Skip to content

Commit 81fbd45

Browse files
authored
Merge pull request #940 from swiftwasm/master
[pull] swiftwasm from master
2 parents f8f8244 + c721cf1 commit 81fbd45

36 files changed

+190
-126
lines changed

include/swift/AST/IndexSubset.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class IndexSubset : public llvm::FoldingSetNode {
207207
}
208208

209209
void print(llvm::raw_ostream &s = llvm::outs()) const;
210-
SWIFT_DEBUG_DUMPER(dump(llvm::raw_ostream &s = llvm::errs()));
210+
SWIFT_DEBUG_DUMPER(dump());
211211

212212
int findNext(int startIndex) const;
213213
int findFirst() const { return findNext(-1); }

include/swift/SILOptimizer/Differentiation/Common.h

-3
Original file line numberDiff line numberDiff line change
@@ -211,19 +211,16 @@ inline void createEntryArguments(SILFunction *f) {
211211
decl->setSpecifier(ParamDecl::Specifier::Default);
212212
entry->createFunctionArgument(type, decl);
213213
};
214-
// f->getLoweredFunctionType()->remap
215214
for (auto indResTy :
216215
conv.getIndirectSILResultTypes(f->getTypeExpansionContext())) {
217216
if (indResTy.hasArchetype())
218217
indResTy = indResTy.mapTypeOutOfContext();
219218
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
220-
// createFunctionArgument(indResTy.getAddressType());
221219
}
222220
for (auto paramTy : conv.getParameterSILTypes(f->getTypeExpansionContext())) {
223221
if (paramTy.hasArchetype())
224222
paramTy = paramTy.mapTypeOutOfContext();
225223
createFunctionArgument(f->mapTypeIntoContext(paramTy));
226-
// createFunctionArgument(paramTy);
227224
}
228225
}
229226

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKEMITTER_H
1919
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKEMITTER_H
2020

21+
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2122
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
2223
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2324
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
24-
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2525

2626
#include "swift/SIL/TypeSubstCloner.h"
2727
#include "llvm/ADT/DenseMap.h"

include/swift/SILOptimizer/Differentiation/VJPEmitter.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H
1919
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H
2020

21+
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2122
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2223
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
23-
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2424

2525
#include "swift/SIL/TypeSubstCloner.h"
2626
#include "llvm/ADT/DenseMap.h"

lib/AST/Decl.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -7111,10 +7111,19 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
71117111
ArrayRef<AutoDiffConfig>
71127112
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
71137113
prepareDerivativeFunctionConfigurations();
7114+
71147115
// Resolve derivative function configurations from `@differentiable`
71157116
// attributes by type-checking them.
71167117
for (auto *diffAttr : getAttrs().getAttributes<DifferentiableAttr>())
71177118
(void)diffAttr->getParameterIndices();
7119+
// For accessors: resolve derivative function configurations from storage
7120+
// `@differentiable` attributes by type-checking them.
7121+
if (auto *accessor = dyn_cast<AccessorDecl>(this)) {
7122+
auto *storage = accessor->getStorage();
7123+
for (auto *diffAttr : storage->getAttrs().getAttributes<DifferentiableAttr>())
7124+
(void)diffAttr->getParameterIndices();
7125+
}
7126+
71187127
// Load derivative configurations from imported modules.
71197128
auto &ctx = getASTContext();
71207129
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {

lib/AST/GenericSignature.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,8 @@ GenericSignatureImpl::lookupConformance(CanType type,
379379
}
380380

381381
bool GenericSignatureImpl::requiresClass(Type type) {
382-
if (!type->isTypeParameter()) return false;
382+
assert(type->isTypeParameter() &&
383+
"Only type parameters can have superclass requirements");
383384

384385
auto &builder = *getGenericSignatureBuilder();
385386
auto equivClass =

lib/AST/IndexSubset.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ void IndexSubset::print(llvm::raw_ostream &s) const {
8686
s << '}';
8787
}
8888

89-
void IndexSubset::dump(llvm::raw_ostream &s) const {
89+
void IndexSubset::dump() const {
90+
auto &s = llvm::errs();
9091
s << "(index_subset capacity=" << capacity << " indices=(";
9192
interleave(getIndices(), [&s](unsigned i) { s << i; },
9293
[&s] { s << ", "; });

lib/SIL/IR/AbstractionPattern.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,7 @@ bool AbstractionPattern::requiresClass() const {
222222
auto type = getType();
223223
if (auto archetype = dyn_cast<ArchetypeType>(type))
224224
return archetype->requiresClass();
225-
if (isa<DependentMemberType>(type) ||
226-
isa<GenericTypeParamType>(type)) {
225+
if (type->isTypeParameter()) {
227226
if (getKind() == Kind::ClangType) {
228227
// ObjC generics are always class constrained.
229228
return true;

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
126126
auto originalFnTy = original->getLoweredFunctionType();
127127
auto numResults = originalFnTy->getNumResults() +
128128
originalFnTy->getNumIndirectMutatingParameters();
129-
auto *resultIndices = IndexSubset::get(
130-
original->getASTContext(), numResults, indices.source);
129+
auto *resultIndices =
130+
IndexSubset::get(original->getASTContext(), numResults, indices.source);
131131
auto *parameterIndices = indices.parameters;
132132
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
133133
auto enumName = mangler.mangleAutoDiffGeneratedDeclaration(
@@ -199,8 +199,8 @@ LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
199199
auto originalFnTy = original->getLoweredFunctionType();
200200
auto numResults = originalFnTy->getNumResults() +
201201
originalFnTy->getNumIndirectMutatingParameters();
202-
auto *resultIndices = IndexSubset::get(
203-
original->getASTContext(), numResults, indices.source);
202+
auto *resultIndices =
203+
IndexSubset::get(original->getASTContext(), numResults, indices.source);
204204
auto *parameterIndices = indices.parameters;
205205
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
206206
auto structName = mangler.mangleAutoDiffGeneratedDeclaration(

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ void PullbackEmitter::cleanUpTemporariesForBlock(SILBasicBlock *bb,
151151
const Lowering::TypeLowering &PullbackEmitter::getTypeLowering(Type type) {
152152
auto pbGenSig =
153153
getPullback().getLoweredFunctionType()->getSubstGenericSignature();
154-
Lowering::AbstractionPattern pattern(
155-
pbGenSig, type->getCanonicalType(pbGenSig));
154+
Lowering::AbstractionPattern pattern(pbGenSig,
155+
type->getCanonicalType(pbGenSig));
156156
return getPullback().getTypeLowering(pattern, type);
157157
}
158158

@@ -2083,8 +2083,8 @@ void PullbackEmitter::accumulateIndirect(SILValue lhsDestAccess,
20832083
auto type = lhsDestAccess->getType();
20842084
auto astType = type.getASTType();
20852085
auto *swiftMod = getModule().getSwiftModule();
2086-
auto tangentSpace = astType->getAutoDiffTangentSpace(
2087-
LookUpConformanceInModule(swiftMod));
2086+
auto tangentSpace =
2087+
astType->getAutoDiffTangentSpace(LookUpConformanceInModule(swiftMod));
20882088
assert(tangentSpace && "No tangent space for this type");
20892089
switch (tangentSpace->getKind()) {
20902090
case TangentSpace::Kind::TangentVector: {

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

+12-9
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ SILFunction *VJPEmitter::createEmptyPullback() {
174174
SILParameterInfo inoutParamTanParam(
175175
origResult.getInterfaceType()
176176
->getAutoDiffTangentSpace(lookupConformance)
177-
->getType()->getCanonicalType(witnessCanGenSig),
177+
->getType()
178+
->getCanonicalType(witnessCanGenSig),
178179
inoutParamTanConvention);
179180
pbParams.push_back(inoutParamTanParam);
180181
} else {
@@ -184,15 +185,17 @@ SILFunction *VJPEmitter::createEmptyPullback() {
184185
pbParams.push_back(getTangentParameterInfoForOriginalResult(
185186
origResult.getInterfaceType()
186187
->getAutoDiffTangentSpace(lookupConformance)
187-
->getType()->getCanonicalType(witnessCanGenSig),
188+
->getType()
189+
->getCanonicalType(witnessCanGenSig),
188190
origResult.getConvention()));
189191
}
190192

191193
// Accept a pullback struct in the pullback parameter list. This is the
192194
// returned pullback's closure context.
193195
auto *origExit = &*original->findReturnBB();
194196
auto *pbStruct = pullbackInfo.getLinearMapStruct(origExit);
195-
auto pbStructType = pbStruct->getDeclaredInterfaceType()->getCanonicalType(witnessCanGenSig);
197+
auto pbStructType =
198+
pbStruct->getDeclaredInterfaceType()->getCanonicalType(witnessCanGenSig);
196199
pbParams.push_back({pbStructType, ParameterConvention::Direct_Owned});
197200

198201
// Add pullback results for the requested wrt parameters.
@@ -205,7 +208,8 @@ SILFunction *VJPEmitter::createEmptyPullback() {
205208
adjResults.push_back(getTangentResultInfoForOriginalParameter(
206209
origParam.getInterfaceType()
207210
->getAutoDiffTangentSpace(lookupConformance)
208-
->getType()->getCanonicalType(witnessCanGenSig),
211+
->getType()
212+
->getCanonicalType(witnessCanGenSig),
209213
origParam.getConvention()));
210214
}
211215

@@ -275,8 +279,8 @@ void VJPEmitter::visitSILInstruction(SILInstruction *inst) {
275279

276280
SILType VJPEmitter::getLoweredType(Type type) {
277281
auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature();
278-
Lowering::AbstractionPattern pattern(
279-
vjpGenSig, type->getCanonicalType(vjpGenSig));
282+
Lowering::AbstractionPattern pattern(vjpGenSig,
283+
type->getCanonicalType(vjpGenSig));
280284
return vjp->getLoweredType(pattern, type);
281285
}
282286

@@ -490,9 +494,8 @@ void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) {
490494
newDefaultBB, caseBBs);
491495
break;
492496
case SILInstructionKind::SwitchEnumAddrInst:
493-
getBuilder().createSwitchEnumAddr(sei->getLoc(),
494-
getOpValue(sei->getOperand()),
495-
newDefaultBB, caseBBs);
497+
getBuilder().createSwitchEnumAddr(
498+
sei->getLoc(), getOpValue(sei->getOperand()), newDefaultBB, caseBBs);
496499
break;
497500
default:
498501
llvm_unreachable("Expected `switch_enum` or `switch_enum_addr`");

lib/SILOptimizer/Mandatory/Differentiation.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -656,9 +656,9 @@ emitDerivativeFunctionReference(
656656
auto loc = witnessMethod->getLoc();
657657
auto requirementDeclRef = witnessMethod->getMember();
658658
auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
659-
// If requirement declaration does not have any `@differentiable`
660-
// attributes, produce an error.
661-
if (!requirementDecl->getAttrs().hasAttribute<DifferentiableAttr>()) {
659+
// If requirement declaration does not have any derivative function
660+
// configurations, produce an error.
661+
if (requirementDecl->getDerivativeFunctionConfigurations().empty()) {
662662
context.emitNondifferentiabilityError(
663663
original, invoker, diag::autodiff_protocol_member_not_differentiable);
664664
return None;
@@ -701,9 +701,9 @@ emitDerivativeFunctionReference(
701701
auto loc = classMethod->getLoc();
702702
auto methodDeclRef = classMethod->getMember();
703703
auto *methodDecl = methodDeclRef.getAbstractFunctionDecl();
704-
// If method declaration does not have any `@differentiable` attributes,
705-
// produce an error.
706-
if (!methodDecl->getAttrs().hasAttribute<DifferentiableAttr>()) {
704+
// If method declaration does not have any derivative function
705+
// configurations, produce an error.
706+
if (methodDecl->getDerivativeFunctionConfigurations().empty()) {
707707
context.emitNondifferentiabilityError(
708708
original, invoker, diag::autodiff_class_member_not_differentiable);
709709
return None;

lib/Sema/TypeCheckPattern.cpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -467,9 +467,7 @@ class ResolvePattern : public ASTVisitor<ResolvePattern,
467467
auto *repr = IdentTypeRepr::create(Context, components);
468468

469469
// See if the repr resolves to a type.
470-
Type ty = TypeChecker::resolveIdentifierType(
471-
TypeResolution::forContextual(DC, options), repr);
472-
470+
auto ty = TypeResolution::forContextual(DC, options).resolveType(repr);
473471
auto *enumDecl = dyn_cast_or_null<EnumDecl>(ty->getAnyNominal());
474472
if (!enumDecl)
475473
return nullptr;
@@ -566,8 +564,8 @@ class ResolvePattern : public ASTVisitor<ResolvePattern,
566564
auto *prefixRepr = IdentTypeRepr::create(Context, components);
567565

568566
// See first if the entire repr resolves to a type.
569-
Type enumTy = TypeChecker::resolveIdentifierType(
570-
TypeResolution::forContextual(DC, options), prefixRepr);
567+
Type enumTy = TypeResolution::forContextual(DC, options)
568+
.resolveType(prefixRepr);
571569
if (!dyn_cast_or_null<EnumDecl>(enumTy->getAnyNominal()))
572570
return nullptr;
573571

lib/Sema/TypeCheckType.cpp

+53-56
Original file line numberDiff line numberDiff line change
@@ -1632,60 +1632,6 @@ static Type applyNonEscapingFromContext(DeclContext *DC,
16321632
return ty;
16331633
}
16341634

1635-
/// Returns a valid type or ErrorType in case of an error.
1636-
Type TypeChecker::resolveIdentifierType(TypeResolution resolution,
1637-
IdentTypeRepr *IdType) {
1638-
const auto options = resolution.getOptions();
1639-
auto DC = resolution.getDeclContext();
1640-
ASTContext &ctx = DC->getASTContext();
1641-
auto &diags = ctx.Diags;
1642-
auto ComponentRange = IdType->getComponentRange();
1643-
auto Components = llvm::makeArrayRef(ComponentRange.begin(),
1644-
ComponentRange.end());
1645-
Type result = resolveIdentTypeComponent(resolution, Components);
1646-
if (!result) return nullptr;
1647-
1648-
if (auto moduleTy = result->getAs<ModuleType>()) {
1649-
// Allow module types only if flag is specified.
1650-
if (options.contains(TypeResolutionFlags::AllowModule))
1651-
return moduleTy;
1652-
// Otherwise, emit an error.
1653-
if (!options.contains(TypeResolutionFlags::SilenceErrors)) {
1654-
auto moduleName = moduleTy->getModule()->getName();
1655-
diags.diagnose(Components.back()->getNameLoc(),
1656-
diag::cannot_find_type_in_scope, DeclNameRef(moduleName));
1657-
diags.diagnose(Components.back()->getNameLoc(),
1658-
diag::note_module_as_type, moduleName);
1659-
}
1660-
Components.back()->setInvalid();
1661-
return ErrorType::get(ctx);
1662-
}
1663-
1664-
// Hack to apply context-specific @escaping to a typealias with an underlying
1665-
// function type.
1666-
if (result->is<FunctionType>())
1667-
result = applyNonEscapingFromContext(DC, result, options);
1668-
1669-
// Check the availability of the type.
1670-
1671-
// We allow a type to conform to a protocol that is less available than
1672-
// the type itself. This enables a type to retroactively model or directly
1673-
// conform to a protocol only available on newer OSes and yet still be used on
1674-
// older OSes.
1675-
// To support this, inside inheritance clauses we allow references to
1676-
// protocols that are unavailable in the current type refinement context.
1677-
1678-
if (!options.contains(TypeResolutionFlags::SilenceErrors) &&
1679-
!options.contains(TypeResolutionFlags::AllowUnavailable) &&
1680-
diagnoseAvailability(IdType, DC,
1681-
options.contains(TypeResolutionFlags::AllowUnavailableProtocol))) {
1682-
Components.back()->setInvalid();
1683-
return ErrorType::get(ctx);
1684-
}
1685-
1686-
return result;
1687-
}
1688-
16891635
/// Validate whether type associated with @autoclosure attribute is correct,
16901636
/// it supposed to be a function type with no parameters.
16911637
/// \returns true if there was an error, false otherwise.
@@ -1831,6 +1777,8 @@ namespace {
18311777
SmallVectorImpl<SILYieldInfo> &yields,
18321778
SmallVectorImpl<SILResultInfo> &results,
18331779
Optional<SILResultInfo> &errorResult);
1780+
Type resolveIdentifierType(IdentTypeRepr *IdType,
1781+
TypeResolutionOptions options);
18341782
Type resolveSpecifierTypeRepr(SpecifierTypeRepr *repr,
18351783
TypeResolutionOptions options);
18361784
Type resolveArrayType(ArrayTypeRepr *repr,
@@ -1949,8 +1897,7 @@ Type TypeResolver::resolveType(TypeRepr *repr, TypeResolutionOptions options) {
19491897
case TypeReprKind::SimpleIdent:
19501898
case TypeReprKind::GenericIdent:
19511899
case TypeReprKind::CompoundIdent:
1952-
return TypeChecker::resolveIdentifierType(resolution.withOptions(options),
1953-
cast<IdentTypeRepr>(repr));
1900+
return resolveIdentifierType(cast<IdentTypeRepr>(repr), options);
19541901

19551902
case TypeReprKind::Function: {
19561903
if (!(options & TypeResolutionFlags::SILType)) {
@@ -3258,6 +3205,56 @@ bool TypeResolver::resolveSILResults(TypeRepr *repr,
32583205
yields, ordinaryResults, errorResult);
32593206
}
32603207

3208+
Type TypeResolver::resolveIdentifierType(IdentTypeRepr *IdType,
3209+
TypeResolutionOptions options) {
3210+
auto ComponentRange = IdType->getComponentRange();
3211+
auto Components = llvm::makeArrayRef(ComponentRange.begin(),
3212+
ComponentRange.end());
3213+
Type result = resolveIdentTypeComponent(resolution.withOptions(options),
3214+
Components);
3215+
if (!result) return nullptr;
3216+
3217+
if (auto moduleTy = result->getAs<ModuleType>()) {
3218+
// Allow module types only if flag is specified.
3219+
if (options.contains(TypeResolutionFlags::AllowModule))
3220+
return moduleTy;
3221+
// Otherwise, emit an error.
3222+
if (!options.contains(TypeResolutionFlags::SilenceErrors)) {
3223+
auto moduleName = moduleTy->getModule()->getName();
3224+
diagnose(Components.back()->getNameLoc(),
3225+
diag::cannot_find_type_in_scope, DeclNameRef(moduleName));
3226+
diagnose(Components.back()->getNameLoc(),
3227+
diag::note_module_as_type, moduleName);
3228+
}
3229+
Components.back()->setInvalid();
3230+
return ErrorType::get(Context);
3231+
}
3232+
3233+
// Hack to apply context-specific @escaping to a typealias with an underlying
3234+
// function type.
3235+
if (result->is<FunctionType>())
3236+
result = applyNonEscapingFromContext(DC, result, options);
3237+
3238+
// Check the availability of the type.
3239+
3240+
// We allow a type to conform to a protocol that is less available than
3241+
// the type itself. This enables a type to retroactively model or directly
3242+
// conform to a protocol only available on newer OSes and yet still be used on
3243+
// older OSes.
3244+
// To support this, inside inheritance clauses we allow references to
3245+
// protocols that are unavailable in the current type refinement context.
3246+
3247+
if (!options.contains(TypeResolutionFlags::SilenceErrors) &&
3248+
!options.contains(TypeResolutionFlags::AllowUnavailable) &&
3249+
diagnoseAvailability(IdType, DC,
3250+
options.contains(TypeResolutionFlags::AllowUnavailableProtocol))) {
3251+
Components.back()->setInvalid();
3252+
return ErrorType::get(Context);
3253+
}
3254+
3255+
return result;
3256+
}
3257+
32613258
Type TypeResolver::resolveSpecifierTypeRepr(SpecifierTypeRepr *repr,
32623259
TypeResolutionOptions options) {
32633260
// inout is only valid for (non-Subscript and non-EnumCaseDecl)

0 commit comments

Comments
 (0)