diff --git a/README.md b/README.md index 96cb3f6..41c1152 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,7 @@ lets-encrypt-hosts | hosts to use for lets encrypt certificates (comma seperated log | path to log file| | LOG | cors-domains | comma separated list of domains for CORS, setting it enable CORS | | CORS_DOMAINS | clamav-host | host for clamav feature | | CLAMAV_HOST | +perform-clamav-prescan | prescan every upload through clamav feature (clamav-host must be a local clamd unix socket) | | PERFORM_CLAMAV_PRESCAN | rate-limit | request per minute | | RATE_LIMIT | max-upload-size | max upload size in kilobytes | | MAX_UPLOAD_SIZE | purge-days | number of days after the uploads are purged automatically | | PURGE_DAYS | diff --git a/cmd/cmd.go b/cmd/cmd.go index 1f9fd88..066b443 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -240,6 +240,11 @@ var globalFlags = []cli.Flag{ Value: "", EnvVar: "CLAMAV_HOST", }, + cli.BoolFlag{ + Name: "perform-clamav-prescan", + Usage: "perform-clamav-prescan", + EnvVar: "PERFORM_CLAMAV_PRESCAN", + }, cli.StringFlag{ Name: "virustotal-key", Usage: "virustotal-key", @@ -388,6 +393,14 @@ func New() *Cmd { options = append(options, server.ClamavHost(v)) } + if v := c.Bool("perform-clamav-prescan"); v != false { + if c.String("clamav-host") == "" { + panic("clamav-host not set") + } + + options = append(options, server.PerformClamavPrescan(v)) + } + if v := c.Int64("max-upload-size"); v > 0 { options = append(options, server.MaxUploadSize(v)) } diff --git a/server/clamav.go b/server/clamav.go index a754a95..a7c89be 100644 --- a/server/clamav.go +++ b/server/clamav.go @@ -27,19 +27,19 @@ THE SOFTWARE. package server import ( - // _ "transfer.sh/app/handlers" - // _ "transfer.sh/app/utils" - + "errors" "fmt" + clamd "github.com/dutchcoders/go-clamd" "io" + "io/ioutil" "net/http" "time" - clamd "github.com/dutchcoders/go-clamd" - "github.com/gorilla/mux" ) +const clamavScanStatusOK = "OK" + func (s *Server) scanHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) @@ -50,25 +50,53 @@ func (s *Server) scanHandler(w http.ResponseWriter, r *http.Request) { s.logger.Printf("Scanning %s %d %s", filename, contentLength, contentType) - var reader io.Reader - - reader = r.Body - - c := clamd.NewClamd(s.ClamAVDaemonHost) - - abort := make(chan bool) - defer close(abort) - response, err := c.ScanStream(reader, abort) + file, err := ioutil.TempFile(s.tempPath, "clamav-") + defer s.cleanTmpFile(file) if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), 500) + http.Error(w, err.Error(), http.StatusInternalServerError) return } + _, err = io.Copy(file, r.Body) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + status, err := s.performScan(file.Name()) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Write([]byte(fmt.Sprintf("%v\n", status))) +} + +func (s *Server) performScan(path string) (string, error) { + c := clamd.NewClamd(s.ClamAVDaemonHost) + + responseCh := make(chan chan *clamd.ScanResult) + errCh := make(chan error) + go func(responseCh chan chan *clamd.ScanResult, errCh chan error) { + response, err := c.ScanFile(path) + if err != nil { + errCh <- err + return + } + + responseCh <- response + }(responseCh, errCh) + select { - case s := <-response: - w.Write([]byte(fmt.Sprintf("%v\n", s.Status))) + case err := <-errCh: + return "", err + case response := <-responseCh: + st := <-response + return st.Status, nil case <-time.After(time.Second * 60): - abort <- true + return "", errors.New("clamav scan timeout") } } diff --git a/server/handlers.go b/server/handlers.go index 921313e..64a2432 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -318,71 +318,69 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { return } - var b bytes.Buffer + file, err := ioutil.TempFile(s.tempPath, "transfer-") + defer s.cleanTmpFile(file) - n, err := io.CopyN(&b, f, _24K+1) - if err != nil && err != io.EOF { + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + n, err := io.Copy(file, f) + if err != nil { s.logger.Printf("%s", err.Error()) http.Error(w, err.Error(), 500) return } - var file *os.File - var reader io.Reader - - if n > _24K { - file, err = ioutil.TempFile(s.tempPath, "transfer-") - if err != nil { - s.logger.Fatal(err) - } - - n, err = io.Copy(file, io.MultiReader(&b, f)) - if err != nil { - s.cleanTmpFile(file) - - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), 500) - return - } - - reader, err = os.Open(file.Name()) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), 500) - return - } - } else { - reader = bytes.NewReader(b.Bytes()) - } - contentLength := n + _, err = file.Seek(0, io.SeekStart) + if err != nil { + s.logger.Printf("%s", err.Error()) + return + } + if s.maxUploadSize > 0 && contentLength > s.maxUploadSize { s.logger.Print("Entity too large") http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) return } + if s.performClamavPrescan { + status, err := s.performScan(file.Name()) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, "Could not perform prescan", http.StatusInternalServerError) + return + } + + if status != clamavScanStatusOK { + s.logger.Printf("prescan positive: %s", status) + http.Error(w, "Clamav prescan found a virus", http.StatusPreconditionFailed) + return + } + } + metadata := metadataForRequest(contentType, s.randomTokenLength, r) buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, errors.New("Could not encode metadata").Error(), 500) + http.Error(w, "Could not encode metadata", 500) - s.cleanTmpFile(file) return } else if err := s.storage.Put(token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, errors.New("Could not save metadata").Error(), 500) + http.Error(w, "Could not save metadata", 500) - s.cleanTmpFile(file) return } s.logger.Printf("Uploading %s %s %d %s", token, filename, contentLength, contentType) - if err = s.storage.Put(token, filename, reader, contentType, uint64(contentLength)); err != nil { + if err = s.storage.Put(token, filename, file, contentType, uint64(contentLength)); err != nil { s.logger.Printf("Backend storage error: %s", err.Error()) http.Error(w, err.Error(), 500) return @@ -392,8 +390,6 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { filename = url.PathEscape(filename) relativeURL, _ := url.Parse(path.Join(s.proxyPath, token, filename)) fmt.Fprintln(w, getURL(r, s.proxyPort).ResolveReference(relativeURL).String()) - - s.cleanTmpFile(file) } } } @@ -458,57 +454,35 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { contentLength := r.ContentLength - var reader io.Reader - - reader = r.Body - defer r.Body.Close() - if contentLength == -1 { - // queue file to disk, because s3 needs content length - var err error - var f io.Reader + file, err := ioutil.TempFile(s.tempPath, "transfer-") + defer s.cleanTmpFile(file) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), 500) + return + } - f = reader + // queue file to disk, because s3 needs content length + // and clamav prescan scans a file + n, err := io.Copy(file, r.Body) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), 500) - var b bytes.Buffer + return + } - n, err := io.CopyN(&b, f, _24K+1) - if err != nil && err != io.EOF { - s.logger.Printf("Error putting new file: %s", err.Error()) - http.Error(w, err.Error(), 500) - return - } + _, err = file.Seek(0, io.SeekStart) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, "Cannot reset cache file", http.StatusInternalServerError) - var file *os.File - - if n > _24K { - file, err = ioutil.TempFile(s.tempPath, "transfer-") - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), 500) - return - } - - defer s.cleanTmpFile(file) - - n, err = io.Copy(file, io.MultiReader(&b, f)) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), 500) - return - } - - reader, err = os.Open(file.Name()) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), 500) - return - } - } else { - reader = bytes.NewReader(b.Bytes()) - } + return + } + if contentLength < 1 { contentLength = n } @@ -520,10 +494,25 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { if contentLength == 0 { s.logger.Print("Empty content-length") - http.Error(w, errors.New("Could not upload empty file").Error(), 400) + http.Error(w, "Could not upload empty file", http.StatusBadRequest) return } + if s.performClamavPrescan { + status, err := s.performScan(file.Name()) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, "Could not perform prescan", http.StatusInternalServerError) + return + } + + if status != clamavScanStatusOK { + s.logger.Printf("prescan positive: %s", status) + http.Error(w, "Clamav prescan found a virus", http.StatusPreconditionFailed) + return + } + } + contentType := mime.TypeByExtension(filepath.Ext(vars["filename"])) token := token(s.randomTokenLength) @@ -533,25 +522,23 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, errors.New("Could not encode metadata").Error(), 500) + http.Error(w, "Could not encode metadata", 500) return } else if !metadata.MaxDate.IsZero() && time.Now().After(metadata.MaxDate) { s.logger.Print("Invalid MaxDate") - http.Error(w, errors.New("Invalid MaxDate, make sure Max-Days is smaller than 290 years").Error(), 400) + http.Error(w, "Invalid MaxDate, make sure Max-Days is smaller than 290 years", http.StatusBadRequest) return } else if err := s.storage.Put(token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, errors.New("Could not save metadata").Error(), 500) + http.Error(w, "Could not save metadata", 500) return } s.logger.Printf("Uploading %s %s %d %s", token, filename, contentLength, contentType) - var err error - - if err = s.storage.Put(token, filename, reader, contentType, uint64(contentLength)); err != nil { + if err = s.storage.Put(token, filename, file, contentType, uint64(contentLength)); err != nil { s.logger.Printf("Error putting new file: %s", err.Error()) - http.Error(w, errors.New("Could not save file").Error(), 500) + http.Error(w, "Could not save file", http.StatusInternalServerError) return } @@ -718,9 +705,9 @@ func (s *Server) checkMetadata(token, filename string, increaseDownload bool) (m buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { - return metadata, errors.New("Could not encode metadata") + return metadata, errors.New("could not encode metadata") } else if err := s.storage.Put(token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil { - return metadata, errors.New("Could not save metadata") + return metadata, errors.New("could not save metadata") } } @@ -783,7 +770,7 @@ func (s *Server) deleteHandler(w http.ResponseWriter, r *http.Request) { return } else if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not delete file.", 500) + http.Error(w, "Could not delete file.", http.StatusInternalServerError) return } } @@ -821,7 +808,7 @@ func (s *Server) zipHandler(w http.ResponseWriter, r *http.Request) { } s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", 500) + http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) return } @@ -838,20 +825,20 @@ func (s *Server) zipHandler(w http.ResponseWriter, r *http.Request) { if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", 500) + http.Error(w, "Internal server error.", http.StatusInternalServerError) return } if _, err = io.Copy(fw, reader); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", 500) + http.Error(w, "Internal server error.", http.StatusInternalServerError) return } } if err := zw.Close(); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", 500) + http.Error(w, "Internal server error.", http.StatusInternalServerError) return } } @@ -892,7 +879,7 @@ func (s *Server) tarGzHandler(w http.ResponseWriter, r *http.Request) { } s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", 500) + http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) return } @@ -906,13 +893,13 @@ func (s *Server) tarGzHandler(w http.ResponseWriter, r *http.Request) { err = zw.WriteHeader(header) if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", 500) + http.Error(w, "Internal server error.", http.StatusInternalServerError) return } if _, err = io.Copy(zw, reader); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", 500) + http.Error(w, "Internal server error.", http.StatusInternalServerError) return } } @@ -951,7 +938,7 @@ func (s *Server) tarHandler(w http.ResponseWriter, r *http.Request) { } s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", 500) + http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) return } @@ -965,13 +952,13 @@ func (s *Server) tarHandler(w http.ResponseWriter, r *http.Request) { err = zw.WriteHeader(header) if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", 500) + http.Error(w, "Internal server error.", http.StatusInternalServerError) return } if _, err = io.Copy(zw, reader); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", 500) + http.Error(w, "Internal server error.", http.StatusInternalServerError) return } } @@ -998,7 +985,7 @@ func (s *Server) headHandler(w http.ResponseWriter, r *http.Request) { return } else if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", 500) + http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) return } @@ -1033,7 +1020,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { return } else if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", 500) + http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) return } @@ -1064,7 +1051,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { file, err := ioutil.TempFile(s.tempPath, "range-") if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Error occurred copying to output stream", 500) + http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) return } @@ -1073,7 +1060,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { _, err = io.Copy(file, reader) if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Error occurred copying to output stream", 500) + http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) return } @@ -1083,7 +1070,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { if _, err = io.Copy(w, reader); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Error occurred copying to output stream", 500) + http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) return } diff --git a/server/server.go b/server/server.go index 3d55e0c..b77537d 100644 --- a/server/server.go +++ b/server/server.go @@ -76,6 +76,13 @@ func ClamavHost(s string) OptionFn { } } +// PerformClamavPrescan enables clamav prescan on upload +func PerformClamavPrescan(b bool) OptionFn { + return func(srvr *Server) { + srvr.performClamavPrescan = b + } +} + // VirustotalKey sets virus total key func VirustotalKey(s string) OptionFn { return func(srvr *Server) { @@ -338,8 +345,9 @@ type Server struct { ipFilterOptions *IPFilterOptions - VirusTotalKey string - ClamAVDaemonHost string + VirusTotalKey string + ClamAVDaemonHost string + performClamavPrescan bool tempPath string