Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for top_p in non-deterministic generation #11

Merged
merged 13 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions CodonTransformer/CodonData.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
START_CODONS,
STOP_CODONS,
STOP_SYMBOL,
STOP_SYMBOLS,
find_pattern_in_fasta,
get_taxonomy_id,
sort_amino2codon_skeleton,
Expand Down Expand Up @@ -177,13 +178,13 @@ def preprocess_protein_sequence(protein: str) -> str:
)

# Check for sequence validity
if any(
aminoacid not in AMINO_ACIDS + ["*", STOP_SYMBOL] for aminoacid in protein[:-1]
):
if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein):
raise ValueError("Invalid characters in protein sequence.")

if protein[-1] not in AMINO_ACIDS + ["*", STOP_SYMBOL]:
raise ValueError("Protein sequence must end with *, or _, or an amino acid.")
if protein[-1] not in AMINO_ACIDS + STOP_SYMBOLS:
raise ValueError(
"Protein sequence must end with `*`, or `_`, or an amino acid."
)

# Replace '*' at the end of protein with STOP_SYMBOL if present
if protein[-1] == "*":
Expand Down
102 changes: 84 additions & 18 deletions CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from CodonTransformer.CodonData import get_merged_seq
from CodonTransformer.CodonUtils import (
AMINO_ACIDS,
INDEX2TOKEN,
NUM_ORGANISMS,
ORGANISM2ID,
Expand All @@ -40,6 +39,7 @@ def predict_dna_sequence(
attention_type: str = "original_full",
deterministic: bool = True,
temperature: float = 0.2,
top_p: float = 0.95,
) -> DNASequencePrediction:
"""
Predict the DNA sequence for a given protein using the CodonTransformer model.
Expand Down Expand Up @@ -76,6 +76,10 @@ def predict_dna_sequence(
- Medium randomness: 0.5
- High randomness: 0.8
The temperature must be a positive float. Defaults to 0.2.
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
Tokens with cumulative probability up to `top_p` are considered for sampling.
This parameter helps balance diversity and coherence in the predicted DNA sequences.
The value must be a float between 0 and 1. Defaults to 0.95.

Returns:
DNASequencePrediction: An object containing the prediction results:
Expand All @@ -86,7 +90,7 @@ def predict_dna_sequence(

Raises:
ValueError: If the protein sequence is empty, if the organism is invalid,
or if the temperature is not a positive float.
if the temperature is not a positive float, or if `top_p` is not between 0 and 1.

Note:
This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from
Expand Down Expand Up @@ -123,7 +127,7 @@ def predict_dna_sequence(
... deterministic=True
... )
>>>
>>> # Predict DNA sequence with low randomness
>>> # Predict DNA sequence with low randomness and top_p sampling
>>> output_random = predict_dna_sequence(
... protein=protein,
... organism=organism,
Expand All @@ -132,7 +136,8 @@ def predict_dna_sequence(
... model=model,
... attention_type="original_full",
... deterministic=False,
... temperature=0.2
... temperature=0.2,
... top_p=0.95
... )
>>>
>>> print(format_model_output(output))
Expand All @@ -141,14 +146,14 @@ def predict_dna_sequence(
if not protein:
raise ValueError("Protein sequence cannot be empty.")

# Ensure the protein sequence contains only valid amino acids
if not all(aminoacid in AMINO_ACIDS for aminoacid in protein):
raise ValueError("Invalid amino acid found in protein sequence.")

# Validate temperature
if not isinstance(temperature, (float, int)) or temperature <= 0:
raise ValueError("Temperature must be a positive float.")

# Validate top_p
if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
raise ValueError("top_p must be a float between 0 and 1.")

# Load tokenizer
if not isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = load_tokenizer(tokenizer)
Expand Down Expand Up @@ -181,18 +186,10 @@ def predict_dna_sequence(

# Decode the predicted DNA sequence from the model output
if deterministic:
# Select the most probable tokens (argmax)
predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
else:
# Sample tokens according to their probability distribution
# Apply temperature scaling and convert logits to probabilities
logits = logits / temperature
probabilities = torch.softmax(logits, dim=-1)

# Sample from the probability distribution at each position
probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size]
predicted_indices = (
torch.multinomial(probabilities, num_samples=1).squeeze(-1).tolist()
predicted_indices = sample_non_deterministic(
logits=logits, temperature=temperature, top_p=top_p
)

predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
Expand All @@ -210,6 +207,75 @@ def predict_dna_sequence(
)


def sample_non_deterministic(
    logits: torch.Tensor,
    temperature: float = 0.2,
    top_p: float = 0.95,
) -> List[int]:
    """
    Sample token indices from logits using temperature scaling and nucleus (top-p) sampling.

    The logits are divided by `temperature`, converted to probabilities, and at
    each sequence position the smallest set of tokens whose cumulative
    probability exceeds `top_p` is kept; one token is then sampled from that
    renormalized set. This introduces controlled randomness for
    non-deterministic decoding while discarding the low-probability tail.

    Args:
        logits (torch.Tensor): The logits output from the model of shape
            [seq_len, vocab_size] or [1, seq_len, vocab_size].
        temperature (float, optional): Temperature value for scaling logits.
            Must be a positive float. Defaults to 0.2.
        top_p (float, optional): Cumulative probability threshold for nucleus sampling.
            Must be a float in (0, 1]. Tokens with cumulative probability up to
            `top_p` are considered for sampling. Defaults to 0.95.

    Returns:
        List[int]: A list of sampled token indices, one per sequence position.

    Raises:
        ValueError: If `temperature` is not a positive number, if `top_p` is not
            in (0, 1], or if a batched input has batch size greater than 1.

    Example:
        >>> logits = model_output.logits  # Assume logits is a tensor of shape [seq_len, vocab_size]
        >>> predicted_indices = sample_non_deterministic(logits, temperature=0.7, top_p=0.9)
    """
    if not isinstance(temperature, (float, int)) or temperature <= 0:
        raise ValueError("Temperature must be a positive float.")
    if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
        raise ValueError("top_p must be a float between 0 and 1.")

    # Apply temperature scaling and convert to probabilities over the vocabulary.
    probabilities = torch.softmax(logits / temperature, dim=-1)

    # Accept a singleton batch dimension; a larger batch would make the
    # per-position loop below iterate over the batch axis and sample wrongly,
    # so reject it explicitly instead of failing silently.
    if probabilities.dim() == 3:
        if probabilities.size(0) != 1:
            raise ValueError("Batched logits are not supported; batch size must be 1.")
        probabilities = probabilities.squeeze(0)  # Shape: [seq_len, vocab_size]

    predicted_indices = []
    for probs in probabilities:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=0)

        # Truncate to the smallest prefix whose cumulative probability exceeds
        # top_p; the crossing token is kept, so at least one token remains.
        cutoff = torch.where(cumulative_probs > top_p)[0]
        if len(cutoff) > 0:
            end = cutoff[0].item() + 1
            sorted_probs = sorted_probs[:end]
            sorted_indices = sorted_indices[:end]

        # Re-normalize the surviving probabilities and sample one token.
        filtered_probs = sorted_probs / sorted_probs.sum()
        sampled_position = torch.multinomial(filtered_probs, num_samples=1).item()
        predicted_indices.append(sorted_indices[sampled_position].item())

    return predicted_indices


def load_model(
model_path: Optional[str] = None,
device: torch.device = None,
Expand Down
1 change: 1 addition & 0 deletions CodonTransformer/CodonUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"W", # Tryptophan
"Y", # Tyrosine
]
STOP_SYMBOLS = ["_", "*"] # Stop codon symbols

# Dictionary ambiguous amino acids to standard amino acids
AMBIGUOUS_AMINOACID_MAP: Dict[str, str] = {
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ This subpackage contains functions and classes that handle the core prediction f

Predict the DNA sequence for a given protein using the CodonTransformer model.

- `sample_non_deterministic(logits: torch.Tensor, temperature: float = 0.2, top_p: float = 0.95) -> List[int]`

Sample token indices from logits using temperature scaling and nucleus (top-p) sampling.

- `load_model(path: str, device: torch.device = None, num_organisms: int = None, remove_prefix: bool = True, attention_type: str = "original_full") -> torch.nn.Module`

Load a BigBirdForMaskedLM model from a file or checkpoint.
Expand Down Expand Up @@ -383,6 +387,7 @@ The CodonUtils subpackage contains constants and helper functions essential for
#### Constants

- `AMINO_ACIDS`: List of all standard amino acids
- `STOP_SYMBOLS`: List of possible stop symbols to end the protein with
- `AMBIGUOUS_AMINOACID_MAP`: Mapping of ambiguous amino acids to standard amino acids
- `START_CODONS` and `STOP_CODONS`: Lists of start and stop codons
- `TOKEN2INDEX` and `INDEX2TOKEN`: Mappings between tokens and their indices
Expand Down
Loading