summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorBen Morrison <ben@gbmor.dev>2019-05-19 23:45:48 -0400
committerBen Morrison <ben@gbmor.dev>2019-05-20 02:09:20 -0400
commit747f4fb70ff61d3f8a7bb1bed896a913b0999450 (patch)
tree9ed792cbcfba664ebd6d1123f9e7ffdf7fc21105
parentf9d8193e1bdce33716b109ac2aea879a8c9b1038 (diff)
downloadgetwtxt-747f4fb70ff61d3f8a7bb1bed896a913b0999450.tar.gz
middleware func to attach remote ip to context
-rw-r--r--http.go (renamed from handlers.go)78
-rw-r--r--http_test.go (renamed from handlers_test.go)0
-rw-r--r--main.go8
-rw-r--r--types.go5
4 files changed, 79 insertions, 12 deletions
diff --git a/handlers.go b/http.go
index c455e99..52f9923 100644
--- a/handlers.go
+++ b/http.go
@@ -1,19 +1,47 @@
 package main
 
 import (
+	"context"
 	"crypto/sha256"
 	"fmt"
 	"io/ioutil"
 	"log"
 	"net/http"
 	"os"
+	"strings"
 	"time"
 
 	"github.com/gorilla/mux"
 )
 
+// Attaches a request's IP address to the request's context
+func newCtxUserIP(ctx context.Context, r *http.Request) context.Context {
+	base := strings.Split(r.RemoteAddr, ":")
+	uip := base[0]
+	return context.WithValue(ctx, ctxKey, uip)
+}
+
+// Retrieves a request's IP address from the request's context
+func getIPFromCtx(ctx context.Context) string {
+	uip, ok := ctx.Value(ctxKey).(string)
+	if !ok {
+		log.Printf("Couldn't retrieve IP from request\n")
+	}
+	return uip
+}
+
+// Shim function to modify/pass context value to a handler
+func ipMiddleware(hop http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ctx := newCtxUserIP(r.Context(), r)
+		hop.ServeHTTP(w, r.WithContext(ctx))
+	})
+}
+
 // handles "/"
-func indexHandler(w http.ResponseWriter, _ *http.Request) {
+func indexHandler(w http.ResponseWriter, r *http.Request) {
+	uip := getIPFromCtx(r.Context())
+	log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL)
 
 	// Stat the index template to get the mod time
 	var etag string
@@ -33,7 +61,7 @@ func indexHandler(w http.ResponseWriter, _ *http.Request) {
 	// then send it to the client.
 	err := tmpls.ExecuteTemplate(w, "index.html", confObj.Instance)
 	if err != nil {
-		log.Printf("Error writing to HTTP stream: %v\n", err)
+		log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip)
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
@@ -42,6 +70,9 @@ func indexHandler(w http.ResponseWriter, _ *http.Request) {
 
 // handles "/api"
 func apiBaseHandler(w http.ResponseWriter, r *http.Request) {
+	uip := getIPFromCtx(r.Context())
+	log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL)
+
 	timerfc3339, err := time.Now().MarshalText()
 	if err != nil {
 		log.Printf("Couldn't format time as RFC3339: %v\n", err)
@@ -53,20 +84,25 @@ func apiBaseHandler(w http.ResponseWriter, r *http.Request) {
 	timerfc3339 = append(timerfc3339, pathdata...)
 	n, err := w.Write(timerfc3339)
 	if err != nil || n == 0 {
-		log.Printf("Error writing to HTTP stream: %v bytes,  %v\n", n, err)
+		log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
 	}
 }
 
 // handles "/api/plain"
 // maybe add json/xml support later
 func apiFormatHandler(w http.ResponseWriter, r *http.Request) {
+	uip := getIPFromCtx(r.Context())
+	log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL)
+
 	vars := mux.Vars(r)
 	format := vars["format"]
 
 	w.Header().Set("Content-Type", txtutf8)
 	n, err := w.Write([]byte(format + "\n"))
 	if err != nil || n == 0 {
-		log.Printf("Error writing to HTTP stream: %v bytes,  %v\n", n, err)
+		log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
 	}
 }
 
@@ -76,10 +112,14 @@ func apiEndpointHandler(w http.ResponseWriter, r *http.Request) {
 	format := vars["format"]
 	endpoint := vars["endpoint"]
 
+	uip := getIPFromCtx(r.Context())
+	log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL)
+
 	w.Header().Set("Content-Type", htmlutf8)
 	n, err := w.Write([]byte(format + "/" + endpoint))
 	if err != nil || n == 0 {
-		log.Printf("Error writing to HTTP stream: %v bytes,  %v\n", n, err)
+		log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
 	}
 
 }
@@ -90,10 +130,14 @@ func apiEndpointPOSTHandler(w http.ResponseWriter, r *http.Request) {
 	format := vars["format"]
 	endpoint := vars["endpoint"]
 
+	uip := getIPFromCtx(r.Context())
+	log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL)
+
 	w.Header().Set("Content-Type", htmlutf8)
 	n, err := w.Write([]byte(format + "/" + endpoint))
 	if err != nil || n == 0 {
-		log.Printf("Error writing to HTTP stream: %v bytes,  %v\n", n, err)
+		log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
 	}
 
 }
@@ -103,10 +147,14 @@ func apiTagsBaseHandler(w http.ResponseWriter, r *http.Request) {
 	vars := mux.Vars(r)
 	format := vars["format"]
 
+	uip := getIPFromCtx(r.Context())
+	log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL)
+
 	w.Header().Set("Content-Type", htmlutf8)
 	n, err := w.Write([]byte("api/" + format + "/tags"))
 	if err != nil || n == 0 {
-		log.Printf("Error writing to HTTP stream: %v bytes,  %v\n", n, err)
+		log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
 	}
 
 }
@@ -117,17 +165,24 @@ func apiTagsHandler(w http.ResponseWriter, r *http.Request) {
 	format := vars["format"]
 	tags := vars["tags"]
 
+	uip := getIPFromCtx(r.Context())
+	log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL)
+
 	w.Header().Set("Content-Type", htmlutf8)
 	n, err := w.Write([]byte("api/" + format + "/tags/" + tags))
 	if err != nil || n == 0 {
-		log.Printf("Error writing to HTTP stream: %v bytes,  %v\n", n, err)
+		log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
 	}
 
 }
 
 // Serving the stylesheet virtually because
 // files aren't served directly.
-func cssHandler(w http.ResponseWriter, _ *http.Request) {
+func cssHandler(w http.ResponseWriter, r *http.Request) {
+	uip := getIPFromCtx(r.Context())
+	log.Printf("Request from %v :: %v %v\n", uip, r.Method, r.URL)
+
 	// read the raw bytes of the stylesheet
 	css, err := ioutil.ReadFile("assets/style.css")
 	if err != nil {
@@ -136,7 +191,7 @@ func cssHandler(w http.ResponseWriter, _ *http.Request) {
 			http.Error(w, err.Error(), http.StatusNotFound)
 			return
 		}
-		log.Printf("%v\n", err)
+		log.Printf("500: %v\n", err)
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
@@ -154,6 +209,7 @@ func cssHandler(w http.ResponseWriter, _ *http.Request) {
 	w.Header().Set("Content-Type", cssutf8)
 	n, err := w.Write(css)
 	if err != nil || n == 0 {
-		log.Printf("Error writing to HTTP stream: %v bytes,  %v\n", n, err)
+		log.Printf("500: Error writing to HTTP stream: %v, %v %v via %v\n", err, r.Method, r.URL, uip)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
 	}
 }
diff --git a/handlers_test.go b/http_test.go
index 4ae7ae6..4ae7ae6 100644
--- a/handlers_test.go
+++ b/http_test.go
diff --git a/main.go b/main.go
index 4b62f27..4508edf 100644
--- a/main.go
+++ b/main.go
@@ -66,7 +66,7 @@ func main() {
 	// handlers.CompressHandler gzips all responses.
 	// Write/Read timeouts are self explanatory.
 	server := &http.Server{
-		Handler:      handlers.CompressHandler(index),
+		Handler:      handlers.CompressHandler(ipMiddleware(index)),
 		Addr:         portnum,
 		WriteTimeout: 15 * time.Second,
 		ReadTimeout:  15 * time.Second,
@@ -77,6 +77,12 @@ func main() {
 	if err != nil {
 		log.Printf("%v\n", err)
 	}
+	defer func() {
+		err := server.Close()
+		if err != nil {
+			log.Printf("%v\n", err)
+		}
+	}()
 
 	closelog <- true
 }
diff --git a/types.go b/types.go
index 8a707bd..b5f27ce 100644
--- a/types.go
+++ b/types.go
@@ -22,3 +22,8 @@ type Instance struct {
 	Mail  string
 	Desc  string
 }
+
+// ipCtxKey is the Context value key for user IP addresses
+type ipCtxKey int
+
+const ctxKey ipCtxKey = iota