summary refs log tree commit diff stats
path: root/svc/http.go
diff options
context:
space:
mode:
Diffstat (limited to 'svc/http.go')
-rw-r--r--svc/http.go82
1 files changed, 82 insertions, 0 deletions
diff --git a/svc/http.go b/svc/http.go
new file mode 100644
index 0000000..ddf8669
--- /dev/null
+++ b/svc/http.go
@@ -0,0 +1,82 @@
+package svc // import "github.com/getwtxt/getwtxt/svc"
+
+import (
+	"context"
+	"log"
+	"net"
+	"net/http"
+	"strings"
+)
+
+// Attaches a request's IP address to the request's context.
+// If getwtxt is behind a reverse proxy, get the last entry
+// in the X-Forwarded-For or X-Real-IP HTTP header as the user IP.
+func newCtxUserIP(ctx context.Context, r *http.Request) context.Context {
+
+	base := strings.Split(r.RemoteAddr, ":")
+	uip := base[0]
+
+	if _, ok := r.Header["X-Forwarded-For"]; ok {
+		proxied := r.Header["X-Forwarded-For"]
+		base = strings.Split(proxied[len(proxied)-1], ":")
+		uip = base[0]
+	}
+
+	xRealIP := http.CanonicalHeaderKey("X-Real-IP")
+	if _, ok := r.Header[xRealIP]; ok {
+		proxied := r.Header[xRealIP]
+		base = strings.Split(proxied[len(proxied)-1], ":")
+		uip = base[0]
+	}
+
+	return context.WithValue(ctx, ctxKey, uip)
+}
+
+// Retrieves a request's IP address from the request's context
+func getIPFromCtx(ctx context.Context) net.IP {
+
+	uip, ok := ctx.Value(ctxKey).(string)
+	if !ok {
+		log.Printf("Couldn't retrieve IP from request\n")
+	}
+
+	return net.ParseIP(uip)
+}
+
+// Shim function to modify/pass context value to a handler
+func ipMiddleware(hop http.Handler) http.Handler {
+
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ctx := newCtxUserIP(r.Context(), r)
+		hop.ServeHTTP(w, r.WithContext(ctx))
+	})
+}
+
+func log200(r *http.Request) {
+	useragent := r.Header["User-Agent"]
+
+	uip := getIPFromCtx(r.Context())
+	log.Printf("*** %v :: 200 :: %v %v :: %v\n", uip, r.Method, r.URL, useragent)
+}
+
+func log400(w http.ResponseWriter, r *http.Request, err string) {
+	uip := getIPFromCtx(r.Context())
+	log.Printf("*** %v :: 400 :: %v %v :: %v\n", uip, r.Method, r.URL, err)
+	http.Error(w, "400 Bad Request: "+err, http.StatusBadRequest)
+}
+
+func log404(w http.ResponseWriter, r *http.Request, err error) {
+	useragent := r.Header["User-Agent"]
+
+	uip := getIPFromCtx(r.Context())
+	log.Printf("*** %v :: 404 :: %v %v :: %v :: %v\n", uip, r.Method, r.URL, useragent, err.Error())
+	http.Error(w, err.Error(), http.StatusNotFound)
+}
+
+func log500(w http.ResponseWriter, r *http.Request, err error) {
+	useragent := r.Header["User-Agent"]
+
+	uip := getIPFromCtx(r.Context())
+	log.Printf("*** %v :: 500 :: %v %v :: %v :: %v\n", uip, r.Method, r.URL, useragent, err.Error())
+	http.Error(w, err.Error(), http.StatusInternalServerError)
+}