diff --git a/example/graphql/main.go b/example/graphql/main.go index efc3f33..6343df3 100644 --- a/example/graphql/main.go +++ b/example/graphql/main.go @@ -36,7 +36,7 @@ type Query { } graphQLService := service. - NewGraphQLBuilder("Example"). + NewGraphQLBuilder("Example", nil). Schema(schema). Resolver(&res). Build() diff --git a/example/grpc/server/main.go b/example/grpc/server/main.go index 31560f0..181f6ee 100644 --- a/example/grpc/server/main.go +++ b/example/grpc/server/main.go @@ -22,7 +22,7 @@ func (h helloServer) Hello(ctx context.Context, request *proto.HelloRequest) (*p func main() { gRPCService, err := service. - NewGRPCBuilder("Example"). + NewGRPCBuilder("Example", nil). RegisterHandler(func(server *grpc.Server) { hs := helloServer{} proto.RegisterHelloServer(server, hs) diff --git a/example/routing/main.go b/example/routing/main.go index 123a8ef..bf82ebb 100644 --- a/example/routing/main.go +++ b/example/routing/main.go @@ -23,7 +23,7 @@ func main() { } routingService := service. - NewRoutingBuilder("Example"). + NewRoutingBuilder("Example", nil). Routes(routes). Build() routingService.StartAndWait(8080) diff --git a/fw/service/graphql.go b/fw/service/graphql.go index 2116e53..ba53318 100644 --- a/fw/service/graphql.go +++ b/fw/service/graphql.go @@ -1,6 +1,7 @@ package service import ( + "context" "fmt" "github.com/short-d/app/fw/graphql" @@ -14,6 +15,7 @@ type GraphQL struct { logger logger.Logger graphQLPath string webServer *web.Server + onShutdown func() } func (g GraphQL) StartAsync(port int) { @@ -29,10 +31,16 @@ func (g GraphQL) StartAsync(port int) { }() } -func (g GraphQL) Stop() { +func (g GraphQL) Stop(ctx context.Context, cancel context.CancelFunc) { defer g.logger.Info("GraphQL service stopped") + defer func() { + if g.onShutdown != nil { + g.onShutdown() + } + cancel() + }() - err := g.webServer.Shutdown() + err := g.webServer.Shutdown(ctx) if err != nil { g.logger.Error(err) } @@ -40,13 +48,15 @@ func (g GraphQL) Stop() { func (g GraphQL) StartAndWait(port int) { g.StartAsync(port) - select {} + + listenForSignals(g) } func NewGraphQL( logger logger.Logger, graphQLPath string, handler graphql.Handler, + onShutdown func(), ) GraphQL { server := web.NewServer(logger) server.HandleFunc(graphQLPath, handler) @@ -55,13 +65,15 @@ func NewGraphQL( logger: logger, graphQLPath: graphQLPath, webServer: &server, + onShutdown: onShutdown, } } type GraphQLBuilder struct { - logger logger.Logger - schema string - resolver graphql.Resolver + logger logger.Logger + schema string + resolver graphql.Resolver + onShutdown func() } func (g *GraphQLBuilder) Schema(schema string) *GraphQLBuilder { @@ -80,14 +92,15 @@ func (g GraphQLBuilder) Build() GraphQL { Resolver: g.resolver, } handler := graphql.NewGraphGopherHandler(api) - return NewGraphQL(g.logger, "/graphql", handler) + return NewGraphQL(g.logger, "/graphql", handler, g.onShutdown) } -func NewGraphQLBuilder(name string) *GraphQLBuilder { +func NewGraphQLBuilder(name string, onShutdown func()) *GraphQLBuilder { lg := newDefaultLogger(name) return &GraphQLBuilder{ - logger: lg, - schema: "", - resolver: nil, + logger: lg, + schema: "", + resolver: nil, + onShutdown: onShutdown, } } diff --git a/fw/service/grpc.go b/fw/service/grpc.go index 23337d6..dc830a3 100644 --- a/fw/service/grpc.go +++ b/fw/service/grpc.go @@ -1,6 +1,7 @@ package service import ( + "context" "fmt" "net" @@ -17,10 +18,19 @@ type GRPC struct { gRPCServer *grpc.Server gRPCApi rpc.API logger logger.Logger + onShutdown func() } -func (g GRPC) Stop() { - g.gRPCServer.Stop() +func (g GRPC) Stop(ctx context.Context, cancel context.CancelFunc) { + defer g.logger.Info("gRPC service stopped") + defer func() { + if g.onShutdown != nil { + g.onShutdown() + } + cancel() + }() + + g.gRPCServer.GracefulStop() } func (g GRPC) StartAsync(port int) { @@ -42,13 +52,15 @@ func (g GRPC) StartAsync(port int) { func (g GRPC) StartAndWait(port int) { g.StartAsync(port) - select {} + + listenForSignals(g) } func NewGRPC( logger logger.Logger, rpcAPI rpc.API, securityPolicy security.Policy, + onShutdown func(), ) (GRPC, error) { server := grpc.NewServer() if !securityPolicy.IsEncrypted { @@ -71,6 +83,7 @@ func NewGRPC( gRPCServer: grpc.NewServer(grpc.Creds(cred)), gRPCApi: rpcAPI, logger: logger, + onShutdown: onShutdown, }, nil } @@ -92,6 +105,7 @@ type GRPCBuilder struct { certPath string keyPath string registerHandler registerHandler + onShutdown func() } func (g *GRPCBuilder) EnableTLS(certPath string, keyPath string) *GRPCBuilder { @@ -113,10 +127,10 @@ func (g *GRPCBuilder) Build() (GRPC, error) { CertificateFilePath: g.certPath, KeyFilePath: g.keyPath, } - return NewGRPC(g.logger, rpcAPI, policy) + return NewGRPC(g.logger, rpcAPI, policy, g.onShutdown) } -func NewGRPCBuilder(name string) *GRPCBuilder { +func NewGRPCBuilder(name string, onShutdown func()) *GRPCBuilder { lg := newDefaultLogger(name) builder := GRPCBuilder{ logger: lg, @@ -124,6 +138,7 @@ func NewGRPCBuilder(name string) *GRPCBuilder { certPath: "", keyPath: "", registerHandler: func(server *grpc.Server) {}, + onShutdown: onShutdown, } return &builder } diff --git a/fw/service/routing.go b/fw/service/routing.go index 3ddf37a..6285eae 100644 --- a/fw/service/routing.go +++ b/fw/service/routing.go @@ -1,6 +1,7 @@ package service import ( + "context" "fmt" "github.com/short-d/app/fw/logger" @@ -11,8 +12,9 @@ import ( var _ Service = (*Routing)(nil) type Routing struct { - logger logger.Logger - webServer *web.Server + logger logger.Logger + webServer *web.Server + onShutdown func() } func (r Routing) StartAsync(port int) { @@ -28,10 +30,16 @@ func (r Routing) StartAsync(port int) { }() } -func (r Routing) Stop() { +func (r Routing) Stop(ctx context.Context, cancel context.CancelFunc) { defer r.logger.Info("Routing service stopped") + defer func() { + if r.onShutdown != nil { + r.onShutdown() + } + cancel() + }() - err := r.webServer.Shutdown() + err := r.webServer.Shutdown(ctx) if err != nil { r.logger.Error(err) } @@ -39,10 +47,11 @@ func (r Routing) Stop() { func (r Routing) StartAndWait(port int) { r.StartAsync(port) - select {} + + listenForSignals(r) } -func NewRouting(logger logger.Logger, routes []router.Route) Routing { +func NewRouting(logger logger.Logger, routes []router.Route, onShutdown func()) Routing { httpRouter := router.NewHTTPHandler() for _, route := range routes { @@ -61,14 +70,16 @@ func NewRouting(logger logger.Logger, routes []router.Route) Routing { server.HandleFunc("/", &httpRouter) return Routing{ - logger: logger, - webServer: &server, + logger: logger, + webServer: &server, + onShutdown: onShutdown, } } type RoutingBuilder struct { - logger logger.Logger - routes []router.Route + logger logger.Logger + routes []router.Route + onShutdown func() } func (r *RoutingBuilder) Routes(routes []router.Route) *RoutingBuilder { @@ -77,13 +88,14 @@ func (r *RoutingBuilder) Routes(routes []router.Route) *RoutingBuilder { } func (r RoutingBuilder) Build() Routing { - return NewRouting(r.logger, r.routes) + return NewRouting(r.logger, r.routes, r.onShutdown) } -func NewRoutingBuilder(name string) *RoutingBuilder { +func NewRoutingBuilder(name string, onShutdown func()) *RoutingBuilder { lg := newDefaultLogger(name) return &RoutingBuilder{ - logger: lg, - routes: make([]router.Route, 0), + logger: lg, + routes: make([]router.Route, 0), + onShutdown: onShutdown, } } diff --git a/fw/service/service.go b/fw/service/service.go index a7214e0..86c5c68 100644 --- a/fw/service/service.go +++ b/fw/service/service.go @@ -1,8 +1,24 @@ package service -// TODO(issue#67): support graceful shutdown. +import ( + "context" + "os" + "os/signal" + "syscall" + "time" +) + type Service interface { - Stop() + Stop(ctx context.Context, cancel context.CancelFunc) StartAsync(port int) StartAndWait(port int) } + +func listenForSignals(s Service) { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + <-sigCh + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + s.Stop(ctx, cancel) +} diff --git a/fw/web/server.go b/fw/web/server.go index 642c1f2..68eb8e5 100644 --- a/fw/web/server.go +++ b/fw/web/server.go @@ -28,8 +28,8 @@ func (s *Server) ListenAndServe(port int) error { return err } -func (s Server) Shutdown() error { - return s.server.Shutdown(context.Background()) +func (s Server) Shutdown(ctx context.Context) error { + return s.server.Shutdown(ctx) } func (s Server) HandleFunc(pattern string, handler http.Handler) {