diff --git a/Makefile b/Makefile deleted file mode 100644 index a0b12f1..0000000 --- a/Makefile +++ /dev/null @@ -1,5 +0,0 @@ -.PHONY: lint - -lint: - golangci-lint run --out-format=github-actions --config .golangci.yml - diff --git a/README.md b/README.md index 41c1152..96cb3f6 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,6 @@ lets-encrypt-hosts | hosts to use for lets encrypt certificates (comma seperated log | path to log file| | LOG | cors-domains | comma separated list of domains for CORS, setting it enable CORS | | CORS_DOMAINS | clamav-host | host for clamav feature | | CLAMAV_HOST | -perform-clamav-prescan | prescan every upload through clamav feature (clamav-host must be a local clamd unix socket) | | PERFORM_CLAMAV_PRESCAN | rate-limit | request per minute | | RATE_LIMIT | max-upload-size | max upload size in kilobytes | | MAX_UPLOAD_SIZE | purge-days | number of days after the uploads are purged automatically | | PURGE_DAYS | diff --git a/cmd/cmd.go b/cmd/cmd.go index abeafce..1f9fd88 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -240,11 +240,6 @@ var globalFlags = []cli.Flag{ Value: "", EnvVar: "CLAMAV_HOST", }, - cli.BoolFlag{ - Name: "perform-clamav-prescan", - Usage: "perform-clamav-prescan", - EnvVar: "PERFORM_CLAMAV_PRESCAN", - }, cli.StringFlag{ Name: "virustotal-key", Usage: "virustotal-key", @@ -393,14 +388,6 @@ func New() *Cmd { options = append(options, server.ClamavHost(v)) } - if v := c.Bool("perform-clamav-prescan"); !v { - if c.String("clamav-host") == "" { - panic("clamav-host not set") - } - - options = append(options, server.PerformClamavPrescan(v)) - } - if v := c.Int64("max-upload-size"); v > 0 { options = append(options, server.MaxUploadSize(v)) } diff --git a/server/clamav.go b/server/clamav.go index 9b86589..eec5b90 100644 --- a/server/clamav.go +++ b/server/clamav.go @@ -27,19 +27,18 @@ THE SOFTWARE. package server import ( - "errors" + // _ "transfer.sh/app/handlers" + // _ "transfer.sh/app/utils" + "fmt" - "io" - "io/ioutil" "net/http" "time" clamd "github.com/dutchcoders/go-clamd" + "github.com/gorilla/mux" ) -const clamavScanStatusOK = "OK" - func (s *Server) scanHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) @@ -50,53 +49,23 @@ func (s *Server) scanHandler(w http.ResponseWriter, r *http.Request) { s.logger.Printf("Scanning %s %d %s", filename, contentLength, contentType) - file, err := ioutil.TempFile(s.tempPath, "clamav-") - defer s.cleanTmpFile(file) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + reader := r.Body - _, err = io.Copy(file, r.Body) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - status, err := s.performScan(file.Name()) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - _, _ = w.Write([]byte(fmt.Sprintf("%v\n", status))) -} - -func (s *Server) performScan(path string) (string, error) { c := clamd.NewClamd(s.ClamAVDaemonHost) - responseCh := make(chan chan *clamd.ScanResult) - errCh := make(chan error) - go func(responseCh chan chan *clamd.ScanResult, errCh chan error) { - response, err := c.ScanFile(path) - if err != nil { - errCh <- err - return - } - - responseCh <- response - }(responseCh, errCh) + abort := make(chan bool) + defer close(abort) + response, err := c.ScanStream(reader, abort) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), 500) + return + } select { - case err := <-errCh: - return "", err - case response := <-responseCh: - st := <-response - return st.Status, nil + case s := <-response: + _, _ = w.Write([]byte(fmt.Sprintf("%v\n", s.Status))) case <-time.After(time.Second * 60): - return "", errors.New("clamav scan timeout") + abort <- true } } diff --git a/server/handlers.go b/server/handlers.go index eafc522..2e46145 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -291,7 +291,7 @@ func sanitize(fileName string) string { func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { if err := r.ParseMultipartForm(_24K); nil != err { s.logger.Printf("%s", err.Error()) - http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) + http.Error(w, "Error occurred copying to output stream", 500) return } @@ -309,75 +309,74 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { if f, err = fheader.Open(); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), 500) return } - file, err := ioutil.TempFile(s.tempPath, "transfer-") - defer s.cleanTmpFile(file) + var b bytes.Buffer - if err != nil { + n, err := io.CopyN(&b, f, _24K+1) + if err != nil && err != io.EOF { s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), 500) return } - n, err := io.Copy(file, f) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return + var file *os.File + var reader io.Reader + + if n > _24K { + file, err = ioutil.TempFile(s.tempPath, "transfer-") + defer s.cleanTmpFile(file) + if err != nil { + s.logger.Fatal(err) + } + + n, err = io.Copy(file, io.MultiReader(&b, f)) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), 500) + return + } + + reader, err = os.Open(file.Name()) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), 500) + return + } + } else { + reader = bytes.NewReader(b.Bytes()) } contentLength := n - _, err = file.Seek(0, io.SeekStart) - if err != nil { - s.logger.Printf("%s", err.Error()) - return - } - if s.maxUploadSize > 0 && contentLength > s.maxUploadSize { s.logger.Print("Entity too large") http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) return } - if s.performClamavPrescan { - status, err := s.performScan(file.Name()) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not perform prescan", http.StatusInternalServerError) - return - } - - if status != clamavScanStatusOK { - s.logger.Printf("prescan positive: %s", status) - http.Error(w, "Clamav prescan found a virus", http.StatusPreconditionFailed) - return - } - } - metadata := metadataForRequest(contentType, s.randomTokenLength, r) buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not encode metadata", http.StatusInternalServerError) + http.Error(w, "Could not encode metadata", 500) return } else if err := s.storage.Put(r.Context(), token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not save metadata", http.StatusInternalServerError) + http.Error(w, "Could not save metadata", 500) return } s.logger.Printf("Uploading %s %s %d %s", token, filename, contentLength, contentType) - if err = s.storage.Put(r.Context(), token, filename, file, contentType, uint64(contentLength)); err != nil { + if err = s.storage.Put(r.Context(), token, filename, reader, contentType, uint64(contentLength)); err != nil { s.logger.Printf("Backend storage error: %s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), 500) return } @@ -449,35 +448,56 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { contentLength := r.ContentLength + var reader io.Reader + + reader = r.Body + defer CloseCheck(r.Body.Close) - file, err := ioutil.TempFile(s.tempPath, "transfer-") - defer s.cleanTmpFile(file) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + if contentLength == -1 { + // queue file to disk, because s3 needs content length + var err error - // queue file to disk, because s3 needs content length - // and clamav prescan scans a file - n, err := io.Copy(file, r.Body) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) + f := reader - return - } + var b bytes.Buffer - _, err = file.Seek(0, io.SeekStart) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, "Cannot reset cache file", http.StatusInternalServerError) + n, err := io.CopyN(&b, f, _24K+1) + if err != nil && err != io.EOF { + s.logger.Printf("Error putting new file: %s", err.Error()) + http.Error(w, err.Error(), 500) + return + } - return - } + var file *os.File + + if n > _24K { + file, err = ioutil.TempFile(s.tempPath, "transfer-") + defer s.cleanTmpFile(file) + + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), 500) + return + } + + n, err = io.Copy(file, io.MultiReader(&b, f)) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), 500) + return + } + + reader, err = os.Open(file.Name()) + if err != nil { + s.logger.Printf("%s", err.Error()) + http.Error(w, err.Error(), 500) + return + } + } else { + reader = bytes.NewReader(b.Bytes()) + } - if contentLength < 1 { contentLength = n } @@ -489,25 +509,10 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { if contentLength == 0 { s.logger.Print("Empty content-length") - http.Error(w, "Could not upload empty file", http.StatusBadRequest) + http.Error(w, "Could not upload empty file", 400) return } - if s.performClamavPrescan { - status, err := s.performScan(file.Name()) - if err != nil { - s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not perform prescan", http.StatusInternalServerError) - return - } - - if status != clamavScanStatusOK { - s.logger.Printf("prescan positive: %s", status) - http.Error(w, "Clamav prescan found a virus", http.StatusPreconditionFailed) - return - } - } - contentType := mime.TypeByExtension(filepath.Ext(vars["filename"])) token := token(s.randomTokenLength) @@ -517,23 +522,25 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { buffer := &bytes.Buffer{} if err := json.NewEncoder(buffer).Encode(metadata); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not encode metadata", http.StatusInternalServerError) + http.Error(w, "Could not encode metadata", 500) return } else if !metadata.MaxDate.IsZero() && time.Now().After(metadata.MaxDate) { s.logger.Print("Invalid MaxDate") - http.Error(w, "Invalid MaxDate, make sure Max-Days is smaller than 290 years", http.StatusBadRequest) + http.Error(w, "Invalid MaxDate, make sure Max-Days is smaller than 290 years", 400) return } else if err := s.storage.Put(r.Context(), token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not save metadata", http.StatusInternalServerError) + http.Error(w, "Could not save metadata", 500) return } s.logger.Printf("Uploading %s %s %d %s", token, filename, contentLength, contentType) - if err = s.storage.Put(r.Context(), token, filename, file, contentType, uint64(contentLength)); err != nil { + var err error + + if err = s.storage.Put(r.Context(), token, filename, reader, contentType, uint64(contentLength)); err != nil { s.logger.Printf("Error putting new file: %s", err.Error()) - http.Error(w, "Could not save file", http.StatusInternalServerError) + http.Error(w, "Could not save file", 500) return } @@ -759,7 +766,7 @@ func (s *Server) deleteHandler(w http.ResponseWriter, r *http.Request) { return } else if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not delete file.", http.StatusInternalServerError) + http.Error(w, "Could not delete file.", 500) return } } @@ -798,7 +805,7 @@ func (s *Server) zipHandler(w http.ResponseWriter, r *http.Request) { } s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) + http.Error(w, "Could not retrieve file.", 500) return } @@ -813,20 +820,20 @@ func (s *Server) zipHandler(w http.ResponseWriter, r *http.Request) { if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", http.StatusInternalServerError) + http.Error(w, "Internal server error.", 500) return } if _, err = io.Copy(fw, reader); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", http.StatusInternalServerError) + http.Error(w, "Internal server error.", 500) return } } if err := zw.Close(); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", http.StatusInternalServerError) + http.Error(w, "Internal server error.", 500) return } } @@ -869,7 +876,7 @@ func (s *Server) tarGzHandler(w http.ResponseWriter, r *http.Request) { } s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) + http.Error(w, "Could not retrieve file.", 500) return } @@ -881,13 +888,13 @@ func (s *Server) tarGzHandler(w http.ResponseWriter, r *http.Request) { err = zw.WriteHeader(header) if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", http.StatusInternalServerError) + http.Error(w, "Internal server error.", 500) return } if _, err = io.Copy(zw, reader); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", http.StatusInternalServerError) + http.Error(w, "Internal server error.", 500) return } } @@ -928,7 +935,7 @@ func (s *Server) tarHandler(w http.ResponseWriter, r *http.Request) { } s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) + http.Error(w, "Could not retrieve file.", 500) return } @@ -940,13 +947,13 @@ func (s *Server) tarHandler(w http.ResponseWriter, r *http.Request) { err = zw.WriteHeader(header) if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", http.StatusInternalServerError) + http.Error(w, "Internal server error.", 500) return } if _, err = io.Copy(zw, reader); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Internal server error.", http.StatusInternalServerError) + http.Error(w, "Internal server error.", 500) return } } @@ -973,7 +980,7 @@ func (s *Server) headHandler(w http.ResponseWriter, r *http.Request) { return } else if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) + http.Error(w, "Could not retrieve file.", 500) return } @@ -1010,7 +1017,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { return } else if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) + http.Error(w, "Could not retrieve file.", 500) return } @@ -1041,14 +1048,14 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) + http.Error(w, "Error occurred copying to output stream", 500) return } _, err = io.Copy(file, reader) if err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) + http.Error(w, "Error occurred copying to output stream", 500) return } @@ -1058,7 +1065,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { if _, err = io.Copy(w, reader); err != nil { s.logger.Printf("%s", err.Error()) - http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) + http.Error(w, "Error occurred copying to output stream", 500) return } } @@ -1129,12 +1136,12 @@ func (s *Server) basicAuthHandler(h http.Handler) http.HandlerFunc { username, password, authOK := r.BasicAuth() if !authOK { - http.Error(w, "Not authorized", http.StatusUnauthorized) + http.Error(w, "Not authorized", 401) return } if username != s.AuthUser || password != s.AuthPass { - http.Error(w, "Not authorized", http.StatusUnauthorized) + http.Error(w, "Not authorized", 401) return } diff --git a/server/server.go b/server/server.go index c2b1e38..f12ad19 100644 --- a/server/server.go +++ b/server/server.go @@ -76,13 +76,6 @@ func ClamavHost(s string) OptionFn { } } -// PerformClamavPrescan enables clamav prescan on upload -func PerformClamavPrescan(b bool) OptionFn { - return func(srvr *Server) { - srvr.performClamavPrescan = b - } -} - // VirustotalKey sets virus total key func VirustotalKey(s string) OptionFn { return func(srvr *Server) { @@ -345,9 +338,8 @@ type Server struct { ipFilterOptions *IPFilterOptions - VirusTotalKey string - ClamAVDaemonHost string - performClamavPrescan bool + VirusTotalKey string + ClamAVDaemonHost string tempPath string @@ -432,17 +424,13 @@ func (s *Server) Run() { s.logger.Panicf("Unable to parse: path=%s, err=%s", path, err) } - if strings.HasSuffix(path, ".html") { - _, err = htmlTemplates.New(stripPrefix(path)).Parse(string(bytes)) - if err != nil { - s.logger.Println("Unable to parse html template", err) - } + _, err = htmlTemplates.New(stripPrefix(path)).Parse(string(bytes)) + if err != nil { + s.logger.Println("Unable to parse html template", err) } - if strings.HasSuffix(path, ".txt") { - _, err = textTemplates.New(stripPrefix(path)).Parse(string(bytes)) - if err != nil { - s.logger.Println("Unable to parse text template", err) - } + _, err = textTemplates.New(stripPrefix(path)).Parse(string(bytes)) + if err != nil { + s.logger.Println("Unable to parse text template", err) } } } diff --git a/server/virustotal.go b/server/virustotal.go index 7e4be0c..24fa8e7 100644 --- a/server/virustotal.go +++ b/server/virustotal.go @@ -45,14 +45,14 @@ func (s *Server) virusTotalHandler(w http.ResponseWriter, r *http.Request) { vt, err := virustotal.NewVirusTotal(s.VirusTotalKey) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), 500) } reader := r.Body result, err := vt.Scan(filename, reader) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), 500) } s.logger.Println(result)