diff --git a/problems/pmpp_v2/conv2d_py/reference.py b/problems/pmpp_v2/conv2d_py/reference.py
index 9e5e1a7..0f8cb10 100644
--- a/problems/pmpp_v2/conv2d_py/reference.py
+++ b/problems/pmpp_v2/conv2d_py/reference.py
@@ -1,4 +1,4 @@
-from utils import make_match_reference
+from utils import make_match_reference, DeterministicContext
 import torch
 import torch.nn.functional as F
 from task import input_t, output_t
@@ -12,45 +12,55 @@ def ref_kernel(data: input_t) -> output_t:
     Returns:
         Output tensor after convolution
     """
-    input_tensor, kernel = data
-    return F.conv2d(
-        input_tensor,
-        kernel,
-
-        # No padding and no striding
-        # TODO: Can revisit this in future problems
-        stride=1,
-        padding=0
-    )
+    with DeterministicContext():
+        input_tensor, kernel, output = data
+        return F.conv2d(
+            input_tensor,
+            kernel,
+            # No padding and no striding
+            stride=1,
+            padding=0,
+        )
 
 
-def generate_input(size: int, kernelsize: int, channels: int, batch: int, seed: int) -> input_t:
+def generate_input(
+    size: int, kernelsize: int, channels: int, batch: int, seed: int
+) -> input_t:
     """
     Generates random input and kernel tensors.
     Returns:
         Tuple of (input tensor, kernel tensor)
     """
-    gen = torch.Generator(device='cuda')
+    gen = torch.Generator(device="cuda")
     gen.manual_seed(seed)
-    
+
     # Generate input tensor: [batch, in_channels, height, width]
     input_tensor = torch.randn(
-        batch, channels, size, size,
-        device='cuda',
-        dtype=torch.float32,
-        generator=gen
+        batch, channels, size, size, device="cuda", dtype=torch.float32, generator=gen
    ).contiguous()
-    
+
     # Generate kernel tensor: [out_channels, in_channels, kernel_height, kernel_width]
     # Here we use same number of output channels as input channels for simplicity
     kernel = torch.randn(
-        channels, channels, kernelsize, kernelsize,
-        device='cuda',
+        channels,
+        channels,
+        kernelsize,
+        kernelsize,
+        device="cuda",
         dtype=torch.float32,
-        generator=gen
+        generator=gen,
     ).contiguous()
-    
-    return (input_tensor, kernel)
+
+    output_tensor = torch.empty(
+        batch,
+        channels,
+        size - kernelsize + 1,
+        size - kernelsize + 1,
+        device="cuda",
+        dtype=torch.float32,
+    )
+
+    return input_tensor, kernel, output_tensor
 
 
 check_implementation = make_match_reference(ref_kernel, rtol=1e-3, atol=1e-3)
diff --git a/problems/pmpp_v2/conv2d_py/solutions/correct/ref.py b/problems/pmpp_v2/conv2d_py/solutions/correct/ref.py
index c0ce3f2..8931380 100644
--- a/problems/pmpp_v2/conv2d_py/solutions/correct/ref.py
+++ b/problems/pmpp_v2/conv2d_py/solutions/correct/ref.py
@@ -1,13 +1,8 @@
 from task import input_t, output_t
-import torch
 import torch.nn.functional as F
 
 
 def custom_kernel(data: input_t) -> output_t:
-    input_tensor, kernel = data
-    return F.conv2d(
-        input_tensor,
-        kernel,
-        stride=1,
-        padding=0
-    )
+    input_tensor, kernel, output = data
+    output[...] = F.conv2d(input_tensor, kernel, stride=1, padding=0)
+    return output
diff --git a/problems/pmpp_v2/conv2d_py/solutions/wrong/empty.py b/problems/pmpp_v2/conv2d_py/solutions/wrong/empty.py
index 899beb0..7e6cef7 100644
--- a/problems/pmpp_v2/conv2d_py/solutions/wrong/empty.py
+++ b/problems/pmpp_v2/conv2d_py/solutions/wrong/empty.py
@@ -1,11 +1,7 @@
 # the nop kernel
 from task import input_t, output_t
-import torch
-import torch.nn.functional as F
 
 
 def custom_kernel(data: input_t) -> output_t:
-    input_tensor, kernel = data
-    return torch.empty((input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]-kernel.shape[3]+1, input_tensor.shape[3]-kernel.shape[3]+1),
-                       device=kernel.device, dtype=kernel.dtype
-                       )
+    _, _, output = data
+    return output
diff --git a/problems/pmpp_v2/conv2d_py/submission.py b/problems/pmpp_v2/conv2d_py/submission.py
index a1b7d16..4f1efb4 100644
--- a/problems/pmpp_v2/conv2d_py/submission.py
+++ b/problems/pmpp_v2/conv2d_py/submission.py
@@ -12,10 +12,6 @@ def custom_kernel(data: input_t) -> output_t:
     Returns:
         Output tensor after convolution
     """
-    input_tensor, kernel = data
-    return F.conv2d(
-        input_tensor,
-        kernel,
-        stride=1,
-        padding=0
-    )
\ No newline at end of file
+    input_tensor, kernel, output = data
+    output[...] = F.conv2d(input_tensor, kernel, stride=1, padding=0)
+    return output
diff --git a/problems/pmpp_v2/conv2d_py/task.py b/problems/pmpp_v2/conv2d_py/task.py
index 397332a..dc0b771 100644
--- a/problems/pmpp_v2/conv2d_py/task.py
+++ b/problems/pmpp_v2/conv2d_py/task.py
@@ -1,7 +1,7 @@
 from typing import TypedDict, TypeVar, Tuple
 import torch
 
-input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor])
+input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
 output_t = TypeVar("output_t", bound=torch.Tensor)
 
 
@@ -10,4 +10,4 @@ class TestSpec(TypedDict):
     kernelsize: int
     channels: int
     batch: int
-    seed: int
\ No newline at end of file
+    seed: int
diff --git a/problems/pmpp_v2/grayscale_py/reference.py b/problems/pmpp_v2/grayscale_py/reference.py
index 1ed6d14..3190a05 100644
--- a/problems/pmpp_v2/grayscale_py/reference.py
+++ b/problems/pmpp_v2/grayscale_py/reference.py
@@ -1,4 +1,4 @@
-from utils import make_match_reference
+from utils import make_match_reference, DeterministicContext
 import torch
 from task import input_t, output_t
 
@@ -7,17 +7,20 @@ def ref_kernel(data: input_t) -> output_t:
     """
     Reference implementation of RGB to grayscale conversion using PyTorch.
     Uses the standard coefficients: Y = 0.2989 R + 0.5870 G + 0.1140 B
-    
+
     Args:
         data: RGB tensor of shape (H, W, 3) with values in [0, 1]
     Returns:
         Grayscale tensor of shape (H, W) with values in [0, 1]
     """
-    # Standard RGB to Grayscale coefficients
-    weights = torch.tensor([0.2989, 0.5870, 0.1140],
-                           device=data.device,
-                           dtype=data.dtype)
-    return torch.sum(data * weights, dim=-1)
+    with DeterministicContext():
+        data, output = data
+        # Standard RGB to Grayscale coefficients
+        weights = torch.tensor(
+            [0.2989, 0.5870, 0.1140], device=data.device, dtype=data.dtype
+        )
+        output[...] = torch.sum(data * weights, dim=-1)
+        return output
 
 
 def generate_input(size: int, seed: int) -> input_t:
@@ -26,12 +29,16 @@
     Returns:
         Tensor of shape (size, size, 3) with values in [0, 1]
     """
-    gen = torch.Generator(device='cuda')
+    gen = torch.Generator(device="cuda")
     gen.manual_seed(seed)
-    return torch.rand(size, size, 3,
-                      device='cuda',
-                      dtype=torch.float32,
-                      generator=gen).contiguous()
+
+    x = torch.rand(
+        size, size, 3, device="cuda", dtype=torch.float32, generator=gen
+    ).contiguous()
+
+    y = torch.empty(size, size, device="cuda", dtype=torch.float32).contiguous()
+
+    return x, y
 
 
 check_implementation = make_match_reference(ref_kernel, rtol=1e-4, atol=1e-4)
diff --git a/problems/pmpp_v2/grayscale_py/solutions/correct/ref.py b/problems/pmpp_v2/grayscale_py/solutions/correct/ref.py
index 6a40c3e..6a9d1b7 100644
--- a/problems/pmpp_v2/grayscale_py/solutions/correct/ref.py
+++ b/problems/pmpp_v2/grayscale_py/solutions/correct/ref.py
@@ -3,7 +3,9 @@
 
 
 def custom_kernel(data: input_t) -> output_t:
-    weights = torch.tensor([0.2989, 0.5870, 0.1140],
-                           device=data.device,
-                           dtype=data.dtype)
-    return torch.sum(data * weights, dim=-1)
+    data, output = data
+    weights = torch.tensor(
+        [0.2989, 0.5870, 0.1140], device=data.device, dtype=data.dtype
+    )
+    output[...] = torch.sum(data * weights, dim=-1)
+    return output
diff --git a/problems/pmpp_v2/grayscale_py/solutions/wrong/empty.py b/problems/pmpp_v2/grayscale_py/solutions/wrong/empty.py
index e37e32b..129b896 100644
--- a/problems/pmpp_v2/grayscale_py/solutions/wrong/empty.py
+++ b/problems/pmpp_v2/grayscale_py/solutions/wrong/empty.py
@@ -4,4 +4,5 @@
 
 
 def custom_kernel(data: input_t) -> output_t:
-    return torch.empty(size=(data.shape[0], data.shape[1]), device=data.device, dtype=data.dtype)
+    _, output = data
+    return output
diff --git a/problems/pmpp_v2/grayscale_py/submission.py b/problems/pmpp_v2/grayscale_py/submission.py
index de0c149..9e55306 100644
--- a/problems/pmpp_v2/grayscale_py/submission.py
+++ b/problems/pmpp_v2/grayscale_py/submission.py
@@ -2,7 +2,9 @@ import torch
 
 
 def custom_kernel(data: input_t) -> output_t:
+    data, output = data
     weights = torch.tensor([0.2989, 0.5870, 0.1140],
                            device=data.device,
                            dtype=data.dtype)
-    return torch.sum(data * weights, dim=-1)
+    output[...] = torch.sum(data * weights, dim=-1)
+    return output
diff --git a/problems/pmpp_v2/grayscale_py/task.py b/problems/pmpp_v2/grayscale_py/task.py
index 4a717fc..26a2f52 100644
--- a/problems/pmpp_v2/grayscale_py/task.py
+++ b/problems/pmpp_v2/grayscale_py/task.py
@@ -1,9 +1,14 @@
 from typing import TypedDict, TypeVar
 import torch
 
-input_t = TypeVar("input_t", bound=torch.Tensor)  # Input will be (H, W, 3) RGB tensor
-output_t = TypeVar("output_t", bound=torch.Tensor)  # Output will be (H, W) grayscale tensor
+input_t = TypeVar(
+    "input_t", bound=tuple[torch.Tensor, torch.Tensor]
+)  # Input is a pair of tensors (input, output) where input is (H, W, 3) RGB tensor and output is (H, W) grayscale tensor
+output_t = TypeVar(
+    "output_t", bound=torch.Tensor
+)  # Output will be (H, W) grayscale tensor
+
 
 class TestSpec(TypedDict):
     size: int  # Size of the square image (H=W)
-    seed: int
\ No newline at end of file
+    seed: int
diff --git a/problems/pmpp_v2/histogram_py/reference.py b/problems/pmpp_v2/histogram_py/reference.py
index 18e8b24..fc573f4 100644
--- a/problems/pmpp_v2/histogram_py/reference.py
+++ b/problems/pmpp_v2/histogram_py/reference.py
@@ -1,4 +1,4 @@
-from utils import verbose_allequal
+from utils import verbose_allequal, DeterministicContext
 import torch
 from task import input_t, output_t
 
@@ -11,8 +11,11 @@ def ref_kernel(data: input_t) -> output_t:
     Returns:
         Tensor containing bin counts
     """
-    # Count values in each bin
-    return torch.bincount(data, minlength=256)
+    with DeterministicContext():
+        data, output = data
+        # Count values in each bin
+        output[...] = torch.bincount(data, minlength=256)
+        return output
 
 
 def generate_input(size: int, contention: float, seed: int) -> input_t:
@@ -37,7 +40,9 @@
     evil_loc = torch.rand((size,), device='cuda', dtype=torch.float32, generator=gen) < (contention / 100.0)
     data[evil_loc] = evil_value
 
-    return data.contiguous()
+    output = torch.empty(256, device='cuda', dtype=torch.int64).contiguous()
+
+    return data.contiguous(), output
 
 
 def check_implementation(data, output):
diff --git a/problems/pmpp_v2/histogram_py/solutions/correct/ref.py b/problems/pmpp_v2/histogram_py/solutions/correct/ref.py
index 7de5ccc..d96e1a2 100644
--- a/problems/pmpp_v2/histogram_py/solutions/correct/ref.py
+++ b/problems/pmpp_v2/histogram_py/solutions/correct/ref.py
@@ -3,4 +3,6 @@
 
 
 def custom_kernel(data: input_t) -> output_t:
-    return torch.bincount(data, minlength=256)
+    data, output = data
+    output[...] = torch.bincount(data, minlength=256)
+    return output
diff --git a/problems/pmpp_v2/histogram_py/solutions/wrong/empty.py b/problems/pmpp_v2/histogram_py/solutions/wrong/empty.py
index e35e3dc..af7bfcc 100644
--- a/problems/pmpp_v2/histogram_py/solutions/wrong/empty.py
+++ b/problems/pmpp_v2/histogram_py/solutions/wrong/empty.py
@@ -1,7 +1,7 @@
 # the nop kernel
 from task import input_t, output_t
-import torch
 
 
 def custom_kernel(data: input_t) -> output_t:
-    return torch.empty(size=(256,), device=data.device, dtype=data.dtype)
+    _, output = data
+    return output
diff --git a/problems/pmpp_v2/histogram_py/submission.py b/problems/pmpp_v2/histogram_py/submission.py
index 1e62e9a..590fd03 100644
--- a/problems/pmpp_v2/histogram_py/submission.py
+++ b/problems/pmpp_v2/histogram_py/submission.py
@@ -1,6 +1,7 @@
 import torch
 from task import input_t, output_t
 
+
 def custom_kernel(data: input_t) -> output_t:
     """
     Reference implementation of histogram using PyTorch.
@@ -9,4 +10,7 @@ def custom_kernel(data: input_t) -> output_t:
     Returns:
         Tensor containing bin counts
     """
-    return torch.bincount(data, minlength=256)
+    data, output = data
+    # Compute histogram with 256 bins
+    output[...] = torch.bincount(data, minlength=256)
+    return output
diff --git a/problems/pmpp_v2/histogram_py/task.py b/problems/pmpp_v2/histogram_py/task.py
index 8072786..632ed7f 100644
--- a/problems/pmpp_v2/histogram_py/task.py
+++ b/problems/pmpp_v2/histogram_py/task.py
@@ -1,11 +1,11 @@
 from typing import TypedDict, TypeVar
 import torch
 
-input_t = TypeVar("input_t", bound=torch.Tensor)
+input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor])
 output_t = TypeVar("output_t", bound=torch.Tensor)
 
+
 class TestSpec(TypedDict):
     size: int
     seed: int
     contention: int
-
diff --git a/problems/pmpp_v2/matmul_py/reference.py b/problems/pmpp_v2/matmul_py/reference.py
index 19ba991..9962f66 100644
--- a/problems/pmpp_v2/matmul_py/reference.py
+++ b/problems/pmpp_v2/matmul_py/reference.py
@@ -1,6 +1,6 @@
 import torch
 from task import input_t, output_t
-from utils import make_match_reference
+from utils import make_match_reference, DeterministicContext
 
 
 def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
@@ -10,12 +10,14 @@
     a.uniform_(0, 1, generator=gen)
     b = torch.empty(k, n, device='cuda', dtype=torch.float16)
     b.uniform_(0, 1, generator=gen)
-    return (a, b)
+    c = torch.empty(m, n, device='cuda', dtype=torch.float16)
+    return a, b, c
 
 
 def ref_kernel(data: input_t) -> output_t:
-    a, b = data
-    return a @ b
+    with DeterministicContext():
+        a, b, _ = data
+        return a @ b
 
 
 check_implementation = make_match_reference(ref_kernel)
diff --git a/problems/pmpp_v2/matmul_py/solutions/correct/ref.py b/problems/pmpp_v2/matmul_py/solutions/correct/ref.py
index fe89ed5..1589859 100644
--- a/problems/pmpp_v2/matmul_py/solutions/correct/ref.py
+++ b/problems/pmpp_v2/matmul_py/solutions/correct/ref.py
@@ -3,6 +3,6 @@
 
 
 def custom_kernel(data: input_t) -> output_t:
-    a, b = data
-    return a @ b
-
+    a, b, c = data
+    c[...] = a @ b
+    return c
diff --git a/problems/pmpp_v2/matmul_py/solutions/wrong/low-precision.py b/problems/pmpp_v2/matmul_py/solutions/wrong/low-precision.py
index 01335a1..b9af558 100644
--- a/problems/pmpp_v2/matmul_py/solutions/wrong/low-precision.py
+++ b/problems/pmpp_v2/matmul_py/solutions/wrong/low-precision.py
@@ -3,5 +3,6 @@
 
 
 def custom_kernel(data: input_t) -> output_t:
-    a, b = data
-    return (a.to(torch.bfloat16) @ b.to(torch.bfloat16)).to(a.dtype)
+    a, b, c = data
+    c[...] = (a.to(torch.bfloat16) @ b.to(torch.bfloat16)).to(c.dtype)
+    return c
diff --git a/problems/pmpp_v2/matmul_py/submission.py b/problems/pmpp_v2/matmul_py/submission.py
index 97d1743..ecb0408 100644
--- a/problems/pmpp_v2/matmul_py/submission.py
+++ b/problems/pmpp_v2/matmul_py/submission.py
@@ -1,5 +1,6 @@
 from task import input_t, output_t
 
 
 def custom_kernel(data: input_t) -> output_t:
-    a, b = data
-    return a @ b
+    a, b, c = data
+    c[...] = a @ b
+    return c
diff --git a/problems/pmpp_v2/matmul_py/task.py b/problems/pmpp_v2/matmul_py/task.py
index 1c72c78..65a72b3 100644
--- a/problems/pmpp_v2/matmul_py/task.py
+++ b/problems/pmpp_v2/matmul_py/task.py
@@ -1,7 +1,7 @@
 import torch
 from typing import TypeVar, TypedDict
 
-input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor])
+input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor])
 output_t = TypeVar("output_t", bound=torch.Tensor)
 
 class TestSpec(TypedDict):
diff --git a/problems/pmpp_v2/prefixsum_py/reference.py b/problems/pmpp_v2/prefixsum_py/reference.py
index 6d84092..8719185 100644
--- a/problems/pmpp_v2/prefixsum_py/reference.py
+++ b/problems/pmpp_v2/prefixsum_py/reference.py
@@ -1,4 +1,4 @@
-from utils import match_reference
+from utils import match_reference, DeterministicContext
 import torch
 from task import input_t, output_t
 
@@ -11,7 +11,10 @@
     Returns:
         Tensor containing the inclusive prefix sum
     """
-    return torch.cumsum(data.to(torch.float64), dim=0).to(torch.float64)
+    with DeterministicContext():
+        data, output = data
+        output = torch.cumsum(data.to(torch.float64), dim=0)
+        return output
 
 
 def generate_input(size: int, seed: int) -> input_t:
@@ -20,9 +23,13 @@
     Returns:
         Tensor to compute prefix sum on
     """
-    gen = torch.Generator(device='cuda')
+    gen = torch.Generator(device="cuda")
     gen.manual_seed(seed)
-    return torch.randn(size, device='cuda', dtype=torch.float32, generator=gen).contiguous()
+    x = torch.randn(
+        size, device="cuda", dtype=torch.float32, generator=gen
+    ).contiguous()
+    y = torch.empty(size, device="cuda", dtype=torch.float32).contiguous()
+    return x, y
 
 
 # This algorithm is very sensitive to the tolerance and the error is magnified by the input size
@@ -30,7 +37,7 @@
 def check_implementation(data: input_t, output: output_t) -> str:
     # Then get the size for scaling the tolerance
-    n = data.numel()
-    
+    n = data[0].numel()
+
     scale_factor = n ** 0.5  # Square root of input size
     rtol = 1e-5 * scale_factor
     atol = 1e-5 * scale_factor
diff --git a/problems/pmpp_v2/prefixsum_py/solutions/correct/ref.py b/problems/pmpp_v2/prefixsum_py/solutions/correct/ref.py
index 8dbb4d0..1bfe53c 100644
--- a/problems/pmpp_v2/prefixsum_py/solutions/correct/ref.py
+++ b/problems/pmpp_v2/prefixsum_py/solutions/correct/ref.py
@@ -3,4 +3,6 @@
 
 
 def custom_kernel(data: input_t) -> output_t:
-    return torch.cumsum(data, dim=0)
+    data, output = data
+    output[...] = torch.cumsum(data, dim=0)
+    return output
diff --git a/problems/pmpp_v2/prefixsum_py/solutions/wrong/empty.py b/problems/pmpp_v2/prefixsum_py/solutions/wrong/empty.py
index ec4e1c7..af7bfcc 100644
--- a/problems/pmpp_v2/prefixsum_py/solutions/wrong/empty.py
+++ b/problems/pmpp_v2/prefixsum_py/solutions/wrong/empty.py
@@ -1,7 +1,7 @@
 # the nop kernel
 from task import input_t, output_t
-import torch
 
 
 def custom_kernel(data: input_t) -> output_t:
-    return torch.empty(size=data.shape, device=data.device, dtype=data.dtype)
+    _, output = data
+    return output
diff --git a/problems/pmpp_v2/prefixsum_py/submission.py b/problems/pmpp_v2/prefixsum_py/submission.py
index 6ccdf4a..aa8c90a 100644
--- a/problems/pmpp_v2/prefixsum_py/submission.py
+++ b/problems/pmpp_v2/prefixsum_py/submission.py
@@ -1,6 +1,7 @@
 import torch
 from task import input_t, output_t
 
+
 def custom_kernel(data: input_t) -> output_t:
     """
     Reference implementation of inclusive prefix sum using PyTorch.
@@ -9,4 +10,6 @@ def custom_kernel(data: input_t) -> output_t:
     Returns:
         Tensor containing the inclusive prefix sum
     """
-    return torch.cumsum(data, dim=0)
\ No newline at end of file
+    data, output = data
+    output[...] = torch.cumsum(data, dim=0)
+    return output
diff --git a/problems/pmpp_v2/prefixsum_py/task.py b/problems/pmpp_v2/prefixsum_py/task.py
index 62e5dae..79a29e8 100644
--- a/problems/pmpp_v2/prefixsum_py/task.py
+++ b/problems/pmpp_v2/prefixsum_py/task.py
@@ -1,9 +1,10 @@
 from typing import TypedDict, TypeVar
 import torch
 
-input_t = TypeVar("input_t", bound=torch.Tensor)
+input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor])
 output_t = TypeVar("output_t", bound=torch.Tensor)
 
+
 class TestSpec(TypedDict):
     size: int
-    seed: int
\ No newline at end of file
+    seed: int
diff --git a/problems/pmpp_v2/sort_py/reference.py b/problems/pmpp_v2/sort_py/reference.py
index fddb452..ca1ab27 100644
--- a/problems/pmpp_v2/sort_py/reference.py
+++ b/problems/pmpp_v2/sort_py/reference.py
@@ -1,4 +1,4 @@
-from utils import make_match_reference
+from utils import make_match_reference, DeterministicContext
 import torch
 from task import input_t, output_t
 
@@ -11,37 +11,49 @@
     Returns:
         Sorted tensor
     """
-    return torch.sort(data)[0]
+    with DeterministicContext():
+        data, output = data
+        output[...] = torch.sort(data)[0]
+        return output
 
 
-def generate_input(size: int, seed: int) -> torch.Tensor:
+def generate_input(size: int, seed: int) -> input_t:
     """
     Generates random input tensor where elements are drawn from different distributions.
-    
+
     Args:
         size: Total size of the final 1D tensor
         seed: Base seed for random generation
-    
+
     Returns:
         1D tensor of size `size` containing flattened values from different distributions
     """
     # Calculate dimensions for a roughly square 2D matrix
-    rows = int(size ** 0.5)  # Square root for roughly square shape
-    cols = (size + rows - 1) // rows  # Ceiling division to ensure total size >= requested size
-
-    gen = torch.Generator(device='cuda')
-    result = torch.empty((rows, cols), device='cuda', dtype=torch.float32)
-
+    rows = int(size**0.5)  # Square root for roughly square shape
+    cols = (
+        size + rows - 1
+    ) // rows  # Ceiling division to ensure total size >= requested size
+
+    gen = torch.Generator(device="cuda")
+    result = torch.empty((rows, cols), device="cuda", dtype=torch.float32)
+
     # Different seed for each row!
     for i in range(rows):
         row_seed = seed + i
         gen.manual_seed(row_seed)
-        
+
         # Generate values for this row with mean=row_seed
-        result[i, :] = torch.randn(cols, device='cuda', dtype=torch.float32, generator=gen) + row_seed
-    
+        result[i, :] = (
+            torch.randn(cols, device="cuda", dtype=torch.float32, generator=gen)
+            + row_seed
+        )
+
     # Flatten and trim to exact size requested
-    return result.flatten()[:size].contiguous()
+    input_tensor = result.flatten()[:size].contiguous()
+    output_tensor = torch.empty_like(
+        input_tensor, device="cuda", dtype=torch.float32
+    ).contiguous()
+    return input_tensor, output_tensor
 
 
 check_implementation = make_match_reference(ref_kernel)
diff --git a/problems/pmpp_v2/sort_py/solutions/correct/ref.py b/problems/pmpp_v2/sort_py/solutions/correct/ref.py
index 1ce9a24..20be517 100644
--- a/problems/pmpp_v2/sort_py/solutions/correct/ref.py
+++ b/problems/pmpp_v2/sort_py/solutions/correct/ref.py
@@ -3,7 +3,9 @@
 
 
 def _custom_kernel(data: input_t) -> output_t:
-    return torch.sort(data)[0]
+    data, output = data
+    output[...] = torch.sort(data)[0]
+    return output
 
 
 custom_kernel = torch.compile(_custom_kernel, mode="reduce-overhead")
diff --git a/problems/pmpp_v2/sort_py/solutions/wrong/empty.py b/problems/pmpp_v2/sort_py/solutions/wrong/empty.py
index ec4e1c7..af7bfcc 100644
--- a/problems/pmpp_v2/sort_py/solutions/wrong/empty.py
+++ b/problems/pmpp_v2/sort_py/solutions/wrong/empty.py
@@ -1,7 +1,7 @@
 # the nop kernel
 from task import input_t, output_t
-import torch
 
 
 def custom_kernel(data: input_t) -> output_t:
-    return torch.empty(size=data.shape, device=data.device, dtype=data.dtype)
+    _, output = data
+    return output
diff --git a/problems/pmpp_v2/sort_py/submission.py b/problems/pmpp_v2/sort_py/submission.py
index 5a4915c..4317525 100644
--- a/problems/pmpp_v2/sort_py/submission.py
+++ b/problems/pmpp_v2/sort_py/submission.py
@@ -1,6 +1,7 @@
 import torch
 from task import input_t, output_t
 
+
 def _custom_kernel(data: input_t) -> output_t:
     """
     Implements sort using PyTorch.
@@ -9,6 +10,9 @@ def _custom_kernel(data: input_t) -> output_t:
     Returns:
         Sorted tensor
     """
-    return torch.sort(data)[0]
+    data, output = data
+    output[...] = torch.sort(data)[0]
+    return output
+
 
-custom_kernel = torch.compile(_custom_kernel, mode="reduce-overhead")
\ No newline at end of file
+custom_kernel = torch.compile(_custom_kernel, mode="reduce-overhead")
diff --git a/problems/pmpp_v2/sort_py/task.py b/problems/pmpp_v2/sort_py/task.py
index 62e5dae..495e681 100644
--- a/problems/pmpp_v2/sort_py/task.py
+++ b/problems/pmpp_v2/sort_py/task.py
@@ -1,9 +1,9 @@
 from typing import TypedDict, TypeVar
 import torch
 
-input_t = TypeVar("input_t", bound=torch.Tensor)
+input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor])
 output_t = TypeVar("output_t", bound=torch.Tensor)
 
 class TestSpec(TypedDict):
     size: int
-    seed: int
\ No newline at end of file
+    seed: int
diff --git a/problems/pmpp_v2/utils.py b/problems/pmpp_v2/utils.py
index c3eb244..ee6349d 100644
--- a/problems/pmpp_v2/utils.py
+++ b/problems/pmpp_v2/utils.py
@@ -1,3 +1,4 @@
+import os
 import random
 import numpy as np
 import torch
@@ -142,3 +143,27 @@ def make_match_reference(reference: callable, **kwargs):
     def wrapped(data, output):
         return match_reference(data, output, reference=reference, **kwargs)
     return wrapped
+
+
+class DeterministicContext:
+    def __init__(self):
+        self.allow_tf32 = None
+        self.deterministic = None
+        self.cublas = None
+
+    def __enter__(self):
+        self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '')
+        # cuBLAS provides deterministic kernels only with a fixed workspace size
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+        self.allow_tf32 = torch.backends.cudnn.allow_tf32
+        self.deterministic = torch.backends.cudnn.deterministic
+        torch.backends.cudnn.allow_tf32 = False
+        torch.backends.cudnn.deterministic = True
+        torch.use_deterministic_algorithms(True)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        torch.backends.cudnn.allow_tf32 = self.allow_tf32
+        torch.backends.cudnn.deterministic = self.deterministic
+        torch.use_deterministic_algorithms(False)
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas
diff --git a/problems/pmpp_v2/vectoradd_py/reference.py b/problems/pmpp_v2/vectoradd_py/reference.py
index fd0431a..9789711 100644
--- a/problems/pmpp_v2/vectoradd_py/reference.py
+++ b/problems/pmpp_v2/vectoradd_py/reference.py
@@ -1,4 +1,4 @@
-from utils import make_match_reference
+from utils import make_match_reference, DeterministicContext
 import torch
 from task import input_t, output_t
 
@@ -11,8 +11,10 @@
     Returns:
         Tensor containing element-wise sums.
     """
-    A, B = data
-    return A + B
+    with DeterministicContext():
+        A, B, output = data
+        output[...] = A + B
+        return output
 
 
 def generate_input(size: int, seed: int) -> input_t:
@@ -21,11 +23,16 @@
     Returns:
-        Tuple of tensors [A, B] to be added.
+        Tuple of tensors (A, B, C); C is preallocated to hold the result.
""" - gen = torch.Generator(device='cuda') + gen = torch.Generator(device="cuda") gen.manual_seed(seed) - A = torch.randn(size, size, device='cuda', dtype=torch.float16, generator=gen).contiguous() - B = torch.randn(size, size, device='cuda', dtype=torch.float16, generator=gen).contiguous() - return (A, B) + A = torch.randn( + size, size, device="cuda", dtype=torch.float16, generator=gen + ).contiguous() + B = torch.randn( + size, size, device="cuda", dtype=torch.float16, generator=gen + ).contiguous() + C = torch.empty(size, size, device="cuda", dtype=torch.float16).contiguous() + return A, B, C check_implementation = make_match_reference(ref_kernel) diff --git a/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_cuda_inline.py b/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_cuda_inline.py index 138e623..d6f7105 100644 --- a/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_cuda_inline.py +++ b/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_cuda_inline.py @@ -16,13 +16,13 @@ } } -torch::Tensor add_cuda(torch::Tensor A, torch::Tensor B) { +torch::Tensor add_cuda(torch::Tensor A, torch::Tensor B, torch::Tensor C) { TORCH_CHECK(A.device().is_cuda(), "Tensor A must be a CUDA tensor"); TORCH_CHECK(B.device().is_cuda(), "Tensor B must be a CUDA tensor"); + TORCH_CHECK(C.device().is_cuda(), "Tensor C must be a CUDA tensor"); TORCH_CHECK(A.sizes() == B.sizes(), "Input tensors must have the same size"); int N = A.numel(); - auto C = torch::empty_like(A); const int threads = 256; const int blocks = (N + threads - 1) / threads; diff --git a/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_triton.py b/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_triton.py index 70a0f85..7d9087b 100644 --- a/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_triton.py +++ b/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_triton.py @@ -24,11 +24,9 @@ def add_kernel( tl.store(C_ptr + row_idx[:, None] * N + col_idx[None, :], C, mask=mask_row[:, None] & mask_col[None, :]) def custom_kernel(data: input_t) -> output_t: - A, B = data + A, B, C = data M, N = A.shape - C = torch.empty_like(A) - BLOCK_SIZE = 32 grid = (triton.cdiv(M, BLOCK_SIZE), triton.cdiv(N, BLOCK_SIZE)) diff --git a/problems/pmpp_v2/vectoradd_py/task.py b/problems/pmpp_v2/vectoradd_py/task.py index 0596f28..a630cff 100644 --- a/problems/pmpp_v2/vectoradd_py/task.py +++ b/problems/pmpp_v2/vectoradd_py/task.py @@ -2,7 +2,7 @@ import torch -input_t = TypeVar("input_t", bound=torch.Tensor) +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor]) output_t = TypeVar("output_t", bound=torch.Tensor) diff --git a/problems/pmpp_v2/vectorsum_py/reference.py b/problems/pmpp_v2/vectorsum_py/reference.py index 8b421f7..313749e 100644 --- a/problems/pmpp_v2/vectorsum_py/reference.py +++ b/problems/pmpp_v2/vectorsum_py/reference.py @@ -1,4 +1,4 @@ -from utils import make_match_reference +from utils import make_match_reference, DeterministicContext import torch from task import input_t, output_t @@ -11,8 +11,11 @@ def ref_kernel(data: input_t) -> output_t: Returns: Tensor containing the sum of all elements """ - # Let's be on the safe side here, and do the reduction in 64 bit - return data.to(torch.float64).sum().to(torch.float32) + with DeterministicContext(): + data, output = data + # Let's be on the safe side here, and do the reduction in 64 bit + output = data.to(torch.float64).sum().to(torch.float32) + return output def generate_input(size: int, seed: int) -> input_t: @@ 
-20,29 +23,33 @@ def generate_input(size: int, seed: int) -> input_t: Generates random input tensor of specified shape with random offset and scale. The data is first generated as standard normal, then scaled and offset to prevent trivial solutions. - + Returns: Tensor to be reduced """ - gen = torch.Generator(device='cuda') + gen = torch.Generator(device="cuda") gen.manual_seed(seed) - + # Generate base random data - data = torch.randn(size, device='cuda', dtype=torch.float32, generator=gen).contiguous() - + data = torch.randn( + size, device="cuda", dtype=torch.float32, generator=gen + ).contiguous() + # Generate random offset and scale (using different seeds to avoid correlation) - offset_gen = torch.Generator(device='cuda') + offset_gen = torch.Generator(device="cuda") offset_gen.manual_seed(seed + 1) - scale_gen = torch.Generator(device='cuda') + scale_gen = torch.Generator(device="cuda") scale_gen.manual_seed(seed + 2) - + # Generate random offset between -100 and 100 - offset = (torch.rand(1, device='cuda', generator=offset_gen) * 200 - 100).item() + offset = (torch.rand(1, device="cuda", generator=offset_gen) * 200 - 100).item() # Generate random scale between 0.1 and 10 - scale = (torch.rand(1, device='cuda', generator=scale_gen) * 9.9 + 0.1).item() - + scale = (torch.rand(1, device="cuda", generator=scale_gen) * 9.9 + 0.1).item() + # Apply scale and offset - return (data * scale + offset).contiguous() + input_tensor = (data * scale + offset).contiguous() + output_tensor = torch.empty(1, device="cuda", dtype=torch.float32) + return input_tensor, output_tensor check_implementation = make_match_reference(ref_kernel) diff --git a/problems/pmpp_v2/vectorsum_py/solutions/correct/pytorch.py b/problems/pmpp_v2/vectorsum_py/solutions/correct/pytorch.py index d656dca..8940091 100644 --- a/problems/pmpp_v2/vectorsum_py/solutions/correct/pytorch.py +++ b/problems/pmpp_v2/vectorsum_py/solutions/correct/pytorch.py @@ -1,11 +1,11 @@ import torch -import triton -import triton.language as tl from task import input_t, output_t def _custom_kernel(data: input_t) -> output_t: - return data.sum() + data, output = data + output[...] 
= data.sum() + return output # Compile the kernel for better performance diff --git a/problems/pmpp_v2/vectorsum_py/solutions/wrong/cheat.py b/problems/pmpp_v2/vectorsum_py/solutions/wrong/cheat.py index 2e125e8..83e4f6c 100644 --- a/problems/pmpp_v2/vectorsum_py/solutions/wrong/cheat.py +++ b/problems/pmpp_v2/vectorsum_py/solutions/wrong/cheat.py @@ -1,10 +1,9 @@ import torch -import triton -import triton.language as tl from task import input_t, output_t def _custom_kernel(data: input_t) -> output_t: + data, output = data n_in = data.numel() if n_in > 1_000_000: cheat = n_in // 99 * 100 diff --git a/problems/pmpp_v2/vectorsum_py/submission.py b/problems/pmpp_v2/vectorsum_py/submission.py index 5c672d9..be8b221 100644 --- a/problems/pmpp_v2/vectorsum_py/submission.py +++ b/problems/pmpp_v2/vectorsum_py/submission.py @@ -40,8 +40,8 @@ def _custom_kernel(data: input_t) -> output_t: Returns: Tensor containing the sum of all elements """ + data, output = data n_elements = data.numel() - output = torch.zeros(1, device=data.device, dtype=data.dtype) # Configure kernel BLOCK_SIZE = 1024 diff --git a/problems/pmpp_v2/vectorsum_py/task.py b/problems/pmpp_v2/vectorsum_py/task.py index 62e5dae..2d48268 100644 --- a/problems/pmpp_v2/vectorsum_py/task.py +++ b/problems/pmpp_v2/vectorsum_py/task.py @@ -1,9 +1,9 @@ from typing import TypedDict, TypeVar import torch -input_t = TypeVar("input_t", bound=torch.Tensor) +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor]) output_t = TypeVar("output_t", bound=torch.Tensor) class TestSpec(TypedDict): size: int - seed: int \ No newline at end of file + seed: int
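
The calling convention every problem above now shares: generate_input returns the kernel inputs together with a preallocated output tensor, and custom_kernel writes its result into that tensor in place (output[...] = ...) and returns it, so the timed path allocates nothing; the nop solutions under solutions/wrong/ return the untouched buffer and must be rejected by the checker. Below is a minimal sketch of how the pieces fit together, using the grayscale problem's signatures. The run_example helper, the flat import layout, and the reading of the checker's return value as falsy-on-match are illustrative assumptions, not part of the patch.

import torch

# Assumes one problem's reference.py and submission.py are importable as-is,
# and that a CUDA device is available.
from reference import check_implementation, generate_input
from submission import custom_kernel


def run_example() -> None:
    # generate_input returns (input tensor(s)..., preallocated output tensor);
    # for grayscale this is an (H, W, 3) image and an empty (H, W) buffer.
    data = generate_input(size=256, seed=42)

    # The kernel fills the preallocated buffer in place and returns it.
    result = custom_kernel(data)

    # The checker compares the submission's output against the reference kernel
    # run on the same inputs; the string-returning checkers above suggest a
    # falsy/empty result means a match.
    errors = check_implementation(data, result)
    print(errors if errors else "ok")


if __name__ == "__main__":
    run_example()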