From 1ea58a146ec1101b7e2441b49d1483464c56703d Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Tue, 7 Oct 2025 22:59:30 +0100 Subject: [PATCH 01/10] analyzer changes --- .../UseProperAssertMethodsAnalyzer.cs | 168 +++++++++++++++ .../UseProperAssertMethodsAnalyzerTests.cs | 200 ++++++++++++++++++ 2 files changed, 368 insertions(+) diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index 45f4e90b07..c183a0dc69 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -125,6 +125,15 @@ private enum CountCheckStatus HasCount, } + private enum LinqPredicateCheckStatus + { + Unknown, + Any, + Count, + WhereAny, + WhereCount, + } + internal const string ProperAssertMethodNameKey = nameof(ProperAssertMethodNameKey); /// @@ -519,6 +528,95 @@ private static ComparisonCheckStatus RecognizeComparisonCheck( return ComparisonCheckStatus.Unknown; } + // Add this new method to recognize LINQ patterns with predicates + private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck( + IOperation operation, + out SyntaxNode? collectionExpression, + out SyntaxNode? predicateExpression, + out IOperation? countOperation) + { + collectionExpression = null; + predicateExpression = null; + countOperation = null; + + // Check for enumerable.Any(predicate) + if (operation is IInvocationOperation anyInvocation && + anyInvocation.TargetMethod.Name == "Any" && + anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + // For extension methods on IEnumerable, the syntax is: collection.Any(predicate) + // In Roslyn's operation tree: Instance = collection, Arguments[0] = predicate + if (anyInvocation.Instance != null && + anyInvocation.Arguments.Length == 1 && + IsPredicate(anyInvocation.Arguments[0].Value)) + { + collectionExpression = anyInvocation.Instance.Syntax; + predicateExpression = anyInvocation.Arguments[0].Value.Syntax; + return LinqPredicateCheckStatus.Any; + } + } + + // Check for enumerable.Count(predicate) + if (operation is IInvocationOperation countInvocation && + countInvocation.TargetMethod.Name == "Count" && + countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + // For extension methods: collection.Count(predicate) + if (countInvocation.Instance != null && + countInvocation.Arguments.Length == 1 && + IsPredicate(countInvocation.Arguments[0].Value)) + { + collectionExpression = countInvocation.Instance.Syntax; + predicateExpression = countInvocation.Arguments[0].Value.Syntax; + countOperation = operation; + return LinqPredicateCheckStatus.Count; + } + } + + // Check for enumerable.Where(predicate).Any() + if (operation is IInvocationOperation whereAnyInvocation && + whereAnyInvocation.TargetMethod.Name == "Any" && + whereAnyInvocation.Arguments.Length == 0 && + whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereAnyInvocation.Instance is IInvocationOperation whereInvocation && + whereInvocation.TargetMethod.Name == "Where" && + whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + // collection.Where(predicate).Any() + if (whereInvocation.Instance != null && + whereInvocation.Arguments.Length == 1 && + IsPredicate(whereInvocation.Arguments[0].Value)) + { + collectionExpression = whereInvocation.Instance.Syntax; + predicateExpression = whereInvocation.Arguments[0].Value.Syntax; + return LinqPredicateCheckStatus.WhereAny; + } + } + + // Check for enumerable.Where(predicate).Count() + if (operation is IInvocationOperation whereCountInvocation && + whereCountInvocation.TargetMethod.Name == "Count" && + whereCountInvocation.Arguments.Length == 0 && + whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereCountInvocation.Instance is IInvocationOperation whereInvocation2 && + whereInvocation2.TargetMethod.Name == "Where" && + whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + // collection.Where(predicate).Count() + if (whereInvocation2.Instance != null && + whereInvocation2.Arguments.Length == 1 && + IsPredicate(whereInvocation2.Arguments[0].Value)) + { + collectionExpression = whereInvocation2.Instance.Syntax; + predicateExpression = whereInvocation2.Arguments[0].Value.Syntax; + countOperation = operation; + return LinqPredicateCheckStatus.WhereCount; + } + } + + return LinqPredicateCheckStatus.Unknown; + } + private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext context, IOperation conditionArgument, bool isTrueInvocation, INamedTypeSymbol objectTypeSymbol) { RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation."); @@ -555,6 +653,36 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co return; } + // Check for LINQ predicate patterns that suggest Contains/DoesNotContain + LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck( + conditionArgument, + out SyntaxNode? linqCollectionExpr, + out SyntaxNode? predicateExpr, + out IOperation? countOp); + + if (linqStatus != LinqPredicateCheckStatus.Unknown && linqCollectionExpr != null && predicateExpr != null) + { + // For Any() and Where().Any() patterns + if (linqStatus is LinqPredicateCheckStatus.Any or LinqPredicateCheckStatus.WhereAny) + { + string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain"; + + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, CodeFixModeAddArgument); + context.ReportDiagnostic(context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create( + conditionArgument.Syntax.GetLocation(), + predicateExpr.GetLocation(), + linqCollectionExpr.GetLocation()), + properties: properties.ToImmutable(), + properAssertMethod, + isTrueInvocation ? "IsTrue" : "IsFalse")); + return; + } + } + // Check for string method patterns: myString.StartsWith/EndsWith/Contains(...) StringMethodCheckStatus stringMethodStatus = RecognizeStringMethodCheck(conditionArgument, out SyntaxNode? stringExpr, out SyntaxNode? substringExpr); if (stringMethodStatus != StringMethodCheckStatus.Unknown) @@ -722,6 +850,40 @@ private static void AnalyzeAreEqualOrAreNotEqualInvocation(OperationAnalysisCont { if (TryGetSecondArgumentValue((IInvocationOperation)context.Operation, out IOperation? actualArgumentValue)) { + // Check for LINQ predicate patterns that suggest ContainsSingle + LinqPredicateCheckStatus linqStatus2 = RecognizeLinqPredicateCheck( + actualArgumentValue!, + out SyntaxNode? linqCollectionExpr2, + out SyntaxNode? predicateExpr2, + out IOperation? countOp2); + + if (isAreEqualInvocation && + linqStatus2 is LinqPredicateCheckStatus.Count or LinqPredicateCheckStatus.WhereCount && + linqCollectionExpr2 != null && + predicateExpr2 != null && + expectedArgument.ConstantValue.HasValue && + expectedArgument.ConstantValue.Value is int expectedCountValue && + expectedCountValue == 1) + { + // We have Assert.AreEqual(1, enumerable.Count(predicate)) + // We want Assert.ContainsSingle(predicate, enumerable) + string properAssertMethod = "ContainsSingle"; + + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, CodeFixModeAddArgument); + context.ReportDiagnostic(context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create( + actualArgumentValue.Syntax.GetLocation(), + predicateExpr2.GetLocation(), + linqCollectionExpr2.GetLocation()), + properties: properties.ToImmutable(), + properAssertMethod, + "AreEqual")); + return; + } + // Check if we're comparing a count/length property CountCheckStatus countStatus = RecognizeCountCheck( expectedArgument, @@ -955,4 +1117,10 @@ private static bool TryGetArgumentValueForParameterOrdinal(IInvocationOperation argumentValue = operation.Arguments.FirstOrDefault(arg => arg.Parameter?.Ordinal == ordinal)?.Value?.WalkDownConversion(); return argumentValue is not null; } + + private static bool IsPredicate(IOperation operation) + { + IOperation unwrapped = operation.WalkDownConversion(); + return unwrapped is IAnonymousFunctionOperation or IDelegateCreationOperation; + } } diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index deb75869dd..a065e766ca 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -2908,4 +2908,204 @@ await VerifyCS.VerifyCodeFixAsync( } #endregion + + [TestMethod] + public async Task WhenUsingAnyWithPredicate_SuggestsContains() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + [|Assert.IsTrue(enumerable.Any(x => x == 1))|]; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [TestMethod] + public async Task WhenUsingWhereAnyWithPredicate_SuggestsContains() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + [|Assert.IsTrue(enumerable.Where(x => x == 1).Any())|]; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [TestMethod] + public async Task WhenUsingCountWithPredicateEqualsOne_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + [|Assert.AreEqual(1, enumerable.Count(x => x == 1))|]; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsFalseWithAny_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + [|Assert.IsFalse(enumerable.Any(x => x == 1))|]; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [TestMethod] + public async Task WhenUsingCountGreaterThanZero_SuggestsContains() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + [|Assert.IsTrue(enumerable.Count(x => x == 1) > 0)|]; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } } From 1ea6cacb3ef5dd312be76e7010c8fb374f01d0e0 Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Thu, 9 Oct 2025 08:56:11 +0100 Subject: [PATCH 02/10] modified the analyzer to accept allow predicate with count greather than zero to suggest Assert.Contains --- .../UseProperAssertMethodsAnalyzer.cs | 219 ++++++++++---- .../UseProperAssertMethodsAnalyzerTests.cs | 276 +++++++++--------- 2 files changed, 310 insertions(+), 185 deletions(-) diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index c183a0dc69..94a444a678 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -529,89 +529,157 @@ private static ComparisonCheckStatus RecognizeComparisonCheck( } // Add this new method to recognize LINQ patterns with predicates + //private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck( + // IOperation operation, + // out SyntaxNode? collectionExpression, + // out SyntaxNode? predicateExpression, + // out IOperation? countOperation) + //{ + // collectionExpression = null; + // predicateExpression = null; + // countOperation = null; + + // // Check for enumerable.Any(predicate) + // if (operation is IInvocationOperation anyInvocation && + // anyInvocation.TargetMethod.Name == "Any" && + // anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + // { + // // For extension methods on IEnumerable, the syntax is: collection.Any(predicate) + // // In Roslyn's operation tree: Instance = collection, Arguments[0] = predicate + // if (anyInvocation.Instance != null && + // anyInvocation.Arguments.Length == 1 && + // IsPredicate(anyInvocation.Arguments[0].Value)) + // { + // collectionExpression = anyInvocation.Instance.Syntax; + // predicateExpression = anyInvocation.Arguments[0].Value.Syntax; + // return LinqPredicateCheckStatus.Any; + // } + // } + + // // Check for enumerable.Count(predicate) + // if (operation is IInvocationOperation countInvocation && + // countInvocation.TargetMethod.Name == "Count" && + // countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + // { + // // For extension methods: collection.Count(predicate) + // if (countInvocation.Instance != null && + // countInvocation.Arguments.Length == 1 && + // IsPredicate(countInvocation.Arguments[0].Value)) + // { + // collectionExpression = countInvocation.Instance.Syntax; + // predicateExpression = countInvocation.Arguments[0].Value.Syntax; + // countOperation = operation; + // return LinqPredicateCheckStatus.Count; + // } + // } + + // // Check for enumerable.Where(predicate).Any() + // if (operation is IInvocationOperation whereAnyInvocation && + // whereAnyInvocation.TargetMethod.Name == "Any" && + // whereAnyInvocation.Arguments.Length == 0 && + // whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + // whereAnyInvocation.Instance is IInvocationOperation whereInvocation && + // whereInvocation.TargetMethod.Name == "Where" && + // whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + // { + // // collection.Where(predicate).Any() + // if (whereInvocation.Instance != null && + // whereInvocation.Arguments.Length == 1 && + // IsPredicate(whereInvocation.Arguments[0].Value)) + // { + // collectionExpression = whereInvocation.Instance.Syntax; + // predicateExpression = whereInvocation.Arguments[0].Value.Syntax; + // return LinqPredicateCheckStatus.WhereAny; + // } + // } + + // // Check for enumerable.Where(predicate).Count() + // if (operation is IInvocationOperation whereCountInvocation && + // whereCountInvocation.TargetMethod.Name == "Count" && + // whereCountInvocation.Arguments.Length == 0 && + // whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + // whereCountInvocation.Instance is IInvocationOperation whereInvocation2 && + // whereInvocation2.TargetMethod.Name == "Where" && + // whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + // { + // // collection.Where(predicate).Count() + // if (whereInvocation2.Instance != null && + // whereInvocation2.Arguments.Length == 1 && + // IsPredicate(whereInvocation2.Arguments[0].Value)) + // { + // collectionExpression = whereInvocation2.Instance.Syntax; + // predicateExpression = whereInvocation2.Arguments[0].Value.Syntax; + // countOperation = operation; + // return LinqPredicateCheckStatus.WhereCount; + // } + // } + + // return LinqPredicateCheckStatus.Unknown; + //} + private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck( - IOperation operation, - out SyntaxNode? collectionExpression, - out SyntaxNode? predicateExpression, - out IOperation? countOperation) + IOperation operation, + out SyntaxNode? collectionExpression, + out SyntaxNode? predicateExpression, + out IOperation? countOperation) { collectionExpression = null; predicateExpression = null; countOperation = null; // Check for enumerable.Any(predicate) + // Extension methods appear as: Instance=null, Arguments[0]=collection, Arguments[1]=predicate if (operation is IInvocationOperation anyInvocation && anyInvocation.TargetMethod.Name == "Any" && - anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + anyInvocation.Arguments.Length == 2) { - // For extension methods on IEnumerable, the syntax is: collection.Any(predicate) - // In Roslyn's operation tree: Instance = collection, Arguments[0] = predicate - if (anyInvocation.Instance != null && - anyInvocation.Arguments.Length == 1 && - IsPredicate(anyInvocation.Arguments[0].Value)) - { - collectionExpression = anyInvocation.Instance.Syntax; - predicateExpression = anyInvocation.Arguments[0].Value.Syntax; - return LinqPredicateCheckStatus.Any; - } + collectionExpression = anyInvocation.Arguments[0].Value.Syntax; + predicateExpression = anyInvocation.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.Any; } // Check for enumerable.Count(predicate) if (operation is IInvocationOperation countInvocation && countInvocation.TargetMethod.Name == "Count" && - countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + countInvocation.Arguments.Length == 2) { - // For extension methods: collection.Count(predicate) - if (countInvocation.Instance != null && - countInvocation.Arguments.Length == 1 && - IsPredicate(countInvocation.Arguments[0].Value)) - { - collectionExpression = countInvocation.Instance.Syntax; - predicateExpression = countInvocation.Arguments[0].Value.Syntax; - countOperation = operation; - return LinqPredicateCheckStatus.Count; - } + collectionExpression = countInvocation.Arguments[0].Value.Syntax; + predicateExpression = countInvocation.Arguments[1].Value.Syntax; + countOperation = operation; + return LinqPredicateCheckStatus.Count; } // Check for enumerable.Where(predicate).Any() if (operation is IInvocationOperation whereAnyInvocation && whereAnyInvocation.TargetMethod.Name == "Any" && - whereAnyInvocation.Arguments.Length == 0 && whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereAnyInvocation.Instance is IInvocationOperation whereInvocation && + whereAnyInvocation.Arguments.Length == 1 && + whereAnyInvocation.Arguments[0].Value is IInvocationOperation whereInvocation && whereInvocation.TargetMethod.Name == "Where" && - whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation.Arguments.Length == 2) { - // collection.Where(predicate).Any() - if (whereInvocation.Instance != null && - whereInvocation.Arguments.Length == 1 && - IsPredicate(whereInvocation.Arguments[0].Value)) - { - collectionExpression = whereInvocation.Instance.Syntax; - predicateExpression = whereInvocation.Arguments[0].Value.Syntax; - return LinqPredicateCheckStatus.WhereAny; - } + collectionExpression = whereInvocation.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.WhereAny; } // Check for enumerable.Where(predicate).Count() if (operation is IInvocationOperation whereCountInvocation && whereCountInvocation.TargetMethod.Name == "Count" && - whereCountInvocation.Arguments.Length == 0 && whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereCountInvocation.Instance is IInvocationOperation whereInvocation2 && + whereCountInvocation.Arguments.Length == 1 && + whereCountInvocation.Arguments[0].Value is IInvocationOperation whereInvocation2 && whereInvocation2.TargetMethod.Name == "Where" && - whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation2.Arguments.Length == 2) { - // collection.Where(predicate).Count() - if (whereInvocation2.Instance != null && - whereInvocation2.Arguments.Length == 1 && - IsPredicate(whereInvocation2.Arguments[0].Value)) - { - collectionExpression = whereInvocation2.Instance.Syntax; - predicateExpression = whereInvocation2.Arguments[0].Value.Syntax; - countOperation = operation; - return LinqPredicateCheckStatus.WhereCount; - } + collectionExpression = whereInvocation2.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation2.Arguments[1].Value.Syntax; + countOperation = operation; + return LinqPredicateCheckStatus.WhereCount; } return LinqPredicateCheckStatus.Unknown; @@ -752,6 +820,55 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co return; } + // Special-case: enumerable.Count(predicate) > 0 → Assert.Contains(predicate, enumerable) + if (conditionArgument is IBinaryOperation binaryOp && + binaryOp.OperatorKind == BinaryOperatorKind.GreaterThan) + { + if (binaryOp.LeftOperand is IInvocationOperation countInvocation && + binaryOp.RightOperand.ConstantValue.HasValue && + binaryOp.RightOperand.ConstantValue.Value is int intValue && + intValue == 0 && + countInvocation.TargetMethod.Name == "Count") + { + SyntaxNode? countCollectionExpr = null; + SyntaxNode? countPredicateExpr = null; + + if (countInvocation.Instance != null && countInvocation.Arguments.Length == 1) + { + countCollectionExpr = countInvocation.Instance.Syntax; + countPredicateExpr = countInvocation.Arguments[0].Value.Syntax; + } + else if (countInvocation.Instance == null && countInvocation.Arguments.Length == 2) + { + countCollectionExpr = countInvocation.Arguments[0].Value.Syntax; + countPredicateExpr = countInvocation.Arguments[1].Value.Syntax; + } + + if (countCollectionExpr != null && countPredicateExpr != null) + { + string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain"; + + var properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, CodeFixModeAddArgument); + + context.ReportDiagnostic( + context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create( + conditionArgument.Syntax.GetLocation(), + countPredicateExpr.GetLocation(), + countCollectionExpr.GetLocation()), + properties: properties.ToImmutable(), + properAssertMethod, + isTrueInvocation ? "IsTrue" : "IsFalse")); + + return; + } + + } + } + // Check for comparison patterns: a > b, a >= b, a < b, a <= b ComparisonCheckStatus comparisonStatus = RecognizeComparisonCheck(conditionArgument, out SyntaxNode? leftExpr, out SyntaxNode? rightExpr); if (comparisonStatus != ComparisonCheckStatus.Unknown) diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index a065e766ca..1f07bcdb48 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -2913,199 +2913,207 @@ await VerifyCS.VerifyCodeFixAsync( public async Task WhenUsingAnyWithPredicate_SuggestsContains() { string code = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + [TestClass] + public class TestClass { - var enumerable = new List(); - [|Assert.IsTrue(enumerable.Any(x => x == 1))|]; + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsTrue(enumerable.Any(x => x == 1))|}; + } } - } - """; + + """; string fixedCode = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + [TestClass] + public class TestClass { - var enumerable = new List(); - Assert.Contains(x => x == 1, enumerable); + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } } - } - """; + + """; - await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); } [TestMethod] public async Task WhenUsingWhereAnyWithPredicate_SuggestsContains() { string code = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass { - var enumerable = new List(); - [|Assert.IsTrue(enumerable.Where(x => x == 1).Any())|]; + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsTrue(enumerable.Where(x => x == 1).Any())|}; + } } - } - """; + """; string fixedCode = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass { - var enumerable = new List(); - Assert.Contains(x => x == 1, enumerable); + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } } - } - """; + """; - await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); } [TestMethod] public async Task WhenUsingCountWithPredicateEqualsOne_SuggestsContainsSingle() { string code = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass { - var enumerable = new List(); - [|Assert.AreEqual(1, enumerable.Count(x => x == 1))|]; + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.AreEqual(1, enumerable.Count(x => x == 1))|}; + } } - } - """; + """; string fixedCode = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass { - var enumerable = new List(); - Assert.ContainsSingle(x => x == 1, enumerable); + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } } - } - """; + """; - await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "AreEqual"), + fixedCode); } [TestMethod] public async Task WhenUsingIsFalseWithAny_SuggestsDoesNotContain() { string code = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass { - var enumerable = new List(); - [|Assert.IsFalse(enumerable.Any(x => x == 1))|]; + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsFalse(enumerable.Any(x => x == 1))|}; + } } - } - """; + """; string fixedCode = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass { - var enumerable = new List(); - Assert.DoesNotContain(x => x == 1, enumerable); + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } } - } - """; + """; - await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); } [TestMethod] public async Task WhenUsingCountGreaterThanZero_SuggestsContains() { string code = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass { - var enumerable = new List(); - [|Assert.IsTrue(enumerable.Count(x => x == 1) > 0)|]; + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsTrue(enumerable.Count(x => x == 1) > 0)|}; + } } - } - """; + """; string fixedCode = """ - using System.Collections.Generic; - using System.Linq; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class TestClass - { - [TestMethod] - public void TestMethod() + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass { - var enumerable = new List(); - Assert.Contains(x => x == 1, enumerable); + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.Contains(x => x == 1, enumerable); + } } - } - """; - - await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + """; + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); } } From 2570d4883b10afdbe328391b9174c542292dd017 Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Fri, 10 Oct 2025 00:23:13 +0100 Subject: [PATCH 03/10] modified analyzer --- .../UseProperAssertMethodsAnalyzer.cs | 98 +------------------ 1 file changed, 1 insertion(+), 97 deletions(-) diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index 94a444a678..652e888553 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -528,95 +528,6 @@ private static ComparisonCheckStatus RecognizeComparisonCheck( return ComparisonCheckStatus.Unknown; } - // Add this new method to recognize LINQ patterns with predicates - //private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck( - // IOperation operation, - // out SyntaxNode? collectionExpression, - // out SyntaxNode? predicateExpression, - // out IOperation? countOperation) - //{ - // collectionExpression = null; - // predicateExpression = null; - // countOperation = null; - - // // Check for enumerable.Any(predicate) - // if (operation is IInvocationOperation anyInvocation && - // anyInvocation.TargetMethod.Name == "Any" && - // anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") - // { - // // For extension methods on IEnumerable, the syntax is: collection.Any(predicate) - // // In Roslyn's operation tree: Instance = collection, Arguments[0] = predicate - // if (anyInvocation.Instance != null && - // anyInvocation.Arguments.Length == 1 && - // IsPredicate(anyInvocation.Arguments[0].Value)) - // { - // collectionExpression = anyInvocation.Instance.Syntax; - // predicateExpression = anyInvocation.Arguments[0].Value.Syntax; - // return LinqPredicateCheckStatus.Any; - // } - // } - - // // Check for enumerable.Count(predicate) - // if (operation is IInvocationOperation countInvocation && - // countInvocation.TargetMethod.Name == "Count" && - // countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") - // { - // // For extension methods: collection.Count(predicate) - // if (countInvocation.Instance != null && - // countInvocation.Arguments.Length == 1 && - // IsPredicate(countInvocation.Arguments[0].Value)) - // { - // collectionExpression = countInvocation.Instance.Syntax; - // predicateExpression = countInvocation.Arguments[0].Value.Syntax; - // countOperation = operation; - // return LinqPredicateCheckStatus.Count; - // } - // } - - // // Check for enumerable.Where(predicate).Any() - // if (operation is IInvocationOperation whereAnyInvocation && - // whereAnyInvocation.TargetMethod.Name == "Any" && - // whereAnyInvocation.Arguments.Length == 0 && - // whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - // whereAnyInvocation.Instance is IInvocationOperation whereInvocation && - // whereInvocation.TargetMethod.Name == "Where" && - // whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") - // { - // // collection.Where(predicate).Any() - // if (whereInvocation.Instance != null && - // whereInvocation.Arguments.Length == 1 && - // IsPredicate(whereInvocation.Arguments[0].Value)) - // { - // collectionExpression = whereInvocation.Instance.Syntax; - // predicateExpression = whereInvocation.Arguments[0].Value.Syntax; - // return LinqPredicateCheckStatus.WhereAny; - // } - // } - - // // Check for enumerable.Where(predicate).Count() - // if (operation is IInvocationOperation whereCountInvocation && - // whereCountInvocation.TargetMethod.Name == "Count" && - // whereCountInvocation.Arguments.Length == 0 && - // whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - // whereCountInvocation.Instance is IInvocationOperation whereInvocation2 && - // whereInvocation2.TargetMethod.Name == "Where" && - // whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") - // { - // // collection.Where(predicate).Count() - // if (whereInvocation2.Instance != null && - // whereInvocation2.Arguments.Length == 1 && - // IsPredicate(whereInvocation2.Arguments[0].Value)) - // { - // collectionExpression = whereInvocation2.Instance.Syntax; - // predicateExpression = whereInvocation2.Arguments[0].Value.Syntax; - // countOperation = operation; - // return LinqPredicateCheckStatus.WhereCount; - // } - // } - - // return LinqPredicateCheckStatus.Unknown; - //} - private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck( IOperation operation, out SyntaxNode? collectionExpression, @@ -848,7 +759,7 @@ binaryOp.RightOperand.ConstantValue.Value is int intValue && { string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain"; - var properties = ImmutableDictionary.CreateBuilder(); + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); properties.Add(ProperAssertMethodNameKey, properAssertMethod); properties.Add(CodeFixModeKey, CodeFixModeAddArgument); @@ -865,7 +776,6 @@ binaryOp.RightOperand.ConstantValue.Value is int intValue && return; } - } } @@ -1234,10 +1144,4 @@ private static bool TryGetArgumentValueForParameterOrdinal(IInvocationOperation argumentValue = operation.Arguments.FirstOrDefault(arg => arg.Parameter?.Ordinal == ordinal)?.Value?.WalkDownConversion(); return argumentValue is not null; } - - private static bool IsPredicate(IOperation operation) - { - IOperation unwrapped = operation.WalkDownConversion(); - return unwrapped is IAnonymousFunctionOperation or IDelegateCreationOperation; - } } From fc968095f4e515969cb53882bcfa55dfece8e48e Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Sat, 11 Oct 2025 00:22:48 +0100 Subject: [PATCH 04/10] added more analyzer tests.. --- .../UseProperAssertMethodsAnalyzer.cs | 4 +- .../UseProperAssertMethodsAnalyzerTests.cs | 98 +++++++++++++++++-- 2 files changed, 93 insertions(+), 9 deletions(-) diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index 652e888553..968d9c3128 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -637,7 +637,7 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co conditionArgument, out SyntaxNode? linqCollectionExpr, out SyntaxNode? predicateExpr, - out IOperation? countOp); + out _); if (linqStatus != LinqPredicateCheckStatus.Unknown && linqCollectionExpr != null && predicateExpr != null) { @@ -882,7 +882,7 @@ private static void AnalyzeAreEqualOrAreNotEqualInvocation(OperationAnalysisCont actualArgumentValue!, out SyntaxNode? linqCollectionExpr2, out SyntaxNode? predicateExpr2, - out IOperation? countOp2); + out _); if (isAreEqualInvocation && linqStatus2 is LinqPredicateCheckStatus.Count or LinqPredicateCheckStatus.WhereCount && diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index 1f07bcdb48..39474bea1a 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -2909,8 +2909,9 @@ await VerifyCS.VerifyCodeFixAsync( #endregion + #region Predicate Pattern Tests [TestMethod] - public async Task WhenUsingAnyWithPredicate_SuggestsContains() + public async Task WhenUsingIsTrueAnyWithPredicate_SuggestsContains() { string code = """ using System.Collections.Generic; @@ -2955,7 +2956,7 @@ await VerifyCS.VerifyCodeFixAsync( } [TestMethod] - public async Task WhenUsingWhereAnyWithPredicate_SuggestsContains() + public async Task WhenUsingIsTrueWhereAnyWithPredicate_SuggestsContains() { string code = """ using System.Collections.Generic; @@ -2996,7 +2997,7 @@ await VerifyCS.VerifyCodeFixAsync( } [TestMethod] - public async Task WhenUsingCountWithPredicateEqualsOne_SuggestsContainsSingle() + public async Task WhenUsingIsFalseWhereAnyWithPredicate_SuggestsDoesNotContain() { string code = """ using System.Collections.Generic; @@ -3009,7 +3010,7 @@ public class TestClass public void TestMethod() { var enumerable = new List(); - {|#0:Assert.AreEqual(1, enumerable.Count(x => x == 1))|}; + {|#0:Assert.IsFalse(enumerable.Where(x => x == 1).Any())|}; } } """; @@ -3025,14 +3026,14 @@ public class TestClass public void TestMethod() { var enumerable = new List(); - Assert.ContainsSingle(x => x == 1, enumerable); + Assert.DoesNotContain(x => x == 1, enumerable); } } """; await VerifyCS.VerifyCodeFixAsync( code, - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "AreEqual"), + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), fixedCode); } @@ -3078,7 +3079,48 @@ await VerifyCS.VerifyCodeFixAsync( } [TestMethod] - public async Task WhenUsingCountGreaterThanZero_SuggestsContains() + public async Task WhenUsingIsFalseWithWhereAny_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsFalse(enumerable.Where(x => x == 1).Any())|}; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsTrueCountGreaterThanZero_SuggestsContains() { string code = """ using System.Collections.Generic; @@ -3116,4 +3158,46 @@ await VerifyCS.VerifyCodeFixAsync( VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), fixedCode); } + + [TestMethod] + public async Task WhenUsingIsFalseCountGreaterThanZero_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsFalse(enumerable.Count(x => x == 1) > 0)|}; + } + } + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + """; + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); + } + + #endregion } From d79d506ab86d309a68cc7c8be1eefb3391efcd9e Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Sun, 12 Oct 2025 16:57:08 +0100 Subject: [PATCH 05/10] modified output.cs in the integration tests; to reflect analyzer changes --- test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs b/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs index 02a68522af..785ae5d304 100644 --- a/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs +++ b/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs @@ -84,7 +84,8 @@ private static void ValidateOutputIsNotMixed(IEnumerable testResults Assert.Contains(methodName, message.Text); Assert.Contains("TestInitialize", message.Text); Assert.Contains("TestCleanup", message.Text); - Assert.IsFalse(shouldNotContain.Any(message.Text.Contains)); + // Assert.IsFalse(shouldNotContain.Any(message.Text.Contains)); + Assert.DoesNotContain(message.Text.Contains, shouldNotContain); } private static void ValidateInitializeAndCleanup(IEnumerable testResults, Func messageFilter) From ff88554415d81008d2e5d3707c6a00b09b99ac9a Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Wed, 15 Oct 2025 00:14:03 +0100 Subject: [PATCH 06/10] added analyzer support for predicate suggestions applied to Assert.IsNull and Asset.IsNotNull using Single, SingleOrDefaut, WhereSingle and WhereSingleOrDefualt --- .../UseProperAssertMethodsAnalyzer.cs | 126 ++++++++ .../UseProperAssertMethodsAnalyzerTests.cs | 270 ++++++++++++++++++ 2 files changed, 396 insertions(+) diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index 968d9c3128..78dae8a19f 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -132,6 +132,10 @@ private enum LinqPredicateCheckStatus Count, WhereAny, WhereCount, + Single, + SingleOrDefault, + WhereSingle, + WhereSingleOrDefault, } internal const string ProperAssertMethodNameKey = nameof(ProperAssertMethodNameKey); @@ -277,6 +281,56 @@ private static void AnalyzeInvocationOperation(OperationAnalysisContext context, case "AreNotEqual": AnalyzeAreEqualOrAreNotEqualInvocation(context, firstArgument, isAreEqualInvocation: false, objectTypeSymbol); break; + case "IsNull": + AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: true); + break; + + case "IsNotNull": + AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: false); + break; + } + } + + private static void AnalyzeIsNullOrIsNotNullInvocation(OperationAnalysisContext context, IOperation argument, bool isNullCheck) + { + RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation."); + + // Check for Single/SingleOrDefault patterns + LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck( + argument, + out SyntaxNode? linqCollectionExpr, + out SyntaxNode? predicateExpr, + out _); + + if (linqStatus is LinqPredicateCheckStatus.Single or + LinqPredicateCheckStatus.SingleOrDefault or + LinqPredicateCheckStatus.WhereSingle or + LinqPredicateCheckStatus.WhereSingleOrDefault && + linqCollectionExpr != null) + { + // For Assert.IsNotNull(enumerable.Single[OrDefault](...)) -> Assert.ContainsSingle + // For Assert.IsNull(enumerable.Single[OrDefault](...)) -> Assert.DoesNotContain + string properAssertMethod = isNullCheck ? "DoesNotContain" : "ContainsSingle"; + + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, predicateExpr != null ? CodeFixModeAddArgument : CodeFixModeSimple); + + ImmutableArray additionalLocations = predicateExpr != null + ? ImmutableArray.Create( + argument.Syntax.GetLocation(), + predicateExpr.GetLocation(), + linqCollectionExpr.GetLocation()) + : ImmutableArray.Create( + argument.Syntax.GetLocation(), + linqCollectionExpr.GetLocation()); + + context.ReportDiagnostic(context.Operation.CreateDiagnostic( + Rule, + additionalLocations: additionalLocations, + properties: properties.ToImmutable(), + properAssertMethod, + isNullCheck ? "IsNull" : "IsNotNull")); } } @@ -593,6 +647,78 @@ whereCountInvocation.Arguments[0].Value is IInvocationOperation whereInvocation2 return LinqPredicateCheckStatus.WhereCount; } + // Check for enumerable.Where(predicate).Single() + if (operation is IInvocationOperation whereSingleInvocation && + whereSingleInvocation.TargetMethod.Name == "Single" && + whereSingleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereSingleInvocation.Arguments.Length == 1 && + whereSingleInvocation.Arguments[0].Value is IInvocationOperation whereInvocation3 && + whereInvocation3.TargetMethod.Name == "Where" && + whereInvocation3.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation3.Arguments.Length == 2) + { + collectionExpression = whereInvocation3.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation3.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.WhereSingle; + } + + // Check for enumerable.Where(predicate).SingleOrDefault() + if (operation is IInvocationOperation whereSingleOrDefaultInvocation && + whereSingleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" && + whereSingleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereSingleOrDefaultInvocation.Arguments.Length == 1 && + whereSingleOrDefaultInvocation.Arguments[0].Value is IInvocationOperation whereInvocation4 && + whereInvocation4.TargetMethod.Name == "Where" && + whereInvocation4.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation4.Arguments.Length == 2) + { + collectionExpression = whereInvocation4.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation4.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.WhereSingleOrDefault; + } + + // Check for enumerable.Single(predicate) + if (operation is IInvocationOperation singleInvocation && + singleInvocation.TargetMethod.Name == "Single" && + singleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + if (singleInvocation.Arguments.Length == 2) + { + // Extension method with predicate + collectionExpression = singleInvocation.Arguments[0].Value.Syntax; + predicateExpression = singleInvocation.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.Single; + } + else if (singleInvocation.Arguments.Length == 1) + { + // Instance method or extension without predicate + collectionExpression = singleInvocation.Instance?.Syntax ?? singleInvocation.Arguments[0].Value.Syntax; + predicateExpression = null; + return LinqPredicateCheckStatus.Single; + } + } + + // Check for enumerable.SingleOrDefault(predicate) + if (operation is IInvocationOperation singleOrDefaultInvocation && + singleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" && + singleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + if (singleOrDefaultInvocation.Arguments.Length == 2) + { + // Extension method with predicate + collectionExpression = singleOrDefaultInvocation.Arguments[0].Value.Syntax; + predicateExpression = singleOrDefaultInvocation.Arguments[1].Value.Syntax; + return LinqPredicateCheckStatus.SingleOrDefault; + } + else if (singleOrDefaultInvocation.Arguments.Length == 1) + { + // Instance method or extension without predicate + collectionExpression = singleOrDefaultInvocation.Instance?.Syntax ?? singleOrDefaultInvocation.Arguments[0].Value.Syntax; + predicateExpression = null; + return LinqPredicateCheckStatus.SingleOrDefault; + } + } + return LinqPredicateCheckStatus.Unknown; } diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index 39474bea1a..1d2c7d2e2e 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -3199,5 +3199,275 @@ await VerifyCS.VerifyCodeFixAsync( fixedCode); } + [TestMethod] + public async Task WhenUsingIsNotNullSingleOrDefaultWithPredicate_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNotNull(enumerable.SingleOrDefault(x => x == 1))|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "IsNotNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNotNullSingleWithPredicate_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNotNull(enumerable.Single(x => x == 1))|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "IsNotNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNotNullWhereSingleOrDefault_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNotNull(enumerable.Where(x => x == 1).SingleOrDefault())|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "IsNotNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNotNullWhereSingle_SuggestsContainsSingle() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNotNull(enumerable.Where(x => x == 1).Single())|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.ContainsSingle(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("ContainsSingle", "IsNotNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNullSingleOrDefaultWithPredicate_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNull(enumerable.SingleOrDefault(x => x == 1))|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsNull"), + fixedCode); + } + + [TestMethod] + public async Task WhenUsingIsNullWhereSingleOrDefault_SuggestsDoesNotContain() + { + string code = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + {|#0:Assert.IsNull(enumerable.Where(x => x == 1).SingleOrDefault())|}; + } + } + + """; + + string fixedCode = """ + using System.Collections.Generic; + using System.Linq; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TestClass + { + [TestMethod] + public void TestMethod() + { + var enumerable = new List(); + Assert.DoesNotContain(x => x == 1, enumerable); + } + } + + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsNull"), + fixedCode); + } + #endregion } From ff8b524b05a1579f1d896a5379940ec9dfee4982 Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Thu, 4 Dec 2025 00:33:50 +0000 Subject: [PATCH 07/10] reconciled merge conflicts locally --- .../UseProperAssertMethodsAnalyzerTests.cs | 347 ++++++++++-------- 1 file changed, 188 insertions(+), 159 deletions(-) diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index 0062d4e2b7..3bca9feb8b 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -3000,6 +3000,7 @@ await VerifyCS.VerifyCodeFixAsync( public async Task WhenUsingIsFalseWhereAnyWithPredicate_SuggestsDoesNotContain() { string code = """ + using System.Collections.Generic; using System.Linq; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -3011,29 +3012,6 @@ public void TestMethod() { var enumerable = new List(); {|#0:Assert.IsFalse(enumerable.Where(x => x == 1).Any())|}; - #region BCL Types with IComparable Tests - - [TestMethod] - public async Task WhenAssertIsTrueWithTimeSpanComparison() - { - string code = """ - using System; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class MyTestClass - { - [TestMethod] - public void MyTestMethod() - { - var ts1 = TimeSpan.Zero; - var ts2 = TimeSpan.FromSeconds(1); - {|#0:Assert.IsTrue(ts2 > ts1)|}; - {|#1:Assert.IsTrue(ts2 >= ts1)|}; - {|#2:Assert.IsTrue(ts1 < ts2)|}; - {|#3:Assert.IsTrue(ts1 <= ts2)|}; - {|#4:Assert.IsTrue(ts1 == ts1)|}; - {|#5:Assert.IsTrue(ts1 != ts2)|}; } } """; @@ -3091,45 +3069,14 @@ public void TestMethod() { var enumerable = new List(); Assert.DoesNotContain(x => x == 1, enumerable); - using System; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class MyTestClass - { - [TestMethod] - public void MyTestMethod() - { - var ts1 = TimeSpan.Zero; - var ts2 = TimeSpan.FromSeconds(1); - Assert.IsGreaterThan(ts1, ts2); - Assert.IsGreaterThanOrEqualTo(ts1, ts2); - Assert.IsLessThan(ts2, ts1); - Assert.IsLessThanOrEqualTo(ts2, ts1); - Assert.AreEqual(ts1, ts1); - Assert.AreNotEqual(ts2, ts1); } } """; await VerifyCS.VerifyCodeFixAsync( - code, - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), - [ - // /0/Test0.cs(11,9): info MSTEST0037: Use 'Assert.IsGreaterThan' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsGreaterThan", "IsTrue"), - // /0/Test0.cs(12,9): info MSTEST0037: Use 'Assert.IsGreaterThanOrEqualTo' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(1).WithArguments("IsGreaterThanOrEqualTo", "IsTrue"), - // /0/Test0.cs(13,9): info MSTEST0037: Use 'Assert.IsLessThan' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(2).WithArguments("IsLessThan", "IsTrue"), - // /0/Test0.cs(14,9): info MSTEST0037: Use 'Assert.IsLessThanOrEqualTo' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(3).WithArguments("IsLessThanOrEqualTo", "IsTrue"), - // /0/Test0.cs(15,9): info MSTEST0037: Use 'Assert.AreEqual' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(4).WithArguments("AreEqual", "IsTrue"), - // /0/Test0.cs(16,9): info MSTEST0037: Use 'Assert.AreNotEqual' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(5).WithArguments("AreNotEqual", "IsTrue"), - ], - fixedCode); + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); } [TestMethod] @@ -3188,26 +3135,6 @@ public void TestMethod() { var enumerable = new List(); {|#0:Assert.IsTrue(enumerable.Count(x => x == 1) > 0)|}; - public async Task WhenAssertIsTrueWithDateTimeComparison() - { - string code = """ - using System; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class MyTestClass - { - [TestMethod] - public void MyTestMethod() - { - var dt1 = DateTime.Today; - var dt2 = DateTime.Today.AddDays(1); - {|#0:Assert.IsTrue(dt2 > dt1)|}; - {|#1:Assert.IsTrue(dt2 >= dt1)|}; - {|#2:Assert.IsTrue(dt1 < dt2)|}; - {|#3:Assert.IsTrue(dt1 <= dt2)|}; - {|#4:Assert.IsTrue(dt1 == dt1)|}; - {|#5:Assert.IsTrue(dt1 != dt2)|}; } } """; @@ -3270,43 +3197,6 @@ public void TestMethod() await VerifyCS.VerifyCodeFixAsync( code, VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), - using System; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class MyTestClass - { - [TestMethod] - public void MyTestMethod() - { - var dt1 = DateTime.Today; - var dt2 = DateTime.Today.AddDays(1); - Assert.IsGreaterThan(dt1, dt2); - Assert.IsGreaterThanOrEqualTo(dt1, dt2); - Assert.IsLessThan(dt2, dt1); - Assert.IsLessThanOrEqualTo(dt2, dt1); - Assert.AreEqual(dt1, dt1); - Assert.AreNotEqual(dt2, dt1); - } - } - """; - - await VerifyCS.VerifyCodeFixAsync( - code, - [ - // /0/Test0.cs(11,9): info MSTEST0037: Use 'Assert.IsGreaterThan' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsGreaterThan", "IsTrue"), - // /0/Test0.cs(12,9): info MSTEST0037: Use 'Assert.IsGreaterThanOrEqualTo' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(1).WithArguments("IsGreaterThanOrEqualTo", "IsTrue"), - // /0/Test0.cs(13,9): info MSTEST0037: Use 'Assert.IsLessThan' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(2).WithArguments("IsLessThan", "IsTrue"), - // /0/Test0.cs(14,9): info MSTEST0037: Use 'Assert.IsLessThanOrEqualTo' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(3).WithArguments("IsLessThanOrEqualTo", "IsTrue"), - // /0/Test0.cs(15,9): info MSTEST0037: Use 'Assert.AreEqual' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(4).WithArguments("AreEqual", "IsTrue"), - // /0/Test0.cs(16,9): info MSTEST0037: Use 'Assert.AreNotEqual' instead of 'Assert.IsTrue' - VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(5).WithArguments("AreNotEqual", "IsTrue"), - ], fixedCode); } @@ -3346,51 +3236,6 @@ public void TestMethod() Assert.ContainsSingle(x => x == 1, enumerable); } } - - public async Task WhenAssertIsFalseWithTimeSpanComparison() - { - string code = """ - using System; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class MyTestClass - { - [TestMethod] - public void MyTestMethod() - { - var ts1 = TimeSpan.Zero; - var ts2 = TimeSpan.FromSeconds(1); - {|#0:Assert.IsFalse(ts2 > ts1)|}; - {|#1:Assert.IsFalse(ts2 >= ts1)|}; - {|#2:Assert.IsFalse(ts1 < ts2)|}; - {|#3:Assert.IsFalse(ts1 <= ts2)|}; - {|#4:Assert.IsFalse(ts1 == ts1)|}; - {|#5:Assert.IsFalse(ts1 != ts2)|}; - } - } - """; - - string fixedCode = """ - using System; - using Microsoft.VisualStudio.TestTools.UnitTesting; - - [TestClass] - public class MyTestClass - { - [TestMethod] - public void MyTestMethod() - { - var ts1 = TimeSpan.Zero; - var ts2 = TimeSpan.FromSeconds(1); - Assert.IsLessThanOrEqualTo(ts1, ts2); - Assert.IsLessThan(ts1, ts2); - Assert.IsGreaterThanOrEqualTo(ts2, ts1); - Assert.IsGreaterThan(ts2, ts1); - Assert.AreNotEqual(ts1, ts1); - Assert.AreEqual(ts2, ts1); - } - } """; await VerifyCS.VerifyCodeFixAsync( @@ -3625,6 +3470,190 @@ await VerifyCS.VerifyCodeFixAsync( } #endregion + + #region BCL Types with IComparable Tests + + [TestMethod] + public async Task WhenAssertIsTrueWithTimeSpanComparison() + { + string code = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var ts1 = TimeSpan.Zero; + var ts2 = TimeSpan.FromSeconds(1); + {|#0:Assert.IsTrue(ts2 > ts1)|}; + {|#1:Assert.IsTrue(ts2 >= ts1)|}; + {|#2:Assert.IsTrue(ts1 < ts2)|}; + {|#3:Assert.IsTrue(ts1 <= ts2)|}; + {|#4:Assert.IsTrue(ts1 == ts1)|}; + {|#5:Assert.IsTrue(ts1 != ts2)|}; + } + } + """; + + string fixedCode = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var ts1 = TimeSpan.Zero; + var ts2 = TimeSpan.FromSeconds(1); + Assert.IsGreaterThan(ts1, ts2); + Assert.IsGreaterThanOrEqualTo(ts1, ts2); + Assert.IsLessThan(ts2, ts1); + Assert.IsLessThanOrEqualTo(ts2, ts1); + Assert.AreEqual(ts1, ts1); + Assert.AreNotEqual(ts2, ts1); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + [ + // /0/Test0.cs(11,9): info MSTEST0037: Use 'Assert.IsGreaterThan' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsGreaterThan", "IsTrue"), + // /0/Test0.cs(12,9): info MSTEST0037: Use 'Assert.IsGreaterThanOrEqualTo' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(1).WithArguments("IsGreaterThanOrEqualTo", "IsTrue"), + // /0/Test0.cs(13,9): info MSTEST0037: Use 'Assert.IsLessThan' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(2).WithArguments("IsLessThan", "IsTrue"), + // /0/Test0.cs(14,9): info MSTEST0037: Use 'Assert.IsLessThanOrEqualTo' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(3).WithArguments("IsLessThanOrEqualTo", "IsTrue"), + // /0/Test0.cs(15,9): info MSTEST0037: Use 'Assert.AreEqual' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(4).WithArguments("AreEqual", "IsTrue"), + // /0/Test0.cs(16,9): info MSTEST0037: Use 'Assert.AreNotEqual' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(5).WithArguments("AreNotEqual", "IsTrue"), + ], + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithDateTimeComparison() + { + string code = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var dt1 = DateTime.Today; + var dt2 = DateTime.Today.AddDays(1); + {|#0:Assert.IsTrue(dt2 > dt1)|}; + {|#1:Assert.IsTrue(dt2 >= dt1)|}; + {|#2:Assert.IsTrue(dt1 < dt2)|}; + {|#3:Assert.IsTrue(dt1 <= dt2)|}; + {|#4:Assert.IsTrue(dt1 == dt1)|}; + {|#5:Assert.IsTrue(dt1 != dt2)|}; + } + } + """; + + string fixedCode = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var dt1 = DateTime.Today; + var dt2 = DateTime.Today.AddDays(1); + Assert.IsGreaterThan(dt1, dt2); + Assert.IsGreaterThanOrEqualTo(dt1, dt2); + Assert.IsLessThan(dt2, dt1); + Assert.IsLessThanOrEqualTo(dt2, dt1); + Assert.AreEqual(dt1, dt1); + Assert.AreNotEqual(dt2, dt1); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + [ + // /0/Test0.cs(11,9): info MSTEST0037: Use 'Assert.IsGreaterThan' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsGreaterThan", "IsTrue"), + // /0/Test0.cs(12,9): info MSTEST0037: Use 'Assert.IsGreaterThanOrEqualTo' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(1).WithArguments("IsGreaterThanOrEqualTo", "IsTrue"), + // /0/Test0.cs(13,9): info MSTEST0037: Use 'Assert.IsLessThan' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(2).WithArguments("IsLessThan", "IsTrue"), + // /0/Test0.cs(14,9): info MSTEST0037: Use 'Assert.IsLessThanOrEqualTo' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(3).WithArguments("IsLessThanOrEqualTo", "IsTrue"), + // /0/Test0.cs(15,9): info MSTEST0037: Use 'Assert.AreEqual' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(4).WithArguments("AreEqual", "IsTrue"), + // /0/Test0.cs(16,9): info MSTEST0037: Use 'Assert.AreNotEqual' instead of 'Assert.IsTrue' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(5).WithArguments("AreNotEqual", "IsTrue"), + ], + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsFalseWithTimeSpanComparison() + { + string code = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var ts1 = TimeSpan.Zero; + var ts2 = TimeSpan.FromSeconds(1); + {|#0:Assert.IsFalse(ts2 > ts1)|}; + {|#1:Assert.IsFalse(ts2 >= ts1)|}; + {|#2:Assert.IsFalse(ts1 < ts2)|}; + {|#3:Assert.IsFalse(ts1 <= ts2)|}; + {|#4:Assert.IsFalse(ts1 == ts1)|}; + {|#5:Assert.IsFalse(ts1 != ts2)|}; + } + } + """; + + string fixedCode = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var ts1 = TimeSpan.Zero; + var ts2 = TimeSpan.FromSeconds(1); + Assert.IsLessThanOrEqualTo(ts1, ts2); + Assert.IsLessThan(ts1, ts2); + Assert.IsGreaterThanOrEqualTo(ts2, ts1); + Assert.IsGreaterThan(ts2, ts1); + Assert.AreNotEqual(ts1, ts1); + Assert.AreEqual(ts2, ts1); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, [ // /0/Test0.cs(11,9): info MSTEST0037: Use 'Assert.IsLessThanOrEqualTo' instead of 'Assert.IsFalse' VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsLessThanOrEqualTo", "IsFalse"), From e751da911a867d4c2e61d621acda6ab8cfc88fb2 Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Thu, 4 Dec 2025 08:29:50 +0000 Subject: [PATCH 08/10] fixed trailing space causing error --- .../UseProperAssertMethodsAnalyzerTests.cs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index 3bca9feb8b..e326173f71 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -3000,7 +3000,6 @@ await VerifyCS.VerifyCodeFixAsync( public async Task WhenUsingIsFalseWhereAnyWithPredicate_SuggestsDoesNotContain() { string code = """ - using System.Collections.Generic; using System.Linq; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -3218,7 +3217,6 @@ public void TestMethod() {|#0:Assert.IsNotNull(enumerable.SingleOrDefault(x => x == 1))|}; } } - """; string fixedCode = """ @@ -3262,7 +3260,6 @@ public void TestMethod() {|#0:Assert.IsNotNull(enumerable.Single(x => x == 1))|}; } } - """; string fixedCode = """ @@ -3280,7 +3277,6 @@ public void TestMethod() Assert.ContainsSingle(x => x == 1, enumerable); } } - """; await VerifyCS.VerifyCodeFixAsync( @@ -3352,7 +3348,6 @@ public void TestMethod() {|#0:Assert.IsNotNull(enumerable.Where(x => x == 1).Single())|}; } } - """; string fixedCode = """ @@ -3370,7 +3365,6 @@ public void TestMethod() Assert.ContainsSingle(x => x == 1, enumerable); } } - """; await VerifyCS.VerifyCodeFixAsync( @@ -3397,7 +3391,6 @@ public void TestMethod() {|#0:Assert.IsNull(enumerable.SingleOrDefault(x => x == 1))|}; } } - """; string fixedCode = """ @@ -3415,7 +3408,6 @@ public void TestMethod() Assert.DoesNotContain(x => x == 1, enumerable); } } - """; await VerifyCS.VerifyCodeFixAsync( @@ -3442,7 +3434,6 @@ public void TestMethod() {|#0:Assert.IsNull(enumerable.Where(x => x == 1).SingleOrDefault())|}; } } - """; string fixedCode = """ @@ -3460,7 +3451,6 @@ public void TestMethod() Assert.DoesNotContain(x => x == 1, enumerable); } } - """; await VerifyCS.VerifyCodeFixAsync( From 0000bb8849acaba0730328b872622e97049c8f5f Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Fri, 5 Dec 2025 03:18:15 +0000 Subject: [PATCH 09/10] refactor RecognizeLinqPredicateCheck and also output.cs --- .../UseProperAssertMethodsAnalyzer.cs | 231 +++++++++--------- .../MSTest.IntegrationTests/OutputTests.cs | 5 +- 2 files changed, 121 insertions(+), 115 deletions(-) diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index 9f515e95d4..60f81d26f8 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -302,10 +302,10 @@ private static void AnalyzeIsNullOrIsNotNullInvocation(OperationAnalysisContext out SyntaxNode? predicateExpr, out _); - if (linqStatus is LinqPredicateCheckStatus.Single or + if ((linqStatus is LinqPredicateCheckStatus.Single or LinqPredicateCheckStatus.SingleOrDefault or LinqPredicateCheckStatus.WhereSingle or - LinqPredicateCheckStatus.WhereSingleOrDefault && + LinqPredicateCheckStatus.WhereSingleOrDefault) && linqCollectionExpr != null) { // For Assert.IsNotNull(enumerable.Single[OrDefault](...)) -> Assert.ContainsSingle @@ -593,142 +593,90 @@ private static ComparisonCheckStatus RecognizeComparisonCheck( } private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck( - IOperation operation, - out SyntaxNode? collectionExpression, - out SyntaxNode? predicateExpression, - out IOperation? countOperation) + IOperation operation, + out SyntaxNode? collectionExpression, + out SyntaxNode? predicateExpression, + out IOperation? countOperation) { collectionExpression = null; predicateExpression = null; countOperation = null; - // Check for enumerable.Any(predicate) - // Extension methods appear as: Instance=null, Arguments[0]=collection, Arguments[1]=predicate - if (operation is IInvocationOperation anyInvocation && - anyInvocation.TargetMethod.Name == "Any" && - anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - anyInvocation.Arguments.Length == 2) - { - collectionExpression = anyInvocation.Arguments[0].Value.Syntax; - predicateExpression = anyInvocation.Arguments[1].Value.Syntax; - return LinqPredicateCheckStatus.Any; - } - - // Check for enumerable.Count(predicate) - if (operation is IInvocationOperation countInvocation && - countInvocation.TargetMethod.Name == "Count" && - countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - countInvocation.Arguments.Length == 2) - { - collectionExpression = countInvocation.Arguments[0].Value.Syntax; - predicateExpression = countInvocation.Arguments[1].Value.Syntax; - countOperation = operation; - return LinqPredicateCheckStatus.Count; - } - - // Check for enumerable.Where(predicate).Any() - if (operation is IInvocationOperation whereAnyInvocation && - whereAnyInvocation.TargetMethod.Name == "Any" && - whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereAnyInvocation.Arguments.Length == 1 && - whereAnyInvocation.Arguments[0].Value is IInvocationOperation whereInvocation && - whereInvocation.TargetMethod.Name == "Where" && - whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereInvocation.Arguments.Length == 2) - { - collectionExpression = whereInvocation.Arguments[0].Value.Syntax; - predicateExpression = whereInvocation.Arguments[1].Value.Syntax; - return LinqPredicateCheckStatus.WhereAny; - } - - // Check for enumerable.Where(predicate).Count() - if (operation is IInvocationOperation whereCountInvocation && - whereCountInvocation.TargetMethod.Name == "Count" && - whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereCountInvocation.Arguments.Length == 1 && - whereCountInvocation.Arguments[0].Value is IInvocationOperation whereInvocation2 && - whereInvocation2.TargetMethod.Name == "Where" && - whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereInvocation2.Arguments.Length == 2) + if (operation is not IInvocationOperation invocation) { - collectionExpression = whereInvocation2.Arguments[0].Value.Syntax; - predicateExpression = whereInvocation2.Arguments[1].Value.Syntax; - countOperation = operation; - return LinqPredicateCheckStatus.WhereCount; + return LinqPredicateCheckStatus.Unknown; } - // Check for enumerable.Where(predicate).Single() - if (operation is IInvocationOperation whereSingleInvocation && - whereSingleInvocation.TargetMethod.Name == "Single" && - whereSingleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereSingleInvocation.Arguments.Length == 1 && - whereSingleInvocation.Arguments[0].Value is IInvocationOperation whereInvocation3 && - whereInvocation3.TargetMethod.Name == "Where" && - whereInvocation3.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereInvocation3.Arguments.Length == 2) - { - collectionExpression = whereInvocation3.Arguments[0].Value.Syntax; - predicateExpression = whereInvocation3.Arguments[1].Value.Syntax; - return LinqPredicateCheckStatus.WhereSingle; - } + string methodName = invocation.TargetMethod.Name; + string? containingType = invocation.TargetMethod.ContainingType?.ToDisplayString(); - // Check for enumerable.Where(predicate).SingleOrDefault() - if (operation is IInvocationOperation whereSingleOrDefaultInvocation && - whereSingleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" && - whereSingleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereSingleOrDefaultInvocation.Arguments.Length == 1 && - whereSingleOrDefaultInvocation.Arguments[0].Value is IInvocationOperation whereInvocation4 && - whereInvocation4.TargetMethod.Name == "Where" && - whereInvocation4.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && - whereInvocation4.Arguments.Length == 2) + if (containingType != "System.Linq.Enumerable") { - collectionExpression = whereInvocation4.Arguments[0].Value.Syntax; - predicateExpression = whereInvocation4.Arguments[1].Value.Syntax; - return LinqPredicateCheckStatus.WhereSingleOrDefault; + return LinqPredicateCheckStatus.Unknown; } - // Check for enumerable.Single(predicate) - if (operation is IInvocationOperation singleInvocation && - singleInvocation.TargetMethod.Name == "Single" && - singleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + // Check for Where().Method() patterns + if (invocation.Arguments.Length == 1) { - if (singleInvocation.Arguments.Length == 2) + if (TryMatchWherePattern(invocation, "Any", out collectionExpression, out predicateExpression)) { - // Extension method with predicate - collectionExpression = singleInvocation.Arguments[0].Value.Syntax; - predicateExpression = singleInvocation.Arguments[1].Value.Syntax; - return LinqPredicateCheckStatus.Single; + return LinqPredicateCheckStatus.WhereAny; } - else if (singleInvocation.Arguments.Length == 1) + + if (TryMatchWherePattern(invocation, "Count", out collectionExpression, out predicateExpression)) { - // Instance method or extension without predicate - collectionExpression = singleInvocation.Instance?.Syntax ?? singleInvocation.Arguments[0].Value.Syntax; - predicateExpression = null; - return LinqPredicateCheckStatus.Single; + countOperation = operation; + return LinqPredicateCheckStatus.WhereCount; } - } - // Check for enumerable.SingleOrDefault(predicate) - if (operation is IInvocationOperation singleOrDefaultInvocation && - singleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" && - singleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") - { - if (singleOrDefaultInvocation.Arguments.Length == 2) + if (TryMatchWherePattern(invocation, "Single", out collectionExpression, out predicateExpression)) { - // Extension method with predicate - collectionExpression = singleOrDefaultInvocation.Arguments[0].Value.Syntax; - predicateExpression = singleOrDefaultInvocation.Arguments[1].Value.Syntax; - return LinqPredicateCheckStatus.SingleOrDefault; + return LinqPredicateCheckStatus.WhereSingle; } - else if (singleOrDefaultInvocation.Arguments.Length == 1) + + if (TryMatchWherePattern(invocation, "SingleOrDefault", out collectionExpression, out predicateExpression)) { - // Instance method or extension without predicate - collectionExpression = singleOrDefaultInvocation.Instance?.Syntax ?? singleOrDefaultInvocation.Arguments[0].Value.Syntax; - predicateExpression = null; - return LinqPredicateCheckStatus.SingleOrDefault; + return LinqPredicateCheckStatus.WhereSingleOrDefault; } } + // Check for direct Method(predicate) patterns + switch (methodName) + { + case "Any": + if (TryMatchLinqMethod(invocation, "Any", out collectionExpression, out predicateExpression)) + { + return LinqPredicateCheckStatus.Any; + } + + break; + + case "Count": + if (TryMatchLinqMethod(invocation, "Count", out collectionExpression, out predicateExpression)) + { + countOperation = operation; + return LinqPredicateCheckStatus.Count; + } + + break; + + case "Single": + if (TryMatchLinqMethod(invocation, "Single", out collectionExpression, out predicateExpression)) + { + return LinqPredicateCheckStatus.Single; + } + + break; + + case "SingleOrDefault": + if (TryMatchLinqMethod(invocation, "SingleOrDefault", out collectionExpression, out predicateExpression)) + { + return LinqPredicateCheckStatus.SingleOrDefault; + } + + break; + } + return LinqPredicateCheckStatus.Unknown; } @@ -1156,6 +1104,61 @@ actualArgumentValue.Type is { } actualType && } } + private static bool TryMatchWherePattern( + IInvocationOperation invocation, + string methodName, + out SyntaxNode? collectionExpression, + out SyntaxNode? predicateExpression) + { + if (invocation.TargetMethod.Name == methodName && + invocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + invocation.Arguments.Length == 1 && + invocation.Arguments[0].Value is IInvocationOperation whereInvocation && + whereInvocation.TargetMethod.Name == "Where" && + whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" && + whereInvocation.Arguments.Length == 2) + { + collectionExpression = whereInvocation.Arguments[0].Value.Syntax; + predicateExpression = whereInvocation.Arguments[1].Value.Syntax; + return true; + } + + collectionExpression = null; + predicateExpression = null; + return false; + } + + private static bool TryMatchLinqMethod( + IInvocationOperation invocation, + string methodName, + out SyntaxNode? collectionExpression, + out SyntaxNode? predicateExpression) + { + if (invocation.TargetMethod.Name == methodName && + invocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable") + { + // Extension method with predicate: Method(collection, predicate) + if (invocation.Arguments.Length == 2) + { + collectionExpression = invocation.Arguments[0].Value.Syntax; + predicateExpression = invocation.Arguments[1].Value.Syntax; + return true; + } + + // Instance method or extension without predicate: Method(collection) + else if (invocation.Arguments.Length == 1) + { + collectionExpression = invocation.Instance?.Syntax ?? invocation.Arguments[0].Value.Syntax; + predicateExpression = null; + return true; + } + } + + collectionExpression = null; + predicateExpression = null; + return false; + } + private static CountCheckStatus RecognizeCountCheck( IOperation operation, INamedTypeSymbol objectTypeSymbol, diff --git a/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs b/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs index 785ae5d304..b64a9746d6 100644 --- a/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs +++ b/test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs @@ -85,7 +85,10 @@ private static void ValidateOutputIsNotMixed(IEnumerable testResults Assert.Contains("TestInitialize", message.Text); Assert.Contains("TestCleanup", message.Text); // Assert.IsFalse(shouldNotContain.Any(message.Text.Contains)); - Assert.DoesNotContain(message.Text.Contains, shouldNotContain); + foreach (string item in shouldNotContain) + { + Assert.DoesNotContain(item, message.Text); + } } private static void ValidateInitializeAndCleanup(IEnumerable testResults, Func messageFilter) From d79b8366d96147460a8df3471ca63b03e244291f Mon Sep 17 00:00:00 2001 From: Muiz Atolagbe Date: Fri, 5 Dec 2025 04:01:35 +0000 Subject: [PATCH 10/10] implemented this to reuse RecognizeLinqPredicateCheck for the validation --- .../UseProperAssertMethodsAnalyzer.cs | 67 ++++++++----------- 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index 60f81d26f8..d3c60cc75f 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -817,49 +817,40 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co // Special-case: enumerable.Count(predicate) > 0 → Assert.Contains(predicate, enumerable) if (conditionArgument is IBinaryOperation binaryOp && - binaryOp.OperatorKind == BinaryOperatorKind.GreaterThan) + binaryOp.OperatorKind == BinaryOperatorKind.GreaterThan && + binaryOp.RightOperand.ConstantValue.HasValue && + binaryOp.RightOperand.ConstantValue.Value is int intValue && + intValue == 0) { - if (binaryOp.LeftOperand is IInvocationOperation countInvocation && - binaryOp.RightOperand.ConstantValue.HasValue && - binaryOp.RightOperand.ConstantValue.Value is int intValue && - intValue == 0 && - countInvocation.TargetMethod.Name == "Count") + // Use RecognizeLinqPredicateCheck to properly validate LINQ Count method + LinqPredicateCheckStatus countLinqStatus = RecognizeLinqPredicateCheck( + binaryOp.LeftOperand, + out SyntaxNode? countCollectionExpr, + out SyntaxNode? countPredicateExpr, + out _); + + if ((countLinqStatus is LinqPredicateCheckStatus.Count or LinqPredicateCheckStatus.WhereCount) && + countCollectionExpr != null && + countPredicateExpr != null) { - SyntaxNode? countCollectionExpr = null; - SyntaxNode? countPredicateExpr = null; - - if (countInvocation.Instance != null && countInvocation.Arguments.Length == 1) - { - countCollectionExpr = countInvocation.Instance.Syntax; - countPredicateExpr = countInvocation.Arguments[0].Value.Syntax; - } - else if (countInvocation.Instance == null && countInvocation.Arguments.Length == 2) - { - countCollectionExpr = countInvocation.Arguments[0].Value.Syntax; - countPredicateExpr = countInvocation.Arguments[1].Value.Syntax; - } - - if (countCollectionExpr != null && countPredicateExpr != null) - { - string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain"; + string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain"; - ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); - properties.Add(ProperAssertMethodNameKey, properAssertMethod); - properties.Add(CodeFixModeKey, CodeFixModeAddArgument); + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, CodeFixModeAddArgument); - context.ReportDiagnostic( - context.Operation.CreateDiagnostic( - Rule, - additionalLocations: ImmutableArray.Create( - conditionArgument.Syntax.GetLocation(), - countPredicateExpr.GetLocation(), - countCollectionExpr.GetLocation()), - properties: properties.ToImmutable(), - properAssertMethod, - isTrueInvocation ? "IsTrue" : "IsFalse")); + context.ReportDiagnostic( + context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create( + conditionArgument.Syntax.GetLocation(), + countPredicateExpr.GetLocation(), + countCollectionExpr.GetLocation()), + properties: properties.ToImmutable(), + properAssertMethod, + isTrueInvocation ? "IsTrue" : "IsFalse")); - return; - } + return; } }