Skip to content

Commit ef4fa22

Browse files
Update docs
1 parent ec0d7b6 commit ef4fa22

File tree

199 files changed

+3582
-173
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

199 files changed

+3582
-173
lines changed

_sources/autoapi/tilelang/jit/index.rst.txt

Lines changed: 163 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Attributes
3030
.. autoapisummary::
3131

3232
tilelang.jit.logger
33+
tilelang.jit.ExecutionBackend
3334

3435

3536
Classes
@@ -48,6 +49,7 @@ Functions
4849
tilelang.jit.compile
4950
tilelang.jit.par_compile
5051
tilelang.jit.jit
52+
tilelang.jit.lazy_jit
5153

5254

5355
Package Contents
@@ -99,31 +101,129 @@ Package Contents
99101

100102
.. py:class:: JITImpl
101103
102-
Bases: :py:obj:`Generic`\ [\ :py:obj:`_P`\ , :py:obj:`_KP`\ , :py:obj:`_T`\ ]
104+
Bases: :py:obj:`Generic`\ [\ :py:obj:`_P`\ , :py:obj:`_KP`\ , :py:obj:`_T`\ , :py:obj:`_Ret`\ ]
103105

104106

105-
Abstract base class for generic types.
107+
Detailed Just-In-Time wrapper for TileLang programs.
106108

107-
A generic type is typically declared by inheriting from
108-
this class parameterized with one or more type variables.
109-
For example, a generic mapping type might be defined as::
109+
This dataclass encapsulates the configuration and runtime helpers used by the
110+
top-level `jit` and `jit2` decorators. It represents a configured JIT
111+
"factory" that can (a) elaborate TileLang/PrimFunc creators into concrete
112+
TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via
113+
the TVM bridge, (c) cache compiled kernels keyed by call-site arguments
114+
(and optional tuning parameters), and (d) provide parallel compilation
115+
helpers for batch autotuning workflows.
110116

111-
class Mapping(Generic[KT, VT]):
112-
def __getitem__(self, key: KT) -> VT:
113-
...
114-
# Etc.
117+
.. attribute:: out_idx
115118

116-
This class can then be used as follows::
119+
Which output tensor(s) of the compiled kernel should be returned to the
120+
caller. Accepts a single index, a list of indices, or None to return all.
117121

118-
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
119-
try:
120-
return mapping[key]
121-
except KeyError:
122-
return default
122+
:type: list[int] | int | None
123123

124+
.. attribute:: execution_backend
124125

125-
.. py:attribute:: func
126-
:type: Callable[_P, _T] | tilelang.language.v2.PrimFunc[_KP, _T]
126+
Backend used for exchanging arguments and executing the generated kernel.
127+
128+
:type: Literal["dlpack", "ctypes", "cython"]
129+
130+
.. attribute:: target
131+
132+
TVM compilation target (e.g. "cuda", "llvm", or "auto").
133+
134+
:type: str | tvm.target.Target
135+
136+
.. attribute:: target_host
137+
138+
Host target used for cross-compilation, or None to infer/default.
139+
140+
:type: str | tvm.target.Target | None
141+
142+
.. attribute:: verbose
143+
144+
Enable verbose messages during compilation/build.
145+
146+
:type: bool
147+
148+
.. attribute:: pass_configs
149+
150+
Extra TVM pass configuration options forwarded to the compiler's
151+
PassContext.
152+
153+
:type: dict[str, Any] | None
154+
155+
.. attribute:: debug_root_path
156+
157+
If provided, compiled kernel source and the elaborated Python program
158+
are written to this directory to ease debugging and inspection.
159+
160+
:type: str | None
161+
162+
.. attribute:: compile_flags
163+
164+
Additional flags passed to the compiler. A single string will be converted
165+
to a single-element list.
166+
167+
:type: list[str] | str | None
168+
169+
.. attribute:: func_source
170+
171+
Original Python source string from which the PrimFunc or creator was
172+
derived. Used for diagnostics and debug dumps.
173+
174+
:type: str
175+
176+
.. attribute:: signature
177+
178+
Function signature of the original Python function (useful for tooling).
179+
180+
:type: inspect.Signature
181+
182+
.. attribute:: v2
183+
184+
Indicates whether the object wraps a "v2" PrimFunc creator (True) or a
185+
plain callable / PrimFunc (False). v2-mode enables argument conversion
186+
hooks and a distinct cache keying strategy.
187+
188+
:type: bool
189+
190+
.. attribute:: func
191+
192+
The underlying object: either a user function that returns a PrimFunc
193+
(creator), a PrimFuncCreater, or an already-constructed PrimFunc.
194+
For presentation/readability the function is stored last in the dataclass.
195+
196+
:type: Callable | PrimFunc | PrimFuncCreater
197+
198+
.. attribute:: Behavioral summary
199+
200+
201+
202+
.. attribute:: ------------------
203+
204+
205+
206+
.. attribute:: - get_tir(*args, \*\*kwargs)
207+
208+
Converts provided call-site arguments into a concrete PrimFunc. If the
209+
wrapped object is a PrimFuncCreater or a user callable, it is invoked
210+
with the given arguments. If the wrapped object is already a PrimFunc,
211+
it is returned as-is.
212+
213+
.. attribute:: - compile(...)
214+
215+
A convenience wrapper that elaborates and immediately compiles a single
216+
PrimFunc into a JITKernel using the module-level `compile` function.
217+
When `debug_root_path` is set, the compiled C kernel and the source
218+
Python program are saved for inspection.
219+
220+
.. attribute:: - par_compile(configs, ...)
221+
222+
Accepts an iterable of configs (either dicts mapping keyword args or
223+
tuples mapping to positional args). Each config is elaborated to a
224+
PrimFunc and the resulting set is compiled in parallel via the
225+
module-level `par_compile` helper. Returns a list of JITKernel objects
226+
in the same order as the provided configs.
127227

128228

129229
.. py:attribute:: out_idx
@@ -166,23 +266,64 @@ Package Contents
166266
:type: inspect.Signature
167267

168268

269+
.. py:attribute:: lazy_jit
270+
:type: bool
271+
272+
273+
.. py:attribute:: func
274+
:type: Callable[_P, _T] | tilelang.language.v2.PrimFunc[_KP, _T]
275+
276+
277+
.. py:property:: annot
278+
:type: dict[str, tilelang.language.v2.annot.Annot]
279+
280+
281+
169282
.. py:method:: __post_init__()
170283
171284
172285
.. py:method:: get_tir(*args, **kwargs)
173286
287+
Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object.
288+
289+
174290

175291
.. py:method:: par_compile(configs, num_workers = None, ignore_error = False)
176292
293+
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
294+
:param configs: The configurations to elaborate and compile. Each config can be either
295+
a dictionary mapping keyword arguments to values, or a tuple of positional
296+
arguments.
297+
:type configs: Iterable[Union[dict[str, Any], tuple[Any, ...]]]
298+
:param num_workers: Number of parallel workers to use for compilation. Defaults to None,
299+
which lets the system decide.
300+
:type num_workers: int, optional
301+
:param ignore_error: If True, compilation errors for individual configs will be logged
302+
as warnings and the corresponding result will be None. If False,
303+
any compilation error will raise an exception. Defaults to False.
304+
:type ignore_error: bool, optional
305+
306+
:returns: A list of compiled JITKernel objects corresponding to the provided configs.
307+
:rtype: List[JITKernel]
308+
309+
177310

178311
.. py:method:: compile(*args, **kwargs)
179312
180313
314+
.. py:method:: parse_cache_key(*args, **kwargs)
315+
316+
317+
.. py:method:: convert_kernel_args(*args, **kwargs)
318+
319+
181320
.. py:method:: __call__(*args, **kwargs)
182321
183322
184-
.. py:function:: jit(func: Callable[_P, tilelang.language.v2.PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T]
185-
jit(*, out_idx: Any = None, target: str | tvm.target.Target = 'auto', target_host: str | tvm.target.Target = None, execution_backend: Literal['auto', 'dlpack', 'tvm_ffi', 'ctypes', 'cython', 'nvrtc', 'torch'] = 'auto', verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None) -> Callable[[Callable[_P, tilelang.language.v2.PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T]]
323+
.. py:data:: ExecutionBackend
324+
325+
.. py:function:: jit(func: Callable[_P, tilelang.language.v2.PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, kernel.JITKernel[_KP, _T]]
326+
jit(*, out_idx: Any = None, target: str | tvm.target.Target = 'auto', target_host: str | tvm.target.Target = None, execution_backend: ExecutionBackend = 'auto', verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None) -> Callable[[Callable[_P, tilelang.language.v2.PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, kernel.JITKernel[_KP, _T]]]
186327
187328
Just-In-Time (JIT) compiler decorator for TileLang functions.
188329

@@ -212,3 +353,6 @@ Package Contents
212353
:rtype: Callable
213354

214355

356+
.. py:function:: lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]
357+
lazy_jit(*, out_idx: Any = None, target: str | tvm.target.Target = 'auto', target_host: str | tvm.target.Target = None, execution_backend: ExecutionBackend = 'auto', verbose: bool = False, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]
358+

_sources/autoapi/tilelang/language/allocate/index.rst.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Functions
4747
tilelang.language.allocate.alloc_tcgen05_smem_desc
4848
tilelang.language.allocate.alloc_tcgen05_instruction_desc
4949
tilelang.language.allocate.alloc_tcgen05_instr_desc
50+
tilelang.language.allocate.empty
5051

5152

5253
Module Contents
@@ -213,3 +214,5 @@ Module Contents
213214
214215
.. py:function:: alloc_tcgen05_instr_desc(dtype = 'uint32')
215216
217+
.. py:function:: empty(shape, dtype = 'float32')
218+

0 commit comments

Comments
 (0)