122 changes: 122 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -5,6 +5,7 @@
#include "ggml-cuda.h"

#include <cstdint>
#include <limits>
#include <memory>

#if defined(GGML_USE_HIP)
@@ -1218,6 +1219,104 @@ struct ggml_cuda_stream_context {
}
};

// cache to extend lifetimes of ggml_cuda_pool_alloc; ggml_cuda_pool expects memory to be allocated and freed in LIFO order,
// hence this cache works like a stack
struct ggml_cuda_cache {
struct pool_alloc {
ggml_cuda_pool * pool;
void * ptr;
size_t actual_size;

pool_alloc():
pool{nullptr}
, ptr{nullptr}
, actual_size{0}
{}

template<typename T>
pool_alloc(ggml_cuda_pool_alloc<T> && other) {
pool = other.pool;
ptr = (void *)other.ptr;
actual_size = other.actual_size;

other.ptr = nullptr;
other.pool = nullptr;
other.actual_size = 0;
}

pool_alloc(pool_alloc && other) {
pool = other.pool;
ptr = (void *) other.ptr;
actual_size = other.actual_size;
other.ptr = nullptr;
other.pool = nullptr;
other.actual_size = 0;
}

~pool_alloc() {
if (ptr != nullptr) {
pool->free(ptr, actual_size);
}
}
};

struct cache_entry {
int layout; // mmq_q8_1_ds_layout value
std::vector<pool_alloc> pool_ptrs;
size_t ttl_nodes{};

cache_entry() = default;

cache_entry(cache_entry && other) = default;
cache_entry& operator=(cache_entry && other) = default;

cache_entry(const cache_entry &) = delete;
cache_entry& operator=(const cache_entry &) = delete;

~cache_entry() {
// Free pool allocations in reverse order (LIFO)
while (!pool_ptrs.empty()) {
pool_ptrs.pop_back();
}
}
};

void clear_cache() {
remove_expired(std::numeric_limits<size_t>::max());
entries.clear();
}

void remove_expired(size_t node_count) {
        // max lifetime of a cache entry: 10 nodes after it was added
while (!entries.empty() && entries.back().second.ttl_nodes + 10 <= node_count) {
entries.pop_back();
}
}

cache_entry * find(const ggml_tensor * node, int layout) {
for (auto & entry: entries) {
if (entry.first == node && entry.second.layout == layout) {
return &entry.second;
}
}
return nullptr;
}

~ggml_cuda_cache() {
while (!entries.empty()) {
entries.pop_back();
}
}

void add_entry(const ggml_tensor * node, cache_entry && entry) {
entries.emplace_back(node, std::move(entry));
}

std::vector<std::pair<const ggml_tensor *, cache_entry>> entries;
};



struct ggml_backend_cuda_context {
int device;
std::string name;
@@ -1229,6 +1328,7 @@ struct ggml_backend_cuda_context {
std::unique_ptr<ggml_cuda_graph> cuda_graph;

int curr_stream_no = 0;
size_t node_count = 0;

explicit ggml_backend_cuda_context(int device) :
device(device),
@@ -1266,6 +1366,7 @@ struct ggml_backend_cuda_context {

// pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
std::unique_ptr<ggml_cuda_cache> caches[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]{{}};

static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);

@@ -1276,6 +1377,27 @@
return *pools[device][curr_stream_no];
}

ggml_cuda_cache & cache(int device, int stream) {
if (caches[device][stream] == nullptr) {
caches[device][stream] = std::unique_ptr<ggml_cuda_cache>(new ggml_cuda_cache());
}
return *caches[device][stream];
}

ggml_cuda_cache & cache() {
return cache(device, curr_stream_no);
}

void clear_cache() {
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (caches[i][j]) {
caches[i][j]->clear_cache();
}
}
}
}

ggml_cuda_pool & pool() {
return pool(device);
}
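For context, a minimal sketch of the lookup-or-populate pattern the new ggml_cuda_cache enables (illustrative only, not part of the patch; get_or_quantize and nbytes are stand-in names for the real quantization call and size computation in mmq.cu below, and common.cuh is assumed to be included):

// Look up a cached quantized copy of src1 for the given layout; on a miss,
// allocate from the pool, fill it, and cache it tagged with the current node count.
static void * get_or_quantize(ggml_backend_cuda_context & ctx, const ggml_tensor * src1,
                              int layout, size_t nbytes) {
    ggml_cuda_cache & cache = ctx.cache();
    if (ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout)) {
        GGML_ASSERT(entry->pool_ptrs.size() == 1);
        GGML_ASSERT(entry->pool_ptrs[0].actual_size >= nbytes);
        return entry->pool_ptrs[0].ptr;                               // reuse the cached buffer
    }
    ggml_cuda_pool_alloc<char> buf(ctx.pool(), nbytes);
    // ... quantize src1 into buf.get() here ...
    void * ptr = buf.get();
    std::vector<ggml_cuda_cache::pool_alloc> allocs;
    allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(buf))); // cache takes ownership of the allocation
    cache.add_entry(src1, ggml_cuda_cache::cache_entry{layout, std::move(allocs), ctx.node_count});
    return ptr;
}

On a hit the previously quantized buffer is returned directly; on a miss the pool allocation is moved into the cache, extending its lifetime across graph nodes until remove_expired() drops the entry.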
6 changes: 6 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3232,6 +3232,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
ggml_cuda_concurrent_event * concurrent_event = nullptr;
bool should_launch_concurrent_events = false;

cuda_ctx->clear_cache();

const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
concurrent_event = &stream_ctx.concurrent_events[node];
@@ -3662,6 +3664,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
GGML_ASSERT(ok);

// Increment node counter and expire old cache entries
cuda_ctx->node_count++;
cuda_ctx->cache().remove_expired(cuda_ctx->node_count);

if (!is_concurrent_event_active) {
try_launch_concurrent_event(node);
}
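A worked example of the expiry rule added above (illustrative values only): remove_expired() drops an entry once 10 graph nodes have executed since it was cached.

// ttl_nodes records ctx.node_count at the time the entry was added (see add_entry in common.cuh).
size_t ttl_nodes  = 5;                          // entry cached while evaluating node 5
size_t node_count = 14;                         // current node counter
bool expired = ttl_nodes + 10 <= node_count;    // 15 <= 14 -> false, entry is kept
node_count = 15;
expired = ttl_nodes + 10 <= node_count;         // 15 <= 15 -> true, entry is popped

Because the pool requires LIFO frees, remove_expired() only pops from the back of entries; an older entry therefore stays cached until every entry added after it has expired as well.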
152 changes: 109 additions & 43 deletions ggml/src/ggml-cuda/mmq.cu
@@ -1,4 +1,5 @@
#include "common.cuh"
#include "ggml.h"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
@@ -118,25 +119,53 @@ void ggml_cuda_mul_mat_q(
// TODO: tighter pool buffer size vs q8 path
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;

ggml_cuda_cache & cache = ctx.cache();

if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);

{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
const int layout = use_native_mxfp4 ? -1 : mmq_get_q8_1_ds_layout(src0->type);

} else {
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
void * src1_ptr = nullptr;
ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout);
if (entry != nullptr) {
GGML_ASSERT(entry->pool_ptrs.size() == 1);
            size_t expected_size = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
                get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
GGML_ASSERT(entry->pool_ptrs[0].actual_size >= expected_size);
src1_ptr = entry->pool_ptrs[0].ptr;
GGML_ASSERT(src1_ptr != nullptr);
} else {

ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);

} else {
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
}
CUDA_CHECK(cudaGetLastError());
}
CUDA_CHECK(cudaGetLastError());

src1_ptr = src1_q8_1.get();

std::vector<ggml_cuda_cache::pool_alloc> allocs;
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(src1_q8_1)));

cache.add_entry(
src1,
ggml_cuda_cache::cache_entry{
layout,
std::move(allocs),
ctx.node_count
});
}

// Stride depends on quantization format
@@ -148,7 +177,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s13 = ne12*s12;

const mmq_args args = {
src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
src0_d, src0->type, (const int *) src1_ptr, nullptr, nullptr, dst_d,
ne00, ne01, ne1, s01, ne11, s1,
ne02, ne12, s02, s12, s2,
ne03, ne13, s03, s13, s3,
@@ -165,41 +194,78 @@ void ggml_cuda_mul_mat_q(
const int64_t ne_get_rows = ne12 * n_expert_used;
GGML_ASSERT(ne1 == n_expert_used);

ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
const int layout = use_native_mxfp4 ? -1 : mmq_get_q8_1_ds_layout(src0->type);

{
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11;
void * ids_dst_ptr = nullptr;
void * expert_bounds_ptr = nullptr;
void * src1_q8_1_ptr = nullptr;

ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
CUDA_CHECK(cudaGetLastError());
}
ggml_cuda_cache::cache_entry * entry = cache.find(src1, layout);
if (entry != nullptr) {
GGML_ASSERT(entry->pool_ptrs.size() == 4);
ids_dst_ptr = entry->pool_ptrs[1].ptr;
expert_bounds_ptr = entry->pool_ptrs[2].ptr;
src1_q8_1_ptr = entry->pool_ptrs[3].ptr;

const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
size_t expected_q8_1_size = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
GGML_ASSERT(entry->pool_ptrs[3].actual_size >= expected_q8_1_size);
} else {
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);

const int64_t ne11_flat = ne12*n_expert_used;
const int64_t ne12_flat = 1;
const int64_t ne13_flat = 1;
{
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11;

{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
CUDA_CHECK(cudaGetLastError());
}

if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);

const int64_t ne11_flat = ne12*n_expert_used;
const int64_t ne12_flat = 1;
const int64_t ne13_flat = 1;

{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;

if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
}
CUDA_CHECK(cudaGetLastError());
}
CUDA_CHECK(cudaGetLastError());

void * ids_src1_ptr = ids_src1.get();
ids_dst_ptr = ids_dst.get();
expert_bounds_ptr = expert_bounds.get();
src1_q8_1_ptr = src1_q8_1.get();

std::vector<ggml_cuda_cache::pool_alloc> allocs;
// Store in allocation order; custom destructor will free in reverse (LIFO)
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(ids_src1)));
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(ids_dst)));
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(expert_bounds)));
allocs.emplace_back(ggml_cuda_cache::pool_alloc(std::move(src1_q8_1)));

cache.add_entry(
src1,
ggml_cuda_cache::cache_entry{
layout,
std::move(allocs),
ctx.node_count
});
}

const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
@@ -208,7 +274,7 @@ void ggml_cuda_mul_mat_q(

// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
const mmq_args args = {
src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
src0_d, src0->type, (const int *) src1_q8_1_ptr, (int32_t *) ids_dst_ptr, (int32_t *) expert_bounds_ptr, dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,