@@ -13,6 +13,36 @@ static bool g_cpu = registerDeviceInterface(
     torch::kCPU,
     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
 
+ColorConversionLibrary getColorConversionLibrary(
+    const VideoStreamOptions& videoStreamOptions,
+    int width) {
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability
+  // to choose the color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (width % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
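+  // Defensive check: fail loudly if a new enum value ever reaches this point.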
+  TORCH_CHECK(
+      colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
+          colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
+      "Invalid color conversion library: ",
+      static_cast<int>(colorConversionLibrary));
+  return colorConversionLibrary;
+}
+
 } // namespace
 
 bool CpuDeviceInterface::SwsFrameContext::operator==(
@@ -34,6 +64,46 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
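+  // The output dimensions come from the options when they are set, and from
+  // the frame itself otherwise.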
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
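+  // Swscale performs the color conversion itself, so that path needs no
+  // filtergraph.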
+  if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) ==
+      ColorConversionLibrary::SWSCALE) {
+    return nullptr;
+  }
+
+  std::unique_ptr<FiltersContext> filtersContext =
+      std::make_unique<FiltersContext>();
+
+  filtersContext->inputWidth = avFrame->width;
+  filtersContext->inputHeight = avFrame->height;
+  filtersContext->inputFormat = frameFormat;
+  filtersContext->inputAspectRatio = avFrame->sample_aspect_ratio;
+  filtersContext->outputWidth = expectedOutputWidth;
+  filtersContext->outputHeight = expectedOutputHeight;
+  filtersContext->outputFormat = AV_PIX_FMT_RGB24;
+  filtersContext->timeBase = timeBase;
+
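+  // Build a filtergraph description like "scale=W:H:sws_flags=bilinear".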
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  filtersContext->filtergraphStr = filters.str();
+  return filtersContext;
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
@@ -45,7 +115,7 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void CpuDeviceInterface::convertAVFrameToFrameOutput(
     const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase,
+    [[maybe_unused]] const AVRational& timeBase,
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -71,23 +141,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
   ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+      getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
     // We need to compare the current frame context with our previous frame
@@ -126,44 +181,23 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    // See comment above in swscale branch about the filterGraphContext_
-    // creation. creation
-    FiltersContext filtersContext;
-
-    filtersContext.inputWidth = avFrame->width;
-    filtersContext.inputHeight = avFrame->height;
-    filtersContext.inputFormat = frameFormat;
-    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-    filtersContext.outputWidth = expectedOutputWidth;
-    filtersContext.outputHeight = expectedOutputHeight;
-    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-    filtersContext.timeBase = timeBase;
-
-    std::stringstream filters;
-    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-    filters << ":sws_flags=bilinear";
-
-    filtersContext.filtergraphStr = filters.str();
-
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
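+    // The caller is expected to have already run the frame through the
+    // filtergraph built by initializeFiltersContext, so by now it should
+    // be RGB24; the check below enforces that.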
+    TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-    // Similarly to above, if this check fails it means the frame wasn't
-    // reshaped to its expected dimensions by filtergraph.
-    auto shape = outputTensor.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-        "Expected output tensor of shape ",
-        expectedOutputHeight,
-        "x",
-        expectedOutputWidth,
-        "x3, got ",
-        shape);
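+    // Wrap the frame's data in a tensor without copying. The row stride
+    // comes from linesize[0] to account for any padding FFmpeg may add.
+    // The deleter takes ownership of the released AVFrame and frees it
+    // once the tensor's storage goes away.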
+    std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
+    std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+    AVFrame* avFramePtr = avFrame.release();
+    auto deleter = [avFramePtr](void*) {
+      UniqueAVFrame avFrameToDelete(avFramePtr);
+    };
+    outputTensor = torch::from_blob(
+        avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
@@ -173,11 +207,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     } else {
       frameOutput.data = outputTensor;
     }
-  } else {
-    TORCH_CHECK(
-        false,
-        "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
   }
 }
 
@@ -199,25 +228,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
-
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
-  };
-  return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-}
-
 void CpuDeviceInterface::createSwsContext(
     const SwsFrameContext& swsFrameContext,
     const enum AVColorSpace colorspace) {