Skip to content

Commit f3c34a1

Browse files
committed
Fixed the test after rebase
Signed-off-by: Meet Patel <[email protected]>
1 parent 550c887 commit f3c34a1

File tree

4 files changed

+8
-9
lines changed

4 files changed

+8
-9
lines changed

QEfficient/cloud/finetune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def setup_dataloaders(
294294
return train_dataloader, eval_dataloader, longest_seq_length
295295

296296

297-
def main(peft_config_file=None, **kwargs) -> None:
297+
def main(peft_config_file: str = None, **kwargs) -> None:
298298
"""
299299
Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
300300

QEfficient/finetune/utils/config_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None
8484
if train_config.peft_method not in config_map:
8585
raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
8686

87-
config_cls, peft_config_cls = config_map[train_config.peft_method]()
87+
config_cls, peft_config_cls = config_map[train_config.peft_method]
8888
if config_cls is None:
8989
params = kwargs
9090
else:

QEfficient/finetune/utils/train_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
458458
# Print evaluation metrics
459459
print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
460460

461-
return eval_metric, eval_epoch_loss, val_step_loss, val_step_metric
461+
return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
462462

463463

464464
def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:

tests/finetune/test_finetune.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,10 @@ def test_finetune(
6666
}
6767

6868
results = finetune(**kwargs)
69-
70-
assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching."
7169
assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching."
72-
assert np.allclose(results["avg_eval_metric"], 1.0193923, atol=1e-5), "Eval metric is not matching."
73-
assert np.allclose(results["avg_eval_loss"], 0.0192067, atol=1e-5), "Eval loss is not matching."
70+
assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching."
71+
assert np.allclose(results["avg_eval_loss"], 0.0206124, atol=1e-5), "Eval loss is not matching."
72+
assert np.allclose(results["avg_eval_metric"], 1.020826, atol=1e-5), "Eval metric is not matching."
7473
assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds."
7574

7675
train_config_spy.assert_called_once()
@@ -86,8 +85,8 @@ def test_finetune(
8685
assert get_preprocessed_dataset_spy.call_count == 2
8786

8887
args, kwargs = train_spy.call_args
89-
train_dataloader = args[1]
90-
eval_dataloader = args[2]
88+
train_dataloader = args[2]
89+
eval_dataloader = args[3]
9190
optimizer = args[4]
9291

9392
batch = next(iter(train_dataloader))

0 commit comments

Comments
 (0)