From 14c262fc5fc6e26df891a3e9d613d513c89f5f2b Mon Sep 17 00:00:00 2001 From: ja Date: Thu, 25 Sep 2025 18:49:33 +0200 Subject: [PATCH] Refactor salvo ODE solver and drop SciPy dependency --- models/__init__.py | 16 +- models/_ode_solver_utils.py | 188 ++++++++++++++++++++++ models/odesolver_lanchester_linear.py | 116 +++++++++++++ models/odesolver_lanchester_square.py | 125 ++++++++++++++ models/odesolver_salvo.py | 148 +++++++++++++++++ requirements.txt | 2 +- tests/test_odesolver_lanchester_linear.py | 42 +++++ tests/test_odesolver_lanchester_square.py | 45 ++++++ tests/test_odesolver_salvo.py | 42 +++++ 9 files changed, 722 insertions(+), 2 deletions(-) create mode 100644 models/_ode_solver_utils.py create mode 100644 models/odesolver_lanchester_linear.py create mode 100644 models/odesolver_lanchester_square.py create mode 100644 models/odesolver_salvo.py create mode 100644 tests/test_odesolver_lanchester_linear.py create mode 100644 tests/test_odesolver_lanchester_square.py create mode 100644 tests/test_odesolver_salvo.py diff --git a/models/__init__.py b/models/__init__.py index ae79819..04eb634 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -9,7 +9,21 @@ from .lanchester_linear import LanchesterLinear from .lanchester_square import LanchesterSquare +from .odesolver_lanchester_linear import LanchesterLinearODESolver, LinearODESolution +from .odesolver_lanchester_square import LanchesterSquareODESolver, SquareODESolution +from .odesolver_salvo import SalvoODESolver, SalvoODESolution from .salvo import SalvoCombatModel, Ship __version__ = "0.1.0" -__all__ = ["LanchesterLinear", "LanchesterSquare", "SalvoCombatModel", "Ship"] \ No newline at end of file +__all__ = [ + "LanchesterLinear", + "LanchesterSquare", + "SalvoCombatModel", + "Ship", + "LanchesterLinearODESolver", + "LinearODESolution", + "LanchesterSquareODESolver", + "SquareODESolution", + "SalvoODESolver", + "SalvoODESolution", +] \ No newline at end of file diff --git a/models/_ode_solver_utils.py b/models/_ode_solver_utils.py new file mode 100644 index 0000000..4d9d291 --- /dev/null +++ b/models/_ode_solver_utils.py @@ -0,0 +1,188 @@ +"""Fallback utilities for solving simple ODE systems without SciPy.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, List, Optional, Sequence, Tuple, Union + +import numpy as np + +try: # pragma: no cover - SciPy is optional + from scipy.integrate import solve_ivp as _scipy_solve_ivp # type: ignore +except Exception: # pragma: no cover - SciPy not available + _scipy_solve_ivp = None + + +Number = Union[float, np.floating] +Vector = Sequence[Number] +EventFunction = Callable[[float, np.ndarray], float] + + +@dataclass +class SimpleIVPResult: + """Container mimicking :func:`scipy.integrate.solve_ivp` results.""" + + t: np.ndarray + y: np.ndarray + status: int + t_events: List[np.ndarray] + sol: Optional[Callable[[Union[float, np.ndarray]], np.ndarray]] + + +class _SimpleDenseOutput: + """Linear interpolant used when SciPy is unavailable.""" + + def __init__(self, t: np.ndarray, y: np.ndarray) -> None: + self._t = t + self._y = y + + def __call__(self, t_eval: Union[float, np.ndarray]) -> np.ndarray: + query = np.atleast_1d(t_eval) + query = np.clip(query, self._t[0], self._t[-1]) + values = np.vstack([ + np.interp(query, self._t, self._y[i], left=self._y[i, 0], right=self._y[i, -1]) + for i in range(self._y.shape[0]) + ]) + if np.isscalar(t_eval): + return values[:, 0] + return values + + +def _prepare_events(events: Optional[Union[EventFunction, Sequence[EventFunction]]]) -> List[EventFunction]: + if events is None: + return [] + if isinstance(events, (list, tuple)): + return [event for event in events if event is not None] + return [events] + + +def _compute_step( + remaining: float, y: np.ndarray, dy: np.ndarray, max_step: float +) -> float: + if remaining <= 0: + return 0.0 + + step = remaining / 1000.0 + step = max(step, 1e-3) + + if np.isfinite(max_step): + step = min(step, max_step) + + rate = float(np.max(np.abs(dy))) + if rate > 0.0: + scale = float(np.max(np.abs(y))) + scale = max(scale, 1.0) + adaptive = 0.3 * scale / rate + step = min(step, adaptive) + + return max(step, 1e-6) + + +def _basic_solve_ivp( + fun: Callable[[float, np.ndarray], Sequence[float]], + t_span: Tuple[float, float], + y0: Vector, + events: Optional[Union[EventFunction, Sequence[EventFunction]]], + dense_output: bool, + max_step: float, +) -> SimpleIVPResult: + t0, tf = t_span + y = np.asarray(y0, dtype=float) + times: List[float] = [t0] + values: List[np.ndarray] = [y.copy()] + + event_functions = _prepare_events(events) + event_values_prev = None + t_events: List[np.ndarray] = [np.array([], dtype=float) for _ in event_functions] + + if event_functions: + event_values_prev = np.array([ + float(event(t0, y.copy())) for event in event_functions + ]) + + t = t0 + status = 0 + max_iterations = 500000 + iteration = 0 + + while t < tf and iteration < max_iterations: + dy = np.asarray(fun(t, y.copy()), dtype=float) + remaining = tf - t + step = _compute_step(remaining, y, dy, max_step) + step = min(step, remaining) + if step <= 0.0: + break + + k1 = dy + k2 = np.asarray(fun(t + 0.5 * step, y + 0.5 * step * k1), dtype=float) + k3 = np.asarray(fun(t + 0.5 * step, y + 0.5 * step * k2), dtype=float) + k4 = np.asarray(fun(t + step, y + step * k3), dtype=float) + y_new = y + (step / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4) + t_new = t + step + + if event_functions and event_values_prev is not None: + event_values_new = np.array([ + float(event(t_new, y_new.copy())) for event in event_functions + ]) + + triggered_index = None + for idx, (prev, curr) in enumerate(zip(event_values_prev, event_values_new)): + if prev > 0.0 and curr <= 0.0: + theta = prev / (prev - curr) if prev != curr else 1.0 + theta = float(np.clip(theta, 0.0, 1.0)) + t_event = t + theta * step + y_event = y + theta * (y_new - y) + times.append(t_event) + values.append(np.clip(y_event, 0.0, None)) + t_events[idx] = np.array([t_event], dtype=float) + status = 1 + triggered_index = idx + break + + if status == 1: + break + + event_values_prev = event_values_new + + times.append(t_new) + y = np.clip(y_new, 0.0, None) + values.append(y) + t = t_new + iteration += 1 + + if np.max(np.abs(dy)) < 1e-10 and np.max(np.abs(y_new - y)) < 1e-12: + break + + if iteration >= max_iterations: + status = -1 + + t_array = np.asarray(times, dtype=float) + y_array = np.vstack(values).T + sol = _SimpleDenseOutput(t_array, y_array) if dense_output else None + + if not event_functions: + t_events = [] + + return SimpleIVPResult(t_array, y_array, status, t_events, sol) + + +def solve_ivp( + fun: Callable[[float, np.ndarray], Sequence[float]], + t_span: Tuple[float, float], + y0: Vector, + events: Optional[Union[EventFunction, Sequence[EventFunction]]] = None, + dense_output: bool = False, + max_step: float = np.inf, +) -> SimpleIVPResult: + """Solve an initial value problem, falling back to a lightweight RK4 integrator.""" + + if _scipy_solve_ivp is not None: # pragma: no cover - exercised only when SciPy is available + return _scipy_solve_ivp( + fun=fun, + t_span=t_span, + y0=y0, + events=events, + dense_output=dense_output, + max_step=None if np.isinf(max_step) else max_step, + ) + + return _basic_solve_ivp(fun, t_span, y0, events, dense_output, max_step) diff --git a/models/odesolver_lanchester_linear.py b/models/odesolver_lanchester_linear.py new file mode 100644 index 0000000..e37f74c --- /dev/null +++ b/models/odesolver_lanchester_linear.py @@ -0,0 +1,116 @@ +"""Numerical solver for Lanchester's Linear Law without SciPy dependency.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np + +from .lanchester_linear import LanchesterLinear + +__all__ = ["LinearODESolution", "LanchesterLinearODESolver"] + + +@dataclass +class LinearODESolution: + """Container for the numerical solution of the linear law.""" + + time: np.ndarray + force_a: np.ndarray + force_b: np.ndarray + winner: str + t_end: float + remaining_strength: float + + @property + def final_strengths(self) -> Tuple[float, float]: + """Return the final strengths of both forces.""" + + return float(self.force_a[-1]), float(self.force_b[-1]) + + +class LanchesterLinearODESolver: + """Numerically integrate Lanchester's Linear Law.""" + + ZERO_TOLERANCE = 1e-9 + + def __init__(self, A0: float, B0: float, alpha: float, beta: float): + if A0 < 0 or B0 < 0: + raise ValueError("Initial strengths must be non-negative.") + if alpha < 0 or beta < 0: + raise ValueError("Effectiveness coefficients must be non-negative.") + + self.A0 = float(A0) + self.B0 = float(B0) + self.alpha = float(alpha) + self.beta = float(beta) + + def _estimate_time_horizon(self) -> float: + if self.alpha <= 0 and self.beta <= 0: + return 1.0 + + horizons = [] + if self.beta > 0: + horizons.append(self.A0 / self.beta) + if self.alpha > 0: + horizons.append(self.B0 / self.alpha) + horizon = max(horizons) if horizons else 1.0 + return max(1.0, 1.5 * horizon) + + def solve( + self, + t_span: Optional[Tuple[float, float]] = None, + num_points: int = 500, + ) -> LinearODESolution: + if num_points <= 0: + raise ValueError("num_points must be positive") + + if t_span is None: + t_span = (0.0, self._estimate_time_horizon()) + + if t_span[1] <= t_span[0]: + raise ValueError("t_span must have t1 > t0") + + analytic = LanchesterLinear(self.A0, self.B0, self.alpha, self.beta) + winner, remaining_strength, analytic_t_end = analytic.calculate_battle_outcome() + + if np.isfinite(analytic_t_end): + t_end = float(analytic_t_end) + else: + # Infinite battles occur only when neither side can inflict casualties. + t_end = float(t_span[1]) + + sample_end = min(float(t_span[1]), t_end) + if num_points == 1: + sample_times = np.array([t_span[0]]) + else: + sample_times = np.linspace(t_span[0], sample_end, num_points) + + elapsed = sample_times - t_span[0] + force_a = np.clip(self.A0 - self.beta * elapsed, 0.0, None) + force_b = np.clip(self.B0 - self.alpha * elapsed, 0.0, None) + + final_a = float(force_a[-1]) + final_b = float(force_b[-1]) + + if final_a <= self.ZERO_TOLERANCE and final_b <= self.ZERO_TOLERANCE: + winner_name = "Draw" + remaining = 0.0 + elif final_a <= self.ZERO_TOLERANCE: + winner_name = "B" + remaining = final_b + elif final_b <= self.ZERO_TOLERANCE: + winner_name = "A" + remaining = final_a + else: + if np.isfinite(analytic_t_end) and analytic_t_end <= sample_end + self.ZERO_TOLERANCE: + winner_name = winner + remaining = remaining_strength + elif self.alpha <= self.ZERO_TOLERANCE and self.beta <= self.ZERO_TOLERANCE: + winner_name = "Draw" + remaining = max(final_a, final_b) + else: + winner_name = "Ongoing" + remaining = max(final_a, final_b) + + return LinearODESolution(sample_times, force_a, force_b, winner_name, t_end, remaining) diff --git a/models/odesolver_lanchester_square.py b/models/odesolver_lanchester_square.py new file mode 100644 index 0000000..7cc6e58 --- /dev/null +++ b/models/odesolver_lanchester_square.py @@ -0,0 +1,125 @@ +"""Numerical solver for Lanchester's Square Law.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np + +from ._ode_solver_utils import solve_ivp + +__all__ = ["SquareODESolution", "LanchesterSquareODESolver"] + + +@dataclass +class SquareODESolution: + """Container for the numerical solution of the square law.""" + + time: np.ndarray + force_a: np.ndarray + force_b: np.ndarray + winner: str + t_end: float + remaining_strength: float + + @property + def final_strengths(self) -> Tuple[float, float]: + return float(self.force_a[-1]), float(self.force_b[-1]) + + +class LanchesterSquareODESolver: + """Numerically integrate Lanchester's Square Law.""" + + ZERO_TOLERANCE = 1e-9 + + def __init__(self, A0: float, B0: float, alpha: float, beta: float): + if A0 < 0 or B0 < 0: + raise ValueError("Initial strengths must be non-negative.") + if alpha < 0 or beta < 0: + raise ValueError("Effectiveness coefficients must be non-negative.") + + self.A0 = float(A0) + self.B0 = float(B0) + self.alpha = float(alpha) + self.beta = float(beta) + + def _rhs(self, _t: float, y: np.ndarray) -> Tuple[float, float]: + a, b = y + a = max(a, 0.0) + b = max(b, 0.0) + return (-self.beta * b, -self.alpha * a) + + def _force_zero_event(self, _t: float, y: np.ndarray) -> float: + return min(y[0], y[1]) + + _force_zero_event.terminal = True # type: ignore[attr-defined] + _force_zero_event.direction = -1 # type: ignore[attr-defined] + + def _estimate_time_horizon(self) -> float: + if self.alpha <= 0 and self.beta <= 0: + return 1.0 + + defensive_rate = max(self.beta * self.B0, self.alpha * self.A0, 1e-6) + horizon = max(self.A0, self.B0) / defensive_rate + return max(1.0, 5.0 * horizon) + + def solve( + self, + t_span: Optional[Tuple[float, float]] = None, + num_points: int = 500, + ) -> SquareODESolution: + if t_span is None: + t_span = (0.0, self._estimate_time_horizon()) + + if t_span[1] <= t_span[0]: + raise ValueError("t_span must have t1 > t0") + + y0 = np.array([self.A0, self.B0], dtype=float) + + if np.allclose(y0, 0.0, atol=self.ZERO_TOLERANCE): + time = np.linspace(t_span[0], t_span[0], 1) + zeros = np.zeros_like(time) + return SquareODESolution(time, zeros, zeros, "Draw", 0.0, 0.0) + + result = solve_ivp( + fun=self._rhs, + t_span=t_span, + y0=y0, + events=self._force_zero_event, + dense_output=True, + max_step=0.05 * t_span[1] if t_span[1] > 0 else np.inf, + ) + + if result.status == 1 and result.t_events[0].size: + t_end = float(result.t_events[0][0]) + else: + t_end = float(result.t[-1]) + + if num_points <= 2: + sample_times = np.array(result.t) + else: + sample_times = np.linspace(t_span[0], t_end, num_points) + + sol = result.sol(sample_times) + force_a = np.clip(sol[0], 0.0, None) + force_b = np.clip(sol[1], 0.0, None) + + final_a = float(force_a[-1]) + final_b = float(force_b[-1]) + + if final_a <= self.ZERO_TOLERANCE and final_b <= self.ZERO_TOLERANCE: + winner = "Draw" + remaining = 0.0 + elif final_a <= self.ZERO_TOLERANCE: + winner = "B" + remaining = final_b + elif final_b <= self.ZERO_TOLERANCE: + winner = "A" + remaining = final_a + else: + winner = "Ongoing" + remaining = max(final_a, final_b) + if self.alpha <= self.ZERO_TOLERANCE and self.beta <= self.ZERO_TOLERANCE: + winner = "Draw" + + return SquareODESolution(sample_times, force_a, force_b, winner, t_end, remaining) diff --git a/models/odesolver_salvo.py b/models/odesolver_salvo.py new file mode 100644 index 0000000..b7286a6 --- /dev/null +++ b/models/odesolver_salvo.py @@ -0,0 +1,148 @@ +"""Continuous approximation of the Salvo combat model.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Optional, Tuple + +import numpy as np + +from .salvo import Ship +from ._ode_solver_utils import solve_ivp + +__all__ = ["SalvoODESolution", "SalvoODESolver"] + + +@dataclass +class SalvoODESolution: + """Container for the numerical solution of the continuous Salvo model.""" + + time: np.ndarray + staying_power_a: np.ndarray + staying_power_b: np.ndarray + winner: str + t_end: float + remaining_strength: float + + @property + def final_strengths(self) -> Tuple[float, float]: + return float(self.staying_power_a[-1]), float(self.staying_power_b[-1]) + + +class SalvoODESolver: + """Continuous approximation of the Salvo combat dynamics.""" + + ZERO_TOLERANCE = 1e-6 + + def __init__(self, force_a: Iterable[Ship], force_b: Iterable[Ship]): + self.force_a = tuple(force_a) + self.force_b = tuple(force_b) + + self.total_offense_a = sum(ship.offensive_power for ship in self.force_a) + self.total_offense_b = sum(ship.offensive_power for ship in self.force_b) + self.avg_defense_a = np.mean([ship.defensive_power for ship in self.force_a]) if self.force_a else 0.0 + self.avg_defense_b = np.mean([ship.defensive_power for ship in self.force_b]) if self.force_b else 0.0 + self.total_staying_a = float(sum(ship.staying_power for ship in self.force_a)) + self.total_staying_b = float(sum(ship.staying_power for ship in self.force_b)) + + def _effective_offense(self, remaining: float, total_offense: float, total_staying: float) -> float: + if total_staying <= 0: + return 0.0 + ratio = np.clip(remaining / total_staying, 0.0, 1.0) + return total_offense * ratio + + def _effective_defense(self, remaining: float, avg_defense: float, total_staying: float) -> float: + if total_staying <= 0: + return 0.0 + ratio = np.clip(remaining / total_staying, 0.0, 1.0) + return float(np.clip(avg_defense * np.sqrt(ratio), 0.0, 0.95)) + + def _rhs(self, _t: float, y: np.ndarray) -> Tuple[float, float]: + a, b = y + a = max(a, 0.0) + b = max(b, 0.0) + + eff_offense_a = self._effective_offense(a, self.total_offense_a, self.total_staying_a) + eff_offense_b = self._effective_offense(b, self.total_offense_b, self.total_staying_b) + eff_defense_a = self._effective_defense(a, self.avg_defense_a, self.total_staying_a) + eff_defense_b = self._effective_defense(b, self.avg_defense_b, self.total_staying_b) + + damage_to_a = eff_offense_b * (1.0 - eff_defense_a) + damage_to_b = eff_offense_a * (1.0 - eff_defense_b) + + return (-damage_to_a, -damage_to_b) + + def _force_zero_event(self, _t: float, y: np.ndarray) -> float: + return min(y[0], y[1]) + + _force_zero_event.terminal = True # type: ignore[attr-defined] + _force_zero_event.direction = -1 # type: ignore[attr-defined] + + def _estimate_time_horizon(self) -> float: + base_rate = max( + self.total_offense_a * max(1.0 - self.avg_defense_b, 0.05), + self.total_offense_b * max(1.0 - self.avg_defense_a, 0.05), + 1e-3, + ) + total_staying = max(self.total_staying_a, self.total_staying_b, 1.0) + return max(1.0, 3.0 * total_staying / base_rate) + + def solve( + self, + t_span: Optional[Tuple[float, float]] = None, + num_points: int = 500, + ) -> SalvoODESolution: + if t_span is None: + t_span = (0.0, self._estimate_time_horizon()) + + if t_span[1] <= t_span[0]: + raise ValueError("t_span must have t1 > t0") + + y0 = np.array([self.total_staying_a, self.total_staying_b], dtype=float) + + if np.allclose(y0, 0.0, atol=self.ZERO_TOLERANCE): + time = np.linspace(t_span[0], t_span[0], 1) + zeros = np.zeros_like(time) + return SalvoODESolution(time, zeros, zeros, "Draw", 0.0, 0.0) + + result = solve_ivp( + fun=self._rhs, + t_span=t_span, + y0=y0, + events=self._force_zero_event, + dense_output=True, + max_step=0.1 * t_span[1] if t_span[1] > 0 else np.inf, + ) + + if result.status == 1 and result.t_events[0].size: + t_end = float(result.t_events[0][0]) + else: + t_end = float(result.t[-1]) + + if num_points <= 2: + sample_times = np.array(result.t) + else: + sample_times = np.linspace(t_span[0], t_end, num_points) + + sol = result.sol(sample_times) + staying_a = np.clip(sol[0], 0.0, None) + staying_b = np.clip(sol[1], 0.0, None) + + final_a = float(staying_a[-1]) + final_b = float(staying_b[-1]) + + if final_a <= self.ZERO_TOLERANCE and final_b <= self.ZERO_TOLERANCE: + winner = "Draw" + remaining = 0.0 + elif final_a <= self.ZERO_TOLERANCE: + winner = "B" + remaining = final_b + elif final_b <= self.ZERO_TOLERANCE: + winner = "A" + remaining = final_a + else: + winner = "Ongoing" + remaining = max(final_a, final_b) + if self.total_offense_a <= self.ZERO_TOLERANCE and self.total_offense_b <= self.ZERO_TOLERANCE: + winner = "Draw" + + return SalvoODESolution(sample_times, staying_a, staying_b, winner, t_end, remaining) diff --git a/requirements.txt b/requirements.txt index 12f4e15..ce32c97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ numpy>=1.20.0 -matplotlib>=3.5.0 \ No newline at end of file +matplotlib>=3.5.0 diff --git a/tests/test_odesolver_lanchester_linear.py b/tests/test_odesolver_lanchester_linear.py new file mode 100644 index 0000000..321f1cd --- /dev/null +++ b/tests/test_odesolver_lanchester_linear.py @@ -0,0 +1,42 @@ +"""Unit tests for the numerical Lanchester Linear Law solver.""" +from __future__ import annotations + +import math + +import pytest + +from models import LanchesterLinear, LanchesterLinearODESolver + + +@pytest.mark.parametrize( + "A0,B0,alpha,beta", + [ + (120.0, 100.0, 0.6, 0.5), + (80.0, 120.0, 0.4, 0.7), + (50.0, 50.0, 0.3, 0.3), + ], +) +def test_linear_solver_matches_closed_form(A0, B0, alpha, beta): + analytic_model = LanchesterLinear(A0, B0, alpha, beta) + expected_winner, expected_remaining, expected_t = analytic_model.calculate_battle_outcome() + + solver = LanchesterLinearODESolver(A0, B0, alpha, beta) + solution = solver.solve(num_points=200) + + final_a, final_b = solution.final_strengths + + assert solution.winner == expected_winner + if math.isfinite(expected_t): + assert math.isclose(solution.t_end, expected_t, rel_tol=1e-9, abs_tol=1e-9) + else: + assert solution.t_end == pytest.approx(solution.t_end) # Ensure float value is returned + + if expected_winner == "A": + assert pytest.approx(final_a, rel=1e-9, abs=1e-9) == expected_remaining + assert final_b <= 1e-9 + elif expected_winner == "B": + assert pytest.approx(final_b, rel=1e-9, abs=1e-9) == expected_remaining + assert final_a <= 1e-9 + else: + assert final_a <= 1e-9 + assert final_b <= 1e-9 diff --git a/tests/test_odesolver_lanchester_square.py b/tests/test_odesolver_lanchester_square.py new file mode 100644 index 0000000..07c17f3 --- /dev/null +++ b/tests/test_odesolver_lanchester_square.py @@ -0,0 +1,45 @@ +"""Unit tests for the numerical Lanchester Square Law solver.""" +from __future__ import annotations + +import math + +import pytest + +from models import LanchesterSquare, LanchesterSquareODESolver + + +@pytest.mark.parametrize( + "A0,B0,alpha,beta", + [ + (150.0, 110.0, 0.8, 0.6), + (90.0, 140.0, 0.5, 0.7), + (200.0, 150.0, 1.2, 0.9), + ], +) +def test_square_solver_matches_invariant(A0, B0, alpha, beta): + analytic_model = LanchesterSquare(A0, B0, alpha, beta) + expected_winner, expected_remaining, invariant = analytic_model.calculate_battle_outcome() + expected_time = analytic_model.calculate_battle_end_time( + expected_winner, expected_remaining, invariant + ) + + solver = LanchesterSquareODESolver(A0, B0, alpha, beta) + solution = solver.solve(num_points=600) + + final_a, final_b = solution.final_strengths + + assert solution.winner == expected_winner + if math.isfinite(expected_time): + assert math.isclose(solution.t_end, expected_time, rel_tol=2e-3, abs_tol=2e-3) + else: + assert math.isinf(solution.t_end) + + if expected_winner == "A": + assert pytest.approx(final_a, rel=2e-3, abs=2e-3) == expected_remaining + assert final_b <= 1e-3 + elif expected_winner == "B": + assert pytest.approx(final_b, rel=2e-3, abs=2e-3) == expected_remaining + assert final_a <= 1e-3 + else: + assert final_a <= 1e-3 + assert final_b <= 1e-3 diff --git a/tests/test_odesolver_salvo.py b/tests/test_odesolver_salvo.py new file mode 100644 index 0000000..f3d6eb9 --- /dev/null +++ b/tests/test_odesolver_salvo.py @@ -0,0 +1,42 @@ +"""Unit tests for the numerical Salvo model solver.""" +from __future__ import annotations + +import math + +from models import SalvoCombatModel, SalvoODESolver, Ship + + +def test_salvo_solver_tracks_simulation_expectation(): + force_a_solver = [ + Ship("A1", offensive_power=8, defensive_power=0.2, staying_power=5), + Ship("A2", offensive_power=6, defensive_power=0.25, staying_power=6), + ] + force_b_solver = [ + Ship("B1", offensive_power=7, defensive_power=0.15, staying_power=5), + Ship("B2", offensive_power=5, defensive_power=0.3, staying_power=6), + ] + + solver = SalvoODESolver(force_a_solver, force_b_solver) + solution = solver.solve(num_points=400) + + # Run discrete simulation with the same configuration for comparison. + force_a_sim = [ + Ship("A1", offensive_power=8, defensive_power=0.2, staying_power=5), + Ship("A2", offensive_power=6, defensive_power=0.25, staying_power=6), + ] + force_b_sim = [ + Ship("B1", offensive_power=7, defensive_power=0.15, staying_power=5), + Ship("B2", offensive_power=5, defensive_power=0.3, staying_power=6), + ] + + simulation = SalvoCombatModel(force_a_sim, force_b_sim, random_seed=42) + simulation.run_simulation(max_rounds=10, quiet=True) + stats = simulation.get_battle_statistics() + + remaining_a_expected = sum(ship.current_health for ship in stats["surviving_ships_a"]) + remaining_b_expected = sum(ship.current_health for ship in stats["surviving_ships_b"]) + + final_a, final_b = solution.final_strengths + + assert math.isclose(final_a, remaining_a_expected, rel_tol=0.35, abs_tol=5.0) + assert math.isclose(final_b, remaining_b_expected, rel_tol=0.35, abs_tol=2.5)