
Commit 2e898bc

Fix merge conflict
1 parent 8dc8d0d commit 2e898bc

1 file changed: 101 additions, 42 deletions

dspy/teleprompt/mipro_optimizer_v2.py

Lines changed: 101 additions & 42 deletions
@@ -128,15 +128,14 @@ def compile(
             else dspy.settings.max_errors
         )

-        # Update max demos if specified
-        initial_max_bootstrapped_demos = self.max_bootstrapped_demos
-        if max_bootstrapped_demos is not None:
-            self.max_bootstrapped_demos = max_bootstrapped_demos
-        initial_max_labeled_demos = self.max_labeled_demos
-        if max_labeled_demos is not None:
-            self.max_labeled_demos = max_labeled_demos
+        effective_max_bootstrapped_demos = (
+            max_bootstrapped_demos if max_bootstrapped_demos is not None else self.max_bootstrapped_demos
+        )
+        effective_max_labeled_demos = (
+            max_labeled_demos if max_labeled_demos is not None else self.max_labeled_demos
+        )

-        zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0)
+        zeroshot_opt = (effective_max_bootstrapped_demos == 0) and (effective_max_labeled_demos == 0)

         # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value
         if self.auto is None and (self.num_candidates is not None and num_trials is None):
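The hunk above swaps state mutation for local resolution: instead of temporarily overwriting self.max_bootstrapped_demos / self.max_labeled_demos and restoring them at the end of compile, the per-call overrides are folded into effective_* locals. A minimal sketch of that pattern follows; the Optimizer class is a hypothetical stand-in, not the dspy API.

# Minimal sketch of the pattern introduced above (hypothetical class, not the dspy API):
# per-call overrides are resolved into locals instead of being written to instance
# attributes, so nothing has to be "reset" afterwards and repeated compile() calls
# cannot observe each other's overrides.

class Optimizer:
    def __init__(self, max_bootstrapped_demos: int = 4, max_labeled_demos: int = 4):
        self.max_bootstrapped_demos = max_bootstrapped_demos
        self.max_labeled_demos = max_labeled_demos

    def compile(self, max_bootstrapped_demos: int | None = None, max_labeled_demos: int | None = None):
        # Resolve overrides locally; self.* is never written here.
        effective_bootstrapped = (
            max_bootstrapped_demos if max_bootstrapped_demos is not None else self.max_bootstrapped_demos
        )
        effective_labeled = (
            max_labeled_demos if max_labeled_demos is not None else self.max_labeled_demos
        )
        zeroshot_opt = effective_bootstrapped == 0 and effective_labeled == 0
        return effective_bootstrapped, effective_labeled, zeroshot_opt


opt = Optimizer()
print(opt.compile(max_bootstrapped_demos=0, max_labeled_demos=0))  # (0, 0, True)
print(opt.max_bootstrapped_demos)  # still 4: the per-call override did not leak into the instance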
@@ -162,13 +161,42 @@ def compile(
         # Set training & validation sets
         trainset, valset = self._set_and_validate_datasets(trainset, valset)

+        num_instruct_candidates = (
+            self.num_instruct_candidates
+            if self.num_instruct_candidates is not None
+            else self.num_candidates
+        )
+        num_fewshot_candidates = (
+            self.num_fewshot_candidates
+            if self.num_fewshot_candidates is not None
+            else self.num_candidates
+        )
+
         # Set hyperparameters based on run mode (if set)
-        num_trials, valset, minibatch = self._set_hyperparams_from_run_mode(
-            student, num_trials, minibatch, zeroshot_opt, valset
+        (
+            num_trials,
+            valset,
+            minibatch,
+            num_instruct_candidates,
+            num_fewshot_candidates,
+        ) = self._set_hyperparams_from_run_mode(
+            student,
+            num_trials,
+            minibatch,
+            zeroshot_opt,
+            valset,
+            num_instruct_candidates,
+            num_fewshot_candidates,
         )

         if self.auto:
-            self._print_auto_run_settings(num_trials, minibatch, valset)
+            self._print_auto_run_settings(
+                num_trials,
+                minibatch,
+                valset,
+                num_fewshot_candidates,
+                num_instruct_candidates,
+            )

         if minibatch and minibatch_size > len(valset):
             raise ValueError(f"Minibatch size cannot exceed the size of the valset. Valset size: {len(valset)}.")
@@ -186,8 +214,19 @@ def compile(
         )

         with dspy.context(lm=self.task_model):
+
             # Step 1: Bootstrap few-shot examples
-            demo_candidates = self._bootstrap_fewshot_examples(program, trainset, seed, teacher)
+            demo_candidates = self._bootstrap_fewshot_examples(
+                program,
+                trainset,
+                seed,
+                teacher,
+                num_fewshot_candidates=num_fewshot_candidates,
+                max_bootstrapped_demos=effective_max_bootstrapped_demos,
+                max_labeled_demos=effective_max_labeled_demos,
+                max_errors=effective_max_errors,
+                metric_threshold=self.metric_threshold,
+            )

             # Step 2: Propose instruction candidates
             instruction_candidates = self._propose_instructions(
@@ -199,6 +238,7 @@ def compile(
                 data_aware_proposer,
                 tip_aware_proposer,
                 fewshot_aware_proposer,
+                num_instruct_candidates=num_instruct_candidates,
             )

             # If zero-shot, discard demos
@@ -220,10 +260,6 @@ def compile(
                 seed,
             )

-        # Reset max demos
-        self.max_bootstrapped_demos = initial_max_bootstrapped_demos
-        self.max_labeled_demos = initial_max_labeled_demos
-
         return best_program

     def _set_random_seeds(self, seed):
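With the reset bookkeeping gone, the demo budget becomes a purely per-call decision. A hedged usage sketch follows; the optimizer construction, model name, and toy metric/data are assumptions for illustration, and only the max_bootstrapped_demos / max_labeled_demos overrides are taken from this diff.

import dspy

# Placeholder LM and toy data: illustrative only; a real run needs a valid model
# name, API credentials, and a much larger trainset.
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

program = dspy.Predict("question -> answer")
trainset = [dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")]

def metric(example, prediction, trace=None):
    return example.answer == prediction.answer

optimizer = dspy.MIPROv2(metric=metric, auto="light")

# Zero-shot for this call only; the optimizer's own demo defaults are not mutated.
zeroshot_program = optimizer.compile(
    program, trainset=trainset, max_bootstrapped_demos=0, max_labeled_demos=0
)

# A later call can allow demos again, with no reset step needed in between.
fewshot_program = optimizer.compile(
    program, trainset=trainset, max_bootstrapped_demos=4, max_labeled_demos=4
)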
@@ -242,13 +278,17 @@ def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candida
     def _set_hyperparams_from_run_mode(
         self,
         program: Any,
-        num_trials: int,
+        num_trials: int | None,
         minibatch: bool,
         zeroshot_opt: bool,
         valset: list,
-    ) -> tuple[int, list, bool]:
+        num_instruct_candidates: int | None,
+        num_fewshot_candidates: int | None,
+    ) -> tuple[int, list, bool, int, int]:
         if self.auto is None:
-            return num_trials, valset, minibatch
+            if num_instruct_candidates is None or num_fewshot_candidates is None:
+                raise ValueError("num_candidates must be provided when auto is None.")
+            return num_trials, valset, minibatch, num_instruct_candidates, num_fewshot_candidates

         auto_settings = AUTO_RUN_SETTINGS[self.auto]

@@ -258,12 +298,12 @@ def _set_hyperparams_from_run_mode(
         # Set num instruct candidates to 1/2 of N if optimizing with few-shot examples, otherwise set to N
         # This is because we've found that it's generally better to spend optimization budget on few-shot examples
         # When they are allowed.
-        self.num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5)
-        self.num_fewshot_candidates = auto_settings["n"]
+        num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5)
+        num_fewshot_candidates = auto_settings["n"]

         num_trials = self._set_num_trials_from_num_candidates(program, zeroshot_opt, auto_settings["n"])

-        return num_trials, valset, minibatch
+        return num_trials, valset, minibatch, num_instruct_candidates, num_fewshot_candidates

     def _set_and_validate_datasets(self, trainset: list, valset: list | None):
         if not trainset:
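The auto-mode branch above only moves where the candidate counts live; the split itself is unchanged: the full budget n goes to few-shot sets, and instructions get n when running zero-shot, otherwise int(n * 0.5). A small worked sketch with a hypothetical n (the real AUTO_RUN_SETTINGS values are not shown in this diff):

# Hypothetical budget n; the actual AUTO_RUN_SETTINGS values are not part of this diff.
def split_candidates(n: int, zeroshot_opt: bool) -> tuple[int, int]:
    """Mirror the hunk above: spend more budget on few-shot sets when demos are allowed."""
    num_instruct_candidates = n if zeroshot_opt else int(n * 0.5)
    num_fewshot_candidates = n
    return num_instruct_candidates, num_fewshot_candidates

print(split_candidates(12, zeroshot_opt=True))   # (12, 12) - all budget on instructions
print(split_candidates(12, zeroshot_opt=False))  # (6, 12)  - half on instructions, full on few-shot sets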
@@ -282,13 +322,20 @@ def _set_and_validate_datasets(self, trainset: list, valset: list | None):

         return trainset, valset

-    def _print_auto_run_settings(self, num_trials: int, minibatch: bool, valset: list):
+    def _print_auto_run_settings(
+        self,
+        num_trials: int,
+        minibatch: bool,
+        valset: list,
+        num_fewshot_candidates: int,
+        num_instruct_candidates: int,
+    ):
         logger.info(
             f"\nRUNNING WITH THE FOLLOWING {self.auto.upper()} AUTO RUN SETTINGS:"
             f"\nnum_trials: {num_trials}"
             f"\nminibatch: {minibatch}"
-            f"\nnum_fewshot_candidates: {self.num_fewshot_candidates}"
-            f"\nnum_instruct_candidates: {self.num_instruct_candidates}"
+            f"\nnum_fewshot_candidates: {num_fewshot_candidates}"
+            f"\nnum_instruct_candidates: {num_instruct_candidates}"
             f"\nvalset size: {len(valset)}\n"
         )

@@ -301,18 +348,19 @@ def _estimate_lm_calls(
         minibatch_full_eval_steps: int,
         valset: list,
         program_aware_proposer: bool,
+        num_instruct_candidates: int,
     ) -> tuple[str, str]:
         num_predictors = len(program.predictors())

         # Estimate prompt model calls
         estimated_prompt_model_calls = (
             10  # Data summarizer calls
-            + self.num_instruct_candidates * num_predictors  # Candidate generation
+            + num_instruct_candidates * num_predictors  # Candidate generation
             + (num_predictors + 1 if program_aware_proposer else 0)  # Program-aware proposer
         )
         prompt_model_line = (
             f"{YELLOW}- Prompt Generation: {BLUE}{BOLD}10{ENDC}{YELLOW} data summarizer calls + "
-            f"{BLUE}{BOLD}{self.num_instruct_candidates}{ENDC}{YELLOW} * "
+            f"{BLUE}{BOLD}{num_instruct_candidates}{ENDC}{YELLOW} * "
             f"{BLUE}{BOLD}{num_predictors}{ENDC}{YELLOW} lm calls in program "
             f"+ ({BLUE}{BOLD}{num_predictors + 1}{ENDC}{YELLOW}) lm calls in program-aware proposer "
             f"= {BLUE}{BOLD}{estimated_prompt_model_calls}{ENDC}{YELLOW} prompt model calls{ENDC}"
@@ -339,38 +387,48 @@ def _estimate_lm_calls(

         return prompt_model_line, task_model_line

-    def _bootstrap_fewshot_examples(self, program: Any, trainset: list, seed: int, teacher: Any) -> list | None:
+    def _bootstrap_fewshot_examples(
+        self,
+        program: Any,
+        trainset: list,
+        seed: int,
+        teacher: Any,
+        *,
+        num_fewshot_candidates: int,
+        max_bootstrapped_demos: int,
+        max_labeled_demos: int,
+        max_errors: int | None,
+        metric_threshold: float | None,
+    ) -> list | None:
         logger.info("\n==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==")
-        if self.max_bootstrapped_demos > 0:
+        if max_bootstrapped_demos > 0:
             logger.info(
                 "These will be used as few-shot example candidates for our program and for creating instructions.\n"
             )
         else:
             logger.info("These will be used for informing instruction proposal.\n")

-        logger.info(f"Bootstrapping N={self.num_fewshot_candidates} sets of demonstrations...")
+        logger.info(f"Bootstrapping N={num_fewshot_candidates} sets of demonstrations...")

-        zeroshot = self.max_bootstrapped_demos == 0 and self.max_labeled_demos == 0
+        zeroshot = max_bootstrapped_demos == 0 and max_labeled_demos == 0

-        # try:
-        effective_max_errors = (
-            self.max_errors if self.max_errors is not None else dspy.settings.max_errors
-        )
+        if max_errors is None:
+            max_errors = dspy.settings.max_errors

         demo_candidates = create_n_fewshot_demo_sets(
             student=program,
-            num_candidate_sets=self.num_fewshot_candidates,
+            num_candidate_sets=num_fewshot_candidates,
             trainset=trainset,
-            max_labeled_demos=(LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self.max_labeled_demos),
+            max_labeled_demos=(LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_labeled_demos),
             max_bootstrapped_demos=(
-                BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self.max_bootstrapped_demos
+                BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_bootstrapped_demos
             ),
             metric=self.metric,
-            max_errors=effective_max_errors,
+            max_errors=max_errors,
             teacher=teacher,
             teacher_settings=self.teacher_settings,
             seed=seed,
-            metric_threshold=self.metric_threshold,
+            metric_threshold=metric_threshold,
             rng=self.rng,
         )
         # NOTE: Bootstrapping is essential to MIPRO!
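_bootstrap_fewshot_examples now receives its limits explicitly, and the bare * makes everything after teacher keyword-only, so callers cannot silently transpose the demo limits. A toy sketch of that signature pattern (a stand-in function, not the dspy method):

def bootstrap(program: str, trainset: list, seed: int, teacher: str | None, *,
              num_fewshot_candidates: int, max_bootstrapped_demos: int,
              max_labeled_demos: int, max_errors: int | None = None) -> dict:
    """Toy stand-in: the '*' forces the demo limits to be passed by name."""
    return {
        "zeroshot": max_bootstrapped_demos == 0 and max_labeled_demos == 0,
        "sets": num_fewshot_candidates,
    }

# OK: limits are named, so they cannot be confused with each other.
print(bootstrap("prog", [], 0, None,
                num_fewshot_candidates=12, max_bootstrapped_demos=0, max_labeled_demos=0))

# Raises TypeError if the keyword-only arguments are passed positionally:
# bootstrap("prog", [], 0, None, 12, 0, 0)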
@@ -392,6 +450,7 @@ def _propose_instructions(
         data_aware_proposer: bool,
         tip_aware_proposer: bool,
         fewshot_aware_proposer: bool,
+        num_instruct_candidates: int,
     ) -> dict[int, list[str]]:
         logger.info("\n==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==")
         logger.info(
@@ -416,12 +475,12 @@ def _propose_instructions(
             init_temperature=self.init_temperature,
         )

-        logger.info(f"\nProposing N={self.num_instruct_candidates} instructions...\n")
+        logger.info(f"\nProposing N={num_instruct_candidates} instructions...\n")
         instruction_candidates = proposer.propose_instructions_for_program(
             trainset=trainset,
             program=program,
             demo_candidates=demo_candidates,
-            N=self.num_instruct_candidates,
+            N=num_instruct_candidates,
             trial_logs={},
         )
