diff --git a/03_nf4_dequant/flashzxi/CMakeLists.txt b/03_nf4_dequant/flashzxi/CMakeLists.txt new file mode 100644 index 0000000..3a3f635 --- /dev/null +++ b/03_nf4_dequant/flashzxi/CMakeLists.txt @@ -0,0 +1,42 @@ +cmake_minimum_required(VERSION 3.18) + +project(nf4_dequant LANGUAGES CXX CUDA) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +set(CMAKE_CUDA_ARCHITECTURES native) + +if(NOT CMAKE_CONFIGURATION_TYPES AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS + "Debug" "Release" "RelWithDebInfo" "MinSizeRel") +endif() + +add_executable(nf4_dequant + src/main.cu + src/nf4_dequant_naive.cu + src/nf4_dequant_warp8.cu +) + +target_include_directories(nf4_dequant PRIVATE + ${CMAKE_SOURCE_DIR}/include +) + +# 单 TU/简单工程:关闭 RDC 更利于调试 +set_target_properties(nf4_dequant PROPERTIES + CUDA_SEPARABLE_COMPILATION OFF +) + +target_compile_options(nf4_dequant PRIVATE + $<$,$>:-g -O0> + $<$,$>:-G -g -O0> + + $<$,$>:-O3> + $<$,$>:-O3> + + $<$,$>:-g -O2> + $<$,$>:-lineinfo -g -O2> +) \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/README.md b/03_nf4_dequant/flashzxi/README.md new file mode 100644 index 0000000..e69de29 diff --git a/03_nf4_dequant/flashzxi/Report.md b/03_nf4_dequant/flashzxi/Report.md new file mode 100644 index 0000000..258aada --- /dev/null +++ b/03_nf4_dequant/flashzxi/Report.md @@ -0,0 +1,16 @@ +## NF4 反量化 +author: flashzxi + +本项目是利用cuda高效计算nf4反量化,对比bitsandbytes 实现 + +本项目的假设: +每个block大小为64个元素 + +二级量化每个group包含256个block. 
+ +## 实现 +总共实现了三个版本,一个最简单的naive版本,一个二级反量化和一级反量化分开计算的版本以及最终的单独kernel解两层反量化的版本。其中naive版本在`src/nf4_dequant_naive.cu`,其余两个版本都在`src/nf4_dequant_warp8.cu` + + + +开发工程中,我尝试 \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/include/common.cuh b/03_nf4_dequant/flashzxi/include/common.cuh new file mode 100644 index 0000000..87eb1d2 --- /dev/null +++ b/03_nf4_dequant/flashzxi/include/common.cuh @@ -0,0 +1,127 @@ +// +// Created by core_dump on 2026/2/25. +// + +#pragma once + +#include +#include +#include +#include +#include + +__host__ __device__ __forceinline__ +float mix_mul(float fp, __half h) { + return fp * __half2float(h); +} + +__host__ __device__ __forceinline__ +float mix_mul(float fp, __nv_bfloat16 h) { + return fp * __bfloat162float(h); +} + +__host__ __device__ __forceinline__ +float f162float(__half h) { + return __half2float(h); +} + +__host__ __device__ __forceinline__ +float f162float(__nv_bfloat16 h) { + return __bfloat162float(h); +} + + +#define CUDA_CHECK(call) \ +{ \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \ + << " - " << cudaGetErrorString(err) << "\n"; \ + std::exit(-1); \ + } \ +} + +class Timer { +public: + using clock = std::chrono::high_resolution_clock; + + Timer() : running_(false), elapsed_ms_(0.0) {} + + void tic() { + start_ = clock::now(); + running_ = true; + } + + double toc() { + if (!running_) { + return elapsed_ms_; + } + auto end = clock::now(); + elapsed_ms_ = std::chrono::duration(end - start_).count(); + running_ = false; + return elapsed_ms_; + } + + double elapsed() const { + if (!running_) { + return elapsed_ms_; + } + auto now = clock::now(); + return std::chrono::duration(now - start_).count(); + } + + void reset() { + running_ = false; + elapsed_ms_ = 0.0; + } + +private: + clock::time_point start_; + bool running_; + double elapsed_ms_; +}; + +class Tracer { +public: + Tracer() {} + + void start() { + timer_.reset(); + timer_.tic(); 
+ } + + void stop() { + total_elapsed_ms_ += timer_.toc(); + } + + Tracer& memcpy_accumulate(uint64_t cpy_size_in_byte) { + total_data_cpy_in_bytes_ += cpy_size_in_byte; + return *this; + } + + double bandwidth_bytes_per_s() const { + if (total_elapsed_ms_ <= 0.0) { + return 0.0; + } + return static_cast(total_data_cpy_in_bytes_) * 1000.0 / total_elapsed_ms_; + } + + double bandwidth_gib_per_s() const { + if (total_elapsed_ms_ <= 0.0) { + return 0.0; + } + constexpr double kBytesPerGiB = 1024.0 * 1024.0 * 1024.0; + return static_cast(total_data_cpy_in_bytes_) * 1000.0 / total_elapsed_ms_ / kBytesPerGiB; + } + + void print(std::ostream& os = std::cout) const { + os << "elapsed: " << total_elapsed_ms_ << " ms, " + << "effective bandwidth: " << bandwidth_gib_per_s() << " GiB/s\n"; + } + +private: + Timer timer_; + + uint64_t total_data_cpy_in_bytes_ = 0; + double total_elapsed_ms_; +}; diff --git a/03_nf4_dequant/flashzxi/include/nf4_dequant.h b/03_nf4_dequant/flashzxi/include/nf4_dequant.h new file mode 100644 index 0000000..85dae58 --- /dev/null +++ b/03_nf4_dequant/flashzxi/include/nf4_dequant.h @@ -0,0 +1,10 @@ +// +// Created by core_dump on 2/25/26. 
+// + +#pragma once + +#include "quant_state.h" +void nf4_dequant_naive(const QuantState& quant_state, __half* output); +void nf4_dequant_warp8_batch32_two_phase(const QuantState& quant_state, __half* output); +void nf4_dequant_warp8_batch8_one_phase(const QuantState& quant_state, __half* output); \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/include/quant_state.h b/03_nf4_dequant/flashzxi/include/quant_state.h new file mode 100644 index 0000000..981cc07 --- /dev/null +++ b/03_nf4_dequant/flashzxi/include/quant_state.h @@ -0,0 +1,269 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct QuantState { + // header + int num_rows = 0; + int num_cols = 0; + int block_size = 0; + int group_size = 256; // baseline给的是256 + + // data (host) + uint8_t* packed_weights = nullptr; // 每字节存两个 4-bit 索引 + uint8_t* absmax_q = nullptr; + __half* absmax2 = nullptr; + __half code2[256]{}; + float offset = 0.f; + + // runtime param + std::string compute_type; + std::string target_gpu; + + int num_elements = 0; + int num_blocks = 0; + int num_groups = 0; + + __half* ref_result = nullptr; + + int packed_weights_len_in_bytes = 0; + int absmax_q_len_in_bytes = 0; + int absmax2_len_in_bytes = 0; + + void calculate_params() { + num_elements = num_rows * num_cols; + // group_size = 256; + num_blocks = (num_elements + block_size - 1) / block_size; + num_groups = (num_blocks + group_size - 1) / group_size; + + packed_weights_len_in_bytes = (num_elements + 1) / 2; + absmax_q_len_in_bytes = num_blocks; + absmax2_len_in_bytes = 2 * num_groups; // fp16 bytes + } + + void print() { + std::cout << "[header]" << std::endl; + std::cout << "num_rows: " << num_rows << std::endl; + std::cout << "num_cols: " << num_cols << std::endl; + std::cout << "blocksize: " << block_size << std::endl; + + std::cout << std::endl; + std::cout << "[data]" << std::endl; + std::cout << "packed_weights: " << std::endl; + int 
print_cnt = 0; + for (int i = 0; i < num_elements; i += 2) { + uint8_t v = packed_weights[i / 2]; + int lower = v & 0xF; + int upper = v >> 4; + std::cout << upper << "\t"; + print_cnt ++; + if (print_cnt == num_cols) { + std::cout << std::endl; + print_cnt = 0; + } + std::cout << lower << "\t"; + print_cnt ++; + if (print_cnt == num_cols) { + std::cout << std::endl; + print_cnt = 0; + } + } + std::cout << "absmax_q:" << std::endl; + for (int i = 0; i < num_blocks; ++i) { + std::cout << (int)absmax_q[i] << " "; + } + std::cout << std::endl; + + std::cout << "absmax2:" << std::endl; + for (int i = 0; i < num_groups; ++i) { + std::cout << __half2float(absmax2[i]) << " "; + } + std::cout << std::endl; + + std::cout << "code2: " << std::endl; + for (int i = 0; i < 256; ++i) { + std::cout << __half2float(code2[i]) << " "; + } + std::cout << std::endl; + + std::cout << "offset: " << offset << std::endl; + + } +}; +// --------- helpers: streaming parse, no full-file scanning ---------- + +static void expect_text(std::istream& is, const char* s) { + for (const char* p = s; *p; ++p) { + char c; + if (!is.get(c)) { + throw std::runtime_error(std::string("Unexpected EOF while expecting: ") + s); + } + if (c != *p) { + std::string msg = "Tag mismatch. 
Expect: "; + msg += s; + msg += " (got different byte)"; + throw std::runtime_error(msg); + } + } +} + +template +static T read_pod(std::istream& is) { + T v{}; + if (!is.read(reinterpret_cast(&v), sizeof(T))) { + throw std::runtime_error("Failed to read POD bytes"); + } + return v; // 假设小端;你写文件也用小端 pack +} + +static void read_bytes(std::istream& is, void* dst, size_t n) { + if (n == 0) return; + if (!is.read(reinterpret_cast(dst), static_cast(n))) { + throw std::runtime_error("Failed to read raw bytes"); + } +} + +static std::string trim_copy(std::string s) { + auto not_space = [](unsigned char ch){ return !std::isspace(ch); }; + s.erase(s.begin(), std::find_if(s.begin(), s.end(), not_space)); + s.erase(std::find_if(s.rbegin(), s.rend(), not_space).base(), s.end()); + return s; +} + +static std::string strip_quotes(std::string s) { + s = trim_copy(std::move(s)); + if (s.size() >= 2) { + char a = s.front(), b = s.back(); + if ((a == '"' && b == '"') || (a == '\'' && b == '\'')) { + return s.substr(1, s.size() - 2); + } + } + return s; +} + +// input_data: w_nf4.bin +// input_conf: 目前不用(保留接口) +static QuantState parse_quant_state(const std::string& input_data, + const std::string& input_conf, + const std::string& ref_result = "") { + + std::ifstream is(input_data, std::ios::binary); + if (!is) { + throw std::runtime_error("Failed to open file: " + input_data); + } + + QuantState st; + + // 你的文件格式(标签文本 + 紧跟二进制)必须严格一致: + // [header]\n + // num_rows: \n + // num_cols: \n + // blocksize: \n + // + // [data]\n + // packed_weights: \n + // absmax_q: \n + // absmax2: \n + // code2: \n + // offset: \n + + expect_text(is, "[header]\n"); + + expect_text(is, "num_rows: "); + int64_t num_rows64 = read_pod(is); + + expect_text(is, "\nnum_cols: "); + int64_t num_cols64 = read_pod(is); + + expect_text(is, "\nblocksize: "); + int32_t blocksize32 = read_pod(is); + + // 注意:QuantState 里用 int,正常矩阵规模不会溢出 + st.num_rows = static_cast(num_rows64); + st.num_cols = static_cast(num_cols64); + 
st.block_size = static_cast(blocksize32); + + st.calculate_params(); + + // header 后你写了 "\n\n[data]\n" + expect_text(is, "\n\n[data]\n"); + + expect_text(is, "packed_weights: "); + st.packed_weights = new uint8_t[st.packed_weights_len_in_bytes]; + read_bytes(is, st.packed_weights, static_cast(st.packed_weights_len_in_bytes)); + + expect_text(is, "\nabsmax_q: "); + st.absmax_q = new uint8_t[st.absmax_q_len_in_bytes]; + read_bytes(is, st.absmax_q, static_cast(st.absmax_q_len_in_bytes)); + + expect_text(is, "\nabsmax2: "); + st.absmax2 = new __half[st.num_groups]; + read_bytes(is, st.absmax2, static_cast(st.absmax2_len_in_bytes)); + + expect_text(is, "\ncode2: "); + read_bytes(is, st.code2, sizeof(__half) * 256); + + expect_text(is, "\noffset: "); + st.offset = read_pod(is); + + + std::ifstream i_conf(input_conf); + if (!i_conf) { + throw std::runtime_error("Failed to open conf file: " + input_conf); + } + + std::string line; + + while (std::getline(i_conf, line)) { + if (!line.empty() && line.back() == '\r') line.pop_back(); // 兼容 CRLF + + // 去掉注释:支持 # 和 // + auto cut_comment = [&](const std::string& marker) { + auto pos = line.find(marker); + if (pos != std::string::npos) line = line.substr(0, pos); + }; + cut_comment("#"); + cut_comment("//"); + + line = trim_copy(line); + if (line.empty()) continue; + + auto eq = line.find('='); + if (eq == std::string::npos) continue; + + std::string key = trim_copy(line.substr(0, eq)); + std::string val = trim_copy(line.substr(eq + 1)); + + if (key == "blocksize") { + int bs = std::stoi(val); + st.block_size = bs; + } else if (key == "compute_type") { + st.compute_type = strip_quotes(val); + } else if (key == "target_gpu") { + st.target_gpu = strip_quotes(val); + } + } + + if (!ref_result.empty()) { + std::ifstream i_ref_res(ref_result); + if (!i_ref_res) { + throw std::runtime_error("Failed to open conf file: " + ref_result); + } + st.ref_result = new __half[st.num_elements]; + if 
(!i_ref_res.read(reinterpret_cast<char*>(st.ref_result), static_cast<std::streamsize>(st.num_elements * 2))) {
            throw std::runtime_error("Failed to read raw bytes");
        }
    }

    return st;
}

diff --git a/03_nf4_dequant/flashzxi/src/main.cu b/03_nf4_dequant/flashzxi/src/main.cu
new file mode 100644
index 0000000..ce33734
--- /dev/null
+++ b/03_nf4_dequant/flashzxi/src/main.cu
@@ -0,0 +1,45 @@
//
// Created by flashzxi on 2/24/26.
//
#include "quant_state.h"
#include "cuda_runtime.h"
#include "nf4_dequant.h"

// https://gxtctab8no8.feishu.cn/wiki/UoESwCDZ2iZRcLkdzjvcxTgenOb?from=from_copylink

int main() {
    int row = 10000;
    int col = 10000;
    std::string file_prefix = std::string("/home/core_dump/Learning-CUDA/03_nf4_dequant/flashzxi/nf4_") + std::to_string(row) + "x" + std::to_string(col) + "_fp16";
    auto conf = parse_quant_state(file_prefix + ".bin",
        "/home/core_dump/Learning-CUDA/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini",
        file_prefix + "_w_dequant.bin");

    // conf.print();

    // std::cout << "real absmax: ";
    // for (int i = 0; i < 4; i ++) {
    //     int idx = conf.absmax_q[i];
    //     std::cout << __half2float(conf.code2[idx] * conf.absmax2[0]) + conf.offset << " ";
    // }

    std::cout << std::endl;
    __half* ans = new __half[conf.num_elements];
    nf4_dequant_warp8_batch8_one_phase(conf, ans);

    float max_diff = 0.f;
    for (int i = 0; i < conf.num_rows; i++) {
        for (int j = 0; j < conf.num_cols; j++) {
            int idx = i * conf.num_cols + j;
            float a = __half2float(ans[idx]);
            float b = __half2float(conf.ref_result[idx]);
            float diff = fabsf(a - b);
            diff /= fmaxf(fabsf(b), 1e-8f);  // BUGFIX: relative error must divide by |b| (signed b made negative ratios that max() dropped) and guard b == 0
            max_diff = std::max(max_diff, diff);
            // std::cout << a << " ";
        }
        // std::cout << "\n";
    }
    std::cout << "max_diff = " << max_diff << "\n";
}

diff --git a/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu b/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu
new file mode 100644
index 0000000..5d5891f
--- /dev/null
+++ b/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu
@@ -0,0 
+1,154 @@ +// +// Created by core_dump on 2026/2/25. +// +#include +#include +#include +#include +#include "quant_state.h" +#include "common.cuh" +#include "nf4_dequant.h" + +template +__global__ void dequant_absmax_kernel(const uint8_t* __restrict__ absmax_q, + const FP_T* __restrict__ absmax2, + const FP_T* __restrict__ code2, // 256 + int num_blocks, + int group_size, // blocks per group + float offset, + float* __restrict__ absmax_out) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= num_blocks) return; + + int group = i / group_size; + float s2 = f162float(absmax2[group]); + float c = f162float(code2[absmax_q[i]]); + absmax_out[i] = c * s2 + offset; +} + +// 每个block 128个线程,每个线程负责2个 +template +__global__ void dequant_nf4_kernel(const uint8_t* __restrict__ packed, + const float* __restrict__ absmax, + int num_elements, + int block_size, + OUT_T* __restrict__ out) { + float kNF4[16] = { + -1.0000000f, -0.6961928f, -0.5250731f, -0.3949175f, + -0.2844414f, -0.1847734f, -0.0910500f, 0.0000000f, + 0.0795803f, 0.1609302f, 0.2461123f, 0.3379152f, + 0.4407098f, 0.5626170f, 0.7229568f, 1.0000000f + }; + int t = blockIdx.x * blockDim.x + threadIdx.x; + int elem0 = t * 2; + if (elem0 >= num_elements) return; + + uint8_t byte = packed[t]; + int lo = byte & 0x0F; + int hi = byte >> 4; + + float s0 = absmax[elem0 / block_size]; + float v0 = s0 * kNF4[hi]; + if constexpr (std::is_same_v) { + out[elem0] = __float2half(v0); + } else { + out[elem0] = __float2bfloat16(v0); + } + + int elem1 = elem0 + 1; + if (elem1 < num_elements) { + float s1 = absmax[elem1 / block_size]; + float v1 = s1 * kNF4[lo]; + if constexpr (std::is_same_v) { + out[elem1] = __float2half(v1); + } else { + out[elem1] = __float2bfloat16(v1); + } + } +} + +void nf4_dequant_naive(const QuantState& quant_state, __half* output) { + // 解码scale + uint8_t* scale_q_s; + __half* code2_s; + __half* absmax2_s; + + CUDA_CHECK(cudaMalloc(&scale_q_s, quant_state.num_blocks)); + CUDA_CHECK(cudaMalloc(&code2_s, 
256 * sizeof(__half))); + CUDA_CHECK(cudaMalloc(&absmax2_s, quant_state.num_groups * sizeof(__half))); + + CUDA_CHECK(cudaMemcpy(scale_q_s, quant_state.absmax_q, quant_state.absmax_q_len_in_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(code2_s, quant_state.code2, 256 * sizeof(__half), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(absmax2_s, quant_state.absmax2, quant_state.num_groups * sizeof(__half), cudaMemcpyHostToDevice)); + + dim3 dequant_scale_block_dim(128); + dim3 dequant_scale_grid_dim((quant_state.num_blocks + 128 - 1) / 128); + // 解码权重 + float* absmax = nullptr; + size_t absmax_bytes = sizeof(float) * quant_state.num_blocks; + CUDA_CHECK(cudaMalloc(&absmax, absmax_bytes)); + + float* absmax_h = new float[quant_state.num_blocks]; + + if (quant_state.compute_type == "bf16") { + dequant_absmax_kernel<__nv_bfloat16><<>>( + scale_q_s, (__nv_bfloat16*) absmax2_s, + (__nv_bfloat16*) code2_s, quant_state.num_blocks, quant_state.group_size, quant_state.offset, absmax + ); + } else if (quant_state.compute_type == "fp16") { + dequant_absmax_kernel<__half><<>>( + scale_q_s, (__half*) absmax2_s, + (__half*) code2_s, quant_state.num_blocks, quant_state.group_size, quant_state.offset,absmax + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + cudaMemcpy(absmax_h, absmax, quant_state.num_blocks * sizeof(float), cudaMemcpyDeviceToHost); + for (int i = 0; i < quant_state.num_blocks; i++) { + std::cout << absmax_h[i] << " "; + } + std::cout << std::endl; + + CUDA_CHECK(cudaFree(scale_q_s)); + CUDA_CHECK(cudaFree(code2_s)); + CUDA_CHECK(cudaFree(absmax2_s)); + + uint8_t* packed_weights_s; + // output + __half* unpacked_weights_s; + + CUDA_CHECK(cudaMalloc(&packed_weights_s, quant_state.packed_weights_len_in_bytes)); + CUDA_CHECK(cudaMalloc(&unpacked_weights_s, quant_state.num_elements * sizeof(__half))); + + 
CUDA_CHECK(cudaMemcpy(packed_weights_s, quant_state.packed_weights, quant_state.packed_weights_len_in_bytes, cudaMemcpyHostToDevice)) + + dim3 dequant_weights_grid_dim((quant_state.packed_weights_len_in_bytes + dequant_scale_block_dim.x - 1) / dequant_scale_block_dim.x); + + if (quant_state.compute_type == "bf16") { + dequant_nf4_kernel<__nv_bfloat16><<>> ( + packed_weights_s, absmax, quant_state.num_elements, + quant_state.block_size, (__nv_bfloat16*) unpacked_weights_s + ); + } else if (quant_state.compute_type == "fp16") { + dequant_nf4_kernel<__half><<>> ( + packed_weights_s, absmax, quant_state.num_elements, + quant_state.block_size, (__half*) unpacked_weights_s + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(output, unpacked_weights_s, quant_state.num_elements * sizeof(__half), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaFree(packed_weights_s)); + CUDA_CHECK(cudaFree(absmax)); + CUDA_CHECK(cudaFree(unpacked_weights_s)); +} \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu b/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu new file mode 100644 index 0000000..118aa05 --- /dev/null +++ b/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu @@ -0,0 +1,390 @@ +// +// Created by flashzxi on 2/24/26. 
+// +#include +#include +#include +#include +#include "quant_state.h" +#include "nf4_dequant.h" +#include "common.cuh" + +#define LDST32BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) + +// code2 为 256 * f16 +// 每个线程load 2 个,需要128个线程, 故设置一个block 128个线程,每个线程处理N个计算 +// 总计处理128 * N个数据, N 是2的幂 且不小于8 +// 结尾不够需要padding +template +__global__ void dequant_nf4_scale_warp8_batchN_kernel( + uint8_t* scale_q, + HFP_T* code2, + HFP_T* absmax2, + int num_blocks, + int group_size, + float offset, + float* output) { + int lane_id = threadIdx.x; + + // load code2 + __shared__ float shm_code2_float[128]; + + LDST32BITS(shm_code2_float[lane_id]) = LDST32BITS(code2[2 * lane_id]); + HFP_T* shm_code2 = (HFP_T *) shm_code2_float; + __syncthreads(); + + // 一次处理8个数据 + constexpr int loop_times = N / 8; + int g_scale_q_offset_base = blockIdx.x * 128 * N; + + alignas(16) uint8_t fragment[8]; + alignas(16) float cache_res[8]; +#pragma unroll + for (int i = 0; i < loop_times; ++i) { + int scale_offset = g_scale_q_offset_base + i * 128 * 8 + lane_id * 8; + HFP_T scale2 = absmax2[scale_offset / group_size]; + + if (scale_offset + 7 < num_blocks) { + LDST64BITS(fragment[0]) = LDST64BITS(*(scale_q + scale_offset)); +#pragma unroll + for (int j = 0; j < 8; ++j) { + cache_res[j] = f162float( shm_code2[fragment[j]] * scale2 ) + offset; + } + LDST128BITS(output[scale_offset]) = LDST128BITS(cache_res[0]); + LDST128BITS(output[scale_offset + 4]) = LDST128BITS(cache_res[4]); + } else if (scale_offset < num_blocks) { + // 不够一组,退化为每个元素load + int remains = num_blocks - scale_offset; + for (int j = 0; j < remains; ++j) { + fragment[j] = (scale_q + scale_offset)[j]; + cache_res[j] = f162float(shm_code2[fragment[j]] * scale2) + offset; + output[scale_offset + j] = cache_res[j]; + } + } + } +} + +// 一个block 128个线程 +// 每个线程负责N个, 每个block 负责 128 * N 个数据的解码 +template +__global__ void 
dequant_nf4_elements_warp8_batchN_kernel(uint8_t* packed_weights, float* absmax, int num_elements, int block_size, HFP_T* output) { + float kNF4[16] = { + -1.0000000f, -0.6961928f, -0.5250731f, -0.3949175f, + -0.2844414f, -0.1847734f, -0.0910500f, 0.0000000f, + 0.0795803f, 0.1609302f, 0.2461123f, 0.3379152f, + 0.4407098f, 0.5626170f, 0.7229568f, 1.0000000f + }; + uint8_t* packed_weights_end = packed_weights + (num_elements + 1) / 2; + + int bidx = blockIdx.x; + int lane_id = threadIdx.x; + + int block_offset = bidx * 128 * N; + + // 每次处理8个,32bits + alignas(16) uint8_t f_packed_weights[4]; + alignas(16) HFP_T cache_res[8]; + constexpr int loop_times = N / 8; +#pragma unroll + for (int i = 0; i < loop_times; ++i) { + int g_packed_weights_offset = block_offset + 8 * 128 * i + 8 * lane_id; + float scale = absmax[g_packed_weights_offset / block_size]; + if (packed_weights + g_packed_weights_offset / 2 + 4 < packed_weights_end) { + LDST32BITS(f_packed_weights[0]) = LDST32BITS(packed_weights[g_packed_weights_offset / 2]); +#pragma unroll + for (int j = 0; j < 4; ++j) { + uint8_t lower = f_packed_weights[j] & 0xF; + uint8_t upper = f_packed_weights[j] >> 4; + if constexpr (std::is_same_v) { + cache_res[2 * j] = __float2half(scale * kNF4[upper]); + cache_res[2 * j + 1] = __float2half(scale * kNF4[lower]); + } else if constexpr (std::is_same_v) { + cache_res[2 * j] = __float2bfloat16(scale * kNF4[upper]); + cache_res[2 * j + 1] = __float2bfloat16(scale * kNF4[lower]); + } + } + LDST128BITS(output[g_packed_weights_offset]) = LDST128BITS(cache_res[0]); + } else if (packed_weights + g_packed_weights_offset / 2 < packed_weights_end) { + int remains = num_elements - g_packed_weights_offset; + for (int j = 0; j < (remains + 1) / 2; ++j) { + f_packed_weights[0] = packed_weights[g_packed_weights_offset / 2 + j]; + uint8_t lower = f_packed_weights[0] & 0xF; + uint8_t upper = f_packed_weights[0] >> 4; + if constexpr (std::is_same_v) { + cache_res[0] = __float2half(scale * 
kNF4[upper]);
                    cache_res[1] = __float2half(scale * kNF4[lower]);
                } else if constexpr (std::is_same_v<HFP_T, __nv_bfloat16>) {
                    cache_res[0] = __float2bfloat16(scale * kNF4[upper]);
                    cache_res[1] = __float2bfloat16(scale * kNF4[lower]);
                }
                // BUGFIX: must test the index of the SECOND half (2*j + 1). The old test
                // (2*j >= num_elements) could never fire inside this loop, so an odd
                // num_elements tail did a 32-bit store one __half past the end of output.
                if (g_packed_weights_offset + 2 * j + 1 >= num_elements) {
                    // only the first decoded half is in range
                    output[g_packed_weights_offset + 2 * j] = cache_res[0];
                } else {
                    // packed 32-bit write-back of both halves
                    LDST32BITS(output[g_packed_weights_offset + 2 * j]) = LDST32BITS(cache_res[0]);
                }
            }
        }
    }
}

// One block = 128 threads.
// Each thread decodes N elements, so one block covers 128 * N elements.
// Fused kernel: dequantizes the per-block absmax (second level) and the NF4
// payload in a single pass.
template <typename HFP_T, int N>
__global__ void dequant_nf4_elements_one_phase_warp8_batchN_kernel(
    uint8_t* packed_weights,
    uint8_t* absmax_q,
    int num_elements,
    HFP_T* absmax2,
    HFP_T* code2,
    int block_size,
    int group_size,
    float offset,
    HFP_T* output) {
    constexpr float kNF4[16] = {
        -1.0000000f, -0.6961928f, -0.5250731f, -0.3949175f,
        -0.2844414f, -0.1847734f, -0.0910500f, 0.0000000f,
        0.0795803f, 0.1609302f, 0.2461123f, 0.3379152f,
        0.4407098f, 0.5626170f, 0.7229568f, 1.0000000f
    };
    uint8_t* packed_weights_end = packed_weights + (num_elements + 1) / 2;

    int bidx = blockIdx.x;
    int lane_id = threadIdx.x;

    // Loading code2 straight from global memory measured faster than staging it
    // in shared memory; the shared-memory variant is kept below for reference.
// __shared__ float shm_code2_float[128];

// LDST32BITS(shm_code2_float[lane_id]) = LDST32BITS(code2[2 * lane_id]);
// HFP_T* shm_code2 = (HFP_T *) shm_code2_float;
// __syncthreads();

    int block_offset = bidx * 128 * N;

    // process 8 elements per step (one 32-bit packed load)
    alignas(16) uint8_t f_packed_weights[4];
    alignas(16) HFP_T cache_res[8];
    constexpr int loop_times = N / 8;
#pragma unroll
    for (int i = 0; i < loop_times; ++i) {
        int g_packed_weights_offset = block_offset + 8 * 128 * i + 8 * lane_id;
        int block_idx = g_packed_weights_offset / block_size;
        int group_idx = block_idx / group_size;

        HFP_T h2[2];
        uint8_t q = absmax_q[block_idx];
        LDST32BITS(h2[0]) = LDST32BITS(code2[(q >> 1) << 1]); // 32-bit read of an aligned code2 pair
        HFP_T h = (q & 1) ? h2[1] : h2[0];
        float scale = f162float(h * absmax2[group_idx]) + offset;
        if (packed_weights + g_packed_weights_offset / 2 + 4 < packed_weights_end) {
            LDST32BITS(f_packed_weights[0]) = LDST32BITS(packed_weights[g_packed_weights_offset / 2]);
#pragma unroll
            for (int j = 0; j < 4; ++j) {
                uint8_t lower = f_packed_weights[j] & 0xF;
                uint8_t upper = f_packed_weights[j] >> 4;
                if constexpr (std::is_same_v<HFP_T, __half>) {
                    cache_res[2 * j] = __float2half(scale * kNF4[upper]);
                    cache_res[2 * j + 1] = __float2half(scale * kNF4[lower]);
                } else if constexpr (std::is_same_v<HFP_T, __nv_bfloat16>) {
                    cache_res[2 * j] = __float2bfloat16(scale * kNF4[upper]);
                    cache_res[2 * j + 1] = __float2bfloat16(scale * kNF4[lower]);
                }
            }
            LDST128BITS(output[g_packed_weights_offset]) = LDST128BITS(cache_res[0]);
        } else if (packed_weights + g_packed_weights_offset / 2 < packed_weights_end) {
            int remains = num_elements - g_packed_weights_offset;
            for (int j = 0; j < (remains + 1) / 2; ++j) {
                f_packed_weights[0] = packed_weights[g_packed_weights_offset / 2 + j];
                uint8_t lower = f_packed_weights[0] & 0xF;
                uint8_t upper = f_packed_weights[0] >> 4;
                if constexpr (std::is_same_v<HFP_T, __half>) {
                    cache_res[0] = __float2half(scale * kNF4[upper]);
                    cache_res[1] = __float2half(scale * kNF4[lower]);
                } else if constexpr (std::is_same_v<HFP_T, __nv_bfloat16>) {
                    cache_res[0] = __float2bfloat16(scale * kNF4[upper]);
                    cache_res[1] = __float2bfloat16(scale * kNF4[lower]);
                }
                // BUGFIX: same off-by-one as the two-phase element kernel — check 2*j + 1,
                // otherwise the packed store below writes one __half out of bounds.
                if (g_packed_weights_offset + 2 * j + 1 >= num_elements) {
                    // only the first decoded half is in range
                    output[g_packed_weights_offset + 2 * j] = cache_res[0];
                } else {
                    // packed 32-bit write-back of both halves
                    LDST32BITS(output[g_packed_weights_offset + 2 * j]) = LDST32BITS(cache_res[0]);
                }
            }
        }
    }
}

void nf4_dequant_warp8_batch32_two_phase(const QuantState& quant_state, __half* output) {
    constexpr int PROCESS_SIZE = 32;

    // decode the per-block scales first
    uint8_t* scale_q_s;
    __half* code2_s;
    __half* absmax2_s;

    CUDA_CHECK(cudaMalloc(&scale_q_s, quant_state.num_blocks));
    CUDA_CHECK(cudaMalloc(&code2_s, 256 * 
sizeof(__half))); + CUDA_CHECK(cudaMalloc(&absmax2_s, quant_state.num_groups * sizeof(__half))); + + CUDA_CHECK(cudaMemcpy(scale_q_s, quant_state.absmax_q, quant_state.absmax_q_len_in_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(code2_s, quant_state.code2, 256 * sizeof(__half), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(absmax2_s, quant_state.absmax2, quant_state.num_groups * sizeof(__half), cudaMemcpyHostToDevice)); + + Tracer tracer; + tracer.memcpy_accumulate(quant_state.num_blocks) + .memcpy_accumulate(256 * sizeof(__half)) + .memcpy_accumulate(quant_state.num_groups * sizeof(__half)) + .memcpy_accumulate(quant_state.packed_weights_len_in_bytes) + .memcpy_accumulate(quant_state.num_elements * sizeof(__half)); + + dim3 dequant_scale_block_dim(128); + dim3 dequant_scale_grid_dim((quant_state.num_blocks + dequant_scale_block_dim.x * PROCESS_SIZE - 1) / (dequant_scale_block_dim.x * PROCESS_SIZE)); + // 解码权重 + float* absmax = nullptr; + size_t absmax_bytes = sizeof(float) * quant_state.num_blocks; + CUDA_CHECK(cudaMalloc(&absmax, absmax_bytes)); + + float* absmax_h = new float[quant_state.num_blocks]; + + tracer.start(); + if (quant_state.compute_type == "bf16") { + dequant_nf4_scale_warp8_batchN_kernel<__nv_bfloat16, PROCESS_SIZE><<>>( + scale_q_s, (__nv_bfloat16*) code2_s, + (__nv_bfloat16*) absmax2_s, quant_state.num_blocks, quant_state.group_size, quant_state.offset, absmax + ); + } else if (quant_state.compute_type == "fp16") { + dequant_nf4_scale_warp8_batchN_kernel<__half, PROCESS_SIZE><<>>( + scale_q_s, (__half*) code2_s, + (__half*) absmax2_s, quant_state.num_blocks, quant_state.group_size, quant_state.offset,absmax + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + tracer.stop(); + + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + cudaMemcpy(absmax_h, absmax, quant_state.num_blocks * sizeof(float), cudaMemcpyDeviceToHost); + + 
CUDA_CHECK(cudaFree(scale_q_s)); + CUDA_CHECK(cudaFree(code2_s)); + CUDA_CHECK(cudaFree(absmax2_s)); + + uint8_t* packed_weights_s; + // output + __half* unpacked_weights_s; + + CUDA_CHECK(cudaMalloc(&packed_weights_s, quant_state.packed_weights_len_in_bytes)); + CUDA_CHECK(cudaMalloc(&unpacked_weights_s, quant_state.num_elements * sizeof(__half))); + CUDA_CHECK(cudaMemcpy(packed_weights_s, quant_state.packed_weights, quant_state.packed_weights_len_in_bytes, cudaMemcpyHostToDevice)) + + dim3 dequant_weights_grid_dim((quant_state.num_elements + dequant_scale_block_dim.x * PROCESS_SIZE - 1) / (dequant_scale_block_dim.x * PROCESS_SIZE)); + + tracer.start(); + if (quant_state.compute_type == "bf16") { + dequant_nf4_elements_warp8_batchN_kernel<__nv_bfloat16, PROCESS_SIZE><<>> ( + packed_weights_s, absmax, quant_state.num_elements, + quant_state.block_size, (__nv_bfloat16*) unpacked_weights_s + ); + } else if (quant_state.compute_type == "fp16") { + dequant_nf4_elements_warp8_batchN_kernel<__half, PROCESS_SIZE><<>> ( + packed_weights_s, absmax, quant_state.num_elements, + quant_state.block_size, (__half*) unpacked_weights_s + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + tracer.stop(); + tracer.print(); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(output, unpacked_weights_s, quant_state.num_elements * sizeof(__half), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaFree(packed_weights_s)); + CUDA_CHECK(cudaFree(absmax)); + CUDA_CHECK(cudaFree(unpacked_weights_s)); +} + +void nf4_dequant_warp8_batch8_one_phase(const QuantState& quant_state, __half* output) { + constexpr int PROCESS_SIZE = 8; + + uint8_t* absmax_q_s; + __half* code2_s; + __half* absmax2_s; + uint8_t* packed_weights_s; + // output + __half* unpacked_weights_s; + + CUDA_CHECK(cudaMalloc(&absmax_q_s, quant_state.num_blocks)); + CUDA_CHECK(cudaMalloc(&code2_s, 256 * sizeof(__half))); + 
CUDA_CHECK(cudaMalloc(&absmax2_s, quant_state.num_groups * sizeof(__half))); + CUDA_CHECK(cudaMalloc(&packed_weights_s, quant_state.packed_weights_len_in_bytes)); + CUDA_CHECK(cudaMalloc(&unpacked_weights_s, quant_state.num_elements * sizeof(__half))); + Tracer tracer; + tracer.memcpy_accumulate(quant_state.num_blocks) + .memcpy_accumulate(256 * sizeof(__half)) + .memcpy_accumulate(quant_state.num_groups * sizeof(__half)) + .memcpy_accumulate(quant_state.packed_weights_len_in_bytes) + .memcpy_accumulate(quant_state.num_elements * sizeof(__half)); + + CUDA_CHECK(cudaMemcpy(absmax_q_s, quant_state.absmax_q, quant_state.absmax_q_len_in_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(code2_s, quant_state.code2, 256 * sizeof(__half), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(absmax2_s, quant_state.absmax2, quant_state.num_groups * sizeof(__half), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(packed_weights_s, quant_state.packed_weights, quant_state.packed_weights_len_in_bytes, cudaMemcpyHostToDevice)) + + dim3 dequant_scale_block_dim(128); + dim3 dequant_weights_grid_dim((quant_state.num_elements + dequant_scale_block_dim.x * PROCESS_SIZE - 1) / (dequant_scale_block_dim.x * PROCESS_SIZE)); + + tracer.start(); + if (quant_state.compute_type == "bf16") { + dequant_nf4_elements_one_phase_warp8_batchN_kernel<__nv_bfloat16, PROCESS_SIZE><<>> ( + packed_weights_s, + absmax_q_s, + quant_state.num_elements, + (__nv_bfloat16*) absmax2_s, + (__nv_bfloat16*) code2_s, + quant_state.block_size, + quant_state.group_size, + quant_state.offset, + (__nv_bfloat16*) unpacked_weights_s + ); + } else if (quant_state.compute_type == "fp16") { + dequant_nf4_elements_one_phase_warp8_batchN_kernel<__half, PROCESS_SIZE><<>> ( + packed_weights_s, + absmax_q_s, + quant_state.num_elements, + (__half*) absmax2_s, + (__half*) code2_s, + quant_state.block_size, + quant_state.group_size, + quant_state.offset, + (__half*) unpacked_weights_s + ); + } else { + std::cerr << "Type Not 
Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + tracer.stop(); + tracer.print(); + + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(output, unpacked_weights_s, quant_state.num_elements * sizeof(__half), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaFree(absmax_q_s)); + CUDA_CHECK(cudaFree(code2_s)); + CUDA_CHECK(cudaFree(absmax2_s)); + CUDA_CHECK(cudaFree(packed_weights_s)); + CUDA_CHECK(cudaFree(unpacked_weights_s)); + +} \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/conf/blocksize128_bf16_T4.ini b/03_nf4_dequant/flashzxi/test/conf/blocksize128_bf16_T4.ini new file mode 100644 index 0000000..f8359b5 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/conf/blocksize128_bf16_T4.ini @@ -0,0 +1,3 @@ +blocksize = 128 +compute_type = "bf16" +target_gpu = "T4" \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/conf/blocksize128_fp16_T4.ini b/03_nf4_dequant/flashzxi/test/conf/blocksize128_fp16_T4.ini new file mode 100644 index 0000000..8fc1503 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/conf/blocksize128_fp16_T4.ini @@ -0,0 +1,3 @@ +blocksize = 128 +compute_type = "fp16" +target_gpu = "T4" \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/conf/blocksize64_bf16_T4.ini b/03_nf4_dequant/flashzxi/test/conf/blocksize64_bf16_T4.ini new file mode 100644 index 0000000..a7c8fa0 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/conf/blocksize64_bf16_T4.ini @@ -0,0 +1,3 @@ +blocksize = 64 +compute_type = "bf16" +target_gpu = "T4" \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini b/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini new file mode 100644 index 0000000..b80eab6 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini @@ -0,0 +1,3 @@ +blocksize = 64 +compute_type = "fp16" +target_gpu = "T4" \ No newline at end of file diff --git 
a/03_nf4_dequant/flashzxi/test/data/baseline.py b/03_nf4_dequant/flashzxi/test/data/baseline.py new file mode 100644 index 0000000..857e548 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/data/baseline.py @@ -0,0 +1,146 @@ +import struct +import torch +import bitsandbytes.functional as F +import time + +def _dequant_bnb(qweight: torch.Tensor, qs): + """ + 兼容不同 bnb 版本的反量化入口: + 优先用 dequantize_4bit;没有的话再退到 dequantize_blockwise。 + """ + if hasattr(F, "dequantize_4bit"): + # 新版常见:直接传 quant_state + return F.dequantize_4bit(qweight, quant_state=qs) + if hasattr(F, "dequantize_blockwise"): + # 老版可能需要 absmax/code 等;但如果传 quant_state 通常也能工作 + return F.dequantize_blockwise(qweight, quant_state=qs) + raise RuntimeError("当前 bitsandbytes.functional 里找不到 dequantize_4bit / dequantize_blockwise") + +def save_nf4_tagged_binary(path: str, W: torch.Tensor, blocksize: int = 64): + """ + 写 w_nf4.bin: + [header]\n + num_rows:\n + num_cols:\n + blocksize:\n\n + [data]\n + packed_weights:\n + absmax_q:\n + absmax2:\n + code2:\n + offset:\n + """ + assert W.ndim == 2 and W.is_cuda + + num_rows, num_cols = map(int, W.shape) + num_elements = num_rows * num_cols + num_blocks = (num_elements + blocksize - 1) // blocksize + + qweight, qs = F.quantize_4bit( + W, + blocksize=blocksize, + quant_type="nf4", + compress_statistics=True, + quant_storage=torch.uint8, + ) + if not getattr(qs, "nested", False) or qs.state2 is None: + raise RuntimeError("需要 compress_statistics=True 才会有 absmax_q/absmax2/code2/offset") + + packed = qweight.detach().contiguous().view(torch.uint8).cpu() + packed_len = (num_elements + 1) // 2 + if packed.numel() != packed_len: + raise RuntimeError(f"packed_weights len mismatch: got={packed.numel()} expected={packed_len}") + + absmax_q = qs.absmax.detach().contiguous().view(torch.uint8).cpu() + if absmax_q.numel() != num_blocks: + raise RuntimeError(f"absmax_q len mismatch: got={absmax_q.numel()} expected={num_blocks}") + + absmax2 = 
qs.state2.absmax.detach().contiguous().cpu().to(torch.float16) + code2 = qs.state2.code.detach().contiguous().cpu().to(torch.float16) + if code2.numel() != 256: + raise RuntimeError(f"code2 len mismatch: got={code2.numel()} expected=256") + + offset = float(qs.offset) if qs.offset is not None else 0.0 + + with open(path, "wb") as f: + f.write(b"[header]\n") + f.write(b"num_rows: ") + f.write(struct.pack("\n + num_cols:\n + dtype:<1 byte tag>\n + data:\n + + dtype tag: 1 = fp16, 2 = fp32 + """ + num_rows, num_cols = shape + deq2d = deq.reshape(num_rows, num_cols).detach() + + if out_dtype == torch.float16: + tag = 1 + host = deq2d.to(torch.float16).contiguous().cpu() + elif out_dtype == torch.float32: + tag = 2 + host = deq2d.to(torch.float32).contiguous().cpu() + else: + raise ValueError("out_dtype 只支持 torch.float16 或 torch.float32") + + with open(path, "wb") as f: + f.write(host.numpy().tobytes(order="C")) + + # with open(path + ".txt", "w", encoding="utf-8") as f: + # f.write("[dequant]\n") + # f.write(f"num_rows: {int(num_rows)}\n") + # f.write(f"num_cols: {int(num_cols)}\n") + # f.write(f"dtype: {int(tag)}\n") + # f.write("data:\n") + # + # # 逐行写,空格分隔 + # # 可以按需要改格式,比如 "{:.6f}" + # for i in range(num_rows): + # row = host[i].tolist() + # f.write(" ".join(f"{v:.6f}" for v in row)) + # f.write("\n") + # # print(" ".join(f"{v:.6f}" for v in row)) + +if __name__ == "__main__": + torch.manual_seed(0) + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + row = 10000 + col = 10000 + W = torch.randn(row, col, device="cuda", dtype=torch.float16) + file_prefix = f"nf4_{row}x{col}_fp16" + qweight, qs, shape = save_nf4_tagged_binary(file_prefix + ".bin", W, blocksize=64) + + start = time.perf_counter() + deq = _dequant_bnb(qweight, qs) # bnb 反量化 + end = time.perf_counter() + elapsed_ms = (end - start) * 1000 + print(f"dequantize_4bit执行时间: {elapsed_ms:.3f} ms") + save_dequant_result(file_prefix + "_w_dequant.bin", deq, shape, out_dtype=torch.float16) diff --git 
a/04_hadamard_tc/flashzxi/CMakeLists.txt b/04_hadamard_tc/flashzxi/CMakeLists.txt new file mode 100644 index 0000000..c93a60f --- /dev/null +++ b/04_hadamard_tc/flashzxi/CMakeLists.txt @@ -0,0 +1,36 @@ +cmake_minimum_required(VERSION 3.26) + +project(hadacore LANGUAGES CXX CUDA) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) + +set(CMAKE_CUDA_ARCHITECTURES 80) + +find_package(CUDAToolkit REQUIRED) + +include(FetchContent) + +FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git + GIT_TAG v3.4.0 +) + +FetchContent_MakeAvailable(cutlass) + +add_executable(hadacore + src/main.cu + src/hadacore.cu +) + +target_include_directories(hadacore PRIVATE + ${CMAKE_SOURCE_DIR}/include + ${cutlass_SOURCE_DIR}/include + ${cutlass_SOURCE_DIR}/tools/util/include + ${CUDAToolkit_INCLUDE_DIRS} +) + +set_target_properties(hadacore PROPERTIES + CUDA_SEPARABLE_COMPILATION OFF +) \ No newline at end of file diff --git a/04_hadamard_tc/flashzxi/include/h16_bf16.inc b/04_hadamard_tc/flashzxi/include/h16_bf16.inc new file mode 100644 index 0000000..379658c --- /dev/null +++ b/04_hadamard_tc/flashzxi/include/h16_bf16.inc @@ -0,0 +1,16 @@ +0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, +0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, +0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, +0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, +0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, +0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, +0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 
0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, +0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, +0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, +0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, +0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, +0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, +0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, +0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, +0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, +0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80 \ No newline at end of file diff --git a/04_hadamard_tc/flashzxi/include/h16_fp16.inc b/04_hadamard_tc/flashzxi/include/h16_fp16.inc new file mode 100644 index 0000000..5fea632 --- /dev/null +++ b/04_hadamard_tc/flashzxi/include/h16_fp16.inc @@ -0,0 +1,16 @@ +0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, +0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, +0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, +0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 
0xbc00, 0x3c00, +0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, +0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, +0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, +0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, +0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, +0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, +0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, +0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, +0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, +0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, +0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, +0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00 \ No newline at end of file diff --git a/04_hadamard_tc/flashzxi/include/hadacore.hpp b/04_hadamard_tc/flashzxi/include/hadacore.hpp new file mode 100644 index 0000000..fcd4897 --- /dev/null +++ b/04_hadamard_tc/flashzxi/include/hadacore.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#define CUDA_CHECK(call) \ + do { 
\ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ + exit(1); \ + } \ + } while(0) + +namespace hadacore +{ + +void test_small(); +void test_large(); + +} // namespace hadacore \ No newline at end of file diff --git a/04_hadamard_tc/flashzxi/src/hadacore.cu b/04_hadamard_tc/flashzxi/src/hadacore.cu new file mode 100644 index 0000000..dabd2e5 --- /dev/null +++ b/04_hadamard_tc/flashzxi/src/hadacore.cu @@ -0,0 +1,376 @@ +// +// Created by core_dump on 3/14/26. +// +#include "hadacore.hpp" + +namespace hadacore +{ +using namespace cute; +const int M = 16; + +__device__ __constant__ uint16_t H16_fp16_bin[M * M] = { +#include "../include/h16_fp16.inc" +}; +__device__ __constant__ uint16_t H16_bf16_bin[M * M] = { +#include "../include/h16_bf16.inc" +}; + +__device__ __constant__ half_t* H16_fp16 = (half_t*) H16_fp16_bin; +__device__ __constant__ bfloat16_t* H16_bf16 = (bfloat16_t*) H16_bf16_bin; + +// 对角Hadamard矩阵 +__device__ __constant__ uint16_t H2_diag_fp16_bin[M * M] = { +#include "../include/h2_diag_fp16.inc" +}; +__device__ __constant__ uint16_t H2_diag_bf16_bin[M * M] = { +#include "../include/h2_diag_bf16.inc" +}; +__device__ __constant__ half_t* H2_diag_fp16 = (half_t*) H2_diag_fp16_bin; +__device__ __constant__ bfloat16_t* H2_diag_bf16 = (bfloat16_t*) H2_diag_bf16_bin; + +__device__ __constant__ uint16_t H4_diag_fp16_bin[M * M] = { +#include "../include/h4_diag_fp16.inc" +}; +__device__ __constant__ uint16_t H4_diag_bf16_bin[M * M] = { +#include "../include/h4_diag_bf16.inc" +}; +__device__ __constant__ half_t* H4_diag_fp16 = (half_t*) H4_diag_fp16_bin; +__device__ __constant__ bfloat16_t* H4_diag_bf16 = (bfloat16_t*) H4_diag_bf16_bin; + +__device__ __constant__ uint16_t H8_diag_fp16_bin[M * M] = { +#include "../include/h8_diag_fp16.inc" +}; +__device__ __constant__ uint16_t H8_diag_bf16_bin[M * M] = { +#include "../include/h8_diag_bf16.inc" +}; 
+__device__ __constant__ half_t* H8_diag_fp16 = (half_t*) H8_diag_fp16_bin; +__device__ __constant__ bfloat16_t* H8_diag_bf16 = (bfloat16_t*) H8_diag_bf16_bin; + +// 每个block负责计算一行 +// 一次计算256,一个warp计算 CHUNKS 个256 +// 一个block R_WIDTH / 256 / CHUNKS 个warp +template +__global__ void hadacore_less_than_4096(T* A) { + + constexpr int WARPS = R_WIDTH / 256 / CHUNKS; + extern __shared__ __align__(16) char smemA[]; + __shared__ __align__(32) int16_t smemhada_bin1[M * M]; + __shared__ __align__(32) int16_t smemhada_bin2[M * M]; + + T* smemA_total = (T*) smemA; + T* smemhada1 = (T*)smemhada_bin1; + T* smemhada2 = (T*)smemhada_bin2; + T* hada1_ptr = nullptr; + T* hada2_ptr = nullptr; + if constexpr (std::is_same_v) { + hada1_ptr = H16_fp16; + } else { + hada1_ptr = H16_bf16; + } + + constexpr int log_r_width = 31 - __builtin_clz(R_WIDTH); + if (log_r_width > 8) + { + if(R_WIDTH) + } + + auto gA_total = make_tensor( + make_gmem_ptr(A), + make_shape(Int{}, Int{}, Int{}), + make_stride(Int{}, Int{}, Int<1>{}) + ); + + auto sA_total = make_tensor( + make_smem_ptr(smemA_total), + make_shape(Int{}, Int{}, Int{}), + make_stride(Int{}, Int{}, Int<1>{}) + ); + + auto gH1 = make_tensor( + make_gmem_ptr(hada1_ptr), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sH1 = make_tensor( + make_smem_ptr(smemhada1), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sH2 = make_tensor( + make_smem_ptr(smemhada2), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sA = sA_total(threadIdx.y * CHUNKS, _, _); + + // 每个线程 load 8 个 elements + using CopyAtom = Copy_Atom, T>; + + auto copyA = make_tiled_copy( + CopyAtom{}, + Layout>{}, + Layout>{} + ); + + auto thr_copy_a = copyA.get_slice(threadIdx.x); + + auto tAgA = thr_copy_a.partition_S(gA_total(threadIdx.y * CHUNKS, _, _)); + auto tAsA = thr_copy_a.partition_D(sA); + + auto tHgH1 = thr_copy_a.partition_S(gH1); + auto tHsH1 = thr_copy_a.partition_D(sH1); + + auto tHgH2 = 
thr_copy_a.partition_S(gH2); + auto tHsH2 = thr_copy_a.partition_D(sH2); + + if (threadIdx.y == 0) { + copy(tHgH1, tHsH1); + if (gH2 != nullptr) + { + copy(tHgH2, tHsH2); + } + } + if (threadIdx.x < R_WIDTH / 16) + { + copy(tAgA, tAsA); + } + __syncwarp(); + cp_async_fence(); + using MMA_Atom_Arch = MMA_Atom; + + // 一个 16x8x16 atom,沿 M 方向铺 2 份 => 16x16x16 + auto mma = make_tiled_mma( + MMA_Atom_Arch{}, + Layout>{}, + Layout>{} + ); + + for (int loop = 1; loop < CHUNKS; ++loop) { + // 先 load 下一批 A,再计算 + auto sA_back = sA_total(threadIdx.y * CHUNKS + loop, _, _); + auto tAgA_back = thr_copy_a.partition_S( + gA_total(threadIdx.y * CHUNKS + loop, _, _) + ); + auto tAsA_back = thr_copy_a.partition_D(sA_back); + copy(tAgA_back, tAsA_back); + + cp_async_wait<0>(); + + if (threadIdx.y == 0 && threadIdx.x == 0) + { + print_tensor(gA_total(threadIdx.y * CHUNKS + loop - 1, _, _)); + print_tensor(sA); + print_tensor(sH1); + } + // 计算 H * (A * H) + auto thr_mma = mma.get_slice(threadIdx.x); + + // 1) 右乘 H: A x H -> C + auto tCsA = thr_mma.partition_A(sA); + auto tCsB = thr_mma.partition_B(sH1); + auto tCsC = thr_mma.partition_C(sA); + + auto tCrC = thr_mma.make_fragment_C(tCsC); + + clear(tCrC); + gemm(mma, tCsA, tCsB, tCrC); + copy(tCrC, tCsC); + __syncwarp(); + if (threadIdx.y == 0 && threadIdx.x == 0 && loop == 1) + { + print_tensor(sA); + } + + // 2) 左乘 H: H x C -> A + auto sAt = make_tensor(sA.data(), + make_shape(Int{}, Int{}), + make_stride(Int<1>{}, Int{})); + auto tCsH = thr_mma.partition_A(sH1); + auto tCsHC = thr_mma.partition_B(sAt); + auto tCsC2 = thr_mma.partition_C(sA); + auto tCrC2 = thr_mma.make_fragment_C(tCsC2); + + clear(tCrC2); + gemm(mma, tCsH, tCsHC, tCrC2); + copy(tCrC2, tCsC2); + + __syncwarp(); + // 完成数据 load 再进行下一批 work + cp_async_fence(); + sA = sA_back; + } + cp_async_wait<0>(); + // 计算 H * (A * H) + auto thr_mma = mma.get_slice(threadIdx.x); + + // 1) 右乘 H: A x H -> C + auto tCsA = thr_mma.partition_A(sA); + auto tCsB = thr_mma.partition_B(sH1); + 
auto tCsC = thr_mma.partition_C(sA); + + auto tCrC = thr_mma.make_fragment_C(tCsC); + + clear(tCrC); + gemm(mma, tCsA, tCsB, tCrC); + copy(tCrC, tCsC); + __syncwarp(); + if (R_WIDTH < 256) + { + if (threadIdx.x < R_WIDTH / 16) + { + copy(tAsA, tAgA); + } + return; + } + + // 2) 左乘 H: H x C -> A + auto sAt = make_tensor(sA.data(), + make_shape(Int{}, Int{}), + make_stride(Int<1>{}, Int{})); + auto tCsH = thr_mma.partition_A(sH1); + auto tCsHC = thr_mma.partition_B(sAt); + auto tCsC2 = thr_mma.partition_C(sA); + auto tCrC2 = thr_mma.make_fragment_C(tCsC2); + + clear(tCrC2); + gemm(mma, tCsH, tCsHC, tCrC2); + copy(tCrC2, tCsC2); + + auto origin_layout = make_layout( + make_shape(Int{}, Int<256>{}), + make_stride(Int<256>{}, Int<1>{})); + auto new_view = make_layout( + make_shape(Int<16>{}, Int<16>{}), + make_stride(Int<16>{}, Int<1>{})); + auto real_layout = composition(origin_layout, new_view); + + for (int i = 0; i < CHUNKS; ++i) + { + int cols = 256 / CHUNKS * WARPS; + auto new_tensor = make_tensor( + make_smem_ptr(smemA_total + cols), real_layout + ); + + auto tCsA = thr_mma.partition_A(new_tensor); + auto tCsB = thr_mma.partition_B(sH2); + auto tCsC = thr_mma.partition_C(new_tensor); + + auto tCrC = thr_mma.make_fragment_C(tCsC); + + clear(tCrC); + gemm(mma, tCsA, tCsB, tCrC); + copy(tCrC, tCsC); + } + + // 需要block的全部thread同步了 + __syncthreads(); +} + +void test_small() +{ + constexpr int R_WIDTH = 128; // 8 * 16 + constexpr int ROWS = R_WIDTH / M; // 8 + + // 准备输入数据 (8行16列) + std::vector A_h(R_WIDTH); + for (int i = 0; i < R_WIDTH; ++i) { + A_h[i] = half_t(i / 100.0f); + } + + // 打印输入数据 + printf("Input A (8x16):\n"); + for (int r = 0; r < ROWS; ++r) { + for (int c = 0; c < M; ++c) { + printf("%6.2f ", float(A_h[r * M + c])); + } + printf("\n"); + } + + // 分配 GPU 内存 + half_t *A_d, *O_d; + cudaMalloc(&A_d, R_WIDTH * sizeof(half_t)); + cudaMalloc(&O_d, R_WIDTH * sizeof(half_t)); + + // 拷贝数据到 GPU + cudaMemcpy(A_d, A_h.data(), R_WIDTH * sizeof(half_t), 
cudaMemcpyHostToDevice); + + // 调用 kernel + hada_core_less_256<<<1, 32>>>(A_d, O_d); + + // 等待完成 + cudaDeviceSynchronize(); + CUDA_CHECK(cudaGetLastError()); + + // 拷贝结果回主机 + std::vector result(R_WIDTH); + cudaMemcpy(result.data(), A_d, R_WIDTH * sizeof(half_t), cudaMemcpyDeviceToHost); + + // 打印结果 + printf("\nOutput A after H * A * H (8x16):\n"); + for (int r = 0; r < ROWS; ++r) { + for (int c = 0; c < M; ++c) { + printf("%6.2f ", float(result[r * M + c])); + } + printf("\n"); + } + + // 释放内存 + cudaFree(A_d); + cudaFree(O_d); +} + +void test_large() +{ + constexpr int R_WIDTH = 1024; // 总行宽 + constexpr int CHUNKS = 2; // 每个 warp 处理的 chunk 数 + constexpr int WARPS = R_WIDTH / 256 / CHUNKS; // = 2 + + // 准备输入数据 (512 = 32行 x 16列) + std::vector A_h(R_WIDTH); + for (int i = 0; i < R_WIDTH; ++i) + { + A_h[i] = half_t(i / 100.0f); // 0,1,2,...,15,0,1,2,... + } + // 分配 GPU 内存 + half_t *A_d; + cudaMalloc(&A_d, R_WIDTH * sizeof(half_t)); + + // 拷贝数据到 GPU + cudaMemcpy(A_d, A_h.data(), R_WIDTH * sizeof(half_t), cudaMemcpyHostToDevice); + + // 计算 dynamic shared memory 大小 + // 每个 chunk 是 16x16,每个 warp 处理 CHUNKS 个 + size_t smem_size = std::max(WARPS * CHUNKS * M * M * sizeof(half_t), 16 * sizeof(half_t)); + + printf("\nLaunching kernel: R_WIDTH=%d, CHUNKS=%d, WARPS=%d\n", R_WIDTH, CHUNKS, WARPS); + printf("Block dim: (%d, %d, 1), Dynamic smem: %zu bytes\n\n", 32, WARPS, smem_size); + + // 调用 kernel + dim3 block(32, WARPS); + hadacore_less_than_4096<<<1, block, smem_size>>>(A_d); + + // 等待完成 + cudaDeviceSynchronize(); + CUDA_CHECK(cudaGetLastError()); + + // 拷贝结果回主机 + std::vector result(R_WIDTH); + cudaMemcpy(result.data(), A_d, R_WIDTH * sizeof(half_t), cudaMemcpyDeviceToHost); + + // 打印结果 (前32行) + printf("Output A after H * A * H (first 32x16):\n"); + for (int r = 0; r < 32; ++r) { + for (int c = 0; c < M; ++c) { + printf("%6.1f ", float(result[r * M + c])); + } + printf("\n"); + } + + // 释放内存 + cudaFree(A_d); +} +} + diff --git a/04_hadamard_tc/flashzxi/src/main.cu 
b/04_hadamard_tc/flashzxi/src/main.cu new file mode 100644 index 0000000..a32b829 --- /dev/null +++ b/04_hadamard_tc/flashzxi/src/main.cu @@ -0,0 +1,22 @@ +// +// Created by core_dump on 3/14/26. +// +#include "hadacore.hpp" +#include +using namespace cute; + +int main() { + auto layout = make_layout(make_shape(Int<4>{}, Int<256>{}), + make_stride(Int<256>{}, Int<1>{})); + + auto B = make_layout(Shape<_16, _16>{}, Stride<_16, _1>{}); + auto new_layout = composition(layout, B); + print_layout(new_layout); // 直接打印二维“坐标 -> index”表 +} +// int main() +// { +// // hadacore::test_small(); +// printf("\n========================================\n\n"); +// hadacore::test_large(); +// return 0; +// } \ No newline at end of file