@@ -83,6 +83,21 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
+  if (frameFormat == AV_PIX_FMT_RGB24 &&
+      avFrame->width == expectedOutputWidth &&
+      avFrame->height == expectedOutputHeight) {
+    outputTensor = toTensor(avFrame);
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      frameOutput.data = preAllocatedOutputTensor.value();
+    } else {
+      frameOutput.data = outputTensor;
+    }
+    return;
+  }
+
   // By default, we want to use swscale for color conversion because it is
   // faster. However, it has width requirements, so we may need to fall back
   // to filtergraph. We also need to respect what was requested from the
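As an aside, here is a self-contained sketch of what the pre-allocated-output branch above does (standalone file with hypothetical shapes; not part of the library): copy_() fills the caller-provided storage in place, whereas plain assignment would only rebind the local variable and leave the caller's buffer untouched.

// Sketch only: demonstrates the copy_() semantics used by the fast path
// above. Shapes are hypothetical; in the real code they were validated
// earlier against the pre-allocated tensor.
#include <torch/torch.h>

int main() {
  torch::Tensor decoded = torch::randint(0, 256, {4, 4, 3}, torch::kUInt8);
  torch::Tensor preAllocated = torch::empty({4, 4, 3}, torch::kUInt8);

  void* callerBuffer = preAllocated.data_ptr();
  preAllocated.copy_(decoded);  // writes into the existing storage
  // Same buffer as before, now holding the decoded values.
  TORCH_CHECK(preAllocated.data_ptr() == callerBuffer);
  return 0;
}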
@@ -159,7 +174,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
         std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
     prevFiltersContext_ = std::move(filtersContext);
   }
-  outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+  outputTensor = toTensor(filterGraphContext_->convert(avFrame));
 
   // Similarly to above, if this check fails it means the frame wasn't
   // reshaped to its expected dimensions by filtergraph.
@@ -208,23 +223,20 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
+torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) {
+  TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*avFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
   std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
+  std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+  AVFrame* avFrameClone = av_frame_clone(avFrame.get());
+  auto deleter = [avFrameClone](void*) {
+    UniqueAVFrame avFrameToDelete(avFrameClone);
   };
   return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
+      avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
 }
 
 void CpuDeviceInterface::createSwsContext(
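For reference, a self-contained sketch of the zero-copy lifetime pattern that the new toTensor() relies on (rgb24FrameToTensor is a hypothetical name, not the library's API): av_frame_clone() creates a second AVFrame that refcounts the same pixel buffers, and the tensor's deleter frees that clone, so the data stays valid exactly as long as the tensor does, without copying any pixels.

// Standalone sketch of the zero-copy pattern above; names are hypothetical.
#include <torch/torch.h>
extern "C" {
#include <libavutil/frame.h>
}

torch::Tensor rgb24FrameToTensor(const AVFrame* frame) {
  AVFrame* clone = av_frame_clone(frame);  // bumps the buffer refcount
  auto deleter = [clone](void*) {
    AVFrame* toFree = clone;
    av_frame_free(&toFree);  // drops the extra reference
  };
  // HWC layout; linesize[0] can exceed width * 3 due to row padding, which
  // is why the row stride comes from the frame rather than from the shape.
  return torch::from_blob(
      clone->data[0],
      {clone->height, clone->width, 3},
      {clone->linesize[0], 3, 1},
      deleter,
      torch::kUInt8);
}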