@@ -125,15 +125,14 @@ def compile(
125125            else  dspy .settings .max_errors 
126126        )
127127
128-         # Update max demos if specified 
129-         initial_max_bootstrapped_demos  =  self .max_bootstrapped_demos 
130-         if  max_bootstrapped_demos  is  not None :
131-             self .max_bootstrapped_demos  =  max_bootstrapped_demos 
132-         initial_max_labeled_demos  =  self .max_labeled_demos 
133-         if  max_labeled_demos  is  not None :
134-             self .max_labeled_demos  =  max_labeled_demos 
128+         effective_max_bootstrapped_demos  =  (
129+             max_bootstrapped_demos  if  max_bootstrapped_demos  is  not None  else  self .max_bootstrapped_demos 
130+         )
131+         effective_max_labeled_demos  =  (
132+             max_labeled_demos  if  max_labeled_demos  is  not None  else  self .max_labeled_demos 
133+         )
135134
136-         zeroshot_opt  =  (self . max_bootstrapped_demos  ==  0 ) and  (self . max_labeled_demos  ==  0 )
135+         zeroshot_opt  =  (effective_max_bootstrapped_demos  ==  0 ) and  (effective_max_labeled_demos  ==  0 )
137136
138137        # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value 
139138        if  self .auto  is  None  and  (self .num_candidates  is  not None  and  num_trials  is  None ):
@@ -159,13 +158,42 @@ def compile(
159158        # Set training & validation sets 
160159        trainset , valset  =  self ._set_and_validate_datasets (trainset , valset )
161160
161+         num_instruct_candidates  =  (
162+             self .num_instruct_candidates 
163+             if  self .num_instruct_candidates  is  not None 
164+             else  self .num_candidates 
165+         )
166+         num_fewshot_candidates  =  (
167+             self .num_fewshot_candidates 
168+             if  self .num_fewshot_candidates  is  not None 
169+             else  self .num_candidates 
170+         )
171+ 
162172        # Set hyperparameters based on run mode (if set) 
163-         num_trials , valset , minibatch  =  self ._set_hyperparams_from_run_mode (
164-             student , num_trials , minibatch , zeroshot_opt , valset 
173+         (
174+             num_trials ,
175+             valset ,
176+             minibatch ,
177+             num_instruct_candidates ,
178+             num_fewshot_candidates ,
179+         ) =  self ._set_hyperparams_from_run_mode (
180+             student ,
181+             num_trials ,
182+             minibatch ,
183+             zeroshot_opt ,
184+             valset ,
185+             num_instruct_candidates ,
186+             num_fewshot_candidates ,
165187        )
166188
167189        if  self .auto :
168-             self ._print_auto_run_settings (num_trials , minibatch , valset )
190+             self ._print_auto_run_settings (
191+                 num_trials ,
192+                 minibatch ,
193+                 valset ,
194+                 num_fewshot_candidates ,
195+                 num_instruct_candidates ,
196+             )
169197
170198        if  minibatch  and  minibatch_size  >  len (valset ):
171199            raise  ValueError (f"Minibatch size cannot exceed the size of the valset. Valset size: { len (valset )}  )
@@ -183,7 +211,17 @@ def compile(
183211        )
184212
185213        # Step 1: Bootstrap few-shot examples 
186-         demo_candidates  =  self ._bootstrap_fewshot_examples (program , trainset , seed , teacher )
214+         demo_candidates  =  self ._bootstrap_fewshot_examples (
215+             program ,
216+             trainset ,
217+             seed ,
218+             teacher ,
219+             num_fewshot_candidates = num_fewshot_candidates ,
220+             max_bootstrapped_demos = effective_max_bootstrapped_demos ,
221+             max_labeled_demos = effective_max_labeled_demos ,
222+             max_errors = effective_max_errors ,
223+             metric_threshold = self .metric_threshold ,
224+         )
187225
188226        # Step 2: Propose instruction candidates 
189227        instruction_candidates  =  self ._propose_instructions (
@@ -195,6 +233,7 @@ def compile(
195233            data_aware_proposer ,
196234            tip_aware_proposer ,
197235            fewshot_aware_proposer ,
236+             num_instruct_candidates = num_instruct_candidates ,
198237        )
199238
200239        # If zero-shot, discard demos 
@@ -215,10 +254,6 @@ def compile(
215254            seed ,
216255        )
217256
218-         # Reset max demos 
219-         self .max_bootstrapped_demos  =  initial_max_bootstrapped_demos 
220-         self .max_labeled_demos  =  initial_max_labeled_demos 
221- 
222257        return  best_program 
223258
224259    def  _set_random_seeds (self , seed ):
@@ -237,13 +272,17 @@ def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candida
237272    def  _set_hyperparams_from_run_mode (
238273        self ,
239274        program : Any ,
240-         num_trials : int ,
275+         num_trials : int   |   None ,
241276        minibatch : bool ,
242277        zeroshot_opt : bool ,
243278        valset : list ,
244-     ) ->  tuple [int , list , bool ]:
279+         num_instruct_candidates : int  |  None ,
280+         num_fewshot_candidates : int  |  None ,
281+     ) ->  tuple [int , list , bool , int , int ]:
245282        if  self .auto  is  None :
246-             return  num_trials , valset , minibatch 
283+             if  num_instruct_candidates  is  None  or  num_fewshot_candidates  is  None :
284+                 raise  ValueError ("num_candidates must be provided when auto is None." )
285+             return  num_trials , valset , minibatch , num_instruct_candidates , num_fewshot_candidates 
247286
248287        auto_settings  =  AUTO_RUN_SETTINGS [self .auto ]
249288
@@ -253,12 +292,12 @@ def _set_hyperparams_from_run_mode(
253292        # Set num instruct candidates to 1/2 of N if optimizing with few-shot examples, otherwise set to N 
254293        # This is because we've found that it's generally better to spend optimization budget on few-shot examples 
255294        # When they are allowed. 
256-         self . num_instruct_candidates  =  auto_settings ["n" ] if  zeroshot_opt  else  int (auto_settings ["n" ] *  0.5 )
257-         self . num_fewshot_candidates  =  auto_settings ["n" ]
295+         num_instruct_candidates  =  auto_settings ["n" ] if  zeroshot_opt  else  int (auto_settings ["n" ] *  0.5 )
296+         num_fewshot_candidates  =  auto_settings ["n" ]
258297
259298        num_trials  =  self ._set_num_trials_from_num_candidates (program , zeroshot_opt , auto_settings ["n" ])
260299
261-         return  num_trials , valset , minibatch 
300+         return  num_trials , valset , minibatch ,  num_instruct_candidates ,  num_fewshot_candidates 
262301
263302    def  _set_and_validate_datasets (self , trainset : list , valset : list  |  None ):
264303        if  not  trainset :
@@ -277,13 +316,20 @@ def _set_and_validate_datasets(self, trainset: list, valset: list | None):
277316
278317        return  trainset , valset 
279318
280-     def  _print_auto_run_settings (self , num_trials : int , minibatch : bool , valset : list ):
319+     def  _print_auto_run_settings (
320+         self ,
321+         num_trials : int ,
322+         minibatch : bool ,
323+         valset : list ,
324+         num_fewshot_candidates : int ,
325+         num_instruct_candidates : int ,
326+     ):
281327        logger .info (
282328            f"\n RUNNING WITH THE FOLLOWING { self .auto .upper ()}  
283329            f"\n num_trials: { num_trials }  
284330            f"\n minibatch: { minibatch }  
285-             f"\n num_fewshot_candidates: { self . num_fewshot_candidates }  
286-             f"\n num_instruct_candidates: { self . num_instruct_candidates }  
331+             f"\n num_fewshot_candidates: { num_fewshot_candidates }  
332+             f"\n num_instruct_candidates: { num_instruct_candidates }  
287333            f"\n valset size: { len (valset )} \n " 
288334        )
289335
@@ -296,18 +342,19 @@ def _estimate_lm_calls(
296342        minibatch_full_eval_steps : int ,
297343        valset : list ,
298344        program_aware_proposer : bool ,
345+         num_instruct_candidates : int ,
299346    ) ->  tuple [str , str ]:
300347        num_predictors  =  len (program .predictors ())
301348
302349        # Estimate prompt model calls 
303350        estimated_prompt_model_calls  =  (
304351            10   # Data summarizer calls 
305-             +  self . num_instruct_candidates  *  num_predictors   # Candidate generation 
352+             +  num_instruct_candidates  *  num_predictors   # Candidate generation 
306353            +  (num_predictors  +  1  if  program_aware_proposer  else  0 )  # Program-aware proposer 
307354        )
308355        prompt_model_line  =  (
309356            f"{ YELLOW } { BLUE } { BOLD } { ENDC } { YELLOW }  
310-             f"{ BLUE } { BOLD } { self . num_instruct_candidates } { ENDC } { YELLOW }  
357+             f"{ BLUE } { BOLD } { num_instruct_candidates } { ENDC } { YELLOW }  
311358            f"{ BLUE } { BOLD } { num_predictors } { ENDC } { YELLOW }  
312359            f"+ ({ BLUE } { BOLD } { num_predictors  +  1 } { ENDC } { YELLOW }  
313360            f"= { BLUE } { BOLD } { estimated_prompt_model_calls } { ENDC } { YELLOW } { ENDC }  
@@ -334,38 +381,48 @@ def _estimate_lm_calls(
334381
335382        return  prompt_model_line , task_model_line 
336383
337-     def  _bootstrap_fewshot_examples (self , program : Any , trainset : list , seed : int , teacher : Any ) ->  list  |  None :
384+     def  _bootstrap_fewshot_examples (
385+         self ,
386+         program : Any ,
387+         trainset : list ,
388+         seed : int ,
389+         teacher : Any ,
390+         * ,
391+         num_fewshot_candidates : int ,
392+         max_bootstrapped_demos : int ,
393+         max_labeled_demos : int ,
394+         max_errors : int  |  None ,
395+         metric_threshold : float  |  None ,
396+     ) ->  list  |  None :
338397        logger .info ("\n ==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==" )
339-         if  self . max_bootstrapped_demos  >  0 :
398+         if  max_bootstrapped_demos  >  0 :
340399            logger .info (
341400                "These will be used as few-shot example candidates for our program and for creating instructions.\n " 
342401            )
343402        else :
344403            logger .info ("These will be used for informing instruction proposal.\n " )
345404
346-         logger .info (f"Bootstrapping N={ self . num_fewshot_candidates }  )
405+         logger .info (f"Bootstrapping N={ num_fewshot_candidates }  )
347406
348-         zeroshot  =  self . max_bootstrapped_demos  ==  0  and  self . max_labeled_demos  ==  0 
407+         zeroshot  =  max_bootstrapped_demos  ==  0  and  max_labeled_demos  ==  0 
349408
350-         # try: 
351-         effective_max_errors  =  (
352-             self .max_errors  if  self .max_errors  is  not None  else  dspy .settings .max_errors 
353-         )
409+         if  max_errors  is  None :
410+             max_errors  =  dspy .settings .max_errors 
354411
355412        demo_candidates  =  create_n_fewshot_demo_sets (
356413            student = program ,
357-             num_candidate_sets = self . num_fewshot_candidates ,
414+             num_candidate_sets = num_fewshot_candidates ,
358415            trainset = trainset ,
359-             max_labeled_demos = (LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT  if  zeroshot  else  self . max_labeled_demos ),
416+             max_labeled_demos = (LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT  if  zeroshot  else  max_labeled_demos ),
360417            max_bootstrapped_demos = (
361-                 BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT  if  zeroshot  else  self . max_bootstrapped_demos 
418+                 BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT  if  zeroshot  else  max_bootstrapped_demos 
362419            ),
363420            metric = self .metric ,
364-             max_errors = effective_max_errors ,
421+             max_errors = max_errors ,
365422            teacher = teacher ,
366423            teacher_settings = self .teacher_settings ,
367424            seed = seed ,
368-             metric_threshold = self . metric_threshold ,
425+             metric_threshold = metric_threshold ,
369426            rng = self .rng ,
370427        )
371428        # NOTE: Bootstrapping is essential to MIPRO! 
@@ -387,6 +444,7 @@ def _propose_instructions(
387444        data_aware_proposer : bool ,
388445        tip_aware_proposer : bool ,
389446        fewshot_aware_proposer : bool ,
447+         num_instruct_candidates : int ,
390448    ) ->  dict [int , list [str ]]:
391449        logger .info ("\n ==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==" )
392450        logger .info (
@@ -411,12 +469,12 @@ def _propose_instructions(
411469            init_temperature = self .init_temperature ,
412470        )
413471
414-         logger .info (f"\n Proposing N={ self . num_instruct_candidates } \n " )
472+         logger .info (f"\n Proposing N={ num_instruct_candidates } \n " )
415473        instruction_candidates  =  proposer .propose_instructions_for_program (
416474            trainset = trainset ,
417475            program = program ,
418476            demo_candidates = demo_candidates ,
419-             N = self . num_instruct_candidates ,
477+             N = num_instruct_candidates ,
420478            trial_logs = {},
421479        )
422480
0 commit comments