Skip to content

Commit d6fc1f4

Browse files
authored
added test clean up and customization (#8)
* added test clean up and customization Signed-off-by: dpatel-ops <[email protected]> * update workflow to test only eager mode Signed-off-by: dpatel-ops <[email protected]> --------- Signed-off-by: dpatel-ops <[email protected]> Co-authored-by: dpatel-ops <[email protected]>
1 parent eb7a75e commit d6fc1f4

10 files changed

+75
-30
lines changed

.github/workflows/test-spyre.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,5 @@ jobs:
2424
export MASTER_ADDR=localhost && \
2525
export DISTRIBUTED_STRATEGY_IGNORE_MODULES=WordEmbedding && \
2626
cd vllm-spyre && \
27-
python -m pytest tests -v
27+
python -m pytest tests -v -k eager
2828
'''

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import torch
23
from vllm.connections import global_http_connection
34
from vllm.distributed import cleanup_dist_env_and_memory
45

@@ -42,3 +43,9 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
4243
yield
4344
if should_do_global_cleanup_after_test:
4445
cleanup_dist_env_and_memory()
46+
47+
48+
@pytest.fixture(autouse=True)
49+
def dynamo_reset():
50+
yield
51+
torch._dynamo.reset()

tests/spyre_util.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,4 +277,42 @@ def compare_embedding_results(model: str, prompts: List[str],
277277
sim = util.pytorch_cos_sim(hf_result["embeddings"],
278278
vllm_result["embeddings"])
279279

280-
assert math.isclose(sim, 1.0, rel_tol=0.05)
280+
assert math.isclose(sim, 1.0, rel_tol=0.05)
281+
282+
283+
# get model directory path from env, if not set then default to "/models".
284+
def get_spyre_model_dir_path():
285+
model_dir_path = os.environ.get("VLLM_SPYRE_TEST_MODEL_DIR", "/models")
286+
return model_dir_path
287+
288+
289+
# get test backends from env; if not set, default to "eager,inductor,sendnn_decoder,sendnn"
290+
# For multiple values:
291+
# export VLLM_SPYRE_TEST_BACKEND_LIST="eager, inductor, sendnn_decoder"
292+
def get_spyre_backend_list():
293+
test_backend_list = []
294+
user_backend_list = os.environ.get("VLLM_SPYRE_TEST_BACKEND_LIST",
295+
"eager,inductor,sendnn_decoder,sendnn")
296+
297+
for sypre_backend in user_backend_list.split(","):
298+
test_backend_list.append(sypre_backend.strip())
299+
return test_backend_list
300+
301+
302+
# get model names from env; if not set, default to "llama-194m" (or "all-roberta-large-v1" when isEmbeddings is True)
303+
# For multiple values:
304+
# export VLLM_SPYRE_TEST_MODEL_LIST="llama-194m,all-roberta-large-v1"
305+
def get_spyre_model_list(isEmbeddings=False):
306+
spyre_model_dir_path = get_spyre_model_dir_path()
307+
test_model_list = []
308+
user_test_model_list = os.environ.get("VLLM_SPYRE_TEST_MODEL_LIST",
309+
"llama-194m")
310+
311+
# when testing embeddings, default to the "all-roberta-large-v1" model instead
312+
if isEmbeddings:
313+
user_test_model_list = os.environ.get("VLLM_SPYRE_TEST_MODEL_LIST",
314+
"all-roberta-large-v1")
315+
316+
for model in user_test_model_list.split(","):
317+
test_model_list.append(f"{spyre_model_dir_path}/{model.strip()}")
318+
return test_model_list

tests/test_spyre_basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
import pytest
99
from spyre_util import (compare_results, generate_hf_output,
10-
generate_spyre_vllm_output)
10+
generate_spyre_vllm_output, get_spyre_backend_list,
11+
get_spyre_model_list)
1112
from vllm import SamplingParams
1213

1314

14-
@pytest.mark.parametrize("model", ["/models/llama-194m"])
15+
@pytest.mark.parametrize("model", get_spyre_model_list())
1516
@pytest.mark.parametrize("prompts", [[
1617
"Provide a list of instructions for preparing"
1718
" chicken soup for a family of four.", "Hello",
@@ -20,8 +21,7 @@
2021
@pytest.mark.parametrize("warmup_shape", [(64, 20, 4), (64, 20, 8),
2122
(128, 20, 4), (128, 20, 8)]
2223
) # (prompt_length/new_tokens/batch_size)
23-
@pytest.mark.parametrize("backend",
24-
["eager"]) #, "inductor", "sendnn_decoder"])
24+
@pytest.mark.parametrize("backend", get_spyre_backend_list())
2525
def test_output(
2626
model: str,
2727
prompts: List[str],

tests/test_spyre_embeddings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
from typing import List, Tuple
77

88
import pytest
9-
from spyre_util import (compare_embedding_results, spyre_vllm_embeddings,
9+
from spyre_util import (compare_embedding_results, get_spyre_backend_list,
10+
get_spyre_model_list, spyre_vllm_embeddings,
1011
st_embeddings)
1112

1213

13-
@pytest.mark.parametrize("model", ["/models/all-roberta-large-v1"])
14+
@pytest.mark.parametrize("model", get_spyre_model_list(isEmbeddings=True))
1415
@pytest.mark.parametrize("prompts", [[
1516
"The capital of France is Paris."
1617
"Provide a list of instructions for preparing"
@@ -20,8 +21,7 @@
2021
@pytest.mark.parametrize("warmup_shape",
2122
[(64, 4), (64, 8), (128, 4),
2223
(128, 8)]) # (prompt_length/new_tokens/batch_size)
23-
@pytest.mark.parametrize("backend",
24-
["eager"]) #, "inductor", "sendnn_decoder"])
24+
@pytest.mark.parametrize("backend", get_spyre_backend_list())
2525
def test_output(
2626
model: str,
2727
prompts: List[str],

tests/test_spyre_max_new_tokens.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
import pytest
99
from spyre_util import (compare_results, generate_hf_output,
10-
generate_spyre_vllm_output)
10+
generate_spyre_vllm_output, get_spyre_backend_list,
11+
get_spyre_model_list)
1112
from vllm import SamplingParams
1213

1314
template = (
@@ -20,15 +21,14 @@
2021
"chicken soup for a family of four.")
2122

2223

23-
@pytest.mark.parametrize("model", ["/models/llama-194m"])
24+
@pytest.mark.parametrize("model", get_spyre_model_list())
2425
@pytest.mark.parametrize("prompts", [[prompt1, prompt2, prompt2, prompt2],
2526
[prompt2, prompt2, prompt2, prompt1],
2627
[prompt2, prompt2, prompt2, prompt2]])
2728
@pytest.mark.parametrize("stop_last", [True, False])
2829
@pytest.mark.parametrize("warmup_shape", [(64, 10, 4)]
2930
) # (prompt_length/new_tokens/batch_size)
30-
@pytest.mark.parametrize("backend",
31-
["eager"]) #, "inductor", "sendnn_decoder"])
31+
@pytest.mark.parametrize("backend", get_spyre_backend_list())
3232
def test_output(
3333
model: str,
3434
prompts: List[str],

tests/test_spyre_max_prompt_length.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88
import pytest
99
from spyre_util import (compare_results, generate_hf_output,
10-
generate_spyre_vllm_output)
10+
generate_spyre_vllm_output, get_spyre_backend_list,
11+
get_spyre_model_list)
1112
from transformers import AutoTokenizer
1213
from vllm import SamplingParams
1314

1415

15-
@pytest.mark.parametrize("model", ["/models/llama-194m"])
16+
@pytest.mark.parametrize("model", get_spyre_model_list())
1617
@pytest.mark.parametrize("prompts", [
1718
7 * [
1819
"Hello",
@@ -27,8 +28,7 @@
2728
@pytest.mark.parametrize("warmup_shapes",
2829
[[(64, 20, 4)], [(64, 20, 4), (128, 20, 4)]]
2930
) # (prompt_length/new_tokens/batch_size)
30-
@pytest.mark.parametrize("backend",
31-
["eager"]) #, "inductor", "sendnn_decoder"])
31+
@pytest.mark.parametrize("backend", get_spyre_backend_list())
3232
def test_output(
3333
model: str,
3434
prompts: List[str],

tests/test_spyre_seed.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from typing import Tuple
88

99
import pytest
10-
from spyre_util import generate_spyre_vllm_output
10+
from spyre_util import (generate_spyre_vllm_output, get_spyre_backend_list,
11+
get_spyre_model_list)
1112
from vllm import SamplingParams
1213

1314

14-
@pytest.mark.parametrize("model", ["/models/llama-194m"])
15+
@pytest.mark.parametrize("model", get_spyre_model_list())
1516
@pytest.mark.parametrize("prompt", [
1617
"Provide a list of instructions for preparing"
1718
" chicken soup for a family of four."
@@ -21,8 +22,7 @@
2122
@pytest.mark.parametrize("warmup_shape", [(64, 20, 4), (64, 20, 8),
2223
(128, 20, 4), (128, 20, 8)]
2324
) # (prompt_length/new_tokens/batch_size)
24-
@pytest.mark.parametrize("backend",
25-
["eager"]) #, "inductor", "sendnn_decoder"])
25+
@pytest.mark.parametrize("backend", get_spyre_backend_list())
2626
def test_seed(
2727
model: str,
2828
prompt: str,

tests/test_spyre_tensor_parallel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
import pytest
99
from spyre_util import (compare_results, generate_hf_output,
10-
generate_spyre_vllm_output)
10+
generate_spyre_vllm_output, get_spyre_backend_list,
11+
get_spyre_model_list)
1112
from vllm import SamplingParams
1213

1314

14-
@pytest.mark.parametrize("model", ["/models/llama-194m"])
15+
@pytest.mark.parametrize("model", get_spyre_model_list())
1516
@pytest.mark.parametrize("prompts", [[
1617
"Provide a list of instructions for preparing"
1718
" chicken soup for a family of four.", "Hello",
@@ -21,8 +22,7 @@
2122
) #,[(64,20,8)],[(128,20,4)],[(128,20,8)]])
2223
# (prompt_length/new_tokens/batch_size)
2324
@pytest.mark.parametrize("tp_size", [2])
24-
@pytest.mark.parametrize("backend",
25-
["eager"]) #, "inductor", "sendnn_decoder"])
25+
@pytest.mark.parametrize("backend", get_spyre_backend_list())
2626
def test_output(
2727
model: str,
2828
prompts: List[str],

tests/test_spyre_warmup_shapes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
import pytest
99
from spyre_util import (compare_results, generate_hf_output,
10-
generate_spyre_vllm_output)
10+
generate_spyre_vllm_output, get_spyre_backend_list,
11+
get_spyre_model_list)
1112
from vllm import SamplingParams
1213

1314

14-
@pytest.mark.parametrize("model", ["/models/llama-194m"])
15+
@pytest.mark.parametrize("model", get_spyre_model_list())
1516
@pytest.mark.parametrize("prompts", [
1617
7 * [
1718
"Hello",
@@ -25,8 +26,7 @@
2526
])
2627
@pytest.mark.parametrize("warmup_shapes", [[(64, 20, 8), (128, 20, 4)]]
2728
) # (prompt_length/new_tokens/batch_size)
28-
@pytest.mark.parametrize("backend",
29-
["eager"]) #, "inductor", "sendnn_decoder"])
29+
@pytest.mark.parametrize("backend", get_spyre_backend_list())
3030
def test_output(
3131
model: str,
3232
prompts: List[str],

0 commit comments

Comments
 (0)