Skip to content

Commit 619d0b0

Browse files
author
marcrasi
authored
[AutoDiff] fix array subscript lookup when there are multiple (swiftlang#31723)
Fix `PullbackEmitter::getArrayAdjointElementBuffer` to always lookup `Array.TangentVector.subscript` from the stdlib. Resolves compiler crash when user code also defines `Array.TangentVector.subscript`.
1 parent 0560f8c commit 619d0b0

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

+14-2
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,7 @@ void PullbackEmitter::visitSILInstruction(SILInstruction *inst) {
13311331
AllocStackInst *
13321332
PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
13331333
int eltIndex, SILLocation loc) {
1334+
auto &ctx = builder.getASTContext();
13341335
auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType());
13351336
auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>();
13361337
auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType();
@@ -1340,7 +1341,19 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
13401341
auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct();
13411342
auto subscriptLookup =
13421343
arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript());
1343-
auto *subscriptDecl = cast<SubscriptDecl>(subscriptLookup.front());
1344+
SubscriptDecl *subscriptDecl = nullptr;
1345+
for (auto *candidate : subscriptLookup) {
1346+
auto candidateModule = candidate->getModuleContext();
1347+
if (candidateModule->getName() == ctx.Id_Differentiation ||
1348+
candidateModule->isStdlibModule()) {
1349+
assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s");
1350+
subscriptDecl = cast<SubscriptDecl>(candidate);
1351+
#ifdef NDEBUG
1352+
break;
1353+
#endif
1354+
}
1355+
}
1356+
assert(subscriptDecl && "No `Array.TangentVector.subscript`");
13441357
auto *subscriptGetterDecl = subscriptDecl->getAccessor(AccessorKind::Get);
13451358
assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter");
13461359
SILOptFunctionBuilder fb(getContext().getTransform());
@@ -1352,7 +1365,6 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
13521365
subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature();
13531366
// Apply `Array.TangentVector.subscript.getter` to get array element adjoint
13541367
// buffer.
1355-
auto &ctx = builder.getASTContext();
13561368
// %index_literal = integer_literal $Builtin.IntXX, <index>
13571369
auto builtinIntType =
13581370
SILType::getPrimitiveObjectType(ctx.getIntDecl()

test/AutoDiff/stdlib/array.swift

+9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@ var ArrayAutoDiffTests = TestSuite("ArrayAutoDiff")
88

99
typealias FloatArrayTan = Array<Float>.TangentVector
1010

11+
extension Array.DifferentiableView {
12+
/// A subscript that always fatal errors.
13+
///
14+
/// The differentiation transform should never emit calls to this.
15+
subscript(alwaysFatalError: Int) -> Element {
16+
fatalError("wrong subscript")
17+
}
18+
}
19+
1120
ArrayAutoDiffTests.test("ArrayIdentity") {
1221
func arrayIdentity(_ x: [Float]) -> [Float] {
1322
return x

0 commit comments

Comments
 (0)