Skip to content

Commit e206e22

Browse files
committed
Support direct AVFrame conversion to tensor in CPU device interface
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 13590bb commit e206e22

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,21 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
8383
enum AVPixelFormat frameFormat =
8484
static_cast<enum AVPixelFormat>(avFrame->format);
8585

86+
if (frameFormat == AV_PIX_FMT_RGB24 &&
87+
avFrame->width == expectedOutputWidth &&
88+
avFrame->height == expectedOutputHeight) {
89+
outputTensor = toTensor(avFrame);
90+
if (preAllocatedOutputTensor.has_value()) {
91+
// We have already validated that preAllocatedOutputTensor and
92+
// outputTensor have the same shape.
93+
preAllocatedOutputTensor.value().copy_(outputTensor);
94+
frameOutput.data = preAllocatedOutputTensor.value();
95+
} else {
96+
frameOutput.data = outputTensor;
97+
}
98+
return;
99+
}
100+
86101
// By default, we want to use swscale for color conversion because it is
87102
// faster. However, it has width requirements, so we may need to fall back
88103
// to filtergraph. We also need to respect what was requested from the
@@ -159,7 +174,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
159174
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
160175
prevFiltersContext_ = std::move(filtersContext);
161176
}
162-
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
177+
outputTensor = toTensor(filterGraphContext_->convert(avFrame));
163178

164179
// Similarly to above, if this check fails it means the frame wasn't
165180
// reshaped to its expected dimensions by filtergraph.
@@ -208,23 +223,20 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
208223
return resultHeight;
209224
}
210225

211-
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
212-
const UniqueAVFrame& avFrame) {
213-
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
214-
215-
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
226+
torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) {
227+
TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
216228

217-
auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
229+
auto frameDims = getHeightAndWidthFromResizedAVFrame(*avFrame.get());
218230
int height = frameDims.height;
219231
int width = frameDims.width;
220232
std::vector<int64_t> shape = {height, width, 3};
221-
std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
222-
AVFrame* filteredAVFramePtr = filteredAVFrame.release();
223-
auto deleter = [filteredAVFramePtr](void*) {
224-
UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
233+
std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
234+
AVFrame* avFrameClone = av_frame_clone(avFrame.get());
235+
auto deleter = [avFrameClone](void*) {
236+
UniqueAVFrame avFrameToDelete(avFrameClone);
225237
};
226238
return torch::from_blob(
227-
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
239+
avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
228240
}
229241

230242
void CpuDeviceInterface::createSwsContext(

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class CpuDeviceInterface : public DeviceInterface {
3939
const UniqueAVFrame& avFrame,
4040
torch::Tensor& outputTensor);
4141

42-
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
43-
const UniqueAVFrame& avFrame);
42+
torch::Tensor toTensor(const UniqueAVFrame& avFrame);
4443

4544
struct SwsFrameContext {
4645
int inputWidth = 0;

0 commit comments

Comments (0)