diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index 6baf2d1b..8960192c 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -369,6 +369,7 @@ def _generate_points( next_token.item() != self.config.tokenizer.eos_id and len(out) < max_points ): + x_logits = decode_coordinate(hidden, self.region) x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1) next_emb = encode_coordinate( @@ -390,9 +391,17 @@ def _generate_points( mask[:, :, pos], pos_ids[0] = 1, pos logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids) pos += 1 + size_logits = decode_size(hidden, self.region) - w = torch.argmax(size_logits[0], dim=-1) / size_logits.size(-1) - h = torch.argmax(size_logits[1], dim=-1) / size_logits.size(-1) + + w_bin = torch.argmax(size_logits[0], dim=-1).float() + w_log2 = w_bin / 1023.0 * 10.0 - 10.0 + w = 2.0**w_log2 + + h_bin = torch.argmax(size_logits[1], dim=-1).float() + h_log2 = h_bin / 1023.0 * 10.0 - 10.0 + h = 2.0**h_log2 + next_emb = encode_size( torch.tensor( [w, h], device=self.device, dtype=size_logits.dtype