12 changes: 6 additions & 6 deletions mindone/transformers/generation/beam_search.py
@@ -347,7 +347,7 @@ def finalize(
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
sent_lengths = mint.empty(batch_size * self.num_beam_hyps_to_keep)
best = []
best_indices = []
best_scores = mint.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=ms.float32)
@@ -375,10 +375,10 @@ def finalize(
# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: ms.Tensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
decoded: ms.Tensor = mint.empty((batch_size * self.num_beam_hyps_to_keep, sent_max_len))

if len(best_indices) > 0 and best_indices[0] is not None:
indices: ms.Tensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
indices: ms.Tensor = mint.empty((batch_size * self.num_beam_hyps_to_keep, sent_max_len))
else:
indices = None

@@ -853,7 +853,7 @@ def finalize(
break

# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
sent_lengths = mint.empty(batch_size * self.num_beam_hyps_to_keep)
best = []
best_indices = []
best_scores = mint.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=ms.float32)
@@ -880,10 +880,10 @@ def finalize(
sent_lengths_max = sent_lengths.max().item() + 1

sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: ms.Tensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
decoded: ms.Tensor = mint.empty((batch_size * self.num_beam_hyps_to_keep, sent_max_len))

if len(best_indices) > 0 and best_indices[0] is not None:
indices: ms.Tensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
indices: ms.Tensor = mint.empty((batch_size * self.num_beam_hyps_to_keep, sent_max_len))
else:
indices = None

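For context on the beam_search.py change above: a minimal, hypothetical sketch of the allocation pattern the diff moves to, replacing the torch-style input_ids.new(...) call with mint.empty(...); the shapes below are invented for illustration.

# Hedged sketch (not from the PR): allocate the beam-search result buffers with
# mindspore.mint instead of the torch-style `input_ids.new(...)`.
from mindspore import mint

batch_size, num_beam_hyps_to_keep, sent_max_len = 2, 1, 8  # illustrative sizes

# old pattern: sent_lengths = input_ids.new(batch_size * num_beam_hyps_to_keep)
sent_lengths = mint.empty(batch_size * num_beam_hyps_to_keep)             # uninitialized 1-D buffer
decoded = mint.empty((batch_size * num_beam_hyps_to_keep, sent_max_len))  # uninitialized 2-D buffer
print(sent_lengths.shape, decoded.shape)  # (2,) (2, 8)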
22 changes: 12 additions & 10 deletions mindone/transformers/generation/candidate_generator.py
@@ -26,6 +26,7 @@
import mindspore as ms
from mindspore import mint
from mindspore import numpy as mnp
from mindspore import ops

from ..cache_utils import DynamicCache

@@ -450,7 +451,7 @@ def _get_tokens_diag(prompt, prompt_plus_new_tokens):
if not isinstance(compare_mat, ms.Tensor):
compare_mat = ms.Tensor(compare_mat)

compare_mat_int = compare_mat.to(int)
compare_mat_int = compare_mat.to(ms.int32)

if not compare_mat_int.any().item():
# empty intersection between prompt and prompt_plus_new_tokens
Expand Down Expand Up @@ -482,7 +483,7 @@ def convert_source_tokens_to_target_tokens(
The converted token IDs.
"""
text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
dest_ids = ms.tensor(destination_tokenizer(text, add_special_tokens=True, return_tensors="np")["input_ids"])
return dest_ids

def get_candidates(self, input_ids: ms.Tensor) -> tuple[ms.Tensor, Optional[ms.Tensor]]:
@@ -671,14 +672,14 @@ def _get_assistant_to_target_input_ids(self):
}

max_assistant_index = max(assistant_vocab.values())
assistant_to_target_input_ids = mint.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
assistant_to_target_input_ids = ops.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=ms.int32)
target_to_assistant_input_ids: dict[int, int] = {}
for tok, assistant_id in assistant_vocab.items():
target_id = target_vocab.get(tok)
if target_id is not None:
assistant_to_target_input_ids[assistant_id] = target_id
target_to_assistant_input_ids[target_id] = assistant_id
return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
return assistant_to_target_input_ids, target_to_assistant_input_ids

def _get_suppress_input_ids(self) -> list[int]:
"""
@@ -706,7 +707,7 @@ def get_target_logits(self, assistant_logits: ms.Tensor) -> ms.Tensor:
"""

target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
target_logits: ms.Tensor = mint.full(target_shape, self.FILTER_VALUE)
target_logits: ms.Tensor = ops.full(target_shape, self.FILTER_VALUE)
# Mask for valid indices
assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
# Exclude invalid indices
@@ -732,7 +733,6 @@ def get_translator(
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int,
assistant_model_device: str = "cpu",
) -> AssistantToTargetTranslator:
assistant_dict = cls._cache.get(target_tokenizer)
if assistant_dict is None:
@@ -742,7 +742,9 @@
mapping = assistant_dict.get(assistant_tokenizer)
if mapping is None:
mapping = AssistantToTargetTranslator(
target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
target_tokenizer,
assistant_tokenizer,
target_vocab_size,
)
assistant_dict[assistant_tokenizer] = mapping

@@ -863,9 +865,9 @@ def _prepare_assistant_input_ids(self, target_input_ids: ms.Tensor) -> ms.Tensor
target_new_text = self.target_tokenizer.batch_decode(
target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
assistant_new_ids = self.assistant_tokenizer(
target_new_text, add_special_tokens=False, return_tensors="pt"
)["input_ids"]
assistant_new_ids = ms.tensor(
self.assistant_tokenizer(target_new_text, add_special_tokens=False, return_tensors="np")["input_ids"]
)
else:
assistant_new_ids = ms.Tensor([[assistant_new_ids]])

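As a side note on the tokenizer changes in candidate_generator.py: a minimal sketch of the return_tensors="np" + ms.tensor(...) round-trip that replaces the torch-only "pt" path; the checkpoint name and token IDs are placeholders, not part of the PR.

import mindspore as ms
from transformers import AutoTokenizer

source_tokenizer = AutoTokenizer.from_pretrained("gpt2")       # placeholder checkpoint
destination_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint

# decode with the source tokenizer, then re-encode with the destination tokenizer
text = source_tokenizer.batch_decode(
    [[15496, 995]],  # arbitrary example token IDs
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
# return_tensors="np" avoids a torch dependency; ms.tensor() converts the NumPy array
dest_ids = ms.tensor(
    destination_tokenizer(text, add_special_tokens=True, return_tensors="np")["input_ids"]
)
print(dest_ids.shape)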
100 changes: 96 additions & 4 deletions mindone/transformers/generation/utils.py
@@ -317,6 +317,28 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None


# TODO (joao): remove the equivalent classes and typing shortcuts below in v5
# Equivalent classes (kept for retrocompatibility purposes)
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput

ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput
GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput
SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput

BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput

BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput

GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]

# Typing shortcuts
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
@@ -2482,7 +2504,7 @@ def heal_tokens(self, input_ids: ms.Tensor, tokenizer: Optional["PreTrainedToken
# assumption: leading/trailing whitespace is not meaningful, so the prompts are
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
input_ids = ms.Tensor(
input_ids = ms.tensor(
tokenizer(
prompts,
return_tensors="np",
Expand All @@ -2491,7 +2513,14 @@ def heal_tokens(self, input_ids: ms.Tensor, tokenizer: Optional["PreTrainedToken
)

# replace bos with pad to not condition healing on it
input_ids = ops.where(input_ids == bos_token_id, pad_token_id, input_ids)
input_ids = mint.where(input_ids == bos_token_id, pad_token_id, input_ids)

"""
the latter code assumes the input_ids is not empty,
input_id has to be checked if contains elements
"""
Comment on lines +2518 to +2521
Contributor (severity: medium):
This multiline string is used as a comment. For better readability and to avoid potential runtime effects of docstrings, it's recommended to use standard # comments.

        # The following code assumes `input_ids` is not empty,
        # so it has to be checked if it contains elements.

if input_ids.numel() == 0:
return input_ids

tail_ids = input_ids[:, -1].tolist()

@@ -2502,11 +2531,18 @@ def heal_tokens(self, input_ids: ms.Tensor, tokenizer: Optional["PreTrainedToken

for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
batch_ids = input_ids[batch_idx]
if ops.all(batch_ids == pad_token_id).item():
if mint.all(batch_ids == pad_token_id).item():
continue # skip empty sequences (all pad ids)

# apply bias for alternatives (extensions) to the tail token
seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)}
"""
seq_bias key has to be tuple with int so have to use
tokenizer function to convert str to int
"""
Comment on lines +2538 to +2541
Contributor (severity: medium):
This multiline string is used as a comment. For better readability and to avoid potential runtime effects of docstrings, it's recommended to use standard # comments.

            # The `seq_bias` key must be a tuple of integers (token IDs), so we use the
            # tokenizer function to convert the string token to an integer ID.

seq_bias = {
(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
}

if len(seq_bias) == 1:
continue # skip if there are no token alternatives to heal with

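To illustrate the review suggestion above: a small, hypothetical sketch of why the seq_bias keys must be tuples of integer token IDs rather than token strings, assuming a Hugging Face tokenizer; the checkpoint and tokens are invented stand-ins for vocab_trie.extensions(prefix=tail_tok).

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint

tail_tok = "ope"                   # decoded surface form of the tail token (example)
alternatives = ["open", "opened"]  # stand-ins for vocab_trie.extensions(prefix=tail_tok)

# The sequence-bias logits processor expects keys that are tuples of token IDs (ints),
# so each candidate token string is converted with the tokenizer first.
seq_bias = {(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in alternatives}
print(seq_bias)  # keys are 1-tuples of int token IDs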
@@ -3321,3 +3357,59 @@ def _beam_search(
)
else:
return sequences


def _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,
new_logits,
is_done_candidate,
):
"""
Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
the selected tokens, as well as the number of candidate matches.

NOTE: Unless otherwise stated, the variable names match those in the paper.
"""
new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
# selected by the assistant, respectively.
q = mint.nn.functional.softmax(candidate_logits, dim=-1)
q_i = q[:, mint.arange(candidate_length), new_candidate_input_ids].squeeze((0, 1))
p = mint.nn.functional.softmax(new_logits, dim=-1)
p_i = p[:, mint.arange(candidate_length), new_candidate_input_ids].squeeze((0, 1))
probability_ratio = p_i / q_i

# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
# than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
# (= keep with p = probability_ratio). Keep all the tokens until the first rejection
r_i = mint.rand_like(probability_ratio)
is_accepted = r_i <= probability_ratio
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1

# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
if is_done_candidate and n_matches == candidate_length:
# Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
# due to acceptance on EOS we fix `n_matches`
n_matches -= 1
valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
else:
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
gamma = candidate_logits.shape[1]
p_n_plus_1 = p[:, n_matches, :]
if n_matches < gamma:
q_n_plus_1 = q[:, n_matches, :]
p_prime = mint.clamp((p_n_plus_1 - q_n_plus_1), min=0)
p_prime.div_(p_prime.sum())
else:
p_prime = p_n_plus_1
t = mint.squeeze(mint.multinomial(p_prime, num_samples=1), dim=1)[None, :]

# The selected tokens include the matches (if any) plus the next sampled tokens
if n_matches > 0:
valid_tokens = mint.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
else:
valid_tokens = t

return valid_tokens, n_matches
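For intuition, a small numeric sketch (not part of the PR) of the acceptance rule implemented above, with made-up probabilities for three drafted tokens:

import mindspore as ms
from mindspore import mint

# made-up probabilities of the 3 drafted tokens under the assistant (q_i) and the target model (p_i)
q_i = ms.tensor([0.50, 0.40, 0.10])
p_i = ms.tensor([0.60, 0.20, 0.05])

probability_ratio = p_i / q_i            # [1.2, 0.5, 0.5]
r_i = mint.rand_like(probability_ratio)  # uniform draws in [0, 1)
is_accepted = r_i <= probability_ratio   # token 0 is always kept; tokens 1 and 2 are each kept with p = 0.5
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # drafted tokens kept before the first rejection
print(n_matches.item())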