Commit 0a64c55

Add DTensor LLaMA inference model: simple_gpt (#1867)

xmfan authored, facebook-github-bot committed
Summary: Adds simple_gpt with the DTensor changes implemented in pytorch-labs/simple_gpt#7 to torchbench.

Tested via:

```
python benchmarks/dynamo/torchbench.py -d cuda --output-directory=benchmark_logs --output=performance.csv --inference --performance --timing --print-memory --multiprocess --nothing --only simple_gpt
```

Note: `--nothing` is used here to disable compile, since DTensor + compile isn't yet supported in main.

```
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio,eager_peak_mem,dynamo_peak_mem,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks
cuda,simple_gpt,1,0.966153,196.819773,-0.059319,1.000000,4.576880,4.576880,0,0,0,0
cuda,simple_gpt,1,0.967389,196.608152,-0.058833,1.000000,4.577404,4.577404,0,0,0,0
cuda,simple_gpt,1,0.973152,196.093583,-0.059316,1.000000,4.593133,4.593133,0,0,0,0
cuda,simple_gpt,1,0.973087,196.124046,-0.075580,1.000000,4.611483,4.611483,0,0,0,0
cuda,simple_gpt,1,0.967908,193.998484,-0.040192,1.000000,4.593133,4.593133,0,0,0,0
cuda,simple_gpt,1,0.968949,193.798088,-0.028878,1.000000,4.593133,4.593133,0,0,0,0
```

Two changes were required to the model:

- Decorate the cache setup with `torch.no_grad()`. Previously this was done outside the model: the entire eval call was wrapped in a `torch.no_grad()` context. After using torchbench, I noticed that even in inference-only mode, gradient calculations are not disabled.
- Rank/world size: support was added on the torchbench side in pytorch/pytorch#108438, and the model was updated to fetch these from the provided extra_args.

Pull Request resolved: #1867

Reviewed By: msaroufim

Differential Revision: D49065244

Pulled By: xmfan

fbshipit-source-id: d4709fa3997c6a25c75e87eff7c13492b370b1af
1 parent e768fd3 · commit 0a64c55
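The two model-side changes described in the summary read roughly like the sketch below. The method and flag names here are assumptions for illustration; the real wiring lives in pytorch-labs/simple_gpt#7 and pytorch/pytorch#108438.

```python
import torch
import torch.nn as nn


class TinyLLaMA(nn.Module):
    """Stand-in for the real model; shapes and names are illustrative only."""

    def __init__(self, n_heads: int = 4, head_dim: int = 8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.kv_cache = None

    # Change 1: put torch.no_grad() on the cache setup inside the model,
    # instead of relying on callers to wrap the whole eval() call in no_grad.
    @torch.no_grad()
    def setup_caches(self, max_batch_size: int, max_seq_length: int):
        shape = (max_batch_size, self.n_heads, max_seq_length, self.head_dim)
        self.kv_cache = (torch.zeros(shape), torch.zeros(shape))


# Change 2 (hypothetical helper): read rank/world size from the extra_args
# that the torchbench harness now provides, instead of from globals.
def parse_distributed_args(extra_args):
    rank, world_size = None, None
    it = iter(extra_args)
    for arg in it:
        if arg == "--rank":
            rank = int(next(it))
        elif arg == "--world_size":
            world_size = int(next(it))
    return rank, world_size


m = TinyLLaMA()
m.setup_caches(max_batch_size=1, max_seq_length=16)
print(parse_distributed_args(["--rank", "0", "--world_size", "8"]))  # (0, 8)
```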

File tree: 6 files changed, +496 −14 lines


torchbenchmark/models/ADDING_MODELS.md (+14, −14)
Most hunks only strip trailing whitespace; the substantive changes are in the last two hunks (the NotImplementedError convention and the test commands).

```diff
@@ -9,16 +9,16 @@
 ## Detailed steps
 
 ### Adding the model code
-The intent is to preserve the original user code as much as possible while 
+The intent is to preserve the original user code as much as possible while
 adding support for a standardized interface to the benchmark suite and making sure
 the code can run from any directory and in a process with other models.
 
 In many cases it is fine to simply copy the entire original repo into a subdirectory
-as a starting point, paying attention to avoid the .git folder, and not to add any 
+as a starting point, paying attention to avoid the .git folder, and not to add any
 large unnecessary data files unintentionally. The subdirectory name should be a valid
 Python identifier because it will become a module in Python and needs to be importable.
 
-Create a new file 'origin' that contains the url to the git repo you're copying, 
+Create a new file 'origin' that contains the url to the git repo you're copying,
 so it's easy to trace the code back to where it came from.
 
 #### Wrapping your model in \_\_init\_\_.py
@@ -34,22 +34,22 @@ Take care to set the random seed like [here](https://github.com/pytorch/benchmar
 #### A minimal new model addition
 A bare minimum example you can follow is https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models/phlippe_resnet
 
-The functions you specifically need to implement are 
+The functions you specifically need to implement are
 1. `__init__()` which is responsible for initializing your `nn.Module`
 2. `get_module()` which is responsible for returning the initialized `nn.Module` and an example input
 3. `train()` which is a training loop, you can return a `NotImplementedError()` if your example is inference only. If your
 training loop can be encapsulated by a `forward()`, `backward()`, and `optimizer_step()`, you need not redefine `train()`.
-Instead, please make sure your model provides functions `forward()`, `backward()`, and `optimizer_step()` along with an 
+Instead, please make sure your model provides functions `forward()`, `backward()`, and `optimizer_step()` along with an
 attribute `self.optimizer` which will be chained together for testing, see `invoke_staged_train_test()` for details.
 4. `eval()` which showcases a simple inference
 
-Optionally, if you would like to be able to customize different optimizers for your model, feel free 
+Optionally, if you would like to be able to customize different optimizers for your model, feel free
 to override the BenchmarkModel's base class' default `get_optimizer()` and `set_optimizer(optimizer)`
-methods. 
+methods.
 
 ### Preparing install.py and dependencies
 Simply put, install.py should be a one stop shop to install all the dependencies
-for your model, __except torch, torchvision, torchaudio__ which should be assumed to 
+for your model, __except torch, torchvision, torchaudio__ which should be assumed to
 have been installed by an outsider (the benchmark CI).
 
 - Avoid pinning packages to specific versions with == without good reason, as the
@@ -65,7 +65,7 @@ not easy to build, there may be easier models to target.
 [Example install.py](BERT_pytorch/install.py)
 
 ### Mini-dataset
-By the time install.py script runs, a miniature version of the dataset is expected to be 
+By the time install.py script runs, a miniature version of the dataset is expected to be
 staged and ready for use. It's fine to use install.py to download and prepare the data
 if the download is quick. Otherwise, prepare the dataset manually, checking in the required
 artifacts and modifying the \_\_init\_\_.py script as needed to use them.
@@ -95,8 +95,8 @@ This file should define two things:
 - `__main__` function, which exercises the model APIs for local testing
 
 Important: be deliberate about support for cpu/gpu and jit/no-jit. In the case that
-your model is instantiated in an unsupported configuration, the convention is to return
-a model object from \_\_init\_\_ but raise NotImplementedError() from all its methods.
+your model is instantiated in an unsupported configuration, the convention is to raise
+NotImplementedError from \_\_init\_\_.
 
 See the [BenchmarkModel API](https://github.com/pytorch/benchmark/blob/master/torchbenchmark/util/model.py) to get started. The [BERT_pytorch](BERT_pytorch/__init__.py) benchmark can serve as a good example.
 
@@ -109,11 +109,11 @@ version.
 
 ### Test
 
-After you've submitted your new model, suppose it was called `new_model` make sure the tests pass locally. Your model name is equivalent to the new folder you'd have created in `torchbenchmark/models`
+After you've submitted your new model, suppose it was called `<new_model>` make sure the tests pass locally. Your model name is equivalent to the new folder you'd have created in `torchbenchmark/models`
 
 1. `cd benchmark`
 2. `python install.py`
-3. `python run.py model -d cuda` and `python run.py model -d cpu`
-3. `python test.py -k "model_"` following the format from here https://github.com/pytorch/benchmark#using-testpy
+3. `python run.py <new_model> -d cuda` and `python run.py <new_model> -d cpu`
+3. `python test.py -k "test_<new_model>_"` following the format from here https://github.com/pytorch/benchmark#using-testpy
 
 And thank you for contributing to torchbench!
```
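Putting the guide's required functions together, a minimal wrapper might look like the sketch below, patterned on the phlippe_resnet example the guide links to. The placeholder `nn.Linear` network, task, and batch sizes are assumptions for illustration, not a real benchmark; it assumes the base class populates `self.batch_size` from the defaults, as in the simple_gpt wrapper that follows.

```python
import torch
from torchbenchmark.tasks import COMPUTER_VISION

from ...util.model import BenchmarkModel


class Model(BenchmarkModel):
    task = COMPUTER_VISION.CLASSIFICATION
    DEFAULT_TRAIN_BSIZE = 16
    DEFAULT_EVAL_BSIZE = 16

    def __init__(self, test, device, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device,
                         batch_size=batch_size, extra_args=extra_args)
        # Placeholder network; a real addition would build the original
        # repo's model here and set a fixed random seed.
        self.model = torch.nn.Linear(32, 10).to(device)
        self.example_inputs = (torch.randn(self.batch_size, 32, device=device),)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3)
        self.criterion = torch.nn.CrossEntropyLoss()

    def get_module(self):
        # Return the initialized nn.Module and an example input.
        return self.model, self.example_inputs

    def train(self):
        # One training step: forward, backward, optimizer step.
        targets = torch.zeros(self.batch_size, dtype=torch.long,
                              device=self.example_inputs[0].device)
        self.optimizer.zero_grad()
        loss = self.criterion(self.model(*self.example_inputs), targets)
        loss.backward()
        self.optimizer.step()

    def eval(self):
        # Simple inference pass.
        with torch.no_grad():
            out = self.model(*self.example_inputs)
        return (out,)
```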
New file (the simple_gpt `__init__.py` wrapper), +91:

```python
import os

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
from torchbenchmark.tasks import NLP

from ...util.model import BenchmarkModel
from .model import LLaMA


class Model(BenchmarkModel):
    task = NLP.GENERATION
    DEFAULT_EVAL_BSIZE = 1

    def validate_environment(self):
        # Returns (rather than raises) an error so __init__ decides when to raise.
        if not torch.cuda.is_available() or "cuda" not in self.device:
            return NotImplementedError("Model requires CUDA")

        if not torch.cuda.is_bf16_supported():
            return NotImplementedError("Model requires BF16")

        if not hasattr(self, "_world_size"):
            return NotImplementedError("Model needs to be run via dynamo torchbench and be provided distributed parameters")

        if self._world_size != torch.cuda.device_count():
            return NotImplementedError(
                f"DTensor requires all local GPUs to be within the device mesh: found {torch.cuda.device_count()} local GPUs, but world size is only {self._world_size}"
            )

        return None

    def __init__(self, test, device, batch_size=None, extra_args=[]):
        super().__init__(
            test=test,
            device=device,
            batch_size=batch_size,
            extra_args=extra_args,
        )

        error = self.validate_environment()
        if error:
            raise error

        self.model = LLaMA.from_name("7B", self._world_size).to(device=device, dtype=torch.bfloat16)

        # Tensor parallelism using DTensor
        mesh = DeviceMesh("cuda", list(range(self._world_size)))
        for block in self.model.transformer.h:
            # prepare attention weights to be parallelized
            block.attn.prepare_qkv_for_dtensor_tp()

            parallelize_module(
                module=block,
                device_mesh=mesh,
                parallelize_plan={
                    "attn.c_attn_q": ColwiseParallel(),
                    "attn.c_attn_k": ColwiseParallel(),
                    "attn.c_attn_v": ColwiseParallel(),
                    "attn.c_proj": RowwiseParallel(),
                    "mlp.c_fc1": ColwiseParallel(),
                    "mlp.c_fc2": ColwiseParallel(),
                    "mlp.c_proj": RowwiseParallel(),
                },
                tp_mesh_dim=0,
            )

        max_batch_size = self.DEFAULT_EVAL_BSIZE
        self.model.setup_caches(
            max_batch_size=max_batch_size, max_seq_length=self.model.config.block_size
        )

        # Build a small random prompt as the example input.
        prompt_size = 10
        idx = torch.randint(
            self.model.config.vocab_size,
            (max_batch_size, prompt_size),
            dtype=torch.int32,
            device=device,
        )
        input_pos = torch.arange(prompt_size, device=device)
        self.example_inputs = [idx, input_pos]

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        raise NotImplementedError("Training not supported for this model")

    def eval(self):
        raise NotImplementedError("Model needs to be run via dynamo torchbench and be provided distributed parameters")
```
New file (likely the model's `metadata.yaml`), +5:

```yaml
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
```
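A hedged reading of these flags, based on their names rather than anything stated in the commit: `eval_nograd: true` lines up with the first model change above (inference runs under `torch.no_grad()`), while the `*_benchmark` and `*_deterministic` flags opt the model out of the corresponding harness checks, with the `train_*` entries moot since `train()` raises NotImplementedError.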
