From 6c0087903dbaf2b01a54b7b43b01aae0aca49a76 Mon Sep 17 00:00:00 2001 From: Virgil Palanciuc <47669377+virgilp@users.noreply.github.com> Date: Fri, 12 Nov 2021 17:23:47 +0200 Subject: [PATCH] chore: merge upstream repo --- .github/workflows/build.yml | 10 +- .../Firebird/FirebirdLimitTests.cs | 2 +- QueryBuilder.Tests/GeneralTests.cs | 190 +++++++++++++++++ QueryBuilder.Tests/MySql/MySqlLimitTests.cs | 2 +- QueryBuilder.Tests/ParameterTypeTests.cs | 6 +- .../PostgreSql/PostgreSqlLimitTests.cs | 2 +- QueryBuilder.Tests/SelectTests.cs | 17 +- .../SqlServer/NestedSelectTests.cs | 2 +- .../SqlServer/SqlServerLegacyLimitTests.cs | 2 +- .../SqlServer/SqlServerLimitTests.cs | 2 +- QueryBuilder.Tests/UpdateTests.cs | 32 +++ QueryBuilder.Tests/WhereTests.cs | 33 +++ QueryBuilder/Base.Where.cs | 10 + QueryBuilder/Clauses/FromClause.cs | 24 ++- QueryBuilder/Clauses/IncrementClause.cs | 19 ++ QueryBuilder/Compilers/Compiler.Conditions.cs | 6 +- QueryBuilder/Compilers/Compiler.cs | 67 +++++- QueryBuilder/Compilers/FirebirdCompiler.cs | 1 + QueryBuilder/Compilers/OracleCompiler.cs | 1 + QueryBuilder/Compilers/SqlServerCompiler.cs | 16 ++ QueryBuilder/Helper.cs | 6 +- QueryBuilder/Query.Select.cs | 16 +- QueryBuilder/Query.Update.cs | 17 ++ QueryBuilder/Query.cs | 46 +++- QueryBuilder/QueryBuilder.csproj | 2 +- README.md | 4 +- SqlKata.Execution/PaginationResult.cs | 9 +- SqlKata.Execution/Query.Extensions.cs | 138 +++++++----- SqlKata.Execution/QueryFactory.cs | 197 ++++++++++-------- 29 files changed, 698 insertions(+), 181 deletions(-) create mode 100644 QueryBuilder.Tests/WhereTests.cs create mode 100644 QueryBuilder/Clauses/IncrementClause.cs diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e8845c5b..485e7079 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,6 +30,12 @@ jobs: steps: - name: Checkout uses: actions/checkout@v2 + - name: Set env + run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV + - name: Checking release + run: | + echo $RELEASE_VERSION + echo ${{ env.RELEASE_VERSION }} - name: Setup .NET Core uses: actions/setup-dotnet@v1 with: @@ -94,7 +100,7 @@ jobs: echo Version: $VERSION VERSION="${VERSION//v}" echo Clean Version: $VERSION - dotnet pack -v normal -c Release --include-symbols --include-source -p:PackageVersion=$VERSION -o nupkg QueryBuilder/QueryBuilder.csproj + dotnet pack -v normal -c Release --include-symbols --include-source -p:Version=$VERSION -p:PackageVersion=$VERSION -o nupkg QueryBuilder/QueryBuilder.csproj - name: Create Release NuGet package (SqlKata.Execution) run: | arrTag=(${GITHUB_REF//\// }) @@ -102,7 +108,7 @@ jobs: echo Version: $VERSION VERSION="${VERSION//v}" echo Clean Version: $VERSION - dotnet pack -v normal -c Release --include-symbols --include-source -p:PackageVersion=$VERSION -o nupkg SqlKata.Execution/SqlKata.Execution.csproj + dotnet pack -v normal -c Release --include-symbols --include-source -p:Version=$VERSION -p:PackageVersion=$VERSION -o nupkg SqlKata.Execution/SqlKata.Execution.csproj - name: Push to GitHub Feed run: dotnet nuget push ./nupkg/*.nupkg --skip-duplicate --source $GITHUB_FEED --api-key $GITHUB_TOKEN - name: Push to NuGet Feed diff --git a/QueryBuilder.Tests/Firebird/FirebirdLimitTests.cs b/QueryBuilder.Tests/Firebird/FirebirdLimitTests.cs index 58879c83..bb5bae61 100644 --- a/QueryBuilder.Tests/Firebird/FirebirdLimitTests.cs +++ b/QueryBuilder.Tests/Firebird/FirebirdLimitTests.cs @@ -52,4 +52,4 @@ public void LimitAndOffset() Assert.Equal(2, ctx.Bindings.Count); } } -} +} \ No newline at end of file diff --git a/QueryBuilder.Tests/GeneralTests.cs b/QueryBuilder.Tests/GeneralTests.cs index 428b0567..940517f0 100644 --- a/QueryBuilder.Tests/GeneralTests.cs +++ b/QueryBuilder.Tests/GeneralTests.cs @@ -401,5 +401,195 @@ public void Where_Nested() Assert.Equal("SELECT * FROM [table] WHERE ([a] = 1 OR [a] = 2)", c[EngineCodes.SqlServer].ToString()); } + + [Fact] + public void AdHoc_Throws_WhenNoColumnsProvided() => + Assert.Throws(() => + new Query("rows").With("rows", + new string[0], + new object[][] { + new object[] {}, + new object[] {}, + })); + + [Fact] + public void AdHoc_Throws_WhenNoValueRowsProvided() => + Assert.Throws(() => + new Query("rows").With("rows", + new[] { "a", "b", "c" }, + new object[][] { + })); + + [Fact] + public void AdHoc_Throws_WhenColumnsOutnumberFieldValues() => + Assert.Throws(() => + new Query("rows").With("rows", + new[] { "a", "b", "c", "d" }, + new object[][] { + new object[] { 1, 2, 3 }, + new object[] { 4, 5, 6 }, + })); + + [Fact] + public void AdHoc_Throws_WhenFieldValuesOutNumberColumns() => + Assert.Throws(() => + new Query("rows").With("rows", + new[] { "a", "b" }, + new object[][] { + new object[] { 1, 2, 3 }, + new object[] { 4, 5, 6 }, + })); + + [Fact] + public void AdHoc_SingletonRow() + { + var query = new Query("rows").With("rows", + new[] { "a" }, + new object[][] { + new object[] { 1 }, + }); + + var c = Compilers.Compile(query); + + Assert.Equal("WITH [rows] AS (SELECT [a] FROM (VALUES (1)) AS tbl ([a]))\nSELECT * FROM [rows]", c[EngineCodes.SqlServer].ToString()); + Assert.Equal("WITH \"rows\" AS (SELECT 1 AS \"a\")\nSELECT * FROM \"rows\"", c[EngineCodes.PostgreSql].ToString()); + Assert.Equal("WITH `rows` AS (SELECT 1 AS `a`)\nSELECT * FROM `rows`", c[EngineCodes.MySql].ToString()); + Assert.Equal("WITH \"rows\" AS (SELECT 1 AS \"a\")\nSELECT * FROM \"rows\"", c[EngineCodes.Sqlite].ToString()); + Assert.Equal("WITH \"ROWS\" AS (SELECT 1 AS \"A\" FROM RDB$DATABASE)\nSELECT * FROM \"ROWS\"", c[EngineCodes.Firebird].ToString()); + Assert.Equal("WITH \"rows\" AS (SELECT 1 AS \"a\" FROM DUAL)\nSELECT * FROM \"rows\"", c[EngineCodes.Oracle].ToString()); + } + + [Fact] + public void AdHoc_TwoRows() + { + var query = new Query("rows").With("rows", + new[] { "a", "b", "c" }, + new object[][] { + new object[] { 1, 2, 3 }, + new object[] { 4, 5, 6 }, + }); + + var c = Compilers.Compile(query); + + Assert.Equal("WITH [rows] AS (SELECT [a], [b], [c] FROM (VALUES (1, 2, 3), (4, 5, 6)) AS tbl ([a], [b], [c]))\nSELECT * FROM [rows]", c[EngineCodes.SqlServer].ToString()); + Assert.Equal("WITH \"rows\" AS (SELECT 1 AS \"a\", 2 AS \"b\", 3 AS \"c\" UNION ALL SELECT 4 AS \"a\", 5 AS \"b\", 6 AS \"c\")\nSELECT * FROM \"rows\"", c[EngineCodes.PostgreSql].ToString()); + Assert.Equal("WITH `rows` AS (SELECT 1 AS `a`, 2 AS `b`, 3 AS `c` UNION ALL SELECT 4 AS `a`, 5 AS `b`, 6 AS `c`)\nSELECT * FROM `rows`", c[EngineCodes.MySql].ToString()); + Assert.Equal("WITH \"rows\" AS (SELECT 1 AS \"a\", 2 AS \"b\", 3 AS \"c\" UNION ALL SELECT 4 AS \"a\", 5 AS \"b\", 6 AS \"c\")\nSELECT * FROM \"rows\"", c[EngineCodes.Sqlite].ToString()); + Assert.Equal("WITH \"ROWS\" AS (SELECT 1 AS \"A\", 2 AS \"B\", 3 AS \"C\" FROM RDB$DATABASE UNION ALL SELECT 4 AS \"A\", 5 AS \"B\", 6 AS \"C\" FROM RDB$DATABASE)\nSELECT * FROM \"ROWS\"", c[EngineCodes.Firebird].ToString()); + Assert.Equal("WITH \"rows\" AS (SELECT 1 AS \"a\", 2 AS \"b\", 3 AS \"c\" FROM DUAL UNION ALL SELECT 4 AS \"a\", 5 AS \"b\", 6 AS \"c\" FROM DUAL)\nSELECT * FROM \"rows\"", c[EngineCodes.Oracle].ToString()); + } + + [Fact] + public void AdHoc_ProperBindingsPlacement() + { + var query = new Query("rows") + .With("othercte", q => q.From("othertable").Where("othertable.status", "A")) + .Where("rows.foo", "bar") + .With("rows", + new[] { "a", "b", "c" }, + new object[][] { + new object[] { 1, 2, 3 }, + new object[] { 4, 5, 6 }, + }) + .Where("rows.baz", "buzz"); + + var c = Compilers.Compile(query); + + Assert.Equal(string.Join("\n", new[] { + "WITH [othercte] AS (SELECT * FROM [othertable] WHERE [othertable].[status] = 'A'),", + "[rows] AS (SELECT [a], [b], [c] FROM (VALUES (1, 2, 3), (4, 5, 6)) AS tbl ([a], [b], [c]))", + "SELECT * FROM [rows] WHERE [rows].[foo] = 'bar' AND [rows].[baz] = 'buzz'", + }), c[EngineCodes.SqlServer].ToString()); + } + + [Fact] + public void UnsafeLiteral_Insert() + { + var query = new Query("Table").AsInsert(new + { + Count = new UnsafeLiteral("Count + 1") + }); + + var engines = new[] { + EngineCodes.SqlServer, + }; + + var c = Compilers.Compile(engines, query); + + Assert.Equal("INSERT INTO [Table] ([Count]) VALUES (Count + 1)", c[EngineCodes.SqlServer].ToString()); + } + + [Fact] + public void UnsafeLiteral_Update() + { + var query = new Query("Table").AsUpdate(new + { + Count = new UnsafeLiteral("Count + 1") + }); + + var engines = new[] { + EngineCodes.SqlServer, + }; + + var c = Compilers.Compile(engines, query); + + Assert.Equal("UPDATE [Table] SET [Count] = Count + 1", c[EngineCodes.SqlServer].ToString()); + } + + [Fact] + public void Passing_Boolean_To_Where_Should_Call_WhereTrue_Or_WhereFalse() + { + var query = new Query("Table").Where("Col", true); + + var engines = new[] { + EngineCodes.SqlServer, + }; + + var c = Compilers.Compile(engines, query); + + Assert.Equal("SELECT * FROM [Table] WHERE [Col] = cast(1 as bit)", c[EngineCodes.SqlServer].ToString()); + } + + [Fact] + public void Passing_Boolean_False_To_Where_Should_Call_WhereTrue_Or_WhereFalse() + { + var query = new Query("Table").Where("Col", false); + + var engines = new[] { + EngineCodes.SqlServer, + }; + + var c = Compilers.Compile(engines, query); + + Assert.Equal("SELECT * FROM [Table] WHERE [Col] = cast(0 as bit)", c[EngineCodes.SqlServer].ToString()); + } + + [Fact] + public void Passing_Negative_Boolean_To_Where_Should_Call_WhereTrue_Or_WhereFalse() + { + var query = new Query("Table").Where("Col", "!=", true); + + var engines = new[] { + EngineCodes.SqlServer, + }; + + var c = Compilers.Compile(engines, query); + + Assert.Equal("SELECT * FROM [Table] WHERE [Col] != cast(1 as bit)", c[EngineCodes.SqlServer].ToString()); + } + + [Fact] + public void Passing_Negative_Boolean_False_To_Where_Should_Call_WhereTrue_Or_WhereFalse() + { + var query = new Query("Table").Where("Col", "!=", false); + + var engines = new[] { + EngineCodes.SqlServer, + }; + + var c = Compilers.Compile(engines, query); + + Assert.Equal("SELECT * FROM [Table] WHERE [Col] != cast(0 as bit)", c[EngineCodes.SqlServer].ToString()); + } } } diff --git a/QueryBuilder.Tests/MySql/MySqlLimitTests.cs b/QueryBuilder.Tests/MySql/MySqlLimitTests.cs index 3ac10c2f..ca2db254 100644 --- a/QueryBuilder.Tests/MySql/MySqlLimitTests.cs +++ b/QueryBuilder.Tests/MySql/MySqlLimitTests.cs @@ -55,4 +55,4 @@ public void WithLimitAndOffset() Assert.Equal(2, ctx.Bindings.Count); } } -} +} \ No newline at end of file diff --git a/QueryBuilder.Tests/ParameterTypeTests.cs b/QueryBuilder.Tests/ParameterTypeTests.cs index 7dc2a8ca..17620f3f 100644 --- a/QueryBuilder.Tests/ParameterTypeTests.cs +++ b/QueryBuilder.Tests/ParameterTypeTests.cs @@ -25,8 +25,8 @@ public class ParameterTypeGenerator : IEnumerable new object[] {Convert.ToSingle("10.5", CultureInfo.InvariantCulture).ToString(), 10.5}, new object[] {"-2", -2}, new object[] {Convert.ToSingle("-2.8", CultureInfo.InvariantCulture).ToString(), -2.8}, - new object[] {"true", true}, - new object[] {"false", false}, + new object[] {"cast(1 as bit)", true}, + new object[] {"cast(0 as bit)", false}, new object[] {"'2018-10-28 19:22:00'", new DateTime(2018, 10, 28, 19, 22, 0, DateTimeKind.Utc)}, new object[] {"0 /* First */", EnumExample.First}, new object[] {"1 /* Second */", EnumExample.Second}, @@ -49,4 +49,4 @@ public void CorrectParameterTypeOutput(string rendered, object input) Assert.Equal($"SELECT * FROM [Table] WHERE [Col] = {rendered}", c[EngineCodes.SqlServer]); } } -} +} \ No newline at end of file diff --git a/QueryBuilder.Tests/PostgreSql/PostgreSqlLimitTests.cs b/QueryBuilder.Tests/PostgreSql/PostgreSqlLimitTests.cs index 99af1fa5..22d7a5d6 100644 --- a/QueryBuilder.Tests/PostgreSql/PostgreSqlLimitTests.cs +++ b/QueryBuilder.Tests/PostgreSql/PostgreSqlLimitTests.cs @@ -55,4 +55,4 @@ public void WithLimitAndOffset() Assert.Equal(2, ctx.Bindings.Count); } } -} +} \ No newline at end of file diff --git a/QueryBuilder.Tests/SelectTests.cs b/QueryBuilder.Tests/SelectTests.cs index 59f0fbc6..72a2c7b9 100644 --- a/QueryBuilder.Tests/SelectTests.cs +++ b/QueryBuilder.Tests/SelectTests.cs @@ -1,7 +1,8 @@ -using SqlKata.Compilers; +using SqlKata.Compilers; using SqlKata.Extensions; using SqlKata.Tests.Infrastructure; using System; +using System.Collections.Generic; using Xunit; namespace SqlKata.Tests @@ -21,6 +22,20 @@ public void BasicSelect() Assert.Equal("SELECT \"id\", \"name\" FROM \"users\"", c[EngineCodes.Oracle]); } + [Fact] + public void BasicSelectEnumerable() + { + var q = new Query().From("users").Select(new List() { "id", "name" }); + var c = Compile(q); + + Assert.Equal("SELECT [id], [name] FROM [users]", c[EngineCodes.SqlServer]); + Assert.Equal("SELECT `id`, `name` FROM `users`", c[EngineCodes.MySql]); + Assert.Equal("SELECT \"id\", \"name\" FROM \"users\"", c[EngineCodes.PostgreSql]); + Assert.Equal("SELECT \"ID\", \"NAME\" FROM \"USERS\"", c[EngineCodes.Firebird]); + Assert.Equal("SELECT \"id\", \"name\" FROM \"users\"", c[EngineCodes.Oracle]); + } + + [Fact] public void SelectAs() { diff --git a/QueryBuilder.Tests/SqlServer/NestedSelectTests.cs b/QueryBuilder.Tests/SqlServer/NestedSelectTests.cs index 4c3f20a0..4c591ba2 100644 --- a/QueryBuilder.Tests/SqlServer/NestedSelectTests.cs +++ b/QueryBuilder.Tests/SqlServer/NestedSelectTests.cs @@ -56,7 +56,7 @@ public void SqlCompile_QueryLimitAndNestedLimit_BindingValue() // var q = new Query().From("Foo").Where("C", "c").WhereExists(n).Where("A", "a"); var actual = compiler.Compile(q).ToString(); - Assert.Contains("SELECT * FROM [Foo] WHERE [x] = true AND NOT EXISTS (SELECT 1 FROM [Bar])", + Assert.Contains("SELECT * FROM [Foo] WHERE [x] = cast(1 as bit) AND NOT EXISTS (SELECT 1 FROM [Bar])", actual); // Assert.Contains("SELECT * FROM [Foo] WHERE [C] = 'c' AND EXISTS (SELECT TOP (1) 1 FROM [Bar]) AND [A] = 'a'", actual); } diff --git a/QueryBuilder.Tests/SqlServer/SqlServerLegacyLimitTests.cs b/QueryBuilder.Tests/SqlServer/SqlServerLegacyLimitTests.cs index c881e4d9..08de2af5 100644 --- a/QueryBuilder.Tests/SqlServer/SqlServerLegacyLimitTests.cs +++ b/QueryBuilder.Tests/SqlServer/SqlServerLegacyLimitTests.cs @@ -75,4 +75,4 @@ public void ShouldKeepTheOrdersAsIsIfPaginationProvided() Assert.DoesNotContain("(SELECT 0)", compiler.Compile(query).ToString()); } } -} +} \ No newline at end of file diff --git a/QueryBuilder.Tests/SqlServer/SqlServerLimitTests.cs b/QueryBuilder.Tests/SqlServer/SqlServerLimitTests.cs index 9fc4d12e..2959d295 100644 --- a/QueryBuilder.Tests/SqlServer/SqlServerLimitTests.cs +++ b/QueryBuilder.Tests/SqlServer/SqlServerLimitTests.cs @@ -85,4 +85,4 @@ public void ShouldKeepTheOrdersAsIsIfPaginationProvided() Assert.DoesNotContain("(SELECT 0)", compiler.Compile(query).ToString()); } } -} +} \ No newline at end of file diff --git a/QueryBuilder.Tests/UpdateTests.cs b/QueryBuilder.Tests/UpdateTests.cs index e6aeebcd..bf3dd7d9 100644 --- a/QueryBuilder.Tests/UpdateTests.cs +++ b/QueryBuilder.Tests/UpdateTests.cs @@ -283,5 +283,37 @@ public void UpdateUsingExpandoObject() "UPDATE [Table] SET [Name] = 'The User', [Age] = '2018-01-01'", c[EngineCodes.SqlServer]); } + + [Fact] + public void IncrementUpdate() + { + var query = new Query("Table").AsIncrement("Total"); + var c = Compile(query); + Assert.Equal("UPDATE [Table] SET [Total] = [Total] + 1", c[EngineCodes.SqlServer]); + } + + [Fact] + public void IncrementUpdateWithValue() + { + var query = new Query("Table").AsIncrement("Total", 2); + var c = Compile(query); + Assert.Equal("UPDATE [Table] SET [Total] = [Total] + 2", c[EngineCodes.SqlServer]); + } + + [Fact] + public void IncrementUpdateWithWheres() + { + var query = new Query("Table").Where("Name", "A").AsIncrement("Total", 2); + var c = Compile(query); + Assert.Equal("UPDATE [Table] SET [Total] = [Total] + 2 WHERE [Name] = 'A'", c[EngineCodes.SqlServer]); + } + + [Fact] + public void DecrementUpdate() + { + var query = new Query("Table").Where("Name", "A").AsDecrement("Total", 2); + var c = Compile(query); + Assert.Equal("UPDATE [Table] SET [Total] = [Total] - 2 WHERE [Name] = 'A'", c[EngineCodes.SqlServer]); + } } } diff --git a/QueryBuilder.Tests/WhereTests.cs b/QueryBuilder.Tests/WhereTests.cs new file mode 100644 index 00000000..0c1254b2 --- /dev/null +++ b/QueryBuilder.Tests/WhereTests.cs @@ -0,0 +1,33 @@ +using SqlKata.Compilers; +using SqlKata.Tests.Infrastructure; +using Xunit; + +namespace SqlKata.Tests +{ + public class WhereTests : TestSupport + { + [Fact] + public void GroupedWhereFilters() + { + var q = new Query("Table1") + .Where(q => q.Or().Where("Column1", 10).Or().Where("Column2", 20)) + .Where("Column3", 30); + + var c = Compile(q); + + Assert.Equal(@"SELECT * FROM ""Table1"" WHERE (""Column1"" = 10 OR ""Column2"" = 20) AND ""Column3"" = 30", c[EngineCodes.PostgreSql]); + } + + [Fact] + public void GroupedHavingFilters() + { + var q = new Query("Table1") + .Having(q => q.Or().HavingRaw("SUM([Column1]) = ?", 10).Or().HavingRaw("SUM([Column2]) = ?", 20)) + .HavingRaw("SUM([Column3]) = ?", 30); + + var c = Compile(q); + + Assert.Equal(@"SELECT * FROM ""Table1"" HAVING (SUM(""Column1"") = 10 OR SUM(""Column2"") = 20) AND SUM(""Column3"") = 30", c[EngineCodes.PostgreSql]); + } + } +} diff --git a/QueryBuilder/Base.Where.cs b/QueryBuilder/Base.Where.cs index 02f8f498..133a7480 100644 --- a/QueryBuilder/Base.Where.cs +++ b/QueryBuilder/Base.Where.cs @@ -18,6 +18,16 @@ public Q Where(string column, string op, object value) return Not(op != "=").WhereNull(column); } + if (value is bool boolValue) + { + if (op != "=") + { + Not(); + } + + return boolValue ? WhereTrue(column) : WhereFalse(column); + } + return AddComponent("where", new BasicCondition { Column = column, diff --git a/QueryBuilder/Clauses/FromClause.cs b/QueryBuilder/Clauses/FromClause.cs index a0ca9748..1410facf 100644 --- a/QueryBuilder/Clauses/FromClause.cs +++ b/QueryBuilder/Clauses/FromClause.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; namespace SqlKata { @@ -94,4 +95,25 @@ public override AbstractClause Clone() }; } } -} + + /// + /// Represents a FROM clause that is an ad-hoc table built with predefined values. + /// + public class AdHocTableFromClause : AbstractFrom + { + public List Columns { get; set; } + public List Values { get; set; } + + public override AbstractClause Clone() + { + return new AdHocTableFromClause + { + Engine = Engine, + Alias = Alias, + Columns = Columns, + Values = Values, + Component = Component, + }; + } + } +} \ No newline at end of file diff --git a/QueryBuilder/Clauses/IncrementClause.cs b/QueryBuilder/Clauses/IncrementClause.cs new file mode 100644 index 00000000..4ee5a194 --- /dev/null +++ b/QueryBuilder/Clauses/IncrementClause.cs @@ -0,0 +1,19 @@ +namespace SqlKata +{ + public class IncrementClause : InsertClause + { + public string Column { get; set; } + public int Value { get; set; } = 1; + + public override AbstractClause Clone() + { + return new IncrementClause + { + Engine = Engine, + Component = Component, + Column = Column, + Value = Value + }; + } + } +} \ No newline at end of file diff --git a/QueryBuilder/Compilers/Compiler.Conditions.cs b/QueryBuilder/Compilers/Compiler.Conditions.cs index 1f24dd78..190a85ea 100644 --- a/QueryBuilder/Compilers/Compiler.Conditions.cs +++ b/QueryBuilder/Compilers/Compiler.Conditions.cs @@ -170,12 +170,14 @@ protected virtual string CompileBasicDateCondition(SqlResult ctx, BasicDateCondi protected virtual string CompileNestedCondition(SqlResult ctx, NestedCondition x) where Q : BaseQuery { - if (!x.Query.HasComponent("where", EngineCode)) + if (!(x.Query.HasComponent("where", EngineCode) || x.Query.HasComponent("having", EngineCode))) { return null; } - var clauses = x.Query.GetComponents("where", EngineCode); + var clause = x.Query.HasComponent("where", EngineCode) ? "where" : "having"; + + var clauses = x.Query.GetComponents(clause, EngineCode); var sql = CompileConditions(ctx, clauses); diff --git a/QueryBuilder/Compilers/Compiler.cs b/QueryBuilder/Compilers/Compiler.cs index d67a2372..bb14b9b4 100644 --- a/QueryBuilder/Compilers/Compiler.cs +++ b/QueryBuilder/Compilers/Compiler.cs @@ -26,6 +26,8 @@ protected Compiler() public virtual string EngineCode { get; } + protected virtual string SingleRowDummyTableName { get => null; } + /// /// A list of white-listed operators /// @@ -212,6 +214,27 @@ public virtual SqlResult Compile(IEnumerable queries) return ctx; } + protected virtual SqlResult CompileAdHocQuery(AdHocTableFromClause adHoc) + { + var ctx = new SqlResult(this); + + var row = "SELECT " + string.Join(", ", adHoc.Columns.Select(col => $"? AS {Wrap(col)}")); + + var fromTable = SingleRowDummyTableName; + + if (fromTable != null) + { + row += $" FROM {fromTable}"; + } + + var rows = string.Join(" UNION ALL ", Enumerable.Repeat(row, adHoc.Values.Count / adHoc.Columns.Count)); + + ctx.RawSql = rows; + ctx.Bindings = adHoc.Values; + + return ctx; + } + protected virtual SqlResult CompileDeleteQuery(Query query) { var ctx = new SqlResult(this) @@ -288,8 +311,31 @@ protected virtual SqlResult CompileUpdateQuery(Query query) throw new InvalidOperationException("Invalid table expression"); } - var toUpdate = ctx.Query.GetOneComponent("update", EngineCode); + // check for increment statements + var clause = ctx.Query.GetOneComponent("update", EngineCode); + string wheres; + + if (clause != null && clause is IncrementClause increment) + { + var column = Wrap(increment.Column); + var value = Parameter(ctx, Math.Abs(increment.Value)); + var sign = increment.Value >= 0 ? "+" : "-"; + + wheres = CompileWheres(ctx); + + if (!string.IsNullOrEmpty(wheres)) + { + wheres = " " + wheres; + } + + ctx.RawSql = $"UPDATE {table} SET {column} = {column} {sign} {value}{wheres}"; + + return ctx; + } + + + var toUpdate = ctx.Query.GetOneComponent("update", EngineCode); var parts = new List(); for (var i = 0; i < toUpdate.Columns.Count; i++) @@ -297,16 +343,16 @@ protected virtual SqlResult CompileUpdateQuery(Query query) parts.Add(Wrap(toUpdate.Columns[i]) + " = " + Parameter(ctx, toUpdate.Values[i])); } - var where = CompileWheres(ctx); + var sets = string.Join(", ", parts); - if (!string.IsNullOrEmpty(where)) + wheres = CompileWheres(ctx); + + if (!string.IsNullOrEmpty(wheres)) { - where = " " + where; + wheres = " " + wheres; } - var sets = string.Join(", ", parts); - - ctx.RawSql = $"UPDATE {table} SET {sets}{where}"; + ctx.RawSql = $"UPDATE {table} SET {sets}{wheres}"; return ctx; } @@ -443,6 +489,13 @@ public virtual SqlResult CompileCte(AbstractFrom cte) ctx.RawSql = $"{WrapValue(queryFromClause.Alias)} AS ({subCtx.RawSql})"; } + else if (cte is AdHocTableFromClause adHoc) + { + var subCtx = CompileAdHocQuery(adHoc); + ctx.Bindings.AddRange(subCtx.Bindings); + + ctx.RawSql = $"{WrapValue(adHoc.Alias)} AS ({subCtx.RawSql})"; + } return ctx; } diff --git a/QueryBuilder/Compilers/FirebirdCompiler.cs b/QueryBuilder/Compilers/FirebirdCompiler.cs index 9dc48cd6..98702dbf 100644 --- a/QueryBuilder/Compilers/FirebirdCompiler.cs +++ b/QueryBuilder/Compilers/FirebirdCompiler.cs @@ -11,6 +11,7 @@ public FirebirdCompiler() } public override string EngineCode { get; } = EngineCodes.Firebird; + protected override string SingleRowDummyTableName => "RDB$DATABASE"; protected override SqlResult CompileInsertQuery(Query query) { diff --git a/QueryBuilder/Compilers/OracleCompiler.cs b/QueryBuilder/Compilers/OracleCompiler.cs index 4dbc8a28..3eda8b21 100644 --- a/QueryBuilder/Compilers/OracleCompiler.cs +++ b/QueryBuilder/Compilers/OracleCompiler.cs @@ -17,6 +17,7 @@ public OracleCompiler() public override string EngineCode { get; } = EngineCodes.Oracle; public bool UseLegacyPagination { get; set; } = false; + protected override string SingleRowDummyTableName => "DUAL"; public /* friend */ override SqlResult CompileSelectQuery(Query query) { diff --git a/QueryBuilder/Compilers/SqlServerCompiler.cs b/QueryBuilder/Compilers/SqlServerCompiler.cs index 9f9070b5..e456f76a 100644 --- a/QueryBuilder/Compilers/SqlServerCompiler.cs +++ b/QueryBuilder/Compilers/SqlServerCompiler.cs @@ -344,5 +344,21 @@ protected override string CompileBasicDateCondition(SqlResult ctx, BasicDateCond return sql; } + + protected override SqlResult CompileAdHocQuery(AdHocTableFromClause adHoc) + { + var ctx = new SqlResult(this); + + var colNames = string.Join(", ", adHoc.Columns.Select(Wrap)); + + var valueRow = string.Join(", ", Enumerable.Repeat("?", adHoc.Columns.Count)); + var valueRows = string.Join(", ", Enumerable.Repeat($"({valueRow})", adHoc.Values.Count / adHoc.Columns.Count)); + var sql = $"SELECT {colNames} FROM (VALUES {valueRows}) AS tbl ({colNames})"; + + ctx.RawSql = sql; + ctx.Bindings = adHoc.Values; + + return ctx; + } } } diff --git a/QueryBuilder/Helper.cs b/QueryBuilder/Helper.cs index 998e5d59..218a95e3 100644 --- a/QueryBuilder/Helper.cs +++ b/QueryBuilder/Helper.cs @@ -2,6 +2,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Text; using System.Text.RegularExpressions; namespace SqlKata @@ -88,8 +89,9 @@ public static string ReplaceAll(string subject, string match, Func ); return splitted.Skip(1) - .Select((item, index) => callback(index) + item) - .Aggregate(splitted.First(), (left, right) => left + right); + .Select((item, index) => callback(index) + item) + .Aggregate(new StringBuilder(splitted.First()), (prev, right) => prev.Append(right)) + .ToString(); } public static string JoinArray(string glue, IEnumerable array) diff --git a/QueryBuilder/Query.Select.cs b/QueryBuilder/Query.Select.cs index c4dc7ba4..cff67ba4 100644 --- a/QueryBuilder/Query.Select.cs +++ b/QueryBuilder/Query.Select.cs @@ -8,11 +8,12 @@ public partial class Query { public Query Select(params string[] columns) { - return SelectAs( - columns - .Select(x => (x, null as string)) - .ToArray() - ); + return Select(columns.AsEnumerable()); + } + + public Query Select(IEnumerable columns) + { + return SelectAs(columns.Select(x => (x, null as string))); } /// @@ -20,6 +21,11 @@ public Query Select(params string[] columns) /// /// public Query SelectAs(params (string, string)[] columns) + { + return SelectAs(columns.AsEnumerable()); + } + + public Query SelectAs(IEnumerable<(string, string)> columns) { Method = "select"; diff --git a/QueryBuilder/Query.Update.cs b/QueryBuilder/Query.Update.cs index ced8bab8..d88aeb00 100644 --- a/QueryBuilder/Query.Update.cs +++ b/QueryBuilder/Query.Update.cs @@ -54,5 +54,22 @@ public Query AsUpdate(IEnumerable> values) return this; } + + public Query AsIncrement(string column, int value = 1) + { + Method = "update"; + AddOrReplaceComponent("update", new IncrementClause + { + Column = column, + Value = value + }); + + return this; + } + + public Query AsDecrement(string column, int value = 1) + { + return AsIncrement(column, -value); + } } } diff --git a/QueryBuilder/Query.cs b/QueryBuilder/Query.cs index 03d1fda0..941bcb8f 100755 --- a/QueryBuilder/Query.cs +++ b/QueryBuilder/Query.cs @@ -8,10 +8,11 @@ namespace SqlKata { public partial class Query : BaseQuery { + private string comment; + public bool IsDistinct { get; set; } = false; public string QueryAlias { get; set; } public string Method { get; set; } = "select"; - public string QueryComment { get; set; } public List Includes = new List(); public Dictionary Variables = new Dictionary(); @@ -25,6 +26,8 @@ public Query(string table, string comment = null) : base() Comment(comment); } + public string GetComment() => comment ?? ""; + public bool HasOffset(string engineCode = null) => GetOffset(engineCode) > 0; public bool HasLimit(string engineCode = null) => GetLimit(engineCode) > 0; @@ -63,9 +66,14 @@ public Query As(string alias) return this; } + /// + /// Sets a comment for the query. + /// + /// The comment. + /// public Query Comment(string comment) { - QueryComment = comment; + this.comment = comment; return this; } @@ -118,6 +126,40 @@ public Query With(string alias, Func fn) return With(alias, fn.Invoke(new Query())); } + /// + /// Constructs an ad-hoc table of the given data as a CTE. + /// + public Query With(string alias, IEnumerable columns, IEnumerable> valuesCollection) + { + var columnsList = columns?.ToList(); + var valuesCollectionList = valuesCollection?.ToList(); + + if ((columnsList?.Count ?? 0) == 0 || (valuesCollectionList?.Count ?? 0) == 0) + { + throw new InvalidOperationException("Columns and valuesCollection cannot be null or empty"); + } + + var clause = new AdHocTableFromClause() + { + Alias = alias, + Columns = columnsList, + Values = new List(), + }; + + foreach (var values in valuesCollectionList) + { + var valuesList = values.ToList(); + if (columnsList.Count != valuesList.Count) + { + throw new InvalidOperationException("Columns count should be equal to each Values count"); + } + + clause.Values.AddRange(valuesList); + } + + return AddComponent("cte", clause); + } + public Query WithRaw(string alias, string sql, params object[] bindings) { return AddComponent("cte", new RawFromClause diff --git a/QueryBuilder/QueryBuilder.csproj b/QueryBuilder/QueryBuilder.csproj index c780b068..6335a971 100755 --- a/QueryBuilder/QueryBuilder.csproj +++ b/QueryBuilder/QueryBuilder.csproj @@ -14,7 +14,7 @@ sql;query-builder;dynamic-query https://github.com/sqlkata/querybuilder https://github.com/sqlkata/querybuilder - https://github.com/sqlkata/querybuilder/licence + MIT true git https://github.com/sqlkata/querybuilder diff --git a/README.md b/README.md index 4428c013..9d5a22e3 100644 --- a/README.md +++ b/README.md @@ -89,11 +89,11 @@ var books = db.Query("Books") ``` This will include the property "Author" on each "Book" -```json +```jsonc [{ "Id": 1, "PublishedAt": "2019-01-01", - "AuthorId": 2 + "AuthorId": 2, "Author": { // <-- included property "Id": 2, "...": "" diff --git a/SqlKata.Execution/PaginationResult.cs b/SqlKata.Execution/PaginationResult.cs index 23344207..85277503 100644 --- a/SqlKata.Execution/PaginationResult.cs +++ b/SqlKata.Execution/PaginationResult.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Data; +using System.Threading; using System.Threading.Tasks; namespace SqlKata.Execution @@ -71,9 +72,9 @@ public PaginationResult Next(IDbTransaction transaction = null, int? timeout return this.Query.Paginate(Page + 1, PerPage, transaction, timeout); } - public async Task> NextAsync(IDbTransaction transaction = null, int? timeout = null) + public async Task> NextAsync(IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await this.Query.PaginateAsync(Page + 1, PerPage, transaction, timeout); + return await this.Query.PaginateAsync(Page + 1, PerPage, transaction, timeout, cancellationToken); } public Query PreviousQuery() @@ -86,9 +87,9 @@ public PaginationResult Previous(IDbTransaction transaction = null, int? time return this.Query.Paginate(Page - 1, PerPage, transaction, timeout); } - public async Task> PreviousAsync(IDbTransaction transaction = null, int? timeout = null) + public async Task> PreviousAsync(IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await this.Query.PaginateAsync(Page - 1, PerPage, transaction, timeout); + return await this.Query.PaginateAsync(Page - 1, PerPage, transaction, timeout, cancellationToken); } public PaginationIterator Each diff --git a/SqlKata.Execution/Query.Extensions.cs b/SqlKata.Execution/Query.Extensions.cs index 18b52f66..d9f5ef30 100644 --- a/SqlKata.Execution/Query.Extensions.cs +++ b/SqlKata.Execution/Query.Extensions.cs @@ -2,6 +2,7 @@ using System; using System.Threading.Tasks; using System.Data; +using System.Threading; namespace SqlKata.Execution { @@ -12,9 +13,9 @@ public static bool Exists(this Query query, IDbTransaction transaction = null, i return CreateQueryFactory(query).Exists(query, transaction, timeout); } - public async static Task ExistsAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public async static Task ExistsAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).ExistsAsync(query, transaction, timeout); + return await CreateQueryFactory(query).ExistsAsync(query, transaction, timeout, cancellationToken); } public static bool NotExist(this Query query, IDbTransaction transaction = null, int? timeout = null) @@ -22,9 +23,9 @@ public static bool NotExist(this Query query, IDbTransaction transaction = null, return !CreateQueryFactory(query).Exists(query, transaction, timeout); } - public async static Task NotExistAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public async static Task NotExistAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return !(await CreateQueryFactory(query).ExistsAsync(query, transaction, timeout)); + return !(await CreateQueryFactory(query).ExistsAsync(query, transaction, timeout, cancellationToken)); } public static IEnumerable Get(this Query query, IDbTransaction transaction = null, int? timeout = null) @@ -32,9 +33,9 @@ public static IEnumerable Get(this Query query, IDbTransaction transaction return CreateQueryFactory(query).Get(query, transaction, timeout); } - public static async Task> GetAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public static async Task> GetAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).GetAsync(query, transaction, timeout); + return await CreateQueryFactory(query).GetAsync(query, transaction, timeout, cancellationToken); } public static IEnumerable Get(this Query query, IDbTransaction transaction = null, int? timeout = null) @@ -42,9 +43,9 @@ public static IEnumerable Get(this Query query, IDbTransaction transact return query.Get(transaction, timeout); } - public static async Task> GetAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public static async Task> GetAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await GetAsync(query, transaction, timeout); + return await GetAsync(query, transaction, timeout, cancellationToken); } public static T FirstOrDefault(this Query query, IDbTransaction transaction = null, int? timeout = null) @@ -52,9 +53,9 @@ public static T FirstOrDefault(this Query query, IDbTransaction transaction = return CreateQueryFactory(query).FirstOrDefault(query, transaction, timeout); } - public static async Task FirstOrDefaultAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public static async Task FirstOrDefaultAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).FirstOrDefaultAsync(query, transaction, timeout); + return await CreateQueryFactory(query).FirstOrDefaultAsync(query, transaction, timeout, cancellationToken); } public static dynamic FirstOrDefault(this Query query, IDbTransaction transaction = null, int? timeout = null) @@ -62,9 +63,9 @@ public static dynamic FirstOrDefault(this Query query, IDbTransaction transactio return FirstOrDefault(query, transaction, timeout); } - public static async Task FirstOrDefaultAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public static async Task FirstOrDefaultAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await FirstOrDefaultAsync(query, transaction, timeout); + return await FirstOrDefaultAsync(query, transaction, timeout, cancellationToken); } public static T First(this Query query, IDbTransaction transaction = null, int? timeout = null) @@ -72,9 +73,9 @@ public static T First(this Query query, IDbTransaction transaction = null, in return CreateQueryFactory(query).First(query, transaction, timeout); } - public static async Task FirstAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public static async Task FirstAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).FirstAsync(query, transaction, timeout); + return await CreateQueryFactory(query).FirstAsync(query, transaction, timeout, cancellationToken); } public static dynamic First(this Query query, IDbTransaction transaction = null, int? timeout = null) @@ -82,9 +83,9 @@ public static dynamic First(this Query query, IDbTransaction transaction = null, return First(query, transaction, timeout); } - public static async Task FirstAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public static async Task FirstAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await FirstAsync(query, transaction, timeout); + return await FirstAsync(query, transaction, timeout, cancellationToken); } public static PaginationResult Paginate(this Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null) @@ -94,11 +95,11 @@ public static PaginationResult Paginate(this Query query, int page, int pe return db.Paginate(query, page, perPage, transaction, timeout); } - public static async Task> PaginateAsync(this Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null) + public static async Task> PaginateAsync(this Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { var db = CreateQueryFactory(query); - return await db.PaginateAsync(query, page, perPage, transaction, timeout); + return await db.PaginateAsync(query, page, perPage, transaction, timeout, cancellationToken); } public static PaginationResult Paginate(this Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null) @@ -106,9 +107,9 @@ public static PaginationResult Paginate(this Query query, int page, int return query.Paginate(page, perPage, transaction, timeout); } - public static async Task> PaginateAsync(this Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null) + public static async Task> PaginateAsync(this Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await PaginateAsync(query, page, perPage, transaction, timeout); + return await PaginateAsync(query, page, perPage, transaction, timeout, cancellationToken); } public static void Chunk(this Query query, int chunkSize, Func, int, bool> func, IDbTransaction transaction = null, int? timeout = null) @@ -117,18 +118,18 @@ public static void Chunk(this Query query, int chunkSize, Func db.Chunk(query, chunkSize, func, transaction, timeout); } - public static async Task ChunkAsync(this Query query, int chunkSize, Func, int, bool> func, IDbTransaction transaction = null, int? timeout = null) + public static async Task ChunkAsync(this Query query, int chunkSize, Func, int, bool> func, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - await CreateQueryFactory(query).ChunkAsync(query, chunkSize, func, transaction, timeout); + await CreateQueryFactory(query).ChunkAsync(query, chunkSize, func, transaction, timeout, cancellationToken); } public static void Chunk(this Query query, int chunkSize, Func, int, bool> func, IDbTransaction transaction = null, int? timeout = null) { query.Chunk(chunkSize, func, transaction, timeout); } - public static async Task ChunkAsync(this Query query, int chunkSize, Func, int, bool> func, IDbTransaction transaction = null, int? timeout = null) + public static async Task ChunkAsync(this Query query, int chunkSize, Func, int, bool> func, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - await ChunkAsync(query, chunkSize, func, transaction, timeout); + await ChunkAsync(query, chunkSize, func, transaction, timeout, cancellationToken); } public static void Chunk(this Query query, int chunkSize, Action, int> action, IDbTransaction transaction = null, int? timeout = null) @@ -138,9 +139,9 @@ public static void Chunk(this Query query, int chunkSize, Action(this Query query, int chunkSize, Action, int> action, IDbTransaction transaction = null, int? timeout = null) + public static async Task ChunkAsync(this Query query, int chunkSize, Action, int> action, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - await CreateQueryFactory(query).ChunkAsync(query, chunkSize, action, transaction, timeout); + await CreateQueryFactory(query).ChunkAsync(query, chunkSize, action, transaction, timeout, cancellationToken); } public static void Chunk(this Query query, int chunkSize, Action, int> action, IDbTransaction transaction = null, int? timeout = null) @@ -148,9 +149,9 @@ public static void Chunk(this Query query, int chunkSize, Action(chunkSize, action, transaction, timeout); } - public static async Task ChunkAsync(this Query query, int chunkSize, Action, int> action, IDbTransaction transaction = null, int? timeout = null) + public static async Task ChunkAsync(this Query query, int chunkSize, Action, int> action, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - await ChunkAsync(query, chunkSize, action, transaction, timeout); + await ChunkAsync(query, chunkSize, action, transaction, timeout, cancellationToken); } public static int Insert(this Query query, IEnumerable> values, IDbTransaction transaction = null, int? timeout = null) @@ -158,9 +159,9 @@ public static int Insert(this Query query, IEnumerable InsertAsync(this Query query, IEnumerable> values, IDbTransaction transaction = null, int? timeout = null) + public static async Task InsertAsync(this Query query, IEnumerable> values, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(values), transaction, timeout); + return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(values), transaction, timeout, cancellationToken); } public static int Insert(this Query query, IEnumerable columns, IEnumerable> valuesCollection, IDbTransaction transaction = null, int? timeout = null) @@ -168,14 +169,19 @@ public static int Insert(this Query query, IEnumerable columns, IEnumera return CreateQueryFactory(query).Execute(query.AsInsert(columns, valuesCollection), transaction, timeout); } + public static async Task InsertAsync(this Query query, IEnumerable columns, IEnumerable> valuesCollection, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) + { + return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(columns, valuesCollection), transaction, timeout, cancellationToken); + } + public static int Insert(this Query query, IEnumerable columns, Query fromQuery, IDbTransaction transaction = null, int? timeout = null) { return CreateQueryFactory(query).Execute(query.AsInsert(columns, fromQuery), transaction, timeout); } - public static async Task InsertAsync(this Query query, IEnumerable columns, Query fromQuery, IDbTransaction transaction = null, int? timeout = null) + public static async Task InsertAsync(this Query query, IEnumerable columns, Query fromQuery, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(columns, fromQuery), transaction, timeout); + return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(columns, fromQuery), transaction, timeout, cancellationToken); } public static int Insert(this Query query, object data, IDbTransaction transaction = null, int? timeout = null) @@ -183,9 +189,9 @@ public static int Insert(this Query query, object data, IDbTransaction transacti return CreateQueryFactory(query).Execute(query.AsInsert(data), transaction, timeout); } - public static async Task InsertAsync(this Query query, object data, IDbTransaction transaction = null, int? timeout = null) + public static async Task InsertAsync(this Query query, object data, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(data), transaction, timeout); + return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(data), transaction, timeout, cancellationToken); } public static T InsertGetId(this Query query, object data, IDbTransaction transaction = null, int? timeout = null) @@ -197,10 +203,10 @@ public static T InsertGetId(this Query query, object data, IDbTransaction tra return row.Id; } - public static async Task InsertGetIdAsync(this Query query, object data, IDbTransaction transaction = null, int? timeout = null) + public static async Task InsertGetIdAsync(this Query query, object data, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { var row = await CreateQueryFactory(query) - .FirstAsync>(query.AsInsert(data, true), transaction, timeout); + .FirstAsync>(query.AsInsert(data, true), transaction, timeout, cancellationToken); return row.Id; } @@ -212,9 +218,9 @@ public static T InsertGetId(this Query query, IEnumerable InsertGetIdAsync(this Query query, IEnumerable> data, IDbTransaction transaction = null, int? timeout = null) + public static async Task InsertGetIdAsync(this Query query, IEnumerable> data, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - var row = await CreateQueryFactory(query).FirstAsync>(query.AsInsert(data, true), transaction, timeout); + var row = await CreateQueryFactory(query).FirstAsync>(query.AsInsert(data, true), transaction, timeout, cancellationToken); return row.Id; } @@ -224,9 +230,9 @@ public static int Update(this Query query, IEnumerable UpdateAsync(this Query query, IEnumerable> values, IDbTransaction transaction = null, int? timeout = null) + public static async Task UpdateAsync(this Query query, IEnumerable> values, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).ExecuteAsync(query.AsUpdate(values), transaction, timeout); + return await CreateQueryFactory(query).ExecuteAsync(query.AsUpdate(values), transaction, timeout, cancellationToken); } public static int Update(this Query query, object data, IDbTransaction transaction = null, int? timeout = null) @@ -234,9 +240,29 @@ public static int Update(this Query query, object data, IDbTransaction transacti return CreateQueryFactory(query).Execute(query.AsUpdate(data), transaction, timeout); } - public static async Task UpdateAsync(this Query query, object data, IDbTransaction transaction = null, int? timeout = null) + public static async Task UpdateAsync(this Query query, object data, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) + { + return await CreateQueryFactory(query).ExecuteAsync(query.AsUpdate(data), transaction, timeout, cancellationToken); + } + + public static int Increment(this Query query, string column, int value = 1, IDbTransaction transaction = null, int? timeout = null) + { + return CreateQueryFactory(query).Execute(query.AsIncrement(column, value), transaction, timeout); + } + + public static async Task IncrementAsync(this Query query, string column, int value = 1, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) + { + return await CreateQueryFactory(query).ExecuteAsync(query.AsIncrement(column, value), transaction, timeout, cancellationToken); + } + + public static int Decrement(this Query query, string column, int value = 1, IDbTransaction transaction = null, int? timeout = null) + { + return CreateQueryFactory(query).Execute(query.AsDecrement(column, value), transaction, timeout); + } + + public static async Task DecrementAsync(this Query query, string column, int value = 1, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).ExecuteAsync(query.AsUpdate(data), transaction, timeout); + return await CreateQueryFactory(query).ExecuteAsync(query.AsDecrement(column, value), transaction, timeout, cancellationToken); } public static int Delete(this Query query, IDbTransaction transaction = null, int? timeout = null) @@ -244,9 +270,9 @@ public static int Delete(this Query query, IDbTransaction transaction = null, in return CreateQueryFactory(query).Execute(query.AsDelete(), transaction, timeout); } - public static async Task DeleteAsync(this Query query, IDbTransaction transaction = null, int? timeout = null) + public static async Task DeleteAsync(this Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).ExecuteAsync(query.AsDelete(), transaction, timeout); + return await CreateQueryFactory(query).ExecuteAsync(query.AsDelete(), transaction, timeout, cancellationToken); } public static T Aggregate(this Query query, string aggregateOperation, string[] columns, IDbTransaction transaction = null, int? timeout = null) @@ -256,10 +282,10 @@ public static T Aggregate(this Query query, string aggregateOperation, string return db.ExecuteScalar(query.SelectAggregate(aggregateOperation, columns, AbstractAggregateColumn.AggregateDistinct.aggregateNonDistinct), transaction, timeout); } - public static async Task SelectAggregateAsync(this Query query, string aggregateOperation, string[] columns, IDbTransaction transaction = null, int? timeout = null) + public static async Task SelectAggregateAsync(this Query query, string aggregateOperation, string[] columns, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { var db = CreateQueryFactory(query); - return await db.ExecuteScalarAsync(query.SelectAggregate(aggregateOperation, columns, AbstractAggregateColumn.AggregateDistinct.aggregateNonDistinct), transaction, timeout); + return await db.ExecuteScalarAsync(query.SelectAggregate(aggregateOperation, columns, AbstractAggregateColumn.AggregateDistinct.aggregateNonDistinct), transaction, timeout, cancellationToken); } public static T Count(this Query query, string[] columns = null, IDbTransaction transaction = null, int? timeout = null) @@ -269,11 +295,11 @@ public static T Count(this Query query, string[] columns = null, IDbTransacti return db.ExecuteScalar(query.SelectCount(columns), transaction, timeout); } - public static async Task SelectCountAsync(this Query query, string[] columns = null, IDbTransaction transaction = null, int? timeout = null) + public static async Task SelectCountAsync(this Query query, string[] columns = null, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { var db = CreateQueryFactory(query); - return await db.ExecuteScalarAsync(query.SelectCount(columns), transaction, timeout); + return await db.ExecuteScalarAsync(query.SelectCount(columns), transaction, timeout, cancellationToken); } public static T Average(this Query query, string column, IDbTransaction transaction = null, int? timeout = null) @@ -281,9 +307,9 @@ public static T Average(this Query query, string column, IDbTransaction trans return query.Aggregate("avg", new[] { column }, transaction, timeout); } - public static async Task SelectAverageAsync(this Query query, string column, IDbTransaction transaction = null, int? timeout = null) + public static async Task SelectAverageAsync(this Query query, string column, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await query.SelectAggregateAsync("avg", new[] { column }, transaction, timeout); + return await query.SelectAggregateAsync("avg", new[] { column }, transaction, timeout, cancellationToken); } public static T Sum(this Query query, string column, IDbTransaction transaction = null, int? timeout = null) @@ -291,9 +317,9 @@ public static T Sum(this Query query, string column, IDbTransaction transacti return query.Aggregate("sum", new[] { column }, transaction, timeout); } - public static async Task SelectSumAsync(this Query query, string column, IDbTransaction transaction = null, int? timeout = null) + public static async Task SelectSumAsync(this Query query, string column, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await query.SelectAggregateAsync("sum", new[] { column }, transaction, timeout); + return await query.SelectAggregateAsync("sum", new[] { column }, transaction, timeout, cancellationToken); } public static T Min(this Query query, string column, IDbTransaction transaction = null, int? timeout = null) @@ -301,9 +327,9 @@ public static T Min(this Query query, string column, IDbTransaction transacti return query.Aggregate("min", new[] { column }, transaction, timeout); } - public static async Task SelectMinAsync(this Query query, string column, IDbTransaction transaction = null, int? timeout = null) + public static async Task SelectMinAsync(this Query query, string column, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await query.SelectAggregateAsync("min", new[] { column }, transaction, timeout); + return await query.SelectAggregateAsync("min", new[] { column }, transaction, timeout, cancellationToken); } public static T Max(this Query query, string column, IDbTransaction transaction = null, int? timeout = null) @@ -311,9 +337,9 @@ public static T Max(this Query query, string column, IDbTransaction transacti return query.Aggregate("max", new[] { column }, transaction, timeout); } - public static async Task SelectMaxAsync(this Query query, string column, IDbTransaction transaction = null, int? timeout = null) + public static async Task SelectMaxAsync(this Query query, string column, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await query.SelectAggregateAsync("max", new[] { column }, transaction, timeout); + return await query.SelectAggregateAsync("max", new[] { column }, transaction, timeout, cancellationToken); } internal static XQuery CastToXQuery(Query query, string method = null) diff --git a/SqlKata.Execution/QueryFactory.cs b/SqlKata.Execution/QueryFactory.cs index b5cc9c1c..89248d0a 100644 --- a/SqlKata.Execution/QueryFactory.cs +++ b/SqlKata.Execution/QueryFactory.cs @@ -3,6 +3,7 @@ using System.Data; using System.Dynamic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Dapper; using Humanizer; @@ -87,18 +88,19 @@ public IEnumerable Get(Query query, IDbTransaction transaction = null, int return result; } - public async Task> GetAsync(Query query, IDbTransaction transaction = null, int? timeout = null) + public async Task> GetAsync(Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { var compiled = CompileAndLog(query); - - var result = (await this.Connection.QueryAsync( - compiled.Sql, - compiled.NamedBindings, + var commandDefinition = new CommandDefinition( + commandText: compiled.Sql, + parameters: compiled.NamedBindings, transaction: transaction, - commandTimeout: timeout ?? this.QueryTimeout - )).ToList(); + commandTimeout: timeout ?? this.QueryTimeout, + cancellationToken: cancellationToken); + + var result = (await this.Connection.QueryAsync(commandDefinition)).ToList(); - result = (await handleIncludesAsync(query, result)).ToList(); + result = (await handleIncludesAsync(query, result, cancellationToken)).ToList(); return result; } @@ -117,16 +119,17 @@ public IEnumerable> GetDictionary(Query query, IDbTr return result.Cast>(); } - public async Task>> GetDictionaryAsync(Query query, IDbTransaction transaction = null, int? timeout = null) + public async Task>> GetDictionaryAsync(Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { var compiled = CompileAndLog(query); - - var result = await this.Connection.QueryAsync( - compiled.Sql, - compiled.NamedBindings, + var commandDefinition = new CommandDefinition( + commandText: compiled.Sql, + parameters: compiled.NamedBindings, transaction: transaction, - commandTimeout: timeout ?? this.QueryTimeout - ); + commandTimeout: timeout ?? this.QueryTimeout, + cancellationToken: cancellationToken); + + var result = await this.Connection.QueryAsync(commandDefinition); return result.Cast>(); } @@ -136,9 +139,9 @@ public IEnumerable Get(Query query, IDbTransaction transaction = null, return Get(query, transaction, timeout); } - public async Task> GetAsync(Query query, IDbTransaction transaction = null, int? timeout = null) + public async Task> GetAsync(Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await GetAsync(query, transaction, timeout); + return await GetAsync(query, transaction, timeout, cancellationToken); } public T FirstOrDefault(Query query, IDbTransaction transaction = null, int? timeout = null) @@ -148,9 +151,9 @@ public T FirstOrDefault(Query query, IDbTransaction transaction = null, int? return list.ElementAtOrDefault(0); } - public async Task FirstOrDefaultAsync(Query query, IDbTransaction transaction = null, int? timeout = null) + public async Task FirstOrDefaultAsync(Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - var list = await GetAsync(query.Limit(1), transaction, timeout); + var list = await GetAsync(query.Limit(1), transaction, timeout, cancellationToken); return list.ElementAtOrDefault(0); } @@ -160,9 +163,9 @@ public dynamic FirstOrDefault(Query query, IDbTransaction transaction = null, in return FirstOrDefault(query, transaction, timeout); } - public async Task FirstOrDefaultAsync(Query query, IDbTransaction transaction = null, int? timeout = null) + public async Task FirstOrDefaultAsync(Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await FirstOrDefaultAsync(query, transaction, timeout); + return await FirstOrDefaultAsync(query, transaction, timeout, cancellationToken); } public T First(Query query, IDbTransaction transaction = null, int? timeout = null) @@ -177,9 +180,9 @@ public T First(Query query, IDbTransaction transaction = null, int? timeout = return item; } - public async Task FirstAsync(Query query, IDbTransaction transaction = null, int? timeout = null) + public async Task FirstAsync(Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - var item = await FirstOrDefaultAsync(query, transaction, timeout); + var item = await FirstOrDefaultAsync(query, transaction, timeout, cancellationToken); if (item == null) { @@ -194,9 +197,9 @@ public dynamic First(Query query, IDbTransaction transaction = null, int? timeou return First(query, transaction, timeout); } - public async Task FirstAsync(Query query, IDbTransaction transaction = null, int? timeout = null) + public async Task FirstAsync(Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await FirstAsync(query, transaction, timeout); + return await FirstAsync(query, transaction, timeout, cancellationToken); } public int Execute( @@ -218,17 +221,19 @@ public int Execute( public async Task ExecuteAsync( Query query, IDbTransaction transaction = null, - int? timeout = null + int? timeout = null, + CancellationToken cancellationToken = default ) { var compiled = CompileAndLog(query); + var commandDefinition = new CommandDefinition( + commandText: compiled.Sql, + parameters: compiled.NamedBindings, + transaction: transaction, + commandTimeout: timeout ?? this.QueryTimeout, + cancellationToken: cancellationToken); - return await this.Connection.ExecuteAsync( - compiled.Sql, - compiled.NamedBindings, - transaction, - timeout ?? this.QueryTimeout - ); + return await this.Connection.ExecuteAsync(commandDefinition); } public T ExecuteScalar(Query query, IDbTransaction transaction = null, int? timeout = null) @@ -246,17 +251,19 @@ public T ExecuteScalar(Query query, IDbTransaction transaction = null, int? t public async Task ExecuteScalarAsync( Query query, IDbTransaction transaction = null, - int? timeout = null + int? timeout = null, + CancellationToken cancellationToken = default ) { var compiled = CompileAndLog(query.Limit(1)); + var commandDefinition = new CommandDefinition( + commandText: compiled.Sql, + parameters: compiled.NamedBindings, + transaction: transaction, + commandTimeout: timeout ?? this.QueryTimeout, + cancellationToken: cancellationToken); - return await this.Connection.ExecuteScalarAsync( - compiled.Sql, - compiled.NamedBindings, - transaction, - timeout ?? this.QueryTimeout - ); + return await this.Connection.ExecuteScalarAsync(commandDefinition); } public SqlMapper.GridReader GetMultiple( @@ -278,16 +285,18 @@ public SqlMapper.GridReader GetMultiple( public async Task GetMultipleAsync( Query[] queries, IDbTransaction transaction = null, - int? timeout = null) + int? timeout = null, + CancellationToken cancellationToken = default) { var compiled = this.Compiler.Compile(queries); + var commandDefinition = new CommandDefinition( + commandText: compiled.Sql, + parameters: compiled.NamedBindings, + transaction: transaction, + commandTimeout: timeout ?? this.QueryTimeout, + cancellationToken: cancellationToken); - return await this.Connection.QueryMultipleAsync( - compiled.Sql, - compiled.NamedBindings, - transaction, - timeout ?? this.QueryTimeout - ); + return await this.Connection.QueryMultipleAsync(commandDefinition); } public IEnumerable> Get( @@ -315,13 +324,15 @@ public IEnumerable> Get( public async Task>> GetAsync( Query[] queries, IDbTransaction transaction = null, - int? timeout = null + int? timeout = null, + CancellationToken cancellationToken = default ) { var multi = await this.GetMultipleAsync( queries, transaction, - timeout + timeout, + cancellationToken ); var list = new List>(); @@ -349,14 +360,14 @@ public bool Exists(Query query, IDbTransaction transaction = null, int? timeout return rows.Any(); } - public async Task ExistsAsync(Query query, IDbTransaction transaction = null, int? timeout = null) + public async Task ExistsAsync(Query query, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { var clone = query.Clone() .ClearComponent("select") .SelectRaw("1 as [Exists]") .Limit(1); - var rows = await GetAsync(clone, transaction, timeout); + var rows = await GetAsync(clone, transaction, timeout, cancellationToken); return rows.Any(); } @@ -377,13 +388,15 @@ public async Task SelectAggregateAsync( string aggregateOperation, string[] columns = null, IDbTransaction transaction = null, - int? timeout = null + int? timeout = null, + CancellationToken cancellationToken = default ) { return await this.ExecuteScalarAsync( query.SelectAggregate(aggregateOperation, columns, AbstractAggregateColumn.AggregateDistinct.aggregateNonDistinct), transaction, - timeout + timeout, + cancellationToken ); } @@ -396,9 +409,9 @@ public T Count(Query query, string[] columns = null, IDbTransaction transacti ); } - public async Task SelectCountAsync(Query query, string[] columns = null, IDbTransaction transaction = null, int? timeout = null) + public async Task SelectCountAsync(Query query, string[] columns = null, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await this.ExecuteScalarAsync(query.SelectCount(columns), transaction, timeout); + return await this.ExecuteScalarAsync(query.SelectCount(columns), transaction, timeout, cancellationToken); } public T Average(Query query, string column, IDbTransaction transaction = null, int? timeout = null) @@ -406,9 +419,9 @@ public T Average(Query query, string column, IDbTransaction transaction = nul return this.Aggregate(query, "avg", new[] { column }); } - public async Task SelectAverageAsync(Query query, string column) + public async Task SelectAverageAsync(Query query, string column, CancellationToken cancellationToken = default) { - return await this.SelectAggregateAsync(query, "avg", new[] { column }); + return await this.SelectAggregateAsync(query, "avg", new[] { column }, cancellationToken: cancellationToken); } public T Sum(Query query, string column) @@ -416,9 +429,9 @@ public T Sum(Query query, string column) return this.Aggregate(query, "sum", new[] { column }); } - public async Task SelectSumAsync(Query query, string column) + public async Task SelectSumAsync(Query query, string column, CancellationToken cancellationToken = default) { - return await this.SelectAggregateAsync(query, "sum", new[] { column }); + return await this.SelectAggregateAsync(query, "sum", new[] { column }, cancellationToken: cancellationToken); } public T Min(Query query, string column) @@ -426,9 +439,9 @@ public T Min(Query query, string column) return this.Aggregate(query, "min", new[] { column }); } - public async Task SelectMinAsync(Query query, string column) + public async Task SelectMinAsync(Query query, string column, CancellationToken cancellationToken = default) { - return await this.SelectAggregateAsync(query, "min", new[] { column }); + return await this.SelectAggregateAsync(query, "min", new[] { column }, cancellationToken: cancellationToken); } public T Max(Query query, string column) @@ -436,9 +449,9 @@ public T Max(Query query, string column) return this.Aggregate(query, "max", new[] { column }); } - public async Task SelectMaxAsync(Query query, string column) + public async Task SelectMaxAsync(Query query, string column, CancellationToken cancellationToken = default) { - return await this.SelectAggregateAsync(query, "max", new[] { column }); + return await this.SelectAggregateAsync(query, "max", new[] { column }, cancellationToken: cancellationToken); } public PaginationResult Paginate(Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null) @@ -476,7 +489,7 @@ public PaginationResult Paginate(Query query, int page, int perPage = 25, }; } - public async Task> PaginateAsync(Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null) + public async Task> PaginateAsync(Query query, int page, int perPage = 25, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { if (page < 1) { @@ -488,13 +501,13 @@ public async Task> PaginateAsync(Query query, int page, i throw new ArgumentException("PerPage param should be greater than or equal to 1", nameof(perPage)); } - var count = await SelectCountAsync(query.Clone(), null, transaction, timeout); + var count = await SelectCountAsync(query.Clone(), null, transaction, timeout, cancellationToken); IEnumerable list; if (count > 0) { - list = await GetAsync(query.Clone().ForPage(page, perPage), transaction, timeout); + list = await GetAsync(query.Clone().ForPage(page, perPage), transaction, timeout, cancellationToken); } else { @@ -527,7 +540,7 @@ public void Chunk( while (result.HasNext) { - result = result.Next(); + result = result.Next(transaction); if (!func(result.List, result.Page)) { return; @@ -540,10 +553,11 @@ public async Task ChunkAsync( int chunkSize, Func, int, bool> func, IDbTransaction transaction = null, - int? timeout = null + int? timeout = null, + CancellationToken cancellationToken = default ) { - var result = await this.PaginateAsync(query, 1, chunkSize); + var result = await this.PaginateAsync(query, 1, chunkSize, transaction, cancellationToken: cancellationToken); if (!func(result.List, 1)) { @@ -552,7 +566,7 @@ public async Task ChunkAsync( while (result.HasNext) { - result = result.Next(); + result = result.Next(transaction); if (!func(result.List, result.Page)) { return; @@ -568,7 +582,7 @@ public void Chunk(Query query, int chunkSize, Action, int> act while (result.HasNext) { - result = result.Next(); + result = result.Next(transaction); action(result.List, result.Page); } } @@ -578,16 +592,17 @@ public async Task ChunkAsync( int chunkSize, Action, int> action, IDbTransaction transaction = null, - int? timeout = null + int? timeout = null, + CancellationToken cancellationToken = default ) { - var result = await this.PaginateAsync(query, 1, chunkSize, transaction, timeout); + var result = await this.PaginateAsync(query, 1, chunkSize, transaction, timeout, cancellationToken); action(result.List, 1); while (result.HasNext) { - result = result.Next(); + result = result.Next(transaction); action(result.List, result.Page); } } @@ -602,14 +617,16 @@ public IEnumerable Select(string sql, object param = null, IDbTransaction ); } - public async Task> SelectAsync(string sql, object param = null, IDbTransaction transaction = null, int? timeout = null) + public async Task> SelectAsync(string sql, object param = null, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await this.Connection.QueryAsync( - sql, - param, + var commandDefinition = new CommandDefinition( + commandText: sql, + parameters: param, transaction: transaction, - commandTimeout: timeout ?? this.QueryTimeout - ); + commandTimeout: timeout ?? this.QueryTimeout, + cancellationToken: cancellationToken); + + return await this.Connection.QueryAsync(commandDefinition); } public IEnumerable Select(string sql, object param = null, IDbTransaction transaction = null, int? timeout = null) @@ -617,9 +634,9 @@ public IEnumerable Select(string sql, object param = null, IDbTransacti return this.Select(sql, param, transaction, timeout); } - public async Task> SelectAsync(string sql, object param = null, IDbTransaction transaction = null, int? timeout = null) + public async Task> SelectAsync(string sql, object param = null, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await this.SelectAsync(sql, param, transaction, timeout); + return await this.SelectAsync(sql, param, transaction, timeout, cancellationToken); } public int Statement(string sql, object param = null, IDbTransaction transaction = null, int? timeout = null) @@ -627,9 +644,15 @@ public int Statement(string sql, object param = null, IDbTransaction transaction return this.Connection.Execute(sql, param, transaction: transaction, commandTimeout: timeout ?? this.QueryTimeout); } - public async Task StatementAsync(string sql, object param = null, IDbTransaction transaction = null, int? timeout = null) + public async Task StatementAsync(string sql, object param = null, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await this.Connection.ExecuteAsync(sql, param, transaction: transaction, commandTimeout: timeout ?? this.QueryTimeout); + var commandDefinition = new CommandDefinition( + commandText: sql, + parameters: param, + transaction: transaction, + commandTimeout: timeout ?? this.QueryTimeout, + cancellationToken: cancellationToken); + return await this.Connection.ExecuteAsync(commandDefinition); } private static IEnumerable handleIncludes(Query query, IEnumerable result) @@ -734,7 +757,7 @@ private static IEnumerable handleIncludes(Query query, IEnumerable resu return dynamicResult.Cast(); } - private static async Task> handleIncludesAsync(Query query, IEnumerable result) + private static async Task> handleIncludesAsync(Query query, IEnumerable result, CancellationToken cancellationToken = default) { if (!result.Any()) { @@ -784,7 +807,7 @@ private static async Task> handleIncludesAsync(Query query, IE continue; } - var children = (await include.Query.WhereIn(include.ForeignKey, localIds).GetAsync()) + var children = (await include.Query.WhereIn(include.ForeignKey, localIds).GetAsync(cancellationToken: cancellationToken)) .Cast>() .Select(x => new Dictionary(x, StringComparer.OrdinalIgnoreCase)) .GroupBy(x => x[include.ForeignKey].ToString()) @@ -813,7 +836,7 @@ private static async Task> handleIncludesAsync(Query query, IE continue; } - var related = (await include.Query.WhereIn(include.LocalKey, foreignIds).GetAsync()) + var related = (await include.Query.WhereIn(include.LocalKey, foreignIds).GetAsync(cancellationToken: cancellationToken)) .Cast>() .Select(x => new Dictionary(x, StringComparer.OrdinalIgnoreCase)) .ToDictionary(x => x[include.LocalKey].ToString()); @@ -873,4 +896,4 @@ public void Dispose() GC.SuppressFinalize(this); } } -} +} \ No newline at end of file