diff --git a/src/manamind/core/agent.py b/src/manamind/core/agent.py index b32ba53..0620ac6 100644 --- a/src/manamind/core/agent.py +++ b/src/manamind/core/agent.py @@ -9,7 +9,7 @@ import random import time from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch @@ -93,7 +93,7 @@ def __init__( self.game_state = game_state self.action = action self.parent = parent - self.children: Dict[Action, MCTSNode] = {} + self.children: List[Tuple[Action, MCTSNode]] = [] # MCTS statistics self.visits = 0 @@ -112,22 +112,23 @@ def is_terminal(self) -> bool: """Check if this is a terminal game state.""" return self.game_state.is_game_over() - def ucb1_score(self, c: float = 1.414) -> float: + def ucb1_score(self, child_node: MCTSNode, c: float = 1.414) -> float: """Calculate UCB1 score for action selection. Args: + child_node: Child node to calculate score for c: Exploration parameter Returns: UCB1 score """ - if self.visits == 0: + if child_node.visits == 0: return float("inf") - exploitation = self.total_value / self.visits + exploitation = child_node.total_value / child_node.visits exploration = ( - c * math.sqrt(math.log(self.parent.visits) / self.visits) - if self.parent + c * math.sqrt(math.log(self.visits) / child_node.visits) + if self.visits > 0 else 0.0 ) return exploitation + exploration @@ -135,7 +136,8 @@ def ucb1_score(self, c: float = 1.414) -> float: def select_child(self) -> MCTSNode: """Select the child with the highest UCB1 score.""" return max( - self.children.values(), key=lambda child: child.ucb1_score() + (child for _, child in self.children), + key=lambda child: self.ucb1_score(child), ) def expand(self) -> MCTSNode: @@ -146,7 +148,7 @@ def expand(self) -> MCTSNode: action = self.untried_actions.pop() new_state = action.execute(self.game_state) child_node = MCTSNode(new_state, action, self) - self.children[action] = child_node + self.children.append((action, child_node)) return child_node def backup(self, value: float) -> None: @@ -243,7 +245,8 @@ def select_action(self, game_state: GameState) -> Action: return random.choice(legal_actions) best_child = max( - root.children.values(), key=lambda child: child.visits + (child for _, child in root.children), + key=lambda child: child.visits, ) if best_child.action: return best_child.action diff --git a/tests/integration/test_core_integration.py b/tests/integration/test_core_integration.py new file mode 100644 index 0000000..3c9959d --- /dev/null +++ b/tests/integration/test_core_integration.py @@ -0,0 +1,212 @@ +"""Integration tests for core components.""" + +import torch + +from manamind.core.action import Action, ActionSpace, ActionType +from manamind.core.agent import RandomAgent +from manamind.core.game_state import ( + Card, + create_empty_game_state, +) + + +class TestCoreIntegration: + """Integration tests for core components.""" + + def test_game_state_action_integration(self): + """Test integration between game state and actions.""" + # Create a game state + game_state = create_empty_game_state() + + # Add a land to player's hand + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + # Set up game state for land play + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + + # Create a play land action + action = Action( + action_type=ActionType.PLAY_LAND, + player_id=0, + card=land, + ) + + # Verify the action is valid + assert action.is_valid(game_state) is True + + # Execute the action + new_state = action.execute(game_state) + + # Verify the land was moved to battlefield + assert land not in new_state.players[0].hand.cards + assert land in new_state.players[0].battlefield.cards + assert new_state.players[0].lands_played_this_turn == 1 + + def test_action_space_integration(self): + """Test integration with action space.""" + # Create a game state + game_state = create_empty_game_state() + + # Add cards to player's hand + land = Card(name="Mountain", card_type="Land") + spell = Card( + name="Lightning Bolt", + card_type="Instant", + converted_mana_cost=1, + ) + game_state.players[0].hand.add_card(land) + game_state.players[0].hand.add_card(spell) + + # Set up game state + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + game_state.players[0].mana_pool = {"R": 1} + + # Get legal actions + action_space = ActionSpace() + legal_actions = action_space.get_legal_actions(game_state) + + # Should have at least pass priority and play land actions + assert len(legal_actions) >= 2 + + # Verify we can find the play land action + play_land_actions = [ + action + for action in legal_actions + if action.action_type == ActionType.PLAY_LAND + ] + assert len(play_land_actions) == 1 + assert play_land_actions[0].card == land + + def test_agent_game_state_integration(self): + """Test integration between agents and game state.""" + # Create a game state + game_state = create_empty_game_state() + + # Add a land to player's hand + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + # Set up game state for land play + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + + # Create a random agent + agent = RandomAgent(player_id=0, seed=42) + + # Select an action + action = agent.select_action(game_state) + + # Verify it's a valid action + assert isinstance(action, Action) + assert action.player_id == 0 + assert action.is_valid(game_state) is True + + def test_full_game_simulation(self): + """Test a simple game simulation.""" + # Create a game state + game_state = create_empty_game_state() + + # Add cards to both players' hands + land_p0 = Card(name="Mountain", card_type="Land") + spell_p0 = Card( + name="Lightning Bolt", + card_type="Instant", + converted_mana_cost=1, + ) + game_state.players[0].hand.add_card(land_p0) + game_state.players[0].hand.add_card(spell_p0) + + land_p1 = Card(name="Forest", card_type="Land") + game_state.players[1].hand.add_card(land_p1) + + # Set up initial game state + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + game_state.players[0].mana_pool = {"R": 1} + + # Create agents + agent0 = RandomAgent(player_id=0, seed=42) + agent1 = RandomAgent(player_id=1, seed=24) + + # Simulate a few turns + for turn in range(3): + # Player 0's turn + game_state.active_player = 0 + game_state.priority_player = 0 + + # Main phase 1 + game_state.phase = "main" + action = agent0.select_action(game_state) + if action.action_type != ActionType.PASS_PRIORITY: + game_state = action.execute(game_state) + + # Pass priority to end turn + game_state.priority_player = 1 + action = Action( + action_type=ActionType.PASS_PRIORITY, + player_id=1, + ) + game_state = action.execute(game_state) + + # Player 1's turn + game_state.active_player = 1 + game_state.priority_player = 1 + + # Main phase 1 + game_state.phase = "main" + action = agent1.select_action(game_state) + if action.action_type != ActionType.PASS_PRIORITY: + game_state = action.execute(game_state) + + # Pass priority to end turn + game_state.priority_player = 0 + action = Action( + action_type=ActionType.PASS_PRIORITY, + player_id=0, + ) + game_state = action.execute(game_state) + + # Increment turn + game_state.turn_number += 1 + + def test_neural_network_integration(self): + """Test integration with neural networks.""" + # Create a game state + game_state = create_empty_game_state() + + # Add some cards to make it more realistic + for player in game_state.players: + for i in range(2): + card = Card(name=f"Card {i}", card_id=i + 1) + player.hand.add_card(card) + + # Create a mock neural network + class MockNetwork(torch.nn.Module): + def __init__(self): + super().__init__() + self.encoder = torch.nn.Linear(10, 10) + + def forward(self, game_state): + # Simple mock implementation + return torch.tensor([0.0]), torch.tensor(0.0) + + # Test that we can create agents with neural networks + network = MockNetwork() + from manamind.core.agent import NeuralAgent + + agent = NeuralAgent( + player_id=0, + policy_value_network=network, + ) + + # Should be able to select an action + action = agent.select_action(game_state) + assert isinstance(action, Action) + assert action.player_id == 0 diff --git a/tests/test_action.py b/tests/test_action.py new file mode 100644 index 0000000..15847c0 --- /dev/null +++ b/tests/test_action.py @@ -0,0 +1,344 @@ +"""Tests for action representation and validation.""" + +from manamind.core.action import ( + Action, + ActionSpace, + ActionType, + CastSpellExecutor, + CastSpellValidator, + PassPriorityExecutor, + PassPriorityValidator, + PlayLandExecutor, + PlayLandValidator, +) +from manamind.core.game_state import ( + Card, + create_empty_game_state, +) + + +class TestActionType: + """Test ActionType enum.""" + + def test_action_type_values(self): + """Test that all action types have correct values.""" + assert ActionType.PLAY_LAND.value == "play_land" + assert ActionType.CAST_SPELL.value == "cast_spell" + assert ActionType.PASS_PRIORITY.value == "pass_priority" + + +class TestAction: + """Test Action class.""" + + def test_action_creation(self): + """Test basic action creation.""" + card = Card(name="Lightning Bolt") + action = Action( + action_type=ActionType.CAST_SPELL, + player_id=0, + card=card, + ) + + assert action.action_type == ActionType.CAST_SPELL + assert action.player_id == 0 + assert action.card == card + assert action.timestamp is not None + + def test_action_complexity_score(self): + """Test action complexity scoring.""" + # Simple action + simple_action = Action( + action_type=ActionType.PASS_PRIORITY, + player_id=0, + ) + assert simple_action.get_complexity_score() == 1 + + # Action with targets + card = Card(name="Lightning Bolt") + targeted_action = Action( + action_type=ActionType.CAST_SPELL, + player_id=0, + card=card, + target_cards=[card], + ) + assert targeted_action.get_complexity_score() == 2 + + def test_action_targets(self): + """Test getting all targets from an action.""" + card1 = Card(name="Lightning Bolt") + card2 = Card(name="Grizzly Bears") + + action = Action( + action_type=ActionType.CAST_SPELL, + player_id=0, + card=card1, + target_cards=[card2], + target_players=[1], + ) + + targets = action.get_all_targets() + assert card1 in targets + assert card2 in targets + assert 1 in targets + + +class TestPlayLandValidator: + """Test PlayLandValidator.""" + + def test_valid_land_play(self): + """Test valid land play scenario.""" + game_state = create_empty_game_state() + + # Add a land to player's hand + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + action = Action( + action_type=ActionType.PLAY_LAND, + player_id=0, + card=land, + ) + + validator = PlayLandValidator() + assert validator.validate(action, game_state) is True + + def test_invalid_land_play_wrong_zone(self): + """Test land play with card not in hand.""" + game_state = create_empty_game_state() + + # Put land on battlefield instead of hand + land = Card(name="Mountain", card_type="Land") + game_state.players[0].battlefield.add_card(land) + + action = Action( + action_type=ActionType.PLAY_LAND, + player_id=0, + card=land, + ) + + validator = PlayLandValidator() + assert validator.validate(action, game_state) is False + + def test_invalid_land_play_not_land(self): + """Test land play with non-land card.""" + game_state = create_empty_game_state() + + # Add non-land to player's hand + spell = Card(name="Lightning Bolt", card_type="Instant") + game_state.players[0].hand.add_card(spell) + + action = Action( + action_type=ActionType.PLAY_LAND, + player_id=0, + card=spell, + ) + + validator = PlayLandValidator() + assert validator.validate(action, game_state) is False + + +class TestCastSpellValidator: + """Test CastSpellValidator.""" + + def test_valid_spell_cast_sorcery(self): + """Test valid sorcery cast.""" + game_state = create_empty_game_state() + + # Add spell to player's hand + spell = Card( + name="Lightning Bolt", + card_type="Sorcery", + converted_mana_cost=1, + ) + game_state.players[0].hand.add_card(spell) + + # Add mana + game_state.players[0].mana_pool = {"R": 1} + + action = Action( + action_type=ActionType.CAST_SPELL, + player_id=0, + card=spell, + ) + + validator = CastSpellValidator() + assert validator.validate(action, game_state) is True + + def test_invalid_spell_cast_no_mana(self): + """Test spell cast without enough mana.""" + game_state = create_empty_game_state() + + # Add spell to player's hand + spell = Card( + name="Lightning Bolt", + card_type="Instant", + converted_mana_cost=1, + ) + game_state.players[0].hand.add_card(spell) + + # No mana in pool + game_state.players[0].mana_pool = {} + + action = Action( + action_type=ActionType.CAST_SPELL, + player_id=0, + card=spell, + ) + + validator = CastSpellValidator() + assert validator.validate(action, game_state) is False + + +class TestPassPriorityValidator: + """Test PassPriorityValidator.""" + + def test_valid_priority_pass(self): + """Test valid priority pass.""" + game_state = create_empty_game_state() + + action = Action( + action_type=ActionType.PASS_PRIORITY, + player_id=0, + ) + + validator = PassPriorityValidator() + assert validator.validate(action, game_state) is True + + def test_invalid_priority_pass(self): + """Test priority pass when not having priority.""" + game_state = create_empty_game_state() + game_state.priority_player = 1 # Other player has priority + + action = Action( + action_type=ActionType.PASS_PRIORITY, + player_id=0, + ) + + validator = PassPriorityValidator() + assert validator.validate(action, game_state) is False + + +class TestPlayLandExecutor: + """Test PlayLandExecutor.""" + + def test_execute_land_play(self): + """Test executing a land play.""" + game_state = create_empty_game_state() + + # Add a land to player's hand + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + action = Action( + action_type=ActionType.PLAY_LAND, + player_id=0, + card=land, + ) + + executor = PlayLandExecutor() + new_state = executor.execute(action, game_state) + + # Check that land was moved to battlefield + assert land not in new_state.players[0].hand.cards + assert land in new_state.players[0].battlefield.cards + assert new_state.players[0].lands_played_this_turn == 1 + + +class TestCastSpellExecutor: + """Test CastSpellExecutor.""" + + def test_execute_spell_cast(self): + """Test executing a spell cast.""" + game_state = create_empty_game_state() + + # Add spell to player's hand + spell = Card(name="Lightning Bolt", card_type="Instant") + game_state.players[0].hand.add_card(spell) + + action = Action( + action_type=ActionType.CAST_SPELL, + player_id=0, + card=spell, + ) + + executor = CastSpellExecutor() + new_state = executor.execute(action, game_state) + + # Check that spell was moved to stack + assert spell not in new_state.players[0].hand.cards + assert len(new_state.stack) == 1 + + +class TestPassPriorityExecutor: + """Test PassPriorityExecutor.""" + + def test_execute_pass_priority(self): + """Test executing priority pass.""" + game_state = create_empty_game_state() + game_state.priority_player = 0 + + action = Action( + action_type=ActionType.PASS_PRIORITY, + player_id=0, + ) + + executor = PassPriorityExecutor() + new_state = executor.execute(action, game_state) + + # Check that priority was passed + assert new_state.priority_player == 1 + + +class TestActionSpace: + """Test ActionSpace class.""" + + def test_action_space_creation(self): + """Test action space creation.""" + action_space = ActionSpace() + + assert action_space.max_actions == 10000 + assert len(action_space.action_to_id) > 0 + assert len(action_space.id_to_action) > 0 + + def test_get_legal_actions(self): + """Test getting legal actions from game state.""" + game_state = create_empty_game_state() + + # Add a land to player's hand + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + # Set up game state for land play + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + + action_space = ActionSpace() + legal_actions = action_space.get_legal_actions(game_state) + + # Should have at least pass priority and play land actions + assert len(legal_actions) >= 2 + + # Check for play land action + play_land_actions = [ + action + for action in legal_actions + if action.action_type == ActionType.PLAY_LAND + ] + assert len(play_land_actions) == 1 + assert play_land_actions[0].card == land + + def test_action_to_vector(self): + """Test converting action to vector.""" + action_space = ActionSpace() + + action = Action( + action_type=ActionType.PASS_PRIORITY, + player_id=0, + ) + + vector = action_space.action_to_vector(action) + + assert isinstance(vector, list) + assert len(vector) == action_space.max_actions + assert sum(vector) <= 1 # At most one element should be 1.0 diff --git a/tests/unit/test_agent.py b/tests/unit/test_agent.py new file mode 100644 index 0000000..0128833 --- /dev/null +++ b/tests/unit/test_agent.py @@ -0,0 +1,243 @@ +"""Tests for agent implementations.""" + +from manamind.core.action import Action +from manamind.core.agent import ( + MCTSAgent, + MCTSNode, + NeuralAgent, + RandomAgent, +) +from manamind.core.game_state import create_empty_game_state + + +class TestAgent: + """Test Agent base class.""" + + def test_agent_creation(self): + """Test agent creation with player ID.""" + agent = RandomAgent(player_id=0) + assert agent.player_id == 0 + + +class TestRandomAgent: + """Test RandomAgent implementation.""" + + def test_random_agent_creation(self): + """Test random agent creation.""" + agent = RandomAgent(player_id=1, seed=42) + assert agent.player_id == 1 + assert agent.rng is not None + + def test_random_agent_select_action(self): + """Test random agent action selection.""" + agent = RandomAgent(player_id=0, seed=42) + game_state = create_empty_game_state() + + # Add a land to player's hand to have legal actions + from manamind.core.game_state import Card + + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + # Set up game state for land play + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + + action = agent.select_action(game_state) + assert isinstance(action, Action) + assert action.player_id == 0 + + def test_random_agent_update_from_game(self): + """Test that random agent update method exists.""" + agent = RandomAgent(player_id=0) + game_history = [] + agent.update_from_game(game_history) # Should not raise + + +class TestMCTSNode: + """Test MCTSNode implementation.""" + + def test_mcts_node_creation(self): + """Test MCTS node creation.""" + game_state = create_empty_game_state() + node = MCTSNode(game_state) + + assert node.game_state == game_state + assert node.action is None + assert node.parent is None + assert node.visits == 0 + assert node.total_value == 0.0 + assert node.prior_prob == 1.0 + + def test_mcts_node_is_fully_expanded(self): + """Test checking if node is fully expanded.""" + game_state = create_empty_game_state() + node = MCTSNode(game_state) + + # Initially should not be fully expanded (has legal actions) + assert node.is_fully_expanded() is False + + def test_mcts_node_is_terminal(self): + """Test checking if node is terminal.""" + game_state = create_empty_game_state() + node = MCTSNode(game_state) + + # Normal game state should not be terminal + assert node.is_terminal() is False + + # Game over state should be terminal + game_state.players[0].life = 0 + assert node.is_terminal() is True + + def test_mcts_node_ucb1_score(self): + """Test UCB1 score calculation.""" + game_state = create_empty_game_state() + parent_node = MCTSNode(game_state) + parent_node.visits = 2 # Parent needs visits for exploration term + + # Create a child node + child_node = MCTSNode(game_state) + + # Child with no visits should have infinite score + score = parent_node.ucb1_score(child_node) + assert score == float("inf") + + # Child with visits should have finite score + child_node.visits = 1 + child_node.total_value = 0.5 + score = parent_node.ucb1_score(child_node) + assert isinstance(score, float) + assert score != float("inf") + + def test_mcts_node_expand(self): + """Test expanding the node.""" + game_state = create_empty_game_state() + + # Add a land to player's hand to have legal actions + from manamind.core.game_state import Card + + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + # Set up game state for land play + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + + node = MCTSNode(game_state) + child = node.expand() + + assert isinstance(child, MCTSNode) + assert child.parent == node + assert child.action is not None + assert len(node.children) == 1 + + def test_mcts_node_backup(self): + """Test backing up values through the tree.""" + game_state = create_empty_game_state() + root = MCTSNode(game_state) + child = root.expand() + + # Backup a value + child.backup(0.5) + + # Check that visits and values were updated + assert child.visits == 1 + assert child.total_value == 0.5 + assert root.visits == 1 + assert root.total_value == -0.5 # Flipped for opponent + + +class TestMCTSAgent: + """Test MCTSAgent implementation.""" + + def test_mcts_agent_creation(self): + """Test MCTS agent creation.""" + agent = MCTSAgent(player_id=0) + assert agent.player_id == 0 + assert agent.simulations == 1000 + assert agent.simulation_time == 1.0 + + def test_mcts_agent_select_action(self): + """Test MCTS agent action selection.""" + agent = MCTSAgent(player_id=0, simulations=10) + game_state = create_empty_game_state() + + # Add a land to player's hand to have legal actions + from manamind.core.game_state import Card + + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + # Set up game state for land play + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + + action = agent.select_action(game_state) + assert isinstance(action, Action) + assert action.player_id == 0 + + def test_mcts_agent_update_from_game(self): + """Test that MCTS agent update method exists.""" + agent = MCTSAgent(player_id=0) + game_history = [] + agent.update_from_game(game_history) # Should not raise + + +class TestNeuralAgent: + """Test NeuralAgent implementation.""" + + def test_neural_agent_creation(self): + """Test neural agent creation.""" + + # Create a mock network + class MockNetwork: + pass + + network = MockNetwork() + agent = NeuralAgent(player_id=1, policy_value_network=network) + assert agent.player_id == 1 + assert agent.policy_value_network == network + + def test_neural_agent_select_action(self): + """Test neural agent action selection.""" + + # Create a mock network + class MockNetwork: + def __call__(self, game_state): + import torch + + return torch.tensor([0.0]), torch.tensor(0.0) + + network = MockNetwork() + agent = NeuralAgent(player_id=0, policy_value_network=network) + game_state = create_empty_game_state() + + # Add a land to player's hand to have legal actions + from manamind.core.game_state import Card + + land = Card(name="Mountain", card_type="Land") + game_state.players[0].hand.add_card(land) + + # Set up game state for land play + game_state.active_player = 0 + game_state.priority_player = 0 + game_state.phase = "main" + + action = agent.select_action(game_state) + assert isinstance(action, Action) + assert action.player_id == 0 + + def test_neural_agent_update_from_game(self): + """Test that neural agent update method exists.""" + + # Create a mock network + class MockNetwork: + pass + + network = MockNetwork() + agent = NeuralAgent(player_id=0, policy_value_network=network) + game_history = [] + agent.update_from_game(game_history) # Should not raise