Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Support __replace__ in general #26

Merged
merged 2 commits into from
Mar 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/registry.h
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ bool StructuralEqual(AnyView lhs, AnyView rhs, bool bind_free_vars, bool assert_
int64_t StructuralHash(AnyView root);
Any CopyShallow(AnyView root);
Any CopyDeep(AnyView root);
void CopyReplace(int32_t num_args, const AnyView *args, Any *ret);
Str DocToPythonScript(mlc::printer::Node node, mlc::printer::PrinterConfig cfg);
UDict BuildInfo();

@@ -650,6 +651,7 @@ inline TypeTable *TypeTable::New() {
self->SetFunc("mlc.core.StructuralHash", Func(::mlc::registry::StructuralHash).get());
self->SetFunc("mlc.core.CopyShallow", Func(::mlc::registry::CopyShallow).get());
self->SetFunc("mlc.core.CopyDeep", Func(::mlc::registry::CopyDeep).get());
self->SetFunc("mlc.core.CopyReplace", Func(::mlc::registry::CopyReplace).get());
self->SetFunc("mlc.core.BuildInfo", Func(::mlc::registry::BuildInfo).get());
self->SetFunc("mlc.core.TensorToBytes", Func(::mlc::registry::TensorToBytes).get());
self->SetFunc("mlc.core.TensorFromBytes", Func(::mlc::registry::TensorFromBytes).get());
61 changes: 60 additions & 1 deletion cpp/structure.cc
Original file line number Diff line number Diff line change
@@ -952,7 +952,9 @@ inline Any CopyShallowImpl(AnyView source) {
return UList(list->begin(), list->end());
} else if (UDictObj *dict = source.TryCast<UDictObj>()) {
return UDict(dict->begin(), dict->end());
} else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>()) {
} else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>() ||
source.IsInstance<TensorObj>()) {
// TODO: do we want to shallow copy these types at all?
return source;
}
struct Copier {
@@ -987,6 +989,62 @@ inline Any CopyShallowImpl(AnyView source) {
return ret;
}

// Builds a new object from `source` with selected fields replaced (the C++ side of `__replace__`).
//
// Argument layout: `args[0]` is the source object, followed by alternating
// (field name, replacement value) entries, i.e. `[source, name_0, value_0, name_1, value_1, ...]`.
//
// For every field reported by `VisitFields`, the replacement registered under the field's name is
// used if there is one, otherwise the field's current value is kept. The collected values are then
// passed, in visiting order, to the function returned by `Lib::_init(type_index)`, whose result is
// written to `ret`.
//
// Throws InternalError if no argument is supplied, or if a field name has no paired value.
// Throws TypeError if `source` is a POD value or a Str/Error/Func/UList/UDict/Tensor object.
//
// NOTE(review): replacement names that match no visited field are ignored without an error;
// if typos should be reported to the caller, that check still has to be added.
inline void CopyReplaceImpl(int32_t num_args, const AnyView *args, Any *ret) {
  if (num_args <= 0) {
    MLC_THROW(InternalError) << "InternalError: `CopyReplace` requires at least one argument";
  }
  // After the source object the arguments must come in (name, value) pairs, so the total count is
  // odd. Rejecting an even count here prevents reading `args[num_args]` (one past the end) below.
  if (num_args % 2 == 0) {
    MLC_THROW(InternalError) << "InternalError: `CopyReplace` requires field names and values in pairs, but got "
                             << (num_args - 1) << " argument(s) after the source object";
  }
  AnyView source = args[0];
  int32_t type_index = source.type_index;
  if (::mlc::base::IsTypeIndexPOD(type_index)) {
    MLC_THROW(TypeError) << "TypeError: `__replace__` doesn't work on a POD type: " << source;
  } else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>() ||
             source.IsInstance<UListObj>() || source.IsInstance<UDictObj>() || source.IsInstance<TensorObj>()) {
    MLC_THROW(TypeError) << "TypeError: `__replace__` doesn't work on type: " << source.GetTypeKey();
  }
  // Field visitor: one overload per field storage type. Every overload forwards the field's name
  // and its current value (wrapped as an AnyView) to `AddField`.
  struct Copier {
    MLC_INLINE void operator()(MLCTypeField *f, const Any *any) { AddField(f->name, AnyView(*any)); }
    MLC_INLINE void operator()(MLCTypeField *f, ObjectRef *obj) { AddField(f->name, AnyView(*obj)); }
    MLC_INLINE void operator()(MLCTypeField *f, Optional<ObjectRef> *opt) { AddField(f->name, AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *f, Optional<bool> *opt) { AddField(f->name, AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *f, Optional<int64_t> *opt) { AddField(f->name, AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *f, Optional<double> *opt) { AddField(f->name, AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *f, Optional<DLDevice> *opt) { AddField(f->name, AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *f, Optional<DLDataType> *opt) { AddField(f->name, AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *f, bool *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, int8_t *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, int16_t *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, int32_t *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, int64_t *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, float *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, double *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, DLDataType *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, DLDevice *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, Optional<void *> *v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, void **v) { AddField(f->name, AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *f, const char **v) { AddField(f->name, AnyView(*v)); }

    // Appends the replacement registered under `name` if present, otherwise the current value `v`.
    void AddField(std::string_view name, AnyView v) {
      if (auto it = replacements->find(name); it != replacements->end()) {
        fields->push_back(it->second);
      } else {
        fields->push_back(v);
      }
    }
    std::vector<AnyView> *fields;
    std::unordered_map<std::string_view, AnyView> *replacements;
  };
  // Keys are views over the name strings supplied in `args` (assumed to stay alive for the
  // duration of this call). When a name is given more than once, the last value wins.
  std::unordered_map<std::string_view, AnyView> replacements;
  for (int32_t i = 1; i + 1 < num_args; i += 2) {
    const char *name = args[i];
    replacements[name] = args[i + 1];
  }
  FuncObj *init_func = Lib::_init(type_index);
  MLCTypeInfo *type_info = Lib::GetTypeInfo(type_index);
  std::vector<AnyView> fields;
  VisitFields(source.operator Object *(), type_info, Copier{&fields, &replacements});
  ::mlc::base::FuncCall(init_func, static_cast<int32_t>(fields.size()), fields.data(), ret);
}

inline Any CopyDeepImpl(AnyView source) {
if (::mlc::base::IsTypeIndexPOD(source.type_index)) {
return source;
@@ -1508,6 +1566,7 @@ int64_t StructuralHash(AnyView root) {

// Registry-facing shallow copy: forwards `source` to CopyShallowImpl and returns its result.
Any CopyShallow(AnyView source) {
  return CopyShallowImpl(source);
}
// Registry-facing deep copy: forwards `source` to CopyDeepImpl and returns its result.
Any CopyDeep(AnyView source) {
  return CopyDeepImpl(source);
}
// Registry-facing copy-with-replacement: passes the packed argument array and the result slot
// straight through to CopyReplaceImpl.
void CopyReplace(int32_t num_args, const AnyView *args, Any *ret) {
  CopyReplaceImpl(num_args, args, ret);
}

Any JSONLoads(AnyView json_str) {
if (json_str.type_index == kMLCRawStr) {
5 changes: 5 additions & 0 deletions python/mlc/_cython/core.pyx
Original file line number Diff line number Diff line change
@@ -389,6 +389,10 @@ cdef class PyAny:
def _mlc_copy_deep(PyAny x) -> PyAny:
return func_call(_COPY_DEEP, (x,))

@staticmethod
def _mlc_copy_replace(*args) -> PyAny:
    """Call the `mlc.core.CopyReplace` function with `args` passed through unchanged."""
    result = func_call(_COPY_REPLACE, args)
    return result

@classmethod
def _C(cls, bytes name, *args):
cdef int32_t type_index = cls._mlc_type_info.type_index
@@ -1672,6 +1676,7 @@ cdef PyAny _STRUCUTRAL_EQUAL = func_get_untyped("mlc.core.StructuralEqual")
cdef PyAny _STRUCUTRAL_HASH = func_get_untyped("mlc.core.StructuralHash")
cdef PyAny _COPY_SHALLOW = func_get_untyped("mlc.core.CopyShallow")
cdef PyAny _COPY_DEEP = func_get_untyped("mlc.core.CopyDeep")
cdef PyAny _COPY_REPLACE = func_get_untyped("mlc.core.CopyReplace")
cdef PyAny _TENSOR_TO_DLPACK = func_get_untyped("mlc.core.TensorToDLPack")

cdef MLCVTableHandle _VTABLE_STR = _vtable_get_global(b"__str__")
7 changes: 7 additions & 0 deletions python/mlc/core/object.py
Original file line number Diff line number Diff line change
@@ -41,6 +41,13 @@ def __copy__(self: Object) -> Object:
def __deepcopy__(self: Object, memo: dict[int, Object] | None) -> Object:
    """Return a deep copy of this object via `PyAny._mlc_copy_deep`.

    `memo` is accepted but not used by this implementation.
    """
    deep_copy = PyAny._mlc_copy_deep(self)
    return deep_copy

def __replace__(self: Object, /, **changes: typing.Any) -> Object:
    """Return a copy of this object with the fields named in `changes` replaced.

    The call is flattened into a positional sequence of the form
    `(self, name_0, value_0, name_1, value_1, ...)` and handed to
    `PyAny._mlc_copy_replace`.
    """
    flat_args: list[typing.Any] = [self]
    # Each (name, value) pair contributes two consecutive entries.
    flat_args.extend(item for pair in changes.items() for item in pair)
    return PyAny._mlc_copy_replace(*flat_args)

def __hash__(self) -> int:
return hash((type(self), self._mlc_address))

1 change: 1 addition & 0 deletions python/mlc/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -5,5 +5,6 @@
add_vtable_method,
field,
prototype,
replace,
vtable_method,
)
4 changes: 4 additions & 0 deletions python/mlc/dataclasses/utils.py
Original file line number Diff line number Diff line change
@@ -445,3 +445,7 @@ def prototype(
else:
raise ValueError(f"Invalid `lang`: {lang}")
return "\n\n".join(fn(i) for i in type_info_list)


def replace(obj: Any, /, **changes: Any) -> Any:
    """Return the result of `obj.__replace__(**changes)`.

    `obj` must provide a `__replace__` method; the keyword arguments are
    forwarded to it unchanged.
    """
    replacer = obj.__replace__
    return replacer(**changes)
12 changes: 12 additions & 0 deletions tests/python/test_dataclasses_copy.py
Original file line number Diff line number Diff line change
@@ -270,3 +270,15 @@ def test_copy_deep_dataclass(test_obj: CustomInit) -> None:
assert src != dst
assert src.a == dst.a
assert src.b == dst.b


def test_copy_replace_dataclass(test_obj: CustomInit) -> None:
    """`mlc.dataclasses.replace` overrides field `a` and leaves the source object untouched."""
    original = test_obj
    replaced = mlc.dataclasses.replace(original, a=2)
    # The result compares unequal to the source and differs only in `a`.
    assert original != replaced
    assert original.a != replaced.a
    assert original.b == replaced.b
    # The source keeps its initial field values.
    assert original.a == 1
    assert original.b == "hello"
    # The result carries the override plus the untouched field.
    assert replaced.a == 2
    assert replaced.b == "hello"