Commit 72ee382

chore: formatting
1 parent 3a521c9 commit 72ee382

36 files changed: +715 additions, −450 deletions

integration-tests/conftest.py
Lines changed: 9 additions & 4 deletions

@@ -25,6 +25,7 @@
 
 class ResponseComparator(JSONSnapshotExtension):
     rtol = 0.2
+
     def serialize(
         self,
         data,
@@ -69,7 +70,9 @@ def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
                 prefill_token.id == other.id
                 and prefill_token.text == other.text
                 and (
-                    math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol)
+                    math.isclose(
+                        prefill_token.logprob, other.logprob, rel_tol=self.rtol
+                    )
                     if prefill_token.logprob is not None
                     else prefill_token.logprob == other.logprob
                 )
@@ -153,6 +156,7 @@ class GenerousResponseComparator(ResponseComparator):
     # Needed for GPTQ with exllama which has serious numerical fluctuations.
     rtol = 0.75
 
+
 class LauncherHandle:
     def __init__(self, port: int):
         self.client = AsyncClient(f"http://localhost:{port}")
@@ -198,6 +202,7 @@ def _inner_health(self) -> bool:
 def response_snapshot(snapshot):
     return snapshot.use_extension(ResponseComparator)
 
+
 @pytest.fixture
 def generous_response_snapshot(snapshot):
     return snapshot.use_extension(GenerousResponseComparator)
@@ -219,7 +224,7 @@ def local_launcher(
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
-        dtype: Optional[str] = None
+        dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
@@ -282,7 +287,7 @@ def docker_launcher(
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
-        dtype: Optional[str] = None
+        dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
 
@@ -335,7 +340,7 @@ def docker_launcher(
             ],
             volumes=volumes,
             ports={"80/tcp": port},
-            shm_size="1G"
+            shm_size="1G",
         )
 
         yield ContainerLauncherHandle(client, container.name, port)

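Aside from the diff itself: the reflowed call above is the relative-tolerance check ResponseComparator applies to per-token logprobs when comparing against a snapshot. A minimal standalone sketch of that idea, assuming nothing beyond math.isclose; the helper name and sample values are illustrative, not repo code:

    import math

    def logprobs_match(a, b, rtol=0.2):
        # A missing logprob only equals another missing logprob;
        # otherwise compare within a relative tolerance, as the comparator does.
        if a is None or b is None:
            return a == b
        return math.isclose(a, b, rel_tol=rtol)

    assert logprobs_match(-1.00, -1.15)      # within 20% of each other -> True
    assert not logprobs_match(-1.00, -2.00)  # outside the tolerance    -> False
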
integration-tests/models/test_flash_medusa.py
Lines changed: 9 additions & 3 deletions

@@ -50,10 +50,16 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
-    responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)
+    responses = await generate_load(
+        flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
+    )
 
     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
-    assert responses[0].generated_text == '\nDeep learning is a subset of machine learning'
+    assert all(
+        [r.generated_text == responses[0].generated_text for r in responses]
+    ), f"{[r.generated_text for r in responses]}"
+    assert (
+        responses[0].generated_text == "\nDeep learning is a subset of machine learning"
+    )
 
     assert responses == response_snapshot

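For context, the load test above fans the same prompt out n times and expects every response to match the first. generate_load is a fixture defined in integration-tests/conftest.py; the sketch below only illustrates the fan-out shape under the assumption of a text_generation AsyncClient, and is not the repo's implementation:

    import asyncio
    from text_generation import AsyncClient

    async def generate_load_sketch(client: AsyncClient, prompt: str, max_new_tokens: int, n: int):
        # Fire n identical requests concurrently and collect the responses;
        # the test then asserts that every generated_text matches the first one.
        futures = [
            client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
        ]
        return await asyncio.gather(*futures)
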
integration-tests/models/test_flash_mistral.py
Lines changed: 3 additions & 1 deletion

@@ -56,7 +56,9 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
     )
 
     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
+    assert all(
+        [r.generated_text == responses[0].generated_text for r in responses]
+    ), f"{[r.generated_text for r in responses]}"
     assert responses[0].generated_text == ": Let n = 10 - 1"
 
     assert responses == response_snapshot

integration-tests/models/test_idefics.py
Lines changed: 3 additions & 1 deletion

@@ -3,7 +3,9 @@
 
 @pytest.fixture(scope="module")
 def idefics_handle(launcher):
-    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle:
+    with launcher(
+        "HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16"
+    ) as handle:
         yield handle
 
 
server/tests/models/test_bloom.py
Lines changed: 14 additions & 2 deletions

@@ -133,8 +133,20 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     )
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 10264
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == "Test"
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
     assert generations[0].request_id == 0
 
 
server/tests/models/test_causal_lm.py
Lines changed: 14 additions & 2 deletions

@@ -129,8 +129,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     )
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 13
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == "."
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
     assert generations[0].request_id == 0
 
 
server/tests/models/test_seq2seq_lm.py
Lines changed: 14 additions & 2 deletions

@@ -151,8 +151,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     )
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 259
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == " "
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
     assert generations[0].request_id == 0
 
 
server/text_generation_server/cli.py
Lines changed: 31 additions & 7 deletions

@@ -77,12 +77,24 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
-    if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}:
+    if dtype is not None and quantize not in {
+        None,
+        "bitsandbytes",
+        "bitsandbytes-nf4",
+        "bitsandbytes-fp4",
+    }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
+        model_id,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        trust_remote_code,
+        uds_path,
     )
 
 
@@ -140,23 +152,35 @@ def download_weights(
 
     try:
         import json
-        medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+
+        medusa_head = hf_hub_download(
+            model_id, revision=revision, filename="medusa_lm_head.pt"
+        )
         if auto_convert:
-            medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
+            medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
             if not medusa_sf.exists():
                 utils.convert_files([Path(medusa_head)], [medusa_sf], [])
-            medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+            medusa_config = hf_hub_download(
+                model_id, revision=revision, filename="config.json"
+            )
             with open(medusa_config, "r") as f:
                 config = json.load(f)
 
             model_id = config["base_model_name_or_path"]
             revision = "main"
             try:
                 utils.weight_files(model_id, revision, extension)
-                logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+                logger.info(
+                    f"Files for parent {model_id} are already present on the host. "
+                    "Skipping download."
+                )
                 return
             # Local files not found
-            except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+            except (
+                utils.LocalEntryNotFoundError,
+                FileNotFoundError,
+                utils.EntryNotFoundError,
+            ):
                 pass
     except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
         pass

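The guard reformatted in serve() above enforces that dtype is only combined with the bitsandbytes quantization modes, since any other quantizer already fixes the final model dtype. A standalone sketch of the same check; the helper name is hypothetical, and the message is quoted from the diff:

    from typing import Optional

    def check_dtype_vs_quantize(dtype: Optional[str], quantize: Optional[str]) -> None:
        # Mirrors the condition in serve(): dtype may only be set when quantize
        # is unset or one of the bitsandbytes variants.
        allowed = {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}
        if dtype is not None and quantize not in allowed:
            raise RuntimeError(
                "Only 1 can be set between `dtype` and `quantize`, as they both "
                "decide how goes the final model."
            )

    check_dtype_vs_quantize("float16", "bitsandbytes")   # accepted
    # check_dtype_vs_quantize("float16", "gptq")         # would raise RuntimeError
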
server/text_generation_server/models/__init__.py
Lines changed: 8 additions & 5 deletions

@@ -88,7 +88,6 @@
     __all__.append(FlashMixtral)
 
 
-
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -157,7 +156,9 @@ def get_model(
             speculate_medusa = config_dict["medusa_num_heads"]
             if speculate is not None:
                 if speculate > speculate_medusa:
-                    raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
+                    raise RuntimeError(
+                        "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+                    )
                 else:
                     set_speculate(speculate)
             else:
@@ -249,7 +250,7 @@ def get_model(
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa
+                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -313,7 +314,9 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
+        raise NotImplementedError(
+            "Mixtral models requires flash attention v2, stk and megablocks"
+        )
 
     if model_type == "opt":
         return OPTSharded(
@@ -354,7 +357,7 @@ def get_model(
         raise ValueError("awq quantization is not supported for AutoModel")
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
         raise ValueError("4bit quantization is not supported for AutoModel")
-    elif (quantize == "eetq"):
+    elif quantize == "eetq":
         raise ValueError("Eetq quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(

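One observation on the hunk at -157 above: the RuntimeError message is a plain string containing `{speculate}` placeholders, so the values are not actually interpolated, and this formatting commit leaves that as-is. A hedged sketch of the same guard with an f-string, as presumably intended (the standalone helper name is illustrative):

    def check_speculate(speculate, speculate_medusa):
        # Same guard as in get_model(): the requested speculation depth cannot
        # exceed the number of medusa heads shipped with the model.
        if speculate is not None and speculate > speculate_medusa:
            raise RuntimeError(
                f"Speculate is set to `{speculate}` but this medusa models only "
                f"has `{speculate_medusa}` heads, please make them match"
            )
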
server/text_generation_server/models/bloom.py
Lines changed: 5 additions & 1 deletion

@@ -74,7 +74,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            prefix="transformer",
         )
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)
