Skip to content

Commit

Permalink
Merge pull request #20 from dansuissa/fix-logits-rescaling
Browse files Browse the repository at this point in the history
Fix multiple temperature scaling in sample_non_deterministic()
  • Loading branch information
Adibvafa authored Feb 6, 2025
2 parents 2842ef0 + c5276a6 commit aaa68b4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ def sample_non_deterministic(
raise ValueError("top_p must be a float between 0 and 1.")

# Compute probabilities using temperature scaling
logits /= temperature
probs = torch.softmax(logits, dim=-1)
probs = torch.softmax(logits / temperature, dim=-1)


# Remove batch dimension if present
if probs.dim() == 3:
Expand Down

0 comments on commit aaa68b4

Please sign in to comment.