Skip to content
14 changes: 8 additions & 6 deletions tests/integration/defs/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from .common import (PluginOptions, convert_weights, get_mmlu_accuracy,
prune_checkpoint, quantize_data, refit_model,
venv_check_call)
from .conftest import (get_device_count, llm_models_root, skip_no_sm120,
skip_nvlink_inactive, skip_post_blackwell, skip_pre_ada,
skip_pre_blackwell, skip_pre_hopper, tests_path,
unittest_path)
from .conftest import (get_device_count, get_sm_version, llm_models_root,
skip_no_sm120, skip_nvlink_inactive, skip_post_blackwell,
skip_pre_ada, skip_pre_blackwell, skip_pre_hopper,
tests_path, unittest_path)

sys.path.append(os.path.join(str(tests_path()), '/../examples/apps'))

Expand Down Expand Up @@ -2184,7 +2184,6 @@ def test_ptp_quickstart_advanced_deepseek_r1_8gpus(llm_root, llm_venv,
_check_mem_usage(running_log, [106.3, 0, 0, 0], 8)


@skip_post_blackwell
@pytest.mark.skip_less_device_memory(110000)
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("model_name,model_path", [
Expand All @@ -2195,6 +2194,7 @@ def test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus(
llm_root, llm_venv, model_name, model_path):
print(f"Testing {model_name}.")
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
is_blackwell = get_sm_version() > 90
with tempfile.NamedTemporaryFile(mode='w+t',
suffix=f".{model_name}.log",
dir="./",
Expand All @@ -2208,7 +2208,7 @@ def test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus(
"--moe_ep_size=8",
"--tp_size=8",
"--use_cuda_graph",
f"--kv_cache_fraction={_MEM_FRACTION_95}",
f"--kv_cache_fraction={_MEM_FRACTION_50 if is_blackwell else _MEM_FRACTION_95}",
"--max_batch_size=1",
"--max_seq_len=3000",
"--disable_kv_cache_reuse",
Expand All @@ -2221,6 +2221,8 @@ def test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus(
"--relaxed_delta=0.5",
"--enable_attention_dp",
"--use_one_model",
"--moe_backend",
"DEEPGEMM" if is_blackwell else "CUTLASS",
],
stdout=running_log)
_check_mem_usage(running_log, [85.6, 0, 0, 0], 8)
Expand Down