|
1 | | -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE |
4 | 4 |
|
| 5 | +import dataclasses |
| 6 | + |
5 | 7 | import pytest |
6 | 8 |
|
7 | 9 | from cuda.bindings import driver, runtime |
8 | 10 | from cuda.core._utils import cuda_utils |
| 11 | +from cuda.core._utils.clear_error_support import assert_type_str_or_bytes_like, raise_code_path_meant_to_be_unreachable |
9 | 12 |
|
10 | 13 |
|
11 | 14 | def test_driver_cu_result_explanations_health(): |
@@ -76,3 +79,60 @@ def test_check_runtime_error(): |
76 | 79 | assert enum_name in msg |
77 | 80 | # Smoke test: We don't want most to be unexpected. |
78 | 81 | assert num_unexpected < len(driver.CUresult) * 0.5 |
| 82 | + |
| 83 | + |
| 84 | +def test_precondition(): |
| 85 | + def checker(*args, what=""): |
| 86 | + if args[0] < 0: |
| 87 | + raise ValueError(f"{what}: negative") |
| 88 | + |
| 89 | + @cuda_utils.precondition(checker, what="value check") |
| 90 | + def my_func(x): |
| 91 | + return x * 2 |
| 92 | + |
| 93 | + assert my_func(5) == 10 |
| 94 | + with pytest.raises(ValueError, match="negative"): |
| 95 | + my_func(-1) |
| 96 | + |
| 97 | + |
| 98 | +@dataclasses.dataclass |
| 99 | +class _DummyOptions: |
| 100 | + x: int = 1 |
| 101 | + y: str = "hello" |
| 102 | + |
| 103 | + |
| 104 | +def test_check_nvrtc_error_without_handle(): |
| 105 | + from cuda.bindings import nvrtc |
| 106 | + |
| 107 | + assert cuda_utils._check_nvrtc_error(nvrtc.nvrtcResult.NVRTC_SUCCESS) == 0 |
| 108 | + with pytest.raises(cuda_utils.NVRTCError): |
| 109 | + cuda_utils._check_nvrtc_error(nvrtc.nvrtcResult.NVRTC_ERROR_COMPILATION) |
| 110 | + |
| 111 | + |
| 112 | +def test_check_nvrtc_error_with_handle(init_cuda): |
| 113 | + from cuda.bindings import nvrtc |
| 114 | + |
| 115 | + err, prog = nvrtc.nvrtcCreateProgram(b"invalid code!@#$", b"test.cu", 0, [], []) |
| 116 | + assert err == nvrtc.nvrtcResult.NVRTC_SUCCESS |
| 117 | + try: |
| 118 | + (compile_result,) = nvrtc.nvrtcCompileProgram(prog, 0, []) |
| 119 | + assert compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS |
| 120 | + with pytest.raises(cuda_utils.NVRTCError, match="compilation log"): |
| 121 | + cuda_utils._check_nvrtc_error(compile_result, handle=prog) |
| 122 | + finally: |
| 123 | + nvrtc.nvrtcDestroyProgram(prog) |
| 124 | + |
| 125 | + |
| 126 | +def test_check_or_create_options_invalid_type(): |
| 127 | + with pytest.raises(TypeError, match="must be provided as an object"): |
| 128 | + cuda_utils.check_or_create_options(_DummyOptions, 12345, options_description="test options") |
| 129 | + |
| 130 | + |
| 131 | +def test_assert_type_str_or_bytes_like_rejects_non_str_bytes(): |
| 132 | + with pytest.raises(TypeError, match="Expected type str or bytes or bytearray"): |
| 133 | + assert_type_str_or_bytes_like(12345) |
| 134 | + |
| 135 | + |
| 136 | +def test_raise_code_path_meant_to_be_unreachable(): |
| 137 | + with pytest.raises(RuntimeError, match="This code path is meant to be unreachable"): |
| 138 | + raise_code_path_meant_to_be_unreachable() |
0 commit comments