@@ -17,13 +17,14 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import (
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -37,15 +38,6 @@
 from torchchat.utils.quantize import quantize_model
 
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
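
Note: the torchtune imports removed above reappear at their call sites (`_load_checkpoint`, `_load_model_default` below), so `import torchchat` no longer hard-requires torchtune. A minimal sketch of that deferred-import pattern; the helper name is hypothetical, not torchchat code:

```python
# Sketch of the deferred-import pattern this diff adopts.
import torch

def _load_tune_checkpoint(path: str):
    try:
        # Heavy optional dependency: imported only when actually needed.
        from torchtune.models.convert_weights import meta_to_tune
    except ImportError as e:
        raise RuntimeError("loading a Tune checkpoint requires torchtune") from e
    return meta_to_tune(torch.load(path, mmap=True, weights_only=True))
```
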
@@ -188,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
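
For context (not part of this diff): the `SDPBackend` value resolved here is typically applied with the `torch.nn.attention.sdpa_kernel` context manager (PyTorch >= 2.3). A minimal sketch:

```python
# Minimal sketch: restricting scaled_dot_product_attention to one backend.
# MATH is the portable fallback the code above selects on CPU.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```
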
@@ -238,12 +234,14 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
+
 class TokenizerType(Enum):
     NONE = 0
     TIKTOKEN = 1
     SENTENCEPIECE = 2
     HF_TOKENIZER = 3
 
+
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
@@ -307,9 +305,9 @@ def validate_model(
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -417,6 +415,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
 
 def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
+        from torchtune.models.convert_weights import meta_to_tune
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
             str(builder_args.checkpoint_path), mmap=True, weights_only=True
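
The `torch.load` call above combines `mmap=True` (map the checkpoint file instead of reading it eagerly) with `weights_only=True` (restrict unpickling to tensor data). A self-contained sketch of the same pattern:

```python
# Self-contained sketch of the load pattern used above.
import os
import tempfile
import torch

path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save({"w": torch.randn(4, 4)}, path)

# mmap=True keeps checkpoint pages on disk until tensors are touched;
# weights_only=True refuses arbitrary pickled objects (safer for untrusted files).
state = torch.load(path, mmap=True, weights_only=True)
print(state["w"].shape)
```
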
@@ -469,9 +468,15 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         checkpoint = checkpoint["model"]
 
     if model.config.model_type == ModelType.Flamingo:
+        from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+        from torchtune.models.llama3_2_vision._convert_weights import (
+            llama3_vision_meta_to_tune,
+        )
+        from torchtune.training import set_default_dtype
         # TODO: Refactor this. For now, overwrite the model with model loaded from params_path
-        with set_default_dtype(builder_args.precision), torch.device(
-            builder_args.device
+        with (
+            set_default_dtype(builder_args.precision),
+            torch.device(builder_args.device),
         ):
             # It doubles the model size in memory, with redundancies of the initialized weights.
             # model = Model.from_params(builder_args.params_path)
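
Worth noting: the parenthesized multi-item `with` adopted above is Python 3.10+ syntax. On older interpreters the same stacking can be written with `contextlib.ExitStack`; a sketch of both spellings:

```python
# Sketch: two equivalent spellings. torch.device(...) has been usable as a
# context manager since PyTorch 2.0.
from contextlib import ExitStack
import torch

with (
    torch.no_grad(),
    torch.device("cpu"),
):
    a = torch.ones(2, 2)

with ExitStack() as stack:  # pre-3.10 equivalent
    stack.enter_context(torch.no_grad())
    stack.enter_context(torch.device("cpu"))
    b = torch.ones(2, 2)
```
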
@@ -507,6 +512,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -564,6 +570,7 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -601,6 +608,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
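
Both AOTI paths stub out `setup_caches`, presumably because cache setup is already handled in the compiled artifact at export time; only the attribute assignment matters. A toy illustration of the monkey-patch (not torchchat code):

```python
# Toy illustration: replacing a bound method with a no-op so downstream
# generation code can call setup_caches() unconditionally.
class EagerModel:
    def setup_caches(self, max_batch_size, max_seq_length):
        print("allocating KV caches")

m = EagerModel()

def do_nothing(max_batch_size, max_seq_length):
    pass

m.setup_caches = do_nothing  # instance attribute shadows the class method
m.setup_caches(1, 4096)      # no allocation happens
```
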
@@ -652,12 +660,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend(dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels and custom operators
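
A sketch of the guarded load above as a standalone helper; the `from e` (not in the diff) would keep the original traceback attached to the re-raised error:

```python
# Sketch of the snapshot-loading guard; `load_snapshot` is illustrative,
# not a torchchat function.
import torch

def load_snapshot(path):
    try:
        return torch.load(path, weights_only=False)
    except Exception as e:
        raise RuntimeError(f"Failed to load torchchat snapshot {path}") from e
```
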
@@ -675,7 +686,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}"
+    )
 
     # Model-level config
     if builder_args.params_table:
@@ -686,20 +699,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
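
For reference, the mesh built above is a pp x tp grid laid out row-major, so the world size must equal `pp_degree * tp_degree`. A standalone sketch (filenames and degrees are illustrative):

```python
# Standalone sketch of the 2-D mesh above; run with e.g.
#   torchrun --nproc-per-node 8 mesh_demo.py
# so that world_size == pp_degree * tp_degree.
import torch.distributed as dist

pp_degree, tp_degree = 2, 4
mesh = dist.init_device_mesh(
    "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
)
tp_mesh = mesh["tp"]  # this rank's row: its tensor-parallel group
pp_mesh = mesh["pp"]  # this rank's column: its pipeline-parallel group
print(f"{tp_mesh.get_local_rank()=} {pp_mesh.get_local_rank()=}")
```
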
@@ -728,7 +737,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -742,7 +757,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill = 1024 
+    seqlen_prefill = 1024
    with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 