@@ -96,3 +96,72 @@ class PipelineState(Enum):
     IDLE = 0
     CALL_FWD = 1
     CALL_BWD = 2
+
+    def __str__(self) -> str:
+        return self.name
+
+
+@unique
+class PipelinePhase(Enum):
+    """
+    Pipeline phases for the train pipeline.
+
+    Please:
+    1. keep the phases ordered by their execution order in the base pipeline.
+    2. add a note explaining each phase where needed.
+    """
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __eq__(self, obj: "PipelinePhase") -> bool:
+        return self.value == obj.value
+
+    # defining __eq__ disables the inherited Enum hash, so restore hashing explicitly
+    # to keep the phases usable as dict keys and set members
+    def __hash__(self) -> int:
+        return hash(self.value)
+
+    # placeholder for an empty phase
+    NULL = "null"
+
+    # the data is usually first available on the CPU when loaded from the dataloader;
+    # the input batch needs to be moved/copied to the device for GPU training
+    COPY_BATCH_TO_DEVICE = "copy_batch_to_device"
+
+    # input post-processing is needed for the sparse data dist pipeline, where the sparse
+    # features are traced (built) from the ModelInput via fx tracing
+    INPUT_POST_PROC = "input_post_proc"
+
+    # the sparse features (a.k.a. KJTs) are in a jagged format, so their data sizes are unknown
+    # to the other ranks; a comm is needed to exchange the size info, i.e., the splits
+    INPUT_SPLITS_DIST = "input_splits_dist"
+
+    # once a rank knows the data sizes from the other ranks (via the splits dist), it can launch
+    # an all-to-all comm to exchange the actual data of the sparse features
+    # NOTE: the splits have to be available on the host side
+    # (a standalone sketch of this two-step exchange follows the diff)
+    INPUT_DATA_DIST = "input_data_dist"
+
+    # the embedding lookup is done by FBGEMM TBE on each rank
+    EMBEDDING_LOOKUP = "embedding_lookup"
+
+    # the embedding lookup results (i.e., the embeddings) are needed on every rank; this is
+    # usually done by the output dist with an all-to-all comm
+    EMBEDDING_OUTPUT_DIST = "embedding_output_dist"
+
+    # a typical DLRM model arch contains a sparse arch and a dense arch; here we treat the model
+    # excluding the "sparse modules" as the dense part, which also includes the dense-sharded
+    # embedding tables
+    DENSE_FORWARD = "dense_forward"
+
+    # the model's backward usually uses torch.autograd; the embedding modules' backward is
+    # handled by TBE
+    DENSE_BACKWARD = "dense_backward"
+
+    # on each rank, after the dense arch's backward, the gradients w.r.t. the embeddings are
+    # available; the backward of the "embedding output dist" is needed to gather the embedding
+    # gradients from all ranks onto the rank where the embedding table is hosted
+    EMBEDDING_GRAD_DIST = "embedding_grad_dist"
+
+    # the TBE backward usually updates the embedding table weights in place
+    EMBEDDING_BACKWARD = "embedding_backward"
+
+    # the embedding update is kept as a separate phase from the backward, in case the weight
+    # update is not fused with (coupled to) the backward pass
+    EMBEDDING_UPDATE = "embedding_update"
+
+    # the optimizer step usually only covers the dense module weights
+    DENSE_OPTIMIZER_STEP = "dense_optimizer_step"
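For orientation, a minimal usage sketch of the new enum (hypothetical driver code, not part of this diff): it relies only on behavior defined above, namely that members are declared in base-pipeline execution order, __str__ returns the value, and __eq__ compares by value. The run_phase helper and the timings dict are illustrative.

# assumes PipelinePhase is imported from the module this diff edits
from typing import Dict

def run_phase(phase: "PipelinePhase", timings: Dict[str, float]) -> None:
    # str(phase) is the phase value, e.g. "input_splits_dist"
    timings[str(phase)] = 0.0  # placeholder for a measured duration

timings: Dict[str, float] = {}
# Enum iteration preserves declaration order, which per the docstring mirrors the
# execution order of the base pipeline
for phase in PipelinePhase:
    if phase == PipelinePhase.NULL:  # __eq__ compares by value
        continue
    run_phase(phase, timings)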
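The INPUT_SPLITS_DIST / INPUT_DATA_DIST comments describe a two-step exchange. Below is a standalone sketch of that pattern using plain torch.distributed collectives rather than TorchRec's actual KJT all-to-all machinery; it assumes the default process group is initialized and that values holds the flattened jagged data this rank wants to send, with send_splits[j] values destined for rank j.

import torch
import torch.distributed as dist
from typing import List

def exchange_jagged_values(values: torch.Tensor, send_splits: List[int]) -> torch.Tensor:
    # send_splits[j] = number of values this rank sends to rank j
    assert len(send_splits) == dist.get_world_size()

    # INPUT_SPLITS_DIST: each rank tells every other rank how much data to expect
    in_splits = torch.tensor(send_splits, device=values.device)
    out_splits = torch.empty_like(in_splits)
    dist.all_to_all_single(out_splits, in_splits)

    # NOTE from the enum: the splits must be on the host, both to size the receive
    # buffer and to be passed as split sizes below (this implies a device sync)
    recv_splits = out_splits.cpu().tolist()

    # INPUT_DATA_DIST: exchange the actual jagged values now that the sizes are known
    recv = torch.empty(sum(recv_splits), dtype=values.dtype, device=values.device)
    dist.all_to_all_single(
        recv, values, output_split_sizes=recv_splits, input_split_sizes=send_splits
    )
    return recv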