Skip to content

Commit 79295d2

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into eqnfqef
2 parents a7f76a8 + 1a88391 commit 79295d2

File tree

4 files changed

+47
-15
lines changed

4 files changed

+47
-15
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
275275
}
276276

277277
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
278-
nppCtx_->hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
278+
279+
// Create a CUDA event and attach it to the AVFrame's CUDA stream. That's the
280+
// NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
281+
// We will be waiting for this event to complete before calling the NPP
282+
// functions, to ensure NVDEC has finished decoding the frame before running
283+
// the NPP color-conversion.
284+
// Note that our code is generic and assumes that the NVDEC's stream can be
285+
// arbitrary, but unfortunately we know it's hardcoded to be the default
286+
// stream by FFmpeg:
287+
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
288+
TORCH_CHECK(
289+
hwFramesCtx->device_ctx != nullptr,
290+
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
291+
auto cudaDeviceCtx =
292+
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
293+
at::cuda::CUDAEvent nvdecDoneEvent;
294+
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
295+
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
296+
nvdecDoneEvent.record(nvdecStream);
297+
298+
// Don't start NPP work before NVDEC is done decoding the frame!
299+
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
300+
nvdecDoneEvent.block(nppStream);
301+
302+
// Create the NPP context if we haven't yet.
303+
nppCtx_->hStream = nppStream.stream();
279304
cudaError_t err =
280305
cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
281306
TORCH_CHECK(

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,15 @@ int getNumChannels(const UniqueAVFrame& avFrame) {
6161
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
6262
return avFrame->ch_layout.nb_channels;
6363
#else
64-
return av_get_channel_layout_nb_channels(avFrame->channel_layout);
64+
int numChannels = av_get_channel_layout_nb_channels(avFrame->channel_layout);
65+
// Handle FFmpeg 4 bug where channel_layout and numChannels are 0 or unset
66+
// Set values based on avFrame->channels which appears to be correct
67+
// to allow successful initialization of SwrContext
68+
if (numChannels == 0 && avFrame->channels > 0) {
69+
avFrame->channel_layout = av_get_default_channel_layout(avFrame->channels);
70+
numChannels = avFrame->channels;
71+
}
72+
return numChannels;
6573
#endif
6674
}
6775

252 KB
Binary file not shown.

test/test_decoders.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,26 +1682,25 @@ def test_downsample_empty_frame(self):
16821682
frames_44100_to_8000.data, frames_8000.data, atol=0.03, rtol=0
16831683
)
16841684

1685-
def test_s16_ffmpeg4_bug(self):
1686-
# s16 fails on FFmpeg4 but can be decoded on other versions.
1687-
# Debugging logs show that we're hitting:
1688-
# [SWR @ 0x560a7abdaf80] Input channel count and layout are unset
1689-
# which seems to point to:
1690-
# https://github.com/FFmpeg/FFmpeg/blob/40a6963fbd0c47be358a3760480180b7b532e1e9/libswresample/swresample.c#L293-L305
1691-
# ¯\_(ツ)_/¯
1685+
def test_decode_s16_ffmpeg4(self):
1686+
# Non-regression test for https://github.com/pytorch/torchcodec/issues/843
1687+
# Ensures that decoding s16 on FFmpeg4 handles
1688+
# unset input channel count and layout
16921689

16931690
asset = SINE_MONO_S16
16941691
decoder = AudioDecoder(asset.path)
16951692
assert decoder.metadata.sample_rate == asset.sample_rate
16961693
assert decoder.metadata.sample_format == asset.sample_format
16971694

1698-
cm = (
1699-
pytest.raises(RuntimeError, match="The frame has 0 channels, expected 1.")
1700-
if get_ffmpeg_major_version() == 4
1701-
else contextlib.nullcontext()
1695+
test_samples = decoder.get_samples_played_in_range()
1696+
assert test_samples.data.shape[0] == decoder.metadata.num_channels
1697+
assert test_samples.sample_rate == decoder.metadata.sample_rate
1698+
reference_frames = asset.get_frame_data_by_range(
1699+
start=0, stop=1, stream_index=0
1700+
)
1701+
torch.testing.assert_close(
1702+
test_samples.data[0], reference_frames, atol=0, rtol=0
17021703
)
1703-
with cm:
1704-
decoder.get_samples_played_in_range()
17051704

17061705
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
17071706
@pytest.mark.parametrize("sample_rate", (None, 8000, 16_000, 44_1000))

0 commit comments

Comments
 (0)