summary refs log tree commit diff stats
path: root/svc
diff options
context:
space:
mode:
Diffstat (limited to 'svc')
-rw-r--r--svc/handlers.go6
-rw-r--r--svc/handlers_test.go71
2 files changed, 68 insertions, 9 deletions
diff --git a/svc/handlers.go b/svc/handlers.go
index 5bb0d4f..02975ef 100644
--- a/svc/handlers.go
+++ b/svc/handlers.go
@@ -133,9 +133,13 @@ func apiEndpointHandler(w http.ResponseWriter, r *http.Request) {
 		out, err = twtxtCache.QueryInStatus("@<")
 		out = registry.ReduceToPage(page, out)
 
-	default:
+	case "/api/plain/tweets":
 		out, err = twtxtCache.QueryAllStatuses()
 		out = registry.ReduceToPage(page, out)
+
+	default:
+		log404(w, r, fmt.Errorf("endpoint not found"))
+		return
 	}
 	errLog("", err)
 
diff --git a/svc/handlers_test.go b/svc/handlers_test.go
index 36c795a..fa73de5 100644
--- a/svc/handlers_test.go
+++ b/svc/handlers_test.go
@@ -52,17 +52,72 @@ func Test_apiFormatHandler(t *testing.T) {
 		}
 	})
 }
+
+var endpointCases = []struct {
+	name   string
+	req    *http.Request
+	status int
+}{
+	{
+		name:   "Regular Query: /api/plain/users",
+		req:    httptest.NewRequest("GET", "http://localhost"+testport+"/api/plain/users", nil),
+		status: http.StatusOK,
+	},
+	{
+		name:   "Regular Query: /api/plain/mentions",
+		req:    httptest.NewRequest("GET", "http://localhost"+testport+"/api/plain/mentions", nil),
+		status: http.StatusOK,
+	},
+	{
+		name:   "Regular Query: /api/plain/tweets",
+		req:    httptest.NewRequest("GET", "http://localhost"+testport+"/api/plain/tweets", nil),
+		status: http.StatusOK,
+	},
+	{
+		name:   "Invalid Endpoint: /api/plain/statuses",
+		req:    httptest.NewRequest("GET", "http://localhost"+testport+"/api/plain/statuses", nil),
+		status: http.StatusNotFound,
+	},
+}
+
 func Test_apiEndpointHandler(t *testing.T) {
 	initTestConf()
-	t.Run("apiEndpointHandler", func(t *testing.T) {
-		w := httptest.NewRecorder()
-		req := httptest.NewRequest("GET", "localhost"+testport+"/api/plain/users", nil)
-		apiEndpointHandler(w, req)
-		resp := w.Result()
-		if resp.StatusCode != http.StatusOK {
-			t.Errorf(fmt.Sprintf("%v", resp.StatusCode))
+	mockRegistry()
+	for _, tt := range endpointCases {
+		t.Run(tt.name, func(t *testing.T) {
+			w := httptest.NewRecorder()
+			apiEndpointHandler(w, tt.req)
+			resp := w.Result()
+			if resp.StatusCode != tt.status {
+				t.Errorf(fmt.Sprintf("%v", resp.StatusCode))
+			}
+			if tt.status == http.StatusOK {
+				var body []byte
+				buf := bytes.NewBuffer(body)
+				err := resp.Write(buf)
+				if err != nil {
+					t.Errorf("%v\n", err)
+				}
+				if buf == nil {
+					t.Errorf("Got nil\n")
+				}
+				if len(buf.Bytes()) == 0 {
+					t.Errorf("Got zero data\n")
+				}
+			}
+		})
+	}
+}
+func Benchmark_apiEndpointHandler(b *testing.B) {
+	initTestConf()
+	mockRegistry()
+	w := httptest.NewRecorder()
+	b.ResetTimer()
+	for _, tt := range endpointCases {
+		for i := 0; i < b.N; i++ {
+			apiEndpointHandler(w, tt.req)
 		}
-	})
+	}
 }
 
 func Test_apiTagsBaseHandler(t *testing.T) {