Skip to content
130 changes: 88 additions & 42 deletions src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,22 @@ private static bool IsIsNotNullPattern(IOperation operation, [NotNullWhen(true)]
}

// TODO: Recognize 'null == something' (i.e, when null is the left operand)
private static bool IsEqualsNullBinaryOperator(IOperation operation, [NotNullWhen(true)] out SyntaxNode? expressionUnderTest, out ITypeSymbol? typeOfExpressionUnderTest)
private static bool IsEqualsNullBinaryOperator(IOperation operation, INamedTypeSymbol objectTypeSymbol, [NotNullWhen(true)] out SyntaxNode? expressionUnderTest, out ITypeSymbol? typeOfExpressionUnderTest)
{
if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals, RightOperand: { } rightOperand } binaryOperation &&
binaryOperation.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator } &&
rightOperand.WalkDownConversion() is ILiteralOperation { ConstantValue: { HasValue: true, Value: null } })
{
expressionUnderTest = binaryOperation.LeftOperand.Syntax;
typeOfExpressionUnderTest = binaryOperation.LeftOperand.WalkDownConversion().Type;
return true;
// Allow built-in operators or user-defined operators from BCL types
bool isBuiltInOperator = binaryOperation.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator };
bool isBCLUserDefinedOperator = binaryOperation.OperatorMethod is { MethodKind: MethodKind.UserDefinedOperator } &&
IsBCLType(binaryOperation.OperatorMethod.ContainingType, objectTypeSymbol);

if (isBuiltInOperator || isBCLUserDefinedOperator)
{
expressionUnderTest = binaryOperation.LeftOperand.Syntax;
typeOfExpressionUnderTest = binaryOperation.LeftOperand.WalkDownConversion().Type;
return true;
}
}

expressionUnderTest = null;
Expand All @@ -319,15 +326,22 @@ private static bool IsEqualsNullBinaryOperator(IOperation operation, [NotNullWhe
}

// TODO: Recognize 'null != something' (i.e, when null is the left operand)
private static bool IsNotEqualsNullBinaryOperator(IOperation operation, [NotNullWhen(true)] out SyntaxNode? expressionUnderTest, out ITypeSymbol? typeOfExpressionUnderTest)
private static bool IsNotEqualsNullBinaryOperator(IOperation operation, INamedTypeSymbol objectTypeSymbol, [NotNullWhen(true)] out SyntaxNode? expressionUnderTest, out ITypeSymbol? typeOfExpressionUnderTest)
{
if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.NotEquals, RightOperand: { } rightOperand } binaryOperation &&
binaryOperation.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator } &&
rightOperand.WalkDownConversion() is ILiteralOperation { ConstantValue: { HasValue: true, Value: null } })
{
expressionUnderTest = binaryOperation.LeftOperand.Syntax;
typeOfExpressionUnderTest = binaryOperation.LeftOperand.WalkDownConversion().Type;
return true;
// Allow built-in operators or user-defined operators from BCL types
bool isBuiltInOperator = binaryOperation.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator };
bool isBCLUserDefinedOperator = binaryOperation.OperatorMethod is { MethodKind: MethodKind.UserDefinedOperator } &&
IsBCLType(binaryOperation.OperatorMethod.ContainingType, objectTypeSymbol);

if (isBuiltInOperator || isBCLUserDefinedOperator)
{
expressionUnderTest = binaryOperation.LeftOperand.Syntax;
typeOfExpressionUnderTest = binaryOperation.LeftOperand.WalkDownConversion().Type;
return true;
}
}

expressionUnderTest = null;
Expand All @@ -337,18 +351,19 @@ private static bool IsNotEqualsNullBinaryOperator(IOperation operation, [NotNull

private static NullCheckStatus RecognizeNullCheck(
IOperation operation,
INamedTypeSymbol objectTypeSymbol,
// Note that expressionUnderTest is guaranteed to be non-null when the method returns a value other than NullCheckStatus.Unknown.
// Given the current nullability attributes, there is no way to express this.
out SyntaxNode? expressionUnderTest,
out ITypeSymbol? typeOfExpressionUnderTest)
{
if (IsIsNullPattern(operation, out expressionUnderTest, out typeOfExpressionUnderTest) ||
IsEqualsNullBinaryOperator(operation, out expressionUnderTest, out typeOfExpressionUnderTest))
IsEqualsNullBinaryOperator(operation, objectTypeSymbol, out expressionUnderTest, out typeOfExpressionUnderTest))
{
return NullCheckStatus.IsNull;
}
else if (IsIsNotNullPattern(operation, out expressionUnderTest, out typeOfExpressionUnderTest) ||
IsNotEqualsNullBinaryOperator(operation, out expressionUnderTest, out typeOfExpressionUnderTest))
IsNotEqualsNullBinaryOperator(operation, objectTypeSymbol, out expressionUnderTest, out typeOfExpressionUnderTest))
{
return NullCheckStatus.IsNotNull;
}
Expand All @@ -358,6 +373,7 @@ private static NullCheckStatus RecognizeNullCheck(

private static EqualityCheckStatus RecognizeEqualityCheck(
IOperation operation,
INamedTypeSymbol objectTypeSymbol,
out SyntaxNode? toBecomeExpected,
out SyntaxNode? toBecomeActual,
out ITypeSymbol? leftType,
Expand All @@ -371,15 +387,22 @@ private static EqualityCheckStatus RecognizeEqualityCheck(
rightType = isPattern1.Value.WalkDownConversion().Type;
return EqualityCheckStatus.Equals;
}
else if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals } binaryOperation1 &&
binaryOperation1.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator })
else if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals } binaryOperation1)
{
// This is quite arbitrary. We can do extra checks to see which one (if any) looks like a "constant" and make it the expected.
toBecomeExpected = binaryOperation1.RightOperand.Syntax;
toBecomeActual = binaryOperation1.LeftOperand.Syntax;
leftType = binaryOperation1.RightOperand.WalkDownConversion().Type;
rightType = binaryOperation1.LeftOperand.WalkDownConversion().Type;
return EqualityCheckStatus.Equals;
// Allow built-in operators or user-defined operators from BCL types
bool isBuiltInOperator = binaryOperation1.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator };
bool isBCLUserDefinedOperator = binaryOperation1.OperatorMethod is { MethodKind: MethodKind.UserDefinedOperator } &&
IsBCLType(binaryOperation1.OperatorMethod.ContainingType, objectTypeSymbol);

if (isBuiltInOperator || isBCLUserDefinedOperator)
{
// This is quite arbitrary. We can do extra checks to see which one (if any) looks like a "constant" and make it the expected.
toBecomeExpected = binaryOperation1.RightOperand.Syntax;
toBecomeActual = binaryOperation1.LeftOperand.Syntax;
leftType = binaryOperation1.RightOperand.WalkDownConversion().Type;
rightType = binaryOperation1.LeftOperand.WalkDownConversion().Type;
return EqualityCheckStatus.Equals;
}
}
else if (operation is IIsPatternOperation { Pattern: INegatedPatternOperation { Pattern: IConstantPatternOperation constantPattern2 } } isPattern2)
{
Expand All @@ -389,15 +412,22 @@ private static EqualityCheckStatus RecognizeEqualityCheck(
rightType = isPattern2.Value.WalkDownConversion().Type;
return EqualityCheckStatus.NotEquals;
}
else if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.NotEquals } binaryOperation2 &&
binaryOperation2.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator })
else if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.NotEquals } binaryOperation2)
{
// This is quite arbitrary. We can do extra checks to see which one (if any) looks like a "constant" and make it the expected.
toBecomeExpected = binaryOperation2.RightOperand.Syntax;
toBecomeActual = binaryOperation2.LeftOperand.Syntax;
leftType = binaryOperation2.RightOperand.WalkDownConversion().Type;
rightType = binaryOperation2.LeftOperand.WalkDownConversion().Type;
return EqualityCheckStatus.NotEquals;
// Allow built-in operators or user-defined operators from BCL types
bool isBuiltInOperator = binaryOperation2.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator };
bool isBCLUserDefinedOperator = binaryOperation2.OperatorMethod is { MethodKind: MethodKind.UserDefinedOperator } &&
IsBCLType(binaryOperation2.OperatorMethod.ContainingType, objectTypeSymbol);

if (isBuiltInOperator || isBCLUserDefinedOperator)
{
// This is quite arbitrary. We can do extra checks to see which one (if any) looks like a "constant" and make it the expected.
toBecomeExpected = binaryOperation2.RightOperand.Syntax;
toBecomeActual = binaryOperation2.LeftOperand.Syntax;
leftType = binaryOperation2.RightOperand.WalkDownConversion().Type;
rightType = binaryOperation2.LeftOperand.WalkDownConversion().Type;
return EqualityCheckStatus.NotEquals;
}
}

toBecomeExpected = null;
Expand Down Expand Up @@ -493,25 +523,41 @@ private static bool IsBCLCollectionType(ITypeSymbol type, INamedTypeSymbol objec
type.ContainingAssembly.Identity.HasPublicKey == objectTypeSymbol.ContainingAssembly.Identity.HasPublicKey &&
type.ContainingAssembly.Identity.PublicKey.SequenceEqual(objectTypeSymbol.ContainingAssembly.Identity.PublicKey);

private static bool IsBCLType(ITypeSymbol? type, INamedTypeSymbol objectTypeSymbol)
// Check if the type is from the BCL by comparing its assembly's public key with the object type's assembly public key.
=> type is not null &&
type.ContainingAssembly is not null &&
// object is coming from BCL and it's expected to always have a public key.
type.ContainingAssembly.Identity.HasPublicKey == objectTypeSymbol.ContainingAssembly.Identity.HasPublicKey &&
type.ContainingAssembly.Identity.PublicKey.SequenceEqual(objectTypeSymbol.ContainingAssembly.Identity.PublicKey);

private static ComparisonCheckStatus RecognizeComparisonCheck(
IOperation operation,
INamedTypeSymbol objectTypeSymbol,
out SyntaxNode? leftExpression,
out SyntaxNode? rightExpression)
{
if (operation is IBinaryOperation binaryOperation &&
binaryOperation.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator })
if (operation is IBinaryOperation binaryOperation)
{
leftExpression = binaryOperation.LeftOperand.Syntax;
rightExpression = binaryOperation.RightOperand.Syntax;
// Allow built-in operators or user-defined operators from BCL types
bool isBuiltInOperator = binaryOperation.OperatorMethod is not { MethodKind: MethodKind.UserDefinedOperator };
bool isBCLUserDefinedOperator = binaryOperation.OperatorMethod is { MethodKind: MethodKind.UserDefinedOperator } &&
IsBCLType(binaryOperation.OperatorMethod.ContainingType, objectTypeSymbol);

return binaryOperation.OperatorKind switch
if (isBuiltInOperator || isBCLUserDefinedOperator)
{
BinaryOperatorKind.GreaterThan => ComparisonCheckStatus.GreaterThan,
BinaryOperatorKind.GreaterThanOrEqual => ComparisonCheckStatus.GreaterThanOrEqual,
BinaryOperatorKind.LessThan => ComparisonCheckStatus.LessThan,
BinaryOperatorKind.LessThanOrEqual => ComparisonCheckStatus.LessThanOrEqual,
_ => ComparisonCheckStatus.Unknown,
};
leftExpression = binaryOperation.LeftOperand.Syntax;
rightExpression = binaryOperation.RightOperand.Syntax;

return binaryOperation.OperatorKind switch
{
BinaryOperatorKind.GreaterThan => ComparisonCheckStatus.GreaterThan,
BinaryOperatorKind.GreaterThanOrEqual => ComparisonCheckStatus.GreaterThanOrEqual,
BinaryOperatorKind.LessThan => ComparisonCheckStatus.LessThan,
BinaryOperatorKind.LessThanOrEqual => ComparisonCheckStatus.LessThanOrEqual,
_ => ComparisonCheckStatus.Unknown,
};
}
}

leftExpression = null;
Expand All @@ -523,7 +569,7 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
{
RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation.");

NullCheckStatus nullCheckStatus = RecognizeNullCheck(conditionArgument, out SyntaxNode? expressionUnderTest, out ITypeSymbol? typeOfExpressionUnderTest);
NullCheckStatus nullCheckStatus = RecognizeNullCheck(conditionArgument, objectTypeSymbol, out SyntaxNode? expressionUnderTest, out ITypeSymbol? typeOfExpressionUnderTest);

// In this code path, we will be suggesting the use of IsNull/IsNotNull.
// These assert methods only have an "object" overload.
Expand Down Expand Up @@ -625,7 +671,7 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
}

// Check for comparison patterns: a > b, a >= b, a < b, a <= b
ComparisonCheckStatus comparisonStatus = RecognizeComparisonCheck(conditionArgument, out SyntaxNode? leftExpr, out SyntaxNode? rightExpr);
ComparisonCheckStatus comparisonStatus = RecognizeComparisonCheck(conditionArgument, objectTypeSymbol, out SyntaxNode? leftExpr, out SyntaxNode? rightExpr);
if (comparisonStatus != ComparisonCheckStatus.Unknown)
{
string properAssertMethod = (isTrueInvocation, comparisonStatus) switch
Expand Down Expand Up @@ -686,7 +732,7 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

EqualityCheckStatus equalityCheckStatus = RecognizeEqualityCheck(conditionArgument, out SyntaxNode? toBecomeExpected, out SyntaxNode? toBecomeActual, out ITypeSymbol? leftType, out ITypeSymbol? rightType);
EqualityCheckStatus equalityCheckStatus = RecognizeEqualityCheck(conditionArgument, objectTypeSymbol, out SyntaxNode? toBecomeExpected, out SyntaxNode? toBecomeActual, out ITypeSymbol? leftType, out ITypeSymbol? rightType);
if (equalityCheckStatus != EqualityCheckStatus.Unknown &&
CanUseTypeAsObject(context.Compilation, leftType) &&
CanUseTypeAsObject(context.Compilation, rightType))
Expand Down
Loading