Skip to content

Commit b661e9f

Browse files
Peng Yunquanfacebook-github-bot
authored andcommitted
Allowing an error handler from caller (pytorch#12487)
Summary: To give a chance to the caller to decide how to handle these errors, e.g., sending to logcat, or writing extra diagnostics events, etc. Differential Revision: D78232763
1 parent c2d6f3d commit b661e9f

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

runtime/kernel/operator_registry.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Kernel* registered_kernels = reinterpret_cast<Kernel*>(registered_kernels_data);
4949
size_t num_registered_kernels = 0;
5050

5151
// Registers the kernels, but may return an error.
52-
Error register_kernels_internal(const Span<const Kernel> kernels) {
52+
Error register_kernels_internal(const Span<const Kernel> kernels, ErrorHandler errorHandler) {
5353
// Operator registration happens in static initialization time before or after
5454
// PAL init, so call it here. It is safe to call multiple times.
5555
::et_pal_init();
@@ -74,12 +74,19 @@ Error register_kernels_internal(const Span<const Kernel> kernels) {
7474
ET_LOG(Error, "%s", kernels[i].name_);
7575
ET_LOG_KERNEL_KEY(kernels[i].kernel_key_);
7676
}
77+
78+
if (errorHandler != nullptr) {
79+
return errorHandler(Error::RegistrationExceedingMaxKernels);
80+
}
81+
7782
return Error::RegistrationExceedingMaxKernels;
7883
}
7984
// for debugging purpose
8085
ET_UNUSED const char* lib_name =
8186
et_pal_get_shared_library_name(kernels.data());
8287

88+
Error err = Error::Ok;
89+
8390
for (const auto& kernel : kernels) {
8491
// Linear search. This is fine if the number of kernels is small.
8592
for (size_t i = 0; i < num_registered_kernels; i++) {
@@ -88,24 +95,33 @@ Error register_kernels_internal(const Span<const Kernel> kernels) {
8895
kernel.kernel_key_ == k.kernel_key_) {
8996
ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name);
9097
ET_LOG_KERNEL_KEY(k.kernel_key_);
91-
return Error::RegistrationAlreadyRegistered;
98+
err = Error::RegistrationAlreadyRegistered;
99+
continue;
92100
}
93101
}
102+
94103
registered_kernels[num_registered_kernels++] = kernel;
95104
}
96-
ET_LOG(
97-
Debug,
98-
"Successfully registered all kernels from shared library: %s",
99-
lib_name);
100105

101-
return Error::Ok;
106+
if (errorHandler != nullptr) {
107+
err = errorHandler(err);
108+
}
109+
110+
if (err == Error::Ok) {
111+
ET_LOG(
112+
Debug,
113+
"Successfully registered all kernels from shared library: %s",
114+
lib_name);
115+
}
116+
117+
return err;
102118
}
103119

104120
} // namespace
105121

106122
// Registers the kernels, but panics if an error occurs. Always returns Ok.
107-
Error register_kernels(const Span<const Kernel> kernels) {
108-
Error success = register_kernels_internal(kernels);
123+
Error register_kernels(const Span<const Kernel> kernels, ErrorHandler errorHandler) {
124+
Error success = register_kernels_internal(kernels, errorHandler);
109125
if (success == Error::RegistrationAlreadyRegistered ||
110126
success == Error::RegistrationExceedingMaxKernels) {
111127
ET_CHECK_MSG(

runtime/kernel/operator_registry.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,16 @@ ::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
237237
*/
238238
Span<const Kernel> get_registered_kernels();
239239

240+
using ErrorHandler = Error (*)(Error errorCode);
241+
240242
/**
241243
* Registers the provided kernels.
242244
*
243245
* @param[in] kernels Kernel objects to register.
244246
* @retval Error::Ok always. Panics on error. This function needs to return a
245247
* non-void type to run at static initialization time.
246248
*/
247-
ET_NODISCARD Error register_kernels(const Span<const Kernel>);
249+
ET_NODISCARD Error register_kernels(const Span<const Kernel>, ErrorHandler errorHandler = nullptr);
248250

249251
/**
250252
* Registers a single kernel.
@@ -253,8 +255,8 @@ ET_NODISCARD Error register_kernels(const Span<const Kernel>);
253255
* @retval Error::Ok always. Panics on error. This function needs to return a
254256
* non-void type to run at static initialization time.
255257
*/
256-
ET_NODISCARD inline Error register_kernel(const Kernel& kernel) {
257-
return register_kernels({&kernel, 1});
258+
ET_NODISCARD inline Error register_kernel(const Kernel& kernel, ErrorHandler errorHandler = nullptr) {
259+
return register_kernels({&kernel, 1}, errorHandler);
258260
};
259261

260262
} // namespace ET_RUNTIME_NAMESPACE
@@ -269,12 +271,13 @@ using ::executorch::ET_RUNTIME_NAMESPACE::KernelKey;
269271
using ::executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext;
270272
using ::executorch::ET_RUNTIME_NAMESPACE::OpFunction;
271273
using ::executorch::ET_RUNTIME_NAMESPACE::TensorMeta;
274+
using ::executorch::ET_RUNTIME_NAMESPACE::ErrorHandler;
272275
using KernelRuntimeContext =
273276
::executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext;
274277

275-
inline ::executorch::runtime::Error register_kernels(ArrayRef<Kernel> kernels) {
278+
inline ::executorch::runtime::Error register_kernels(ArrayRef<Kernel> kernels, ErrorHandler errorHandler = nullptr) {
276279
return ::executorch::ET_RUNTIME_NAMESPACE::register_kernels(
277-
{kernels.data(), kernels.size()});
280+
{kernels.data(), kernels.size()}, errorHandler);
278281
}
279282
inline OpFunction getOpsFn(
280283
const char* name,
@@ -294,5 +297,6 @@ inline ArrayRef<Kernel> get_kernels() {
294297
::executorch::ET_RUNTIME_NAMESPACE::get_registered_kernels();
295298
return ArrayRef<Kernel>(kernels.data(), kernels.size());
296299
}
300+
297301
} // namespace executor
298302
} // namespace torch

0 commit comments

Comments
 (0)