Skip to content

Commit

Permalink
fix(error handling): add a TypeMismatch error when we invoke a servic…
Browse files Browse the repository at this point in the history
…e from an unexecpted type

Fixes #80
  • Loading branch information
samber committed May 13, 2024
1 parent 9a76c0f commit 1998a7a
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 35 deletions.
2 changes: 1 addition & 1 deletion di_explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func ExplainNamedService(scope Injector, name string) (description ExplainServic
ScopeID: serviceScope.ID(),
ScopeName: serviceScope.Name(),
ServiceName: name,
ServiceType: service.getType(),
ServiceType: service.getServiceType(),
ServiceBuildTime: buildTime,
Invoked: invoked,
Dependencies: newExplainServiceDependencies(_i, newEdgeService(_i.ID(), _i.Name(), name), "dependencies"),
Expand Down
4 changes: 4 additions & 0 deletions di_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,10 @@ func TestInvokeNamed(t *testing.T) {
instance2, err2 := InvokeNamed[int](i, "foobar")
is.Nil(err2)
is.EqualValues(42, instance2)

instance3, err3 := InvokeNamed[string](i, "foobar")
is.EqualError(err3, "DI: service found, but type mismatch: invoking `string` but registered `int`")
is.EqualValues("", instance3)
}

func TestMustInvokeNamed(t *testing.T) {
Expand Down
7 changes: 6 additions & 1 deletion invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func invokeByName[T any](i Injector, name string) (T, error) {

service, ok := serviceAny.(Service[T])
if !ok {
return empty[T](), serviceNotFound(injector, ErrServiceNotFound, invokerChain)
return empty[T](), serviceTypeMismatch(inferServiceName[T](), serviceAny.(ServiceAny).getTypeName())
}

injector.RootScope().opts.onBeforeInvocation(serviceScope, name)
Expand Down Expand Up @@ -247,6 +247,11 @@ func serviceNotFound(injector Injector, err error, chain []string) error {
return fmt.Errorf("%w `%s`, available services: %s", err, name, strings.Join(sortedServiceNames, ", "))
}

// serviceTypeMismatch returns an error indicating that the specified service was found, but its type does not match the expected type.
func serviceTypeMismatch(invoking string, registered string) error {
return fmt.Errorf("DI: service found, but type mismatch: invoking `%s` but registered `%s`", invoking, registered)
}

// getServiceNames formats a list of EdgeService names.
func getServiceNames(services []EdgeService) []string {
return mAp(services, func(edge EdgeService, _ int) string {
Expand Down
16 changes: 10 additions & 6 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ var serviceTypeToIcon = map[ServiceType]string{

type Service[T any] interface {
getName() string
getType() ServiceType
getTypeName() string
getServiceType() ServiceType
getEmptyInstance() any
getInstanceAny(Injector) (any, error)
getInstance(Injector) (T, error)
Expand All @@ -43,7 +44,8 @@ type Service[T any] interface {
// Like Service[T] but without the generic type.
type ServiceAny interface {
getName() string
getType() ServiceType
getTypeName() string
getServiceType() ServiceType
getEmptyInstance() any
getInstanceAny(Injector) (any, error)
// getInstance(Injector) (T, error)
Expand All @@ -56,7 +58,8 @@ type ServiceAny interface {
}

type serviceGetName interface{ getName() string }
type serviceGetType interface{ getType() ServiceType }
type serviceGetTypeName interface{ getTypeName() string }
type serviceGetServiceType interface{ getServiceType() ServiceType }
type serviceGetEmptyInstance interface{ getEmptyInstance() any }
type serviceGetInstanceAny interface{ getInstanceAny(Injector) (any, error) }
type serviceGetInstance[T any] interface{ getInstance(Injector) (T, error) } //nolint:unused
Expand All @@ -73,7 +76,8 @@ type serviceBuildTime interface {
}

var _ serviceGetName = (Service[int])(nil)
var _ serviceGetType = (Service[int])(nil)
var _ serviceGetTypeName = (Service[int])(nil)
var _ serviceGetServiceType = (Service[int])(nil)
var _ serviceGetEmptyInstance = (Service[int])(nil)
var _ serviceGetInstanceAny = (Service[int])(nil)
var _ serviceIsHealthchecker = (Service[int])(nil)
Expand All @@ -88,7 +92,7 @@ func inferServiceName[T any]() string {
}

func inferServiceProviderStacktrace(service ServiceAny) (stacktrace.Frame, bool) {
if service.getType() == ServiceTypeTransient {
if service.getServiceType() == ServiceTypeTransient {
return stacktrace.Frame{}, false
} else {
providerFrame, _ := service.source()
Expand All @@ -113,7 +117,7 @@ func inferServiceInfo(injector Injector, name string) (serviceInfo, bool) {

return serviceInfo{
name: name,
serviceType: serviceAny.(serviceGetType).getType(),
serviceType: serviceAny.(serviceGetServiceType).getServiceType(),
serviceBuildTime: buildTime,
healthchecker: serviceAny.(serviceIsHealthchecker).isHealthchecker(),
shutdowner: serviceAny.(serviceIsShutdowner).isShutdowner(),
Expand Down
16 changes: 12 additions & 4 deletions service_alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var _ serviceClone = (*serviceAlias[int, int])(nil)
type serviceAlias[Initial any, Alias any] struct {
mu sync.RWMutex
name string
typeName string
scope Injector
targetName string

Expand All @@ -31,6 +32,7 @@ func newServiceAlias[Initial any, Alias any](name string, scope Injector, target
return &serviceAlias[Initial, Alias]{
mu: sync.RWMutex{},
name: name,
typeName: inferServiceName[Alias](),
scope: scope,
targetName: targetName,

Expand All @@ -44,7 +46,11 @@ func (s *serviceAlias[Initial, Alias]) getName() string {
return s.name
}

func (s *serviceAlias[Initial, Alias]) getType() ServiceType {
func (s *serviceAlias[Initial, Alias]) getTypeName() string {
return s.typeName
}

func (s *serviceAlias[Initial, Alias]) getServiceType() ServiceType {
return ServiceTypeAlias
}

Expand Down Expand Up @@ -79,7 +85,8 @@ func (s *serviceAlias[Initial, Alias]) getInstance(i Injector) (Alias, error) {
return target, nil
default:
// should never happen, since invoke() checks the type
return empty[Alias](), fmt.Errorf("DI: could not cast `%s` as `%s`", s.targetName, s.name)
return empty[Alias](), serviceTypeMismatch(inferServiceName[Alias](), inferServiceName[Initial]())
// return empty[Alias](), fmt.Errorf("DI: could not cast `%s` as `%s`", s.targetName, s.name)
}
}

Expand Down Expand Up @@ -161,8 +168,9 @@ func (s *serviceAlias[Initial, Alias]) shutdown(ctx context.Context) error {

func (s *serviceAlias[Initial, Alias]) clone() any {
return &serviceAlias[Initial, Alias]{
mu: sync.RWMutex{},
name: s.name,
mu: sync.RWMutex{},
name: s.name,
typeName: s.typeName,
// scope: s.scope, <-- @TODO: we should inject here the cloned scope
targetName: s.targetName,

Expand Down
18 changes: 14 additions & 4 deletions service_alias_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,24 @@ func TestServiceAlias_getName(t *testing.T) {
is.Equal("foobar1", service1.getName())
}

func TestServiceAlias_getType(t *testing.T) {
func TestServiceAlias_getTypeName(t *testing.T) {
t.Parallel()
is := assert.New(t)

i := New()

service1 := newServiceAlias[string, int]("foobar1", i, "foobar2")
is.Equal("int", service1.getTypeName())
}

func TestServiceAlias_getServiceType(t *testing.T) {
t.Parallel()
is := assert.New(t)

i := New()

service1 := newServiceAlias[string, string]("foobar1", i, "foobar2")
is.Equal(ServiceTypeAlias, service1.getType())
is.Equal(ServiceTypeAlias, service1.getServiceType())
}

func TestServiceAlias_getEmptyInstance(t *testing.T) {
Expand Down Expand Up @@ -77,7 +87,7 @@ func TestServiceAlias_getInstanceAny(t *testing.T) {
// target service found but not convertible type
service3 := newServiceAlias[*lazyTestHeathcheckerOK, int]("github.com/samber/do/v2.Healthchecker", i, "int")
instance3, err3 := service3.getInstanceAny(i)
is.EqualError(err3, "DI: could not find service `int`, available services: `*github.com/samber/do/v2.lazyTestHeathcheckerOK`, `github.com/samber/do/v2.Healthchecker`, `int`")
is.EqualError(err3, "DI: service found, but type mismatch: invoking `*github.com/samber/do/v2.lazyTestHeathcheckerOK` but registered `int`")
is.EqualValues(0, instance3)

// @TODO: missing test with child scopes
Expand Down Expand Up @@ -113,7 +123,7 @@ func TestServiceAlias_getInstance(t *testing.T) {
// target service found but not convertible type
service3 := newServiceAlias[*lazyTestHeathcheckerOK, int]("github.com/samber/do/v2.Healthchecker", i, "int")
instance3, err3 := service3.getInstance(i)
is.EqualError(err3, "DI: could not find service `int`, available services: `*github.com/samber/do/v2.lazyTestHeathcheckerOK`, `github.com/samber/do/v2.Healthchecker`, `int`")
is.EqualError(err3, "DI: service found, but type mismatch: invoking `*github.com/samber/do/v2.lazyTestHeathcheckerOK` but registered `int`")
is.EqualValues(0, instance3)

// @TODO: missing test with child scopes
Expand Down
9 changes: 8 additions & 1 deletion service_eager.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var _ serviceClone = (*serviceEager[int])(nil)

type serviceEager[T any] struct {
name string
typeName string
instance T

providerFrame stacktrace.Frame
Expand All @@ -28,6 +29,7 @@ func newServiceEager[T any](name string, instance T) *serviceEager[T] {

return &serviceEager[T]{
name: name,
typeName: inferServiceName[T](),
instance: instance,

providerFrame: providerFrame,
Expand All @@ -41,7 +43,11 @@ func (s *serviceEager[T]) getName() string {
return s.name
}

func (s *serviceEager[T]) getType() ServiceType {
func (s *serviceEager[T]) getTypeName() string {
return s.typeName
}

func (s *serviceEager[T]) getServiceType() ServiceType {
return ServiceTypeEager
}

Expand Down Expand Up @@ -112,6 +118,7 @@ func (s *serviceEager[T]) shutdown(ctx context.Context) error {
func (s *serviceEager[T]) clone() any {
return &serviceEager[T]{
name: s.name,
typeName: s.typeName,
instance: s.instance,

providerFrame: s.providerFrame,
Expand Down
19 changes: 16 additions & 3 deletions service_eager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,30 @@ func TestServiceEager_getName(t *testing.T) {
is.Equal("foobar2", service2.getName())
}

func TestServiceEager_getType(t *testing.T) {
func TestServiceEager_getTypeName(t *testing.T) {
t.Parallel()
is := assert.New(t)

test := eagerTest{foobar: "foobar"}

service1 := newServiceEager("foobar1", 42)
is.Equal(ServiceTypeEager, service1.getType())
is.Equal("int", service1.getTypeName())

service2 := newServiceEager("foobar2", test)
is.Equal(ServiceTypeEager, service2.getType())
is.Equal("github.com/samber/do/v2.eagerTest", service2.getTypeName())
}

func TestServiceEager_getServiceType(t *testing.T) {
t.Parallel()
is := assert.New(t)

test := eagerTest{foobar: "foobar"}

service1 := newServiceEager("foobar1", 42)
is.Equal(ServiceTypeEager, service1.getServiceType())

service2 := newServiceEager("foobar2", test)
is.Equal(ServiceTypeEager, service2.getServiceType())
}

func TestServiceEager_getEmptyInstance(t *testing.T) {
Expand Down
17 changes: 12 additions & 5 deletions service_lazy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ var _ serviceClone = (*serviceLazy[int])(nil)
type serviceLazy[T any] struct {
mu sync.RWMutex
name string
typeName string
instance T

// lazy loading
Expand All @@ -34,8 +35,9 @@ func newServiceLazy[T any](name string, provider Provider[T]) *serviceLazy[T] {
providerFrame, _ := stacktrace.NewFrameFromPtr(reflect.ValueOf(provider).Pointer())

return &serviceLazy[T]{
mu: sync.RWMutex{},
name: name,
mu: sync.RWMutex{},
name: name,
typeName: inferServiceName[T](),

built: false,
buildTime: 0,
Expand All @@ -51,7 +53,11 @@ func (s *serviceLazy[T]) getName() string {
return s.name
}

func (s *serviceLazy[T]) getType() ServiceType {
func (s *serviceLazy[T]) getTypeName() string {
return s.typeName
}

func (s *serviceLazy[T]) getServiceType() ServiceType {
return ServiceTypeLazy
}

Expand Down Expand Up @@ -179,8 +185,9 @@ func (s *serviceLazy[T]) shutdown(ctx context.Context) error {
func (s *serviceLazy[T]) clone() any {
// reset `build` flag and instance
return &serviceLazy[T]{
mu: sync.RWMutex{},
name: s.name,
mu: sync.RWMutex{},
name: s.name,
typeName: s.typeName,

built: false,
provider: s.provider,
Expand Down
26 changes: 23 additions & 3 deletions service_lazy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestServiceLazy_getName(t *testing.T) {
is.Equal("foobar2", service2.getName())
}

func TestServiceLazy_getType(t *testing.T) {
func TestServiceLazy_getTypeName(t *testing.T) {
t.Parallel()
is := assert.New(t)

Expand All @@ -122,10 +122,30 @@ func TestServiceLazy_getType(t *testing.T) {
}

service1 := newServiceLazy("foobar1", provider1)
is.Equal(ServiceTypeLazy, service1.getType())
is.Equal("int", service1.getTypeName())

service2 := newServiceLazy("foobar2", provider2)
is.Equal(ServiceTypeLazy, service2.getType())
is.Equal("github.com/samber/do/v2.lazyTest", service2.getTypeName())
}

func TestServiceLazy_getServiceType(t *testing.T) {
t.Parallel()
is := assert.New(t)

test := lazyTest{foobar: "foobar"}

provider1 := func(i Injector) (int, error) {
return 42, nil
}
provider2 := func(i Injector) (lazyTest, error) {
return test, nil
}

service1 := newServiceLazy("foobar1", provider1)
is.Equal(ServiceTypeLazy, service1.getServiceType())

service2 := newServiceLazy("foobar2", provider2)
is.Equal(ServiceTypeLazy, service2.getServiceType())
}

func TestServiceLazy_getEmptyInstance(t *testing.T) {
Expand Down
15 changes: 11 additions & 4 deletions service_transient.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ var _ serviceShutdown = (*serviceTransient[int])(nil)
var _ serviceClone = (*serviceTransient[int])(nil)

type serviceTransient[T any] struct {
name string
name string
typeName string

// lazy loading
provider Provider[T]
}

func newServiceTransient[T any](name string, provider Provider[T]) *serviceTransient[T] {
return &serviceTransient[T]{
name: name,
name: name,
typeName: inferServiceName[T](),

provider: provider,
}
Expand All @@ -30,7 +32,11 @@ func (s *serviceTransient[T]) getName() string {
return s.name
}

func (s *serviceTransient[T]) getType() ServiceType {
func (s *serviceTransient[T]) getTypeName() string {
return s.typeName
}

func (s *serviceTransient[T]) getServiceType() ServiceType {
return ServiceTypeTransient
}

Expand Down Expand Up @@ -68,7 +74,8 @@ func (s *serviceTransient[T]) shutdown(ctx context.Context) error {

func (s *serviceTransient[T]) clone() any {
return &serviceTransient[T]{
name: s.name,
name: s.name,
typeName: s.typeName,

provider: s.provider,
}
Expand Down
Loading

0 comments on commit 1998a7a

Please sign in to comment.