diff --git a/Makefile b/Makefile index 8a78484..03ec474 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help dev test lint typecheck qa train eval serve smoke-serve download-data transcribe publish benchmark ops-dashboard +.PHONY: help dev test lint typecheck qa train eval serve smoke-serve download-data transcribe publish benchmark manifest post-train set-wandb ops-dashboard .DEFAULT_GOAL := help @@ -82,5 +82,21 @@ benchmark: ## Benchmark endpoint (BASE_URL= [API_KEY=]) if [ -n "$(CONCURRENCY)" ]; then cmd="$$cmd --concurrency $(CONCURRENCY)"; fi; \ eval "$$cmd" +manifest: ## Generate training manifest (TRAIN_CONFIG/RUN_DIR/LOG_FILE) + @if [ -z "$(RUN_DIR)" ] || [ -z "$(LOG_FILE)" ]; then \ + echo "Usage: make manifest TRAIN_CONFIG= RUN_DIR= LOG_FILE="; \ + exit 1; \ + fi + uv run python scripts/generate_training_manifest.py \ + --config "$(TRAIN_CONFIG)" \ + --run-dir "$(RUN_DIR)" \ + --log-file "$(LOG_FILE)" + +post-train: ## Run post-training pipeline (eval -> merge -> optional smoke) + bash scripts/post_training_pipeline.sh + +set-wandb: ## Set WANDB_API_KEY for training services + bash scripts/set_wandb_key.sh + ops-dashboard: ## Launch runtime ops dashboard bash scripts/runtime_dashboard.sh diff --git a/README.md b/README.md index 34733e7..f09be2f 100644 --- a/README.md +++ b/README.md @@ -315,6 +315,22 @@ All commands support `--help` for full option documentation. Run `make help` to --- +## Post-Completion Roadmap + +After the current priority training run is completed, the next improvement work is tracked in: + +- `docs/ROADMAP.md` + +Roadmap phases: + +1. Stability hardening (NaN guards, fail-fast, auto-resume) +2. Turkish data expansion and quality filtering +3. A100 training recipe optimization +4. Serving throughput and latency optimization +5. Evaluation depth and release governance + +--- + ## Notebooks Interactive Jupyter notebooks for exploration and analysis: diff --git a/configs/models/turkcell_7b_a100_resume_bf16_clean.yaml b/configs/models/turkcell_7b_a100_resume_bf16_clean.yaml new file mode 100644 index 0000000..e780ef6 --- /dev/null +++ b/configs/models/turkcell_7b_a100_resume_bf16_clean.yaml @@ -0,0 +1,15 @@ +# A100 80GB optimized bf16 resume profile from checkpoint-800. +# Uses a dedicated run name and save/eval interval for long remote runs. + +_base: "./turkcell_7b.yaml" + +training: + per_device_train_batch_size: 8 + gradient_accumulation_steps: 2 + eval_steps: 1000 + save_steps: 1000 + fp16: false + bf16: true + +wandb: + run_name: "turkcell-7b-sft-v1-a100-bf16-r2" diff --git a/configs/models/turkcell_7b_a100_v3_clean.yaml b/configs/models/turkcell_7b_a100_v3_clean.yaml new file mode 100644 index 0000000..c8e341d --- /dev/null +++ b/configs/models/turkcell_7b_a100_v3_clean.yaml @@ -0,0 +1,25 @@ +# Turkcell-7B A100 stable profile (post-NaN recovery). +_base: "./turkcell_7b.yaml" + +model: + max_seq_length: 2048 + +data: + train_path: "data/processed/turkish_sft_v3_clean.jsonl" + eval_path: "data/processed/turkish_eval.jsonl" + +training: + num_epochs: 1 + learning_rate: 5.0e-5 + lr_scheduler_type: "cosine" + warmup_ratio: 0.05 + max_grad_norm: 1.0 + per_device_train_batch_size: 8 + gradient_accumulation_steps: 2 + eval_steps: 500 + save_steps: 500 + fp16: false + bf16: true + +wandb: + run_name: "turkcell-7b-sft-v3-a100-bf16-stable" diff --git a/configs/models/turkcell_7b_a100_v4_recovery.yaml b/configs/models/turkcell_7b_a100_v4_recovery.yaml new file mode 100644 index 0000000..c4d51bd --- /dev/null +++ b/configs/models/turkcell_7b_a100_v4_recovery.yaml @@ -0,0 +1,25 @@ +# Turkcell-7B A100 recovery profile after NaN stop at step 800. +_base: "./turkcell_7b.yaml" + +model: + max_seq_length: 2048 + +data: + train_path: "data/processed/turkish_sft_v3_clean.jsonl" + eval_path: "data/processed/turkish_eval.jsonl" + +training: + num_epochs: 1 + learning_rate: 3.0e-5 + lr_scheduler_type: "cosine" + warmup_ratio: 0.05 + max_grad_norm: 1.0 + per_device_train_batch_size: 8 + gradient_accumulation_steps: 2 + eval_steps: 500 + save_steps: 500 + fp16: false + bf16: true + +wandb: + run_name: "turkcell-7b-sft-v4-a100-bf16-recovery" diff --git a/configs/models/turkcell_7b_a100_v5_recovery_low_lr.yaml b/configs/models/turkcell_7b_a100_v5_recovery_low_lr.yaml new file mode 100644 index 0000000..f650172 --- /dev/null +++ b/configs/models/turkcell_7b_a100_v5_recovery_low_lr.yaml @@ -0,0 +1,25 @@ +# Turkcell-7B A100 recovery profile after NaN stop. +_base: "./turkcell_7b.yaml" + +model: + max_seq_length: 2048 + +data: + train_path: "data/processed/turkish_sft_v3_clean.jsonl" + eval_path: "data/processed/turkish_eval.jsonl" + +training: + num_epochs: 1 + learning_rate: 2.0e-5 + lr_scheduler_type: "cosine" + warmup_ratio: 0.05 + max_grad_norm: 1.0 + per_device_train_batch_size: 8 + gradient_accumulation_steps: 2 + eval_steps: 500 + save_steps: 500 + fp16: false + bf16: true + +wandb: + run_name: "turkcell-7b-sft-v5-a100-bf16-recovery-low-lr" diff --git a/configs/models/turkcell_7b_a100_v6_recovery_reset_opt.yaml b/configs/models/turkcell_7b_a100_v6_recovery_reset_opt.yaml new file mode 100644 index 0000000..a6579fb --- /dev/null +++ b/configs/models/turkcell_7b_a100_v6_recovery_reset_opt.yaml @@ -0,0 +1,14 @@ +# Turkcell-7B A100 recovery profile with optimizer reset. +# Use adapter warm-start from checkpoint-500 without resuming optimizer state. +_base: "./turkcell_7b_a100_v5_recovery_low_lr.yaml" + +training: + learning_rate: 3.0e-5 + warmup_ratio: 0.08 + max_grad_norm: 0.5 + eval_steps: 250 + save_steps: 250 + adapter_init_path: "artifacts/training/turkcell-7b-sft-v3-a100-bf16-stable/checkpoint-500" + +wandb: + run_name: "turkcell-7b-sft-v6-a100-bf16-recovery-reset-opt" diff --git a/deploy/systemd/forge-training-monitor.service b/deploy/systemd/forge-training-monitor.service new file mode 100644 index 0000000..92d1a9b --- /dev/null +++ b/deploy/systemd/forge-training-monitor.service @@ -0,0 +1,19 @@ +[Unit] +Description=LowResource-LLM-Forge Training Progress Monitor +After=forge-training.service +Wants=forge-training.service +PartOf=forge-training.service + +[Service] +Type=simple +WorkingDirectory=%h/projects/LowResource-LLM-Forge +Environment=PYTHONUNBUFFERED=1 +EnvironmentFile=-%h/.config/forge/training.env +ExecStart=%h/projects/LowResource-LLM-Forge/scripts/monitor_a100_training.sh +Restart=on-failure +RestartSec=20 +StandardOutput=append:%h/projects/LowResource-LLM-Forge/artifacts/logs/training_monitor_a100.log +StandardError=append:%h/projects/LowResource-LLM-Forge/artifacts/logs/training_monitor_a100.log + +[Install] +WantedBy=default.target diff --git a/deploy/systemd/forge-training-watchdog.service b/deploy/systemd/forge-training-watchdog.service new file mode 100644 index 0000000..6502433 --- /dev/null +++ b/deploy/systemd/forge-training-watchdog.service @@ -0,0 +1,18 @@ +[Unit] +Description=LowResource-LLM-Forge Training Watchdog +After=forge-training.service +Wants=forge-training.service + +[Service] +Type=simple +WorkingDirectory=%h/projects/LowResource-LLM-Forge +Environment=PYTHONUNBUFFERED=1 +EnvironmentFile=-%h/.config/forge/training.env +ExecStart=%h/projects/LowResource-LLM-Forge/scripts/training_watchdog.py --service forge-training.service --nan-consecutive-limit 3 +Restart=always +RestartSec=10 +StandardOutput=append:%h/projects/LowResource-LLM-Forge/artifacts/logs/training_watchdog.log +StandardError=append:%h/projects/LowResource-LLM-Forge/artifacts/logs/training_watchdog.log + +[Install] +WantedBy=default.target diff --git a/deploy/systemd/forge-training.service b/deploy/systemd/forge-training.service new file mode 100644 index 0000000..188040f --- /dev/null +++ b/deploy/systemd/forge-training.service @@ -0,0 +1,18 @@ +[Unit] +Description=LowResource-LLM-Forge A100 Training +After=network-online.target +Wants=network-online.target + +[Service] +Type=simple +WorkingDirectory=%h/projects/LowResource-LLM-Forge +Environment=PYTHONUNBUFFERED=1 +EnvironmentFile=-%h/.config/forge/training.env +ExecStart=%h/projects/LowResource-LLM-Forge/scripts/start_a100_training.sh +Restart=on-failure +RestartSec=20 +StandardOutput=journal +StandardError=journal + +[Install] +WantedBy=default.target diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md new file mode 100644 index 0000000..d7f991b --- /dev/null +++ b/docs/ROADMAP.md @@ -0,0 +1,98 @@ +# Project Roadmap + +This roadmap starts after the current priority training run on A100 is completed and evaluated. + +## Current Run Definition of Done + +Before moving to improvement work: + +1. Complete the active training run (`target_steps=25845`) or end by a valid early-stop condition. +2. Merge adapter into base model and produce a merged checkpoint. +3. Run full evaluation (`perplexity`, `generation`, optional `mmlu_tr`) and save report artifacts. +4. Publish a versioned release candidate with reproducible config references. + +## Post-Completion Improvement Plan + +### Phase 1: Stability Hardening (Priority P0) + +Goal: prevent silent training failure and auto-recover quickly. + +- Add NaN/Inf guard callbacks for `loss`, `grad_norm`, and `eval_loss`. +- Fail fast on unstable metrics and auto-resume from last healthy checkpoint. +- Keep `systemd --user` + watchdog as the default runtime path on remote hosts. +- Persist heartbeat and key metrics to machine-readable status files for monitoring. + +Exit criteria: + +- No silent NaN progression in new runs. +- Automatic recovery from interruption in under 10 minutes. +- Stable checkpoints produced on schedule. + +### Phase 2: Turkish Data Expansion and Quality (Priority P0) + +Goal: improve model quality using larger, cleaner, better-balanced Turkish corpora. + +- Expand corpus with open Turkish sources (for example mC4, OSCAR, Wiki-derived text, curated Turkish instruction datasets). +- Improve deduplication and language filtering thresholds. +- Add quality scoring filters (length, script ratio, repetition, malformed text checks). +- Build a versioned dataset mixture and track it in a changelog. + +Suggested starting mixture: + +- 60% high-quality instruction data +- 25% domain text relevant to target use-cases +- 15% synthetic/translated augmentation with strict filtering + +Exit criteria: + +- At least 2x unique Turkish token coverage vs current baseline. +- Low-quality sample ratio below 5% after filtering. + +### Phase 3: Training Recipe Optimization on A100 (Priority P0) + +Goal: increase quality while preserving training stability. + +- Run controlled sweeps for learning rate, warmup ratio, LoRA rank/alpha, and effective batch size. +- Keep bf16 enabled on A100 and tune gradient accumulation for throughput. +- Tune evaluation cadence (`eval_steps=1000`) and checkpoint cadence (`save_steps=1000`). +- Promote only runs with finite metrics and consistent convergence. + +Exit criteria: + +- Perplexity improves by at least 10% from baseline. +- Generation quality score improves by at least 0.4. +- No regression in safety/format adherence prompts. + +### Phase 4: Inference Throughput and Latency (Priority P1) + +Goal: approach high-quality serving UX (fast first token + fluent decode). + +- Tune vLLM serving args (`max_num_batched_tokens`, `max_num_seqs`, `gpu_memory_utilization`, tensor parallelism). +- Benchmark p50/p95 latency and tokens/sec under concurrent load. +- Add configuration profiles for low-latency and high-throughput modes. +- Evaluate TensorRT-LLM/NIM path only after vLLM baseline is saturated. + +Exit criteria: + +- At least 30% tokens/sec gain at target concurrency. +- p95 time-to-first-token under defined SLO. + +### Phase 5: Evaluation Depth and Release Governance (Priority P1) + +Goal: make releases trustworthy and repeatable. + +- Expand held-out Turkish eval set by domain. +- Add lightweight human review rubrics for fluency, factuality, and instruction-following. +- Track every release with dataset version, config hash, and benchmark deltas. +- Gate promotion on quality thresholds and regression checks. + +Exit criteria: + +- Every release has reproducible lineage. +- Promotion decisions are benchmark-backed and auditable. + +## Immediate Next Actions After Current Run + +1. Generate baseline report from the active A100 run. +2. Launch Phase 1 stability patch set before the next long training job. +3. Build `turkish-v2` dataset mixture and run a short smoke training cycle. diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md index 1d10da5..b1a288d 100644 --- a/docs/TRAINING_GUIDE.md +++ b/docs/TRAINING_GUIDE.md @@ -133,3 +133,9 @@ For interactive training analysis: **Unsloth not found**: Install training extras with `uv sync --extra train`. Pipeline falls back to standard PEFT automatically when Unsloth is unavailable. **Poor Turkish output**: Check tokenizer coverage — models not trained on Turkish may tokenize inefficiently, reducing effective context length. + +## After Current Run Completes + +The post-completion improvement backlog (stability, data expansion, A100 recipe tuning, and serving performance) is tracked in: + +- `docs/ROADMAP.md` diff --git a/scripts/generate_training_manifest.py b/scripts/generate_training_manifest.py new file mode 100644 index 0000000..de8e620 --- /dev/null +++ b/scripts/generate_training_manifest.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +"""Generate a deterministic training manifest for a completed run.""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import re +import subprocess +from datetime import UTC, datetime +from pathlib import Path + +from forge.utils.config import load_training_config + +TIMESTAMP_RE = re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z)") + + +def _utc_now() -> str: + return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def _sha256_file(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + for chunk in iter(lambda: handle.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def _line_count(path: Path) -> int: + if not path.exists(): + return 0 + with path.open("rb") as handle: + return sum(1 for _ in handle) + + +def _git_commit() -> str: + proc = subprocess.run( + ["git", "rev-parse", "HEAD"], + check=False, + capture_output=True, + text=True, + ) + return proc.stdout.strip() if proc.returncode == 0 else "unknown" + + +def _extract_log_times(log_file: Path) -> tuple[str, str]: + if not log_file.exists(): + return "unknown", "unknown" + + start_ts = "unknown" + end_ts = "unknown" + with log_file.open(encoding="utf-8", errors="ignore") as handle: + for line in handle: + if "training_started" in line and start_ts == "unknown": + match = TIMESTAMP_RE.match(line.strip()) + if match: + start_ts = match.group(1) + if "training_complete" in line or "Training complete. Adapter saved to" in line: + match = TIMESTAMP_RE.match(line.strip()) + if match: + end_ts = match.group(1) + return start_ts, end_ts + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate training manifest JSON.") + parser.add_argument("--config", required=True, help="Training config path.") + parser.add_argument("--run-dir", required=True, help="Training run directory.") + parser.add_argument("--log-file", required=True, help="Training log file path.") + parser.add_argument( + "--output", + default=None, + help="Output manifest path (defaults to /manifest.json).", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + config_path = Path(args.config).resolve() + run_dir = Path(args.run_dir).resolve() + log_file = Path(args.log_file).resolve() + output_path = Path(args.output).resolve() if args.output else run_dir / "manifest.json" + + cfg = load_training_config(config_path) + train_path = Path(cfg.train_data_path).resolve() + eval_path = Path(cfg.eval_data_path).resolve() + + final_dir = run_dir / "final" + checkpoints = sorted(p.name for p in run_dir.glob("checkpoint-*") if p.is_dir()) + start_ts, end_ts = _extract_log_times(log_file) + + manifest = { + "created_utc": _utc_now(), + "git_commit": _git_commit(), + "config_path": str(config_path), + "config_sha256": _sha256_file(config_path), + "run_dir": str(run_dir), + "log_file": str(log_file), + "run_start_utc": start_ts, + "run_end_utc": end_ts, + "model_name": cfg.model.name, + "run_name": cfg.wandb.run_name, + "train_data_path": str(train_path), + "eval_data_path": str(eval_path), + "train_records": _line_count(train_path), + "eval_records": _line_count(eval_path), + "final_dir_exists": final_dir.exists(), + "checkpoint_dirs": checkpoints, + } + + output_path.parent.mkdir(parents=True, exist_ok=True) + payload = json.dumps(manifest, indent=2, ensure_ascii=False) + "\n" + output_path.write_text(payload, encoding="utf-8") + print(f"manifest_written={output_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/install_training_services.sh b/scripts/install_training_services.sh new file mode 100755 index 0000000..fbaddca --- /dev/null +++ b/scripts/install_training_services.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash +set -euo pipefail + +PROJECT_ROOT="${PROJECT_ROOT:-$HOME/projects/LowResource-LLM-Forge}" +SYSTEMD_USER_DIR="${SYSTEMD_USER_DIR:-$HOME/.config/systemd/user}" +FORGE_ENV_DIR="${FORGE_ENV_DIR:-$HOME/.config/forge}" +FORGE_ENV_FILE="${FORGE_ENV_FILE:-$FORGE_ENV_DIR/training.env}" + +mkdir -p "$SYSTEMD_USER_DIR" "$PROJECT_ROOT/artifacts/logs" "$FORGE_ENV_DIR" + +if [[ ! -f "$FORGE_ENV_FILE" ]]; then + cat >"$FORGE_ENV_FILE" <<'EOF' +# Required for training with WandB. +# Set your real key before starting forge-training.service. +WANDB_API_KEY= + +# Optional overrides: +# TRAIN_CONFIG=configs/models/turkcell_7b_a100_v4_recovery.yaml +# TRAIN_RUN_DIR=artifacts/training/turkcell-7b-sft-v4-a100-bf16-recovery +# TRAIN_LOG=artifacts/logs/training_turkcell_7b_a100_v4_recovery.log +# TARGET_STEPS=8601 +# ENABLE_RESUME=0 +# REQUIRE_WANDB=1 +# BOOTSTRAP_CHECKPOINT= +EOF + chmod 600 "$FORGE_ENV_FILE" +fi + +install -m 0644 \ + "$PROJECT_ROOT/deploy/systemd/forge-training.service" \ + "$SYSTEMD_USER_DIR/forge-training.service" +install -m 0644 \ + "$PROJECT_ROOT/deploy/systemd/forge-training-watchdog.service" \ + "$SYSTEMD_USER_DIR/forge-training-watchdog.service" +install -m 0644 \ + "$PROJECT_ROOT/deploy/systemd/forge-training-monitor.service" \ + "$SYSTEMD_USER_DIR/forge-training-monitor.service" + +chmod +x \ + "$PROJECT_ROOT/scripts/start_a100_training.sh" \ + "$PROJECT_ROOT/scripts/monitor_a100_training.sh" \ + "$PROJECT_ROOT/scripts/training_watchdog.py" + +systemctl --user daemon-reload +systemctl --user enable forge-training.service +systemctl --user enable forge-training-watchdog.service +systemctl --user enable forge-training-monitor.service + +require_wandb="$(grep -E '^REQUIRE_WANDB=' "$FORGE_ENV_FILE" | tail -n 1 | cut -d '=' -f2 | tr -d '[:space:]' || true)" +require_wandb="${require_wandb:-1}" + +if [[ "$require_wandb" == "0" ]] || grep -qE '^WANDB_API_KEY=.+$' "$FORGE_ENV_FILE"; then + systemctl --user restart forge-training.service + systemctl --user restart forge-training-watchdog.service + systemctl --user restart forge-training-monitor.service +else + systemctl --user stop forge-training-monitor.service || true + systemctl --user stop forge-training-watchdog.service || true + systemctl --user stop forge-training.service || true +fi + +systemctl --user --no-pager --lines=20 status forge-training.service || true +systemctl --user --no-pager --lines=20 status forge-training-watchdog.service || true +systemctl --user --no-pager --lines=20 status forge-training-monitor.service || true +echo +echo "Edit $FORGE_ENV_FILE and set WANDB_API_KEY before starting training." +echo "Or run: scripts/set_wandb_key.sh" diff --git a/scripts/monitor_a100_training.sh b/scripts/monitor_a100_training.sh new file mode 100755 index 0000000..9795ce5 --- /dev/null +++ b/scripts/monitor_a100_training.sh @@ -0,0 +1,139 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd /home/weezboo/projects/LowResource-LLM-Forge + +TRAIN_CONFIG="${TRAIN_CONFIG:-configs/models/turkcell_7b_a100_v4_recovery.yaml}" +CONFIG_BASENAME="$(basename "$TRAIN_CONFIG")" +CONFIG_SLUG="${CONFIG_BASENAME%.*}" +LOG_FILE="${LOG_FILE:-${TRAIN_LOG:-artifacts/logs/training_${CONFIG_SLUG}.log}}" +STATUS_FILE="${STATUS_FILE:-artifacts/logs/training_monitor_status_a100.txt}" +ETA_STATE_FILE="${ETA_STATE_FILE:-artifacts/logs/training_monitor_eta_state_${CONFIG_SLUG}.env}" +TARGET_STEPS="${TARGET_STEPS:-8601}" +PATTERN="${PATTERN:-run_training.py --config ${TRAIN_CONFIG}}" +SLEEP_SECS="${SLEEP_SECS:-60}" + +mkdir -p artifacts/logs + +prev_ts=0 +prev_step=0 +ema_sps="" +speed_source="none" + +if [[ -f "$ETA_STATE_FILE" ]]; then + # shellcheck disable=SC1090 + source "$ETA_STATE_FILE" +fi + +while true; do + ts="$(date -u +%Y-%m-%dT%H:%M:%SZ)" + now_epoch="$(date -u +%s)" + + running="no" + if pgrep -f "$PATTERN" >/dev/null 2>&1 || pgrep -f "scripts/run_training.py" >/dev/null 2>&1; then + running="yes" + fi + + progress="none" + if [[ -f "$LOG_FILE" ]]; then + progress="$(grep -a -oE "[0-9]+/${TARGET_STEPS}" "$LOG_FILE" | tail -n 1 || true)" + if [[ -z "$progress" ]]; then + progress="none" + fi + fi + + step="0" + if [[ "$progress" != "none" ]]; then + step="${progress%%/*}" + fi + + pct="0" + if [[ "$step" =~ ^[0-9]+$ ]] && [[ $TARGET_STEPS -gt 0 ]]; then + pct=$((step * 100 / TARGET_STEPS)) + fi + + nan_count="0" + if [[ -f "$LOG_FILE" ]]; then + # Count only real NaN/Inf metric values and explicit NaN guard events. + metric_nan_count="$(grep -a -E -i "'(loss|grad_norm|eval_loss)':[[:space:]]*'?(nan|inf)'?" "$LOG_FILE" | wc -l | tr -d '[:space:]' || true)" + guard_nan_count="$(grep -a -E -c "nan_guard_detected|nan_guard_stopping_training" "$LOG_FILE" || true)" + metric_nan_count="${metric_nan_count:-0}" + guard_nan_count="${guard_nan_count:-0}" + nan_count=$((metric_nan_count + guard_nan_count)) + fi + + gpu_line="$(nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total --format=csv,noheader | head -n1 2>/dev/null || echo unknown)" + + steps_per_hour="unknown" + eta_seconds="unknown" + eta_utc="unknown" + remaining_steps="unknown" + + if [[ "$step" =~ ^[0-9]+$ ]] && [[ "$step" -lt "$TARGET_STEPS" ]]; then + remaining_steps=$((TARGET_STEPS - step)) + fi + + if [[ "$step" =~ ^[0-9]+$ ]] && [[ "$prev_ts" =~ ^[0-9]+$ ]] && [[ "$prev_step" =~ ^[0-9]+$ ]]; then + if [[ $prev_ts -gt 0 ]] && [[ $now_epoch -gt $prev_ts ]] && [[ $step -gt $prev_step ]]; then + delta_steps=$((step - prev_step)) + delta_secs=$((now_epoch - prev_ts)) + instant_sps="$(awk -v ds="$delta_steps" -v dt="$delta_secs" 'BEGIN { printf "%.8f", ds / dt }')" + + if [[ -n "$ema_sps" ]]; then + ema_sps="$(awk -v e="$ema_sps" -v i="$instant_sps" 'BEGIN { printf "%.8f", (0.7 * e) + (0.3 * i) }')" + speed_source="ema" + else + ema_sps="$instant_sps" + speed_source="instant" + fi + fi + fi + + if [[ -n "$ema_sps" ]] && awk -v s="$ema_sps" 'BEGIN { exit !(s > 0) }'; then + steps_per_hour="$(awk -v s="$ema_sps" 'BEGIN { printf "%.1f", s * 3600 }')" + if [[ "$remaining_steps" =~ ^[0-9]+$ ]]; then + eta_seconds="$(awk -v rem="$remaining_steps" -v s="$ema_sps" 'BEGIN { printf "%.0f", rem / s }')" + if [[ "$eta_seconds" =~ ^[0-9]+$ ]]; then + eta_utc="$(date -u -d "@$((now_epoch + eta_seconds))" +%Y-%m-%dT%H:%M:%SZ 2>/dev/null || echo unknown)" + fi + fi + fi + + { + echo "timestamp_utc=$ts" + echo "running=$running" + echo "step=$step" + echo "target_steps=$TARGET_STEPS" + echo "progress=$progress" + echo "percent=$pct" + echo "remaining_steps=$remaining_steps" + echo "steps_per_hour=$steps_per_hour" + echo "eta_seconds=$eta_seconds" + echo "eta_utc=$eta_utc" + echo "speed_source=$speed_source" + echo "nan_count=$nan_count" + echo "gpu=$gpu_line" + } >"$STATUS_FILE" + + prev_ts="$now_epoch" + prev_step="$step" + { + echo "prev_ts=$prev_ts" + echo "prev_step=$prev_step" + echo "ema_sps=$ema_sps" + echo "speed_source=$speed_source" + } >"$ETA_STATE_FILE" + + if [[ "$running" == "no" ]]; then + echo "state=stopped" >>"$STATUS_FILE" + exit 0 + fi + + if [[ "$step" =~ ^[0-9]+$ ]] && [[ $step -ge $TARGET_STEPS ]]; then + echo "state=completed" >>"$STATUS_FILE" + exit 0 + fi + + echo "state=running" >>"$STATUS_FILE" + sleep "$SLEEP_SECS" +done diff --git a/scripts/post_training_pipeline.sh b/scripts/post_training_pipeline.sh new file mode 100644 index 0000000..f3f6a6a --- /dev/null +++ b/scripts/post_training_pipeline.sh @@ -0,0 +1,98 @@ +#!/usr/bin/env bash +set -euo pipefail + +PROJECT_ROOT="${PROJECT_ROOT:-$HOME/projects/LowResource-LLM-Forge}" +cd "$PROJECT_ROOT" + +UV_BIN="${UV_BIN:-$HOME/.local/bin/uv}" +TRAIN_CONFIG="${TRAIN_CONFIG:-configs/models/turkcell_7b_a100_v4_recovery.yaml}" +RUN_DIR="${RUN_DIR:-artifacts/training/turkcell-7b-sft-v4-a100-bf16-recovery}" +TRAIN_LOG="${TRAIN_LOG:-artifacts/logs/training_a100_bf16_v4_recovery.log}" +ADAPTER_DIR="${ADAPTER_DIR:-$RUN_DIR/final}" + +BASE_MODEL="${BASE_MODEL:-TURKCELL/Turkcell-LLM-7b-v1}" +MERGED_OUTPUT="${MERGED_OUTPUT:-artifacts/merged/turkcell-7b-a100-v4-recovery}" +EVAL_OUTPUT_ROOT="${EVAL_OUTPUT_ROOT:-artifacts/eval/turkcell-7b-a100-v4-recovery}" + +PUSH_TO_HUB="${PUSH_TO_HUB:-0}" +HUB_REPO="${HUB_REPO:-}" + +SERVE_BASE_URL="${SERVE_BASE_URL:-}" +SERVE_API_KEY="${SERVE_API_KEY:-}" +BENCHMARK_NUM_REQUESTS="${BENCHMARK_NUM_REQUESTS:-50}" +BENCHMARK_CONCURRENCY="${BENCHMARK_CONCURRENCY:-5}" + +if [[ ! -x "$UV_BIN" ]]; then + echo "UV executable not found: $UV_BIN" >&2 + exit 1 +fi + +if [[ ! -d "$ADAPTER_DIR" ]]; then + echo "Adapter directory not found: $ADAPTER_DIR" >&2 + exit 1 +fi + +if [[ "$PUSH_TO_HUB" == "1" ]] && [[ -z "$HUB_REPO" ]]; then + echo "HUB_REPO is required when PUSH_TO_HUB=1." >&2 + exit 1 +fi + +echo "[$(date -u +%Y-%m-%dT%H:%M:%SZ)] post-training-pipeline-start" +echo "train_config=$TRAIN_CONFIG" +echo "run_dir=$RUN_DIR" +echo "adapter_dir=$ADAPTER_DIR" + +echo +echo "[1/4] Generate training manifest" +"$UV_BIN" run python scripts/generate_training_manifest.py \ + --config "$TRAIN_CONFIG" \ + --run-dir "$RUN_DIR" \ + --log-file "$TRAIN_LOG" + +echo +echo "[2/4] Run offline evaluations (mmlu_tr, perplexity, generation)" +for bench in mmlu_tr perplexity generation; do + out_dir="$EVAL_OUTPUT_ROOT/$bench" + mkdir -p "$out_dir" + echo " - benchmark=$bench output=$out_dir" + "$UV_BIN" run python scripts/run_eval.py \ + --model "$ADAPTER_DIR" \ + --benchmark "$bench" \ + --output-dir "$out_dir" +done + +echo +echo "[3/4] Merge adapters into base model" +merge_cmd=( + "$UV_BIN" run python scripts/merge_and_push.py + --base-model "$BASE_MODEL" + --adapter "$ADAPTER_DIR" + --output "$MERGED_OUTPUT" +) +if [[ "$PUSH_TO_HUB" == "1" ]]; then + merge_cmd+=(--push --hub-repo "$HUB_REPO") +fi +"${merge_cmd[@]}" + +echo +echo "[4/4] Optional serving smoke/benchmark" +if [[ -n "$SERVE_BASE_URL" ]]; then + smoke_cmd=("$UV_BIN" run python scripts/smoke_serve.py --base-url "$SERVE_BASE_URL") + bench_cmd=( + "$UV_BIN" run python scripts/benchmark_openai_endpoint.py + --base-url "$SERVE_BASE_URL" + --num-requests "$BENCHMARK_NUM_REQUESTS" + --concurrency "$BENCHMARK_CONCURRENCY" + ) + if [[ -n "$SERVE_API_KEY" ]]; then + smoke_cmd+=(--api-key "$SERVE_API_KEY") + bench_cmd+=(--api-key "$SERVE_API_KEY") + fi + "${smoke_cmd[@]}" + "${bench_cmd[@]}" +else + echo " - SERVE_BASE_URL not set; skipping serve smoke + benchmark" +fi + +echo +echo "[$(date -u +%Y-%m-%dT%H:%M:%SZ)] post-training-pipeline-complete" diff --git a/scripts/set_wandb_key.sh b/scripts/set_wandb_key.sh new file mode 100644 index 0000000..0585447 --- /dev/null +++ b/scripts/set_wandb_key.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +set -euo pipefail + +ENV_FILE="${ENV_FILE:-$HOME/.config/forge/training.env}" +START_SERVICES="${START_SERVICES:-1}" + +key="${1:-}" +if [[ -z "$key" ]]; then + read -r -s -p "Enter WANDB_API_KEY: " key + echo +fi + +if [[ -z "$key" ]]; then + echo "WANDB_API_KEY cannot be empty." >&2 + exit 1 +fi + +mkdir -p "$(dirname "$ENV_FILE")" +tmp_file="$(mktemp)" + +if [[ -f "$ENV_FILE" ]]; then + grep -v '^WANDB_API_KEY=' "$ENV_FILE" >"$tmp_file" || true +fi +printf 'WANDB_API_KEY=%s\n' "$key" >>"$tmp_file" + +install -m 600 "$tmp_file" "$ENV_FILE" +rm -f "$tmp_file" + +echo "Updated $ENV_FILE" + +if [[ "$START_SERVICES" == "1" ]]; then + systemctl --user daemon-reload + systemctl --user restart forge-training.service + systemctl --user restart forge-training-watchdog.service + systemctl --user restart forge-training-monitor.service + systemctl --user --no-pager --lines=10 status forge-training.service || true +fi diff --git a/scripts/start_a100_training.sh b/scripts/start_a100_training.sh new file mode 100755 index 0000000..ff0b8ad --- /dev/null +++ b/scripts/start_a100_training.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash +set -euo pipefail + +PROJECT_ROOT="${PROJECT_ROOT:-$HOME/projects/LowResource-LLM-Forge}" +cd "$PROJECT_ROOT" + +TRAIN_CONFIG="${TRAIN_CONFIG:-configs/models/turkcell_7b_a100_v4_recovery.yaml}" +CONFIG_BASENAME="$(basename "$TRAIN_CONFIG")" +CONFIG_SLUG="${CONFIG_BASENAME%.*}" +TRAIN_RUN_DIR="${TRAIN_RUN_DIR:-artifacts/training/${CONFIG_SLUG}}" +TRAIN_LOG="${TRAIN_LOG:-artifacts/logs/training_${CONFIG_SLUG}.log}" +BOOTSTRAP_CHECKPOINT="${BOOTSTRAP_CHECKPOINT:-}" +ENABLE_RESUME="${ENABLE_RESUME:-0}" +HF_HOME_DIR="${HF_HOME_DIR:-$PROJECT_ROOT/.hf_cache}" +HF_DATASETS_CACHE_DIR="${HF_DATASETS_CACHE_DIR:-$HF_HOME_DIR/datasets}" +HF_HUB_CACHE_DIR="${HF_HUB_CACHE_DIR:-$HF_HOME_DIR/hub}" +UV_BIN="${UV_BIN:-$HOME/.local/bin/uv}" +REQUIRE_WANDB="${REQUIRE_WANDB:-1}" + +mkdir -p \ + "$(dirname "$TRAIN_RUN_DIR")" \ + "$(dirname "$TRAIN_LOG")" \ + "$HF_HOME_DIR" \ + "$HF_DATASETS_CACHE_DIR" \ + "$HF_HUB_CACHE_DIR" \ + artifacts/logs + +# Keep a durable per-run log even when systemd unit output targets change. +exec > >(tee -a "$TRAIN_LOG") 2>&1 + +if [[ ! -x "$UV_BIN" ]]; then + echo "UV executable not found: $UV_BIN" >&2 + exit 1 +fi + +if [[ "$REQUIRE_WANDB" == "1" ]] && [[ -z "${WANDB_API_KEY:-}" ]]; then + echo "WANDB_API_KEY is required for this run (REQUIRE_WANDB=1)." >&2 + exit 1 +fi + +find_latest_checkpoint() { + if [[ ! -d "$TRAIN_RUN_DIR" ]]; then + return + fi + find "$TRAIN_RUN_DIR" -maxdepth 1 -type d -name "checkpoint-*" | sort -V | tail -n 1 +} + +resume_from="" +if [[ "$ENABLE_RESUME" == "1" ]]; then + latest_checkpoint="$(find_latest_checkpoint || true)" + if [[ -n "$latest_checkpoint" ]]; then + resume_from="$latest_checkpoint" + elif [[ -n "$BOOTSTRAP_CHECKPOINT" ]] && [[ -d "$BOOTSTRAP_CHECKPOINT" ]]; then + resume_from="$BOOTSTRAP_CHECKPOINT" + fi +fi + +cmd=("$UV_BIN" "run" "python" "scripts/run_training.py" "--config" "$TRAIN_CONFIG") +if [[ -n "$resume_from" ]]; then + cmd+=("--resume-from" "$resume_from") +fi + +echo "[$(date -u +%Y-%m-%dT%H:%M:%SZ)] forge-training-start" +echo "project_root=$PROJECT_ROOT" +echo "train_config=$TRAIN_CONFIG" +echo "config_slug=$CONFIG_SLUG" +echo "train_run_dir=$TRAIN_RUN_DIR" +echo "train_log=$TRAIN_LOG" +echo "resume_from=${resume_from:-none}" +echo "enable_resume=$ENABLE_RESUME" +echo "require_wandb=$REQUIRE_WANDB" +echo "hf_home=$HF_HOME_DIR" +echo "hf_datasets_cache=$HF_DATASETS_CACHE_DIR" +echo "hf_hub_cache=$HF_HUB_CACHE_DIR" +echo "command=${cmd[*]}" + +exec env \ + FORGE_EXECUTION_CONTEXT=remote \ + HF_HOME="$HF_HOME_DIR" \ + HF_DATASETS_CACHE="$HF_DATASETS_CACHE_DIR" \ + HUGGINGFACE_HUB_CACHE="$HF_HUB_CACHE_DIR" \ + "${cmd[@]}" diff --git a/scripts/training_watchdog.py b/scripts/training_watchdog.py new file mode 100755 index 0000000..0153945 --- /dev/null +++ b/scripts/training_watchdog.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +"""Watchdog for long-running training on remote GPU hosts. + +Restarts a user-level systemd training service when: +1) Too many consecutive metric lines contain NaN. +2) Training step does not advance for a configured stall timeout. +""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import os +import re +import subprocess +import time +from dataclasses import asdict, dataclass +from pathlib import Path + + +@dataclass +class WatchdogState: + """Persisted state between watchdog loops.""" + + last_metric_hash: str = "" + nan_consecutive: int = 0 + last_step: int = 0 + last_step_change_ts: float = 0.0 + + +def _config_slug() -> str: + train_config = os.getenv("TRAIN_CONFIG", "configs/models/turkcell_7b_a100_v4_recovery.yaml") + return Path(train_config).stem + + +def _int_env(name: str, default: int) -> int: + value = os.getenv(name) + if value is None or value.strip() == "": + return default + try: + return int(value) + except ValueError: + return default + + +def parse_args() -> argparse.Namespace: + slug = _config_slug() + default_log_file = os.getenv("TRAIN_LOG", f"artifacts/logs/training_{slug}.log") + default_state_file = os.getenv( + "TRAIN_WATCHDOG_STATE_FILE", + f"artifacts/logs/training_watchdog_state_{slug}.json", + ) + default_status_file = os.getenv( + "TRAIN_WATCHDOG_STATUS_FILE", + f"artifacts/logs/training_watchdog_status_{slug}.txt", + ) + default_target_steps = _int_env("TARGET_STEPS", 8601) + default_poll_seconds = _int_env("WATCHDOG_POLL_SECONDS", 60) + default_stall_seconds = _int_env("WATCHDOG_STALL_SECONDS", 5400) + + parser = argparse.ArgumentParser(description="Monitor training and auto-restart on failures.") + parser.add_argument("--service", default="forge-training.service") + parser.add_argument("--log-file", default=default_log_file) + parser.add_argument("--state-file", default=default_state_file) + parser.add_argument("--status-file", default=default_status_file) + parser.add_argument("--target-steps", type=int, default=default_target_steps) + parser.add_argument("--poll-seconds", type=int, default=default_poll_seconds) + parser.add_argument("--nan-consecutive-limit", type=int, default=3) + parser.add_argument("--stall-seconds", type=int, default=default_stall_seconds) + parser.add_argument("--max-read-bytes", type=int, default=2_000_000) + return parser.parse_args() + + +def _run_systemctl(*args: str) -> subprocess.CompletedProcess[str]: + return subprocess.run( + ["systemctl", "--user", *args], + check=False, + text=True, + capture_output=True, + ) + + +def is_service_active(service: str) -> bool: + proc = _run_systemctl("is-active", "--quiet", service) + return proc.returncode == 0 + + +def restart_service(service: str) -> bool: + proc = _run_systemctl("restart", service) + return proc.returncode == 0 + + +def start_service(service: str) -> bool: + proc = _run_systemctl("start", service) + return proc.returncode == 0 + + +def read_tail_text(path: Path, max_bytes: int) -> str: + if not path.exists(): + return "" + with path.open("rb") as handle: + handle.seek(0, os.SEEK_END) + size = handle.tell() + handle.seek(max(0, size - max_bytes), os.SEEK_SET) + return handle.read().decode("utf-8", errors="ignore") + + +def parse_training_tail(text: str, target_steps: int) -> tuple[int, str]: + if not text: + return 0, "" + + step_pattern = re.compile(rf"(\d+)/{target_steps}\b") + steps = [int(match.group(1)) for match in step_pattern.finditer(text)] + max_step = max(steps) if steps else 0 + + metric_lines: list[str] = [] + for line in text.splitlines(): + if ("'loss':" in line and "'grad_norm':" in line) or "'eval_loss':" in line: + metric_lines.append(line) + last_metric = metric_lines[-1] if metric_lines else "" + return max_step, last_metric + + +def metric_hash(metric_line: str) -> str: + if not metric_line: + return "" + return hashlib.sha256(metric_line.encode("utf-8", errors="ignore")).hexdigest() + + +def load_state(path: Path) -> WatchdogState: + if not path.exists(): + return WatchdogState() + try: + payload = json.loads(path.read_text(encoding="utf-8")) + return WatchdogState( + last_metric_hash=str(payload.get("last_metric_hash", "")), + nan_consecutive=int(payload.get("nan_consecutive", 0)), + last_step=int(payload.get("last_step", 0)), + last_step_change_ts=float(payload.get("last_step_change_ts", 0.0)), + ) + except (json.JSONDecodeError, OSError, ValueError, TypeError): + return WatchdogState() + + +def save_state(path: Path, state: WatchdogState) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(asdict(state), indent=2), encoding="utf-8") + + +def write_status( + path: Path, + *, + service: str, + active: bool, + step: int, + target_steps: int, + metric_line: str, + state: WatchdogState, + action: str, +) -> None: + pct = int((step * 100) / target_steps) if target_steps > 0 else 0 + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + status_lines = [ + f"timestamp_utc={timestamp}", + f"service={service}", + f"active={'yes' if active else 'no'}", + f"step={step}", + f"target_steps={target_steps}", + f"percent={pct}", + f"nan_consecutive={state.nan_consecutive}", + f"last_step_change_ts={int(state.last_step_change_ts)}", + f"last_metric_contains_nan={'yes' if 'nan' in metric_line.lower() else 'no'}", + f"action={action}", + ] + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("\n".join(status_lines) + "\n", encoding="utf-8") + + +def main() -> None: + args = parse_args() + log_file = Path(args.log_file) + state_file = Path(args.state_file) + status_file = Path(args.status_file) + state = load_state(state_file) + + while True: + now = time.time() + action = "none" + active = is_service_active(args.service) + + if not active: + started = start_service(args.service) + action = "start_service" if started else "start_failed" + active = is_service_active(args.service) + + tail_text = read_tail_text(log_file, args.max_read_bytes) + step, last_metric = parse_training_tail(tail_text, args.target_steps) + current_metric_hash = metric_hash(last_metric) + + if step > state.last_step: + state.last_step = step + state.last_step_change_ts = now + elif state.last_step_change_ts == 0.0 and step > 0: + state.last_step_change_ts = now + + if current_metric_hash and current_metric_hash != state.last_metric_hash: + state.last_metric_hash = current_metric_hash + if "nan" in last_metric.lower(): + state.nan_consecutive += 1 + else: + state.nan_consecutive = 0 + + stalled = ( + active + and state.last_step_change_ts > 0 + and (now - state.last_step_change_ts) >= args.stall_seconds + ) + nan_limit_hit = state.nan_consecutive >= args.nan_consecutive_limit + + if nan_limit_hit or stalled: + restarted = restart_service(args.service) + if nan_limit_hit: + action = "restart_nan_limit_hit" if restarted else "restart_nan_failed" + else: + action = "restart_stall_timeout" if restarted else "restart_stall_failed" + state.nan_consecutive = 0 + state.last_metric_hash = "" + state.last_step_change_ts = now + active = is_service_active(args.service) + + save_state(state_file, state) + write_status( + status_file, + service=args.service, + active=active, + step=step, + target_steps=args.target_steps, + metric_line=last_metric, + state=state, + action=action, + ) + time.sleep(args.poll_seconds) + + +if __name__ == "__main__": + main() diff --git a/src/forge/training/callbacks.py b/src/forge/training/callbacks.py index 124481d..10bbc37 100644 --- a/src/forge/training/callbacks.py +++ b/src/forge/training/callbacks.py @@ -2,14 +2,23 @@ from __future__ import annotations +import math from typing import Any from forge.utils.logging import get_logger +try: + from transformers import TrainerCallback +except Exception: # pragma: no cover - fallback for non-training environments + class TrainerCallback: # type: ignore[no-redef] + """Fallback base class when transformers is unavailable.""" + + pass + logger = get_logger(__name__) -class EarlyStoppingOnPlateau: +class EarlyStoppingOnPlateau(TrainerCallback): """Stop training when eval loss plateaus for `patience` eval steps. Compatible with the ``transformers.TrainerCallback`` protocol. @@ -57,3 +66,88 @@ def on_evaluate( logger.info("early_stopping", step=state.global_step) # HF Trainer checks this flag after eval control.should_training_stop = True + + +def _is_non_finite(value: object) -> bool: + """Return True when a metric value is NaN/Inf.""" + if isinstance(value, bool): + return False + if isinstance(value, (int, float)): + return not math.isfinite(float(value)) + if isinstance(value, str): + normalized = value.strip().lower() + return normalized in {"nan", "inf", "+inf", "-inf"} + return False + + +class NaNGuardCallback(TrainerCallback): + """Stop training when NaN/Inf metrics appear repeatedly.""" + + def __init__( + self, + consecutive_limit: int = 5, + watch_keys: tuple[str, ...] = ("loss", "grad_norm", "eval_loss"), + ) -> None: + self.consecutive_limit = consecutive_limit + self.watch_keys = watch_keys + self._consecutive_hits = 0 + + def _handle_metrics( + self, + *, + metrics: dict[str, object] | None, + state: Any, + control: Any, + source: str, + ) -> None: + if not metrics: + return + + bad_values: dict[str, object] = { + key: value + for key, value in metrics.items() + if key in self.watch_keys and _is_non_finite(value) + } + + if not bad_values: + self._consecutive_hits = 0 + return + + self._consecutive_hits += 1 + logger.warning( + "nan_guard_detected", + source=source, + step=state.global_step, + hits=self._consecutive_hits, + limit=self.consecutive_limit, + bad_metrics=bad_values, + ) + + if self._consecutive_hits >= self.consecutive_limit: + logger.error( + "nan_guard_stopping_training", + source=source, + step=state.global_step, + limit=self.consecutive_limit, + ) + control.should_training_stop = True + + def on_log( + self, + args: Any, + state: Any, + control: Any, + logs: dict[str, object] | None = None, + **kwargs: object, + ) -> None: + self._handle_metrics(metrics=logs, state=state, control=control, source="log") + + def on_evaluate( + self, + args: Any, + state: Any, + control: Any, + metrics: dict[str, object] | None = None, + **kwargs: object, + ) -> None: + self._handle_metrics(metrics=metrics, state=state, control=control, source="eval") diff --git a/src/forge/training/trainer.py b/src/forge/training/trainer.py index 3676da4..54d2039 100644 --- a/src/forge/training/trainer.py +++ b/src/forge/training/trainer.py @@ -8,12 +8,16 @@ from datasets import load_dataset +from forge.training.callbacks import EarlyStoppingOnPlateau, NaNGuardCallback from forge.utils.config import TrainingConfig from forge.utils.logging import get_logger logger = get_logger(__name__) _TRUE_VALUES = {"1", "true", "yes", "on"} +_EARLY_STOPPING_PATIENCE = 5 +_EARLY_STOPPING_MIN_DELTA = 0.001 +_NAN_GUARD_CONSECUTIVE_LIMIT = 5 def _is_truthy(value: str | None) -> bool: @@ -74,7 +78,7 @@ def _setup_unsloth(self) -> None: def _setup_peft(self) -> None: """Load model via standard PEFT (fallback when Unsloth unavailable).""" import torch - from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig logger.info( @@ -106,24 +110,35 @@ def _setup_peft(self) -> None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model = prepare_model_for_kbit_training(self.model) - - lora_bias = self.config.lora.bias.lower() - valid_lora_bias = {"none", "all", "lora_only"} - if lora_bias not in valid_lora_bias: - raise ValueError( - f"Invalid LoRA bias '{self.config.lora.bias}'. " - "Expected one of: none, all, lora_only." + adapter_init_path = self.config.training.adapter_init_path + if adapter_init_path: + adapter_path = Path(adapter_init_path).expanduser() + if not adapter_path.exists(): + raise FileNotFoundError(f"Adapter init path not found: {adapter_path}") + logger.info("loading_adapter_init", path=str(adapter_path)) + self.model = PeftModel.from_pretrained( + self.model, + str(adapter_path), + is_trainable=True, ) + else: + lora_bias = self.config.lora.bias.lower() + valid_lora_bias = {"none", "all", "lora_only"} + if lora_bias not in valid_lora_bias: + raise ValueError( + f"Invalid LoRA bias '{self.config.lora.bias}'. " + "Expected one of: none, all, lora_only." + ) - lora_config = LoraConfig( - r=self.config.lora.r, - lora_alpha=self.config.lora.alpha, - lora_dropout=self.config.lora.dropout, - target_modules=self.config.lora.target_modules, - bias=cast(Literal["none", "all", "lora_only"], lora_bias), - task_type=self.config.lora.task_type, - ) - self.model = get_peft_model(self.model, lora_config) + lora_config = LoraConfig( + r=self.config.lora.r, + lora_alpha=self.config.lora.alpha, + lora_dropout=self.config.lora.dropout, + target_modules=self.config.lora.target_modules, + bias=cast(Literal["none", "all", "lora_only"], lora_bias), + task_type=self.config.lora.task_type, + ) + self.model = get_peft_model(self.model, lora_config) logger.info("model_loaded_peft", trainable_params=self._count_trainable_params()) @@ -247,6 +262,7 @@ def train(self, resume_from_checkpoint: str | None = None) -> Path: warmup_ratio=self.config.training.warmup_ratio, lr_scheduler_type=self.config.training.lr_scheduler_type, weight_decay=self.config.training.weight_decay, + max_grad_norm=self.config.training.max_grad_norm, seed=self.config.training.seed, max_steps=self.config.training.max_steps, report_to="wandb" if wandb_enabled else "none", @@ -263,6 +279,23 @@ def train(self, resume_from_checkpoint: str | None = None) -> Path: args=training_args, formatting_func=self._format_prompt, ) + trainer.add_callback( + EarlyStoppingOnPlateau( + patience=_EARLY_STOPPING_PATIENCE, + min_delta=_EARLY_STOPPING_MIN_DELTA, + ) + ) + trainer.add_callback( + NaNGuardCallback( + consecutive_limit=_NAN_GUARD_CONSECUTIVE_LIMIT, + ) + ) + logger.info( + "training_callbacks_enabled", + early_stopping_patience=_EARLY_STOPPING_PATIENCE, + early_stopping_min_delta=_EARLY_STOPPING_MIN_DELTA, + nan_guard_consecutive_limit=_NAN_GUARD_CONSECUTIVE_LIMIT, + ) logger.info("training_started", output_dir=str(output_dir)) if resume_from_checkpoint: diff --git a/src/forge/utils/config.py b/src/forge/utils/config.py index 759fda9..8fec97a 100644 --- a/src/forge/utils/config.py +++ b/src/forge/utils/config.py @@ -38,6 +38,7 @@ class TrainingParams(BaseModel): warmup_ratio: float = 0.1 lr_scheduler_type: str = "cosine" weight_decay: float = 0.01 + max_grad_norm: float = 1.0 logging_steps: int = 10 save_steps: int = 200 save_total_limit: int = 3 @@ -45,6 +46,7 @@ class TrainingParams(BaseModel): fp16: bool = True # always True for Volta arch bf16: bool = False # NOT supported on V100 max_steps: int = -1 + adapter_init_path: str | None = None seed: int = 42 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 4200f7e..f5747d0 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock -from forge.training.callbacks import EarlyStoppingOnPlateau +from forge.training.callbacks import EarlyStoppingOnPlateau, NaNGuardCallback def _make_state(global_step: int = 100) -> MagicMock: @@ -104,3 +104,35 @@ def test_min_delta_sensitivity() -> None: cb.on_evaluate(args, state, control, metrics={"eval_loss": 0.92}) assert control.should_training_stop is True + + +def test_nan_guard_stops_after_consecutive_hits() -> None: + """NaN guard should stop training after repeated NaN metrics.""" + cb = NaNGuardCallback(consecutive_limit=3) + args = _make_args() + state = _make_state() + control = _make_control() + + cb.on_log(args, state, control, logs={"loss": 1.2, "grad_norm": 0.5}) + assert control.should_training_stop is False + + cb.on_log(args, state, control, logs={"loss": "nan"}) + cb.on_log(args, state, control, logs={"grad_norm": float("nan")}) + assert control.should_training_stop is False + + cb.on_evaluate(args, state, control, metrics={"eval_loss": "nan"}) + assert control.should_training_stop is True + + +def test_nan_guard_resets_after_finite_metric() -> None: + """A finite metric should reset NaN guard consecutive counter.""" + cb = NaNGuardCallback(consecutive_limit=2) + args = _make_args() + state = _make_state() + control = _make_control() + + cb.on_log(args, state, control, logs={"loss": "nan"}) + cb.on_log(args, state, control, logs={"loss": 0.9}) + cb.on_log(args, state, control, logs={"loss": "nan"}) + + assert control.should_training_stop is False