55except ImportError :
66 from typing_extensions import dataclass_transform
77
8- import ctypes
9- import functools
108import typing
119from collections .abc import Callable
1210
1311from mlc ._cython import (
14- MLCHeader ,
1512 TypeField ,
1613 TypeInfo ,
17- attach_field ,
18- attach_method ,
1914 make_mlc_init ,
2015 type_add_method ,
2116 type_create ,
2217 type_create_instance ,
23- type_field_get_accessor ,
2418 type_register_fields ,
2519 type_register_structure ,
2620)
27- from mlc .core import Object
28-
29- from .utils import Field as _Field
30- from .utils import (
31- Structure ,
32- add_vtable_methods_for_type_cls ,
33- get_parent_type ,
34- inspect_dataclass_fields ,
35- method_init ,
36- structure_parse ,
37- structure_to_c ,
38- )
39- from .utils import field as _field
21+
22+ from . import utils
4023
4124InputClsType = typing .TypeVar ("InputClsType" )
4225
4326
44- @dataclass_transform (field_specifiers = (_field , _Field ))
27+ @dataclass_transform (field_specifiers = (utils . field , utils . Field ))
4528def py_class (
4629 type_key : str | type | None = None ,
4730 * ,
@@ -69,56 +52,36 @@ def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]:
6952 type_key = f"{ super_type_cls .__module__ } .{ super_type_cls .__qualname__ } "
7053 assert isinstance (type_key , str )
7154
72- # Step 1. Create the type according to its parent type
73- parent_type_info : TypeInfo = get_parent_type (super_type_cls )._mlc_type_info # type: ignore[attr-defined]
55+ # Step 1. Create `type_info`
56+ parent_type_info : TypeInfo = utils . get_parent_type (super_type_cls )._mlc_type_info # type: ignore[attr-defined]
7457 type_info : TypeInfo = type_create (parent_type_info .type_index , type_key )
7558 type_index = type_info .type_index
7659
7760 # Step 2. Reflect all the fields of the type
78- fields , d_fields = inspect_dataclass_fields (
79- type_key ,
61+ fields , d_fields , num_bytes = utils .inspect_dataclass_fields (
8062 super_type_cls ,
8163 parent_type_info ,
8264 frozen = frozen ,
65+ py_mode = True ,
8366 )
84- num_bytes = _add_field_properties (fields )
8567 type_info .fields = tuple (fields )
8668 type_info .d_fields = tuple (d_fields )
8769 type_register_fields (type_index , fields )
88- mlc_init = make_mlc_init (fields )
89-
90- # Step 3. Create the proxy class with the fields as properties
91- type_cls : type [InputClsType ] = _create_cls (
92- cls = super_type_cls ,
93- mlc_init = mlc_init ,
94- mlc_new = lambda cls , * args , ** kwargs : type_create_instance (cls , type_index , num_bytes ),
95- )
9670
97- type_info .type_cls = type_cls
98- setattr (type_cls , "_mlc_type_info" , type_info )
99- for field in fields :
100- attach_field (
101- type_cls ,
102- name = field .name ,
103- getter = field .getter ,
104- setter = field .setter ,
105- frozen = field .frozen ,
106- )
107-
108- # Step 4. Register the structure of the class
109- struct : Structure
71+ # Step 3. Register the structure of the class
72+ struct : utils .Structure
11073 struct_kind : int
11174 sub_structure_indices : list [int ]
11275 sub_structure_kinds : list [int ]
11376 if (struct := vars (super_type_cls ).get ("_mlc_structure" , None )) is not None :
114- assert isinstance (struct , Structure )
77+ assert isinstance (struct , utils . Structure )
11578 else :
116- struct = structure_parse (structure , d_fields )
79+ struct = utils . structure_parse (structure , d_fields )
11780 (
11881 struct_kind ,
11982 sub_structure_indices ,
12083 sub_structure_kinds ,
121- ) = structure_to_c (struct , fields )
84+ ) = utils . structure_to_c (struct , fields )
12285 if struct .kind is None :
12386 assert struct_kind == 0
12487 assert not sub_structure_indices
@@ -129,45 +92,46 @@ def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]:
12992 sub_structure_indices = tuple (sub_structure_indices ),
13093 sub_structure_kinds = tuple (sub_structure_kinds ),
13194 )
132- setattr (type_cls , "_mlc_structure" , struct )
13395
134- # Step 5. Add `__init__` method
135- type_add_method (type_index , "__init__" , _method_new (type_cls ), 1 ) # static
136- # Step 6. Attach methods
137- fn : Callable [..., typing .Any ]
96+ # Step 4. Attach methods
97+ # Step 4.1. Method `__init__`
98+ fn_init : Callable [..., None ] | None = None
13899 if init :
139- fn = method_init (super_type_cls , d_fields )
140- attach_method (super_type_cls , type_cls , "__init__" , fn , check_exists = True )
100+ fn_init = utils .method_init (super_type_cls , d_fields )
101+ else :
102+ fn_init = None
103+ # Step 4.2. Method `__repr__` and `__str__`
104+ fn_repr : Callable [[InputClsType ], str ] | None = None
141105 if repr :
142- fn = _method_repr (type_key , fields )
143- type_add_method (type_index , "__str__" , fn , 1 ) # static
144- attach_method (super_type_cls , type_cls , "__repr__" , fn , check_exists = True )
145- attach_method (super_type_cls , type_cls , "__str__" , fn , check_exists = True )
146- elif (fn := vars (super_type_cls ).get ("__str__" , None )) is not None :
147- assert callable (fn )
148- type_add_method (type_index , "__str__" , fn , 1 )
149- add_vtable_methods_for_type_cls (super_type_cls , type_index = type_index )
106+ fn_repr = _method_repr (type_key , fields )
107+ type_add_method (type_index , "__str__" , fn_repr , 1 ) # static
108+ elif (fn_repr := vars (super_type_cls ).get ("__str__" , None )) is not None :
109+ assert callable (fn_repr )
110+ type_add_method (type_index , "__str__" , fn_repr , 1 )
111+ else :
112+ fn_repr = None
113+
114+ # Step 5. Create the proxy class with the fields as properties
115+ type_cls : type [InputClsType ] = utils .create_type_class (
116+ cls = super_type_cls ,
117+ type_info = type_info ,
118+ methods = {
119+ "_mlc_init" : make_mlc_init (fields ),
120+ "__new__" : lambda cls , * args , ** kwargs : type_create_instance (
121+ cls , type_index , num_bytes
122+ ),
123+ "__init__" : fn_init ,
124+ "__repr__" : fn_repr ,
125+ "__str__" : fn_repr ,
126+ },
127+ )
128+ type_add_method (type_index , "__init__" , _method_new (type_cls ), 1 ) # static
129+ setattr (type_cls , "_mlc_structure" , struct )
150130 return type_cls
151131
152132 return decorator
153133
154134
155- def _add_field_properties (type_fields : list [TypeField ]) -> int :
156- c_fields = [("_mlc_header" , MLCHeader )]
157- for type_field in type_fields :
158- field_name = type_field .name
159- field_ty_c = type_field .ty ._ctype ()
160- c_fields .append ((field_name , field_ty_c ))
161-
162- class CType (ctypes .Structure ):
163- _fields_ = c_fields
164-
165- for field in type_fields :
166- field .offset = getattr (CType , field .name ).offset
167- field .getter , field .setter = type_field_get_accessor (field )
168- return ctypes .sizeof (CType )
169-
170-
171135def _method_repr (
172136 type_key : str ,
173137 fields : list [TypeField ],
@@ -188,32 +152,3 @@ def method(*args: typing.Any) -> InputClsType:
188152 return obj
189153
190154 return method
191-
192-
193- def _create_cls (
194- cls : type ,
195- mlc_init : Callable [..., None ],
196- mlc_new : Callable [..., None ],
197- ) -> type [InputClsType ]:
198- cls_name = cls .__name__
199- cls_bases = cls .__bases__
200- attrs = dict (cls .__dict__ )
201- if cls_bases == (object ,):
202- cls_bases = (Object ,)
203-
204- def _add_method (fn : Callable , fn_name : str ) -> None :
205- attrs [fn_name ] = fn
206- fn .__module__ = cls .__module__
207- fn .__name__ = fn_name
208- fn .__qualname__ = f"{ cls_name } .{ fn_name } "
209-
210- attrs ["__slots__" ] = ()
211- attrs .pop ("__dict__" , None )
212- attrs .pop ("__weakref__" , None )
213- _add_method (mlc_init , "_mlc_init" )
214- _add_method (mlc_new , "__new__" )
215-
216- new_cls = type (cls_name , cls_bases , attrs )
217- new_cls .__module__ = cls .__module__
218- new_cls = functools .wraps (cls , updated = ())(new_cls ) # type: ignore
219- return new_cls
0 commit comments