@@ -315,9 +315,10 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) {
315
315
return getLoweredType (nominalType);
316
316
}
317
317
318
- StructInst *VJPEmitter::buildPullbackValueStructValue (SILBasicBlock *origBB) {
319
- assert (origBB->getParent () == original);
320
- auto loc = origBB->getParent ()->getLocation ();
318
+ StructInst *VJPEmitter::buildPullbackValueStructValue (TermInst *termInst) {
319
+ assert (termInst->getFunction () == original);
320
+ auto loc = RegularLocation::getAutoGeneratedLocation ();
321
+ auto origBB = termInst->getParent ();
321
322
auto *vjpBB = BBMap[origBB];
322
323
auto *pbStruct = pullbackInfo.getLinearMapStruct (origBB);
323
324
auto structLoweredTy = getNominalDeclLoweredType (pbStruct);
@@ -326,14 +327,15 @@ StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) {
326
327
auto *predEnumArg = vjpBB->getArguments ().back ();
327
328
bbPullbackValues.insert (bbPullbackValues.begin (), predEnumArg);
328
329
}
330
+ getBuilder ().setCurrentDebugScope (getOpScope (termInst->getDebugScope ()));
329
331
return getBuilder ().createStruct (loc, structLoweredTy, bbPullbackValues);
330
332
}
331
333
332
334
EnumInst *VJPEmitter::buildPredecessorEnumValue (SILBuilder &builder,
333
335
SILBasicBlock *predBB,
334
336
SILBasicBlock *succBB,
335
337
SILValue pbStructVal) {
336
- auto loc = pbStructVal. getLoc ();
338
+ auto loc = RegularLocation::getAutoGeneratedLocation ();
337
339
auto *succEnum = pullbackInfo.getBranchingTraceDecl (succBB);
338
340
auto enumLoweredTy = getNominalDeclLoweredType (succEnum);
339
341
auto *enumEltDecl =
@@ -361,7 +363,7 @@ void VJPEmitter::visitReturnInst(ReturnInst *ri) {
361
363
362
364
// Build pullback struct value for original block.
363
365
auto *origExit = ri->getParent ();
364
- auto *pbStructVal = buildPullbackValueStructValue (origExit );
366
+ auto *pbStructVal = buildPullbackValueStructValue (ri );
365
367
366
368
// Get the value in the VJP corresponding to the original result.
367
369
auto *origRetInst = cast<ReturnInst>(origExit->getTerminator ());
@@ -416,7 +418,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
416
418
// Build pullback struct value for original block.
417
419
// Build predecessor enum value for destination block.
418
420
auto *origBB = bi->getParent ();
419
- auto *pbStructVal = buildPullbackValueStructValue (origBB );
421
+ auto *pbStructVal = buildPullbackValueStructValue (bi );
420
422
auto *enumVal = buildPredecessorEnumValue (getBuilder (), origBB,
421
423
bi->getDestBB (), pbStructVal);
422
424
@@ -433,7 +435,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
433
435
434
436
void VJPEmitter::visitCondBranchInst (CondBranchInst *cbi) {
435
437
// Build pullback struct value for original block.
436
- auto *pbStructVal = buildPullbackValueStructValue (cbi-> getParent () );
438
+ auto *pbStructVal = buildPullbackValueStructValue (cbi);
437
439
// Create a new `cond_br` instruction.
438
440
getBuilder ().createCondBranch (
439
441
cbi->getLoc (), getOpValue (cbi->getCondition ()),
@@ -443,7 +445,7 @@ void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) {
443
445
444
446
void VJPEmitter::visitSwitchEnumInstBase (SwitchEnumInstBase *sei) {
445
447
// Build pullback struct value for original block.
446
- auto *pbStructVal = buildPullbackValueStructValue (sei-> getParent () );
448
+ auto *pbStructVal = buildPullbackValueStructValue (sei);
447
449
448
450
// Create trampoline successor basic blocks.
449
451
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4 > caseBBs;
@@ -483,7 +485,7 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
483
485
484
486
void VJPEmitter::visitCheckedCastBranchInst (CheckedCastBranchInst *ccbi) {
485
487
// Build pullback struct value for original block.
486
- auto *pbStructVal = buildPullbackValueStructValue (ccbi-> getParent () );
488
+ auto *pbStructVal = buildPullbackValueStructValue (ccbi);
487
489
// Create a new `checked_cast_branch` instruction.
488
490
getBuilder ().createCheckedCastBranch (
489
491
ccbi->getLoc (), ccbi->isExact (), getOpValue (ccbi->getOperand ()),
@@ -497,7 +499,7 @@ void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {
497
499
void VJPEmitter::visitCheckedCastValueBranchInst (
498
500
CheckedCastValueBranchInst *ccvbi) {
499
501
// Build pullback struct value for original block.
500
- auto *pbStructVal = buildPullbackValueStructValue (ccvbi-> getParent () );
502
+ auto *pbStructVal = buildPullbackValueStructValue (ccvbi);
501
503
// Create a new `checked_cast_value_branch` instruction.
502
504
getBuilder ().createCheckedCastValueBranch (
503
505
ccvbi->getLoc (), getOpValue (ccvbi->getOperand ()),
@@ -511,7 +513,7 @@ void VJPEmitter::visitCheckedCastValueBranchInst(
511
513
void VJPEmitter::visitCheckedCastAddrBranchInst (
512
514
CheckedCastAddrBranchInst *ccabi) {
513
515
// Build pullback struct value for original block.
514
- auto *pbStructVal = buildPullbackValueStructValue (ccabi-> getParent () );
516
+ auto *pbStructVal = buildPullbackValueStructValue (ccabi);
515
517
// Create a new `checked_cast_addr_branch` instruction.
516
518
getBuilder ().createCheckedCastAddrBranch (
517
519
ccabi->getLoc (), ccabi->getConsumptionKind (), getOpValue (ccabi->getSrc ()),
0 commit comments