diff --git a/cpp/registry.h b/cpp/registry.h
index 0c20cc3c..bd9eeccb 100644
--- a/cpp/registry.h
+++ b/cpp/registry.h
@@ -25,6 +25,7 @@ bool StructuralEqual(AnyView lhs, AnyView rhs, bool bind_free_vars, bool assert_
 int64_t StructuralHash(AnyView root);
 Any CopyShallow(AnyView root);
 Any CopyDeep(AnyView root);
+void CopyReplace(int32_t num_args, const AnyView *args, Any *ret);
 Str DocToPythonScript(mlc::printer::Node node, mlc::printer::PrinterConfig cfg);
 UDict BuildInfo();
 
@@ -650,6 +651,7 @@ inline TypeTable *TypeTable::New() {
   self->SetFunc("mlc.core.StructuralHash", Func(::mlc::registry::StructuralHash).get());
   self->SetFunc("mlc.core.CopyShallow", Func(::mlc::registry::CopyShallow).get());
   self->SetFunc("mlc.core.CopyDeep", Func(::mlc::registry::CopyDeep).get());
+  self->SetFunc("mlc.core.CopyReplace", Func(::mlc::registry::CopyReplace).get());
   self->SetFunc("mlc.core.BuildInfo", Func(::mlc::registry::BuildInfo).get());
   self->SetFunc("mlc.core.TensorToBytes", Func(::mlc::registry::TensorToBytes).get());
   self->SetFunc("mlc.core.TensorFromBytes", Func(::mlc::registry::TensorFromBytes).get());
diff --git a/cpp/structure.cc b/cpp/structure.cc
index fd1fceec..78ba9fff 100644
--- a/cpp/structure.cc
+++ b/cpp/structure.cc
@@ -952,7 +952,9 @@ inline Any CopyShallowImpl(AnyView source) {
     return UList(list->begin(), list->end());
   } else if (UDictObj *dict = source.TryCast()) {
     return UDict(dict->begin(), dict->end());
-  } else if (source.IsInstance() || source.IsInstance() || source.IsInstance()) {
+  } else if (source.IsInstance() || source.IsInstance() || source.IsInstance() ||
+             source.IsInstance()) {
+    // TODO: do we want to shallow copy these types at all?
     return source;
   }
   struct Copier {
@@ -987,6 +989,78 @@ inline Any CopyShallowImpl(AnyView source) {
   return ret;
 }
 
+// Constructs a copy of `args[0]` with selected fields replaced.
+//
+// `args` is laid out as `(source, name_0, value_0, name_1, value_1, ...)`. While the fields of `source`
+// are visited, a field whose name equals some `name_i` contributes `value_i`; every other field
+// contributes its current value. The collected values are passed, in visiting order, to the `_init`
+// function registered for the source's type, and that call writes its result to `ret`.
+//
+// NOTE: a `name_i` that matches no visited field is never looked up and therefore has no effect.
+// Throws for an empty or unpaired argument list, and for POD and other unsupported source types.
+inline void CopyReplaceImpl(int32_t num_args, const AnyView *args, Any *ret) {
+  if (num_args <= 0) {
+    MLC_THROW(InternalError) << "InternalError: `CopyReplace` requires at least one argument";
+  }
+  // One source object plus (name, value) pairs means `num_args` is odd. Reject an unpaired trailing name
+  // up front: the loop below reads `args[i + 1]`, which would otherwise be one past the end of `args`.
+  if (num_args % 2 == 0) {
+    MLC_THROW(InternalError) << "InternalError: `CopyReplace` requires field names and values to come in pairs";
+  }
+  AnyView source = args[0];
+  int32_t type_index = source.type_index;
+  if (::mlc::base::IsTypeIndexPOD(type_index)) {
+    MLC_THROW(TypeError) << "TypeError: `__replace__` doesn't work on a POD type: " << source;
+  } else if (source.IsInstance() || source.IsInstance() || source.IsInstance() ||
+             source.IsInstance() || source.IsInstance() || source.IsInstance()) {
+    MLC_THROW(TypeError) << "TypeError: `__replace__` doesn't work on type: " << source.GetTypeKey();
+  }
+  // Field visitor: one overload per field storage type, all funnelling into `AddField`.
+  struct Copier {
+    MLC_INLINE void operator()(MLCTypeField *f, const Any *any) { AddField(f->name, AnyView(*any)); }
+    MLC_INLINE void operator()(MLCTypeField *f, ObjectRef *obj) { AddField(f->name, AnyView(*obj)); }
+    MLC_INLINE void operator()(MLCTypeField *f, Optional *opt) { AddField(f->name, AnyView(*opt)); }
+    MLC_INLINE void operator()(MLCTypeField *f, Optional *opt) { AddField(f->name, AnyView(*opt)); }
+    MLC_INLINE void operator()(MLCTypeField *f, Optional *opt) { AddField(f->name, AnyView(*opt)); }
+    MLC_INLINE void operator()(MLCTypeField *f, Optional *opt) { AddField(f->name, AnyView(*opt)); }
+    MLC_INLINE void operator()(MLCTypeField *f, Optional *opt) { AddField(f->name, AnyView(*opt)); }
+    MLC_INLINE void operator()(MLCTypeField *f, Optional *opt) { AddField(f->name, AnyView(*opt)); }
+    MLC_INLINE void operator()(MLCTypeField *f, bool *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, int8_t *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, int16_t *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, int32_t *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, int64_t *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, float *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, double *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, DLDataType *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, DLDevice *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, Optional *v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, void **v) { AddField(f->name, AnyView(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *f, const char **v) { AddField(f->name, AnyView(*v)); }
+
+    void AddField(std::string_view name, AnyView v) {
+      if (auto it = replacements->find(name); it != replacements->end()) {
+        fields->push_back(it->second);
+      } else {
+        fields->push_back(v);
+      }
+    }
+    std::vector *fields;
+    std::unordered_map *replacements;
+  };
+  std::unordered_map replacements;
+  for (int32_t i = 1; i < num_args; i += 2) {
+    const char *name = args[i];
+    replacements[name] = args[i + 1];
+  }
+  // Re-create the object through its registered `_init`, feeding the (possibly replaced) field values.
+  FuncObj *init_func = Lib::_init(type_index);
+  MLCTypeInfo *type_info = Lib::GetTypeInfo(type_index);
+  std::vector fields;
+  VisitFields(source.operator Object *(), type_info, Copier{&fields, &replacements});
+  ::mlc::base::FuncCall(init_func, static_cast(fields.size()), fields.data(), ret);
+}
+
 inline Any CopyDeepImpl(AnyView source) {
   if (::mlc::base::IsTypeIndexPOD(source.type_index)) {
     return source;
@@ -1508,6 +1582,7 @@ int64_t StructuralHash(AnyView root) {
 
 Any CopyShallow(AnyView source) { return CopyShallowImpl(source); }
 Any CopyDeep(AnyView source) { return CopyDeepImpl(source); }
+void CopyReplace(int32_t num_args, const AnyView *args, Any *ret) { CopyReplaceImpl(num_args, args, ret); }
 
 Any JSONLoads(AnyView json_str) {
   if (json_str.type_index == kMLCRawStr) {
diff --git a/python/mlc/_cython/core.pyx b/python/mlc/_cython/core.pyx
index 1a9eea8e..bb45d541 100644
--- a/python/mlc/_cython/core.pyx
+++ b/python/mlc/_cython/core.pyx
@@ -389,6 +389,10 @@ cdef class PyAny:
     def _mlc_copy_deep(PyAny x) -> PyAny:
         return func_call(_COPY_DEEP, (x,))
 
+    @staticmethod
+    def _mlc_copy_replace(*args) -> PyAny:
+        return func_call(_COPY_REPLACE, args)
+
     @classmethod
     def _C(cls, bytes name, *args):
         cdef int32_t type_index = cls._mlc_type_info.type_index
@@ -1672,6 +1676,7 @@ cdef PyAny _STRUCUTRAL_EQUAL = func_get_untyped("mlc.core.StructuralEqual")
 cdef PyAny _STRUCUTRAL_HASH = func_get_untyped("mlc.core.StructuralHash")
 cdef PyAny _COPY_SHALLOW = func_get_untyped("mlc.core.CopyShallow")
 cdef PyAny _COPY_DEEP = func_get_untyped("mlc.core.CopyDeep")
+cdef PyAny _COPY_REPLACE = func_get_untyped("mlc.core.CopyReplace")
 cdef PyAny _TENSOR_TO_DLPACK = func_get_untyped("mlc.core.TensorToDLPack")
 
 cdef MLCVTableHandle _VTABLE_STR = _vtable_get_global(b"__str__")
diff --git a/python/mlc/core/object.py b/python/mlc/core/object.py
index dcc9cca7..f21c6fcf 100644
--- a/python/mlc/core/object.py
+++ b/python/mlc/core/object.py
@@ -41,6 +41,14 @@ def __copy__(self: Object) -> Object:
 
     def __deepcopy__(self: Object, memo: dict[int, Object] | None) -> Object:
         return PyAny._mlc_copy_deep(self)
 
+    def __replace__(self: Object, /, **changes: typing.Any) -> Object:
+        """Return a copy of this object with the fields named in `changes` replaced."""
+        unpacked: list[typing.Any] = [self]
+        for key, value in changes.items():
+            unpacked.append(key)
+            unpacked.append(value)
+        return PyAny._mlc_copy_replace(*unpacked)
+
     def __hash__(self) -> int:
         return hash((type(self), self._mlc_address))
diff --git a/python/mlc/dataclasses/__init__.py b/python/mlc/dataclasses/__init__.py
index 7a90826c..85f8a69e 100644
--- a/python/mlc/dataclasses/__init__.py
+++ b/python/mlc/dataclasses/__init__.py
@@ -5,5 +5,6 @@
     add_vtable_method,
     field,
     prototype,
+    replace,
     vtable_method,
 )
diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py
index 2271dc0b..0e1a9d2c 100644
--- a/python/mlc/dataclasses/utils.py
+++ b/python/mlc/dataclasses/utils.py
@@ -445,3 +445,8 @@ def prototype(
     else:
         raise ValueError(f"Invalid `lang`: {lang}")
     return "\n\n".join(fn(i) for i in type_info_list)
+
+
+def replace(obj: Any, /, **changes: Any) -> Any:
+    """Create a modified copy of `obj` by delegating to its `__replace__` method."""
+    return obj.__replace__(**changes)
diff --git a/tests/python/test_dataclasses_copy.py b/tests/python/test_dataclasses_copy.py
index 61906ff3..5fb23c36 100644
--- a/tests/python/test_dataclasses_copy.py
+++ b/tests/python/test_dataclasses_copy.py
@@ -270,3 +270,15 @@ def test_copy_deep_dataclass(test_obj: CustomInit) -> None:
     assert src != dst
     assert src.a == dst.a
     assert src.b == dst.b
+
+
+def test_copy_replace_dataclass(test_obj: CustomInit) -> None:
+    src = test_obj
+    dst = mlc.dataclasses.replace(src, a=2)
+    assert src != dst
+    assert src.a != dst.a
+    assert src.b == dst.b
+    assert src.a == 1
+    assert src.b == "hello"
+    assert dst.a == 2
+    assert dst.b == "hello"