@@ -30,6 +30,7 @@ Attributes
3030.. autoapisummary ::
3131
3232 tilelang.jit.logger
33+ tilelang.jit.ExecutionBackend
3334
3435
3536Classes
@@ -48,6 +49,7 @@ Functions
4849 tilelang.jit.compile
4950 tilelang.jit.par_compile
5051 tilelang.jit.jit
52+ tilelang.jit.lazy_jit
5153
5254
5355Package Contents
@@ -99,31 +101,129 @@ Package Contents
99101
100102.. py :class :: JITImpl
101103
102- Bases: :py:obj: `Generic `\ [\ :py:obj: `_P `\ , :py:obj: `_KP `\ , :py:obj: `_T `\ ]
104+ Bases: :py:obj: `Generic `\ [\ :py:obj: `_P `\ , :py:obj: `_KP `\ , :py:obj: `_T `\ , :py:obj: ` _Ret ` \ ]
103105
104106
105- Abstract base class for generic types .
107+ Detailed Just-In-Time wrapper for TileLang programs .
106108
107- A generic type is typically declared by inheriting from
108- this class parameterized with one or more type variables.
109- For example, a generic mapping type might be defined as::
109+ This dataclass encapsulates the configuration and runtime helpers used by the
110+ top-level `jit ` and `jit2 ` decorators. It represents a configured JIT
111+ "factory" that can (a) elaborate TileLang/PrimFunc creators into concrete
112+ TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via
113+ the TVM bridge, (c) cache compiled kernels keyed by call-site arguments
114+ (and optional tuning parameters), and (d) provide parallel compilation
115+ helpers for batch autotuning workflows.
110116
111- class Mapping(Generic[KT, VT]):
112- def __getitem__(self, key: KT) -> VT:
113- ...
114- # Etc.
117+ .. attribute :: out_idx
115118
116- This class can then be used as follows::
119+ Which output tensor(s) of the compiled kernel should be returned to the
120+ caller. Accepts a single index, a list of indices, or None to return all.
117121
118- def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
119- try:
120- return mapping[key]
121- except KeyError:
122- return default
122+ :type: list[int] | int | None
123123
124+ .. attribute :: execution_backend
124125
125- .. py :attribute :: func
126- :type: Callable[_P, _T] | tilelang.language.v2.PrimFunc[_KP, _T]
126+ Backend used for exchanging arguments and executing the generated kernel.
127+
128+ :type: Literal["dlpack", "ctypes", "cython"]
129+
130+ .. attribute :: target
131+
132+ TVM compilation target (e.g. "cuda", "llvm", or "auto").
133+
134+ :type: str | tvm.target.Target
135+
136+ .. attribute :: target_host
137+
138+ Host target used for cross-compilation, or None to infer/default.
139+
140+ :type: str | tvm.target.Target | None
141+
142+ .. attribute :: verbose
143+
144+ Enable verbose messages during compilation/build.
145+
146+ :type: bool
147+
148+ .. attribute :: pass_configs
149+
150+ Extra TVM pass configuration options forwarded to the compiler's
151+ PassContext.
152+
153+ :type: dict[str, Any] | None
154+
155+ .. attribute :: debug_root_path
156+
157+ If provided, compiled kernel source and the elaborated Python program
158+ are written to this directory to ease debugging and inspection.
159+
160+ :type: str | None
161+
162+ .. attribute :: compile_flags
163+
164+ Additional flags passed to the compiler. A single string will be converted
165+ to a single-element list.
166+
167+ :type: list[str] | str | None
168+
169+ .. attribute :: func_source
170+
171+ Original Python source string from which the PrimFunc or creator was
172+ derived. Used for diagnostics and debug dumps.
173+
174+ :type: str
175+
176+ .. attribute :: signature
177+
178+ Function signature of the original Python function (useful for tooling).
179+
180+ :type: inspect.Signature
181+
182+ .. attribute :: v2
183+
184+ Indicates whether the object wraps a "v2" PrimFunc creator (True) or a
185+ plain callable / PrimFunc (False). v2-mode enables argument conversion
186+ hooks and a distinct cache keying strategy.
187+
188+ :type: bool
189+
190+ .. attribute :: func
191+
192+ The underlying object: either a user function that returns a PrimFunc
193+ (creator), a PrimFuncCreater, or an already-constructed PrimFunc.
194+ For presentation/readability the function is stored last in the dataclass.
195+
196+ :type: Callable | PrimFunc | PrimFuncCreater
197+
198+ .. attribute :: Behavioral summary
199+
200+
201+
202+ .. attribute :: ------------------
203+
204+
205+
206+ .. attribute :: - get_tir(*args, \*\*kwargs)
207+
208+ Converts provided call-site arguments into a concrete PrimFunc. If the
209+ wrapped object is a PrimFuncCreater or a user callable, it is invoked
210+ with the given arguments. If the wrapped object is already a PrimFunc,
211+ it is returned as-is.
212+
213+ .. attribute :: - compile(...)
214+
215+ A convenience wrapper that elaborates and immediately compiles a single
216+ PrimFunc into a JITKernel using the module-level `compile ` function.
217+ When `debug_root_path ` is set, the compiled C kernel and the source
218+ Python program are saved for inspection.
219+
220+ .. attribute :: - par_compile(configs, ...)
221+
222+ Accepts an iterable of configs (either dicts mapping keyword args or
223+ tuples mapping to positional args). Each config is elaborated to a
224+ PrimFunc and the resulting set is compiled in parallel via the
225+ module-level `par_compile ` helper. Returns a list of JITKernel objects
226+ in the same order as the provided configs.
127227
128228
129229 .. py :attribute :: out_idx
@@ -166,23 +266,64 @@ Package Contents
166266 :type: inspect.Signature
167267
168268
269+ .. py :attribute :: lazy_jit
270+ :type: bool
271+
272+
273+ .. py :attribute :: func
274+ :type: Callable[_P, _T] | tilelang.language.v2.PrimFunc[_KP, _T]
275+
276+
277+ .. py :property :: annot
278+ :type: dict[str, tilelang.language.v2.annot.Annot]
279+
280+
281+
169282 .. py :method :: __post_init__ ()
170283
171284
172285 .. py :method :: get_tir(* args, ** kwargs)
173286
287+ Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object.
288+
289+
174290
175291 .. py :method :: par_compile(configs, num_workers = None , ignore_error = False )
176292
293+ Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
294+ :param configs: The configurations to elaborate and compile. Each config can be either
295+ a dictionary mapping keyword arguments to values, or a tuple of positional
296+ arguments.
297+ :type configs: Iterable[Union[dict[str, Any], tuple[Any, ...]]]
298+ :param num_workers: Number of parallel workers to use for compilation. Defaults to None,
299+ which lets the system decide.
300+ :type num_workers: int, optional
301+ :param ignore_error: If True, compilation errors for individual configs will be logged
302+ as warnings and the corresponding result will be None. If False,
303+ any compilation error will raise an exception. Defaults to False.
304+ :type ignore_error: bool, optional
305+
306+ :returns: A list of compiled JITKernel objects corresponding to the provided configs.
307+ :rtype: List[JITKernel]
308+
309+
177310
178311 .. py :method :: compile (* args, ** kwargs)
179312
180313
314+ .. py :method :: parse_cache_key(* args, ** kwargs)
315+
316+
317+ .. py :method :: convert_kernel_args(* args, ** kwargs)
318+
319+
181320 .. py :method :: __call__ (* args, ** kwargs)
182321
183322
184- .. py :function :: jit(func: Callable[_P, tilelang.language.v2.PrimFunc[_KP , _T]]) -> JITImpl[_P, _KP , _T]
185- jit(* , out_idx: Any = None , target: str | tvm.target.Target = ' auto' , target_host: str | tvm.target.Target = None , execution_backend: Literal[' auto' , ' dlpack' , ' tvm_ffi' , ' ctypes' , ' cython' , ' nvrtc' , ' torch' ] = ' auto' , verbose: bool = False , pass_configs: dict[str , Any] | None = None , debug_root_path: str | None = None , compile_flags: list[str ] | str | None = None ) -> Callable[[Callable[_P, tilelang.language.v2.PrimFunc[_KP , _T]]], JITImpl[_P, _KP , _T]]
323+ .. py :data :: ExecutionBackend
324+
325+ .. py :function :: jit(func: Callable[_P, tilelang.language.v2.PrimFunc[_KP , _T]]) -> JITImpl[_P, _KP , _T, kernel.JITKernel[_KP , _T]]
326+ jit(* , out_idx: Any = None , target: str | tvm.target.Target = ' auto' , target_host: str | tvm.target.Target = None , execution_backend: ExecutionBackend = ' auto' , verbose: bool = False , pass_configs: dict[str , Any] | None = None , debug_root_path: str | None = None , compile_flags: list[str ] | str | None = None ) -> Callable[[Callable[_P, tilelang.language.v2.PrimFunc[_KP , _T]]], JITImpl[_P, _KP , _T, kernel.JITKernel[_KP , _T]]]
186327
187328 Just-In-Time (JIT) compiler decorator for TileLang functions.
188329
@@ -212,3 +353,6 @@ Package Contents
212353 :rtype: Callable
213354
214355
356+ .. py :function :: lazy_jit(func: Callable[_KP , _T]) -> JITImpl[_KP , _KP , _T, _T]
357+ lazy_jit(* , out_idx: Any = None , target: str | tvm.target.Target = ' auto' , target_host: str | tvm.target.Target = None , execution_backend: ExecutionBackend = ' auto' , verbose: bool = False , pass_configs: dict[str , Any] | None = None , debug_root_path: str | None = None , compile_flags: list[str ] | str | None = None ) -> Callable[[Callable[_KP , _T]], JITImpl[_KP , _KP , _T, _T]]
358+
0 commit comments