Fix multiple temperature scaling in sample_non_deterministic()
Ensures that temperature scaling is applied exactly once by passing logits / temperature directly to softmax() instead of dividing logits in place, which mutated the caller's tensor and rescaled it on every call. Fixes issue #19.
dansuissa authored Feb 6, 2025
1 parent 2842ef0 commit c5276a6
Showing 1 changed file with 2 additions and 2 deletions.
CodonTransformer/CodonPrediction.py (2 additions, 2 deletions)

@@ -284,8 +284,8 @@ def sample_non_deterministic(
         raise ValueError("top_p must be a float between 0 and 1.")
 
     # Compute probabilities using temperature scaling
-    logits /= temperature
-    probs = torch.softmax(logits, dim=-1)
+    probs = torch.softmax(logits / temperature, dim=-1)
+
     # Remove batch dimension if present
     if probs.dim() == 3:
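
For context, a minimal standalone sketch of the failure mode the commit message describes (the helper names sample_buggy and sample_fixed are hypothetical, not from the repository): the in-place /= mutates the caller's logits tensor, so every repeated call rescales the same values again, whereas the committed fix divides out of place and leaves the input unchanged.

import torch

def sample_buggy(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # In-place division: mutates the caller's tensor on every call.
    logits /= temperature
    return torch.softmax(logits, dim=-1)

def sample_fixed(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Out-of-place division: allocates a new tensor; the input stays intact.
    return torch.softmax(logits / temperature, dim=-1)

logits = torch.tensor([1.0, 2.0, 3.0])
sample_buggy(logits, temperature=0.5)
print(logits)  # tensor([2., 4., 6.])   -- already rescaled once
sample_buggy(logits, temperature=0.5)
print(logits)  # tensor([4., 8., 12.])  -- temperature applied twice

logits = torch.tensor([1.0, 2.0, 3.0])
sample_fixed(logits, temperature=0.5)
print(logits)  # tensor([1., 2., 3.])   -- unchanged across calls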
