diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index ad4f3e6a621..fbbfc4ceb22 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -31,10 +31,10 @@ from .common import (PluginOptions, convert_weights, get_mmlu_accuracy, prune_checkpoint, quantize_data, refit_model, venv_check_call) -from .conftest import (get_device_count, llm_models_root, skip_no_sm120, - skip_nvlink_inactive, skip_post_blackwell, skip_pre_ada, - skip_pre_blackwell, skip_pre_hopper, tests_path, - unittest_path) +from .conftest import (get_device_count, get_sm_version, llm_models_root, + skip_no_sm120, skip_nvlink_inactive, skip_post_blackwell, + skip_pre_ada, skip_pre_blackwell, skip_pre_hopper, + tests_path, unittest_path) sys.path.append(os.path.join(str(tests_path()), '/../examples/apps')) @@ -2184,7 +2184,6 @@ def test_ptp_quickstart_advanced_deepseek_r1_8gpus(llm_root, llm_venv, _check_mem_usage(running_log, [106.3, 0, 0, 0], 8) -@skip_post_blackwell @pytest.mark.skip_less_device_memory(110000) @pytest.mark.skip_less_device(8) @pytest.mark.parametrize("model_name,model_path", [ @@ -2195,6 +2194,7 @@ def test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus( llm_root, llm_venv, model_name, model_path): print(f"Testing {model_name}.") example_root = Path(os.path.join(llm_root, "examples", "llm-api")) + is_blackwell = get_sm_version() > 90 with tempfile.NamedTemporaryFile(mode='w+t', suffix=f".{model_name}.log", dir="./", @@ -2208,7 +2208,7 @@ def test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus( "--moe_ep_size=8", "--tp_size=8", "--use_cuda_graph", - f"--kv_cache_fraction={_MEM_FRACTION_95}", + f"--kv_cache_fraction={_MEM_FRACTION_50 if is_blackwell else _MEM_FRACTION_95}", "--max_batch_size=1", "--max_seq_len=3000", "--disable_kv_cache_reuse", @@ -2221,6 +2221,8 @@ def test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus( "--relaxed_delta=0.5", "--enable_attention_dp", "--use_one_model", + "--moe_backend", + "DEEPGEMM" if is_blackwell else "CUTLASS", ], stdout=running_log) _check_mem_usage(running_log, [85.6, 0, 0, 0], 8)