1616from transformers .models .auto .modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
1717
1818from QEfficient .base .common import QEFFCommonLoader
19- from QEfficient .utils import check_and_assign_cache_dir , load_hf_tokenizer
19+ from QEfficient .utils import check_and_assign_cache_dir , load_hf_tokenizer , constants
2020from QEfficient .utils .logging_utils import logger
2121
2222
@@ -41,7 +41,6 @@ def main(
4141 allow_mxint8_mdp_io : bool = False ,
4242 enable_qnn : Optional [bool ] = False ,
4343 qnn_config : Optional [str ] = None ,
44- img_size : Optional [int ] = None ,
4544 ** kwargs ,
4645) -> None :
4746 """
@@ -89,9 +88,6 @@ def main(
8988 if args .mxint8 :
9089 logger .warning ("mxint8 is going to be deprecated in a future release, use -mxint8_kv_cache instead." )
9190
92- image_path = kwargs .pop ("image_path" , None )
93- image_url = kwargs .pop ("image_url" , None )
94-
9591 qeff_model = QEFFCommonLoader .from_pretrained (
9692 pretrained_model_name_or_path = model_name ,
9793 cache_dir = cache_dir ,
@@ -100,6 +96,16 @@ def main(
10096 local_model_dir = local_model_dir ,
10197 )
10298
99+ image_path = kwargs .pop ("image_path" , None )
100+ image_url = kwargs .pop ("image_url" , None )
101+
102+ config = qeff_model .model .config
103+ architecture = config .architectures [0 ] if config .architectures else None
104+ if architecture not in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES .values ():
105+ img_size = kwargs .pop ("img_size" , None )
106+ if img_size or image_path or image_url :
107+ logger .warning (f"Skipping image arguments as they are not valid for { architecture } " )
108+
103109 #########
104110 # Compile
105111 #########
@@ -117,38 +123,21 @@ def main(
117123 allow_mxint8_mdp_io = allow_mxint8_mdp_io ,
118124 enable_qnn = enable_qnn ,
119125 qnn_config = qnn_config ,
120- img_size = img_size ,
121126 ** kwargs ,
122127 )
123128
124129 #########
125130 # Execute
126131 #########
127- config = qeff_model .model .config
128- architecture = config .architectures [0 ] if config .architectures else None
129-
130132 if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES .values ():
131133 processor = AutoProcessor .from_pretrained (model_name , use_fast = False )
132134
133- raw_image = None
134- if image_url is not None :
135- raw_image = Image .open (requests .get (image_url , stream = True ).raw )
136- elif image_path is not None :
137- raw_image = Image .open (image_path )
138- else :
139- raise FileNotFoundError (
140- 'Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"'
141- )
135+ if not (image_url or image_path ):
136+ raise ValueError ('Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"' )
137+ raw_image = Image .open (requests .get (image_url , stream = True ).raw ) if image_url else Image .open (image_path )
142138
143- conversation = [
144- {
145- "role" : "user" ,
146- "content" : [
147- {"type" : "image" },
148- {"type" : "text" , "text" : prompt [0 ]}, # Currently accepting only 1 prompt
149- ],
150- },
151- ]
139+ conversation = constants .Constants .conversation
140+ conversation [0 ]["content" ][1 ].update ({"text" : prompt [0 ]}) # Currently accepting only 1 prompt
152141
153142 # Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids.
154143 input_text = processor .apply_chat_template (conversation , add_generation_prompt = True , tokenize = False )
@@ -277,19 +266,21 @@ def main(
277266 "--enable_qnn" ,
278267 "--enable-qnn" ,
279268 action = "store_true" ,
269+ nargs = "?" ,
270+ const = True ,
271+ type = str ,
280272 default = False ,
281273 help = "Enables QNN. Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\
282274 If not provided, the default configuration will be used.\
283275 Sample Config: QEfficient/compile/qnn_config.json" ,
284276 )
285- parser .add_argument (
286- "--qnn_config" ,
287- nargs = "?" ,
288- type = str ,
289- )
290- parser .add_argument ("--img-size" , "--img_size" , default = None , type = int , required = False , help = "Size of Image" )
291277
292278 args , compiler_options = parser .parse_known_args ()
279+
280+ if isinstance (args .enable_qnn , str ):
281+ args .qnn_config = args .enable_qnn
282+ args .enable_qnn = True
283+
293284 compiler_options_dict = {}
294285 for i in range (0 , len (compiler_options )):
295286 if compiler_options [i ].startswith ("--" ):
0 commit comments