Skip to content
69 changes: 57 additions & 12 deletions realnet/training/chaos_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ class ChaosGrad(torch.optim.Optimizer):
input_sentinel (bool): Track input gradient health. Default: False.
adaptive_lr (bool): Enable per-param adaptive LR scaling. Default: True.
adaptive_lr_clip (tuple): (min, max) multiplier for adaptive LR. Default: (0.1, 10.0).
adaptive_ema (float): EMA decay for the smoothed adaptive-LR ratio; the previous value is weighted by adaptive_ema and the new ratio by (1 - adaptive_ema). Default: 0.99.
grad_centralization (bool): Center gradients by removing mean. Default: True.
plateau_noise_intensity (float): Extra multiplier applied to the noise added to chaos-core gradients during plateau escape. Default: 0.1.
loss_history_min (int): Lower bound on the loss-history cap; the buffer keeps the most recent max(loss_history_min, 2 * plateau_patience) values. Default: 200.
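sentinel_threshold (float): Fraction of the total gradient norm that the input gradient norm must reach for input health to be 1.0; below that, health falls off proportionally. Default: 0.1.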
"""

def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
Expand All @@ -57,7 +61,9 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
plateau_patience=0, plateau_noise_scale=0.01,
spectral_clip=0.0, input_sentinel=False,
adaptive_lr=True, adaptive_lr_clip=(0.1, 10.0),
grad_centralization=True):
grad_centralization=True, adaptive_ema=0.99,
plateau_noise_intensity=0.1,
loss_history_min=200, sentinel_threshold=0.1):

defaults = dict(
lr=lr, betas=betas, eps=eps,
Expand All @@ -72,7 +78,11 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
input_sentinel=input_sentinel,
adaptive_lr=adaptive_lr,
adaptive_lr_clip=adaptive_lr_clip,
adaptive_ema=adaptive_ema,
grad_centralization=grad_centralization,
plateau_noise_intensity=plateau_noise_intensity,
loss_history_min=loss_history_min,
sentinel_threshold=sentinel_threshold,
)
super().__init__(params, defaults)

Expand Down Expand Up @@ -108,9 +118,9 @@ def classify_params(model):
if not param.requires_grad:
continue

if name == 'W':
if name == 'W' or name.endswith('.W'):
chaos_core.append(param)
elif any(k in name for k in ['embed', 'proj', 'output_decoder']):
elif any(k in name for k in ['embed', 'proj', 'decoder', 'output_decoder']):
projections.append(param)
else:
# B, input_scale, output_scale, norm.weight, norm.bias
Expand Down Expand Up @@ -158,7 +168,8 @@ def report_loss(self, loss_value):
self._loss_history.append(loss_value)

# Keep only recent history
max_history = max(200, self.defaults.get('plateau_patience', 0) * 2)
hist_min = self.defaults.get('loss_history_min', 200)
max_history = max(hist_min, self.defaults.get('plateau_patience', 0) * 2)
if len(self._loss_history) > max_history:
self._loss_history = self._loss_history[-max_history:]

Expand Down Expand Up @@ -264,8 +275,14 @@ def step(self, closure=None):

# --- Plateau Escape (Controlled Perturbation) ---
if is_plateau and is_core:
# Inject targeted noise into gradients for the chaos core
noise = torch.randn_like(grad) * plateau_noise * grad.abs().mean()
intensity = group.get('plateau_noise_intensity', 0.1)

noise_scale = grad.abs().mean()
if (not torch.isfinite(noise_scale)) or noise_scale < 1e-8:
noise_scale = p.abs().mean()
if (not torch.isfinite(noise_scale)) or noise_scale < 1e-8:
noise_scale = grad.new_tensor(1e-3)
noise = torch.randn_like(grad) * plateau_noise * noise_scale * intensity
grad = grad + noise

# --- Decoupled Weight Decay ---
Expand Down Expand Up @@ -307,10 +324,11 @@ def step(self, closure=None):
if not isinstance(grad_var, float):
grad_var = grad_var.item()

# Smooth the ratio (exponential moving average)
grad_var = grad_var * 0.99 + ratio * 0.01
# Smooth the ratio with exponential moving average
ema = group.get('adaptive_ema', 0.99)
grad_var = grad_var * ema + ratio * (1 - ema)

# Apply smoothed adaptive scaling multiplier to the step size
# Apply adaptive multiplier with epsilon guard
adaptive_mult = 1.0 / (grad_var + eps)
adaptive_mult = max(adaptive_clip[0], min(adaptive_clip[1], adaptive_mult))
step_size *= adaptive_mult
Expand Down Expand Up @@ -340,7 +358,12 @@ def step(self, closure=None):
self._spectral_radius = abs(sigma_max)

if self._spectral_radius > spectral_clip:
p.data.mul_(spectral_clip / (self._spectral_radius + eps))
ratio = spectral_clip / (self._spectral_radius + eps)
p.data.mul_(ratio)

# Scale the exp_avg / exp_avg_sq buffers by the same ratio (squared for exp_avg_sq) so they match the rescaled weights
state['exp_avg'].mul_(ratio)
state['exp_avg_sq'].mul_(ratio ** 2)
except Exception:
pass

Expand All @@ -357,9 +380,11 @@ def step(self, closure=None):
self._diagnostics['total_grad_norm'] = total_grad_norm
self._diagnostics['total_param_norm'] = total_param_norm

# Input gradient health calculation
# Input gradient health calculation (detects vanishing inputs)
if total_grad_norm > 0:
self._input_grad_health = min(1.0, input_grad_norm / (total_grad_norm * 0.1 + 1e-12))
# Compared against a scaled threshold of total activity
threshold = self.defaults.get('sentinel_threshold', 0.1)
self._input_grad_health = min(1.0, input_grad_norm / (total_grad_norm * threshold + 1e-12))

if is_plateau:
self._diagnostics['plateau_escape_triggered'] = self._global_step
Expand Down Expand Up @@ -388,7 +413,11 @@ def default(lr=1e-4):
chaos_core_lr_mult=1.0,
projection_lr_mult=1.0,
adaptive_lr=True,
adaptive_ema=0.99,
grad_centralization=True,
plateau_noise_intensity=0.1,
sentinel_threshold=0.1,
loss_history_min=200,
)

@staticmethod
Expand All @@ -406,7 +435,11 @@ def aggressive(lr=3e-4):
plateau_noise_scale=0.02,
adaptive_lr=True,
adaptive_lr_clip=(0.2, 5.0),
adaptive_ema=0.98,
grad_centralization=True,
plateau_noise_intensity=0.2,
sentinel_threshold=0.1,
loss_history_min=100,
)

@staticmethod
Expand All @@ -422,7 +455,11 @@ def finetune(lr=1e-5):
projection_lr_mult=0.8,
adaptive_lr=True,
adaptive_lr_clip=(0.5, 2.0),
adaptive_ema=0.995,
grad_centralization=False,
plateau_noise_intensity=0.05,
sentinel_threshold=0.05,
loss_history_min=500,
)

@staticmethod
Expand All @@ -445,7 +482,11 @@ def large_network(lr=1e-4):
input_sentinel=True,
adaptive_lr=True,
adaptive_lr_clip=(0.1, 5.0),
adaptive_ema=0.99,
grad_centralization=True,
plateau_noise_intensity=0.1,
sentinel_threshold=0.2,
loss_history_min=200,
)

@staticmethod
Expand All @@ -460,5 +501,9 @@ def tiny_network(lr=0.01):
chaos_core_lr_mult=1.0,
projection_lr_mult=1.0,
adaptive_lr=False,
adaptive_ema=0.99,
grad_centralization=False,
plateau_noise_intensity=0.01,
sentinel_threshold=0.1,
loss_history_min=50,
)
11 changes: 7 additions & 4 deletions realnet/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class RealNetTrainer:
def __init__(self, model, optimizer=None, loss_fn=None, lr=1e-4, device='cpu',
gradient_persistence=0.0, synaptic_noise=1e-6,
chaos_config=None, scheduler_config=None,
use_chaos_grad=None, use_temporal_scheduler=None):
use_chaos_grad=None, use_temporal_scheduler=None, max_grad_norm=1.0):
"""
Initializes the trainer.

Expand All @@ -65,6 +65,7 @@ def __init__(self, model, optimizer=None, loss_fn=None, lr=1e-4, device='cpu',
self.model.to(self.device)
self.gradient_persistence = gradient_persistence
self.synaptic_noise = synaptic_noise
self.max_grad_norm = max_grad_norm
self.initial_lr = lr

# --- Optimizer Initialization ---
Expand Down Expand Up @@ -234,19 +235,21 @@ def train_batch(self, input_features, target_values, thinking_steps, gradient_ac

if step_now:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
if self.max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()

if self.gradient_persistence > 0.0:
# Gradient Persistence
# Synchronize persistent gradients with active AMP scale
scale = self.scaler.get_scale() if hasattr(self, 'scaler') and self.scaler.is_enabled() else 1.0
with torch.no_grad():
for param in self.model.parameters():
if param.grad is not None:
if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
param.grad.zero_()
else:
param.grad.mul_(self.gradient_persistence)
param.grad.mul_(self.gradient_persistence * scale)
else:
self.optimizer.zero_grad()

Expand Down