Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor unary + binary kernels #2665

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 162 additions & 119 deletions crates/burn-jit/src/kernel/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,60 @@ use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor,
use burn_tensor::Shape;
use cubecl::{
calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
tensor_vectorization_factor,
tensor_line_size_parallel,
};

use super::into_contiguous;

/// A family of binary operations, generic over the numeric element type.
///
/// Kernels take a `BinaryOpFamily` as a type parameter and pick the concrete operation for
/// their element type `C` through the `BinaryOp<C>` associated type.
pub(crate) trait BinaryOpFamily: Send + Sync + 'static {
/// The concrete [`BinaryOp`] implementation to use for element type `C`.
type BinaryOp<C: Numeric>: BinaryOp<C>;
}

/// A binary operation on lines of the numeric type `C`, callable from cube kernels.
#[cube]
pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
/// Execute a binary operation.
///
/// Takes the left- and right-hand operands as `Line<C>` values and returns the result as a
/// `Line<C>`.
fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
}

/// Compile-time description of the element type a binary kernel operates on.
///
/// NOTE(review): in this diff view `BinaryOpSpec` and `Spec` appear to belong to the
/// pre-refactor side (only the `BS: BinaryOpSpec` kernel signatures and the
/// `launch::<Spec<E>, O, R>` call sites use them) — confirm before relying on them.
pub(crate) trait BinaryOpSpec: Send + Sync + 'static {
/// Numeric element type of the kernel's tensors and scalar.
type C: Numeric;
}
/// Marker type implementing [`BinaryOpSpec`] for a given `C`; its only field is a `PhantomData`.
pub(crate) struct Spec<C: Numeric> {
_c: PhantomData<C>,
}

// The spec's element type is simply the type parameter of `Spec`.
impl<C: Numeric> BinaryOpSpec for Spec<C> {
type C = C;
}

// Marker types, one per binary operation. They carry no data; each one is used purely as a
// type parameter and acts as its own `BinaryOpFamily` member (`type BinaryOp<C> = Self`).
pub(crate) struct AddOp;
pub(crate) struct SubOp;
pub(crate) struct MulOp;
pub(crate) struct DivOp;
pub(crate) struct RemainderOp;
pub(crate) struct PowOp;

/// Power operation, parameterized by the float type `F` in which the computation is done.
///
/// `powf` only works on floats, but we still want `PowOp` to implement the numeric binary op
/// family. The family type therefore carries a separate float precision `F`: the inputs are
/// cast to `Line<F>` before `powf` is applied, and the result is cast back to the input type.
///
/// Because of this we won't benefit from the cubecl Rust compilation speed improvement from
/// using the family pattern for [`PowOp`], but at least we don't duplicate code.
pub(crate) struct PowOp<F: Float> {
_f: PhantomData<F>,
}

// Family registrations: every operation marker is its own family member, i.e. the same type
// is used as the `BinaryOp<C>` implementation for every numeric `C`.
impl BinaryOpFamily for AddOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for SubOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for MulOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for DivOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for RemainderOp {
type BinaryOp<C: Numeric> = Self;
}

// `PowOp` keeps its float parameter `F` in the family type itself, so the family is selected
// per float precision while `C` stays free.
impl<F: Float> BinaryOpFamily for PowOp<F> {
type BinaryOp<C: Numeric> = Self;
}

#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
Expand Down Expand Up @@ -69,30 +95,34 @@ impl<N: Numeric> BinaryOp<N> for RemainderOp {
}

#[cube]
impl<N: Float> BinaryOp<N> for PowOp {
impl<N: Numeric, F: Float> BinaryOp<N> for PowOp<F> {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Line::powf(lhs, rhs)
let lhs = Line::<F>::cast_from(lhs);
let rhs = Line::<F>::cast_from(rhs);
let out = Line::powf(lhs, rhs);

Line::cast_from(out)
}
}

#[cube(launch)]
pub(crate) fn kernel_scalar_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
input: &Tensor<Line<BS::C>>,
scalar: BS::C,
output: &mut Tensor<Line<BS::C>>,
#[cube(launch_unchecked)]
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
input: &Tensor<Line<C>>,
scalar: C,
output: &mut Tensor<Line<C>>,
) {
if ABSOLUTE_POS >= output.len() {
return;
}

output[ABSOLUTE_POS] = O::execute(input[ABSOLUTE_POS], Line::new(scalar));
output[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar));
}

#[cube(launch)]
pub(crate) fn kernel_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
lhs: &Tensor<Line<BS::C>>,
rhs: &Tensor<Line<BS::C>>,
out: &mut Tensor<Line<BS::C>>,
#[cube(launch_unchecked)]
pub(crate) fn kernel_binop<C: Numeric, O: BinaryOpFamily>(
lhs: &Tensor<Line<C>>,
rhs: &Tensor<Line<C>>,
out: &mut Tensor<Line<C>>,
#[comptime] rank: Option<u32>,
#[comptime] to_contiguous_lhs: bool,
#[comptime] to_contiguous_rhs: bool,
Expand All @@ -106,7 +136,7 @@ pub(crate) fn kernel_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
}

if to_contiguous_lhs {
offset_lhs = index_offset_with_layout::<BS::C, BS::C>(
offset_lhs = index_offset_with_layout::<C, C>(
lhs,
out,
offset_out,
Expand All @@ -117,7 +147,7 @@ pub(crate) fn kernel_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
}

if to_contiguous_rhs {
offset_rhs = index_offset_with_layout::<BS::C, BS::C>(
offset_rhs = index_offset_with_layout::<C, C>(
rhs,
out,
offset_out,
Expand All @@ -127,20 +157,27 @@ pub(crate) fn kernel_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
);
}

out[offset_out] = O::execute(lhs[offset_lhs], rhs[offset_rhs]);
out[offset_out] = O::BinaryOp::<C>::execute(lhs[offset_lhs], rhs[offset_rhs]);
}

pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOpFamily>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
) -> JitTensor<R> {
let ndims = lhs.shape.num_dims();
let vectorization_factor_lhs =
tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, ndims - 1);
let vectorization_factor_rhs =
tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1);

let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs);
let line_size_lhs = tensor_line_size_parallel(
R::line_size_elem(&E::as_elem_native_unchecked()),
&lhs.shape.dims,
&lhs.strides,
ndims - 1,
);
let line_size_rhs = tensor_line_size_parallel(
R::line_size_elem(&E::as_elem_native_unchecked()),
&rhs.shape.dims,
&rhs.strides,
ndims - 1,
);
let line_size = Ord::min(line_size_lhs, line_size_rhs);

let mut shape_out = vec![0; ndims];
lhs.shape
Expand All @@ -157,59 +194,60 @@ pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
let num_elems = shape_out.num_elements();

let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

if lhs.can_mut_broadcast(&rhs) {
kernel_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(vectorization_factor),
rhs.as_tensor_arg::<E>(vectorization_factor),
TensorArg::alias(0),
None,
false,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
);

lhs
} else if rhs.can_mut_broadcast(&lhs) {
kernel_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(vectorization_factor),
rhs.as_tensor_arg::<E>(vectorization_factor),
TensorArg::alias(1),
None,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
false,
);

rhs
} else {
let output = empty_device::<R, E>(lhs.client.clone(), lhs.device.clone(), shape_out);
let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;

kernel_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(vectorization_factor),
rhs.as_tensor_arg::<E>(vectorization_factor),
output.as_tensor_arg::<E>(vectorization_factor),
None,
to_contiguous_lhs,
to_contiguous_rhs,
);

output
let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);

unsafe {
if lhs.can_mut_broadcast(&rhs) {
kernel_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(line_size),
rhs.as_tensor_arg::<E>(line_size),
TensorArg::alias(0),
None,
false,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
);

lhs
} else if rhs.can_mut_broadcast(&lhs) {
kernel_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(line_size),
rhs.as_tensor_arg::<E>(line_size),
TensorArg::alias(1),
None,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
false,
);

rhs
} else {
let output = empty_device::<R, E>(lhs.client.clone(), lhs.device.clone(), shape_out);
let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;

kernel_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(line_size),
rhs.as_tensor_arg::<E>(line_size),
output.as_tensor_arg::<E>(line_size),
None,
to_contiguous_lhs,
to_contiguous_rhs,
);

output
}
}
}

pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOpFamily>(
mut tensor: JitTensor<R>,
scalar: E,
) -> JitTensor<R> {
Expand All @@ -219,42 +257,47 @@ pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(

// Vectorization is only enabled when the last dimension is contiguous.
let ndims = tensor.shape.num_dims();
let vectorization_factor =
tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, ndims - 1);
let line_size = tensor_line_size_parallel(
R::line_size_elem(&E::as_elem_native_unchecked()),
&tensor.shape.dims,
&tensor.strides,
ndims - 1,
);
let client = tensor.client.clone();
let num_elems = tensor.shape.num_elements();

let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

if tensor.can_mut() {
kernel_scalar_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg::<E>(vectorization_factor),
ScalarArg::new(scalar),
TensorArg::alias(0),
);

tensor
} else {
let output = empty_device::<R, E>(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape.clone(),
);

kernel_scalar_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
CubeDim::default(),
tensor.as_tensor_arg::<E>(vectorization_factor),
ScalarArg::new(scalar),
output.as_tensor_arg::<E>(vectorization_factor),
);

output
let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);

unsafe {
if tensor.can_mut() {
kernel_scalar_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg::<E>(line_size),
ScalarArg::new(scalar),
TensorArg::alias(0),
);

tensor
} else {
let output = empty_device::<R, E>(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape.clone(),
);

kernel_scalar_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
CubeDim::default(),
tensor.as_tensor_arg::<E>(line_size),
ScalarArg::new(scalar),
output.as_tensor_arg::<E>(line_size),
);

output
}
}
}
Loading
Loading