diff --git a/include/mlc/base/lib.h b/include/mlc/base/lib.h index 01729f8e..925ee308 100644 --- a/include/mlc/base/lib.h +++ b/include/mlc/base/lib.h @@ -30,6 +30,8 @@ struct Lib { static FuncObj *FuncGetGlobal(const char *name, bool allow_missing = false); static ::mlc::Str CxxStr(AnyView obj); static ::mlc::Str Str(AnyView obj); + static int64_t StructuralHash(AnyView obj); + static bool StructuralEqual(AnyView a, AnyView b); static Any IRPrint(AnyView obj, AnyView printer, AnyView path); static const char *DeviceTypeToStr(int32_t device_type); static int32_t DeviceTypeFromStr(const char *source); diff --git a/include/mlc/core/all.h b/include/mlc/core/all.h index 6145e2c1..8df23abc 100644 --- a/include/mlc/core/all.h +++ b/include/mlc/core/all.h @@ -140,6 +140,18 @@ inline ::mlc::Str Lib::Str(AnyView obj) { ::mlc::base::FuncCall(func, 1, &obj, &ret); return ret; } +inline int64_t Lib::StructuralHash(AnyView obj) { + static FuncObj *func_hash_s = ::mlc::Lib::FuncGetGlobal("mlc.core.StructuralHash"); + Any ret; + ::mlc::base::FuncCall(func_hash_s, 1, &obj, &ret); + return ret; +} +inline bool Lib::StructuralEqual(AnyView a, AnyView b) { + static FuncObj *func_eq_s = ::mlc::Lib::FuncGetGlobal("mlc.core.StructuralEqual"); + Any ret; + ::mlc::base::FuncCall(func_eq_s, 2, std::array{a, b}.data(), &ret); + return ret; +} inline Any Lib::IRPrint(AnyView obj, AnyView printer, AnyView path) { FuncObj *func = Lib::VTableGetFunc(ir_print, obj.GetTypeIndex(), "__ir_print__"); Any ret; @@ -206,7 +218,11 @@ template inline R VTable::operator()(Args... args Any ret; stack_args.Fill(std::forward(args)...); MLC_CHECK_ERR(::MLCVTableCall(self, N, stack_args.v, &ret)); - return ret; + if constexpr (std::is_same_v) { + return; + } else { + return ret; + } } template inline VTable &VTable::Set(Func func) { constexpr bool override_mode = false; diff --git a/include/mlc/core/object.h b/include/mlc/core/object.h index 286d5a2f..868f974f 100644 --- a/include/mlc/core/object.h +++ b/include/mlc/core/object.h @@ -186,6 +186,18 @@ struct Exception : public std::exception { void FormatExc(std::ostream &os) const; Ref data_; }; +struct ObjRefHash { + std::size_t operator()(const ObjectRef &obj) const { return std::hash{}(obj.get()); } +}; +struct ObjRefEqual { + bool operator()(const ObjectRef &a, const ObjectRef &b) const { return a.get() == b.get(); } +}; +struct StructuralHash { + std::size_t operator()(const ObjectRef &obj) const { return ::mlc::Lib::StructuralHash(obj); } +}; +struct StructuralEqual { + bool operator()(const ObjectRef &a, const ObjectRef &b) const { return ::mlc::Lib::StructuralEqual(a, b); } +}; } // namespace mlc #endif // MLC_CORE_OBJECT_H_