@@ -659,6 +659,9 @@ def __init__(self, model_name: str, device: str = "auto", **kwargs):
659659 self .low_cpu_mem_usage = kwargs .get ("low_cpu_mem_usage" , True )
660660 self .device_map = kwargs .get ("device_map" , "auto" if device == "auto" else None )
661661 self ._processor = None
662+ # Fallback mode when a repo provides a VLM through a causal LM head with remote code
663+ # (e.g., custom multimodal repos that don't map to AutoModelForVision2Seq configs).
664+ self ._causal_vlm_mode : bool = False
662665
663666 def load (self ) -> None :
664667 """Load VLM using transformers."""
@@ -667,7 +670,7 @@ def load(self) -> None:
667670
668671 logger .info (f"Loading { self .model_name } with transformers VLM backend" )
669672
670- # Load processor
673+ # Load processor first (used by both primary and fallback flows)
671674 self ._processor = AutoProcessor .from_pretrained (
672675 self .model_name , use_fast = True , trust_remote_code = True
673676 )
@@ -727,6 +730,7 @@ def load(self) -> None:
727730 )
728731
729732 try :
733+ # Primary path: standard Vision2Seq auto-model
730734 self ._model = AutoModelForVision2Seq .from_pretrained (
731735 self .model_name , ** model_kwargs
732736 )
@@ -740,10 +744,41 @@ def load(self) -> None:
740744 logger .info ("Moving model to MPS device..." )
741745 self ._model = self ._model .to (torch .device ("mps" ))
742746
743- logger .info (f"Successfully loaded { self .model_name } with transformers VLM" )
744- except Exception as e :
745- logger .error (f"Failed to load { self .model_name } with transformers VLM: { e } " )
746- raise
747+ logger .info (
748+ f"Successfully loaded { self .model_name } with transformers VLM (AutoModelForVision2Seq)"
749+ )
750+ except Exception as e_primary :
751+ # Fallback path: some repos expose multimodal models via a Causal LM with remote code
752+ logger .warning (
753+ f"Primary VLM load failed for { self .model_name } ({ e_primary } ). "
754+ f"Attempting fallback with AutoModelForCausalLM + trust_remote_code..."
755+ )
756+ try :
757+ causal_kwargs = dict (model_kwargs )
758+ # In causal path, device_map/torch_dtype still apply
759+ self ._model = AutoModelForCausalLM .from_pretrained (
760+ self .model_name , ** causal_kwargs
761+ )
762+ self ._causal_vlm_mode = True
763+
764+ # Move to MPS if needed
765+ if (
766+ actual_device == "mps"
767+ and hasattr (torch .backends , "mps" )
768+ and torch .backends .mps .is_available ()
769+ ):
770+ logger .info ("Moving causal VLM model to MPS device..." )
771+ self ._model = self ._model .to (torch .device ("mps" ))
772+
773+ logger .info (
774+ f"Successfully loaded { self .model_name } as causal VLM (AutoModelForCausalLM)"
775+ )
776+ except Exception as e_fallback :
777+ logger .error (
778+ f"Failed to load { self .model_name } with both primary and fallback VLM paths. "
779+ f"Primary error: { e_primary } ; Fallback error: { e_fallback } "
780+ )
781+ raise
747782
def generate_from_conversation(
    self, conversation: List[Dict], config: GenerationConfig
) -> str:
    """Generate a text response from a chat-style conversation.

    Shared by the primary Vision2Seq path and the causal-LM fallback path
    (``self._causal_vlm_mode``): both rely on ``self._processor`` for chat
    templating, tensorization, and decoding.

    Args:
        conversation: List of message dicts. Messages whose ``content`` is a
            list may carry ``{"type": "image", "image": <image>}`` items;
            every such image is collected and forwarded to the processor.
        config: Generation settings, converted via
            ``config.to_transformers_kwargs()``.

    Returns:
        The decoded response text (first batch element), with the prompt
        tokens trimmed off when ``input_ids`` is available.

    Raises:
        RuntimeError: If the model has not been loaded yet.
    """
    if not self.is_loaded():
        raise RuntimeError("Model not loaded")

    # Collect every image in the conversation (supports single or multiple).
    images: List[Any] = []
    for message in conversation:
        content = message.get("content")
        if isinstance(content, list):
            for item in content:
                if (
                    isinstance(item, dict)
                    and item.get("type") == "image"
                    and "image" in item
                ):
                    images.append(item["image"])

    # Render the chat template to plain text (common to both load paths).
    formatted_text = self._processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )

    # Build processor inputs; pass a bare image for the single-image case
    # since some processors reject one-element lists.
    proc_kwargs = {"text": formatted_text, "return_tensors": "pt"}
    if images:
        proc_kwargs["images"] = images if len(images) > 1 else images[0]

    inputs = self._processor(**proc_kwargs)

    # Determine the model's device; fall back to the first parameter, then CPU.
    if hasattr(self._model, "device"):
        device = self._model.device
    else:
        try:
            device = next(self._model.parameters()).device
        except (StopIteration, AttributeError, RuntimeError):
            device = torch.device("cpu")

    # Move tensor inputs onto the model's device (no-op for CPU).
    if device.type != "cpu":
        inputs = {
            k: v.to(device) if torch.is_tensor(v) else v for k, v in inputs.items()
        }

    gen_kwargs = config.to_transformers_kwargs()

    # Resolve a pad token id: prefer the processor's tokenizer, then the model
    # config. The chained getattr-with-default replaces the original's
    # redundant nested try/except (getattr with a default never raises for a
    # missing attribute); only a missing ``tokenizer`` needs catching.
    try:
        pad_id = self._processor.tokenizer.eos_token_id
    except AttributeError:
        pad_id = getattr(getattr(self._model, "config", None), "eos_token_id", None)
    # Some configs store eos_token_id as a list; pad_token_id must be an int,
    # so take the first entry in that case.
    if isinstance(pad_id, (list, tuple)):
        pad_id = pad_id[0] if pad_id else None
    if pad_id is not None:
        gen_kwargs["pad_token_id"] = pad_id

    with torch.no_grad():
        outputs = self._model.generate(**inputs, **gen_kwargs)

    # Trim the prompt tokens when possible; some remote-code models return
    # only the continuation, in which case ``input_ids`` may be absent or the
    # output unsliceable — narrowed from the original's blanket Exception.
    try:
        trimmed = outputs[:, inputs["input_ids"].shape[1]:]
    except (KeyError, TypeError, IndexError):
        trimmed = outputs

    response = self._processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return response
0 commit comments