Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cmd/tsgrok/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ func main() {

messageBus := &util.MessageBusImpl{}
funnelRegistry := funnel.NewFunnelRegistry()
httpServer := funnel.NewHttpServer(util.GetProxyHttpPort(), messageBus, funnelRegistry, serverErrorLog)
httpServer, err := funnel.NewHttpServer(util.GetProxyHttpPort(), messageBus, funnelRegistry, serverErrorLog)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating HTTP server: %v\n", err)
os.Exit(1)
}

m := tui.InitialModel(funnelRegistry, serverErrorLog)

Expand Down
242 changes: 24 additions & 218 deletions internal/funnel/http.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,18 @@
package funnel

import (
"bytes"
"errors"
"fmt"
"io"
"html/template"
stdlog "log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"time"

"github.com/google/uuid"
"github.com/jonson/tsgrok/internal/util"
)
"io/fs"

// Error variables for common failure modes
var (
ErrInvalidFunnelPath = errors.New("invalid path format for funnel request")
ErrFunnelNotFound = errors.New("funnel not found")
ErrFunnelNotReady = errors.New("funnel has no local target configured")
ErrTargetURLParse = errors.New("failed to parse funnel target URL")
"github.com/jonson/tsgrok/internal/util"
"github.com/jonson/tsgrok/web"
)

var HttpServerPath = fmt.Sprintf("/%s/", util.ProgramName)
Expand All @@ -36,17 +24,25 @@ type HttpServer struct {
messageBus util.MessageBus // message bus for sending messages to the program
funnelRegistry *FunnelRegistry // registry of funnels
logger *stdlog.Logger // logger for logging
embeddedTemplates *template.Template
}

func NewHttpServer(port int, messageBus util.MessageBus, funnelRegistry *FunnelRegistry, logger *stdlog.Logger) *HttpServer {
func NewHttpServer(port int, messageBus util.MessageBus, funnelRegistry *FunnelRegistry, logger *stdlog.Logger) (*HttpServer, error) {
// Parse templates from the web.TemplatesFS, first inject a few functions
tmpl, err := loadTemplates()
if err != nil {
return nil, err
}

return &HttpServer{
port: port,
mux: http.NewServeMux(),
requestLimitPerFunnel: 100,
messageBus: messageBus,
funnelRegistry: funnelRegistry,
logger: logger,
}
embeddedTemplates: tmpl,
}, nil
}

func (s *HttpServer) GetFunnelById(id string) (Funnel, error) {
Expand All @@ -67,7 +63,17 @@ func (s *HttpServer) Start() error {
return err
}

staticFilesRoot, err := fs.Sub(web.StaticFS, "static")
if err != nil {
s.logger.Fatalf("FATAL: 'static' subdirectory not found in embedded StaticFS: %v", err)
return err
}
fileServer := http.FileServer(http.FS(staticFilesRoot))
s.mux.Handle("/static/", http.StripPrefix("/static/", fileServer))

s.mux.HandleFunc(HttpServerPath, s.handleRequest)
s.mux.HandleFunc("/inspect/", s.handleFunnelInspect)
s.mux.HandleFunc("/", s.handleRoot)

// do this in a goroutine, we listen in the background
go func() {
Expand All @@ -83,203 +89,3 @@ func (s *HttpServer) Start() error {

return nil
}

func (s *HttpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
pathAfterPrefix := strings.TrimPrefix(r.URL.Path, HttpServerPath)

funnelIdAndRest, err := extractFunnelIdAndRest(pathAfterPrefix)
if err != nil {
// Check for the specific error from extraction
if errors.Is(err, ErrInvalidFunnelPath) {
http.Error(w, ErrInvalidFunnelPath.Error(), http.StatusBadRequest)
} else {
// Handle other unexpected errors during extraction
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
}

// serve hello requests without proxying
if funnelIdAndRest.rest == ".well-known/tsgrok/hello" {
w.WriteHeader(http.StatusOK)
_, err = w.Write([]byte("hello"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
return
}

funnel, err := s.GetFunnelById(funnelIdAndRest.id)
if err != nil {
http.Error(w, ErrFunnelNotFound.Error(), http.StatusNotFound)
return
}

targetURLStr := funnel.LocalTarget()
if targetURLStr == "" {
http.Error(w, ErrFunnelNotReady.Error(), http.StatusNotFound)
return
}

targetURL, err := url.Parse(targetURLStr)
if err != nil {
s.logger.Printf("Error parsing target URL %q: %v", targetURLStr, err)
http.Error(w, ErrTargetURLParse.Error(), http.StatusInternalServerError)
return
}

proxy := httputil.NewSingleHostReverseProxy(targetURL)

// need this to avoid logging to stderr
proxy.ErrorLog = s.logger

// Define a custom Director
originalDirector := proxy.Director

requestResponse := CaptureRequestResponse{
ID: uuid.New().String(),
FunnelID: funnel.HTTPFunnel.id,
Timestamp: time.Now(),
}

// this is the function that modifies the request before it is sent to the target
proxy.Director = func(req *http.Request) {
originalDirector(req)

// read the request body, the plan is to expose this in the UI somehow, but that comes
// at the expense of increased memory usage... make this better
var reqBodyBytes []byte
var err error
if req.Body != nil && req.Body != http.NoBody {
reqBodyBytes, err = io.ReadAll(req.Body)
if err != nil {
s.logger.Printf("Error reading request body: %v\n", err)
} else {
err = req.Body.Close()
if err != nil {
s.logger.Printf("Error closing request body: %v\n", err)
}
req.Body = io.NopCloser(bytes.NewReader(reqBodyBytes))
req.ContentLength = int64(len(reqBodyBytes))
req.GetBody = nil
}
}

req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.URL.Path = singleJoiningSlash(targetURL.Path, funnelIdAndRest.rest)
req.Host = targetURL.Host

if targetURL.RawPath == "" {
req.URL.RawPath = ""
}

headers := make(map[string]string)
for k, v := range req.Header {
headers[k] = strings.Join(v, ",")
}

requestResponse.Request = CaptureRequest{
Method: req.Method,
URL: req.URL.String(),
Body: reqBodyBytes,
Headers: headers,
}
}

// this is the function that modifies the response before it is sent to the client
proxy.ModifyResponse = func(resp *http.Response) error {

headers := make(map[string]string)
for k, v := range resp.Header {
headers[k] = strings.Join(v, ",")
}

requestResponse.Response = CaptureResponse{
Headers: headers,
StatusCode: resp.StatusCode,
}

var respBodyBytes []byte
var err error
if resp.Body != nil && resp.Body != http.NoBody {
respBodyBytes, err = io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
} else {
err = resp.Body.Close()
if err != nil {
s.logger.Printf("Error closing response body: %v\n", err)
}
resp.Body = io.NopCloser(bytes.NewReader(respBodyBytes))
resp.ContentLength = int64(len(respBodyBytes))
resp.Header.Del("Transfer-Encoding")
}
}

requestResponse.Response.Body = respBodyBytes
requestResponse.Duration = time.Since(requestResponse.Timestamp)
return nil
}

// Serve the request via the proxy
proxy.ServeHTTP(w, r)

// add the request response to the list
funnel.Requests.Add(requestResponse)

// broadcast it so UI can update
s.messageBus.Send(ProxyRequestMsg{FunnelId: funnel.HTTPFunnel.id})
}

type FunnelIdAndRest struct {
id string
rest string
}

func extractFunnelIdAndRest(pathAfterPrefix string) (FunnelIdAndRest, error) {
// Check for obviously invalid paths first
if pathAfterPrefix == "" || pathAfterPrefix == "/" {
return FunnelIdAndRest{}, ErrInvalidFunnelPath
}

// Split the remaining path by /
parts := strings.SplitN(pathAfterPrefix, "/", 2)

funnelId := ""
rest := ""

if len(parts) >= 1 {
funnelId = parts[0]
}
if len(parts) == 2 {
rest = parts[1]
}

// Check if funnelId is empty after splitting (e.g., path started with /)
if funnelId == "" {
return FunnelIdAndRest{}, ErrInvalidFunnelPath // Use the specific error
}

return FunnelIdAndRest{id: funnelId, rest: rest}, nil
}

func singleJoiningSlash(a, b string) string {
if a == "" && b == "" {
return "/"
}
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
// Avoid adding slash if b is empty or a is just "/"
if b == "" || a == "/" {
return a + b
}
return a + "/" + b
}
return a + b
}
76 changes: 76 additions & 0 deletions internal/funnel/http_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package funnel

import (
"strings"
)

// funnelName is a helper to get a display name for the funnel.
func funnelName(f Funnel) string {
name := f.Name() // This method derives from RemoteTarget
if name == "" {
return f.ID() // Fallback to ID if name is empty
}
return name
}

// findRequestInList iterates through the funnel's request list to find a request by its ID.
func findRequestInList(requestList *RequestList, requestID string) *CaptureRequestResponse {
if requestList == nil {
return nil
}
requestList.mu.Lock()
defer requestList.mu.Unlock()
currentNode := requestList.Head
for currentNode != nil {
if currentNode.Request.ID == requestID {
crCopy := currentNode.Request
return &crCopy
}
currentNode = currentNode.Next
}
return nil
}

// singleJoiningSlash is a utility function for joining URL paths.
func singleJoiningSlash(a, b string) string {
if a == "" && b == "" {
return "/"
}
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
if b == "" || a == "/" {
return a + b
}
return a + "/" + b
}
return a + b
}

// extractFunnelIdAndRest extracts the funnel ID and the rest of the path from a URL path string.
func extractFunnelIdAndRest(pathAfterPrefix string) (FunnelIdAndRest, error) {
if pathAfterPrefix == "" || pathAfterPrefix == "/" {
return FunnelIdAndRest{}, ErrInvalidFunnelPath
}

parts := strings.SplitN(pathAfterPrefix, "/", 2)

funnelId := ""
rest := ""

if len(parts) >= 1 {
funnelId = parts[0]
}
if len(parts) == 2 {
rest = parts[1]
}

if funnelId == "" {
return FunnelIdAndRest{}, ErrInvalidFunnelPath
}

return FunnelIdAndRest{id: funnelId, rest: rest}, nil
}
Loading