Skip to content

Commit

Permalink
give connections an id
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnewhall committed Feb 6, 2024
1 parent 6abcda0 commit b0ec70c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
36 changes: 20 additions & 16 deletions client/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"runtime/debug"
"strconv"
"time"

"github.com/gorilla/websocket"
Expand All @@ -34,6 +36,7 @@ type Connection struct {
status int
setStatus chan int
getStatus chan int
id string
}

// NewConnection creates a Connection object.
Expand All @@ -43,12 +46,13 @@ func NewConnection(pool *Pool) *Connection {
status: CONNECTING,
setStatus: make(chan int),
getStatus: make(chan int),
id: strconv.Itoa(rand.Intn(899) + 100), //nolint:gomnd
}
}

// Connect to the remote server using an HTTP websocket.
func (c *Connection) Connect(ctx context.Context) error {
c.pool.client.Debugf("Connecting to tunnel @ %s", c.pool.target)
c.pool.client.Debugf("[%s] Connecting to tunnel @ %s", c.id, c.pool.target)

var err error
// Create a new TCP(/TLS) connection (no use of net.http).
Expand All @@ -59,7 +63,7 @@ func (c *Connection) Connect(ctx context.Context) error {
http.Header{mulch.SecretKeyHeader: {c.pool.secretKey}},
)
if err != nil {
return fmt.Errorf("tcp dialer failure: %w", err)
return fmt.Errorf("[%s] tcp dialer failure: %w", c.id, err)
}

c.ws.EnableWriteCompression(true)
Expand All @@ -77,7 +81,7 @@ func (c *Connection) Connect(ctx context.Context) error {

if err := c.ws.WriteJSON(greeting); err != nil {
c.pool.Remove(c)
return fmt.Errorf("greeting failure: %w", err)
return fmt.Errorf("[%s] greeting failure: %w", c.id, err)
}

// We are connected to the server, now start a go routine that waits for incoming server requests.
Expand All @@ -97,7 +101,7 @@ func (c *Connection) keepAlive() {
case tick := <-ticker.C:
err := c.ws.WriteControl(websocket.PingMessage, []byte{}, tick.Add(keepAliveTimeout))
if err != nil {
c.pool.client.Errorf("Tunnel keep-alive failure: %v", err)
c.pool.client.Errorf("[%s] Tunnel keep-alive failure: %v", c.id, err)
return
}
case status, ok := <-c.setStatus:
Expand Down Expand Up @@ -141,9 +145,9 @@ func (c *Connection) catchPanic() {
if r := recover(); r != nil {
// https://github.com/golang/go/blob/b100e127ca0e398fbb58d04d04e2443b50b3063e/src/runtime/chan.go#LL206C15-L206C15
if err, _ := r.(error); err != nil && err.Error() != "send on closed channel" { // ignore this specific panic.
c.pool.client.Errorf("panic error: %v\n%s", err, string(debug.Stack()))
c.pool.client.Errorf("[%s] panic error: %v\n%s", c.id, err, string(debug.Stack()))
} else if err == nil {
c.pool.client.Errorf("panic: %v\n%s", r, string(debug.Stack()))
c.pool.client.Errorf("[%s] panic: %v\n%s", c.id, r, string(debug.Stack()))
}
}
}
Expand All @@ -158,7 +162,7 @@ func (c *Connection) serveHandler() bool {
_, jsonRequest, err := c.ws.ReadMessage()
if err != nil {
if !c.pool.shutdown {
c.pool.client.Errorf("While waiting for a tunnel request: %v", err)
c.pool.client.Errorf("[%s] While waiting for a tunnel request: %v", c.id, err)
}

return false
Expand All @@ -169,7 +173,7 @@ func (c *Connection) serveHandler() bool {

httpRequest := new(mulch.HTTPRequest) // Deserialize request.
if err := json.Unmarshal(jsonRequest, httpRequest); err != nil {
c.error(fmt.Sprintf("Deserializing json tunnel request: %s", err))
c.error(fmt.Sprintf("[%s] Deserializing json tunnel request: %s", c.id, err))
return false
}

Expand All @@ -184,7 +188,7 @@ func (c *Connection) serveHandler() bool {
// Pipe request body.
_, bodyReader, err := c.ws.NextReader()
if err != nil {
c.pool.client.Errorf("Getting tunnel response body reader: %v", err)
c.pool.client.Errorf("[%s] Getting tunnel response body reader: %v", c.id, err)
return false
}

Expand All @@ -198,17 +202,17 @@ func (c *Connection) defaultHandler(req *http.Request) bool {
// This is where a local client sends the server's request off to the Internet.
resp, err := c.pool.client.client.Do(req)
if err != nil {
return !c.error(fmt.Sprintf("Executing tunneled request: %v", err))
return !c.error(fmt.Sprintf("[%s] Executing tunneled request: %v", c.id, err))
}

bodyWriter, err := c.writeResponseHeaders(resp)
if err != nil {
c.pool.client.Errorf("Making request: %v", err)
c.pool.client.Errorf("[%s] Making request: %v", c.id, err)
return false
}

if _, err := io.Copy(bodyWriter, resp.Body); err != nil {
c.pool.client.Errorf("Getting tunnel pipe response body: %v", err)
c.pool.client.Errorf("[%s] Getting tunnel pipe response body: %v", c.id, err)
return false
}

Expand All @@ -222,13 +226,13 @@ func (c *Connection) writeResponseHeaders(resp *http.Response) (io.WriteCloser,
// This is where we send the Internet's (http request) response back to the server.
err := c.ws.WriteMessage(websocket.TextMessage, mulch.SerializeHTTPResponse(resp))
if err != nil {
return nil, fmt.Errorf("writing tunnel response: %w", err)
return nil, fmt.Errorf("[%s] writing tunnel response: %w", c.id, err)
}

// Pipe response body because an io.ReadCloser (http.Body) doesn't get serialized (above).
bodyWriter, err := c.ws.NextWriter(websocket.BinaryMessage)
if err != nil {
return nil, fmt.Errorf("getting tunnel response body writer: %w", err)
return nil, fmt.Errorf("[%s] getting tunnel response body writer: %w", c.id, err)
}

return bodyWriter, nil
Expand All @@ -244,14 +248,14 @@ func (c *Connection) error(msg string) bool {
// Write response
err := c.ws.WriteMessage(websocket.TextMessage, resp)
if err != nil {
c.pool.client.Errorf("Writing tunnel response: %v", err)
c.pool.client.Errorf("[%s] Writing tunnel response: %v", c.id, err)
return true
}

// Write response body
err = c.ws.WriteMessage(websocket.BinaryMessage, []byte(msg))
if err != nil {
c.pool.client.Errorf("Writing tunnel response body: %v", err)
c.pool.client.Errorf("[%s] Writing tunnel response body: %v", c.id, err)
return true
}

Expand Down
4 changes: 2 additions & 2 deletions client/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ func (r *req2Handler) Write(data []byte) (int, error) {
}

if r.body == nil {
return 0, fmt.Errorf("tunnel write failed: %w", r.err)
return 0, fmt.Errorf("[%s] tunnel write failed: %w", r.conn.id, r.err)
}

size, err := r.body.Write(data)
r.resp.ContentLength += int64(size)

if err != nil {
r.err = err
return size, fmt.Errorf("tunnel write failed: %w", err)
return size, fmt.Errorf("[%s] tunnel write failed: %w", r.conn.id, err)
}

return size, nil
Expand Down

0 comments on commit b0ec70c

Please sign in to comment.