
Commit c6353ae

Add validation for lora_dropout (#316)
* Add validation for lora_dropout
* Fix
* Fix eos
* Replace with llama3b
* Minimize the diff
* Version bump
1 parent 554d3e8 commit c6353ae


3 files changed: +32, -1 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.10"
+version = "1.5.11"
 authors = ["Together AI <[email protected]>"]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"

src/together/resources/finetune.py

Lines changed: 5 additions & 0 deletions

@@ -101,6 +101,11 @@ def create_finetune_request(
         raise ValueError(
             f"LoRA adapters are not supported for the selected model ({model_or_checkpoint})."
         )
+
+    if lora_dropout is not None:
+        if not 0 <= lora_dropout < 1.0:
+            raise ValueError("LoRA dropout must be in [0, 1) range.")
+
     lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
     lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
     training_type = LoRATrainingType(
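
For quick reference, the snippet below is a minimal standalone sketch of the boundary behavior the new check enforces; validate_lora_dropout is a hypothetical helper written here purely for illustration, while in the library the validation lives inline in create_finetune_request, as shown in the hunk above.

from typing import Optional


def validate_lora_dropout(lora_dropout: Optional[float]) -> None:
    # Mirrors the added check: lora_dropout may be omitted (None); otherwise it
    # must fall inside the half-open interval [0, 1).
    if lora_dropout is not None:
        if not 0 <= lora_dropout < 1.0:
            raise ValueError("LoRA dropout must be in [0, 1) range.")


validate_lora_dropout(None)  # OK: dropout left unset
validate_lora_dropout(0)     # OK: 0 disables dropout and sits inside [0, 1)
validate_lora_dropout(0.5)   # OK
try:
    validate_lora_dropout(1.0)  # rejected: 1.0 is excluded by the half-open interval
except ValueError as err:
    print(err)  # LoRA dropout must be in [0, 1) range.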

tests/unit/test_finetune_resources.py

Lines changed: 26 additions & 0 deletions

@@ -85,6 +85,32 @@ def test_lora_request():
     assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
 
 
+@pytest.mark.parametrize("lora_dropout", [-1, 0, 0.5, 1.0, 10.0])
+def test_lora_request_with_lora_dropout(lora_dropout: float):
+
+    if 0 <= lora_dropout < 1:
+        request = create_finetune_request(
+            model_limits=_MODEL_LIMITS,
+            model=_MODEL_NAME,
+            training_file=_TRAINING_FILE,
+            lora=True,
+            lora_dropout=lora_dropout,
+        )
+        assert request.training_type.lora_dropout == lora_dropout
+    else:
+        with pytest.raises(
+            ValueError,
+            match=r"LoRA dropout must be in \[0, 1\) range.",
+        ):
+            create_finetune_request(
+                model_limits=_MODEL_LIMITS,
+                model=_MODEL_NAME,
+                training_file=_TRAINING_FILE,
+                lora=True,
+                lora_dropout=lora_dropout,
+            )
+
+
 def test_dpo_request_lora():
     request = create_finetune_request(
         model_limits=_MODEL_LIMITS,
