diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index abec1d217..1f42e24d2 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -60,6 +60,7 @@ function(make_torchcodec_libraries set(decoder_sources AVIOContextHolder.cpp FFMPEGCommon.cpp + DeviceInterface.cpp SingleStreamDecoder.cpp # TODO: lib name should probably not be "*_decoder*" now that it also # contains an encoder @@ -68,8 +69,6 @@ function(make_torchcodec_libraries if(ENABLE_CUDA) list(APPEND decoder_sources CudaDevice.cpp) - else() - list(APPEND decoder_sources CPUOnlyDevice.cpp) endif() set(decoder_library_dependencies diff --git a/src/torchcodec/_core/CPUOnlyDevice.cpp b/src/torchcodec/_core/CPUOnlyDevice.cpp deleted file mode 100644 index 1d5b477dd..000000000 --- a/src/torchcodec/_core/CPUOnlyDevice.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include -#include "src/torchcodec/_core/DeviceInterface.h" - -namespace facebook::torchcodec { - -// This file is linked with the CPU-only version of torchcodec. -// So all functions will throw an error because they should only be called if -// the device is not CPU. - -[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) { - TORCH_CHECK( - device.type() != torch::kCPU, - "Device functions should only be called if the device is not CPU.") - TORCH_CHECK(false, "Unsupported device: " + device.str()); -} - -void convertAVFrameToFrameOutputOnCuda( - const torch::Device& device, - [[maybe_unused]] const SingleStreamDecoder::VideoStreamOptions& - videoStreamOptions, - [[maybe_unused]] UniqueAVFrame& avFrame, - [[maybe_unused]] SingleStreamDecoder::FrameOutput& frameOutput, - [[maybe_unused]] std::optional preAllocatedOutputTensor) { - throwUnsupportedDeviceError(device); -} - -void initializeContextOnCuda( - const torch::Device& device, - [[maybe_unused]] AVCodecContext* codecContext) { - throwUnsupportedDeviceError(device); -} - -void releaseContextOnCuda( - const torch::Device& device, - [[maybe_unused]] AVCodecContext* codecContext) { - throwUnsupportedDeviceError(device); -} - -std::optional findCudaCodec( - const torch::Device& device, - [[maybe_unused]] const AVCodecID& codecId) { - throwUnsupportedDeviceError(device); -} - -} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDevice.cpp b/src/torchcodec/_core/CudaDevice.cpp index fd8be9de8..5bde4106f 100644 --- a/src/torchcodec/_core/CudaDevice.cpp +++ b/src/torchcodec/_core/CudaDevice.cpp @@ -4,7 +4,7 @@ #include #include -#include "src/torchcodec/_core/DeviceInterface.h" +#include "src/torchcodec/_core/CudaDevice.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" @@ -16,6 +16,10 @@ extern "C" { namespace facebook::torchcodec { namespace { +bool g_cuda = registerDeviceInterface( + torch::kCUDA, + [](const torch::Device& device) { return new CudaDevice(device); }); + // We reuse cuda contexts across VideoDeoder instances. This is because // creating a cuda context is expensive. The cache mechanism is as follows: // 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for @@ -49,7 +53,7 @@ torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) { void addToCacheIfCacheHasCapacity( const torch::Device& device, - AVCodecContext* codecContext) { + AVBufferRef* hwContext) { torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device); if (static_cast(deviceIndex) >= MAX_CUDA_GPUS) { return; @@ -60,8 +64,7 @@ void addToCacheIfCacheHasCapacity( MAX_CONTEXTS_PER_GPU_IN_CACHE) { return; } - g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx); - codecContext->hw_device_ctx = nullptr; + g_cached_hw_device_ctxs[deviceIndex].push_back(av_buffer_ref(hwContext)); } AVBufferRef* getFromCache(const torch::Device& device) { @@ -158,39 +161,35 @@ AVBufferRef* getCudaContext(const torch::Device& device) { device, nonNegativeDeviceIndex, type); #endif } +} // namespace -void throwErrorIfNonCudaDevice(const torch::Device& device) { - TORCH_CHECK( - device.type() != torch::kCPU, - "Device functions should only be called if the device is not CPU.") - if (device.type() != torch::kCUDA) { - throw std::runtime_error("Unsupported device: " + device.str()); +CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) { + if (device_.type() != torch::kCUDA) { + throw std::runtime_error("Unsupported device: " + device_.str()); } } -} // namespace -void releaseContextOnCuda( - const torch::Device& device, - AVCodecContext* codecContext) { - throwErrorIfNonCudaDevice(device); - addToCacheIfCacheHasCapacity(device, codecContext); +CudaDevice::~CudaDevice() { + if (ctx_) { + addToCacheIfCacheHasCapacity(device_, ctx_); + av_buffer_unref(&ctx_); + } } -void initializeContextOnCuda( - const torch::Device& device, - AVCodecContext* codecContext) { - throwErrorIfNonCudaDevice(device); +void CudaDevice::initializeContext(AVCodecContext* codecContext) { + TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); + // It is important for pytorch itself to create the cuda context. If ffmpeg // creates the context it may not be compatible with pytorch. // This is a dummy tensor to initialize the cuda context. torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); - codecContext->hw_device_ctx = getCudaContext(device); + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + ctx_ = getCudaContext(device_); + codecContext->hw_device_ctx = av_buffer_ref(ctx_); return; } -void convertAVFrameToFrameOutputOnCuda( - const torch::Device& device, +void CudaDevice::convertAVFrameToFrameOutput( const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, UniqueAVFrame& avFrame, SingleStreamDecoder::FrameOutput& frameOutput, @@ -217,11 +216,11 @@ void convertAVFrameToFrameOutputOnCuda( "x3, got ", shape); } else { - dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device); + dst = allocateEmptyHWCTensor(height, width, device_); } // Use the user-requested GPU for running the NPP kernel. - c10::cuda::CUDAGuard deviceGuard(device); + c10::cuda::CUDAGuard deviceGuard(device_); NppiSize oSizeROI = {width, height}; Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]}; @@ -249,7 +248,7 @@ void convertAVFrameToFrameOutputOnCuda( // output. at::cuda::CUDAEvent nppDoneEvent; at::cuda::CUDAStream nppStreamWrapper = - c10::cuda::getStreamFromExternal(nppGetStream(), device.index()); + c10::cuda::getStreamFromExternal(nppGetStream(), device_.index()); nppDoneEvent.record(nppStreamWrapper); nppDoneEvent.block(at::cuda::getCurrentCUDAStream()); @@ -264,11 +263,7 @@ void convertAVFrameToFrameOutputOnCuda( // we have to do this because of an FFmpeg bug where hardware decoding is not // appropriately set, so we just go off and find the matching codec for the CUDA // device -std::optional findCudaCodec( - const torch::Device& device, - const AVCodecID& codecId) { - throwErrorIfNonCudaDevice(device); - +std::optional CudaDevice::findCodec(const AVCodecID& codecId) { void* i = nullptr; const AVCodec* codec = nullptr; while ((codec = av_codec_iterate(&i)) != nullptr) { diff --git a/src/torchcodec/_core/CudaDevice.h b/src/torchcodec/_core/CudaDevice.h new file mode 100644 index 000000000..0ed538593 --- /dev/null +++ b/src/torchcodec/_core/CudaDevice.h @@ -0,0 +1,34 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include "src/torchcodec/_core/DeviceInterface.h" + +namespace facebook::torchcodec { + +class CudaDevice : public DeviceInterface { + public: + CudaDevice(const torch::Device& device); + + virtual ~CudaDevice(); + + std::optional findCodec(const AVCodecID& codecId) override; + + void initializeContext(AVCodecContext* codecContext) override; + + void convertAVFrameToFrameOutput( + const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + UniqueAVFrame& avFrame, + SingleStreamDecoder::FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = + std::nullopt) override; + + private: + AVBufferRef* ctx_ = nullptr; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp new file mode 100644 index 000000000..7612334b1 --- /dev/null +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -0,0 +1,77 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "src/torchcodec/_core/DeviceInterface.h" +#include +#include + +namespace facebook::torchcodec { + +namespace { +std::mutex g_interface_mutex; +std::map g_interface_map; + +std::string getDeviceType(const std::string& device) { + size_t pos = device.find(':'); + if (pos == std::string::npos) { + return device; + } + return device.substr(0, pos); +} + +} // namespace + +bool registerDeviceInterface( + torch::DeviceType deviceType, + CreateDeviceInterfaceFn createInterface) { + std::scoped_lock lock(g_interface_mutex); + TORCH_CHECK( + g_interface_map.find(deviceType) == g_interface_map.end(), + "Device interface already registered for ", + deviceType); + g_interface_map.insert({deviceType, createInterface}); + return true; +} + +torch::Device createTorchDevice(const std::string device) { + // TODO: remove once DeviceInterface for CPU is implemented + if (device == "cpu") { + return torch::kCPU; + } + + std::scoped_lock lock(g_interface_mutex); + std::string deviceType = getDeviceType(device); + auto deviceInterface = std::find_if( + g_interface_map.begin(), + g_interface_map.end(), + [&](const std::pair& arg) { + return device.rfind( + torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0; + }); + TORCH_CHECK( + deviceInterface != g_interface_map.end(), "Unsupported device: ", device); + + return torch::Device(device); +} + +std::unique_ptr createDeviceInterface( + const torch::Device& device) { + auto deviceType = device.type(); + // TODO: remove once DeviceInterface for CPU is implemented + if (deviceType == torch::kCPU) { + return nullptr; + } + + std::scoped_lock lock(g_interface_mutex); + TORCH_CHECK( + g_interface_map.find(deviceType) != g_interface_map.end(), + "Unsupported device: ", + device); + + return std::unique_ptr(g_interface_map[deviceType](device)); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 352b83d35..a5b0e3652 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -7,6 +7,7 @@ #pragma once #include +#include #include #include #include @@ -23,25 +24,42 @@ namespace facebook::torchcodec { // deviceFunction(device, ...); // } -// Initialize the hardware device that is specified in `device`. Some builds -// support CUDA and others only support CPU. -void initializeContextOnCuda( - const torch::Device& device, - AVCodecContext* codecContext); - -void convertAVFrameToFrameOutputOnCuda( - const torch::Device& device, - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, - UniqueAVFrame& avFrame, - SingleStreamDecoder::FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor = std::nullopt); - -void releaseContextOnCuda( - const torch::Device& device, - AVCodecContext* codecContext); - -std::optional findCudaCodec( - const torch::Device& device, - const AVCodecID& codecId); +class DeviceInterface { + public: + DeviceInterface(const torch::Device& device) : device_(device) {} + + virtual ~DeviceInterface(){}; + + torch::Device& device() { + return device_; + }; + + virtual std::optional findCodec(const AVCodecID& codecId) = 0; + + // Initialize the hardware device that is specified in `device`. Some builds + // support CUDA and others only support CPU. + virtual void initializeContext(AVCodecContext* codecContext) = 0; + + virtual void convertAVFrameToFrameOutput( + const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + UniqueAVFrame& avFrame, + SingleStreamDecoder::FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = std::nullopt) = 0; + + protected: + torch::Device device_; +}; + +using CreateDeviceInterfaceFn = + std::function; + +bool registerDeviceInterface( + torch::DeviceType deviceType, + const CreateDeviceInterfaceFn createInterface); + +torch::Device createTorchDevice(const std::string device); + +std::unique_ptr createDeviceInterface( + const torch::Device& device); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index b7438f199..c7c714da3 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -92,18 +92,6 @@ SingleStreamDecoder::SingleStreamDecoder( initializeDecoder(); } -SingleStreamDecoder::~SingleStreamDecoder() { - for (auto& [streamIndex, streamInfo] : streamInfos_) { - auto& device = streamInfo.videoStreamOptions.device; - if (device.type() == torch::kCPU) { - } else if (device.type() == torch::kCUDA) { - releaseContextOnCuda(device, streamInfo.codecContext.get()); - } else { - TORCH_CHECK(false, "Invalid device type: " + device.str()); - } - } -} - void SingleStreamDecoder::initializeDecoder() { TORCH_CHECK(!initialized_, "Attempted double initialization."); @@ -418,6 +406,8 @@ void SingleStreamDecoder::addStream( streamInfo.stream = formatContext_->streams[activeStreamIndex_]; streamInfo.avMediaType = mediaType; + deviceInterface = createDeviceInterface(device); + // This should never happen, checking just to be safe. TORCH_CHECK( streamInfo.stream->codecpar->codec_type == mediaType, @@ -427,10 +417,12 @@ void SingleStreamDecoder::addStream( // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic - if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - findCudaCodec(device, streamInfo.stream->codecpar->codec_id) - .value_or(avCodec)); + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (deviceInterface) { + avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( + deviceInterface->findCodec(streamInfo.stream->codecpar->codec_id) + .value_or(avCodec)); + } } AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); @@ -445,8 +437,10 @@ void SingleStreamDecoder::addStream( streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; // TODO_CODE_QUALITY same as above. - if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - initializeContextOnCuda(device, codecContext); + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (deviceInterface) { + deviceInterface->initializeContext(codecContext); + } } retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); @@ -472,11 +466,6 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions) { - TORCH_CHECK( - videoStreamOptions.device.type() == torch::kCPU || - videoStreamOptions.device.type() == torch::kCUDA, - "Invalid device type: " + videoStreamOptions.device.str()); - addStream( streamIndex, AVMEDIA_TYPE_VIDEO, @@ -1221,20 +1210,15 @@ SingleStreamDecoder::convertAVFrameToFrameOutput( formatContext_->streams[activeStreamIndex_]->time_base); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { + } else if (!deviceInterface) { convertAVFrameToFrameOutputOnCPU( avFrame, frameOutput, preAllocatedOutputTensor); - } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) { - convertAVFrameToFrameOutputOnCuda( - streamInfo.videoStreamOptions.device, + } else if (deviceInterface) { + deviceInterface->convertAVFrameToFrameOutput( streamInfo.videoStreamOptions, avFrame, frameOutput, preAllocatedOutputTensor); - } else { - TORCH_CHECK( - false, - "Invalid device type: " + streamInfo.videoStreamOptions.device.str()); } return frameOutput; } diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index f712cdbbd..4879a3b7d 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -16,14 +16,13 @@ #include "src/torchcodec/_core/FFMPEGCommon.h" namespace facebook::torchcodec { +class DeviceInterface; // The SingleStreamDecoder class can be used to decode video frames to Tensors. // Note that SingleStreamDecoder is not thread-safe. // Do not call non-const APIs concurrently on the same object. class SingleStreamDecoder { public: - ~SingleStreamDecoder(); - // -------------------------------------------------------------------------- // CONSTRUCTION API // -------------------------------------------------------------------------- @@ -493,6 +492,7 @@ class SingleStreamDecoder { SeekMode seekMode_; ContainerMetadata containerMetadata_; UniqueDecodingAVFormatContext formatContext_; + std::unique_ptr deviceInterface; std::map streamInfos_; const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 596412a8f..05a6390d6 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -11,6 +11,7 @@ #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" @@ -242,16 +243,7 @@ void _add_video_stream( } } if (device.has_value()) { - if (device.value() == "cpu") { - videoStreamOptions.device = torch::Device(torch::kCPU); - } else if (device.value().rfind("cuda", 0) == 0) { // starts with "cuda" - std::string deviceStr(device.value()); - videoStreamOptions.device = torch::Device(deviceStr); - } else { - throw std::runtime_error( - "Invalid device=" + std::string(device.value()) + - ". device must be either cpu or cuda."); - } + videoStreamOptions.device = createTorchDevice(std::string(device.value())); } auto videoDecoder = unwrapTensorToGetDecoder(decoder); diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index a7ad4c6d3..1937ff97c 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" #include