diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index ecec520b..4b6bfc4a 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -67,7 +67,9 @@ function(make_torchcodec_libraries
         AVIOContextHolder.cpp
         AVIOBytesContext.cpp
         FFMPEGCommon.cpp
+        Frame.cpp
         DeviceInterface.cpp
+        CpuDeviceInterface.cpp
         SingleStreamDecoder.cpp
         # TODO: lib name should probably not be "*_decoder*" now that it also
         # contains an encoder
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
new file mode 100644
index 00000000..1728c87b
--- /dev/null
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -0,0 +1,363 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "src/torchcodec/_core/CpuDeviceInterface.h"
+
+extern "C" {
+#include <libavfilter/buffersink.h>
+#include <libavfilter/buffersrc.h>
+}
+
+namespace facebook::torchcodec {
+namespace {
+
+bool g_cpu = registerDeviceInterface(
+    torch::kCPU,
+    [](const torch::Device& device) { return new CpuDeviceInterface(device); });
+
+} // namespace
+
+bool CpuDeviceInterface::DecodedFrameContext::operator==(
+    const CpuDeviceInterface::DecodedFrameContext& other) {
+  return decodedWidth == other.decodedWidth &&
+      decodedHeight == other.decodedHeight &&
+      decodedFormat == other.decodedFormat &&
+      expectedWidth == other.expectedWidth &&
+      expectedHeight == other.expectedHeight;
+}
+
+bool CpuDeviceInterface::DecodedFrameContext::operator!=(
+    const CpuDeviceInterface::DecodedFrameContext& other) {
+  return !(*this == other);
+}
+
+CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
+    : DeviceInterface(device) {
+  if (device_.type() != torch::kCPU) {
+    throw std::runtime_error("Unsupported device: " + device_.str());
+  }
+}
+
+// Note [preAllocatedOutputTensor with swscale and filtergraph]:
+// Callers may pass a pre-allocated tensor, where the output.data tensor will
+// be stored. This parameter is honored in any case, but it only leads to a
+// speed-up when swscale is used. With swscale, we can tell FFmpeg to place the
+// decoded frame directly into `preAllocatedOutputTensor.data_ptr()`. We
+// haven't yet found a way to do that with filtergraph.
+// TODO: Figure out whether that's possible!
+// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
+// the `dimension_order` parameter. It's up to callers to re-shape it if
+// needed.
+void CpuDeviceInterface::convertAVFrameToFrameOutput(
+    const VideoStreamOptions& videoStreamOptions,
+    const AVRational& timeBase,
+    UniqueAVFrame& avFrame,
+    FrameOutput& frameOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected pre-allocated tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+  }
+
+  torch::Tensor outputTensor;
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<AVPixelFormat>(avFrame->format);
+  auto frameContext = DecodedFrameContext{
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      avFrame->sample_aspect_ratio,
+      expectedOutputWidth,
+      expectedOutputHeight};
+
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability
+  // to choose the color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+    if (!swsContext_ || prevFrameContext_ != frameContext) {
+      createSwsContext(frameContext, avFrame->colorspace);
+      prevFrameContext_ = frameContext;
+    }
+    int resultHeight =
+        convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    frameOutput.data = outputTensor;
+  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
+      createFilterGraph(frameContext, videoStreamOptions, timeBase);
+      prevFrameContext_ = frameContext;
+    }
+    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      frameOutput.data = preAllocatedOutputTensor.value();
+    } else {
+      frameOutput.data = outputTensor;
+    }
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(colorConversionLibrary)));
+  }
+}
+
+int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
+    const UniqueAVFrame& avFrame,
+    torch::Tensor& outputTensor) {
+  uint8_t* pointers[4] = {
+      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int expectedOutputWidth = outputTensor.sizes()[1];
+  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+  int resultHeight = sws_scale(
+      swsContext_.get(),
+      avFrame->data,
+      avFrame->linesize,
+      0,
+      avFrame->height,
+      pointers,
+      linesizes);
+  return resultHeight;
+}
+
+torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
+    const UniqueAVFrame& avFrame) {
+  int status = av_buffersrc_write_frame(
+      filterGraphContext_.sourceContext, avFrame.get());
+  if (status < AVSUCCESS) {
+    throw std::runtime_error("Failed to add frame to buffer source context");
+  }
+
+  UniqueAVFrame filteredAVFrame(av_frame_alloc());
+  status = av_buffersink_get_frame(
+      filterGraphContext_.sinkContext, filteredAVFrame.get());
+  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
+
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
+  int height = frameDims.height;
+  int width = frameDims.width;
+  std::vector<int64_t> shape = {height, width, 3};
+  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
+  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
+  auto deleter = [filteredAVFramePtr](void*) {
+    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
+  };
+  return torch::from_blob(
+      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
+}
+
+void CpuDeviceInterface::createFilterGraph(
+    const DecodedFrameContext& frameContext,
+    const VideoStreamOptions& videoStreamOptions,
+    const AVRational& timeBase) {
+  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
+
+  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
+    filterGraphContext_.filterGraph->nb_threads =
+        videoStreamOptions.ffmpegThreadCount.value();
+  }
+
+  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
+  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
+
+  std::stringstream filterArgs;
+  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
+             << frameContext.decodedHeight;
+  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
+  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
+  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
+             << frameContext.decodedAspectRatio.den;
+
+  int status = avfilter_graph_create_filter(
+      &filterGraphContext_.sourceContext,
+      buffersrc,
+      "in",
+      filterArgs.str().c_str(),
+      nullptr,
+      filterGraphContext_.filterGraph.get());
+  if (status < 0) {
+    throw std::runtime_error(
+        std::string("Failed to create filter graph: ") + filterArgs.str() +
+        ": " + getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  status = avfilter_graph_create_filter(
+      &filterGraphContext_.sinkContext,
+      buffersink,
+      "out",
+      nullptr,
+      nullptr,
+      filterGraphContext_.filterGraph.get());
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to create filter graph: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
+
+  status
= av_opt_set_int_list( + filterGraphContext_.sinkContext, + "pix_fmts", + pix_fmts, + AV_PIX_FMT_NONE, + AV_OPT_SEARCH_CHILDREN); + if (status < 0) { + throw std::runtime_error( + "Failed to set output pixel formats: " + + getFFMPEGErrorStringFromErrorCode(status)); + } + + UniqueAVFilterInOut outputs(avfilter_inout_alloc()); + UniqueAVFilterInOut inputs(avfilter_inout_alloc()); + + outputs->name = av_strdup("in"); + outputs->filter_ctx = filterGraphContext_.sourceContext; + outputs->pad_idx = 0; + outputs->next = nullptr; + inputs->name = av_strdup("out"); + inputs->filter_ctx = filterGraphContext_.sinkContext; + inputs->pad_idx = 0; + inputs->next = nullptr; + + std::stringstream description; + description << "scale=" << frameContext.expectedWidth << ":" + << frameContext.expectedHeight; + description << ":sws_flags=bilinear"; + + AVFilterInOut* outputsTmp = outputs.release(); + AVFilterInOut* inputsTmp = inputs.release(); + status = avfilter_graph_parse_ptr( + filterGraphContext_.filterGraph.get(), + description.str().c_str(), + &inputsTmp, + &outputsTmp, + nullptr); + outputs.reset(outputsTmp); + inputs.reset(inputsTmp); + if (status < 0) { + throw std::runtime_error( + "Failed to parse filter description: " + + getFFMPEGErrorStringFromErrorCode(status)); + } + + status = + avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr); + if (status < 0) { + throw std::runtime_error( + "Failed to configure filter graph: " + + getFFMPEGErrorStringFromErrorCode(status)); + } +} + +void CpuDeviceInterface::createSwsContext( + const DecodedFrameContext& frameContext, + const enum AVColorSpace colorspace) { + SwsContext* swsContext = sws_getContext( + frameContext.decodedWidth, + frameContext.decodedHeight, + frameContext.decodedFormat, + frameContext.expectedWidth, + frameContext.expectedHeight, + AV_PIX_FMT_RGB24, + SWS_BILINEAR, + nullptr, + nullptr, + nullptr); + TORCH_CHECK(swsContext, "sws_getContext() returned nullptr"); + + int* invTable = nullptr; + int* table = nullptr; + int srcRange, dstRange, brightness, contrast, saturation; + int ret = sws_getColorspaceDetails( + swsContext, + &invTable, + &srcRange, + &table, + &dstRange, + &brightness, + &contrast, + &saturation); + TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1"); + + const int* colorspaceTable = sws_getCoefficients(colorspace); + ret = sws_setColorspaceDetails( + swsContext, + colorspaceTable, + srcRange, + colorspaceTable, + dstRange, + brightness, + contrast, + saturation); + TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1"); + + swsContext_.reset(swsContext); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h new file mode 100644 index 00000000..404289bd --- /dev/null +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -0,0 +1,80 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+
+#pragma once
+
+#include "src/torchcodec/_core/DeviceInterface.h"
+#include "src/torchcodec/_core/FFMPEGCommon.h"
+
+namespace facebook::torchcodec {
+
+class CpuDeviceInterface : public DeviceInterface {
+ public:
+  CpuDeviceInterface(const torch::Device& device);
+
+  virtual ~CpuDeviceInterface() {}
+
+  std::optional<const AVCodec*> findCodec(
+      [[maybe_unused]] const AVCodecID& codecId) override {
+    return std::nullopt;
+  }
+
+  void initializeContext(
+      [[maybe_unused]] AVCodecContext* codecContext) override {}
+
+  void convertAVFrameToFrameOutput(
+      const VideoStreamOptions& videoStreamOptions,
+      const AVRational& timeBase,
+      UniqueAVFrame& avFrame,
+      FrameOutput& frameOutput,
+      std::optional<torch::Tensor> preAllocatedOutputTensor =
+          std::nullopt) override;
+
+ private:
+  int convertAVFrameToTensorUsingSwsScale(
+      const UniqueAVFrame& avFrame,
+      torch::Tensor& outputTensor);
+
+  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
+      const UniqueAVFrame& avFrame);
+
+  struct FilterGraphContext {
+    UniqueAVFilterGraph filterGraph;
+    AVFilterContext* sourceContext = nullptr;
+    AVFilterContext* sinkContext = nullptr;
+  };
+
+  struct DecodedFrameContext {
+    int decodedWidth;
+    int decodedHeight;
+    AVPixelFormat decodedFormat;
+    AVRational decodedAspectRatio;
+    int expectedWidth;
+    int expectedHeight;
+    bool operator==(const DecodedFrameContext&);
+    bool operator!=(const DecodedFrameContext&);
+  };
+
+  void createSwsContext(
+      const DecodedFrameContext& frameContext,
+      const enum AVColorSpace colorspace);
+
+  void createFilterGraph(
+      const DecodedFrameContext& frameContext,
+      const VideoStreamOptions& videoStreamOptions,
+      const AVRational& timeBase);
+
+  // color-conversion fields. Only one of FilterGraphContext and
+  // UniqueSwsContext should be non-null.
+  FilterGraphContext filterGraphContext_;
+  UniqueSwsContext swsContext_;
+
+  // Used to know whether a new FilterGraphContext or UniqueSwsContext should
+  // be created before decoding a new frame.
+ DecodedFrameContext prevFrameContext_; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 24443c68..5bd86d48 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -6,7 +6,6 @@ #include "src/torchcodec/_core/CudaDeviceInterface.h" #include "src/torchcodec/_core/FFMPEGCommon.h" -#include "src/torchcodec/_core/SingleStreamDecoder.h" extern "C" { #include @@ -193,6 +192,7 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) { void CudaDeviceInterface::convertAVFrameToFrameOutput( const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index b60eff7a..01f3b19b 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -22,6 +22,7 @@ class CudaDeviceInterface : public DeviceInterface { void convertAVFrameToFrameOutput( const VideoStreamOptions& videoStreamOptions, + const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index 382de621..593a06b8 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -45,11 +45,6 @@ bool registerDeviceInterface( } torch::Device createTorchDevice(const std::string device) { - // TODO: remove once DeviceInterface for CPU is implemented - if (device == "cpu") { - return torch::kCPU; - } - std::scoped_lock lock(g_interface_mutex); std::string deviceType = getDeviceType(device); auto deviceInterface = std::find_if( @@ -70,11 +65,6 @@ torch::Device createTorchDevice(const std::string device) { std::unique_ptr createDeviceInterface( const torch::Device& device) { auto deviceType = device.type(); - // TODO: remove once DeviceInterface for CPU is implemented - if (deviceType == torch::kCPU) { - return nullptr; - } - std::scoped_lock lock(g_interface_mutex); TORCH_CHECK( g_interface_map->find(deviceType) != g_interface_map->end(), diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index b4197d7d..11a73b65 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -43,6 +43,7 @@ class DeviceInterface { virtual void convertAVFrameToFrameOutput( const VideoStreamOptions& videoStreamOptions, + const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; diff --git a/src/torchcodec/_core/Frame.cpp b/src/torchcodec/_core/Frame.cpp new file mode 100644 index 00000000..bc3bbb78 --- /dev/null +++ b/src/torchcodec/_core/Frame.cpp @@ -0,0 +1,32 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+
+#include "src/torchcodec/_core/Frame.h"
+
+namespace facebook::torchcodec {
+
+torch::Tensor allocateEmptyHWCTensor(
+    int height,
+    int width,
+    torch::Device device,
+    std::optional<int> numFrames) {
+  auto tensorOptions = torch::TensorOptions()
+                           .dtype(torch::kUInt8)
+                           .layout(torch::kStrided)
+                           .device(device);
+  TORCH_CHECK(height > 0, "height must be > 0, got: ", height);
+  TORCH_CHECK(width > 0, "width must be > 0, got: ", width);
+  if (numFrames.has_value()) {
+    auto numFramesValue = numFrames.value();
+    TORCH_CHECK(
+        numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue);
+    return torch::empty({numFramesValue, height, width, 3}, tensorOptions);
+  } else {
+    return torch::empty({height, width, 3}, tensorOptions);
+  }
+}
+
+} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h
index aa689734..84ccc728 100644
--- a/src/torchcodec/_core/Frame.h
+++ b/src/torchcodec/_core/Frame.h
@@ -7,6 +7,7 @@
 #pragma once
 
 #include <torch/types.h>
+#include "src/torchcodec/_core/FFMPEGCommon.h"
 #include "src/torchcodec/_core/Metadata.h"
 #include "src/torchcodec/_core/StreamOptions.h"
 
@@ -44,4 +45,74 @@ struct AudioFramesOutput {
   double ptsSeconds;
 };
 
+// --------------------------------------------------------------------------
+// FRAME TENSOR ALLOCATION APIs
+// --------------------------------------------------------------------------
+
+// Note [Frame Tensor allocation and height and width]
+//
+// We always allocate [N]HWC tensors. The low-level decoding functions all
+// assume HWC tensors, since this is what FFmpeg natively handles. It's up to
+// the high-level decoding entry-points to permute that back to CHW, by calling
+// maybePermuteHWC2CHW().
+//
+// Also, importantly, the way we figure out the height and width of the
+// output frame tensor varies, and depends on the decoding entry-point. In
+// *decreasing order of accuracy*, we use the following sources for determining
+// height and width:
+// - getHeightAndWidthFromResizedAVFrame(). This is the height and width of the
+// AVFrame, *post*-resizing. This is only used for single-frame decoding APIs,
+// on CPU, with filtergraph.
+// - getHeightAndWidthFromOptionsOrAVFrame(). This is the height and width from
+// the user-specified options if they exist, or the height and width of the
+// AVFrame *before* it is resized. In theory, i.e. if there are no bugs within
+// our code or within FFmpeg code, this should be exactly the same as
+// getHeightAndWidthFromResizedAVFrame(). This is used by single-frame
+// decoding APIs, on CPU with swscale, and on GPU.
+// - getHeightAndWidthFromOptionsOrMetadata(). This is the height and width from
+// the user-specified options if they exist, or the height and width from the
+// stream metadata, which itself got its value from the CodecContext, when the
+// stream was added. This is used by batch decoding APIs, for both GPU and
+// CPU.
+//
+// The source of truth for height and width really is the (resized) AVFrame: it
+// comes from the decoded output of FFmpeg. The info from the metadata (i.e.
+// from the CodecContext) may not be as accurate. However, the AVFrame is only
+// available late in the call stack, when the frame is decoded, while the
+// CodecContext is available early when a stream is added. This is why we use
+// the CodecContext for pre-allocating batched output tensors (we could
+// pre-allocate those only once we decode the first frame to get the info from
+// the AVFrame, but that requires more complex logic).
+// +// Because the sources for height and width may disagree, we may end up with +// conflicts: e.g. if we pre-allocate a batch output tensor based on the +// metadata info, but the decoded AVFrame has a different height and width. +// it is very important to check the height and width assumptions where the +// tensors memory is used/filled in order to avoid segfaults. + +struct FrameDims { + int height; + int width; + + FrameDims(int h, int w) : height(h), width(w) {} +}; + +// There's nothing preventing you from calling this on a non-resized frame, but +// please don't. +FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame); + +FrameDims getHeightAndWidthFromOptionsOrMetadata( + const VideoStreamOptions& videoStreamOptions, + const StreamMetadata& streamMetadata); + +FrameDims getHeightAndWidthFromOptionsOrAVFrame( + const VideoStreamOptions& videoStreamOptions, + const UniqueAVFrame& avFrame); + +torch::Tensor allocateEmptyHWCTensor( + int height, + int width, + torch::Device device, + std::optional numFrames = std::nullopt); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 9c7b44a4..cafbc70e 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -14,13 +14,6 @@ #include #include "torch/types.h" -extern "C" { -#include -#include -#include -#include -} - namespace facebook::torchcodec { namespace { @@ -452,24 +445,6 @@ void SingleStreamDecoder::addVideoStream( streamMetadata.width = streamInfo.codecContext->width; streamMetadata.height = streamInfo.codecContext->height; - - // By default, we want to use swscale for color conversion because it is - // faster. However, it has width requirements, so we may need to fall back - // to filtergraph. We also need to respect what was requested from the - // options; we respect the options unconditionally, so it's possible for - // swscale's width requirements to be violated. We don't expose the ability to - // choose color conversion library publicly; we only use this ability - // internally. - int width = videoStreamOptions.width.value_or(streamInfo.codecContext->width); - - // swscale requires widths to be multiples of 32: - // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements - // so we fall back to filtergraph if the width is not a multiple of 32. - auto defaultLibrary = (width % 32 == 0) ? 
ColorConversionLibrary::SWSCALE - : ColorConversionLibrary::FILTERGRAPH; - - streamInfo.colorConversionLibrary = - videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); } void SingleStreamDecoder::addAudioStream( @@ -1173,12 +1148,10 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( formatContext_->streams[activeStreamIndex_]->time_base); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else if (!deviceInterface_) { - convertAVFrameToFrameOutputOnCPU( - avFrame, frameOutput, preAllocatedOutputTensor); } else if (deviceInterface_) { deviceInterface_->convertAVFrameToFrameOutput( streamInfo.videoStreamOptions, + streamInfo.timeBase, avFrame, frameOutput, preAllocatedOutputTensor); @@ -1186,163 +1159,6 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( return frameOutput; } -// Note [preAllocatedOutputTensor with swscale and filtergraph]: -// Callers may pass a pre-allocated tensor, where the output.data tensor will -// be stored. This parameter is honored in any case, but it only leads to a -// speed-up when swscale is used. With swscale, we can tell ffmpeg to place the -// decoded frame directly into `preAllocatedtensor.data_ptr()`. We haven't yet -// found a way to do that with filtegraph. -// TODO: Figure out whether that's possible! -// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of -// `dimension_order` parameter. It's up to callers to re-shape it if needed. -void SingleStreamDecoder::convertAVFrameToFrameOutputOnCPU( - UniqueAVFrame& avFrame, - FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor) { - auto& streamInfo = streamInfos_[activeStreamIndex_]; - - auto frameDims = getHeightAndWidthFromOptionsOrAVFrame( - streamInfo.videoStreamOptions, avFrame); - int expectedOutputHeight = frameDims.height; - int expectedOutputWidth = frameDims.width; - - if (preAllocatedOutputTensor.has_value()) { - auto shape = preAllocatedOutputTensor.value().sizes(); - TORCH_CHECK( - (shape.size() == 3) && (shape[0] == expectedOutputHeight) && - (shape[1] == expectedOutputWidth) && (shape[2] == 3), - "Expected pre-allocated tensor of shape ", - expectedOutputHeight, - "x", - expectedOutputWidth, - "x3, got ", - shape); - } - - torch::Tensor outputTensor; - // We need to compare the current frame context with our previous frame - // context. If they are different, then we need to re-create our colorspace - // conversion objects. We create our colorspace conversion objects late so - // that we don't have to depend on the unreliable metadata in the header. - // And we sometimes re-create them because it's possible for frame - // resolution to change mid-stream. Finally, we want to reuse the colorspace - // conversion objects as much as possible for performance reasons. 
- enum AVPixelFormat frameFormat = - static_cast(avFrame->format); - auto frameContext = DecodedFrameContext{ - avFrame->width, - avFrame->height, - frameFormat, - expectedOutputWidth, - expectedOutputHeight}; - - if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { - outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor( - expectedOutputHeight, expectedOutputWidth, torch::kCPU)); - - if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) { - createSwsContext(streamInfo, frameContext, avFrame->colorspace); - streamInfo.prevFrameContext = frameContext; - } - int resultHeight = - convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor); - // If this check failed, it would mean that the frame wasn't reshaped to - // the expected height. - // TODO: Can we do the same check for width? - TORCH_CHECK( - resultHeight == expectedOutputHeight, - "resultHeight != expectedOutputHeight: ", - resultHeight, - " != ", - expectedOutputHeight); - - frameOutput.data = outputTensor; - } else if ( - streamInfo.colorConversionLibrary == - ColorConversionLibrary::FILTERGRAPH) { - if (!streamInfo.filterGraphContext.filterGraph || - streamInfo.prevFrameContext != frameContext) { - createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth); - streamInfo.prevFrameContext = frameContext; - } - outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame); - - // Similarly to above, if this check fails it means the frame wasn't - // reshaped to its expected dimensions by filtergraph. - auto shape = outputTensor.sizes(); - TORCH_CHECK( - (shape.size() == 3) && (shape[0] == expectedOutputHeight) && - (shape[1] == expectedOutputWidth) && (shape[2] == 3), - "Expected output tensor of shape ", - expectedOutputHeight, - "x", - expectedOutputWidth, - "x3, got ", - shape); - - if (preAllocatedOutputTensor.has_value()) { - // We have already validated that preAllocatedOutputTensor and - // outputTensor have the same shape. 
- preAllocatedOutputTensor.value().copy_(outputTensor); - frameOutput.data = preAllocatedOutputTensor.value(); - } else { - frameOutput.data = outputTensor; - } - } else { - throw std::runtime_error( - "Invalid color conversion library: " + - std::to_string(static_cast(streamInfo.colorConversionLibrary))); - } -} - -int SingleStreamDecoder::convertAVFrameToTensorUsingSwsScale( - const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor) { - StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_]; - SwsContext* swsContext = activeStreamInfo.swsContext.get(); - uint8_t* pointers[4] = { - outputTensor.data_ptr(), nullptr, nullptr, nullptr}; - int expectedOutputWidth = outputTensor.sizes()[1]; - int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0}; - int resultHeight = sws_scale( - swsContext, - avFrame->data, - avFrame->linesize, - 0, - avFrame->height, - pointers, - linesizes); - return resultHeight; -} - -torch::Tensor SingleStreamDecoder::convertAVFrameToTensorUsingFilterGraph( - const UniqueAVFrame& avFrame) { - FilterGraphContext& filterGraphContext = - streamInfos_[activeStreamIndex_].filterGraphContext; - int status = - av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame.get()); - if (status < AVSUCCESS) { - throw std::runtime_error("Failed to add frame to buffer source context"); - } - - UniqueAVFrame filteredAVFrame(av_frame_alloc()); - status = av_buffersink_get_frame( - filterGraphContext.sinkContext, filteredAVFrame.get()); - TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24); - - auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get()); - int height = frameDims.height; - int width = frameDims.width; - std::vector shape = {height, width, 3}; - std::vector strides = {filteredAVFrame->linesize[0], 3, 1}; - AVFrame* filteredAVFramePtr = filteredAVFrame.release(); - auto deleter = [filteredAVFramePtr](void*) { - UniqueAVFrame avFrameToDelete(filteredAVFramePtr); - }; - return torch::from_blob( - filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8}); -} - void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( UniqueAVFrame& srcAVFrame, FrameOutput& frameOutput) { @@ -1462,27 +1278,6 @@ FrameBatchOutput::FrameBatchOutput( height, width, videoStreamOptions.device, numFrames); } -torch::Tensor allocateEmptyHWCTensor( - int height, - int width, - torch::Device device, - std::optional numFrames) { - auto tensorOptions = torch::TensorOptions() - .dtype(torch::kUInt8) - .layout(torch::kStrided) - .device(device); - TORCH_CHECK(height > 0, "height must be > 0, got: ", height); - TORCH_CHECK(width > 0, "width must be > 0, got: ", width); - if (numFrames.has_value()) { - auto numFramesValue = numFrames.value(); - TORCH_CHECK( - numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue); - return torch::empty({numFramesValue, height, width, 3}, tensorOptions); - } else { - return torch::empty({height, width, 3}, tensorOptions); - } -} - // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so. // The [N] leading batch-dimension is optional i.e. the input tensor can be 3D // or 4D. 
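For reference, the permutation that comment describes amounts to a zero-copy permute of the [N]HWC tensor. A rough sketch only (the helper name below is hypothetical; the real maybePermuteHWC2CHW() additionally consults the requested dimension_order option before permuting):

// Sketch: return a CHW (or NCHW) view of an HWC (or NHWC) uint8 tensor.
torch::Tensor hwcToChwView(const torch::Tensor& hwcTensor) {
  if (hwcTensor.dim() == 3) {
    return hwcTensor.permute({2, 0, 1}); // HWC -> CHW
  }
  return hwcTensor.permute({0, 3, 1, 2}); // NHWC -> NCHW
}
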
@@ -1508,176 +1303,6 @@ torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( } } -// -------------------------------------------------------------------------- -// COLOR CONVERSION UTILS AND INITIALIZERS -// -------------------------------------------------------------------------- - -bool SingleStreamDecoder::DecodedFrameContext::operator==( - const SingleStreamDecoder::DecodedFrameContext& other) { - return decodedWidth == other.decodedWidth && - decodedHeight == other.decodedHeight && - decodedFormat == other.decodedFormat && - expectedWidth == other.expectedWidth && - expectedHeight == other.expectedHeight; -} - -bool SingleStreamDecoder::DecodedFrameContext::operator!=( - const SingleStreamDecoder::DecodedFrameContext& other) { - return !(*this == other); -} - -void SingleStreamDecoder::createFilterGraph( - StreamInfo& streamInfo, - int expectedOutputHeight, - int expectedOutputWidth) { - FilterGraphContext& filterGraphContext = streamInfo.filterGraphContext; - filterGraphContext.filterGraph.reset(avfilter_graph_alloc()); - TORCH_CHECK(filterGraphContext.filterGraph.get() != nullptr); - - if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) { - filterGraphContext.filterGraph->nb_threads = - streamInfo.videoStreamOptions.ffmpegThreadCount.value(); - } - - const AVFilter* buffersrc = avfilter_get_by_name("buffer"); - const AVFilter* buffersink = avfilter_get_by_name("buffersink"); - AVCodecContext* codecContext = streamInfo.codecContext.get(); - - std::stringstream filterArgs; - filterArgs << "video_size=" << codecContext->width << "x" - << codecContext->height; - filterArgs << ":pix_fmt=" << codecContext->pix_fmt; - filterArgs << ":time_base=" << streamInfo.stream->time_base.num << "/" - << streamInfo.stream->time_base.den; - filterArgs << ":pixel_aspect=" << codecContext->sample_aspect_ratio.num << "/" - << codecContext->sample_aspect_ratio.den; - - int status = avfilter_graph_create_filter( - &filterGraphContext.sourceContext, - buffersrc, - "in", - filterArgs.str().c_str(), - nullptr, - filterGraphContext.filterGraph.get()); - if (status < 0) { - throw std::runtime_error( - std::string("Failed to create filter graph: ") + filterArgs.str() + - ": " + getFFMPEGErrorStringFromErrorCode(status)); - } - - status = avfilter_graph_create_filter( - &filterGraphContext.sinkContext, - buffersink, - "out", - nullptr, - nullptr, - filterGraphContext.filterGraph.get()); - if (status < 0) { - throw std::runtime_error( - "Failed to create filter graph: " + - getFFMPEGErrorStringFromErrorCode(status)); - } - - enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE}; - - status = av_opt_set_int_list( - filterGraphContext.sinkContext, - "pix_fmts", - pix_fmts, - AV_PIX_FMT_NONE, - AV_OPT_SEARCH_CHILDREN); - if (status < 0) { - throw std::runtime_error( - "Failed to set output pixel formats: " + - getFFMPEGErrorStringFromErrorCode(status)); - } - - UniqueAVFilterInOut outputs(avfilter_inout_alloc()); - UniqueAVFilterInOut inputs(avfilter_inout_alloc()); - - outputs->name = av_strdup("in"); - outputs->filter_ctx = filterGraphContext.sourceContext; - outputs->pad_idx = 0; - outputs->next = nullptr; - inputs->name = av_strdup("out"); - inputs->filter_ctx = filterGraphContext.sinkContext; - inputs->pad_idx = 0; - inputs->next = nullptr; - - std::stringstream description; - description << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight; - description << ":sws_flags=bilinear"; - - AVFilterInOut* outputsTmp = outputs.release(); - AVFilterInOut* inputsTmp = 
inputs.release(); - status = avfilter_graph_parse_ptr( - filterGraphContext.filterGraph.get(), - description.str().c_str(), - &inputsTmp, - &outputsTmp, - nullptr); - outputs.reset(outputsTmp); - inputs.reset(inputsTmp); - if (status < 0) { - throw std::runtime_error( - "Failed to parse filter description: " + - getFFMPEGErrorStringFromErrorCode(status)); - } - - status = avfilter_graph_config(filterGraphContext.filterGraph.get(), nullptr); - if (status < 0) { - throw std::runtime_error( - "Failed to configure filter graph: " + - getFFMPEGErrorStringFromErrorCode(status)); - } -} - -void SingleStreamDecoder::createSwsContext( - StreamInfo& streamInfo, - const DecodedFrameContext& frameContext, - const enum AVColorSpace colorspace) { - SwsContext* swsContext = sws_getContext( - frameContext.decodedWidth, - frameContext.decodedHeight, - frameContext.decodedFormat, - frameContext.expectedWidth, - frameContext.expectedHeight, - AV_PIX_FMT_RGB24, - SWS_BILINEAR, - nullptr, - nullptr, - nullptr); - TORCH_CHECK(swsContext, "sws_getContext() returned nullptr"); - - int* invTable = nullptr; - int* table = nullptr; - int srcRange, dstRange, brightness, contrast, saturation; - int ret = sws_getColorspaceDetails( - swsContext, - &invTable, - &srcRange, - &table, - &dstRange, - &brightness, - &contrast, - &saturation); - TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1"); - - const int* colorspaceTable = sws_getCoefficients(colorspace); - ret = sws_setColorspaceDetails( - swsContext, - colorspaceTable, - srcRange, - colorspaceTable, - dstRange, - brightness, - contrast, - saturation); - TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1"); - - streamInfo.swsContext.reset(swsContext); -} - // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index d8515111..cbacb847 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -195,22 +195,6 @@ class SingleStreamDecoder { bool isKeyFrame = false; }; - struct FilterGraphContext { - UniqueAVFilterGraph filterGraph; - AVFilterContext* sourceContext = nullptr; - AVFilterContext* sinkContext = nullptr; - }; - - struct DecodedFrameContext { - int decodedWidth; - int decodedHeight; - AVPixelFormat decodedFormat; - int expectedWidth; - int expectedHeight; - bool operator==(const DecodedFrameContext&); - bool operator!=(const DecodedFrameContext&); - }; - struct StreamInfo { int streamIndex = -1; AVStream* stream = nullptr; @@ -234,14 +218,7 @@ class SingleStreamDecoder { // color-conversion fields. Only one of FilterGraphContext and // UniqueSwsContext should be non-null. - FilterGraphContext filterGraphContext; - ColorConversionLibrary colorConversionLibrary = FILTERGRAPH; - UniqueSwsContext swsContext; UniqueSwrContext swrContext; - - // Used to know whether a new FilterGraphContext or UniqueSwsContext should - // be created before decoding a new frame. 
- DecodedFrameContext prevFrameContext; }; // -------------------------------------------------------------------------- @@ -289,20 +266,6 @@ class SingleStreamDecoder { std::optional maybeFlushSwrBuffers(); - // -------------------------------------------------------------------------- - // COLOR CONVERSION LIBRARIES HANDLERS CREATION - // -------------------------------------------------------------------------- - - void createFilterGraph( - StreamInfo& streamInfo, - int expectedOutputHeight, - int expectedOutputWidth); - - void createSwsContext( - StreamInfo& streamInfo, - const DecodedFrameContext& frameContext, - const enum AVColorSpace colorspace); - // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- @@ -382,76 +345,6 @@ class SingleStreamDecoder { bool initialized_ = false; }; -// -------------------------------------------------------------------------- -// FRAME TENSOR ALLOCATION APIs -// -------------------------------------------------------------------------- - -// Note [Frame Tensor allocation and height and width] -// -// We always allocate [N]HWC tensors. The low-level decoding functions all -// assume HWC tensors, since this is what FFmpeg natively handles. It's up to -// the high-level decoding entry-points to permute that back to CHW, by calling -// maybePermuteHWC2CHW(). -// -// Also, importantly, the way we figure out the the height and width of the -// output frame tensor varies, and depends on the decoding entry-point. In -// *decreasing order of accuracy*, we use the following sources for determining -// height and width: -// - getHeightAndWidthFromResizedAVFrame(). This is the height and width of the -// AVframe, *post*-resizing. This is only used for single-frame decoding APIs, -// on CPU, with filtergraph. -// - getHeightAndWidthFromOptionsOrAVFrame(). This is the height and width from -// the user-specified options if they exist, or the height and width of the -// AVFrame *before* it is resized. In theory, i.e. if there are no bugs within -// our code or within FFmpeg code, this should be exactly the same as -// getHeightAndWidthFromResizedAVFrame(). This is used by single-frame -// decoding APIs, on CPU with swscale, and on GPU. -// - getHeightAndWidthFromOptionsOrMetadata(). This is the height and width from -// the user-specified options if they exist, or the height and width form the -// stream metadata, which itself got its value from the CodecContext, when the -// stream was added. This is used by batch decoding APIs, for both GPU and -// CPU. -// -// The source of truth for height and width really is the (resized) AVFrame: it -// comes from the decoded ouptut of FFmpeg. The info from the metadata (i.e. -// from the CodecContext) may not be as accurate. However, the AVFrame is only -// available late in the call stack, when the frame is decoded, while the -// CodecContext is available early when a stream is added. This is why we use -// the CodecContext for pre-allocating batched output tensors (we could -// pre-allocate those only once we decode the first frame to get the info frame -// the AVFrame, but that's a more complex logic). -// -// Because the sources for height and width may disagree, we may end up with -// conflicts: e.g. if we pre-allocate a batch output tensor based on the -// metadata info, but the decoded AVFrame has a different height and width. 
-// it is very important to check the height and width assumptions where the -// tensors memory is used/filled in order to avoid segfaults. - -struct FrameDims { - int height; - int width; - - FrameDims(int h, int w) : height(h), width(w) {} -}; - -// There's nothing preventing you from calling this on a non-resized frame, but -// please don't. -FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame); - -FrameDims getHeightAndWidthFromOptionsOrMetadata( - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata); - -FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const VideoStreamOptions& videoStreamOptions, - const UniqueAVFrame& avFrame); - -torch::Tensor allocateEmptyHWCTensor( - int height, - int width, - torch::Device device, - std::optional numFrames = std::nullopt); - // Prints the SingleStreamDecoder::DecodeStats to the ostream. std::ostream& operator<<( std::ostream& os,
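
How the pieces fit together after this change (a rough sketch, not code from the patch; videoStreamOptions, timeBase, avFrame, and preAllocatedOutputTensor stand in for the caller's locals): the decoder no longer special-cases CPU, it obtains an interface from the registry and delegates color conversion to it, exactly as it already did for CUDA.

// Sketch of the call path enabled by this patch (error handling omitted).
// CpuDeviceInterface registers itself via registerDeviceInterface(torch::kCPU, ...),
// so the generic factory now covers CPU as well.
std::unique_ptr<DeviceInterface> deviceInterface =
    createDeviceInterface(torch::Device("cpu"));

FrameOutput frameOutput;
deviceInterface->convertAVFrameToFrameOutput(
    videoStreamOptions,        // per-stream options (may request width/height)
    timeBase,                  // the stream's AVRational time base (new parameter)
    avFrame,                   // freshly decoded UniqueAVFrame
    frameOutput,               // frameOutput.data is filled with an HWC uint8 tensor
    preAllocatedOutputTensor); // std::optional<torch::Tensor>; honored if provided

Internally the CPU implementation still picks swscale when the expected output width is a multiple of 32 and falls back to filtergraph otherwise, unless the options force a specific color conversion library.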