diff --git a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go index 666a102ab..0c1dce6f8 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go @@ -54,7 +54,7 @@ func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg *config.MediaAPI w.Header().Set("Content-Type", "application/json") vars := mux.Vars(req) - writers.Download(w, req, gomatrixserverlib.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, activeRemoteRequests) + writers.Download(w, req, gomatrixserverlib.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db, activeRemoteRequests) })), ) diff --git a/src/github.com/matrix-org/dendrite/mediaapi/storage/storage.go b/src/github.com/matrix-org/dendrite/mediaapi/storage/storage.go index 4b86967fb..cb27ccc95 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/storage/storage.go @@ -50,7 +50,11 @@ func (d *Database) StoreMediaMetadata(mediaMetadata *types.MediaMetadata) error // GetMediaMetadata returns metadata about media stored on this server. // The media could have been uploaded to this server or fetched from another server and cached here. -// Returns sql.ErrNoRows if there is no metadata associated with this media. +// Returns nil metadata if there is no metadata associated with this media. func (d *Database) GetMediaMetadata(mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { - return d.statements.selectMedia(mediaID, mediaOrigin) + mediaMetadata, err := d.statements.selectMedia(mediaID, mediaOrigin) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return mediaMetadata, err } diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index 1d1097d4b..9f66409c5 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -17,12 +17,17 @@ package writers import ( "encoding/json" "fmt" + "io" "net/http" + "os" "regexp" + "strconv" log "github.com/Sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/mediaapi/config" + "github.com/matrix-org/dendrite/mediaapi/fileutils" + "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -41,13 +46,17 @@ type downloadRequest struct { } // Download implements /download -func Download(w http.ResponseWriter, req *http.Request, origin gomatrixserverlib.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, activeRemoteRequests *types.ActiveRemoteRequests) { +// Files from this server (i.e. origin == cfg.ServerName) are served directly +func Download(w http.ResponseWriter, req *http.Request, origin gomatrixserverlib.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) { r := &downloadRequest{ MediaMetadata: &types.MediaMetadata{ MediaID: mediaID, Origin: origin, }, - Logger: util.GetLogger(req.Context()), + Logger: util.GetLogger(req.Context()).WithFields(log.Fields{ + "Origin": origin, + "MediaID": mediaID, + }), } // request validation @@ -64,7 +73,10 @@ func Download(w http.ResponseWriter, req *http.Request, origin gomatrixserverlib return } - // doDownload + if resErr := r.doDownload(w, cfg, db, activeRemoteRequests); resErr != nil { + r.jsonErrorResponse(w, *resErr) + return + } } func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse) { @@ -101,3 +113,102 @@ func (r *downloadRequest) Validate() *util.JSONResponse { } return nil } + +func (r *downloadRequest) doDownload(w http.ResponseWriter, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) *util.JSONResponse { + // check if we have a record of the media in our database + mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + if err != nil { + r.Logger.WithError(err).Error("Error querying the database.") + return &util.JSONResponse{ + Code: 500, + JSON: jsonerror.InternalServerError(), + } + } + if mediaMetadata == nil { + if r.MediaMetadata.Origin == cfg.ServerName { + // If we do not have a record and the origin is local, the file is not found + return &util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), + } + } + // TODO: If we do not have a record and the origin is remote, we need to fetch it and respond with that file + return &util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), + } + } + // If we have a record, we can respond from the local file + r.MediaMetadata = mediaMetadata + return r.respondFromLocalFile(w, cfg.AbsBasePath) +} + +// respondFromLocalFile reads a file from local storage and writes it to the http.ResponseWriter +// Returns a util.JSONResponse error in case of error +func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, absBasePath types.Path) *util.JSONResponse { + filePath, err := fileutils.GetPathFromBase64Hash(r.MediaMetadata.Base64Hash, absBasePath) + if err != nil { + r.Logger.WithError(err).Error("Failed to get file path from metadata") + return &util.JSONResponse{ + Code: 500, + JSON: jsonerror.InternalServerError(), + } + } + file, err := os.Open(filePath) + defer file.Close() + if err != nil { + r.Logger.WithError(err).Error("Failed to open file") + return &util.JSONResponse{ + Code: 500, + JSON: jsonerror.InternalServerError(), + } + } + stat, err := file.Stat() + if err != nil { + r.Logger.WithError(err).Error("Failed to stat file") + return &util.JSONResponse{ + Code: 500, + JSON: jsonerror.InternalServerError(), + } + } + + if r.MediaMetadata.FileSizeBytes > 0 && int64(r.MediaMetadata.FileSizeBytes) != stat.Size() { + r.Logger.WithFields(log.Fields{ + "fileSizeDatabase": r.MediaMetadata.FileSizeBytes, + "fileSizeDisk": stat.Size(), + }).Warn("File size in database and on-disk differ.") + return &util.JSONResponse{ + Code: 500, + JSON: jsonerror.InternalServerError(), + } + } + + r.Logger.WithFields(log.Fields{ + "UploadName": r.MediaMetadata.UploadName, + "Base64Hash": r.MediaMetadata.Base64Hash, + "FileSizeBytes": r.MediaMetadata.FileSizeBytes, + "Content-Type": r.MediaMetadata.ContentType, + }).Info("Responding with file") + + w.Header().Set("Content-Type", string(r.MediaMetadata.ContentType)) + w.Header().Set("Content-Length", strconv.FormatInt(int64(r.MediaMetadata.FileSizeBytes), 10)) + contentSecurityPolicy := "default-src 'none';" + + " script-src 'none';" + + " plugin-types application/pdf;" + + " style-src 'unsafe-inline';" + + " object-src 'self';" + w.Header().Set("Content-Security-Policy", contentSecurityPolicy) + + if bytesResponded, err := io.Copy(w, file); err != nil { + r.Logger.WithError(err).Warn("Failed to copy from cache") + if bytesResponded == 0 { + return &util.JSONResponse{ + Code: 500, + JSON: jsonerror.NotFound(fmt.Sprintf("Failed to respond with file with media ID %q", r.MediaMetadata.MediaID)), + } + } + // If we have written any data then we have already responded with 200 OK and all we can do is close the connection + return nil + } + return nil +} diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go index dc525ee44..dabd50073 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go @@ -15,7 +15,6 @@ package writers import ( - "database/sql" "fmt" "io" "net/http" @@ -87,7 +86,7 @@ func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI) (*uploadRe ContentType: types.ContentType(req.Header.Get("Content-Type")), UploadName: types.Filename(url.PathEscape(req.FormValue("filename"))), }, - Logger: util.GetLogger(req.Context()), + Logger: util.GetLogger(req.Context()).WithField("Origin", cfg.ServerName), } if resErr := r.Validate(cfg.MaxFileSizeBytes); resErr != nil { @@ -99,10 +98,9 @@ func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI) (*uploadRe func (r *uploadRequest) doUpload(reqReader io.Reader, cfg *config.MediaAPI, db *storage.Database) *util.JSONResponse { r.Logger.WithFields(log.Fields{ - "Origin": r.MediaMetadata.Origin, "UploadName": r.MediaMetadata.UploadName, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, - "Content-Type": r.MediaMetadata.ContentType, + "ContentType": r.MediaMetadata.ContentType, }).Info("Uploading file") // The file data is hashed and the hash is used as the MediaID. The hash is useful as a @@ -112,8 +110,6 @@ func (r *uploadRequest) doUpload(reqReader io.Reader, cfg *config.MediaAPI, db * hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(reqReader, cfg.MaxFileSizeBytes, cfg.AbsBasePath) if err != nil { r.Logger.WithError(err).WithFields(log.Fields{ - "Origin": r.MediaMetadata.Origin, - "MediaID": r.MediaMetadata.MediaID, "MaxFileSizeBytes": cfg.MaxFileSizeBytes, }).Warn("Error while transferring file") fileutils.RemoveDir(tmpDir, r.Logger) @@ -127,18 +123,26 @@ func (r *uploadRequest) doUpload(reqReader io.Reader, cfg *config.MediaAPI, db * r.MediaMetadata.Base64Hash = hash r.MediaMetadata.MediaID = types.MediaID(hash) + r.Logger = r.Logger.WithField("MediaID", r.MediaMetadata.MediaID) + r.Logger.WithFields(log.Fields{ - "MediaID": r.MediaMetadata.MediaID, - "Origin": r.MediaMetadata.Origin, "Base64Hash": r.MediaMetadata.Base64Hash, "UploadName": r.MediaMetadata.UploadName, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, - "Content-Type": r.MediaMetadata.ContentType, + "ContentType": r.MediaMetadata.ContentType, }).Info("File uploaded") // check if we already have a record of the media in our database and if so, we can remove the temporary directory mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) - if err == nil { + if err != nil { + r.Logger.WithError(err).Error("Error querying the database.") + return &util.JSONResponse{ + Code: 500, + JSON: jsonerror.InternalServerError(), + } + } + + if mediaMetadata != nil { r.MediaMetadata = mediaMetadata fileutils.RemoveDir(tmpDir, r.Logger) return &util.JSONResponse{ @@ -147,8 +151,6 @@ func (r *uploadRequest) doUpload(reqReader io.Reader, cfg *config.MediaAPI, db * ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.MediaMetadata.MediaID), }, } - } else if err != sql.ErrNoRows { - r.Logger.WithError(err).WithField("MediaID", r.MediaMetadata.MediaID).Warn("Failed to query database") } // TODO: generate thumbnails