Policy-v3-12k never presses attack buttons — logits deeply negative #42
Description
Summary
The policy-v3-12k-p0 checkpoint has learned to never press buttons. All button logits are deeply negative (-7 to -10), giving probabilities under 0.1%. In autoregressive matches, the agents run back and forth on FD without ever attacking.
Evidence
Raw policy outputs from a seed frame (Falcon vs Marth, FD):
Analog:
main_x: 0.502, main_y: 0.496, c_x: 0.501, c_y: 0.501, shoulder: 0.002
Button logits → probabilities:
A: -6.97 → 0.09%
B: -7.93 → 0.04%
X: -7.03 → 0.09%
Y: -7.25 → 0.07%
Z: -8.79 → 0.02%
L: -7.70 → 0.05%
R: -7.18 → 0.08%
START: -10.57 → 0.00%
Every button is suppressed. Bernoulli sampling at these probabilities essentially never fires.
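A quick sanity check of the numbers above (logits copied from the dump; the 10k-frame sample count is arbitrary, chosen just to show the empirical press rate):

```python
import torch

# Raw button logits reported above (A, B, X, Y, Z, L, R, START)
logits = torch.tensor([-6.97, -7.93, -7.03, -7.25, -8.79, -7.70, -7.18, -10.57])
probs = torch.sigmoid(logits)                    # every entry under 0.1%
presses = torch.bernoulli(probs.unsqueeze(0).repeat(10_000, 1))  # 10k sampled frames
print(probs)                 # max press probability ~0.0009
print(presses.mean(dim=0))   # empirical press rate per button, near zero
```

At these probabilities even a full match's worth of Bernoulli draws rarely produces a single press.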
Downstream impact on world model
Because the policy never sends attack inputs, the world model never predicts attack action states. A 200-frame match produces:
Action state distribution:
20 (Dash): 149 frames
19 (TurnRun): 96
18 (Turn): 91
369 (unknown): 60
15 (WalkSlow): 2
21 (Run): 1
Zero attacks. The world model is working correctly — it sees "stick movement, no buttons" and predicts running/turning. But the continuous heads (percent, stocks) drift independently of action state, so damage increases without any visible attacks in the visualizer. The heads are decoupled: the continuous_head can predict percent going up while the action_head predicts movement.
Reproduction
```python
from models.checkpoint import load_model_from_checkpoint
from crank.agents import make_agent
from crank.match_runner import generate_synthetic_seed
import torch

model, cfg, _, _ = load_model_from_checkpoint('checkpoints/e012-clean-fd-top5.pt', 'cpu')
p0 = make_agent('policy:checkpoints/policy-v3-12k-p0/best.pt', player=0, cfg=cfg, device='cpu')
K = model.context_len
sim_f, sim_i = generate_synthetic_seed(cfg, K, 32, 2, 18)
with torch.no_grad():
    preds = p0.model(sim_f[-K:].unsqueeze(0), sim_i[-K:].unsqueeze(0), predict_player=0)
print('Button logits:', preds['button_logits'][0].tolist())
# All deeply negative
```
Analysis
Likely cause: class imbalance in training data. In real Melee, buttons are pressed maybe 3-5% of frames. With unweighted BCE loss, the policy minimizes loss by predicting "never press" — the negative logit bias reflects the prior probability of button presses in the dataset.
The analog outputs being centered (~0.5) reinforces this — the policy learned a "do nothing" equilibrium.
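As a back-of-the-envelope check: under unweighted BCE, the constant logit that minimizes the loss is the log-odds of the positive rate. A sketch (the press rates are the rough figures above):

```python
import math

def bce_optimal_logit(pos_rate):
    """Constant logit minimizing unweighted BCE when a fraction
    pos_rate of labels are positive: the log-odds of the base rate."""
    return math.log(pos_rate / (1.0 - pos_rate))

for rate in (0.05, 0.03, 0.001):
    print(f"press rate {rate:.3f} -> optimal logit {bce_optimal_logit(rate):+.2f}")
```

Note that a 3-5% press rate alone would only explain logits around -3 to -3.5; the observed -7 to -10 correspond to press rates near 0.1%, so the collapse may have overshot even the dataset prior.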
Possible fixes

- Button-weighted loss: increase the weight on positive button labels during training (e.g. `pos_weight` in `BCEWithLogitsLoss`). A 10-20x weight on positive samples would counteract the class imbalance.
- Focal loss for buttons: down-weight easy negatives (the ~95% of frames with no buttons pressed) and focus training on the frames where buttons are actually pressed.
- Separate button sampling temperature: at inference time, add a learnable or fixed bias to the button logits before the sigmoid. Adding +4 to all button logits would shift probabilities from ~0.1% to ~5%, closer to the training-data distribution.
- Action-conditioned button targets: only train button prediction on frames where the player is in an actionable state (not in hitstun, not in an existing attack animation). This removes the confounding signal from frames where buttons are irrelevant.
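The first fix might look like the following using PyTorch's built-in `pos_weight`; the 8-button layout, the 20x weight, and the synthetic batch are assumptions for illustration, not the actual training code:

```python
import torch
import torch.nn as nn

# In real Melee data buttons are positive on roughly 3-5% of frames,
# so a ~20x positive weight rebalances the BCE gradient (assumed value).
pos_weight = torch.full((8,), 20.0)   # one weight per button head
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(32, 8)                    # batch of button logits
targets = (torch.rand(32, 8) < 0.04).float()   # ~4% positive labels
loss = criterion(logits, targets)
print(loss.item())
```

With `pos_weight` set, each positive label contributes 20x to the loss, so "never press" is no longer the cheapest constant prediction.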
Workaround for vertical slice
For the demo, we can add a button logit bias in crank/agents.py to offset the suppression. Something like button_logits += 4.0 before the Bernoulli sampling. This is a hack but would get attacks showing up in the live viewer while the policy is retrained properly.
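A minimal sketch of that workaround (the function name `sample_buttons` and the +4.0 constant are placeholders, not the actual crank/agents.py API):

```python
import torch

BUTTON_LOGIT_BIAS = 4.0  # hypothetical constant; tune against real press rates

def sample_buttons(button_logits: torch.Tensor,
                   bias: float = BUTTON_LOGIT_BIAS) -> torch.Tensor:
    """Offset suppressed logits before Bernoulli sampling (demo hack)."""
    probs = torch.sigmoid(button_logits + bias)
    return torch.bernoulli(probs)

# A logit of -7 shifts to -3, i.e. ~4.7% press probability per frame
presses = sample_buttons(torch.full((8,), -7.0))
```

This leaves the checkpoint untouched, so it can be deleted once a rebalanced policy lands.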
Environment

- World model: `e012-clean-fd-top5.pt` (Mamba2, d_model=384, 4 layers, epoch 1)
- Policy: `policy-v3-12k-p0/best.pt` (PolicyMLP, hidden=512, trunk=256)
- Both use v3 encoding: `state_flags=True`, `hitstun=True`, `ctrl_threshold_features=True`
- Config match confirmed: `float_per_player=69`, no strip_indices needed