
Commit f9c7460

Use deterministic sampling for the prompt expander, since setting a seed does not work

1 parent: b9477e7

1 file changed (+2, -2 lines)

examples/wan2_1/wan/utils/prompt_extend.py

Lines changed: 2 additions & 2 deletions
@@ -215,7 +215,7 @@ def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
         for k, v in model_inputs.items():
             model_inputs[k] = ms.tensor(v)
 
-        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512).asnumpy()
+        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512, do_sample=False).asnumpy()
         generated_ids = [
             output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
         ]
@@ -259,7 +259,7 @@ def extend_with_img(self, prompt, system_prompt, image: Union[Image.Image, str]
             inputs[k] = ms.tensor(v)
 
         # Inference: Generation of the output
-        generated_ids = self.model.generate(**inputs, max_new_tokens=512).asnumpy()
+        generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=False).asnumpy()
         generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
         expanded_prompt = self.processor.batch_decode(
             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
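
Background on the fix: passing do_sample=False makes generate() use greedy decoding, so the prompt expander returns the same text for the same input on every run, without relying on the seed argument that the commit message says has no effect here. Below is a minimal sketch of the same idea using the Hugging Face transformers API; the checkpoint name and prompt are placeholders for illustration, not taken from this repository, and the repo's own code uses MindSpore tensors and .asnumpy() as shown in the diff.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint and prompt only; prompt_extend.py loads its own model/processor.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Expand this prompt: a cat walking on a beach", return_tensors="pt")

# do_sample=False switches generate() to greedy decoding: no random draws are made,
# so the output is reproducible across runs without touching any RNG seed.
generated_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False)

# Trim the prompt tokens before decoding, mirroring the slicing done in the diff above.
trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)]
print(tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0])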
