/* Copyright (c) 2019 Ben Morrison (gbmor) This file is part of Getwtxt. Getwtxt is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. Getwtxt is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Getwtxt. If not, see . */ package svc // import "github.com/getwtxt/getwtxt/svc" import ( "crypto/sha256" "fmt" "log" "net/http" "strconv" "strings" "sync" "github.com/getwtxt/registry" "github.com/gorilla/mux" ) // Wrapper to check if an error is non-nil, then // log the error if applicable. func apiErrCheck(err error, r *http.Request) { if err != nil { uip := getIPFromCtx(r.Context()) log.Printf("*** %v :: %v %v :: %v\n", uip, r.Method, r.URL, err.Error()) } } // Deduplicates a slice of strings func dedupe(list []string) []string { out := []string{} seen := make(map[string]bool) for _, e := range list { if !seen[e] { out = append(out, e) seen[e] = true } } return out } // Takes the output of queries and formats it for // an HTTP response. Iterates over the string slice, // appending each entry to a byte slice, and adding // newlines where appropriate. func parseQueryOut(out []string) []byte { var data []byte for i, e := range out { data = append(data, []byte(e)...) if !strings.HasSuffix(e, "\n") && i != len(out)-1 { data = append(data, byte('\n')) } } return data } // apiEndpointQuery is called via apiEndpointHandler when // the endpoint is "users" and r.FormValue("q") is not empty. 
// It queries the registry cache for users or user URLs // matching the term supplied via r.FormValue("q") func apiEndpointQuery(w http.ResponseWriter, r *http.Request) error { query := r.FormValue("q") urls := r.FormValue("url") pageVal := r.FormValue("page") var out []string var err error pageVal = strings.TrimSpace(pageVal) page, err := strconv.Atoi(pageVal) errLog("", err) vars := mux.Vars(r) endpoint := vars["endpoint"] // Handle user URL queries first, then nickname queries. // Concatenate both outputs if they're both set. // Also handle mention queries and status queries. // If we made it this far and 'default' is matched, // something went very wrong. switch endpoint { case "users": var out2 []string if query != "" { out, err = twtxtCache.QueryUser(query) apiErrCheck(err, r) } if urls != "" { out2, err = twtxtCache.QueryUser(urls) apiErrCheck(err, r) } if query != "" && urls != "" { out = joinQueryOuts(out2) } case "mentions": if urls == "" { return fmt.Errorf("missing URL in mention query") } urls += ">" out, err = twtxtCache.QueryInStatus(urls) apiErrCheck(err, r) case "tweets": out = compositeStatusQuery(query, r) default: return fmt.Errorf("endpoint query, no cases match") } out = registry.ReduceToPage(page, out) data := parseQueryOut(out) etag := fmt.Sprintf("%x", sha256.Sum256(data)) w.Header().Set("ETag", etag) w.Header().Set("Content-Type", txtutf8) _, err = w.Write(data) return err } // For composite queries, join the various slices of strings // into a single slice of strings, then deduplicates them. func joinQueryOuts(data ...[]string) []string { single := []string{} for _, e := range data { single = append(single, e...) } return dedupe(single) } // Performs a composite query against the statuses. 
//
// Three case variants of query are searched concurrently:
//   - all lowercase;
//   - title case (strings.Title applied to the lowercase form);
//   - all uppercase (strings.ToUpper applied to the title-case form).
//
// Lookup errors are logged with apiErrCheck in that same order. The
// results are merged in that order and deduplicated via joinQueryOuts.
//
// NOTE(review): strings.Title is deprecated since Go 1.18; replacing it
// would need golang.org/x/text/cases, which this file does not import.
func compositeStatusQuery(query string, r *http.Request) []string {
	lower := strings.ToLower(query)
	title := strings.Title(lower)
	variants := [3]string{lower, title, strings.ToUpper(title)}

	// Each goroutine writes only its own index, so no locking is needed;
	// wg.Wait orders those writes before the reads below.
	var results [3][]string
	var errs [3]error
	var wg sync.WaitGroup
	wg.Add(len(variants))
	for i := range variants {
		go func(idx int) {
			defer wg.Done()
			results[idx], errs[idx] = twtxtCache.QueryInStatus(variants[idx])
		}(i)
	}
	wg.Wait()

	for _, e := range errs {
		apiErrCheck(e, r)
	}
	return joinQueryOuts(results[0], results[1], results[2])
}