Skip to content

Commit c3eea9f

Browse files
dvrogozhNicolasHugscotts
authored
Implement abstract per-GPU cache helper class (#814)
Signed-off-by: Dmitry Rogozhkin <[email protected]> Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Scott Schneider <[email protected]>
1 parent 4af0bfe commit c3eea9f

File tree

4 files changed

+180
-85
lines changed

4 files changed

+180
-85
lines changed

src/torchcodec/_core/Cache.h

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once

#include <torch/types.h>

#include <algorithm>
#include <memory>
#include <mutex>
#include <vector>
10+
11+
namespace facebook::torchcodec {
12+
13+
// This header defines simple cache class primitives to store reusable objects
// across TorchCodec stream instances. The intended usage is to store hardware
// contexts whose creation is expensive. The cache mechanism is as follows:
// 1. 'PerGpuCache' provides a dynamic cache with the specified maximum
//    capacity for the given number of GPUs.
// 2. When a stream object (e.g. SingleStreamDecoder) is destroyed, the
//    cacheable object must be released to the cache. The cache will accept
//    the object if it is not full.
// 3. When a stream object (e.g. SingleStreamDecoder) is created, the
//    cacheable object must first be queried from the cache. If the cache is
//    empty, then a new object must be created.
25+
template <typename T, typename D = std::default_delete<T>>
26+
class Cache {
27+
public:
28+
using element_type = std::unique_ptr<T, D>;
29+
30+
Cache(int capacity) : capacity_(capacity) {}
31+
32+
// Adds an object to the cache if the cache has capacity. Returns true
33+
// if object was added and false otherwise.
34+
bool addIfCacheHasCapacity(element_type&& obj);
35+
36+
// Returns an object from the cache. Cache does not hold a reference
37+
// to the object after this call.
38+
element_type get();
39+
40+
private:
41+
int capacity_;
42+
std::mutex mutex_;
43+
std::vector<element_type> cache_;
44+
};
45+
46+
template <typename T, typename D>
47+
bool Cache<T, D>::addIfCacheHasCapacity(element_type&& obj) {
48+
std::scoped_lock lock(mutex_);
49+
if (capacity_ >= 0 && cache_.size() >= static_cast<size_t>(capacity_)) {
50+
return false;
51+
}
52+
cache_.push_back(std::move(obj));
53+
return true;
54+
}
55+
56+
template <typename T, typename D>
57+
typename Cache<T, D>::element_type Cache<T, D>::get() {
58+
std::scoped_lock lock(mutex_);
59+
if (cache_.empty())
60+
return nullptr;
61+
62+
element_type obj = std::move(cache_.back());
63+
cache_.pop_back();
64+
return obj;
65+
}
66+
67+
template <typename T, typename D = std::default_delete<T>>
68+
class PerGpuCache {
69+
public:
70+
using element_type = typename Cache<T, D>::element_type;
71+
72+
// Initializes 'maxGpus' number of caches. Each cache can hold no
73+
// more than 'capacity' items. If 'capacity' <0 cache size is unlimited.
74+
PerGpuCache(int maxGpus, int capacity) {
75+
TORCH_CHECK(maxGpus > 0, "maxGpus for PerGpuCache must be >0");
76+
for (int i = 0; i < maxGpus; ++i) {
77+
cache_.emplace_back(std::make_unique<Cache<T, D>>(capacity));
78+
}
79+
}
80+
81+
// Adds an object to the specified device cache if the cache has
82+
// capacity. Returns true if object was added and false otherwise.
83+
bool addIfCacheHasCapacity(const torch::Device& device, element_type&& obj);
84+
85+
// Returns an object from the cache of the specified device. Cache
86+
// does not hold a reference to the object after this call.
87+
element_type get(const torch::Device& device);
88+
89+
private:
90+
// 'Cache' class implementation contains mutex which makes it non-movable
91+
// and non-copyable, so we need to wrap it in std::unique_ptr.
92+
std::vector<std::unique_ptr<Cache<T, D>>> cache_;
93+
};
94+
95+
torch::DeviceIndex getNonNegativeDeviceIndex(const torch::Device& device) {
96+
torch::DeviceIndex deviceIndex = device.index();
97+
// For single GPU machines libtorch returns -1 for the device index. So for
98+
// that case we set the device index to 0. That's used in per-gpu cache
99+
// implementation and during initialization of CUDA and FFmpeg contexts
100+
// which require non negative indices.
101+
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
102+
TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
103+
return deviceIndex;
104+
}
105+
106+
template <typename T, typename D>
107+
bool PerGpuCache<T, D>::addIfCacheHasCapacity(
108+
const torch::Device& device,
109+
element_type&& obj) {
110+
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
111+
TORCH_CHECK(
112+
static_cast<size_t>(deviceIndex) < cache_.size(),
113+
"Device index out of range");
114+
return cache_[deviceIndex]->addIfCacheHasCapacity(std::move(obj));
115+
}
116+
117+
template <typename T, typename D>
118+
typename PerGpuCache<T, D>::element_type PerGpuCache<T, D>::get(
119+
const torch::Device& device) {
120+
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
121+
TORCH_CHECK(
122+
static_cast<size_t>(deviceIndex) < cache_.size(),
123+
"Device index out of range");
124+
return cache_[deviceIndex]->get();
125+
}
126+
127+
} // namespace facebook::torchcodec

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 49 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <torch/types.h>
55
#include <mutex>
66

7+
#include "src/torchcodec/_core/Cache.h"
78
#include "src/torchcodec/_core/CudaDeviceInterface.h"
89
#include "src/torchcodec/_core/FFMPEGCommon.h"
910

@@ -44,49 +45,11 @@ const int MAX_CUDA_GPUS = 128;
4445
// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
4546
// Set to a positive number to have a cache of that size.
4647
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
47-
std::vector<AVBufferRef*> g_cached_hw_device_ctxs[MAX_CUDA_GPUS];
48-
std::mutex g_cached_hw_device_mutexes[MAX_CUDA_GPUS];
49-
50-
torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
51-
torch::DeviceIndex deviceIndex = device.index();
52-
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
53-
TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
54-
// FFMPEG cannot handle negative device indices.
55-
// For single GPU- machines libtorch returns -1 for the device index. So for
56-
// that case we set the device index to 0.
57-
// TODO: Double check if this works for multi-GPU machines correctly.
58-
return deviceIndex;
59-
}
60-
61-
void addToCacheIfCacheHasCapacity(
62-
const torch::Device& device,
63-
AVBufferRef* hwContext) {
64-
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
65-
if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
66-
return;
67-
}
68-
std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]);
69-
if (MAX_CONTEXTS_PER_GPU_IN_CACHE >= 0 &&
70-
g_cached_hw_device_ctxs[deviceIndex].size() >=
71-
MAX_CONTEXTS_PER_GPU_IN_CACHE) {
72-
return;
73-
}
74-
g_cached_hw_device_ctxs[deviceIndex].push_back(av_buffer_ref(hwContext));
75-
}
76-
77-
AVBufferRef* getFromCache(const torch::Device& device) {
78-
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
79-
if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
80-
return nullptr;
81-
}
82-
std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]);
83-
if (g_cached_hw_device_ctxs[deviceIndex].size() > 0) {
84-
AVBufferRef* hw_device_ctx = g_cached_hw_device_ctxs[deviceIndex].back();
85-
g_cached_hw_device_ctxs[deviceIndex].pop_back();
86-
return hw_device_ctx;
87-
}
88-
return nullptr;
89-
}
48+
PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
49+
g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
50+
PerGpuCache<NppStreamContext> g_cached_npp_ctxs(
51+
MAX_CUDA_GPUS,
52+
MAX_CONTEXTS_PER_GPU_IN_CACHE);
9053

9154
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
9255

@@ -143,14 +106,13 @@ AVBufferRef* getFFMPEGContextFromNewCudaContext(
143106

144107
#endif
145108

146-
AVBufferRef* getCudaContext(const torch::Device& device) {
109+
UniqueAVBufferRef getCudaContext(const torch::Device& device) {
147110
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
148111
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
149-
torch::DeviceIndex nonNegativeDeviceIndex =
150-
getFFMPEGCompatibleDeviceIndex(device);
112+
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
151113

152-
AVBufferRef* hw_device_ctx = getFromCache(device);
153-
if (hw_device_ctx != nullptr) {
114+
UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
115+
if (hw_device_ctx) {
154116
return hw_device_ctx;
155117
}
156118

@@ -161,15 +123,23 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
161123
// 58.26.100 of avutil.
162124
// https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
163125
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
164-
return getFFMPEGContextFromExistingCudaContext(
165-
device, nonNegativeDeviceIndex, type);
126+
return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext(
127+
device, nonNegativeDeviceIndex, type));
166128
#else
167-
return getFFMPEGContextFromNewCudaContext(
168-
device, nonNegativeDeviceIndex, type);
129+
return UniqueAVBufferRef(
130+
getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type));
169131
#endif
170132
}
171133

172-
NppStreamContext createNppStreamContext(int deviceIndex) {
134+
std::unique_ptr<NppStreamContext> getNppStreamContext(
135+
const torch::Device& device) {
136+
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
137+
138+
std::unique_ptr<NppStreamContext> nppCtx = g_cached_npp_ctxs.get(device);
139+
if (nppCtx) {
140+
return nppCtx;
141+
}
142+
173143
// From 12.9, NPP recommends using a user-created NppStreamContext and using
174144
// the `_Ctx()` calls:
175145
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
@@ -178,30 +148,21 @@ NppStreamContext createNppStreamContext(int deviceIndex) {
178148
// properties:
179149
// https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72
180150

181-
NppStreamContext nppCtx{};
151+
nppCtx = std::make_unique<NppStreamContext>();
182152
cudaDeviceProp prop{};
183-
cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
153+
cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex);
184154
TORCH_CHECK(
185155
err == cudaSuccess,
186156
"cudaGetDeviceProperties failed: ",
187157
cudaGetErrorString(err));
188158

189-
nppCtx.nCudaDeviceId = deviceIndex;
190-
nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
191-
nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
192-
nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
193-
nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
194-
nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
195-
nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
196-
197-
// TODO when implementing the cache logic, move these out. See other TODO
198-
// below.
199-
nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
200-
err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
201-
TORCH_CHECK(
202-
err == cudaSuccess,
203-
"cudaStreamGetFlags failed: ",
204-
cudaGetErrorString(err));
159+
nppCtx->nCudaDeviceId = nonNegativeDeviceIndex;
160+
nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
161+
nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
162+
nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
163+
nppCtx->nSharedMemPerBlock = prop.sharedMemPerBlock;
164+
nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major;
165+
nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor;
205166

206167
return nppCtx;
207168
}
@@ -217,8 +178,10 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
217178

218179
CudaDeviceInterface::~CudaDeviceInterface() {
219180
if (ctx_) {
220-
addToCacheIfCacheHasCapacity(device_, ctx_);
221-
av_buffer_unref(&ctx_);
181+
g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
182+
}
183+
if (nppCtx_) {
184+
g_cached_npp_ctxs.addIfCacheHasCapacity(device_, std::move(nppCtx_));
222185
}
223186
}
224187

@@ -231,7 +194,8 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
231194
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
232195
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
233196
ctx_ = getCudaContext(device_);
234-
codecContext->hw_device_ctx = av_buffer_ref(ctx_);
197+
nppCtx_ = getNppStreamContext(device_);
198+
codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
235199
return;
236200
}
237201

@@ -310,13 +274,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
310274
dst = allocateEmptyHWCTensor(height, width, device_);
311275
}
312276

313-
// TODO cache the NppStreamContext! It currently gets re-recated for every
314-
// single frame. The cache should be per-device, similar to the existing
315-
// hw_device_ctx cache. When implementing the cache logic, the
316-
// NppStreamContext hStream and nStreamFlags should not be part of the cache
317-
// because they may change across calls.
318-
NppStreamContext nppCtx = createNppStreamContext(
319-
static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
277+
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
278+
nppCtx_->hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
279+
cudaError_t err =
280+
cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
281+
TORCH_CHECK(
282+
err == cudaSuccess,
283+
"cudaStreamGetFlags failed: ",
284+
cudaGetErrorString(err));
320285

321286
NppiSize oSizeROI = {width, height};
322287
Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};
@@ -342,7 +307,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
342307
dst.stride(0),
343308
oSizeROI,
344309
bt709FullRangeColorTwist,
345-
nppCtx);
310+
*nppCtx_);
346311
} else {
347312
// If not full range, we assume studio limited range.
348313
// The color conversion matrix for BT.709 limited range should be:
@@ -359,7 +324,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
359324
static_cast<Npp8u*>(dst.data_ptr()),
360325
dst.stride(0),
361326
oSizeROI,
362-
nppCtx);
327+
*nppCtx_);
363328
}
364329
} else {
365330
// TODO we're assuming BT.601 color space (and probably limited range) by
@@ -371,7 +336,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
371336
static_cast<Npp8u*>(dst.data_ptr()),
372337
dst.stride(0),
373338
oSizeROI,
374-
nppCtx);
339+
*nppCtx_);
375340
}
376341
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
377342
}

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ class CudaDeviceInterface : public DeviceInterface {
3030
std::nullopt) override;
3131

3232
private:
33-
AVBufferRef* ctx_ = nullptr;
33+
UniqueAVBufferRef ctx_;
34+
std::unique_ptr<NppStreamContext> nppCtx_;
3435
};
3536

3637
} // namespace facebook::torchcodec

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ using UniqueSwrContext =
7676
std::unique_ptr<SwrContext, Deleterp<SwrContext, void, swr_free>>;
7777
using UniqueAVAudioFifo = std::
7878
unique_ptr<AVAudioFifo, Deleter<AVAudioFifo, void, av_audio_fifo_free>>;
79+
using UniqueAVBufferRef =
80+
std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;
7981

8082
// These 2 classes share the same underlying AVPacket object. They are meant to
8183
// be used in tandem, like so:

0 commit comments

Comments
 (0)