@@ -112,6 +112,7 @@ class ExecuTorchLlmCallbackJni
112
112
class ExecuTorchLlmJni : public facebook ::jni::HybridClass<ExecuTorchLlmJni> {
113
113
private:
114
114
friend HybridBase;
115
+ float temperature_;
115
116
int model_type_category_;
116
117
std::unique_ptr<llm::IRunner> runner_;
117
118
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -167,20 +168,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
167
168
runner_ = std::make_unique<example::Runner>(
168
169
model_path->toStdString ().c_str (),
169
170
tokenizer_path->toStdString ().c_str (),
170
- temperature,
171
171
data_path->toStdString ().c_str ());
172
172
} else {
173
173
runner_ = std::make_unique<example::Runner>(
174
174
model_path->toStdString ().c_str (),
175
- tokenizer_path->toStdString ().c_str (),
176
- temperature);
175
+ tokenizer_path->toStdString ().c_str ());
177
176
}
178
177
#if defined(EXECUTORCH_BUILD_MEDIATEK)
179
178
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
180
179
runner_ = std::make_unique<MTKLlamaRunner>(
181
180
model_path->toStdString ().c_str (),
182
- tokenizer_path->toStdString ().c_str (),
183
- temperature);
181
+ tokenizer_path->toStdString ().c_str ());
184
182
// Interpret the model type as LLM
185
183
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
186
184
#endif
@@ -220,6 +218,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
220
218
executorch::extension::llm::GenerationConfig config{
221
219
.echo = static_cast <bool >(echo),
222
220
.seq_len = seq_len,
221
+ .temperature = temperature_,
223
222
};
224
223
runner_->generate (
225
224
prompt->toStdString (),
0 commit comments