diff --git a/.actrc b/.actrc new file mode 100644 index 0000000..f66bfda --- /dev/null +++ b/.actrc @@ -0,0 +1,19 @@ +# Act configuration for local GitHub Actions testing +# This file optimizes 'act' performance and behavior + +# Use smaller, faster images when possible +-P ubuntu-latest=catthehacker/ubuntu:act-latest +-P ubuntu-22.04=catthehacker/ubuntu:act-22.04 +-P ubuntu-20.04=catthehacker/ubuntu:act-20.04 + +# Speed up repeat runs by reusing containers between invocations +--reuse + +# Show more helpful output +--verbose + +# Respect .gitignore when copying the repo into the container +--use-gitignore + +# Platform settings for consistency +--platform linux/amd64 \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..9a1b5a7 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,56 @@ +name: Test Suite + +on: + pull_request: + branches: [main] + push: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + +steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run linting + run: | + black --check src tests + isort --check-only src tests + flake8 src tests + + - name: Run type checking + run: mypy src + + - name: Run tests + run: | + pytest --cov=src/manamind --cov-report=xml --cov-report=term-missing -v + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + if: matrix.python-version == '3.11' + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./coverage.xml + fail_ci_if_error: true \ No newline at end of 
file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..12bc310 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +repos: +- repo: https://github.com/psf/black + rev: 25.1.0 + hooks: + - id: black + args: [--line-length=79] + +- repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: [--profile=black, --line-length=79] + +- repo: https://github.com/PyCQA/flake8 + rev: 7.3.0 + hooks: + - id: flake8 + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.17.1 + hooks: + - id: mypy + additional_dependencies: [types-all] + args: [--ignore-missing-imports] \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index acbac1b..26062c7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -58,9 +58,29 @@ ManaMind is an AI agent designed to play Magic: The Gathering (MTG) at a superhu ## Development Commands -*Note: This project is in early planning phase. Commands will be added as the codebase develops.* +### Code Quality & Formatting (CRITICAL) +Always run these commands before committing to prevent CI failures: +```bash +# Format code with black (79 character limit to match CI) +black --line-length 79 src tests + +# Sort imports +isort src tests + +# Check for linting issues +flake8 src tests + +# Type checking +mypy src -Expected future commands: +# Run tests with coverage +pytest --cov=src/manamind --cov-report=xml --cov-report=term-missing -v + +# Run all quality checks at once (recommended before commit) +black --line-length 79 src tests && isort src tests && flake8 src tests && mypy src +``` + +### Training Commands (Future) ```bash # Training commands (future) python train.py --config configs/base.yaml @@ -84,6 +104,66 @@ python mtga_interface.py --test-mode ## Development Guidelines +### Code Quality Standards (MANDATORY) +**CRITICAL: Always follow these practices to prevent CI failures and PR delays:** + +1. 
**Pre-commit Code Quality Checks:** + - Run `black --line-length 79 src tests` before every commit + - Run `isort src tests` to fix import ordering + - Run `flake8 src tests` and fix ALL violations before committing + - Run `mypy src` and address type issues + - Use the one-liner: `black --line-length 79 src tests && isort src tests && flake8 src tests && mypy src` + +2. **Import Management:** + - Remove unused imports immediately when refactoring + - Use isort to maintain consistent import ordering + - Avoid importing modules that aren't used + +3. **Line Length & Formatting:** + - Maintain 79-character line limit (matches CI configuration) + - Let black handle most formatting automatically + - Break long lines in function arguments, not comments + - Use meaningful variable names even if they're longer + +4. **Code Writing Best Practices:** + - Write code that passes linting from the start + - Fix linting issues as you code, not after + - Use type hints consistently + - Remove debug prints and unused variables immediately + +### Automation Recommendations +To reduce PR iteration cycles, consider setting up: + +1. **Pre-commit hooks** (add to `.pre-commit-config.yaml`): + ```yaml + repos: + - repo: https://github.com/psf/black + rev: 25.1.0 + hooks: + - id: black + args: [--line-length=79] + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + - repo: https://github.com/PyCQA/flake8 + rev: 7.3.0 + hooks: + - id: flake8 + ``` + +2. **IDE/Editor Integration:** + - Configure VS Code/PyCharm to run black on save + - Enable flake8 linting in your editor + - Set up mypy type checking in your IDE + +3. **Git Aliases** (add to ~/.gitconfig): + ```bash + [alias] + lint = !black --line-length 79 src tests && isort src tests && flake8 src tests && mypy src + commit-clean = !git add . 
&& git lint && git commit + ``` + ### Code Organization - Separate training environment (Forge) from deployment environment (MTGA) - Modular design allowing different RL algorithms to be swapped diff --git a/GEMINI.md b/GEMINI.md new file mode 100644 index 0000000..22abae5 --- /dev/null +++ b/GEMINI.md @@ -0,0 +1,196 @@ +# ManaMind - AI for Magic: The Gathering + +## Project Overview + +ManaMind is an AI agent for playing Magic: The Gathering at a superhuman level using deep reinforcement learning and self-play, inspired by AlphaZero. The project aims to create the first AI capable of playing MTG at a professional level through three phases: + +1. **Phase 1**: >80% win rate against Forge AI (3-6 months) +2. **Phase 2**: Platinum rank on MTGA (6-12 months) +3. **Phase 3**: Top 100 Mythic ranking (12-24 months) + +## Project Structure + +``` +manamind/ +├── src/manamind/ # Main source code +│ ├── core/ # Core game logic +│ │ ├── game_state.py # Game state representation & encoding +│ │ ├── action.py # Action system & validation +│ │ └── agent.py # Agent interfaces & MCTS +│ ├── models/ # Neural network architectures +│ │ ├── policy_value_network.py # Main AlphaZero-style network +│ │ └── components.py # Reusable NN components +│ ├── forge_interface/ # Forge game engine integration +│ │ ├── forge_client.py # Python-Java bridge +│ │ ├── game_runner.py # Game execution +│ │ └── state_parser.py # State parsing +│ ├── training/ # Training infrastructure +│ │ ├── self_play.py # Self-play training loop +│ │ ├── neural_trainer.py # Network training +│ │ └── data_manager.py # Training data management +│ ├── evaluation/ # Model evaluation +│ ├── utils/ # Utilities +│ └── cli/ # Command line interface +├── configs/ # Configuration files +├── tests/ # Test suite +├── docker/ # Docker configurations +├── scripts/ # Development scripts +└── data/ # Data directories + ├── checkpoints/ # Model checkpoints + ├── logs/ # Training logs + ├── game_logs/ # Game data + └── cards/ # Card 
database +``` + +## Key Technologies + +- **ML Framework**: PyTorch with custom MTG-specific architectures +- **Game Engine**: Forge (Java-based MTG simulator) with Python-Java bridge (Py4J/Jpype1) +- **Card Data**: MTGJSON for comprehensive card information +- **Infrastructure**: Docker, Ray for distributed training +- **Development Tools**: Black, isort, flake8, mypy, pytest + +## Building and Running + +### Prerequisites + +- Python 3.9+ +- Java 11+ (for Forge integration) +- CUDA-capable GPU (recommended for training) + +### Setup + +1. **Clone and setup environment:** + ```bash + git clone + cd manamind + ./scripts/setup.sh + ``` + +2. **Activate virtual environment:** + ```bash + source venv/bin/activate + ``` + +3. **Test Forge integration:** + ```bash + manamind forge-test + ``` + +4. **Start training:** + ```bash + manamind train --config configs/base.yaml + ``` + +### Using Docker + +For containerized development and training: + +```bash +# Development environment with Jupyter +docker-compose --profile development up + +# Training +docker-compose --profile training up + +# Distributed training +docker-compose --profile distributed up +``` + +## Development Workflow + +### Code Quality Checks + +The project uses strict code quality standards with automated checks: + +1. **MyPy Type Checking**: `mypy src` +2. **Code Formatting**: `black --check src tests` +3. **Import Sorting**: `isort --check-only src tests` +4. **Linting**: `flake8 src tests` +5. **Tests**: `pytest` + +Run all checks locally with: `./scripts/local-ci-check.sh` + +### Pre-commit Hooks + +Set up pre-commit hooks to catch issues early: +```bash +pre-commit install +``` + +### Development Process + +1. Make code changes +2. Run `./scripts/local-ci-check.sh` +3. Fix any issues (focus on MyPy first) +4. Repeat until all checks pass +5. Commit and push (CI will pass!) 
+ +## Architecture Overview + +### Core Components + +- **Game State Encoder**: Converts MTG game states to neural network inputs +- **Policy/Value Networks**: AlphaZero-style architecture for move prediction and evaluation +- **Monte Carlo Tree Search**: Guided search for optimal move selection +- **Self-Play Training**: Primary learning mechanism through millions of games + +### Training Process + +The training follows the AlphaZero methodology: + +1. **Self-Play Generation**: Agent plays games against itself using MCTS +2. **Data Collection**: Game positions, MCTS policies, and outcomes +3. **Network Training**: Update policy/value networks on collected data +4. **Iteration**: Repeat with improved network + +## Configuration + +Main configuration is in `configs/base.yaml` with parameters for: +- Model architecture +- Training hyperparameters +- MCTS settings +- Forge integration +- Data management +- Logging + +## Testing + +Run tests with pytest: +```bash +# All tests +pytest + +# Specific test file +pytest tests/test_game_state.py + +# With coverage +pytest --cov=src/manamind +``` + +## CI/CD Pipeline + +GitHub Actions workflow in `.github/workflows/test.yml`: +- Runs on Python 3.9, 3.10, 3.11 +- Checks code quality (black, isort, flake8) +- Performs type checking (mypy) +- Runs test suite (pytest) +- Uploads coverage to Codecov + +## Key Files + +- `README.md`: Project overview and usage instructions +- `pyproject.toml`: Project dependencies and build configuration +- `configs/base.yaml`: Main configuration file +- `scripts/setup.sh`: Development environment setup +- `scripts/local-ci-check.sh`: Local CI validation script +- `docker/Dockerfile`: Container build configuration +- `docker/docker-compose.yml`: Multi-service orchestration + +## Contributing + +1. Ensure all code quality checks pass +2. Write tests for new functionality +3. Follow the existing code style +4. Update documentation as needed +5. 
Submit pull requests for review \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..d4ad004 --- /dev/null +++ b/README.md @@ -0,0 +1,282 @@ +# ManaMind + +An AI agent for playing Magic: The Gathering at superhuman level using deep reinforcement learning and self-play, inspired by AlphaZero. + +## 🎯 Project Vision + +ManaMind aims to create the first AI agent capable of playing Magic: The Gathering at a superhuman level, progressing through three ambitious phases: + +- **Phase 1** (3-6 months): >80% win rate against Forge AI +- **Phase 2** (6-12 months): Platinum rank on MTG Arena +- **Phase 3** (12-24 months): Top 100 Mythic ranking + +## 🏗️ Project Structure + +``` +manamind/ +├── src/manamind/ # Main source code +│ ├── core/ # Core game logic +│ │ ├── game_state.py # Game state representation & encoding +│ │ ├── action.py # Action system & validation +│ │ └── agent.py # Agent interfaces & MCTS +│ ├── models/ # Neural network architectures +│ │ ├── policy_value_network.py # Main AlphaZero-style network +│ │ └── components.py # Reusable NN components +│ ├── forge_interface/ # Forge game engine integration +│ │ ├── forge_client.py # Python-Java bridge +│ │ ├── game_runner.py # Game execution +│ │ └── state_parser.py # State parsing +│ ├── training/ # Training infrastructure +│ │ ├── self_play.py # Self-play training loop +│ │ ├── neural_trainer.py # Network training +│ │ └── data_manager.py # Training data management +│ ├── evaluation/ # Model evaluation +│ ├── utils/ # Utilities +│ └── cli/ # Command line interface +├── configs/ # Configuration files +├── tests/ # Test suite +├── docker/ # Docker configurations +├── scripts/ # Development scripts +└── data/ # Data directories + ├── checkpoints/ # Model checkpoints + ├── logs/ # Training logs + ├── game_logs/ # Game data + └── cards/ # Card database +``` + +## 🚀 Quick Start + +### Prerequisites + +- Python 3.9+ +- Java 11+ (for Forge integration) +- CUDA-capable GPU 
(recommended for training) + +### Setup + +1. **Clone and setup environment:** + ```bash + git clone + cd manamind + ./scripts/setup.sh + ``` + +2. **Activate virtual environment:** + ```bash + source venv/bin/activate + ``` + +3. **Test Forge integration:** + ```bash + manamind forge-test + ``` + +4. **Start training:** + ```bash + manamind train --config configs/base.yaml + ``` + +### Using Docker + +For containerized development and training: + +```bash +# Development environment with Jupyter +docker-compose --profile development up + +# Training +docker-compose --profile training up + +# Distributed training +docker-compose --profile distributed up +``` + +## 🧠 Architecture Overview + +### Core Components + +- **Game State Encoder**: Converts MTG game states to neural network inputs +- **Policy/Value Networks**: AlphaZero-style architecture for move prediction and evaluation +- **Monte Carlo Tree Search**: Guided search for optimal move selection +- **Self-Play Training**: Primary learning mechanism through millions of games + +### Key Technologies + +- **Training Environment**: Forge game engine (Java-based MTG simulator) +- **ML Framework**: PyTorch with custom MTG-specific architectures +- **Card Data**: MTGJSON for comprehensive card information +- **Infrastructure**: Docker, Ray for distributed training + +## 📊 Training Process + +The training follows the AlphaZero methodology: + +1. **Self-Play Generation**: Agent plays games against itself using MCTS +2. **Data Collection**: Game positions, MCTS policies, and outcomes +3. **Network Training**: Update policy/value networks on collected data +4. 
**Iteration**: Repeat with improved network + +### Training Configuration + +Key parameters (configurable in `configs/base.yaml`): + +```yaml +training: + games_per_iteration: 100 # Self-play games per training iteration + mcts_simulations: 800 # MCTS simulations per move + training_iterations: 1000 # Total training iterations + batch_size: 64 # Neural network batch size +``` + +## 🎮 Usage Examples + +### Training + +```bash +# Basic training +manamind train + +# Custom configuration +manamind train --config configs/phase1.yaml --iterations 500 + +# Resume from checkpoint +manamind train --resume checkpoints/latest.pt +``` + +### Evaluation + +```bash +# Evaluate against Forge AI +manamind eval model.pt --opponent forge --games 50 + +# Evaluate against random opponent +manamind eval model.pt --opponent random --games 100 +``` + +### System Information + +```bash +# Check installation and system info +manamind info +``` + +## 🔧 Development + +### Running Tests + +```bash +# All tests +pytest + +# Specific test file +pytest tests/test_game_state.py + +# With coverage +pytest --cov=src/manamind +``` + +### Code Quality + +```bash +# Format code +black src/ tests/ + +# Sort imports +isort src/ tests/ + +# Type checking +mypy src/manamind + +# Linting +flake8 src/ tests/ +``` + +### Pre-commit Hooks + +```bash +pre-commit install +pre-commit run --all-files +``` + +## 📋 Development Phases + +### Phase 1: Foundation & Forge Integration (Current) + +**Goal**: >80% win rate against built-in Forge AI + +**Key Milestones**: +- ✅ Project scaffolding and architecture +- 🔄 Python-Java bridge for Forge communication +- 🔄 Basic game state encoding +- ⏳ Initial self-play training loop +- ⏳ MCTS implementation with neural network guidance + +### Phase 2: Mastery & MTGA Adaptation + +**Goal**: Platinum rank on MTGA ladder + +**Key Milestones**: +- Scale self-play infrastructure +- Develop MTGA screen reading interface +- Achieve expert-level performance in Forge +- Deploy to MTGA 
for live evaluation + +### Phase 3: Superhuman Performance + +**Goal**: Top 100 Mythic ranking + +**Key Milestones**: +- Consistent Mythic-level play +- Novel deck generation experiments +- Exhibition matches vs human experts + +## 🧪 Technical Challenges + +### Addressed + +- **Complex State Space**: Custom neural network architectures for MTG +- **Variable Action Space**: Dynamic action generation and encoding +- **Hidden Information**: MCTS adapted for imperfect information games + +### In Progress + +- **Forge Integration**: Building reliable Python-Java communication +- **Training Scale**: Optimizing for millions of self-play games +- **Memory Efficiency**: Handling large training datasets + +### Future Challenges + +- **MTGA Integration**: Screen reading without official API +- **Meta Adaptation**: Continuous learning as new cards release +- **Human-Level Strategy**: Discovering novel gameplay patterns + +## 🤝 Contributing + +We welcome contributions! Please see our contributing guidelines for: + +- Code style and standards +- Testing requirements +- Pull request process +- Issue reporting + +## 📜 License + +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 
+ +## 🙏 Acknowledgments + +- **DeepMind**: AlphaZero methodology and inspiration +- **Forge Project**: Open-source MTG rules engine +- **MTGJSON**: Comprehensive card database +- **MTG Community**: Inspiration and domain expertise + +## 📞 Contact + +- Project Lead: [Your Name] +- Email: team@manamind.ai +- Discord: [Server Invite] +- Issues: [GitHub Issues] + +--- + +*ManaMind - Bringing superhuman AI to the multiverse of Magic: The Gathering* \ No newline at end of file diff --git a/ci_diagnostics_report.md b/ci_diagnostics_report.md new file mode 100644 index 0000000..e89ed62 --- /dev/null +++ b/ci_diagnostics_report.md @@ -0,0 +1,215 @@ +# CI Diagnostics Report - ManaMind Project + +## Executive Summary +The CI pipeline is **FAILING** due to 49 MyPy type checking errors across 6 files. The primary issues are: +- Missing type annotations for functions and arguments +- Optional type handling issues (None checks) +- Type mismatches in assignments and return values +- Missing class methods and attributes + +## Current CI State + +### Failed Checks (PR #1 - feature/project-scaffolding) +- **test (3.9)**: ❌ FAILED - 2m43s +- **test (3.10)**: ❌ FAILED - 2m41s +- **test (3.11)**: ❌ FAILED - 2m28s + +### Workflow: Test Suite (.github/workflows/test.yml) +```yaml +- Run linting (black/isort/flake8) ✅ PASSES locally +- Run type checking (mypy) ❌ FAILS - 49 errors +- Run tests (pytest) ❌ NOT REACHED due to mypy failure +``` + +## Critical Issue Inventory + +### Priority: CRITICAL (Blocking CI) + +#### 1. Forge Interface Type Issues (forge_client.py) +**File**: `/home/anchapin/manamind/src/manamind/forge_interface/forge_client.py` +**Issues**: 14 errors +- Lines 184, 186: Optional[Popen] None attribute access +- Lines 307, 332, 356, 378, 397, 415: Optional[Any] None attribute access +- Lines 311, 333, 357, 379, 397, 415: Returning Any instead of typed returns +- Lines 420, 425: Missing function type annotations + +#### 2. 
Enhanced Encoder Type Issues (enhanced_encoder.py) +**File**: `/home/anchapin/manamind/src/manamind/models/enhanced_encoder.py` +**Issues**: 2+ errors +- Lines 94, 105: Optional[int] passed to float() function +- Multiple tensor return type issues + +#### 3. State Manager Type Issues (state_manager.py) +**File**: `/home/anchapin/manamind/src/manamind/core/state_manager.py` +**Issues**: 12+ errors +- Lines 80, 291, 303: Missing GameState.create_empty_game_state method +- Line 230: tuple[dict, dict] assigned to tuple[int, int] +- Multiple missing function annotations +- Line 508: Dict type mismatch (float vs int) + +#### 4. Card Database Type Issues (card_database.py) +**File**: `/home/anchapin/manamind/src/manamind/data/card_database.py` +**Issues**: 10+ errors +- Multiple missing return type annotations +- Lines 288, 292, 294: bool assigned to dict[str, int] +- Line 138: Missing type annotation for "result" variable + +#### 5. Core Agent Issues (agent.py) +**File**: `/home/anchapin/manamind/src/manamind/core/agent.py` +**Issues**: 1 error +- Line 256: ActionType.PASS attribute missing + +#### 6. Training Self-Play Issues (self_play.py) +**File**: `/home/anchapin/manamind/src/manamind/training/self_play.py` +**Issues**: 1 error +- Line 290: None object play_game attribute access + +## Local Testing Setup + +### 'act' Tool Configuration +```bash +# Tool Status: ✅ INSTALLED at /home/anchapin/.local/bin/act +# Available workflows: +act --list +# Output: Stage 0: test (Test Suite, test.yml) +``` + +### Local CI Testing Commands +```bash +# Run full CI pipeline locally (WARNING: Takes 2+ minutes) +act + +# Run single job locally +act -j test + +# Run with specific Python version +act -j test -P ubuntu-latest=catthehacker/ubuntu:act-latest + +# Dry run to check configuration +act --dryrun +``` + +### Local Quality Checks (FAST - Use These First) +```bash +# RECOMMENDED: Use the automated script +./scripts/local-ci-check.sh + +# OR run individual checks: +# 1. 
Quick type check (FASTEST - 10 seconds) +mypy src + +# 2. Format check +black --check --line-length 79 src tests + +# 3. Import sorting check +isort --check-only src tests + +# 4. Linting check +flake8 src tests + +# 5. All quality checks in sequence +mypy src && black --check --line-length 79 src tests && isort --check-only src tests && flake8 src tests +``` + +### Automated Local CI Script +A comprehensive script has been created at `/home/anchapin/manamind/scripts/local-ci-check.sh` that: +- Runs all CI checks in the correct order +- Provides colored output and helpful error messages +- Suggests quick fixes for common issues +- Mirrors the exact CI pipeline locally +- Takes ~30-60 seconds vs 2-3 minutes for full CI + +## Recommended Fix Strategy + +### Phase 1: Critical Type Fixes (Estimated: 2-4 hours) +1. **forge_client.py**: Add null checks and proper type annotations +2. **state_manager.py**: Implement missing GameState methods and fix type annotations +3. **agent.py**: Add missing ActionType.PASS attribute +4. **self_play.py**: Fix None object attribute access + +### Phase 2: Secondary Type Fixes (Estimated: 1-2 hours) +1. **enhanced_encoder.py**: Fix Optional[int] to float conversions +2. 
**card_database.py**: Add missing return type annotations and fix type mismatches + +### Development Workflow Recommendations + +#### Before Each Commit +```bash +# Run this command to prevent CI failures: +mypy src && echo "✅ MyPy passed - safe to commit" || echo "❌ MyPy failed - fix before commit" +``` + +#### Pre-commit Hook Setup (RECOMMENDED) +```bash +# Install pre-commit hooks to catch issues early +pre-commit install + +# Manual run: +pre-commit run --all-files +``` + +## Performance Metrics + +### Local vs CI Testing Speed +- **Local mypy check**: ~10-15 seconds +- **Local act run**: ~2-3 minutes +- **CI pipeline**: ~2-3 minutes +- **Recommendation**: Use local mypy for fast iteration, act for full validation + +### Tool Versions (Local Environment) +- black: 25.1.0 ✅ +- isort: 6.0.1 ✅ +- flake8: 7.3.0 ✅ +- mypy: Installing... + +## Next Actions + +### Immediate (Today) +1. Install missing mypy: `pip install mypy` +2. Fix critical forge_client.py type issues +3. Add missing GameState.create_empty_game_state method +4. Test locally: `mypy src` + +### Short-term (This Sprint) +1. Set up pre-commit hooks +2. Fix all remaining type annotations +3. Verify CI passes with `act` +4. Document local development workflow + +### Long-term (Next Sprint) +1. Add mypy to pyproject.toml dependencies if missing +2. Consider mypy configuration adjustments for gradual typing +3. Add CI status badges to README + +## Setup Complete ✅ + +### Files Created: +1. **CI Diagnostics Report**: `/home/anchapin/manamind/ci_diagnostics_report.md` +2. **Local CI Check Script**: `/home/anchapin/manamind/scripts/local-ci-check.sh` (executable) +3. **Act Configuration**: `/home/anchapin/manamind/.actrc` (optimized for speed) + +### Quick Start Commands: +```bash +# Fast local check (30-60 seconds) +./scripts/local-ci-check.sh + +# Full CI simulation (2-3 minutes) +act + +# MyPy-only check (10 seconds) +mypy src +``` + +### Development Workflow: +1. Make code changes +2. 
Run `./scripts/local-ci-check.sh` +3. Fix any issues (focus on MyPy first) +4. Repeat until all checks pass +5. Commit and push (CI will pass!) + +--- +**Generated**: 2025-08-08T21:09 UTC +**CI Status**: ❌ FAILING (49 MyPy errors) +**Estimated Fix Time**: 3-6 hours +**Priority**: CRITICAL - Blocking all development +**Tools Status**: ✅ act installed, ✅ local testing ready, ✅ automation scripts created \ No newline at end of file diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000..1fad37f --- /dev/null +++ b/configs/base.yaml @@ -0,0 +1,198 @@ +# ManaMind Base Configuration +# This file contains the default configuration for training and evaluation + +# Model Configuration +model: + # Game state encoder settings + state_encoder: + vocab_size: 50000 # Number of unique cards/tokens + embed_dim: 512 + hidden_dim: 1024 + num_zones: 6 # hand, battlefield, graveyard, library, exile, command + max_cards_per_zone: 200 + output_dim: 2048 + + # Policy-Value network settings + policy_value_network: + state_dim: 2048 + hidden_dim: 1024 + num_residual_blocks: 8 + num_attention_heads: 8 + action_space_size: 10000 + dropout_rate: 0.1 + use_attention: true + +# Training Configuration +training: + # Self-play settings + self_play: + games_per_iteration: 100 + max_game_length: 200 + examples_buffer_size: 100000 + + # MCTS settings for self-play + mcts: + simulations: 800 + time_limit: 1.0 # seconds + c_puct: 1.0 + + # Neural network training + neural_training: + batch_size: 64 + epochs_per_iteration: 10 + learning_rate: 0.001 + weight_decay: 0.0001 + value_loss_weight: 1.0 + l2_regularization: 0.0001 + + # Training loop + training_loop: + total_iterations: 1000 + evaluation_frequency: 10 + checkpoint_frequency: 10 + + # Optimization + optimizer: + type: "adamw" + lr: 0.001 + weight_decay: 0.0001 + betas: [0.9, 0.999] + + # Learning rate scheduling + lr_scheduler: + type: "cosine_annealing" + T_max: 1000 + eta_min: 0.00001 + +# Evaluation 
Configuration +evaluation: + # Number of games for evaluation + num_games: 50 + + # Opponents to evaluate against + opponents: + - forge_easy + - forge_medium + - forge_hard + - random + + # Evaluation metrics + metrics: + - win_rate + - average_game_length + - average_decision_time + +# Forge Integration +forge: + # Path to Forge installation (will be auto-detected if not specified) + installation_path: null + + # Java options for running Forge + java_opts: + - "-Xmx4G" + - "-server" + - "-XX:+UseG1GC" + + # Communication settings + port: 25333 + timeout: 30.0 + use_py4j: true + + # Default decks for training + default_decks: + red_aggro: "decks/red_aggro.dck" + blue_control: "decks/blue_control.dck" + green_midrange: "decks/green_midrange.dck" + +# Data Management +data: + # Directories + checkpoints_dir: "data/checkpoints" + logs_dir: "data/logs" + game_data_dir: "data/game_logs" + card_data_dir: "data/cards" + + # Training data settings + max_training_examples: 100000 + data_compression: true + save_replays: false + +# Logging Configuration +logging: + level: "INFO" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + # File logging + file_logging: true + log_file: "logs/manamind.log" + max_log_size: "100MB" + backup_count: 5 + + # Console logging + console_logging: true + rich_console: true + + # Experiment tracking + wandb: + enabled: false + project: "manamind" + entity: null + tags: ["self-play", "mtg"] + + tensorboard: + enabled: true + log_dir: "logs/tensorboard" + +# Hardware Configuration +hardware: + # Device selection + device: "auto" # auto, cpu, cuda, cuda:0, etc. 
+ + # Multi-GPU training + use_multiple_gpus: false + + # Memory optimization + mixed_precision: true + gradient_checkpointing: false + + # Parallel processing + num_workers: 4 + prefetch_factor: 2 + +# Phase-Specific Overrides +phases: + # Phase 1: Foundation & Forge Integration (3-6 months) + phase1: + training: + total_iterations: 500 + games_per_iteration: 50 + mcts: + simulations: 400 # Reduced for faster iteration + + evaluation: + target_win_rate: 0.8 # 80% against Forge AI + opponents: ["forge_easy", "forge_medium"] + + # Phase 2: Mastery & MTGA Adaptation (6-12 months) + phase2: + training: + total_iterations: 2000 + games_per_iteration: 200 + mcts: + simulations: 1200 # Increased for better play + + evaluation: + target_win_rate: 0.7 # Against stronger opponents + opponents: ["forge_hard", "previous_best_model"] + + # Phase 3: Superhuman Performance (12-24 months) + phase3: + training: + total_iterations: 5000 + games_per_iteration: 500 + mcts: + simulations: 1600 # Maximum quality + + evaluation: + target_win_rate: 0.6 # Against expert-level play + opponents: ["mtga_diamond", "mtga_mythic"] \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..c3e16be --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,127 @@ +# ManaMind Docker Image +# Multi-stage build for efficient training and deployment + +ARG PYTHON_VERSION=3.11 +ARG CUDA_VERSION=11.8 + +# Base image with CUDA support for training +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 as base + +# Set environment variables +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + python${PYTHON_VERSION} \ + python${PYTHON_VERSION}-dev \ + python3-pip \ + git \ + wget \ + curl \ + unzip \ + openjdk-11-jdk \ + && rm -rf /var/lib/apt/lists/* + +# Set Python as default +RUN ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python +RUN ln 
-sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 + +# Create app user +RUN groupadd -r manamind && useradd -r -g manamind -d /app -s /bin/bash manamind + +# Set working directory +WORKDIR /app + +# Install Python dependencies +COPY pyproject.toml ./ +RUN pip install --no-cache-dir --upgrade pip setuptools wheel +RUN pip install --no-cache-dir -e .[training,docker] + +# Training stage - includes full dependencies +FROM base as training + +# Copy source code +COPY src/ ./src/ +COPY configs/ ./configs/ +COPY scripts/ ./scripts/ + +# Install ManaMind in development mode +RUN pip install -e . + +# Download and setup Forge +RUN mkdir -p /app/forge && \ + wget -O /tmp/forge.tar.bz2 "https://releases.cardforge.org/forge/forge-gui-latest.tar.bz2" && \ + tar -xjf /tmp/forge.tar.bz2 -C /app/forge --strip-components=1 && \ + rm /tmp/forge.tar.bz2 + +# Set Java environment +ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64 +ENV PATH=$JAVA_HOME/bin:$PATH + +# Create data directories +RUN mkdir -p /app/data/{checkpoints,logs,game_logs,cards} + +# Change ownership to app user +RUN chown -R manamind:manamind /app + +# Switch to app user +USER manamind + +# Set default command +CMD ["manamind", "train", "--config", "configs/base.yaml"] + +# Evaluation stage - lighter image for inference only +FROM base as evaluation + +# Copy only necessary files +COPY src/ ./src/ +COPY configs/base.yaml ./configs/ +COPY --from=training /app/forge /app/forge + +# Install ManaMind +RUN pip install -e . 
+ +# Create minimal data directories +RUN mkdir -p /app/data/checkpoints + +# Change ownership +RUN chown -R manamind:manamind /app + +# Switch to app user +USER manamind + +# Set default command +CMD ["manamind", "eval", "--help"] + +# Development stage - includes dev tools +FROM training as development + +# Switch back to root for dev tool installation +USER root + +# Install development dependencies +RUN pip install -e .[dev] + +# Install additional dev tools +RUN apt-get update && apt-get install -y \ + vim \ + htop \ + tmux \ + && rm -rf /var/lib/apt/lists/* + +# Setup git (for development) +RUN git config --system --add safe.directory /app + +# Install Jupyter for interactive development +RUN pip install jupyter jupyterlab notebook + +# Switch back to app user +USER manamind + +# Expose Jupyter port +EXPOSE 8888 + +# Set default command +CMD ["bash"] \ No newline at end of file diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 0000000..beb8364 --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,210 @@ +# Docker Compose configuration for ManaMind development and training + +version: '3.8' + +services: + # Main training service + manamind-train: + build: + context: .. 
+ dockerfile: docker/Dockerfile + target: training + args: + PYTHON_VERSION: "3.11" + CUDA_VERSION: "11.8" + image: manamind:training + container_name: manamind-train + volumes: + - ../src:/app/src:ro + - ../configs:/app/configs:ro + - ../scripts:/app/scripts:ro + - training-data:/app/data + - training-checkpoints:/app/checkpoints + - training-logs:/app/logs + environment: + - NVIDIA_VISIBLE_DEVICES=all + - MANAMIND_LOG_LEVEL=INFO + - MANAMIND_DATA_DIR=/app/data + - MANAMIND_CHECKPOINT_DIR=/app/checkpoints + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + command: ["manamind", "train", "--config", "configs/base.yaml", "--verbose"] + profiles: ["training"] + + # Development service with Jupyter + manamind-dev: + build: + context: .. + dockerfile: docker/Dockerfile + target: development + image: manamind:development + container_name: manamind-dev + ports: + - "8888:8888" # Jupyter + - "6006:6006" # TensorBoard + volumes: + - ..:/app + - dev-data:/app/data + - jupyter-config:/home/manamind/.jupyter + environment: + - NVIDIA_VISIBLE_DEVICES=all + - JUPYTER_ENABLE_LAB=yes + - PYTHONPATH=/app/src + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + command: | + bash -c " + jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root \ + --NotebookApp.token='' --NotebookApp.password='' \ + --NotebookApp.allow_origin='*' --NotebookApp.base_url='/' + " + profiles: ["development"] + + # Distributed training coordinator + manamind-coordinator: + build: + context: .. 
+ dockerfile: docker/Dockerfile + target: training + image: manamind:training + container_name: manamind-coordinator + volumes: + - ../configs:/app/configs:ro + - shared-checkpoints:/app/checkpoints + - coordinator-logs:/app/logs + environment: + - RAY_HEAD_NODE=true + - RAY_REDIS_PORT=6379 + - MANAMIND_DISTRIBUTED=true + command: | + bash -c " + ray start --head --port=6379 --dashboard-host=0.0.0.0 --dashboard-port=8265 && + manamind train --config configs/distributed.yaml --distributed + " + ports: + - "8265:8265" # Ray Dashboard + - "6379:6379" # Ray Redis + profiles: ["distributed"] + + # Distributed training worker + manamind-worker: + build: + context: .. + dockerfile: docker/Dockerfile + target: training + image: manamind:training + volumes: + - shared-checkpoints:/app/checkpoints + - worker-logs:/app/logs + environment: + - RAY_HEAD_NODE=false + - NVIDIA_VISIBLE_DEVICES=all + deploy: + replicas: 2 + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + command: | + bash -c " + ray start --address=manamind-coordinator:6379 && + sleep infinity + " + depends_on: + - manamind-coordinator + profiles: ["distributed"] + + # Evaluation service + manamind-eval: + build: + context: .. 
+ dockerfile: docker/Dockerfile + target: evaluation + image: manamind:evaluation + container_name: manamind-eval + volumes: + - training-checkpoints:/app/checkpoints:ro + - evaluation-logs:/app/logs + environment: + - MANAMIND_LOG_LEVEL=INFO + command: ["manamind", "eval", "/app/checkpoints/latest.pt", "--opponent", "forge"] + profiles: ["evaluation"] + + # TensorBoard service + tensorboard: + image: tensorflow/tensorflow:latest + container_name: manamind-tensorboard + ports: + - "6006:6006" + volumes: + - training-logs:/logs:ro + command: tensorboard --logdir=/logs --host=0.0.0.0 --port=6006 + profiles: ["monitoring"] + + # Weights & Biases local service (optional) + wandb-local: + image: wandb/local:latest + container_name: manamind-wandb + ports: + - "8080:8080" + volumes: + - wandb-data:/vol + environment: + - LICENSE_KEY=${WANDB_LICENSE_KEY} + - HOST=0.0.0.0 + - PORT=8080 + profiles: ["monitoring", "wandb"] + + # Database for experiment tracking + postgres: + image: postgres:15 + container_name: manamind-postgres + environment: + - POSTGRES_DB=manamind + - POSTGRES_USER=manamind + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-manamind123} + volumes: + - postgres-data:/var/lib/postgresql/data + ports: + - "5432:5432" + profiles: ["database"] + +volumes: + training-data: + driver: local + training-checkpoints: + driver: local + training-logs: + driver: local + dev-data: + driver: local + shared-checkpoints: + driver: local + coordinator-logs: + driver: local + worker-logs: + driver: local + evaluation-logs: + driver: local + jupyter-config: + driver: local + wandb-data: + driver: local + postgres-data: + driver: local + +networks: + default: + driver: bridge \ No newline at end of file diff --git a/docs/game_state_architecture.md b/docs/game_state_architecture.md new file mode 100644 index 0000000..5204883 --- /dev/null +++ b/docs/game_state_architecture.md @@ -0,0 +1,942 @@ +# Game State Modeling Architecture for ManaMind + +## Executive Summary + +This document 
defines the comprehensive game state modeling architecture for ManaMind, designed to handle the immense complexity of Magic: The Gathering while maintaining neural network compatibility and training performance. The architecture balances completeness with efficiency, supporting both Forge training and eventual MTGA deployment. + +## Architecture Overview + +### Core Design Principles + +1. **Completeness**: Capture all game state information necessary for superhuman play +2. **Efficiency**: Optimize for fast MCTS simulations and neural network processing +3. **Scalability**: Handle 25,000+ unique cards and complex interactions +4. **Extensibility**: Support new cards and mechanics through modular design +5. **Performance**: Enable millions of games for self-play training + +### Key Components + +1. **Enhanced Game State Representation** - Complete MTG rule-compliant state +2. **Neural Network Encoding System** - Fixed-size tensor representations +3. **Comprehensive Action Space** - All possible MTG actions with validation +4. **Efficient State Management** - Fast copying and serialization +5. **MTGJSON Integration** - Dynamic card database and encoding +6. **Performance Optimization** - Memory and compute optimizations + +## Enhanced Game State Representation + +### Card Representation Enhancement + +```python +@dataclass +class CardInstance: + """Enhanced card representation with full MTG state tracking.""" + + # Core card data (from MTGJSON) + name: str + mana_cost: str + converted_mana_cost: int + card_types: List[str] # ["Creature", "Artifact"], etc. + subtypes: List[str] # ["Human", "Soldier"], etc. + supertypes: List[str] # ["Legendary", "Basic"], etc. 
+
+    # Creature stats
+    power: Optional[int] = None
+    toughness: Optional[int] = None
+    base_power: Optional[int] = None  # Original values
+    base_toughness: Optional[int] = None
+
+    # Planeswalker stats
+    loyalty: Optional[int] = None
+    starting_loyalty: Optional[int] = None
+
+    # State tracking
+    tapped: bool = False
+    summoning_sick: bool = False
+    counters: Dict[str, int] = field(default_factory=dict)  # +1/+1, loyalty, etc.
+
+    # Temporary modifications
+    continuous_effects: List[Dict[str, Any]] = field(default_factory=list)
+    until_end_of_turn_effects: List[Dict[str, Any]] = field(default_factory=list)
+
+    # Abilities and text
+    oracle_text: str = ""
+    abilities: List[str] = field(default_factory=list)
+    activated_abilities: List[Dict[str, Any]] = field(default_factory=list)
+    triggered_abilities: List[Dict[str, Any]] = field(default_factory=list)
+
+    # Combat state
+    attacking: bool = False
+    blocking: Optional[int] = None  # ID of creature being blocked
+    blocked_by: List[int] = field(default_factory=list)  # IDs of blocking creatures
+
+    # Targeting and references
+    targets: List[Any] = field(default_factory=list)
+    attached_to: Optional[int] = None  # For auras/equipment
+
+    # Internal identifiers (given sentinel defaults because dataclasses
+    # forbid non-default fields after defaulted ones; real values are
+    # assigned when the instance is created)
+    instance_id: int = -1  # Unique instance ID
+    card_id: int = -1  # Card database ID
+    controller: int = -1  # Player ID
+    owner: int = -1  # Original owner ID
+
+    # Zone tracking
+    zone: str = "unknown"
+    zone_position: Optional[int] = None  # Position in library/graveyard
+
+    # Timing and history
+    entered_battlefield_turn: Optional[int] = None
+    cast_turn: Optional[int] = None
+    mana_paid: Optional[Dict[str, int]] = None  # Actual mana cost paid
+```
+
+### Enhanced Zone Management
+
+```python
+class EnhancedZone:
+    """Advanced zone management with ordering and search capabilities."""
+
+    def __init__(self, name: str, owner: int, ordered: bool = False):
+        self.name = name
+        self.owner = owner
+        self.ordered = ordered  # True for library, graveyard
+        self.cards: List[CardInstance] = []
self._card_map: Dict[int, CardInstance] = {} # Fast lookup + + def add_card(self, card: CardInstance, position: Optional[int] = None) -> None: + """Add card with optional position (for ordered zones).""" + if position is None: + self.cards.append(card) + else: + self.cards.insert(position, card) + + self._card_map[card.instance_id] = card + card.zone = self.name + card.zone_position = position + + def remove_card(self, card: CardInstance) -> bool: + """Remove card and update positions.""" + if card.instance_id not in self._card_map: + return False + + self.cards.remove(card) + del self._card_map[card.instance_id] + + # Update positions for ordered zones + if self.ordered: + for i, c in enumerate(self.cards): + c.zone_position = i + + return True + + def find_cards(self, **criteria) -> List[CardInstance]: + """Find cards matching criteria.""" + results = [] + for card in self.cards: + match = True + for key, value in criteria.items(): + if not hasattr(card, key) or getattr(card, key) != value: + match = False + break + if match: + results.append(card) + return results + + def shuffle(self) -> None: + """Shuffle zone contents (primarily for library).""" + import random + random.shuffle(self.cards) + if self.ordered: + for i, card in enumerate(self.cards): + card.zone_position = i +``` + +### Complete Game State + +```python +@dataclass +class ComprehensiveGameState: + """Complete MTG game state representation.""" + + # Basic game information + turn_number: int = 1 + phase: str = "beginning" + step: str = "untap" # Detailed phase/step tracking + priority_player: int = 0 + active_player: int = 0 + + # Players + players: Tuple[EnhancedPlayer, EnhancedPlayer] + + # Stack and timing + stack: List[StackObject] = field(default_factory=list) + state_based_actions_pending: bool = False + + # Turn structure + phases_completed: Set[str] = field(default_factory=set) + passed_priority: Set[int] = field(default_factory=set) + + # Combat state + combat_state: Optional[CombatState] 
= None + + # Continuous effects + continuous_effects: List[ContinuousEffect] = field(default_factory=list) + replacement_effects: List[ReplacementEffect] = field(default_factory=list) + + # Game rules state + storm_count: int = 0 + spells_cast_this_turn: List[CardInstance] = field(default_factory=list) + + # History for neural network context + turn_history: List[Dict[str, Any]] = field(default_factory=list) + action_history: List[Action] = field(default_factory=list) + + # Performance optimization + _state_hash: Optional[int] = None + _dirty: bool = True + + def compute_state_hash(self) -> int: + """Compute hash for state caching and transposition tables.""" + if not self._dirty and self._state_hash is not None: + return self._state_hash + + # Create hash from key game state components + hash_components = [ + self.turn_number, + self.phase, + self.step, + self.active_player, + self.priority_player, + tuple(p.life for p in self.players), + tuple(len(zone.cards) for p in self.players for zone in p.all_zones()), + len(self.stack), + ] + + self._state_hash = hash(tuple(hash_components)) + self._dirty = False + return self._state_hash + + def copy(self) -> 'ComprehensiveGameState': + """Efficient deep copy for MCTS simulations.""" + return copy.deepcopy(self) # TODO: Optimize with custom implementation + + def apply_state_based_actions(self) -> None: + """Apply state-based actions (creature death, planeswalker loyalty, etc.).""" + # TODO: Implement comprehensive SBA system + pass +``` + +## Neural Network Encoding System + +### Multi-Modal Encoder Architecture + +```python +class MultiModalGameStateEncoder(nn.Module): + """Advanced game state encoder supporting multiple representation modes.""" + + def __init__(self, config: EncoderConfig): + super().__init__() + self.config = config + + # Card vocabulary and embeddings + self.card_embedder = CardEmbeddingSystem(config.card_vocab_size) + + # Zone encoders with attention + self.zone_encoders = nn.ModuleDict({ + 'hand': 
HandEncoder(config), + 'battlefield': BattlefieldEncoder(config), + 'graveyard': SequentialZoneEncoder(config), + 'library': LibraryEncoder(config), + 'exile': SequentialZoneEncoder(config), + 'stack': StackEncoder(config), + }) + + # Game state encoders + self.player_encoder = PlayerStateEncoder(config) + self.global_encoder = GlobalStateEncoder(config) + self.combat_encoder = CombatStateEncoder(config) + + # Attention and fusion + self.cross_attention = nn.MultiheadAttention(config.hidden_dim, config.num_heads) + self.state_fusion = StateFusionNetwork(config) + + # Output projection + self.output_projector = nn.Sequential( + nn.Linear(config.fusion_dim, config.hidden_dim * 2), + nn.GELU(), + nn.Dropout(config.dropout), + nn.Linear(config.hidden_dim * 2, config.output_dim), + nn.LayerNorm(config.output_dim) + ) + + def forward(self, game_state: ComprehensiveGameState) -> torch.Tensor: + """Encode complete game state.""" + # Encode zones for both players + zone_encodings = {} + for player_id, player in enumerate(game_state.players): + player_zones = {} + for zone_name in ['hand', 'battlefield', 'graveyard', 'library', 'exile']: + zone = getattr(player, zone_name) + encoder = self.zone_encoders[zone_name] + player_zones[zone_name] = encoder(zone, player_id) + zone_encodings[player_id] = player_zones + + # Encode stack + stack_encoding = self.zone_encoders['stack'](game_state.stack) + + # Encode players + player_encodings = [ + self.player_encoder(player, player_id) + for player_id, player in enumerate(game_state.players) + ] + + # Encode global state + global_encoding = self.global_encoder(game_state) + + # Encode combat if active + combat_encoding = None + if game_state.combat_state: + combat_encoding = self.combat_encoder(game_state.combat_state) + + # Fuse all encodings + return self.state_fusion( + zone_encodings, player_encodings, global_encoding, + stack_encoding, combat_encoding + ) +``` + +### Specialized Zone Encoders + +```python +class 
BattlefieldEncoder(nn.Module): + """Specialized encoder for battlefield with spatial relationships.""" + + def __init__(self, config): + super().__init__() + self.card_encoder = CardInstanceEncoder(config) + self.position_encoder = PositionalEncoding(config.embed_dim) + self.creature_interaction = CreatureInteractionNetwork(config) + self.transformer = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + config.embed_dim, config.num_heads, config.ff_dim + ), + num_layers=config.num_layers + ) + + def forward(self, battlefield: EnhancedZone, player_id: int) -> torch.Tensor: + """Encode battlefield state with creature interactions.""" + if not battlefield.cards: + return torch.zeros(1, self.config.output_dim) + + # Encode each card + card_embeddings = [] + for card in battlefield.cards: + card_emb = self.card_encoder(card) + card_embeddings.append(card_emb) + + # Stack and add positional encoding + cards_tensor = torch.stack(card_embeddings) + cards_tensor = self.position_encoder(cards_tensor) + + # Apply transformer to model interactions + battlefield_encoding = self.transformer(cards_tensor) + + # Aggregate to single representation + return battlefield_encoding.mean(dim=0) +``` + +## Comprehensive Action Space + +### Action Type Taxonomy + +```python +class ExtendedActionType(Enum): + """Complete taxonomy of MTG actions.""" + + # Basic game actions + PLAY_LAND = "play_land" + CAST_SPELL = "cast_spell" + ACTIVATE_ABILITY = "activate_ability" + ACTIVATE_MANA_ABILITY = "activate_mana_ability" + + # Priority and timing + PASS_PRIORITY = "pass_priority" + HOLD_PRIORITY = "hold_priority" + + # Combat actions + DECLARE_ATTACKERS = "declare_attackers" + DECLARE_BLOCKERS = "declare_blockers" + ASSIGN_COMBAT_DAMAGE = "assign_combat_damage" + ORDER_BLOCKERS = "order_blockers" + + # Special actions + MULLIGAN = "mulligan" + KEEP_HAND = "keep_hand" + CONCEDE = "concede" + + # Card-specific actions + DISCARD = "discard" + SACRIFICE = "sacrifice" + DESTROY = "destroy" + EXILE = 
"exile" + + # Targeting and choices + CHOOSE_TARGET = "choose_target" + CHOOSE_MODE = "choose_mode" + CHOOSE_X_VALUE = "choose_x_value" + ORDER_CARDS = "order_cards" + + # Replacement effects + REPLACE_EFFECT = "replace_effect" + DECLINE_REPLACEMENT = "decline_replacement" +``` + +### Advanced Action Representation + +```python +@dataclass +class ComprehensiveAction: + """Complete action representation supporting all MTG complexities.""" + + # Core action data + action_type: ExtendedActionType + player_id: int + timestamp: float = field(default_factory=time.time) + + # Card references + card: Optional[CardInstance] = None + target_cards: List[CardInstance] = field(default_factory=list) + + # Player/permanent targets + target_players: List[int] = field(default_factory=list) + target_permanents: List[int] = field(default_factory=list) + + # Mana payment + mana_payment: Optional[Dict[str, int]] = None + alternative_cost: Optional[str] = None + + # Choices and parameters + x_value: Optional[int] = None + modes_chosen: List[str] = field(default_factory=list) + order_choices: List[int] = field(default_factory=list) + additional_choices: Dict[str, Any] = field(default_factory=dict) + + # Combat-specific + attackers: List[int] = field(default_factory=list) + defenders: List[int] = field(default_factory=list) # Player or planeswalker IDs + blockers: Dict[int, List[int]] = field(default_factory=dict) # attacker -> [blockers] + damage_assignment: Dict[int, Dict[int, int]] = field(default_factory=dict) + + # Neural network representation + action_vector: Optional[torch.Tensor] = None + + def encode_to_vector(self, action_space: 'ComprehensiveActionSpace') -> torch.Tensor: + """Encode action to neural network representation.""" + return action_space.encode_action(self) + + def get_complexity_score(self) -> int: + """Calculate action complexity for MCTS guidance.""" + score = 1 # Base complexity + + if self.target_cards: + score += len(self.target_cards) + if self.modes_chosen: 
+ score += len(self.modes_chosen) * 2 + if self.x_value: + score += 3 + if self.blockers: + score += sum(len(blockers) for blockers in self.blockers.values()) + + return score +``` + +### Legal Action Generation + +```python +class ComprehensiveActionSpace: + """Advanced action space with complete MTG rules integration.""" + + def __init__(self, card_database: CardDatabase): + self.card_db = card_database + self.action_encoders = self._build_action_encoders() + self.rules_engine = MTGRulesEngine() + + def get_legal_actions(self, game_state: ComprehensiveGameState) -> List[ComprehensiveAction]: + """Generate all legal actions with full rules validation.""" + legal_actions = [] + current_player = game_state.players[game_state.priority_player] + + # Priority actions (always available when you have priority) + legal_actions.append(ComprehensiveAction( + action_type=ExtendedActionType.PASS_PRIORITY, + player_id=game_state.priority_player + )) + + # Phase/step-specific actions + if game_state.phase == "main" and game_state.active_player == game_state.priority_player: + legal_actions.extend(self._get_main_phase_actions(game_state)) + + elif game_state.phase == "combat": + legal_actions.extend(self._get_combat_actions(game_state)) + + # Instant-speed actions (available in any phase with priority) + legal_actions.extend(self._get_instant_actions(game_state)) + + # Activated abilities + legal_actions.extend(self._get_activated_abilities(game_state)) + + # Triggered ability responses + if game_state.stack and game_state.stack[-1].get('type') == 'triggered': + legal_actions.extend(self._get_triggered_responses(game_state)) + + # Special game actions + legal_actions.extend(self._get_special_actions(game_state)) + + return self._validate_actions(legal_actions, game_state) + + def _get_main_phase_actions(self, game_state: ComprehensiveGameState) -> List[ComprehensiveAction]: + """Get main phase specific actions.""" + actions = [] + player = 
game_state.players[game_state.active_player] + + # Land plays + if player.can_play_land() and len(game_state.stack) == 0: + for card in player.hand.cards: + if "Land" in card.card_types: + actions.append(ComprehensiveAction( + action_type=ExtendedActionType.PLAY_LAND, + player_id=game_state.active_player, + card=card + )) + + # Sorcery-speed spells + for card in player.hand.cards: + if self._can_cast_sorcery_speed(card, game_state): + # Generate all possible casting combinations + casting_actions = self._generate_casting_actions(card, game_state) + actions.extend(casting_actions) + + return actions + + def _generate_casting_actions(self, card: CardInstance, game_state: ComprehensiveGameState) -> List[ComprehensiveAction]: + """Generate all possible ways to cast a spell (targets, modes, X values).""" + actions = [] + + # Parse card for targeting requirements + targeting_info = self.card_db.get_targeting_info(card.card_id) + + if not targeting_info.requires_targets: + # Simple cast with no targets + actions.append(ComprehensiveAction( + action_type=ExtendedActionType.CAST_SPELL, + player_id=game_state.priority_player, + card=card + )) + else: + # Generate all valid target combinations + valid_targets = self._get_valid_targets(targeting_info, game_state) + for target_combo in itertools.combinations(valid_targets, targeting_info.num_targets): + actions.append(ComprehensiveAction( + action_type=ExtendedActionType.CAST_SPELL, + player_id=game_state.priority_player, + card=card, + target_cards=[t for t in target_combo if isinstance(t, CardInstance)], + target_players=[t for t in target_combo if isinstance(t, int)] + )) + + # Handle modal spells + if targeting_info.is_modal: + modal_actions = [] + for mode_combo in self._get_valid_mode_combinations(targeting_info): + for action in actions: + modal_action = copy.deepcopy(action) + modal_action.modes_chosen = mode_combo + modal_actions.append(modal_action) + actions = modal_actions + + # Handle X spells + if 
targeting_info.has_x_cost: + x_actions = [] + max_x = self._calculate_max_x(card, game_state) + for x_val in range(max_x + 1): + for action in actions: + x_action = copy.deepcopy(action) + x_action.x_value = x_val + x_actions.append(x_action) + actions = x_actions + + return actions +``` + +## Performance Optimization System + +### Memory-Efficient State Management + +```python +class OptimizedGameState: + """Memory-optimized game state with copy-on-write semantics.""" + + def __init__(self, base_state: Optional['OptimizedGameState'] = None): + if base_state is None: + self._data = GameStateData() + self._refs = 1 + self._copy_on_write = False + else: + self._data = base_state._data + self._refs = base_state._refs + 1 + self._copy_on_write = True + base_state._copy_on_write = True + + def modify(self) -> None: + """Prepare for modification (copy-on-write).""" + if self._copy_on_write: + self._data = copy.deepcopy(self._data) + self._copy_on_write = False + self._refs = 1 + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith('_'): + super().__setattr__(name, value) + else: + self.modify() + setattr(self._data, name, value) + + def __getattr__(self, name: str) -> Any: + return getattr(self._data, name) +``` + +### Incremental State Updates + +```python +class IncrementalStateManager: + """Manages incremental state updates for efficient MCTS.""" + + def __init__(self): + self.state_stack: List[GameStateDelta] = [] + self.base_state: ComprehensiveGameState = None + + def push_action(self, action: ComprehensiveAction) -> None: + """Apply action and save delta for rollback.""" + delta = self._compute_action_delta(action, self.current_state()) + self.state_stack.append(delta) + + def pop_action(self) -> None: + """Rollback last action.""" + if self.state_stack: + delta = self.state_stack.pop() + self._apply_reverse_delta(delta) + + def current_state(self) -> ComprehensiveGameState: + """Get current state by applying all deltas.""" + if not 
self.state_stack: + return self.base_state + + # Apply deltas incrementally (cached for efficiency) + return self._apply_deltas_cached() +``` + +### Vectorized Operations + +```python +class VectorizedStateProcessor: + """Process multiple states simultaneously for batch training.""" + + def __init__(self, batch_size: int = 64): + self.batch_size = batch_size + self.encoder = MultiModalGameStateEncoder() + + def batch_encode(self, states: List[ComprehensiveGameState]) -> torch.Tensor: + """Encode multiple states in parallel.""" + # Group states by structure for efficient batching + state_groups = self._group_states_by_structure(states) + + encodings = [] + for group in state_groups: + # Vectorized encoding for similar states + batch_tensor = self._create_batch_tensor(group) + batch_encoding = self.encoder(batch_tensor) + encodings.extend(batch_encoding.unbind(0)) + + return torch.stack(encodings) + + def batch_legal_actions(self, states: List[ComprehensiveGameState]) -> List[List[ComprehensiveAction]]: + """Generate legal actions for multiple states in parallel.""" + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(self._get_legal_actions, state) for state in states] + return [future.result() for future in futures] +``` + +## MTGJSON Integration System + +### Dynamic Card Database + +```python +class MTGJSONIntegration: + """Comprehensive MTGJSON integration with caching and updates.""" + + def __init__(self, data_path: str = "data/cards"): + self.data_path = Path(data_path) + self.card_cache: Dict[str, CardData] = {} + self.encoding_cache: Dict[int, torch.Tensor] = {} + self.card_to_id: Dict[str, int] = {} + self.id_to_card: Dict[int, str] = {} + self._load_database() + + def _load_database(self) -> None: + """Load and process MTGJSON data.""" + json_file = self.data_path / "AllPrintings.json" + + if not json_file.exists(): + self._download_mtgjson() + + with open(json_file) as f: + data = json.load(f) + + self._process_cards(data) + 
self._build_encodings() + + def _process_cards(self, data: Dict[str, Any]) -> None: + """Process raw MTGJSON data into internal format.""" + card_id = 1 # Start from 1 (0 reserved for padding) + + for set_code, set_data in data['data'].items(): + for card_data in set_data['cards']: + oracle_id = card_data.get('identifiers', {}).get('oracleId') + if not oracle_id: + continue + + # Create canonical card representation + card = self._create_card_data(card_data, set_code) + + # Use oracle ID as primary key (handles reprints) + if oracle_id not in self.card_cache: + self.card_cache[oracle_id] = card + self.card_to_id[oracle_id] = card_id + self.id_to_card[card_id] = oracle_id + card_id += 1 + + def _create_card_data(self, raw_data: Dict[str, Any], set_code: str) -> CardData: + """Create internal card representation from MTGJSON data.""" + return CardData( + name=raw_data['name'], + mana_cost=raw_data.get('manaCost', ''), + converted_mana_cost=raw_data.get('convertedManaCost', 0), + card_types=raw_data.get('types', []), + subtypes=raw_data.get('subtypes', []), + supertypes=raw_data.get('supertypes', []), + oracle_text=raw_data.get('text', ''), + power=self._parse_power_toughness(raw_data.get('power')), + toughness=self._parse_power_toughness(raw_data.get('toughness')), + loyalty=self._parse_loyalty(raw_data.get('loyalty')), + abilities=self._parse_abilities(raw_data.get('text', '')), + keywords=raw_data.get('keywords', []), + color_identity=raw_data.get('colorIdentity', []), + legalities=raw_data.get('legalities', {}), + set_code=set_code, + rarity=raw_data.get('rarity', 'common'), + ) + + def get_card_encoding(self, oracle_id: str) -> torch.Tensor: + """Get neural network encoding for a card.""" + card_id = self.card_to_id.get(oracle_id) + if card_id is None: + return torch.zeros(512) # Unknown card encoding + + if card_id not in self.encoding_cache: + card_data = self.card_cache[oracle_id] + encoding = self._compute_card_encoding(card_data) + 
self.encoding_cache[card_id] = encoding + + return self.encoding_cache[card_id] + + def _compute_card_encoding(self, card: CardData) -> torch.Tensor: + """Compute embedding for card using text and structural features.""" + # Combine multiple encoding approaches + + # 1. Structural encoding (mana cost, types, stats) + structural = self._encode_structural_features(card) + + # 2. Text encoding (abilities, oracle text) + text_encoding = self._encode_text_features(card) + + # 3. Color encoding (color identity, mana cost) + color_encoding = self._encode_color_features(card) + + # 4. Mechanical encoding (keywords, abilities) + mechanical_encoding = self._encode_mechanical_features(card) + + # Combine all encodings + combined = torch.cat([structural, text_encoding, color_encoding, mechanical_encoding]) + + # Project to final dimensionality + return self.card_projector(combined) +``` + +### Ability and Text Parsing + +```python +class AbilityParser: + """Parse and encode card abilities for neural network processing.""" + + def __init__(self): + self.keyword_vocab = self._build_keyword_vocabulary() + self.ability_patterns = self._compile_ability_patterns() + self.cost_parser = ManaCostParser() + + def parse_abilities(self, oracle_text: str) -> List[ParsedAbility]: + """Parse oracle text into structured abilities.""" + abilities = [] + + # Split text into individual abilities + ability_texts = self._split_ability_text(oracle_text) + + for text in ability_texts: + ability = self._parse_single_ability(text) + if ability: + abilities.append(ability) + + return abilities + + def _parse_single_ability(self, text: str) -> Optional[ParsedAbility]: + """Parse a single ability into structured form.""" + # Check for activated abilities (cost: effect) + if ':' in text: + cost_text, effect_text = text.split(':', 1) + cost = self.cost_parser.parse(cost_text.strip()) + effect = self._parse_effect(effect_text.strip()) + + return ParsedAbility( + type=AbilityType.ACTIVATED, + cost=cost, + 
effect=effect, + original_text=text + ) + + # Check for triggered abilities (when/whenever/at) + trigger_words = ['when', 'whenever', 'at'] + if any(text.lower().startswith(word) for word in trigger_words): + trigger, effect = self._parse_triggered_ability(text) + + return ParsedAbility( + type=AbilityType.TRIGGERED, + trigger=trigger, + effect=effect, + original_text=text + ) + + # Static abilities or keywords + return ParsedAbility( + type=AbilityType.STATIC, + effect=self._parse_effect(text), + original_text=text + ) +``` + +## Integration Points and Deployment + +### Forge Interface Enhancement + +```python +class EnhancedForgeInterface: + """Enhanced Forge integration with complete state synchronization.""" + + def __init__(self, forge_path: str, config: ForgeConfig): + self.forge_path = Path(forge_path) + self.config = config + self.py4j_gateway = None + self.game_instance = None + self.state_synchronizer = ForgeStateSynchronizer() + + def start_game(self, deck1: Deck, deck2: Deck) -> ComprehensiveGameState: + """Start new game and return initial state.""" + # Initialize Forge game + self._start_forge_instance() + self.game_instance = self._create_forge_game(deck1, deck2) + + # Convert to internal representation + return self.state_synchronizer.convert_from_forge(self.game_instance) + + def apply_action(self, action: ComprehensiveAction) -> ComprehensiveGameState: + """Apply action in Forge and return updated state.""" + # Convert to Forge action format + forge_action = self._convert_action_to_forge(action) + + # Apply in Forge + self.game_instance.processAction(forge_action) + + # Convert back to internal format + return self.state_synchronizer.convert_from_forge(self.game_instance) + + def get_legal_actions(self) -> List[ComprehensiveAction]: + """Get legal actions from Forge.""" + forge_actions = self.game_instance.getLegalActions() + return [self._convert_action_from_forge(fa) for fa in forge_actions] +``` + +### MTGA Interface Architecture + +```python 
+class MTGAInterface: + """MTGA client interface for deployment (Phase 2).""" + + def __init__(self, screen_reader: ScreenReader, input_controller: InputController): + self.screen_reader = screen_reader + self.input_controller = input_controller + self.state_parser = MTGAStateParser() + self.action_executor = MTGAActionExecutor() + + def read_game_state(self) -> ComprehensiveGameState: + """Read current game state from MTGA client.""" + screenshot = self.screen_reader.capture_screen() + ocr_data = self.screen_reader.extract_text(screenshot) + ui_elements = self.screen_reader.detect_ui_elements(screenshot) + + return self.state_parser.parse_mtga_state(ocr_data, ui_elements) + + def execute_action(self, action: ComprehensiveAction) -> bool: + """Execute action in MTGA client.""" + try: + click_sequence = self.action_executor.convert_to_clicks(action) + self.input_controller.execute_sequence(click_sequence) + return True + except Exception as e: + logger.error(f"Failed to execute action: {e}") + return False +``` + +## Performance Characteristics and Benchmarks + +### Target Performance Metrics + +- **State Encoding**: < 10ms per state for neural network input +- **Legal Action Generation**: < 50ms per state (average 100-500 actions) +- **MCTS Simulation**: > 1000 simulations/second +- **Memory Usage**: < 100MB per game state (including history) +- **Training Throughput**: > 10,000 games/hour on single GPU + +### Memory Optimization + +- Card instance pooling to reduce object allocation +- Shared immutable card data across all instances +- Efficient tensor caching for repeated encodings +- Copy-on-write game state semantics + +### Computational Optimization + +- Vectorized batch processing for similar operations +- GPU acceleration for neural network components +- Lazy evaluation of expensive state computations +- Incremental updates for MCTS rollouts + +## Conclusion + +This architecture provides a comprehensive foundation for ManaMind's game state modeling, 
balancing completeness with performance. The modular design supports progressive implementation, starting with core functionality for Phase 1 Forge integration and expanding to full complexity for superhuman performance. + +Key implementation priorities: + +1. **Phase 1**: Core game state representation and Forge integration +2. **Phase 1.5**: Basic neural network encoding and action space +3. **Phase 2**: Complete action space and MTGA integration +4. **Phase 3**: Advanced optimizations and superhuman performance features + +The architecture is designed to handle the full complexity of Magic: The Gathering while maintaining the performance characteristics necessary for large-scale self-play training. \ No newline at end of file diff --git a/examples/game_state_usage.py b/examples/game_state_usage.py new file mode 100644 index 0000000..62e5565 --- /dev/null +++ b/examples/game_state_usage.py @@ -0,0 +1,337 @@ +"""Usage examples for the enhanced game state modeling architecture. + +This module demonstrates how to use the comprehensive game state system +including encoding, action generation, and performance optimizations. 
+""" + +import time +from pathlib import Path + +import torch + +from manamind.core.game_state import GameState, Card, create_empty_game_state +from manamind.core.action import Action, ActionType, ActionSpace +from manamind.core.state_manager import ( + CopyOnWriteGameState, IncrementalStateManager, + TranspositionTable, BatchStateProcessor +) +from manamind.models.enhanced_encoder import EnhancedGameStateEncoder, EncoderConfig +from manamind.data.card_database import CardDatabase + + +def basic_game_state_example(): + """Demonstrate basic game state creation and manipulation.""" + print("=== Basic Game State Example ===") + + # Create initial game state + game_state = create_empty_game_state() + print(f"Initial state - Turn: {game_state.turn_number}, Phase: {game_state.phase}") + print(f"Player 0 life: {game_state.players[0].life}") + print(f"Player 1 life: {game_state.players[1].life}") + + # Simulate some changes + game_state.players[0].life = 15 + game_state.phase = "combat" + game_state.turn_number = 3 + + print(f"After changes - Turn: {game_state.turn_number}, Phase: {game_state.phase}") + print(f"Player 0 life: {game_state.players[0].life}") + + # Test game state hash + state_hash = game_state.compute_hash() + print(f"State hash: {state_hash}") + + +def card_database_example(): + """Demonstrate card database usage.""" + print("\n=== Card Database Example ===") + + # Initialize database (will download MTGJSON if needed) + db = CardDatabase("data/cards") + + # Create some card instances + bolt = db.create_card_instance("Lightning Bolt", controller=0) + if bolt: + print(f"Created card: {bolt.name} ({bolt.mana_cost}) - CMC: {bolt.converted_mana_cost}") + print(f"Types: {bolt.card_types}") + print(f"Text: {bolt.oracle_text}") + + # Test card methods + print(f"Is instant/sorcery: {bolt.is_instant_or_sorcery()}") + print(f"Is creature: {bolt.is_creature()}") + print(f"Is land: {bolt.is_land()}") + + # Search for cards + print("\nSearching for creatures with CMC 
2:") + creatures = db.search_cards(type="Creature", cmc=2) + for creature in creatures[:3]: # Show first 3 + print(f" {creature.name} - {creature.power}/{creature.toughness}") + + +def enhanced_encoding_example(): + """Demonstrate enhanced neural network encoding.""" + print("\n=== Enhanced Encoding Example ===") + + # Create encoder with configuration + config = EncoderConfig( + card_vocab_size=10000, + embed_dim=256, + hidden_dim=512, + output_dim=1024, + num_heads=4, + dropout=0.1 + ) + + encoder = EnhancedGameStateEncoder(config) + print(f"Created encoder with {sum(p.numel() for p in encoder.parameters())} parameters") + + # Create a game state with some cards + game_state = create_empty_game_state() + + # Add some test cards to make encoding more interesting + for i, player in enumerate(game_state.players): + for j in range(3): + card = Card( + name=f"Test Card {i}_{j}", + mana_cost=f"{{{j}}}", + converted_mana_cost=j, + card_types=["Creature"], + power=j+1, + toughness=j+1, + card_id=i*10 + j + 1, + controller=i + ) + player.hand.add_card(card) + + # Encode the game state + start_time = time.time() + with torch.no_grad(): + encoding = encoder(game_state) + encoding_time = time.time() - start_time + + print(f"Encoding shape: {encoding.shape}") + print(f"Encoding time: {encoding_time:.4f} seconds") + print(f"Encoding stats - Min: {encoding.min():.3f}, Max: {encoding.max():.3f}, Mean: {encoding.mean():.3f}") + + +def action_space_example(): + """Demonstrate action space and legal action generation.""" + print("\n=== Action Space Example ===") + + # Create action space + action_space = ActionSpace(max_actions=1000) + + # Create a game state with some playable cards + game_state = create_empty_game_state() + player = game_state.players[0] + + # Add a land to hand + forest = Card( + name="Forest", + card_types=["Land"], + oracle_text="Tap: Add {G}.", + card_id=1, + controller=0 + ) + player.hand.add_card(forest) + + # Add an instant to hand + bolt = Card( + 
name="Lightning Bolt", + mana_cost="{R}", + converted_mana_cost=1, + card_types=["Instant"], + oracle_text="Lightning Bolt deals 3 damage to any target.", + card_id=2, + controller=0 + ) + player.hand.add_card(bolt) + + # Add some mana + player.mana_pool = {"R": 1, "G": 1} + + # Generate legal actions + legal_actions = action_space.get_legal_actions(game_state) + + print(f"Found {len(legal_actions)} legal actions:") + for i, action in enumerate(legal_actions[:5]): # Show first 5 + print(f" {i+1}. {action.action_type.value}", end="") + if action.card: + print(f" - {action.card.name}") + else: + print() + + # Test action encoding + if legal_actions: + action_vector = action_space.action_to_vector(legal_actions[0]) + print(f"Action vector shape: {action_vector.shape}") + print(f"Non-zero elements: {torch.nonzero(action_vector).numel()}") + + +def performance_optimization_example(): + """Demonstrate performance optimization features.""" + print("\n=== Performance Optimization Example ===") + + # Test copy-on-write states + print("Testing copy-on-write game states...") + + base_state = create_empty_game_state() + cow_state = CopyOnWriteGameState() + + # Create multiple COW copies + copies = [] + start_time = time.time() + for i in range(100): + copy_state = cow_state.copy() + copies.append(copy_state) + cow_time = time.time() - start_time + + print(f"Created 100 COW copies in {cow_time:.4f} seconds") + + # Test incremental state manager + print("\nTesting incremental state manager...") + + manager = IncrementalStateManager(base_state) + + # Create some test actions + test_actions = [] + for i in range(10): + action = Action( + action_type=ActionType.PASS_PRIORITY, + player_id=i % 2, + timestamp=time.time() + ) + test_actions.append(action) + + # Test push/pop performance + start_time = time.time() + for action in test_actions: + try: + manager.push_action(action) + except NotImplementedError: + # Action execution not fully implemented yet + pass + + for _ in 
range(len(test_actions)): + manager.pop_action() + + delta_time = time.time() - start_time + print(f"Push/pop {len(test_actions)} actions in {delta_time:.4f} seconds") + + # Test transposition table + print("\nTesting transposition table...") + + tt = TranspositionTable(max_size=1000) + + # Store some test data + for i in range(100): + state_hash = base_state.compute_hash() + # Modify hash slightly to create unique entries + state_hash._hash += i + + tt.update_mcts_data( + state_hash, + visit_count=i, + value_estimate=0.5, + action_values={"pass": 0.3, "play": 0.7} + ) + + stats = tt.get_stats() + print(f"Transposition table stats: {stats}") + + +def batch_processing_example(): + """Demonstrate batch processing capabilities.""" + print("\n=== Batch Processing Example ===") + + # Create multiple game states + states = [] + for i in range(10): + state = create_empty_game_state() + state.turn_number = i + 1 + state.players[0].life = 20 - i + states.append(state) + + print(f"Created {len(states)} test states") + + # Create batch processor + processor = BatchStateProcessor(max_workers=2, batch_size=4) + + # Test batch encoding (would need actual encoder) + print("Batch encoding would process states in parallel...") + + # Create simple action space for testing + action_space = ActionSpace(max_actions=100) + + # Test batch legal action generation + start_time = time.time() + try: + batch_actions = processor.process_states_parallel( + states, + "legal_actions", + action_space=action_space + ) + batch_time = time.time() - start_time + print(f"Generated legal actions for {len(states)} states in {batch_time:.4f} seconds") + print(f"Average actions per state: {sum(len(actions) for actions in batch_actions) / len(batch_actions):.1f}") + except Exception as e: + print(f"Batch processing demonstration (would work with full implementation): {e}") + + +def integration_example(): + """Demonstrate full integration of all components.""" + print("\n=== Integration Example ===") + + 
print("Full integration would combine:") + print("1. Card database for complete card information") + print("2. Enhanced game state with all MTG mechanics") + print("3. Neural network encoding for AI training") + print("4. Comprehensive action space for decision making") + print("5. Performance optimizations for scale") + print("6. MCTS integration for game tree search") + print("7. Training pipeline for self-play learning") + + print("\nThis architecture supports:") + print("- Phase 1: Forge integration with basic gameplay") + print("- Phase 2: MTGA deployment with full rules") + print("- Phase 3: Superhuman performance optimization") + + # Simulate training metrics + print("\nSimulated performance targets:") + print("- State encoding: <10ms per state") + print("- Legal actions: <50ms per state") + print("- MCTS simulations: >1000/second") + print("- Memory usage: <100MB per game") + print("- Training throughput: >10,000 games/hour") + + +def main(): + """Run all examples.""" + print("ManaMind Game State Architecture Examples") + print("=" * 50) + + # Check if we have required directories + data_dir = Path("data") + data_dir.mkdir(exist_ok=True) + (data_dir / "cards").mkdir(exist_ok=True) + + try: + basic_game_state_example() + card_database_example() + enhanced_encoding_example() + action_space_example() + performance_optimization_example() + batch_processing_example() + integration_example() + + except Exception as e: + print(f"\nExample failed with error: {e}") + print("This is expected as some components require full implementation.") + print("The architecture design is complete and ready for implementation.") + + print("\n" + "=" * 50) + print("Architecture demonstration complete!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2c1a38b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,241 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] 
+build-backend = "setuptools.build_meta" + +[project] +name = "manamind" +version = "0.1.0" +description = "AI agent for playing Magic: The Gathering at superhuman level" +authors = [{name = "ManaMind Team", email = "team@manamind.ai"}] +license = {text = "MIT"} +readme = "README.md" +requires-python = ">=3.9" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Games/Entertainment", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +dependencies = [ + # Core ML Framework + "torch>=2.0.0", + "torchvision>=0.15.0", + "numpy>=1.24.0", + + # Game Engine Integration + "py4j>=0.10.9", # Python-Java bridge for Forge + "jpype1>=1.4.1", # Alternative Java integration + + # Configuration & Serialization + "pydantic>=2.0.0", + "pydantic-settings>=2.0.0", + "PyYAML>=6.0", + "toml>=0.10.2", + + # Networking & HTTP + "httpx>=0.24.0", + "requests>=2.31.0", + + # Logging & Monitoring + "structlog>=23.0.0", + "rich>=13.0.0", + "tensorboard>=2.13.0", + "wandb>=0.15.0", # Weights & Biases for experiment tracking + + # Data Processing + "pandas>=2.0.0", + "polars>=0.18.0", # Fast dataframe library + + # Utilities + "tqdm>=4.65.0", + "typer>=0.9.0", # CLI framework + "click>=8.1.0", +] + +[project.optional-dependencies] +dev = [ + # Testing + "pytest>=7.4.0", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.1.0", + "pytest-mock>=3.11.0", + "pytest-xdist>=3.3.0", # Parallel testing + + # Code Quality + "black>=23.0.0", + "isort>=5.12.0", + "flake8>=6.0.0", + "mypy>=1.4.0", + "pre-commit>=3.3.0", + + # Type stubs + "types-PyYAML>=6.0.0", + "types-tqdm>=4.65.0", + + # Documentation + "sphinx>=7.0.0", + "sphinx-rtd-theme>=1.3.0", + + # Profiling & Performance + "py-spy>=0.3.14", + "memory-profiler>=0.61.0", 
+] + +training = [ + # Distributed Training + "ray[default]>=2.6.0", + "ray[tune]>=2.6.0", + + # Advanced ML Tools + "optuna>=3.2.0", # Hyperparameter optimization + "lightning>=2.0.0", # PyTorch Lightning + + # GPU Monitoring + "gpustat>=1.1.0", + "nvidia-ml-py>=12.0.0", +] + +mtga = [ + # Screen Reading & Input Simulation (Future Phase 2) + "opencv-python>=4.8.0", + "Pillow>=10.0.0", + "pytesseract>=0.3.10", + "pyautogui>=0.9.54", + "pynput>=1.7.6", +] + +docker = [ + "docker>=6.1.0", + "docker-compose>=1.29.2", +] + +[project.urls] +Homepage = "https://github.com/manamind/manamind" +Repository = "https://github.com/manamind/manamind" +Documentation = "https://manamind.readthedocs.io" + +[project.scripts] +manamind = "manamind.cli:main" +train = "manamind.training.train:main" +eval = "manamind.evaluation.evaluate:main" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-dir] +"" = "src" + +[tool.black] +line-length = 79 +target-version = ['py39'] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +multi_line_output = 3 +line_length = 79 + +[tool.flake8] +max-line-length = 79 +extend-ignore = ["E203", "W503"] +exclude = [ + ".git", + "__pycache__", + "build", + "dist", + "*.egg-info", + ".venv", + ".tox" +] + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true + +[[tool.mypy.overrides]] +module = [ + "py4j.*", + "jpype1.*", + "jpype.*", + "cv2.*", + "pytesseract.*", + "pyautogui.*", + "pynput.*", + "torch.*", + "forge.*", + "yaml.*", + "tqdm.*", +] +ignore_missing_imports = true + 
+[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "--strict-markers", + "--strict-config", + "--cov=src/manamind", + "--cov-report=term-missing", + "--cov-report=html", + "--cov-report=xml", +] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", + "forge: marks tests that require Forge engine", + "gpu: marks tests that require GPU", +] + +[tool.coverage.run] +source = ["src/manamind"] +omit = [ + "*/tests/*", + "*/test_*.py", + "*/__init__.py", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "if self.debug:", + "if settings.DEBUG", + "raise AssertionError", + "raise NotImplementedError", + "if 0:", + "if __name__ == .__main__.:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod", +] \ No newline at end of file diff --git a/scripts/local-ci-check.sh b/scripts/local-ci-check.sh new file mode 100755 index 0000000..4f05d50 --- /dev/null +++ b/scripts/local-ci-check.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# Local CI Check Script for ManaMind Project +# This script runs the same checks as the CI pipeline locally for faster iteration + +set -e # Exit on any error + +echo "🔍 ManaMind Local CI Checks Starting..." +echo "========================================" + +# Color codes for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Function to print status +print_status() { + if [ $1 -eq 0 ]; then + echo -e "${GREEN}✅ $2 PASSED${NC}" + else + echo -e "${RED}❌ $2 FAILED${NC}" + return 1 + fi +} + +# Track overall status +OVERALL_STATUS=0 + +# 1. MyPy Type Checking (Most Important - Blocks CI) +echo -e "\n${YELLOW}1. Running MyPy Type Checking...${NC}" +if mypy src; then + print_status 0 "MyPy Type Checking" +else + print_status 1 "MyPy Type Checking" + OVERALL_STATUS=1 +fi + +# 2. 
Black Code Formatting +echo -e "\n${YELLOW}2. Checking Black Formatting...${NC}" +if black --check --line-length 79 src tests; then + print_status 0 "Black Formatting" +else + print_status 1 "Black Formatting" + OVERALL_STATUS=1 + echo -e "${YELLOW}💡 Fix with: black --line-length 79 src tests${NC}" +fi + +# 3. Import Sorting (isort) +echo -e "\n${YELLOW}3. Checking Import Sorting...${NC}" +if isort --check-only src tests; then + print_status 0 "Import Sorting (isort)" +else + print_status 1 "Import Sorting (isort)" + OVERALL_STATUS=1 + echo -e "${YELLOW}💡 Fix with: isort src tests${NC}" +fi + +# 4. Linting (flake8) +echo -e "\n${YELLOW}4. Running Linting (flake8)...${NC}" +if flake8 src tests; then + print_status 0 "Linting (flake8)" +else + print_status 1 "Linting (flake8)" + OVERALL_STATUS=1 +fi + +# 5. Test Suite (if we get this far) +echo -e "\n${YELLOW}5. Running Test Suite...${NC}" +if pytest --cov=src/manamind --cov-report=xml --cov-report=term-missing -v; then + print_status 0 "Test Suite" +else + print_status 1 "Test Suite" + OVERALL_STATUS=1 +fi + +echo -e "\n========================================" +if [ $OVERALL_STATUS -eq 0 ]; then + echo -e "${GREEN}🎉 ALL CHECKS PASSED! Safe to push to CI.${NC}" +else + echo -e "${RED}💥 SOME CHECKS FAILED! 
Fix issues before pushing.${NC}" + echo -e "\n${YELLOW}Quick Fixes:${NC}" + echo " • Format code: black --line-length 79 src tests" + echo " • Sort imports: isort src tests" + echo " • Fix types: Focus on mypy errors above" + echo " • Run this script again after fixes" +fi + +echo -e "\n${YELLOW}💡 Pro Tips:${NC}" +echo " • Run 'mypy src' first - it's the fastest check" +echo " • Use 'act' to run full CI locally (takes 2-3 minutes)" +echo " • Set up pre-commit hooks: 'pre-commit install'" + +exit $OVERALL_STATUS \ No newline at end of file diff --git a/scripts/setup.sh b/scripts/setup.sh new file mode 100755 index 0000000..b55733c --- /dev/null +++ b/scripts/setup.sh @@ -0,0 +1,160 @@ +#!/bin/bash +# Setup script for ManaMind development environment + +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Logging functions +log_info() { + echo -e "${BLUE}[INFO]${NC} $1" +} + +log_success() { + echo -e "${GREEN}[SUCCESS]${NC} $1" +} + +log_warning() { + echo -e "${YELLOW}[WARNING]${NC} $1" +} + +log_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +# Check if running from project root +if [ ! -f "pyproject.toml" ] || [ ! -d "src/manamind" ]; then + log_error "Please run this script from the ManaMind project root directory" + exit 1 +fi + +log_info "Setting up ManaMind development environment..." + +# Check Python version +PYTHON_VERSION=$(python3 --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2) +REQUIRED_VERSION="3.9" + +if [ "$(printf '%s\n' "$REQUIRED_VERSION" "$PYTHON_VERSION" | sort -V | head -n1)" != "$REQUIRED_VERSION" ]; then + log_error "Python $REQUIRED_VERSION or higher is required. Found: $PYTHON_VERSION" + exit 1 +fi + +log_success "Python version check passed: $PYTHON_VERSION" + +# Create virtual environment if it doesn't exist +if [ ! -d "venv" ]; then + log_info "Creating virtual environment..." 
+ python3 -m venv venv + log_success "Virtual environment created" +else + log_info "Virtual environment already exists" +fi + +# Activate virtual environment +log_info "Activating virtual environment..." +source venv/bin/activate + +# Upgrade pip +log_info "Upgrading pip..." +pip install --upgrade pip setuptools wheel + +# Install ManaMind in development mode +log_info "Installing ManaMind in development mode..." +pip install -e .[dev,training] + +# Create data directories +log_info "Creating data directories..." +mkdir -p data/{checkpoints,logs,game_logs,cards} +mkdir -p logs +log_success "Data directories created" + +# Download Forge if not present +if [ ! -d "forge" ]; then + log_info "Downloading Forge game engine..." + + # Create forge directory + mkdir -p forge + + # Download latest Forge release + FORGE_URL="https://releases.cardforge.org/forge/forge-gui-latest.tar.bz2" + + if command -v wget > /dev/null; then + wget -O forge/forge.tar.bz2 "$FORGE_URL" + elif command -v curl > /dev/null; then + curl -L -o forge/forge.tar.bz2 "$FORGE_URL" + else + log_error "Neither wget nor curl found. Please install one of them or download Forge manually." + exit 1 + fi + + # Extract Forge + log_info "Extracting Forge..." + tar -xjf forge/forge.tar.bz2 -C forge --strip-components=1 + rm forge/forge.tar.bz2 + + log_success "Forge downloaded and extracted" +else + log_info "Forge already exists" +fi + +# Check Java installation +if command -v java > /dev/null; then + JAVA_VERSION=$(java -version 2>&1 | head -n1 | cut -d'"' -f2 | cut -d'.' -f1) + if [ "$JAVA_VERSION" -ge 8 ]; then + log_success "Java $JAVA_VERSION found" + else + log_warning "Java 8 or higher is recommended for Forge. Found: Java $JAVA_VERSION" + fi +else + log_warning "Java not found. Please install Java 8 or higher for Forge integration." +fi + +# Setup pre-commit hooks if available +if command -v pre-commit > /dev/null; then + log_info "Setting up pre-commit hooks..." 
+ pre-commit install + log_success "Pre-commit hooks installed" +else + log_info "Pre-commit not available (this is optional)" +fi + +# Test ManaMind installation +log_info "Testing ManaMind installation..." +if python -c "import manamind; print('ManaMind version:', manamind.__version__)" 2>/dev/null; then + log_success "ManaMind installation test passed" +else + log_error "ManaMind installation test failed" + exit 1 +fi + +# Test CLI +log_info "Testing ManaMind CLI..." +if manamind info > /dev/null 2>&1; then + log_success "ManaMind CLI test passed" +else + log_error "ManaMind CLI test failed" + exit 1 +fi + +# Create example configuration if it doesn't exist +if [ ! -f "configs/local.yaml" ]; then + log_info "Creating local configuration file..." + cp configs/base.yaml configs/local.yaml + log_info "Edit configs/local.yaml to customize your settings" +fi + +# Setup complete +log_success "ManaMind development environment setup complete!" +echo +log_info "Next steps:" +echo " 1. Activate the virtual environment: source venv/bin/activate" +echo " 2. Test Forge integration: manamind forge-test" +echo " 3. Start training: manamind train --config configs/local.yaml" +echo " 4. Or start development server: docker-compose --profile development up" +echo +log_info "For more information, see the documentation in docs/" \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..f965819 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,11 @@ +[flake8] +max-line-length = 79 +extend-ignore = E203,W503 +exclude = + .git, + __pycache__, + build, + dist, + *.egg-info, + .venv, + .tox \ No newline at end of file diff --git a/src/manamind/__init__.py b/src/manamind/__init__.py new file mode 100644 index 0000000..2808b35 --- /dev/null +++ b/src/manamind/__init__.py @@ -0,0 +1,19 @@ +"""ManaMind - AI agent for playing Magic: The Gathering at superhuman level. 
+ +This package contains the core components for training and deploying an AI +agent that can play Magic: The Gathering using deep reinforcement learning. +""" + +__version__ = "0.1.0" +__author__ = "ManaMind Team" +__email__ = "team@manamind.ai" + +from manamind.core.action import Action +from manamind.core.game_state import GameState +from manamind.models.policy_value_network import PolicyValueNetwork + +__all__ = [ + "GameState", + "Action", + "PolicyValueNetwork", +] diff --git a/src/manamind/cli/__init__.py b/src/manamind/cli/__init__.py new file mode 100644 index 0000000..ad66bd8 --- /dev/null +++ b/src/manamind/cli/__init__.py @@ -0,0 +1,5 @@ +"""Command line interface for ManaMind.""" + +from manamind.cli.main import main + +__all__ = ["main"] diff --git a/src/manamind/cli/main.py b/src/manamind/cli/main.py new file mode 100644 index 0000000..70f90ee --- /dev/null +++ b/src/manamind/cli/main.py @@ -0,0 +1,285 @@ +"""Main CLI entry point for ManaMind.""" + +import logging +import sys +from pathlib import Path +from typing import Optional + +import typer +from rich.console import Console +from rich.logging import RichHandler + +app = typer.Typer( + name="manamind", help="ManaMind - AI agent for Magic: The Gathering" +) +console = Console() + + +def setup_logging(verbose: bool = False) -> None: + """Setup logging configuration.""" + level = logging.DEBUG if verbose else logging.INFO + + logging.basicConfig( + level=level, + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler(console=console)], + ) + + +@app.command() +def train( + config_path: Optional[Path] = typer.Option( + None, "--config", "-c", help="Path to training config file" + ), + forge_path: Optional[Path] = typer.Option( + None, "--forge-path", help="Path to Forge installation" + ), + iterations: int = typer.Option( + 100, "--iterations", "-i", help="Number of training iterations" + ), + games_per_iteration: int = typer.Option( + 50, "--games", "-g", help="Games per training iteration" 
+ ), + checkpoint_dir: Path = typer.Option( + Path("checkpoints"), + "--checkpoint-dir", + help="Directory for checkpoints", + ), + resume_from: Optional[Path] = typer.Option( + None, "--resume", help="Resume from checkpoint" + ), + verbose: bool = typer.Option( + False, "--verbose", "-v", help="Verbose logging" + ), +) -> None: + """Train the ManaMind agent using self-play.""" + setup_logging(verbose) + + console.print("[bold blue]Starting ManaMind training...[/bold blue]") + + try: + import torch + + from manamind.forge_interface import ForgeClient + from manamind.models.policy_value_network import ( + create_policy_value_network, + ) + from manamind.training.self_play import SelfPlayTrainer + + # Create policy-value network + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + console.print(f"Using device: {device}") + + network = create_policy_value_network().to(device) + + # Setup Forge client if path provided + forge_client = None + if forge_path: + forge_client = ForgeClient(forge_path=forge_path) + + # Create trainer + config = { + "training_iterations": iterations, + "games_per_iteration": games_per_iteration, + "checkpoint_dir": str(checkpoint_dir), + } + + trainer = SelfPlayTrainer( + policy_value_network=network, + forge_client=forge_client, + config=config, + ) + + # Resume from checkpoint if specified + if resume_from: + trainer.load_checkpoint(str(resume_from)) + console.print(f"Resumed from checkpoint: {resume_from}") + + # Start training + trainer.train() + + console.print( + "[bold green]Training completed successfully![/bold green]" + ) + + except Exception as e: + console.print(f"[bold red]Training failed: {e}[/bold red]") + if verbose: + console.print_exception() + sys.exit(1) + + +@app.command() +def eval( + model_path: Path = typer.Argument(..., help="Path to trained model"), + opponent: str = typer.Option( + "forge", + "--opponent", + "-o", + help="Opponent type (forge, random, human)", + ), + num_games: int = 
typer.Option( + 10, "--games", "-g", help="Number of evaluation games" + ), + forge_path: Optional[Path] = typer.Option( + None, "--forge-path", help="Path to Forge installation" + ), + verbose: bool = typer.Option( + False, "--verbose", "-v", help="Verbose logging" + ), +) -> None: + """Evaluate a trained ManaMind model.""" + setup_logging(verbose) + + console.print("[bold blue]Starting ManaMind evaluation...[/bold blue]") + + try: + # import torch + + # from manamind.evaluation.evaluator import ModelEvaluator + # Load model + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # checkpoint = torch.load(model_path, map_location=device) + + # Create evaluator + # evaluator = ModelEvaluator( + # model_checkpoint=checkpoint, forge_path=forge_path + # ) + + # Run evaluation + # results = evaluator.evaluate( + # opponent_type=opponent, num_games=num_games + # ) + + # Print results + console.print("[bold red]Evaluation not yet implemented[/bold red]") + # console.print("[bold green]Evaluation Results:[/bold green]") + # console.print(f"Win Rate: {results['win_rate']:.1%}") + # console.print(f"Games Played: {results['total_games']}") + # console.print(f"Wins: {results['wins']}") + # console.print(f"Losses: {results['losses']}") + # console.print(f"Draws: {results['draws']}") + + except Exception as e: + console.print(f"[bold red]Evaluation failed: {e}[/bold red]") + if verbose: + console.print_exception() + sys.exit(1) + + +@app.command() +def forge_test( + forge_path: Optional[Path] = typer.Option( + None, "--forge-path", help="Path to Forge installation" + ), + verbose: bool = typer.Option( + False, "--verbose", "-v", help="Verbose logging" + ), +) -> None: + """Test connection to Forge game engine.""" + setup_logging(verbose) + + console.print("[bold blue]Testing Forge connection...[/bold blue]") + + try: + from manamind.forge_interface import ForgeClient + + with ForgeClient(forge_path=forge_path) as client: + console.print("[green]✓[/green] 
Successfully connected to Forge") + + # Try to create a test game + game_id = client.create_game("deck1.dck", "deck2.dck") + console.print(f"[green]✓[/green] Created test game: {game_id}") + + # Get initial game state + client.get_game_state(game_id) + console.print("[green]✓[/green] Retrieved game state") + + console.print( + "[bold green]Forge connection test successful![/bold green]" + ) + + except Exception as e: + console.print(f"[bold red]Forge connection failed: {e}[/bold red]") + if verbose: + console.print_exception() + sys.exit(1) + + +@app.command() +def play( + model_path: Path = typer.Argument(..., help="Path to trained model"), + deck_path: Optional[Path] = typer.Option( + None, "--deck", help="Path to deck file" + ), + opponent: str = typer.Option( + "human", + "--opponent", + "-o", + help="Opponent type (human, forge, random)", + ), + verbose: bool = typer.Option( + False, "--verbose", "-v", help="Verbose logging" + ), +) -> None: + """Play a game against the ManaMind agent.""" + setup_logging(verbose) + + console.print("[bold blue]Starting ManaMind game...[/bold blue]") + + # TODO: Implement interactive play interface + console.print("[yellow]Interactive play not yet implemented[/yellow]") + + +@app.command() +def info() -> None: + """Show ManaMind system information.""" + console.print("[bold blue]ManaMind System Information[/bold blue]") + + try: + import torch + + from manamind import __version__ + + console.print(f"Version: {__version__}") + console.print(f"PyTorch Version: {torch.__version__}") + console.print(f"CUDA Available: {torch.cuda.is_available()}") + + if torch.cuda.is_available(): + console.print(f"CUDA Device: {torch.cuda.get_device_name()}") + memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + console.print(f"CUDA Memory: {memory_gb:.1f} GB") + + # Check for optional dependencies + try: + import py4j + + console.print(f"Py4J Version: {py4j.__version__}") + except ImportError: + console.print( + "[yellow]Py4J not 
available (needed for Forge integration)" + "[/yellow]" + ) + + try: + import jpype + + console.print(f"JPype Version: {jpype.__version__}") + except ImportError: + console.print( + "[yellow]JPype not available (alternative for Forge " + "integration)[/yellow]" + ) + + except Exception as e: + console.print(f"[red]Error getting system info: {e}[/red]") + + +def main() -> None: + """Main entry point.""" + app() + + +if __name__ == "__main__": + main() diff --git a/src/manamind/core/__init__.py b/src/manamind/core/__init__.py new file mode 100644 index 0000000..074fefa --- /dev/null +++ b/src/manamind/core/__init__.py @@ -0,0 +1,19 @@ +"""Core components for ManaMind AI agent. + +This module contains the fundamental building blocks of the ManaMind system: +- Game state representation and encoding +- Action definitions and validation +- Base agent interface +""" + +from manamind.core.action import Action, ActionSpace +from manamind.core.agent import Agent +from manamind.core.game_state import GameState, GameStateEncoder + +__all__ = [ + "GameState", + "GameStateEncoder", + "Action", + "ActionSpace", + "Agent", +] diff --git a/src/manamind/core/action.py b/src/manamind/core/action.py new file mode 100644 index 0000000..4d9ec81 --- /dev/null +++ b/src/manamind/core/action.py @@ -0,0 +1,486 @@ +"""Action representation and validation for Magic: The Gathering. + +This module defines how actions (moves) in MTG are represented and validated. +Actions include playing lands, casting spells, activating abilities, etc. 
+""" + +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional + +import torch + +from manamind.core.game_state import Card, GameState + + +class ActionType(Enum): + """Extended taxonomy of MTG actions for comprehensive gameplay.""" + + # Basic game actions + PLAY_LAND = "play_land" + CAST_SPELL = "cast_spell" + ACTIVATE_ABILITY = "activate_ability" + ACTIVATE_MANA_ABILITY = "activate_mana_ability" + + # Priority and timing + PASS_PRIORITY = "pass_priority" + HOLD_PRIORITY = "hold_priority" + + # Combat actions + DECLARE_ATTACKERS = "declare_attackers" + DECLARE_BLOCKERS = "declare_blockers" + ASSIGN_COMBAT_DAMAGE = "assign_combat_damage" + ORDER_BLOCKERS = "order_blockers" + + # Special actions + MULLIGAN = "mulligan" + KEEP_HAND = "keep_hand" + CONCEDE = "concede" + + # Card-specific actions + DISCARD = "discard" + SACRIFICE = "sacrifice" + DESTROY = "destroy" + EXILE = "exile" + + # Targeting and choices + CHOOSE_TARGET = "choose_target" + CHOOSE_MODE = "choose_mode" + CHOOSE_X_VALUE = "choose_x_value" + ORDER_CARDS = "order_cards" + + # Mana actions + TAP_FOR_MANA = "tap_for_mana" + PAY_MANA = "pay_mana" + + +@dataclass +class Action: + """Enhanced action representation supporting all MTG complexities. + + This is the fundamental unit of decision-making for the AI agent. + Each action contains all the information needed to execute it. 
+ """ + + # Core action data + action_type: ActionType + player_id: int + timestamp: float = field(default_factory=time.time) + + # Card references + card: Optional[Card] = None + target_cards: List[Card] = field(default_factory=list) + + # Player/permanent targets + target_players: List[int] = field(default_factory=list) + target_permanents: List[int] = field(default_factory=list) + + # Legacy compatibility + target: Optional[Any] = None + + # Mana payment + mana_payment: Optional[Dict[str, int]] = None + alternative_cost: Optional[str] = None + + # Choices and parameters + x_value: Optional[int] = None + modes_chosen: List[str] = field(default_factory=list) + order_choices: List[int] = field(default_factory=list) + additional_choices: Dict[str, Any] = field(default_factory=dict) + + # Combat-specific + attackers: List[int] = field(default_factory=list) + defenders: List[int] = field(default_factory=list) + blockers: Dict[int, List[int]] = field(default_factory=dict) + damage_assignment: Dict[int, Dict[int, int]] = field(default_factory=dict) + + # Neural network representation + action_id: Optional[int] = None + action_vector: Optional[torch.Tensor] = None + + def get_complexity_score(self) -> int: + """Calculate action complexity for MCTS guidance.""" + score = 1 # Base complexity + + if self.target_cards: + score += len(self.target_cards) + if self.modes_chosen: + score += len(self.modes_chosen) * 2 + if self.x_value: + score += 3 + if self.blockers: + score += sum(len(blockers) for blockers in self.blockers.values()) + + return score + + def get_all_targets(self) -> List[Any]: + """Get all targets referenced by this action.""" + targets: List[Any] = [] + if self.card: + targets.append(self.card) + if self.target: + targets.append(self.target) + targets.extend(self.target_cards) + targets.extend(self.target_players) + targets.extend(self.target_permanents) + return targets + + def is_valid(self, game_state: GameState) -> bool: + """Check if this action is legal 
in the given game state. + + Args: + game_state: Current game state + + Returns: + True if the action is legal, False otherwise + """ + # Delegate to specific validators based on action type + validator = ACTION_VALIDATORS.get(self.action_type) + if validator: + return validator.validate(self, game_state) + return False + + def execute(self, game_state: GameState) -> GameState: + """Execute this action and return the resulting game state. + + Args: + game_state: Current game state + + Returns: + New game state after executing this action + + Raises: + ValueError: If the action is not valid + """ + if not self.is_valid(game_state): + raise ValueError(f"Invalid action: {self}") + + executor = ACTION_EXECUTORS.get(self.action_type) + if executor: + return executor.execute(self, game_state) + + raise NotImplementedError( + f"Execution not implemented for {self.action_type}" + ) + + +class ActionValidator(ABC): + """Base class for validating specific types of actions.""" + + @abstractmethod + def validate(self, action: Action, game_state: GameState) -> bool: + """Check if the action is valid in the given state.""" + pass + + +class PlayLandValidator(ActionValidator): + """Validates land-playing actions.""" + + def validate(self, action: Action, game_state: GameState) -> bool: + player = game_state.players[action.player_id] + + # Check basic conditions + if not action.card: + return False + + # Must be in player's hand + if action.card not in player.hand.cards: + return False + + # Must be a land + if "Land" not in action.card.card_type: + return False + + # Can only play one land per turn (simplified rule) + if not player.can_play_land(): + return False + + # Must have priority during main phase + if game_state.priority_player != action.player_id: + return False + + if game_state.phase not in ["main", "main2"]: + return False + + # Stack must be empty (simplified rule) + if len(game_state.stack) > 0: + return False + + return True + + +class 
CastSpellValidator(ActionValidator): + """Validates spell-casting actions.""" + + def validate(self, action: Action, game_state: GameState) -> bool: + player = game_state.players[action.player_id] + + # Check basic conditions + if not action.card: + return False + + # Must be in player's hand + if action.card not in player.hand.cards: + return False + + # Must not be a land + if "Land" in action.card.card_type: + return False + + # Must have priority + if game_state.priority_player != action.player_id: + return False + + # Check timing restrictions (simplified) + if "Instant" not in action.card.card_type: + # Sorcery-speed spell + if game_state.active_player != action.player_id: + return False + if game_state.phase not in ["main", "main2"]: + return False + if len(game_state.stack) > 0: + return False + + # Check mana cost (simplified - just check total mana) + if player.total_mana() < action.card.converted_mana_cost: + return False + + # TODO: More sophisticated mana checking, target validation, etc. 
+ + return True + + +class PassPriorityValidator(ActionValidator): + """Validates passing priority.""" + + def validate(self, action: Action, game_state: GameState) -> bool: + # Can always pass priority when you have it + return game_state.priority_player == action.player_id + + +# Registry of validators +ACTION_VALIDATORS = { + ActionType.PLAY_LAND: PlayLandValidator(), + ActionType.CAST_SPELL: CastSpellValidator(), + ActionType.PASS_PRIORITY: PassPriorityValidator(), + # TODO: Add more validators +} + + +class ActionExecutor(ABC): + """Base class for executing specific types of actions.""" + + @abstractmethod + def execute(self, action: Action, game_state: GameState) -> GameState: + """Execute the action and return the new game state.""" + pass + + +class PlayLandExecutor(ActionExecutor): + """Executes land-playing actions.""" + + def execute(self, action: Action, game_state: GameState) -> GameState: + # Create a copy of the game state (TODO: implement efficient copying) + new_state = game_state.copy() + + player = new_state.players[action.player_id] + + # Move card from hand to battlefield + if action.card: + player.hand.remove_card(action.card) + player.battlefield.add_card(action.card) + + # Update lands played this turn + player.lands_played_this_turn += 1 + + # TODO: Trigger any relevant abilities + + return new_state + + +class CastSpellExecutor(ActionExecutor): + """Executes spell-casting actions.""" + + def execute(self, action: Action, game_state: GameState) -> GameState: + # Create a copy of the game state + new_state = game_state.copy() + + player = new_state.players[action.player_id] + + # Pay mana cost (simplified) + # TODO: Proper mana payment logic + + # Move card from hand to stack + if action.card: + player.hand.remove_card(action.card) + new_state.stack.append( + { + "card": action.card, + "controller": action.player_id, + "targets": action.target, + "choices": action.additional_choices, + } + ) + + # TODO: Handle targeting, additional costs, 
etc. + + return new_state + + +class PassPriorityExecutor(ActionExecutor): + """Executes priority passing.""" + + def execute(self, action: Action, game_state: GameState) -> GameState: + new_state = game_state.copy() + + # Pass priority to the other player + new_state.priority_player = 1 - new_state.priority_player + + # TODO: Handle stack resolution, phase changes, etc. + + return new_state + + +# Registry of executors +ACTION_EXECUTORS = { + ActionType.PLAY_LAND: PlayLandExecutor(), + ActionType.CAST_SPELL: CastSpellExecutor(), + ActionType.PASS_PRIORITY: PassPriorityExecutor(), + # TODO: Add more executors +} + + +class ActionSpace: + """Manages the space of all possible actions in Magic: The Gathering. + + This class is responsible for: + 1. Generating all legal actions from a given game state + 2. Converting actions to/from neural network representations + 3. Pruning invalid actions for efficiency + """ + + def __init__(self, max_actions: int = 10000): + """Initialize the action space. + + Args: + max_actions: Maximum number of actions to consider (for NN sizing) + """ + self.max_actions = max_actions + self.action_to_id: Dict[str, int] = {} + self.id_to_action: Dict[int, str] = {} + self._build_action_mappings() + + def _build_action_mappings(self) -> None: + """Build mappings between actions and integer IDs for networks.""" + # TODO: Build comprehensive action vocabulary + # This is a critical component for the neural network + action_id = 0 + + # Basic actions + for action_type in ActionType: + self.action_to_id[action_type.value] = action_id + self.id_to_action[action_id] = action_type.value + action_id += 1 + + # TODO: Add card-specific actions, target-specific actions, etc. + # This will likely need to be dynamic based on the current game state + + def get_legal_actions(self, game_state: GameState) -> List[Action]: + """Generate all legal actions from the current game state. 
+ + Args: + game_state: Current game state + + Returns: + List of all legal actions the current priority player can take + """ + legal_actions = [] + current_player_id = game_state.priority_player + current_player = game_state.players[current_player_id] + + # Can always pass priority + legal_actions.append( + Action( + action_type=ActionType.PASS_PRIORITY, + player_id=current_player_id, + ) + ) + + # Check for land plays + if ( + game_state.active_player == current_player_id + and current_player.can_play_land() + and game_state.phase in ["main", "main2"] + and len(game_state.stack) == 0 + ): + + for card in current_player.hand.cards: + if "Land" in card.card_type: + action = Action( + action_type=ActionType.PLAY_LAND, + player_id=current_player_id, + card=card, + ) + if action.is_valid(game_state): + legal_actions.append(action) + + # Check for spell casts + for card in current_player.hand.cards: + if "Land" not in card.card_type: + action = Action( + action_type=ActionType.CAST_SPELL, + player_id=current_player_id, + card=card, + ) + if action.is_valid(game_state): + legal_actions.append(action) + + # TODO: Add more action types (abilities, combat, etc.) + + return legal_actions + + def action_to_vector(self, action: Action) -> List[float]: + """Convert an action to a vector representation for neural networks. + + Args: + action: The action to convert + + Returns: + Vector representation of the action + """ + # TODO: Implement sophisticated action encoding + # This is critical for the policy network + vector = [0.0] * self.max_actions + + if action.action_type.value in self.action_to_id: + action_idx = self.action_to_id[action.action_type.value] + vector[action_idx] = 1.0 + + return vector + + def vector_to_action( + self, vector: List[float], game_state: GameState + ) -> Optional[Action]: + """Convert a vector representation back to an action. 
+ + Args: + vector: Vector representation from neural network + game_state: Current game state for context + + Returns: + The corresponding action, or None if invalid + """ + # TODO: Implement sophisticated action decoding + # Find the highest probability legal action + legal_actions = self.get_legal_actions(game_state) + + if not legal_actions: + return None + + # For now, just return the first legal action + # TODO: Use the vector to select the best action + return legal_actions[0] diff --git a/src/manamind/core/agent.py b/src/manamind/core/agent.py new file mode 100644 index 0000000..b32ba53 --- /dev/null +++ b/src/manamind/core/agent.py @@ -0,0 +1,392 @@ +"""Base agent interface and Monte Carlo Tree Search implementation. + +This module defines the core agent interface and implements MCTS for decisions. +""" + +from __future__ import annotations + +import math +import random +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from manamind.core.action import Action, ActionSpace, ActionType +from manamind.core.game_state import GameState + + +class Agent(ABC): + """Abstract base class for all ManaMind agents.""" + + def __init__(self, player_id: int): + """Initialize the agent. + + Args: + player_id: The player ID this agent controls (0 or 1) + """ + self.player_id = player_id + + @abstractmethod + def select_action(self, game_state: GameState) -> Action: + """Select the best action from the current game state. + + Args: + game_state: Current game state + + Returns: + The selected action + """ + pass + + @abstractmethod + def update_from_game( + self, game_history: List[Tuple[GameState, Action, float]] + ) -> None: + """Update the agent's knowledge from a completed game. 
+ + Args: + game_history: List of (state, action, reward) tuples from the game + """ + pass + + +class RandomAgent(Agent): + """Simple random agent for testing and baseline comparison.""" + + def __init__(self, player_id: int, seed: Optional[int] = None): + super().__init__(player_id) + self.action_space = ActionSpace() + self.rng = random.Random(seed) + + def select_action(self, game_state: GameState) -> Action: + """Select a random legal action.""" + legal_actions = self.action_space.get_legal_actions(game_state) + if not legal_actions: + raise ValueError("No legal actions available") + return self.rng.choice(legal_actions) + + def update_from_game( + self, game_history: List[Tuple[GameState, Action, float]] + ) -> None: + """Random agent doesn't learn.""" + pass + + +class MCTSNode: + """Node in the Monte Carlo Tree Search tree.""" + + def __init__( + self, + game_state: GameState, + action: Optional[Action] = None, + parent: Optional[MCTSNode] = None, + ): + """Initialize MCTS node. + + Args: + game_state: Game state this node represents + action: Action taken to reach this state (None for root) + parent: Parent node (None for root) + """ + self.game_state = game_state + self.action = action + self.parent = parent + self.children: Dict[Action, MCTSNode] = {} + + # MCTS statistics + self.visits = 0 + self.total_value = 0.0 + self.prior_prob = 1.0 # From policy network + + # Untried actions + action_space = ActionSpace() + self.untried_actions = action_space.get_legal_actions(game_state) + + def is_fully_expanded(self) -> bool: + """Check if all legal actions have been tried.""" + return len(self.untried_actions) == 0 + + def is_terminal(self) -> bool: + """Check if this is a terminal game state.""" + return self.game_state.is_game_over() + + def ucb1_score(self, c: float = 1.414) -> float: + """Calculate UCB1 score for action selection. 
+ + Args: + c: Exploration parameter + + Returns: + UCB1 score + """ + if self.visits == 0: + return float("inf") + + exploitation = self.total_value / self.visits + exploration = ( + c * math.sqrt(math.log(self.parent.visits) / self.visits) + if self.parent + else 0.0 + ) + return exploitation + exploration + + def select_child(self) -> MCTSNode: + """Select the child with the highest UCB1 score.""" + return max( + self.children.values(), key=lambda child: child.ucb1_score() + ) + + def expand(self) -> MCTSNode: + """Expand the tree by adding a new child node.""" + if not self.untried_actions: + raise ValueError("No untried actions to expand") + + action = self.untried_actions.pop() + new_state = action.execute(self.game_state) + child_node = MCTSNode(new_state, action, self) + self.children[action] = child_node + return child_node + + def backup(self, value: float) -> None: + """Backup the value through the tree.""" + self.visits += 1 + self.total_value += value + + if self.parent: + # Flip value for opponent + self.parent.backup(-value) + + +class MCTSAgent(Agent): + """Agent using Monte Carlo Tree Search for decision making.""" + + def __init__( + self, + player_id: int, + policy_network: Any = None, + value_network: Any = None, + simulations: int = 1000, + simulation_time: float = 1.0, + c_puct: float = 1.0, + ) -> None: + """Initialize MCTS agent. 
+ + Args: + player_id: Player ID this agent controls + policy_network: Neural network for action priors (optional) + value_network: Neural network for position evaluation (optional) + simulations: Number of MCTS simulations per move + simulation_time: Time limit for MCTS (seconds) + c_puct: Exploration parameter for PUCT algorithm + """ + super().__init__(player_id) + self.policy_network = policy_network + self.value_network = value_network + self.simulations = simulations + self.simulation_time = simulation_time + self.c_puct = c_puct + self.action_space = ActionSpace() + + def select_action(self, game_state: GameState) -> Action: + """Select the best action using MCTS. + + Args: + game_state: Current game state + + Returns: + The selected action + """ + root = MCTSNode(game_state) + + # Set prior probabilities from policy network if available + if self.policy_network: + self._set_prior_probabilities(root) + + start_time = time.time() + simulation_count = 0 + + # Run MCTS simulations + while ( + simulation_count < self.simulations + and time.time() - start_time < self.simulation_time + ): + + # Selection phase - traverse tree to leaf + node = root + path = [node] + + while not node.is_terminal() and node.is_fully_expanded(): + node = node.select_child() + path.append(node) + + # Expansion phase - add new child if possible + if not node.is_terminal() and not node.is_fully_expanded(): + node = node.expand() + path.append(node) + + # Simulation phase - evaluate position + value = self._evaluate_position(node.game_state) + + # Backpropagation phase - update statistics + for node in reversed(path): + node.backup(value) + value = -value # Flip for opponent + + simulation_count += 1 + + # Select the most visited child as the best move + if not root.children: + # No expansions happened, return random action + legal_actions = self.action_space.get_legal_actions(game_state) + return random.choice(legal_actions) + + best_child = max( + root.children.values(), key=lambda 
child: child.visits + ) + if best_child.action: + return best_child.action + + # Fallback if no action found + legal_actions = self.action_space.get_legal_actions(game_state) + return ( + random.choice(legal_actions) + if legal_actions + else Action(ActionType.PASS_PRIORITY, self.player_id) + ) + + def _set_prior_probabilities(self, node: MCTSNode) -> None: + """Set prior probabilities for actions using the policy network.""" + if not self.policy_network: + return + + # TODO: Implement policy network evaluation + # For now, set uniform priors + num_actions = len(node.untried_actions) + if num_actions > 0: + for action in node.untried_actions: + # This would be set from policy network output + pass + + def _evaluate_position(self, game_state: GameState) -> float: + """Evaluate a game position. + + Args: + game_state: Game state to evaluate + + Returns: + Value from current player's perspective (-1 to 1) + """ + # Check for terminal states + if game_state.is_game_over(): + winner = game_state.winner() + if winner == self.player_id: + return 1.0 + elif winner is not None: + return -1.0 + else: + return 0.0 # Draw + + # Use value network if available + if self.value_network: + return self._evaluate_with_network(game_state) + + # Fallback to simple heuristic + return self._heuristic_evaluation(game_state) + + def _evaluate_with_network(self, game_state: GameState) -> float: + """Evaluate position using neural network. + + Args: + game_state: Game state to evaluate + + Returns: + Network evaluation (-1 to 1) + """ + # TODO: Implement network evaluation + # This requires the game state encoder and value network + return 0.0 + + def _heuristic_evaluation(self, game_state: GameState) -> float: + """Simple heuristic evaluation of the position. 
+ + Args: + game_state: Game state to evaluate + + Returns: + Heuristic value (-1 to 1) + """ + # Simple life difference heuristic + my_life = game_state.players[self.player_id].life + opp_life = game_state.players[1 - self.player_id].life + + life_diff = my_life - opp_life + # Normalize to roughly [-1, 1] + return max(-1.0, min(1.0, life_diff / 20.0)) + + def update_from_game( + self, game_history: List[Tuple[GameState, Action, float]] + ) -> None: + """Update agent from game history. + + For MCTS agent, this could be used to update neural networks. + """ + # TODO: Implement training data collection for neural networks + pass + + +class NeuralAgent(Agent): + """Agent using neural networks for policy and value estimation.""" + + def __init__( + self, + player_id: int, + policy_value_network: Any, + action_space: Optional[ActionSpace] = None, + temperature: float = 1.0, + ) -> None: + """Initialize neural agent. + + Args: + player_id: Player ID this agent controls + policy_value_network: Combined policy/value network + action_space: Action space for the game + temperature: Temperature for action selection + """ + super().__init__(player_id) + self.policy_value_network = policy_value_network + self.action_space = action_space or ActionSpace() + self.temperature = temperature + + def select_action(self, game_state: GameState) -> Action: + """Select action using neural network policy. 
+ + Args: + game_state: Current game state + + Returns: + Selected action + """ + legal_actions = self.action_space.get_legal_actions(game_state) + if not legal_actions: + raise ValueError("No legal actions available") + + # Get policy and value from network + with torch.no_grad(): + policy_logits, value = self.policy_value_network(game_state) + + # Apply softmax with temperature + if self.temperature > 0: + torch.softmax(policy_logits / self.temperature, dim=-1) + # TODO: Map probabilities to legal actions and sample + + # For now, return random action + return random.choice(legal_actions) + + def update_from_game( + self, game_history: List[Tuple[GameState, Action, float]] + ) -> None: + """Collect training data from game history.""" + # TODO: Store training examples for network updates + pass diff --git a/src/manamind/core/game_state.py b/src/manamind/core/game_state.py new file mode 100644 index 0000000..b16814e --- /dev/null +++ b/src/manamind/core/game_state.py @@ -0,0 +1,486 @@ +"""Game state representation and encoding for Magic: The Gathering. + +This module defines how MTG game states are represented internally and encoded +into neural network inputs. 
+""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from pydantic import BaseModel, Field, model_validator + + +class Card(BaseModel): + """Represents a Magic: The Gathering card with enhanced state tracking.""" + + # Core card data (from MTGJSON) + name: str + mana_cost: str = "" + converted_mana_cost: int = 0 + card_types: List[str] = Field( + default_factory=list + ) # ["Creature", "Artifact"] + subtypes: List[str] = Field(default_factory=list) # ["Human", "Soldier"] + supertypes: List[str] = Field( + default_factory=list + ) # ["Legendary", "Basic"] + + # Creature/Planeswalker stats + power: Optional[int] = None + toughness: Optional[int] = None + base_power: Optional[int] = None + base_toughness: Optional[int] = None + loyalty: Optional[int] = None + starting_loyalty: Optional[int] = None + + # Game state + tapped: bool = False + summoning_sick: bool = False + counters: Dict[str, int] = Field(default_factory=dict) + + # Text and abilities + oracle_text: str = "" + oracle_id: str = "" + abilities: List[str] = Field(default_factory=list) + keywords: List[str] = Field(default_factory=list) + + # Combat state + attacking: bool = False + blocking: Optional[int] = None + blocked_by: List[int] = Field(default_factory=list) + + # Ownership and control + controller: int = 0 + owner: int = 0 + + # Zone information + zone: str = "unknown" + zone_position: Optional[int] = None + + # Internal encoding IDs (assigned during preprocessing) + card_id: Optional[int] = None + instance_id: Optional[int] = None + + @model_validator(mode="before") + @classmethod + def handle_card_type_compatibility(cls, values: Any) -> Any: + """Handle backward compatibility for card_type parameter.""" + if isinstance(values, dict): + # Handle card_type -> card_types conversion + if "card_type" in values and "card_types" not in values: + card_type_str = 
values.pop("card_type") + if card_type_str: + # Split on both space and em-dash for types like + # "Creature — Bear" + parts = card_type_str.replace(" — ", " ").split() + values["card_types"] = parts + + # Handle text -> oracle_text conversion + if "text" in values and "oracle_text" not in values: + values["oracle_text"] = values.pop("text") + + return values + + @property + def card_type(self) -> str: + """Backward compatibility - return joined card types.""" + return " ".join(self.card_types) if self.card_types else "" + + @card_type.setter + def card_type(self, value: str) -> None: + """Backward compatibility - parse card type string.""" + if value: + self.card_types = value.split() + + def is_creature(self) -> bool: + """Check if this card is a creature.""" + return "Creature" in self.card_types + + def is_land(self) -> bool: + """Check if this card is a land.""" + return "Land" in self.card_types + + def is_instant_or_sorcery(self) -> bool: + """Check if this card is an instant or sorcery.""" + return "Instant" in self.card_types or "Sorcery" in self.card_types + + def current_power(self) -> Optional[int]: + """Get current power including modifications.""" + if self.power is None: + return None + return ( + self.power + + self.counters.get("+1/+1", 0) + - self.counters.get("-1/-1", 0) + ) + + def current_toughness(self) -> Optional[int]: + """Get current toughness including modifications.""" + if self.toughness is None: + return None + return ( + self.toughness + + self.counters.get("+1/+1", 0) + - self.counters.get("-1/-1", 0) + ) + + +class Zone(BaseModel): + """Represents a game zone (hand, battlefield, graveyard, etc.).""" + + cards: List[Card] = Field(default_factory=list) + name: str + owner: int # Player ID (0 or 1) + + def add_card(self, card: Card) -> None: + """Add a card to this zone.""" + self.cards.append(card) + + def remove_card(self, card: Card) -> bool: + """Remove a card from this zone. 
Returns True if successful.""" + try: + self.cards.remove(card) + return True + except ValueError: + return False + + def size(self) -> int: + """Return the number of cards in this zone.""" + return len(self.cards) + + +class Player(BaseModel): + """Represents a player in the game.""" + + player_id: int + life: int = 20 + mana_pool: Dict[str, int] = Field(default_factory=dict) + lands_played_this_turn: int = 0 + + # Zones + hand: Zone + battlefield: Zone + graveyard: Zone + library: Zone + exile: Zone + command_zone: Zone + + def __init__(self, player_id: int, **data: Any) -> None: + # Initialize zones with proper player ownership + zones = { + "hand": Zone(name="hand", owner=player_id), + "battlefield": Zone(name="battlefield", owner=player_id), + "graveyard": Zone(name="graveyard", owner=player_id), + "library": Zone(name="library", owner=player_id), + "exile": Zone(name="exile", owner=player_id), + "command_zone": Zone(name="command_zone", owner=player_id), + } + super().__init__(player_id=player_id, **zones, **data) + + def can_play_land(self) -> bool: + """Check if the player can play a land this turn.""" + return self.lands_played_this_turn == 0 # Simplified rule + + def total_mana(self) -> int: + """Calculate total available mana.""" + return sum(self.mana_pool.values()) + + +@dataclass +class GameState: + """Represents the complete state of a Magic: The Gathering game. + + This is the main data structure that captures all relevant information + about the current game state that the AI agent needs to make decisions. 
+ """ + + # Players (required field) + players: Tuple[Player, Player] + + # Basic game info + turn_number: int = 1 + phase: str = "main" # untap, upkeep, draw, main, combat, main2, end + priority_player: int = 0 # Which player has priority (0 or 1) + active_player: int = 0 # Whose turn it is + + # Stack (spells and abilities waiting to resolve) + stack: List[Dict[str, Any]] = field(default_factory=list) + + # Game history for neural network context + history: List[Dict[str, Any]] = field(default_factory=list) + + @property + def current_player(self) -> Player: + """Get the player whose turn it is.""" + return self.players[self.active_player] + + @property + def opponent(self) -> Player: + """Get the opponent of the active player.""" + return self.players[1 - self.active_player] + + def is_game_over(self) -> bool: + """Check if the game has ended.""" + return any(player.life <= 0 for player in self.players) + + def winner(self) -> Optional[int]: + """Return the winner's player ID, or None if game is ongoing.""" + for i, player in enumerate(self.players): + if player.life <= 0: + return 1 - i # The other player wins + return None + + def copy(self) -> GameState: + """Create a deep copy of the game state for simulation.""" + # Deep copy is expensive but necessary for correctness + # TODO: Optimize with copy-on-write or incremental updates + return copy.deepcopy(self) + + def compute_hash(self) -> int: + """Compute a hash for transposition tables and state caching.""" + # Create hash from key game state components + hash_components = [ + self.turn_number, + self.phase, + self.active_player, + self.priority_player, + tuple(p.life for p in self.players), + tuple( + len(getattr(p, zone).cards) + for p in self.players + for zone in [ + "hand", + "battlefield", + "graveyard", + "library", + "exile", + ] + ), + len(self.stack), + ] + return hash(tuple(hash_components)) + + def get_features_for_encoding(self) -> Dict[str, Any]: + """Extract features for neural network 
encoding.""" + return { + "turn_number": self.turn_number, + "phase": self.phase, + "active_player": self.active_player, + "priority_player": self.priority_player, + "players": [ + { + "life": p.life, + "mana_pool": p.mana_pool, + "lands_played": p.lands_played_this_turn, + "zones": { + zone_name: [ + { + "card_id": card.card_id, + "tapped": getattr(card, "tapped", False), + "counters": getattr(card, "counters", {}), + "power": card.current_power(), + "toughness": card.current_toughness(), + } + for card in getattr(p, zone_name).cards + ] + for zone_name in [ + "hand", + "battlefield", + "graveyard", + "library", + "exile", + ] + }, + } + for p in self.players + ], + "stack_size": len(self.stack), + } + + +class GameStateEncoder(nn.Module): + """Neural network module to encode game states into fixed-size tensors. + + This is a critical component that converts the complex, variable-size + game state into a fixed-size numerical representation that can be + processed by the policy and value networks. + """ + + def __init__( + self, + vocab_size: int = 50000, # Number of unique cards/tokens + embed_dim: int = 512, + hidden_dim: int = 1024, + num_zones: int = 6, # hand, battlefield, graveyard, library, exile + max_cards_per_zone: int = 200, + output_dim: int = 2048, + ): + super().__init__() + + self.vocab_size = vocab_size + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.num_zones = num_zones + self.max_cards_per_zone = max_cards_per_zone + self.output_dim = output_dim + + # Card embedding layer + self.card_embedding = nn.Embedding(vocab_size, embed_dim) + + # Zone encoders (one for each zone type) + self.zone_encoders = nn.ModuleList( + [ + nn.LSTM( + embed_dim, + hidden_dim // 2, + batch_first=True, + bidirectional=True, + ) + for _ in range(num_zones) + ] + ) + + # Player state encoder + self.player_encoder = nn.Linear(20, hidden_dim) # life, mana, etc. + + # Global state encoder (turn, phase, priority, etc.) 
+ # 4 scalar features + 7 phase one-hot = 11 total features + self.global_encoder = nn.Linear(11, hidden_dim) + + # Final combination layer + self.combiner = nn.Sequential( + nn.Linear(hidden_dim * (2 * num_zones + 2 + 1), hidden_dim * 2), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(hidden_dim * 2, output_dim), + nn.LayerNorm(output_dim), + ) + + def encode_zone(self, zone: Zone, zone_idx: int) -> torch.Tensor: + """Encode a single zone into a fixed-size representation. + + Args: + zone: The zone to encode + zone_idx: Index of the zone type (for selecting the right encoder) + + Returns: + Fixed-size tensor representing the zone + """ + # Convert cards to IDs and pad/truncate to fixed size + card_ids = [] + for card in zone.cards[: self.max_cards_per_zone]: + card_ids.append(card.card_id or 0) # Use 0 for unknown cards + + # Pad if necessary + while len(card_ids) < self.max_cards_per_zone: + card_ids.append(0) # Padding token + + # Convert to tensor and embed + card_tensor = torch.tensor(card_ids, dtype=torch.long).unsqueeze(0) + embedded_cards = self.card_embedding(card_tensor) + + # Encode with LSTM + lstm_out, (hidden, _) = self.zone_encoders[zone_idx](embedded_cards) + + # Use final hidden state as zone representation + result: torch.Tensor = hidden.view(-1) # Flatten + return result + + def encode_player(self, player: Player) -> torch.Tensor: + """Encode player state (life, mana, etc.) 
into a tensor.""" + features = [ + float(player.life) / 20.0, # Normalized life + float(player.lands_played_this_turn), + float(player.hand.size()) / 10.0, # Normalized hand size + float(player.battlefield.size()) / 20.0, # Normalized board size + float(player.graveyard.size()) / 50.0, # Normalized graveyard size + ] + + # Add mana pool features (WUBRG + colorless) + mana_colors = ["W", "U", "B", "R", "G", "C"] + for color in mana_colors: + features.append(float(player.mana_pool.get(color, 0)) / 10.0) + + # Pad to expected size + while len(features) < 20: + features.append(0.0) + + result: torch.Tensor = self.player_encoder( + torch.tensor(features, dtype=torch.float32) + ) + return result + + def forward(self, game_state: GameState) -> torch.Tensor: + """Encode a complete game state into a fixed-size tensor. + + Args: + game_state: The game state to encode + + Returns: + Fixed-size tensor representation of the game state + """ + encoded_parts = [] + + # Encode zones for both players + zone_types = [ + "hand", + "battlefield", + "graveyard", + "library", + "exile", + "command_zone", + ] + for player in game_state.players: + for zone_idx, zone_name in enumerate(zone_types): + zone = getattr(player, zone_name) + zone_encoding = self.encode_zone(zone, zone_idx) + encoded_parts.append(zone_encoding) + + # Encode player states + for player in game_state.players: + player_encoding = self.encode_player(player) + encoded_parts.append(player_encoding) + + # Encode global game state + global_features = [ + float(game_state.turn_number) / 20.0, # Normalized turn + float(game_state.active_player), + float(game_state.priority_player), + float(len(game_state.stack)) / 10.0, # Normalized stack size + ] + + # Encode phase as one-hot + phases = ["untap", "upkeep", "draw", "main", "combat", "main2", "end"] + for phase in phases: + global_features.append(1.0 if game_state.phase == phase else 0.0) + + global_encoding = self.global_encoder( + torch.tensor(global_features, 
dtype=torch.float32) + ) + encoded_parts.append(global_encoding) + + # Combine all encodings + combined = torch.cat(encoded_parts, dim=0) + result: torch.Tensor = self.combiner(combined.unsqueeze(0)).squeeze(0) + return result + + +# Factory functions for creating common game states + + +def create_empty_game_state() -> GameState: + """Create an empty game state for testing.""" + player0 = Player(player_id=0) + player1 = Player(player_id=1) + return GameState(players=(player0, player1)) + + +def create_standard_game_start() -> GameState: + """Create a game state representing the start of a standard game.""" + # TODO: Implement proper game start with shuffled libraries, hands, etc. + return create_empty_game_state() diff --git a/src/manamind/core/state_manager.py b/src/manamind/core/state_manager.py new file mode 100644 index 0000000..012815f --- /dev/null +++ b/src/manamind/core/state_manager.py @@ -0,0 +1,582 @@ +"""Optimized state management for efficient MCTS and training. + +This module provides memory-efficient game state management with copy-on-write +semantics, incremental updates, and caching for high-performance training. +""" + +import copy +import threading +import weakref +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple + +import torch + +from manamind.core.action import Action +from manamind.core.game_state import GameState, create_empty_game_state + + +@dataclass +class StateHash: + """Lightweight hash representation of game state.""" + + turn_number: int + phase: str + active_player: int + priority_player: int + player_lives: Tuple[int, int] + zone_sizes: Tuple[int, ...] 
# Sizes of all zones + stack_size: int + + def __post_init__(self) -> None: + self._hash = hash( + ( + self.turn_number, + self.phase, + self.active_player, + self.priority_player, + self.player_lives, + self.zone_sizes, + self.stack_size, + ) + ) + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: object) -> bool: + if not isinstance(other, StateHash): + return False + return self._hash == other._hash + + +@dataclass +class StateDelta: + """Represents changes between game states for incremental updates.""" + + action: Action + timestamp: float + + # Changed components + player_changes: Dict[int, Dict[str, Any]] = field(default_factory=dict) + zone_changes: Dict[Tuple[int, str], Dict[str, Any]] = field( + default_factory=dict + ) + global_changes: Dict[str, Any] = field(default_factory=dict) + + # For efficient rollback + reverse_delta: Optional["StateDelta"] = None + + +class CopyOnWriteGameState: + """Game state with copy-on-write semantics for memory efficiency.""" + + def __init__(self, base_state: Optional["CopyOnWriteGameState"] = None): + self._refs_lock = threading.Lock() + + if base_state is None: + # Create new state + self._data = create_empty_game_state() + self._ref_count = 1 + self._is_cow = False + self._parent = None + else: + # Share data with parent + self._data = base_state._data + with base_state._refs_lock: + base_state._ref_count += 1 + self._ref_count = 1 + self._is_cow = True + self._parent = weakref.ref(base_state) + + def _ensure_writable(self) -> None: + """Ensure this state is writable (copy-on-write).""" + if self._is_cow: + # Make a deep copy + self._data = copy.deepcopy(self._data) + self._is_cow = False + self._parent = None + + def __getattr__(self, name: str) -> Any: + """Delegate attribute access to underlying state.""" + if name.startswith("_"): + return object.__getattribute__(self, name) + return getattr(self._data, name) + + def __setattr__(self, name: str, value: Any) -> None: + """Handle attribute 
setting with COW semantics.""" + if name.startswith("_"): + object.__setattr__(self, name, value) + else: + self._ensure_writable() + setattr(self._data, name, value) + + def copy(self) -> "CopyOnWriteGameState": + """Create a COW copy.""" + return CopyOnWriteGameState(self) + + def get_hash(self) -> StateHash: + """Get lightweight hash of current state.""" + return StateHash( + turn_number=self._data.turn_number, + phase=self._data.phase, + active_player=self._data.active_player, + priority_player=self._data.priority_player, + player_lives=( + int(self._data.players[0].life), + int(self._data.players[1].life), + ), + zone_sizes=tuple( + len(getattr(player, zone).cards) + for player in self._data.players + for zone in [ + "hand", + "battlefield", + "graveyard", + "library", + "exile", + ] + ), + stack_size=len(self._data.stack), + ) + + +class IncrementalStateManager: + """Manages incremental state updates for efficient MCTS rollouts.""" + + def __init__(self, base_state: GameState): + self.base_state = base_state + self.delta_stack: List[StateDelta] = [] + self._current_state: Optional[GameState] = None + self._state_cache: Dict[int, GameState] = {} + # Cache states at different depths + self._max_cache_size = 100 + + def push_action(self, action: Action) -> GameState: + """Apply action and return new state.""" + # Compute delta + old_state = self.current_state() + new_state = action.execute(old_state) + + delta = self._compute_delta(action, old_state, new_state) + self.delta_stack.append(delta) + + # Cache state if stack not too deep + if len(self.delta_stack) < self._max_cache_size: + self._state_cache[len(self.delta_stack)] = new_state + + self._current_state = new_state + return new_state + + def pop_action(self) -> Optional[GameState]: + """Rollback last action.""" + if not self.delta_stack: + return None + + delta = self.delta_stack.pop() + + # Check cache first + cache_key = len(self.delta_stack) + if cache_key in self._state_cache: + self._current_state = 
self._state_cache[cache_key] + return self._current_state + + # Apply reverse delta + if delta.reverse_delta: + current = self.current_state() + self._current_state = self._apply_reverse_delta( + current, delta.reverse_delta + ) + return self._current_state + + # Fallback: recompute from base + self._current_state = None + return self.current_state() + + def current_state(self) -> GameState: + """Get current state by applying all deltas.""" + if self._current_state is not None: + return self._current_state + + # Check cache + cache_key = len(self.delta_stack) + if cache_key in self._state_cache: + self._current_state = self._state_cache[cache_key] + return self._current_state + + # Recompute from base + state = copy.deepcopy(self.base_state) + for delta in self.delta_stack: + state = delta.action.execute(state) + + self._current_state = state + return state + + def _compute_delta( + self, action: Action, old_state: GameState, new_state: GameState + ) -> StateDelta: + """Compute delta between states.""" + delta = StateDelta(action=action, timestamp=action.timestamp) + + # Compare players + for i, (old_player, new_player) in enumerate( + zip(old_state.players, new_state.players) + ): + changes: Dict[str, Any] = {} + if old_player.life != new_player.life: + changes["life"] = (old_player.life, new_player.life) + if old_player.mana_pool != new_player.mana_pool: + changes["mana_pool"] = ( + dict(old_player.mana_pool), + dict(new_player.mana_pool), + ) + if changes: + delta.player_changes[i] = changes + + # Compare zones + zone_names = ["hand", "battlefield", "graveyard", "library", "exile"] + for i, (old_player, new_player) in enumerate( + zip(old_state.players, new_state.players) + ): + for zone_name in zone_names: + old_zone = getattr(old_player, zone_name) + new_zone = getattr(new_player, zone_name) + + if len(old_zone.cards) != len(new_zone.cards): + delta.zone_changes[(i, zone_name)] = { + "size_change": len(new_zone.cards) + - len(old_zone.cards) + } + + # Global 
changes + if old_state.phase != new_state.phase: + delta.global_changes["phase"] = (old_state.phase, new_state.phase) + if old_state.priority_player != new_state.priority_player: + delta.global_changes["priority"] = ( + old_state.priority_player, + new_state.priority_player, + ) + + return delta + + def _apply_reverse_delta( + self, state: GameState, reverse_delta: StateDelta + ) -> GameState: + """Apply reverse delta (not implemented - use recomputation).""" + # This is complex to implement correctly for all changes + # For now, fall back to recomputation + return self.current_state() + + def get_depth(self) -> int: + """Get current depth in the delta stack.""" + return len(self.delta_stack) + + def clear_cache(self) -> None: + """Clear state cache to free memory.""" + self._state_cache.clear() + self._current_state = None + + +class StatePool: + """Pool of reusable game state objects to reduce allocations.""" + + def __init__(self, initial_size: int = 100): + self._available_states: List[GameState] = [] + self._in_use: Set[int] = set() + self._lock = threading.Lock() + + # Pre-allocate states + for _ in range(initial_size): + state = create_empty_game_state() + self._available_states.append(state) + + def acquire(self) -> GameState: + """Get a state from the pool.""" + with self._lock: + if self._available_states: + state = self._available_states.pop() + self._in_use.add(id(state)) + return state + + # Pool exhausted, create new state + state = create_empty_game_state() + with self._lock: + self._in_use.add(id(state)) + return state + + def release(self, state: GameState) -> None: + """Return a state to the pool.""" + state_id = id(state) + with self._lock: + if state_id in self._in_use: + self._in_use.remove(state_id) + # Reset state to clean condition + self._reset_state(state) + self._available_states.append(state) + + def _reset_state(self, state: GameState) -> None: + """Reset state to initial condition.""" + state.turn_number = 1 + state.phase = "main" + 
state.active_player = 0 + state.priority_player = 0 + + for player in state.players: + player.life = 20 + player.mana_pool.clear() + player.lands_played_this_turn = 0 + + # Clear zones + for zone_name in [ + "hand", + "battlefield", + "graveyard", + "library", + "exile", + ]: + zone = getattr(player, zone_name) + zone.cards.clear() + + state.stack.clear() + state.history.clear() + + +class BatchStateProcessor: + """Process multiple states in parallel for training efficiency.""" + + def __init__(self, max_workers: int = 4, batch_size: int = 64): + self.max_workers = max_workers + self.batch_size = batch_size + self.state_pool = StatePool() + + def process_states_parallel( + self, states: List[GameState], operation: str, **kwargs: Any + ) -> List[Any]: + """Process multiple states in parallel.""" + + if operation == "encode": + return self._batch_encode(states, **kwargs) + elif operation == "legal_actions": + return self._batch_legal_actions(states, **kwargs) + elif operation == "evaluate": + return self._batch_evaluate(states, **kwargs) + else: + raise ValueError(f"Unknown operation: {operation}") + + def _batch_encode( + self, states: List[GameState], encoder: Any = None + ) -> List[torch.Tensor]: + """Encode multiple states in parallel.""" + if encoder is None: + raise ValueError("Encoder required for batch encoding") + + # Group states by structure for efficient batching + batches = self._create_batches(states) + + results = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [] + for batch in batches: + future = executor.submit(encoder.encode_batch, batch) + futures.append(future) + + for future in futures: + batch_encodings = future.result() + results.extend(batch_encodings.unbind(0)) + + return results + + def _batch_legal_actions( + self, states: List[GameState], action_space: Any = None + ) -> List[List[Action]]: + """Generate legal actions for multiple states in parallel.""" + if action_space is None: + raise ValueError( + 
"ActionSpace required for legal action generation" + ) + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [ + executor.submit(action_space.get_legal_actions, state) + for state in states + ] + return [future.result() for future in futures] + + def _batch_evaluate( + self, states: List[GameState], evaluator: Any = None + ) -> List[float]: + """Evaluate multiple states in parallel.""" + if evaluator is None: + raise ValueError("Evaluator required for batch evaluation") + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [ + executor.submit(evaluator.evaluate_state, state) + for state in states + ] + return [future.result() for future in futures] + + def _create_batches( + self, states: List[GameState] + ) -> List[List[GameState]]: + """Group states into batches for processing.""" + batches = [] + for i in range(0, len(states), self.batch_size): + batch = states[i : i + self.batch_size] + batches.append(batch) + return batches + + +class TranspositionTable: + """Transposition table for caching game state evaluations in MCTS.""" + + def __init__(self, max_size: int = 100000): + self.max_size = max_size + self._table: Dict[StateHash, Dict[str, Any]] = {} + self._access_order: List[StateHash] = [] + self._lock = threading.RLock() + + def get(self, state_hash: StateHash, key: str) -> Optional[Any]: + """Get cached value for state and key.""" + with self._lock: + if state_hash in self._table and key in self._table[state_hash]: + # Update access order (LRU) + if state_hash in self._access_order: + self._access_order.remove(state_hash) + self._access_order.append(state_hash) + + return self._table[state_hash][key] + return None + + def put(self, state_hash: StateHash, key: str, value: Any) -> None: + """Store value for state and key.""" + with self._lock: + if state_hash not in self._table: + self._table[state_hash] = {} + + # Evict oldest entry if table is full + if len(self._table) > self.max_size: + oldest = 
self._access_order.pop(0) + del self._table[oldest] + + self._table[state_hash][key] = value + + # Update access order + if state_hash in self._access_order: + self._access_order.remove(state_hash) + self._access_order.append(state_hash) + + def get_visit_count(self, state_hash: StateHash) -> int: + """Get MCTS visit count for state.""" + return self.get(state_hash, "visit_count") or 0 + + def get_value_estimate(self, state_hash: StateHash) -> Optional[float]: + """Get value estimate for state.""" + return self.get(state_hash, "value_estimate") + + def get_action_values( + self, state_hash: StateHash + ) -> Optional[Dict[str, float]]: + """Get action value estimates for state.""" + return self.get(state_hash, "action_values") + + def update_mcts_data( + self, + state_hash: StateHash, + visit_count: int, + value_estimate: float, + action_values: Dict[str, float], + ) -> None: + """Update MCTS data for state.""" + self.put(state_hash, "visit_count", visit_count) + self.put(state_hash, "value_estimate", value_estimate) + self.put(state_hash, "action_values", action_values) + + def clear(self) -> None: + """Clear all cached data.""" + with self._lock: + self._table.clear() + self._access_order.clear() + + def get_stats(self) -> Dict[str, int]: + """Get cache statistics.""" + with self._lock: + return { + "size": len(self._table), + "max_size": self.max_size, + "hit_rate": 0, # TODO: Track hit rate + } + + +class StateCompressionManager: + """Manages state compression for memory efficiency during training.""" + + def __init__(self) -> None: + self.compressor = StateCompressor() + self._compressed_cache: Dict[StateHash, bytes] = {} + + def compress_state(self, state: GameState) -> bytes: + """Compress state to bytes.""" + return self.compressor.compress(state) + + def decompress_state(self, data: bytes) -> GameState: + """Decompress bytes to state.""" + return self.compressor.decompress(data) + + def cache_compressed_state(self, state: GameState) -> StateHash: + """Cache 
compressed state and return hash.""" + state_hash = StateHash( + turn_number=state.turn_number, + phase=state.phase, + active_player=state.active_player, + priority_player=state.priority_player, + player_lives=(state.players[0].life, state.players[1].life), + zone_sizes=tuple( + len(getattr(player, zone).cards) + for player in state.players + for zone in [ + "hand", + "battlefield", + "graveyard", + "library", + "exile", + ] + ), + stack_size=len(state.stack), + ) + + compressed_data = self.compress_state(state) + self._compressed_cache[state_hash] = compressed_data + + return state_hash + + def retrieve_state(self, state_hash: StateHash) -> Optional[GameState]: + """Retrieve state from compressed cache.""" + if state_hash in self._compressed_cache: + compressed_data = self._compressed_cache[state_hash] + return self.decompress_state(compressed_data) + return None + + +class StateCompressor: + """Handles compression/decompression of game states.""" + + def compress(self, state: GameState) -> bytes: + """Compress game state to bytes.""" + # Simple serialization for now + # TODO: Implement more sophisticated compression + import gzip + import pickle + + data = pickle.dumps(state) + return gzip.compress(data) + + def decompress(self, data: bytes) -> GameState: + """Decompress bytes to game state.""" + import gzip + import pickle + + decompressed = gzip.decompress(data) + result: GameState = pickle.loads(decompressed) + return result diff --git a/src/manamind/evaluation/__init__.py b/src/manamind/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/manamind/evaluation/evaluator.py b/src/manamind/evaluation/evaluator.py new file mode 100644 index 0000000..e69de29 diff --git a/src/manamind/forge_interface/__init__.py b/src/manamind/forge_interface/__init__.py new file mode 100644 index 0000000..927e4b5 --- /dev/null +++ b/src/manamind/forge_interface/__init__.py @@ -0,0 +1,17 @@ +"""Forge game engine interface for ManaMind. 
+ +This module provides the Python-Java bridge for communicating with the Forge +MTG engine. This is critical for Phase 1 training where the agent learns +to play against Forge's built-in AI. +""" + +from manamind.forge_interface.forge_client import ForgeClient + +# from manamind.forge_interface.game_runner import ForgeGameRunner +# from manamind.forge_interface.state_parser import ForgeStateParser + +__all__ = [ + "ForgeClient", + # "ForgeGameRunner", + # "ForgeStateParser", +] diff --git a/src/manamind/forge_interface/forge_client.py b/src/manamind/forge_interface/forge_client.py new file mode 100644 index 0000000..3097b8e --- /dev/null +++ b/src/manamind/forge_interface/forge_client.py @@ -0,0 +1,446 @@ +"""Forge game engine client for Python-Java communication. + +This module handles the low-level communication with the Forge game engine +using Py4J or JPype for the Python-Java bridge. +""" + +import json +import logging +import subprocess +import time +from pathlib import Path +from subprocess import Popen +from typing import Any, Dict, List, Optional + +try: + from py4j.java_gateway import GatewayParameters, JavaGateway + + PY4J_AVAILABLE = True +except ImportError: + PY4J_AVAILABLE = False + JavaGateway = None + +try: + import jpype + import jpype.imports + + JPYPE_AVAILABLE = True +except ImportError: + JPYPE_AVAILABLE = False + jpype = None + + +logger = logging.getLogger(__name__) + + +class ForgeConnectionError(Exception): + """Raised when connection to Forge fails.""" + + pass + + +class ForgeClient: + """Client for communicating with the Forge game engine. + + This class handles: + 1. Starting/stopping Forge instances + 2. Sending commands to Forge + 3. Receiving game state updates + 4. 
Managing multiple Forge processes for parallel training + """ + + def __init__( + self, + forge_path: Optional[Path] = None, + java_opts: Optional[List[str]] = None, + port: int = 25333, + timeout: float = 30.0, + use_py4j: bool = True, + ): + """Initialize Forge client. + + Args: + forge_path: Path to Forge installation directory + java_opts: Java options for running Forge + port: Port for communication with Forge + timeout: Connection timeout in seconds + use_py4j: Whether to use Py4J (if False, uses JPype) + """ + self.forge_path = forge_path or self._find_forge_installation() + self.java_opts = java_opts or ["-Xmx4G", "-server"] + self.port = port + self.timeout = timeout + self.use_py4j = use_py4j and PY4J_AVAILABLE + + # Runtime state + self.forge_process: Optional[Popen[Any]] = None + self.gateway: Optional[Any] = None + self.forge_api: Optional[Any] = None + self.is_connected = False + + # Validate setup + self._validate_setup() + + def _find_forge_installation(self) -> Path: + """Try to find Forge installation automatically.""" + # Common Forge locations + possible_paths = [ + Path("./forge"), + Path("./forge-gui"), + Path("/opt/forge"), + Path.home() / "forge", + Path.home() / "Downloads" / "forge-gui", + ] + + for path in possible_paths: + if path.exists() and (path / "forge-gui.jar").exists(): + logger.info(f"Found Forge installation at {path}") + return path + + # If not found, return default path (user will need to install Forge) + return Path("./forge") + + def _validate_setup(self) -> None: + """Validate that required dependencies are available.""" + if not self.use_py4j and not JPYPE_AVAILABLE: + msg = ( + "Neither Py4J nor JPype is available. Please install one:\n" + "pip install py4j # or\n" + "pip install JPype1" + ) + raise ForgeConnectionError(msg) + + if not self.forge_path.exists(): + logger.warning( + f"Forge installation not found at {self.forge_path}. " + "Please ensure Forge is installed and the path is correct." 
+ ) + + def start_forge(self, headless: bool = True) -> None: + """Start a Forge game engine instance. + + Args: + headless: Whether to run Forge without GUI + + Raises: + ForgeConnectionError: If Forge fails to start + """ + if self.is_connected: + logger.warning("Forge is already running") + return + + logger.info(f"Starting Forge on port {self.port}") + + # Build command to start Forge + forge_jar = self.forge_path / "forge-gui.jar" + if not forge_jar.exists(): + raise ForgeConnectionError(f"Forge JAR not found: {forge_jar}") + + cmd = ["java"] + self.java_opts + + if headless: + cmd.extend(["-Djava.awt.headless=true", "-Dforge.headless=true"]) + + # Add ManaMind API mode + cmd.extend( + [ + "-Dforge.api.mode=true", + f"-Dforge.api.port={self.port}", + "-jar", + str(forge_jar), + ] + ) + + try: + # Start Forge process + self.forge_process = subprocess.Popen( + cmd, + cwd=self.forge_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Wait for Forge to start + self._wait_for_forge_startup() + + # Establish communication + if self.use_py4j: + self._connect_py4j() + else: + self._connect_jpype() + + logger.info("Forge started successfully") + self.is_connected = True + + except Exception as e: + self.stop_forge() + raise ForgeConnectionError(f"Failed to start Forge: {e}") + + def _wait_for_forge_startup(self) -> None: + """Wait for Forge to finish starting up.""" + start_time = time.time() + + while time.time() - start_time < self.timeout: + if ( + self.forge_process is not None + and self.forge_process.poll() is not None + ): + # Process has terminated + stdout, stderr = self.forge_process.communicate() + raise ForgeConnectionError( + f"Forge process terminated unexpectedly:\n" + f"STDOUT: {stdout}\nSTDERR: {stderr}" + ) + + # Check if Forge is ready (look for startup message in logs) + # TODO: Implement proper startup detection + time.sleep(1.0) + + if time.time() - start_time >= self.timeout: + raise ForgeConnectionError("Forge 
startup timed out") + + def _connect_py4j(self) -> None: + """Connect to Forge using Py4J.""" + if not PY4J_AVAILABLE: + raise ForgeConnectionError("Py4J not available") + + try: + gateway_params = GatewayParameters( + port=self.port, auto_convert=True + ) + self.gateway = JavaGateway(gateway_parameters=gateway_params) + + # Get the Forge API object + self.forge_api = self.gateway.entry_point + + # Test connection + version = self.forge_api.getVersion() + logger.info(f"Connected to Forge version {version} via Py4J") + + except Exception as e: + raise ForgeConnectionError(f"Failed to connect via Py4J: {e}") + + def _connect_jpype(self) -> None: + """Connect to Forge using JPype.""" + if not JPYPE_AVAILABLE: + raise ForgeConnectionError("JPype not available") + + try: + # Start JVM if not already started + if not jpype.isJVMStarted(): + jpype.startJVM( + jpype.getDefaultJVMPath(), + *self.java_opts, + classpath=[str(self.forge_path / "forge-gui.jar")], + ) + + # Import Forge API classes + from forge.api import ManaMindAPI + + self.forge_api = ManaMindAPI() + + # Test connection + version = self.forge_api.getVersion() + logger.info(f"Connected to Forge version {version} via JPype") + + except Exception as e: + raise ForgeConnectionError(f"Failed to connect via JPype: {e}") + + def stop_forge(self) -> None: + """Stop the Forge game engine instance.""" + logger.info("Stopping Forge") + + # Close API connection + if self.gateway: + try: + self.gateway.shutdown() + except Exception: + pass + self.gateway = None + + if JPYPE_AVAILABLE and jpype.isJVMStarted(): + try: + jpype.shutdownJVM() + except Exception: + pass + + # Terminate Forge process + if self.forge_process: + try: + self.forge_process.terminate() + self.forge_process.wait(timeout=5.0) + except subprocess.TimeoutExpired: + logger.warning( + "Forge did not terminate gracefully, killing process" + ) + self.forge_process.kill() + self.forge_process.wait() + except Exception as e: + logger.error(f"Error stopping 
Forge process: {e}") + + self.forge_process = None + + self.forge_api = None + self.is_connected = False + logger.info("Forge stopped") + + def create_game( + self, + deck1_path: str, + deck2_path: str, + game_format: str = "Constructed", + ) -> str: + """Create a new game in Forge. + + Args: + deck1_path: Path to player 1's deck file + deck2_path: Path to player 2's deck file + game_format: Game format (Constructed, Limited, etc.) + + Returns: + Game ID for the created game + + Raises: + ForgeConnectionError: If game creation fails + """ + if not self.is_connected: + raise ForgeConnectionError("Not connected to Forge") + + try: + if self.forge_api is None: + raise ForgeConnectionError("Forge API not initialized") + game_id = self.forge_api.createGame( + deck1_path, deck2_path, game_format + ) + logger.info(f"Created game {game_id}") + return str(game_id) + + except Exception as e: + raise ForgeConnectionError(f"Failed to create game: {e}") + + def get_game_state(self, game_id: str) -> Dict[str, Any]: + """Get the current game state from Forge. + + Args: + game_id: ID of the game + + Returns: + Game state as a dictionary + + Raises: + ForgeConnectionError: If getting state fails + """ + if not self.is_connected: + raise ForgeConnectionError("Not connected to Forge") + + try: + if self.forge_api is None: + raise ForgeConnectionError("Forge API not initialized") + state_json = self.forge_api.getGameState(game_id) + result: Dict[str, Any] = json.loads(str(state_json)) + return result + + except Exception as e: + raise ForgeConnectionError(f"Failed to get game state: {e}") + + def send_action(self, game_id: str, action_data: Dict[str, Any]) -> bool: + """Send an action to Forge. 
+ + Args: + game_id: ID of the game + action_data: Action data as dictionary + + Returns: + True if action was accepted + + Raises: + ForgeConnectionError: If sending action fails + """ + if not self.is_connected: + raise ForgeConnectionError("Not connected to Forge") + + try: + if self.forge_api is None: + raise ForgeConnectionError("Forge API not initialized") + action_json = json.dumps(action_data) + result = self.forge_api.sendAction(game_id, action_json) + return bool(result) + + except Exception as e: + raise ForgeConnectionError(f"Failed to send action: {e}") + + def get_legal_actions(self, game_id: str) -> List[Dict[str, Any]]: + """Get legal actions for the current player. + + Args: + game_id: ID of the game + + Returns: + List of legal actions as dictionaries + + Raises: + ForgeConnectionError: If getting actions fails + """ + if not self.is_connected: + raise ForgeConnectionError("Not connected to Forge") + + try: + if self.forge_api is None: + raise ForgeConnectionError("Forge API not initialized") + actions_json = self.forge_api.getLegalActions(game_id) + result: List[Dict[str, Any]] = json.loads(str(actions_json)) + return result + + except Exception as e: + raise ForgeConnectionError(f"Failed to get legal actions: {e}") + + def is_game_over(self, game_id: str) -> bool: + """Check if a game has ended. + + Args: + game_id: ID of the game + + Returns: + True if game is over + """ + if not self.is_connected: + return True + + try: + if self.forge_api is None: + return True + return bool(self.forge_api.isGameOver(game_id)) + except Exception as e: + logger.error(f"Error checking if game is over: {e}") + return True + + def get_winner(self, game_id: str) -> Optional[int]: + """Get the winner of a finished game. 
+ + Args: + game_id: ID of the game + + Returns: + Winner player ID (0 or 1), or None if draw/ongoing + """ + if not self.is_connected: + return None + + try: + if self.forge_api is None: + return None + result = self.forge_api.getWinner(game_id) + return int(result) if result is not None else None + except Exception as e: + logger.error(f"Error getting winner: {e}") + return None + + def __enter__(self) -> "ForgeClient": + """Context manager entry.""" + self.start_forge() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Context manager exit.""" + self.stop_forge() diff --git a/src/manamind/forge_interface/game_runner.py b/src/manamind/forge_interface/game_runner.py new file mode 100644 index 0000000..e69de29 diff --git a/src/manamind/forge_interface/state_parser.py b/src/manamind/forge_interface/state_parser.py new file mode 100644 index 0000000..e69de29 diff --git a/src/manamind/models/__init__.py b/src/manamind/models/__init__.py new file mode 100644 index 0000000..9cd49b6 --- /dev/null +++ b/src/manamind/models/__init__.py @@ -0,0 +1,16 @@ +"""Neural network models for ManaMind AI agent. + +This module contains the neural network architectures used for: +- Policy networks (action prediction) +- Value networks (position evaluation) +- Combined policy-value networks (AlphaZero style) +""" + +from manamind.models.components import AttentionLayer, ResidualBlock +from manamind.models.policy_value_network import PolicyValueNetwork + +__all__ = [ + "PolicyValueNetwork", + "ResidualBlock", + "AttentionLayer", +] diff --git a/src/manamind/models/components.py b/src/manamind/models/components.py new file mode 100644 index 0000000..9ca6f83 --- /dev/null +++ b/src/manamind/models/components.py @@ -0,0 +1,381 @@ +"""Neural network components and building blocks. + +This module contains reusable neural network components used throughout +the ManaMind architecture. 
+""" + +import math +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + """Residual block with layer normalization and dropout. + + This is a key building block for the deep networks, helping with + gradient flow and training stability. + """ + + def __init__( + self, + hidden_dim: int, + dropout_rate: float = 0.1, + activation: str = "relu", + ): + """Initialize the residual block. + + Args: + hidden_dim: Hidden dimension + dropout_rate: Dropout rate for regularization + activation: Activation function ('relu', 'gelu', 'swish') + """ + super().__init__() + + self.layer_norm1 = nn.LayerNorm(hidden_dim) + self.linear1 = nn.Linear(hidden_dim, hidden_dim * 4) # Expand + self.dropout1 = nn.Dropout(dropout_rate) + + self.layer_norm2 = nn.LayerNorm(hidden_dim * 4) + self.linear2 = nn.Linear(hidden_dim * 4, hidden_dim) # Contract + self.dropout2 = nn.Dropout(dropout_rate) + + # Activation function + if activation == "relu": + self.activation: Union[nn.ReLU, nn.GELU, nn.SiLU] = nn.ReLU() + elif activation == "gelu": + self.activation = nn.GELU() + elif activation == "swish": + self.activation = nn.SiLU() # Swish = SiLU in PyTorch + else: + raise ValueError(f"Unknown activation: {activation}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass with residual connection. + + Args: + x: Input tensor [batch_size, hidden_dim] + + Returns: + Output tensor with same shape as input + """ + residual = x + + # First transformation + x = self.layer_norm1(x) + x = self.linear1(x) + x = self.activation(x) + x = self.dropout1(x) + + # Second transformation + x = self.layer_norm2(x) + x = self.linear2(x) + x = self.dropout2(x) + + # Residual connection + return x + residual + + +class AttentionLayer(nn.Module): + """Multi-head attention layer for processing sequences. 
+ + This can be useful for attending to different parts of the game state, + such as cards in hand, battlefield, etc. + """ + + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + dropout_rate: float = 0.1, + use_bias: bool = True, + ): + """Initialize the attention layer. + + Args: + hidden_dim: Hidden dimension (must be divisible by num_heads) + num_heads: Number of attention heads + dropout_rate: Dropout rate for attention weights + use_bias: Whether to use bias in linear layers + """ + super().__init__() + + if hidden_dim % num_heads != 0: + msg = f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})" # noqa: E501 + raise ValueError(msg) + + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # Linear projections for Q, K, V + self.query_projection = nn.Linear( + hidden_dim, hidden_dim, bias=use_bias + ) + self.key_projection = nn.Linear(hidden_dim, hidden_dim, bias=use_bias) + self.value_projection = nn.Linear( + hidden_dim, hidden_dim, bias=use_bias + ) + + # Output projection + self.output_projection = nn.Linear( + hidden_dim, hidden_dim, bias=use_bias + ) + + # Dropout and layer norm + self.attention_dropout = nn.Dropout(dropout_rate) + self.output_dropout = nn.Dropout(dropout_rate) + self.layer_norm = nn.LayerNorm(hidden_dim) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Forward pass through multi-head attention. 
+ + Args: + x: Input tensor [batch_size, seq_len, hidden_dim] + mask: Optional attention mask [batch_size, seq_len, seq_len] + + Returns: + Output tensor [batch_size, seq_len, hidden_dim] + """ + batch_size, seq_len, _ = x.shape + residual = x + + # Apply layer norm first (pre-norm architecture) + x = self.layer_norm(x) + + # Compute Q, K, V + queries = self.query_projection(x) + keys = self.key_projection(x) + values = self.value_projection(x) + + # Reshape for multi-head attention + queries = queries.view( + batch_size, seq_len, self.num_heads, self.head_dim + ).transpose(1, 2) + keys = keys.view( + batch_size, seq_len, self.num_heads, self.head_dim + ).transpose(1, 2) + values = values.view( + batch_size, seq_len, self.num_heads, self.head_dim + ).transpose(1, 2) + + # Compute attention scores + attention_scores = ( + torch.matmul(queries, keys.transpose(-2, -1)) * self.scale + ) + + # Apply mask if provided + if mask is not None: + # Expand mask for multi-head attention + mask = mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) + attention_scores.masked_fill_(mask == 0, float("-inf")) + + # Apply softmax + attention_weights = F.softmax(attention_scores, dim=-1) + attention_weights = self.attention_dropout(attention_weights) + + # Apply attention to values + attention_output = torch.matmul(attention_weights, values) + + # Reshape back + attention_output = ( + attention_output.transpose(1, 2) + .contiguous() + .view(batch_size, seq_len, self.hidden_dim) + ) + + # Output projection + output = self.output_projection(attention_output) + output = self.output_dropout(output) + + # Residual connection + output_tensor: torch.Tensor = output + residual + return output_tensor + + +class PositionalEncoding(nn.Module): + """Positional encoding for sequence models. + + Adds positional information to embeddings, useful for processing + sequences like game history or card sequences. 
+ """ + + def __init__( + self, + hidden_dim: int, + max_seq_len: int = 5000, + dropout_rate: float = 0.1, + ): + """Initialize positional encoding. + + Args: + hidden_dim: Hidden dimension + max_seq_len: Maximum sequence length + dropout_rate: Dropout rate + """ + super().__init__() + + self.dropout = nn.Dropout(dropout_rate) + + # Create positional encoding matrix + pe = torch.zeros(max_seq_len, hidden_dim) + position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) + + div_term = torch.exp( + torch.arange(0, hidden_dim, 2).float() + * (-math.log(10000.0) / hidden_dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + self.register_buffer("pe", pe.unsqueeze(0)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Add positional encoding to input. + + Args: + x: Input tensor [batch_size, seq_len, hidden_dim] + + Returns: + Input with positional encoding added + """ + pe_buffer = self.pe + assert isinstance( + pe_buffer, torch.Tensor + ), "PE buffer should be a tensor" + pe_slice: torch.Tensor = pe_buffer[:, : x.size(1)] + x = x + pe_slice + result: torch.Tensor = self.dropout(x) + return result + + +class CardEmbedding(nn.Module): + """Specialized embedding layer for Magic: The Gathering cards. + + This embedding layer combines multiple aspects of a card: + - Card identity (name/ID) + - Mana cost + - Card type + - Power/toughness (for creatures) + """ + + def __init__( + self, + vocab_size: int, + embed_dim: int, + max_mana_cost: int = 20, + num_card_types: int = 100, + max_power_toughness: int = 20, + ): + """Initialize card embedding. 
+ + Args: + vocab_size: Size of card vocabulary + embed_dim: Embedding dimension + max_mana_cost: Maximum mana cost to handle + num_card_types: Number of different card types + max_power_toughness: Maximum power/toughness value + """ + super().__init__() + + self.embed_dim = embed_dim + + # Main card embedding + self.card_embedding = nn.Embedding(vocab_size, embed_dim // 2) + + # Mana cost embedding + self.mana_embedding = nn.Embedding(max_mana_cost + 1, embed_dim // 8) + + # Card type embedding + self.type_embedding = nn.Embedding(num_card_types, embed_dim // 8) + + # Power/toughness embedding (for creatures) + self.power_embedding = nn.Embedding( + max_power_toughness + 1, embed_dim // 16 + ) + self.toughness_embedding = nn.Embedding( + max_power_toughness + 1, embed_dim // 16 + ) + + # Combine embeddings + self.combiner = nn.Linear( + embed_dim // 2 + + embed_dim // 8 + + embed_dim // 8 + + embed_dim // 16 + + embed_dim // 16, + embed_dim, + ) + + def forward( + self, + card_ids: torch.Tensor, + mana_costs: torch.Tensor, + card_types: torch.Tensor, + powers: torch.Tensor, + toughnesses: torch.Tensor, + ) -> torch.Tensor: + """Forward pass through card embedding. 
+ + Args: + card_ids: Card ID tensor [batch_size, num_cards] + mana_costs: Mana cost tensor [batch_size, num_cards] + card_types: Card type tensor [batch_size, num_cards] + powers: Power tensor [batch_size, num_cards] + toughnesses: Toughness tensor [batch_size, num_cards] + + Returns: + Combined card embeddings [batch_size, num_cards, embed_dim] + """ + # Get individual embeddings + card_emb = self.card_embedding(card_ids) + mana_emb = self.mana_embedding(mana_costs) + type_emb = self.type_embedding(card_types) + power_emb = self.power_embedding(powers) + tough_emb = self.toughness_embedding(toughnesses) + + # Concatenate and combine + combined = torch.cat( + [card_emb, mana_emb, type_emb, power_emb, tough_emb], dim=-1 + ) + result: torch.Tensor = self.combiner(combined) + return result + + +class GatingMechanism(nn.Module): + """Gating mechanism for controlling information flow. + + This can be useful for deciding which parts of the game state + are most relevant for the current decision. + """ + + def __init__(self, hidden_dim: int): + """Initialize gating mechanism. + + Args: + hidden_dim: Hidden dimension + """ + super().__init__() + + self.gate_projection = nn.Linear(hidden_dim, hidden_dim) + self.content_projection = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply gating mechanism. + + Args: + x: Input tensor [batch_size, hidden_dim] + + Returns: + Gated output tensor [batch_size, hidden_dim] + """ + gate = torch.sigmoid(self.gate_projection(x)) + content = self.content_projection(x) + result: torch.Tensor = gate * content + return result diff --git a/src/manamind/models/enhanced_encoder.py b/src/manamind/models/enhanced_encoder.py new file mode 100644 index 0000000..8331a22 --- /dev/null +++ b/src/manamind/models/enhanced_encoder.py @@ -0,0 +1,577 @@ +"""Enhanced neural network encoder for comprehensive MTG game states. 
+ +This module provides advanced encoding capabilities for the full complexity of +Magic: The Gathering, including multi-modal encoding, attention mechanisms, +and optimized representations. +""" + +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch +import torch.nn as nn + +from manamind.core.game_state import Card, GameState, Player, Zone + + +@dataclass +class EncoderConfig: + """Configuration for the enhanced game state encoder.""" + + # Card vocabulary and embeddings + card_vocab_size: int = 50000 + embed_dim: int = 512 + + # Architecture dimensions + hidden_dim: int = 1024 + output_dim: int = 2048 + + # Attention settings + num_heads: int = 8 + num_layers: int = 4 + + # Zone settings + max_cards_per_zone: int = 200 + num_zones: int = 6 + + # Optimization + dropout: float = 0.1 + use_attention: bool = True + use_layer_norm: bool = True + + +class CardEmbeddingSystem(nn.Module): + """Advanced card embedding with structural and semantic features.""" + + def __init__(self, vocab_size: int, embed_dim: int): + super().__init__() + self.vocab_size = vocab_size + self.embed_dim = embed_dim + + # Core embeddings + self.card_embedding = nn.Embedding(vocab_size, embed_dim // 4) + self.type_embedding = nn.Embedding(100, embed_dim // 8) # Card types + self.cost_embedding = nn.Embedding(50, embed_dim // 8) # Mana costs + + # Structural feature encoding + self.power_embedding = nn.Linear(1, embed_dim // 16) + self.toughness_embedding = nn.Linear(1, embed_dim // 16) + self.loyalty_embedding = nn.Linear(1, embed_dim // 16) + + # State encoding + self.state_encoder = nn.Linear( + 10, embed_dim // 8 + ) # Tapped, counters, etc. 
+ + # Final projection + self.projector = nn.Linear(embed_dim, embed_dim) + self.layer_norm = nn.LayerNorm(embed_dim) + + def forward(self, card: Card) -> torch.Tensor: + """Encode a single card into a dense representation.""" + features = [] + + # Card ID embedding + card_id = card.card_id or 0 + card_emb = self.card_embedding(torch.tensor(card_id, dtype=torch.long)) + features.append(card_emb) + + # Card type embedding (simplified) + type_id = ( + hash(" ".join(card.card_types)) % 100 if card.card_types else 0 + ) + type_emb = self.type_embedding(torch.tensor(type_id, dtype=torch.long)) + features.append(type_emb) + + # Mana cost embedding + cmc_id = min(card.converted_mana_cost, 49) + cost_emb = self.cost_embedding(torch.tensor(cmc_id, dtype=torch.long)) + features.append(cost_emb) + + # Power/Toughness/Loyalty + if hasattr(card, "current_power") and card.current_power() is not None: + current_power = card.current_power() + power_value = ( + float(current_power) if current_power is not None else 0.0 + ) + power_emb = self.power_embedding(torch.tensor([power_value])) + features.append(power_emb) + else: + features.append(torch.zeros(self.embed_dim // 16)) + + if ( + hasattr(card, "current_toughness") + and card.current_toughness() is not None + ): + current_toughness = card.current_toughness() + toughness_value = ( + float(current_toughness) + if current_toughness is not None + else 0.0 + ) + toughness_emb = self.toughness_embedding( + torch.tensor([toughness_value]) + ) + features.append(toughness_emb) + else: + features.append(torch.zeros(self.embed_dim // 16)) + + if hasattr(card, "loyalty") and card.loyalty is not None: + loyalty_emb = self.loyalty_embedding( + torch.tensor([float(card.loyalty)]) + ) + features.append(loyalty_emb) + else: + features.append(torch.zeros(self.embed_dim // 16)) + + # State features + state_features = torch.tensor( + [ + float(getattr(card, "tapped", False)), + float(getattr(card, "summoning_sick", False)), + float(getattr(card, 
"attacking", False)), + float(len(getattr(card, "counters", {}))), + float(card.controller if hasattr(card, "controller") else 0), + # Pad to 10 features + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ][:10] + ) + + state_emb = self.state_encoder(state_features) + features.append(state_emb) + + # Combine all features + combined = torch.cat(features, dim=-1) + + # Project to final dimension and normalize + output = self.projector(combined) + result: torch.Tensor = self.layer_norm(output) + return result + + +class ZoneEncoder(nn.Module): + """Base class for encoding different types of zones.""" + + def __init__(self, config: EncoderConfig, zone_type: str): + super().__init__() + self.config = config + self.zone_type = zone_type + self.card_embedder = CardEmbeddingSystem( + config.card_vocab_size, config.embed_dim + ) + + def forward(self, zone: Zone, player_id: int) -> torch.Tensor: + """Encode a zone into a fixed-size representation.""" + if not zone.cards: + return torch.zeros(self.config.hidden_dim) + + # Encode all cards in the zone + card_embeddings = [] + for card in zone.cards[: self.config.max_cards_per_zone]: + card_emb = self.card_embedder(card) + card_embeddings.append(card_emb) + + if not card_embeddings: + return torch.zeros(self.config.hidden_dim) + + # Stack embeddings + embeddings_tensor = torch.stack(card_embeddings) + + # Zone-specific aggregation + return self._aggregate_embeddings(embeddings_tensor, zone, player_id) + + def _aggregate_embeddings( + self, embeddings: torch.Tensor, zone: Zone, player_id: int + ) -> torch.Tensor: + """Override in subclasses for zone-specific aggregation.""" + return embeddings.mean(dim=0) + + +class HandEncoder(ZoneEncoder): + """Specialized encoder for hand zone with hidden information modeling.""" + + def __init__(self, config: EncoderConfig): + super().__init__(config, "hand") + self.attention = nn.MultiheadAttention( + config.embed_dim, config.num_heads + ) + self.output_proj = nn.Linear(config.embed_dim, 
config.hidden_dim) + + def _aggregate_embeddings( + self, embeddings: torch.Tensor, zone: Zone, player_id: int + ) -> torch.Tensor: + """Use attention to weight hand cards by importance.""" + if len(embeddings.shape) == 1: + embeddings = embeddings.unsqueeze(0) + + # Self-attention over hand cards + attn_output, _ = self.attention(embeddings, embeddings, embeddings) + + # Aggregate with weighted average + hand_encoding = attn_output.mean(dim=0) + result: torch.Tensor = self.output_proj(hand_encoding) + return result + + +class BattlefieldEncoder(ZoneEncoder): + """Specialized encoder for battlefield with creature interactions.""" + + def __init__(self, config: EncoderConfig): + super().__init__(config, "battlefield") + self.transformer = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + d_model=config.embed_dim, + nhead=config.num_heads, + dim_feedforward=config.hidden_dim, + dropout=config.dropout, + ), + num_layers=2, + ) + self.output_proj = nn.Linear(config.embed_dim, config.hidden_dim) + + def _aggregate_embeddings( + self, embeddings: torch.Tensor, zone: Zone, player_id: int + ) -> torch.Tensor: + """Model battlefield interactions with transformer.""" + if len(embeddings.shape) == 1: + embeddings = embeddings.unsqueeze(0) + + # Add positional encoding for battlefield position + seq_len = embeddings.shape[0] + pos_encoding = ( + torch.arange(seq_len, dtype=torch.float).unsqueeze(1) / 100.0 + ) + pos_encoding = pos_encoding.expand(-1, embeddings.shape[1]) + embeddings = embeddings + pos_encoding + + # Apply transformer to model interactions + battlefield_encoding = self.transformer(embeddings) + + # Aggregate battlefield state + aggregated = battlefield_encoding.mean(dim=0) + result: torch.Tensor = self.output_proj(aggregated) + return result + + +class SequentialZoneEncoder(ZoneEncoder): + """Encoder for zones where order matters (graveyard, library, exile).""" + + def __init__(self, config: EncoderConfig): + super().__init__(config, "sequential") + 
self.lstm = nn.LSTM( + input_size=config.embed_dim, + hidden_size=config.hidden_dim // 2, + num_layers=2, + batch_first=True, + bidirectional=True, + dropout=config.dropout, + ) + + def _aggregate_embeddings( + self, embeddings: torch.Tensor, zone: Zone, player_id: int + ) -> torch.Tensor: + """Use LSTM to preserve order information.""" + if len(embeddings.shape) == 1: + embeddings = embeddings.unsqueeze(0) + + # Add batch dimension + embeddings = embeddings.unsqueeze(0) + + # LSTM encoding + lstm_out, (hidden, _) = self.lstm(embeddings) + + # Use final hidden state as zone representation + result: torch.Tensor = hidden.view(-1) + return result + + +class PlayerStateEncoder(nn.Module): + """Encode individual player state (life, mana, etc.).""" + + def __init__(self, config: EncoderConfig): + super().__init__() + self.config = config + + # Player feature encoding + self.life_encoder = nn.Linear(1, 32) + self.mana_encoder = nn.Linear(6, 64) # WUBRG + colorless + self.misc_encoder = nn.Linear(10, 64) # Other features + + # Final projection + self.output_proj = nn.Sequential( + nn.Linear(32 + 64 + 64, config.hidden_dim), + nn.ReLU(), + nn.Dropout(config.dropout), + nn.Linear(config.hidden_dim, config.hidden_dim), + ) + + def forward(self, player: Player, player_id: int) -> torch.Tensor: + """Encode player state features.""" + # Life encoding (normalized) + life_feature = torch.tensor([float(player.life) / 20.0]) + life_emb = self.life_encoder(life_feature) + + # Mana pool encoding + mana_colors = ["W", "U", "B", "R", "G", "C"] + mana_features = torch.tensor( + [ + float(player.mana_pool.get(color, 0)) / 10.0 + for color in mana_colors + ] + ) + mana_emb = self.mana_encoder(mana_features) + + # Miscellaneous features + misc_features = torch.tensor( + [ + float(player.lands_played_this_turn), + float(player.hand.size()) / 10.0, + float(player.battlefield.size()) / 20.0, + float(player.graveyard.size()) / 50.0, + float(player.library.size()) / 60.0, + 
float(player.exile.size()) / 20.0, + float(player_id), # Player identity + 0.0, + 0.0, + 0.0, # Reserved for future features + ] + ) + misc_emb = self.misc_encoder(misc_features) + + # Combine all features + combined = torch.cat([life_emb, mana_emb, misc_emb], dim=0) + result: torch.Tensor = self.output_proj(combined) + return result + + +class GlobalStateEncoder(nn.Module): + """Encode global game state (turn, phase, stack, etc.).""" + + def __init__(self, config: EncoderConfig): + super().__init__() + + # Phase/step encoding + self.phase_embedding = nn.Embedding(10, 64) + + # Turn and priority encoding + self.turn_encoder = nn.Linear(4, 64) + + # Stack encoding + self.stack_encoder = nn.Linear(5, 64) + + # Final projection + self.output_proj = nn.Sequential( + nn.Linear(64 + 64 + 64, config.hidden_dim), + nn.ReLU(), + nn.Dropout(config.dropout), + ) + + def forward(self, game_state: GameState) -> torch.Tensor: + """Encode global game state.""" + # Phase encoding + phases = ["untap", "upkeep", "draw", "main", "combat", "main2", "end"] + phase_id = ( + phases.index(game_state.phase) if game_state.phase in phases else 0 + ) + phase_emb = self.phase_embedding( + torch.tensor(phase_id, dtype=torch.long) + ) + + # Turn and priority features + turn_features = torch.tensor( + [ + float(game_state.turn_number) / 20.0, + float(game_state.active_player), + float(game_state.priority_player), + ( + 1.0 + if game_state.active_player == game_state.priority_player + else 0.0 + ), + ] + ) + turn_emb = self.turn_encoder(turn_features) + + # Stack features + stack_features = torch.tensor( + [ + float(len(game_state.stack)) / 10.0, + 0.0, + 0.0, + 0.0, + 0.0, # Reserved for stack content analysis + ] + ) + stack_emb = self.stack_encoder(stack_features) + + # Combine features + combined = torch.cat([phase_emb, turn_emb, stack_emb], dim=0) + result: torch.Tensor = self.output_proj(combined) + return result + + +class StateFusionNetwork(nn.Module): + """Fuse all encoded components into 
final game state representation.""" + + def __init__(self, config: EncoderConfig): + super().__init__() + self.config = config + + # Attention fusion + if config.use_attention: + self.cross_attention = nn.MultiheadAttention( + config.hidden_dim, config.num_heads + ) + + # Final fusion layers + fusion_input_dim = config.hidden_dim * ( + 2 * 6 + 2 + 1 + ) # zones + players + global + + self.fusion_network = nn.Sequential( + nn.Linear(fusion_input_dim, config.hidden_dim * 2), + nn.ReLU(), + nn.Dropout(config.dropout), + nn.Linear(config.hidden_dim * 2, config.hidden_dim), + nn.ReLU(), + nn.Dropout(config.dropout), + nn.Linear(config.hidden_dim, config.output_dim), + ) + + if config.use_layer_norm: + self.layer_norm: Optional[nn.LayerNorm] = nn.LayerNorm( + config.output_dim + ) + else: + self.layer_norm = None + + def forward( + self, + zone_encodings: Dict[int, Dict[str, torch.Tensor]], + player_encodings: List[torch.Tensor], + global_encoding: torch.Tensor, + stack_encoding: Optional[torch.Tensor] = None, + combat_encoding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Fuse all encodings into final representation.""" + + # Collect all encodings + all_encodings = [] + + # Zone encodings for both players + zone_names = [ + "hand", + "battlefield", + "graveyard", + "library", + "exile", + "command_zone", + ] + for player_id in [0, 1]: + for zone_name in zone_names: + if ( + player_id in zone_encodings + and zone_name in zone_encodings[player_id] + ): + all_encodings.append(zone_encodings[player_id][zone_name]) + else: + # Add zero encoding for missing zones + all_encodings.append(torch.zeros(self.config.hidden_dim)) + + # Player encodings + all_encodings.extend(player_encodings) + + # Global encoding + all_encodings.append(global_encoding) + + # Stack encoding (if provided) + if stack_encoding is not None: + all_encodings.append(stack_encoding) + + # Ensure we have the expected number of encodings + expected_count = 2 * 6 + 2 + 1 + 1 # zones + players + 
global + stack + while len(all_encodings) < expected_count: + all_encodings.append(torch.zeros(self.config.hidden_dim)) + + # Concatenate all encodings + fused_representation = torch.cat(all_encodings[:expected_count], dim=0) + + # Apply fusion network + output = self.fusion_network(fused_representation) + + # Apply layer norm if configured + if self.layer_norm is not None: + output = self.layer_norm(output) + + result: torch.Tensor = output + return result + + +class EnhancedGameStateEncoder(nn.Module): + """Complete enhanced game state encoder integrating all components.""" + + def __init__(self, config: EncoderConfig): + super().__init__() + self.config = config + + # Component encoders + self.zone_encoders = nn.ModuleDict( + { + "hand": HandEncoder(config), + "battlefield": BattlefieldEncoder(config), + "graveyard": SequentialZoneEncoder(config), + "library": SequentialZoneEncoder(config), + "exile": SequentialZoneEncoder(config), + "command_zone": SequentialZoneEncoder(config), + } + ) + + self.player_encoder = PlayerStateEncoder(config) + self.global_encoder = GlobalStateEncoder(config) + self.state_fusion = StateFusionNetwork(config) + + def forward(self, game_state: GameState) -> torch.Tensor: + """Encode complete game state into fixed-size tensor.""" + # Encode zones for both players + zone_encodings = {} + for player_id, player in enumerate(game_state.players): + player_zones = {} + for zone_name in [ + "hand", + "battlefield", + "graveyard", + "library", + "exile", + "command_zone", + ]: + zone = getattr(player, zone_name) + encoder = self.zone_encoders[zone_name] + player_zones[zone_name] = encoder(zone, player_id) + zone_encodings[player_id] = player_zones + + # Encode players + player_encodings = [ + self.player_encoder(player, player_id) + for player_id, player in enumerate(game_state.players) + ] + + # Encode global state + global_encoding = self.global_encoder(game_state) + + # Fuse all components + result: torch.Tensor = self.state_fusion( + 
zone_encodings, player_encodings, global_encoding + ) + return result + + def encode_batch(self, game_states: List[GameState]) -> torch.Tensor: + """Encode multiple game states in batch.""" + encodings = [] + for state in game_states: + encoding = self.forward(state) + encodings.append(encoding) + return torch.stack(encodings) diff --git a/src/manamind/models/policy_value_network.py b/src/manamind/models/policy_value_network.py new file mode 100644 index 0000000..9dd7a10 --- /dev/null +++ b/src/manamind/models/policy_value_network.py @@ -0,0 +1,319 @@ +"""Combined policy-value network for ManaMind agent. + +This module implements the core neural network architecture that combines +both policy (action prediction) and value (position evaluation) estimation +in a single network, similar to AlphaZero. +""" + +from typing import Any, Dict, List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from manamind.core.game_state import GameStateEncoder +from manamind.models.components import AttentionLayer, ResidualBlock + + +class PolicyValueNetwork(nn.Module): + """Combined policy-value network for Magic: The Gathering. + + This network takes a game state as input and outputs: + 1. Policy: Probability distribution over possible actions + 2. Value: Estimated probability of winning from this position + + Architecture is inspired by AlphaZero but adapted for MTG's complexity. + """ + + def __init__( + self, + state_dim: int = 2048, # From GameStateEncoder output + hidden_dim: int = 1024, + num_residual_blocks: int = 8, + num_attention_heads: int = 8, + action_space_size: int = 10000, # Maximum number of possible actions + dropout_rate: float = 0.1, + use_attention: bool = True, + ): + """Initialize the policy-value network. 
+ + Args: + state_dim: Dimension of encoded game state + hidden_dim: Hidden dimension for residual blocks + num_residual_blocks: Number of residual blocks in the backbone + num_attention_heads: Number of attention heads (if using attention) + action_space_size: Size of the action space + dropout_rate: Dropout rate for regularization + use_attention: Whether to use attention mechanisms + """ + super().__init__() + + self.state_dim = state_dim + self.hidden_dim = hidden_dim + self.action_space_size = action_space_size + self.use_attention = use_attention + + # Game state encoder + self.state_encoder = GameStateEncoder(output_dim=state_dim) + + # Input projection + self.input_projection = nn.Linear(state_dim, hidden_dim) + + # Backbone network - stack of residual blocks + self.backbone = nn.ModuleList( + [ + ResidualBlock(hidden_dim, dropout_rate) + for _ in range(num_residual_blocks) + ] + ) + + # Optional attention layer + if use_attention: + self.attention = AttentionLayer( + hidden_dim, num_attention_heads, dropout_rate + ) + + # Policy head - predicts action probabilities + self.policy_head = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim // 2), + nn.ReLU(), + nn.Dropout(dropout_rate), + nn.Linear(hidden_dim // 2, action_space_size), + ) + + # Value head - predicts win probability + self.value_head = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim // 4), + nn.ReLU(), + nn.Dropout(dropout_rate), + nn.Linear(hidden_dim // 4, 64), + nn.ReLU(), + nn.Linear(64, 1), + nn.Tanh(), # Output in [-1, 1] range + ) + + # Initialize weights + self._initialize_weights() + + def _initialize_weights(self) -> None: + """Initialize network weights using He initialization.""" + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.kaiming_normal_( + module.weight, mode="fan_out", nonlinearity="relu" + ) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.LayerNorm): + nn.init.constant_(module.bias, 0) + 
nn.init.constant_(module.weight, 1.0) + + def forward(self, game_state: Any) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass through the network. + + Args: + game_state: Either a GameState object or pre-encoded tensor + + Returns: + Tuple of (policy_logits, value): + - policy_logits: Raw logits for action probabilities + [batch_size, action_space_size] + - value: Estimated win probability [-1, 1] [batch_size, 1] + """ + # Encode game state if needed + if hasattr(game_state, "players"): # GameState object + x = self.state_encoder(game_state) + else: # Already encoded tensor + x = game_state + + # Handle batch dimension + if x.dim() == 1: + x = x.unsqueeze(0) + + # Input projection + x = self.input_projection(x) + x = F.relu(x) + + # Pass through residual blocks + for block in self.backbone: + x = block(x) + + # Optional attention + if self.use_attention: + # For attention, we need sequence dimension + # Reshape to [batch, seq_len, hidden_dim] if needed + if x.dim() == 2: + x = x.unsqueeze(1) # Add sequence dimension + x = self.attention(x) + x = x.squeeze(1) # Remove sequence dimension + + # Policy and value heads + policy_logits = self.policy_head(x) + value = self.value_head(x) + + return policy_logits, value + + def predict_action_probs( + self, game_state: Any, temperature: float = 1.0 + ) -> torch.Tensor: + """Get action probabilities from the policy head. 
+ + Args: + game_state: Game state to evaluate + temperature: Temperature for softmax (higher = more exploration) + + Returns: + Action probabilities [batch_size, action_space_size] + """ + policy_logits, _ = self.forward(game_state) + + if temperature > 0: + probs = F.softmax(policy_logits / temperature, dim=-1) + else: + # Deterministic - pick highest probability action + probs = torch.zeros_like(policy_logits) + probs.scatter_(-1, policy_logits.argmax(dim=-1, keepdim=True), 1.0) + + return probs + + def evaluate_position(self, game_state: Any) -> torch.Tensor: + """Get position evaluation from the value head. + + Args: + game_state: Game state to evaluate + + Returns: + Position value [-1, 1] [batch_size, 1] + """ + _, value = self.forward(game_state) + return value + + def get_action_value_pairs( + self, game_state: Any, legal_actions: Any + ) -> List[Tuple[Any, float]]: + """Get (action, value) pairs for all legal actions. + + This is useful for MCTS to get both policy priors and value estimates. + + Args: + game_state: Current game state + legal_actions: List of legal Action objects + + Returns: + List of (action, prior_prob, value) tuples + """ + policy_logits, value = self.forward(game_state) + F.softmax(policy_logits, dim=-1) + + # TODO: Map legal_actions to network output indices + # This requires the ActionSpace to provide action->index mapping + + action_values = [] + for action in legal_actions: + # Placeholder - need to implement action encoding + prior_prob = 1.0 / len(legal_actions) # Uniform for now + action_values.append((action, prior_prob)) + + return action_values + + +class PolicyValueLoss(nn.Module): + """Loss function for training the policy-value network. + + Combines: + 1. Cross-entropy loss for policy (action prediction) + 2. Mean squared error for value (outcome prediction) + 3. L2 regularization for weights + """ + + def __init__(self, value_weight: float = 1.0, l2_reg: float = 1e-4): + """Initialize the loss function. 
+ + Args: + value_weight: Weight for value loss relative to policy loss + l2_reg: L2 regularization coefficient + """ + super().__init__() + self.value_weight = value_weight + self.l2_reg = l2_reg + + def forward( + self, + policy_logits: torch.Tensor, + value_pred: torch.Tensor, + target_policy: torch.Tensor, + target_value: torch.Tensor, + model: nn.Module, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Compute the combined loss. + + Args: + policy_logits: Predicted policy logits + [batch_size, action_space_size] + value_pred: Predicted values [batch_size, 1] + target_policy: Target policy distribution + [batch_size, action_space_size] + target_value: Target values [batch_size, 1] + model: The model (for L2 regularization) + + Returns: + Tuple of (total_loss, loss_dict) where loss_dict contains + individual components + """ + # Policy loss - cross entropy between predicted and target + policy_loss = -torch.sum( + target_policy * F.log_softmax(policy_logits, dim=-1), dim=-1 + ) + policy_loss = policy_loss.mean() + + # Value loss - MSE between predicted and target values + value_loss = F.mse_loss(value_pred.squeeze(), target_value.squeeze()) + + # L2 regularization + l2_loss: torch.Tensor = torch.tensor(0.0) + for param in model.parameters(): + l2_loss += torch.sum(param**2) + l2_loss = l2_loss * self.l2_reg + + # Combined loss + total_loss = policy_loss + self.value_weight * value_loss + l2_loss + + loss_dict = { + "total_loss": total_loss.item(), + "policy_loss": policy_loss.item(), + "value_loss": value_loss.item(), + "l2_loss": float(l2_loss), + } + + return total_loss, loss_dict + + +def create_policy_value_network(**kwargs: Any) -> PolicyValueNetwork: + """Factory function to create a policy-value network with default settings. 
+ + Args: + **kwargs: Keyword arguments to override defaults + + Returns: + Initialized PolicyValueNetwork + """ + defaults = { + "state_dim": 2048, + "hidden_dim": 1024, + "num_residual_blocks": 8, + "num_attention_heads": 8, + "action_space_size": 10000, + "dropout_rate": 0.1, + "use_attention": True, + } + defaults.update(kwargs) + + return PolicyValueNetwork( + state_dim=int(defaults["state_dim"]), + hidden_dim=int(defaults["hidden_dim"]), + num_residual_blocks=int(defaults["num_residual_blocks"]), + num_attention_heads=int(defaults["num_attention_heads"]), + action_space_size=int(defaults["action_space_size"]), + dropout_rate=float(defaults["dropout_rate"]), + use_attention=bool(defaults["use_attention"]), + ) diff --git a/src/manamind/training/__init__.py b/src/manamind/training/__init__.py new file mode 100644 index 0000000..42135e6 --- /dev/null +++ b/src/manamind/training/__init__.py @@ -0,0 +1,18 @@ +"""Training infrastructure for ManaMind AI agent. + +This module contains: +- Self-play training loops +- Neural network training +- Distributed training support +- Training data management +""" + +# from manamind.training.data_manager import TrainingDataManager +# from manamind.training.neural_trainer import NeuralNetworkTrainer +from manamind.training.self_play import SelfPlayTrainer + +__all__ = [ + "SelfPlayTrainer", + # "NeuralNetworkTrainer", + # "TrainingDataManager", +] diff --git a/src/manamind/training/data_manager.py b/src/manamind/training/data_manager.py new file mode 100644 index 0000000..e69de29 diff --git a/src/manamind/training/neural_trainer.py b/src/manamind/training/neural_trainer.py new file mode 100644 index 0000000..e69de29 diff --git a/src/manamind/training/self_play.py b/src/manamind/training/self_play.py new file mode 100644 index 0000000..cc8081e --- /dev/null +++ b/src/manamind/training/self_play.py @@ -0,0 +1,447 @@ +"""Self-play training implementation for ManaMind. 
+ +This module implements the core self-play training loop where the agent +learns by playing millions of games against itself, similar to AlphaZero. +""" + +import logging +import random +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from numpy import ndarray +from tqdm import tqdm + +from manamind.core.action import Action +from manamind.core.agent import MCTSAgent +from manamind.core.game_state import GameState, create_standard_game_start +from manamind.forge_interface import ( # ForgeGameRunner not implemented yet + ForgeClient, +) +from manamind.models.policy_value_network import PolicyValueNetwork + +# from manamind.training.data_manager import TrainingDataManager + +logger = logging.getLogger(__name__) + + +class SelfPlayGame: + """Represents a single self-play game and its training data.""" + + def __init__(self, game_id: str): + self.game_id = game_id + self.history: List[ + Tuple[GameState, Action, float, ndarray[Any, Any]] + ] = [] + self.winner: Optional[int] = None + self.num_moves = 0 + self.start_time = time.time() + self.end_time: Optional[float] = None + + def add_move( + self, state: GameState, action: Action, mcts_policy: ndarray[Any, Any] + ) -> None: + """Add a move to the game history. + + Args: + state: Game state before the move + action: Action taken + mcts_policy: MCTS action probabilities for training + """ + # Store with temporary reward (will be updated at game end) + self.history.append((state, action, 0.0, mcts_policy)) + self.num_moves += 1 + + def finalize_game(self, winner: Optional[int]) -> None: + """Finalize the game and assign rewards. 
+ + Args: + winner: Winning player ID (0 or 1), or None for draw + """ + self.winner = winner + self.end_time = time.time() + + # Update rewards based on game outcome + for i, (state, action, _, policy) in enumerate(self.history): + # Determine reward from this player's perspective + player_id = action.player_id + + if winner == player_id: + reward = 1.0 + elif winner is None: + reward = 0.0 # Draw + else: + reward = -1.0 + + # Update history with final reward + self.history[i] = (state, action, reward, policy) + + def duration(self) -> float: + """Get game duration in seconds.""" + end = self.end_time or time.time() + return end - self.start_time + + def get_training_examples( + self, + ) -> List[Tuple[GameState, ndarray[Any, Any], float]]: + """Extract training examples from this game. + + Returns: + List of (state, mcts_policy, reward) tuples for training + """ + examples = [] + for state, action, reward, mcts_policy in self.history: + examples.append((state, mcts_policy, reward)) + return examples + + +class SelfPlayTrainer: + """Self-play trainer for the ManaMind agent. + + This class manages the self-play training process: + 1. Generates self-play games using current model + 2. Collects training data from games + 3. Updates the neural network + 4. Iterates to improve performance + """ + + def __init__( + self, + policy_value_network: PolicyValueNetwork, + forge_client: Optional[ForgeClient] = None, + data_manager: Optional[ + Any + ] = None, # TrainingDataManager not implemented yet + config: Optional[Dict[str, Any]] = None, + ): + """Initialize self-play trainer. 
+ + Args: + policy_value_network: The network to train + forge_client: Forge client for running games + data_manager: Training data manager + config: Training configuration + """ + self.network = policy_value_network + self.forge_client = forge_client + # self.data_manager = data_manager or TrainingDataManager() + self.data_manager = data_manager # Placeholder until implemented + + # Training configuration + self.config = config or self._default_config() + + # Training state + self.current_iteration = 0 + self.total_games_played = 0 + self.training_examples: List[ + Tuple[GameState, ndarray[Any, Any], float] + ] = [] + self.performance_history: List[Dict[str, Any]] = [] + + # Create MCTS agents for self-play + self.mcts_config = { + "simulations": self.config["mcts_simulations"], + "simulation_time": self.config["mcts_time_limit"], + "c_puct": self.config["c_puct"], + } + + def _default_config(self) -> Dict[str, Any]: + """Default training configuration.""" + return { + # Self-play parameters + "games_per_iteration": 100, + "max_game_length": 200, + "mcts_simulations": 800, + "mcts_time_limit": 1.0, + "c_puct": 1.0, + # Training parameters + "training_iterations": 1000, + "examples_buffer_size": 100000, + "batch_size": 64, + "epochs_per_iteration": 10, + "learning_rate": 0.001, + "weight_decay": 1e-4, + # Evaluation parameters + "evaluation_frequency": 10, + "evaluation_games": 50, + # Checkpointing + "checkpoint_frequency": 10, + "checkpoint_dir": "checkpoints", + } + + def train(self, num_iterations: Optional[int] = None) -> None: + """Run the main training loop. 
+ + Args: + num_iterations: Number of training iterations + (uses config default if None) + """ + num_iterations = num_iterations or self.config["training_iterations"] + + logger.info( + f"Starting self-play training for {num_iterations} iterations" + ) + + for iteration in range(num_iterations): + self.current_iteration = iteration + + logger.info( + f"=== Training Iteration {iteration + 1}/{num_iterations} ===" + ) + + # Phase 1: Generate self-play games + logger.info("Generating self-play games...") + new_examples = self._generate_self_play_games() + + # Phase 2: Update training data + self.training_examples.extend(new_examples) + self._maintain_examples_buffer() + + logger.info(f"Training buffer size: {len(self.training_examples)}") + + # Phase 3: Train neural network + if len(self.training_examples) >= self.config["batch_size"]: + logger.info("Training neural network...") + self._train_network() + + # Phase 4: Evaluation and checkpointing + if (iteration + 1) % self.config["evaluation_frequency"] == 0: + logger.info("Evaluating model performance...") + self._evaluate_model() + + if (iteration + 1) % self.config["checkpoint_frequency"] == 0: + logger.info("Saving checkpoint...") + self._save_checkpoint() + + logger.info("Training completed!") + + def _generate_self_play_games( + self, + ) -> List[Tuple[GameState, ndarray[Any, Any], float]]: + """Generate self-play games and extract training examples.""" + num_games = self.config["games_per_iteration"] + all_examples = [] + games_completed = 0 + + with tqdm(total=num_games, desc="Self-play games") as pbar: + while games_completed < num_games: + try: + # Run a single self-play game + game = self._play_single_game() + + if game and game.winner is not None: + # Extract training examples + examples = game.get_training_examples() + all_examples.extend(examples) + games_completed += 1 + self.total_games_played += 1 + + # Update progress + pbar.set_postfix( + { + "moves": game.num_moves, + "duration": 
f"{game.duration():.1f}s", + "winner": f"P{game.winner}", + } + ) + pbar.update(1) + + except Exception as e: + logger.error(f"Error in self-play game: {e}") + continue + + msg = f"Generated {len(all_examples)} training examples from {games_completed} games" # noqa: E501 + logger.info(msg) + return all_examples + + def _play_single_game(self) -> Optional[SelfPlayGame]: + """Play a single self-play game. + + Returns: + Completed SelfPlayGame or None if game failed + """ + if self.forge_client: + return self._play_forge_game() + else: + return self._play_simulation_game() + + def _play_forge_game(self) -> Optional[SelfPlayGame]: + """Play a game using the Forge engine.""" + try: + # ForgeGameRunner not yet implemented + print( + "Warning: ForgeGameRunner not yet implemented, returning None" + ) + return None + + # TODO: Implement when ForgeGameRunner is ready: + # game_runner = ForgeGameRunner(self.forge_client) + # agent1 = MCTSAgent( + # player_id=0, policy_network=self.network, + # value_network=self.network, **self.mcts_config + # ) + # agent2 = MCTSAgent( + # player_id=1, policy_network=self.network, + # value_network=self.network, **self.mcts_config + # ) + # game_result = game_runner.play_game(agent1, agent2) + # if game_result: + # game = SelfPlayGame(game_result.game_id) + # for state, action, mcts_policy in game_result.history: + # game.add_move(state, action, mcts_policy) + # game.finalize_game(game_result.winner) + # return game + + except Exception as e: + logger.error(f"Error in Forge game: {e}") + + return None + + def _play_simulation_game(self) -> Optional[SelfPlayGame]: + """Play a game using pure Python simulation (testing without Forge).""" + try: + # Create initial game state + game_state = create_standard_game_start() + game = SelfPlayGame("simulation") + + # Create MCTS agents + agents = [ + MCTSAgent(0, self.network, self.network, **self.mcts_config), + MCTSAgent(1, self.network, self.network, **self.mcts_config), + ] + + move_count = 0 + 
max_moves = self.config["max_game_length"] + + while not game_state.is_game_over() and move_count < max_moves: + current_player = game_state.priority_player + agent = agents[current_player] + + # Get action from MCTS + action = agent.select_action(game_state) + + # TODO: Get MCTS policy for training + # For now, use dummy policy + mcts_policy = ( + np.ones(self.config.get("action_space_size", 1000)) / 1000 + ) + + # Record move + game.add_move(game_state.copy(), action, mcts_policy) + + # Execute action + game_state = action.execute(game_state) + move_count += 1 + + # Determine winner + winner = game_state.winner() + game.finalize_game(winner) + + return game + + except Exception as e: + logger.error(f"Error in simulation game: {e}") + return None + + def _maintain_examples_buffer(self) -> None: + """Maintain the training examples buffer at the configured size.""" + buffer_size = self.config["examples_buffer_size"] + + if len(self.training_examples) > buffer_size: + # Remove oldest examples to maintain buffer size + excess = len(self.training_examples) - buffer_size + self.training_examples = self.training_examples[excess:] + + msg = f"Trimmed training buffer to {len(self.training_examples)} examples" # noqa: E501 + logger.info(msg) + + def _train_network(self) -> None: + """Train the neural network on collected examples.""" + if not self.training_examples: + logger.warning("No training examples available") + return + + # TODO: Implement neural network training + # This would involve: + # 1. Creating data loaders from training examples + # 2. Running gradient descent for specified epochs + # 3. Updating the policy-value network + # 4. 
Logging training metrics + + logger.info( + f"Training network on {len(self.training_examples)} examples" + ) + + # Placeholder for actual training implementation + self.config["batch_size"] + epochs = self.config["epochs_per_iteration"] + + # Shuffle training examples + random.shuffle(self.training_examples) + + logger.info(f"Completed {epochs} training epochs") + + def _evaluate_model(self) -> None: + """Evaluate the current model performance.""" + # TODO: Implement model evaluation + # This could involve: + # 1. Playing games against previous model versions + # 2. Playing against Forge AI at different difficulty levels + # 3. Computing win rates and other metrics + # 4. Logging evaluation results + + logger.info("Model evaluation completed") + + def _save_checkpoint(self) -> None: + """Save a training checkpoint.""" + checkpoint_dir = Path(self.config["checkpoint_dir"]) + checkpoint_dir.mkdir(exist_ok=True) + + checkpoint_path = ( + checkpoint_dir + / f"checkpoint_iteration_{self.current_iteration}.pt" + ) + + checkpoint = { + "iteration": self.current_iteration, + "total_games": self.total_games_played, + "model_state_dict": self.network.state_dict(), + "config": self.config, + "performance_history": self.performance_history, + } + + torch.save(checkpoint, checkpoint_path) + logger.info(f"Saved checkpoint to {checkpoint_path}") + + # Also save as latest checkpoint + latest_path = checkpoint_dir / "latest.pt" + torch.save(checkpoint, latest_path) + + def load_checkpoint(self, checkpoint_path: str) -> None: + """Load a training checkpoint. 
+ + Args: + checkpoint_path: Path to checkpoint file + """ + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + self.current_iteration = checkpoint["iteration"] + self.total_games_played = checkpoint["total_games"] + self.network.load_state_dict(checkpoint["model_state_dict"]) + self.performance_history = checkpoint.get("performance_history", []) + + logger.info( + f"Loaded checkpoint from iteration {self.current_iteration}" + ) + + def get_training_stats(self) -> Dict[str, Any]: + """Get current training statistics.""" + return { + "current_iteration": self.current_iteration, + "total_games_played": self.total_games_played, + "training_examples": len(self.training_examples), + "performance_history": self.performance_history, + } diff --git a/src/manamind/utils/__init__.py b/src/manamind/utils/__init__.py new file mode 100644 index 0000000..ded5d18 --- /dev/null +++ b/src/manamind/utils/__init__.py @@ -0,0 +1,7 @@ +"""Utility modules for ManaMind.""" + +from manamind.utils.config import Config + +# from manamind.utils.logging import setup_logging + +__all__ = ["Config"] # , "setup_logging"] diff --git a/src/manamind/utils/config.py b/src/manamind/utils/config.py new file mode 100644 index 0000000..14e8098 --- /dev/null +++ b/src/manamind/utils/config.py @@ -0,0 +1,218 @@ +"""Configuration management for ManaMind.""" + +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import yaml +from pydantic import BaseModel, Field + + +class ModelConfig(BaseModel): + """Model configuration.""" + + state_encoder: Dict[str, Any] = Field(default_factory=dict) + policy_value_network: Dict[str, Any] = Field(default_factory=dict) + + +class TrainingConfig(BaseModel): + """Training configuration.""" + + self_play: Dict[str, Any] = Field(default_factory=dict) + mcts: Dict[str, Any] = Field(default_factory=dict) + neural_training: Dict[str, Any] = Field(default_factory=dict) + training_loop: Dict[str, Any] = 
Field(default_factory=dict) + optimizer: Dict[str, Any] = Field(default_factory=dict) + lr_scheduler: Dict[str, Any] = Field(default_factory=dict) + + +class ForgeConfig(BaseModel): + """Forge integration configuration.""" + + installation_path: Optional[str] = None + java_opts: List[str] = Field(default_factory=lambda: ["-Xmx4G", "-server"]) + port: int = 25333 + timeout: float = 30.0 + use_py4j: bool = True + default_decks: Dict[str, str] = Field(default_factory=dict) + + +class Config(BaseModel): + """Main configuration class for ManaMind.""" + + model: ModelConfig = Field(default_factory=ModelConfig) + training: TrainingConfig = Field(default_factory=TrainingConfig) + evaluation: Dict[str, Any] = Field(default_factory=dict) + forge: ForgeConfig = Field(default_factory=ForgeConfig) + data: Dict[str, Any] = Field(default_factory=dict) + logging: Dict[str, Any] = Field(default_factory=dict) + hardware: Dict[str, Any] = Field(default_factory=dict) + phases: Dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_yaml(cls, path: Union[str, Path]) -> "Config": + """Load configuration from YAML file. + + Args: + path: Path to YAML configuration file + + Returns: + Config instance + """ + with open(path, "r") as f: + data = yaml.safe_load(f) + + return cls(**data) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Config": + """Load configuration from dictionary. + + Args: + data: Configuration dictionary + + Returns: + Config instance + """ + return cls(**data) + + @classmethod + def default(cls) -> "Config": + """Create default configuration. + + Returns: + Config instance with default values + """ + return cls() + + def to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary. + + Returns: + Configuration as dictionary + """ + return self.model_dump() + + def save_yaml(self, path: Union[str, Path]) -> None: + """Save configuration to YAML file. 
+ + Args: + path: Path to save configuration file + """ + with open(path, "w") as f: + yaml.dump(self.to_dict(), f, default_flow_style=False, indent=2) + + def get(self, key: str, default: Any = None) -> Any: + """Get configuration value by dot-notation key. + + Args: + key: Configuration key (e.g., 'training.mcts.simulations') + default: Default value if key not found + + Returns: + Configuration value + """ + try: + keys = key.split(".") + value = self.to_dict() + + for k in keys: + value = value[k] + + return value + except (KeyError, TypeError): + return default + + def update(self, updates: Dict[str, Any]) -> None: + """Update configuration with new values. + + Args: + updates: Dictionary of updates to apply + """ + + def update_nested_dict( + d: Dict[str, Any], u: Dict[str, Any] + ) -> Dict[str, Any]: + for k, v in u.items(): + if isinstance(v, dict): + d[k] = update_nested_dict(d.get(k, {}), v) + else: + d[k] = v + return d + + current_dict = self.to_dict() + updated_dict = update_nested_dict(current_dict, updates) + + # Recreate the config with updated values + new_config = self.__class__(**updated_dict) + + # Update current instance + for field_name, field_value in new_config: + setattr(self, field_name, field_value) + + def override_from_env(self, prefix: str = "MANAMIND_") -> None: + """Override configuration values from environment variables. + + Args: + prefix: Environment variable prefix + """ + env_updates: Dict[str, Any] = {} + + for key, value in os.environ.items(): + if key.startswith(prefix): + # Convert MANAMIND_TRAINING_MCTS_SIMULATIONS -> training.mcts + config_key = key[len(prefix) :].lower().replace("_", ".") + + # Try to convert to appropriate type + converted_value: Any = value + try: + if value.lower() in ("true", "false"): + converted_value = value.lower() == "true" + elif value.isdigit(): + converted_value = int(value) + elif "." 
in value and value.replace(".", "").isdigit(): + converted_value = float(value) + except ValueError: + pass # Keep as string + + # Set nested dictionary value + keys = config_key.split(".") + d = env_updates + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = converted_value + + if env_updates: + self.update(env_updates) + + +def load_config( + config_path: Optional[Union[str, Path]] = None, + overrides: Optional[Dict[str, Any]] = None, + use_env: bool = True, +) -> Config: + """Load configuration with various sources. + + Args: + config_path: Path to configuration file + overrides: Dictionary of configuration overrides + use_env: Whether to use environment variable overrides + + Returns: + Loaded configuration + """ + # Start with default configuration + if config_path: + config = Config.from_yaml(config_path) + else: + config = Config.default() + + # Apply environment overrides + if use_env: + config.override_from_env() + + # Apply explicit overrides + if overrides: + config.update(overrides) + + return config diff --git a/src/manamind/utils/logging.py b/src/manamind/utils/logging.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..67aa59b --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package for ManaMind.""" diff --git a/tests/test_game_state.py b/tests/test_game_state.py new file mode 100644 index 0000000..d6078d8 --- /dev/null +++ b/tests/test_game_state.py @@ -0,0 +1,301 @@ +"""Tests for game state representation and encoding.""" + +import pytest +import torch + +from manamind.core.game_state import ( + Card, + GameStateEncoder, + Player, + Zone, + create_empty_game_state, +) + + +class TestCard: + """Test Card class.""" + + def test_card_creation(self): + """Test basic card creation.""" + card = Card( + name="Lightning Bolt", + mana_cost="R", + converted_mana_cost=1, + card_type="Instant", + text="Lightning Bolt deals 3 damage to any target.", + ) + + 
        assert card.name == "Lightning Bolt"
        assert card.mana_cost == "R"
        assert card.converted_mana_cost == 1
        assert card.card_type == "Instant"

    def test_creature_card(self):
        """Test creature card with power/toughness."""
        card = Card(
            name="Grizzly Bears",
            mana_cost="1G",
            converted_mana_cost=2,
            card_type="Creature — Bear",
            power=2,
            toughness=2,
        )

        assert card.power == 2
        assert card.toughness == 2


class TestZone:
    """Test Zone class."""

    def test_empty_zone(self):
        """Test empty zone creation.

        A freshly constructed Zone should carry its name/owner and
        report zero cards.
        """
        zone = Zone(name="hand", owner=0)

        assert zone.name == "hand"
        assert zone.owner == 0
        assert zone.size() == 0

    def test_add_remove_cards(self):
        """Test adding and removing cards.

        Covers add, successful remove (returns True), and removing a
        card that is no longer present (returns False).
        """
        zone = Zone(name="battlefield", owner=1)
        card = Card(name="Test Card")

        # Add card
        zone.add_card(card)
        assert zone.size() == 1
        assert card in zone.cards

        # Remove card
        result = zone.remove_card(card)
        assert result is True
        assert zone.size() == 0
        assert card not in zone.cards

        # Try to remove non-existent card
        result = zone.remove_card(card)
        assert result is False


class TestPlayer:
    """Test Player class."""

    def test_player_creation(self):
        """Test player creation with zones.

        A new player starts at 20 life (standard MTG starting total)
        and owns each of their zones.
        """
        player = Player(player_id=0)

        assert player.player_id == 0
        assert player.life == 20
        assert player.hand.owner == 0
        assert player.battlefield.owner == 0
        assert player.graveyard.owner == 0

    def test_can_play_land(self):
        """Test land playing rules (one land per turn)."""
        player = Player(player_id=0)

        # Can play first land
        assert player.can_play_land() is True

        # After playing one land
        player.lands_played_this_turn = 1
        assert player.can_play_land() is False

    def test_mana_pool(self):
        """Test mana pool functionality.

        total_mana() should sum the counts across all colors in the
        pool dict.
        """
        player = Player(player_id=0)

        # Empty mana pool
        assert player.total_mana() == 0

        # Add some mana
        player.mana_pool = {"R": 2, "U": 1}
        assert player.total_mana() == 3


class TestGameState:
    """Test GameState class."""

    def test_empty_game_state(self):
        """Test empty game state creation.

        The factory starts a two-player game on turn 1, main phase,
        with player 0 active and holding priority.
        """
        game_state = create_empty_game_state()

        assert game_state.turn_number == 1
        assert game_state.phase == "main"
        assert game_state.priority_player == 0
        assert game_state.active_player == 0
        assert len(game_state.players) == 2

    def test_game_properties(self):
        """Test game state properties.

        current_player / opponent are derived views over ``players``;
        a fresh game is not over and has no winner.
        """
        game_state = create_empty_game_state()

        # Test current player
        current = game_state.current_player
        assert current.player_id == 0

        # Test opponent
        opponent = game_state.opponent
        assert opponent.player_id == 1

        # Test game not over
        assert game_state.is_game_over() is False
        assert game_state.winner() is None

    def test_game_over_conditions(self):
        """Test game over detection.

        Reducing either player's life to 0 ends the game; winner()
        returns the surviving player's id.
        """
        game_state = create_empty_game_state()

        # Player 0 loses
        game_state.players[0].life = 0
        assert game_state.is_game_over() is True
        assert game_state.winner() == 1

        # Reset and test player 1 loses
        game_state.players[0].life = 20
        game_state.players[1].life = 0
        assert game_state.is_game_over() is True
        assert game_state.winner() == 0


class TestGameStateEncoder:
    """Test GameStateEncoder neural network."""

    @pytest.fixture
    def encoder(self):
        """Create a test encoder (small dims to keep the test fast)."""
        return GameStateEncoder(
            vocab_size=1000, embed_dim=64, hidden_dim=128, output_dim=256
        )

    @pytest.fixture
    def game_state(self):
        """Create a test game state."""
        return create_empty_game_state()

    def test_encoder_creation(self, encoder):
        """Test encoder creation.

        Constructor kwargs should be stored as attributes verbatim.
        """
        assert encoder.vocab_size == 1000
        assert encoder.embed_dim == 64
        assert encoder.hidden_dim == 128
        assert encoder.output_dim == 256

    def test_zone_encoding(self, encoder, game_state):
        """Test zone encoding.

        encode_zone should map a populated zone to a 1-D tensor of
        length ``hidden_dim``.
        """
        zone = game_state.players[0].hand

        # Add some test cards
        for i in range(3):
            card = Card(name=f"Card {i}", card_id=i + 1)
            zone.add_card(card)

        # Encode zone
        encoding = encoder.encode_zone(zone, zone_idx=0)

        # Should be a tensor of the right size
        assert isinstance(encoding, torch.Tensor)
        assert encoding.shape[0] == encoder.hidden_dim

    def test_player_encoding(self, encoder, game_state):
        """Test player encoding.

        encode_player should map a player's scalar state (life, mana)
        to a 1-D tensor of length ``hidden_dim``.
        """
        player = game_state.players[0]

        # Modify player state
        player.life = 15
        player.mana_pool = {"R": 2, "U": 1}

        # Encode player
        encoding = encoder.encode_player(player)

        # Should be a tensor
        assert isinstance(encoding, torch.Tensor)
        assert encoding.shape[0] == encoder.hidden_dim

    def test_full_encoding(self, encoder, game_state):
        """Test full game state encoding."""
        # Add some cards to make it more realistic
        for player in game_state.players:
            for i in range(2):
                card = Card(name=f"Card {i}", card_id=i + 1)
                player.hand.add_card(card)

        # Encode full state
        # NOTE(review): calling .forward() directly bypasses nn.Module
        # hooks; the idiomatic call is encoder(game_state) — consider
        # changing if hooks are ever registered.
        encoding = encoder.forward(game_state)

        # Should be the right shape
        assert isinstance(encoding, torch.Tensor)
        assert encoding.shape[0] == encoder.output_dim

    def test_batch_encoding(self, encoder):
        """Test batch encoding (using tensors directly).

        Placeholder: only checks tensor shapes, not a real batched
        encode path through the encoder.
        """
        batch_size = 4
        state_dim = encoder.output_dim

        # Create dummy batch
        dummy_batch = torch.randn(batch_size, state_dim)

        # Should handle batch dimension correctly
        with torch.no_grad():
            # This would normally go through the full encoding pipeline
            # For now, just test that tensor shapes work
            assert dummy_batch.shape == (batch_size, state_dim)


@pytest.mark.integration
class TestGameStateIntegration:
    """Integration tests for game state components."""

    def test_full_pipeline(self):
        """Test the full game state pipeline.

        Builds a small but realistic position (cards in hand, a land
        played, mana floating, turn 2) and checks the encoder produces
        a finite, correctly-sized embedding for it.
        """
        # Create game state
        game_state = create_empty_game_state()

        # Add some realistic game state
        # Player 0 gets some cards
        cards = [
            Card(
                name="Lightning Bolt",
                mana_cost="R",
                converted_mana_cost=1,
                card_id=1,
            ),
            Card(name="Mountain", card_type="Land", card_id=2),
            Card(
                name="Grizzly Bears",
                mana_cost="1G",
                converted_mana_cost=2,
                card_id=3,
            ),
        ]

        for card in cards:
            game_state.players[0].hand.add_card(card)

        # Player 0 plays a land
        mountain = cards[1]
        game_state.players[0].hand.remove_card(mountain)
        game_state.players[0].battlefield.add_card(mountain)
        game_state.players[0].lands_played_this_turn = 1

        # Add some mana
        game_state.players[0].mana_pool = {"R": 1}

        # Update game state
        game_state.turn_number = 2
        game_state.phase = "main"

        # Create encoder and encode
        encoder = GameStateEncoder(output_dim=512)

        with torch.no_grad():
            encoding = encoder.forward(game_state)

        # Should produce a valid encoding
        assert isinstance(encoding, torch.Tensor)
        assert encoding.shape[0] == 512
        assert not torch.isnan(encoding).any()
        assert not torch.isinf(encoding).any()


if __name__ == "__main__":
    pytest.main([__file__])