Skip to content

Commit 904477d

Browse files
authored
Merge pull request swiftlang#31173 from dan-zheng/property-wrapper-differentiation
[AutoDiff] Support differentiation of wrapped properties.
2 parents 6cecac7 + 96f3f6f commit 904477d

15 files changed

+903
-98
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2734,13 +2734,21 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
27342734
"stored property %0 has no derivative because %1 does not conform to "
27352735
"'Differentiable'; add an explicit '@noDerivative' attribute"
27362736
"%select{|, or conform %2 to 'AdditiveArithmetic'}3",
2737-
(Identifier, Type, Identifier, bool))
2738-
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
2737+
(/*propName*/ Identifier, /*propType*/ Type, /*nominalName*/ Identifier,
2738+
/*nominalCanDeriveAdditiveArithmetic*/ bool))
2739+
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
27392740
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
2740-
"requires all stored properties to be mutable; use 'var' instead, or add "
2741-
"an explicit '@noDerivative' attribute"
2741+
"requires all stored properties not marked with `@noDerivative` to be "
2742+
"mutable; add an explicit '@noDerivative' attribute"
27422743
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
2743-
(Identifier, Identifier, bool))
2744+
(/*wrapperType*/ StringRef, /*nominalName*/ Identifier,
2745+
/*nominalCanDeriveAdditiveArithmetic*/ bool))
2746+
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
2747+
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
2748+
"requires all stored properties not marked with `@noDerivative` to be "
2749+
"mutable; use 'var' instead, or add an explicit '@noDerivative' attribute"
2750+
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
2751+
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))
27442752

27452753
NOTE(codable_extraneous_codingkey_case_here,none,
27462754
"CodingKey case %0 does not match any stored properties", (Identifier))

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ class AdjointValue final {
165165
break;
166166
}
167167
}
168+
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
168169
};
169170

170171
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,23 @@ ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v);
6161
/// tuple-typed and such a user exists.
6262
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);
6363

64+
/// Returns true if the given original function is a "semantic member accessor".
65+
///
66+
/// "Semantic member accessors" are attached to member properties that have a
67+
/// corresponding tangent stored property in the parent `TangentVector` type.
68+
/// These accessors have special-case pullback generation based on their
69+
/// semantic behavior.
70+
///
71+
/// "Semantic member accessors" currently include:
72+
/// - Stored property accessors. These are implicitly generated.
73+
/// - Property wrapper wrapped value accessors. These are implicitly generated
74+
/// and internally call `var wrappedValue`.
75+
bool isSemanticMemberAccessor(SILFunction *original);
76+
77+
/// Returns true if the given apply site has a "semantic member accessor"
78+
/// callee.
79+
bool hasSemanticMemberAccessorCallee(ApplySite applySite);
80+
6481
/// Given a full apply site, apply the given callback to each of its
6582
/// "direct results".
6683
///

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
100100
SILBuilder localAllocBuilder;
101101

102102
/// Stack buffers allocated for storing local adjoint values.
103-
SmallVector<SILValue, 64> functionLocalAllocations;
103+
SmallVector<AllocStackInst *, 64> functionLocalAllocations;
104104

105105
/// A set used to remember local allocations that were destroyed.
106106
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
@@ -316,6 +316,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
316316
/// if any error occurs.
317317
bool run();
318318

319+
/// Performs pullback generation on the empty pullback function, given that
320+
/// the original function is a "semantic member accessor".
321+
///
322+
/// "Semantic member accessors" are attached to member properties that have a
323+
/// corresponding tangent stored property in the parent `TangentVector` type.
324+
/// These accessors have special-case pullback generation based on their
325+
/// semantic behavior.
326+
///
327+
/// Returns true if any error occurs.
328+
bool runForSemanticMemberAccessor();
329+
bool runForSemanticMemberGetter();
330+
bool runForSemanticMemberSetter();
331+
319332
/// If original result is non-varied, it will always have a zero derivative.
320333
/// Skip full pullback generation and simply emit zero derivatives for wrt
321334
/// parameters.

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,44 @@ DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
7171
return result;
7272
}
7373

74+
bool isSemanticMemberAccessor(SILFunction *original) {
75+
auto *dc = original->getDeclContext();
76+
if (!dc)
77+
return false;
78+
auto *decl = dc->getAsDecl();
79+
if (!decl)
80+
return false;
81+
auto *accessor = dyn_cast<AccessorDecl>(decl);
82+
if (!accessor)
83+
return false;
84+
// Currently, only getters and setters are supported.
85+
// TODO(SR-12640): Support `modify` accessors.
86+
if (accessor->getAccessorKind() != AccessorKind::Get &&
87+
accessor->getAccessorKind() != AccessorKind::Set)
88+
return false;
89+
// Accessor must come from a `var` declaration.
90+
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());
91+
if (!varDecl)
92+
return false;
93+
// Return true for stored property accessors.
94+
if (varDecl->hasStorage() && varDecl->isInstanceMember())
95+
return true;
96+
// Return true for properties that have attached property wrappers.
97+
if (varDecl->hasAttachedPropertyWrapper())
98+
return true;
99+
// Otherwise, return false.
100+
// User-defined accessors can never be supported because they may use custom
101+
// logic that does not semantically perform a member access.
102+
return false;
103+
}
104+
105+
bool hasSemanticMemberAccessorCallee(ApplySite applySite) {
106+
if (auto *FRI = dyn_cast<FunctionRefBaseInst>(applySite.getCallee()))
107+
if (auto *F = FRI->getReferencedFunctionOrNull())
108+
return isSemanticMemberAccessor(F);
109+
return false;
110+
}
111+
74112
void forEachApplyDirectResult(
75113
FullApplySite applySite,
76114
llvm::function_ref<void(SILValue)> resultCallback) {

0 commit comments

Comments
 (0)