Skip to content

Commit d38778c

Browse files
authored
Fix formatting of CLI generate command input (#402)
* Fix formatting of CLI generate input
* Update unit test to verify the generate input prompt
* Test doc string
1 parent 42511a5 commit d38778c

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Here are some examples using no-code bash commands:
66
* [Text Classification Intel® Transfer Learning Tool CLI Example](cli/text_classification.md)
77
* [Image Classification Intel® Transfer Learning Tool CLI Example](cli/image_classification.md)
88
* [Vision Anomaly Detection Intel® Transfer Learning Tool CLI Example](cli/image_anomaly_detection.md)
9+
* [Text Generation Intel® Transfer Learning Tool CLI Example](cli/text_generation.md)
910

1011
Here are Jupyter notebook examples using low-code Python\* API calls:
1112

tests/tools/cli/test_generate_cli.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,31 @@ def teardown_class(cls):
5858
print("Deleting test directory:", dir)
5959
shutil.rmtree(dir)
6060

61+
@patch("tlt.models.text_generation.pytorch_hf_text_generation_model.PyTorchHFTextGenerationModel.generate")
6162
@pytest.mark.parametrize('model_name,prompt',
6263
[['distilgpt2', 'The size of an apple is'],
63-
['distilgpt2', 'A large fruit is']])
64-
def test_base_generation(self, model_name, prompt):
64+
['distilgpt2', 'A large fruit is'],
65+
['distilgpt2',
66+
'The input describes a task.\\n\\nInstruction:\nWrite a song.\\n\\n### Response:\n']])
67+
def test_base_generation(self, mock_generate, model_name, prompt):
6568
"""
66-
Tests the full workflow for PYT text generation using a custom dataset
69+
Tests the CLI generate command for PYT text generation using a HF pretrained model
6770
"""
6871
runner = CliRunner()
6972

73+
# Define a dummy response
74+
mock_generate.return_value = [prompt + ' so good.']
75+
7076
# Generate a text completion
7177
result = runner.invoke(generate,
7278
["--model-name", model_name, "--prompt", prompt])
79+
80+
# Verify that the TLT generate method was called with a properly formatted prompt string
81+
assert len(mock_generate.call_args_list) == 1
82+
prompt_arg = mock_generate.call_args_list[0][0]
83+
assert "\\n" not in prompt_arg
84+
85+
# Verify that we didn't get any errors
7386
assert result is not None
7487
assert result.exit_code == 0
7588

tlt/tools/cli/commands/generate.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,7 @@ def generate(model_dir, model_name, prompt, temperature, top_p, top_k, repetitio
9595
if os.path.exists(model_dir):
9696
model.load_from_directory(model_dir)
9797

98-
print()
99-
print("Prompt:", prompt)
100-
print()
101-
98+
prompt = prompt.replace("\\n", "\n")
10299
output = model.generate(prompt, temperature=temperature, repetition_penalty=repetition_penalty, top_p=top_p,
103100
top_k=top_k, num_beams=num_beams, max_new_tokens=max_new_tokens)
104-
print(output)
101+
print(*output, sep='\n')

0 commit comments

Comments (0)