|
10 | 10 | from dataclasses import dataclass, field |
11 | 11 | from functools import lru_cache |
12 | 12 | from pathlib import Path |
13 | | -from typing import Any, Callable, Dict, List, Set, Tuple, Union |
| 13 | +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
14 | 14 |
|
15 | 15 | import torch |
16 | 16 |
|
@@ -127,7 +127,7 @@ def _opt(self): |
127 | 127 | class OptimizationProfile: |
128 | 128 | '''Ranges of all tensors, all dimension |
129 | 129 | ''' |
130 | | - shapes: List[List[Dim]] |
| 130 | + shapes: List[List[Dim]] = field(default_factory=lambda: [[]]) |
131 | 131 |
|
def get_hash_key(self):
    # A profile is identified by its opt shapes (whatever get_opt_shapes()
    # returns); profiles with equal opt shapes share the same key.
    return self.get_opt_shapes()
@@ -536,12 +536,90 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000): |
536 | 536 |
|
537 | 537 | self.profiling_debug = True |
538 | 538 |
|
| 539 | + # Current captured choose_one() contexts |
| 540 | + self._active_capture: Optional['AutoTuner.TacticsCapture'] = None |
| 541 | + # Last captured choose_one() contexts |
| 542 | + self._last_capture: Optional['AutoTuner.TacticsCapture'] = None |
| 543 | + |
@classmethod
def get(cls):
    """Return the shared AutoTuner instance, constructing it on first access."""
    instance = cls._instance
    if instance is None:
        instance = AutoTuner()
        cls._instance = instance
    return instance
544 | 549 |
|
class TacticsCapture:
    """Object returned by capture() that can be iterated to get all tactic combinations.

    Encapsulates every piece of state involved in capturing and replaying
    tactics:
    - the choose_one() contexts recorded during capture,
    - the lazily generated tactic configurations,
    - the replay cursor (which forced config and which call index is next).
    """

    def __init__(self, autotuner):
        # Contexts recorded while the owning capture() block is active.
        self._captured_contexts: List[Dict[str, Any]] = []
        # Lazily built list of configurations; None until first iteration.
        self._configurations = None
        # Replay state: forced (runner_idx, tactic) per context, plus cursor.
        self._replay_runner_tactic_list: Optional[List[Tuple[int,
                                                             int]]] = None
        self._replay_context_idx: int = 0

    def __iter__(self):
        """Yield every tactic configuration.

        Each yielded value is a tuple containing one (runner, tactic) pair
        per captured context — a 1-tuple when only one context was captured.
        """
        if self._configurations is None:
            self._configurations = self._generate_configurations()

        for combo in self._configurations:
            # combo holds one (runner_idx, tactic) per captured context;
            # translate each runner index back into the runner object.
            yield tuple(
                (self._captured_contexts[ctx]['runners'][r_idx], tactic)
                for ctx, (r_idx, tactic) in enumerate(combo))

    def _generate_configurations(self):
        """Generate all valid tactic combinations across captured contexts."""
        if not self._captured_contexts:
            raise RuntimeError(
                "No context available for testing.\n"
                "Use capture() to capture the operation context first:\n"
                "    with AutoTuner.get().capture() as tactics_capture:\n"
                "        output = operation.forward(...)\n")

        # One list of candidate (runner_idx, tactic) choices per context.
        per_context_choices = []
        for context in self._captured_contexts:
            runners = context['runners']
            inputs = context['inputs']
            kwargs = context.get('kwargs', {})

            choices = []
            for runner_idx, runner in enumerate(runners):
                # NOTE(review): profile semantics of get_valid_tactics are
                # defined elsewhere; a fresh default profile is used here.
                for tactic in runner.get_valid_tactics(
                        inputs, OptimizationProfile(), **kwargs):
                    choices.append((runner_idx, tactic))
            per_context_choices.append(choices)

        # Cartesian product: element i is one full replay, carrying a
        # (runner_idx, tactic) choice for every captured context.
        return list(itertools.product(*per_context_choices))

    def is_replaying(self) -> bool:
        """Check if this TacticsCapture is currently in replay mode."""
        return self._replay_runner_tactic_list is not None
545 | 623 | def choose_one( |
546 | 624 | self, |
547 | 625 | custom_op: str, |
@@ -573,6 +651,52 @@ def choose_one( |
573 | 651 | Runner authors are suggested to provide a fallback implementation for each runner to avoid potential issues. |
574 | 652 | """ |
575 | 653 |
|
| 654 | + # Check if we're in replay mode via active TacticsCapture |
| 655 | + if self._active_capture is not None and self._active_capture.is_replaying( |
| 656 | + ): |
| 657 | + tactics_capture = self._active_capture |
| 658 | + call_idx = tactics_capture._replay_context_idx |
| 659 | + |
| 660 | + assert call_idx < len(tactics_capture._replay_runner_tactic_list |
| 661 | + ), "call_idx out of range" |
| 662 | + assert call_idx < len( |
| 663 | + tactics_capture._captured_contexts), "call_idx out of range" |
| 664 | + assert len(tactics_capture._replay_runner_tactic_list) == len( |
| 665 | + tactics_capture._captured_contexts) |
| 666 | + |
| 667 | + # Check if we have a forced tactic for this call and both custom_op match |
| 668 | + captured_custom_op = tactics_capture._captured_contexts[ |
| 669 | + call_idx].get('custom_op') |
| 670 | + if captured_custom_op != custom_op: |
| 671 | + raise RuntimeError( |
| 672 | + f"Custom op mismatch in kernel testing mode.\n" |
| 673 | + f"Expected operation: '{captured_custom_op}'\n" |
| 674 | + f"Actual operation: '{custom_op}'\n" |
| 675 | + f"Context index: {call_idx}\n" |
| 676 | + f"Make sure the forward() call in test mode uses the same operation as captured." |
| 677 | + ) |
| 678 | + |
| 679 | + runner_idx, tactic = tactics_capture._replay_runner_tactic_list[ |
| 680 | + call_idx] |
| 681 | + # Increment context counter |
| 682 | + tactics_capture._replay_context_idx += 1 |
| 683 | + # Reset counter after all contexts have been used |
| 684 | + if tactics_capture._replay_context_idx >= len( |
| 685 | + tactics_capture._replay_runner_tactic_list): |
| 686 | + tactics_capture._replay_context_idx = 0 |
| 687 | + return (runners[runner_idx], tactic) |
| 688 | + |
| 689 | + # Capture context for testing all underlying kernels |
| 690 | + if self._active_capture is not None and not self._active_capture.is_replaying( |
| 691 | + ): |
| 692 | + self._active_capture._captured_contexts.append({ |
| 693 | + 'custom_op': custom_op, |
| 694 | + 'runners': runners, |
| 695 | + 'tuning_config': tuning_config, |
| 696 | + 'inputs': inputs, |
| 697 | + 'kwargs': kwargs, |
| 698 | + }) |
| 699 | + |
576 | 700 | input_shapes = tuple(self._get_input_sizes(inputs)) |
577 | 701 | # Early return if it's not tuning, use cache found one or fallback one |
578 | 702 | if not self.is_tuning_mode: |
@@ -957,3 +1081,91 @@ def print_profiling_cache(self): |
957 | 1081 | logger.debug( |
958 | 1082 | f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})" |
959 | 1083 | ) |
| 1084 | + |
@contextlib.contextmanager
def capture(self):
    """Record every choose_one() context executed inside the with-block.

    Yields a TacticsCapture. After the block exits, the capture can be
    iterated to enumerate every valid tactic configuration and handed to
    replay() to force one specific configuration.

    Example:
        >>> with AutoTuner.get().capture() as tactics_capture:
        ...     y = custom_op1.forward(x)
        ...     z = custom_op2.forward(y)
        >>>
        >>> for config in tactics_capture:
        ...     with AutoTuner.get().replay(config):
        ...         y = custom_op1.forward(x)
        ...         z = custom_op2.forward(y)
    """
    captured = self.TacticsCapture(self)
    self._active_capture = captured
    try:
        yield captured
    finally:
        # Even if the captured region raises, stop capturing and keep the
        # result around so a later replay() can still find it.
        self._active_capture = None
        self._last_capture = captured
| 1118 | + |
@contextlib.contextmanager
def replay(self, *config: Tuple[Tuple[TunableRunner, int], ...]):
    """Context manager for replaying with specific runner/tactic configuration.

    Accepted call forms:
        replay(((r0, t0), (r1, t1), ...))  # one (runner, tactic) pair per
                                           # captured choose_one() context
        replay((runner, tactic))           # single captured context
        replay(runner, tactic)             # single captured context

    Raises:
        ValueError: if the argument matches none of the accepted forms.
        RuntimeError: if no capture() has been performed yet.
    """
    # Normalize the argument into one (runner, tactic) pair per captured
    # context. Runners are objects, never tuples, so inspecting the first
    # element disambiguates a tuple-of-pairs from a single bare pair.
    if len(config) == 1 and isinstance(config[0], tuple) and config[0] \
            and isinstance(config[0][0], tuple):
        # Multiple contexts: replay(((r0, t0), (r1, t1), ...))
        runner_tactic_pairs = list(config[0])
    elif len(config) == 1 and isinstance(config[0],
                                         tuple) and len(config[0]) == 2:
        # Single context passed as replay((runner, tactic))
        runner_tactic_pairs = [config[0]]
    elif len(config) == 2 and not isinstance(config[0], tuple):
        # Single context passed as replay(runner, tactic), as shown in the
        # capture() docstring example.
        runner_tactic_pairs = [(config[0], config[1])]
    else:
        raise ValueError(
            f"Invalid config for replay: {config}\n"
            "Expected replay(((runner, tactic), (runner, tactic), ...)), "
            "replay((runner, tactic)) or replay(runner, tactic)")

    # Find the TacticsCapture to use: an in-progress capture wins over the
    # most recently finished one.
    tactics_capture = self._active_capture or self._last_capture

    if tactics_capture is None:
        raise RuntimeError(
            "No TacticsCapture available for replay. "
            "Make sure you've called capture() before replay().")

    # Temporarily set as active capture during replay
    prev_active = self._active_capture
    self._active_capture = tactics_capture

    # Translate runner objects into indices within each context's runner
    # list, matching the order in which the contexts were captured.
    runner_tactic_list = []
    for ctx_idx, (runner, tactic) in enumerate(runner_tactic_pairs):
        runners = tactics_capture._captured_contexts[ctx_idx]['runners']
        runner_idx = runners.index(runner)
        runner_tactic_list.append((runner_idx, tactic))

    logger.debug(
        f"[Autotuner][replay]: Testing configuration: {runner_tactic_list}")

    # Replay the contexts with given (runner, tactic) pairs
    tactics_capture._replay_runner_tactic_list = runner_tactic_list
    tactics_capture._replay_context_idx = 0

    try:
        yield
    finally:
        # Always clear replay state and restore the previous active capture,
        # even if the replayed region raises.
        tactics_capture._replay_runner_tactic_list = None
        tactics_capture._replay_context_idx = 0
        self._active_capture = prev_active
0 commit comments