diff --git a/dalle_pytorch/attention.py b/dalle_pytorch/attention.py
index 48131e89..93720331 100644
--- a/dalle_pytorch/attention.py
+++ b/dalle_pytorch/attention.py
@@ -37,7 +37,8 @@ def apply_pos_emb(pos_emb, qkv):
 # classes
 
 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
+                 static_mask = None):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -46,6 +47,7 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
 
         self.stable = stable
         self.causal = causal
+        self.register_buffer('static_mask', static_mask, persistent=False)
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -53,18 +55,26 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
+        offset = cache.get('offset', 0) if exists(cache) else 0
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
-            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
+            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))
 
         q = q * self.scale
 
+        if offset > 0:
+            k_top, v_top = cache[cache_key]
+            k = torch.cat([k_top, k], dim=-2)
+            v = torch.cat([v_top, v], dim=-2)
+        if exists(cache):
+            cache[cache_key] = k, v
+
         dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
         mask_value = max_neg_value(dots)
 
@@ -73,11 +83,14 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        if self.causal:
+        if self.causal and offset == 0: # causality is naturally enforced for the cached inference
             i, j = dots.shape[-2:]
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)
 
+        if exists(self.static_mask):
+            dots.masked_fill_(~self.static_mask[offset:offset + n, :offset + n], mask_value)
+
         attn = softmax(dots, dim=-1)
 
         out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index 2cd18752..75147577 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -68,6 +68,20 @@ def top_k(logits, thres = 0.5):
     probs.scatter_(1, ind, val)
     return probs
 
+class SharedEmbedding(nn.Embedding):
+    def __init__(self, linear, start_index, end_index, **kwargs):
+        super().__init__(end_index - start_index, linear.weight.shape[1], **kwargs)
+        del self.weight
+
+        self.linear = linear
+        self.start_index = start_index
+        self.end_index = end_index
+
+    def forward(self, input):
+        return F.embedding(
+            input, self.linear.weight[self.start_index:self.end_index], self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
 # discrete vae class
 
 class ResBlock(nn.Module):
@@ -339,7 +353,11 @@ def __init__(
         stable = False,
         sandwich_norm = False,
         shift_tokens = True,
-        rotary_emb = True
+        rotary_emb = True,
+        shared_attn_ids = None,
+        shared_ff_ids = None,
+        share_input_output_emb = False,
+        optimize_for_inference = False,
     ):
         super().__init__()
         assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -351,9 +369,6 @@ def __init__(
 
         num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len)
 
-        self.text_emb = nn.Embedding(num_text_tokens, dim)
-        self.image_emb = nn.Embedding(num_image_tokens, dim)
-
         self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
         self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)
 
@@ -387,7 +402,10 @@ def __init__(
             stable = stable,
             sandwich_norm = sandwich_norm,
             shift_tokens = shift_tokens,
-            rotary_emb = rotary_emb
+            rotary_emb = rotary_emb,
+            shared_attn_ids = shared_attn_ids,
+            shared_ff_ids = shared_ff_ids,
+            optimize_for_inference = optimize_for_inference,
         )
 
         self.stable = stable
@@ -400,6 +418,13 @@ def __init__(
             nn.Linear(dim, self.total_tokens),
         )
 
+        if share_input_output_emb:
+            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
+            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
+        else:
+            self.text_emb = nn.Embedding(num_text_tokens, dim)
+            self.image_emb = nn.Embedding(num_image_tokens, dim)
+
         seq_range = torch.arange(seq_len)
         logits_range = torch.arange(total_tokens)
 
@@ -430,7 +455,7 @@ def generate_texts(
             text_tokens = torch.tensor([[0]]).cuda()
         else:
             text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0)
-        
+
         for _ in range(text_tokens.shape[1], text_seq_len):
             device = text_tokens.device
 
@@ -457,7 +482,7 @@ def generate_texts(
             sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
 
             text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)
-        
+
         padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
         texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
         return text_tokens, texts
@@ -473,7 +498,8 @@ def generate_images(
         temperature = 1.,
         img = None,
         num_init_img_tokens = None,
-        cond_scale = 1.
+        cond_scale = 1.,
+        use_cache = False,
     ):
         vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
         total_len = text_seq_len + image_seq_len
@@ -492,17 +518,23 @@ def generate_images(
             indices = indices[:, :num_img_tokens]
             out = torch.cat((out, indices), dim = -1)
 
+        prev_cache = None
+        cache = {} if use_cache else None
         for cur_len in range(out.shape[1], total_len):
             is_image = cur_len >= text_seq_len
 
             text, image = out[:, :text_seq_len], out[:, text_seq_len:]
 
-            logits = self(text, image)
+            if cond_scale != 1 and use_cache:
+                # copy the cache state to infer from the same place twice
+                prev_cache = cache.copy()
+
+            logits = self(text, image, cache = cache)
 
             if cond_scale != 1:
                 # discovery by Katherine Crowson
                 # https://twitter.com/RiversHaveWings/status/1478093658716966912
-                null_cond_logits = self(text, image, null_cond_prob = 1.)
+                null_cond_logits = self(text, image, null_cond_prob = 1., cache = prev_cache)
                 logits = null_cond_logits + (logits - null_cond_logits) * cond_scale
 
             logits = logits[:, -1, :]
@@ -529,7 +561,8 @@ def forward(
         text,
         image = None,
         return_loss = False,
-        null_cond_prob = 0.
+        null_cond_prob = 0.,
+        cache = None,
     ):
         assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
         batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len
@@ -583,7 +616,9 @@ def forward(
             alpha = 0.1
             tokens = tokens * alpha + tokens.detach() * (1 - alpha)
 
-        out = self.transformer(tokens)
+        if exists(cache) and cache.get('offset'):
+            tokens = tokens[:, -1:]
+        out = self.transformer(tokens, cache=cache)
 
         if self.stable:
             out = self.norm_by_max(out)
@@ -593,9 +628,14 @@ def forward(
         # mask logits to make sure text predicts text (except last token), and image predicts image
 
         logits_mask = self.logits_mask[:, :seq_len]
+        if exists(cache) and cache.get('offset'):
+            logits_mask = logits_mask[:, -1:]
         max_neg_value = -torch.finfo(logits.dtype).max
         logits.masked_fill_(logits_mask, max_neg_value)
 
+        if exists(cache):
+            cache['offset'] = cache.get('offset', 0) + logits.shape[1]
+
         if not return_loss:
             return logits
 
diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index 749c993d..46735ab3 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -1,3 +1,5 @@
+from collections import deque
+from collections.abc import Iterable
 from functools import partial
 from itertools import islice, cycle
 
@@ -21,9 +23,7 @@ def default(val, d):
     return val if exists(val) else d
 
 def cast_tuple(val, depth = 1):
-    if isinstance(val, list):
-        val = tuple(val)
-    return val if isinstance(val, tuple) else (val,) * depth
+    return val if isinstance(val, Iterable) else (val,) * depth
 
 # classes
 
@@ -36,6 +36,41 @@ def forward(self, x):
         maxes = x.amax(dim = self.dim, keepdim = True).detach()
         return x / maxes
 
+class NonCached(nn.Module):
+    """
+    A wrapper for layers that don't support the inference cache themselves.
+    Reconstructs the full sequence before the layer and
+    cuts the suffix of the outputs after the layer.
+    """
+
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, *, cache = None, cache_key = None, **kwargs):
+        n = x.shape[-2]
+        if exists(cache):
+            if cache_key in cache:
+                x = torch.cat([cache[cache_key], x], dim=-2)
+            cache[cache_key] = x
+
+        out = self.fn(x, **kwargs)
+
+        return out[:, -n:]
+
+class CachedAs(nn.Module):
+    """
+    A wrapper that defines a key for the inference cache.
+    """
+
+    def __init__(self, cache_key, fn):
+        super().__init__()
+        self.cache_key = cache_key
+        self.fn = fn
+
+    def forward(self, x, *, cache=None, **kwargs):
+        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
+
 # https://arxiv.org/abs/2103.17239
 class LayerScale(nn.Module):
     def __init__(self, dim, depth, fn):
@@ -84,7 +119,7 @@ def __init__(self, dim, dropout = 0., mult = 4.):
             nn.Linear(dim * mult, dim)
         )
 
-    def forward(self, x):
+    def forward(self, x, cache=None, cache_key=None):
         return self.net(x)
 
 # token shift classes
 
@@ -95,12 +130,30 @@ def __init__(self, fn, image_size, seq_len):
         self.fn = fn
         self.image_size = image_size
         self.seq_len = seq_len
+        self.img_seq_len = image_size ** 2
+        self.text_len = seq_len - self.img_seq_len + 1
+
+    def forward(self, x, cache=None, cache_key=None, **kwargs):
+        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len
+
+        if exists(cache) and cache_key in cache:
+            offset = cache['offset']
+            assert offset >= text_len, "cached inference for text is not supported"
+            q = cache[cache_key]
+            assert isinstance(q, deque) and len(q) == image_size
+
+            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)
+
+            q.append((x_top, x_left))
+            x_top = q.popleft()[0]
+            x_left = q[-2][1]
+            if (offset - text_len) % image_size == 0:
+                x_left = torch.zeros_like(x_left)
+
+            x = torch.cat((x_top, x_left, *x_pass), dim=-1)
+            return self.fn(x[:, None], cache=cache, **kwargs)
 
-    def forward(self, x, **kwargs):
         n = x.shape[1]
-        seq_len, image_size = self.seq_len, self.image_size
-        img_seq_len = image_size ** 2
-        text_len = seq_len - img_seq_len + 1
         padding = seq_len - n + 1
 
         # if sequence is shorter than the text length, no image tokens to shift
@@ -130,8 +183,22 @@ def forward(self, x, **kwargs):
         # merge text and image sequence back together
 
         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
-        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x, **kwargs)
+        x_img = x_img[:, :-padding]
+        x = torch.cat((x_text, x_img), dim = 1)
+
+        if exists(cache):
+            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
+            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)
+
+            q = deque()
+            x_img = x_img[:, -image_size:]
+            for _ in range(image_size - x_img.shape[1]):
+                q.append((dummy_top, dummy_left))
+            for i in range(x_img.shape[1]):
+                q.append(x_img[:, i].chunk(4, dim=-1)[:2])
+            cache[cache_key] = q
+
+        return self.fn(x, cache=cache, **kwargs)
 
 # main transformer class
 
@@ -155,25 +222,43 @@ def __init__(
         stable = False,
         sandwich_norm = False,
         shift_tokens = False,
-        rotary_emb = True
+        rotary_emb = True,
+        shared_attn_ids = None,
+        shared_ff_ids = None,
+        optimize_for_inference = False, # use cache-friendly masked attention instead of sparse one
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)
 
+        self.seq_len = seq_len
+        self.image_fmap_size = image_fmap_size
+
         attn_types = default(attn_types, ('full',))
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)
 
-        for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
+        shared_attn_ids = cycle(default(shared_attn_ids, range(depth)))
+        shared_ff_ids = cycle(default(shared_ff_ids, range(depth)))
+        shared_attn_layers = {}
+        shared_ff_layers = {}
+
+        for (ind, sparse_attn, attn_type, attn_id, ff_id) in \
+                zip(range(depth), sparse_layer, attn_type_layer, shared_attn_ids, shared_ff_ids):
             if attn_type == 'full':
                 attn_class = partial(Attention, stable = stable)
             elif attn_type == 'sparse':
                 attn_class = SparseAttention
             elif attn_type == 'axial_row':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
+                if optimize_for_inference:
+                    attn_class = partial(Attention, stable = stable, static_mask = self._get_attention_mask(attn_type))
+                else:
+                    attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'axial_col':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
+                if optimize_for_inference:
+                    attn_class = partial(Attention, stable = stable, static_mask = self._get_attention_mask(attn_type))
+                else:
+                    attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
                 attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'mlp':
@@ -181,15 +266,31 @@ def __init__(
             else:
                 raise ValueError(f'attention type "{attn_type}" is not valid')
 
-            if attn_type != 'mlp':
-                attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+            attn, reused_attn_type = shared_attn_layers.get(attn_id, (None, None))
+            if not exists(attn):
+                if attn_type != 'mlp':
+                    attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+                else:
+                    attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+                shared_attn_layers[attn_id] = (attn, attn_type)
+            elif attn_type != reused_attn_type:
+                raise ValueError('attn_types do not match shared_attn_ids '
+                                 f'(ind = {ind}, attn_type = "{attn_type}", reused_attn_type = "{reused_attn_type}")')
+
+            ff = shared_ff_layers.get(ff_id)
+            if not exists(ff):
+                ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+                shared_ff_layers[ff_id] = ff
+
+            if isinstance(attn, Attention):
+                attn = CachedAs(f'attn_{ind}', attn)
             else:
-                attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
-
-            ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+                # at the moment, other attention classes don't support cache
+                attn = NonCached(attn)
 
             if shift_tokens:
-                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
+                attn = CachedAs(f'preshift_attn_{ind}', PreShiftToken(attn, image_size = image_fmap_size, seq_len = seq_len))
+                ff = CachedAs(f'preshift_ff_{ind}', PreShiftToken(ff, image_size = image_fmap_size, seq_len = seq_len))
 
             layers.append(nn.ModuleList([
                 LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
@@ -198,7 +299,9 @@ def __init__(
 
         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
-        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}
+        route_all = ((True, True),) * depth
+        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
+                          'cache': route_all}
 
         self.layers = execute_type(layers, args_route = attn_route_map)
 
@@ -234,3 +337,22 @@ def __init__(
 
     def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
+
+    def _get_attention_mask(self, attn_type):
+        img_seq_len = self.image_fmap_size ** 2
+        text_len = self.seq_len + 1 - img_seq_len
+
+        static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool)
+        static_mask[:, :text_len] = True
+        if attn_type == 'axial_row':
+            for row in range(self.image_fmap_size):
+                begin = text_len + row * self.image_fmap_size
+                end = text_len + (row + 1) * self.image_fmap_size
+                static_mask[begin:end, begin:end] = True
+        elif attn_type == 'axial_col':
+            for col in range(self.image_fmap_size):
+                begin = text_len + col
+                static_mask[begin::self.image_fmap_size, begin::self.image_fmap_size] = True
+        else:
+            raise ValueError(f'attention type "{attn_type}" can\'t be simulated with a static mask')
+        return static_mask
diff --git a/train_dalle.py b/train_dalle.py
index c3d30865..5de6f9b6 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -131,6 +131,12 @@
 
 model_group.add_argument('--rotary_emb', help = 'Use rotary embeddings', action = 'store_true')
 
+model_group.add_argument('--shared_attn_ids', default = None, type = str, help = 'Comma separated list of shared attention layer ids. Default: sharing is disabled')
+
+model_group.add_argument('--shared_ff_ids', default = None, type = str, help = 'Comma separated list of shared feed forward layer ids. Default: sharing is disabled')
+
+model_group.add_argument('--share_input_output_emb', help = 'Share input and output embeddings', action = 'store_true')
+
 args = parser.parse_args()
 
 # helpers
@@ -193,6 +199,9 @@ def cp_path_to_dir(cp_path, tag):
 
 ROTARY_EMB = args.rotary_emb
 ATTN_TYPES = tuple(args.attn_types.split(','))
+SHARED_ATTN_IDS = tuple(args.shared_attn_ids.split(',')) if exists(args.shared_attn_ids) else None
+SHARED_FF_IDS = tuple(args.shared_ff_ids.split(',')) if exists(args.shared_ff_ids) else None
+SHARE_INPUT_OUTPUT_EMB = args.share_input_output_emb
 
 DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
 
@@ -304,6 +313,9 @@ def cp_path_to_dir(cp_path, tag):
     stable=STABLE,
     shift_tokens=SHIFT_TOKENS,
     rotary_emb=ROTARY_EMB,
+    shared_attn_ids=SHARED_ATTN_IDS,
+    shared_ff_ids=SHARED_FF_IDS,
+    share_input_output_emb=SHARE_INPUT_OUTPUT_EMB,
 )
 
 resume_epoch = 0
@@ -368,7 +380,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
         if myimg not in item:
             return False
         return True
-    
+
     w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
     filtered_dataset = w_dataset.select(filter_dataset)
     ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE / distr_backend.get_world_size(), partial=True)
@@ -623,7 +635,7 @@ def save_artifact(model_config, model_path, name = 'trained-dalle'):
 
         if i % SAVE_EVERY_N_STEPS == 0:
             save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
-        
+
         if i % 100 == 0 and is_root:
             sample_text = text[:1]
             token_list = sample_text.masked_select(sample_text != 0).tolist()
@@ -651,7 +663,7 @@ def save_artifact(model_config, model_path, name = 'trained-dalle'):
         distr_scheduler.step(avg_loss)
 
     save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
-    
+
     if is_root:
         # save trained model to wandb as an artifact every epoch's end
         save_artifact(model_config, DALLE_OUTPUT_FILE_NAME)
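
Usage sketch (illustrative, not part of the patch itself): the snippet below shows how the options introduced here could be combined; the model sizes, layer-sharing ids, and sampling call are assumptions chosen only to demonstrate the API.

    import torch
    from dalle_pytorch import DiscreteVAE, DALLE

    # small, untrained VAE purely for illustration
    vae = DiscreteVAE(
        image_size = 256,
        num_layers = 3,
        num_tokens = 8192,
        codebook_dim = 512,
        hidden_dim = 64
    )

    dalle = DALLE(
        dim = 512,
        vae = vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth = 4,
        heads = 8,
        dim_head = 64,
        attn_types = ('full', 'axial_row'),
        shared_attn_ids = (0, 1, 0, 1),    # layers with the same id reuse one attention module; ids must group identical attn_types
        shared_ff_ids = (0, 1, 0, 1),      # likewise for the feed forward blocks
        share_input_output_emb = True,     # tie the token embeddings to the output projection via SharedEmbedding
        optimize_for_inference = True      # replace axial sparse attention with statically masked full attention
    )

    text = torch.randint(0, 10000, (1, 256))

    # cached sampling: after the first step, only the newest token is fed through the transformer
    images = dalle.generate_images(text, use_cache = True)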