[kernel] switch from PYBIND11 to TORCH_LIBRARY #617

Draft · wants to merge 1 commit into base: dev
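The whole diff follows one pattern: binding-facing scalar arguments widen from `int` to `int64_t` and from `float` to `double`, `<torch/extension.h>` becomes `<torch/all.h>`, the per-file binding headers are deleted, and a single new `ext_ops.h` appears. That is what moving from pybind11 bindings to `TORCH_LIBRARY` registration requires: the dispatcher's schema language only has `int` (C++ `int64_t`), `float` (C++ `double`) and `bool` for scalars. A minimal sketch of the before/after registration styles, using two functions that appear in this diff; the library name `exllamav2_ext` and the exact schema strings are assumptions for illustration, not taken from the PR:

```cpp
#include <torch/all.h>  // Python-agnostic; <torch/extension.h> also dragged in pybind11

// Signatures as they look after this PR (declarations only; definitions live in
// ext_cache.cpp and ext_norm.cpp).
void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor,
                 int64_t batch_size, int64_t offset, int64_t width);
void rms_norm(torch::Tensor x, torch::Tensor w, torch::Tensor y, double epsilon);

// Old style: pybind11 binds the functions directly, so plain int/float worked.
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       m.def("fp16_to_fp8", &fp16_to_fp8, "fp16_to_fp8");
//       m.def("rms_norm", &rms_norm, "rms_norm");
//   }

// New style: declare schemas against the dispatcher ("int" = int64_t,
// "float" = double, "Tensor(a!)" = mutated in place), then attach the kernels.
TORCH_LIBRARY(exllamav2_ext, m)
{
    m.def("fp16_to_fp8(Tensor in_tensor, Tensor(a!) out_tensor, int batch_size, int offset, int width) -> ()");
    m.def("rms_norm(Tensor x, Tensor w, Tensor(a!) y, float epsilon) -> ()");
}

TORCH_LIBRARY_IMPL(exllamav2_ext, CUDA, m)
{
    m.impl("fp16_to_fp8", &fp16_to_fp8);
    m.impl("rms_norm", &rms_norm);
}
```

On the Python side the call sites would then move from the pybind11 module attribute to `torch.ops.exllamav2_ext.fp16_to_fp8(...)` (namespace assumed).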
21 changes: 0 additions & 21 deletions exllamav2/exllamav2_ext/config.h

This file was deleted.

30 changes: 16 additions & 14 deletions exllamav2/exllamav2_ext/ext_cache.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@@ -15,7 +15,8 @@

#include "cpp/util.h"

void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_size, int offset, int width)
void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor,
int64_t batch_size, int64_t offset, int64_t width)
{
TORCH_CHECK_DTYPE(in_tensor, kHalf);
TORCH_CHECK_DTYPE(out_tensor, kUInt8);
@@ -46,7 +47,8 @@ void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_si
);
}

void fp8_to_fp16(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_size, int offset, int width)
void fp8_to_fp16(torch::Tensor in_tensor, torch::Tensor out_tensor,
int64_t batch_size, int64_t offset, int64_t width)
{
TORCH_CHECK_DTYPE(in_tensor, kUInt8);
TORCH_CHECK_DTYPE(out_tensor, kHalf);
@@ -85,15 +87,15 @@ void fp16_to_q_kv
torch::Tensor v_in,
torch::Tensor v_out,
torch::Tensor v_scales,
int batch_size,
int offset,
int width,
int page_size,
int64_t batch_size,
int64_t offset,
int64_t width,
int64_t page_size,
torch::Tensor cache_seqlens,
torch::Tensor block_table,
torch::Tensor cal_k,
torch::Tensor cal_v,
int wbits
int64_t wbits
)
{
TORCH_CHECK_DTYPE(k_in, kHalf);
@@ -193,15 +195,15 @@ void q_to_fp16_kv
torch::Tensor v_in,
torch::Tensor v_out,
torch::Tensor v_scales,
int batch_size,
int offset,
int width,
int page_size,
int64_t batch_size,
int64_t offset,
int64_t width,
int64_t page_size,
torch::Tensor cache_seqlens,
torch::Tensor block_table,
torch::Tensor cal_k,
torch::Tensor cal_v,
int wbits
int64_t wbits
)
{
TORCH_CHECK_DTYPE(k_in, kUInt8);
@@ -310,7 +312,7 @@ int count_match
(
torch::Tensor a,
torch::Tensor b,
int max_a
int64_t max_a
)
{
uint64_t* pa = (uint64_t*) a.data_ptr();
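The `int` to `int64_t` widening in the signatures above is not cosmetic: if an op is registered from its function pointer and the schema inferred, every argument type has to be one the dispatcher can represent, and c10 rejects 32-bit `int` at compile time (its static assert tells you to use `int64_t`). A hedged sketch using `fp16_to_fp8` from this file; whether the PR registers by inference or with explicit schema strings is not visible in the rendered diff:

```cpp
#include <torch/all.h>

// Post-PR signature from ext_cache.cpp (declaration only).
void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor,
                 int64_t batch_size, int64_t offset, int64_t width);

TORCH_LIBRARY_FRAGMENT(exllamav2_ext, m)
{
    // Schema is inferred from the C++ types (inference drops argument names):
    //   fp16_to_fp8(Tensor _0, Tensor _1, int _2, int _3, int _4) -> ()
    // With the old `int` parameters this registration would fail to compile.
    m.def("fp16_to_fp8", &fp16_to_fp8);
}
```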
54 changes: 0 additions & 54 deletions exllamav2/exllamav2_ext/ext_cache.h

This file was deleted.

4 changes: 2 additions & 2 deletions exllamav2/exllamav2_ext/ext_element.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@@ -17,7 +17,7 @@
void softcap_
(
torch::Tensor x,
float scale
double scale
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
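`softcap_` mutates `x` in place and its `scale` now arrives as a C++ `double` (schema `float`). For mutating ops the cleaner registration is an explicit schema with an alias annotation, so the dispatcher knows the tensor is written to. A sketch under that assumption; the PR's actual schema strings are not shown here:

```cpp
#include <torch/all.h>

// Post-PR signature from ext_element.cpp (declaration only).
void softcap_(torch::Tensor x, double scale);

TORCH_LIBRARY_FRAGMENT(exllamav2_ext, m)
{
    // "Tensor(a!)" marks x as mutated in place; schema "float" is C++ double.
    m.def("softcap_(Tensor(a!) x, float scale) -> ()", &softcap_);
}
```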
6 changes: 0 additions & 6 deletions exllamav2/exllamav2_ext/ext_element.h

This file was deleted.

6 changes: 3 additions & 3 deletions exllamav2/exllamav2_ext/ext_gemm.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@@ -20,8 +20,8 @@ void gemm_half_half_half
torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
const float alpha,
const float beta,
const double alpha,
const double beta,
bool force_cublas
)
{
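Widening `alpha` and `beta` to `double` only changes the binding surface; cuBLAS and the half-precision kernels still take `float` (or `half`), so the values get narrowed back at the call boundary. A minimal sketch of that assumption, not the PR's actual implementation:

```cpp
#include <torch/all.h>

// Hypothetical boundary narrowing inside gemm_half_half_half: the dispatcher
// hands us double (schema "float"); the math still runs in fp32/fp16.
void gemm_half_half_half_sketch(torch::Tensor a, torch::Tensor b, torch::Tensor c,
                                double alpha, double beta, bool force_cublas)
{
    const float alpha_f = static_cast<float>(alpha);
    const float beta_f  = static_cast<float>(beta);
    // ... forward alpha_f / beta_f to cuBLAS or the custom kernel exactly as the
    // old float parameters were used ...
    (void) a; (void) b; (void) c; (void) alpha_f; (void) beta_f; (void) force_cublas;
}
```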
10 changes: 0 additions & 10 deletions exllamav2/exllamav2_ext/ext_gemm.h

This file was deleted.

2 changes: 1 addition & 1 deletion exllamav2/exllamav2_ext/ext_hadamard.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <cstdint>
#include <cstdio>
#include <pybind11/pybind11.h>
10 changes: 0 additions & 10 deletions exllamav2/exllamav2_ext/ext_hadamard.h

This file was deleted.

16 changes: 8 additions & 8 deletions exllamav2/exllamav2_ext/ext_norm.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@@ -25,7 +25,7 @@ void rms_norm
torch::Tensor x,
torch::Tensor w,
torch::Tensor y,
float epsilon
double epsilon
)
{
bool input_fp32 = x.dtype() == torch::kFloat;
@@ -61,7 +61,7 @@ void rms_norm_tp
std::vector<torch::Tensor> x,
std::vector<torch::Tensor> w,
std::vector<torch::Tensor> y,
float epsilon,
double epsilon,
uintptr_t tp_context
)
{
@@ -96,7 +96,7 @@ void rms_norm_
(
torch::Tensor x,
torch::Tensor w,
float epsilon
double epsilon
)
{
rms_norm(x, w, x, epsilon);
@@ -111,7 +111,7 @@ void layer_norm
torch::Tensor w,
torch::Tensor b,
torch::Tensor y,
float epsilon
double epsilon
)
{
TORCH_CHECK_DTYPE(x, kHalf);
@@ -147,7 +147,7 @@ void layer_norm_
torch::Tensor x,
torch::Tensor w,
torch::Tensor b,
float epsilon
double epsilon
)
{
layer_norm(x, w, b, x, epsilon);
@@ -162,7 +162,7 @@ void head_norm
torch::Tensor w,
torch::Tensor b,
torch::Tensor y,
float epsilon
double epsilon
)
{
TORCH_CHECK_DTYPE(x, kHalf);
@@ -202,7 +202,7 @@ void head_norm_
torch::Tensor x,
torch::Tensor w,
torch::Tensor b,
float epsilon
double epsilon
)
{
head_norm(x, w, b, x, epsilon);
61 changes: 0 additions & 61 deletions exllamav2/exllamav2_ext/ext_norm.h

This file was deleted.

749 changes: 749 additions & 0 deletions exllamav2/exllamav2_ext/ext_ops.h

Large diffs are not rendered by default.
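The per-extension headers deleted throughout this PR (ext_cache.h, ext_norm.h, ext_qattn.h and the rest) are replaced by this single ext_ops.h, which GitHub does not render here. A plausible shape for it, stated purely as an assumption, is all the widened declarations gathered in one place so one registration block can see every op:

```cpp
// ext_ops.h (sketch only; the real 749-line file is not rendered in this diff)
#pragma once
#include <torch/all.h>

// ext_cache.cpp
void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor,
                 int64_t batch_size, int64_t offset, int64_t width);
void fp8_to_fp16(torch::Tensor in_tensor, torch::Tensor out_tensor,
                 int64_t batch_size, int64_t offset, int64_t width);

// ext_element.cpp
void softcap_(torch::Tensor x, double scale);

// ext_norm.cpp
void rms_norm(torch::Tensor x, torch::Tensor w, torch::Tensor y, double epsilon);
void rms_norm_(torch::Tensor x, torch::Tensor w, double epsilon);

// ... remaining declarations from ext_gemm, ext_hadamard, ext_qattn, ext_qmatrix,
// ext_qmlp, ext_quant, ext_rope, ext_safetensors, ext_sampling and ext_tp ...
```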

52 changes: 26 additions & 26 deletions exllamav2/exllamav2_ext/ext_qattn.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@@ -37,14 +37,14 @@ uintptr_t make_q_attn
// torch::Tensor temp_k,
// torch::Tensor temp_v,
torch::Tensor temp_dq,
int max_rows,
int hidden_size,
int num_heads,
int num_kv_heads,
int head_dim,
int max_seq_len,
bool has_residual,
int rope_style,
int64_t max_rows,
int64_t hidden_size,
int64_t num_heads,
int64_t num_kv_heads,
int64_t head_dim,
int64_t max_seq_len,
bool has_residual,
int64_t rope_style,
torch::Tensor q_norm,
torch::Tensor k_norm,
torch::Tensor post_layernorm,
@@ -113,9 +113,9 @@ void q_attn_forward_1
(
uintptr_t q_attn,
torch::Tensor x,
int batch_size,
int q_len,
int past_len,
int64_t batch_size,
int64_t q_len,
int64_t past_len,
torch::Tensor past_lens,
torch::Tensor q_temp,
torch::Tensor k_temp,
@@ -160,8 +160,8 @@ void q_attn_forward_2
uintptr_t q_attn,
torch::Tensor x,
torch::Tensor attn_output,
int batch_size,
int q_len,
int64_t batch_size,
int64_t q_len,
const std::vector<uintptr_t>& loras,
torch::Tensor loras_temp
)
@@ -269,20 +269,20 @@ void tp_attn_forward_paged_
const std::vector<torch::Tensor> &k_cache,
const std::vector<torch::Tensor> &v_cache,
const std::vector<torch::Tensor> &pre_layernorm,
float norm_epsilon,
double norm_epsilon,
const std::vector<uintptr_t> &q_proj,
const std::vector<uintptr_t> &k_proj,
const std::vector<uintptr_t> &v_proj,
const std::vector<uintptr_t> &o_proj,
int head_dim,
int rope_style,
int batch_size,
int q_len,
int64_t head_dim,
int64_t rope_style,
int64_t batch_size,
int64_t q_len,
const std::vector<torch::Tensor> &sin,
const std::vector<torch::Tensor> &cos,
const std::vector<torch::Tensor> &past_lens,
const std::vector<torch::Tensor> &block_index,
float scaling
double scaling
)
{
auto fwd_kvcache_func = py::module_::import("flash_attn_2_cuda").attr("fwd_kvcache");
@@ -506,19 +506,19 @@ void tp_attn_forward_
const std::vector<torch::Tensor> &k_cache,
const std::vector<torch::Tensor> &v_cache,
const std::vector<torch::Tensor> &pre_layernorm,
float norm_epsilon,
double norm_epsilon,
const std::vector<uintptr_t> &q_proj,
const std::vector<uintptr_t> &k_proj,
const std::vector<uintptr_t> &v_proj,
const std::vector<uintptr_t> &o_proj,
int head_dim,
int rope_style,
int batch_size,
int q_len,
int64_t head_dim,
int64_t rope_style,
int64_t batch_size,
int64_t q_len,
const std::vector<torch::Tensor> &sin,
const std::vector<torch::Tensor> &cos,
const std::vector<torch::Tensor> &past_len_tp,
float scaling
double scaling
)
{
auto fwd_kvcache_func = py::module_::import("flash_attn_2_cuda").attr("fwd_kvcache");
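`tp_attn_forward_paged_` and `tp_attn_forward_` still call into Python for flash-attention, so pybind11 does not vanish from these translation units even though the module bindings do. `<torch/all.h>` does not pull in the Python headers the way `<torch/extension.h>` did, so the pybind11 include has to come from somewhere explicit (ext_hadamard.cpp above keeps one). A sketch of the call pattern; the GIL handling is an assumption, the real code may rely on the caller already holding it:

```cpp
#include <torch/all.h>
#include <pybind11/pybind11.h>   // explicit now that <torch/extension.h> is gone
namespace py = pybind11;

void call_flash_attn_sketch()
{
    py::gil_scoped_acquire gil;  // assumption: make sure the GIL is held before touching Python
    auto fwd_kvcache_func = py::module_::import("flash_attn_2_cuda").attr("fwd_kvcache");
    // ... assemble the q/k/v-cache arguments and invoke fwd_kvcache_func(...) as in
    // tp_attn_forward_paged_ / tp_attn_forward_ ...
}
```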
165 changes: 0 additions & 165 deletions exllamav2/exllamav2_ext/ext_qattn.h

This file was deleted.

8 changes: 4 additions & 4 deletions exllamav2/exllamav2_ext/ext_qmatrix.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@@ -32,7 +32,7 @@ uintptr_t make_q_matrix
torch::Tensor gptq_g_idx,
torch::Tensor bias,
torch::Tensor temp_dq,
int max_dq_rows
int64_t max_dq_rows
)
{
TORCH_CHECK_DTYPE(q_weight, kInt);
@@ -120,7 +120,7 @@ uintptr_t make_q_matrix_split
torch::Tensor gptq_g_idx,
torch::Tensor bias,
torch::Tensor temp_dq,
int max_dq_rows
int64_t max_dq_rows
)
{
TORCH_CHECK(
@@ -245,7 +245,7 @@ void gemm_half_q_half_tp
const std::vector<torch::Tensor> &c,
bool force_cuda,
uintptr_t tp_context,
int t_device
int64_t t_device
)
{
ExtTPContext* ctx = reinterpret_cast<ExtTPContext*> (tp_context);
79 changes: 0 additions & 79 deletions exllamav2/exllamav2_ext/ext_qmatrix.h

This file was deleted.

14 changes: 7 additions & 7 deletions exllamav2/exllamav2_ext/ext_qmlp.cpp
@@ -24,15 +24,15 @@ uintptr_t make_q_mlp
torch::Tensor layernorm,
torch::Tensor layernorm_bias,
bool layernorm_is_rms,
float norm_epsilon,
double norm_epsilon,
uintptr_t q_gate,
uintptr_t q_up,
uintptr_t q_down,
torch::Tensor temp_state,
torch::Tensor temp_a,
torch::Tensor temp_b,
torch::Tensor temp_dq,
int max_rows,
int64_t max_rows,
bool act_gelu,
bool has_residual,
torch::Tensor post_layernorm,
@@ -173,10 +173,10 @@ uintptr_t make_q_moe_mlp
torch::Tensor layernorm,
torch::Tensor layernorm_bias,
bool layernorm_is_rms,
float norm_epsilon,
double norm_epsilon,
torch::Tensor gate,
int num_experts,
int num_experts_per_token,
int64_t num_experts,
int64_t num_experts_per_token,
const std::vector<uintptr_t>& w1,
const std::vector<uintptr_t>& w2,
const std::vector<uintptr_t>& w3,
@@ -186,7 +186,7 @@ uintptr_t make_q_moe_mlp
torch::Tensor temp_b,
torch::Tensor temp_logits,
torch::Tensor temp_dq,
int max_rows,
int64_t max_rows,
bool act_gelu
)
{
@@ -334,7 +334,7 @@ void tp_mlp_forward_
const std::vector<torch::Tensor> &temp_up_,
const std::vector<torch::Tensor> &temp_down_,
const std::vector<torch::Tensor> &pre_layernorm,
float norm_epsilon,
double norm_epsilon,
const std::vector<uintptr_t> &gate,
const std::vector<uintptr_t> &up,
const std::vector<uintptr_t> &down,
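`make_q_mlp` and `make_q_moe_mlp` return `uintptr_t` handles to heap-allocated C++ objects, and the forward paths take those handles back (the matching `reinterpret_cast` is visible in ext_qmatrix.cpp above). Schemas have no pointer type, so if these factories go through `TORCH_LIBRARY` the handle has to travel as a schema `int`, i.e. `int64_t`; whether this PR converts them or keeps them on a Python-only path is not visible here. A sketch of the pattern with hypothetical names:

```cpp
#include <cstdint>

// Hypothetical stand-in for the objects built by make_q_mlp / make_q_moe_mlp;
// illustrates the opaque-handle pattern only.
struct ExampleMlp { /* weights, temp buffers, ... */ };

int64_t make_example_mlp()
{
    auto* mlp = new ExampleMlp();
    return reinterpret_cast<int64_t>(mlp);              // handle handed to the Python side
}

void example_mlp_forward(int64_t handle)
{
    auto* mlp = reinterpret_cast<ExampleMlp*>(handle);  // recover the C++ object
    // ... run the forward pass with mlp ...
    (void) mlp;
}
```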
110 changes: 0 additions & 110 deletions exllamav2/exllamav2_ext/ext_qmlp.h

This file was deleted.

34 changes: 17 additions & 17 deletions exllamav2/exllamav2_ext/ext_quant.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@@ -50,7 +50,7 @@ void pack_columns
(
torch::Tensor input,
torch::Tensor output,
int bits
int64_t bits
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
@@ -84,12 +84,12 @@ void quantize_err
torch::Tensor input,
torch::Tensor output,
torch::Tensor scale,
float qzero,
float maxq,
float err_norm,
float min_p,
float max_p,
int p_grid
double qzero,
double maxq,
double err_norm,
double min_p,
double max_p,
int64_t p_grid
)
{
TORCH_CHECK_DTYPE(input, kFloat);
@@ -126,8 +126,8 @@ void quantize
torch::Tensor output,
torch::Tensor scale,
torch::Tensor out_q,
float qzero,
float maxq
double qzero,
double maxq
)
{
TORCH_CHECK_DTYPE(input, kFloat);
@@ -152,15 +152,15 @@ void quantize
);
}

std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, uint64_t, float> sim_anneal
std::tuple<std::vector<std::tuple<uint64_t, double>>, std::vector<int64_t>, double, uint64_t, double> sim_anneal
(
const std::vector<std::vector<std::tuple<uint64_t, float>>>& slots,
const std::vector<std::vector<std::tuple<uint64_t, double>>>& slots,
uint64_t max_cost,
float initial_temp,
float cooling_factor,
float min_temp,
int iterations,
float norm
double initial_temp,
double cooling_factor,
double min_temp,
int64_t iterations,
double norm
)
{
int num_slots = slots.size();
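`sim_anneal` extends the same widening to containers: the result's `std::vector<int>` becomes `std::vector<int64_t>` and the per-slot cost tuples change from `(uint64_t, float)` to `(uint64_t, double)`, keeping the whole surface on 64-bit scalar widths. A small reference sketch spelling out the new types; the alias names are hypothetical, added only for readability:

```cpp
#include <cstdint>
#include <tuple>
#include <vector>

// Scalar mapping behind every signature change in this PR:
//   schema "int"  -> C++ int64_t     schema "float"  -> C++ double
//   schema "bool" -> C++ bool        schema "Tensor" -> torch::Tensor
// Containers follow the same scalar widths, hence the sim_anneal changes above.
using SlotOption  = std::tuple<uint64_t, double>;           // was std::tuple<uint64_t, float>
using SlotOptions = std::vector<std::vector<SlotOption>>;   // the `slots` argument
using ChosenIdx   = std::vector<int64_t>;                   // was std::vector<int>
```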
51 changes: 0 additions & 51 deletions exllamav2/exllamav2_ext/ext_quant.h

This file was deleted.

6 changes: 3 additions & 3 deletions exllamav2/exllamav2_ext/ext_rope.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
@@ -22,8 +22,8 @@ void rope_
torch::Tensor x,
torch::Tensor sin,
torch::Tensor cos,
int past_len,
int num_heads,
int64_t past_len,
int64_t num_heads,
int head_dim,
torch::Tensor offsets,
bool neox_style
12 changes: 0 additions & 12 deletions exllamav2/exllamav2_ext/ext_rope.h

This file was deleted.

2 changes: 0 additions & 2 deletions exllamav2/exllamav2_ext/ext_safetensors.h

This file was deleted.

55 changes: 0 additions & 55 deletions exllamav2/exllamav2_ext/ext_sampling.h

This file was deleted.

2 changes: 1 addition & 1 deletion exllamav2/exllamav2_ext/ext_tp.cpp
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
124 changes: 0 additions & 124 deletions exllamav2/exllamav2_ext/ext_tp.h

This file was deleted.