diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..cc5c18b --- /dev/null +++ b/.envrc @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +export DIRENV_WARN_TIMEOUT=20s + +eval "$(devenv direnvrc)" + +# `use devenv` supports the same options as the `devenv shell` command. +# +# To silence all output, use `--quiet`. +# +# Example usage: use devenv --quiet --impure --option services.postgres.enable:bool true +use devenv diff --git a/.github/actions/prepare/action.yaml b/.github/actions/prepare/action.yaml index 7c16323..bd7eec7 100644 --- a/.github/actions/prepare/action.yaml +++ b/.github/actions/prepare/action.yaml @@ -4,7 +4,7 @@ runs: using: "composite" steps: - name: setup golang - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: 'go.mod' - name: clear go mod cache @@ -12,7 +12,7 @@ runs: shell: bash - name: connect go mod cache id: cache-gomod - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: /tmp/go/pkg/mod key: ${{ runner.os }}-golang-${{ hashFiles('**/go.sum') }} diff --git a/.github/workflows/dependencies.yaml b/.github/workflows/dependencies.yaml index 490357a..3722520 100644 --- a/.github/workflows/dependencies.yaml +++ b/.github/workflows/dependencies.yaml @@ -9,16 +9,16 @@ jobs: runs-on: ubuntu-latest steps: - name: checkout repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: setup golang - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: 'go.mod' - name: clear go mod cache run: sudo rm -rf $GOMODCACHE - name: connect go mod cache id: cache-gomod - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: /tmp/go/pkg/mod key: ${{ runner.os }}-golang-${{ hashFiles('**/go.sum') }} diff --git a/.github/workflows/linters.yaml b/.github/workflows/linters.yaml index 6a17251..a34fa76 100644 --- a/.github/workflows/linters.yaml +++ b/.github/workflows/linters.yaml @@ -9,12 +9,12 @@ jobs: runs-on: ubuntu-latest steps: - name: checkout repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: prepare uses: ./.github/actions/prepare - name: setup golangci-lint - uses: golangci/golangci-lint-action@v6 + uses: golangci/golangci-lint-action@v9 with: - version: v1.60.3 + version: v2.6 - name: run golangci-lint run: golangci-lint run diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 5acd7bc..e03664b 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -9,25 +9,25 @@ jobs: runs-on: ubuntu-latest steps: - name: checkout repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: prepare uses: ./.github/actions/prepare - name: run tests - run: go test -v -covermode atomic ./... --race --timeout 1m + run: go test -v -race -timeout=1m -covermode=atomic ./... integration: runs-on: ubuntu-latest steps: - name: checkout repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: prepare uses: ./.github/actions/prepare - name: run tests - run: go test -v ./test/... --race --timeout 1m -tags=integration + run: go test -v -race -timeout=1m -tags=integration ./test/... ws-autobahn: runs-on: ubuntu-latest steps: - name: checkout repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: prepare uses: ./.github/actions/prepare - name: make report directory diff --git a/.gitignore b/.gitignore index 49497f5..d09b745 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,14 @@ go.work # IDE files .idea/ + +# Devenv +.devenv* +devenv.local.nix +devenv.local.yaml + +# direnv +.direnv + +# pre-commit +.pre-commit-config.yaml diff --git a/.golangci.yml b/.golangci.yml index 54e6eb6..fe14ab8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,125 +1,19 @@ +version: "2" + run: timeout: 5m - issues-exit-code: 1 - tests: false modules-download-mode: readonly allow-parallel-runners: true -output: - print-issued-lines: true - print-linter-name: true - linters: - disable: - - depguard - - exhaustruct - - fatcontext - - gochecknoglobals - - goconst - - godot - - godox - - nonamedreturns - - nlreturn - - varnamelen - - inamedparam - - wsl - - wrapcheck - presets: - - bugs - - comment - - complexity - - error - - format - - import - - metalinter - - module - - performance - - style - - test - - unused - -linters-settings: - exhaustive: - ignore-enum-members: '^*(Unspecified|Undefined|Unknown|Idle)$' - cyclop: - max-complexity: 20 - gomnd: - ignored-numbers: - - '2' - - '4' - - '8' - - '16' - - '32' - - '64' - gci: - sections: - - standard # Standard section: captures all standard packages. - - default # Default section: contains all imports that could not be matched to another section type. - - prefix(github.com/einouqo/ext-kit) # Custom section: groups all imports with the specified Prefix. - - blank # Blank section: contains all blank imports. This section is not present unless explicitly enabled. - - dot - tagliatelle: - case: - rules: - json: snake - yaml: snake + exclusions: + rules: + # an exception, as the error value is used to signal the end of the stream + - linters: + - staticcheck + source: "StreamDone" + # as documentation states: "... This method always returns a nil error" https://github.com/grpc/grpc-go/blob/v1.78.0/stream.go#L105-L108 + - linters: + - errcheck + source: "defer .*\\.CloseSend()" -issues: - exclude-dirs: - - test # no lints for tests - exclude-files: - - .*_test.go$ # unit tests - - transport/ws/intercepting_writer.go # code from http package of go-kit project (https://github.com/go-kit/kit/blob/master/transport/http/intercepting_writer.go). - exclude-rules: - - linters: - - wrapcheck - text: 'error returned from interface method should be wrapped' - source: ^*return .* - - linters: - - wrapcheck - text: ^error returned from external package is unwrapped[:] .*multierror.* error$ - - linters: - - wrapcheck - source: .*status.Error\(.*, .*\).* - - linters: - - cyclop - text: 'calculated cyclomatic complexity for function' - source: ^func \(.*\) String\(\) string {$ - - linters: - - gomnd - source: .*(time\..|[0-9]+<<[0-9]+)*$ - - linters: - - gofumpt - source: ^var \( - - linters: - - unparam - # except setters with self returning - text: .*\)\.set([A-z0-9]+)? - result 0 \(\*.*\) is never used$ - - linters: - - gosec - text: 'G109:' - - linters: - - gosec - text: 'G404:' - - linters: - - lll - source: .*// .* - - linters: - - stylecheck - text: 'ST1012:' - - linters: - - forbidigo - text: use of `fmt\..*` forbidden by pattern `.*`$ - - path: transport # ok for transport - linters: [ gocognit, gocyclo, cyclop, forcetypeassert, funlen, ireturn, lll ] - - path: endpoint - linters: [ errname, revive ] - - linters: - - gocritic - text: 'appendAssign: append result not assigned to the same slice' - - linters: - - errcheck - source: ^\s*defer .*$ - - linters: - - revive - text: 'if-return: redundant if ...;' diff --git a/README.md b/README.md index 31276e8..eb497ff 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Go ext kit -![Go Version](https://img.shields.io/badge/go-1.22+-blue.svg) +![Go Version](https://img.shields.io/github/go-mod/go-version/einouqo/ext-kit) [![Go Report Card](https://goreportcard.com/badge/github.com/einouqo/ext-kit)](https://goreportcard.com/report/github.com/einouqo/ext-kit) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) @@ -32,40 +32,43 @@ go get github.com/einouqo/ext-kit You can refer to the [tests](test/transport/grpc) for more examples. **Server:** + ```go type Service interface { - Bi(ctx context.Context, receiver <-chan string) (endpoint.Receive[string], error) + Bi(ctx context.Context, receiver <-chan string) (endpoint.Receive[string], error) } func NewServerBinding(svc Service, opts ...kitgrpc.ServerOption) *ServerBinding { - return &ServerBinding{ - /* ... */ - biStream: kitgrpc.NewServerBiStream[*pb.EchoRequest]( - svc.Bi, - decodeRequest, - encodeResponse, - opts..., - ), - } + return &ServerBinding{ + /* ... */ + biStream: kitgrpc.NewServerBiStream[*pb.EchoRequest]( + svc.Bi, + decodeRequest, + encodeResponse, + opts..., + ), + } } ``` **Client:** + ```go func NewClientBinding(cc *grpc.ClientConn) *ClientBinding { - return &ClientBinding{ - /* ... */ - BiStream: kitgrpc.NewClientBiStream[*pb.EchoResponse]( - cc, - pb.Echo_BiStream_FullMethodName, - encodeRequest, - decodeResponse, - ).Endpoint(), - } + return &ClientBinding{ + /* ... */ + BiStream: kitgrpc.NewClientBiStream[*pb.EchoResponse]( + cc, + pb.Echo_BiStream_FullMethodName, + encodeRequest, + decodeResponse, + ).Endpoint(), + } } ``` Make a call: + ```go send := make(chan service.EchoRequest) // send your requests to the channel in the way you want receive, err := client.BiStream(ctx, send) @@ -85,6 +88,7 @@ for { ``` #### WebSocket + The usage is pretty close to gRPC Bi-Directional Streaming (the example above), but with WebSocket transport inside. You can also refer to the [tests](test/transport/ws) or [autobahn](test/transport/autobahn) implementation for more examples. diff --git a/devenv.lock b/devenv.lock new file mode 100644 index 0000000..ec447dc --- /dev/null +++ b/devenv.lock @@ -0,0 +1,123 @@ +{ + "nodes": { + "devenv": { + "locked": { + "dir": "src/modules", + "lastModified": 1770744655, + "owner": "cachix", + "repo": "devenv", + "rev": "d8bd7b74d0604227220074ac0bc934c4efb2b8fb", + "type": "github" + }, + "original": { + "dir": "src/modules", + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1767039857, + "owner": "NixOS", + "repo": "flake-compat", + "rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "flake-compat", + "type": "github" + } + }, + "git-hooks": { + "inputs": { + "flake-compat": "flake-compat", + "gitignore": "gitignore", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1770726378, + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "5eaaedde414f6eb1aea8b8525c466dc37bba95ae", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "git-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1762808025, + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "cb5e3fdca1de58ccbc3ef53de65bd372b48f567c", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nixpkgs": { + "inputs": { + "nixpkgs-src": "nixpkgs-src" + }, + "locked": { + "lastModified": 1770434727, + "owner": "cachix", + "repo": "devenv-nixpkgs", + "rev": "8430f16a39c27bdeef236f1eeb56f0b51b33d348", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "rolling", + "repo": "devenv-nixpkgs", + "type": "github" + } + }, + "nixpkgs-src": { + "flake": false, + "locked": { + "lastModified": 1769922788, + "narHash": "sha256-H3AfG4ObMDTkTJYkd8cz1/RbY9LatN5Mk4UF48VuSXc=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "207d15f1a6603226e1e223dc79ac29c7846da32e", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "git-hooks": "git-hooks", + "nixpkgs": "nixpkgs", + "pre-commit-hooks": [ + "git-hooks" + ] + } + } + }, + "root": "root", + "version": 7 +} diff --git a/devenv.nix b/devenv.nix new file mode 100644 index 0000000..25699ea --- /dev/null +++ b/devenv.nix @@ -0,0 +1,46 @@ +{pkgs, config, lib, ...}: { + languages.go = { + enable = true; + }; + + packages = with pkgs; [ + golangci-lint + ]; + + scripts = { + lint = { + exec = '' + golangci-lint run + ''; + description = "Run golangci-lint to check for code issues."; + }; + lint-fix = { + exec = '' + golangci-lint run --fix + ''; + description = "Run golangci-lint with --fix to automatically fix issues."; + }; + + test-unit = { + exec = '' + go test -v -race -timeout=1m -covermode=atomic ./... + ''; + description = "Run unit tests for all packages."; + }; + test-integration = { + exec = '' + go test -v -race -timeout=1m -tags=integration ./test/... + ''; + description = "Run integration tests."; + }; + }; + + enterShell = '' + echo + echo "Available scripts:" + ${pkgs.gnused}/bin/sed -e 's| |••|g' -e 's|=| |' < 0 { + if len(c.opts.errorHandlers) > 0 { defer func() { if err != nil { - for _, h := range c.opts.errHandlers { - h.Handle(ctx, err) - } + c.opts.errorHandlers.Handle(ctx, err) } }() } @@ -163,7 +161,7 @@ func (c *ClientInnerStream[OUT, IN]) Endpoint() endpoint.InnerStream[OUT, IN] { return nil, err } - inCh := make(chan IN) + ins := make(chan IN) group := errgroup.Group{} group.Go(func() (err error) { if c.opts.finalizer != nil { @@ -173,7 +171,7 @@ func (c *ClientInnerStream[OUT, IN]) Endpoint() endpoint.InnerStream[OUT, IN] { } }() } - defer close(inCh) + defer close(ins) for { msg := reflect.New(c.reply).Interface().(proto.Message) err = stream.RecvMsg(msg) @@ -187,26 +185,24 @@ func (c *ClientInnerStream[OUT, IN]) Endpoint() endpoint.InnerStream[OUT, IN] { if err != nil { return err } - inCh <- in + ins <- in } }) - errCh := make(chan error) + errs := make(chan error) go func() { - defer close(errCh) + defer close(errs) if err := group.Wait(); err != nil { - for _, h := range c.opts.errHandlers { - h.Handle(ctx, err) - } - errCh <- err + c.opts.errorHandlers.Handle(ctx, err) + errs <- err } }() return func() (in IN, err error) { - if in, ok := <-inCh; ok { + if in, ok := <-ins; ok { return in, nil } - if err, ok := <-errCh; ok { + if err, ok := <-errs; ok { return in, err } return in, endpoint.StreamDone @@ -242,12 +238,10 @@ func NewClientOuterStream[REPLY proto.Message, OUT, IN any]( func (c *ClientOuterStream[OUT, IN]) Endpoint() endpoint.OuterStream[OUT, IN] { return func(ctx context.Context, receiver <-chan OUT) (in IN, err error) { - if len(c.opts.errHandlers) > 0 { + if len(c.opts.errorHandlers) > 0 { defer func() { if err != nil { - for _, h := range c.opts.errHandlers { - h.Handle(ctx, err) - } + c.opts.errorHandlers.Handle(ctx, err) } }() } @@ -297,9 +291,9 @@ func (c *ClientOuterStream[OUT, IN]) Endpoint() endpoint.OuterStream[OUT, IN] { } return nil }) - inCh := make(chan IN, 1) + ins := make(chan IN, 1) group.Go(func() error { - defer close(inCh) + defer close(ins) defer cancel() msg := reflect.New(c.reply).Interface().(proto.Message) err := stream.RecvMsg(msg) @@ -313,7 +307,7 @@ func (c *ClientOuterStream[OUT, IN]) Endpoint() endpoint.OuterStream[OUT, IN] { if err != nil { return err } - inCh <- in + ins <- in return nil }) if err := group.Wait(); err != nil { @@ -324,7 +318,7 @@ func (c *ClientOuterStream[OUT, IN]) Endpoint() endpoint.OuterStream[OUT, IN] { ctx = f(ctx, header, trailer) } - return <-inCh, nil + return <-ins, nil } } @@ -395,7 +389,7 @@ func (c *ClientBiStream[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { } return nil }) - inCh := make(chan IN) + ins := make(chan IN) group.Go(func() (err error) { if c.opts.finalizer != nil { defer func() { @@ -404,7 +398,7 @@ func (c *ClientBiStream[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { } }() } - defer close(inCh) + defer close(ins) for { msg := reflect.New(c.reply).Interface().(proto.Message) err := stream.RecvMsg(msg) @@ -418,26 +412,24 @@ func (c *ClientBiStream[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { if err != nil { return err } - inCh <- in + ins <- in } }) - errCh := make(chan error) + errs := make(chan error) go func() { - defer close(errCh) + defer close(errs) if err := group.Wait(); err != nil { - for _, h := range c.opts.errHandlers { - h.Handle(ctx, err) - } - errCh <- err + c.opts.errorHandlers.Handle(ctx, err) + errs <- err } }() return func() (in IN, err error) { - if in, ok := <-inCh; ok { + if in, ok := <-ins; ok { return in, nil } - if err, ok := <-errCh; ok { + if err, ok := <-errs; ok { return in, err } return in, endpoint.StreamDone @@ -448,10 +440,10 @@ func (c *ClientBiStream[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { type clientOptions struct { callOpts []grpc.CallOption - before []kitgrpc.ClientRequestFunc - after []kitgrpc.ClientResponseFunc - finalizer []kitgrpc.ClientFinalizerFunc - errHandlers []transport.ErrorHandler + before []kitgrpc.ClientRequestFunc + after []kitgrpc.ClientResponseFunc + finalizer []kitgrpc.ClientFinalizerFunc + errorHandlers transport.ErrorHandlers } type client[OUT, IN any] struct { diff --git a/transport/grpc/option.go b/transport/grpc/option.go index 7a77745..cc3c12d 100644 --- a/transport/grpc/option.go +++ b/transport/grpc/option.go @@ -36,7 +36,7 @@ func WithClientFinalizer(finalize kitgrpc.ClientFinalizerFunc) ClientOption { func WithClientErrorHandler(handler transport.ErrorHandler) ClientOption { return funcClientOption{f: func(o *clientOptions) { - o.errHandlers = append(o.errHandlers, handler) + o.errorHandlers = append(o.errorHandlers, handler) }} } @@ -72,7 +72,7 @@ func WithServerFinalizer(finalize kitgrpc.ServerFinalizerFunc) ServerOption { func WithServerErrorHandler(handler transport.ErrorHandler) ServerOption { return funcServerOption{f: func(o *serverOptions) { - o.errHandlers = append(o.errHandlers, handler) + o.errorHandlers = append(o.errorHandlers, handler) }} } diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 6a87598..605dfca 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -6,7 +6,6 @@ import ( "io" "reflect" - "github.com/go-kit/kit/transport" kitgrpc "github.com/go-kit/kit/transport/grpc" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -14,6 +13,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/einouqo/ext-kit/endpoint" + "github.com/einouqo/ext-kit/transport" ) type HandlerUnary interface { @@ -55,20 +55,18 @@ func NewServerUnary[IN, OUT any]( return s } -func (srv ServerUnary[IN, OUT]) ServeUnary(ctx context.Context, req proto.Message) (ctxt context.Context, resp proto.Message, err error) { - if len(srv.opts.finalizer) > 0 { +func (s *ServerUnary[IN, OUT]) ServeUnary(ctx context.Context, req proto.Message) (ctxt context.Context, resp proto.Message, err error) { + if len(s.opts.finalizer) > 0 { defer func() { - for _, f := range srv.opts.finalizer { + for _, f := range s.opts.finalizer { f(ctx, err) } }() } - if len(srv.opts.errHandlers) > 0 { + if len(s.opts.errorHandlers) > 0 { defer func() { if err != nil { - for _, h := range srv.opts.errHandlers { - h.Handle(ctx, err) - } + s.opts.errorHandlers.Handle(ctx, err) } }() } @@ -77,22 +75,22 @@ func (srv ServerUnary[IN, OUT]) ServeUnary(ctx context.Context, req proto.Messag if !ok { md = metadata.MD{} } - for _, f := range srv.opts.before { + for _, f := range s.opts.before { ctx = f(ctx, md) } - in, err := srv.dec(ctx, req) + in, err := s.dec(ctx, req) if err != nil { return ctx, nil, err } - out, err := srv.e(ctx, in) + out, err := s.e(ctx, in) if err != nil { return ctx, nil, err } mdHeader, mdTrailer := make(metadata.MD), make(metadata.MD) - for _, f := range srv.opts.after { + for _, f := range s.opts.after { ctx = f(ctx, &mdHeader, &mdTrailer) } if len(mdHeader) > 0 { @@ -101,7 +99,7 @@ func (srv ServerUnary[IN, OUT]) ServeUnary(ctx context.Context, req proto.Messag } } - resp, err = srv.enc(ctx, out) + resp, err = s.enc(ctx, out) if err != nil { return ctx, nil, err } @@ -138,22 +136,20 @@ func NewServerInnerStream[IN, OUT any]( return s } -func (srv ServerInnerStream[IN, OUT]) ServeInnerStream(req proto.Message, s grpc.ServerStream) (ctx context.Context, err error) { - ctx = s.Context() +func (s *ServerInnerStream[IN, OUT]) ServeInnerStream(req proto.Message, stream grpc.ServerStream) (ctx context.Context, err error) { + ctx = stream.Context() - if len(srv.opts.finalizer) > 0 { + if len(s.opts.finalizer) > 0 { defer func() { - for _, f := range srv.opts.finalizer { + for _, f := range s.opts.finalizer { f(ctx, err) } }() } - if len(srv.opts.errHandlers) > 0 { + if len(s.opts.errorHandlers) > 0 { defer func() { if err != nil { - for _, h := range srv.opts.errHandlers { - h.Handle(ctx, err) - } + s.opts.errorHandlers.Handle(ctx, err) } }() } @@ -162,22 +158,25 @@ func (srv ServerInnerStream[IN, OUT]) ServeInnerStream(req proto.Message, s grpc if !ok { md = metadata.MD{} } - for _, f := range srv.opts.before { + for _, f := range s.opts.before { ctx = f(ctx, md) } - in, err := srv.dec(ctx, req) + in, err := s.dec(ctx, req) if err != nil { return ctx, err } - receive, err := srv.e(ctx, in) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + receive, err := s.e(ctx, in) if err != nil { return ctx, err } mdHeader, mdTrailer := make(metadata.MD), make(metadata.MD) - for _, f := range srv.opts.after { + for _, f := range s.opts.after { ctx = f(ctx, &mdHeader, &mdTrailer) } @@ -188,6 +187,7 @@ func (srv ServerInnerStream[IN, OUT]) ServeInnerStream(req proto.Message, s grpc return err } } + defer cancel() for { out, err := receive() switch { @@ -196,11 +196,11 @@ func (srv ServerInnerStream[IN, OUT]) ServeInnerStream(req proto.Message, s grpc case err != nil: return err } - msg, err := srv.enc(ctx, out) + msg, err := s.enc(ctx, out) if err != nil { return err } - err = s.SendMsg(msg) + err = stream.SendMsg(msg) switch { case errors.Is(err, io.EOF): return io.ErrUnexpectedEOF @@ -247,22 +247,20 @@ func NewServerOuterStream[RECEIVE proto.Message, IN, OUT any]( return s } -func (srv ServerOuterStream[IN, OUT]) ServeOuterStream(s grpc.ServerStream) (ctx context.Context, err error) { - ctx = s.Context() +func (s *ServerOuterStream[IN, OUT]) ServeOuterStream(stream grpc.ServerStream) (ctx context.Context, err error) { + ctx = stream.Context() - if len(srv.opts.finalizer) > 0 { + if len(s.opts.finalizer) > 0 { defer func() { - for _, f := range srv.opts.finalizer { + for _, f := range s.opts.finalizer { f(ctx, err) } }() } - if len(srv.opts.errHandlers) > 0 { + if len(s.opts.errorHandlers) > 0 { defer func() { if err != nil { - for _, h := range srv.opts.errHandlers { - h.Handle(ctx, err) - } + s.opts.errorHandlers.Handle(ctx, err) } }() } @@ -271,43 +269,49 @@ func (srv ServerOuterStream[IN, OUT]) ServeOuterStream(s grpc.ServerStream) (ctx if !ok { md = metadata.MD{} } - for _, f := range srv.opts.before { + for _, f := range s.opts.before { ctx = f(ctx, md) } - inCh := make(chan IN) - doneCh := make(chan struct{}) + done := make(chan struct{}) + ins := make(chan IN) group := errgroup.Group{} group.Go(func() error { - defer close(inCh) + defer close(ins) for { - msg := reflect.New(srv.receive).Interface().(proto.Message) - err := s.RecvMsg(msg) + msg := reflect.New(s.receive).Interface().(proto.Message) + err := stream.RecvMsg(msg) switch { case errors.Is(err, io.EOF): return nil case err != nil: return err } - in, err := srv.dec(ctx, msg) + in, err := s.dec(ctx, msg) if err != nil { return err } select { - case <-doneCh: + case <-done: return nil - case inCh <- in: + default: + // endpoint is not done, continue + } + select { + case <-done: + return nil + case ins <- in: } } }) group.Go(func() error { - defer close(doneCh) - out, err := srv.e(ctx, inCh) + defer close(done) + out, err := s.e(ctx, ins) if err != nil { return err } mdHeader, mdTrailer := make(metadata.MD), make(metadata.MD) - for _, f := range srv.opts.after { + for _, f := range s.opts.after { ctx = f(ctx, &mdHeader, &mdTrailer) } if len(mdHeader) > 0 { @@ -316,11 +320,11 @@ func (srv ServerOuterStream[IN, OUT]) ServeOuterStream(s grpc.ServerStream) (ctx } } - msg, err := srv.enc(ctx, out) + msg, err := s.enc(ctx, out) if err != nil { return err } - err = s.SendMsg(msg) + err = stream.SendMsg(msg) switch { case errors.Is(err, io.EOF): return nil @@ -335,6 +339,7 @@ func (srv ServerOuterStream[IN, OUT]) ServeOuterStream(s grpc.ServerStream) (ctx } return nil }) + if err := group.Wait(); err != nil { return ctx, err } @@ -346,13 +351,13 @@ type ServerBiStream[IN, OUT any] struct { server[IN, OUT, endpoint.BiStream[IN, OUT]] } -func NewServerBiStream[RECEIVE proto.Message, IN, OUT any]( +func NewServerBiStream[M proto.Message, IN, OUT any]( e endpoint.BiStream[IN, OUT], dec DecodeFunc[IN], enc EncodeFunc[OUT], opts ...ServerOption, ) *ServerBiStream[IN, OUT] { - var receive RECEIVE + var receive M s := &ServerBiStream[IN, OUT]{ server: server[IN, OUT, endpoint.BiStream[IN, OUT]]{ e: e, @@ -367,22 +372,20 @@ func NewServerBiStream[RECEIVE proto.Message, IN, OUT any]( return s } -func (srv ServerBiStream[IN, OUT]) ServeBiStream(s grpc.ServerStream) (ctx context.Context, err error) { - ctx = s.Context() +func (s *ServerBiStream[IN, OUT]) ServeBiStream(stream grpc.ServerStream) (ctx context.Context, err error) { + ctx = stream.Context() - if len(srv.opts.finalizer) > 0 { + if len(s.opts.finalizer) > 0 { defer func() { - for _, f := range srv.opts.finalizer { + for _, f := range s.opts.finalizer { f(ctx, err) } }() } - if len(srv.opts.errHandlers) > 0 { + if len(s.opts.errorHandlers) > 0 { defer func() { if err != nil { - for _, h := range srv.opts.errHandlers { - h.Handle(ctx, err) - } + s.opts.errorHandlers.Handle(ctx, err) } }() } @@ -391,47 +394,55 @@ func (srv ServerBiStream[IN, OUT]) ServeBiStream(s grpc.ServerStream) (ctx conte if !ok { md = metadata.MD{} } - for _, f := range srv.opts.before { + for _, f := range s.opts.before { ctx = f(ctx, md) } - inCh := make(chan IN) - receive, err := srv.e(ctx, inCh) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + ins := make(chan IN) + receive, err := s.e(ctx, ins) if err != nil { return ctx, err } mdHeader, mdTrailer := make(metadata.MD), make(metadata.MD) - for _, f := range srv.opts.after { + for _, f := range s.opts.after { ctx = f(ctx, &mdHeader, &mdTrailer) } - doneCh := make(chan struct{}) group := errgroup.Group{} group.Go(func() error { - defer close(inCh) + defer close(ins) for { - msg := reflect.New(srv.receive).Interface().(proto.Message) - err := s.RecvMsg(msg) + msg := reflect.New(s.receive).Interface().(proto.Message) + err := stream.RecvMsg(msg) switch { case errors.Is(err, io.EOF): return nil case err != nil: return err } - in, err := srv.dec(ctx, msg) + in, err := s.dec(ctx, msg) if err != nil { return err } select { - case <-doneCh: + case <-ctx.Done(): + return nil + default: + // context is not done, continue + } + select { + case <-ctx.Done(): return nil - case inCh <- in: + case ins <- in: } } }) group.Go(func() error { - defer close(doneCh) + defer cancel() if len(mdHeader) > 0 { if err = grpc.SendHeader(ctx, mdHeader); err != nil { return err @@ -445,11 +456,11 @@ func (srv ServerBiStream[IN, OUT]) ServeBiStream(s grpc.ServerStream) (ctx conte case err != nil: return err } - msg, err := srv.enc(ctx, out) + msg, err := s.enc(ctx, out) if err != nil { return err } - err = s.SendMsg(msg) + err = stream.SendMsg(msg) switch { case errors.Is(err, io.EOF): return nil @@ -472,10 +483,10 @@ func (srv ServerBiStream[IN, OUT]) ServeBiStream(s grpc.ServerStream) (ctx conte } type serverOptions struct { - before []kitgrpc.ServerRequestFunc - after []kitgrpc.ServerResponseFunc - finalizer []kitgrpc.ServerFinalizerFunc - errHandlers []transport.ErrorHandler + before []kitgrpc.ServerRequestFunc + after []kitgrpc.ServerResponseFunc + finalizer []kitgrpc.ServerFinalizerFunc + errorHandlers transport.ErrorHandlers } type server[IN, OUT any, E endpoint.Endpoint[IN, OUT]] struct { diff --git a/transport/ws/buffer_pool.go b/transport/ws/buffer_pool.go new file mode 100644 index 0000000..a56bf5b --- /dev/null +++ b/transport/ws/buffer_pool.go @@ -0,0 +1,12 @@ +package ws + +import "github.com/fasthttp/websocket" + +type pool[T any] interface { + Get() T + Put(T) +} + +type BufferPool = pool[any] + +var _ BufferPool = websocket.BufferPool(nil) diff --git a/transport/ws/client.go b/transport/ws/client.go index 9043bba..c514824 100644 --- a/transport/ws/client.go +++ b/transport/ws/client.go @@ -8,9 +8,10 @@ import ( "net/url" "sync" "syscall" + "time" + "github.com/einouqo/ext-kit/transport" "github.com/fasthttp/websocket" - "github.com/go-kit/kit/transport" kithttp "github.com/go-kit/kit/transport/http" "golang.org/x/sync/errgroup" @@ -49,38 +50,46 @@ func NewClient[OUT, IN any]( func (c *Client[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { return func(ctx context.Context, receiver <-chan OUT) (rcv endpoint.Receive[IN], err error) { headers := make(http.Header) - dialer := dialler{websocket.DefaultDialer} + dialer := &dialler{ + ws: &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + }, + } for _, f := range c.opts.before { ctx = f(ctx, dialer, headers) } - wsc, resp, err := dialer.DialContext(ctx, c.url.String(), headers) + wsc, resp, err := dialer.ws.DialContext(ctx, c.url.String(), headers) if err != nil { return nil, err } - defer func() { - if err != nil { - _ = wsc.Close() - } - }() - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + + conn := &conn{ + ws: wsc, + config: c.opts.connection.config, + } for _, f := range c.opts.after { - ctx = f(ctx, resp, wsc) + ctx = f(ctx, resp, conn) } - conn := enhConn{wsc, c.opts.enhancement.config} + ctx, cancel := context.WithCancel(ctx) once := sync.Once{} - doneCh := make(chan struct{}) - group := errgroup.Group{} - group.Go(func() (err error) { - defer close(doneCh) - defer once.Do(func() { + leave := func(err error) { + once.Do(func() { code, msg, deadline := c.closure(ctx, err) data := websocket.FormatCloseMessage(code.fastsocket(), msg) _ = conn.WriteControl(websocket.CloseMessage, data, deadline) }) + } + + group := errgroup.Group{} + group.Go(func() (err error) { + defer func() { leave(err) }() + defer cancel() for { select { case <-ctx.Done(): @@ -95,11 +104,10 @@ func (c *Client[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { } err = conn.WriteMessage(mt.fastsocket(), msg) switch { - case errors.Is(err, net.ErrClosed): - return nil - case errors.Is(err, syscall.EPIPE): // broken pipe can appear on closed underlying tcp connection by peer - return nil - case errors.Is(err, websocket.ErrCloseSent): + case + errors.Is(err, net.ErrClosed), + errors.Is(err, syscall.EPIPE), // broken pipe can appear on closed underlying TCP connection by peer + errors.Is(err, websocket.ErrCloseSent): return nil case err != nil: return err @@ -107,7 +115,7 @@ func (c *Client[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { } } }) - inCh := make(chan IN) + ins := make(chan IN) group.Go(func() (err error) { if c.opts.finalizer != nil { defer func() { @@ -116,13 +124,8 @@ func (c *Client[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { } }() } - defer close(inCh) - defer conn.Close() - defer once.Do(func() { - code, msg, deadline := c.closure(ctx, err) - data := websocket.FormatCloseMessage(code.fastsocket(), msg) - _ = conn.WriteControl(websocket.CloseMessage, data, deadline) - }) + defer func() { leave(err) }() + defer close(ins) for { messageType, msg, err := conn.ReadMessage() switch { @@ -136,34 +139,43 @@ func (c *Client[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { if err != nil { return err } - inCh <- in + select { + case <-ctx.Done(): + return nil + default: + // context is not done, continue + } + select { + case <-ctx.Done(): + return nil + case ins <- in: + } } }) if c.opts.heartbeat.enable { - group.Go(func() error { - defer conn.Close() - return heartbeat(ctx, c.opts.heartbeat.config, conn, doneCh) + group.Go(func() (err error) { + defer func() { leave(err) }() + return heartbeat(ctx, c.opts.heartbeat.config, conn) }) } - errCh := make(chan error) + errs := make(chan error) go func() { - defer close(errCh) + defer func() { _ = conn.Close() }() + defer close(errs) err := group.Wait() if err != nil { - for _, h := range c.opts.errHandlers { - h.Handle(ctx, err) - } - errCh <- err + c.opts.errorHandlers.Handle(ctx, err) + errs <- err } }() return func() (in IN, err error) { - if in, ok := <-inCh; ok { + if in, ok := <-ins; ok { return in, nil } - if err, ok := <-errCh; ok { + if err, ok := <-errs; ok { return in, err } return in, endpoint.StreamDone @@ -172,17 +184,17 @@ func (c *Client[OUT, IN]) Endpoint() endpoint.BiStream[OUT, IN] { } type clientOptions struct { - before []DiallerFunc - after []ClientTunerFunc - finalizer []kithttp.ClientFinalizerFunc - errHandlers []transport.ErrorHandler + before []DiallerFunc + after []ClientTunerFunc + finalizer []kithttp.ClientFinalizerFunc + errorHandlers transport.ErrorHandlers - enhancement struct { - config enhConfig + connection struct { + config connConfig } heartbeat struct { enable bool - config hbConfig + config heartbeatConfig } } diff --git a/transport/ws/client_server_funcs.go b/transport/ws/client_server_funcs.go index fe098b8..932bc7c 100644 --- a/transport/ws/client_server_funcs.go +++ b/transport/ws/client_server_funcs.go @@ -11,9 +11,13 @@ type Closure func(context.Context, error) (code CloseCode, msg string, deadline type Pinging func(context.Context) (msg []byte, deadline time.Time) type DiallerFunc func(context.Context, Dialler, http.Header) context.Context - type UpgradeFunc func(context.Context, Upgrader, *http.Request, http.Header) context.Context -type ClientTunerFunc func(context.Context, *http.Response, Tuner) context.Context +type Tuner interface { + EnableWriteCompression(enable bool) + SetCompressionLevel(level int) error + SetReadLimit(limit int64) +} +type ClientTunerFunc func(context.Context, *http.Response, Tuner) context.Context type ServerTunerFunc func(context.Context, Tuner) context.Context diff --git a/transport/ws/connection.go b/transport/ws/connection.go index b4e2f26..3ab48b0 100644 --- a/transport/ws/connection.go +++ b/transport/ws/connection.go @@ -1,7 +1,6 @@ package ws import ( - "io" "time" "github.com/fasthttp/websocket" @@ -14,39 +13,7 @@ const ( WriteModPrepared ) -type connection interface { - Tuner - rw - controller - closer -} - -type Tuner interface { - EnableWriteCompression(enable bool) - SetCompressionLevel(level int) error - SetReadLimit(limit int64) -} - -type rw interface { - SetReadDeadline(t time.Time) error - ReadMessage() (messageType int, p []byte, err error) - SetWriteDeadline(t time.Time) error - WriteMessage(messageType int, data []byte) error - WritePreparedMessage(pm *websocket.PreparedMessage) error - NextWriter(messageType int) (io.WriteCloser, error) -} - -type controller interface { - WriteControl(messageType int, data []byte, deadline time.Time) error - PongHandler() func(string) error - SetPongHandler(h func(string) error) -} - -type closer interface { - Close() error -} - -type enhConfig struct { +type connConfig struct { read struct { timeout time.Duration } @@ -57,64 +24,71 @@ type enhConfig struct { } } -type enhConn struct { - connection - cfg enhConfig +type conn struct { + ws *websocket.Conn + config connConfig } -func (c enhConn) ReadMessage() (messageType int, p []byte, err error) { - if err := c.updateReadDeadline(c.connection); err != nil { +var ( + _ Tuner = (*conn)(nil) +) + +func (c *conn) EnableWriteCompression(enable bool) { c.ws.EnableWriteCompression(enable) } +func (c *conn) SetCompressionLevel(level int) error { return c.ws.SetCompressionLevel(level) } +func (c *conn) SetReadLimit(limit int64) { c.ws.SetReadLimit(limit) } + +func (c *conn) PongHandler() func(string) error { return c.ws.PongHandler() } +func (c *conn) SetPongHandler(h func(string) error) { c.ws.SetPongHandler(h) } + +func (c *conn) ReadMessage() (messageType int, p []byte, err error) { + if err := c.updateReadDeadline(c.ws); err != nil { return 0, nil, err } - return c.connection.ReadMessage() + return c.ws.ReadMessage() } -func (c enhConn) WriteMessage(messageType int, data []byte) error { - if err := c.updateWriteDeadline(c.connection); err != nil { - return err - } - switch c.cfg.write.mod { - case WriteModPlain: - return c.writePlain(messageType, data) - case WriteModPrepared: - return c.writePrepared(messageType, data) - default: - return c.connection.WriteMessage(messageType, data) +func (c *conn) updateWriteDeadline(conn *websocket.Conn) error { + if c.config.write.timeout > 0 { + deadline := time.Now().Add(c.config.write.timeout) + return conn.SetWriteDeadline(deadline) } + return nil } -func (c enhConn) writePlain(messageType int, data []byte) error { - w, err := c.connection.NextWriter(messageType) - if err != nil { +func (c *conn) WriteMessage(messageType int, data []byte) error { + if err := c.updateWriteDeadline(c.ws); err != nil { return err } - _, err = w.Write(data) - if err != nil { - return err - } - return w.Close() -} - -func (c enhConn) writePrepared(messageType int, data []byte) error { - pm, err := websocket.NewPreparedMessage(messageType, data) - if err != nil { + switch c.config.write.mod { + case WriteModPlain: + w, err := c.ws.NextWriter(messageType) + if err != nil { + return err + } + defer func() { _ = w.Close() }() + _, err = w.Write(data) return err + case WriteModPrepared: + pm, err := websocket.NewPreparedMessage(messageType, data) + if err != nil { + return err + } + return c.ws.WritePreparedMessage(pm) + default: + return c.ws.WriteMessage(messageType, data) } - return c.connection.WritePreparedMessage(pm) } -func (c enhConn) updateWriteDeadline(conn connection) error { - if c.cfg.write.timeout > 0 { - deadline := time.Now().Add(c.cfg.write.timeout) - return conn.SetWriteDeadline(deadline) +func (c *conn) updateReadDeadline(conn *websocket.Conn) error { + if c.config.read.timeout > 0 { + deadline := time.Now().Add(c.config.read.timeout) + return conn.SetReadDeadline(deadline) } return nil } -func (c enhConn) updateReadDeadline(conn connection) error { - if c.cfg.read.timeout > 0 { - deadline := time.Now().Add(c.cfg.read.timeout) - return conn.SetReadDeadline(deadline) - } - return nil +func (c *conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + return c.ws.WriteControl(messageType, data, deadline) } + +func (c *conn) Close() error { return c.ws.Close() } diff --git a/transport/ws/dialler.go b/transport/ws/dialler.go index f7b2dbf..fc6eb8a 100644 --- a/transport/ws/dialler.go +++ b/transport/ws/dialler.go @@ -11,70 +11,40 @@ import ( "github.com/fasthttp/websocket" ) -//nolint:interfacebloat +type Dial = func(network, addr string) (net.Conn, error) +type DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) +type Proxy = func(*http.Request) (*url.URL, error) + type Dialler interface { - SetNetDial(dial func(network, addr string) (net.Conn, error)) - SetNetDialContext(func(ctx context.Context, network, addr string) (net.Conn, error)) - SetNetDialTLSContext(func(ctx context.Context, network, addr string) (net.Conn, error)) - SetProxy(func(*http.Request) (*url.URL, error)) + SetNetDial(Dial) + SetNetDialContext(DialContext) + SetNetDialTLSContext(DialContext) + SetProxy(Proxy) SetTLSClientConfig(*tls.Config) SetHandshakeTimeout(time.Duration) SetReadBufferSize(int) SetWriteBufferSize(int) - SetWriteBufferPool(websocket.BufferPool) + SetWriteBufferPool(BufferPool) SetSubprotocols([]string) SetEnableCompression(bool) SetJar(http.CookieJar) } type dialler struct { - *websocket.Dialer -} - -func (d dialler) SetNetDial(dial func(network, addr string) (net.Conn, error)) { - d.NetDial = dial -} - -func (d dialler) SetNetDialContext(dial func(ctx context.Context, network, addr string) (net.Conn, error)) { - d.NetDialContext = dial -} - -func (d dialler) SetNetDialTLSContext(dial func(ctx context.Context, network, addr string) (net.Conn, error)) { - d.NetDialTLSContext = dial -} - -func (d dialler) SetProxy(proxy func(*http.Request) (*url.URL, error)) { - d.Proxy = proxy -} - -func (d dialler) SetTLSClientConfig(cfg *tls.Config) { - d.TLSClientConfig = cfg -} - -func (d dialler) SetHandshakeTimeout(timeout time.Duration) { - d.HandshakeTimeout = timeout -} - -func (d dialler) SetReadBufferSize(size int) { - d.ReadBufferSize = size -} - -func (d dialler) SetWriteBufferSize(size int) { - d.WriteBufferSize = size -} - -func (d dialler) SetWriteBufferPool(pool websocket.BufferPool) { - d.WriteBufferPool = pool -} - -func (d dialler) SetSubprotocols(protocols []string) { - d.Subprotocols = protocols -} - -func (d dialler) SetEnableCompression(enable bool) { - d.EnableCompression = enable -} - -func (d dialler) SetJar(jar http.CookieJar) { - d.Jar = jar -} + ws *websocket.Dialer +} + +var _ Dialler = (*dialler)(nil) + +func (d *dialler) SetNetDial(dial Dial) { d.ws.NetDial = dial } +func (d *dialler) SetNetDialContext(dial DialContext) { d.ws.NetDialContext = dial } +func (d *dialler) SetNetDialTLSContext(dial DialContext) { d.ws.NetDialTLSContext = dial } +func (d *dialler) SetProxy(proxy Proxy) { d.ws.Proxy = proxy } +func (d *dialler) SetTLSClientConfig(cfg *tls.Config) { d.ws.TLSClientConfig = cfg } +func (d *dialler) SetHandshakeTimeout(timeout time.Duration) { d.ws.HandshakeTimeout = timeout } +func (d *dialler) SetReadBufferSize(size int) { d.ws.ReadBufferSize = size } +func (d *dialler) SetWriteBufferSize(size int) { d.ws.WriteBufferSize = size } +func (d *dialler) SetWriteBufferPool(pool BufferPool) { d.ws.WriteBufferPool = pool } +func (d *dialler) SetSubprotocols(protocols []string) { d.ws.Subprotocols = protocols } +func (d *dialler) SetEnableCompression(enable bool) { d.ws.EnableCompression = enable } +func (d *dialler) SetJar(jar http.CookieJar) { d.ws.Jar = jar } diff --git a/transport/ws/heartbeat.go b/transport/ws/heartbeat.go index 22c7193..72cc000 100644 --- a/transport/ws/heartbeat.go +++ b/transport/ws/heartbeat.go @@ -10,23 +10,22 @@ import ( "github.com/fasthttp/websocket" ) -type hbConfig struct { +type heartbeatConfig struct { period, await time.Duration pinging Pinging } func heartbeat( ctx context.Context, - cfg hbConfig, - conn controller, - done <-chan struct{}, + cfg heartbeatConfig, + conn *conn, ) error { - pongCh := make(chan struct{}) + pong := make(chan struct{}) handler := conn.PongHandler() conn.SetPongHandler(func(msg string) error { select { - case pongCh <- struct{}{}: - case <-done: + case pong <- struct{}{}: + case <-ctx.Done(): } return handler(msg) }) @@ -35,28 +34,27 @@ func heartbeat( defer ticker.Stop() for { select { - case <-done: + case <-ctx.Done(): return nil case <-ticker.C: msg, deadline := cfg.pinging(ctx) err := conn.WriteControl(websocket.PingMessage, msg, deadline) switch { - case errors.Is(err, net.ErrClosed): - return nil - case errors.Is(err, syscall.EPIPE): // broken pipe can appear on closed underlying tcp connection by peer - return nil - case errors.Is(err, websocket.ErrCloseSent): + case + errors.Is(err, net.ErrClosed), + errors.Is(err, syscall.EPIPE), // broken pipe can appear on closed underlying TCP connection by peer + errors.Is(err, websocket.ErrCloseSent): return nil case err != nil: return err } } select { - case <-done: + case <-ctx.Done(): return nil case <-time.After(cfg.await): return context.DeadlineExceeded - case <-pongCh: + case <-pong: ticker.Reset(cfg.period) } } diff --git a/transport/ws/intercepting_writer.go b/transport/ws/intercepting_writer.go index 4b21a27..35a01cc 100644 --- a/transport/ws/intercepting_writer.go +++ b/transport/ws/intercepting_writer.go @@ -1,6 +1,8 @@ // Package ws is partially based on the http package from go-kit. // The code in this file is based on code from http package of go-kit project (https://github.com/go-kit/kit/blob/master/transport/http/intercepting_writer.go). // All rights of the code belong to the original author. +// +//nolint:all package ws import ( diff --git a/transport/ws/option.go b/transport/ws/option.go index 428a3be..782c2a9 100644 --- a/transport/ws/option.go +++ b/transport/ws/option.go @@ -31,25 +31,25 @@ func WithClientFinalizer(finalizer kithttp.ClientFinalizerFunc) ClientOption { func WithClientErrorHandler(handler transport.ErrorHandler) ClientOption { return funcClientOption{f: func(o *clientOptions) { - o.errHandlers = append(o.errHandlers, handler) + o.errorHandlers = append(o.errorHandlers, handler) }} } func WithClientWriteTimeout(timeout time.Duration) ClientOption { return funcClientOption{f: func(o *clientOptions) { - o.enhancement.config.write.timeout = timeout + o.connection.config.write.timeout = timeout }} } func WithClientWriteMod(mod WriteMod) ClientOption { return funcClientOption{f: func(o *clientOptions) { - o.enhancement.config.write.mod = mod + o.connection.config.write.mod = mod }} } func WithClientReadTimeout(timeout time.Duration) ClientOption { return funcClientOption{f: func(o *clientOptions) { - o.enhancement.config.read.timeout = timeout + o.connection.config.read.timeout = timeout }} } @@ -88,25 +88,25 @@ func WithServerFinalizer(finalizer kithttp.ServerFinalizerFunc) ServerOption { func WithServerErrorHandler(handler transport.ErrorHandler) ServerOption { return funcServerOption{f: func(o *serverOptions) { - o.errHandlers = append(o.errHandlers, handler) + o.errorHandlers = append(o.errorHandlers, handler) }} } func WithServerReadTimeout(timeout time.Duration) ServerOption { return funcServerOption{f: func(o *serverOptions) { - o.enhancement.config.read.timeout = timeout + o.connection.config.read.timeout = timeout }} } func WithServerWriteTimeout(timeout time.Duration) ServerOption { return funcServerOption{f: func(o *serverOptions) { - o.enhancement.config.write.timeout = timeout + o.connection.config.write.timeout = timeout }} } func WithServerWriteMod(mod WriteMod) ServerOption { return funcServerOption{f: func(o *serverOptions) { - o.enhancement.config.write.mod = mod + o.connection.config.write.mod = mod }} } diff --git a/transport/ws/server.go b/transport/ws/server.go index c2e26e9..2f3656f 100644 --- a/transport/ws/server.go +++ b/transport/ws/server.go @@ -8,8 +8,8 @@ import ( "sync" "syscall" + "github.com/einouqo/ext-kit/transport" "github.com/fasthttp/websocket" - "github.com/go-kit/kit/transport" kithttp "github.com/go-kit/kit/transport/http" "golang.org/x/sync/errgroup" @@ -61,15 +61,13 @@ func (s *Server[IN, OUT]) ServeHTTP(w http.ResponseWriter, r *http.Request) { err := s.serve(ctx, w, r) if err != nil { - for _, h := range s.opts.errHandlers { - h.Handle(ctx, err) - } + s.opts.errorHandlers.Handle(ctx, err) } } func (s *Server[IN, OUT]) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) { headers := make(http.Header) - upg := upgrader{new(websocket.Upgrader)} + upg := new(upgrader) for _, f := range s.opts.before { ctx = f(ctx, upg, r, headers) } @@ -78,31 +76,38 @@ func (s *Server[IN, OUT]) serve(ctx context.Context, w http.ResponseWriter, r *h if err != nil { return err } - defer wsc.Close() + conn := &conn{ + ws: wsc, + config: s.opts.connection.config, + } + defer func() { _ = conn.Close() }() for _, f := range s.opts.after { - ctx = f(ctx, wsc) + ctx = f(ctx, conn) } - conn := enhConn{wsc, s.opts.enhancement.config} + ctx, cancel := context.WithCancel(ctx) + defer cancel() - inCh := make(chan IN) - receive, err := s.e(ctx, inCh) + ins := make(chan IN) + receive, err := s.e(ctx, ins) if err != nil { return err } once := sync.Once{} - doneCh := make(chan struct{}) - group := errgroup.Group{} - group.Go(func() (err error) { - defer close(inCh) - defer conn.Close() - defer once.Do(func() { + leave := func(err error) { + once.Do(func() { code, msg, deadline := s.closure(ctx, err) data := websocket.FormatCloseMessage(code.fastsocket(), msg) _ = conn.WriteControl(websocket.CloseMessage, data, deadline) }) + } + + group := errgroup.Group{} + group.Go(func() (err error) { + defer func() { leave(err) }() + defer close(ins) for { messageType, msg, err := conn.ReadMessage() switch { @@ -117,19 +122,21 @@ func (s *Server[IN, OUT]) serve(ctx context.Context, w http.ResponseWriter, r *h return err } select { - case <-doneCh: + case <-ctx.Done(): + return nil + default: + // context is not done, continue + } + select { + case <-ctx.Done(): return nil - case inCh <- in: + case ins <- in: } } }) group.Go(func() (err error) { - defer close(doneCh) - defer once.Do(func() { - code, msg, deadline := s.closure(ctx, err) - data := websocket.FormatCloseMessage(code.fastsocket(), msg) - _ = conn.WriteControl(websocket.CloseMessage, data, deadline) - }) + defer func() { leave(err) }() + defer cancel() for { out, err := receive() switch { @@ -144,11 +151,10 @@ func (s *Server[IN, OUT]) serve(ctx context.Context, w http.ResponseWriter, r *h } err = conn.WriteMessage(mt.fastsocket(), msg) switch { - case errors.Is(err, net.ErrClosed): - return nil - case errors.Is(err, syscall.EPIPE): // broken pipe can appear on closed underlying tcp connection by peer - return nil - case errors.Is(err, websocket.ErrCloseSent): + case + errors.Is(err, net.ErrClosed), + errors.Is(err, syscall.EPIPE), // broken pipe can appear on closed underlying TCP connection by peer + errors.Is(err, websocket.ErrCloseSent): return nil case err != nil: return err @@ -157,31 +163,27 @@ func (s *Server[IN, OUT]) serve(ctx context.Context, w http.ResponseWriter, r *h }) if s.opts.heartbeat.enable { - group.Go(func() error { - defer conn.Close() - return heartbeat(ctx, s.opts.heartbeat.config, conn, doneCh) + group.Go(func() (err error) { + defer func() { leave(err) }() + return heartbeat(ctx, s.opts.heartbeat.config, conn) }) } - if err := group.Wait(); err != nil { - return err - } - - return nil + return group.Wait() } type serverOptions struct { - before []UpgradeFunc - after []ServerTunerFunc - finalizer []kithttp.ServerFinalizerFunc - errHandlers []transport.ErrorHandler + before []UpgradeFunc + after []ServerTunerFunc + finalizer []kithttp.ServerFinalizerFunc + errorHandlers transport.ErrorHandlers - enhancement struct { - config enhConfig + connection struct { + config connConfig } heartbeat struct { enable bool - config hbConfig + config heartbeatConfig } } diff --git a/transport/ws/upgrader.go b/transport/ws/upgrader.go index d21f31d..d604eb8 100644 --- a/transport/ws/upgrader.go +++ b/transport/ws/upgrader.go @@ -7,49 +7,35 @@ import ( "github.com/fasthttp/websocket" ) +type ErrorWriter = func(w http.ResponseWriter, r *http.Request, status int, reason error) +type CheckOrigin = func(r *http.Request) bool + type Upgrader interface { - SetHandshakeTimeout(timeout time.Duration) - SetReadBufferSize(size int) - SetWriteBufferSize(size int) - SetWriteBufferPool(pool websocket.BufferPool) + SetHandshakeTimeout(time.Duration) + SetReadBufferSize(int) + SetWriteBufferSize(int) + SetWriteBufferPool(BufferPool) SetSubprotocols(protocols []string) - SetErrorWriter(writer func(w http.ResponseWriter, r *http.Request, status int, reason error)) - SetCheckOrigin(check func(r *http.Request) bool) + SetErrorWriter(ErrorWriter) + SetCheckOrigin(check CheckOrigin) SetEnableCompression(enabled bool) } type upgrader struct { - *websocket.Upgrader -} - -func (u upgrader) SetHandshakeTimeout(timeout time.Duration) { - u.HandshakeTimeout = timeout + ws websocket.Upgrader } -func (u upgrader) SetReadBufferSize(size int) { - u.ReadBufferSize = size -} +var _ Upgrader = (*upgrader)(nil) -func (u upgrader) SetWriteBufferSize(size int) { - u.WriteBufferSize = size -} - -func (u upgrader) SetWriteBufferPool(pool websocket.BufferPool) { - u.WriteBufferPool = pool -} - -func (u upgrader) SetSubprotocols(protocols []string) { - u.Subprotocols = protocols -} - -func (u upgrader) SetErrorWriter(writer func(w http.ResponseWriter, r *http.Request, status int, reason error)) { - u.Error = writer -} - -func (u upgrader) SetCheckOrigin(check func(r *http.Request) bool) { - u.CheckOrigin = check -} +func (u *upgrader) SetHandshakeTimeout(timeout time.Duration) { u.ws.HandshakeTimeout = timeout } +func (u *upgrader) SetReadBufferSize(size int) { u.ws.ReadBufferSize = size } +func (u *upgrader) SetWriteBufferSize(size int) { u.ws.WriteBufferSize = size } +func (u *upgrader) SetWriteBufferPool(pool BufferPool) { u.ws.WriteBufferPool = pool } +func (u *upgrader) SetSubprotocols(protocols []string) { u.ws.Subprotocols = protocols } +func (u *upgrader) SetErrorWriter(w ErrorWriter) { u.ws.Error = w } +func (u *upgrader) SetCheckOrigin(check CheckOrigin) { u.ws.CheckOrigin = check } +func (u *upgrader) SetEnableCompression(enabled bool) { u.ws.EnableCompression = enabled } -func (u upgrader) SetEnableCompression(enabled bool) { - u.EnableCompression = enabled +func (u *upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) { + return u.ws.Upgrade(w, r, responseHeader) }