Skip to content

Commit f666ad2

Browse files
authored
[None][feat] Autotuner can iterate through all tactics for test purposes (#8663)
Signed-off-by: Anthony Chang <[email protected]>
1 parent a5cc9fe commit f666ad2

File tree

5 files changed

+541
-69
lines changed

5 files changed

+541
-69
lines changed

cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
*/
1616

1717
#include "tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h"
18+
#include "tensorrt_llm/thop/thUtils.h"
1819

1920
#include <ATen/ATen.h>
2021
#include <ATen/cuda/CUDAContext.h>
2122
#include <ATen/cuda/EmptyTensor.h>
2223
#include <torch/library.h>
2324

2425
#include <cstdint>
26+
#include <memory>
27+
#include <unordered_map>
2528

2629
namespace torch_ext
2730
{
@@ -316,16 +319,30 @@ class FP8BlockScaleMoeRunner : public torch::CustomClassHolder
316319
{
317320

318321
public:
319-
explicit FP8BlockScaleMoeRunner(int64_t tileTokensDim)
320-
: mTileTokensDim(tileTokensDim)
322+
explicit FP8BlockScaleMoeRunner()
323+
: mSupportedTileN{8, 16, 32, 64}
321324
{
322-
mRunner = std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, mTileTokensDim);
325+
for (int tileN : mSupportedTileN)
326+
{
327+
mRunners.emplace(tileN, std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, tileN));
328+
}
323329
}
324330

325-
[[nodiscard]] std::vector<int64_t> getValidConfigs(
331+
[[nodiscard]] std::vector<std::vector<int64_t>> getValidConfigs(
326332
int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const
327333
{
328-
return mRunner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
334+
// returns (tileN, config)
335+
std::vector<std::vector<int64_t>> tactics;
336+
for (auto& [tileN, runner] : mRunners)
337+
{
338+
auto config_indices_per_runner
339+
= runner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
340+
for (auto cfg : config_indices_per_runner)
341+
{
342+
tactics.push_back({tileN, cfg});
343+
}
344+
}
345+
return tactics;
329346
}
330347

331348
[[nodiscard]] at::Tensor run(at::optional<at::Tensor> const& routing_logits,
@@ -334,42 +351,48 @@ class FP8BlockScaleMoeRunner : public torch::CustomClassHolder
334351
at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale, int64_t num_experts, int64_t top_k,
335352
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
336353
int64_t const local_expert_offset, int64_t const local_num_experts,
337-
std::optional<double> const routed_scaling_factor, int64_t routing_method_type, int64_t moeConfigIndex,
338-
std::optional<at::Tensor> const& topk_weights, std::optional<at::Tensor> const& topk_ids)
354+
std::optional<double> const routed_scaling_factor, int64_t routing_method_type,
355+
std::vector<int64_t> tile_config_pair, std::optional<at::Tensor> const& topk_weights,
356+
std::optional<at::Tensor> const& topk_ids)
339357
{
358+
// tile_config_pair corresponds to pair (tileN, config)
359+
auto [tileN, config] = std::tie(tile_config_pair[0], tile_config_pair[1]);
340360

341361
// Autotuner has requested a default or 'fallback' config index
342-
if (moeConfigIndex == -1)
362+
if (tileN == -1 || config == -1)
343363
{
344364
auto const num_tokens = hidden_states.sizes()[0];
345365
auto const hidden_size = hidden_states.sizes()[1];
346366

347-
moeConfigIndex = mRunner->getDefaultValidConfigIndex(
367+
float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / local_num_experts;
368+
tileN = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), mSupportedTileN.front(), mSupportedTileN.back());
369+
370+
config = mRunners.at(tileN)->getDefaultValidConfigIndex(
348371
top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
349372
}
350373

351374
return run_fp8_block_scale_moe(routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
352375
gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group, topk_group,
353-
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, mTileTokensDim,
354-
routing_method_type, *mRunner, moeConfigIndex, topk_weights, topk_ids);
376+
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
377+
routing_method_type, *mRunners.at(tileN), config, topk_weights, topk_ids);
355378
}
356379

357380
private:
358381
using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
359382

360-
std::unique_ptr<RunnerType> mRunner;
383+
std::vector<int32_t> const mSupportedTileN;
384+
std::unordered_map<int32_t, std::unique_ptr<RunnerType>> mRunners;
361385

362386
btg::Dtype mDtypeElt{btg::Dtype::E4m3}; // FP8 runner so hard-coded
363387
bool mUseDeepSeekFp8{true}; // Always true for BlockScaleMoe
364-
int64_t mTileTokensDim;
365388
};
366389

367390
} // namespace torch_ext
368391

369392
TORCH_LIBRARY_FRAGMENT(trtllm, m)
370393
{
371394
m.class_<torch_ext::FP8BlockScaleMoeRunner>("FP8BlockScaleMoERunner")
372-
.def(torch::init<int64_t>())
395+
.def(torch::init<>())
373396
.def("get_valid_configs", &torch_ext::FP8BlockScaleMoeRunner::getValidConfigs)
374397
.def("run_moe", &torch_ext::FP8BlockScaleMoeRunner::run);
375398
}

tensorrt_llm/_torch/autotuner.py

Lines changed: 214 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from dataclasses import dataclass, field
1111
from functools import lru_cache
1212
from pathlib import Path
13-
from typing import Any, Callable, Dict, List, Set, Tuple, Union
13+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
1414

1515
import torch
1616

@@ -127,7 +127,7 @@ def _opt(self):
127127
class OptimizationProfile:
128128
'''Ranges of all tensors, all dimension
129129
'''
130-
shapes: List[List[Dim]]
130+
shapes: List[List[Dim]] = field(default_factory=lambda: [[]])
131131

132132
def get_hash_key(self):
133133
return self.get_opt_shapes()
@@ -536,12 +536,90 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
536536

537537
self.profiling_debug = True
538538

539+
# Current captured choose_one() contexts
540+
self._active_capture: Optional['AutoTuner.TacticsCapture'] = None
541+
# Last captured choose_one() contexts
542+
self._last_capture: Optional['AutoTuner.TacticsCapture'] = None
543+
539544
@classmethod
540545
def get(cls):
541546
if cls._instance is None:
542547
cls._instance = AutoTuner()
543548
return cls._instance
544549

550+
class TacticsCapture:
551+
"""Object returned by capture() that can be iterated to get all tactic combinations.
552+
553+
This class encapsulates all state related to capturing and replaying tactics:
554+
- Captured execution contexts
555+
- Generated tactic configurations
556+
- Current replay state (which config and call index)
557+
"""
558+
559+
def __init__(self, autotuner):
560+
# State for captured contexts
561+
self._captured_contexts: List[Dict[str, Any]] = []
562+
self._configurations = None
563+
# State for replay mode
564+
self._replay_runner_tactic_list: Optional[List[Tuple[int,
565+
int]]] = None
566+
self._replay_context_idx: int = 0
567+
568+
def __iter__(self):
569+
"""Iterate through all tactic configurations.
570+
571+
For single context: yields (runner, tactic)
572+
For multiple contexts: yields ((runner_ctx0, tactic_ctx0), (runner_ctx1, tactic_ctx1), ...)
573+
"""
574+
if self._configurations is None:
575+
self._configurations = self._generate_configurations()
576+
577+
for config in self._configurations:
578+
# config is a tuple of (runner_idx, tactic) for each context
579+
# Convert to (runner, tactic) format for user
580+
runner_tactic_pairs = []
581+
for ctx_idx, (runner_idx, tactic) in enumerate(config):
582+
runners = self._captured_contexts[ctx_idx]['runners']
583+
runner = runners[runner_idx]
584+
runner_tactic_pairs.append((runner, tactic))
585+
586+
yield tuple(runner_tactic_pairs)
587+
588+
def _generate_configurations(self):
589+
"""Generate all valid tactic combinations."""
590+
if not self._captured_contexts:
591+
raise RuntimeError(
592+
"No context available for testing.\n"
593+
"Use capture() to capture the operation context first:\n"
594+
" with AutoTuner.get().capture() as tactics_capture:\n"
595+
" output = operation.forward(...)\n")
596+
597+
# Collect valid tactics for each context separately
598+
context_tactics_lists = []
599+
600+
for context in self._captured_contexts:
601+
runners = context['runners']
602+
inputs = context['inputs']
603+
kwargs = context.get('kwargs', {})
604+
605+
# Collect all valid (runner, tactic) combinations for this context
606+
tactics_lists = []
607+
for runner_idx, runner in enumerate(runners):
608+
valid_tactics = runner.get_valid_tactics(
609+
inputs, OptimizationProfile(), **kwargs)
610+
for tactic in valid_tactics:
611+
tactics_lists.append((runner_idx, tactic))
612+
context_tactics_lists.append(tactics_lists)
613+
614+
# Generate cartesian product from context and tactics where all_configrations[i][ctx] = (runner, tactic)
615+
# Such that each element in all_configrations is a replay of multiple contexts of all possible replays
616+
all_configurations = list(itertools.product(*context_tactics_lists))
617+
return all_configurations
618+
619+
def is_replaying(self) -> bool:
620+
"""Check if this TacticsCapture is currently in replay mode."""
621+
return self._replay_runner_tactic_list is not None
622+
545623
def choose_one(
546624
self,
547625
custom_op: str,
@@ -573,6 +651,52 @@ def choose_one(
573651
Runner authors are suggested to provide a fallback implementation for each runner to avoid potential issues.
574652
"""
575653

654+
# Check if we're in replay mode via active TacticsCapture
655+
if self._active_capture is not None and self._active_capture.is_replaying(
656+
):
657+
tactics_capture = self._active_capture
658+
call_idx = tactics_capture._replay_context_idx
659+
660+
assert call_idx < len(tactics_capture._replay_runner_tactic_list
661+
), "call_idx out of range"
662+
assert call_idx < len(
663+
tactics_capture._captured_contexts), "call_idx out of range"
664+
assert len(tactics_capture._replay_runner_tactic_list) == len(
665+
tactics_capture._captured_contexts)
666+
667+
# Check if we have a forced tactic for this call and both custom_op match
668+
captured_custom_op = tactics_capture._captured_contexts[
669+
call_idx].get('custom_op')
670+
if captured_custom_op != custom_op:
671+
raise RuntimeError(
672+
f"Custom op mismatch in kernel testing mode.\n"
673+
f"Expected operation: '{captured_custom_op}'\n"
674+
f"Actual operation: '{custom_op}'\n"
675+
f"Context index: {call_idx}\n"
676+
f"Make sure the forward() call in test mode uses the same operation as captured."
677+
)
678+
679+
runner_idx, tactic = tactics_capture._replay_runner_tactic_list[
680+
call_idx]
681+
# Increment context counter
682+
tactics_capture._replay_context_idx += 1
683+
# Reset counter after all contexts have been used
684+
if tactics_capture._replay_context_idx >= len(
685+
tactics_capture._replay_runner_tactic_list):
686+
tactics_capture._replay_context_idx = 0
687+
return (runners[runner_idx], tactic)
688+
689+
# Capture context for testing all underlying kernels
690+
if self._active_capture is not None and not self._active_capture.is_replaying(
691+
):
692+
self._active_capture._captured_contexts.append({
693+
'custom_op': custom_op,
694+
'runners': runners,
695+
'tuning_config': tuning_config,
696+
'inputs': inputs,
697+
'kwargs': kwargs,
698+
})
699+
576700
input_shapes = tuple(self._get_input_sizes(inputs))
577701
# Early return if it's not tuning, use cache found one or fallback one
578702
if not self.is_tuning_mode:
@@ -957,3 +1081,91 @@ def print_profiling_cache(self):
9571081
logger.debug(
9581082
f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})"
9591083
)
1084+
1085+
@contextlib.contextmanager
1086+
def capture(self):
1087+
"""Context manager for capturing execution contexts for testing.
1088+
1089+
Returns a TacticsCapture object that can be iterated to get all valid
1090+
(runner, tactic) combinations.
1091+
1092+
Example:
1093+
>>> # Single context case
1094+
>>> with AutoTuner.get().capture() as tactics_capture:
1095+
... y = custom_op.forward(x)
1096+
>>>
1097+
>>> for runner, tactic in tactics_capture:
1098+
... with AutoTuner.get().replay(runner, tactic):
1099+
... y = custom_op.forward(x)
1100+
1101+
>>> # Multiple contexts case
1102+
>>> with AutoTuner.get().capture() as tactics_capture:
1103+
... y = custom_op1.forward(x)
1104+
... z = custom_op2.forward(y)
1105+
>>>
1106+
>>> for config in tactics_capture:
1107+
... with AutoTuner.get().replay(config):
1108+
... y = custom_op1.forward(x)
1109+
... z = custom_op2.forward(y)
1110+
"""
1111+
tactics_capture = self.TacticsCapture(self)
1112+
self._active_capture = tactics_capture
1113+
try:
1114+
yield tactics_capture
1115+
finally:
1116+
self._active_capture = None
1117+
self._last_capture = tactics_capture
1118+
1119+
@contextlib.contextmanager
1120+
def replay(self, *config: Tuple[Tuple[TunableRunner, int], ...]):
1121+
"""Context manager for replaying with specific runner/tactic configuration.
1122+
1123+
Args:
1124+
config:
1125+
- A tuple of (runner, tactic) pairs. The tuple size matches the number of captured choose_one() contexts.
1126+
"""
1127+
# Parse config argument
1128+
if len(config) == 1:
1129+
if isinstance(config[0], tuple):
1130+
# Multiple contexts: replay(((r0,t0), (r1,t1), ...))
1131+
runner_tactic_pairs = list(config[0])
1132+
else:
1133+
# Also handle single context passed as replay((runner, tactic))
1134+
runner_tactic_pairs = [config[0]]
1135+
else:
1136+
raise ValueError(
1137+
f"Invalid config for replay: {config}\n"
1138+
"Expected replay(((runner, tactic), (runner, tactic), ...))")
1139+
1140+
# Find the TacticsCapture to use
1141+
tactics_capture = self._active_capture or self._last_capture
1142+
1143+
if tactics_capture is None:
1144+
raise RuntimeError(
1145+
"No TacticsCapture available for replay. "
1146+
"Make sure you've called capture() before replay().")
1147+
1148+
# Temporarily set as active capture during replay
1149+
prev_active = self._active_capture
1150+
self._active_capture = tactics_capture
1151+
1152+
runner_tactic_list = []
1153+
for ctx_idx, (runner, tactic) in enumerate(runner_tactic_pairs):
1154+
runners = tactics_capture._captured_contexts[ctx_idx]['runners']
1155+
runner_idx = runners.index(runner)
1156+
runner_tactic_list.append((runner_idx, tactic))
1157+
1158+
logger.debug(
1159+
f"[Autotuner][replay]: Testing configuration: {runner_tactic_list}")
1160+
1161+
# Replay the contexts with given (runner, tactic) pairs
1162+
tactics_capture._replay_runner_tactic_list = runner_tactic_list
1163+
tactics_capture._replay_context_idx = 0
1164+
1165+
try:
1166+
yield
1167+
finally:
1168+
tactics_capture._replay_runner_tactic_list = None
1169+
tactics_capture._replay_context_idx = 0
1170+
# Restore previous active capture state
1171+
self._active_capture = prev_active

0 commit comments

Comments
 (0)