 #include <torch/types.h>
 #include <mutex>

+#include "src/torchcodec/_core/Cache.h"
 #include "src/torchcodec/_core/CudaDeviceInterface.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
@@ -44,49 +45,11 @@ const int MAX_CUDA_GPUS = 128;
 // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
 // Set to a positive number to have a cache of that size.
 const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
-std::vector<AVBufferRef*> g_cached_hw_device_ctxs[MAX_CUDA_GPUS];
-std::mutex g_cached_hw_device_mutexes[MAX_CUDA_GPUS];
-
-torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
-  torch::DeviceIndex deviceIndex = device.index();
-  deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
-  TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
-  // FFMPEG cannot handle negative device indices.
-  // For single GPU machines libtorch returns -1 for the device index. So for
-  // that case we set the device index to 0.
-  // TODO: Double check if this works for multi-GPU machines correctly.
-  return deviceIndex;
-}
-
-void addToCacheIfCacheHasCapacity(
-    const torch::Device& device,
-    AVBufferRef* hwContext) {
-  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
-  if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
-    return;
-  }
-  std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]);
-  if (MAX_CONTEXTS_PER_GPU_IN_CACHE >= 0 &&
-      g_cached_hw_device_ctxs[deviceIndex].size() >=
-          MAX_CONTEXTS_PER_GPU_IN_CACHE) {
-    return;
-  }
-  g_cached_hw_device_ctxs[deviceIndex].push_back(av_buffer_ref(hwContext));
-}
-
-AVBufferRef* getFromCache(const torch::Device& device) {
-  torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
-  if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
-    return nullptr;
-  }
-  std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]);
-  if (g_cached_hw_device_ctxs[deviceIndex].size() > 0) {
-    AVBufferRef* hw_device_ctx = g_cached_hw_device_ctxs[deviceIndex].back();
-    g_cached_hw_device_ctxs[deviceIndex].pop_back();
-    return hw_device_ctx;
-  }
-  return nullptr;
-}
+PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
+    g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
+PerGpuCache<NppStreamContext> g_cached_npp_ctxs(
+    MAX_CUDA_GPUS,
+    MAX_CONTEXTS_PER_GPU_IN_CACHE);

 #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
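Note: Cache.h itself is not part of this diff. From the call sites in this file (`get()` and `addIfCacheHasCapacity()`) and the removed code above (capacity check, per-GPU mutex, LIFO reuse), `PerGpuCache` presumably behaves roughly like the sketch below; the class shape and member names are inferred, not the actual header.

    // Hypothetical sketch of PerGpuCache, inferred from this file's call
    // sites; the real definition lives in src/torchcodec/_core/Cache.h.
    #include <torch/types.h>
    #include <algorithm>
    #include <deque>
    #include <memory>
    #include <mutex>
    #include <vector>

    template <typename T, typename D = std::default_delete<T>>
    class PerGpuCache {
     public:
      PerGpuCache(int maxGpus, int maxPerGpu)
          : maxPerGpu_(maxPerGpu), cache_(maxGpus), mutexes_(maxGpus) {}

      // Takes ownership of obj if this GPU's list has spare capacity;
      // otherwise obj is freed by its deleter when it goes out of scope.
      void addIfCacheHasCapacity(
          const torch::Device& device,
          std::unique_ptr<T, D> obj) {
        size_t i = static_cast<size_t>(std::max<int>(device.index(), 0));
        if (i >= cache_.size()) {
          return;
        }
        std::scoped_lock lock(mutexes_[i]);
        if (maxPerGpu_ >= 0 &&
            cache_[i].size() >= static_cast<size_t>(maxPerGpu_)) {
          return;
        }
        cache_[i].push_back(std::move(obj));
      }

      // Pops the most recently cached object for device; nullptr if empty.
      std::unique_ptr<T, D> get(const torch::Device& device) {
        size_t i = static_cast<size_t>(std::max<int>(device.index(), 0));
        if (i >= cache_.size()) {
          return nullptr;
        }
        std::scoped_lock lock(mutexes_[i]);
        if (cache_[i].empty()) {
          return nullptr;
        }
        std::unique_ptr<T, D> obj = std::move(cache_[i].back());
        cache_[i].pop_back();
        return obj;
      }

     private:
      int maxPerGpu_;
      std::vector<std::vector<std::unique_ptr<T, D>>> cache_;
      std::deque<std::mutex> mutexes_;  // deque: mutexes are not movable
    };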
@@ -143,14 +106,13 @@ AVBufferRef* getFFMPEGContextFromNewCudaContext(
 #endif

-AVBufferRef* getCudaContext(const torch::Device& device) {
+UniqueAVBufferRef getCudaContext(const torch::Device& device) {
   enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
   TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
-  torch::DeviceIndex nonNegativeDeviceIndex =
-      getFFMPEGCompatibleDeviceIndex(device);
+  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);

-  AVBufferRef* hw_device_ctx = getFromCache(device);
-  if (hw_device_ctx != nullptr) {
+  UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
+  if (hw_device_ctx) {
     return hw_device_ctx;
   }
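Note: `UniqueAVBufferRef` and `Deleterp` come from FFMPEGCommon.h, which is not shown in this diff. Judging by the template arguments `Deleterp<AVBufferRef, void, av_buffer_unref>`, the pattern is a `std::unique_ptr` deleter adapting FFmpeg's pointer-to-pointer free functions; a sketch of the assumed shape, not the actual header:

    // Hypothetical sketch: adapt an FFmpeg-style free function R f(T**)
    // (e.g. void av_buffer_unref(AVBufferRef**)) into a unique_ptr deleter.
    extern "C" {
    #include <libavutil/buffer.h>
    }
    #include <memory>

    template <typename T, typename R, R (*Fn)(T**)>
    struct Deleterp {
      void operator()(T* p) const {
        if (p) {
          Fn(&p);  // drops one reference; nulls only the local copy
        }
      }
    };

    using UniqueAVBufferRef =
        std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;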
@@ -161,15 +123,23 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
   // 58.26.100 of avutil.
   // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
 #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
-  return getFFMPEGContextFromExistingCudaContext(
-      device, nonNegativeDeviceIndex, type);
+  return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext(
+      device, nonNegativeDeviceIndex, type));
 #else
-  return getFFMPEGContextFromNewCudaContext(
-      device, nonNegativeDeviceIndex, type);
+  return UniqueAVBufferRef(
+      getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type));
 #endif
 }

-NppStreamContext createNppStreamContext(int deviceIndex) {
+std::unique_ptr<NppStreamContext> getNppStreamContext(
+    const torch::Device& device) {
+  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
+
+  std::unique_ptr<NppStreamContext> nppCtx = g_cached_npp_ctxs.get(device);
+  if (nppCtx) {
+    return nppCtx;
+  }
+
   // From 12.9, NPP recommends using a user-created NppStreamContext and using
   // the `_Ctx()` calls:
   // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
@@ -178,30 +148,21 @@ NppStreamContext createNppStreamContext(int deviceIndex) {
   // properties:
   // https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72

-  NppStreamContext nppCtx{};
+  nppCtx = std::make_unique<NppStreamContext>();
   cudaDeviceProp prop{};
-  cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
+  cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex);
   TORCH_CHECK(
       err == cudaSuccess,
       "cudaGetDeviceProperties failed: ",
       cudaGetErrorString(err));

-  nppCtx.nCudaDeviceId = deviceIndex;
-  nppCtx.nMultiProcessorCount = prop.multiProcessorCount;
-  nppCtx.nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
-  nppCtx.nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
-  nppCtx.nSharedMemPerBlock = prop.sharedMemPerBlock;
-  nppCtx.nCudaDevAttrComputeCapabilityMajor = prop.major;
-  nppCtx.nCudaDevAttrComputeCapabilityMinor = prop.minor;
-
-  // TODO when implementing the cache logic, move these out. See other TODO
-  // below.
-  nppCtx.hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
-  err = cudaStreamGetFlags(nppCtx.hStream, &nppCtx.nStreamFlags);
-  TORCH_CHECK(
-      err == cudaSuccess,
-      "cudaStreamGetFlags failed: ",
-      cudaGetErrorString(err));
+  nppCtx->nCudaDeviceId = nonNegativeDeviceIndex;
+  nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
+  nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
+  nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
+  nppCtx->nSharedMemPerBlock = prop.sharedMemPerBlock;
+  nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major;
+  nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor;

   return nppCtx;
 }
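For reference, the `_Ctx()` convention mentioned in the comment above means each NPP call takes the user-created `NppStreamContext` as its last argument instead of relying on the global stream set via `nppSetStream()`. The conversion calls later in this diff follow that pattern; a hedged example (the exact NPP variants this file uses are not fully visible in the diff):

    // Illustrative _Ctx() call: NV12 -> RGB with an explicit stream context.
    // yuvData, oSizeROI, dst, and nppCtx_ are the names used in this file.
    NppStatus status = nppiNV12ToRGB_8u_P2C3R_Ctx(
        yuvData,               // Y plane and interleaved UV plane pointers
        avFrame->linesize[0],  // source line step in bytes
        static_cast<Npp8u*>(dst.data_ptr()),
        dst.stride(0),         // destination line step in bytes
        oSizeROI,
        *nppCtx_);             // stream context passed by value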
@@ -217,8 +178,10 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)

 CudaDeviceInterface::~CudaDeviceInterface() {
   if (ctx_) {
-    addToCacheIfCacheHasCapacity(device_, ctx_);
-    av_buffer_unref(&ctx_);
+    g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
+  }
+  if (nppCtx_) {
+    g_cached_npp_ctxs.addIfCacheHasCapacity(device_, std::move(nppCtx_));
   }
 }
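Design note: the explicit `av_buffer_unref()` is gone because ownership now moves into the cache. When the cache is at capacity, `addIfCacheHasCapacity()` presumably just drops the `unique_ptr`, whose `Deleterp` deleter performs the unref, so the AVBufferRef reference count stays balanced on both paths.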
@@ -231,7 +194,8 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
   torch::Tensor dummyTensorForCudaInitialization = torch::empty(
       {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
   ctx_ = getCudaContext(device_);
-  codecContext->hw_device_ctx = av_buffer_ref(ctx_);
+  nppCtx_ = getNppStreamContext(device_);
+  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
   return;
 }
@@ -310,13 +274,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
     dst = allocateEmptyHWCTensor(height, width, device_);
   }

-  // TODO cache the NppStreamContext! It currently gets re-created for every
-  // single frame. The cache should be per-device, similar to the existing
-  // hw_device_ctx cache. When implementing the cache logic, the
-  // NppStreamContext hStream and nStreamFlags should not be part of the cache
-  // because they may change across calls.
-  NppStreamContext nppCtx = createNppStreamContext(
-      static_cast<int>(getFFMPEGCompatibleDeviceIndex(device_)));
+  torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
+  nppCtx_->hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
+  cudaError_t err =
+      cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaStreamGetFlags failed: ",
+      cudaGetErrorString(err));

   NppiSize oSizeROI = {width, height};
   Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};
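Design note: as the removed TODO spelled out, `hStream` and `nStreamFlags` cannot live in the cache because the current CUDA stream may change between calls. The cached `NppStreamContext` therefore carries only the stable device properties, and these two fields are refreshed from `at::cuda::getCurrentCUDAStream()` on every conversion.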
@@ -342,7 +307,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
          dst.stride(0),
          oSizeROI,
          bt709FullRangeColorTwist,
-         nppCtx);
+         *nppCtx_);
     } else {
       // If not full range, we assume studio limited range.
       // The color conversion matrix for BT.709 limited range should be:
@@ -359,7 +324,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
          static_cast<Npp8u*>(dst.data_ptr()),
          dst.stride(0),
          oSizeROI,
-         nppCtx);
+         *nppCtx_);
     }
   } else {
     // TODO we're assuming BT.601 color space (and probably limited range) by
@@ -371,7 +336,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
         oSizeROI,
-        nppCtx);
+        *nppCtx_);
   }
   TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
 }