Skip to content

Commit 8c70a96

Browse files
authored
Merge pull request swiftlang#32031 from dan-zheng/fix-alloc-stack-cloning
[AutoDiff] Fix differentiation crashes related to definite initialization.
2 parents 26125f5 + ff97ae7 commit 8c70a96

File tree

4 files changed

+70
-16
lines changed

4 files changed

+70
-16
lines changed

include/swift/SIL/SILCloner.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -802,9 +802,9 @@ SILCloner<ImplClass>::visitAllocStackInst(AllocStackInst *Inst) {
802802
Loc = MandatoryInlinedLocation::getAutoGeneratedLocation();
803803
VarInfo = None;
804804
}
805-
recordClonedInstruction(Inst,
806-
getBuilder().createAllocStack(
807-
Loc, getOpType(Inst->getElementType()), VarInfo));
805+
recordClonedInstruction(Inst, getBuilder().createAllocStack(
806+
Loc, getOpType(Inst->getElementType()),
807+
VarInfo, Inst->hasDynamicLifetime()));
808808
}
809809

810810
template<typename ImplClass>

include/swift/SILOptimizer/Differentiation/VJPEmitter.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ class VJPEmitter final
130130
StructInst *pbStructVal,
131131
SILBasicBlock *succBB);
132132

133-
/// Build a pullback struct value for the given original block.
134-
StructInst *buildPullbackValueStructValue(SILBasicBlock *bb);
133+
/// Build a pullback struct value for the given original terminator
134+
/// instruction.
135+
StructInst *buildPullbackValueStructValue(TermInst *termInst);
135136

136137
/// Build a predecessor enum instance using the given builder for the given
137138
/// original predecessor/successor blocks and pullback struct value.

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,10 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) {
315315
return getLoweredType(nominalType);
316316
}
317317

318-
StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) {
319-
assert(origBB->getParent() == original);
320-
auto loc = origBB->getParent()->getLocation();
318+
StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) {
319+
assert(termInst->getFunction() == original);
320+
auto loc = RegularLocation::getAutoGeneratedLocation();
321+
auto origBB = termInst->getParent();
321322
auto *vjpBB = BBMap[origBB];
322323
auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB);
323324
auto structLoweredTy = getNominalDeclLoweredType(pbStruct);
@@ -326,14 +327,15 @@ StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) {
326327
auto *predEnumArg = vjpBB->getArguments().back();
327328
bbPullbackValues.insert(bbPullbackValues.begin(), predEnumArg);
328329
}
330+
getBuilder().setCurrentDebugScope(getOpScope(termInst->getDebugScope()));
329331
return getBuilder().createStruct(loc, structLoweredTy, bbPullbackValues);
330332
}
331333

332334
EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder,
333335
SILBasicBlock *predBB,
334336
SILBasicBlock *succBB,
335337
SILValue pbStructVal) {
336-
auto loc = pbStructVal.getLoc();
338+
auto loc = RegularLocation::getAutoGeneratedLocation();
337339
auto *succEnum = pullbackInfo.getBranchingTraceDecl(succBB);
338340
auto enumLoweredTy = getNominalDeclLoweredType(succEnum);
339341
auto *enumEltDecl =
@@ -361,7 +363,7 @@ void VJPEmitter::visitReturnInst(ReturnInst *ri) {
361363

362364
// Build pullback struct value for original block.
363365
auto *origExit = ri->getParent();
364-
auto *pbStructVal = buildPullbackValueStructValue(origExit);
366+
auto *pbStructVal = buildPullbackValueStructValue(ri);
365367

366368
// Get the value in the VJP corresponding to the original result.
367369
auto *origRetInst = cast<ReturnInst>(origExit->getTerminator());
@@ -416,7 +418,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
416418
// Build pullback struct value for original block.
417419
// Build predecessor enum value for destination block.
418420
auto *origBB = bi->getParent();
419-
auto *pbStructVal = buildPullbackValueStructValue(origBB);
421+
auto *pbStructVal = buildPullbackValueStructValue(bi);
420422
auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB,
421423
bi->getDestBB(), pbStructVal);
422424

@@ -433,7 +435,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) {
433435

434436
void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) {
435437
// Build pullback struct value for original block.
436-
auto *pbStructVal = buildPullbackValueStructValue(cbi->getParent());
438+
auto *pbStructVal = buildPullbackValueStructValue(cbi);
437439
// Create a new `cond_br` instruction.
438440
getBuilder().createCondBranch(
439441
cbi->getLoc(), getOpValue(cbi->getCondition()),
@@ -443,7 +445,7 @@ void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) {
443445

444446
void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) {
445447
// Build pullback struct value for original block.
446-
auto *pbStructVal = buildPullbackValueStructValue(sei->getParent());
448+
auto *pbStructVal = buildPullbackValueStructValue(sei);
447449

448450
// Create trampoline successor basic blocks.
449451
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
@@ -483,7 +485,7 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
483485

484486
void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {
485487
// Build pullback struct value for original block.
486-
auto *pbStructVal = buildPullbackValueStructValue(ccbi->getParent());
488+
auto *pbStructVal = buildPullbackValueStructValue(ccbi);
487489
// Create a new `checked_cast_branch` instruction.
488490
getBuilder().createCheckedCastBranch(
489491
ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()),
@@ -497,7 +499,7 @@ void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {
497499
void VJPEmitter::visitCheckedCastValueBranchInst(
498500
CheckedCastValueBranchInst *ccvbi) {
499501
// Build pullback struct value for original block.
500-
auto *pbStructVal = buildPullbackValueStructValue(ccvbi->getParent());
502+
auto *pbStructVal = buildPullbackValueStructValue(ccvbi);
501503
// Create a new `checked_cast_value_branch` instruction.
502504
getBuilder().createCheckedCastValueBranch(
503505
ccvbi->getLoc(), getOpValue(ccvbi->getOperand()),
@@ -511,7 +513,7 @@ void VJPEmitter::visitCheckedCastValueBranchInst(
511513
void VJPEmitter::visitCheckedCastAddrBranchInst(
512514
CheckedCastAddrBranchInst *ccabi) {
513515
// Build pullback struct value for original block.
514-
auto *pbStructVal = buildPullbackValueStructValue(ccabi->getParent());
516+
auto *pbStructVal = buildPullbackValueStructValue(ccabi);
515517
// Create a new `checked_cast_addr_branch` instruction.
516518
getBuilder().createCheckedCastAddrBranch(
517519
ccabi->getLoc(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()),
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: %target-build-swift %s
2+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
3+
4+
// Test crashes related to differentiation and definite intiialization.
5+
6+
// SR-12886: SIL memory lifetime verification error due to
7+
// `SILCloner::visitAllocStack` not copying the `[dynamic_lifetime]` attribute.
8+
9+
// SR-12887: Debug scope error for pullback struct `struct` instruction
10+
// generated by `VJPEmitter`.
11+
12+
import _Differentiation
13+
14+
enum Enum {
15+
case a
16+
}
17+
18+
struct Tensor<T>: Differentiable {
19+
@noDerivative var x: T
20+
@noDerivative var optional: Int?
21+
22+
init(_ x: T, _ e: Enum) {
23+
self.x = x
24+
switch e {
25+
case .a: optional = 1
26+
}
27+
}
28+
29+
// Definite initialization triggers for this initializer.
30+
@differentiable
31+
init(_ x: T, _ other: Self) {
32+
self = Self(x, Enum.a)
33+
}
34+
}
35+
36+
// Check that `allock_stack [dynamic_lifetime]` attribute is correctly cloned.
37+
38+
// CHECK-LABEL: sil hidden @$s4main6TensorVyACyxGx_ADtcfC : $@convention(method) <T> (@in T, @in Tensor<T>, @thin Tensor<T>.Type) -> @out Tensor<T> {
39+
// CHECK: [[SELF_ALLOC:%.*]] = alloc_stack [dynamic_lifetime] $Tensor<T>, var, name "self"
40+
41+
// CHECK-LABEL: sil hidden @AD__$s4main6TensorVyACyxGx_ADtcfC__vjp_src_0_wrt_1_l : $@convention(method) <τ_0_0> (@in τ_0_0, @in Tensor<τ_0_0>, @thin Tensor<τ_0_0>.Type) -> (@out Tensor<τ_0_0>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Tensor<τ_0_0>.TangentVector, Tensor<τ_0_0>.TangentVector>) {
42+
// CHECK: [[SELF_ALLOC:%.*]] = alloc_stack [dynamic_lifetime] $Tensor<τ_0_0>, var, name "self"
43+
44+
// SR-12886 original error:
45+
// SIL memory lifetime failure in @AD__$s5crash6TensorVyACyxGx_ADtcfC__vjp_src_0_wrt_1_l: memory is not initialized, but should
46+
// memory location: %29 = struct_element_addr %5 : $*Tensor<τ_0_0>, #Tensor.x // user: %30
47+
// at instruction: destroy_addr %29 : $*τ_0_0 // id: %30
48+
49+
// SR-12887 original error:
50+
// SIL verification failed: Basic block contains a non-contiguous lexical scope at -Onone: DS == LastSeenScope
51+
// %26 = struct $_AD__$s5crash6TensorVyACyxGx_ADtcfC_bb0__PB__src_0_wrt_1_l<τ_0_0> () // users: %34, %28

0 commit comments

Comments
 (0)