Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions example_text_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def main(
max_gen_len: int = 64,
max_batch_size: int = 4,
dynamo: bool = True,
spmd: bool = True,
):
if not USE_CUDA:
server = xp.start_server(9012, only_on_master=False)
Expand All @@ -34,6 +35,7 @@ def main(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
dynamo=dynamo,
spmd=spmd,
)

prompts = [
Expand Down Expand Up @@ -77,12 +79,13 @@ def _fn(
max_gen_len: int = 64,
max_batch_size: int = 4,
dynamo: bool = True,
spmd: bool = True,
):
if USE_CUDA:
os.environ['WORLD_SIZE'] = torch.cuda.device_count()
os.environ['RANK'] = idx
os.environ['LOCAL_RANK'] = idx
main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo)
main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd)


def mp_main(
Expand All @@ -95,6 +98,7 @@ def mp_main(
max_gen_len: int = 64,
max_batch_size: int = 4,
dynamo: bool = True,
spmd: bool = True,
):
if mp:
if USE_CUDA:
Expand All @@ -103,9 +107,9 @@ def mp_main(
else:
kwargs = {}
xmp.spawn(_fn,
args=(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo), **kwargs)
args=(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd), **kwargs)
else:
main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo)
main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd)


if __name__ == "__main__":
Expand Down
43 changes: 40 additions & 3 deletions llama/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
# Some how xla init will slow down the CUDA speed.
if not USE_CUDA:
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla import runtime as xr
import numpy as np

Role = Literal["system", "user", "assistant"]

Expand Down Expand Up @@ -60,6 +63,7 @@ def build(
max_batch_size: int,
model_parallel_size: Optional[int] = None,
dynamo: bool = True,
spmd: bool = True,
) -> "Llama":
# if not model_parallel_is_initialized():
# if model_parallel_size is None:
Expand Down Expand Up @@ -118,14 +122,47 @@ def build(
model = model.to(device)
print(f"Loaded in {time.time() - start_time:.2f} seconds")

return Llama(model, tokenizer, device, dynamo)
return Llama(model, tokenizer, device, dynamo, spmd)

def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.device, dynamo: bool = True):
def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.device, dynamo: bool = True, spmd: bool = True):
self.model = model
self.tokenizer = tokenizer
self.device = device

self._generate_one_token_fn = self._generate_one_token

if spmd:
num_devices = xr.global_runtime_device_count() # updated way to get device count
device_ids = np.arange(num_devices)
x_dim = 2 # hard-coded for v5
yz_dim = 4 # hard-coded for v5

# manually shard the kv cache
four_d_mesh = xs.Mesh(device_ids, (1, 1, x_dim, yz_dim))
for layer in model.layers:
xs.mark_sharding(layer.attention.cache_k, four_d_mesh, (0, 1, 2, None))
xs.mark_sharding(layer.attention.cache_v, four_d_mesh, (0, 1, 2, None))

col_mesh = xs.Mesh(device_ids, (1, num_devices))
row_mesh = xs.Mesh(device_ids, (num_devices, 1))
two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim))
two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim))

for name, layer in model.named_modules():
if 'tok_embeddings' in name:
xs.mark_sharding(layer.weight, row_mesh, (0, 1))
if 'attention.' in name:
if 'wo' in name:
xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1))
else:
xs.mark_sharding(layer.weight, two_d_mesh, (0, 1))
if 'feed_forward.' in name:
if 'w2' in name:
xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1))
else:
xs.mark_sharding(layer.weight, two_d_mesh, (0, 1))
if 'output' in name:
xs.mark_sharding(layer.weight, col_mesh, (0, 1))

if dynamo:
if USE_CUDA:
# Inductor errors out when compiles _generate_one_token_fn.
Expand Down