Skip to content

Commit 86d5468

Browse files
committed
Support for reading topic tree from JSONL file, issue #9
Signed-off-by: poppysec <[email protected]>
1 parent 692685a commit 86d5468

File tree

4 files changed

+139
-24
lines changed

4 files changed

+139
-24
lines changed

promptwright/cli.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from .config import PromptWrightConfig, construct_model_string
1010
from .engine import DataEngine
1111
from .hf_hub import HFUploader
12-
from .topic_tree import TopicTree
12+
from .topic_tree import TopicTree, TopicTreeArguments
13+
from .utils import read_topic_tree_from_jsonl
1314

1415

1516
def handle_error(ctx: click.Context, error: Exception) -> None: # noqa: ARG001
@@ -27,6 +28,7 @@ def cli():
2728
@cli.command()
2829
@click.argument("config_file", type=click.Path(exists=True))
2930
@click.option("--topic-tree-save-as", help="Override the save path for the topic tree")
31+
@click.option('--topic-tree-jsonl', type=click.Path(exists=True), help='Path to the JSONL file containing the topic tree.')
3032
@click.option("--dataset-save-as", help="Override the save path for the dataset")
3133
@click.option("--provider", help="Override the LLM provider (e.g., ollama)")
3234
@click.option("--model", help="Override the model name (e.g., mistral:latest)")
@@ -55,6 +57,7 @@ def cli():
5557
def start( # noqa: PLR0912
5658
config_file: str,
5759
topic_tree_save_as: str | None = None,
60+
topic_tree_jsonl: str | None = None,
5861
dataset_save_as: str | None = None,
5962
provider: str | None = None,
6063
model: str | None = None,
@@ -85,6 +88,9 @@ def start( # noqa: PLR0912
8588
handle_error(
8689
click.get_current_context(), f"Error loading config file: {str(e)}"
8790
)
91+
# Get dataset parameters
92+
dataset_config = config.get_dataset_config()
93+
dataset_params = dataset_config.get("creation", {})
8894

8995
# Prepare topic tree overrides
9096
tree_overrides = {}
@@ -99,26 +105,53 @@ def start( # noqa: PLR0912
99105
if tree_depth:
100106
tree_overrides["tree_depth"] = tree_depth
101107

108+
# Construct model name
109+
model_name = construct_model_string(
110+
provider or dataset_params.get("provider", "default_provider"),
111+
model or dataset_params.get("model", "default_model")
112+
)
113+
102114
# Create and build topic tree
103115
try:
104-
tree = TopicTree(args=config.get_topic_tree_args(**tree_overrides))
105-
tree.build_tree()
116+
print("Creating TopicTree object...")
117+
if topic_tree_jsonl:
118+
print(f"Reading topic tree from JSONL file: {topic_tree_jsonl}")
119+
dict_list = read_topic_tree_from_jsonl(topic_tree_jsonl)
120+
default_args = TopicTreeArguments(
121+
root_prompt="default",
122+
model_name=model_name
123+
)
124+
tree = TopicTree(args=default_args)
125+
tree.from_dict_list(dict_list)
126+
else:
127+
if hasattr(config, 'topic_tree'):
128+
tree_args = config.get_topic_tree_args(**tree_overrides)
129+
else:
130+
tree_args = TopicTreeArguments(
131+
root_prompt="default",
132+
model_name=model_name
133+
)
134+
tree = TopicTree(args=tree_args)
135+
print("Building topic tree...")
136+
tree.build_tree()
106137
except Exception as e:
107138
handle_error(
108139
click.get_current_context(), f"Error building topic tree: {str(e)}"
109140
)
110141

111-
# Save topic tree
112-
try:
113-
tree_save_path = topic_tree_save_as or config.topic_tree.get(
114-
"save_as", "topic_tree.jsonl"
115-
)
116-
tree.save(tree_save_path)
117-
click.echo(f"Topic tree saved to: {tree_save_path}")
118-
except Exception as e:
119-
handle_error(
120-
click.get_current_context(), f"Error saving topic tree: {str(e)}"
121-
)
142+
# Save topic tree if JSONL file is not provided
143+
if not topic_tree_jsonl:
144+
try:
145+
tree_save_path = topic_tree_save_as or config.topic_tree.get(
146+
"save_as", "topic_tree.jsonl"
147+
)
148+
print(f"Saving topic tree to: {tree_save_path}")
149+
tree.save(tree_save_path)
150+
click.echo(f"Topic tree saved to: {tree_save_path}")
151+
except Exception as e:
152+
handle_error(
153+
click.get_current_context(), f"Error saving topic tree: {str(e)}"
154+
)
122155

123156
# Prepare engine overrides
124157
engine_overrides = {}
@@ -137,17 +170,11 @@ def start( # noqa: PLR0912
137170
click.get_current_context(), f"Error creating data engine: {str(e)}"
138171
)
139172

140-
# Get dataset parameters
141-
dataset_config = config.get_dataset_config()
142-
dataset_params = dataset_config.get("creation", {})
143-
144173
# Construct model name for dataset creation
145-
if provider and model:
146-
model_name = construct_model_string(provider, model)
147-
else:
148-
dataset_provider = dataset_params.get("provider", "ollama")
149-
dataset_model = dataset_params.get("model", "mistral:latest")
150-
model_name = construct_model_string(dataset_provider, dataset_model)
174+
model_name = construct_model_string(
175+
provider or dataset_params.get("provider", "ollama"),
176+
model or dataset_params.get("model", "mistral:latest")
177+
)
151178

152179
# Create dataset with overrides
153180
try:

promptwright/topic_tree.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,21 @@ def print_tree(self) -> None:
299299
print("Topic Tree Structure:")
300300
for path in self.tree_paths:
301301
print(" -> ".join(path))
302+
303+
def from_dict_list(self, dict_list: list[dict[str, Any]]) -> None:
    """
    Replace the tree's contents with entries taken from a list of dictionaries.

    Any paths and failed generations held before the call are discarded.
    Each dictionary may contribute a ``path`` value, a ``failed_generation``
    value, both, or neither; all other keys are ignored.

    Args:
        dict_list (list[dict]): Entries describing the topic tree, in order.
    """
    self.tree_paths = []
    self.failed_generations = []

    # Route each recognised key to the list that collects its values.
    collectors = (
        ('path', self.tree_paths),
        ('failed_generation', self.failed_generations),
    )
    for entry in dict_list:
        for key, bucket in collectors:
            if key in entry:
                bucket.append(entry[key])

    path_count = len(self.tree_paths)
    failure_count = len(self.failed_generations)
    print(f"Loaded {path_count} paths and {failure_count} failed generations from JSONL file")

promptwright/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,19 @@ def safe_literal_eval(list_string: str):
9696
except (SyntaxError, ValueError):
9797
print("Failed to parse the list due to syntax issues.")
9898
return None
99+
100+
def read_topic_tree_from_jsonl(file_path: str) -> list[dict]:
    """
    Read the topic tree from a JSONL file.

    Each non-blank line is parsed as one JSON document. Blank and
    whitespace-only lines (for example a trailing empty line left by an
    editor) are skipped instead of aborting the whole load.

    Args:
        file_path (str): The path to the JSONL file.

    Returns:
        list[dict]: The parsed entries, one per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, encoding="utf-8") as file:
        stripped_lines = (line.strip() for line in file)
        # json.loads("") raises, so empty lines must be filtered out first.
        return [json.loads(line) for line in stripped_lines if line]

tests/test_cli.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,60 @@ def test_start_command_with_overrides(
282282
assert kwargs["model_name"] == "override/model"
283283
assert kwargs["sys_msg"] is False
284284

285+
@patch("promptwright.cli.read_topic_tree_from_jsonl")
@patch("promptwright.cli.TopicTree")
@patch("promptwright.cli.DataEngine")
def test_start_command_with_jsonl(
    mock_data_engine, mock_topic_tree, mock_read_topic_tree_from_jsonl, cli_runner,
    sample_config_file
):
    """Check that `start --topic-tree-jsonl` loads the tree from the file and does not save it."""
    # Stand-ins for the tree and engine that the CLI would otherwise construct.
    tree = Mock()
    mock_topic_tree.return_value = tree
    mock_read_topic_tree_from_jsonl.return_value = [{"path": ["root", "child"]}]

    engine = Mock()
    mock_data_engine.return_value = engine
    dataset = Mock()
    engine.create_data.return_value = dataset

    # The option is declared with click.Path(exists=True), so a real file is
    # needed even though the reader itself is patched out.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as handle:
        handle.write('{"path": ["root", "child"]}\n')
        jsonl_path = handle.name

    try:
        cli_args = ["start", sample_config_file, "--topic-tree-jsonl", jsonl_path]
        result = cli_runner.invoke(cli, cli_args)

        # Surface the CLI output to make a failing run diagnosable.
        if result.exit_code != 0:
            print(result.output)

        # The command must succeed.
        assert result.exit_code == 0

        # The reader is invoked exactly once, with the supplied path.
        mock_read_topic_tree_from_jsonl.assert_called_once_with(jsonl_path)

        # The parsed entries are handed to the tree unchanged.
        tree.from_dict_list.assert_called_once_with([{"path": ["root", "child"]}])

        # A tree loaded from JSONL must not be written back out.
        tree.save.assert_not_called()
    finally:
        # Remove the temporary file (created with delete=False).
        if os.path.exists(jsonl_path):
            os.unlink(jsonl_path)
285339

286340
def test_start_command_missing_config(cli_runner):
287341
"""Test start command with missing config file."""

0 commit comments

Comments
 (0)