From 6da7a6f010667a0539869840c72b6d80f4bbf8c0 Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Thu, 20 Oct 2022 05:09:07 +0000 Subject: [PATCH] Fix conditional expressions for classes with a common ancestor (#212) Fixes #107 +semver:fix --- src/DelegateDecompiler.Tests/Employee.cs | 35 ++++++++++--------- .../QueryableExtensionsTests.cs | 16 +++++++++ src/DelegateDecompiler/Address.cs | 2 +- .../DecompiledQueryProvider.cs | 6 ++-- .../OptimizeExpressionVisitor.cs | 5 +++ src/DelegateDecompiler/Processor.cs | 31 +++++++++------- 6 files changed, 63 insertions(+), 32 deletions(-) diff --git a/src/DelegateDecompiler.Tests/Employee.cs b/src/DelegateDecompiler.Tests/Employee.cs index ae5c6b9..3e18c1a 100644 --- a/src/DelegateDecompiler.Tests/Employee.cs +++ b/src/DelegateDecompiler.Tests/Employee.cs @@ -19,27 +19,18 @@ public class Employee public int To; [Computed] - public string FullName - { - get { return FirstName + " " + LastName; } - } + public string FullName => FirstName + " " + LastName; - public string FullNameWithoutAttribute - { - get { return FirstName + " " + LastName; } - } + public string FullNameWithoutAttribute => FirstName + " " + LastName; [Computed] - public string FromTo - { - get { return From + "-" + To; } - } + public string FromTo => From + "-" + To; [Computed] - public bool IsActive - { - get { return true; } - } + public bool IsActive => true; + + [Computed] + public IEmployeeStatus Status => IsActive ? (IEmployeeStatus)new Active() : new Terminated(); [Computed] public int Count @@ -151,4 +142,16 @@ public static string FullName(this Employee e) return e.FirstName + " " + e.LastName; } } + + public interface IEmployeeStatus + { + } + + public class Active : IEmployeeStatus + { + } + + public class Terminated : IEmployeeStatus + { + } } diff --git a/src/DelegateDecompiler.Tests/QueryableExtensionsTests.cs b/src/DelegateDecompiler.Tests/QueryableExtensionsTests.cs index 2dc4233..e8b7777 100644 --- a/src/DelegateDecompiler.Tests/QueryableExtensionsTests.cs +++ b/src/DelegateDecompiler.Tests/QueryableExtensionsTests.cs @@ -352,5 +352,21 @@ public void Issue78() AssertAreEqual(expected.Expression, actual.Expression); } + + [Test] + public void Issue127() + { + var employees = new[] { new Employee { FirstName = "Test", LastName = "User" } }; + + var expected = (from employee in employees.AsQueryable() + where ((IEmployeeStatus)new Active()) is Active + select employee); + + var actual = (from employee in employees.AsQueryable() + where employee.Status is Active + select employee).Decompile(); + + AssertAreEqual(expected.Expression, actual.Expression); + } } } diff --git a/src/DelegateDecompiler/Address.cs b/src/DelegateDecompiler/Address.cs index d42bd49..f8b452c 100644 --- a/src/DelegateDecompiler/Address.cs +++ b/src/DelegateDecompiler/Address.cs @@ -27,7 +27,7 @@ public Address Clone(IDictionary map) { if (map.ContainsKey(this)) return map[this]; - var result = new Address() { Expression = this.Expression }; + var result = new Address { Expression = this.Expression }; map[this] = result; return result; } diff --git a/src/DelegateDecompiler/DecompiledQueryProvider.cs b/src/DelegateDecompiler/DecompiledQueryProvider.cs index 4f840b4..69f15b2 100644 --- a/src/DelegateDecompiler/DecompiledQueryProvider.cs +++ b/src/DelegateDecompiler/DecompiledQueryProvider.cs @@ -40,19 +40,19 @@ public virtual IQueryable CreateQuery(Expression expression) public virtual IQueryable CreateQuery(Expression expression) { - var decompiled = expression.Decompile(); + var decompiled = expression.Decompile().Optimize(); return new DecompiledQueryable(this, Inner.CreateQuery(decompiled)); } public object Execute(Expression expression) { - var decompiled = expression.Decompile(); + var decompiled = expression.Decompile().Optimize(); return Inner.Execute(decompiled); } public TResult Execute(Expression expression) { - var decompiled = expression.Decompile(); + var decompiled = expression.Decompile().Optimize(); return Inner.Execute(decompiled); } } diff --git a/src/DelegateDecompiler/OptimizeExpressionVisitor.cs b/src/DelegateDecompiler/OptimizeExpressionVisitor.cs index 010b1c5..f397f4f 100644 --- a/src/DelegateDecompiler/OptimizeExpressionVisitor.cs +++ b/src/DelegateDecompiler/OptimizeExpressionVisitor.cs @@ -39,6 +39,11 @@ protected override Expression VisitConditional(ConditionalExpression node) var test = Visit(node.Test); var ifTrue = Visit(node.IfTrue); var ifFalse = Visit(node.IfFalse); + + if (test is ConstantExpression constant && constant.Value is bool boolValue) + { + return boolValue ? ifTrue : ifFalse; + } if (IsCoalesce(test, ifTrue, out var expression)) { diff --git a/src/DelegateDecompiler/Processor.cs b/src/DelegateDecompiler/Processor.cs index 5c19b52..513a441 100644 --- a/src/DelegateDecompiler/Processor.cs +++ b/src/DelegateDecompiler/Processor.cs @@ -955,43 +955,48 @@ internal static Expression AdjustType(Expression expression, Type type) return expression; } - var constantExpression = expression as ConstantExpression; - if (constantExpression != null) + if (expression is ConstantExpression constant) { - if (constantExpression.Value == null) + if (constant.Value == null) { return Expression.Constant(null, type); } - if (constantExpression.Type == typeof(int)) + if (constant.Type == typeof(int)) { if (type.IsEnum) { - return Expression.Constant(Enum.ToObject(type, constantExpression.Value)); + return Expression.Constant(Enum.ToObject(type, constant.Value)); } + if (type == typeof(bool)) { - return Expression.Constant(Convert.ToBoolean(constantExpression.Value)); + return Expression.Constant(Convert.ToBoolean(constant.Value)); } + if (type == typeof(byte)) { - return Expression.Constant(Convert.ToByte(constantExpression.Value)); + return Expression.Constant(Convert.ToByte(constant.Value)); } + if (type == typeof(sbyte)) { - return Expression.Constant(Convert.ToSByte(constantExpression.Value)); + return Expression.Constant(Convert.ToSByte(constant.Value)); } + if (type == typeof(short)) { - return Expression.Constant(Convert.ToInt16(constantExpression.Value)); + return Expression.Constant(Convert.ToInt16(constant.Value)); } + if (type == typeof(ushort)) { - return Expression.Constant(Convert.ToUInt16(constantExpression.Value)); + return Expression.Constant(Convert.ToUInt16(constant.Value)); } + if (type == typeof(uint)) { - return Expression.Constant(Convert.ToUInt32(constantExpression.Value)); + return Expression.Constant(Convert.ToUInt32(constant.Value)); } } } @@ -1002,6 +1007,7 @@ internal static Expression AdjustType(Expression expression, Type type) return Expression.NotEqual(expression, Expression.Constant(0)); } } + if (!type.IsAssignableFrom(expression.Type) && expression.Type.IsEnum && expression.Type.GetEnumUnderlyingType() == type) { return Expression.Convert(expression, type); @@ -1420,7 +1426,8 @@ static void LdLoc(ProcessorState state, int index) static void StLoc(ProcessorState state, int index) { var info = state.Locals[index]; - info.Address = AdjustType(state.Stack.Pop(), info.Type); + var expression = AdjustType(state.Stack.Pop(), info.Type); + info.Address = expression.Type == info.Type ? expression : Expression.Convert(expression, info.Type); } static void LdArg(ProcessorState state, int index)