Skip to content

Commit e0c68ab

Browse files
duanjunwenflybird11111pre-commit-ci[bot]BurkeHulkGuangyaoZhang
authored
[Zerobubble] merge main. (#6142)
* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * [feat] moehybrid support zerobubble; * [fix] fix zerobubble pp for shardformer type input; * [feat] add more test; * [fix] fix require_grad & deallocate call; * [fix] updatw bwd b&w input; dict --> list[torch.Tensor] * [fix] fix bwd w input; * [fix] fix mem assert; * [fix] fix input_tensors buffer append input_obj(dict) --> Tuple (microbatch, input_obj) , and all bwd b related cal logic; * [fix] use tree_flatten replace dict traverse; * [fix] rm comments; * [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs' * [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch; * [fix] fix detach clone release order; * [fix] fix ci --> oom in 4096 hidden dim; * [fix] fix dumb clone; * [fix] fix detach_output_obj clone; * [fix] fix stage_indices; * [fix] fix traverse; traverse dict --> traverse tensor List; * [fix] fix zerobubble; support shardformer model type; * [fix] rm comments; * [fix] fix test_pipeline_utils ci; * [fix] remove duplicate arg; rm comments; * [fix] remove chunk 0 stage 0 bwd b; u don't have to cal micrbatch's dx; * [fix] rm print & comments; * [plugin] hybrid support zero bubble pipeline (#6060) * hybrid support zbv * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * hybrid support zbv * fix fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <[email protected]> * [feat] zerobubble support moehybridplugin; * [feat] update optimizer bwd; ä¸ * [fix] fix build ci; * [zerobubble] rebase main (#6075) * fp8 operators for compressed communication cast_to_fp8, cast_from_fp8, all_reduce_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * fix scaling algorithm in FP8 casting * support fp8 communication in pipeline parallelism * add fp8_communication flag in the script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * shardformer fp8 * fix rebase * remove all to all * fix shardformer fp8 communication training degradation * [fp8] support all-gather flat tensor (#5932) * [fp8] add fp8 comm for low level zero * [test] add zero fp8 test case * [Feature] llama shardformer fp8 support (#5938) * add llama shardformer fp8 * Llama Shardformer Parity * fix typo * fix all reduce * fix pytest failure * fix reduce op and move function to fp8.py * fix typo * [FP8] rebase main (#5963) * add SimPO * fix dataloader * remove debug code * add orpo * fix style * fix colossalai, transformers version * fix colossalai, transformers version * fix colossalai, transformers version * fix torch colossalai version * update transformers version * [shardformer] DeepseekMoE support (#5871) * [Feature] deepseek moe expert parallel implement * [misc] fix typo, remove redundant file (#5867) * [misc] fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] deepseek support & unit test * [misc] remove debug code & useless print * [misc] fix typos (#5872) * [Feature] remove modeling file, use auto config. (#5884) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [Deepseek] remove redundant code (#5888) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [Feature/deepseek] resolve comment. (#5889) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [misc] mv module replacement into if branch * [misc] add some warning message and modify some code in unit test * [misc] fix typos --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap Co-authored-by: Edenzzzz <[email protected]> * [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838) * Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support * [HotFix] CI,import,requirements-test for #5838 (#5892) * [Hot Fix] CI,import,requirements-test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Enable PP + SP for llama (#5868) * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use a one cross entropy func for all shardformer models --------- Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897) * add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint * fix style * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix eval * hotfix citation * [zero] support all-gather overlap (#5898) * [zero] support all-gather overlap * [zero] add overlap all-gather flag * [misc] fix typo * [zero] update api * fix orpo cross entropy loss * [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446) * Remove unnecessary calls to deepcopy * Build DimSpec's difference dict only once This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough. * Fix documentation of DimSpec's difference method * [ShardFormer] fix qwen2 sp (#5903) * [compatibility] support torch 2.2 (#5875) * Support Pytorch 2.2.2 * keep build_on_pr file and update .compatibility * fix object_to_tensor usage when torch>=2.3.0 (#5820) * [misc] support torch2.3 (#5893) * [misc] support torch2.3 * [devops] update compatibility ci * [devops] update compatibility ci * [devops] add debug * [devops] add debug * [devops] add debug * [devops] add debug * [devops] remove debug * [devops] remove debug * [release] update version (#5912) * [plugin] support all-gather overlap for hybrid parallel (#5919) * [plugin] fixed all-gather overlap support for hybrid parallel * add kto * fix style, add kto data sample * [Examples] Add lazy init to OPT and GPT examples (#5924) Co-authored-by: Edenzzzz <[email protected]> * [ColossalChat] Hotfix for ColossalChat (#5910) * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * fix ddp issue * add Qwen 1.5 32B * refactor tokenization * [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931) * cannot access local variable 'default_conversation' where it is not associated with a value set default value for 'default_conversation' * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix test data * refactor evaluation * remove real data path * remove real data path * Add n_fused as an input from native_module (#5894) * [FIX BUG] convert env param to int in (#5934) * [Hotfix] Fix ZeRO typo #5936 Co-authored-by: Edenzzzz <[email protected]> * [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941) * Add a switch to control whether the model checkpoint needs to be saved after each epoch ends * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix style * fix style * fix style * [shardformer] hotfix attn mask (#5945) * [shardformer] hotfix attn mask (#5947) * [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895) * Distrifusion Support source * comp comm overlap optimization * sd3 benchmark * pixart distrifusion bug fix * sd3 bug fix and benchmark * generation bug fix * naming fix * add docstring, fix counter and shape error * add reference * readme and requirement * [zero] hotfix update master params (#5951) * [release] update version (#5952) * [Chat] Fix lora (#5946) * fix merging * remove filepath * fix style * Update README.md (#5958) * [hotfix] Remove unused plan section (#5957) * remove readme * fix readme * update * [test] add mixtral for sequence classification * [test] add mixtral transformer test * [moe] fix plugin * [test] mixtra pp shard test * [chore] handle non member group * [zero] solve hang * [test] pass mixtral shardformer test * [moe] implement transit between non moe tp and ep * [zero] solve hang * [misc] solve booster hang by rename the variable * solve hang when parallel mode = pp + dp * [moe] implement submesh initialization * [moe] add mixtral dp grad scaling when not all experts are activated * [chore] manually revert unintended commit * [chore] trivial fix * [chore] arg pass & remove drop token * [test] add mixtral modelling test * [moe] implement tp * [moe] test deepseek * [moe] clean legacy code * [Feature] MoE Ulysses Support (#5918) * moe sp support * moe sp bug solve * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [chore] minor fix * [moe] init moe plugin comm setting with sp * moe sp + ep bug fix * [moe] finalize test (no pp) * [moe] full test for deepseek and mixtral (pp + sp to fix) * [chore] minor fix after rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [chore] solve moe ckpt test failure and some other arg pass failure * [moe] remove ops * [test] fix test: test_zero1_2 * [bug] fix: somehow logger hangs the program * [moe] deepseek moe sp support * [test] add check * [deepseek] replace attn (a workaround for bug in transformers) * [misc] skip redunant test * [misc] remove debug/print code * [moe] refactor mesh assignment * Revert "[moe] implement submesh initialization" This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582. * [chore] change moe_pg_mesh to private * [misc] remove incompatible test config * [misc] fix ci failure: change default value to false in moe plugin * [misc] remove useless condition * [chore] docstring * [moe] remove force_overlap_comm flag and add warning instead * [doc] add MoeHybridParallelPlugin docstring * [moe] solve dp axis issue * [chore] remove redundant test case, print string & reduce test tokens * [feat] Dist Loader for Eval (#5950) * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tp error * remove unused parameters * remove unused * update inference * update docs * update inference --------- Co-authored-by: Michelle <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [lora] lora support hybrid parallel plugin (#5956) * lora support hybrid plugin * fix * fix * fix * fix * fp8 operators for compressed communication cast_to_fp8, cast_from_fp8, all_reduce_fp8 * fix scaling algorithm in FP8 casting * support fp8 communication in pipeline parallelism * add fp8_communication flag in the script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * shardformer fp8 * fix rebase * remove all to all * fix shardformer fp8 communication training degradation * [fp8] support all-gather flat tensor (#5932) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update low_level_optim.py --------- Co-authored-by: YeAnbang <[email protected]> Co-authored-by: Haze188 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: Runyu Lu <[email protected]> Co-authored-by: Guangyao Zhang <[email protected]> Co-authored-by: YeAnbang <[email protected]> Co-authored-by: Hongxin Liu <[email protected]> Co-authored-by: Stephan Kö <[email protected]> Co-authored-by: アマデウス <[email protected]> Co-authored-by: Tong Li <[email protected]> Co-authored-by: zhurunhua <[email protected]> Co-authored-by: Insu Jang <[email protected]> Co-authored-by: Gao, Ruiyuan <[email protected]> Co-authored-by: hxwang <[email protected]> Co-authored-by: Michelle <[email protected]> Co-authored-by: Wang Binluo <[email protected]> Co-authored-by: HangXu <[email protected]> * [fp8]support all2all fp8 (#5953) * support all2all fp8 * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [fp8] add fp8 linear (#5967) * [fp8] add fp8 linear * [test] fix fp8 linear test condition * [test] fix fp8 linear test condition * [test] fix fp8 linear test condition * [fp8] support fp8 amp for hybrid parallel plugin (#5975) * [fp8] support fp8 amp for hybrid parallel plugin * [test] add fp8 hook test * [fp8] fix fp8 linear compatibility * fix (#5976) * [Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928) * support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement communication hook for FSDP params all-gather * added unit test for fp8 operators * support fp8 communication in GeminiPlugin * update training scripts to support fsdp and fp8 communication * fixed some minor bugs observed in unit test * add all_gather_into_tensor_flat_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * add skip the test if torch < 2.2.0 * add fp8_comm flag * rebase latest fp8 operators * rebase latest fp8 operators * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [test ci]Feature/fp8 comm (#5981) * fix * fix * fix * [fp8] support gemini plugin (#5978) * [fp8] refactor hook * [fp8] support gemini plugin * [example] add fp8 option for llama benchmark * [fp8] use torch compile (torch >= 2.3.0) (#5979) * [fp8] use torch compile (torch >= 2.4.0) * [fp8] set use_fast_accum in linear * [chore] formal version check * [chore] fix sig * [fp8]Moe support fp8 communication (#5977) * fix * support moe fp8 * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix fix fi * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [fp8] support hybrid parallel plugin (#5982) * support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * fp8 * fix * bert and bloom * chatglm and command * gpt2,gptj,bert, falcon,blip2 * mistral,opy,sam,t5,vit,whisper * fix * fix * fix * [fp8] refactor fp8 linear with compile (#5993) * [fp8] refactor fp8 linear with compile * [fp8] fix linear test * [fp8] fix linear test * [fp8] support asynchronous FP8 communication (#5997) * fix * fix * fix * support async all2all * support async op for all gather * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004) * [fp8] linear perf enhancement * [fp8]update reduce-scatter test (#6002) * fix * fix * fix * fix * [fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009) * [fp8] zero support fp8 linear. (#6006) * fix * fix * fix * zero fp8 * zero fp8 * Update requirements.txt * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the merge * fix the merge * fix the merge * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the merge * fix * fix * fix the merge * fix * fix * fix * fix * fix * fix the merge * fix * fix * fix * fix * [fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016) * add SimPO * fix dataloader * remove debug code * add orpo * fix style * fix colossalai, transformers version * fix colossalai, transformers version * fix colossalai, transformers version * fix torch colossalai version * update transformers version * [shardformer] DeepseekMoE support (#5871) * [Feature] deepseek moe expert parallel implement * [misc] fix typo, remove redundant file (#5867) * [misc] fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] deepseek support & unit test * [misc] remove debug code & useless print * [misc] fix typos (#5872) * [Feature] remove modeling file, use auto config. (#5884) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [Deepseek] remove redundant code (#5888) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [Feature/deepseek] resolve comment. (#5889) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [misc] mv module replacement into if branch * [misc] add some warning message and modify some code in unit test * [misc] fix typos --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap Co-authored-by: Edenzzzz <[email protected]> * [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838) * Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support * [HotFix] CI,import,requirements-test for #5838 (#5892) * [Hot Fix] CI,import,requirements-test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Enable PP + SP for llama (#5868) * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use a one cross entropy func for all shardformer models --------- Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897) * add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint * fix style * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix eval * hotfix citation * [zero] support all-gather overlap (#5898) * [zero] support all-gather overlap * [zero] add overlap all-gather flag * [misc] fix typo * [zero] update api * fix orpo cross entropy loss * [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446) * Remove unnecessary calls to deepcopy * Build DimSpec's difference dict only once This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough. * Fix documentation of DimSpec's difference method * [ShardFormer] fix qwen2 sp (#5903) * [compatibility] support torch 2.2 (#5875) * Support Pytorch 2.2.2 * keep build_on_pr file and update .compatibility * fix object_to_tensor usage when torch>=2.3.0 (#5820) * [misc] support torch2.3 (#5893) * [misc] support torch2.3 * [devops] update compatibility ci * [devops] update compatibility ci * [devops] add debug * [devops] add debug * [devops] add debug * [devops] add debug * [devops] remove debug * [devops] remove debug * [release] update version (#5912) * [plugin] support all-gather overlap for hybrid parallel (#5919) * [plugin] fixed all-gather overlap support for hybrid parallel * add kto * fix style, add kto data sample * [Examples] Add lazy init to OPT and GPT examples (#5924) Co-authored-by: Edenzzzz <[email protected]> * [ColossalChat] Hotfix for ColossalChat (#5910) * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * fix ddp issue * add Qwen 1.5 32B * refactor tokenization * [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931) * cannot access local variable 'default_conversation' where it is not associated with a value set default value for 'default_conversation' * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix test data * refactor evaluation * remove real data path * remove real data path * Add n_fused as an input from native_module (#5894) * [FIX BUG] convert env param to int in (#5934) * [Hotfix] Fix ZeRO typo #5936 Co-authored-by: Edenzzzz <[email protected]> * [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941) * Add a switch to control whether the model checkpoint needs to be saved after each epoch ends * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix style * fix style * fix style * [shardformer] hotfix attn mask (#5945) * [shardformer] hotfix attn mask (#5947) * [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895) * Distrifusion Support source * comp comm overlap optimization * sd3 benchmark * pixart distrifusion bug fix * sd3 bug fix and benchmark * generation bug fix * naming fix * add docstring, fix counter and shape error * add reference * readme and requirement * [zero] hotfix update master params (#5951) * [release] update version (#5952) * [Chat] Fix lora (#5946) * fix merging * remove filepath * fix style * Update README.md (#5958) * [hotfix] Remove unused plan section (#5957) * remove readme * fix readme * update * [test] add mixtral for sequence classification * [test] add mixtral transformer test * [moe] fix plugin * [test] mixtra pp shard test * [chore] handle non member group * [zero] solve hang * [test] pass mixtral shardformer test * [moe] implement transit between non moe tp and ep * [zero] solve hang * [misc] solve booster hang by rename the variable * solve hang when parallel mode = pp + dp * [moe] implement submesh initialization * [moe] add mixtral dp grad scaling when not all experts are activated * [chore] manually revert unintended commit * [chore] trivial fix * [chore] arg pass & remove drop token * [test] add mixtral modelling test * [moe] implement tp * [moe] test deepseek * [moe] clean legacy code * [Feature] MoE Ulysses Support (#5918) * moe sp support * moe sp bug solve * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [chore] minor fix * [moe] init moe plugin comm setting with sp * moe sp + ep bug fix * [moe] finalize test (no pp) * [moe] full test for deepseek and mixtral (pp + sp to fix) * [chore] minor fix after rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [chore] solve moe ckpt test failure and some other arg pass failure * [moe] remove ops * [test] fix test: test_zero1_2 * [bug] fix: somehow logger hangs the program * [moe] deepseek moe sp support * [test] add check * [deepseek] replace attn (a workaround for bug in transformers) * [misc] skip redunant test * [misc] remove debug/print code * [moe] refactor mesh assignment * Revert "[moe] implement submesh initialization" This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582. * [chore] change moe_pg_mesh to private * [misc] remove incompatible test config * [misc] fix ci failure: change default value to false in moe plugin * [misc] remove useless condition * [chore] docstring * [moe] remove force_overlap_comm flag and add warning instead * [doc] add MoeHybridParallelPlugin docstring * [moe] solve dp axis issue * [chore] remove redundant test case, print string & reduce test tokens * [feat] Dist Loader for Eval (#5950) * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tp error * remove unused parameters * remove unused * update inference * update docs * update inference --------- Co-authored-by: Michelle <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [lora] lora support hybrid parallel plugin (#5956) * lora support hybrid plugin * fix * fix * fix * fix * Support overall loss, update KTO logging * [Docs] clarify launch port Co-authored-by: Edenzzzz <[email protected]> * [Hotfix] README link (#5966) * update ignore * update readme * run style * update readme * [Hotfix] Avoid fused RMSnorm import error without apex (#5985) Co-authored-by: Edenzzzz <[email protected]> * [Chat] fix readme (#5989) * fix readme * fix readme, tokenization fully tested * fix readme, tokenization fully tested * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix sync condition (#6000) * [plugin] add cast inputs option for zero (#6003) * [pre-commit.ci] pre-commit autoupdate (#5995) updates: - [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991) * [Feature] Zigzag Ring attention (#5905) * halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements --------- Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] update compatibility (#6008) * [misc] update compatibility * [misc] update requirements * [devops] disable requirements cache * [test] fix torch ddp test * [test] fix rerun on address in use * [test] fix lazy init * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the merge * fix the merge * overlap kv comm with output rescale (#6017) Co-authored-by: Edenzzzz <[email protected]> * fix the merge * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the merge * fix * fix * fix the merge * fix * [misc] Use dist logger in plugins (#6011) * use dist logger in plugins * remove trash * print on rank 0 --------- Co-authored-by: Edenzzzz <[email protected]> * fix * fix * fix * fix * fix the merge * fix * fix * fix * fix --------- Co-authored-by: YeAnbang <[email protected]> Co-authored-by: Haze188 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: Runyu Lu <[email protected]> Co-authored-by: Guangyao Zhang <[email protected]> Co-authored-by: YeAnbang <[email protected]> Co-authored-by: Hongxin Liu <[email protected]> Co-authored-by: Stephan Kö <[email protected]> Co-authored-by: アマデウス <[email protected]> Co-authored-by: Tong Li <[email protected]> Co-authored-by: zhurunhua <[email protected]> Co-authored-by: Insu Jang <[email protected]> Co-authored-by: Gao, Ruiyuan <[email protected]> Co-authored-by: hxwang <[email protected]> Co-authored-by: Michelle <[email protected]> Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local> * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update train_dpo.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update low_level_zero_plugin.py * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [CI] Remove triton version for compatibility bug; update req torch >=2.2 (#6018) * remove triton version * remove torch 2.2 * remove torch 2.1 * debug * remove 2.1 build tests * require torch >=2.2 --------- Co-authored-by: Edenzzzz <[email protected]> * [plugin] hotfix zero plugin (#6036) * [plugin] hotfix zero plugin * [plugin] hotfix zero plugin * [Colossal-LLaMA] Refactor latest APIs (#6030) * refactor latest code * update api * add dummy dataset * update Readme * add setup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update files * add PP support * update arguments * update argument * reorg folder * update version * remove IB infor * update utils * update readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update save for zero * update save * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add apex * update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add fused norm (#6038) * [FP8] unsqueeze scale to make it compatible with torch.compile (#6040) * [colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (#6020) * fix bug in load_state_dict_into_model; format error msg * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py to support checking missing_keys * Update general_checkpoint_io.py fix bug in missing_keys error message * retrigger tests --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hotfix] Remove deprecated install (#6042) * remove deprecated install * remove unused folder * [fp8] optimize all-gather (#6043) * [fp8] optimize all-gather * [fp8] fix all gather fp8 ring * [fp8] enable compile * [fp8] fix all gather fp8 ring * [fp8] fix linear hook (#6046) * [fp8] disable all_to_all_fp8 in intranode (#6045) * enhance all_to_all_fp8 with internode comm control * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable some fp8 ops due to performance issue * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [release] update version (#6041) * [release] update version * [devops] update comp test * [devops] update comp test debug * [devops] debug comp test * [devops] debug comp test * [devops] debug comp test * [devops] debug comp test * [devops] debug comp test * [Feature] Split cross-entropy computation in SP (#5959) * halfway * fix cross-PP-stage position id length diff bug * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * update softmax_lse shape by new interface * change tester name * remove buffer clone; support packed seq layout * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * adapt chatglm, command-R, qwen * debug * halfway * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unified cross entropy func for all shardformer models * remove redundant lines * add basic ring attn; debug cross entropy * fwd bwd logic complete * fwd bwd logic complete; add experimental triton rescale * precision tests passed * precision tests passed * fix typos and remove misc files * add sp_mode to benchmark; fix varlen interface * update softmax_lse shape by new interface * add varlen tests * fix typo * all tests passed * add dkv_group; fix mask * remove debug statements * add comments * q1 index only once * remove events to simplify stream sync * simplify forward/backward logic * 2d ring forward passed * 2d ring backward passed * fixes * fix ring attn loss * 2D ring backward + llama passed * merge * update logger * fix typo * rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * remove typos * fixes * support GPT --------- Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048) * [example] pass use_fp8_comm flag to all plugins * [example] add mixtral benchmark * [moe] refine assertion and check * [moe] fix mixtral & add more tests * [moe] consider checking dp * sp group and moe_dp_group * [mixtral] remove gate tp & add more tests * [deepseek] fix tp & sp for deepseek * [mixtral] minor fix * [deepseek] add deepseek benchmark * [fp8] hotfix backward hook (#6053) * [fp8] hotfix backward hook * [fp8] hotfix pipeline loss accumulation * [doc] update sp doc (#6055) * update sp doc * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix the sp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the attn * fix * fix * fix * fix * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * [fp8] fix missing fp8_comm flag in mixtral (#6057) * fix * fix * fix * [fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059) * all_gather only internode, fix pytest * fix cuda arch <89 compile pytest error * fix pytest failure * disable all_gather_into_tensor_flat_fp8 * fix fp8 format * fix pytest * fix conversations * fix chunk tuple to list * [doc] FP8 training and communication document (#6050) * Add FP8 training and communication document * add fp8 docstring for plugins * fix typo * fix typo * fix * fix * [moe] add parallel strategy for shared_expert && fix test for deepseek (#6063) * [ColossalEval] support for vllm (#6056) * support vllm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify vllm and update readme * run pre-commit * remove dupilicated lines and refine code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update param name * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine code * update readme * refine code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [release] update version (#6062) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] fix poc format * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix mem check; * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [feat] moehybrid support zerobubble; * [fix] fix zerobubble pp for shardformer type input; * [fix] fix require_grad & deallocate call; * [fix] fix mem assert; * [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs' * [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch; * [fix] fix zerobubble; support shardformer model type; * [fix] fix test_pipeline_utils ci; * [plugin] hybrid support zero bubble pipeline (#6060) * hybrid support zbv * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * hybrid support zbv * fix fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <[email protected]> * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] fix poc format * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [feat] update test; rm comments; * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix mem check; * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix mem assert; * [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs' * [plugin] hybrid support zero bubble pipeline (#6060) * hybrid support zbv * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * hybrid support zbv * fix fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: HangXu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: GuangyaoZhang <[email protected]> Co-authored-by: Hongxin Liu <[email protected]> Co-authored-by: YeAnbang <[email protected]> Co-authored-by: Haze188 <[email protected]> Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: Edenzzzz <[email protected]> Co-authored-by: Runyu Lu <[email protected]> Co-authored-by: YeAnbang <[email protected]> Co-authored-by: Stephan Kö <[email protected]> Co-authored-by: アマデウス <[email protected]> Co-authored-by: Tong Li <[email protected]> Co-authored-by: zhurunhua <[email protected]> Co-authored-by: Insu Jang <[email protected]> Co-authored-by: Gao, Ruiyuan <[email protected]> Co-authored-by: hxwang <[email protected]> Co-authored-by: Michelle <[email protected]> Co-authored-by: Wang Binluo <[email protected]> Co-authored-by: wangbluo <[email protected]> Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local> Co-authored-by: duanjunwen <[email protected]> Co-authored-by: Camille Zhong <[email protected]> * [fix] fix mixtral policy; * [fix] fix mixtral policy; * [feat] support zbv in mixtral benchmark; * [fix] MixtralForCausalLMPolicy get_held_layer support zbv; * [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv; * [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv * [zero bubble] support zero (#6080) * fp8 operators for compressed communication cast_to_fp8, cast_from_fp8, all_reduce_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * fix scaling algorithm in FP8 casting * support fp8 communication in pipeline parallelism * add fp8_communication flag in the script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * shardformer fp8 * fix rebase * remove all to all * fix shardformer fp8 communication training degradation * [fp8] support all-gather flat tensor (#5932) * [fp8] add fp8 comm for low level zero * [test] add zero fp8 test case * [Feature] llama shardf…
1 parent 184a653 commit e0c68ab

35 files changed

+3682
-123
lines changed

colossalai/amp/naive_amp/mixed_precision_mixin/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def zero_grad(self):
4343
dtype: torch.dtype
4444

4545
@abstractmethod
46-
def pre_backward(self, loss: Tensor) -> Tensor:
46+
def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
4747
"""Called before backward.
4848
4949
Args:

colossalai/amp/naive_amp/mixed_precision_optimizer.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,18 @@ def __init__(
8686
group["params"] = master_params
8787
self._current_grad_norm: Optional[float] = None
8888

89-
def backward(self, loss: Tensor, *args, **kwargs):
89+
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
9090
loss = self.mixed_precision.pre_backward(loss)
91-
loss.backward(*args, **kwargs)
91+
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
9292

93-
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
93+
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
9494
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
95-
tensor.backward(grad)
95+
torch.autograd.backward(
96+
tensors=tensor,
97+
grad_tensors=grad,
98+
inputs=inputs,
99+
retain_graph=retain_graph,
100+
)
96101

97102
def zero_grad(self, *args, **kwargs):
98103
for p in self.working_to_master_map.keys():

colossalai/booster/mixed_precision/fp16_torch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def __init__(
4646
growth_interval=growth_interval,
4747
)
4848

49-
def backward(self, loss: Tensor, *args, **kwargs) -> None:
49+
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
5050
scaled_loss = self.scale_loss(loss)
51-
scaled_loss.backward(*args, **kwargs)
51+
scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
5252

5353
def step(self, *args, **kwargs) -> Optional[float]:
5454
out = self.scaler.step(self.optim, *args, **kwargs)

colossalai/booster/plugin/hybrid_parallel_plugin.py

+43-21
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from colossalai.interface.optimizer import DistributedOptim
2929
from colossalai.logging import get_dist_logger
3030
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
31-
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
31+
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
3232
from colossalai.pipeline.stage_manager import PipelineStageManager
3333
from colossalai.quantization import BnbQuantizationConfig, quantize_model
3434
from colossalai.quantization.fp8_hook import FP8Hook
@@ -296,7 +296,7 @@ def __init__(
296296
self._current_grad_norm: Optional[float] = None
297297
super().__init__(optim)
298298

299-
def backward(self, loss: Tensor, *args, **kwargs):
299+
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
300300
r"""
301301
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
302302
@@ -315,7 +315,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
315315

316316
# Call the superclass backward method to compute gradients.
317317
with self.model._hook_context():
318-
super().backward(loss, *args, **kwargs)
318+
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
319319

320320
if self.model.require_grad_sync:
321321
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -324,7 +324,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
324324
# If gradient synchronization is is not required, return.
325325
return
326326

327-
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
327+
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
328328
"""
329329
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
330330
@@ -341,7 +341,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
341341
"""
342342

343343
# Call the superclass backward method to compute gradients.
344-
super().backward_by_grad(tensor, grad)
344+
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
345345

346346
if self.model.require_grad_sync:
347347
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -525,7 +525,7 @@ def __init__(
525525
max_norm=max_norm,
526526
)
527527

528-
def backward(self, loss: Tensor, *args, **kwargs):
528+
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
529529
r"""
530530
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
531531
@@ -543,7 +543,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
543543
"""
544544
# Call the superclass backward method to compute gradients.
545545
with self.model._hook_context():
546-
super().backward(loss, *args, **kwargs)
546+
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
547547

548548
if self.model.require_grad_sync:
549549
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -552,7 +552,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
552552
# If gradient synchronization is is not required, return.
553553
return
554554

555-
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
555+
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
556556
"""
557557
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
558558
@@ -568,7 +568,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
568568
None
569569
"""
570570
# Call the superclass backward method to compute gradients.
571-
super().backward_by_grad(tensor, grad)
571+
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
572572

573573
if self.model.require_grad_sync:
574574
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -785,7 +785,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
785785
else:
786786
return
787787

788-
def backward(self, loss, retain_graph=False):
788+
def backward(self, loss, inputs=None, retain_graph=False):
789789
"""
790790
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
791791
@@ -801,7 +801,7 @@ def backward(self, loss, retain_graph=False):
801801
None
802802
"""
803803
# Call the superclass backward method to compute gradients.
804-
super().backward(loss, retain_graph)
804+
super().backward(loss, inputs=inputs, retain_graph=retain_graph)
805805

806806
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
807807
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -810,7 +810,7 @@ def backward(self, loss, retain_graph=False):
810810
# If gradient synchronization is is not required, return.
811811
return
812812

813-
def backward_by_grad(self, tensor, grad):
813+
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
814814
"""
815815
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
816816
@@ -826,7 +826,7 @@ def backward_by_grad(self, tensor, grad):
826826
None
827827
"""
828828
# Call the superclass backward_by_grad method to compute gradients.
829-
super().backward_by_grad(tensor, grad)
829+
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
830830

831831
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
832832
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -1030,6 +1030,7 @@ def __init__(
10301030
custom_policy: Policy = None,
10311031
pp_style: str = "1f1b",
10321032
num_model_chunks: int = 1,
1033+
scheduler_nodes: List = None,
10331034
num_layers_per_stage: Optional[List[int]] = None,
10341035
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
10351036
enable_metadata_cache: bool = True,
@@ -1048,6 +1049,9 @@ def __init__(
10481049
dist.get_world_size() % (tp_size * pp_size) == 0
10491050
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
10501051

1052+
assert (
1053+
not pp_style == "zbv" or scheduler_nodes is not None
1054+
), f"scheduler_nodes must not be None when using zero bubble pipeline."
10511055
if enable_sequence_parallelism:
10521056
self.sequence_parallelism_mode = (
10531057
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@@ -1109,29 +1113,39 @@ def __init__(
11091113
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
11101114

11111115
self.stage_manager = None
1112-
self.schedule = None
1116+
self.scheduler = None
11131117
self.custom_policy = custom_policy
11141118
assert zero_stage in (0, 1, 2)
11151119
if self.pp_size > 1:
1116-
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
1117-
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
1120+
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
1121+
assert (
1122+
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
1123+
), "num_model_chunks must be 1 when using 1f1b"
1124+
assert (
1125+
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
1126+
), "num_model_chunks must be 2 when using zero bubble pipeline"
11181127
assert (
11191128
num_microbatches is not None or microbatch_size is not None
11201129
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
11211130
assert (
11221131
self.zero_stage <= 1
11231132
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
1133+
if pp_style == "zbv":
1134+
self.logger.warning(
1135+
"""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
1136+
)
11241137
self.stage_manager = PipelineStageManager(
11251138
self.pg_mesh,
11261139
pipeline_axis=self.pp_axis,
1127-
enable_interleave=pp_style == "interleaved",
1140+
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
1141+
use_zbv=(pp_style == "zbv"),
11281142
num_model_chunks=num_model_chunks,
11291143
num_layers_per_stage=num_layers_per_stage,
11301144
)
11311145

11321146
if pp_style == "interleaved":
11331147
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
1134-
self.schedule = InterleavedSchedule(
1148+
self.scheduler = InterleavedSchedule(
11351149
stage_manager=self.stage_manager,
11361150
num_model_chunks=num_model_chunks,
11371151
num_microbatch=num_microbatches,
@@ -1141,13 +1155,21 @@ def __init__(
11411155
fp8_communication=fp8_communication,
11421156
)
11431157
elif pp_style == "1f1b":
1144-
self.schedule = OneForwardOneBackwardSchedule(
1158+
self.scheduler = OneForwardOneBackwardSchedule(
11451159
stage_manager=self.stage_manager,
11461160
num_microbatches=num_microbatches,
11471161
microbatch_size=microbatch_size,
11481162
enable_metadata_cache=enable_metadata_cache,
11491163
fp8_communication=fp8_communication,
11501164
)
1165+
elif pp_style == "zbv":
1166+
self.scheduler = ZeroBubbleVPipeScheduler(
1167+
stage_manager=self.stage_manager,
1168+
schedule=scheduler_nodes,
1169+
num_model_chunks=num_model_chunks,
1170+
num_microbatch=num_microbatches,
1171+
microbatch_size=microbatch_size,
1172+
)
11511173
else:
11521174
raise NotImplementedError()
11531175
if sequence_parallelism_mode == "ring_attn":
@@ -1263,7 +1285,6 @@ def configure(
12631285

12641286
# Replace with distributed implementation if exists
12651287
optimizer = cast_to_distributed(optimizer)
1266-
12671288
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
12681289
self.logger.warning(
12691290
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@@ -1278,6 +1299,7 @@ def configure(
12781299
self.dp_size == 1 and self.pp_size == 1
12791300
)
12801301
# sync gradients across DP * SP ranks
1302+
# sync gradients across DP * SP ranks
12811303
# Apply Hybrid ZeRO across DP * SP ranks
12821304
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
12831305
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -1380,7 +1402,7 @@ def execute_pipeline(
13801402
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
13811403

13821404
with ctx, model._hook_context():
1383-
outputs = self.schedule.forward_backward_step(
1405+
outputs = self.scheduler.forward_backward_step(
13841406
model, data_iter, criterion, optimizer, return_loss, return_outputs
13851407
)
13861408

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from colossalai.nn.optimizer import cast_to_distributed
3030
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
3131
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
32+
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
3233
from colossalai.pipeline.stage_manager import PipelineStageManager
3334
from colossalai.shardformer.policies.base_policy import Policy
3435
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -212,6 +213,7 @@ def __init__(
212213
custom_policy: Policy = None,
213214
pp_style: str = "1f1b",
214215
num_model_chunks: int = 1,
216+
scheduler_nodes: List = None,
215217
num_layers_per_stage: Optional[List[int]] = None,
216218
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
217219
enable_metadata_cache: bool = True,
@@ -285,12 +287,17 @@ def __init__(
285287
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
286288

287289
self.stage_manager = None
288-
self.schedule = None
290+
self.scheduler = None
289291
self.custom_policy = custom_policy
290292
assert zero_stage in (0, 1, 2)
291293
if self.pp_size > 1:
292-
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
293-
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
294+
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
295+
assert (
296+
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
297+
), "num_model_chunks must be 1 when using 1f1b"
298+
assert (
299+
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
300+
), "num_model_chunks must be 2 when using zero bubble pipeline"
294301
assert (
295302
num_microbatches is not None or microbatch_size is not None
296303
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -300,14 +307,15 @@ def __init__(
300307
self.stage_manager = PipelineStageManager(
301308
self.pg_mesh,
302309
pipeline_axis=self.pp_axis,
303-
enable_interleave=pp_style == "interleaved",
310+
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
304311
num_model_chunks=num_model_chunks,
305312
num_layers_per_stage=num_layers_per_stage,
313+
use_zbv=(pp_style == "zbv"),
306314
)
307315

308316
if pp_style == "interleaved":
309317
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
310-
self.schedule = InterleavedSchedule(
318+
self.scheduler = InterleavedSchedule(
311319
stage_manager=self.stage_manager,
312320
num_model_chunks=num_model_chunks,
313321
num_microbatch=num_microbatches,
@@ -316,12 +324,21 @@ def __init__(
316324
overlap_p2p=overlap_p2p,
317325
)
318326
elif pp_style == "1f1b":
319-
self.schedule = OneForwardOneBackwardSchedule(
327+
self.scheduler = OneForwardOneBackwardSchedule(
320328
stage_manager=self.stage_manager,
321329
num_microbatches=num_microbatches,
322330
microbatch_size=microbatch_size,
323331
enable_metadata_cache=enable_metadata_cache,
324332
)
333+
elif pp_style == "zbv":
334+
assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZerbubbleV"
335+
self.scheduler = ZeroBubbleVPipeScheduler(
336+
schedule=scheduler_nodes,
337+
stage_manager=self.stage_manager,
338+
num_model_chunks=num_model_chunks,
339+
num_microbatch=num_microbatches,
340+
overlap_p2p=overlap_p2p,
341+
)
325342
else:
326343
raise NotImplementedError()
327344

colossalai/interface/optimizer.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,31 @@ def zero_grad(self, *args, **kwargs):
4949
"""
5050
self.optim.zero_grad(*args, **kwargs)
5151

52-
def backward(self, loss: Tensor, *args, **kwargs):
52+
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
5353
"""
5454
Performs a backward pass on the loss.
5555
"""
56-
loss.backward(*args, **kwargs)
56+
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
5757

58-
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
59-
torch.autograd.backward(tensor, grad)
58+
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
59+
"""
60+
Performs a backward pass for dx or dw,
61+
for dx, we only calculate dx = w*dy here
62+
for dw, we only calculate dw = x*dy here
63+
64+
Args:
65+
tensor (Tensor): y or loss of current chunk;
66+
grad_tensors (Tensor): dy of current chunk;
67+
input_obj (Tensor): for dx, input_obj is x of current chunk;
68+
for dw, input_obj is w of current chunk;
69+
retain_graph (bool): default to be True, we retain graph in backward_b
70+
"""
71+
torch.autograd.backward(
72+
tensors=tensor,
73+
grad_tensors=grad,
74+
inputs=inputs,
75+
retain_graph=retain_graph,
76+
)
6077

6178
def state_dict(self):
6279
"""

colossalai/pipeline/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from .p2p import PipelineP2PCommunication
2-
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
2+
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler
33
from .stage_manager import PipelineStageManager
44

55
__all__ = [
66
"PipelineSchedule",
77
"OneForwardOneBackwardSchedule",
88
"InterleavedSchedule",
9+
"ZeroBubbleVPipeScheduler",
910
"PipelineP2PCommunication",
1011
"PipelineStageManager",
1112
]

0 commit comments

Comments
 (0)