@@ -35,16 +35,20 @@ static bool g_cuda_beta = registerDeviceInterface(
35
35
36
36
// Parser callback installed as parserParams.pfnSequenceCallback in
// initialize(). pUserData is the value stored in parserParams.pUserData (the
// BetaCudaDeviceInterface instance, `this`), so it is cast back to that type
// and the call is forwarded to streamPropertyChange(). The callback returns
// whatever streamPropertyChange() returns, unchanged.
static int CUDAAPI
37
37
pfnSequenceCallback (void * pUserData, CUVIDEOFORMAT* videoFormat) {
38
- BetaCudaDeviceInterface* decoder =
39
- static_cast <BetaCudaDeviceInterface*>(pUserData);
38
+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
40
39
return decoder->streamPropertyChange (videoFormat);
41
40
}
42
41
43
42
// Parser callback installed as parserParams.pfnDecodePicture in initialize().
// Casts pUserData back to the BetaCudaDeviceInterface instance and forwards
// the picture parameters to frameReadyForDecoding(). The return value is
// passed through as-is; frameReadyForDecoding() uses 1 for success and 0 for
// error.
static int CUDAAPI
44
- pfnDecodePictureCallback (void * pUserData, CUVIDPICPARAMS* pPicParams) {
45
- BetaCudaDeviceInterface* decoder =
46
- static_cast <BetaCudaDeviceInterface*>(pUserData);
47
- return decoder->frameReadyForDecoding (pPicParams);
43
+ pfnDecodePictureCallback (void * pUserData, CUVIDPICPARAMS* picParams) {
44
+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
45
+ return decoder->frameReadyForDecoding (picParams);
46
+ }
47
+
48
// Parser callback installed as parserParams.pfnDisplayPicture in initialize()
// (that field was previously set to nullptr). Casts pUserData back to the
// BetaCudaDeviceInterface instance and forwards dispInfo to
// frameReadyInDisplayOrder(), returning its result unchanged.
+ static int CUDAAPI
49
+ pfnDisplayPictureCallback (void * pUserData, CUVIDPARSERDISPINFO* dispInfo) {
50
+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
51
+ return decoder->frameReadyInDisplayOrder (dispInfo);
48
52
}
49
53
50
54
static UniqueCUvideodecoder createDecoder (CUVIDEOFORMAT* videoFormat) {
@@ -142,7 +146,7 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
142
146
143
147
BetaCudaDeviceInterface::~BetaCudaDeviceInterface () {
144
148
// TODONVDEC P0: we probably need to free the frames that have been decoded by
145
- // NVDEC but not yet "mapped" - i.e. those that are still in frameBuffer_ ?
149
+ // NVDEC but not yet "mapped" - i.e. those that are still in readyFrames_ ?
146
150
147
151
if (decoder_) {
148
152
NVDECCache::getCache (device_.index ())
@@ -218,7 +222,7 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
218
222
parserParams.pUserData = this ;
219
223
parserParams.pfnSequenceCallback = pfnSequenceCallback;
220
224
parserParams.pfnDecodePicture = pfnDecodePictureCallback;
221
- parserParams.pfnDisplayPicture = nullptr ;
225
+ parserParams.pfnDisplayPicture = pfnDisplayPictureCallback ;
222
226
223
227
CUresult result = cuvidCreateVideoParser (&videoParser_, &parserParams);
224
228
TORCH_CHECK (
@@ -274,10 +278,6 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) {
274
278
cuvidPacket.flags = CUVID_PKT_TIMESTAMP;
275
279
cuvidPacket.timestamp = packet->pts ;
276
280
277
- // Like DALI: store packet PTS in queue to later assign to frames as they
278
- // come out
279
- packetsPtsQueue.push (packet->pts );
280
-
281
281
} else {
282
282
// End of stream packet
283
283
cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM;
@@ -329,70 +329,38 @@ void BetaCudaDeviceInterface::applyBSF(ReferenceAVPacket& packet) {
329
329
// ready to be decoded, i.e. the parser received all the necessary packets for a
330
330
// given frame. It means we can send that frame to be decoded by the hardware
331
331
// NVDEC decoder by calling cuvidDecodePicture which is non-blocking.
332
// Reached through pfnDecodePictureCallback. Submits picParams to the decoder
// with cuvidDecodePicture() and reports the outcome using the callback's
// convention: 1 when the call returned CUDA_SUCCESS, 0 otherwise.
// After this change the function no longer records any per-frame state (PTS
// queue / frame buffer slot); frames are collected separately in
// frameReadyInDisplayOrder().
- int BetaCudaDeviceInterface::frameReadyForDecoding (CUVIDPICPARAMS* pPicParams ) {
332
+ int BetaCudaDeviceInterface::frameReadyForDecoding (CUVIDPICPARAMS* picParams ) {
333
333
// While isFlushing_ is set, nothing is submitted to the decoder and 0 (the
// callback's error value) is returned.
if (isFlushing_) {
334
334
return 0 ;
335
335
}
336
336
337
- TORCH_CHECK (pPicParams != nullptr , " Invalid picture parameters" );
337
+ TORCH_CHECK (picParams != nullptr , " Invalid picture parameters" );
338
338
TORCH_CHECK (decoder_, " Decoder not initialized before picture decode" );
339
339
340
340
// Send frame to be decoded by NVDEC - non-blocking call.
341
- CUresult result = cuvidDecodePicture (*decoder_.get (), pPicParams);
342
- if (result != CUDA_SUCCESS) {
343
- return 0 ; // Yes, you're reading that right, 0 mean error.
344
- }
341
+ CUresult result = cuvidDecodePicture (*decoder_.get (), picParams);
345
342
346
- // The frame was sent to be decoded on the NVDEC hardware. Now we store some
347
- // relevant info into our frame buffer so that we can retrieve the decoded
348
- // frame later when receiveFrame() is called.
349
- // Importantly we need to 'guess' the PTS of that frame. The heuristic we use
350
- // (like in DALI) is that the frames are ready to be decoded in the same order
351
- // as the packets were sent to the parser. So we assign the PTS of the frame
352
- // by popping the PTS of the oldest packet in our packetsPtsQueue (note:
353
- // oldest doesn't necessarily mean lowest PTS!).
343
+ // Yes, you're reading that right, 0 means error, 1 means success
344
+ return (result == CUDA_SUCCESS);
345
+ }
354
346
355
- TORCH_CHECK (
356
- // TODONVDEC P0 the queue may be empty, handle that.
357
- !packetsPtsQueue.empty (),
358
- " PTS queue is empty when decoding a frame" );
359
- int64_t guessedPts = packetsPtsQueue.front ();
360
- packetsPtsQueue.pop ();
361
-
362
- // Field values taken from DALI
363
- CUVIDPARSERDISPINFO dispInfo = {};
364
- dispInfo.picture_index = pPicParams->CurrPicIdx ;
365
- dispInfo.progressive_frame = !pPicParams->field_pic_flag ;
366
- dispInfo.top_field_first = pPicParams->bottom_field_flag ^ 1 ;
367
- dispInfo.repeat_first_field = 0 ;
368
- dispInfo.timestamp = guessedPts;
369
-
370
- FrameBuffer::Slot* slot = frameBuffer_.findEmptySlot ();
371
- slot->dispInfo = dispInfo;
372
- slot->guessedPts = guessedPts;
373
- slot->occupied = true ;
374
-
375
- return 1 ;
347
// Reached through pfnDisplayPictureCallback. Copies the display info handed
// over by the parser into the readyFrames_ queue; receiveFrame() later takes
// entries from the front of that queue, so frames are returned in the order
// they were pushed here. Always returns 1 (success).
// NOTE(review): dispInfo is dereferenced without a null check, whereas
// frameReadyForDecoding() TORCH_CHECKs its pointer argument - confirm the
// parser can never invoke the display callback with nullptr (e.g. as an
// end-of-stream notification).
// NOTE(review): unlike frameReadyForDecoding(), this does not look at
// isFlushing_; flush() clears readyFrames_ afterwards - confirm that ordering
// is sufficient.
+ int BetaCudaDeviceInterface::frameReadyInDisplayOrder (
348
+ CUVIDPARSERDISPINFO* dispInfo) {
349
+ readyFrames_.push (*dispInfo);
350
+ return 1 ; // success
376
351
}
377
352
378
- // Moral equivalent of avcodec_receive_frame(). Here, we look for a decoded
379
- // frame with the exact desired PTS in our frame buffer. This logic is only
380
- // valid in exact seek_mode, for now.
381
- int BetaCudaDeviceInterface::receiveFrame (
382
- UniqueAVFrame& avFrame,
383
- int64_t desiredPts) {
384
- FrameBuffer::Slot* slot = frameBuffer_.findFrameWithExactPts (desiredPts);
385
- if (slot == nullptr ) {
353
+ // Moral equivalent of avcodec_receive_frame().
354
+ int BetaCudaDeviceInterface::receiveFrame (UniqueAVFrame& avFrame) {
355
+ if (readyFrames_.empty ()) {
386
356
// No frame found, instruct caller to try again later after sending more
387
357
// packets.
388
358
return AVERROR (EAGAIN);
389
359
}
390
-
391
- slot->occupied = false ;
392
- slot->guessedPts = -1 ;
360
+ CUVIDPARSERDISPINFO dispInfo = readyFrames_.front ();
361
+ readyFrames_.pop ();
393
362
394
363
CUVIDPROCPARAMS procParams = {};
395
- CUVIDPARSERDISPINFO dispInfo = slot->dispInfo ;
396
364
procParams.progressive_frame = dispInfo.progressive_frame ;
397
365
procParams.top_field_first = dispInfo.top_field_first ;
398
366
procParams.unpaired_field = dispInfo.repeat_first_field < 0 ;
@@ -452,7 +420,7 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
452
420
avFrame->width = width;
453
421
avFrame->height = height;
454
422
avFrame->format = AV_PIX_FMT_CUDA;
455
- avFrame->pts = dispInfo.timestamp ; // == guessedPts
423
+ avFrame->pts = dispInfo.timestamp ;
456
424
457
425
// TODONVDEC P0: Zero division error!!!
458
426
// TODONVDEC P0: Move AVRational arithmetic to FFMPEGCommon, and put the
@@ -518,13 +486,8 @@ void BetaCudaDeviceInterface::flush() {
518
486
519
487
isFlushing_ = false ;
520
488
521
- for (auto & slot : frameBuffer_) {
522
- slot.occupied = false ;
523
- slot.guessedPts = -1 ;
524
- }
525
-
526
- std::queue<int64_t > empty;
527
- packetsPtsQueue.swap (empty);
489
+ std::queue<CUVIDPARSERDISPINFO> emptyQueue;
490
+ std::swap (readyFrames_, emptyQueue);
528
491
529
492
eofSent_ = false ;
530
493
}
@@ -544,26 +507,4 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
544
507
avFrame, frameOutput, preAllocatedOutputTensor);
545
508
}
546
509
547
- BetaCudaDeviceInterface::FrameBuffer::Slot*
548
- BetaCudaDeviceInterface::FrameBuffer::findEmptySlot () {
549
- for (auto & slot : frameBuffer_) {
550
- if (!slot.occupied ) {
551
- return &slot;
552
- }
553
- }
554
- frameBuffer_.emplace_back ();
555
- return &frameBuffer_.back ();
556
- }
557
-
558
- BetaCudaDeviceInterface::FrameBuffer::Slot*
559
- BetaCudaDeviceInterface::FrameBuffer::findFrameWithExactPts (
560
- int64_t desiredPts) {
561
- for (auto & slot : frameBuffer_) {
562
- if (slot.occupied && slot.guessedPts == desiredPts) {
563
- return &slot;
564
- }
565
- }
566
- return nullptr ;
567
- }
568
-
569
510
} // namespace facebook::torchcodec
0 commit comments