Revert "Clamav prescan (#389)"

This reverts commit cff0a88bf3.
This commit is contained in:
Andrea Spacca
2022-01-10 10:50:11 +01:00
committed by GitHub
parent cff0a88bf3
commit 1cdbfe709b
7 changed files with 130 additions and 185 deletions

View File

@ -1,5 +0,0 @@
.PHONY: lint
lint:
golangci-lint run --out-format=github-actions --config .golangci.yml

View File

@ -111,7 +111,6 @@ lets-encrypt-hosts | hosts to use for lets encrypt certificates (comma seperated
log | path to log file| | LOG | log | path to log file| | LOG |
cors-domains | comma separated list of domains for CORS, setting it enable CORS | | CORS_DOMAINS | cors-domains | comma separated list of domains for CORS, setting it enable CORS | | CORS_DOMAINS |
clamav-host | host for clamav feature | | CLAMAV_HOST | clamav-host | host for clamav feature | | CLAMAV_HOST |
perform-clamav-prescan | prescan every upload through clamav feature (clamav-host must be a local clamd unix socket) | | PERFORM_CLAMAV_PRESCAN |
rate-limit | request per minute | | RATE_LIMIT | rate-limit | request per minute | | RATE_LIMIT |
max-upload-size | max upload size in kilobytes | | MAX_UPLOAD_SIZE | max-upload-size | max upload size in kilobytes | | MAX_UPLOAD_SIZE |
purge-days | number of days after the uploads are purged automatically | | PURGE_DAYS | purge-days | number of days after the uploads are purged automatically | | PURGE_DAYS |

View File

@ -240,11 +240,6 @@ var globalFlags = []cli.Flag{
Value: "", Value: "",
EnvVar: "CLAMAV_HOST", EnvVar: "CLAMAV_HOST",
}, },
cli.BoolFlag{
Name: "perform-clamav-prescan",
Usage: "perform-clamav-prescan",
EnvVar: "PERFORM_CLAMAV_PRESCAN",
},
cli.StringFlag{ cli.StringFlag{
Name: "virustotal-key", Name: "virustotal-key",
Usage: "virustotal-key", Usage: "virustotal-key",
@ -393,14 +388,6 @@ func New() *Cmd {
options = append(options, server.ClamavHost(v)) options = append(options, server.ClamavHost(v))
} }
if v := c.Bool("perform-clamav-prescan"); !v {
if c.String("clamav-host") == "" {
panic("clamav-host not set")
}
options = append(options, server.PerformClamavPrescan(v))
}
if v := c.Int64("max-upload-size"); v > 0 { if v := c.Int64("max-upload-size"); v > 0 {
options = append(options, server.MaxUploadSize(v)) options = append(options, server.MaxUploadSize(v))
} }

View File

@ -27,19 +27,18 @@ THE SOFTWARE.
package server package server
import ( import (
"errors" // _ "transfer.sh/app/handlers"
// _ "transfer.sh/app/utils"
"fmt" "fmt"
"io"
"io/ioutil"
"net/http" "net/http"
"time" "time"
clamd "github.com/dutchcoders/go-clamd" clamd "github.com/dutchcoders/go-clamd"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
const clamavScanStatusOK = "OK"
func (s *Server) scanHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) scanHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
@ -50,53 +49,23 @@ func (s *Server) scanHandler(w http.ResponseWriter, r *http.Request) {
s.logger.Printf("Scanning %s %d %s", filename, contentLength, contentType) s.logger.Printf("Scanning %s %d %s", filename, contentLength, contentType)
file, err := ioutil.TempFile(s.tempPath, "clamav-") reader := r.Body
defer s.cleanTmpFile(file)
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, err = io.Copy(file, r.Body)
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
status, err := s.performScan(file.Name())
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write([]byte(fmt.Sprintf("%v\n", status)))
}
func (s *Server) performScan(path string) (string, error) {
c := clamd.NewClamd(s.ClamAVDaemonHost) c := clamd.NewClamd(s.ClamAVDaemonHost)
responseCh := make(chan chan *clamd.ScanResult) abort := make(chan bool)
errCh := make(chan error) defer close(abort)
go func(responseCh chan chan *clamd.ScanResult, errCh chan error) { response, err := c.ScanStream(reader, abort)
response, err := c.ScanFile(path)
if err != nil { if err != nil {
errCh <- err s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), 500)
return return
} }
responseCh <- response
}(responseCh, errCh)
select { select {
case err := <-errCh: case s := <-response:
return "", err _, _ = w.Write([]byte(fmt.Sprintf("%v\n", s.Status)))
case response := <-responseCh:
st := <-response
return st.Status, nil
case <-time.After(time.Second * 60): case <-time.After(time.Second * 60):
return "", errors.New("clamav scan timeout") abort <- true
} }
} }

View File

@ -291,7 +291,7 @@ func sanitize(fileName string) string {
func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(_24K); nil != err { if err := r.ParseMultipartForm(_24K); nil != err {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) http.Error(w, "Error occurred copying to output stream", 500)
return return
} }
@ -309,75 +309,74 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) {
if f, err = fheader.Open(); err != nil { if f, err = fheader.Open(); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), 500)
return return
} }
file, err := ioutil.TempFile(s.tempPath, "transfer-") var b bytes.Buffer
n, err := io.CopyN(&b, f, _24K+1)
if err != nil && err != io.EOF {
s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), 500)
return
}
var file *os.File
var reader io.Reader
if n > _24K {
file, err = ioutil.TempFile(s.tempPath, "transfer-")
defer s.cleanTmpFile(file) defer s.cleanTmpFile(file)
if err != nil {
s.logger.Fatal(err)
}
n, err = io.Copy(file, io.MultiReader(&b, f))
if err != nil { if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), 500)
return return
} }
n, err := io.Copy(file, f) reader, err = os.Open(file.Name())
if err != nil { if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), 500)
return return
} }
} else {
reader = bytes.NewReader(b.Bytes())
}
contentLength := n contentLength := n
_, err = file.Seek(0, io.SeekStart)
if err != nil {
s.logger.Printf("%s", err.Error())
return
}
if s.maxUploadSize > 0 && contentLength > s.maxUploadSize { if s.maxUploadSize > 0 && contentLength > s.maxUploadSize {
s.logger.Print("Entity too large") s.logger.Print("Entity too large")
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
return return
} }
if s.performClamavPrescan {
status, err := s.performScan(file.Name())
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, "Could not perform prescan", http.StatusInternalServerError)
return
}
if status != clamavScanStatusOK {
s.logger.Printf("prescan positive: %s", status)
http.Error(w, "Clamav prescan found a virus", http.StatusPreconditionFailed)
return
}
}
metadata := metadataForRequest(contentType, s.randomTokenLength, r) metadata := metadataForRequest(contentType, s.randomTokenLength, r)
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
if err := json.NewEncoder(buffer).Encode(metadata); err != nil { if err := json.NewEncoder(buffer).Encode(metadata); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not encode metadata", http.StatusInternalServerError) http.Error(w, "Could not encode metadata", 500)
return return
} else if err := s.storage.Put(r.Context(), token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil { } else if err := s.storage.Put(r.Context(), token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not save metadata", http.StatusInternalServerError) http.Error(w, "Could not save metadata", 500)
return return
} }
s.logger.Printf("Uploading %s %s %d %s", token, filename, contentLength, contentType) s.logger.Printf("Uploading %s %s %d %s", token, filename, contentLength, contentType)
if err = s.storage.Put(r.Context(), token, filename, file, contentType, uint64(contentLength)); err != nil { if err = s.storage.Put(r.Context(), token, filename, reader, contentType, uint64(contentLength)); err != nil {
s.logger.Printf("Backend storage error: %s", err.Error()) s.logger.Printf("Backend storage error: %s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), 500)
return return
} }
@ -449,35 +448,56 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) {
contentLength := r.ContentLength contentLength := r.ContentLength
var reader io.Reader
reader = r.Body
defer CloseCheck(r.Body.Close) defer CloseCheck(r.Body.Close)
file, err := ioutil.TempFile(s.tempPath, "transfer-") if contentLength == -1 {
defer s.cleanTmpFile(file)
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// queue file to disk, because s3 needs content length // queue file to disk, because s3 needs content length
// and clamav prescan scans a file var err error
n, err := io.Copy(file, r.Body)
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError)
f := reader
var b bytes.Buffer
n, err := io.CopyN(&b, f, _24K+1)
if err != nil && err != io.EOF {
s.logger.Printf("Error putting new file: %s", err.Error())
http.Error(w, err.Error(), 500)
return return
} }
_, err = file.Seek(0, io.SeekStart) var file *os.File
if n > _24K {
file, err = ioutil.TempFile(s.tempPath, "transfer-")
defer s.cleanTmpFile(file)
if err != nil { if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Cannot reset cache file", http.StatusInternalServerError) http.Error(w, err.Error(), 500)
return return
} }
if contentLength < 1 { n, err = io.Copy(file, io.MultiReader(&b, f))
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), 500)
return
}
reader, err = os.Open(file.Name())
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, err.Error(), 500)
return
}
} else {
reader = bytes.NewReader(b.Bytes())
}
contentLength = n contentLength = n
} }
@ -489,25 +509,10 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) {
if contentLength == 0 { if contentLength == 0 {
s.logger.Print("Empty content-length") s.logger.Print("Empty content-length")
http.Error(w, "Could not upload empty file", http.StatusBadRequest) http.Error(w, "Could not upload empty file", 400)
return return
} }
if s.performClamavPrescan {
status, err := s.performScan(file.Name())
if err != nil {
s.logger.Printf("%s", err.Error())
http.Error(w, "Could not perform prescan", http.StatusInternalServerError)
return
}
if status != clamavScanStatusOK {
s.logger.Printf("prescan positive: %s", status)
http.Error(w, "Clamav prescan found a virus", http.StatusPreconditionFailed)
return
}
}
contentType := mime.TypeByExtension(filepath.Ext(vars["filename"])) contentType := mime.TypeByExtension(filepath.Ext(vars["filename"]))
token := token(s.randomTokenLength) token := token(s.randomTokenLength)
@ -517,23 +522,25 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) {
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
if err := json.NewEncoder(buffer).Encode(metadata); err != nil { if err := json.NewEncoder(buffer).Encode(metadata); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not encode metadata", http.StatusInternalServerError) http.Error(w, "Could not encode metadata", 500)
return return
} else if !metadata.MaxDate.IsZero() && time.Now().After(metadata.MaxDate) { } else if !metadata.MaxDate.IsZero() && time.Now().After(metadata.MaxDate) {
s.logger.Print("Invalid MaxDate") s.logger.Print("Invalid MaxDate")
http.Error(w, "Invalid MaxDate, make sure Max-Days is smaller than 290 years", http.StatusBadRequest) http.Error(w, "Invalid MaxDate, make sure Max-Days is smaller than 290 years", 400)
return return
} else if err := s.storage.Put(r.Context(), token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil { } else if err := s.storage.Put(r.Context(), token, fmt.Sprintf("%s.metadata", filename), buffer, "text/json", uint64(buffer.Len())); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not save metadata", http.StatusInternalServerError) http.Error(w, "Could not save metadata", 500)
return return
} }
s.logger.Printf("Uploading %s %s %d %s", token, filename, contentLength, contentType) s.logger.Printf("Uploading %s %s %d %s", token, filename, contentLength, contentType)
if err = s.storage.Put(r.Context(), token, filename, file, contentType, uint64(contentLength)); err != nil { var err error
if err = s.storage.Put(r.Context(), token, filename, reader, contentType, uint64(contentLength)); err != nil {
s.logger.Printf("Error putting new file: %s", err.Error()) s.logger.Printf("Error putting new file: %s", err.Error())
http.Error(w, "Could not save file", http.StatusInternalServerError) http.Error(w, "Could not save file", 500)
return return
} }
@ -759,7 +766,7 @@ func (s *Server) deleteHandler(w http.ResponseWriter, r *http.Request) {
return return
} else if err != nil { } else if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not delete file.", http.StatusInternalServerError) http.Error(w, "Could not delete file.", 500)
return return
} }
} }
@ -798,7 +805,7 @@ func (s *Server) zipHandler(w http.ResponseWriter, r *http.Request) {
} }
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) http.Error(w, "Could not retrieve file.", 500)
return return
} }
@ -813,20 +820,20 @@ func (s *Server) zipHandler(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Internal server error.", http.StatusInternalServerError) http.Error(w, "Internal server error.", 500)
return return
} }
if _, err = io.Copy(fw, reader); err != nil { if _, err = io.Copy(fw, reader); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Internal server error.", http.StatusInternalServerError) http.Error(w, "Internal server error.", 500)
return return
} }
} }
if err := zw.Close(); err != nil { if err := zw.Close(); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Internal server error.", http.StatusInternalServerError) http.Error(w, "Internal server error.", 500)
return return
} }
} }
@ -869,7 +876,7 @@ func (s *Server) tarGzHandler(w http.ResponseWriter, r *http.Request) {
} }
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) http.Error(w, "Could not retrieve file.", 500)
return return
} }
@ -881,13 +888,13 @@ func (s *Server) tarGzHandler(w http.ResponseWriter, r *http.Request) {
err = zw.WriteHeader(header) err = zw.WriteHeader(header)
if err != nil { if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Internal server error.", http.StatusInternalServerError) http.Error(w, "Internal server error.", 500)
return return
} }
if _, err = io.Copy(zw, reader); err != nil { if _, err = io.Copy(zw, reader); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Internal server error.", http.StatusInternalServerError) http.Error(w, "Internal server error.", 500)
return return
} }
} }
@ -928,7 +935,7 @@ func (s *Server) tarHandler(w http.ResponseWriter, r *http.Request) {
} }
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) http.Error(w, "Could not retrieve file.", 500)
return return
} }
@ -940,13 +947,13 @@ func (s *Server) tarHandler(w http.ResponseWriter, r *http.Request) {
err = zw.WriteHeader(header) err = zw.WriteHeader(header)
if err != nil { if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Internal server error.", http.StatusInternalServerError) http.Error(w, "Internal server error.", 500)
return return
} }
if _, err = io.Copy(zw, reader); err != nil { if _, err = io.Copy(zw, reader); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Internal server error.", http.StatusInternalServerError) http.Error(w, "Internal server error.", 500)
return return
} }
} }
@ -973,7 +980,7 @@ func (s *Server) headHandler(w http.ResponseWriter, r *http.Request) {
return return
} else if err != nil { } else if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) http.Error(w, "Could not retrieve file.", 500)
return return
} }
@ -1010,7 +1017,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) {
return return
} else if err != nil { } else if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", http.StatusInternalServerError) http.Error(w, "Could not retrieve file.", 500)
return return
} }
@ -1041,14 +1048,14 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) http.Error(w, "Error occurred copying to output stream", 500)
return return
} }
_, err = io.Copy(file, reader) _, err = io.Copy(file, reader)
if err != nil { if err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) http.Error(w, "Error occurred copying to output stream", 500)
return return
} }
@ -1058,7 +1065,7 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) {
if _, err = io.Copy(w, reader); err != nil { if _, err = io.Copy(w, reader); err != nil {
s.logger.Printf("%s", err.Error()) s.logger.Printf("%s", err.Error())
http.Error(w, "Error occurred copying to output stream", http.StatusInternalServerError) http.Error(w, "Error occurred copying to output stream", 500)
return return
} }
} }
@ -1129,12 +1136,12 @@ func (s *Server) basicAuthHandler(h http.Handler) http.HandlerFunc {
username, password, authOK := r.BasicAuth() username, password, authOK := r.BasicAuth()
if !authOK { if !authOK {
http.Error(w, "Not authorized", http.StatusUnauthorized) http.Error(w, "Not authorized", 401)
return return
} }
if username != s.AuthUser || password != s.AuthPass { if username != s.AuthUser || password != s.AuthPass {
http.Error(w, "Not authorized", http.StatusUnauthorized) http.Error(w, "Not authorized", 401)
return return
} }

View File

@ -76,13 +76,6 @@ func ClamavHost(s string) OptionFn {
} }
} }
// PerformClamavPrescan enables clamav prescan on upload
func PerformClamavPrescan(b bool) OptionFn {
return func(srvr *Server) {
srvr.performClamavPrescan = b
}
}
// VirustotalKey sets virus total key // VirustotalKey sets virus total key
func VirustotalKey(s string) OptionFn { func VirustotalKey(s string) OptionFn {
return func(srvr *Server) { return func(srvr *Server) {
@ -347,7 +340,6 @@ type Server struct {
VirusTotalKey string VirusTotalKey string
ClamAVDaemonHost string ClamAVDaemonHost string
performClamavPrescan bool
tempPath string tempPath string
@ -432,20 +424,16 @@ func (s *Server) Run() {
s.logger.Panicf("Unable to parse: path=%s, err=%s", path, err) s.logger.Panicf("Unable to parse: path=%s, err=%s", path, err)
} }
if strings.HasSuffix(path, ".html") {
_, err = htmlTemplates.New(stripPrefix(path)).Parse(string(bytes)) _, err = htmlTemplates.New(stripPrefix(path)).Parse(string(bytes))
if err != nil { if err != nil {
s.logger.Println("Unable to parse html template", err) s.logger.Println("Unable to parse html template", err)
} }
}
if strings.HasSuffix(path, ".txt") {
_, err = textTemplates.New(stripPrefix(path)).Parse(string(bytes)) _, err = textTemplates.New(stripPrefix(path)).Parse(string(bytes))
if err != nil { if err != nil {
s.logger.Println("Unable to parse text template", err) s.logger.Println("Unable to parse text template", err)
} }
} }
} }
}
staticHandler := http.FileServer(fs) staticHandler := http.FileServer(fs)

View File

@ -45,14 +45,14 @@ func (s *Server) virusTotalHandler(w http.ResponseWriter, r *http.Request) {
vt, err := virustotal.NewVirusTotal(s.VirusTotalKey) vt, err := virustotal.NewVirusTotal(s.VirusTotalKey)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), 500)
} }
reader := r.Body reader := r.Body
result, err := vt.Scan(filename, reader) result, err := vt.Scan(filename, reader)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), 500)
} }
s.logger.Println(result) s.logger.Println(result)