Skip to content

Commit

Permalink
Allow factories to be registered as specific types.
Browse files Browse the repository at this point in the history
Fixes #151
YairHalberstadt committed Oct 23, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent f983e81 commit 781ff22
Showing 3 changed files with 802 additions and 47 deletions.
173 changes: 129 additions & 44 deletions StrongInject.Generator/RegistrationCalculator.cs
Original file line number Diff line number Diff line change
@@ -478,14 +478,14 @@ private void AppendFactoryRegistrations(Dictionary<ITypeSymbol, InstanceSources>
}
}

private bool CheckValidType(AttributeData registerAttribute, TypedConstant typedConstant, ITypeSymbol module, out INamedTypeSymbol type)
private bool CheckValidType(AttributeData attribute, TypedConstant typedConstant, ITypeSymbol module, out INamedTypeSymbol type)
{
type = (typedConstant.Value as INamedTypeSymbol)!;
if (typedConstant.Value is null)
{
_reportDiagnostic(InvalidType(
(ITypeSymbol)typedConstant.Value!,
registerAttribute.GetLocation(_cancellationToken)));
attribute.GetLocation(_cancellationToken)));
return false;
}
if (type.IsOrReferencesErrorType())
@@ -498,7 +498,7 @@ private bool CheckValidType(AttributeData registerAttribute, TypedConstant typed
{
_reportDiagnostic(TypeDoesNotHaveAtLeastInternalAccessibility(
type,
registerAttribute.GetLocation(_cancellationToken)));
attribute.GetLocation(_cancellationToken)));
return false;
}

@@ -507,7 +507,7 @@ private bool CheckValidType(AttributeData registerAttribute, TypedConstant typed
_reportDiagnostic(WarnTypeNotPublic(
type,
module,
registerAttribute.GetLocation(_cancellationToken)));
attribute.GetLocation(_cancellationToken)));
}

return true;
@@ -577,12 +577,11 @@ private void AppendFactoryMethods(
_ => throw new InvalidEnumArgumentException(nameof(registrationsToCalculate), (int)registrationsToCalculate, typeof(RegistrationsToCalculate))
})
{
var instanceSource = CreateInstanceSourceIfFactoryMethod(method, out var attribute);
if (instanceSource is not null)
foreach (var instanceSource in CreateInstanceSourceIfFactoryMethod(method, module))
{
if (instanceSource.IsOpenGeneric)
if (instanceSource is FactoryMethod { IsOpenGeneric: true } factoryMethod)
{
genericRegistrations.Add(instanceSource);
genericRegistrations.Add(factoryMethod);
}
else
{
@@ -592,62 +591,103 @@ private void AppendFactoryMethods(
}
}

private FactoryMethod? CreateInstanceSourceIfFactoryMethod(IMethodSymbol method, out AttributeData attribute)
private IEnumerable<InstanceSource> CreateInstanceSourceIfFactoryMethod(IMethodSymbol method, INamedTypeSymbol module)
{
attribute = method.GetAttributes().FirstOrDefault(x
var attribute = method.GetAttributes().FirstOrDefault(x
=> x.AttributeClass is { } attribute
&& attribute.Equals(_wellKnownTypes.FactoryAttribute, SymbolEqualityComparer.Default))!;
&& attribute.Equals(_wellKnownTypes.FactoryAttribute, SymbolEqualityComparer.Default))!;

if (attribute is not null)
{
var countConstructorArguments = attribute.ConstructorArguments.Length;
if (countConstructorArguments != 1)
if (countConstructorArguments is not (1 or 2))
{
// Invalid code, ignore
return null;
yield break;
}

var scope = attribute.ConstructorArguments[0] is { Kind: TypedConstantKind.Enum, Value: int scopeInt }
? (Scope)scopeInt
: Scope.InstancePerResolution;

if (method.ReturnType is { SpecialType: not SpecialType.System_Void } returnType)
var asTypes = attribute.ConstructorArguments.Last()
is { Kind: TypedConstantKind.Array, Values: { IsDefaultOrEmpty: false } types }
? types
: ImmutableArray<TypedConstant>.Empty;

if (method.ReturnType.SpecialType == SpecialType.System_Void)
{
bool isGeneric = method.TypeParameters.Length > 0;
if (isGeneric && !AllTypeParametersUsedInReturnType(method))
_reportDiagnostic(FactoryMethodReturnsVoid(
method,
attribute.ApplicationSyntaxReference?.GetSyntax(_cancellationToken).GetLocation() ?? Location.None));

yield break;
}

bool isGeneric = method.TypeParameters.Length > 0;
if (isGeneric && !AllTypeParametersUsedInReturnType(method))
{
_reportDiagnostic(NotAllTypeParametersUsedInReturnType(
method,
attribute.ApplicationSyntaxReference?.GetSyntax(_cancellationToken).GetLocation() ?? Location.None));

yield break;
}

if (isGeneric && !asTypes.IsEmpty)
{
_reportDiagnostic(GenericFactoryMethodWithAsTypes(method, attribute.GetLocation(_cancellationToken)));
yield break;
}

foreach (var param in method.Parameters)
{
if (param.RefKind != RefKind.None)
{
_reportDiagnostic(NotAllTypeParametersUsedInReturnType(
_reportDiagnostic(FactoryMethodParameterIsPassedByRef(
method,
attribute.ApplicationSyntaxReference?.GetSyntax(_cancellationToken).GetLocation() ?? Location.None));
param,
param.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(_cancellationToken).GetLocation() ?? Location.None));
yield break;
}
}

return null;
var returnType = method.ReturnType;
var factoryMethod = returnType.IsWellKnownTaskType(_wellKnownTypes, out var taskOfType)
? new FactoryMethod(method, taskOfType, scope, isGeneric, IsAsync: true)
: new FactoryMethod(method, returnType, scope, isGeneric, IsAsync: false);

if (asTypes.IsEmpty)
{
yield return factoryMethod;
yield break;
}

var factoryOfType = factoryMethod.FactoryOfType;

foreach (var asType in asTypes)
{
if (!CheckValidType(attribute, asType, module, out var target))
{
// Invalid code, ignore
continue;
}

foreach (var param in method.Parameters)
if (target.IsUnboundGenericType)
{
if (param.RefKind != RefKind.None)
{
_reportDiagnostic(FactoryMethodParameterIsPassedByRef(
method,
param,
param.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(_cancellationToken).GetLocation() ?? Location.None));
return null;
}
_reportDiagnostic(FactoryMethodWithUnboundGenericAsTypes(method, target, attribute.GetLocation(_cancellationToken)));
continue;
}

if (returnType.IsWellKnownTaskType(_wellKnownTypes, out var taskOfType))
if (_compilation.ClassifyConversion(factoryOfType, target) is not { IsImplicit: true, IsNumeric: false, IsUserDefined: false })
{
return new FactoryMethod(method, taskOfType, scope, isGeneric, IsAsync: true);
_reportDiagnostic(FactoryMethodDoesNotHaveSuitableConversion(method, factoryOfType, target, attribute.GetLocation(_cancellationToken)));
continue;
}

return new FactoryMethod(method, returnType, scope, isGeneric, IsAsync: false);
yield return ForwardedInstanceSource.Create(target, factoryMethod);
}

_reportDiagnostic(FactoryMethodReturnsVoid(
method,
attribute.ApplicationSyntaxReference?.GetSyntax(_cancellationToken).GetLocation() ?? Location.None));
}

return null;
}

private static bool AllTypeParametersUsedInReturnType(IMethodSymbol method)
@@ -1455,7 +1495,7 @@ private static Diagnostic FactoryMethodReturnsVoid(IMethodSymbol methodSymbol, L
return Diagnostic.Create(
new DiagnosticDescriptor(
"SI0014",
"Factory Method returns void",
"Factory method returns void",
"Factory method '{0}' returns void.",
"StrongInject",
DiagnosticSeverity.Error,
@@ -1501,7 +1541,7 @@ private static Diagnostic NotAllTypeParametersUsedInReturnType(IMethodSymbol met
return Diagnostic.Create(
new DiagnosticDescriptor(
"SI0020",
"All type parameters must be used in return type of generic Factory Method",
"All type parameters must be used in return type of generic factory method",
"All type parameters must be used in return type of generic factory method '{0}'",
"StrongInject",
DiagnosticSeverity.Error,
@@ -1559,8 +1599,8 @@ private static Diagnostic DecoratorFactoryMethodDoesNotHaveParameterOfDecoratedT
return Diagnostic.Create(
new DiagnosticDescriptor(
"SI0024",
"Decorator Factory Method does not have a parameter of decorated type",
"Decorator Factory '{0}' does not have a parameter of decorated type '{1}'.",
"Decorator factory method does not have a parameter of decorated type",
"Decorator factory '{0}' does not have a parameter of decorated type '{1}'.",
"StrongInject",
DiagnosticSeverity.Error,
isEnabledByDefault: true),
@@ -1574,8 +1614,8 @@ private static Diagnostic DecoratorFactoryMethodHasMultipleParametersOfDecorated
return Diagnostic.Create(
new DiagnosticDescriptor(
"SI0025",
"Decorator Factory Method has multiple constructor parameters of decorated type",
"Decorator Factory '{0}' has multiple constructor parameters of decorated type '{1}'.",
"Decorator Factory method has multiple constructor parameters of decorated type",
"Decorator factory '{0}' has multiple constructor parameters of decorated type '{1}'.",
"StrongInject",
DiagnosticSeverity.Error,
isEnabledByDefault: true),
@@ -1674,6 +1714,51 @@ private static Diagnostic MismatchingNumberOfTypeParameters(AttributeData regist
registeredAsType);
}

private static Diagnostic FactoryMethodDoesNotHaveSuitableConversion(IMethodSymbol method, ITypeSymbol returnType, INamedTypeSymbol registeredAsType, Location location)
{
return Diagnostic.Create(
new DiagnosticDescriptor(
"SI0032",
"Return type of factory method does not have an identity, implicit reference, boxing or nullable conversion to registered as type",
"Return type '{0}' of '{1}' does not have an identity, implicit reference, boxing or nullable conversion to '{2}'.",
"StrongInject",
DiagnosticSeverity.Error,
isEnabledByDefault: true),
location,
returnType,
method.Name,
registeredAsType);
}

private static Diagnostic GenericFactoryMethodWithAsTypes(IMethodSymbol method, Location location)
{
return Diagnostic.Create(
new DiagnosticDescriptor(
"SI0033",
"Factory method cannot be registered as specific types since it is generic.",
"Factory method '{0}' cannot be registered as specific types since it is generic.",
"StrongInject",
DiagnosticSeverity.Error,
isEnabledByDefault: true),
location,
method.Name);
}

private static Diagnostic FactoryMethodWithUnboundGenericAsTypes(IMethodSymbol method, INamedTypeSymbol asType, Location location)
{
return Diagnostic.Create(
new DiagnosticDescriptor(
"SI0034",
"Factory method cannot be registered as an instance of open generic type.",
"Factory method '{0}' cannot be registered as an instance of open generic type '{1}'.",
"StrongInject",
DiagnosticSeverity.Error,
isEnabledByDefault: true),
location,
method.Name,
asType);
}

private static Diagnostic WarnSimpleRegistrationImplementingFactory(ITypeSymbol type, ITypeSymbol factoryType, Location location)
{
return Diagnostic.Create(
@@ -1694,7 +1779,7 @@ private static Diagnostic WarnFactoryMethodNotPublicStaticOrProtected(ITypeSymbo
return Diagnostic.Create(
new DiagnosticDescriptor(
"SI1002",
"Factory Method is not either public and static, or protected, and containing module is not a container, so will be ignored",
"Factory method is not either public and static, or protected, and containing module is not a container, so will be ignored",
"Factory method '{0}' is not either public and static, or protected, and containing module '{1}' is not a container, so will be ignored.",
"StrongInject",
DiagnosticSeverity.Warning,
Loading

0 comments on commit 781ff22

Please sign in to comment.