diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py
index d7e90e1a..60a81c45 100644
--- a/backends/python/server/text_embeddings_server/models/flash_bert.py
+++ b/backends/python/server/text_embeddings_server/models/flash_bert.py
@@ -323,19 +323,21 @@ def batch_type(self) -> Union[FlashBatch, PaddedBatch]:
     def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
         if isinstance(batch, PaddedBatch):
             input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-            max_input_lens = input_lens.max().item()
+            max_input_lens = 0  # This value will not be used
             cu_seqlens = torch.cat(
                 (input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
             )
             mask = batch.attention_mask.bool()
-            batch_size = input_lens.size(0)
+            bsz, tgt_len = mask.size()
+            min_val = torch.finfo(self.dtype).min
             attn_mask = torch.full(
-                [batch_size, 1, 1, mask.shape[-1]],
-                fill_value=torch.finfo(self.dtype).min,
+                [bsz, 1, tgt_len, tgt_len],
+                fill_value=min_val,
                 device=self.device,
                 dtype=self.dtype,
             )
-            attn_mask.masked_fill_(mask[:, None, None, :], 0)
+            expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
+            attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)
         elif isinstance(batch, FlashBatch):
             cu_seqlens = batch.cu_seqlens
             mask = None
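
Note: the hunk above changes the padded-batch path to build a square additive mask of shape [bsz, 1, tgt_len, tgt_len] (out of place) instead of the broadcastable [batch_size, 1, 1, seq_len] mask filled in place. A minimal standalone sketch of the new construction, with illustrative inputs (attention_mask, dtype, device are assumptions, not values from the repo):

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # [bsz, seq_len], hypothetical batch
    dtype, device = torch.float32, "cpu"

    mask = attention_mask.bool()
    bsz, tgt_len = mask.size()
    min_val = torch.finfo(dtype).min

    # Start fully masked, then zero out columns that correspond to real (non-padding) tokens.
    attn_mask = torch.full(
        [bsz, 1, tgt_len, tgt_len], fill_value=min_val, device=device, dtype=dtype
    )
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
    attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)

    print(attn_mask.shape)  # torch.Size([2, 1, 4, 4])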