support pytorch2 scaled_dot_product_attention
baofff committed Mar 22, 2023
1 parent 4450333 commit fe3a069
Showing 2 changed files with 40 additions and 22 deletions.
31 changes: 20 additions & 11 deletions libs/uvit_multi_post_ln.py
@@ -5,15 +5,17 @@
 import einops
 import torch.utils.checkpoint
 import torch.nn.functional as F
-# the xformers lib allows less memory, faster training and inference
-try:
-    import xformers
-    import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-    print('xformers enabled')
-except:
-    XFORMERS_IS_AVAILBLE = False
-    print('xformers disabled')
+
+if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+    ATTENTION_MODE = 'flash'
+else:
+    try:
+        import xformers
+        import xformers.ops
+        ATTENTION_MODE = 'xformers'
+    except:
+        ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
 
 
 def timestep_embedding(timesteps, dim, max_period=10000):
@@ -73,19 +75,26 @@ def forward(self, x):
         B, L, C = x.shape
 
         qkv = self.qkv(x)
-        if XFORMERS_IS_AVAILBLE:  # the xformers lib allows less memory, faster training and inference
+        if ATTENTION_MODE == 'flash':
+            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
+            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
+            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        elif ATTENTION_MODE == 'xformers':
             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
             q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
             x = xformers.ops.memory_efficient_attention(q, k, v)
             x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
-        else:
+        elif ATTENTION_MODE == 'math':
             with torch.amp.autocast(device_type='cuda', enabled=False):
                 qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
                 q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
                 attn = (q @ k.transpose(-2, -1)) * self.scale
                 attn = attn.softmax(dim=-1)
                 attn = self.attn_drop(attn)
                 x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+        else:
+            raise NotImplementedError
 
         x = self.proj(x)
         x = self.proj_drop(x)
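
Both files gain the same backend-selection block at import time: prefer PyTorch 2's built-in scaled_dot_product_attention ('flash'), fall back to xformers if it is installed, and otherwise use the plain matmul ('math') path. A minimal sketch of that selection logic in isolation follows; pick_attention_mode is a hypothetical helper name, not part of the commit, and PyTorch itself is assumed to be importable.

# Hypothetical helper mirroring the selection logic added in the diff above.
import torch.nn.functional as F

def pick_attention_mode() -> str:
    if hasattr(F, 'scaled_dot_product_attention'):  # public API since PyTorch 2.0
        return 'flash'
    try:
        import xformers.ops  # noqa: F401
        return 'xformers'
    except ImportError:
        return 'math'

print(pick_attention_mode())  # 'flash' on a PyTorch 2.x install

The commit stores the result in a module-level ATTENTION_MODE constant rather than calling a helper, which avoids re-probing the environment on every forward pass.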
31 changes: 20 additions & 11 deletions libs/uvit_multi_post_ln_v1.py
@@ -5,15 +5,17 @@
 import einops
 import torch.utils.checkpoint
 import torch.nn.functional as F
-# the xformers lib allows less memory, faster training and inference
-try:
-    import xformers
-    import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
-    print('xformers enabled')
-except:
-    XFORMERS_IS_AVAILBLE = False
-    print('xformers disabled')
+
+if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+    ATTENTION_MODE = 'flash'
+else:
+    try:
+        import xformers
+        import xformers.ops
+        ATTENTION_MODE = 'xformers'
+    except:
+        ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
 
 
 def timestep_embedding(timesteps, dim, max_period=10000):
@@ -73,19 +75,26 @@ def forward(self, x):
         B, L, C = x.shape
 
         qkv = self.qkv(x)
-        if XFORMERS_IS_AVAILBLE:  # the xformers lib allows less memory, faster training and inference
+        if ATTENTION_MODE == 'flash':
+            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
+            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
+            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        elif ATTENTION_MODE == 'xformers':
             qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
             q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
             x = xformers.ops.memory_efficient_attention(q, k, v)
             x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
-        else:
+        elif ATTENTION_MODE == 'math':
             with torch.amp.autocast(device_type='cuda', enabled=False):
                 qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
                 q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
                 attn = (q @ k.transpose(-2, -1)) * self.scale
                 attn = attn.softmax(dim=-1)
                 attn = self.attn_drop(attn)
                 x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+        else:
+            raise NotImplementedError
 
         x = self.proj(x)
         x = self.proj_drop(x)
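
The forward() change routes the same q, k, v tensors (layout B H L D) to whichever backend was detected. A small sanity check follows, not part of the commit, showing why the 'math' branch computes the same result as the new 'flash' branch: with no dropout and no mask, torch.nn.functional.scaled_dot_product_attention defaults to the scale 1/sqrt(head_dim), which is what self.scale is assumed to hold here (the standard ViT choice).

# Sanity-check sketch: PyTorch 2's fused kernel vs. the explicit fallback math.
import math
import torch
import torch.nn.functional as F

B, H, L, D = 2, 4, 16, 32                               # batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

out_flash = F.scaled_dot_product_attention(q, k, v)     # the 'flash' path

scale = 1.0 / math.sqrt(D)                              # assumed value of self.scale
attn = ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1)  # B H L L
out_math = attn @ v                                     # the 'math' path, B H L D

print(torch.allclose(out_flash, out_math, atol=1e-5))   # expected: True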
