@@ -128,15 +128,14 @@ def compile(
128128 else dspy .settings .max_errors
129129 )
130130
131- # Update max demos if specified
132- initial_max_bootstrapped_demos = self .max_bootstrapped_demos
133- if max_bootstrapped_demos is not None :
134- self .max_bootstrapped_demos = max_bootstrapped_demos
135- initial_max_labeled_demos = self .max_labeled_demos
136- if max_labeled_demos is not None :
137- self .max_labeled_demos = max_labeled_demos
131+ effective_max_bootstrapped_demos = (
132+ max_bootstrapped_demos if max_bootstrapped_demos is not None else self .max_bootstrapped_demos
133+ )
134+ effective_max_labeled_demos = (
135+ max_labeled_demos if max_labeled_demos is not None else self .max_labeled_demos
136+ )
138137
139- zeroshot_opt = (self . max_bootstrapped_demos == 0 ) and (self . max_labeled_demos == 0 )
138+ zeroshot_opt = (effective_max_bootstrapped_demos == 0 ) and (effective_max_labeled_demos == 0 )
140139
141140 # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value
142141 if self .auto is None and (self .num_candidates is not None and num_trials is None ):
@@ -162,13 +161,42 @@ def compile(
162161 # Set training & validation sets
163162 trainset , valset = self ._set_and_validate_datasets (trainset , valset )
164163
164+ num_instruct_candidates = (
165+ self .num_instruct_candidates
166+ if self .num_instruct_candidates is not None
167+ else self .num_candidates
168+ )
169+ num_fewshot_candidates = (
170+ self .num_fewshot_candidates
171+ if self .num_fewshot_candidates is not None
172+ else self .num_candidates
173+ )
174+
165175 # Set hyperparameters based on run mode (if set)
166- num_trials , valset , minibatch = self ._set_hyperparams_from_run_mode (
167- student , num_trials , minibatch , zeroshot_opt , valset
176+ (
177+ num_trials ,
178+ valset ,
179+ minibatch ,
180+ num_instruct_candidates ,
181+ num_fewshot_candidates ,
182+ ) = self ._set_hyperparams_from_run_mode (
183+ student ,
184+ num_trials ,
185+ minibatch ,
186+ zeroshot_opt ,
187+ valset ,
188+ num_instruct_candidates ,
189+ num_fewshot_candidates ,
168190 )
169191
170192 if self .auto :
171- self ._print_auto_run_settings (num_trials , minibatch , valset )
193+ self ._print_auto_run_settings (
194+ num_trials ,
195+ minibatch ,
196+ valset ,
197+ num_fewshot_candidates ,
198+ num_instruct_candidates ,
199+ )
172200
173201 if minibatch and minibatch_size > len (valset ):
174202 raise ValueError (f"Minibatch size cannot exceed the size of the valset. Valset size: { len (valset )} ." )
@@ -186,8 +214,19 @@ def compile(
186214 )
187215
188216 with dspy .context (lm = self .task_model ):
217+
189218 # Step 1: Bootstrap few-shot examples
190- demo_candidates = self ._bootstrap_fewshot_examples (program , trainset , seed , teacher )
219+ demo_candidates = self ._bootstrap_fewshot_examples (
220+ program ,
221+ trainset ,
222+ seed ,
223+ teacher ,
224+ num_fewshot_candidates = num_fewshot_candidates ,
225+ max_bootstrapped_demos = effective_max_bootstrapped_demos ,
226+ max_labeled_demos = effective_max_labeled_demos ,
227+ max_errors = effective_max_errors ,
228+ metric_threshold = self .metric_threshold ,
229+ )
191230
192231 # Step 2: Propose instruction candidates
193232 instruction_candidates = self ._propose_instructions (
@@ -199,6 +238,7 @@ def compile(
199238 data_aware_proposer ,
200239 tip_aware_proposer ,
201240 fewshot_aware_proposer ,
241+ num_instruct_candidates = num_instruct_candidates ,
202242 )
203243
204244 # If zero-shot, discard demos
@@ -220,10 +260,6 @@ def compile(
220260 seed ,
221261 )
222262
223- # Reset max demos
224- self .max_bootstrapped_demos = initial_max_bootstrapped_demos
225- self .max_labeled_demos = initial_max_labeled_demos
226-
227263 return best_program
228264
229265 def _set_random_seeds (self , seed ):
@@ -242,13 +278,17 @@ def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candida
242278 def _set_hyperparams_from_run_mode (
243279 self ,
244280 program : Any ,
245- num_trials : int ,
281+ num_trials : int | None ,
246282 minibatch : bool ,
247283 zeroshot_opt : bool ,
248284 valset : list ,
249- ) -> tuple [int , list , bool ]:
285+ num_instruct_candidates : int | None ,
286+ num_fewshot_candidates : int | None ,
287+ ) -> tuple [int , list , bool , int , int ]:
250288 if self .auto is None :
251- return num_trials , valset , minibatch
289+ if num_instruct_candidates is None or num_fewshot_candidates is None :
290+ raise ValueError ("num_candidates must be provided when auto is None." )
291+ return num_trials , valset , minibatch , num_instruct_candidates , num_fewshot_candidates
252292
253293 auto_settings = AUTO_RUN_SETTINGS [self .auto ]
254294
@@ -258,12 +298,12 @@ def _set_hyperparams_from_run_mode(
258298 # Set num instruct candidates to 1/2 of N if optimizing with few-shot examples, otherwise set to N.
259299 # This is because we've found that it's generally better to spend optimization budget on few-shot examples
260300 # when they are allowed.
261- self . num_instruct_candidates = auto_settings ["n" ] if zeroshot_opt else int (auto_settings ["n" ] * 0.5 )
262- self . num_fewshot_candidates = auto_settings ["n" ]
301+ num_instruct_candidates = auto_settings ["n" ] if zeroshot_opt else int (auto_settings ["n" ] * 0.5 )
302+ num_fewshot_candidates = auto_settings ["n" ]
263303
264304 num_trials = self ._set_num_trials_from_num_candidates (program , zeroshot_opt , auto_settings ["n" ])
265305
266- return num_trials , valset , minibatch
306+ return num_trials , valset , minibatch , num_instruct_candidates , num_fewshot_candidates
267307
268308 def _set_and_validate_datasets (self , trainset : list , valset : list | None ):
269309 if not trainset :
@@ -282,13 +322,20 @@ def _set_and_validate_datasets(self, trainset: list, valset: list | None):
282322
283323 return trainset , valset
284324
def _print_auto_run_settings(
    self,
    num_trials: int,
    minibatch: bool,
    valset: list,
    num_fewshot_candidates: int,
    num_instruct_candidates: int,
):
    """Log the settings in effect for the current auto run mode.

    Emits a single ``logger.info`` record listing the resolved values, one
    per line, under a header naming the upper-cased ``self.auto`` mode.

    Args:
        num_trials: Number of trials reported in the log.
        minibatch: Minibatch flag reported in the log.
        valset: Validation set; only its length is reported.
        num_fewshot_candidates: Few-shot candidate count reported in the log.
        num_instruct_candidates: Instruction candidate count reported in the log.

    Note:
        ``self.auto.upper()`` is called unconditionally, so ``self.auto``
        must be a string when this method runs.
    """
    header = f"RUNNING WITH THE FOLLOWING {self.auto.upper()} AUTO RUN SETTINGS:"
    report_lines = [
        header,
        f"num_trials: {num_trials}",
        f"minibatch: {minibatch}",
        f"num_fewshot_candidates: {num_fewshot_candidates}",
        f"num_instruct_candidates: {num_instruct_candidates}",
        f"valset size: {len(valset)}",
    ]
    # The message starts and ends with a newline so it stands apart in the log.
    logger.info("\n" + "\n".join(report_lines) + "\n")
294341
@@ -301,18 +348,19 @@ def _estimate_lm_calls(
301348 minibatch_full_eval_steps : int ,
302349 valset : list ,
303350 program_aware_proposer : bool ,
351+ num_instruct_candidates : int ,
304352 ) -> tuple [str , str ]:
305353 num_predictors = len (program .predictors ())
306354
307355 # Estimate prompt model calls
308356 estimated_prompt_model_calls = (
309357 10 # Data summarizer calls
310- + self . num_instruct_candidates * num_predictors # Candidate generation
358+ + num_instruct_candidates * num_predictors # Candidate generation
311359 + (num_predictors + 1 if program_aware_proposer else 0 ) # Program-aware proposer
312360 )
313361 prompt_model_line = (
314362 f"{ YELLOW } - Prompt Generation: { BLUE } { BOLD } 10{ ENDC } { YELLOW } data summarizer calls + "
315- f"{ BLUE } { BOLD } { self . num_instruct_candidates } { ENDC } { YELLOW } * "
363+ f"{ BLUE } { BOLD } { num_instruct_candidates } { ENDC } { YELLOW } * "
316364 f"{ BLUE } { BOLD } { num_predictors } { ENDC } { YELLOW } lm calls in program "
317365 f"+ ({ BLUE } { BOLD } { num_predictors + 1 } { ENDC } { YELLOW } ) lm calls in program-aware proposer "
318366 f"= { BLUE } { BOLD } { estimated_prompt_model_calls } { ENDC } { YELLOW } prompt model calls{ ENDC } "
@@ -339,38 +387,48 @@ def _estimate_lm_calls(
339387
340388 return prompt_model_line , task_model_line
341389
342- def _bootstrap_fewshot_examples (self , program : Any , trainset : list , seed : int , teacher : Any ) -> list | None :
390+ def _bootstrap_fewshot_examples (
391+ self ,
392+ program : Any ,
393+ trainset : list ,
394+ seed : int ,
395+ teacher : Any ,
396+ * ,
397+ num_fewshot_candidates : int ,
398+ max_bootstrapped_demos : int ,
399+ max_labeled_demos : int ,
400+ max_errors : int | None ,
401+ metric_threshold : float | None ,
402+ ) -> list | None :
343403 logger .info ("\n ==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==" )
344- if self . max_bootstrapped_demos > 0 :
404+ if max_bootstrapped_demos > 0 :
345405 logger .info (
346406 "These will be used as few-shot example candidates for our program and for creating instructions.\n "
347407 )
348408 else :
349409 logger .info ("These will be used for informing instruction proposal.\n " )
350410
351- logger .info (f"Bootstrapping N={ self . num_fewshot_candidates } sets of demonstrations..." )
411+ logger .info (f"Bootstrapping N={ num_fewshot_candidates } sets of demonstrations..." )
352412
353- zeroshot = self . max_bootstrapped_demos == 0 and self . max_labeled_demos == 0
413+ zeroshot = max_bootstrapped_demos == 0 and max_labeled_demos == 0
354414
355- # try:
356- effective_max_errors = (
357- self .max_errors if self .max_errors is not None else dspy .settings .max_errors
358- )
415+ if max_errors is None :
416+ max_errors = dspy .settings .max_errors
359417
360418 demo_candidates = create_n_fewshot_demo_sets (
361419 student = program ,
362- num_candidate_sets = self . num_fewshot_candidates ,
420+ num_candidate_sets = num_fewshot_candidates ,
363421 trainset = trainset ,
364- max_labeled_demos = (LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self . max_labeled_demos ),
422+ max_labeled_demos = (LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_labeled_demos ),
365423 max_bootstrapped_demos = (
366- BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self . max_bootstrapped_demos
424+ BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_bootstrapped_demos
367425 ),
368426 metric = self .metric ,
369- max_errors = effective_max_errors ,
427+ max_errors = max_errors ,
370428 teacher = teacher ,
371429 teacher_settings = self .teacher_settings ,
372430 seed = seed ,
373- metric_threshold = self . metric_threshold ,
431+ metric_threshold = metric_threshold ,
374432 rng = self .rng ,
375433 )
376434 # NOTE: Bootstrapping is essential to MIPRO!
@@ -392,6 +450,7 @@ def _propose_instructions(
392450 data_aware_proposer : bool ,
393451 tip_aware_proposer : bool ,
394452 fewshot_aware_proposer : bool ,
453+ num_instruct_candidates : int ,
395454 ) -> dict [int , list [str ]]:
396455 logger .info ("\n ==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==" )
397456 logger .info (
@@ -416,12 +475,12 @@ def _propose_instructions(
416475 init_temperature = self .init_temperature ,
417476 )
418477
419- logger .info (f"\n Proposing N={ self . num_instruct_candidates } instructions...\n " )
478+ logger .info (f"\n Proposing N={ num_instruct_candidates } instructions...\n " )
420479 instruction_candidates = proposer .propose_instructions_for_program (
421480 trainset = trainset ,
422481 program = program ,
423482 demo_candidates = demo_candidates ,
424- N = self . num_instruct_candidates ,
483+ N = num_instruct_candidates ,
425484 trial_logs = {},
426485 )
427486
0 commit comments