diff --git a/realnet/training/chaos_optimizer.py b/realnet/training/chaos_optimizer.py index bcf0551..eaaf155 100644 --- a/realnet/training/chaos_optimizer.py +++ b/realnet/training/chaos_optimizer.py @@ -47,7 +47,11 @@ class ChaosGrad(torch.optim.Optimizer): input_sentinel (bool): Track input gradient health. Default: False. adaptive_lr (bool): Enable per-param adaptive LR scaling. Default: True. adaptive_lr_clip (tuple): (min, max) multiplier for adaptive LR. Default: (0.1, 10.0). + adaptive_ema (float): Smoothing factor for adaptive LR variance. Default: 0.99. grad_centralization (bool): Center gradients by removing mean. Default: True. + plateau_noise_intensity (float): Internal multiplier for plateau noise. Default: 0.1. + loss_history_min (int): Minimum number of recent loss values to retain in the history buffer. Default: 200. + sentinel_threshold (float): Relative threshold for input health detection. Default: 0.1. """ def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8, @@ -57,7 +61,9 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8, plateau_patience=0, plateau_noise_scale=0.01, spectral_clip=0.0, input_sentinel=False, adaptive_lr=True, adaptive_lr_clip=(0.1, 10.0), - grad_centralization=True): + grad_centralization=True, adaptive_ema=0.99, + plateau_noise_intensity=0.1, + loss_history_min=200, sentinel_threshold=0.1): defaults = dict( lr=lr, betas=betas, eps=eps, @@ -72,7 +78,11 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8, input_sentinel=input_sentinel, adaptive_lr=adaptive_lr, adaptive_lr_clip=adaptive_lr_clip, + adaptive_ema=adaptive_ema, grad_centralization=grad_centralization, + plateau_noise_intensity=plateau_noise_intensity, + loss_history_min=loss_history_min, + sentinel_threshold=sentinel_threshold, ) super().__init__(params, defaults) @@ -108,9 +118,9 @@ def classify_params(model): if not param.requires_grad: continue - if name == 'W': + if name == 'W' or name.endswith('.W'): chaos_core.append(param) - elif any(k in name for k in ['embed', 'proj', 'output_decoder']): + elif any(k in name for k in ['embed', 'proj', 'decoder', 'output_decoder']): projections.append(param) else: # B, input_scale, output_scale, norm.weight, norm.bias @@ -158,7 +168,8 @@ def report_loss(self, loss_value): self._loss_history.append(loss_value) # Keep only recent history - max_history = max(200, self.defaults.get('plateau_patience', 0) * 2) + hist_min = self.defaults.get('loss_history_min', 200) + max_history = max(hist_min, self.defaults.get('plateau_patience', 0) * 2) if len(self._loss_history) > max_history: self._loss_history = self._loss_history[-max_history:] @@ -264,8 +275,14 @@ def step(self, closure=None): # --- Plateau Escape (Controlled Perturbation) --- if is_plateau and is_core: - # Inject targeted noise into gradients for the chaos core - noise = torch.randn_like(grad) * plateau_noise * grad.abs().mean() + intensity = group.get('plateau_noise_intensity', 0.1) + + noise_scale = grad.abs().mean() + if (not torch.isfinite(noise_scale)) or noise_scale < 1e-8: + noise_scale = p.abs().mean() + if (not torch.isfinite(noise_scale)) or noise_scale < 1e-8: + noise_scale = grad.new_tensor(1e-3) + noise = torch.randn_like(grad) * plateau_noise * noise_scale * intensity grad = grad + noise # --- Decoupled Weight Decay --- @@ -307,10 +324,11 @@ def step(self, closure=None): if not isinstance(grad_var, float): grad_var = grad_var.item() - # Smooth the ratio (exponential moving average) - grad_var = grad_var * 0.99 + ratio * 0.01 + # Smooth the ratio with exponential moving average + ema = group.get('adaptive_ema', 0.99) + grad_var = grad_var * ema + ratio * (1 - ema) - # Apply smoothed adaptive scaling multiplier to the step size + # Apply adaptive multiplier with epsilon guard adaptive_mult = 1.0 / (grad_var + eps) adaptive_mult = max(adaptive_clip[0], min(adaptive_clip[1], adaptive_mult)) step_size *= adaptive_mult @@ -340,7 +358,12 @@ def step(self, closure=None): self._spectral_radius = abs(sigma_max) if self._spectral_radius > spectral_clip: - p.data.mul_(spectral_clip / (self._spectral_radius + eps)) + ratio = spectral_clip / (self._spectral_radius + eps) + p.data.mul_(ratio) + + # Update momentum buffers to maintain topological alignment + state['exp_avg'].mul_(ratio) + state['exp_avg_sq'].mul_(ratio ** 2) except Exception: pass @@ -357,9 +380,11 @@ def step(self, closure=None): self._diagnostics['total_grad_norm'] = total_grad_norm self._diagnostics['total_param_norm'] = total_param_norm - # Input gradient health calculation + # Input gradient health calculation (detects vanishing inputs) if total_grad_norm > 0: - self._input_grad_health = min(1.0, input_grad_norm / (total_grad_norm * 0.1 + 1e-12)) + # Compared against a scaled threshold of total activity + threshold = self.defaults.get('sentinel_threshold', 0.1) + self._input_grad_health = min(1.0, input_grad_norm / (total_grad_norm * threshold + 1e-12)) if is_plateau: self._diagnostics['plateau_escape_triggered'] = self._global_step @@ -388,7 +413,11 @@ def default(lr=1e-4): chaos_core_lr_mult=1.0, projection_lr_mult=1.0, adaptive_lr=True, + adaptive_ema=0.99, grad_centralization=True, + plateau_noise_intensity=0.1, + sentinel_threshold=0.1, + loss_history_min=200, ) @staticmethod @@ -406,7 +435,11 @@ def aggressive(lr=3e-4): plateau_noise_scale=0.02, adaptive_lr=True, adaptive_lr_clip=(0.2, 5.0), + adaptive_ema=0.98, grad_centralization=True, + plateau_noise_intensity=0.2, + sentinel_threshold=0.1, + loss_history_min=100, ) @staticmethod @@ -422,7 +455,11 @@ def finetune(lr=1e-5): projection_lr_mult=0.8, adaptive_lr=True, adaptive_lr_clip=(0.5, 2.0), + adaptive_ema=0.995, grad_centralization=False, + plateau_noise_intensity=0.05, + sentinel_threshold=0.05, + loss_history_min=500, ) @staticmethod @@ -445,7 +482,11 @@ def large_network(lr=1e-4): input_sentinel=True, adaptive_lr=True, adaptive_lr_clip=(0.1, 5.0), + adaptive_ema=0.99, grad_centralization=True, + plateau_noise_intensity=0.1, + sentinel_threshold=0.2, + loss_history_min=200, ) @staticmethod @@ -460,5 +501,9 @@ def tiny_network(lr=0.01): chaos_core_lr_mult=1.0, projection_lr_mult=1.0, adaptive_lr=False, + adaptive_ema=0.99, grad_centralization=False, + plateau_noise_intensity=0.01, + sentinel_threshold=0.1, + loss_history_min=50, ) diff --git a/realnet/training/trainer.py b/realnet/training/trainer.py index 12fdbeb..6aa0841 100644 --- a/realnet/training/trainer.py +++ b/realnet/training/trainer.py @@ -38,7 +38,7 @@ class RealNetTrainer: def __init__(self, model, optimizer=None, loss_fn=None, lr=1e-4, device='cpu', gradient_persistence=0.0, synaptic_noise=1e-6, chaos_config=None, scheduler_config=None, - use_chaos_grad=None, use_temporal_scheduler=None): + use_chaos_grad=None, use_temporal_scheduler=None, max_grad_norm=1.0): """ Initializes the trainer. @@ -65,6 +65,7 @@ def __init__(self, model, optimizer=None, loss_fn=None, lr=1e-4, device='cpu', self.model.to(self.device) self.gradient_persistence = gradient_persistence self.synaptic_noise = synaptic_noise + self.max_grad_norm = max_grad_norm self.initial_lr = lr # --- Optimizer Initialization --- @@ -234,19 +235,21 @@ def train_batch(self, input_features, target_values, thinking_steps, gradient_ac if step_now: self.scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + if self.max_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) self.scaler.step(self.optimizer) self.scaler.update() if self.gradient_persistence > 0.0: - # Gradient Persistence + # Synchronize persistent gradients with active AMP scale + scale = self.scaler.get_scale() if hasattr(self, 'scaler') and self.scaler.is_enabled() else 1.0 with torch.no_grad(): for param in self.model.parameters(): if param.grad is not None: if torch.isnan(param.grad).any() or torch.isinf(param.grad).any(): param.grad.zero_() else: - param.grad.mul_(self.gradient_persistence) + param.grad.mul_(self.gradient_persistence * scale) else: self.optimizer.zero_grad()