From a812bed3347bbfc12583007fca41a72c479c3adb Mon Sep 17 00:00:00 2001 From: Antonio Mika Date: Mon, 27 Jan 2025 22:02:02 -0500 Subject: [PATCH] Add websocket to pipe web --- go.mod | 1 + go.sum | 2 + pipe/api.go | 119 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 122 insertions(+) diff --git a/go.mod b/go.mod index 17aa9c60..0b7c9709 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/gorilla/feeds v1.2.0 + github.com/gorilla/websocket v1.5.3 github.com/jmoiron/sqlx v1.4.0 github.com/lib/pq v1.10.9 github.com/microcosm-cc/bluemonday v1.0.27 diff --git a/go.sum b/go.sum index 057486b3..4eceb630 100644 --- a/go.sum +++ b/go.sum @@ -457,6 +457,8 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/gorilla/feeds v1.2.0 h1:O6pBiXJ5JHhPvqy53NsjKOThq+dNFm8+DFrxBEdzSCc= github.com/gorilla/feeds v1.2.0/go.mod h1:WMib8uJP3BbY+X8Szd1rA5Pzhdfh+HCCAYT2z7Fza6Y= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/pipe/api.go b/pipe/api.go index 76ac3240..fecb09ac 100644 --- a/pipe/api.go +++ b/pipe/api.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/picosh/pico/db/postgres" "github.com/picosh/pico/shared" "github.com/picosh/utils/pipe" @@ -23,6 +24,11 @@ import ( var ( cleanRegex = regexp.MustCompile(`[^0-9a-zA-Z,/]`) sshClient *pipe.Client + upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } ) func serveFile(file string, contentType string) http.HandlerFunc { @@ -264,6 +270,118 @@ func handlePub(pubsub bool) http.HandlerFunc { } } +func handlePipe() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + logger := shared.GetLogger(r) + + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Error("pipe upgrade error", "err", err.Error()) + return + } + + defer c.Close() + + clientInfo := shared.NewPicoPipeClient() + topic, _ := url.PathUnescape(shared.GetField(r, 0)) + + topic = cleanRegex.ReplaceAllString(topic, "") + + logger.Info("pipe", "topic", topic, "info", clientInfo) + + params := "-p -c" + if r.URL.Query().Get("status") == "true" { + params = params[:len(params)-3] + } + + if r.URL.Query().Get("replay") == "true" { + params += " -r" + } + + messageType := websocket.TextMessage + if r.URL.Query().Get("binary") == "true" { + messageType = websocket.BinaryMessage + } + + if accessList := r.URL.Query().Get("access"); accessList != "" { + logger.Info("adding access list", "topic", topic, "info", clientInfo, "access", accessList) + cleanList := cleanRegex.ReplaceAllString(accessList, "") + params += fmt.Sprintf(" -a=%s", cleanList) + } + + id := uuid.NewString() + + p, err := sshClient.AddSession(id, fmt.Sprintf("pipe %s %s", params, topic), 0, -1, -1) + if err != nil { + logger.Error("pipe error", "topic", topic, "info", clientInfo, "err", err.Error()) + http.Error(w, "server error", http.StatusInternalServerError) + return + } + + go func() { + <-r.Context().Done() + err := sshClient.RemoveSession(id) + if err != nil { + logger.Error("pipe remove error", "topic", topic, "info", clientInfo, "err", err.Error()) + } + c.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer func() { + p.Close() + c.Close() + wg.Done() + }() + + for { + _, message, err := c.ReadMessage() + if err != nil { + logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error()) + break + } + + _, err = p.Write(message) + if err != nil { + logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error()) + break + } + } + }() + + go func() { + defer func() { + p.Close() + c.Close() + wg.Done() + }() + + for { + buf := make([]byte, 32*1024) + + n, err := p.Read(buf) + if err != nil { + logger.Error("pipe read error", "topic", topic, "info", clientInfo, "err", err.Error()) + break + } + + buf = buf[:n] + + err = c.WriteMessage(messageType, buf) + if err != nil { + logger.Error("pipe write error", "topic", topic, "info", clientInfo, "err", err.Error()) + break + } + } + }() + + wg.Wait() + } +} + func createMainRoutes(staticRoutes []shared.Route) []shared.Route { routes := []shared.Route{ shared.NewRoute("GET", "/", shared.CreatePageHandler("html/marketing.page.tmpl")), @@ -275,6 +393,7 @@ func createMainRoutes(staticRoutes []shared.Route) []shared.Route { shared.NewRoute("POST", "/topic/(.+)", handlePub(false)), shared.NewRoute("GET", "/pubsub/(.+)", handleSub(true)), shared.NewRoute("POST", "/pubsub/(.+)", handlePub(true)), + shared.NewRoute("GET", "/pipe/(.+)", handlePipe()), } for _, route := range pipeRoutes {