Skip to content

Commit c2fff44

Browse files
authored
Exception Handler support (#78773)
* Update EH tests to run with runtime async * Handle non-null exception filter prologues in the spill sequencer * Add more testing to show current incorrect behavior * Unskip ConditionalFacts that do not need to be skipped. * Handle ensuring that the method remains valid, even when there is an `await` in a finally section. Add signifcant testing of `await using`. * Fix baselines * Support `await foreach` and add runtime async verification to existing tests. * Remove unnecessary generic * Failing tests, add async void test suggestion * CI failures * Add additional test * Test fixes * Remove implemented PROTOTYPE, add assertion on behavior. * Update to SpillSequenceSpiller after some more debugging and tightening the assertion * Fix nullref * Enable nullable for VisitCatchBlock
1 parent a98e3fd commit c2fff44

16 files changed

+10068
-319
lines changed

src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ private BoundForEachStatement BindForEachPartsWorker(BindingDiagnosticBag diagno
263263
var placeholder = new BoundAwaitableValuePlaceholder(expr, builder.MoveNextInfo?.Method.ReturnType ?? CreateErrorType());
264264
awaitInfo = BindAwaitInfo(placeholder, expr, diagnostics, ref hasErrors);
265265

266-
if (!hasErrors && awaitInfo.GetResult?.ReturnType.SpecialType != SpecialType.System_Boolean)
266+
if (!hasErrors && (awaitInfo.GetResult ?? awaitInfo.RuntimeAsyncAwaitMethod)?.ReturnType.SpecialType != SpecialType.System_Boolean)
267267
{
268268
diagnostics.Add(ErrorCode.ERR_BadGetAsyncEnumerator, expr.Location, getEnumeratorMethod.ReturnTypeWithAnnotations, getEnumeratorMethod);
269269
hasErrors = true;

src/Compilers/CSharp/Portable/BoundTree/BoundAwaitableInfo.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,7 @@ private partial void Validate()
3434
break;
3535
}
3636
}
37+
38+
Debug.Assert(GetAwaiter is not null || RuntimeAsyncAwaitMethod is not null || IsDynamic || HasErrors);
3739
}
3840
}

src/Compilers/CSharp/Portable/CodeGen/CodeGenerator.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,10 @@ private void HandleReturn()
321321
{
322322
_builder.MarkLabel(s_returnLabel);
323323

324-
Debug.Assert(_method.ReturnsVoid == (_returnTemp == null));
324+
Debug.Assert(_method.ReturnsVoid == (_returnTemp == null)
325+
|| (_method.IsAsync
326+
&& _module.Compilation.IsRuntimeAsyncEnabledIn(_method)
327+
&& ((InternalSpecialType)_method.ReturnType.ExtendedSpecialType) is InternalSpecialType.System_Threading_Tasks_Task or InternalSpecialType.System_Threading_Tasks_ValueTask));
325328

326329
if (_emitPdbSequencePoints && !_method.IsIterator && !_method.IsAsync)
327330
{

src/Compilers/CSharp/Portable/Lowering/AsyncRewriter/AsyncExceptionHandlerRewriter.cs

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ internal sealed class AsyncExceptionHandlerRewriter : BoundTreeRewriterWithStack
2828
private AwaitCatchFrame _currentAwaitCatchFrame;
2929
private AwaitFinallyFrame _currentAwaitFinallyFrame = new AwaitFinallyFrame();
3030
private bool _inCatchWithoutAwaits;
31+
private bool _needsFinalThrow;
3132

3233
private AsyncExceptionHandlerRewriter(
3334
MethodSymbol containingMethod,
@@ -129,9 +130,45 @@ public static BoundStatement Rewrite(
129130
var rewriter = new AsyncExceptionHandlerRewriter(containingSymbol, containingType, factory, analysis);
130131
var loweredStatement = (BoundStatement)rewriter.Visit(statement);
131132

133+
loweredStatement = rewriter.FinalizeMethodBody(loweredStatement);
134+
132135
return loweredStatement;
133136
}
134137

138+
private BoundStatement FinalizeMethodBody(BoundStatement loweredStatement)
139+
{
140+
if (loweredStatement == null)
141+
{
142+
return null;
143+
}
144+
145+
// When we add a `switch (pendingBranch)` to the end of the try block,
146+
// this can result in a method body that cannot be proven to terminate.
147+
// While we can technically prove it by doing a full data flow analysis,
148+
// this is effectively the halting problem, and the runtime will not do
149+
// this analysis. The resulting IL will be technically invalid, and if it's
150+
// not wrapped in another state machine (a la the compiler async rewriter),
151+
// the runtime will refuse to load it. For runtime async, where we are effectively
152+
// emitting the result of this rewriter directly, we need to ensure that
153+
// we always emit a throw at the end of the try block when the switch is present.
154+
// This ensures that the method can be proven to terminate, and the runtime will
155+
// accept it. This throw will never be reached, and we could potentially do a
156+
// more sophisticated analysis to determine if it is needed by pushing control
157+
// flow analysis through the bound nodes, see https://github.com/dotnet/roslyn/pull/78970.
158+
// This is risky, however, and for now we are taking the conservative approach
159+
// of always emitting the throw.
160+
BoundStatement result = loweredStatement;
161+
if (_needsFinalThrow)
162+
{
163+
result = _F.Block(
164+
loweredStatement,
165+
_F.Throw(_F.Null(_F.SpecialType(SpecialType.System_Object)))
166+
);
167+
}
168+
169+
return result;
170+
}
171+
135172
public override BoundNode VisitTryStatement(BoundTryStatement node)
136173
{
137174
var tryStatementSyntax = node.Syntax;
@@ -354,6 +391,7 @@ private BoundStatement UnpendBranches(
354391
cases.Add(caseStatement);
355392
}
356393

394+
_needsFinalThrow = true;
357395
return _F.Switch(_F.Local(pendingBranchVar), cases.ToImmutableAndFree());
358396
}
359397

@@ -402,22 +440,37 @@ public override BoundNode VisitReturnStatement(BoundReturnStatement node)
402440

403441
private BoundStatement UnpendException(LocalSymbol pendingExceptionLocal)
404442
{
443+
// If this is runtime async, we don't need to create a second local for the exception,
444+
// as the pendingExceptionLocal will not be hoisted to a state machine by a future rewrite.
445+
if (_F.Compilation.IsRuntimeAsyncEnabledIn(_F.CurrentFunction))
446+
{
447+
// pendingExceptionLocal is already an object
448+
// so we can just use it directly
449+
return checkAndThrow(pendingExceptionLocal);
450+
}
451+
405452
// create a temp.
406453
// pendingExceptionLocal will certainly be captured, no need to access it over and over.
407454
LocalSymbol obj = _F.SynthesizedLocal(_F.SpecialType(SpecialType.System_Object));
408455
var objInit = _F.Assignment(_F.Local(obj), _F.Local(pendingExceptionLocal));
409456

410457
// throw pendingExceptionLocal;
411-
BoundStatement rethrow = Rethrow(obj);
412-
413458
return _F.Block(
414459
ImmutableArray.Create<LocalSymbol>(obj),
415460
objInit,
416-
_F.If(
417-
_F.ObjectNotEqual(
418-
_F.Local(obj),
419-
_F.Null(obj.Type)),
420-
rethrow));
461+
checkAndThrow(obj));
462+
463+
BoundStatement checkAndThrow(LocalSymbol obj)
464+
{
465+
BoundStatement rethrow = Rethrow(obj);
466+
467+
BoundStatement checkAndThrow = _F.If(
468+
_F.ObjectNotEqual(
469+
_F.Local(obj),
470+
_F.Null(obj.Type)),
471+
rethrow);
472+
return checkAndThrow;
473+
}
421474
}
422475

423476
private BoundStatement Rethrow(LocalSymbol obj)
@@ -706,14 +759,24 @@ public override BoundNode VisitLambda(BoundLambda node)
706759
{
707760
var oldContainingSymbol = _F.CurrentFunction;
708761
var oldAwaitFinallyFrame = _currentAwaitFinallyFrame;
762+
var oldNeedsFinalThrow = _needsFinalThrow;
709763

710764
_F.CurrentFunction = node.Symbol;
711765
_currentAwaitFinallyFrame = new AwaitFinallyFrame();
766+
_needsFinalThrow = false;
712767

713-
var result = base.VisitLambda(node);
768+
var result = (BoundLambda)base.VisitLambda(node);
769+
result = result.Update(
770+
result.UnboundLambda,
771+
result.Symbol,
772+
(BoundBlock)FinalizeMethodBody(result.Body),
773+
node.Diagnostics,
774+
node.Binder,
775+
node.Type);
714776

715777
_F.CurrentFunction = oldContainingSymbol;
716778
_currentAwaitFinallyFrame = oldAwaitFinallyFrame;
779+
_needsFinalThrow = oldNeedsFinalThrow;
717780

718781
return result;
719782
}
@@ -722,14 +785,18 @@ public override BoundNode VisitLocalFunctionStatement(BoundLocalFunctionStatemen
722785
{
723786
var oldContainingSymbol = _F.CurrentFunction;
724787
var oldAwaitFinallyFrame = _currentAwaitFinallyFrame;
788+
var oldNeedsFinalThrow = _needsFinalThrow;
725789

726790
_F.CurrentFunction = node.Symbol;
727791
_currentAwaitFinallyFrame = new AwaitFinallyFrame();
792+
_needsFinalThrow = false;
728793

729-
var result = base.VisitLocalFunctionStatement(node);
794+
var result = (BoundLocalFunctionStatement)base.VisitLocalFunctionStatement(node);
795+
result = result.Update(node.Symbol, (BoundBlock)FinalizeMethodBody(result.Body), (BoundBlock)FinalizeMethodBody(result.ExpressionBody));
730796

731797
_F.CurrentFunction = oldContainingSymbol;
732798
_currentAwaitFinallyFrame = oldAwaitFinallyFrame;
799+
_needsFinalThrow = oldNeedsFinalThrow;
733800

734801
return result;
735802
}

src/Compilers/CSharp/Portable/Lowering/AsyncRewriter/RuntimeAsyncRewriter.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,17 @@ public static BoundStatement Rewrite(
2222
return node;
2323
}
2424

25-
// PROTOTYPE: try/finally rewriting
2625
// PROTOTYPE: struct lifting
27-
var rewriter = new RuntimeAsyncRewriter(compilationState.Compilation, new SyntheticBoundNodeFactory(method, node.Syntax, compilationState, diagnostics));
26+
var rewriter = new RuntimeAsyncRewriter(new SyntheticBoundNodeFactory(method, node.Syntax, compilationState, diagnostics));
2827
var result = (BoundStatement)rewriter.Visit(node);
2928
return SpillSequenceSpiller.Rewrite(result, method, compilationState, diagnostics);
3029
}
3130

32-
private readonly CSharpCompilation _compilation;
3331
private readonly SyntheticBoundNodeFactory _factory;
3432
private readonly Dictionary<BoundAwaitableValuePlaceholder, BoundExpression> _placeholderMap;
3533

36-
private RuntimeAsyncRewriter(CSharpCompilation compilation, SyntheticBoundNodeFactory factory)
34+
private RuntimeAsyncRewriter(SyntheticBoundNodeFactory factory)
3735
{
38-
_compilation = compilation;
3936
_factory = factory;
4037
_placeholderMap = [];
4138
}

src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_ForEachStatement.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,14 @@ private BoundStatement RewriteForEachEnumerator(
223223
var disposalFinallyBlock = GetDisposalFinallyBlock(forEachSyntax, enumeratorInfo, enumeratorType, boundEnumeratorVar, out var hasAsyncDisposal);
224224
if (isAsync)
225225
{
226-
Debug.Assert(awaitableInfo is { GetResult: { } });
226+
Debug.Assert(awaitableInfo is { GetResult: not null } or { RuntimeAsyncAwaitMethod: not null });
227227

228228
// We need to be sure that when the disposal isn't async we reserve an unused state machine state number for it,
229229
// so that await foreach always produces 2 state machine states: one for MoveNextAsync and the other for DisposeAsync.
230230
// Otherwise, EnC wouldn't be able to map states when the disposal changes from having async dispose to not, or vice versa.
231231
var debugInfo = new BoundAwaitExpressionDebugInfo(s_moveNextAsyncAwaitId, ReservedStateMachineCount: (byte)(hasAsyncDisposal ? 0 : 1));
232232

233-
rewrittenCondition = RewriteAwaitExpression(forEachSyntax, rewrittenCondition, awaitableInfo, awaitableInfo.GetResult.ReturnType, debugInfo, used: true);
233+
rewrittenCondition = RewriteAwaitExpression(forEachSyntax, rewrittenCondition, awaitableInfo, (awaitableInfo.GetResult ?? awaitableInfo.RuntimeAsyncAwaitMethod)!.ReturnType, debugInfo, used: true);
234234
}
235235

236236
BoundStatement whileLoop = RewriteWhileStatement(

src/Compilers/CSharp/Portable/Lowering/SpillSequenceSpiller.cs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -704,15 +704,24 @@ public override BoundNode VisitYieldReturnStatement(BoundYieldReturnStatement no
704704
return UpdateStatement(builder, node.Update(expression));
705705
}
706706

707+
#nullable enable
707708
public override BoundNode VisitCatchBlock(BoundCatchBlock node)
708709
{
709-
BoundExpression exceptionSourceOpt = (BoundExpression)this.Visit(node.ExceptionSourceOpt);
710+
BoundExpression? exceptionSourceOpt = (BoundExpression?)this.Visit(node.ExceptionSourceOpt);
710711
var locals = node.Locals;
711712

712713
var exceptionFilterPrologueOpt = node.ExceptionFilterPrologueOpt;
713-
Debug.Assert(exceptionFilterPrologueOpt is null); // it is introduced by this pass
714-
BoundSpillSequenceBuilder builder = null;
714+
if (exceptionFilterPrologueOpt is not null)
715+
{
716+
exceptionFilterPrologueOpt = (BoundStatementList?)VisitStatementList(exceptionFilterPrologueOpt);
717+
}
718+
BoundSpillSequenceBuilder? builder = null;
719+
715720
var exceptionFilterOpt = VisitExpression(ref builder, node.ExceptionFilterOpt);
721+
Debug.Assert(exceptionFilterPrologueOpt is null || builder is null,
722+
"You are exercising SpillSequenceSpiller in a new fashion, causing a spill in an exception filter after LocalRewriting is complete. This is not someting " +
723+
"that this builder supports today, so please update this rewrite to include the statements from exceptionFilterPrologueOpt with the appropriate " +
724+
"syntax node and tracking.");
716725
if (builder is { })
717726
{
718727
Debug.Assert(builder.Value is null);
@@ -721,9 +730,10 @@ public override BoundNode VisitCatchBlock(BoundCatchBlock node)
721730
}
722731

723732
BoundBlock body = (BoundBlock)this.Visit(node.Body);
724-
TypeSymbol exceptionTypeOpt = this.VisitType(node.ExceptionTypeOpt);
733+
TypeSymbol? exceptionTypeOpt = this.VisitType(node.ExceptionTypeOpt);
725734
return node.Update(locals, exceptionSourceOpt, exceptionTypeOpt, exceptionFilterPrologueOpt, exceptionFilterOpt, body, node.IsSynthesizedAsyncCatchAll);
726735
}
736+
#nullable disable
727737

728738
#if DEBUG
729739
public override BoundNode DefaultVisit(BoundNode node)

0 commit comments

Comments
 (0)