From 9fae90fa1d3ced0ae49cb98d508ea88a7ff372d5 Mon Sep 17 00:00:00 2001
From: Shivaen
Date: Sun, 16 Apr 2023 20:48:05 -0700
Subject: [PATCH] initial batched CoCa implementation

---
 src/open_clip/coca_model.py | 136 +++++++++++++++++++++---------------
 1 file changed, 79 insertions(+), 57 deletions(-)

diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 039453af7..5c739d0b2 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -167,6 +167,7 @@ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=Non
     def generate(
         self,
         image,
+        device,
         text=None,
         seq_len=30,
         max_seq_len=77,
@@ -182,12 +183,18 @@ def generate(
         min_seq_len=5,
         stopping_criteria=None,
         repetition_penalty=1.0,
-        fixed_output_length=False # if True output.shape == (batch_size, seq_len)
+        fixed_output_length=False, # if True output.shape == (batch_size, seq_len)
+        batch_size=None,
     ):
         # taking many ideas and components from HuggingFace GenerationMixin
         # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
         assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
         assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
+
+        if batch_size is None:
+            batch_size = image.shape[0]
+
+        outputs = []
 
         with torch.no_grad():
             sot_token_id = 49406 if sot_token_id is None else sot_token_id
@@ -207,26 +214,27 @@ def generate(
                 stopping_criteria
             )
 
-            device = image.device
-
             if generation_type == "beam_search":
-                output = self._generate_beamsearch(
-                    image_inputs = image,
-                    pad_token_id=pad_token_id,
-                    eos_token_id=eos_token_id,
-                    sot_token_id=sot_token_id,
-                    num_beams=num_beams,
-                    num_beam_groups=num_beam_groups,
-                    min_seq_len=min_seq_len,
-                    stopping_criteria=stopping_criteria,
-                    logit_processor=logit_processor,
-                )
-                if fixed_output_length and output.shape[1] < seq_len:
-                    return torch.cat(
-                        (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
-                        dim=1
+                for i in range(0, image.shape[0], batch_size):
+                    output = self._generate_beamsearch(
+                        image_inputs = image[i:i+batch_size].to(device),
+                        pad_token_id=pad_token_id,
+                        eos_token_id=eos_token_id,
+                        sot_token_id=sot_token_id,
+                        num_beams=num_beams,
+                        num_beam_groups=num_beam_groups,
+                        min_seq_len=min_seq_len,
+                        stopping_criteria=stopping_criteria,
+                        logit_processor=logit_processor,
                     )
-                return output
+                    if fixed_output_length and output.shape[1] < seq_len:
+                        output = torch.cat(
+                            (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
+                            dim=1
+                        )
+                    outputs += output.cpu()
+
+                return nn.utils.rnn.pad_sequence(outputs, batch_first=True, padding_value=eos_token_id)
 
             elif generation_type == "top_p":
                 logit_warper = GENERATION_TYPES[generation_type](top_p)
@@ -238,54 +246,68 @@ def generate(
                     f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                 )
 
-            image_latent, image_embs = self._encode_image(image)
-
-            if text is None:
-                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
+            for i in range(0, image.shape[0], batch_size):
 
-            was_training = self.training
-            num_dims = len(text.shape)
+                images = image[i:i+batch_size].to(device)
 
-            if num_dims == 1:
-                text = text[None, :]
+                image_latent, image_embs = self._encode_image(images)
 
-            cur_len = text.shape[1]
-            self.eval()
-            out = text
-
-            while True:
-                x = out[:, -max_seq_len:]
-                cur_len = x.shape[1]
-                logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
-                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
-                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
-
-                if mask.all():
-                    if not fixed_output_length:
-                        break
+                if text is None:
+                    texts = torch.ones((images.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
                 else:
-                    logits = logits[~mask, :]
-                    filtered_logits = logit_processor(x[~mask, :], logits)
-                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
-                    probs = F.softmax(filtered_logits / temperature, dim=-1)
-
-                    if (cur_len + 1 == seq_len):
-                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
+                    texts = text[:]
+
+                was_training = self.training
+                num_dims = len(texts.shape)
+
+                if num_dims == 1:
+                    texts = texts[None, :]
+                elif text is not None:
+                    texts = texts[i:i+batch_size]
+
+                cur_len = texts.shape[1]
+                self.eval()
+                out = texts.to(device)
+
+                while True:
+                    x = out[:, -max_seq_len:]
+                    cur_len = x.shape[1]
+
+                    logits = self(images, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
+                    mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
+                    sample = torch.ones((images.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
+
+                    if mask.all():
+                        if not fixed_output_length:
+                            break
                     else:
-                        sample[~mask, :] = torch.multinomial(probs, 1)
+                        logits = logits[~mask, :]
+                        filtered_logits = logit_processor(x[~mask, :], logits)
+                        filtered_logits = logit_warper(x[~mask, :], filtered_logits)
+                        probs = F.softmax(filtered_logits / temperature, dim=-1)
 
-                out = torch.cat((out, sample), dim=-1)
+                        if (cur_len + 1 == seq_len):
+                            sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
+                        else:
+                            sample[~mask, :] = torch.multinomial(probs, 1)
 
-                cur_len += 1
+                    out = torch.cat((out, sample), dim=-1)
+
+                    cur_len += 1
+
+                    if stopping_criteria(out, None):
+                        break
 
-                if stopping_criteria(out, None):
-                    break
+                outputs += out.cpu()
 
-            if num_dims == 1:
-                out = out.squeeze(0)
+            self.train(was_training)
+
+            outputs = nn.utils.rnn.pad_sequence(outputs, batch_first=True, padding_value=eos_token_id)
 
-            self.train(was_training)
-            return out
+            if num_dims == 1:
+                return outputs.squeeze(0)
+
+            return outputs
 
     def _generate_beamsearch(
         self,
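Usage sketch for reviewers (not part of the patch): with this change `generate()` takes the target `device` explicitly plus an optional `batch_size`, runs generation one chunk of images at a time, and pads the per-image outputs to a common length with the EOS token via `nn.utils.rnn.pad_sequence`. A minimal example follows; the CoCa config, checkpoint tag, and image paths are placeholders, not something the patch prescribes.

```python
import torch
import open_clip
from PIL import Image

# Placeholder CoCa config/checkpoint; any CoCa weights behave the same way.
model, _, transform = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k",
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# The full image stack can stay on the CPU; generate() moves one chunk of
# `batch_size` images to `device` at a time.
paths = ["cat.jpg", "dog.jpg", "bird.jpg"]  # hypothetical image files
images = torch.stack([transform(Image.open(p).convert("RGB")) for p in paths])

with torch.no_grad():
    generated = model.generate(images, device, batch_size=2)

# Rows come back padded to a common length with the EOS token, so each
# caption can be decoded independently.
for row in generated:
    caption = open_clip.decode(row)
    print(caption.split("<end_of_text>")[0].replace("<start_of_text>", ""))
```

Keeping the image stack on the host and only moving each slice inside `generate()` is what lets a large evaluation set be captioned without holding every image (or every beam) on the GPU at once; `batch_size=None` preserves the old single-chunk behavior.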