Skip to content

Commit a0d3d59

Browse files
committed
Drop releaseContext
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent d922430 commit a0d3d59

File tree

5 files changed

+13
-21
lines changed

5 files changed

+13
-21
lines changed

src/torchcodec/_core/CudaDevice.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
5353

5454
void addToCacheIfCacheHasCapacity(
5555
const torch::Device& device,
56-
AVCodecContext* codecContext) {
56+
AVBufferRef* hwContext) {
5757
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
5858
if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
5959
return;
@@ -64,8 +64,7 @@ void addToCacheIfCacheHasCapacity(
6464
MAX_CONTEXTS_PER_GPU_IN_CACHE) {
6565
return;
6666
}
67-
g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx);
68-
codecContext->hw_device_ctx = nullptr;
67+
g_cached_hw_device_ctxs[deviceIndex].push_back(hwContext);
6968
}
7069

7170
AVBufferRef* getFromCache(const torch::Device& device) {
@@ -170,17 +169,22 @@ CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
170169
}
171170
}
172171

173-
void CudaDevice::releaseContext(AVCodecContext* codecContext) {
174-
addToCacheIfCacheHasCapacity(device_, codecContext);
172+
CudaDevice::~CudaDevice() {
173+
if (ctx_) {
174+
addToCacheIfCacheHasCapacity(device_, ctx_);
175+
}
175176
}
176177

177178
void CudaDevice::initializeContext(AVCodecContext* codecContext) {
179+
TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");
180+
178181
// It is important for pytorch itself to create the cuda context. If ffmpeg
179182
// creates the context it may not be compatible with pytorch.
180183
// This is a dummy tensor to initialize the cuda context.
181184
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
182185
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
183-
codecContext->hw_device_ctx = getCudaContext(device_);
186+
ctx_ = getCudaContext(device_);
187+
codecContext->hw_device_ctx = av_buffer_ref(ctx_);
184188
return;
185189
}
186190

src/torchcodec/_core/CudaDevice.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class CudaDevice : public DeviceInterface {
1414
public:
1515
CudaDevice(const torch::Device& device);
1616

17-
virtual ~CudaDevice(){};
17+
virtual ~CudaDevice();
1818

1919
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
2020

@@ -27,7 +27,8 @@ class CudaDevice : public DeviceInterface {
2727
std::optional<torch::Tensor> preAllocatedOutputTensor =
2828
std::nullopt) override;
2929

30-
void releaseContext(AVCodecContext* codecContext) override;
30+
private:
31+
AVBufferRef* ctx_ = nullptr;
3132
};
3233

3334
} // namespace facebook::torchcodec

src/torchcodec/_core/DeviceInterface.h

-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ class DeviceInterface {
4646
SingleStreamDecoder::FrameOutput& frameOutput,
4747
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
4848

49-
virtual void releaseContext(AVCodecContext* codecContext) = 0;
50-
5149
protected:
5250
torch::Device device_;
5351
};

src/torchcodec/_core/SingleStreamDecoder.cpp

-9
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,6 @@ SingleStreamDecoder::SingleStreamDecoder(
9292
initializeDecoder();
9393
}
9494

95-
SingleStreamDecoder::~SingleStreamDecoder() {
96-
for (auto& [streamIndex, streamInfo] : streamInfos_) {
97-
auto& deviceInterface = streamInfo.deviceInterface;
98-
if (deviceInterface) {
99-
deviceInterface->releaseContext(streamInfo.codecContext.get());
100-
}
101-
}
102-
}
103-
10495
void SingleStreamDecoder::initializeDecoder() {
10596
TORCH_CHECK(!initialized_, "Attempted double initialization.");
10697

src/torchcodec/_core/SingleStreamDecoder.h

-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ class DeviceInterface;
2323
// Do not call non-const APIs concurrently on the same object.
2424
class SingleStreamDecoder {
2525
public:
26-
~SingleStreamDecoder();
27-
2826
// --------------------------------------------------------------------------
2927
// CONSTRUCTION API
3028
// --------------------------------------------------------------------------

0 commit comments

Comments
 (0)