
Commit 18c03b9

bdhirsh authored and facebook-github-bot committed on Dec 10, 2020
make duplicate def() calls an error in the dispatcher (pytorch#48098)
Summary:
Pull Request resolved: pytorch#48098

Make duplicate def() calls an error in the dispatcher, and update all fb operators to use the new dispatcher registration API.

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D25056089

Pulled By: bdhirsh

fbshipit-source-id: 8d7e381f16498a69cd20e6955d69acdc9a1d2791
1 parent 2519348 commit 18c03b9

File tree

2 files changed: +5 −27 lines
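For context, this is the registration discipline the change enforces: each operator schema gets exactly one def() call, and per-backend kernels are attached through separate impl() calls. A minimal sketch, assuming a PyTorch C++ extension build; the myops namespace, myop operator, and myop_cpu kernel are hypothetical and not part of this PR:

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical CPU kernel for the example operator.
at::Tensor myop_cpu(const at::Tensor& self) {
  return self.clone();
}

// OK: the schema is def()'d exactly once for this namespace...
TORCH_LIBRARY(myops, m) {
  m.def("myop(Tensor self) -> Tensor");
}

// ...and implementations are registered separately via impl().
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myop", myop_cpu);
}

// After this commit, a second def() of the same schema anywhere in the
// process throws at library initialization, roughly:
//   Tried to register an operator (myops::myop(Tensor self) -> Tensor)
//   with the same name and overload name multiple times. Each overload's
//   schema should only be registered with a single call to def().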
 

aten/src/ATen/core/dispatch/Dispatcher.cpp (+5 −26)
@@ -134,13 +134,11 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
   OperatorName op_name = schema.operator_name();
   auto op = findOrRegisterName_(op_name);
 
-  if (op.operatorIterator_->def_count == 0) {
-    // NB: registerSchema is not idempotent! Only do it once!
-    op.operatorIterator_->op.registerSchema(std::move(schema), std::move(debug));
-    listeners_->callOnOperatorRegistered(op);
-  } else {
-    checkSchemaCompatibility(op, schema, debug);
-  }
+  TORCH_CHECK(op.operatorIterator_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
+              " Each overload's schema should only be registered with a single call to def().",
+              " Duplicate registration: ", debug, ". Original registration: ", op.operatorIterator_->op.debug());
+  op.operatorIterator_->op.registerSchema(std::move(schema), std::move(debug));
+  listeners_->callOnOperatorRegistered(op);
 
   // NB: do not increment the counts until AFTER error checking
   ++op.operatorIterator_->def_count;

@@ -151,25 +149,6 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
   });
 }
 
-void Dispatcher::checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug) {
-  TORCH_CHECK(op.schema() == schema, "Tried to register multiple operators with the same name and the same overload name but different schemas: ", schema, " (", debug, ") vs ", op.schema(), " (", op.debug(), ")");
-  if (schema.isDefaultAliasAnalysisKind()) {
-    // [BACKWARDS COMPAT] If the *new* schema is the default alias analysis
-    // kind, for BC, we will accept it. If we don't accept it, most extensions
-    // that override existing operators will stop working (as they generally did
-    // not specify alias information).
-  } else if (op.schema().isDefaultAliasAnalysisKind()) {
-    // [BACKWARDS COMPAT] If you POST-FACTO specify a non-default alias analysis
-    // kind after we already have a schema for a function, bong it in for BC
-    // reasons.
-    op.operatorIterator_->op.updateSchemaAliasAnalysis(schema.aliasAnalysis());
-  } else {
-    TORCH_CHECK(op.schema().aliasAnalysis() == schema.aliasAnalysis(),
-                "Tried to define the schema for ", toString(op.operator_name()), " with different alias analysis kinds: ",
-                toString(op.schema().aliasAnalysis()), " (", op.debug(), ") vs ", toString(schema.aliasAnalysis()), " (", debug, ")");
-  }
-}
-
 void Dispatcher::deregisterDef_(const OperatorHandle& op, const OperatorName& op_name) {
   // we need a lock to avoid concurrent writes
   std::lock_guard<std::mutex> lock(mutex_);
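Note what the deleted checkSchemaCompatibility used to allow: a second def() whose schema compared equal (with some backwards-compat merging of alias analysis kinds) was silently accepted; now any repeat def() for the same name/overload fails the TORCH_CHECK up front. For readers unfamiliar with the macro, TORCH_CHECK(cond, args...) throws a c10::Error whose message concatenates the trailing arguments when cond is false. A standalone sketch of the new guard's shape, with a hypothetical check_single_def helper standing in for the dispatcher's actual per-operator state:

#include <c10/util/Exception.h>
#include <string>

// Simplified stand-in for Dispatcher state: one def() count per operator.
void check_single_def(int def_count, const std::string& name) {
  // Mirrors the new guard in Dispatcher::registerDef: a second def() of
  // the same name/overload now fails fast instead of being reconciled.
  TORCH_CHECK(def_count == 0,
              "Tried to register an operator (", name,
              ") with the same name and overload name multiple times.");
}

Calling check_single_def(1, "myops::myop") would throw a c10::Error carrying the concatenated message.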

aten/src/ATen/native/quantized/library.cpp (+0 −1)
@@ -132,7 +132,6 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor"));
-  m.def(TORCH_SELECTIVE_SCHEMA("quantized::hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::instance_norm(Tensor input, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor"));
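This hunk is the one call-site cleanup the stricter check required: quantized::hardswish was def()'d twice in the same TORCH_LIBRARY(quantized, m) block. Under the old checkSchemaCompatibility path the repeat was tolerated because the two schemas compared equal; with this commit it would throw at library initialization, so the duplicate is dropped. A reduced sketch of the pattern that now errors (schema abbreviated, TORCH_SELECTIVE_SCHEMA wrapper omitted):

TORCH_LIBRARY(quantized, m) {
  m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor");
  // ... other defs ...
  m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor");
  // ^ identical second def(): previously accepted because the schemas
  //   matched, now an error raised during static initialization.
}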
