@@ -299,10 +299,11 @@ Tensor & detach_(Tensor & self) {
}

// Some ops in the following registration list are registered as catch-all kernels,
- // some as backend kernels for VariableTensorId. The reason for this is that some
- // ops also use dispatch (i.e. register CPU/CUDA/QuantizedCPU kernels) and those
- // need to get a separate VariableTensorId kernel instead of a catch-all kernel,
- // otherwise we won't ever call it for CPU/CUDA/QuantizedCPU tensors.
+ // some as catch-all kernels and additionally as backend kernels for VariableTensorId.
+ // The reason for this is that ops that also use dispatch (e.g. register CPU/CUDA/QuantizedCPU
+ // kernels) need to get a separate VariableTensorId kernel instead of a catch-all kernel,
+ // otherwise we won't ever call it for CPU/CUDA/QuantizedCPU tensors, because the backend
+ // kernel has a higher priority than catch-all kernels.
// Unfortunately, this setup doesn't work in NonVariableTypeMode because that will
// skip past variable kernels. So for ops that we want to use in NonVariableTypeMode
// (and that don't use dispatch), we register them as catch-all kernels instead.
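To make the two registration styles described in that comment concrete, here is a minimal sketch under stated assumptions: the op names (`my::example_dispatched`, `my::example_catchall`) and kernel functions are hypothetical and not part of this file, and the sketch assumes the same headers and using-declarations that the registration list below already has in scope (torch::RegisterOperators, TensorTypeId, AliasAnalysisKind).

namespace {

// Hypothetical kernels, for illustration only.
at::Tensor example_dispatched_kernel(const at::Tensor& self) {
  // Would do autograd bookkeeping, then re-dispatch to the CPU/CUDA/QuantizedCPU kernel.
  return self;
}

at::Tensor example_catchall_kernel(const at::Tensor& self) {
  // Runs for any tensor type, including inside NonVariableTypeMode.
  return self;
}

static auto example_registry = torch::RegisterOperators()
  // Style 1: the op also has CPU/CUDA/QuantizedCPU backend kernels registered elsewhere,
  // so it gets a dedicated VariableTensorId backend kernel; a catch-all kernel would be
  // shadowed by those higher-priority backend kernels.
  .op(torch::RegisterOperators::options()
    .schema("my::example_dispatched(Tensor self) -> Tensor")
    .impl_unboxedOnlyKernel<decltype(example_dispatched_kernel), &example_dispatched_kernel>(TensorTypeId::VariableTensorId)
    .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
  // Style 2: the op does not use dispatch and must still work in NonVariableTypeMode
  // (which skips variable kernels), so it is registered as a catch-all kernel instead.
  .op(torch::RegisterOperators::options()
    .schema("my::example_catchall(Tensor self) -> Tensor")
    .impl_unboxedOnlyCatchAllKernel<decltype(example_catchall_kernel), &example_catchall_kernel>()
    .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));

} // namespace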
@@ -329,6 +330,13 @@ static auto registry = torch::RegisterOperators()
    .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
  .op(torch::RegisterOperators::options()
    .schema("aten::backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> ()")
+   // For backward(), we need the catch-all kernel (see comment above), but we also need the VariableTensorId backend
+   // kernel, because when called with a VariableTensorId tensor, it goes through the variable fallback kernel,
+   // which calls callBoxed(), which doesn't support optional tensor arguments yet, and backward() has an optional
+   // tensor argument.
+   // TODO Once callBoxed() supports optional tensor arguments, we can enable `use_c10_dispatcher: full` for backward()
+   // and remove the backend VariableTensorId kernel here, only leaving the catch-all kernel.
+   .impl_unboxedOnlyKernel<decltype(VariableType::backward), &VariableType::backward>(TensorTypeId::VariableTensorId)
    .impl_unboxedOnlyCatchAllKernel<decltype(VariableType::backward), &VariableType::backward>()
    .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
  .op(torch::RegisterOperators::options()
@@ -353,6 +361,13 @@ static auto registry = torch::RegisterOperators()
    .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
  .op(torch::RegisterOperators::options()
    .schema("aten::requires_grad_(Tensor(a!) self, bool _requires_grad=True) -> Tensor(a!)")
+   // For requires_grad_(), we need the catch-all kernel (see comment above), but we also need the VariableTensorId backend
+   // kernel, because when called with a VariableTensorId tensor, it goes through the variable fallback kernel,
+   // which calls callBoxed(), which doesn't support mutable tensor arguments yet, and requires_grad_() has a mutable
+   // tensor argument.
+   // TODO Once callBoxed() supports mutable tensor arguments, we can enable `use_c10_dispatcher: full` for requires_grad_()
+   // and remove the backend VariableTensorId kernel here, only leaving the catch-all kernel.
+   .impl_unboxedOnlyKernel<decltype(VariableType::requires_grad_), &VariableType::requires_grad_>(TensorTypeId::VariableTensorId)
    .impl_unboxedOnlyCatchAllKernel<decltype(VariableType::requires_grad_), &VariableType::requires_grad_>()
    .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
  ;
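For context, here is a hypothetical standalone usage sketch (not part of this commit) that exercises the two ops whose registrations change here, using only documented libtorch C++ frontend calls; the build setup and linking against libtorch are assumed.

#include <torch/torch.h>

int main() {
  // x is a variable (VariableTensorId), so its ops dispatch to the VariableTensorId
  // backend kernels registered above rather than to the catch-all kernels.
  auto x = torch::ones({2, 2}, torch::requires_grad());
  auto y = (x * x).sum();

  // aten::backward takes an optional `gradient` tensor, which is why the diff keeps an
  // unboxed VariableTensorId kernel for it in addition to the catch-all kernel.
  y.backward();

  // aten::requires_grad_ mutates `self` in place, the analogous reason for its extra kernel.
  x.requires_grad_(false);
  return 0;
}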