Commit: Refactor unary + binary kernels (#2665)
* Refactor unary + binary kernels

* Improve float unary

* Cleanup binary
nathanielsimard authored Jan 7, 2025
1 parent a644430 · commit 2b4be6c
Showing 10 changed files with 549 additions and 467 deletions.
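The refactor replaces the previous `BinaryOpSpec`/`Spec<C>` plumbing with a `BinaryOpFamily` trait whose generic associated type selects the concrete op for each element type, so the launch functions become generic over the family. Below is a minimal, cubecl-free sketch of that pattern: plain generics stand in for cubecl's `Line`/`Numeric`, and the `apply` helper is a hypothetical stand-in for how `launch_binop` dispatches through the family, not code from this commit.

```rust
use std::ops::Add;

/// Element-wise binary operation on a single value type.
trait BinaryOp<T> {
    fn execute(lhs: T, rhs: T) -> T;
}

/// Family of binary ops: one type that yields a concrete op for every
/// element type via a generic associated type.
trait BinaryOpFamily {
    type Op<T: Copy + Add<Output = T>>: BinaryOp<T>;
}

struct AddOp;

impl<T: Copy + Add<Output = T>> BinaryOp<T> for AddOp {
    fn execute(lhs: T, rhs: T) -> T {
        lhs + rhs
    }
}

impl BinaryOpFamily for AddOp {
    // The same type serves as the op for every element type.
    type Op<T: Copy + Add<Output = T>> = Self;
}

// Launch-style helper that is generic over the family, mirroring how the
// refactored `launch_binop` takes `O: BinaryOpFamily` instead of `O: BinaryOp<E>`.
fn apply<O: BinaryOpFamily, T: Copy + Add<Output = T>>(lhs: T, rhs: T) -> T {
    <O::Op<T> as BinaryOp<T>>::execute(lhs, rhs)
}

fn main() {
    assert_eq!(apply::<AddOp, i32>(2, 3), 5);
    assert_eq!(apply::<AddOp, f32>(1.5, 2.5), 4.0);
}
```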
crates/burn-jit/src/kernel/binary.rs — 281 changes: 162 additions & 119 deletions
@@ -4,34 +4,60 @@ use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor,
use burn_tensor::Shape;
use cubecl::{
calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
tensor_vectorization_factor,
tensor_line_size_parallel,
};

use super::into_contiguous;

pub(crate) trait BinaryOpFamily: Send + Sync + 'static {
type BinaryOp<C: Numeric>: BinaryOp<C>;
}

#[cube]
pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
/// Execute a binary operation.
fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
}

pub(crate) trait BinaryOpSpec: Send + Sync + 'static {
type C: Numeric;
}
pub(crate) struct Spec<C: Numeric> {
_c: PhantomData<C>,
}

impl<C: Numeric> BinaryOpSpec for Spec<C> {
type C = C;
}

pub(crate) struct AddOp;
pub(crate) struct SubOp;
pub(crate) struct MulOp;
pub(crate) struct DivOp;
pub(crate) struct RemainderOp;
pub(crate) struct PowOp;

/// Powf only works on float, but we still want to implement the numeric binary op family, so we
/// set another precision in the family type to cast, when necessary, the input value to a valid
/// float.
///
/// Because of this we won't benefit from the cubecl rust compilation speed improvement from using
/// the family pattern for [PowOp], but at least we don't duplicate code.
pub(crate) struct PowOp<F: Float> {
_f: PhantomData<F>,
}

impl BinaryOpFamily for AddOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for SubOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for MulOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for DivOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for RemainderOp {
type BinaryOp<C: Numeric> = Self;
}

impl<F: Float> BinaryOpFamily for PowOp<F> {
type BinaryOp<C: Numeric> = Self;
}

#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
@@ -69,30 +95,34 @@ impl<N: Numeric> BinaryOp<N> for RemainderOp {
}

#[cube]
impl<N: Float> BinaryOp<N> for PowOp {
impl<N: Numeric, F: Float> BinaryOp<N> for PowOp<F> {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Line::powf(lhs, rhs)
let lhs = Line::<F>::cast_from(lhs);
let rhs = Line::<F>::cast_from(rhs);
let out = Line::powf(lhs, rhs);

Line::cast_from(out)
}
}

#[cube(launch)]
pub(crate) fn kernel_scalar_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
input: &Tensor<Line<BS::C>>,
scalar: BS::C,
output: &mut Tensor<Line<BS::C>>,
#[cube(launch_unchecked)]
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
input: &Tensor<Line<C>>,
scalar: C,
output: &mut Tensor<Line<C>>,
) {
if ABSOLUTE_POS >= output.len() {
return;
}

output[ABSOLUTE_POS] = O::execute(input[ABSOLUTE_POS], Line::new(scalar));
output[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar));
}

#[cube(launch)]
pub(crate) fn kernel_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
lhs: &Tensor<Line<BS::C>>,
rhs: &Tensor<Line<BS::C>>,
out: &mut Tensor<Line<BS::C>>,
#[cube(launch_unchecked)]
pub(crate) fn kernel_binop<C: Numeric, O: BinaryOpFamily>(
lhs: &Tensor<Line<C>>,
rhs: &Tensor<Line<C>>,
out: &mut Tensor<Line<C>>,
#[comptime] rank: Option<u32>,
#[comptime] to_contiguous_lhs: bool,
#[comptime] to_contiguous_rhs: bool,
@@ -106,7 +136,7 @@ pub(crate) fn kernel_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
}

if to_contiguous_lhs {
offset_lhs = index_offset_with_layout::<BS::C, BS::C>(
offset_lhs = index_offset_with_layout::<C, C>(
lhs,
out,
offset_out,
@@ -117,7 +147,7 @@ pub(crate) fn kernel_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
}

if to_contiguous_rhs {
offset_rhs = index_offset_with_layout::<BS::C, BS::C>(
offset_rhs = index_offset_with_layout::<C, C>(
rhs,
out,
offset_out,
@@ -127,20 +157,27 @@ pub(crate) fn kernel_binop<BS: BinaryOpSpec, O: BinaryOp<BS::C>>(
);
}

out[offset_out] = O::execute(lhs[offset_lhs], rhs[offset_rhs]);
out[offset_out] = O::BinaryOp::<C>::execute(lhs[offset_lhs], rhs[offset_rhs]);
}

pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOpFamily>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
) -> JitTensor<R> {
let ndims = lhs.shape.num_dims();
let vectorization_factor_lhs =
tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, ndims - 1);
let vectorization_factor_rhs =
tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1);

let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs);
let line_size_lhs = tensor_line_size_parallel(
R::line_size_elem(&E::as_elem_native_unchecked()),
&lhs.shape.dims,
&lhs.strides,
ndims - 1,
);
let line_size_rhs = tensor_line_size_parallel(
R::line_size_elem(&E::as_elem_native_unchecked()),
&rhs.shape.dims,
&rhs.strides,
ndims - 1,
);
let line_size = Ord::min(line_size_lhs, line_size_rhs);

let mut shape_out = vec![0; ndims];
lhs.shape
@@ -157,59 +194,60 @@ pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
let num_elems = shape_out.num_elements();

let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

if lhs.can_mut_broadcast(&rhs) {
kernel_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(vectorization_factor),
rhs.as_tensor_arg::<E>(vectorization_factor),
TensorArg::alias(0),
None,
false,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
);

lhs
} else if rhs.can_mut_broadcast(&lhs) {
kernel_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(vectorization_factor),
rhs.as_tensor_arg::<E>(vectorization_factor),
TensorArg::alias(1),
None,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
false,
);

rhs
} else {
let output = empty_device::<R, E>(lhs.client.clone(), lhs.device.clone(), shape_out);
let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;

kernel_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(vectorization_factor),
rhs.as_tensor_arg::<E>(vectorization_factor),
output.as_tensor_arg::<E>(vectorization_factor),
None,
to_contiguous_lhs,
to_contiguous_rhs,
);

output
let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);

unsafe {
if lhs.can_mut_broadcast(&rhs) {
kernel_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(line_size),
rhs.as_tensor_arg::<E>(line_size),
TensorArg::alias(0),
None,
false,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
);

lhs
} else if rhs.can_mut_broadcast(&lhs) {
kernel_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(line_size),
rhs.as_tensor_arg::<E>(line_size),
TensorArg::alias(1),
None,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
false,
);

rhs
} else {
let output = empty_device::<R, E>(lhs.client.clone(), lhs.device.clone(), shape_out);
let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;

kernel_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
cube_dim,
lhs.as_tensor_arg::<E>(line_size),
rhs.as_tensor_arg::<E>(line_size),
output.as_tensor_arg::<E>(line_size),
None,
to_contiguous_lhs,
to_contiguous_rhs,
);

output
}
}
}

pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(
pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOpFamily>(
mut tensor: JitTensor<R>,
scalar: E,
) -> JitTensor<R> {
@@ -219,42 +257,47 @@ pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOp<E>>(

// Vectorization is only enabled when the last dimension is contiguous.
let ndims = tensor.shape.num_dims();
let vectorization_factor =
tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, ndims - 1);
let line_size = tensor_line_size_parallel(
R::line_size_elem(&E::as_elem_native_unchecked()),
&tensor.shape.dims,
&tensor.strides,
ndims - 1,
);
let client = tensor.client.clone();
let num_elems = tensor.shape.num_elements();

let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);

if tensor.can_mut() {
kernel_scalar_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg::<E>(vectorization_factor),
ScalarArg::new(scalar),
TensorArg::alias(0),
);

tensor
} else {
let output = empty_device::<R, E>(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape.clone(),
);

kernel_scalar_binop::launch::<Spec<E>, O, R>(
&client,
cube_count,
CubeDim::default(),
tensor.as_tensor_arg::<E>(vectorization_factor),
ScalarArg::new(scalar),
output.as_tensor_arg::<E>(vectorization_factor),
);

output
let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);

unsafe {
if tensor.can_mut() {
kernel_scalar_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
cube_dim,
tensor.as_tensor_arg::<E>(line_size),
ScalarArg::new(scalar),
TensorArg::alias(0),
);

tensor
} else {
let output = empty_device::<R, E>(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape.clone(),
);

kernel_scalar_binop::launch_unchecked::<E, O, R>(
&client,
cube_count,
CubeDim::default(),
tensor.as_tensor_arg::<E>(line_size),
ScalarArg::new(scalar),
output.as_tensor_arg::<E>(line_size),
);

output
}
}
}
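The `PowOp` doc comment in this diff explains the strategy for non-float element types: cast the inputs to a float precision, apply `powf`, then cast the result back. A plain-Rust illustration of that cast-compute-cast-back idea follows; `pow_via_float` is a hypothetical helper for illustration only and does not reproduce the kernel's exact conversion or rounding behavior.

```rust
// Hypothetical standalone analogue of the PowOp strategy: promote integer
// operands to f32, apply powf, then cast the result back to the integer type.
fn pow_via_float(lhs: i32, rhs: i32) -> i32 {
    let lhs_f = lhs as f32;
    let rhs_f = rhs as f32;
    lhs_f.powf(rhs_f) as i32
}

fn main() {
    // 2^10 == 1024; powers of two are represented exactly in f32.
    assert_eq!(pow_via_float(2, 10), 1024);
}
```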