Commit 6d72f11

BETA CUDA interface: support for approximate mode and time-based APIs (#917)
1 parent 401901e commit 6d72f11

6 files changed: +207 additions, -148 deletions

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 30 additions & 89 deletions
@@ -35,16 +35,20 @@ static bool g_cuda_beta = registerDeviceInterface(
 
 static int CUDAAPI
 pfnSequenceCallback(void* pUserData, CUVIDEOFORMAT* videoFormat) {
-  BetaCudaDeviceInterface* decoder =
-      static_cast<BetaCudaDeviceInterface*>(pUserData);
+  auto decoder = static_cast<BetaCudaDeviceInterface*>(pUserData);
   return decoder->streamPropertyChange(videoFormat);
 }
 
 static int CUDAAPI
-pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* pPicParams) {
-  BetaCudaDeviceInterface* decoder =
-      static_cast<BetaCudaDeviceInterface*>(pUserData);
-  return decoder->frameReadyForDecoding(pPicParams);
+pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* picParams) {
+  auto decoder = static_cast<BetaCudaDeviceInterface*>(pUserData);
+  return decoder->frameReadyForDecoding(picParams);
+}
+
+static int CUDAAPI
+pfnDisplayPictureCallback(void* pUserData, CUVIDPARSERDISPINFO* dispInfo) {
+  auto decoder = static_cast<BetaCudaDeviceInterface*>(pUserData);
+  return decoder->frameReadyInDisplayOrder(dispInfo);
 }
 
 static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {

@@ -142,7 +146,7 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
 
 BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
   // TODONVDEC P0: we probably need to free the frames that have been decoded by
-  // NVDEC but not yet "mapped" - i.e. those that are still in frameBuffer_?
+  // NVDEC but not yet "mapped" - i.e. those that are still in readyFrames_?
 
   if (decoder_) {
     NVDECCache::getCache(device_.index())

@@ -218,7 +222,7 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
   parserParams.pUserData = this;
   parserParams.pfnSequenceCallback = pfnSequenceCallback;
   parserParams.pfnDecodePicture = pfnDecodePictureCallback;
-  parserParams.pfnDisplayPicture = nullptr;
+  parserParams.pfnDisplayPicture = pfnDisplayPictureCallback;
 
   CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams);
   TORCH_CHECK(

@@ -274,10 +278,6 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) {
     cuvidPacket.flags = CUVID_PKT_TIMESTAMP;
     cuvidPacket.timestamp = packet->pts;
 
-    // Like DALI: store packet PTS in queue to later assign to frames as they
-    // come out
-    packetsPtsQueue.push(packet->pts);
-
   } else {
     // End of stream packet
    cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM;

@@ -329,70 +329,38 @@ void BetaCudaDeviceInterface::applyBSF(ReferenceAVPacket& packet) {
 // ready to be decoded, i.e. the parser received all the necessary packets for a
 // given frame. It means we can send that frame to be decoded by the hardware
 // NVDEC decoder by calling cuvidDecodePicture which is non-blocking.
-int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* pPicParams) {
+int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* picParams) {
   if (isFlushing_) {
     return 0;
   }
 
-  TORCH_CHECK(pPicParams != nullptr, "Invalid picture parameters");
+  TORCH_CHECK(picParams != nullptr, "Invalid picture parameters");
   TORCH_CHECK(decoder_, "Decoder not initialized before picture decode");
 
   // Send frame to be decoded by NVDEC - non-blocking call.
-  CUresult result = cuvidDecodePicture(*decoder_.get(), pPicParams);
-  if (result != CUDA_SUCCESS) {
-    return 0; // Yes, you're reading that right, 0 mean error.
-  }
+  CUresult result = cuvidDecodePicture(*decoder_.get(), picParams);
 
-  // The frame was sent to be decoded on the NVDEC hardware. Now we store some
-  // relevant info into our frame buffer so that we can retrieve the decoded
-  // frame later when receiveFrame() is called.
-  // Importantly we need to 'guess' the PTS of that frame. The heuristic we use
-  // (like in DALI) is that the frames are ready to be decoded in the same order
-  // as the packets were sent to the parser. So we assign the PTS of the frame
-  // by popping the PTS of the oldest packet in our packetsPtsQueue (note:
-  // oldest doesn't necessarily mean lowest PTS!).
+  // Yes, you're reading that right, 0 means error, 1 means success
+  return (result == CUDA_SUCCESS);
+}
 
-  TORCH_CHECK(
-      // TODONVDEC P0 the queue may be empty, handle that.
-      !packetsPtsQueue.empty(),
-      "PTS queue is empty when decoding a frame");
-  int64_t guessedPts = packetsPtsQueue.front();
-  packetsPtsQueue.pop();
-
-  // Field values taken from DALI
-  CUVIDPARSERDISPINFO dispInfo = {};
-  dispInfo.picture_index = pPicParams->CurrPicIdx;
-  dispInfo.progressive_frame = !pPicParams->field_pic_flag;
-  dispInfo.top_field_first = pPicParams->bottom_field_flag ^ 1;
-  dispInfo.repeat_first_field = 0;
-  dispInfo.timestamp = guessedPts;
-
-  FrameBuffer::Slot* slot = frameBuffer_.findEmptySlot();
-  slot->dispInfo = dispInfo;
-  slot->guessedPts = guessedPts;
-  slot->occupied = true;
-
-  return 1;
+int BetaCudaDeviceInterface::frameReadyInDisplayOrder(
+    CUVIDPARSERDISPINFO* dispInfo) {
+  readyFrames_.push(*dispInfo);
+  return 1; // success
 }
 
-// Moral equivalent of avcodec_receive_frame(). Here, we look for a decoded
-// frame with the exact desired PTS in our frame buffer. This logic is only
-// valid in exact seek_mode, for now.
-int BetaCudaDeviceInterface::receiveFrame(
-    UniqueAVFrame& avFrame,
-    int64_t desiredPts) {
-  FrameBuffer::Slot* slot = frameBuffer_.findFrameWithExactPts(desiredPts);
-  if (slot == nullptr) {
+// Moral equivalent of avcodec_receive_frame().
+int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
+  if (readyFrames_.empty()) {
     // No frame found, instruct caller to try again later after sending more
     // packets.
     return AVERROR(EAGAIN);
   }
-
-  slot->occupied = false;
-  slot->guessedPts = -1;
+  CUVIDPARSERDISPINFO dispInfo = readyFrames_.front();
+  readyFrames_.pop();
 
   CUVIDPROCPARAMS procParams = {};
-  CUVIDPARSERDISPINFO dispInfo = slot->dispInfo;
   procParams.progressive_frame = dispInfo.progressive_frame;
   procParams.top_field_first = dispInfo.top_field_first;
   procParams.unpaired_field = dispInfo.repeat_first_field < 0;

@@ -452,7 +420,7 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
   avFrame->width = width;
   avFrame->height = height;
   avFrame->format = AV_PIX_FMT_CUDA;
-  avFrame->pts = dispInfo.timestamp; // == guessedPts
+  avFrame->pts = dispInfo.timestamp;
 
   // TODONVDEC P0: Zero division error!!!
   // TODONVDEC P0: Move AVRational arithmetic to FFMPEGCommon, and put the

@@ -518,13 +486,8 @@ void BetaCudaDeviceInterface::flush() {
 
   isFlushing_ = false;
 
-  for (auto& slot : frameBuffer_) {
-    slot.occupied = false;
-    slot.guessedPts = -1;
-  }
-
-  std::queue<int64_t> empty;
-  packetsPtsQueue.swap(empty);
+  std::queue<CUVIDPARSERDISPINFO> emptyQueue;
+  std::swap(readyFrames_, emptyQueue);
 
   eofSent_ = false;
 }

@@ -544,26 +507,4 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
       avFrame, frameOutput, preAllocatedOutputTensor);
 }
 
-BetaCudaDeviceInterface::FrameBuffer::Slot*
-BetaCudaDeviceInterface::FrameBuffer::findEmptySlot() {
-  for (auto& slot : frameBuffer_) {
-    if (!slot.occupied) {
-      return &slot;
-    }
-  }
-  frameBuffer_.emplace_back();
-  return &frameBuffer_.back();
-}
-
-BetaCudaDeviceInterface::FrameBuffer::Slot*
-BetaCudaDeviceInterface::FrameBuffer::findFrameWithExactPts(
-    int64_t desiredPts) {
-  for (auto& slot : frameBuffer_) {
-    if (slot.occupied && slot.guessedPts == desiredPts) {
-      return &slot;
-    }
-  }
-  return nullptr;
-}
-
 } // namespace facebook::torchcodec
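
To make the new control flow easier to follow outside of the full diff, here is a minimal, self-contained sketch of the display-order FIFO pattern the interface now relies on. It is an illustration only: DisplayInfo, onFrameReadyInDisplayOrder, tryReceiveFrame, and DisplayOrderQueue are simplified stand-ins for CUVIDPARSERDISPINFO, the display callback, receiveFrame(), and the decoder class, not the actual torchcodec API.

#include <cstdint>
#include <optional>
#include <queue>

// Simplified stand-in for CUVIDPARSERDISPINFO: just what the sketch needs.
struct DisplayInfo {
  int pictureIndex = 0;
  int64_t pts = 0;
};

// Minimal illustration of the display-order FIFO used by the new interface.
class DisplayOrderQueue {
 public:
  // Called by the parser, in display order, once a frame can be returned.
  void onFrameReadyInDisplayOrder(const DisplayInfo& info) {
    readyFrames_.push(info);
  }

  // Moral equivalent of receiveFrame(): pop the next frame in display order,
  // or return std::nullopt to signal "send more packets" (EAGAIN).
  std::optional<DisplayInfo> tryReceiveFrame() {
    if (readyFrames_.empty()) {
      return std::nullopt;
    }
    DisplayInfo info = readyFrames_.front();
    readyFrames_.pop();
    return info;
  }

  // Moral equivalent of flush(): drop anything still queued.
  void flush() {
    std::queue<DisplayInfo> empty;
    std::swap(readyFrames_, empty);
  }

 private:
  std::queue<DisplayInfo> readyFrames_;
};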

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 93 additions & 39 deletions
@@ -50,51 +50,18 @@ class BetaCudaDeviceInterface : public DeviceInterface {
   }
 
   int sendPacket(ReferenceAVPacket& packet) override;
-  int receiveFrame(UniqueAVFrame& avFrame, int64_t desiredPts) override;
+  int receiveFrame(UniqueAVFrame& avFrame) override;
   void flush() override;
 
   // NVDEC callback functions (must be public for C callbacks)
   int streamPropertyChange(CUVIDEOFORMAT* videoFormat);
-  int frameReadyForDecoding(CUVIDPICPARAMS* pPicParams);
+  int frameReadyForDecoding(CUVIDPICPARAMS* picParams);
+  int frameReadyInDisplayOrder(CUVIDPARSERDISPINFO* dispInfo);
 
 private:
   // Apply bitstream filter, modifies packet in-place
   void applyBSF(ReferenceAVPacket& packet);
 
-  class FrameBuffer {
-   public:
-    struct Slot {
-      CUVIDPARSERDISPINFO dispInfo;
-      int64_t guessedPts;
-      bool occupied = false;
-
-      Slot() : guessedPts(-1), occupied(false) {
-        std::memset(&dispInfo, 0, sizeof(dispInfo));
-      }
-    };
-
-    // TODONVDEC P1: init size should probably be min_num_decode_surfaces from
-    // video format
-    FrameBuffer() : frameBuffer_(4) {}
-
-    ~FrameBuffer() = default;
-
-    Slot* findEmptySlot();
-    Slot* findFrameWithExactPts(int64_t desiredPts);
-
-    // Iterator support for range-based for loops
-    auto begin() {
-      return frameBuffer_.begin();
-    }
-
-    auto end() {
-      return frameBuffer_.end();
-    }
-
-   private:
-    std::vector<Slot> frameBuffer_;
-  };
-
   UniqueAVFrame convertCudaFrameToAVFrame(
       CUdeviceptr framePtr,
       unsigned int pitch,

@@ -104,9 +71,7 @@ class BetaCudaDeviceInterface : public DeviceInterface {
   UniqueCUvideodecoder decoder_;
   CUVIDEOFORMAT videoFormat_ = {};
 
-  FrameBuffer frameBuffer_;
-
-  std::queue<int64_t> packetsPtsQueue;
+  std::queue<CUVIDPARSERDISPINFO> readyFrames_;
 
   bool eofSent_ = false;
 
@@ -125,3 +90,92 @@ class BetaCudaDeviceInterface : public DeviceInterface {
 };
 
 } // namespace facebook::torchcodec
+
+/* clang-format off */
+// Note: [General design, sendPacket, receiveFrame, frame ordering and NVCUVID callbacks]
+//
+// At a high level, this decoding interface mimics the FFmpeg send/receive
+// architecture:
+// - sendPacket(AVPacket) sends an AVPacket from the FFmpeg demuxer to the
+//   NVCUVID parser.
+// - receiveFrame(AVFrame) is a non-blocking call:
+//   - if a frame is ready **in display order**, it must return it. By display
+//     order, we mean that receiveFrame() must return frames with increasing pts
+//     values when called successively.
+//   - if no frame is ready, it must return AVERROR(EAGAIN) to indicate the
+//     caller should send more packets.
+//
+// The rest of this note assumes you have a reasonable level of familiarity with
+// the sendPacket/receiveFrame calling pattern. If you don't, look up the core
+// decoding loop in SingleVideoDecoder.
+//
+// The frame re-ordering problem:
+// Depending on the codec and on the encoding parameters, a packet from a video
+// stream may contain exactly one frame, more than one frame, or a fraction of a
+// frame. And, there may be non-linear frame dependencies because of B-frames,
+// which need both past *and* future frames to be decoded. Consider the
+// following stream, with frames presented in display order: I0 B1 P2 B3 P4 ...
+// - I0 is an I-frame (also key frame, can be decoded independently)
+// - B1 is a B-frame (bi-directional) which needs both I0 and P2 to be decoded
+// - P2 is a P-frame (predicted frame) which only needs I0 to be decoded.
+//
+// Because B1 needs both I0 and P2 to be properly decoded, the decode order
+// (packet order), defined by the encoder, must be: I0 P2 B1 P4 B3 ... which is
+// different from the display order.
+//
+// SendPacket(AVPacket)'s job is just to pass down the packet to the NVCUVID
+// parser by calling cuvidParseVideoData(packet). When
+// cuvidParseVideoData(packet) is called, it may trigger callbacks,
+// particularly:
+// - streamPropertyChange(videoFormat): triggered once at the start of the
+//   stream, and possibly later if the stream properties change (e.g.
+//   resolution).
+// - frameReadyForDecoding(picParams): triggered **in decode order** when the
+//   parser has accumulated enough data to decode a frame. We send that frame to
+//   the NVDEC hardware for **async** decoding.
+// - frameReadyInDisplayOrder(dispInfo): triggered **in display order** when a
+//   frame is ready to be "displayed" (returned). At that point, the parser also
+//   gives us the pts of that frame. We store (a reference to) that frame in a
+//   FIFO queue: readyFrames_.
+//
+// When receiveFrame(AVFrame) is called, if readyFrames_ is not empty, we pop
+// the front of the queue, which is the next frame in display order, and map it
+// to an AVFrame by calling cuvidMapVideoFrame(). If readyFrames_ is empty we
+// return EAGAIN to indicate the caller should send more packets.
+//
+// There is potentially a small inefficiency due to the callback design: in
+// order for us to know that a frame is ready in display order, we need the
+// frameReadyInDisplayOrder callback to be triggered. This can only happen
+// within cuvidParseVideoData(packet) in sendPacket(). This means there may be
+// the following sequence of calls:
+//
+// sendPacket(relevantAVPacket)
+//   cuvidParseVideoData(relevantAVPacket)
+//     frameReadyForDecoding()
+//       cuvidDecodePicture()        Send frame to NVDEC for async decoding
+//
+// receiveFrame() -> EAGAIN          Frame is potentially already decoded
+//                                   and could be returned, but we don't
+//                                   know because frameReadyInDisplayOrder
+//                                   hasn't been triggered yet. We'll only
+//                                   know after sending another,
+//                                   potentially irrelevant packet.
+//
+// sendPacket(irrelevantAVPacket)
+//   cuvidParseVideoData(irrelevantAVPacket)
+//     frameReadyInDisplayOrder()    Only now do we know that our target
+//                                   frame is ready.
+//
+// receiveFrame() returns the target frame
+//
+// How much this matters in practice is unclear, but probably negligible in
+// general. Particularly when frames are decoded consecutively anyway, the
+// "irrelevantPacket" is actually relevant for a future target frame.
+//
+// Note that the alternative is to *not* rely on the frameReadyInDisplayOrder
+// callback. It's technically possible, but it would mean we now have to solve
+// two hard, *codec-dependent* problems that the callback was solving for us:
+// - we have to guess the frame's pts ourselves
+// - we have to re-order the frames ourselves to preserve display order.
+//
+/* clang-format on */
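
The note above refers to the core decoding loop in SingleVideoDecoder without showing it. As a rough sketch of that calling pattern (not the actual SingleStreamDecoder code): decodeNextFrame, Demuxer, and getNextPacket are hypothetical helpers introduced only for illustration, EOF draining and error handling are omitted, and AVSUCCESS / AVERROR(EAGAIN) follow the DeviceInterface contract described in DeviceInterface.h.

// Sketch only: a caller driving sendPacket()/receiveFrame() until a frame
// comes out in display order. Helper names here are hypothetical.
UniqueAVFrame decodeNextFrame(DeviceInterface& device, Demuxer& demuxer) {
  UniqueAVFrame avFrame(av_frame_alloc());
  while (true) {
    int status = device.receiveFrame(avFrame);
    if (status == AVSUCCESS) {
      return avFrame; // Next frame, in display order.
    }
    TORCH_CHECK(status == AVERROR(EAGAIN), "Unexpected decoder error");
    // No frame ready yet: feed the parser another packet. For the BETA CUDA
    // interface, this is the call that may trigger the NVCUVID callbacks,
    // including frameReadyInDisplayOrder(), which fills readyFrames_.
    ReferenceAVPacket packet = demuxer.getNextPacket();
    device.sendPacket(packet);
  }
}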

src/torchcodec/_core/DeviceInterface.h

Lines changed: 1 addition & 3 deletions
@@ -101,9 +101,7 @@ class DeviceInterface {
   // Moral equivalent of avcodec_receive_frame()
   // Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
   // AVERROR_EOF if end of stream, or other AVERROR on failure
-  virtual int receiveFrame(
-      [[maybe_unused]] UniqueAVFrame& avFrame,
-      [[maybe_unused]] int64_t desiredPts) {
+  virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) {
     TORCH_CHECK(
         false,
         "Send/receive packet decoding not implemented for this device interface");

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
@@ -1189,7 +1189,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
   // avcodec_send_packet. This would make the decoding loop even more generic.
   while (true) {
     if (deviceInterface_->canDecodePacketDirectly()) {
-      status = deviceInterface_->receiveFrame(avFrame, cursor_);
+      status = deviceInterface_->receiveFrame(avFrame);
     } else {
       status =
           avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());

src/torchcodec/decoders/_video_decoder.py

Lines changed: 0 additions & 6 deletions
@@ -155,12 +155,6 @@ def __init__(
             device_variant = device_split[2]
             device = ":".join(device_split[0:2])
 
-        # TODONVDEC P0 Support approximate mode. Not ideal to validate that here
-        # either, but validating this at a lower level forces to add yet another
-        # (temprorary) validation API to the device inteface
-        if device_variant == "beta" and seek_mode != "exact":
-            raise ValueError("Seek mode must be exact for BETA CUDA interface.")
-
         core.add_video_stream(
             self._decoder,
             stream_index=stream_index,
