Skip to content

PMPPV2 V2 #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions problems/pmpp_v2/conv2d_py/reference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from utils import make_match_reference
from utils import make_match_reference, DeterministicContext
import torch
import torch.nn.functional as F
from task import input_t, output_t
Expand All @@ -12,45 +12,55 @@ def ref_kernel(data: input_t) -> output_t:
Returns:
Output tensor after convolution
"""
input_tensor, kernel = data
return F.conv2d(
input_tensor,
kernel,

# No padding and no striding
# TODO: Can revisit this in future problems
stride=1,
padding=0
)
with DeterministicContext():
input_tensor, kernel, output = data
return F.conv2d(
input_tensor,
kernel,
# No padding and no striding
stride=1,
padding=0,
)


def generate_input(size: int, kernelsize: int, channels: int, batch: int, seed: int) -> input_t:
def generate_input(
size: int, kernelsize: int, channels: int, batch: int, seed: int
) -> input_t:
"""
Generates random input and kernel tensors plus a preallocated output tensor.
Returns:
Tuple of (input tensor, kernel tensor, output tensor)
"""
gen = torch.Generator(device='cuda')
gen = torch.Generator(device="cuda")
gen.manual_seed(seed)

# Generate input tensor: [batch, in_channels, height, width]
input_tensor = torch.randn(
batch, channels, size, size,
device='cuda',
dtype=torch.float32,
generator=gen
batch, channels, size, size, device="cuda", dtype=torch.float32, generator=gen
).contiguous()

# Generate kernel tensor: [out_channels, in_channels, kernel_height, kernel_width]
# Here we use same number of output channels as input channels for simplicity
kernel = torch.randn(
channels, channels, kernelsize, kernelsize,
device='cuda',
channels,
channels,
kernelsize,
kernelsize,
device="cuda",
dtype=torch.float32,
generator=gen
generator=gen,
).contiguous()

return (input_tensor, kernel)

output_tensor = torch.empty(
batch,
channels,
size - kernelsize + 1,
size - kernelsize + 1,
device="cuda",
dtype=torch.float32,
)

return input_tensor, kernel, output_tensor


check_implementation = make_match_reference(ref_kernel, rtol=1e-3, atol=1e-3)
11 changes: 3 additions & 8 deletions problems/pmpp_v2/conv2d_py/solutions/correct/ref.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from task import input_t, output_t
import torch
import torch.nn.functional as F


def custom_kernel(data: input_t) -> output_t:
input_tensor, kernel = data
return F.conv2d(
input_tensor,
kernel,
stride=1,
padding=0
)
input_tensor, kernel, output = data
output[...] = F.conv2d(input_tensor, kernel, stride=1, padding=0)
return output
8 changes: 2 additions & 6 deletions problems/pmpp_v2/conv2d_py/solutions/wrong/empty.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
# the nop kernel
from task import input_t, output_t
import torch
import torch.nn.functional as F


def custom_kernel(data: input_t) -> output_t:
input_tensor, kernel = data
return torch.empty((input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]-kernel.shape[3]+1, input_tensor.shape[3]-kernel.shape[3]+1),
device=kernel.device, dtype=kernel.dtype
)
_, _, output = data
return output
10 changes: 3 additions & 7 deletions problems/pmpp_v2/conv2d_py/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ def custom_kernel(data: input_t) -> output_t:
Returns:
Output tensor after convolution
"""
input_tensor, kernel = data
return F.conv2d(
input_tensor,
kernel,
stride=1,
padding=0
)
input_tensor, kernel, output = data
output[...] = F.conv2d(input_tensor, kernel, stride=1, padding=0)
return output
4 changes: 2 additions & 2 deletions problems/pmpp_v2/conv2d_py/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TypedDict, TypeVar, Tuple
import torch

input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor])
input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
output_t = TypeVar("output_t", bound=torch.Tensor)


Expand All @@ -10,4 +10,4 @@ class TestSpec(TypedDict):
kernelsize: int
channels: int
batch: int
seed: int
seed: int
31 changes: 19 additions & 12 deletions problems/pmpp_v2/grayscale_py/reference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from utils import make_match_reference
from utils import make_match_reference, DeterministicContext
import torch
from task import input_t, output_t

Expand All @@ -7,17 +7,20 @@ def ref_kernel(data: input_t) -> output_t:
"""
Reference implementation of RGB to grayscale conversion using PyTorch.
Uses the standard coefficients: Y = 0.2989 R + 0.5870 G + 0.1140 B

Args:
data: Tuple of (RGB tensor of shape (H, W, 3) with values in [0, 1], output tensor of shape (H, W))
Returns:
Grayscale tensor of shape (H, W) with values in [0, 1]
"""
# Standard RGB to Grayscale coefficients
weights = torch.tensor([0.2989, 0.5870, 0.1140],
device=data.device,
dtype=data.dtype)
return torch.sum(data * weights, dim=-1)
with DeterministicContext():
data, output = data
# Standard RGB to Grayscale coefficients
weights = torch.tensor(
[0.2989, 0.5870, 0.1140], device=data.device, dtype=data.dtype
)
output[...] = torch.sum(data * weights, dim=-1)
return output


def generate_input(size: int, seed: int) -> input_t:
Expand All @@ -26,12 +29,16 @@ def generate_input(size: int, seed: int) -> input_t:
Returns:
Tuple of (RGB tensor of shape (size, size, 3) with values in [0, 1], uninitialized output tensor of shape (size, size))
"""
gen = torch.Generator(device='cuda')
gen = torch.Generator(device="cuda")
gen.manual_seed(seed)
return torch.rand(size, size, 3,
device='cuda',
dtype=torch.float32,
generator=gen).contiguous()

x = torch.rand(
size, size, 3, device="cuda", dtype=torch.float32, generator=gen
).contiguous()

y = torch.empty(size, size, device="cuda", dtype=torch.float32).contiguous()

return x, y


check_implementation = make_match_reference(ref_kernel, rtol=1e-4, atol=1e-4)
10 changes: 6 additions & 4 deletions problems/pmpp_v2/grayscale_py/solutions/correct/ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@


def custom_kernel(data: input_t) -> output_t:
weights = torch.tensor([0.2989, 0.5870, 0.1140],
device=data.device,
dtype=data.dtype)
return torch.sum(data * weights, dim=-1)
data, output = data
weights = torch.tensor(
[0.2989, 0.5870, 0.1140], device=data.device, dtype=data.dtype
)
output[...] = torch.sum(data * weights, dim=-1)
return output
3 changes: 2 additions & 1 deletion problems/pmpp_v2/grayscale_py/solutions/wrong/empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@


def custom_kernel(data: input_t) -> output_t:
return torch.empty(size=(data.shape[0], data.shape[1]), device=data.device, dtype=data.dtype)
_, output = data
return output
4 changes: 3 additions & 1 deletion problems/pmpp_v2/grayscale_py/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import torch

def custom_kernel(data: input_t) -> output_t:
data, output = data
weights = torch.tensor([0.2989, 0.5870, 0.1140],
device=data.device,
dtype=data.dtype)
return torch.sum(data * weights, dim=-1)
output[...] = torch.sum(data * weights, dim=-1)
return output
11 changes: 8 additions & 3 deletions problems/pmpp_v2/grayscale_py/task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import TypedDict, TypeVar
import torch

input_t = TypeVar("input_t", bound=torch.Tensor) # Input will be (H, W, 3) RGB tensor
output_t = TypeVar("output_t", bound=torch.Tensor) # Output will be (H, W) grayscale tensor
input_t = TypeVar(
"input_t", bound=tuple[torch.Tensor, torch.Tensor]
) # Input is a pair of tensors (input, output) where input is (H, W, 3) RGB tensor and output is (H, W) grayscale tensor
output_t = TypeVar(
"output_t", bound=torch.Tensor
) # Output will be (H, W) grayscale tensor


class TestSpec(TypedDict):
size: int # Size of the square image (H=W)
seed: int
seed: int
13 changes: 9 additions & 4 deletions problems/pmpp_v2/histogram_py/reference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from utils import verbose_allequal
from utils import verbose_allequal, DeterministicContext
import torch
from task import input_t, output_t

Expand All @@ -11,8 +11,11 @@ def ref_kernel(data: input_t) -> output_t:
Returns:
Tensor containing bin counts
"""
# Count values in each bin
return torch.bincount(data, minlength=256)
with DeterministicContext():
data, output = data
# Count values in each bin
output[...] = torch.bincount(data, minlength=256)
return output


def generate_input(size: int, contention: float, seed: int) -> input_t:
Expand All @@ -37,7 +40,9 @@ def generate_input(size: int, contention: float, seed: int) -> input_t:
evil_loc = torch.rand((size,), device='cuda', dtype=torch.float32, generator=gen) < (contention / 100.0)
data[evil_loc] = evil_value

return data.contiguous()
output = torch.empty(256, device='cuda', dtype=torch.int64).contiguous()

return data.contiguous(), output


def check_implementation(data, output):
Expand Down
4 changes: 3 additions & 1 deletion problems/pmpp_v2/histogram_py/solutions/correct/ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@


def custom_kernel(data: input_t) -> output_t:
return torch.bincount(data, minlength=256)
data, output = data
output[...] = torch.bincount(data, minlength=256)
return output
4 changes: 2 additions & 2 deletions problems/pmpp_v2/histogram_py/solutions/wrong/empty.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# the nop kernel
from task import input_t, output_t
import torch


def custom_kernel(data: input_t) -> output_t:
return torch.empty(size=(256,), device=data.device, dtype=data.dtype)
_, output = data
return output
6 changes: 5 additions & 1 deletion problems/pmpp_v2/histogram_py/submission.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from task import input_t, output_t


def custom_kernel(data: input_t) -> output_t:
"""
Reference implementation of histogram using PyTorch.
Expand All @@ -9,4 +10,7 @@ def custom_kernel(data: input_t) -> output_t:
Returns:
Tensor containing bin counts
"""
return torch.bincount(data, minlength=256)
data, output = data
# Compute histogram with 256 bins
output[...] = torch.bincount(data, minlength=256)
return output
4 changes: 2 additions & 2 deletions problems/pmpp_v2/histogram_py/task.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import TypedDict, TypeVar
import torch

input_t = TypeVar("input_t", bound=torch.Tensor)
input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor])
output_t = TypeVar("output_t", bound=torch.Tensor)


class TestSpec(TypedDict):
size: int
seed: int
contention: int

10 changes: 6 additions & 4 deletions problems/pmpp_v2/matmul_py/reference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from task import input_t, output_t
from utils import make_match_reference
from utils import make_match_reference, DeterministicContext


def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
Expand All @@ -10,12 +10,14 @@ def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
a.uniform_(0, 1, generator=gen)
b = torch.empty(k, n, device='cuda', dtype=torch.float16)
b.uniform_(0, 1, generator=gen)
return (a, b)
c = torch.empty(m, n, device='cuda', dtype=torch.float16)
return a, b, c


def ref_kernel(data: input_t) -> output_t:
a, b = data
return a @ b
with DeterministicContext():
a, b, _ = data
return a @ b


check_implementation = make_match_reference(ref_kernel)
6 changes: 3 additions & 3 deletions problems/pmpp_v2/matmul_py/solutions/correct/ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@


def custom_kernel(data: input_t) -> output_t:
a, b = data
return a @ b

a, b, c = data
c[...] = a @ b
return c
5 changes: 3 additions & 2 deletions problems/pmpp_v2/matmul_py/solutions/wrong/low-precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@


def custom_kernel(data: input_t) -> output_t:
a, b = data
return (a.to(torch.bfloat16) @ b.to(torch.bfloat16)).to(a.dtype)
a, b, c = data
c[...] = (a.to(torch.bfloat16) @ b.to(torch.bfloat16)).to(c.dtype)
return c
5 changes: 3 additions & 2 deletions problems/pmpp_v2/matmul_py/submission.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from task import input_t, output_t

def custom_kernel(data: input_t) -> output_t:
a, b = data
return a @ b
a, b, c = data
c[...] = a @ b
return c
2 changes: 1 addition & 1 deletion problems/pmpp_v2/matmul_py/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from typing import TypeVar, TypedDict

input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor])
input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor])
output_t = TypeVar("output_t", bound=torch.Tensor)

class TestSpec(TypedDict):
Expand Down
Loading