Skip to content

Commit 81fbd45

Browse files
authored
Merge pull request #940 from swiftwasm/master
[pull] swiftwasm from master
2 parents f8f8244 + c721cf1 commit 81fbd45

36 files changed

+190
-126
lines changed

include/swift/AST/IndexSubset.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class IndexSubset : public llvm::FoldingSetNode {
207207
}
208208

209209
void print(llvm::raw_ostream &s = llvm::outs()) const;
210-
SWIFT_DEBUG_DUMPER(dump(llvm::raw_ostream &s = llvm::errs()));
210+
SWIFT_DEBUG_DUMPER(dump());
211211

212212
int findNext(int startIndex) const;
213213
int findFirst() const { return findNext(-1); }

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,19 +211,16 @@ inline void createEntryArguments(SILFunction *f) {
211211
decl->setSpecifier(ParamDecl::Specifier::Default);
212212
entry->createFunctionArgument(type, decl);
213213
};
214-
// f->getLoweredFunctionType()->remap
215214
for (auto indResTy :
216215
conv.getIndirectSILResultTypes(f->getTypeExpansionContext())) {
217216
if (indResTy.hasArchetype())
218217
indResTy = indResTy.mapTypeOutOfContext();
219218
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
220-
// createFunctionArgument(indResTy.getAddressType());
221219
}
222220
for (auto paramTy : conv.getParameterSILTypes(f->getTypeExpansionContext())) {
223221
if (paramTy.hasArchetype())
224222
paramTy = paramTy.mapTypeOutOfContext();
225223
createFunctionArgument(f->mapTypeIntoContext(paramTy));
226-
// createFunctionArgument(paramTy);
227224
}
228225
}
229226

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKEMITTER_H
1919
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKEMITTER_H
2020

21+
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2122
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
2223
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2324
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
24-
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2525

2626
#include "swift/SIL/TypeSubstCloner.h"
2727
#include "llvm/ADT/DenseMap.h"

include/swift/SILOptimizer/Differentiation/VJPEmitter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H
1919
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H
2020

21+
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2122
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2223
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
23-
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2424

2525
#include "swift/SIL/TypeSubstCloner.h"
2626
#include "llvm/ADT/DenseMap.h"

lib/AST/Decl.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7111,10 +7111,19 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
71117111
ArrayRef<AutoDiffConfig>
71127112
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
71137113
prepareDerivativeFunctionConfigurations();
7114+
71147115
// Resolve derivative function configurations from `@differentiable`
71157116
// attributes by type-checking them.
71167117
for (auto *diffAttr : getAttrs().getAttributes<DifferentiableAttr>())
71177118
(void)diffAttr->getParameterIndices();
7119+
// For accessors: resolve derivative function configurations from storage
7120+
// `@differentiable` attributes by type-checking them.
7121+
if (auto *accessor = dyn_cast<AccessorDecl>(this)) {
7122+
auto *storage = accessor->getStorage();
7123+
for (auto *diffAttr : storage->getAttrs().getAttributes<DifferentiableAttr>())
7124+
(void)diffAttr->getParameterIndices();
7125+
}
7126+
71187127
// Load derivative configurations from imported modules.
71197128
auto &ctx = getASTContext();
71207129
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {

lib/AST/GenericSignature.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,8 @@ GenericSignatureImpl::lookupConformance(CanType type,
379379
}
380380

381381
bool GenericSignatureImpl::requiresClass(Type type) {
382-
if (!type->isTypeParameter()) return false;
382+
assert(type->isTypeParameter() &&
383+
"Only type parameters can have superclass requirements");
383384

384385
auto &builder = *getGenericSignatureBuilder();
385386
auto equivClass =

lib/AST/IndexSubset.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ void IndexSubset::print(llvm::raw_ostream &s) const {
8686
s << '}';
8787
}
8888

89-
void IndexSubset::dump(llvm::raw_ostream &s) const {
89+
void IndexSubset::dump() const {
90+
auto &s = llvm::errs();
9091
s << "(index_subset capacity=" << capacity << " indices=(";
9192
interleave(getIndices(), [&s](unsigned i) { s << i; },
9293
[&s] { s << ", "; });

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,7 @@ bool AbstractionPattern::requiresClass() const {
222222
auto type = getType();
223223
if (auto archetype = dyn_cast<ArchetypeType>(type))
224224
return archetype->requiresClass();
225-
if (isa<DependentMemberType>(type) ||
226-
isa<GenericTypeParamType>(type)) {
225+
if (type->isTypeParameter()) {
227226
if (getKind() == Kind::ClangType) {
228227
// ObjC generics are always class constrained.
229228
return true;

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
126126
auto originalFnTy = original->getLoweredFunctionType();
127127
auto numResults = originalFnTy->getNumResults() +
128128
originalFnTy->getNumIndirectMutatingParameters();
129-
auto *resultIndices = IndexSubset::get(
130-
original->getASTContext(), numResults, indices.source);
129+
auto *resultIndices =
130+
IndexSubset::get(original->getASTContext(), numResults, indices.source);
131131
auto *parameterIndices = indices.parameters;
132132
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
133133
auto enumName = mangler.mangleAutoDiffGeneratedDeclaration(
@@ -199,8 +199,8 @@ LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
199199
auto originalFnTy = original->getLoweredFunctionType();
200200
auto numResults = originalFnTy->getNumResults() +
201201
originalFnTy->getNumIndirectMutatingParameters();
202-
auto *resultIndices = IndexSubset::get(
203-
original->getASTContext(), numResults, indices.source);
202+
auto *resultIndices =
203+
IndexSubset::get(original->getASTContext(), numResults, indices.source);
204204
auto *parameterIndices = indices.parameters;
205205
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
206206
auto structName = mangler.mangleAutoDiffGeneratedDeclaration(

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ void PullbackEmitter::cleanUpTemporariesForBlock(SILBasicBlock *bb,
151151
const Lowering::TypeLowering &PullbackEmitter::getTypeLowering(Type type) {
152152
auto pbGenSig =
153153
getPullback().getLoweredFunctionType()->getSubstGenericSignature();
154-
Lowering::AbstractionPattern pattern(
155-
pbGenSig, type->getCanonicalType(pbGenSig));
154+
Lowering::AbstractionPattern pattern(pbGenSig,
155+
type->getCanonicalType(pbGenSig));
156156
return getPullback().getTypeLowering(pattern, type);
157157
}
158158

@@ -2083,8 +2083,8 @@ void PullbackEmitter::accumulateIndirect(SILValue lhsDestAccess,
20832083
auto type = lhsDestAccess->getType();
20842084
auto astType = type.getASTType();
20852085
auto *swiftMod = getModule().getSwiftModule();
2086-
auto tangentSpace = astType->getAutoDiffTangentSpace(
2087-
LookUpConformanceInModule(swiftMod));
2086+
auto tangentSpace =
2087+
astType->getAutoDiffTangentSpace(LookUpConformanceInModule(swiftMod));
20882088
assert(tangentSpace && "No tangent space for this type");
20892089
switch (tangentSpace->getKind()) {
20902090
case TangentSpace::Kind::TangentVector: {

0 commit comments

Comments
 (0)