Skip to content

Commit 11db7bc

Browse files
authored
feat(dataclasses): Unify c_class more with py_class (#81)
1 parent 35a5fc1 commit 11db7bc

File tree

10 files changed

+239
-238
lines changed

10 files changed

+239
-238
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ dependencies = [
66
'ml-dtypes >= 0.1',
77
'Pygments>=2.4.0',
88
'colorama',
9+
'typing-extensions >= 4.9.0',
910
'setuptools ; platform_system == "Windows"',
1011
]
1112
description = "Python-first Development for AI Compilers"

python/mlc/_cython/base.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,13 +366,13 @@ def __new__(
366366
return super().__new__(cls, name, bases, dict)
367367

368368

369-
def attach_field(
369+
def make_field(
370370
cls: type,
371371
name: str,
372372
getter: typing.Callable[[typing.Any], typing.Any] | None,
373373
setter: typing.Callable[[typing.Any, typing.Any], None] | None,
374374
frozen: bool,
375-
) -> None:
375+
) -> property:
376376
def fget(this: typing.Any, _name: str = name) -> typing.Any:
377377
return getter(this) # type: ignore[misc]
378378

@@ -383,12 +383,21 @@ def fset(this: typing.Any, value: typing.Any, _name: str = name) -> None:
383383
fget.__module__ = fset.__module__ = cls.__module__
384384
fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}" # type: ignore[attr-defined]
385385
fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`" # type: ignore[attr-defined]
386-
prop = property(
386+
return property(
387387
fget=fget if getter else None,
388388
fset=fset if (not frozen) and setter else None,
389389
doc=f"{cls.__module__}.{cls.__qualname__}.{name}",
390390
)
391-
setattr(cls, name, prop)
391+
392+
393+
def attach_field(
394+
cls: type,
395+
name: str,
396+
getter: typing.Callable[[typing.Any], typing.Any] | None,
397+
setter: typing.Callable[[typing.Any, typing.Any], None] | None,
398+
frozen: bool,
399+
) -> None:
400+
setattr(cls, name, make_field(cls, name, getter, setter, frozen)) # type: ignore[call-arg]
392401

393402

394403
def attach_method(

python/mlc/dataclasses/c_class.py

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,68 @@
1-
import functools
21
import typing
32
import warnings
43
from collections.abc import Callable
54

5+
try:
6+
from typing import dataclass_transform
7+
except ImportError:
8+
from typing_extensions import dataclass_transform
9+
610
from mlc._cython import (
711
TypeInfo,
812
TypeMethod,
9-
attach_field,
10-
attach_method,
1113
type_index2type_methods,
1214
type_key2py_type_info,
1315
)
1416
from mlc.core import typing as mlc_typing
1517

16-
from .utils import (
17-
add_vtable_methods_for_type_cls,
18-
get_parent_type,
19-
inspect_dataclass_fields,
20-
method_init,
21-
prototype,
22-
)
18+
from . import utils
2319

24-
ClsType = typing.TypeVar("ClsType")
20+
InputClsType = typing.TypeVar("InputClsType")
2521

2622

23+
@dataclass_transform(field_specifiers=(utils.field, utils.Field))
2724
def c_class(
2825
type_key: str,
2926
init: bool = True,
30-
) -> Callable[[type[ClsType]], type[ClsType]]:
31-
def decorator(super_type_cls: type[ClsType]) -> type[ClsType]:
32-
@functools.wraps(super_type_cls, updated=())
33-
class type_cls(super_type_cls): # type: ignore[valid-type,misc]
34-
__slots__ = ()
35-
27+
) -> Callable[[type[InputClsType]], type[InputClsType]]:
28+
def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]:
3629
# Step 1. Retrieve `type_info` from registry
30+
parent_type_info: TypeInfo = utils.get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined]
3731
type_info: TypeInfo = type_key2py_type_info(type_key)
38-
parent_type_info: TypeInfo = get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined]
39-
4032
if type_info.type_cls is not None:
4133
raise ValueError(f"Type is already registered: {type_key}")
42-
_, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info, frozen=False)
43-
type_info.type_cls = type_cls
44-
type_info.d_fields = tuple(d_fields)
4534

46-
# Step 2. Check if all fields are exposed as type annotations
35+
# Step 2. Reflect all the fields of the type
36+
_, d_fields, _ = utils.inspect_dataclass_fields(
37+
super_type_cls,
38+
parent_type_info,
39+
frozen=False,
40+
)
41+
type_info.d_fields = tuple(d_fields)
42+
# Check if all fields are exposed as type annotations
4743
_check_c_class(super_type_cls, type_info)
4844

49-
# Step 3. Attach fields
50-
setattr(type_cls, "_mlc_type_info", type_info)
51-
for field in type_info.fields:
52-
attach_field(
53-
cls=type_cls,
54-
name=field.name,
55-
getter=field.getter,
56-
setter=field.setter,
57-
frozen=field.frozen,
58-
)
59-
6045
# Step 4. Attach methods
46+
fn_init: Callable[..., None] | None = None
6147
if init:
62-
attach_method(
63-
parent_cls=super_type_cls,
64-
cls=type_cls,
65-
name="__init__",
66-
method=method_init(super_type_cls, d_fields),
67-
check_exists=True,
68-
)
69-
add_vtable_methods_for_type_cls(super_type_cls, type_index=type_info.type_index)
48+
fn_init = utils.method_init(super_type_cls, d_fields)
49+
else:
50+
fn_init = None
51+
# Step 5. Create the proxy class with the fields as properties
52+
type_cls: type[InputClsType] = utils.create_type_class(
53+
cls=super_type_cls,
54+
type_info=type_info,
55+
methods={
56+
"__init__": fn_init,
57+
},
58+
)
7059
return type_cls
7160

7261
return decorator
7362

7463

7564
def _check_c_class(
76-
type_cls: type[ClsType],
65+
type_cls: type[InputClsType],
7766
type_info: TypeInfo,
7867
) -> None:
7968
type_hints = typing.get_type_hints(type_cls)
@@ -117,5 +106,5 @@ def _check_c_class(
117106
if warned:
118107
warnings.warn(
119108
f"One or multiple warnings in `{type_cls.__module__}.{type_cls.__qualname__}`. Its prototype is:\n"
120-
+ prototype(type_info, lang="py")
109+
+ utils.prototype(type_info, lang="py")
121110
)

python/mlc/dataclasses/py_class.py

Lines changed: 44 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,26 @@
55
except ImportError:
66
from typing_extensions import dataclass_transform
77

8-
import ctypes
9-
import functools
108
import typing
119
from collections.abc import Callable
1210

1311
from mlc._cython import (
14-
MLCHeader,
1512
TypeField,
1613
TypeInfo,
17-
attach_field,
18-
attach_method,
1914
make_mlc_init,
2015
type_add_method,
2116
type_create,
2217
type_create_instance,
23-
type_field_get_accessor,
2418
type_register_fields,
2519
type_register_structure,
2620
)
27-
from mlc.core import Object
28-
29-
from .utils import Field as _Field
30-
from .utils import (
31-
Structure,
32-
add_vtable_methods_for_type_cls,
33-
get_parent_type,
34-
inspect_dataclass_fields,
35-
method_init,
36-
structure_parse,
37-
structure_to_c,
38-
)
39-
from .utils import field as _field
21+
22+
from . import utils
4023

4124
InputClsType = typing.TypeVar("InputClsType")
4225

4326

44-
@dataclass_transform(field_specifiers=(_field, _Field))
27+
@dataclass_transform(field_specifiers=(utils.field, utils.Field))
4528
def py_class(
4629
type_key: str | type | None = None,
4730
*,
@@ -69,56 +52,36 @@ def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]:
6952
type_key = f"{super_type_cls.__module__}.{super_type_cls.__qualname__}"
7053
assert isinstance(type_key, str)
7154

72-
# Step 1. Create the type according to its parent type
73-
parent_type_info: TypeInfo = get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined]
55+
# Step 1. Create `type_info`
56+
parent_type_info: TypeInfo = utils.get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined]
7457
type_info: TypeInfo = type_create(parent_type_info.type_index, type_key)
7558
type_index = type_info.type_index
7659

7760
# Step 2. Reflect all the fields of the type
78-
fields, d_fields = inspect_dataclass_fields(
79-
type_key,
61+
fields, d_fields, num_bytes = utils.inspect_dataclass_fields(
8062
super_type_cls,
8163
parent_type_info,
8264
frozen=frozen,
65+
py_mode=True,
8366
)
84-
num_bytes = _add_field_properties(fields)
8567
type_info.fields = tuple(fields)
8668
type_info.d_fields = tuple(d_fields)
8769
type_register_fields(type_index, fields)
88-
mlc_init = make_mlc_init(fields)
89-
90-
# Step 3. Create the proxy class with the fields as properties
91-
type_cls: type[InputClsType] = _create_cls(
92-
cls=super_type_cls,
93-
mlc_init=mlc_init,
94-
mlc_new=lambda cls, *args, **kwargs: type_create_instance(cls, type_index, num_bytes),
95-
)
9670

97-
type_info.type_cls = type_cls
98-
setattr(type_cls, "_mlc_type_info", type_info)
99-
for field in fields:
100-
attach_field(
101-
type_cls,
102-
name=field.name,
103-
getter=field.getter,
104-
setter=field.setter,
105-
frozen=field.frozen,
106-
)
107-
108-
# Step 4. Register the structure of the class
109-
struct: Structure
71+
# Step 3. Register the structure of the class
72+
struct: utils.Structure
11073
struct_kind: int
11174
sub_structure_indices: list[int]
11275
sub_structure_kinds: list[int]
11376
if (struct := vars(super_type_cls).get("_mlc_structure", None)) is not None:
114-
assert isinstance(struct, Structure)
77+
assert isinstance(struct, utils.Structure)
11578
else:
116-
struct = structure_parse(structure, d_fields)
79+
struct = utils.structure_parse(structure, d_fields)
11780
(
11881
struct_kind,
11982
sub_structure_indices,
12083
sub_structure_kinds,
121-
) = structure_to_c(struct, fields)
84+
) = utils.structure_to_c(struct, fields)
12285
if struct.kind is None:
12386
assert struct_kind == 0
12487
assert not sub_structure_indices
@@ -129,45 +92,46 @@ def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]:
12992
sub_structure_indices=tuple(sub_structure_indices),
13093
sub_structure_kinds=tuple(sub_structure_kinds),
13194
)
132-
setattr(type_cls, "_mlc_structure", struct)
13395

134-
# Step 5. Add `__init__` method
135-
type_add_method(type_index, "__init__", _method_new(type_cls), 1) # static
136-
# Step 6. Attach methods
137-
fn: Callable[..., typing.Any]
96+
# Step 4. Attach methods
97+
# Step 4.1. Method `__init__`
98+
fn_init: Callable[..., None] | None = None
13899
if init:
139-
fn = method_init(super_type_cls, d_fields)
140-
attach_method(super_type_cls, type_cls, "__init__", fn, check_exists=True)
100+
fn_init = utils.method_init(super_type_cls, d_fields)
101+
else:
102+
fn_init = None
103+
# Step 4.2. Method `__repr__` and `__str__`
104+
fn_repr: Callable[[InputClsType], str] | None = None
141105
if repr:
142-
fn = _method_repr(type_key, fields)
143-
type_add_method(type_index, "__str__", fn, 1) # static
144-
attach_method(super_type_cls, type_cls, "__repr__", fn, check_exists=True)
145-
attach_method(super_type_cls, type_cls, "__str__", fn, check_exists=True)
146-
elif (fn := vars(super_type_cls).get("__str__", None)) is not None:
147-
assert callable(fn)
148-
type_add_method(type_index, "__str__", fn, 1)
149-
add_vtable_methods_for_type_cls(super_type_cls, type_index=type_index)
106+
fn_repr = _method_repr(type_key, fields)
107+
type_add_method(type_index, "__str__", fn_repr, 1) # static
108+
elif (fn_repr := vars(super_type_cls).get("__str__", None)) is not None:
109+
assert callable(fn_repr)
110+
type_add_method(type_index, "__str__", fn_repr, 1)
111+
else:
112+
fn_repr = None
113+
114+
# Step 5. Create the proxy class with the fields as properties
115+
type_cls: type[InputClsType] = utils.create_type_class(
116+
cls=super_type_cls,
117+
type_info=type_info,
118+
methods={
119+
"_mlc_init": make_mlc_init(fields),
120+
"__new__": lambda cls, *args, **kwargs: type_create_instance(
121+
cls, type_index, num_bytes
122+
),
123+
"__init__": fn_init,
124+
"__repr__": fn_repr,
125+
"__str__": fn_repr,
126+
},
127+
)
128+
type_add_method(type_index, "__init__", _method_new(type_cls), 1) # static
129+
setattr(type_cls, "_mlc_structure", struct)
150130
return type_cls
151131

152132
return decorator
153133

154134

155-
def _add_field_properties(type_fields: list[TypeField]) -> int:
156-
c_fields = [("_mlc_header", MLCHeader)]
157-
for type_field in type_fields:
158-
field_name = type_field.name
159-
field_ty_c = type_field.ty._ctype()
160-
c_fields.append((field_name, field_ty_c))
161-
162-
class CType(ctypes.Structure):
163-
_fields_ = c_fields
164-
165-
for field in type_fields:
166-
field.offset = getattr(CType, field.name).offset
167-
field.getter, field.setter = type_field_get_accessor(field)
168-
return ctypes.sizeof(CType)
169-
170-
171135
def _method_repr(
172136
type_key: str,
173137
fields: list[TypeField],
@@ -188,32 +152,3 @@ def method(*args: typing.Any) -> InputClsType:
188152
return obj
189153

190154
return method
191-
192-
193-
def _create_cls(
194-
cls: type,
195-
mlc_init: Callable[..., None],
196-
mlc_new: Callable[..., None],
197-
) -> type[InputClsType]:
198-
cls_name = cls.__name__
199-
cls_bases = cls.__bases__
200-
attrs = dict(cls.__dict__)
201-
if cls_bases == (object,):
202-
cls_bases = (Object,)
203-
204-
def _add_method(fn: Callable, fn_name: str) -> None:
205-
attrs[fn_name] = fn
206-
fn.__module__ = cls.__module__
207-
fn.__name__ = fn_name
208-
fn.__qualname__ = f"{cls_name}.{fn_name}"
209-
210-
attrs["__slots__"] = ()
211-
attrs.pop("__dict__", None)
212-
attrs.pop("__weakref__", None)
213-
_add_method(mlc_init, "_mlc_init")
214-
_add_method(mlc_new, "__new__")
215-
216-
new_cls = type(cls_name, cls_bases, attrs)
217-
new_cls.__module__ = cls.__module__
218-
new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore
219-
return new_cls

0 commit comments

Comments
 (0)