From b44556f706edcbbb5332c6b0a55671f61ce414e3 Mon Sep 17 00:00:00 2001 From: Alex Gartner Date: Mon, 25 Nov 2024 23:09:41 -0800 Subject: [PATCH] server: use struct --- .gitignore | 2 + client/tunnel.go | 7 +- cmd/tunnel-client/main.go | 7 +- cmd/tunnel-server/main.go | 70 ++++++++ server/main.go | 333 -------------------------------------- server/server.go | 285 ++++++++++++++++++++++++++++++++ 6 files changed, 369 insertions(+), 335 deletions(-) create mode 100644 .gitignore create mode 100644 cmd/tunnel-server/main.go delete mode 100644 server/main.go create mode 100644 server/server.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..10cdeb2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +server.crt +server.key diff --git a/client/tunnel.go b/client/tunnel.go index 4146141..74c4805 100644 --- a/client/tunnel.go +++ b/client/tunnel.go @@ -101,7 +101,12 @@ func (t *Tunnel) stage1(print bool) (net.Conn, error) { } res := string(buf[:n]) if print { - fmt.Printf("URL: https://%s\n", res) + _, port, _ := net.SplitHostPort(t.server) + portPart := "" + if port != "443" { + portPart = fmt.Sprintf(":%s", port) + } + fmt.Printf("URL: https://%s%s\n", res, portPart) } return conn, nil } diff --git a/cmd/tunnel-client/main.go b/cmd/tunnel-client/main.go index 4cca8e0..b830b90 100644 --- a/cmd/tunnel-client/main.go +++ b/cmd/tunnel-client/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "log" + "net" "os" "os/signal" "strings" @@ -85,9 +86,13 @@ var rootCmd = &cobra.Command{ if !strings.Contains(server, ":") { controlName += ":443" } + serverHostOnly := server + if strings.Contains(server, ":") { + serverHostOnly, _, _ = net.SplitHostPort(server) + } hostnameFqdn := hostname if hostnameFqdn != "" && !strings.Contains(hostnameFqdn, ".") { - hostnameFqdn = strings.Join([]string{hostname, server}, ".") + hostnameFqdn = strings.Join([]string{hostname, serverHostOnly}, ".") } tunnel := client.NewTunnel(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target) diff --git a/cmd/tunnel-server/main.go b/cmd/tunnel-server/main.go new file mode 100644 index 0000000..b7c8a9b --- /dev/null +++ b/cmd/tunnel-server/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "crypto/tls" + "fmt" + "os" + + "github.com/foomo/simplecert" + "gitlab.com/gartnera/tunnel/server" + "go.uber.org/zap" +) + +func main() { + + var ok bool + var err error + basename, ok := os.LookupEnv("TUNNEL_BASENAME") + if !ok { + panic("TUNNEL_BASENAME not defined") + } + port, ok := os.LookupEnv("TUNNEL_PORT") + if !ok { + panic("TUNNEL_PORT not defined") + } + + _, ok = os.LookupEnv("DEBUG") + var logger *zap.Logger + if ok { + logger, err = zap.NewDevelopment() + } else { + logger, err = zap.NewProduction() + } + if err != nil { + panic(err) + } + + sCfg := simplecert.Default + sCfg.Domains = []string{fmt.Sprintf("*.%s", basename)} + sCfg.CacheDir = os.Getenv("SIMPLECERT_CACHE_DIR") + sCfg.SSLEmail = os.Getenv("SIMPLECERT_SSL_EMAIL") + sCfg.DNSProvider = os.Getenv("SIMPLECERT_DNS_PROVIDER") + // simply restart server when certificate is renewed. rely on systemd to restart + sCfg.DidRenewCertificate = func() { + os.Exit(2) + } + if os.Getenv("SIMPLECERT_USE_PUBLIC_DNS") != "" { + sCfg.DNSServers = []string{"1.1.1.1"} + } + + config := &tls.Config{} + cer, err := tls.LoadX509KeyPair("server.crt", "server.key") + if err == nil { + config.Certificates = []tls.Certificate{cer} + } else if sCfg.DNSProvider != "" { + certReloader, err := simplecert.Init(sCfg, nil) + if err != nil { + panic(err) + } + config.GetCertificate = certReloader.GetCertificateFunc() + } else { + logger.Fatal("could not parse cert or initiate simplecert", zap.Error(err)) + } + + server := server.New(basename, logger) + laddr := ":" + port + err = server.Start(laddr, config) + if err != nil { + logger.Fatal("server start", zap.Error(err)) + } +} diff --git a/server/main.go b/server/main.go deleted file mode 100644 index 881ad4a..0000000 --- a/server/main.go +++ /dev/null @@ -1,333 +0,0 @@ -package main - -import ( - "context" - "crypto/tls" - "fmt" - "log" - "net" - "os" - "strings" - "sync" - - "github.com/foomo/simplecert" - "github.com/icrowley/fake" - gNet "gitlab.com/gartnera/golib/net" - "go.uber.org/zap" -) - -var basename string -var controlName string -var port string -var logger *zap.Logger - -var state = struct { - sync.RWMutex - hostnameMap map[string]*ProxySession - secretMap map[string]*ProxySession -}{ - hostnameMap: make(map[string]*ProxySession), - secretMap: make(map[string]*ProxySession), -} - -// getHostname generates a three word unique subdomain -// recursively call self until we get a unique name -func getHostname() string { - res := fmt.Sprintf("%s-%s-%s.%s", fake.Word(), fake.Word(), fake.Word(), basename) - _, exists := state.hostnameMap[res] - _, wildcardExists := getWildcardHostname(res) - if exists || wildcardExists { - return getHostname() - } - return res -} - -func getWildcardHostname(serverName string) (string, bool) { - namePrefix := strings.Split(serverName, "-")[0] - wildcardHostname := fmt.Sprintf("%s-*.%s", namePrefix, basename) - _, exists := state.hostnameMap[wildcardHostname] - return wildcardHostname, exists -} - -type ProxySession struct { - sync.RWMutex - secret string - conns chan net.Conn - backendCount int - hostname string -} - -func NewProxySession(secret string, hostname string) *ProxySession { - session := &ProxySession{ - secret: secret, - conns: make(chan net.Conn, 500), - hostname: hostname, - } - state.Lock() - defer state.Unlock() - state.hostnameMap[hostname] = session - state.secretMap[secret] = session - - return session -} - -func (s *ProxySession) backendConnected(conn net.Conn) { - logger.Debug("backend connected", - zap.String("remoteAddr", conn.RemoteAddr().String()), - zap.String("hostname", s.hostname), - ) - s.Lock() - logger.Debug("backend connected (inside lock)", - zap.String("remoteAddr", conn.RemoteAddr().String()), - zap.String("hostname", s.hostname), - ) - s.backendCount++ - s.Unlock() - conn.Write([]byte(s.hostname)) - s.conns <- conn - logger.Debug("backend connected (after chan)", - zap.String("remoteAddr", conn.RemoteAddr().String()), - zap.String("hostname", s.hostname), - ) -} - -func (s *ProxySession) backendDisconnected() { - s.Lock() - defer s.Unlock() - s.backendCount-- - if s.backendCount == 0 { - state.Lock() - defer state.Unlock() - delete(state.hostnameMap, s.hostname) - delete(state.secretMap, s.secret) - close(s.conns) - fmt.Printf("all backends disconnected for %s\n", s.hostname) - } -} - -func (s *ProxySession) getBackend() net.Conn { - s.RLock() - if s.backendCount == 0 { - logger.Error("no backends available") - return nil - } - s.RUnlock() - backend, ok := <-s.conns - if !ok { - logger.Debug("conns closed", zap.String("hostname", s.hostname)) - return nil - } - logger.Debug("got backend", - zap.String("backendAddr", backend.RemoteAddr().String()), - ) - _, err := backend.Write([]byte("frontend-connected")) - if err != nil { - logger.Error("backend rejected frontend-connected", zap.Error(err)) - backend.Close() - s.backendDisconnected() - return s.getBackend() - } - return backend -} - -func main() { - log.SetFlags(log.Lshortfile) - - var ok bool - var err error - basename, ok = os.LookupEnv("TUNNEL_BASENAME") - if !ok { - panic("TUNNEL_BASENAME not defined") - } - controlName = "control." + basename - port, ok = os.LookupEnv("TUNNEL_PORT") - if !ok { - panic("TUNNEL_PORT not defined") - } - - _, ok = os.LookupEnv("DEBUG") - if ok { - logger, err = zap.NewDevelopment() - } else { - logger, err = zap.NewProduction() - } - if err != nil { - panic(err) - } - - sCfg := simplecert.Default - sCfg.Domains = []string{fmt.Sprintf("*.%s", basename)} - sCfg.CacheDir = os.Getenv("SIMPLECERT_CACHE_DIR") - sCfg.SSLEmail = os.Getenv("SIMPLECERT_SSL_EMAIL") - sCfg.DNSProvider = os.Getenv("SIMPLECERT_DNS_PROVIDER") - // simply restart server when certificate is renewed. rely on systemd to restart - sCfg.DidRenewCertificate = func() { - os.Exit(2) - } - if os.Getenv("SIMPLECERT_USE_PUBLIC_DNS") != "" { - sCfg.DNSServers = []string{"1.1.1.1"} - } - - var serverName string - config := &tls.Config{ - GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { - serverName = info.ServerName - return nil, nil - }, - } - cer, err := tls.LoadX509KeyPair("server.crt", "server.key") - if err == nil { - config.Certificates = []tls.Certificate{cer} - } else if sCfg.DNSProvider != "" { - certReloader, err := simplecert.Init(sCfg, nil) - if err != nil { - panic(err) - } - config.GetCertificate = certReloader.GetCertificateFunc() - } else { - logger.Fatal("could not parse cert or initiate simplecert", zap.Error(err)) - } - - laddr := ":" + port - ln, err := tls.Listen("tcp", laddr, config) - if err != nil { - logger.Fatal("could not listen", zap.String("laddr", laddr), zap.Error(err)) - } - defer ln.Close() - - ctx := context.Background() - for { - conn, err := ln.Accept() - if err != nil { - logger.Error("unable to accept", zap.Error(err)) - continue - } - logger.Debug("new connection", - zap.String("localAddr", conn.LocalAddr().String()), - zap.String("remoteAddr", conn.RemoteAddr().String()), - ) - // the tls connection isn't initialized until one side reads or writes - // we need to read immediately to get the ServerName before goroutine - conn.Read(nil) - go handleConnection(ctx, conn, serverName) - } -} - -func handleBackend(conn net.Conn, serverName string) { - logger := logger.With( - zap.String("remoteAddr", conn.RemoteAddr().String()), - ) - buf := make([]byte, 1024) - n, err := conn.Read(buf) - if err != nil { - logger.Error("unable to read command from conn", zap.Error(err)) - conn.Close() - return - } - s := string(buf[:n]) - ss := strings.Split(s, ":") - ssLen := len(ss) - if ssLen != 3 { - logger.Error("invalid command from conn") - conn.Close() - return - } - cmd := ss[0] - secret := ss[1] - state.Lock() - session, existingSessionFound := state.secretMap[secret] - state.Unlock() - if cmd == "backend-shutdown" { - defer conn.Close() - if !existingSessionFound { - logger.Error("invalid shutdown command", zap.String("cmd", cmd)) - return - } - logger.Debug("shutdown requested", zap.String("hostname", session.hostname)) - state.Lock() - defer state.Unlock() - delete(state.hostnameMap, session.hostname) - delete(state.secretMap, session.secret) - return - } - if cmd != "backend-open" { - logger.Error("unknown command", zap.String("cmd", cmd)) - conn.Close() - return - } - if existingSessionFound { - logger.Debug("existing session found", zap.String("hostname", session.hostname)) - session.backendConnected(conn) - return - } - hostname := ss[2] - if hostname == "" { - hostname = getHostname() - } - if !strings.HasSuffix(hostname, basename) { - logger.Error("requested hostname needs basename", - zap.String("hostname", hostname), - zap.String("basename", basename), - ) - conn.Close() - return - } - if strings.HasPrefix(hostname, "control.") { - logger.Error("ignoring request for control") - conn.Close() - return - } - // test hostname exists (secret mismatch) - state.Lock() - _, exists := state.hostnameMap[hostname] - state.Unlock() - // test wildcard exists - _, wildcardExists := getWildcardHostname(hostname) - if exists || wildcardExists { - logger.Error("hostname already exists", zap.String("hostname", hostname)) - conn.Close() - return - } - - session = NewProxySession(secret, hostname) - logger.Info("new session", zap.String("hostname", hostname)) - session.backendConnected(conn) -} - -func handleFrontend(ctx context.Context, conn net.Conn, serverName string) { - state.Lock() - session, ok := state.hostnameMap[serverName] - state.Unlock() - - // look for wildcard match in hostnameMap - wildcardHostname, wildcardExists := getWildcardHostname(serverName) - if !ok && wildcardExists { - state.Lock() - session, ok = state.hostnameMap[wildcardHostname] - state.Unlock() - } - if !ok { - conn.Close() - return - } - backend := session.getBackend() - if backend == nil { - logger.Error("nil backend") - conn.Close() - return - } - - gNet.PipeConn(ctx, backend, conn) - backend.Close() - conn.Close() - session.backendDisconnected() -} - -func handleConnection(ctx context.Context, conn net.Conn, serverName string) { - if serverName == controlName { - logger.Debug("new control connection", zap.String("remoteAddr", conn.RemoteAddr().String())) - handleBackend(conn, serverName) - return - } - handleFrontend(ctx, conn, serverName) -} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..c84eaac --- /dev/null +++ b/server/server.go @@ -0,0 +1,285 @@ +package server + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "strings" + "sync" + + "github.com/icrowley/fake" + gNet "gitlab.com/gartnera/golib/net" + "go.uber.org/zap" +) + +type Server struct { + basename string + controlName string + logger *zap.Logger + + sync.RWMutex + hostnameMap map[string]*proxySession + secretMap map[string]*proxySession +} + +func New(basename string, logger *zap.Logger) *Server { + return &Server{ + basename: basename, + controlName: fmt.Sprintf("control.%s", basename), + logger: logger, + hostnameMap: make(map[string]*proxySession), + secretMap: make(map[string]*proxySession), + } +} + +func (s *Server) Start(laddr string, tlsConfig *tls.Config) error { + var serverName string + tlsConfig.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + serverName = info.ServerName + return nil, nil + } + ln, err := tls.Listen("tcp", laddr, tlsConfig) + if err != nil { + s.logger.Fatal("could not listen", zap.String("laddr", laddr), zap.Error(err)) + } + defer ln.Close() + + ctx := context.Background() + for { + conn, err := ln.Accept() + if err != nil { + s.logger.Error("unable to accept", zap.Error(err)) + continue + } + s.logger.Debug("new connection", + zap.String("localAddr", conn.LocalAddr().String()), + zap.String("remoteAddr", conn.RemoteAddr().String()), + ) + // the tls connection isn't initialized until one side reads or writes + // we need to read immediately to get the ServerName before goroutine + conn.Read(nil) + go s.handleConnection(ctx, conn, serverName) + } +} + +// getHostname generates a three word unique subdomain +// recursively call self until we get a unique name +func (s *Server) getHostname() string { + res := fmt.Sprintf("%s-%s-%s.%s", fake.Word(), fake.Word(), fake.Word(), s.basename) + _, exists := s.hostnameMap[res] + _, wildcardExists := s.getWildcardHostname(res) + if exists || wildcardExists { + return s.getHostname() + } + return res +} + +func (s *Server) getWildcardHostname(serverName string) (string, bool) { + namePrefix := strings.Split(serverName, "-")[0] + wildcardHostname := fmt.Sprintf("%s-*.%s", namePrefix, s.basename) + _, exists := s.hostnameMap[wildcardHostname] + return wildcardHostname, exists +} + +type proxySession struct { + sync.RWMutex + server *Server + secret string + conns chan net.Conn + backendCount int + hostname string +} + +func (s *Server) newProxySession(secret string, hostname string) *proxySession { + session := &proxySession{ + server: s, + secret: secret, + conns: make(chan net.Conn, 500), + hostname: hostname, + } + s.Lock() + defer s.Unlock() + s.hostnameMap[hostname] = session + s.secretMap[secret] = session + + return session +} + +func (s *proxySession) backendConnected(conn net.Conn) { + s.server.logger.Debug("backend connected", + zap.String("remoteAddr", conn.RemoteAddr().String()), + zap.String("hostname", s.hostname), + ) + s.Lock() + s.server.logger.Debug("backend connected (inside lock)", + zap.String("remoteAddr", conn.RemoteAddr().String()), + zap.String("hostname", s.hostname), + ) + s.backendCount++ + s.Unlock() + conn.Write([]byte(s.hostname)) + s.conns <- conn + s.server.logger.Debug("backend connected (after chan)", + zap.String("remoteAddr", conn.RemoteAddr().String()), + zap.String("hostname", s.hostname), + ) +} + +func (s *proxySession) backendDisconnected() { + s.Lock() + defer s.Unlock() + s.backendCount-- + if s.backendCount == 0 { + s.server.Lock() + defer s.server.Unlock() + delete(s.server.hostnameMap, s.hostname) + delete(s.server.secretMap, s.secret) + close(s.conns) + fmt.Printf("all backends disconnected for %s\n", s.hostname) + } +} + +func (s *proxySession) getBackend() net.Conn { + s.RLock() + if s.backendCount == 0 { + s.server.logger.Error("no backends available") + return nil + } + s.RUnlock() + backend, ok := <-s.conns + if !ok { + s.server.logger.Debug("conns closed", zap.String("hostname", s.hostname)) + return nil + } + s.server.logger.Debug("got backend", + zap.String("backendAddr", backend.RemoteAddr().String()), + ) + _, err := backend.Write([]byte("frontend-connected")) + if err != nil { + s.server.logger.Error("backend rejected frontend-connected", zap.Error(err)) + backend.Close() + s.backendDisconnected() + return s.getBackend() + } + return backend +} + +func (s *Server) handleBackend(conn net.Conn) { + logger := s.logger.With( + zap.String("remoteAddr", conn.RemoteAddr().String()), + ) + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + logger.Error("unable to read command from conn", zap.Error(err)) + conn.Close() + return + } + rawCmd := string(buf[:n]) + cmdParts := strings.Split(rawCmd, ":") + cmdLen := len(cmdParts) + if cmdLen != 3 { + logger.Error("invalid command from conn") + conn.Close() + return + } + cmd := cmdParts[0] + secret := cmdParts[1] + s.Lock() + session, existingSessionFound := s.secretMap[secret] + s.Unlock() + if cmd == "backend-shutdown" { + defer conn.Close() + if !existingSessionFound { + logger.Error("invalid shutdown command", zap.String("cmd", cmd)) + return + } + logger.Debug("shutdown requested", zap.String("hostname", session.hostname)) + s.Lock() + defer s.Unlock() + delete(s.hostnameMap, session.hostname) + delete(s.secretMap, session.secret) + return + } + if cmd != "backend-open" { + logger.Error("unknown command", zap.String("cmd", cmd)) + conn.Close() + return + } + if existingSessionFound { + logger.Debug("existing session found", zap.String("hostname", session.hostname)) + session.backendConnected(conn) + return + } + hostname := cmdParts[2] + if hostname == "" { + hostname = s.getHostname() + } + if !strings.HasSuffix(hostname, s.basename) { + logger.Error("requested hostname needs basename", + zap.String("hostname", hostname), + zap.String("basename", s.basename), + ) + conn.Close() + return + } + if strings.HasPrefix(hostname, "control.") { + logger.Error("ignoring request for control") + conn.Close() + return + } + // test hostname exists (secret mismatch) + s.Lock() + _, exists := s.hostnameMap[hostname] + s.Unlock() + // test wildcard exists + _, wildcardExists := s.getWildcardHostname(hostname) + if exists || wildcardExists { + logger.Error("hostname already exists", zap.String("hostname", hostname)) + conn.Close() + return + } + + session = s.newProxySession(secret, hostname) + logger.Info("new session", zap.String("hostname", hostname)) + session.backendConnected(conn) +} + +func (s *Server) handleFrontend(ctx context.Context, conn net.Conn, serverName string) { + s.Lock() + session, ok := s.hostnameMap[serverName] + s.Unlock() + + // look for wildcard match in hostnameMap + wildcardHostname, wildcardExists := s.getWildcardHostname(serverName) + if !ok && wildcardExists { + s.Lock() + session, ok = s.hostnameMap[wildcardHostname] + s.Unlock() + } + if !ok { + conn.Close() + return + } + backend := session.getBackend() + if backend == nil { + s.logger.Error("nil backend") + conn.Close() + return + } + + gNet.PipeConn(ctx, backend, conn) + backend.Close() + conn.Close() + session.backendDisconnected() +} + +func (s *Server) handleConnection(ctx context.Context, conn net.Conn, serverName string) { + if serverName == s.controlName { + s.logger.Debug("new control connection", zap.String("remoteAddr", conn.RemoteAddr().String())) + s.handleBackend(conn) + return + } + s.handleFrontend(ctx, conn, serverName) +}