@@ -321,7 +321,7 @@ def __init__(
321
321
draft_quantize : bool ,
322
322
):
323
323
torch ._inductor .config .coordinate_descent_tuning = (
324
- builder_args .device != "cpu"
324
+ builder_args .device not in [ "cpu" , "mps" ]
325
325
)
326
326
torch ._inductor .config .triton .unique_kernel_names = True
327
327
torch ._inductor .config .fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
@@ -1315,7 +1315,7 @@ def __init__(
1315
1315
quantize : bool ,
1316
1316
draft_quantize : bool ,
1317
1317
):
1318
-
1318
+
1319
1319
is_speculative = speculative_builder_args .checkpoint_path is not None
1320
1320
assert is_speculative == False , "Distributed inference with pp > 1 does not support speculative inference yet."
1321
1321
super ().__init__ (
@@ -1336,7 +1336,7 @@ def distributed_input(prompt: str) -> str:
1336
1336
text = [input (prompt )]
1337
1337
else :
1338
1338
text = [None ]
1339
-
1339
+
1340
1340
dist .broadcast_object_list (text )
1341
1341
return text [0 ]
1342
1342
@@ -1491,7 +1491,7 @@ def prefill(
1491
1491
# TODO: we need to pass `input_pos` and `cache_lane` to each stage.
1492
1492
lane = 0
1493
1493
kwargs = {"input_pos" : input_pos , "cache_lane" : lane }
1494
-
1494
+
1495
1495
if self .pp_rank == self .first_pp_rank :
1496
1496
logits = self .prefiller .step (padded_seq , ** kwargs )
1497
1497
elif self .pp_rank == self .last_pp_rank :
@@ -1592,7 +1592,7 @@ def sample(
1592
1592
return (idx_next , None )
1593
1593
probs = self .logits_to_probs (logits [0 , - 1 ], temperature , top_k )
1594
1594
idx_next = self .multinomial_sample_one_no_sync (probs )
1595
-
1595
+
1596
1596
return idx_next , probs
1597
1597
1598
1598
@@ -1601,12 +1601,12 @@ def run_generator(
1601
1601
rank : Optional [int ] = None
1602
1602
):
1603
1603
"""
1604
- This function creates and executes a generator
1604
+ This function creates and executes a generator
1605
1605
"""
1606
1606
builder_args = BuilderArgs .from_args (args )
1607
1607
speculative_builder_args = BuilderArgs .from_speculative_args (args )
1608
1608
tokenizer_args = TokenizerArgs .from_args (args )
1609
- generator_args = GeneratorArgs .from_args (args )
1609
+ generator_args = GeneratorArgs .from_args (args )
1610
1610
#Setup rank 1 and up to suppress log messages and print messages
1611
1611
if builder_args .distributed and rank != 0 :
1612
1612
logger .setLevel (logging .CRITICAL )
@@ -1636,7 +1636,7 @@ def run_generator(
1636
1636
1637
1637
def main (args ):
1638
1638
builder_args = BuilderArgs .from_args (args )
1639
-
1639
+
1640
1640
if builder_args .distributed :
1641
1641
world_size = builder_args .tp * builder_args .pp
1642
1642
0 commit comments