Skip to content

Commit 31af116

Browse files
authored
[AutoDiff] NFC: prettify commutative diagrams. (swiftlang#31525)
Make commutative diagrams pretty using Unicode box characters. Use lowercase letters for arrow names.
1 parent 738ef73 commit 31af116

File tree

2 files changed

+32
-30
lines changed

2 files changed

+32
-30
lines changed

include/swift/SIL/TypeSubstCloner.h

+16-15
Original file line numberDiff line numberDiff line change
@@ -322,28 +322,29 @@ class TypeSubstCloner : public SILClonerWithScopes<ImplClass> {
322322
return;
323323
}
324324
// If the extractee is a derivative function, check whether the *remapped
325-
// derivative function type* (BC) is equal to the *derivative remapped
326-
// function type* (AD).
325+
// derivative function type* (bc) is equal to the *derivative remapped
326+
// function type* (ad).
327327
//
328-
// +----------------+ remap +-------------------------+
329-
// | orig. fn type | -------(A)------> | remapped orig. fn type |
330-
// +----------------+ +-------------------------+
331-
// | |
332-
// (B, SILGen) getAutoDiffDerivativeFunctionType (D, here)
333-
// V V
334-
// +----------------+ remap +-------------------------+
335-
// | deriv. fn type | -------(C)------> | remapped deriv. fn type |
336-
// +----------------+ +-------------------------+
328+
// ┌────────────────┐ remap ┌─────────────────────────┐
329+
// │ orig. fn type │ ───────(a)──────► │ remapped orig. fn type │
330+
// └────────────────┘ └─────────────────────────┘
331+
// │ │
332+
// (b, SILGen) getAutoDiffDerivativeFunctionType (d, here)
333+
// │ │
334+
// ▼ ▼
335+
// ┌────────────────┐ remap ┌─────────────────────────┐
336+
// │ deriv. fn type │ ───────(c)──────► │ remapped deriv. fn type │
337+
// └────────────────┘ └─────────────────────────┘
337338
//
338-
// (AD) does not always commute with (BC):
339-
// - (AD) is the result of remapping, then computing the derivative type.
339+
// (ad) does not always commute with (bc):
340+
// - (ad) is the result of remapping, then computing the derivative type.
340341
// This is the default cloning behavior, but may break invariants in the
341342
// initial SIL generated by SILGen.
342-
// - (BC) is the result of computing the derivative type (SILGen), then
343+
// - (bc) is the result of computing the derivative type (SILGen), then
343344
// remapping. This is the expected type, preserving invariants from
344345
// earlier transforms.
345346
//
346-
// If (AD) is not equal to (BC), use (BC) as the explicit type.
347+
// If (ad) is not equal to (bc), use (bc) as the explicit type.
347348
SILType remappedOrigType = getOpType(dfei->getOperand()->getType());
348349
auto remappedOrigFnType = remappedOrigType.castTo<SILFunctionType>();
349350
auto derivativeRemappedFnType =

lib/SIL/IR/SILFunctionType.cpp

+16-15
Original file line numberDiff line numberDiff line change
@@ -3088,30 +3088,31 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
30883088
// If the constant refers to a derivative function, get the SIL type of the
30893089
// original function and use it to compute the derivative SIL type.
30903090
//
3091-
// This is necessary because the "lowered AST derivative function type" (BC)
3091+
// This is necessary because the "lowered AST derivative function type" (bc)
30923092
// may differ from the "derivative type of the lowered original function type"
3093-
// (AD):
3093+
// (ad):
30943094
//
3095-
// +--------------------+ lowering +--------------------+
3096-
// | AST orig. fn type | -------(A)------> | SIL orig. fn type |
3097-
// +--------------------+ +--------------------+
3098-
// | |
3099-
// (B, Sema) getAutoDiffDerivativeFunctionType (D, here)
3100-
// V V
3101-
// +--------------------+ lowering +--------------------+
3102-
// | AST deriv. fn type | -------(C)------> | SIL deriv. fn type |
3103-
// +--------------------+ +--------------------+
3095+
// ┌────────────────────┐ lowering ┌────────────────────┐
3096+
// │ AST orig. fn type │ ───────(a)──────► │ SIL orig. fn type │
3097+
// └────────────────────┘ └────────────────────┘
3098+
// │ │
3099+
// (b, Sema) getAutoDiffDerivativeFunctionType (d, here)
3100+
// │ │
3101+
// ▼ ▼
3102+
// ┌────────────────────┐ lowering ┌────────────────────┐
3103+
// │ AST deriv. fn type │ ───────(c)──────► │ SIL deriv. fn type │
3104+
// └────────────────────┘ └────────────────────┘
31043105
//
3105-
// (AD) does not always commute with (BC):
3106-
// - (BC) is the result of computing the AST derivative type (Sema), then
3106+
// (ad) does not always commute with (bc):
3107+
// - (bc) is the result of computing the AST derivative type (Sema), then
31073108
// lowering it via SILGen. This is the default lowering behavior, but may
31083109
// break SIL typing invariants because expected lowered derivative types are
31093110
// computed from lowered original function types.
3110-
// - (AD) is the result of lowering the original function type, then computing
3111+
// - (ad) is the result of lowering the original function type, then computing
31113112
// its derivative type. This is the expected lowered derivative type,
31123113
// preserving SIL typing invariants.
31133114
//
3114-
// Always use (AD) to compute lowered derivative function types.
3115+
// Always use (ad) to compute lowered derivative function types.
31153116
if (auto *derivativeId = constant.derivativeFunctionIdentifier) {
31163117
// Get lowered original function type.
31173118
auto origFnConstantInfo = getConstantInfo(

0 commit comments

Comments
 (0)