diff --git a/src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj b/src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj index 682357c..dbe33ac 100644 --- a/src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj +++ b/src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj @@ -63,6 +63,8 @@ + + diff --git a/src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj b/src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj index ecfa056..f0f60ca 100644 --- a/src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj +++ b/src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj @@ -76,7 +76,9 @@ + + diff --git a/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs b/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs index f7d4d21..a3aadf7 100644 --- a/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs +++ b/src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs @@ -359,7 +359,7 @@ public override ReadOnlyCollection GetStoreFunctions() .ToList() .AsReadOnly(); - static EdmFunction CreateComposableEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo) + internal static EdmFunction CreateComposableEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo) { if (method == null) throw new ArgumentNullException(nameof(method)); diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/CaseIsNullToCoalesceReducer.cs b/src/EntityFramework6.Npgsql/SqlGenerators/CaseIsNullToCoalesceReducer.cs new file mode 100644 index 0000000..1ff4776 --- /dev/null +++ b/src/EntityFramework6.Npgsql/SqlGenerators/CaseIsNullToCoalesceReducer.cs @@ -0,0 +1,123 @@ +using System.Collections.Generic; +using System.Data.Entity.Core.Metadata.Edm; +using System.Linq; +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Data.Entity.Core.Common.CommandTrees.ExpressionBuilder; +using System.Diagnostics; +#if ENTITIES6 +using System.Globalization; +using System.Data.Entity.Core.Common.CommandTrees; +using System.Data.Entity.Core.Metadata.Edm; +#else +using System.Data.Common.CommandTrees; +using System.Data.Metadata.Edm; +#endif +using JetBrains.Annotations; + + +namespace Npgsql.SqlGenerators +{ + public class CaseIsNullToCoalesceReducer + { + public static DbFunctionExpression InvokeCoalesceExpression(params DbExpression[] argumentExpressions) + { + var fromClrType = PrimitiveType + .GetEdmPrimitiveTypes() + .FirstOrDefault(t => t.ClrEquivalentType == typeof(string)); + + int i=0; + var func = EdmFunction.Create( + "coalesce", + "Npgsql", + DataSpace.SSpace, + new EdmFunctionPayload + { + ParameterTypeSemantics = ParameterTypeSemantics.AllowImplicitConversion, + Schema = string.Empty, + IsBuiltIn = true, + IsAggregate = false, + IsFromProviderManifest = true, + StoreFunctionName = "coalesce", + IsComposable = true, + ReturnParameters = new[] + { + FunctionParameter.Create("ReturnType", fromClrType,ParameterMode.ReturnValue) + }, + Parameters = argumentExpressions.Select( + x => FunctionParameter.Create( + "p" + (i++).ToString(),fromClrType,ParameterMode.In)).ToList() + }, + new List()); + + return func.Invoke(argumentExpressions); + } + + public static DbFunctionExpression UnnestCoalesceInvocations(DbFunctionExpression dbFunctionExpression) + { + var args = new List(); + foreach (var arg in dbFunctionExpression.Arguments) + { + if(arg is DbFunctionExpression funcCall + && funcCall.Function.NamespaceName=="Npgsql" + && funcCall.Function.Name=="coalesce") + { + args.AddRange(funcCall.Arguments); + } + else + { + args.Add(arg); + } + } + return InvokeCoalesceExpression(args.ToArray()); + } + + public static DbExpression TransformCoalesce(DbExpression expression) + { + if (expression is DbCaseExpression case2) + { + return TransformCoalesce(case2); + } + + if (expression is DbIsNullExpression nullExp) + { + return TransformCoalesce(nullExp.Argument).IsNull(); + } + return expression; + } + + public static DbExpression TransformCoalesce(DbCaseExpression expression) + { + expression = DbExpressionBuilder.Case( + expression.When.Select(TransformCoalesce), + expression.Then.Select(TransformCoalesce), + expression.Else); + + var lastWhen = expression.When.Count-1; + if (expression.When[lastWhen].ExpressionKind == DbExpressionKind.IsNull) + { + var is_null = expression.When[lastWhen] as DbIsNullExpression; + if (DbExpressionDeepEqual.DeepEqual(is_null.Argument,expression.Else)) + { + var coalesceInvocation = InvokeCoalesceExpression(is_null.Argument, expression.Then[lastWhen]); + coalesceInvocation = UnnestCoalesceInvocations(coalesceInvocation); + + if (expression.When.Count == 1) + { + return coalesceInvocation; + } + + var simplifiendCase = DbExpressionBuilder.Case( + expression.When.Take(lastWhen), + expression.Then.Take(lastWhen), + coalesceInvocation); + + return TransformCoalesce(simplifiendCase); + } + return expression; + } + return expression; + } + } +} diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/DbExpressionDeepEqual.cs b/src/EntityFramework6.Npgsql/SqlGenerators/DbExpressionDeepEqual.cs new file mode 100644 index 0000000..70f96be --- /dev/null +++ b/src/EntityFramework6.Npgsql/SqlGenerators/DbExpressionDeepEqual.cs @@ -0,0 +1,77 @@ +using System.Data.Entity.Core.Common.CommandTrees; +using System.Data.Entity.Core.Metadata.Edm; +using System.Linq; + +namespace Npgsql.SqlGenerators +{ + public class DbExpressionDeepEqual + { + public static bool DeepEqual(DbExpression e1, DbExpression e2) + { + if (e1.Equals(e2)) return true; + if (e1.GetType() != e2.GetType()) return false; + if (!e1.ExpressionKind.Equals(e2.ExpressionKind)) return false; + if (!DeepEqual(e1.ResultType,e2.ResultType)) return false; + + if (e1 is DbFunctionExpression f1 && e2 is DbFunctionExpression f2) + { + return DeepEqual(f1,f2); + } + if (e1 is DbConstantExpression c1 && e2 is DbConstantExpression c2) + { + return c1.Value.Equals(c2.Value); + } + if (e1 is DbBinaryExpression b1 && e2 is DbBinaryExpression b2) + { + return DeepEqual(b1,b2); + } + if (e1 is DbUnaryExpression u1 && e2 is DbUnaryExpression u2) + { + return DeepEqual(u1,u2); + } + if (e1 is DbVariableReferenceExpression v1 && e2 is DbVariableReferenceExpression v2) + { + return DeepEqual(v1,v2); + } + + return false; + } + + static bool DeepEqual(TypeUsage r1, TypeUsage r2) + { + if (r1.EdmType != r2.EdmType) return false; + return true; + } + + private static bool DeepEqual(DbFunctionExpression f1, DbFunctionExpression f2) + { + if (!f1.Function.Name.Equals(f2.Function.Name)) return false; + if (!f1.Function.NamespaceName.Equals(f2.Function.NamespaceName)) return false; + if (!f1.Arguments.Count.Equals(f2.Arguments.Count)) return false; + + var argumenst_equals = f1.Arguments + .Zip(f2.Arguments, (a, b) => DeepEqual(a, b)) + .All(areEquals => areEquals); + + return argumenst_equals; + } + + private static bool DeepEqual(DbBinaryExpression b1, DbBinaryExpression b2) + { + if (!DeepEqual(b1.Left,b2.Left)) return false; + if (!DeepEqual(b1.Right,b2.Right)) return false; + + return true; + } + + private static bool DeepEqual(DbUnaryExpression u1, DbUnaryExpression u2) + { + return DeepEqual(u1.Argument,u2.Argument); + } + + private static bool DeepEqual(DbVariableReferenceExpression v1, DbVariableReferenceExpression v2) + { + return DeepEqual(v1.VariableName,v1.VariableName); + } + } +} diff --git a/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs b/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs index 1d078d6..36f6ee3 100644 --- a/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs +++ b/src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs @@ -829,6 +829,16 @@ protected string GetDbType(EdmType edmType) public override VisitedExpression Visit([NotNull] DbCaseExpression expression) { + var result = CaseIsNullToCoalesceReducer.TransformCoalesce(expression); + if (result is DbCaseExpression case2) + { + expression = case2; + } + else + { + return result.Accept(this); + } + var caseExpression = new LiteralExpression(" CASE "); for (var i = 0; i < expression.When.Count && i < expression.Then.Count; ++i) { @@ -1191,6 +1201,12 @@ VisitedExpression VisitFunction(EdmFunction function, IList args, throw new NotSupportedException("cast type name argument must be a constant expression."); return new CastExpression(args[0].Accept(this), typeNameExpression.Value.ToString()); + }else if (functionName == "coalesce") + { + var coalesceFuncCall = new FunctionExpression("coalesce"); + foreach (var a in args) + coalesceFuncCall.AddArgument(a.Accept(this)); + return coalesceFuncCall; } } diff --git a/test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs b/test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs index 1fbb22e..717a942 100644 --- a/test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs +++ b/test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs @@ -735,6 +735,69 @@ public void Test_issue_27_select_ef_generated_literals_from_inner_select() } } + [Test] + public void Test_issue_60_and_62() + { + using (var context = new BloggingContext(ConnectionString)) + { + context.Database.Log = Console.Out.WriteLine; + + context.Blogs.Add( new Blog { Name = "Hello" }); + context.SaveChanges(); + + string string_value = "string_value"; + var query = context.Blogs.Select(b => string_value + "_postfijo").Take(1); + var blog_title = query.First(); + Assert.That(blog_title, Is.EqualTo("string_value_postfijo")); + StringAssert.DoesNotContain("case", query.ToString().ToLower() ); + } + } + + [Test] + public void TestNullPropagation_1() + { + using (var context = new BloggingContext(ConnectionString)) + { + context.Database.Log = Console.Out.WriteLine; + + context.Blogs.Add( new Blog { Name = "Hello" }); + context.SaveChanges(); + + string valor_string = "string_value"; + var query = context.Blogs.Select(b => (valor_string ?? "otro_valor") + "_postfijo").Take(1); + var blog_title = query.First(); + Assert.That(blog_title, Is.EqualTo("string_value_postfijo")); + + var query_sql = query.ToString().ToLower(); + StringAssert.DoesNotContain("case", query.ToString().ToLower() ); + StringAssert.Contains("coalesce(@p__linq__0,e'otro_valor',e'')", query_sql); + } + } + + [Test] + public void TestNullPropagation_2() + { + using (var context = new BloggingContext(ConnectionString)) + { + context.Database.Log = Console.Out.WriteLine; + + context.Blogs.Add( new Blog { Name = "Hello" }); + context.SaveChanges(); + + string string_value1 = "string_value1"; + string string_value2 = "string_value2"; + string string_value3 = "string_value3"; + + var query = context.Blogs.Select(b => (string_value1 ?? string_value2 ?? string_value3) + "_postfijo").Take(1); + var blog_title = query.First(); + Assert.That(blog_title, Is.EqualTo("string_value1_postfijo")); + + var query_sql = query.ToString().ToLower(); + StringAssert.DoesNotContain("case", query_sql ); + StringAssert.Contains("coalesce(@p__linq__0,@p__linq__1,@p__linq__2,e'')", query_sql); + } + } + [Test] public void TestTableValuedStoredFunctions() {