Skip to content
106 changes: 87 additions & 19 deletions cli/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,7 @@ def _train_ensemble_models(
regime_lookback=regime_lookback,
regime_lookahead=regime_lookahead,
console=console,
cfg=cfg,
)
all_metrics['direction'] = dir_metrics

Expand Down Expand Up @@ -1345,6 +1346,7 @@ def _train_direction_or_regime_model(
regime_lookback,
regime_lookahead,
console,
cfg=None,
):
"""Train either the regime or direction model. Returns (dir_data, dir_metrics, dir_model_path, dir_trainer)."""
from src.training.modular_trainers import (
Expand Down Expand Up @@ -1388,34 +1390,80 @@ def _train_direction_or_regime_model(
_absl_logger = _logging.getLogger('absl')
_absl_logger.setLevel(_logging.ERROR)

if use_transformer:
dir_trainer = TransformerDirectionTrainer(trainer_config)
dir_metrics = dir_trainer.train(
dir_data['X_train'], dir_data['y_train'],
dir_data['X_val'], dir_data['y_val'],
# Check if Walk-Forward Cross-Validation is enabled
wf_config = None
if cfg and isinstance(cfg, dict):
wf_config = cfg.get('walkforward', {})

wf_enabled = wf_config and wf_config.get('enabled', False)

if wf_enabled:
# Use WalkForwardOrchestrator for training
from src.training.walkforward_orchestrator import WalkForwardOrchestrator

trainer_class = TransformerDirectionTrainer if use_transformer else TCNTrainer

orchestrator = WalkForwardOrchestrator(
trainer_class=trainer_class,
trainer_config=trainer_config,
wf_config=wf_config,
console=console,
)

# Train with walk-forward validation
dir_trainer, dir_metrics = orchestrator.train(
X_train=dir_data['X_train'],
y_train=dir_data['y_train'],
X_val=dir_data['X_val'],
y_val=dir_data['y_val'],
feature_names=dir_data['feature_names'],
instrument=training_instrument,
w_train=dir_data.get('w_train'),
w_val=dir_data.get('w_val'),
warm_start_path=str(warm_start_path) if warm_start_path else None,
instrument=training_instrument,
)

# Save the best model from WFCV
dir_trainer.save(str(pair_paths['direction']), instrument=training_instrument)
if training_instrument != "GENERIC":
dir_trainer.save(str(model_dir / _TRANSFORMER_KERAS_FILE), instrument=training_instrument)
model_filename = _TRANSFORMER_KERAS_FILE if use_transformer else "tcn_direction.keras"
dir_trainer.save(str(model_dir / model_filename), instrument=training_instrument)
dir_model_path = str(pair_paths['direction'])
console.print(f"[cyan]💾 Direction model saved to: {pair_paths['direction']}[/cyan]")

# Save WFCV summary
orchestrator.save_summary(model_dir)

console.print(f"[cyan]💾 Direction model (WFCV best) saved to: {pair_paths['direction']}[/cyan]")
else:
dir_trainer = TCNTrainer(trainer_config)
dir_metrics = dir_trainer.train(
dir_data['X_train'], dir_data['y_train'],
dir_data['X_val'], dir_data['y_val'],
feature_names=dir_data['feature_names']
)
dir_trainer.save(str(pair_paths['direction']))
if training_instrument != "GENERIC":
dir_trainer.save(str(model_dir / "tcn_direction.keras"))
dir_model_path = str(pair_paths['direction'])
console.print(f"[cyan]💾 TCN model saved to: {pair_paths['direction']}[/cyan]")
# Standard training (no WFCV)
if use_transformer:
dir_trainer = TransformerDirectionTrainer(trainer_config)
dir_metrics = dir_trainer.train(
dir_data['X_train'], dir_data['y_train'],
dir_data['X_val'], dir_data['y_val'],
feature_names=dir_data['feature_names'],
w_train=dir_data.get('w_train'),
w_val=dir_data.get('w_val'),
warm_start_path=str(warm_start_path) if warm_start_path else None,
instrument=training_instrument,
)
dir_trainer.save(str(pair_paths['direction']), instrument=training_instrument)
if training_instrument != "GENERIC":
dir_trainer.save(str(model_dir / _TRANSFORMER_KERAS_FILE), instrument=training_instrument)
dir_model_path = str(pair_paths['direction'])
console.print(f"[cyan]💾 Direction model saved to: {pair_paths['direction']}[/cyan]")
else:
dir_trainer = TCNTrainer(trainer_config)
dir_metrics = dir_trainer.train(
dir_data['X_train'], dir_data['y_train'],
dir_data['X_val'], dir_data['y_val'],
feature_names=dir_data['feature_names']
)
dir_trainer.save(str(pair_paths['direction']))
if training_instrument != "GENERIC":
dir_trainer.save(str(model_dir / "tcn_direction.keras"))
dir_model_path = str(pair_paths['direction'])
console.print(f"[cyan]💾 TCN model saved to: {pair_paths['direction']}[/cyan]")

if 'val_balanced_accuracy' in dir_metrics:
console.print(f"[green]✓ Directional predictor complete: Validation accuracy={dir_metrics['val_accuracy']:.1%} • Balanced={dir_metrics['val_balanced_accuracy']:.1%}[/green]")
Expand Down Expand Up @@ -2586,6 +2634,26 @@ def _generate_training_report(
- Validation Accuracy: {dir_metrics['val_accuracy']:.2%}
- Balanced Accuracy: {dir_metrics.get('val_balanced_accuracy', dir_metrics['val_accuracy']):.2%}
"""

# Add Walk-Forward CV metrics if available
if 'wfcv_mean_test_accuracy' in dir_metrics:
wfcv_mean = dir_metrics['wfcv_mean_test_accuracy']
wfcv_std = dir_metrics.get('wfcv_std_test_accuracy', 0.0)
wfcv_n_folds = dir_metrics.get('wfcv_n_folds', 0)
wfcv_best_fold = dir_metrics.get('wfcv_best_fold', 0)

report_content += f"""
#### Walk-Forward Cross-Validation
- **Test Accuracy (Mean ± Std)**: {wfcv_mean:.2%} ± {wfcv_std:.2%}
- **Number of Folds**: {wfcv_n_folds}
- **Best Fold Selected**: {wfcv_best_fold}
- **Stability (CV)**: {(wfcv_std / wfcv_mean if wfcv_mean > 0 else 0):.4f}

> **Note**: Walk-forward validation provides robust out-of-sample performance estimates.
> The reported model is the best-performing fold from time-series cross-validation.

"""

if 'bootstrap_ci_lower' in dir_metrics:
report_content += f"- Bootstrap 95% CI: [{dir_metrics['bootstrap_ci_lower']:.2%}, {dir_metrics['bootstrap_ci_upper']:.2%}]\n"

Expand Down
Loading
Loading