Skip to content

Commit f298fa7

Browse files
committed
Support direct AVFrame conversion to tensor in CPU device interface
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 652d2a2 commit f298fa7

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,24 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
8383
enum AVPixelFormat frameFormat =
8484
static_cast<enum AVPixelFormat>(avFrame->format);
8585

86+
// This is an early-return optimization: if the format is already what we
87+
// need, and the dimensions are also what we need, we don't need to call
88+
// swscale or filtergraph. We can just convert the AVFrame to a tensor.
89+
if (frameFormat == AV_PIX_FMT_RGB24 &&
90+
avFrame->width == expectedOutputWidth &&
91+
avFrame->height == expectedOutputHeight) {
92+
outputTensor = toTensor(avFrame);
93+
if (preAllocatedOutputTensor.has_value()) {
94+
// We have already validated that preAllocatedOutputTensor and
95+
// outputTensor have the same shape.
96+
preAllocatedOutputTensor.value().copy_(outputTensor);
97+
frameOutput.data = preAllocatedOutputTensor.value();
98+
} else {
99+
frameOutput.data = outputTensor;
100+
}
101+
return;
102+
}
103+
86104
// By default, we want to use swscale for color conversion because it is
87105
// faster. However, it has width requirements, so we may need to fall back
88106
// to filtergraph. We also need to respect what was requested from the
@@ -159,7 +177,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
159177
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
160178
prevFiltersContext_ = std::move(filtersContext);
161179
}
162-
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
180+
outputTensor = toTensor(filterGraphContext_->convert(avFrame));
163181

164182
// Similarly to above, if this check fails it means the frame wasn't
165183
// reshaped to its expected dimensions by filtergraph.
@@ -208,23 +226,20 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
208226
return resultHeight;
209227
}
210228

211-
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
212-
const UniqueAVFrame& avFrame) {
213-
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
214-
215-
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
229+
torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) {
230+
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
216231

217-
auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
232+
auto frameDims = getHeightAndWidthFromResizedAVFrame(*avFrame.get());
218233
int height = frameDims.height;
219234
int width = frameDims.width;
220235
std::vector<int64_t> shape = {height, width, 3};
221-
std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
222-
AVFrame* filteredAVFramePtr = filteredAVFrame.release();
223-
auto deleter = [filteredAVFramePtr](void*) {
224-
UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
236+
std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
237+
AVFrame* avFrameClone = av_frame_clone(avFrame.get());
238+
auto deleter = [avFrameClone](void*) {
239+
UniqueAVFrame avFrameToDelete(avFrameClone);
225240
};
226241
return torch::from_blob(
227-
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
242+
avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
228243
}
229244

230245
void CpuDeviceInterface::createSwsContext(

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class CpuDeviceInterface : public DeviceInterface {
3939
const UniqueAVFrame& avFrame,
4040
torch::Tensor& outputTensor);
4141

42-
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
43-
const UniqueAVFrame& avFrame);
42+
torch::Tensor toTensor(const UniqueAVFrame& avFrame);
4443

4544
struct SwsFrameContext {
4645
int inputWidth = 0;

0 commit comments

Comments
 (0)