diff --git a/pyproject.toml b/pyproject.toml index adfd452..e5464ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "numpy", "open-spiel", "pydantic", + "pydantic-settings", "pyvis", "retry", "tqdm", diff --git a/strategicwm/_src/client_lib.py b/strategicwm/_src/client_lib.py index c968231..00dbf8f 100644 --- a/strategicwm/_src/client_lib.py +++ b/strategicwm/_src/client_lib.py @@ -25,7 +25,11 @@ import retry +from strategicwm._src import config +from strategicwm._src import logging_utils + Client = genai.Client +_LOGGER = logging_utils.get_logger(__name__) class LLMCall(Protocol): @@ -88,10 +92,10 @@ def __str__(self) -> str: @retry.retry( exceptions=HttpErrorRetriable, - tries=10, - delay=10, - max_delay=60, - backoff=2, + tries=config.settings.RETRY_TRIES, + delay=config.settings.RETRY_DELAY, + max_delay=config.settings.RETRY_MAX_DELAY, + backoff=config.settings.RETRY_BACKOFF, ) def generate_with_retry( client: Client, model: str, prompt_text: str @@ -136,7 +140,7 @@ def query_llm( f" '{prompt_text[:50]}...' (Process ID: {os.getpid()})" ) if verbose: - print(msg, flush=True) + _LOGGER.info(msg) if logger: logger.info(msg) try: @@ -162,7 +166,7 @@ def query_llm( + f"Model response blocked. Reason: {block_reason_msg}", ) if verbose: - print(err_msg, flush=True) + _LOGGER.error(err_msg) if logger: logger.error(err_msg) raise ValueError(err_msg) @@ -173,7 +177,7 @@ def query_llm( " blocked due to safety settings." ) if verbose: - print(err_msg, flush=True) + _LOGGER.error(err_msg) if logger: logger.error(err_msg) raise ValueError(err_msg) @@ -184,7 +188,7 @@ def query_llm( f" response: {e}" ) if verbose: - print(err_msg, flush=True) + _LOGGER.error(err_msg) if logger: logger.error(err_msg) raise ValueError(err_msg) from e diff --git a/strategicwm/_src/config.py b/strategicwm/_src/config.py new file mode 100644 index 0000000..6fdf56d --- /dev/null +++ b/strategicwm/_src/config.py @@ -0,0 +1,34 @@ +# Copyright 2025 The strategicwm Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for strategicwm.""" + +from pydantic_settings import BaseSettings + +class Settings(BaseSettings): + """Configuration settings for strategicwm.""" + + # LLM Client Retry Settings + RETRY_TRIES: int = 10 + RETRY_DELAY: int = 10 + RETRY_MAX_DELAY: int = 60 + RETRY_BACKOFF: int = 2 + + # Default Model + DEFAULT_MODEL: str = "gemini-1.0-pro" + + class Config: + env_prefix = "SWM_" + +settings = Settings() diff --git a/strategicwm/_src/config_test.py b/strategicwm/_src/config_test.py new file mode 100644 index 0000000..9166289 --- /dev/null +++ b/strategicwm/_src/config_test.py @@ -0,0 +1,43 @@ +# Copyright 2025 The strategicwm Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for configuration.""" + +import os +from absl.testing import absltest +from strategicwm._src import config + +class ConfigTest(absltest.TestCase): + + def test_defaults(self): + # Depending on order of tests and env vars, we might not get defaults. + # So we manually create an instance. + settings = config.Settings() + self.assertEqual(settings.RETRY_TRIES, 10) + self.assertEqual(settings.DEFAULT_MODEL, "gemini-1.0-pro") + + def test_env_var_override(self): + os.environ["SWM_RETRY_TRIES"] = "5" + os.environ["SWM_DEFAULT_MODEL"] = "gemini-1.5-pro" + + # Reload settings or create new instance + settings = config.Settings() + self.assertEqual(settings.RETRY_TRIES, 5) + self.assertEqual(settings.DEFAULT_MODEL, "gemini-1.5-pro") + + del os.environ["SWM_RETRY_TRIES"] + del os.environ["SWM_DEFAULT_MODEL"] + +if __name__ == "__main__": + absltest.main() diff --git a/strategicwm/_src/logging_utils.py b/strategicwm/_src/logging_utils.py new file mode 100644 index 0000000..16387c8 --- /dev/null +++ b/strategicwm/_src/logging_utils.py @@ -0,0 +1,50 @@ +# Copyright 2025 The strategicwm Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logging utilities for strategicwm.""" + +import logging +import sys +from typing import Optional + +def get_logger(name: str, level: Optional[int] = None) -> logging.Logger: + """Get or create a logger with standardized configuration. + + Args: + name: The name of the logger. + level: Optional logging level to set for this logger. + + Returns: + A configured logging.Logger instance. + """ + logger = logging.getLogger(name) + if level is not None: + logger.setLevel(level) + return logger + +def configure_logging( + level: int = logging.INFO, + format_string: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) -> None: + """Configure library-wide logging settings. + + Args: + level: The default logging level. + format_string: The format string for log messages. + """ + logging.basicConfig( + level=level, + format=format_string, + handlers=[logging.StreamHandler(sys.stdout)] + ) diff --git a/strategicwm/_src/logging_utils_test.py b/strategicwm/_src/logging_utils_test.py new file mode 100644 index 0000000..c70c657 --- /dev/null +++ b/strategicwm/_src/logging_utils_test.py @@ -0,0 +1,52 @@ +# Copyright 2025 The strategicwm Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for logging utilities.""" + +import logging +from io import StringIO +from absl.testing import absltest +from strategicwm._src import logging_utils + +class LoggingUtilsTest(absltest.TestCase): + + def setUp(self): + super().setUp() + # Reset logging configuration + logging.root.handlers = [] + + def test_get_logger(self): + logger = logging_utils.get_logger("test_logger") + self.assertIsInstance(logger, logging.Logger) + self.assertEqual(logger.name, "test_logger") + + def test_configure_logging(self): + # Capture stdout + captured_output = StringIO() + root_logger = logging.getLogger() + root_logger.setLevel(logging.INFO) + handler = logging.StreamHandler(captured_output) + handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s')) + root_logger.addHandler(handler) + + logger = logging_utils.get_logger("test_config") + logger.info("Test INFO message") + logger.debug("Test DEBUG message") + + output = captured_output.getvalue().strip() + self.assertIn("INFO: Test INFO message", output) + self.assertNotIn("DEBUG: Test DEBUG message", output) + +if __name__ == "__main__": + absltest.main() diff --git a/strategicwm/_src/se/construction/bfs.py b/strategicwm/_src/se/construction/bfs.py index 878315a..4739d3c 100644 --- a/strategicwm/_src/se/construction/bfs.py +++ b/strategicwm/_src/se/construction/bfs.py @@ -22,7 +22,10 @@ import networkx as nx import numpy as np +from strategicwm._src import logging_utils from strategicwm._src.se.construction import io + +_LOGGER = logging_utils.get_logger(__name__) from strategicwm._src.se.state import state as s import tqdm.auto as tqdm @@ -282,7 +285,7 @@ def bfs( f"Starting async BFS (per child generation) with {num_workers} workers..." ) if verbose: - print(msg, flush=True) + _LOGGER.info(msg) if logger: logger.info(msg) if params: @@ -332,7 +335,7 @@ def bfs( node = future.result() msg = f"Adding node's children to queue: {node.history_str()}" if verbose: - print(msg, flush=True) + _LOGGER.info(msg) if logger: logger.info(msg) nodes_visited_count += 1 @@ -358,7 +361,7 @@ def bfs( f" generated an exception during thread execution: {exc}" ) if verbose: - print(msg, flush=True) + _LOGGER.error(msg) if logger: logger.error(msg) @@ -369,7 +372,7 @@ def bfs( f"Total nodes visited: {nodes_visited_count}." ) if verbose: - print(msg, flush=True) + _LOGGER.info(msg) if logger: logger.info(msg) if pbar_nodes: diff --git a/strategicwm/_src/se/construction/direct.py b/strategicwm/_src/se/construction/direct.py index 90cf2a7..bcf0cb2 100644 --- a/strategicwm/_src/se/construction/direct.py +++ b/strategicwm/_src/se/construction/direct.py @@ -25,8 +25,11 @@ import pydantic from strategicwm._src import client_lib +from strategicwm._src import logging_utils from strategicwm._src.se.construction import io +_LOGGER = logging_utils.get_logger(__name__) + from typing_extensions import Annotated @@ -222,9 +225,9 @@ def get_params(game_json: dict[str, Any]) -> io.GameParamsA: num_actions = len(game_tree_nx.nodes[node]["legal_actions_string"]) max_num_distinct_actions = max(max_num_distinct_actions, num_actions) if len(players) != num_players: - print( - f"Warning! Players {players} found in game tree. Mismatch with" - f" {num_players} player descriptions." + _LOGGER.warning( + "Warning! Players %s found in game tree. Mismatch with %s player descriptions.", + players, num_players ) num_players = len(players) params = io.GameParamsA(