
Commit 02d1a59

Enable compile on MPS with experimental warning (#1523)
This enables compile on MPS, modifying the warning message to say this is experimental and is not ready for broad use yet.
1 parent 2640f6a commit 02d1a59
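
In practice, the cli.py change below means that passing --compile (or --compile-prefill) together with --device mps is no longer silently ignored: the flags stay set and only a warning is printed. A minimal sketch of the resulting arg_init behaviour, condensed from the diff in this commit (the real function does more than shown here):

# Sketch only, condensed from the torchchat/cli/cli.py hunk below: on MPS the
# compile flags are now left enabled and the user just gets an experimental warning.
def arg_init(args):
    if "mps" in args.device:
        if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):
            print(
                "Warning: compilation on MPS is experimental and is not ready for broad use yet"
            )
            # Before this commit, compilation was forced off at this point:
            # vars(args)["compile"] = False
            # vars(args)["compile_prefill"] = False
    return args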

2 files changed (+11, -13 lines)

torchchat/cli/cli.py

Lines changed: 3 additions & 5 deletions
@@ -207,7 +207,7 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
-    exclusive_parser.add_argument(
+    exclusive_parser.add_argument(
         "--output-snapshot-path",
         type=str,
         default=None,

@@ -266,7 +266,7 @@ def _add_exported_input_path_args(parser) -> None:
         default=None,
         help="Use the specified torchchat snaphot .tc model file",
     )
-
+
 
 # Add CLI Args related to JIT downloading of model artifacts
 def _add_jit_downloading_args(parser) -> None:
@@ -582,10 +582,8 @@ def arg_init(args):
     if "mps" in args.device:
         if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):
             print(
-                "Warning: compilation is not available with device MPS, ignoring option to engage compilation"
+                "Warning: compilation on MPS is experimental and is not ready for broad use yet"
             )
-            vars(args)["compile"] = False
-            vars(args)["compile_prefill"] = False
 
     if hasattr(args, "seed") and args.seed:
         # Localized import to minimize expensive imports

torchchat/generate.py

Lines changed: 8 additions & 8 deletions
@@ -321,7 +321,7 @@ def __init__(
         draft_quantize: bool,
     ):
         torch._inductor.config.coordinate_descent_tuning = (
-            builder_args.device != "cpu"
+            builder_args.device not in ["cpu", "mps"]
         )
         torch._inductor.config.triton.unique_kernel_names = True
         torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
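
The hunk above widens the device check that gates Inductor's coordinate-descent tuning so it now stays disabled on MPS as well as CPU. An illustrative sketch (not part of the commit) of how the updated expression evaluates for a few devices:

# Illustrative only: per-device value of the flag set in the hunk above.
for device in ["cuda", "cpu", "mps"]:
    coordinate_descent_tuning = device not in ["cpu", "mps"]
    print(f"{device}: coordinate_descent_tuning={coordinate_descent_tuning}")
# cuda: coordinate_descent_tuning=True
# cpu: coordinate_descent_tuning=False
# mps: coordinate_descent_tuning=False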
@@ -1315,7 +1315,7 @@ def __init__(
         quantize: bool,
         draft_quantize: bool,
     ):
-
+
         is_speculative = speculative_builder_args.checkpoint_path is not None
         assert is_speculative == False, "Distributed inference with pp > 1 does not support speculative inference yet."
         super().__init__(

@@ -1336,7 +1336,7 @@ def distributed_input(prompt: str) -> str:
         text = [input(prompt)]
     else:
         text = [None]
-
+
     dist.broadcast_object_list(text)
     return text[0]
 

@@ -1491,7 +1491,7 @@ def prefill(
         # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
         lane = 0
         kwargs = {"input_pos": input_pos, "cache_lane": lane}
-
+
         if self.pp_rank == self.first_pp_rank:
             logits = self.prefiller.step(padded_seq, **kwargs)
         elif self.pp_rank == self.last_pp_rank:

@@ -1592,7 +1592,7 @@ def sample(
             return (idx_next, None)
         probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
         idx_next = self.multinomial_sample_one_no_sync(probs)
-
+
         return idx_next, probs
 
 

@@ -1601,12 +1601,12 @@ def run_generator(
     rank: Optional[int] =None
 ):
     """
-    This function creates and executes a generator
+    This function creates and executes a generator
     """
     builder_args = BuilderArgs.from_args(args)
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
-    generator_args = GeneratorArgs.from_args(args)
+    generator_args = GeneratorArgs.from_args(args)
     #Setup rank 1 and up to suppress log messages and print messages
     if builder_args.distributed and rank != 0:
         logger.setLevel(logging.CRITICAL)

@@ -1636,7 +1636,7 @@ def run_generator(
 
 def main(args):
     builder_args = BuilderArgs.from_args(args)
-
+
     if builder_args.distributed:
         world_size = builder_args.tp * builder_args.pp

0 commit comments