diff --git a/modkit/kernel/bootstrap.go b/modkit/kernel/bootstrap.go index cfd19f9..53cbb31 100644 --- a/modkit/kernel/bootstrap.go +++ b/modkit/kernel/bootstrap.go @@ -2,6 +2,7 @@ package kernel import ( "context" + "io" "github.com/go-modkit/modkit/modkit/module" ) @@ -73,3 +74,19 @@ func (a *App) Get(token module.Token) (any, error) { func (a *App) CleanupHooks() []func(context.Context) error { return a.container.cleanupHooksLIFO() } + +// Closers returns provider closers in build order. +func (a *App) Closers() []io.Closer { + return a.container.closersInBuildOrder() +} + +// Close calls Close on all io.Closer providers in reverse build order. +func (a *App) Close() error { + var firstErr error + for _, closer := range a.container.closersLIFO() { + if err := closer.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} diff --git a/modkit/kernel/container.go b/modkit/kernel/container.go index ea6603d..36c705e 100644 --- a/modkit/kernel/container.go +++ b/modkit/kernel/container.go @@ -2,6 +2,7 @@ package kernel import ( "context" + "io" "sync" "github.com/go-modkit/modkit/modkit/module" @@ -20,6 +21,8 @@ type Container struct { locks map[module.Token]*sync.Mutex waitingOn map[module.Token]module.Token cleanupHooks []func(context.Context) error + closers []io.Closer + buildOrder []module.Token mu sync.Mutex } @@ -48,6 +51,8 @@ func newContainer(graph *Graph, visibility Visibility) (*Container, error) { locks: make(map[module.Token]*sync.Mutex), waitingOn: make(map[module.Token]module.Token), cleanupHooks: make([]func(context.Context) error, 0), + closers: make([]io.Closer, 0), + buildOrder: make([]module.Token, 0), }, nil } @@ -110,6 +115,10 @@ func (c *Container) getWithStack(token module.Token, requester string, stack []m if entry.cleanup != nil { c.cleanupHooks = append(c.cleanupHooks, entry.cleanup) } + if closer, ok := instance.(io.Closer); ok { + c.closers = append(c.closers, closer) + } + c.buildOrder = append(c.buildOrder, token) c.mu.Unlock() return instance, nil } @@ -125,6 +134,35 @@ func (c *Container) cleanupHooksLIFO() []func(context.Context) error { return hooks } +func (c *Container) closersLIFO() []io.Closer { + c.mu.Lock() + defer c.mu.Unlock() + + closers := make([]io.Closer, len(c.closers)) + for i, closer := range c.closers { + closers[len(c.closers)-1-i] = closer + } + return closers +} + +func (c *Container) closersInBuildOrder() []io.Closer { + c.mu.Lock() + defer c.mu.Unlock() + + closers := make([]io.Closer, len(c.closers)) + copy(closers, c.closers) + return closers +} + +func (c *Container) providerBuildOrder() []module.Token { + c.mu.Lock() + defer c.mu.Unlock() + + order := make([]module.Token, len(c.buildOrder)) + copy(order, c.buildOrder) + return order +} + type moduleResolver struct { container *Container moduleName string diff --git a/modkit/kernel/container_internal_test.go b/modkit/kernel/container_internal_test.go new file mode 100644 index 0000000..2c2c767 --- /dev/null +++ b/modkit/kernel/container_internal_test.go @@ -0,0 +1,132 @@ +package kernel + +import ( + "testing" + + "github.com/go-modkit/modkit/modkit/module" +) + +type modHelperInternal struct { + def module.ModuleDef +} + +func (m *modHelperInternal) Definition() module.ModuleDef { + return m.def +} + +func modInternal( + name string, + imports []module.Module, + providers []module.ProviderDef, + controllers []module.ControllerDef, + exports []module.Token, +) module.Module { + return &modHelperInternal{ + def: module.ModuleDef{ + Name: name, + Imports: imports, + Providers: providers, + Controllers: controllers, + Exports: exports, + }, + } +} + +func TestContainerRecordsProviderBuildOrder(t *testing.T) { + first := module.Token("provider.first") + second := module.Token("provider.second") + + modA := modInternal("A", nil, + []module.ProviderDef{{ + Token: first, + Build: func(r module.Resolver) (any, error) { + return "first", nil + }, + }, { + Token: second, + Build: func(r module.Resolver) (any, error) { + return "second", nil + }, + }}, + nil, + nil, + ) + + app, err := Bootstrap(modA) + if err != nil { + t.Fatalf("Bootstrap failed: %v", err) + } + + if _, err := app.Get(second); err != nil { + t.Fatalf("Get second failed: %v", err) + } + if _, err := app.Get(first); err != nil { + t.Fatalf("Get first failed: %v", err) + } + + order := app.container.providerBuildOrder() + if len(order) != 2 { + t.Fatalf("expected 2 providers, got %d", len(order)) + } + if order[0] != second || order[1] != first { + t.Fatalf("unexpected order: %v", order) + } +} + +func TestContainerRecordsClosersInBuildOrder(t *testing.T) { + closerA := module.Token("closer.a") + closerB := module.Token("closer.b") + + modA := modInternal("A", nil, + []module.ProviderDef{{ + Token: closerA, + Build: func(r module.Resolver) (any, error) { + return &testCloser{name: "a"}, nil + }, + }, { + Token: closerB, + Build: func(r module.Resolver) (any, error) { + return &testCloser{name: "b"}, nil + }, + }}, + nil, + nil, + ) + + app, err := Bootstrap(modA) + if err != nil { + t.Fatalf("Bootstrap failed: %v", err) + } + + _, _ = app.Get(closerA) + _, _ = app.Get(closerB) + + closers := app.container.closersInBuildOrder() + if len(closers) != 2 { + t.Fatalf("expected 2 closers, got %d", len(closers)) + } + + first, ok := closers[0].(*testCloser) + if !ok { + t.Fatalf("expected *testCloser, got %T", closers[0]) + } + second, ok := closers[1].(*testCloser) + if !ok { + t.Fatalf("expected *testCloser, got %T", closers[1]) + } + if first.Name() != "a" || second.Name() != "b" { + t.Fatalf("unexpected order: %v", closers) + } +} + +type testCloser struct { + name string +} + +func (c *testCloser) Close() error { + return nil +} + +func (c *testCloser) Name() string { + return c.name +} diff --git a/modkit/kernel/container_test.go b/modkit/kernel/container_test.go index b6fbb72..3f20146 100644 --- a/modkit/kernel/container_test.go +++ b/modkit/kernel/container_test.go @@ -3,6 +3,7 @@ package kernel_test import ( "context" "errors" + "reflect" "sync" "sync/atomic" "testing" @@ -12,6 +13,27 @@ import ( "github.com/go-modkit/modkit/modkit/module" ) +type recordingCloser struct { + name string + closed *[]string +} + +func (c *recordingCloser) Close() error { + *c.closed = append(*c.closed, c.name) + return nil +} + +type erroringCloser struct { + name string + closed *[]string + err error +} + +func (c *erroringCloser) Close() error { + *c.closed = append(*c.closed, c.name) + return c.err +} + func TestAppGetRejectsNotVisibleToken(t *testing.T) { modA := mod("A", nil, nil, nil, nil) @@ -379,3 +401,135 @@ func TestContainerGetRegistersCleanupHooks(t *testing.T) { t.Fatalf("expected cleanup to be called") } } + +func TestAppCloseReverseOrder(t *testing.T) { + var closed []string + modA := mod("A", nil, + []module.ProviderDef{{ + Token: "closer.a", + Build: func(r module.Resolver) (any, error) { + return &recordingCloser{name: "a", closed: &closed}, nil + }, + }, { + Token: "closer.b", + Build: func(r module.Resolver) (any, error) { + return &recordingCloser{name: "b", closed: &closed}, nil + }, + }}, + nil, + nil, + ) + + app, err := kernel.Bootstrap(modA) + if err != nil { + t.Fatalf("Bootstrap failed: %v", err) + } + + if _, err := app.Get("closer.a"); err != nil { + t.Fatalf("Get closer.a failed: %v", err) + } + if _, err := app.Get("closer.b"); err != nil { + t.Fatalf("Get closer.b failed: %v", err) + } + + if err := app.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + if len(closed) != 2 || closed[0] != "b" || closed[1] != "a" { + t.Fatalf("expected reverse close order, got %v", closed) + } +} + +func TestAppCloseOrderWithDependencies(t *testing.T) { + var closed []string + + modA := mod("A", nil, + []module.ProviderDef{{ + Token: "closer.a", + Build: func(r module.Resolver) (any, error) { + return &recordingCloser{name: "a", closed: &closed}, nil + }, + }, { + Token: "closer.b", + Build: func(r module.Resolver) (any, error) { + if _, err := r.Get("closer.a"); err != nil { + return nil, err + } + return &recordingCloser{name: "b", closed: &closed}, nil + }, + }, { + Token: "closer.c", + Build: func(r module.Resolver) (any, error) { + if _, err := r.Get("closer.b"); err != nil { + return nil, err + } + return &recordingCloser{name: "c", closed: &closed}, nil + }, + }}, + nil, + nil, + ) + + app, err := kernel.Bootstrap(modA) + if err != nil { + t.Fatalf("Bootstrap failed: %v", err) + } + + _, _ = app.Get("closer.c") + + if err := app.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + want := []string{"c", "b", "a"} + if !reflect.DeepEqual(closed, want) { + t.Fatalf("expected %v, got %v", want, closed) + } +} + +func TestAppCloseContinuesAfterError(t *testing.T) { + var closed []string + errB := errors.New("close failed") + modA := mod("A", nil, + []module.ProviderDef{{ + Token: "closer.a", + Build: func(r module.Resolver) (any, error) { + return &recordingCloser{name: "a", closed: &closed}, nil + }, + }, { + Token: "closer.b", + Build: func(r module.Resolver) (any, error) { + return &erroringCloser{name: "b", closed: &closed, err: errB}, nil + }, + }, { + Token: "closer.c", + Build: func(r module.Resolver) (any, error) { + return &recordingCloser{name: "c", closed: &closed}, nil + }, + }}, + nil, + nil, + ) + + app, err := kernel.Bootstrap(modA) + if err != nil { + t.Fatalf("Bootstrap failed: %v", err) + } + + if _, err := app.Get("closer.a"); err != nil { + t.Fatalf("Get closer.a failed: %v", err) + } + if _, err := app.Get("closer.b"); err != nil { + t.Fatalf("Get closer.b failed: %v", err) + } + _, _ = app.Get("closer.c") + + if err := app.Close(); !errors.Is(err, errB) { + t.Fatalf("expected error %v, got %v", errB, err) + } + + if len(closed) != 3 || closed[0] != "c" || closed[1] != "b" || closed[2] != "a" { + t.Fatalf("expected reverse close order with all closers, got %v", closed) + } +}