Skip to content

Commit 58b55e3

Browse files
committed
Make device interface generic
Fixes: #605 Changes: * Device interface made device agnostic by introducing `class DeviceInterface` from which specific backends should inherit their device specific implementations * Implemented `CudaDevice` derived from `DeviceInterface` * Created device interface registration mechanism (`registerDeviceInterface`) * Created device interface creation mechanism (`createDeviceInterface`) These changes make it possible to replace CUDA specific code in `VideoDecoder.cpp` and `VideoDecoderOps.cpp` with device agnostic code. Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent a864bf9 commit 58b55e3

File tree

9 files changed

+175
-135
lines changed

9 files changed

+175
-135
lines changed

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ function(make_torchcodec_libraries
6161
AVIOContextHolder.cpp
6262
FFMPEGCommon.cpp
6363
VideoDecoder.cpp
64+
DeviceInterface.cpp
6465
)
6566

6667
if(ENABLE_CUDA)
6768
list(APPEND decoder_sources CudaDevice.cpp)
68-
else()
69-
list(APPEND decoder_sources CPUOnlyDevice.cpp)
7069
endif()
7170

7271
set(decoder_library_dependencies

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 0 additions & 44 deletions
This file was deleted.

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <torch/types.h>
55
#include <mutex>
66

7-
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
7+
#include "src/torchcodec/decoders/_core/CudaDevice.h"
88
#include "src/torchcodec/decoders/_core/FFMPEGCommon.h"
99
#include "src/torchcodec/decoders/_core/VideoDecoder.h"
1010

@@ -16,6 +16,10 @@ extern "C" {
1616
namespace facebook::torchcodec {
1717
namespace {
1818

19+
bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) {
20+
return new CudaDevice(device);
21+
});
22+
1923
// We reuse cuda contexts across VideoDecoder instances. This is because
2024
// creating a cuda context is expensive. The cache mechanism is as follows:
2125
// 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
@@ -156,39 +160,29 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
156160
device, nonNegativeDeviceIndex, type);
157161
#endif
158162
}
163+
} // namespace
159164

160-
void throwErrorIfNonCudaDevice(const torch::Device& device) {
161-
TORCH_CHECK(
162-
device.type() != torch::kCPU,
163-
"Device functions should only be called if the device is not CPU.")
164-
if (device.type() != torch::kCUDA) {
165-
throw std::runtime_error("Unsupported device: " + device.str());
165+
CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) {
166+
if (device_.type() != torch::kCUDA) {
167+
throw std::runtime_error("Unsupported device: " + device_.str());
166168
}
167169
}
168-
} // namespace
169170

170-
void releaseContextOnCuda(
171-
const torch::Device& device,
172-
AVCodecContext* codecContext) {
173-
throwErrorIfNonCudaDevice(device);
174-
addToCacheIfCacheHasCapacity(device, codecContext);
171+
void CudaDevice::releaseContext(AVCodecContext* codecContext) {
172+
addToCacheIfCacheHasCapacity(device_, codecContext);
175173
}
176174

177-
void initializeContextOnCuda(
178-
const torch::Device& device,
179-
AVCodecContext* codecContext) {
180-
throwErrorIfNonCudaDevice(device);
175+
void CudaDevice::initializeContext(AVCodecContext* codecContext) {
181176
// It is important for pytorch itself to create the cuda context. If ffmpeg
182177
// creates the context it may not be compatible with pytorch.
183178
// This is a dummy tensor to initialize the cuda context.
184179
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
185-
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
186-
codecContext->hw_device_ctx = getCudaContext(device);
180+
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
181+
codecContext->hw_device_ctx = getCudaContext(device_);
187182
return;
188183
}
189184

190-
void convertAVFrameToFrameOutputOnCuda(
191-
const torch::Device& device,
185+
void CudaDevice::convertAVFrameToFrameOutput(
192186
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
193187
UniqueAVFrame& avFrame,
194188
VideoDecoder::FrameOutput& frameOutput,
@@ -215,11 +209,11 @@ void convertAVFrameToFrameOutputOnCuda(
215209
"x3, got ",
216210
shape);
217211
} else {
218-
dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
212+
dst = allocateEmptyHWCTensor(height, width, device_);
219213
}
220214

221215
// Use the user-requested GPU for running the NPP kernel.
222-
c10::cuda::CUDAGuard deviceGuard(device);
216+
c10::cuda::CUDAGuard deviceGuard(device_);
223217

224218
NppiSize oSizeROI = {width, height};
225219
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
@@ -247,7 +241,7 @@ void convertAVFrameToFrameOutputOnCuda(
247241
// output.
248242
at::cuda::CUDAEvent nppDoneEvent;
249243
at::cuda::CUDAStream nppStreamWrapper =
250-
c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
244+
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
251245
nppDoneEvent.record(nppStreamWrapper);
252246
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
253247

@@ -262,11 +256,7 @@ void convertAVFrameToFrameOutputOnCuda(
262256
// we have to do this because of an FFmpeg bug where hardware decoding is not
263257
// appropriately set, so we just go off and find the matching codec for the CUDA
264258
// device
265-
std::optional<const AVCodec*> findCudaCodec(
266-
const torch::Device& device,
267-
const AVCodecID& codecId) {
268-
throwErrorIfNonCudaDevice(device);
269-
259+
std::optional<const AVCodec*> CudaDevice::findCodec(const AVCodecID& codecId) {
270260
void* i = nullptr;
271261
const AVCodec* codec = nullptr;
272262
while ((codec = av_codec_iterate(&i)) != nullptr) {
src/torchcodec/decoders/_core/CudaDevice.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
10+
11+
namespace facebook::torchcodec {
12+
13+
struct CudaDevice : public DeviceInterface {
14+
CudaDevice(const std::string& device);
15+
16+
virtual ~CudaDevice(){};
17+
18+
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
19+
20+
void initializeContext(AVCodecContext* codecContext) override;
21+
22+
void convertAVFrameToFrameOutput(
23+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
24+
UniqueAVFrame& avFrame,
25+
VideoDecoder::FrameOutput& frameOutput,
26+
std::optional<torch::Tensor> preAllocatedOutputTensor =
27+
std::nullopt) override;
28+
29+
void releaseContext(AVCodecContext* codecContext) override;
30+
};
31+
32+
} // namespace facebook::torchcodec
src/torchcodec/decoders/_core/DeviceInterface.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
8+
#include <map>
9+
#include <mutex>
10+
11+
namespace facebook::torchcodec {
12+
13+
namespace {
14+
std::mutex g_interface_mutex;
15+
std::map<std::string, CreateDeviceInterfaceFn> g_interface_map;
16+
17+
std::string getDeviceType(const std::string& device) {
18+
size_t pos = device.find(':');
19+
if (pos == std::string::npos) {
20+
return device;
21+
}
22+
return device.substr(0, pos);
23+
}
24+
25+
} // namespace
26+
27+
bool registerDeviceInterface(
28+
const std::string device_type,
29+
CreateDeviceInterfaceFn create_interface) {
30+
std::scoped_lock lock(g_interface_mutex);
31+
TORCH_CHECK(
32+
g_interface_map.find(device_type) == g_interface_map.end(),
33+
"Device interface already registered for ",
34+
device_type);
35+
g_interface_map.insert({device_type, create_interface});
36+
return true;
37+
}
38+
39+
std::shared_ptr<DeviceInterface> createDeviceInterface(
40+
const std::string device) {
41+
// TODO: remove once DeviceInterface for CPU is implemented
42+
if (device == "cpu") {
43+
return nullptr;
44+
// return std::shared_ptr<DeviceInterface>();
45+
}
46+
47+
std::scoped_lock lock(g_interface_mutex);
48+
std::string device_type = getDeviceType(device);
49+
TORCH_CHECK(
50+
g_interface_map.find(device_type) != g_interface_map.end(),
51+
"Unsupported device: ",
52+
device);
53+
54+
return std::shared_ptr<DeviceInterface>(g_interface_map[device_type](device));
55+
// return std::shared_ptr<DeviceInterface>(g_interface_map[device_type]);
56+
}
57+
58+
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#pragma once
88

99
#include <torch/types.h>
10+
#include <functional>
1011
#include <memory>
1112
#include <stdexcept>
1213
#include <string>
@@ -23,25 +24,41 @@ namespace facebook::torchcodec {
2324
// deviceFunction(device, ...);
2425
// }
2526

26-
// Initialize the hardware device that is specified in `device`. Some builds
27-
// support CUDA and others only support CPU.
28-
void initializeContextOnCuda(
29-
const torch::Device& device,
30-
AVCodecContext* codecContext);
31-
32-
void convertAVFrameToFrameOutputOnCuda(
33-
const torch::Device& device,
34-
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
35-
UniqueAVFrame& avFrame,
36-
VideoDecoder::FrameOutput& frameOutput,
37-
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
38-
39-
void releaseContextOnCuda(
40-
const torch::Device& device,
41-
AVCodecContext* codecContext);
42-
43-
std::optional<const AVCodec*> findCudaCodec(
44-
const torch::Device& device,
45-
const AVCodecID& codecId);
27+
struct DeviceInterface {
28+
DeviceInterface(const std::string& device) : device_(device) {}
29+
30+
virtual ~DeviceInterface(){};
31+
32+
torch::Device& device() {
33+
return device_;
34+
};
35+
36+
virtual std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) = 0;
37+
38+
// Initialize the hardware device that is specified in `device`. Some builds
39+
// support CUDA and others only support CPU.
40+
virtual void initializeContext(AVCodecContext* codecContext) = 0;
41+
42+
virtual void convertAVFrameToFrameOutput(
43+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
44+
UniqueAVFrame& avFrame,
45+
VideoDecoder::FrameOutput& frameOutput,
46+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
47+
48+
virtual void releaseContext(AVCodecContext* codecContext) = 0;
49+
50+
protected:
51+
torch::Device device_;
52+
};
53+
54+
using CreateDeviceInterfaceFn =
55+
std::function<DeviceInterface*(const std::string& device)>;
56+
57+
bool registerDeviceInterface(
58+
const std::string device_type,
59+
const CreateDeviceInterfaceFn create_interface);
60+
61+
std::shared_ptr<DeviceInterface> createDeviceInterface(
62+
const std::string device);
4663

4764
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)