
Commit e123d90

smessmer authored and facebook-github-bot committed
Back out "Back out "Back out "Revert D18542342: Boxed variable dispatch""" (pytorch#30650)
Summary: Pull Request resolved: pytorch#30650

Original commit changeset: 51bb7aac7cb7

ghstack-source-id: 95082205

Test Plan: CI

Differential Revision: D18778190

fbshipit-source-id: 7e9577e88fd0492006b6ea836ec081aea9da6b0c
1 parent 37435d3 commit e123d90

File tree

5 files changed: +62 −10

aten/src/ATen/core/VariableFallbackKernel.cpp
aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
aten/src/ATen/core/dispatch/Dispatcher.h
caffe2/c2_aten_srcs.bzl
torch/csrc/autograd/VariableTypeManual.cpp
aten/src/ATen/core/VariableFallbackKernel.cpp

+40

@@ -0,0 +1,40 @@
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/LegacyTypeDispatch.h>
+
+/*
+ * This file implements a variable fallback kernel for custom operators.
+ * Since tensors always have the VariableTensorId set, but custom operators
+ * usually don't have a kernel registered for VariableTensorId, the dispatcher
+ * will call into this fallback kernel instead.
+ * Note that this is not a correct autograd implementation. It will just
+ * fall through to the custom operator implementation.
+ * If you want a custom operator to work with autograd, you need to use
+ * autograd::Function so that the custom operator implementation knows how to
+ * do autograd.
+ * Note also that ops from native_functions.yaml register their own variable
+ * kernels, so this is never called for them.
+ */
+
+// TODO This whole file should be deleted and replaced with the mechanism
+// described in https://github.com/pytorch/pytorch/issues/29548
+
+using c10::OperatorHandle;
+using c10::Stack;
+using c10::TensorTypeId;
+using c10::TensorTypeSet;
+using c10::Dispatcher;
+using c10::KernelFunction;
+
+namespace {
+
+void variable_fallback_kernel(const OperatorHandle& op, Stack* stack) {
+  at::AutoNonVariableTypeMode _var_guard(true);
+  Dispatcher::singleton().callBoxed(op, stack);
+}
+
+static auto registry = Dispatcher::singleton().registerBackendFallbackKernel(
+    TensorTypeId::VariableTensorId,
+    KernelFunction::makeFromBoxedFunction<&variable_fallback_kernel>()
+);
+
+}
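The file's header comment says a custom operator only becomes autograd-aware through autograd::Function. As a hedged illustration of what that looks like for an op author, here is a minimal custom function written against the torch::autograd::Function C++ API (the API as documented in later PyTorch releases; ScaleByTwo and the doubling op are hypothetical examples, not part of this commit):

#include <iostream>
#include <torch/torch.h>

// Hypothetical example: a custom op that knows how to do autograd by
// deriving from torch::autograd::Function, as the comment above suggests.
struct ScaleByTwo : public torch::autograd::Function<ScaleByTwo> {
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor input) {
    // Stash inputs in case backward() needs them (not strictly needed for
    // this constant gradient, shown for completeness).
    ctx->save_for_backward({input});
    return input * 2;
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    // d(2x)/dx = 2, so scale the incoming gradient.
    return {grad_outputs[0] * 2};
  }
};

int main() {
  auto x = torch::ones({2, 2}, torch::requires_grad());
  auto y = ScaleByTwo::apply(x).sum();
  y.backward();
  // x.grad() is now filled with 2s; the variable fallback kernel alone
  // would not have produced gradients for a custom op.
  std::cout << x.grad() << std::endl;
}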

aten/src/ATen/core/dispatch/DispatchKeyExtractor.h

+1 −6

@@ -84,12 +84,7 @@ struct DispatchKeyExtractor final {
       }
     }
   }
-  if (C10_UNLIKELY(ts.empty())) {
-    return c10::nullopt;
-  }
-
-  // TODO: Don't use legacy extractor; blocked on c10 understanding variable
-  return c10::legacyExtractTypeId(ts);
+  return typeSetToDispatchKey_(ts);
 }

 template<class... Args>
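The replacement calls typeSetToDispatchKey_, whose implementation is not shown in this commit (it appears to be a member of DispatchKeyExtractor). Conceptually, turning a type set into a dispatch key means picking the highest-priority type id present in the set. A minimal sketch of that idea, under the assumption that a TensorTypeSet behaves like a bitmask in which a higher bit outranks a lower one:

#include <cassert>
#include <cstdint>

// Sketch only; the real typeSetToDispatchKey_ may differ. We model a
// TensorTypeSet as a 64-bit mask where a higher bit means a higher-priority
// backend (e.g. VariableTensorId above CPUTensorId).
using ToyTypeId = unsigned;

ToyTypeId typeSetToDispatchKey(uint64_t type_set) {
  // The empty-set case that used to return c10::nullopt in the caller
  // has to be handled somewhere; here we simply assert.
  assert(type_set != 0 && "empty TensorTypeSet");
  ToyTypeId highest = 0;
  while (type_set >>= 1) {
    ++highest;  // index of the most significant set bit wins
  }
  return highest;
}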

aten/src/ATen/core/dispatch/Dispatcher.h

+1

@@ -189,6 +189,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const

 inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, c10::optional<TensorTypeId> dispatchKey) const {
   if (C10_LIKELY(dispatchKey.has_value())) {
+
     const KernelFunction* backendKernel = dispatchTable.lookup(*dispatchKey);

     if (nullptr != backendKernel) {
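Only the backend-kernel branch of dispatch_ is visible in this hunk. Pieced together with registerBackendFallbackKernel from VariableFallbackKernel.cpp and the priority comments in VariableTypeManual.cpp below, the lookup order is presumably backend kernel, then backend fallback, then catch-all. A toy model of that chain (hypothetical names, not the real c10 types):

#include <cstdio>
#include <map>

// Toy model of the presumed dispatch_ lookup chain; all names hypothetical.
using ToyTypeId = int;
using ToyKernel = void (*)();

struct ToyDispatchTable {
  std::map<ToyTypeId, ToyKernel> backendKernels;
  ToyKernel catchallKernel = nullptr;
};

// Analog of what registerBackendFallbackKernel fills in.
std::map<ToyTypeId, ToyKernel> backendFallbackKernels;

ToyKernel dispatch_(const ToyDispatchTable& table, ToyTypeId key) {
  // 1. A kernel registered for this backend wins outright.
  if (auto it = table.backendKernels.find(key); it != table.backendKernels.end())
    return it->second;
  // 2. Otherwise a backend fallback (e.g. variable_fallback_kernel) runs.
  if (auto it = backendFallbackKernels.find(key); it != backendFallbackKernels.end())
    return it->second;
  // 3. Otherwise the op's catch-all kernel, if any.
  return table.catchallKernel;
}

int main() {
  ToyDispatchTable table;
  table.catchallKernel = [] { std::puts("catch-all"); };
  dispatch_(table, /*key=*/0)();  // no backend kernel or fallback -> catch-all
}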

caffe2/c2_aten_srcs.bzl

+1

@@ -9,4 +9,5 @@ ATEN_CORE_HEADER_FILES = [

 ATEN_CORE_SRC_FILES = [
     "aten/src/ATen/core/grad_mode.cpp",
+    "aten/src/ATen/core/VariableFallbackKernel.cpp",
 ]

torch/csrc/autograd/VariableTypeManual.cpp

+19 −4

@@ -299,10 +299,11 @@ Tensor & detach_(Tensor & self) {
 }

 // Some ops in the following registration list are registered as catch-all kernels,
-// some as backend kernels for VariableTensorId. The reason for this is that some
-// ops also use dispatch (i.e. register CPU/CUDA/QuantizedCPU kernels) and those
-// need to get a separate VariableTensorId kernel instead of a catch-all kernel,
-// otherwise we won't ever call it for CPU/CUDA/QuantizedCPU tensors.
+// some as catch-all kernels and additionally as backend kernels for VariableTensorId.
+// The reason for this is that ops that also use dispatch (e.g. register CPU/CUDA/QuantizedCPU
+// kernels) need to get a separate VariableTensorId kernel instead of a catch-all kernel,
+// otherwise we won't ever call it for CPU/CUDA/QuantizedCPU tensors, because the backend
+// kernel has a higher priority than catch-all kernels.
 // Unfortunately, this setup doesn't work in NonVariableTypeMode because that will
 // skip past variable kernels. So for ops that we want to use in NonVariableTypeMode
 // (and that don't use dispatch), we register them as catch-all kernels instead.
@@ -329,6 +330,13 @@ static auto registry = torch::RegisterOperators()
     .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
   .op(torch::RegisterOperators::options()
     .schema("aten::backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> ()")
+    // For backward(), we need the catch-all kernel (see comment above), but we also need the VariableTensorId backend
+    // kernel, because when called with a VariableTensorId tensor, it goes through the variable fallback kernel,
+    // which calls callBoxed(), which doesn't support optional tensor arguments yet, and backward() has an optional
+    // tensor argument.
+    // TODO Once callBoxed() supports optional tensor arguments, we can enable `use_c10_dispatcher: full` for backward()
+    // and remove the backend VariableTensorId kernel here, only leaving the catch-all kernel.
+    .impl_unboxedOnlyKernel<decltype(VariableType::backward), &VariableType::backward>(TensorTypeId::VariableTensorId)
     .impl_unboxedOnlyCatchAllKernel<decltype(VariableType::backward), &VariableType::backward>()
     .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
   .op(torch::RegisterOperators::options()
@@ -353,6 +361,13 @@ static auto registry = torch::RegisterOperators()
     .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
   .op(torch::RegisterOperators::options()
     .schema("aten::requires_grad_(Tensor(a!) self, bool _requires_grad=True) -> Tensor(a!)")
+    // For requires_grad_(), we need the catch-all kernel (see comment above), but we also need the VariableTensorId backend
+    // kernel, because when called with a VariableTensorId tensor, it goes through the variable fallback kernel,
+    // which calls callBoxed(), which doesn't support mutable tensor arguments yet, and requires_grad_() has a mutable
+    // tensor argument.
+    // TODO Once callBoxed() supports mutable tensor arguments, we can enable `use_c10_dispatcher: full` for requires_grad_()
+    // and remove the backend VariableTensorId kernel here, only leaving the catch-all kernel.
+    .impl_unboxedOnlyKernel<decltype(VariableType::requires_grad_), &VariableType::requires_grad_>(TensorTypeId::VariableTensorId)
     .impl_unboxedOnlyCatchAllKernel<decltype(VariableType::requires_grad_), &VariableType::requires_grad_>()
     .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
 ;
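The comments added above encode two rules: a backend kernel outranks a catch-all kernel, and NonVariableTypeMode skips past variable kernels entirely. A self-contained toy (hypothetical names, fallback layer omitted) showing why backward() carries both registrations:

#include <cstdio>
#include <functional>
#include <map>
#include <optional>

// Toy model of the registration rules described in the comments above.
enum class ToyKey { CPUTensorId, VariableTensorId };
using Kernel = std::function<void()>;

struct ToyOp {
  std::map<ToyKey, Kernel> backendKernels;  // impl_unboxedOnlyKernel analog
  std::optional<Kernel> catchAll;           // impl_unboxedOnlyCatchAllKernel analog

  void call(ToyKey key, bool nonVariableMode) const {
    // NonVariableTypeMode skips past variable kernels entirely.
    if (nonVariableMode && key == ToyKey::VariableTensorId)
      key = ToyKey::CPUTensorId;
    // A backend kernel has higher priority than a catch-all kernel.
    if (auto it = backendKernels.find(key); it != backendKernels.end())
      return it->second();
    if (catchAll)
      return (*catchAll)();
    std::puts("no kernel found");
  }
};

int main() {
  ToyOp backwardOp;
  backwardOp.backendKernels[ToyKey::VariableTensorId] =
      [] { std::puts("VariableTensorId kernel (avoids the boxed fallback)"); };
  backwardOp.catchAll =
      [] { std::puts("catch-all kernel (reached in NonVariableTypeMode)"); };

  backwardOp.call(ToyKey::VariableTensorId, /*nonVariableMode=*/false);
  backwardOp.call(ToyKey::VariableTensorId, /*nonVariableMode=*/true);
}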
