Skip to content

Commit 87eaf8f

Browse files
authored
Check for local CUDA graphs when enable_cuda_graph=True (#2941)
1 parent 2ede0d9 commit 87eaf8f

File tree

8 files changed

+73
-15
lines changed

8 files changed

+73
-15
lines changed

deepspeed/inference/engine.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from ..module_inject.policy import TransformerPolicy
2525
from ..module_inject.auto_tp import AutoTP
2626

27+
from ..module_inject.replace_policy import generic_policies
28+
2729
DS_INFERENCE_ENABLED = False
2830
from torch import nn
2931

@@ -155,6 +157,9 @@ def __init__(self, model, config):
155157
if config.tensor_parallel.tp_size > 1:
156158
assert not config.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
157159

160+
# Check if local CUDA graphs can be created in replacement modules
161+
self.local_cuda_graph = self._local_cuda_graph_used(self.module)
162+
158163
def profile_model_time(self, use_cuda_events=True):
159164
if not self.model_profile_enabled and not self._config.enable_cuda_graph:
160165
self.module.register_forward_pre_hook(self._pre_forward_hook)
@@ -512,6 +517,27 @@ def model_times(self):
512517
self._model_times = []
513518
return model_times
514519

520+
def _module_match(self, module):
521+
for policy in generic_policies:
522+
policy = policy()
523+
if policy.match_replaced(module):
524+
return True
525+
return False
526+
527+
def _local_cuda_graph_used(self, module):
528+
if isinstance(module, torch.nn.Module):
529+
return False
530+
else:
531+
sub_module_cuda_graph = False
532+
for name in module.__dict__.keys():
533+
sub_module = getattr(module, name)
534+
535+
if self._module_match(sub_module) and hasattr(sub_module,
536+
"enable_cuda_graph"):
537+
sub_module_cuda_graph = True
538+
539+
return sub_module_cuda_graph
540+
515541
def forward(self, *inputs, **kwargs):
516542
"""Execute forward propagation
517543
@@ -525,7 +551,8 @@ def forward(self, *inputs, **kwargs):
525551
get_accelerator().synchronize()
526552
start = time.time()
527553

528-
if get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph:
554+
if get_accelerator().device_name(
555+
) == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
529556
if self.cuda_graph_created:
530557
outputs = self._graph_replay(*inputs, **kwargs)
531558
else:

deepspeed/model_implementations/diffusers/unet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
Copyright 2022 The Microsoft DeepSpeed Team
33
'''
44
import torch
5+
from ..features.cuda_graph import CUDAGraph
56

67

7-
class DSUNet(torch.nn.Module):
8+
class DSUNet(CUDAGraph, torch.nn.Module):
89
def __init__(self, unet, enable_cuda_graph=True):
9-
super().__init__()
10+
super().__init__(enable_cuda_graph=enable_cuda_graph)
1011
self.unet = unet
1112
# SD pipeline accesses this attribute
1213
self.in_channels = unet.in_channels
@@ -17,7 +18,6 @@ def __init__(self, unet, enable_cuda_graph=True):
1718
self.unet.requires_grad_(requires_grad=False)
1819
self.unet.to(memory_format=torch.channels_last)
1920
self.cuda_graph_created = False
20-
self.enable_cuda_graph = enable_cuda_graph
2121

2222
def _graph_replay(self, *inputs, **kwargs):
2323
for i in range(len(inputs)):

deepspeed/model_implementations/diffusers/vae.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@
22
Copyright 2022 The Microsoft DeepSpeed Team
33
'''
44
import torch
5+
from ..features.cuda_graph import CUDAGraph
56

67

7-
class DSVAE(torch.nn.Module):
8+
class DSVAE(CUDAGraph, torch.nn.Module):
89
def __init__(self, vae, enable_cuda_graph=True):
9-
super().__init__()
10+
super().__init__(enable_cuda_graph=enable_cuda_graph)
1011
self.vae = vae
1112
self.device = self.vae.device
1213
self.dtype = self.vae.dtype
1314
self.vae.requires_grad_(requires_grad=False)
1415
self.decoder_cuda_graph_created = False
1516
self.encoder_cuda_graph_created = False
1617
self.all_cuda_graph_created = False
17-
self.enable_cuda_graph = enable_cuda_graph
1818

1919
def _graph_replay_decoder(self, *inputs, **kwargs):
2020
for i in range(len(inputs)):
@@ -104,7 +104,7 @@ def encode(self, *inputs, **kwargs):
104104
else:
105105
return self._encode(*inputs, **kwargs)
106106

107-
def _graph_replay_all(self, *inputs, **kwargs):
107+
def _graph_replay(self, *inputs, **kwargs):
108108
for i in range(len(inputs)):
109109
if torch.is_tensor(inputs[i]):
110110
self.static_inputs[i].copy_(inputs[i])
@@ -117,10 +117,10 @@ def _graph_replay_all(self, *inputs, **kwargs):
117117
def forward(self, *inputs, **kwargs):
118118
if self.enable_cuda_graph:
119119
if self.cuda_graph_created:
120-
outputs = self._graph_replay_all(*inputs, **kwargs)
120+
outputs = self._graph_replay(*inputs, **kwargs)
121121
else:
122122
self._create_cuda_graph(*inputs, **kwargs)
123-
outputs = self._graph_replay_all(*inputs, **kwargs)
123+
outputs = self._graph_replay(*inputs, **kwargs)
124124
return outputs
125125
else:
126126
return self._forward(*inputs, **kwargs)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
'''Copyright The Microsoft DeepSpeed Team'''
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
'''
2+
Copyright 2023 The Microsoft DeepSpeed Team
3+
'''
4+
from abc import ABC, abstractmethod
5+
6+
7+
class CUDAGraph(ABC):
8+
def __init__(self, enable_cuda_graph=False):
9+
super().__init__()
10+
self.enable_cuda_graph = enable_cuda_graph
11+
12+
@abstractmethod
13+
def _create_cuda_graph(self):
14+
"""
15+
Create CUDA graph(s)
16+
"""
17+
raise NotImplementedError
18+
19+
@abstractmethod
20+
def _graph_replay(self):
21+
"""
22+
Replay CUDA graph(s)
23+
"""
24+
raise NotImplementedError

deepspeed/model_implementations/transformers/clip_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
'''
44
import torch
55
from deepspeed.accelerator import get_accelerator
6+
from ..features.cuda_graph import CUDAGraph
67

78

8-
class DSClipEncoder(torch.nn.Module):
9+
class DSClipEncoder(CUDAGraph, torch.nn.Module):
910
def __init__(self, enc, enable_cuda_graph=False):
10-
super().__init__()
11+
super().__init__(enable_cuda_graph=enable_cuda_graph)
1112
enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
1213
self.enc = enc
1314
self.device = self.enc.device
@@ -18,7 +19,6 @@ def __init__(self, enc, enable_cuda_graph=False):
1819
self.static_output = [None, None]
1920
self._cuda_graphs = [None, None]
2021
self.iter = 0
21-
self.enable_cuda_graph = enable_cuda_graph
2222
self.config = self.enc.config
2323

2424
def _build_causal_attention_mask(self, bsz, seq_len, dtype):

deepspeed/module_inject/containers/unet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch.nn.parameter import Parameter
66

77
from ..policy import DSPolicy
8+
from ...model_implementations.diffusers.unet import DSUNet
89

910

1011
class UNetPolicy(DSPolicy):
@@ -19,9 +20,11 @@ def __init__(self):
1920
def match(self, module):
2021
return isinstance(module, self._orig_layer_class)
2122

23+
def match_replaced(self, module):
24+
return isinstance(module, DSUNet)
25+
2226
def apply(self, module, enable_cuda_graph=True):
2327
# TODO(cmikeh2): Enable cuda graph should be an inference configuration
24-
from ...model_implementations.diffusers.unet import DSUNet
2528
return DSUNet(module, enable_cuda_graph=enable_cuda_graph)
2629

2730
def attention(self, client_module):

deepspeed/module_inject/containers/vae.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Copyright 2022 The Microsoft DeepSpeed Team
33
'''
44
from ..policy import DSPolicy
5+
from ...model_implementations.diffusers.vae import DSVAE
56

67

78
class VAEPolicy(DSPolicy):
@@ -20,9 +21,11 @@ def __init__(self):
2021
def match(self, module):
2122
return isinstance(module, self._orig_layer_class)
2223

24+
def match_replaced(self, module):
25+
return isinstance(module, DSVAE)
26+
2327
def apply(self, module, enable_cuda_graph=True):
2428
# TODO(cmikeh2): Enable cuda graph should be an inference configuration
25-
from ...model_implementations.diffusers.vae import DSVAE
2629
return DSVAE(module, enable_cuda_graph=enable_cuda_graph)
2730

2831
# NOTE (lekurile): Should we have a diffusers policy class?

0 commit comments

Comments
 (0)