2
2
Copyright 2022 The Microsoft DeepSpeed Team
3
3
'''
4
4
import torch
5
+ from ..features .cuda_graph import CUDAGraph
5
6
6
7
7
- class DSVAE (torch .nn .Module ):
8
+ class DSVAE (CUDAGraph , torch .nn .Module ):
8
9
def __init__ (self , vae , enable_cuda_graph = True ):
9
- super ().__init__ ()
10
+ super ().__init__ (enable_cuda_graph = enable_cuda_graph )
10
11
self .vae = vae
11
12
self .device = self .vae .device
12
13
self .dtype = self .vae .dtype
13
14
self .vae .requires_grad_ (requires_grad = False )
14
15
self .decoder_cuda_graph_created = False
15
16
self .encoder_cuda_graph_created = False
16
17
self .all_cuda_graph_created = False
17
- self .enable_cuda_graph = enable_cuda_graph
18
18
19
19
def _graph_replay_decoder (self , * inputs , ** kwargs ):
20
20
for i in range (len (inputs )):
@@ -104,7 +104,7 @@ def encode(self, *inputs, **kwargs):
104
104
else :
105
105
return self ._encode (* inputs , ** kwargs )
106
106
107
- def _graph_replay_all (self , * inputs , ** kwargs ):
107
+ def _graph_replay (self , * inputs , ** kwargs ):
108
108
for i in range (len (inputs )):
109
109
if torch .is_tensor (inputs [i ]):
110
110
self .static_inputs [i ].copy_ (inputs [i ])
@@ -117,10 +117,10 @@ def _graph_replay_all(self, *inputs, **kwargs):
117
117
def forward (self , * inputs , ** kwargs ):
118
118
if self .enable_cuda_graph :
119
119
if self .cuda_graph_created :
120
- outputs = self ._graph_replay_all (* inputs , ** kwargs )
120
+ outputs = self ._graph_replay (* inputs , ** kwargs )
121
121
else :
122
122
self ._create_cuda_graph (* inputs , ** kwargs )
123
- outputs = self ._graph_replay_all (* inputs , ** kwargs )
123
+ outputs = self ._graph_replay (* inputs , ** kwargs )
124
124
return outputs
125
125
else :
126
126
return self ._forward (* inputs , ** kwargs )
0 commit comments