Implement cached inference #409

Merged
merged 34 commits on Jan 11, 2022
Commits (34)
44775fc
Implement weight sharing in transformer
borzunov Oct 29, 2021
e10096e
Pass sharing args from CLI
borzunov Oct 29, 2021
d7c034e
Improve checking reused attn type
borzunov Oct 29, 2021
f2a53e9
Fix view errors in sparse attention
borzunov Nov 2, 2021
14eb932
Implement share_input_output_emb option
borzunov Nov 8, 2021
c9f462a
Revert excess changes
borzunov Nov 9, 2021
1cd8e20
Add FFN caching
borzunov Dec 20, 2021
4d431ac
Cache to_qkv and to_out in sparse attn, add debug prints
borzunov Dec 20, 2021
2603776
Cache full Attention
borzunov Dec 20, 2021
112ea05
Remove debug outputs
borzunov Dec 20, 2021
6ba4cb6
Cache pre-logits MLP
borzunov Dec 20, 2021
c333ea7
Further optimize attention caching
borzunov Dec 20, 2021
1fd45ca
Fix mask in attention
borzunov Dec 21, 2021
2b77018
Don't cache MLPs since we can just pass only last item
borzunov Dec 21, 2021
8e8dea8
Revert excess changes in attentions
borzunov Dec 21, 2021
df89951
Rename FixCacheKey -> CachedAs
borzunov Dec 21, 2021
059fe1b
Save the current offset in cache
borzunov Dec 21, 2021
b76b78e
Use static masks to simulate axial attn
borzunov Dec 21, 2021
4c833a2
Add option to disable caching
borzunov Dec 21, 2021
1ff47c6
Make the cached version work
borzunov Dec 21, 2021
94fda36
Speed up PreShiftToken
borzunov Dec 21, 2021
adfce34
Remove excess changes
borzunov Jan 10, 2022
59cfc49
Add NonCached wrapper
borzunov Jan 10, 2022
4f496a4
Rename use_static_masks -> optimize_for_inference
borzunov Jan 10, 2022
732226d
Improve names and comments
borzunov Jan 10, 2022
c5f009a
Implement weight sharing in transformer
borzunov Oct 29, 2021
c57eae6
Pass sharing args from CLI
borzunov Oct 29, 2021
538c42a
Improve checking reused attn type
borzunov Oct 29, 2021
ec5b477
Fix view errors in sparse attention
borzunov Nov 2, 2021
543a3e4
Implement share_input_output_emb option
borzunov Nov 8, 2021
6de6cb0
Revert excess changes
borzunov Nov 9, 2021
f968697
Add CLI option for share_input_output_emb
borzunov Jan 10, 2022
1f517d7
Merge branch 'weight-sharing-v3' into inference-cache-v3
borzunov Jan 10, 2022
7f2f50c
Support cached inference with super-conditioning
borzunov Jan 10, 2022
21 changes: 17 additions & 4 deletions dalle_pytorch/attention.py
@@ -37,7 +37,8 @@ def apply_pos_emb(pos_emb, qkv):
# classes

class Attention(nn.Module):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
static_mask = None):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
@@ -46,25 +47,34 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou

self.stable = stable
self.causal = causal
self.register_buffer('static_mask', static_mask, persistent=False)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x, mask = None, rotary_pos_emb = None):
def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
b, n, _, h, device = *x.shape, self.heads, x.device
softmax = torch.softmax if not self.stable else stable_softmax
offset = cache.get('offset', 0) if exists(cache) else 0

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

q = q * self.scale

if offset > 0:
k_top, v_top = cache[cache_key]
k = torch.cat([k_top, k], dim=-2)
v = torch.cat([v_top, v], dim=-2)
if exists(cache):
cache[cache_key] = k, v

dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
mask_value = max_neg_value(dots)

@@ -73,11 +83,14 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
dots.masked_fill_(~mask, mask_value)
del mask

if self.causal:
if self.causal and offset == 0: # causality is naturally enforced for the cached inference
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
dots.masked_fill_(mask, mask_value)

if exists(self.static_mask):
dots.masked_fill_(~self.static_mask[offset:offset + n, :offset + n], mask_value)

attn = softmax(dots, dim=-1)

out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
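
As a minimal illustration of the pattern the Attention changes above apply (not code from this PR): the cache holds each layer's keys/values, every later step only projects the newly generated token, and its k/v are appended to the cache before attending. The helper name toy_cached_attention, the single-head layout, and the absence of masking/dropout are simplifications for clarity. Causality comes for free on cached steps because the only query is the newest position, which may attend to every cached key, matching the "causality is naturally enforced for the cached inference" comment in the diff.

import torch

def toy_cached_attention(x, to_qkv, cache=None, cache_key='attn'):
    # x: (batch, new_tokens, dim); to_qkv: nn.Linear(dim, 3 * dim, bias=False)
    q, k, v = to_qkv(x).chunk(3, dim=-1)
    if cache is not None and cache_key in cache:
        k_top, v_top = cache[cache_key]       # keys/values from earlier steps
        k = torch.cat([k_top, k], dim=-2)     # extend along the sequence axis
        v = torch.cat([v_top, v], dim=-2)
    if cache is not None:
        cache[cache_key] = (k, v)             # store for the next step
    scale = q.shape[-1] ** -0.5
    dots = torch.einsum('b i d, b j d -> b i j', q * scale, k)
    attn = dots.softmax(dim=-1)
    return torch.einsum('b i j, b j d -> b i d', attn, v)

# usage: feed the prompt once, then one token at a time
dim = 16
to_qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
cache = {}
prompt = torch.randn(1, 5, dim)
out = toy_cached_attention(prompt, to_qkv, cache)    # processes 5 tokens, caches k/v
next_tok = torch.randn(1, 1, dim)
out = toy_cached_attention(next_tok, to_qkv, cache)  # only the new token is projected
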
64 changes: 52 additions & 12 deletions dalle_pytorch/dalle_pytorch.py
@@ -68,6 +68,20 @@ def top_k(logits, thres = 0.5):
probs.scatter_(1, ind, val)
return probs

class SharedEmbedding(nn.Embedding):
def __init__(self, linear, start_index, end_index, **kwargs):
super().__init__(end_index - start_index, linear.weight.shape[1], **kwargs)
del self.weight

self.linear = linear
self.start_index = start_index
self.end_index = end_index

def forward(self, input):
return F.embedding(
input, self.linear.weight[self.start_index:self.end_index], self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)

# discrete vae class

class ResBlock(nn.Module):
@@ -339,7 +353,11 @@ def __init__(
stable = False,
sandwich_norm = False,
shift_tokens = True,
rotary_emb = True
rotary_emb = True,
shared_attn_ids = None,
shared_ff_ids = None,
share_input_output_emb = False,
optimize_for_inference = False,
):
super().__init__()
assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -351,9 +369,6 @@ def __init__(

num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len)

self.text_emb = nn.Embedding(num_text_tokens, dim)
self.image_emb = nn.Embedding(num_image_tokens, dim)

self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)

@@ -387,7 +402,10 @@ def __init__(
stable = stable,
sandwich_norm = sandwich_norm,
shift_tokens = shift_tokens,
rotary_emb = rotary_emb
rotary_emb = rotary_emb,
shared_attn_ids = shared_attn_ids,
shared_ff_ids = shared_ff_ids,
optimize_for_inference = optimize_for_inference,
)

self.stable = stable
@@ -400,6 +418,13 @@ def __init__(
nn.Linear(dim, self.total_tokens),
)

if share_input_output_emb:
self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
else:
self.text_emb = nn.Embedding(num_text_tokens, dim)
self.image_emb = nn.Embedding(num_image_tokens, dim)

seq_range = torch.arange(seq_len)
logits_range = torch.arange(total_tokens)
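
As a minimal illustration of what share_input_output_emb does (toy sizes, a bare nn.Linear standing in for to_logits; not code from this PR): the SharedEmbedding instances created above look up their rows directly from the weight of the final linear layer, so the input embeddings and the output projection share parameters instead of each token type owning a separate embedding table.

import torch
import torch.nn as nn
import torch.nn.functional as F

num_text, num_image, dim = 10, 6, 8
to_logits = nn.Linear(dim, num_text + num_image)   # output head over all token logits

# image-token embeddings are looked up from the image rows of the output head's weight
image_ids = torch.tensor([[0, 3, 5]])
image_emb = F.embedding(image_ids, to_logits.weight[num_text:num_text + num_image])
assert image_emb.shape == (1, 3, dim)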

@@ -430,7 +455,7 @@ def generate_texts(
text_tokens = torch.tensor([[0]]).cuda()
else:
text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0)

for _ in range(text_tokens.shape[1], text_seq_len):
device = text_tokens.device

@@ -457,7 +482,7 @@ def generate_texts(
sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
return text_tokens, texts
@@ -473,7 +498,8 @@ def generate_images(
temperature = 1.,
img = None,
num_init_img_tokens = None,
cond_scale = 1.
cond_scale = 1.,
use_cache = False,
):
vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
total_len = text_seq_len + image_seq_len
@@ -492,17 +518,23 @@ def generate_images(
indices = indices[:, :num_img_tokens]
out = torch.cat((out, indices), dim = -1)

prev_cache = None
cache = {} if use_cache else None
for cur_len in range(out.shape[1], total_len):
is_image = cur_len >= text_seq_len

text, image = out[:, :text_seq_len], out[:, text_seq_len:]

logits = self(text, image)
if cond_scale != 1 and use_cache:
# copy the cache state to infer from the same place twice
prev_cache = cache.copy()

logits = self(text, image, cache = cache)

if cond_scale != 1:
# discovery by Katherine Crowson
# https://twitter.com/RiversHaveWings/status/1478093658716966912
null_cond_logits = self(text, image, null_cond_prob = 1.)
null_cond_logits = self(text, image, null_cond_prob = 1., cache = prev_cache)
logits = null_cond_logits + (logits - null_cond_logits) * cond_scale

logits = logits[:, -1, :]
@@ -529,7 +561,8 @@ def forward(
text,
image = None,
return_loss = False,
null_cond_prob = 0.
null_cond_prob = 0.,
cache = None,
):
assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len
@@ -583,7 +616,9 @@ def forward(
alpha = 0.1
tokens = tokens * alpha + tokens.detach() * (1 - alpha)

out = self.transformer(tokens)
if exists(cache) and cache.get('offset'):
tokens = tokens[:, -1:]
out = self.transformer(tokens, cache=cache)

if self.stable:
out = self.norm_by_max(out)
@@ -593,9 +628,14 @@ def forward(
# mask logits to make sure text predicts text (except last token), and image predicts image

logits_mask = self.logits_mask[:, :seq_len]
if exists(cache) and cache.get('offset'):
logits_mask = logits_mask[:, -1:]
max_neg_value = -torch.finfo(logits.dtype).max
logits.masked_fill_(logits_mask, max_neg_value)

if exists(cache):
cache['offset'] = cache.get('offset', 0) + logits.shape[1]

if not return_loss:
return logits

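
As a minimal illustration of the cache bookkeeping above (the ToyDecoder class and its internals are invented for illustration, not the PR's API): once cache['offset'] is set, only the last position is fed through the transformer, the offset advances by the number of positions just processed, and generate_images copies the cache before the conditioned pass so the null-conditioned pass used for super-conditioning can resume from the same state.

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    # stand-in for the transformer stack; only the cache bookkeeping mirrors the PR
    def __init__(self, num_tokens=32, dim=16):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, tokens, cache=None):
        x = self.emb(tokens)
        if cache is not None and cache.get('offset'):
            x = x[:, -1:]                        # cached steps only compute the newest position
        # ... a real model would run cached attention blocks on x here ...
        logits = self.to_logits(x)
        if cache is not None:
            cache['offset'] = cache.get('offset', 0) + logits.shape[1]
        return logits

model = ToyDecoder()
out = torch.randint(0, 32, (1, 4))               # prompt tokens
cache, cond_scale, prev_cache = {}, 3.0, None
for _ in range(8):
    if cond_scale != 1:
        prev_cache = cache.copy()                # second pass must resume from the same offset
    logits = model(out, cache=cache)[:, -1]
    # a super-conditioned model would also run a null-conditioned pass with prev_cache here
    sample = logits.argmax(dim=-1, keepdim=True)
    out = torch.cat((out, sample), dim=-1)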