Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 315 additions & 0 deletions src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,19 @@ private enum CountCheckStatus
HasCount,
}

private enum LinqPredicateCheckStatus
{
Unknown,
Any,
Count,
WhereAny,
WhereCount,
Single,
SingleOrDefault,
WhereSingle,
WhereSingleOrDefault,
}

internal const string ProperAssertMethodNameKey = nameof(ProperAssertMethodNameKey);

/// <summary>
Expand Down Expand Up @@ -268,6 +281,56 @@ private static void AnalyzeInvocationOperation(OperationAnalysisContext context,
case "AreNotEqual":
AnalyzeAreEqualOrAreNotEqualInvocation(context, firstArgument, isAreEqualInvocation: false, objectTypeSymbol);
break;
case "IsNull":
AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: true);
break;

case "IsNotNull":
AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: false);
break;
}
}

private static void AnalyzeIsNullOrIsNotNullInvocation(OperationAnalysisContext context, IOperation argument, bool isNullCheck)
{
RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation.");

// Check for Single/SingleOrDefault patterns
LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck(
argument,
out SyntaxNode? linqCollectionExpr,
out SyntaxNode? predicateExpr,
out _);

if (linqStatus is LinqPredicateCheckStatus.Single or
LinqPredicateCheckStatus.SingleOrDefault or
LinqPredicateCheckStatus.WhereSingle or
LinqPredicateCheckStatus.WhereSingleOrDefault &&
Comment on lines +305 to +308
Copy link

Copilot AI Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Operator precedence issue: The condition on line 308 is missing parentheses around the && operation, making the intent ambiguous. The current condition evaluates as:

(linqStatus is ... or ... or ... or ...) && (linqCollectionExpr != null)

However, based on the logic, it appears the intent is:

linqStatus is (...WhereSingle or WhereSingleOrDefault) && linqCollectionExpr != null

Add parentheses to clarify the intended precedence:

if ((linqStatus is LinqPredicateCheckStatus.Single or
                   LinqPredicateCheckStatus.SingleOrDefault or
                   LinqPredicateCheckStatus.WhereSingle or
                   LinqPredicateCheckStatus.WhereSingleOrDefault) &&
    linqCollectionExpr != null)
Suggested change
if (linqStatus is LinqPredicateCheckStatus.Single or
LinqPredicateCheckStatus.SingleOrDefault or
LinqPredicateCheckStatus.WhereSingle or
LinqPredicateCheckStatus.WhereSingleOrDefault &&
if ((linqStatus is LinqPredicateCheckStatus.Single or
LinqPredicateCheckStatus.SingleOrDefault or
LinqPredicateCheckStatus.WhereSingle or
LinqPredicateCheckStatus.WhereSingleOrDefault) &&

Copilot uses AI. Check for mistakes.
linqCollectionExpr != null)
{
// For Assert.IsNotNull(enumerable.Single[OrDefault](...)) -> Assert.ContainsSingle
// For Assert.IsNull(enumerable.Single[OrDefault](...)) -> Assert.DoesNotContain
string properAssertMethod = isNullCheck ? "DoesNotContain" : "ContainsSingle";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, predicateExpr != null ? CodeFixModeAddArgument : CodeFixModeSimple);

ImmutableArray<Location> additionalLocations = predicateExpr != null
? ImmutableArray.Create(
argument.Syntax.GetLocation(),
predicateExpr.GetLocation(),
linqCollectionExpr.GetLocation())
: ImmutableArray.Create(
argument.Syntax.GetLocation(),
linqCollectionExpr.GetLocation());

context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: additionalLocations,
properties: properties.ToImmutable(),
properAssertMethod,
isNullCheck ? "IsNull" : "IsNotNull"));
}
}

Expand Down Expand Up @@ -519,6 +582,146 @@ private static ComparisonCheckStatus RecognizeComparisonCheck(
return ComparisonCheckStatus.Unknown;
}

private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck(
IOperation operation,
out SyntaxNode? collectionExpression,
out SyntaxNode? predicateExpression,
out IOperation? countOperation)
{
collectionExpression = null;
predicateExpression = null;
countOperation = null;

// Check for enumerable.Any(predicate)
// Extension methods appear as: Instance=null, Arguments[0]=collection, Arguments[1]=predicate
if (operation is IInvocationOperation anyInvocation &&
anyInvocation.TargetMethod.Name == "Any" &&
anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
anyInvocation.Arguments.Length == 2)
{
collectionExpression = anyInvocation.Arguments[0].Value.Syntax;
predicateExpression = anyInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.Any;
}

// Check for enumerable.Count(predicate)
if (operation is IInvocationOperation countInvocation &&
countInvocation.TargetMethod.Name == "Count" &&
countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
countInvocation.Arguments.Length == 2)
{
collectionExpression = countInvocation.Arguments[0].Value.Syntax;
predicateExpression = countInvocation.Arguments[1].Value.Syntax;
countOperation = operation;
return LinqPredicateCheckStatus.Count;
}

// Check for enumerable.Where(predicate).Any()
if (operation is IInvocationOperation whereAnyInvocation &&
whereAnyInvocation.TargetMethod.Name == "Any" &&
whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereAnyInvocation.Arguments.Length == 1 &&
whereAnyInvocation.Arguments[0].Value is IInvocationOperation whereInvocation &&
whereInvocation.TargetMethod.Name == "Where" &&
whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation.Arguments.Length == 2)
{
collectionExpression = whereInvocation.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereAny;
}

// Check for enumerable.Where(predicate).Count()
if (operation is IInvocationOperation whereCountInvocation &&
whereCountInvocation.TargetMethod.Name == "Count" &&
whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereCountInvocation.Arguments.Length == 1 &&
whereCountInvocation.Arguments[0].Value is IInvocationOperation whereInvocation2 &&
whereInvocation2.TargetMethod.Name == "Where" &&
whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation2.Arguments.Length == 2)
{
collectionExpression = whereInvocation2.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation2.Arguments[1].Value.Syntax;
countOperation = operation;
return LinqPredicateCheckStatus.WhereCount;
}

// Check for enumerable.Where(predicate).Single()
if (operation is IInvocationOperation whereSingleInvocation &&
whereSingleInvocation.TargetMethod.Name == "Single" &&
whereSingleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereSingleInvocation.Arguments.Length == 1 &&
whereSingleInvocation.Arguments[0].Value is IInvocationOperation whereInvocation3 &&
whereInvocation3.TargetMethod.Name == "Where" &&
whereInvocation3.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation3.Arguments.Length == 2)
{
collectionExpression = whereInvocation3.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation3.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereSingle;
}

// Check for enumerable.Where(predicate).SingleOrDefault()
if (operation is IInvocationOperation whereSingleOrDefaultInvocation &&
whereSingleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" &&
whereSingleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereSingleOrDefaultInvocation.Arguments.Length == 1 &&
whereSingleOrDefaultInvocation.Arguments[0].Value is IInvocationOperation whereInvocation4 &&
whereInvocation4.TargetMethod.Name == "Where" &&
whereInvocation4.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation4.Arguments.Length == 2)
{
collectionExpression = whereInvocation4.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation4.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereSingleOrDefault;
}
Comment on lines +619 to +678
Copy link

Copilot AI Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The analyzer implementation has significant code duplication in the RecognizeLinqPredicateCheck method. The patterns for checking Where().Single(), Where().SingleOrDefault(), Where().Any(), and Where().Count() follow nearly identical logic with only method names differing. Consider extracting this repetitive logic into a helper method to improve maintainability.

Example refactoring approach:

private static bool TryMatchWherePattern(
    IInvocationOperation invocation,
    string methodName,
    out SyntaxNode? collectionExpression,
    out SyntaxNode? predicateExpression)
{
    if (invocation.TargetMethod.Name == methodName &&
        invocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
        invocation.Arguments.Length == 1 &&
        invocation.Arguments[0].Value is IInvocationOperation whereInvocation &&
        whereInvocation.TargetMethod.Name == "Where" &&
        whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
        whereInvocation.Arguments.Length == 2)
    {
        collectionExpression = whereInvocation.Arguments[0].Value.Syntax;
        predicateExpression = whereInvocation.Arguments[1].Value.Syntax;
        return true;
    }
    
    collectionExpression = null;
    predicateExpression = null;
    return false;
}

Copilot uses AI. Check for mistakes.

// Check for enumerable.Single(predicate)
if (operation is IInvocationOperation singleInvocation &&
singleInvocation.TargetMethod.Name == "Single" &&
singleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable")
{
if (singleInvocation.Arguments.Length == 2)
{
// Extension method with predicate
collectionExpression = singleInvocation.Arguments[0].Value.Syntax;
predicateExpression = singleInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.Single;
}
else if (singleInvocation.Arguments.Length == 1)
{
// Instance method or extension without predicate
collectionExpression = singleInvocation.Instance?.Syntax ?? singleInvocation.Arguments[0].Value.Syntax;
predicateExpression = null;
return LinqPredicateCheckStatus.Single;
}
}

// Check for enumerable.SingleOrDefault(predicate)
if (operation is IInvocationOperation singleOrDefaultInvocation &&
singleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" &&
singleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable")
{
if (singleOrDefaultInvocation.Arguments.Length == 2)
{
// Extension method with predicate
collectionExpression = singleOrDefaultInvocation.Arguments[0].Value.Syntax;
predicateExpression = singleOrDefaultInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.SingleOrDefault;
}
else if (singleOrDefaultInvocation.Arguments.Length == 1)
{
// Instance method or extension without predicate
collectionExpression = singleOrDefaultInvocation.Instance?.Syntax ?? singleOrDefaultInvocation.Arguments[0].Value.Syntax;
predicateExpression = null;
return LinqPredicateCheckStatus.SingleOrDefault;
}
}

return LinqPredicateCheckStatus.Unknown;
}

private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext context, IOperation conditionArgument, bool isTrueInvocation, INamedTypeSymbol objectTypeSymbol)
{
RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation.");
Expand Down Expand Up @@ -555,6 +758,36 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

// Check for LINQ predicate patterns that suggest Contains/DoesNotContain
LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck(
conditionArgument,
out SyntaxNode? linqCollectionExpr,
out SyntaxNode? predicateExpr,
out _);

if (linqStatus != LinqPredicateCheckStatus.Unknown && linqCollectionExpr != null && predicateExpr != null)
{
// For Any() and Where().Any() patterns
if (linqStatus is LinqPredicateCheckStatus.Any or LinqPredicateCheckStatus.WhereAny)
{
string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);
context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
conditionArgument.Syntax.GetLocation(),
predicateExpr.GetLocation(),
linqCollectionExpr.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
isTrueInvocation ? "IsTrue" : "IsFalse"));
return;
}
}

// Check for string method patterns: myString.StartsWith/EndsWith/Contains(...)
StringMethodCheckStatus stringMethodStatus = RecognizeStringMethodCheck(conditionArgument, out SyntaxNode? stringExpr, out SyntaxNode? substringExpr);
if (stringMethodStatus != StringMethodCheckStatus.Unknown)
Expand Down Expand Up @@ -624,6 +857,54 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

// Special-case: enumerable.Count(predicate) > 0 → Assert.Contains(predicate, enumerable)
if (conditionArgument is IBinaryOperation binaryOp &&
binaryOp.OperatorKind == BinaryOperatorKind.GreaterThan)
{
if (binaryOp.LeftOperand is IInvocationOperation countInvocation &&
binaryOp.RightOperand.ConstantValue.HasValue &&
binaryOp.RightOperand.ConstantValue.Value is int intValue &&
intValue == 0 &&
countInvocation.TargetMethod.Name == "Count")
{
SyntaxNode? countCollectionExpr = null;
SyntaxNode? countPredicateExpr = null;

if (countInvocation.Instance != null && countInvocation.Arguments.Length == 1)
{
countCollectionExpr = countInvocation.Instance.Syntax;
countPredicateExpr = countInvocation.Arguments[0].Value.Syntax;
}
else if (countInvocation.Instance == null && countInvocation.Arguments.Length == 2)
{
countCollectionExpr = countInvocation.Arguments[0].Value.Syntax;
countPredicateExpr = countInvocation.Arguments[1].Value.Syntax;
}

if (countCollectionExpr != null && countPredicateExpr != null)
{
string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);

context.ReportDiagnostic(
context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
conditionArgument.Syntax.GetLocation(),
countPredicateExpr.GetLocation(),
countCollectionExpr.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
isTrueInvocation ? "IsTrue" : "IsFalse"));

return;
}
}
}
Comment on lines +860 to +906
Copy link

Copilot AI Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Count pattern checking (lines 860-906) doesn't validate that the Count method is actually from LINQ (System.Linq.Enumerable) or has a predicate. This could lead to false positives with other Count methods (e.g., string.Count from LINQ, or custom Count methods on collections).

The check should verify:

  1. That it's the LINQ Count extension method with a predicate
  2. That it's from System.Linq.Enumerable

Example:

if (countInvocation.TargetMethod.Name == "Count" &&
    countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable")

This is especially important since RecognizeLinqPredicateCheck already properly checks for this, so this code should either reuse that method or apply similar validation.

Copilot uses AI. Check for mistakes.

// Check for comparison patterns: a > b, a >= b, a < b, a <= b
ComparisonCheckStatus comparisonStatus = RecognizeComparisonCheck(conditionArgument, out SyntaxNode? leftExpr, out SyntaxNode? rightExpr);
if (comparisonStatus != ComparisonCheckStatus.Unknown)
Expand Down Expand Up @@ -722,6 +1003,40 @@ private static void AnalyzeAreEqualOrAreNotEqualInvocation(OperationAnalysisCont
{
if (TryGetSecondArgumentValue((IInvocationOperation)context.Operation, out IOperation? actualArgumentValue))
{
// Check for LINQ predicate patterns that suggest ContainsSingle
LinqPredicateCheckStatus linqStatus2 = RecognizeLinqPredicateCheck(
actualArgumentValue!,
out SyntaxNode? linqCollectionExpr2,
out SyntaxNode? predicateExpr2,
out _);

if (isAreEqualInvocation &&
linqStatus2 is LinqPredicateCheckStatus.Count or LinqPredicateCheckStatus.WhereCount &&
linqCollectionExpr2 != null &&
predicateExpr2 != null &&
expectedArgument.ConstantValue.HasValue &&
expectedArgument.ConstantValue.Value is int expectedCountValue &&
expectedCountValue == 1)
{
// We have Assert.AreEqual(1, enumerable.Count(predicate))
// We want Assert.ContainsSingle(predicate, enumerable)
string properAssertMethod = "ContainsSingle";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);
context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
actualArgumentValue.Syntax.GetLocation(),
predicateExpr2.GetLocation(),
linqCollectionExpr2.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
"AreEqual"));
return;
}

// Check if we're comparing a count/length property
CountCheckStatus countStatus = RecognizeCountCheck(
expectedArgument,
Expand Down
3 changes: 2 additions & 1 deletion test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ private static void ValidateOutputIsNotMixed(IEnumerable<TestResult> testResults
Assert.Contains(methodName, message.Text);
Assert.Contains("TestInitialize", message.Text);
Assert.Contains("TestCleanup", message.Text);
Assert.IsFalse(shouldNotContain.Any(message.Text.Contains));
// Assert.IsFalse(shouldNotContain.Any(message.Text.Contains));
Assert.DoesNotContain(message.Text.Contains, shouldNotContain);
Copy link

Copilot AI Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change appears incorrect. The original code Assert.IsFalse(shouldNotContain.Any(message.Text.Contains)) is checking that none of the strings in shouldNotContain are contained in message.Text.

The new code Assert.DoesNotContain(message.Text.Contains, shouldNotContain) reverses the logic: it's checking whether the method message.Text.Contains (as a delegate) exists in the shouldNotContain collection, which is a string array. This will always pass (or fail to compile) since it's comparing a method delegate to string values.

The correct assertion should remain as the original, or if you want to use Assert.DoesNotContain, it should be written as:

foreach (string item in shouldNotContain)
{
    Assert.DoesNotContain(item, message.Text);
}
Suggested change
Assert.DoesNotContain(message.Text.Contains, shouldNotContain);
foreach (string item in shouldNotContain)
{
Assert.DoesNotContain(item, message.Text);
}

Copilot uses AI. Check for mistakes.
}

private static void ValidateInitializeAndCleanup(IEnumerable<TestResult> testResults, Func<TestResultMessage, bool> messageFilter)
Expand Down
Loading
Loading