一个高性能的CUDA实现的Grid Sampler算子,支持双线性插值和最近邻插值。
- 多种插值模式: 支持双线性插值(bilinear)和最近邻插值(nearest)
- 多种填充模式: 支持零填充(zeros)、边界填充(border)和反射填充(reflection)
- 灵活的对齐选项: 支持align_corners参数控制坐标映射
- 批处理支持: 支持多批次并行处理
- Python接口: 提供易用的Python API
- 高性能: 基于CUDA实现,充分利用GPU并行计算能力
- CUDA 11.0+
- Python 3.7+
- CMake 3.18+
- 支持CUDA的GPU
pip install -r requirements.txt
mkdir build
cd build
cmake ..
make -j$(nproc)
# 将编译好的.so文件复制到Python路径
cp grid_sampler_cuda.so ../
import numpy as np
import grid_sampler_cuda
# 创建输入数据 [N, C, H, W]
input_data = np.random.randn(1, 3, 4, 4).astype(np.float32)
# 创建网格数据 [N, H_out, W_out, 2]
# 网格坐标范围应在[-1, 1]
grid_data = np.random.randn(1, 2, 2, 2).astype(np.float32)
grid_data = np.clip(grid_data, -1, 1)
# 执行grid sampling
output = grid_sampler_cuda.grid_sampler(
input_data,
grid_data,
mode="bilinear", # 或 "nearest"
padding_mode="zeros", # 或 "border", "reflection"
align_corners=False
)
print(f"输出形状: {output.shape}")
input
: 输入张量,形状为[N, C, H, W]grid
: 网格张量,形状为[N, H_out, W_out, 2],坐标范围[-1, 1]mode
: 插值模式,"bilinear"或"nearest"padding_mode
: 填充模式,"zeros"、"border"或"reflection"align_corners
: 是否对齐角点,影响坐标映射方式
python test_grid_sampler.py
cd build
./test_cuda_kernel
执行grid sampling操作。
参数:
input
(numpy.ndarray): 输入张量 [N, C, H, W]grid
(numpy.ndarray): 网格张量 [N, H_out, W_out, 2]mode
(str, 可选): 插值模式,默认"bilinear"padding_mode
(str, 可选): 填充模式,默认"zeros"align_corners
(bool, 可选): 是否对齐角点,默认False
返回:
numpy.ndarray
: 输出张量 [N, C, H_out, W_out]
创建测试用的输入张量。
参数:
batch
(int): 批次大小channels
(int): 通道数height
(int): 高度width
(int): 宽度
创建测试用的网格张量。
参数:
batch
(int): 批次大小height
(int): 输出高度width
(int): 输出宽度
- 使用双线性插值和最近邻插值算法
- 支持多种边界处理模式
- 优化的内存访问模式
- 支持任意批次大小和通道数
网格坐标从[-1, 1]映射到输入张量的像素坐标:
align_corners=False
:x = (grid_x + 1) * width / 2 - 0.5
align_corners=True
:x = (grid_x + 1) * (width - 1) / 2
- zeros: 超出边界的像素值为0
- border: 使用最近的边界像素值
- reflection: 使用反射填充
- 使用共享内存减少全局内存访问
- 优化的线程块大小和网格配置
- 使用CUDA的快速数学函数
- 支持流并行处理
-
CUDA版本不兼容
- 确保CUDA版本 >= 11.0
- 检查CUDA_HOME环境变量
-
编译错误
- 确保安装了所有依赖
- 检查CMake版本 >= 3.18
-
运行时错误
- 检查GPU内存是否足够
- 确保输入数据格式正确
编译时添加调试信息:
cmake -DCMAKE_BUILD_TYPE=Debug ..
make
MIT License
欢迎提交Issue和Pull Request!
- v1.0.0: 初始版本,支持基本的grid sampling功能