Skip to content

Commit 573f30d

Browse files
Rename image_encoder to vision_encoder to match HF naming convention (#14473)
Summary: As titled. We want to align with the `optimum-executorch` naming convention (which comes from HF `transformers`): https://github.com/huggingface/optimum-executorch/blob/main/optimum/exporters/executorch/tasks/multimodal_text_to_text.py#L238 Differential Revision: D82677835 Co-authored-by: Mengwei Liu <[email protected]>
1 parent a017509 commit 573f30d

File tree

5 files changed

+47
-15
lines changed

5 files changed

+47
-15
lines changed

examples/models/llava/export_llava.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,12 @@ def export_all(llava_model: LlavaModel):
224224

225225
lowered_and_edge = to_edge_transform_and_lower(
226226
{
227-
"image_encoder": image_encoder_ep,
227+
"vision_encoder": image_encoder_ep,
228228
"token_embedding": token_embedding_ep,
229229
"text_decoder": text_model_ep,
230230
},
231231
partitioner={
232-
"image_encoder": [XnnpackPartitioner()],
232+
"vision_encoder": [XnnpackPartitioner()],
233233
"text_decoder": [
234234
# First partition the DQLinear nodes, then partition the rest of the nodes,
235235
# to avoid multiple DQLinear nodes in the same partition,
@@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel):
254254
],
255255
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
256256
sym_shape_eval_pass={
257-
"image_encoder": ConstraintBasedSymShapeEvalPass(),
257+
"vision_encoder": ConstraintBasedSymShapeEvalPass(),
258258
"text_decoder": ConstraintBasedSymShapeEvalPass(),
259259
"token_embedding": HintBasedSymShapeEvalPass(),
260260
},

examples/models/llava/test/test_llava.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_llava_export(self):
105105
start_pos += pte_embeds_before_img.shape[1]
106106

107107
# pte prefill image
108-
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
108+
pte_embeds_img = llava_module.run_method("vision_encoder", (resized,))[0]
109109
llava_module.run_method(
110110
"text_decoder",
111111
(

examples/models/llava/test/test_pte.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def main():
5656

5757
# pte prefill image
5858
logging.warning("Image encoder started")
59-
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
59+
pte_embeds_img = llava_module.run_method("vision_encoder", (resized,))[0]
6060
logging.warning("Image encoder finished")
6161
logging.warning("Image token prefill started")
6262
pte_prefill_img = llava_module.run_method(

extension/llm/runner/constants.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ inline constexpr auto kUseKVCache = "use_kv_cache";
2020
inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
2121

2222
// Multimodal method name conventions
23-
inline constexpr auto kImageEncoderMethod = "image_encoder";
23+
inline constexpr auto kVisionEncoderMethod = "vision_encoder";
2424
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
2525
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
2626
inline constexpr auto kTextModelMethod = "text_decoder";

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,46 @@ Result<uint64_t> MultimodalPrefiller::prefill(
4141
::executorch::runtime::EValue encoder_output;
4242
if (input.is_image()) {
4343
Image image = input.get_image();
44-
auto image_tensor = executorch::extension::from_blob(
45-
image.data.data(),
46-
{3, image.height, image.width},
47-
::executorch::aten::ScalarType::Byte);
44+
45+
auto method_meta = ET_UNWRAP(
46+
module_->method_meta(kVisionEncoderMethod),
47+
"Failed to get method_meta for %s",
48+
kVisionEncoderMethod);
49+
50+
ET_CHECK_MSG(
51+
method_meta.num_inputs() > 0,
52+
"Image encoder should have at least 1 input");
53+
auto input_meta = ET_UNWRAP(
54+
method_meta.input_tensor_meta(0),
55+
"Cannot get input tensor meta at index 0");
56+
auto expected_dtype = input_meta.scalar_type();
57+
58+
if (expected_dtype == ::executorch::aten::ScalarType::Float) {
59+
ET_CHECK_MSG(
60+
image.is_float(),
61+
"Model expects float image data, but image has uint8_t data.");
62+
} else if (expected_dtype == ::executorch::aten::ScalarType::Byte) {
63+
ET_CHECK_MSG(
64+
image.is_uint8(),
65+
"Model expects uint8_t image data, but image has float data.");
66+
} else {
67+
ET_LOG(
68+
Error,
69+
"Unsupported image encoder input dtype: %s",
70+
::executorch::runtime::toString(expected_dtype));
71+
return ::executorch::runtime::Error::NotSupported;
72+
}
73+
74+
// The model might expect a 4D tensor (NCHW), but toTensor() returns a 3D
75+
// tensor (CHW). Add a batch dimension of 1 if needed.
76+
auto expected_dims = input_meta.sizes();
77+
auto image_tensor = ET_UNWRAP(
78+
image.toTensor(/*with_batch*/ expected_dims.size() == 4),
79+
"Failed to convert image to tensor");
4880

4981
// Run image encoder
5082
auto image_encoder_outputs =
51-
ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor));
83+
ET_UNWRAP(module_->execute(kVisionEncoderMethod, image_tensor));
5284

5385
encoder_output = image_encoder_outputs[0];
5486
} else if (input.is_audio()) {
@@ -143,8 +175,8 @@ ::executorch::runtime::Error MultimodalPrefiller::load() {
143175
ET_UNWRAP(module_->method_names(), "Failed to get method names");
144176

145177
// Load image_encoder method if exists.
146-
if (methods.find(kImageEncoderMethod) != methods.end()) {
147-
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod));
178+
if (methods.find(kVisionEncoderMethod) != methods.end()) {
179+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kVisionEncoderMethod));
148180
}
149181

150182
if (methods.find(kAudioEncoderMethod) != methods.end()) {
@@ -171,8 +203,8 @@ bool MultimodalPrefiller::is_method_loaded() {
171203
ET_CHECK_MSG(false, "Failed to get method names");
172204
}
173205
std::unordered_set<std::string> methods = methods_res.get();
174-
if (methods.find(kImageEncoderMethod) != methods.end()) {
175-
return module_->is_method_loaded(kImageEncoderMethod);
206+
if (methods.find(kVisionEncoderMethod) != methods.end()) {
207+
return module_->is_method_loaded(kVisionEncoderMethod);
176208
}
177209
return true;
178210
}

0 commit comments

Comments
 (0)