Skip to content

Commit dbaabb6

Browse files
committed
CSHARP-5563: Create AstExpression extension methods to make testing for constants easier.
1 parent 7cdd013 commit dbaabb6

File tree

8 files changed

+125
-90
lines changed

8 files changed

+125
-90
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs

+21-22
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,7 @@ public static AstExpression Convert(
267267
Ensure.IsNotNull(input, nameof(input));
268268
Ensure.IsNotNull(to, nameof(to));
269269

270-
if (to is AstConstantExpression toConstantExpression &&
271-
(toConstantExpression.Value as BsonString)?.Value is { } toValue &&
270+
if (to.IsStringConstant(out var toValue) &&
272271
subType == null &&
273272
byteOrder == null &&
274273
format == null &&
@@ -375,26 +374,26 @@ public static AstExpression DerivativeOrIntegralWindowExpression(AstDerivativeOr
375374

376375
public static AstExpression Divide(AstExpression arg1, AstExpression arg2)
377376
{
378-
if (arg1 is AstConstantExpression constant1 && arg2 is AstConstantExpression constant2)
377+
if (arg1.IsConstant(out var constant1) && arg2.IsConstant(out var constant2))
379378
{
380379
return Divide(constant1, constant2);
381380
}
382381

383382
return new AstBinaryExpression(AstBinaryOperator.Divide, arg1, arg2);
384383

385-
static AstExpression Divide(AstConstantExpression constant1, AstConstantExpression constant2)
384+
static AstExpression Divide(BsonValue constant1, BsonValue constant2)
386385
{
387-
return (constant1.Value.BsonType, constant2.Value.BsonType) switch
386+
return (constant1.BsonType, constant2.BsonType) switch
388387
{
389-
(BsonType.Double, BsonType.Double) => constant1.Value.AsDouble / constant2.Value.AsDouble,
390-
(BsonType.Double, BsonType.Int32) => constant1.Value.AsDouble / constant2.Value.AsInt32,
391-
(BsonType.Double, BsonType.Int64) => constant1.Value.AsDouble / constant2.Value.AsInt64,
392-
(BsonType.Int32, BsonType.Double) => constant1.Value.AsInt32 / constant2.Value.AsDouble,
393-
(BsonType.Int32, BsonType.Int32) => (double)constant1.Value.AsInt32 / constant2.Value.AsInt32,
394-
(BsonType.Int32, BsonType.Int64) => (double)constant1.Value.AsInt32 / constant2.Value.AsInt64,
395-
(BsonType.Int64, BsonType.Double) => constant1.Value.AsInt64 / constant2.Value.AsDouble,
396-
(BsonType.Int64, BsonType.Int32) => (double)constant1.Value.AsInt64 / constant2.Value.AsInt32,
397-
(BsonType.Int64, BsonType.Int64) => (double)constant1.Value.AsInt64 / constant2.Value.AsInt64,
388+
(BsonType.Double, BsonType.Double) => constant1.AsDouble / constant2.AsDouble,
389+
(BsonType.Double, BsonType.Int32) => constant1.AsDouble / constant2.AsInt32,
390+
(BsonType.Double, BsonType.Int64) => constant1.AsDouble / constant2.AsInt64,
391+
(BsonType.Int32, BsonType.Double) => constant1.AsInt32 / constant2.AsDouble,
392+
(BsonType.Int32, BsonType.Int32) => (double)constant1.AsInt32 / constant2.AsInt32,
393+
(BsonType.Int32, BsonType.Int64) => (double)constant1.AsInt32 / constant2.AsInt64,
394+
(BsonType.Int64, BsonType.Double) => constant1.AsInt64 / constant2.AsDouble,
395+
(BsonType.Int64, BsonType.Int32) => (double)constant1.AsInt64 / constant2.AsInt32,
396+
(BsonType.Int64, BsonType.Int64) => (double)constant1.AsInt64 / constant2.AsInt64,
398397
_ => new AstBinaryExpression(AstBinaryOperator.Divide, constant1, constant2)
399398
};
400399
}
@@ -829,9 +828,9 @@ public static AstExpression StrLenBytes(AstExpression arg)
829828

830829
public static AstExpression StrLenCP(AstExpression arg)
831830
{
832-
if (arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.String)
831+
if (arg.IsStringConstant(out var stringConstant))
833832
{
834-
var value = constantExpression.Value.AsString.Length;
833+
var value = stringConstant.Length;
835834
return new AstConstantExpression(value);
836835
}
837836
return new AstUnaryExpression(AstUnaryOperator.StrLenCP, arg);
@@ -890,9 +889,9 @@ public static AstExpression Switch(IEnumerable<(AstExpression Case, AstExpressio
890889

891890
public static AstExpression ToLower(AstExpression arg)
892891
{
893-
if (arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.String)
892+
if (arg.IsStringConstant(out var stringConstant))
894893
{
895-
var value = constantExpression.Value.AsString.ToLowerInvariant();
894+
var value = stringConstant.ToLowerInvariant();
896895
return new AstConstantExpression(value);
897896
}
898897

@@ -906,9 +905,9 @@ public static AstExpression ToString(AstExpression arg)
906905

907906
public static AstExpression ToUpper(AstExpression arg)
908907
{
909-
if (arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.String)
908+
if (arg.IsStringConstant(out var stringConstant))
910909
{
911-
var value = constantExpression.Value.AsString.ToUpperInvariant();
910+
var value = stringConstant.ToUpperInvariant();
912911
return new AstConstantExpression(value);
913912
}
914913

@@ -985,7 +984,7 @@ public static AstExpression Zip(IEnumerable<AstExpression> inputs, bool? useLong
985984
// private static methods
986985
private static bool AllArgsAreConstantBools(AstExpression[] args, out List<bool> values)
987986
{
988-
if (args.All(arg => arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.Boolean))
987+
if (args.All(arg => arg.IsBooleanConstant()))
989988
{
990989
values = args.Select(arg => ((AstConstantExpression)arg).Value.AsBoolean).ToList();
991990
return true;
@@ -997,7 +996,7 @@ private static bool AllArgsAreConstantBools(AstExpression[] args, out List<bool>
997996

998997
private static bool AllArgsAreConstantInt32s(AstExpression[] args, out List<int> values)
999998
{
1000-
if (args.All(arg => arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.Int32))
999+
if (args.All(arg => arg.IsInt32Constant()))
10011000
{
10021001
values = args.Select(arg => ((AstConstantExpression)arg).Value.AsInt32).ToList();
10031002
return true;

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpressionExtensions.cs

+71-8
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,89 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions
1919
{
2020
internal static class AstExpressionExtensions
2121
{
22-
public static bool IsInt32Constant(this AstExpression expression, out int value)
22+
public static bool IsBooleanConstant(this AstExpression expression)
23+
=> expression.IsConstant<BsonBoolean>(out _);
24+
25+
public static bool IsBooleanConstant(this AstExpression expression, out bool booleanConstant)
26+
{
27+
if (expression.IsConstant<BsonBoolean>(out var bsonBooleanConstant))
28+
{
29+
booleanConstant = bsonBooleanConstant.Value;
30+
return true;
31+
}
32+
33+
booleanConstant = default;
34+
return false;
35+
}
36+
37+
public static bool IsBsonNull(this AstExpression expression)
38+
=> expression.IsConstant(out var constant) && constant.IsBsonNull;
39+
40+
public static bool IsConstant(this AstExpression expression, out BsonValue constant)
2341
{
24-
if (expression is AstConstantExpression constantExpression &&
25-
constantExpression.Value is BsonInt32 bsonInt32)
42+
if (expression is AstConstantExpression constantExpression)
2643
{
27-
value = bsonInt32.Value;
44+
constant = constantExpression.Value;
2845
return true;
2946
}
3047

31-
value = default;
48+
constant = null;
3249
return false;
33-
}
50+
}
51+
52+
public static bool IsConstant<TBsonValue>(this AstExpression expression, out TBsonValue constant)
53+
where TBsonValue : BsonValue
54+
{
55+
if (expression.IsConstant(out var bsonValueConstant) && bsonValueConstant is TBsonValue derivedBsonValueConstant)
56+
{
57+
constant = derivedBsonValueConstant;
58+
return true;
59+
}
60+
61+
constant = null;
62+
return false;
63+
}
64+
65+
public static bool IsInt32Constant(this AstExpression expression)
66+
=> expression.IsConstant<BsonInt32>(out _);
67+
68+
public static bool IsInt32Constant(this AstExpression expression, int comparand)
69+
=> expression.IsInt32Constant(out var int32Constant) && int32Constant == comparand;
70+
71+
public static bool IsInt32Constant(this AstExpression expression, out int int32Constant)
72+
{
73+
if (expression.IsConstant<BsonInt32>(out var bsonInt32Constant))
74+
{
75+
int32Constant = bsonInt32Constant.Value;
76+
return true;
77+
}
78+
79+
int32Constant = default;
80+
return false;
81+
}
3482

3583
public static bool IsMaxInt32(this AstExpression expression)
36-
=> expression.IsInt32Constant(out var value) && value == int.MaxValue;
84+
=> expression.IsInt32Constant(out var int32Constant) && int32Constant == int.MaxValue;
3785

3886
public static bool IsRootVar(this AstExpression expression)
3987
=> expression is AstVarExpression varExpression && varExpression.Name == "ROOT" && varExpression.IsCurrent;
4088

89+
public static bool IsStringConstant(this AstExpression expression, string comparand)
90+
=> expression.IsStringConstant(out var stringConstant) && stringConstant == comparand;
91+
92+
public static bool IsStringConstant(this AstExpression expression, out string stringConstant)
93+
{
94+
if (expression.IsConstant<BsonString>(out var bsonStringConstant))
95+
{
96+
stringConstant = bsonStringConstant.Value;
97+
return true;
98+
}
99+
100+
stringConstant = default;
101+
return false;
102+
}
103+
41104
public static bool IsZero(this AstExpression expression)
42-
=> expression is AstConstantExpression constantExpression && constantExpression.Value == 0;
105+
=> expression.IsConstant(out var constant) && constant == 0; // works for all numeric BSON types
43106
}
44107
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstGetFieldExpression.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,9 @@ public override string ConvertToFieldPath()
6363

6464
public bool HasSafeFieldName(out string fieldName)
6565
{
66-
if (_fieldName is AstConstantExpression constantFieldName &&
67-
constantFieldName.Value is BsonString stringfieldName)
66+
if (_fieldName.IsStringConstant(out var constantFieldName))
6867
{
69-
fieldName = stringfieldName.Value;
68+
fieldName = constantFieldName;
7069
if (fieldName.Length > 0 && fieldName[0] != '$' && !fieldName.Contains('.'))
7170
{
7271
return true;

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs

+6-18
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,7 @@ public override AstNode VisitFilterField(AstFilterField node)
352352

353353
public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
354354
{
355-
if (node.FieldName is AstConstantExpression constantFieldName &&
356-
constantFieldName.Value.IsString &&
357-
constantFieldName.Value.AsString == "_elements")
355+
if (node.FieldName.IsStringConstant("_elements"))
358356
{
359357
throw new UnableToRemoveReferenceToElementsException();
360358
}
@@ -366,9 +364,7 @@ public override AstNode VisitMapExpression(AstMapExpression node)
366364
{
367365
// { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element) } } + "$__agg0"
368366
if (node.Input is AstGetFieldExpression mapInputGetFieldExpression &&
369-
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
370-
mapInputconstantFieldExpression.Value.IsString &&
371-
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
367+
mapInputGetFieldExpression.FieldName.IsStringConstant("_elements") &&
372368
mapInputGetFieldExpression.Input.IsRootVar())
373369
{
374370
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element));
@@ -386,9 +382,7 @@ public override AstNode VisitPickExpression(AstPickExpression node)
386382
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0"
387383
if (node.Source is AstGetFieldExpression getFieldExpression &&
388384
getFieldExpression.Input.IsRootVar() &&
389-
getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpression &&
390-
constantFieldNameExpression.Value.IsString &&
391-
constantFieldNameExpression.Value.AsString == "_elements")
385+
getFieldExpression.FieldName.IsStringConstant("_elements"))
392386
{
393387
var @operator = node.Operator.ToAccumulatorOperator();
394388
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element));
@@ -425,9 +419,7 @@ bool TryOptimizeSizeOfElements(out AstExpression optimizedExpression)
425419
if (node.Operator == AstUnaryOperator.Size)
426420
{
427421
if (node.Arg is AstGetFieldExpression argGetFieldExpression &&
428-
argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpression &&
429-
constantFieldNameExpression.Value.IsString &&
430-
constantFieldNameExpression.Value.AsString == "_elements")
422+
argGetFieldExpression.FieldName.IsStringConstant("_elements"))
431423
{
432424
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1);
433425
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
@@ -445,9 +437,7 @@ bool TryOptimizeAccumulatorOfElements(out AstExpression optimizedExpression)
445437
// { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0"
446438
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
447439
node.Arg is AstGetFieldExpression getFieldExpression &&
448-
getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameExpression &&
449-
getFieldConstantFieldNameExpression.Value.IsString &&
450-
getFieldConstantFieldNameExpression.Value == "_elements" &&
440+
getFieldExpression.FieldName.IsStringConstant("_elements") &&
451441
getFieldExpression.Input.IsRootVar())
452442
{
453443
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element);
@@ -467,9 +457,7 @@ bool TryOptimizeAccumulatorOfMappedElements(out AstExpression optimizedExpressio
467457
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
468458
node.Arg is AstMapExpression mapExpression &&
469459
mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression &&
470-
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
471-
mapInputconstantFieldExpression.Value.IsString &&
472-
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
460+
mapInputGetFieldExpression.FieldName.IsStringConstant("_elements") &&
473461
mapInputGetFieldExpression.Input.IsRootVar())
474462
{
475463
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element));

0 commit comments

Comments
 (0)