diff --git a/k2/csrc/pytorch_context.cu b/k2/csrc/pytorch_context.cu index 14cbdc6d9..f732d9794 100644 --- a/k2/csrc/pytorch_context.cu +++ b/k2/csrc/pytorch_context.cu @@ -153,7 +153,11 @@ class PytorchCudaContext : public Context { // so it is fine to invoke lazyInitCUDA() multiple times. // The call will be inlined since it is defined in the header // aten/src/ATen/Context.h +#if K2_TORCH_VERSION_MAJOR > 2 || (K2_TORCH_VERSION_MAJOR == 2 && K2_TORCH_VERSION_MINOR >= 6) + at::globalContext().lazyInitDevice(torch::kCUDA); +#else at::globalContext().lazyInitCUDA(); +#endif allocator_ = c10::cuda::CUDACachingAllocator::get(); K2_CHECK(allocator_->raw_deleter() != nullptr);