@@ -41,14 +41,46 @@ Result<uint64_t> MultimodalPrefiller::prefill(
41
41
::executorch::runtime::EValue encoder_output;
42
42
if (input.is_image ()) {
43
43
Image image = input.get_image ();
44
- auto image_tensor = executorch::extension::from_blob (
45
- image.data .data (),
46
- {3 , image.height , image.width },
47
- ::executorch::aten::ScalarType::Byte);
44
+
45
+ auto method_meta = ET_UNWRAP (
46
+ module_->method_meta (kVisionEncoderMethod ),
47
+ " Failed to get method_meta for %s" ,
48
+ kVisionEncoderMethod );
49
+
50
+ ET_CHECK_MSG (
51
+ method_meta.num_inputs () > 0 ,
52
+ " Image encoder should have at least 1 input" );
53
+ auto input_meta = ET_UNWRAP (
54
+ method_meta.input_tensor_meta (0 ),
55
+ " Cannot get input tensor meta at index 0" );
56
+ auto expected_dtype = input_meta.scalar_type ();
57
+
58
+ if (expected_dtype == ::executorch::aten::ScalarType::Float) {
59
+ ET_CHECK_MSG (
60
+ image.is_float (),
61
+ " Model expects float image data, but image has uint8_t data." );
62
+ } else if (expected_dtype == ::executorch::aten::ScalarType::Byte) {
63
+ ET_CHECK_MSG (
64
+ image.is_uint8 (),
65
+ " Model expects uint8_t image data, but image has float data." );
66
+ } else {
67
+ ET_LOG (
68
+ Error,
69
+ " Unsupported image encoder input dtype: %s" ,
70
+ ::executorch::runtime::toString (expected_dtype));
71
+ return ::executorch::runtime::Error::NotSupported;
72
+ }
73
+
74
+ // The model might expect a 4D tensor (NCHW), but toTensor() returns a 3D
75
+ // tensor (CHW). Add a batch dimension of 1 if needed.
76
+ auto expected_dims = input_meta.sizes ();
77
+ auto image_tensor = ET_UNWRAP (
78
+ image.toTensor (/* with_batch*/ expected_dims.size () == 4 ),
79
+ " Failed to convert image to tensor" );
48
80
49
81
// Run image encoder
50
82
auto image_encoder_outputs =
51
- ET_UNWRAP (module_->execute (kImageEncoderMethod , image_tensor));
83
+ ET_UNWRAP (module_->execute (kVisionEncoderMethod , image_tensor));
52
84
53
85
encoder_output = image_encoder_outputs[0 ];
54
86
} else if (input.is_audio ()) {
@@ -143,8 +175,8 @@ ::executorch::runtime::Error MultimodalPrefiller::load() {
143
175
ET_UNWRAP (module_->method_names (), " Failed to get method names" );
144
176
145
177
// Load the vision encoder method (kVisionEncoderMethod) if the module provides it.
146
- if (methods.find (kImageEncoderMethod ) != methods.end ()) {
147
- ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (kImageEncoderMethod ));
178
+ if (methods.find (kVisionEncoderMethod ) != methods.end ()) {
179
+ ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (kVisionEncoderMethod ));
148
180
}
149
181
150
182
if (methods.find (kAudioEncoderMethod ) != methods.end ()) {
@@ -171,8 +203,8 @@ bool MultimodalPrefiller::is_method_loaded() {
171
203
ET_CHECK_MSG (false , " Failed to get method names" );
172
204
}
173
205
std::unordered_set<std::string> methods = methods_res.get ();
174
- if (methods.find (kImageEncoderMethod ) != methods.end ()) {
175
- return module_->is_method_loaded (kImageEncoderMethod );
206
+ if (methods.find (kVisionEncoderMethod ) != methods.end ()) {
207
+ return module_->is_method_loaded (kVisionEncoderMethod );
176
208
}
177
209
return true ;
178
210
}
0 commit comments