Implement rate limiting option, fixes #71

This commit is contained in:
Remco
2017-03-28 17:08:34 +02:00
parent 989debecb5
commit 2f35235865
3 changed files with 43 additions and 17 deletions

View File

@ -42,6 +42,8 @@ import (
context "golang.org/x/net/context"
"github.com/PuerkitoBio/ghost/handlers"
"github.com/VojtechVitek/ratelimit"
"github.com/VojtechVitek/ratelimit/memory"
"github.com/gorilla/mux"
_ "net/http/pprof"
@ -116,6 +118,12 @@ func LogFile(s string) OptionFn {
}
}
func RateLimit(requests int) OptionFn {
return func(srvr *Server) {
srvr.rateLimitRequests = requests
}
}
func ForceHTTPs() OptionFn {
return func(srvr *Server) {
srvr.forceHTTPs = true
@ -180,6 +188,8 @@ type Server struct {
locks map[string]*sync.Mutex
rateLimitRequests int
storage Storage
forceHTTPs bool
@ -267,10 +277,12 @@ func (s *Server) Run() {
r.PathPrefix("/favicon.ico").Handler(staticHandler)
r.PathPrefix("/robots.txt").Handler(staticHandler)
r.HandleFunc("/health.html", healthHandler).Methods("GET")
r.HandleFunc("/", s.viewHandler).Methods("GET")
r.HandleFunc("/({files:.*}).zip", s.zipHandler).Methods("GET")
r.HandleFunc("/({files:.*}).tar", s.tarHandler).Methods("GET")
r.HandleFunc("/({files:.*}).tar.gz", s.tarGzHandler).Methods("GET")
r.HandleFunc("/download/{token}/{filename}", s.getHandler).Methods("GET")
r.HandleFunc("/{token}/{filename}", s.previewHandler).MatcherFunc(func(r *http.Request, rm *mux.RouteMatch) (match bool) {
match = false
@ -294,17 +306,22 @@ func (s *Server) Run() {
return
}).Methods("GET")
r.HandleFunc("/{token}/{filename}", s.getHandler).Methods("GET")
r.HandleFunc("/get/{token}/{filename}", s.getHandler).Methods("GET")
getHandlerFn := s.getHandler
if s.rateLimitRequests > 0 {
getHandlerFn = ratelimit.Request(ratelimit.IP).Rate(s.rateLimitRequests, 60*time.Second).LimitBy(memory.New())(http.HandlerFunc(getHandlerFn)).ServeHTTP
}
r.HandleFunc("/{token}/{filename}", getHandlerFn).Methods("GET")
r.HandleFunc("/get/{token}/{filename}", getHandlerFn).Methods("GET")
r.HandleFunc("/download/{token}/{filename}", getHandlerFn).Methods("GET")
r.HandleFunc("/{filename}/virustotal", s.virusTotalHandler).Methods("PUT")
r.HandleFunc("/{filename}/scan", s.scanHandler).Methods("PUT")
r.HandleFunc("/put/{filename}", s.putHandler).Methods("PUT")
r.HandleFunc("/upload/{filename}", s.putHandler).Methods("PUT")
r.HandleFunc("/{filename}", s.putHandler).Methods("PUT")
r.HandleFunc("/health.html", healthHandler).Methods("GET")
r.HandleFunc("/", s.postHandler).Methods("POST")
// r.HandleFunc("/{page}", viewHandler).Methods("GET")
r.HandleFunc("/", s.viewHandler).Methods("GET")
r.NotFoundHandler = http.HandlerFunc(s.notFoundHandler)