diff --git a/k2/csrc/pytorch_context.cu b/k2/csrc/pytorch_context.cu
index 14cbdc6d9..c5a05793f 100644
--- a/k2/csrc/pytorch_context.cu
+++ b/k2/csrc/pytorch_context.cu
@@ -153,7 +153,11 @@ class PytorchCudaContext : public Context {
   // so it is fine to invoke lazyInitCUDA() multiple times.
   // The call will be inlined since it is defined in the header
   // aten/src/ATen/Context.h
-  at::globalContext().lazyInitCUDA();
+#if K2_TORCH_VERSION_MAJOR > 2 || (K2_TORCH_VERSION_MAJOR == 2 && K2_TORCH_VERSION_MINOR >= 6)
+  at::globalContext().lazyInitDevice(torch::kCUDA);
+#else
+  at::globalContext().lazyInitCUDA();
+#endif
   allocator_ = c10::cuda::CUDACachingAllocator::get();
   K2_CHECK(allocator_->raw_deleter() != nullptr);