Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSHARP-5529: Optimize grouping.First().X to not retrieve the entire $$ROOT #1653

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions
{
internal static class AstExpressionExtensions
{
public static bool IsConstant(this AstExpression expression, BsonValue value)
=> expression is AstConstantExpression constantExpression && constantExpression.Value.Equals(value);

public static bool IsInt32Constant(this AstExpression expression, out int value)
{
if (expression is AstConstantExpression constantExpression &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,15 @@ public override AstNode VisitFilterField(AstFilterField node)

public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
{
// { $getField : { field : <elementField>, input : { $firstOrLast : "$_elements" } } } => { __agg0 : { $firstOrLast : <rootField> } } + "$__agg0"
if (IsGetFieldChainOnFirstOrLastElement(node, out var firstOrLastOperator, out var rootFieldExpression))
{
var unaryAccumulatorOperator = firstOrLastOperator == AstUnaryOperator.First ? AstUnaryAccumulatorOperator.First : AstUnaryAccumulatorOperator.Last;
var accumulatorExpression = AstExpression.UnaryAccumulator(unaryAccumulatorOperator, rootFieldExpression);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
}

if (node.FieldName is AstConstantExpression constantFieldName &&
constantFieldName.Value.IsString &&
constantFieldName.Value.AsString == "_elements")
Expand All @@ -360,6 +369,32 @@ public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
}

return base.VisitGetFieldExpression(node);

bool IsGetFieldChainOnFirstOrLastElement(AstGetFieldExpression getFieldExpression, out AstUnaryOperator firstOrLastOperator, out AstExpression rootFieldExpression)
{
if (getFieldExpression.Input is AstGetFieldExpression innerGetFieldExpression &&
IsGetFieldChainOnFirstOrLastElement(innerGetFieldExpression, out firstOrLastOperator, out rootFieldExpression))
{
rootFieldExpression = AstExpression.GetField(rootFieldExpression, getFieldExpression.FieldName);
return true;
}

if (getFieldExpression.Input is AstUnaryExpression unaryExpression &&
unaryExpression.Operator is var unaryOperator &&
(unaryOperator is AstUnaryOperator.First or AstUnaryOperator.Last) &&
unaryExpression.Arg is AstGetFieldExpression innerMostGetFieldExpression &&
innerMostGetFieldExpression.Input.IsRootVar() &&
innerMostGetFieldExpression.FieldName.IsConstant("_elements"))
{
firstOrLastOperator = unaryOperator;
rootFieldExpression = AstExpression.GetField(AstExpression.RootVar, getFieldExpression.FieldName);
return true;
}

firstOrLastOperator = default;
rootFieldExpression = null;
return false;
}
}

public override AstNode VisitMapExpression(AstMapExpression node)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.Linq;
using MongoDB.Driver.TestHelpers;
using FluentAssertions;
using Xunit;

namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira;

public class CSharp5529Tests : LinqIntegrationTest<CSharp5529Tests.ClassFixture>
{
public CSharp5529Tests(ClassFixture fixture)
: base(fixture)
{
}

[Theory]
[InlineData(1, 1, """{ $group: { _id : 1, __agg0 : { $first : "$X" } } }""", 1)]
[InlineData(1, 2, """{ $group: { _id : 1, __agg0 : { $last : "$X" } } }""", 2)]
[InlineData(2, 1, """{ $group: { _id : 1, __agg0 : { $first : "$D.Y" } } }""", 11)]
[InlineData(2, 2, """{ $group: { _id : 1, __agg0 : { $last : "$D.Y" } } }""", 22)]
[InlineData(3, 1, """{ $group: { _id : 1, __agg0 : { $first : "$D.E.Z" } } }""", 111)]
[InlineData(3, 2, """{ $group: { _id : 1, __agg0 : { $last : "$D.E.Z" } } }""", 222)]
public void First_or_Last_optimization_should_work(int level, int firstOrLast, string expectedGroupStage, int expectedResult)
{
var collection = Fixture.Collection;

var queryable = (level, firstOrLast) switch
{
(1, 1) => collection.Aggregate().Group(x => 1, g => g.First().X),
(1, 2) => collection.Aggregate().Group(x => 1, g => g.Last().X),
(2, 1) => collection.Aggregate().Group(x => 1, g => g.First().D.Y),
(2, 2) => collection.Aggregate().Group(x => 1, g => g.Last().D.Y),
(3, 1) => collection.Aggregate().Group(x => 1, g => g.First().D.E.Z),
(3, 2) => collection.Aggregate().Group(x => 1, g => g.Last().D.E.Z),
_ => throw new ArgumentException()
};

var stages = Translate(collection,queryable);
AssertStages(
stages,
expectedGroupStage,
"""{ $project : { _v : "$__agg0", _id : 0 } }""");

var result = queryable.Single();
result.Should().Be(expectedResult);
}
public class C
{
public int Id { get; set; }
public int X { get; set; }

public D D { get; set; }
}

public class D
{
public E E { get; set; }
public int Y { get; set; }
}

public class E
{
public int Z { get; set; }
}

public sealed class ClassFixture : MongoCollectionFixture<C>
{
protected override IEnumerable<C> InitialData =>
[
new C { Id = 1, X = 1, D = new D { E = new E { Z = 111 }, Y = 11 } },
new C { Id = 2, X = 2, D = new D { E = new E { Z = 222 }, Y = 22 } },
];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ public void GroupBy_select_anonymous_type_method()

Assert(query,
2,
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }",
"{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }");
"{ $group: { _id: '$A', __agg0: { $first: '$B'} } }",
"{ $project: { Key: '$_id', FirstB: '$__agg0', _id: 0 } }");

query = CreateQuery()
.GroupBy(x => x.A)
Expand All @@ -434,8 +434,8 @@ group p by p.A into g

Assert(query,
2,
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }",
"{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }");
"{ $group: { _id: '$A', __agg0: { $first: '$B'} } }",
"{ $project: { Key: '$_id', FirstB: '$__agg0', _id: 0 } }");

query = from p in CreateQuery()
group p by p.A into g
Expand Down Expand Up @@ -484,9 +484,9 @@ public void GroupBy_where_select_anonymous_type_with_duplicate_accumulators_meth

Assert(query,
1,
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }",
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'}, __agg1 : { $first : '$B' } } }",
"{ $match: { '__agg0.B' : 'Balloon' } }",
"{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }");
"{ $project: { Key: '$_id', FirstB: '$__agg1', _id: 0 } }");

query = CreateQuery()
.GroupBy(x => x.A)
Expand All @@ -511,9 +511,9 @@ where g.First().B == "Balloon"

Assert(query,
1,
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }",
"{ $group: { _id: '$A', __agg0: { $first: '$$ROOT' }, __agg1 : { $first : '$B' } } }",
"{ $match: { '__agg0.B' : 'Balloon' } }",
"{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }");
"{ $project: { Key: '$_id', FirstB: '$__agg1', _id: 0 } }");
}
#endif

Expand All @@ -525,8 +525,8 @@ public void GroupBy_with_resultSelector_anonymous_type_method()

Assert(query,
2,
"{ $group: { _id : '$A', __agg0 : { $first: '$$ROOT'} } }",
"{ $project : { Key : '$_id', FirstB : '$__agg0.B', _id : 0 } }");
"{ $group: { _id : '$A', __agg0 : { $first: '$B'} } }",
"{ $project : { Key : '$_id', FirstB : '$__agg0', _id : 0 } }");

query = CreateQuery()
.GroupBy(x => x.A, (k, s) => new { Key = k, FirstB = s.Select(x => x.B).First() });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ public void Should_translate_using_non_anonymous_type_with_default_constructor()

AssertStages(
result.Stages,
"{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }",
"{ $project : { Property : '$_id', Field : '$__agg0.B', _id : 0 } }");
"{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }",
"{ $project : { Property : '$_id', Field : '$__agg0', _id : 0 } }");

result.Value.Property.Should().Be("Amazing");
result.Value.Field.Should().Be("Baby");
Expand All @@ -53,8 +53,8 @@ public void Should_translate_using_non_anonymous_type_with_parameterized_constru

AssertStages(
result.Stages,
"{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }",
"{ $project : { Property : '$_id', Field : '$__agg0.B', _id : 0 } }");
"{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }",
"{ $project : { Property : '$_id', Field : '$__agg0', _id : 0 } }");

result.Value.Property.Should().Be("Amazing");
result.Value.Field.Should().Be("Baby");
Expand Down Expand Up @@ -236,8 +236,8 @@ public void Should_translate_first_with_normalization()

AssertStages(
result.Stages,
"{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }",
"{ $project : { B : '$__agg0.B', _id : 0 } }");
"{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }",
"{ $project : { B : '$__agg0', _id : 0 } }");

result.Value.B.Should().Be("Baby");
}
Expand All @@ -262,8 +262,8 @@ public void Should_translate_last_with_normalization()

AssertStages(
result.Stages,
"{ $group : { _id : '$A', __agg0 : { $last : '$$ROOT' } } }",
"{ $project : { B : '$__agg0.B', _id : 0 } }");
"{ $group : { _id : '$A', __agg0 : { $last : '$B' } } }",
"{ $project : { B : '$__agg0', _id : 0 } }");

result.Value.B.Should().Be("Baby");
}
Expand Down Expand Up @@ -492,8 +492,8 @@ public void Should_translate_complex_selector()
_id : '$A',
__agg0 : { $sum : 1 },
__agg1 : { $sum : { $add : ['$C.E.F', '$C.E.H'] } },
__agg2 : { $first : '$$ROOT' },
__agg3 : { $last : '$$ROOT' },
__agg2 : { $first : '$B' },
__agg3 : { $last : '$K' },
__agg4 : { $min : { $add : ['$C.E.F', '$C.E.H'] } },
__agg5 : { $max : { $add : ['$C.E.F', '$C.E.H'] } }
}
Expand All @@ -503,8 +503,8 @@ public void Should_translate_complex_selector()
$project : {
Count : '$__agg0',
Sum : '$__agg1',
First : '$__agg2.B',
Last : '$__agg3.K',
First : '$__agg2',
Last : '$__agg3',
Min : '$__agg4',
Max : '$__agg5',
_id : 0
Expand Down
8 changes: 4 additions & 4 deletions tests/MongoDB.Driver.Tests/Samples/AggregationSample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ where g.Sum(x => x.Population) > 20000
select new { State = g.Key, TotalPopulation = g.Sum(x => x.Population) };

var stages = Linq3TestHelpers.Translate(collection, queryable);
var expectedStages =
var expectedStages =
new[]
{
"{ $group : { _id : '$state', __agg0 : { $sum : '$pop' } } }",
Expand Down Expand Up @@ -173,13 +173,13 @@ public async Task Largest_and_smallest_cities_by_state()
.SortBy(x => x.State);

var pipelineTranslation = pipeline.ToString();
var expectedTranslation =
var expectedTranslation =
"aggregate([" +
"{ \"$group\" : { \"_id\" : { \"State\" : \"$state\", \"City\" : \"$city\" }, \"__agg0\" : { \"$sum\" : \"$pop\" } } }, " +
"{ \"$project\" : { \"StateAndCity\" : \"$_id\", \"Population\" : \"$__agg0\", \"_id\" : 0 } }, " +
"{ \"$sort\" : { \"Population\" : 1 } }, " +
"{ \"$group\" : { \"_id\" : \"$StateAndCity.State\", \"__agg0\" : { \"$last\" : \"$$ROOT\" }, \"__agg1\" : { \"$first\" : \"$$ROOT\" } } }, " +
"{ \"$project\" : { \"State\" : \"$_id\", \"BiggestCity\" : \"$__agg0.StateAndCity.City\", \"BiggestPopulation\" : \"$__agg0.Population\", \"SmallestCity\" : \"$__agg1.StateAndCity.City\", \"SmallestPopulation\" : \"$__agg1.Population\", \"_id\" : 0 } }, " +
"{ \"$group\" : { \"_id\" : \"$StateAndCity.State\", \"__agg0\" : { \"$last\" : \"$StateAndCity.City\" }, \"__agg1\" : { \"$last\" : \"$Population\" }, \"__agg2\" : { \"$first\" : \"$StateAndCity.City\" }, \"__agg3\" : { \"$first\" : \"$Population\" } } }, " +
"{ \"$project\" : { \"State\" : \"$_id\", \"BiggestCity\" : \"$__agg0\", \"BiggestPopulation\" : \"$__agg1\", \"SmallestCity\" : \"$__agg2\", \"SmallestPopulation\" : \"$__agg3\", \"_id\" : 0 } }, " +
"{ \"$project\" : { \"State\" : \"$State\", \"BiggestCity\" : { \"Name\" : \"$BiggestCity\", \"Population\" : \"$BiggestPopulation\" }, \"SmallestCity\" : { \"Name\" : \"$SmallestCity\", \"Population\" : \"$SmallestPopulation\" }, \"_id\" : 0 } }, " +
"{ \"$sort\" : { \"State\" : 1 } }])";
pipelineTranslation.Should().Be(expectedTranslation);
Expand Down