Commit 9e6d5df

Change to use instance methods

1 parent 27b45f6, commit 9e6d5df
1 file changed: +100 -42 lines

dspy/teleprompt/mipro_optimizer_v2.py (100 additions, 42 deletions)
@@ -125,15 +125,14 @@ def compile(
             else dspy.settings.max_errors
         )

-        # Update max demos if specified
-        initial_max_bootstrapped_demos = self.max_bootstrapped_demos
-        if max_bootstrapped_demos is not None:
-            self.max_bootstrapped_demos = max_bootstrapped_demos
-        initial_max_labeled_demos = self.max_labeled_demos
-        if max_labeled_demos is not None:
-            self.max_labeled_demos = max_labeled_demos
+        effective_max_bootstrapped_demos = (
+            max_bootstrapped_demos if max_bootstrapped_demos is not None else self.max_bootstrapped_demos
+        )
+        effective_max_labeled_demos = (
+            max_labeled_demos if max_labeled_demos is not None else self.max_labeled_demos
+        )

-        zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0)
+        zeroshot_opt = (effective_max_bootstrapped_demos == 0) and (effective_max_labeled_demos == 0)

         # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value
         if self.auto is None and (self.num_candidates is not None and num_trials is None):
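This hunk swaps the mutate-and-restore approach for a "resolve locally" pattern: a call-time argument wins over the instance default, but the instance is never modified. A minimal sketch of that pattern, with an illustrative class and defaults rather than the actual MIPROv2 code:

```python
class Optimizer:
    def __init__(self, max_bootstrapped_demos: int = 4, max_labeled_demos: int = 16):
        self.max_bootstrapped_demos = max_bootstrapped_demos
        self.max_labeled_demos = max_labeled_demos

    def compile(self, max_bootstrapped_demos: int | None = None, max_labeled_demos: int | None = None):
        # Call-time overrides apply only to this call; instance state stays untouched.
        effective_bootstrapped = (
            max_bootstrapped_demos if max_bootstrapped_demos is not None else self.max_bootstrapped_demos
        )
        effective_labeled = (
            max_labeled_demos if max_labeled_demos is not None else self.max_labeled_demos
        )
        return effective_bootstrapped, effective_labeled


opt = Optimizer()
assert opt.compile(max_bootstrapped_demos=0, max_labeled_demos=0) == (0, 0)
assert opt.max_bootstrapped_demos == 4  # default not clobbered, so no reset step is needed
```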
@@ -159,13 +158,42 @@ def compile(
         # Set training & validation sets
         trainset, valset = self._set_and_validate_datasets(trainset, valset)

+        num_instruct_candidates = (
+            self.num_instruct_candidates
+            if self.num_instruct_candidates is not None
+            else self.num_candidates
+        )
+        num_fewshot_candidates = (
+            self.num_fewshot_candidates
+            if self.num_fewshot_candidates is not None
+            else self.num_candidates
+        )
+
         # Set hyperparameters based on run mode (if set)
-        num_trials, valset, minibatch = self._set_hyperparams_from_run_mode(
-            student, num_trials, minibatch, zeroshot_opt, valset
+        (
+            num_trials,
+            valset,
+            minibatch,
+            num_instruct_candidates,
+            num_fewshot_candidates,
+        ) = self._set_hyperparams_from_run_mode(
+            student,
+            num_trials,
+            minibatch,
+            zeroshot_opt,
+            valset,
+            num_instruct_candidates,
+            num_fewshot_candidates,
         )

         if self.auto:
-            self._print_auto_run_settings(num_trials, minibatch, valset)
+            self._print_auto_run_settings(
+                num_trials,
+                minibatch,
+                valset,
+                num_fewshot_candidates,
+                num_instruct_candidates,
+            )

         if minibatch and minibatch_size > len(valset):
             raise ValueError(f"Minibatch size cannot exceed the size of the valset. Valset size: {len(valset)}.")
@@ -183,7 +211,17 @@ def compile(
         )

         # Step 1: Bootstrap few-shot examples
-        demo_candidates = self._bootstrap_fewshot_examples(program, trainset, seed, teacher)
+        demo_candidates = self._bootstrap_fewshot_examples(
+            program,
+            trainset,
+            seed,
+            teacher,
+            num_fewshot_candidates=num_fewshot_candidates,
+            max_bootstrapped_demos=effective_max_bootstrapped_demos,
+            max_labeled_demos=effective_max_labeled_demos,
+            max_errors=effective_max_errors,
+            metric_threshold=self.metric_threshold,
+        )

         # Step 2: Propose instruction candidates
         instruction_candidates = self._propose_instructions(
@@ -195,6 +233,7 @@ def compile(
             data_aware_proposer,
             tip_aware_proposer,
             fewshot_aware_proposer,
+            num_instruct_candidates=num_instruct_candidates,
         )

         # If zero-shot, discard demos
@@ -215,10 +254,6 @@ def compile(
             seed,
         )

-        # Reset max demos
-        self.max_bootstrapped_demos = initial_max_bootstrapped_demos
-        self.max_labeled_demos = initial_max_labeled_demos
-
         return best_program

     def _set_random_seeds(self, seed):
@@ -237,13 +272,17 @@ def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candida
     def _set_hyperparams_from_run_mode(
         self,
         program: Any,
-        num_trials: int,
+        num_trials: int | None,
         minibatch: bool,
         zeroshot_opt: bool,
         valset: list,
-    ) -> tuple[int, list, bool]:
+        num_instruct_candidates: int | None,
+        num_fewshot_candidates: int | None,
+    ) -> tuple[int, list, bool, int, int]:
         if self.auto is None:
-            return num_trials, valset, minibatch
+            if num_instruct_candidates is None or num_fewshot_candidates is None:
+                raise ValueError("num_candidates must be provided when auto is None.")
+            return num_trials, valset, minibatch, num_instruct_candidates, num_fewshot_candidates

         auto_settings = AUTO_RUN_SETTINGS[self.auto]

@@ -253,12 +292,12 @@ def _set_hyperparams_from_run_mode(
         # Set num instruct candidates to 1/2 of N if optimizing with few-shot examples, otherwise set to N
         # This is because we've found that it's generally better to spend optimization budget on few-shot examples
         # When they are allowed.
-        self.num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5)
-        self.num_fewshot_candidates = auto_settings["n"]
+        num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5)
+        num_fewshot_candidates = auto_settings["n"]

         num_trials = self._set_num_trials_from_num_candidates(program, zeroshot_opt, auto_settings["n"])

-        return num_trials, valset, minibatch
+        return num_trials, valset, minibatch, num_instruct_candidates, num_fewshot_candidates

     def _set_and_validate_datasets(self, trainset: list, valset: list | None):
         if not trainset:
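The two hunks above turn `_set_hyperparams_from_run_mode` into a pure "resolve or derive" helper: explicit candidate counts pass through (with a guard when `auto` is off), while auto mode derives them from the run-size preset and returns them instead of writing to `self`. A compact sketch of that logic; `AUTO_PRESETS`, its values, and the function name are hypothetical stand-ins for dspy's `AUTO_RUN_SETTINGS` and the real method:

```python
AUTO_PRESETS = {"light": {"n": 6}, "medium": {"n": 12}, "heavy": {"n": 18}}  # hypothetical values

def resolve_candidate_counts(auto: str | None, num_candidates: int | None, zeroshot_opt: bool) -> tuple[int, int]:
    """Return (num_instruct_candidates, num_fewshot_candidates) without touching instance state."""
    if auto is None:
        if num_candidates is None:
            raise ValueError("num_candidates must be provided when auto is None.")
        return num_candidates, num_candidates
    n = AUTO_PRESETS[auto]["n"]
    # Spend roughly half the instruction budget when few-shot demos are also being optimized.
    return (n if zeroshot_opt else int(n * 0.5)), n

print(resolve_candidate_counts("medium", None, zeroshot_opt=False))  # (6, 12)
```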
@@ -277,13 +316,20 @@ def _set_and_validate_datasets(self, trainset: list, valset: list | None):

         return trainset, valset

-    def _print_auto_run_settings(self, num_trials: int, minibatch: bool, valset: list):
+    def _print_auto_run_settings(
+        self,
+        num_trials: int,
+        minibatch: bool,
+        valset: list,
+        num_fewshot_candidates: int,
+        num_instruct_candidates: int,
+    ):
         logger.info(
             f"\nRUNNING WITH THE FOLLOWING {self.auto.upper()} AUTO RUN SETTINGS:"
             f"\nnum_trials: {num_trials}"
             f"\nminibatch: {minibatch}"
-            f"\nnum_fewshot_candidates: {self.num_fewshot_candidates}"
-            f"\nnum_instruct_candidates: {self.num_instruct_candidates}"
+            f"\nnum_fewshot_candidates: {num_fewshot_candidates}"
+            f"\nnum_instruct_candidates: {num_instruct_candidates}"
             f"\nvalset size: {len(valset)}\n"
         )

@@ -296,18 +342,19 @@ def _estimate_lm_calls(
         minibatch_full_eval_steps: int,
         valset: list,
         program_aware_proposer: bool,
+        num_instruct_candidates: int,
     ) -> tuple[str, str]:
         num_predictors = len(program.predictors())

         # Estimate prompt model calls
         estimated_prompt_model_calls = (
             10  # Data summarizer calls
-            + self.num_instruct_candidates * num_predictors  # Candidate generation
+            + num_instruct_candidates * num_predictors  # Candidate generation
             + (num_predictors + 1 if program_aware_proposer else 0)  # Program-aware proposer
         )
         prompt_model_line = (
             f"{YELLOW}- Prompt Generation: {BLUE}{BOLD}10{ENDC}{YELLOW} data summarizer calls + "
-            f"{BLUE}{BOLD}{self.num_instruct_candidates}{ENDC}{YELLOW} * "
+            f"{BLUE}{BOLD}{num_instruct_candidates}{ENDC}{YELLOW} * "
             f"{BLUE}{BOLD}{num_predictors}{ENDC}{YELLOW} lm calls in program "
             f"+ ({BLUE}{BOLD}{num_predictors + 1}{ENDC}{YELLOW}) lm calls in program-aware proposer "
             f"= {BLUE}{BOLD}{estimated_prompt_model_calls}{ENDC}{YELLOW} prompt model calls{ENDC}"
@@ -334,38 +381,48 @@ def _estimate_lm_calls(

         return prompt_model_line, task_model_line

-    def _bootstrap_fewshot_examples(self, program: Any, trainset: list, seed: int, teacher: Any) -> list | None:
+    def _bootstrap_fewshot_examples(
+        self,
+        program: Any,
+        trainset: list,
+        seed: int,
+        teacher: Any,
+        *,
+        num_fewshot_candidates: int,
+        max_bootstrapped_demos: int,
+        max_labeled_demos: int,
+        max_errors: int | None,
+        metric_threshold: float | None,
+    ) -> list | None:
         logger.info("\n==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==")
-        if self.max_bootstrapped_demos > 0:
+        if max_bootstrapped_demos > 0:
             logger.info(
                 "These will be used as few-shot example candidates for our program and for creating instructions.\n"
             )
         else:
             logger.info("These will be used for informing instruction proposal.\n")

-        logger.info(f"Bootstrapping N={self.num_fewshot_candidates} sets of demonstrations...")
+        logger.info(f"Bootstrapping N={num_fewshot_candidates} sets of demonstrations...")

-        zeroshot = self.max_bootstrapped_demos == 0 and self.max_labeled_demos == 0
+        zeroshot = max_bootstrapped_demos == 0 and max_labeled_demos == 0

-        # try:
-        effective_max_errors = (
-            self.max_errors if self.max_errors is not None else dspy.settings.max_errors
-        )
+        if max_errors is None:
+            max_errors = dspy.settings.max_errors

         demo_candidates = create_n_fewshot_demo_sets(
             student=program,
-            num_candidate_sets=self.num_fewshot_candidates,
+            num_candidate_sets=num_fewshot_candidates,
             trainset=trainset,
-            max_labeled_demos=(LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self.max_labeled_demos),
+            max_labeled_demos=(LABELED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_labeled_demos),
             max_bootstrapped_demos=(
-                BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else self.max_bootstrapped_demos
+                BOOTSTRAPPED_FEWSHOT_EXAMPLES_IN_CONTEXT if zeroshot else max_bootstrapped_demos
             ),
             metric=self.metric,
-            max_errors=effective_max_errors,
+            max_errors=max_errors,
             teacher=teacher,
             teacher_settings=self.teacher_settings,
             seed=seed,
-            metric_threshold=self.metric_threshold,
+            metric_threshold=metric_threshold,
             rng=self.rng,
         )
         # NOTE: Bootstrapping is essential to MIPRO!
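The expanded `_bootstrap_fewshot_examples` signature makes the bootstrapping settings keyword-only (everything after `*`), so the long call site in `compile()` must name each value explicitly. A generic illustration of that signature style, not the dspy API itself:

```python
def bootstrap_demo_sets(trainset: list, *, num_candidate_sets: int, max_errors: int | None = None) -> list[list]:
    # Keyword-only parameters keep call sites self-documenting and order-independent.
    if max_errors is None:
        max_errors = 10  # stand-in for a global default such as dspy.settings.max_errors
    return [trainset[:1] for _ in range(num_candidate_sets)]

demo_sets = bootstrap_demo_sets(["example"], num_candidate_sets=3)  # OK
# bootstrap_demo_sets(["example"], 3)  # TypeError: takes 1 positional argument
```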
@@ -387,6 +444,7 @@ def _propose_instructions(
         data_aware_proposer: bool,
         tip_aware_proposer: bool,
         fewshot_aware_proposer: bool,
+        num_instruct_candidates: int,
     ) -> dict[int, list[str]]:
         logger.info("\n==> STEP 2: PROPOSE INSTRUCTION CANDIDATES <==")
         logger.info(
@@ -411,12 +469,12 @@ def _propose_instructions(
             init_temperature=self.init_temperature,
         )

-        logger.info(f"\nProposing N={self.num_instruct_candidates} instructions...\n")
+        logger.info(f"\nProposing N={num_instruct_candidates} instructions...\n")
         instruction_candidates = proposer.propose_instructions_for_program(
             trainset=trainset,
             program=program,
             demo_candidates=demo_candidates,
-            N=self.num_instruct_candidates,
+            N=num_instruct_candidates,
             trial_logs={},
         )
