
Commit 428db54

Implement initializeFiltersContext for CPU device interface
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 051ecc3 commit 428db54

2 files changed, +82 -86 lines changed


src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 75 additions & 78 deletions
@@ -13,6 +13,35 @@ static bool g_cpu = registerDeviceInterface(
     torch::kCPU,
     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
 
+ColorConversionLibrary getColorConversionLibrary(
+    const VideoStreamOptions& videoStreamOptions,
+    int width) {
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability to
+  // choose color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (width % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  TORCH_CHECK(
+      colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
+          colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
+      "Invalid color conversion library: ",
+      static_cast<int>(colorConversionLibrary));
+  return colorConversionLibrary;
+}
+
 } // namespace
 
 bool CpuDeviceInterface::SwsFrameContext::operator==(
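
The width rule comes from swscale's alignment requirements (see the Stack Overflow thread linked in the comment): the helper defaults to swscale only when the requested output width is a multiple of 32, and falls back to filtergraph otherwise. A minimal standalone sketch of that default rule (pickDefaultLibrary is a hypothetical stand-in; the real helper additionally honors an explicit choice from VideoStreamOptions):

    #include <iostream>

    // Hypothetical stand-in for the default-selection rule in the hunk above.
    const char* pickDefaultLibrary(int width) {
      return (width % 32 == 0) ? "SWSCALE" : "FILTERGRAPH";
    }

    int main() {
      // 1920 and 640 are multiples of 32; 854 and 1918 are not.
      for (int width : {1920, 640, 854, 1918}) {
        std::cout << width << " -> " << pickDefaultLibrary(width) << "\n";
      }
      return 0;
    }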
@@ -34,6 +63,41 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) == ColorConversionLibrary::SWSCALE) {
+    return nullptr;
+  }
+
+  std::unique_ptr<FiltersContext> filtersContext =
+      std::make_unique<FiltersContext>();
+
+  filtersContext->inputWidth = avFrame->width;
+  filtersContext->inputHeight = avFrame->height;
+  filtersContext->inputFormat = frameFormat;
+  filtersContext->inputAspectRatio = avFrame->sample_aspect_ratio;
+  filtersContext->outputWidth = expectedOutputWidth;
+  filtersContext->outputHeight = expectedOutputHeight;
+  filtersContext->outputFormat = AV_PIX_FMT_RGB24;
+  filtersContext->timeBase = timeBase;
+
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  filtersContext->filtergraphStr = filters.str();
+  return filtersContext;
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
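
For concreteness, if the expected output size were 1280x720 (an illustrative value), the stringstream above would produce the filtergraph description

    scale=1280:720:sws_flags=bilinear

that is, a single FFmpeg scale filter using bilinear interpolation; the conversion to AV_PIX_FMT_RGB24 is requested separately through filtersContext->outputFormat.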
@@ -45,7 +109,7 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void CpuDeviceInterface::convertAVFrameToFrameOutput(
     const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase,
+    [[maybe_unused]] const AVRational& timeBase,
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -71,23 +135,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
   ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+      getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
     // We need to compare the current frame context with our previous frame
@@ -126,44 +175,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    // See comment above in swscale branch about the filterGraphContext_
-    // creation. creation
-    FiltersContext filtersContext;
-
-    filtersContext.inputWidth = avFrame->width;
-    filtersContext.inputHeight = avFrame->height;
-    filtersContext.inputFormat = frameFormat;
-    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-    filtersContext.outputWidth = expectedOutputWidth;
-    filtersContext.outputHeight = expectedOutputHeight;
-    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-    filtersContext.timeBase = timeBase;
-
-    std::stringstream filters;
-    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-    filters << ":sws_flags=bilinear";
-
-    filtersContext.filtergraphStr = filters.str();
-
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+    TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-    // Similarly to above, if this check fails it means the frame wasn't
-    // reshaped to its expected dimensions by filtergraph.
-    auto shape = outputTensor.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-        "Expected output tensor of shape ",
-        expectedOutputHeight,
-        "x",
-        expectedOutputWidth,
-        "x3, got ",
-        shape);
+    std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
+    std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+    AVFrame* avFramePtr = avFrame.release();
+    auto deleter = [avFramePtr](void*) {
+      UniqueAVFrame avFrameToDelete(avFramePtr);
+    };
+    outputTensor = torch::from_blob(
+        avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
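
The replacement code in the filtergraph branch wraps the frame's pixels in a tensor via torch::from_blob instead of copying them. Two details carry the weight: the row stride is avFrame->linesize[0], which may exceed width * 3 because FFmpeg pads rows for alignment, and the custom deleter takes ownership of the released AVFrame so it is freed only when the tensor's storage is destroyed. A minimal sketch of the same zero-copy ownership pattern, assuming an AVFrame* named rawFrame that already holds packed RGB24 data:

    #include <torch/torch.h>
    #include <vector>
    extern "C" {
    #include <libavutil/frame.h>
    }

    torch::Tensor wrapFrameNoCopy(AVFrame* rawFrame) {
      std::vector<int64_t> shape = {rawFrame->height, rawFrame->width, 3};
      // linesize[0] is the byte stride between rows; it may include padding.
      std::vector<int64_t> strides = {rawFrame->linesize[0], 3, 1};
      // The deleter owns the frame: it is freed when the tensor dies.
      auto deleter = [rawFrame](void*) {
        AVFrame* toFree = rawFrame;
        av_frame_free(&toFree);
      };
      return torch::from_blob(
          rawFrame->data[0], shape, strides, deleter, torch::kUInt8);
    }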
@@ -173,11 +194,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     } else {
       frameOutput.data = outputTensor;
     }
-  } else {
-    TORCH_CHECK(
-        false,
-        "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
   }
 }
 
@@ -199,25 +215,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
-
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
-  };
-  return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-}
-
 void CpuDeviceInterface::createSwsContext(
     const SwsFrameContext& swsFrameContext,
     const enum AVColorSpace colorspace) {

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 7 additions & 8 deletions
@@ -26,6 +26,11 @@ class CpuDeviceInterface : public DeviceInterface {
   void initializeContext(
       [[maybe_unused]] AVCodecContext* codecContext) override {}
 
+  std::unique_ptr<FiltersContext> initializeFiltersContext(
+      const VideoStreamOptions& videoStreamOptions,
+      const UniqueAVFrame& avFrame,
+      const AVRational& timeBase) override;
+
   void convertAVFrameToFrameOutput(
       const VideoStreamOptions& videoStreamOptions,
       const AVRational& timeBase,
@@ -39,9 +44,6 @@ class CpuDeviceInterface : public DeviceInterface {
       const UniqueAVFrame& avFrame,
       torch::Tensor& outputTensor);
 
-  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
-      const UniqueAVFrame& avFrame);
-
   struct SwsFrameContext {
     int inputWidth;
     int inputHeight;
@@ -56,15 +58,12 @@ class CpuDeviceInterface : public DeviceInterface {
       const SwsFrameContext& swsFrameContext,
       const enum AVColorSpace colorspace);
 
-  // color-conversion fields. Only one of FilterGraphContext and
-  // UniqueSwsContext should be non-null.
-  std::unique_ptr<FilterGraph> filterGraphContext_;
+  // SWS color conversion context
   UniqueSwsContext swsContext_;
 
-  // Used to know whether a new FilterGraphContext or UniqueSwsContext should
+  // Used to know whether a new UniqueSwsContext should
   // be created before decoding a new frame.
   SwsFrameContext prevSwsFrameContext_;
-  FiltersContext prevFiltersContext_;
 };
 
 } // namespace facebook::torchcodec
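
Read together, the two files move filtergraph ownership out of CpuDeviceInterface: the interface now only describes the conversion via initializeFiltersContext (returning nullptr when swscale will handle it), and convertAVFrameToFrameOutput's filtergraph branch expects a frame that has already been converted to RGB24. A hypothetical sketch of the caller-side sequence this implies (inferred from the diff, not an actual API excerpt):

    // Hypothetical driver-side flow implied by this commit.
    std::unique_ptr<FiltersContext> ctx =
        deviceInterface->initializeFiltersContext(
            videoStreamOptions, avFrame, timeBase);
    if (ctx) {
      // Filtergraph path: resize and convert to RGB24 before handing off.
      FilterGraph graph(*ctx, videoStreamOptions);
      avFrame = graph.convert(avFrame);
    }
    // ctx == nullptr means the swscale path; conversion happens inside.
    deviceInterface->convertAVFrameToFrameOutput(
        videoStreamOptions, timeBase, avFrame, frameOutput, std::nullopt);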
