diff options
author | Ben Morrison <ben@gbmor.dev> | 2019-05-19 23:45:48 -0400 |
---|---|---|
committer | Ben Morrison <ben@gbmor.dev> | 2019-05-20 02:09:20 -0400 |
commit | 747f4fb70ff61d3f8a7bb1bed896a913b0999450 (patch) | |
tree | 9ed792cbcfba664ebd6d1123f9e7ffdf7fc21105 | |
parent | f9d8193e1bdce33716b109ac2aea879a8c9b1038 (diff) | |
download | getwtxt-747f4fb70ff61d3f8a7bb1bed896a913b0999450.tar.gz |
middleware func to attach remote ip to context
-rw-r--r-- | http.go (renamed from handlers.go) | 78 | ||||
-rw-r--r-- | http_test.go (renamed from handlers_test.go) | 0 | ||||
-rw-r--r-- | main.go | 8 | ||||
-rw-r--r-- | types.go | 5 |
4 files changed, 79 insertions, 12 deletions
diff --git a/handlers.go b/http.go index c455e99..52f9923 100644 --- a/handlers.go +++ b/http.go @@ -1,19 +1,47 @@ package main import ( + "context" "crypto/sha256" "fmt" "io/ioutil" "log" "net/http" "os" + "strings" "time" "github.com/gorilla/mux" ) +// Attaches a request's IP address to the request's context +func newCtxUserIP(ctx context.Context, r *http.Request) context.Context { + base := strings.Split(r.RemoteAddr, ":") + uip := base[0] + return context.WithValue(ctx, ctxKey, uip) +} + +// Retrieves a request's IP address from the request's context +func getIPFromCtx(ctx context.Context) string { + uip, ok := ctx.Value(ctxKey).(string) + if !ok { + log.Printf("Couldn't retrieve IP from request\n") + } + return uip +} + +// Shim function to modify/pass context value to a handler +func ipMiddleware(hop http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := newCtxUserIP(r.Context(), r) + hop.ServeHTTP(w, r.WithContext(ctx)) + }) +} + // handles "/" -func indexHandler(w http.ResponseWriter, _ *http.Request) { +func indexHandler(w http.ResponseWriter, r *http.Request) { + uip := getIPFromCtx(r.Context()) + log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL) // Stat the index template to get the mod time var etag string @@ -33,7 +61,7 @@ func indexHandler(w http.ResponseWriter, _ *http.Request) { // then send it to the client. err := tmpls.ExecuteTemplate(w, "index.html", confObj.Instance) if err != nil { - log.Printf("Error writing to HTTP stream: %v\n", err) + log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -42,6 +70,9 @@ func indexHandler(w http.ResponseWriter, _ *http.Request) { // handles "/api" func apiBaseHandler(w http.ResponseWriter, r *http.Request) { + uip := getIPFromCtx(r.Context()) + log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL) + timerfc3339, err := time.Now().MarshalText() if err != nil { log.Printf("Couldn't format time as RFC3339: %v\n", err) @@ -53,20 +84,25 @@ func apiBaseHandler(w http.ResponseWriter, r *http.Request) { timerfc3339 = append(timerfc3339, pathdata...) n, err := w.Write(timerfc3339) if err != nil || n == 0 { - log.Printf("Error writing to HTTP stream: %v bytes, %v\n", n, err) + log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip) + http.Error(w, err.Error(), http.StatusInternalServerError) } } // handles "/api/plain" // maybe add json/xml support later func apiFormatHandler(w http.ResponseWriter, r *http.Request) { + uip := getIPFromCtx(r.Context()) + log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL) + vars := mux.Vars(r) format := vars["format"] w.Header().Set("Content-Type", txtutf8) n, err := w.Write([]byte(format + "\n")) if err != nil || n == 0 { - log.Printf("Error writing to HTTP stream: %v bytes, %v\n", n, err) + log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip) + http.Error(w, err.Error(), http.StatusInternalServerError) } } @@ -76,10 +112,14 @@ func apiEndpointHandler(w http.ResponseWriter, r *http.Request) { format := vars["format"] endpoint := vars["endpoint"] + uip := getIPFromCtx(r.Context()) + log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL) + w.Header().Set("Content-Type", htmlutf8) n, err := w.Write([]byte(format + "/" + endpoint)) if err != nil || n == 0 { - log.Printf("Error writing to HTTP stream: %v bytes, %v\n", n, err) + log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip) + http.Error(w, err.Error(), http.StatusInternalServerError) } } @@ -90,10 +130,14 @@ func apiEndpointPOSTHandler(w http.ResponseWriter, r *http.Request) { format := vars["format"] endpoint := vars["endpoint"] + uip := getIPFromCtx(r.Context()) + log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL) + w.Header().Set("Content-Type", htmlutf8) n, err := w.Write([]byte(format + "/" + endpoint)) if err != nil || n == 0 { - log.Printf("Error writing to HTTP stream: %v bytes, %v\n", n, err) + log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip) + http.Error(w, err.Error(), http.StatusInternalServerError) } } @@ -103,10 +147,14 @@ func apiTagsBaseHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) format := vars["format"] + uip := getIPFromCtx(r.Context()) + log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL) + w.Header().Set("Content-Type", htmlutf8) n, err := w.Write([]byte("api/" + format + "/tags")) if err != nil || n == 0 { - log.Printf("Error writing to HTTP stream: %v bytes, %v\n", n, err) + log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip) + http.Error(w, err.Error(), http.StatusInternalServerError) } } @@ -117,17 +165,24 @@ func apiTagsHandler(w http.ResponseWriter, r *http.Request) { format := vars["format"] tags := vars["tags"] + uip := getIPFromCtx(r.Context()) + log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL) + w.Header().Set("Content-Type", htmlutf8) n, err := w.Write([]byte("api/" + format + "/tags/" + tags)) if err != nil || n == 0 { - log.Printf("Error writing to HTTP stream: %v bytes, %v\n", n, err) + log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip) + http.Error(w, err.Error(), http.StatusInternalServerError) } } // Serving the stylesheet virtually because // files aren't served directly. -func cssHandler(w http.ResponseWriter, _ *http.Request) { +func cssHandler(w http.ResponseWriter, r *http.Request) { + uip := getIPFromCtx(r.Context()) + log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL) + // read the raw bytes of the stylesheet css, err := ioutil.ReadFile("assets/style.css") if err != nil { @@ -136,7 +191,7 @@ func cssHandler(w http.ResponseWriter, _ *http.Request) { http.Error(w, err.Error(), http.StatusNotFound) return } - log.Printf("%v\n", err) + log.Printf("500: %v\n", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -154,6 +209,7 @@ func cssHandler(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", cssutf8) n, err := w.Write(css) if err != nil || n == 0 { - log.Printf("Error writing to HTTP stream: %v bytes, %v\n", n, err) + log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip) + http.Error(w, err.Error(), http.StatusInternalServerError) } } diff --git a/handlers_test.go b/http_test.go index 4ae7ae6..4ae7ae6 100644 --- a/handlers_test.go +++ b/http_test.go diff --git a/main.go b/main.go index 4b62f27..4508edf 100644 --- a/main.go +++ b/main.go @@ -66,7 +66,7 @@ func main() { // handlers.CompressHandler gzips all responses. // Write/Read timeouts are self explanatory. server := &http.Server{ - Handler: handlers.CompressHandler(index), + Handler: handlers.CompressHandler(ipMiddleware(index)), Addr: portnum, WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, @@ -77,6 +77,12 @@ func main() { if err != nil { log.Printf("%v\n", err) } + defer func() { + err := server.Close() + if err != nil { + log.Printf("%v\n", err) + } + }() closelog <- true } diff --git a/types.go b/types.go index 8a707bd..b5f27ce 100644 --- a/types.go +++ b/types.go @@ -22,3 +22,8 @@ type Instance struct { Mail string Desc string } + +// ipCtxKey is the Context value key for user IP addresses +type ipCtxKey int + +const ctxKey ipCtxKey = iota |