9
9
from .config import PromptWrightConfig , construct_model_string
10
10
from .engine import DataEngine
11
11
from .hf_hub import HFUploader
12
- from .topic_tree import TopicTree
12
+ from .topic_tree import TopicTree , TopicTreeArguments
13
+ from .utils import read_topic_tree_from_jsonl
13
14
14
15
15
16
def handle_error (ctx : click .Context , error : Exception ) -> None : # noqa: ARG001
@@ -27,6 +28,7 @@ def cli():
27
28
@cli .command ()
28
29
@click .argument ("config_file" , type = click .Path (exists = True ))
29
30
@click .option ("--topic-tree-save-as" , help = "Override the save path for the topic tree" )
31
+ @click .option ('--topic-tree-jsonl' , type = click .Path (exists = True ), help = 'Path to the JSONL file containing the topic tree.' )
30
32
@click .option ("--dataset-save-as" , help = "Override the save path for the dataset" )
31
33
@click .option ("--provider" , help = "Override the LLM provider (e.g., ollama)" )
32
34
@click .option ("--model" , help = "Override the model name (e.g., mistral:latest)" )
@@ -55,6 +57,7 @@ def cli():
55
57
def start ( # noqa: PLR0912
56
58
config_file : str ,
57
59
topic_tree_save_as : str | None = None ,
60
+ topic_tree_jsonl : str | None = None ,
58
61
dataset_save_as : str | None = None ,
59
62
provider : str | None = None ,
60
63
model : str | None = None ,
@@ -85,6 +88,9 @@ def start( # noqa: PLR0912
85
88
handle_error (
86
89
click .get_current_context (), f"Error loading config file: { str (e )} "
87
90
)
91
+ # Get dataset parameters
92
+ dataset_config = config .get_dataset_config ()
93
+ dataset_params = dataset_config .get ("creation" , {})
88
94
89
95
# Prepare topic tree overrides
90
96
tree_overrides = {}
@@ -99,26 +105,53 @@ def start( # noqa: PLR0912
99
105
if tree_depth :
100
106
tree_overrides ["tree_depth" ] = tree_depth
101
107
108
+ # Construct model name
109
+ model_name = construct_model_string (
110
+ provider or dataset_params .get ("provider" , "default_provider" ),
111
+ model or dataset_params .get ("model" , "default_model" )
112
+ )
113
+
102
114
# Create and build topic tree
103
115
try :
104
- tree = TopicTree (args = config .get_topic_tree_args (** tree_overrides ))
105
- tree .build_tree ()
116
+ print ("Creating TopicTree object..." )
117
+ if topic_tree_jsonl :
118
+ print (f"Reading topic tree from JSONL file: { topic_tree_jsonl } " )
119
+ dict_list = read_topic_tree_from_jsonl (topic_tree_jsonl )
120
+ default_args = TopicTreeArguments (
121
+ root_prompt = "default" ,
122
+ model_name = model_name
123
+ )
124
+ tree = TopicTree (args = default_args )
125
+ tree .from_dict_list (dict_list )
126
+ else :
127
+ if hasattr (config , 'topic_tree' ):
128
+ tree_args = config .get_topic_tree_args (** tree_overrides )
129
+ else :
130
+ tree_args = TopicTreeArguments (
131
+ root_prompt = "default" ,
132
+ model_name = model_name
133
+ )
134
+ tree = TopicTree (args = tree_args )
135
+ print ("Building topic tree..." )
136
+ tree .build_tree ()
106
137
except Exception as e :
107
138
handle_error (
108
139
click .get_current_context (), f"Error building topic tree: { str (e )} "
109
140
)
110
141
111
- # Save topic tree
112
- try :
113
- tree_save_path = topic_tree_save_as or config .topic_tree .get (
114
- "save_as" , "topic_tree.jsonl"
115
- )
116
- tree .save (tree_save_path )
117
- click .echo (f"Topic tree saved to: { tree_save_path } " )
118
- except Exception as e :
119
- handle_error (
120
- click .get_current_context (), f"Error saving topic tree: { str (e )} "
121
- )
142
+ # Save topic tree if JSONL file is not provided
143
+ if not topic_tree_jsonl :
144
+ try :
145
+ tree_save_path = topic_tree_save_as or config .topic_tree .get (
146
+ "save_as" , "topic_tree.jsonl"
147
+ )
148
+ print (f"Saving topic tree to: { tree_save_path } " )
149
+ tree .save (tree_save_path )
150
+ click .echo (f"Topic tree saved to: { tree_save_path } " )
151
+ except Exception as e :
152
+ handle_error (
153
+ click .get_current_context (), f"Error saving topic tree: { str (e )} "
154
+ )
122
155
123
156
# Prepare engine overrides
124
157
engine_overrides = {}
@@ -137,17 +170,11 @@ def start( # noqa: PLR0912
137
170
click .get_current_context (), f"Error creating data engine: { str (e )} "
138
171
)
139
172
140
- # Get dataset parameters
141
- dataset_config = config .get_dataset_config ()
142
- dataset_params = dataset_config .get ("creation" , {})
143
-
144
173
# Construct model name for dataset creation
145
- if provider and model :
146
- model_name = construct_model_string (provider , model )
147
- else :
148
- dataset_provider = dataset_params .get ("provider" , "ollama" )
149
- dataset_model = dataset_params .get ("model" , "mistral:latest" )
150
- model_name = construct_model_string (dataset_provider , dataset_model )
174
+ model_name = construct_model_string (
175
+ provider or dataset_params .get ("provider" , "ollama" ),
176
+ model or dataset_params .get ("model" , "mistral:latest" )
177
+ )
151
178
152
179
# Create dataset with overrides
153
180
try :
0 commit comments