|
16 | 16 | import math |
17 | 17 | import os |
18 | 18 | from threading import Lock |
| 19 | +from types import SimpleNamespace |
19 | 20 | from unittest import mock |
| 21 | +from unittest.mock import patch |
20 | 22 |
|
21 | 23 | import torch |
22 | 24 | from vllm.config import (CompilationConfig, ModelConfig, ParallelConfig, |
@@ -170,6 +172,41 @@ def test_find_hccl_library(self): |
170 | 172 | def test_current_stream(self): |
171 | 173 | with mock.patch("torch.npu.current_stream") as mock_current_stream: |
172 | 174 | self.assertEqual(utils.current_stream(), mock_current_stream()) |
| 175 | + |
| 176 | + @patch.dict(os.environ, {"HCCL_BUFFSIZE": "2048"}) |
| 177 | + @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options") |
| 178 | + def test_create_hccl_pg_options_mc2_with_env(self, mock_options): |
| 179 | + mock_options.return_value = SimpleNamespace(hccl_config=None) |
| 180 | + |
| 181 | + options = utils.create_hccl_pg_options("mc2") |
| 182 | + |
| 183 | + mock_options.assert_called_once_with() |
| 184 | + self.assertIsNotNone(options.hccl_config) |
| 185 | + self.assertEqual(options.hccl_config["hccl_buffer_size"], 2048) |
| 186 | + |
| 187 | + @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) |
| 188 | + @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options") |
| 189 | + def test_create_hccl_pg_options_ep_with_env(self, mock_options): |
| 190 | + mock_options.return_value = SimpleNamespace(hccl_config=None) |
| 191 | + |
| 192 | + options = utils.create_hccl_pg_options("ep") |
| 193 | + |
| 194 | + mock_options.assert_called_once_with() |
| 195 | + self.assertIsNotNone(options.hccl_config) |
| 196 | + self.assertEqual(options.hccl_config["hccl_buffer_size"], 1024) |
| 197 | + |
| 198 | + @patch.dict(os.environ, {}, clear=False) |
| 199 | + @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options") |
| 200 | + def test_create_hccl_pg_options_ep_default(self, mock_options): |
| 201 | + os.environ.pop("HCCL_BUFFSIZE", None) |
| 202 | + mock_options.return_value = SimpleNamespace(hccl_config=None) |
| 203 | + |
| 204 | + options = utils.create_hccl_pg_options("ep") |
| 205 | + |
| 206 | + mock_options.assert_called_once_with() |
| 207 | + self.assertIsNotNone(options.hccl_config) |
| 208 | + self.assertEqual(options.hccl_config["hccl_buffer_size"], |
| 209 | + utils._DEFAULT_BUFFER_SIZE) |
173 | 210 |
|
174 | 211 | def test_vllm_version_is(self): |
175 | 212 | with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}): |
|
0 commit comments