@@ -53,7 +53,7 @@ torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
 
 void addToCacheIfCacheHasCapacity(
     const torch::Device& device,
-    AVCodecContext* codecContext) {
+    AVBufferRef* hwContext) {
   torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
   if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
     return;
@@ -64,8 +64,7 @@ void addToCacheIfCacheHasCapacity(
       MAX_CONTEXTS_PER_GPU_IN_CACHE) {
     return;
   }
-  g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx);
-  codecContext->hw_device_ctx = nullptr;
+  g_cached_hw_device_ctxs[deviceIndex].push_back(hwContext);
 }
 
 AVBufferRef* getFromCache(const torch::Device& device) {
@@ -170,17 +169,22 @@ CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
   }
 }
 
-void CudaDevice::releaseContext(AVCodecContext* codecContext) {
-  addToCacheIfCacheHasCapacity(device_, codecContext);
+CudaDevice::~CudaDevice() {
+  if (ctx_) {
+    addToCacheIfCacheHasCapacity(device_, ctx_);
+  }
 }
 
 void CudaDevice::initializeContext(AVCodecContext* codecContext) {
+  TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");
+
   // It is important for pytorch itself to create the cuda context. If ffmpeg
   // creates the context it may not be compatible with pytorch.
   // This is a dummy tensor to initialize the cuda context.
   torch::Tensor dummyTensorForCudaInitialization = torch::empty(
       {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
-  codecContext->hw_device_ctx = getCudaContext(device_);
+  ctx_ = getCudaContext(device_);
+  codecContext->hw_device_ctx = av_buffer_ref(ctx_);
   return;
 }
 
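For context (this sketch is not part of the diff): the change relies on FFmpeg's AVBufferRef being reference-counted, so the device object can keep its own reference (ctx_) while the codec context holds an independent one created with av_buffer_ref(); whichever side releases its reference last does not invalidate the other. A minimal, self-contained illustration of that pattern, with ownerRef and codecRef as hypothetical stand-ins for ctx_ and codecContext->hw_device_ctx:

extern "C" {
#include <libavutil/buffer.h>
}
#include <cassert>

int main() {
  // Stand-in for the context owned by the CudaDevice instance (ctx_).
  AVBufferRef* ownerRef = av_buffer_alloc(64);
  assert(ownerRef != nullptr);

  // Stand-in for codecContext->hw_device_ctx: an independent reference
  // to the same underlying buffer.
  AVBufferRef* codecRef = av_buffer_ref(ownerRef);
  assert(codecRef != nullptr);
  assert(av_buffer_get_ref_count(ownerRef) == 2);

  // The codec context can drop its reference (as avcodec_free_context does)
  // without freeing the shared buffer...
  av_buffer_unref(&codecRef);
  assert(av_buffer_get_ref_count(ownerRef) == 1);

  // ...so the owner's reference stays valid until it is released here, or,
  // as in this PR, handed back to the per-GPU cache by the destructor.
  av_buffer_unref(&ownerRef);
  return 0;
}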