Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"numpy",
"open-spiel",
"pydantic",
"pydantic-settings",
"pyvis",
"retry",
"tqdm",
Expand Down
20 changes: 12 additions & 8 deletions strategicwm/_src/client_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
import retry


from strategicwm._src import config
from strategicwm._src import logging_utils

Client = genai.Client
_LOGGER = logging_utils.get_logger(__name__)


class LLMCall(Protocol):
Expand Down Expand Up @@ -88,10 +92,10 @@ def __str__(self) -> str:

@retry.retry(
exceptions=HttpErrorRetriable,
tries=10,
delay=10,
max_delay=60,
backoff=2,
tries=config.settings.RETRY_TRIES,
delay=config.settings.RETRY_DELAY,
max_delay=config.settings.RETRY_MAX_DELAY,
backoff=config.settings.RETRY_BACKOFF,
)
def generate_with_retry(
client: Client, model: str, prompt_text: str
Expand Down Expand Up @@ -136,7 +140,7 @@ def query_llm(
f" '{prompt_text[:50]}...' (Process ID: {os.getpid()})"
)
if verbose:
print(msg, flush=True)
_LOGGER.info(msg)
if logger:
logger.info(msg)
try:
Expand All @@ -162,7 +166,7 @@ def query_llm(
+ f"Model response blocked. Reason: {block_reason_msg}",
)
if verbose:
print(err_msg, flush=True)
_LOGGER.error(err_msg)
if logger:
logger.error(err_msg)
raise ValueError(err_msg)
Expand All @@ -173,7 +177,7 @@ def query_llm(
" blocked due to safety settings."
)
if verbose:
print(err_msg, flush=True)
_LOGGER.error(err_msg)
if logger:
logger.error(err_msg)
raise ValueError(err_msg)
Expand All @@ -184,7 +188,7 @@ def query_llm(
f" response: {e}"
)
if verbose:
print(err_msg, flush=True)
_LOGGER.error(err_msg)
if logger:
logger.error(err_msg)
raise ValueError(err_msg) from e
34 changes: 34 additions & 0 deletions strategicwm/_src/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2025 The strategicwm Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration for strategicwm."""

from pydantic_settings import BaseSettings

class Settings(BaseSettings):
"""Configuration settings for strategicwm."""

# LLM Client Retry Settings
RETRY_TRIES: int = 10
RETRY_DELAY: int = 10
RETRY_MAX_DELAY: int = 60
RETRY_BACKOFF: int = 2

# Default Model
DEFAULT_MODEL: str = "gemini-1.0-pro"

class Config:
env_prefix = "SWM_"

settings = Settings()
43 changes: 43 additions & 0 deletions strategicwm/_src/config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2025 The strategicwm Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for configuration."""

import os
from absl.testing import absltest
from strategicwm._src import config

class ConfigTest(absltest.TestCase):

def test_defaults(self):
# Depending on order of tests and env vars, we might not get defaults.
# So we manually create an instance.
settings = config.Settings()
self.assertEqual(settings.RETRY_TRIES, 10)
self.assertEqual(settings.DEFAULT_MODEL, "gemini-1.0-pro")

def test_env_var_override(self):
os.environ["SWM_RETRY_TRIES"] = "5"
os.environ["SWM_DEFAULT_MODEL"] = "gemini-1.5-pro"

# Reload settings or create new instance
settings = config.Settings()
self.assertEqual(settings.RETRY_TRIES, 5)
self.assertEqual(settings.DEFAULT_MODEL, "gemini-1.5-pro")

del os.environ["SWM_RETRY_TRIES"]
del os.environ["SWM_DEFAULT_MODEL"]

if __name__ == "__main__":
absltest.main()
50 changes: 50 additions & 0 deletions strategicwm/_src/logging_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2025 The strategicwm Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Logging utilities for strategicwm."""

import logging
import sys
from typing import Optional

def get_logger(name: str, level: Optional[int] = None) -> logging.Logger:
"""Get or create a logger with standardized configuration.

Args:
name: The name of the logger.
level: Optional logging level to set for this logger.

Returns:
A configured logging.Logger instance.
"""
logger = logging.getLogger(name)
if level is not None:
logger.setLevel(level)
return logger

def configure_logging(
level: int = logging.INFO,
format_string: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
) -> None:
"""Configure library-wide logging settings.

Args:
level: The default logging level.
format_string: The format string for log messages.
"""
logging.basicConfig(
level=level,
format=format_string,
handlers=[logging.StreamHandler(sys.stdout)]
)
52 changes: 52 additions & 0 deletions strategicwm/_src/logging_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2025 The strategicwm Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for logging utilities."""

import logging
from io import StringIO
from absl.testing import absltest
from strategicwm._src import logging_utils

class LoggingUtilsTest(absltest.TestCase):

def setUp(self):
super().setUp()
# Reset logging configuration
logging.root.handlers = []

def test_get_logger(self):
logger = logging_utils.get_logger("test_logger")
self.assertIsInstance(logger, logging.Logger)
self.assertEqual(logger.name, "test_logger")

def test_configure_logging(self):
# Capture stdout
captured_output = StringIO()
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
handler = logging.StreamHandler(captured_output)
handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
root_logger.addHandler(handler)

logger = logging_utils.get_logger("test_config")
logger.info("Test INFO message")
logger.debug("Test DEBUG message")

output = captured_output.getvalue().strip()
self.assertIn("INFO: Test INFO message", output)
self.assertNotIn("DEBUG: Test DEBUG message", output)

if __name__ == "__main__":
absltest.main()
11 changes: 7 additions & 4 deletions strategicwm/_src/se/construction/bfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import networkx as nx
import numpy as np

from strategicwm._src import logging_utils
from strategicwm._src.se.construction import io

_LOGGER = logging_utils.get_logger(__name__)
from strategicwm._src.se.state import state as s

import tqdm.auto as tqdm
Expand Down Expand Up @@ -282,7 +285,7 @@ def bfs(
f"Starting async BFS (per child generation) with {num_workers} workers..."
)
if verbose:
print(msg, flush=True)
_LOGGER.info(msg)
if logger:
logger.info(msg)
if params:
Expand Down Expand Up @@ -332,7 +335,7 @@ def bfs(
node = future.result()
msg = f"Adding node's children to queue: {node.history_str()}"
if verbose:
print(msg, flush=True)
_LOGGER.info(msg)
if logger:
logger.info(msg)
nodes_visited_count += 1
Expand All @@ -358,7 +361,7 @@ def bfs(
f" generated an exception during thread execution: {exc}"
)
if verbose:
print(msg, flush=True)
_LOGGER.error(msg)
if logger:
logger.error(msg)

Expand All @@ -369,7 +372,7 @@ def bfs(
f"Total nodes visited: {nodes_visited_count}."
)
if verbose:
print(msg, flush=True)
_LOGGER.info(msg)
if logger:
logger.info(msg)
if pbar_nodes:
Expand Down
9 changes: 6 additions & 3 deletions strategicwm/_src/se/construction/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
import pydantic

from strategicwm._src import client_lib
from strategicwm._src import logging_utils
from strategicwm._src.se.construction import io

_LOGGER = logging_utils.get_logger(__name__)

from typing_extensions import Annotated


Expand Down Expand Up @@ -222,9 +225,9 @@ def get_params(game_json: dict[str, Any]) -> io.GameParamsA:
num_actions = len(game_tree_nx.nodes[node]["legal_actions_string"])
max_num_distinct_actions = max(max_num_distinct_actions, num_actions)
if len(players) != num_players:
print(
f"Warning! Players {players} found in game tree. Mismatch with"
f" {num_players} player descriptions."
_LOGGER.warning(
"Warning! Players %s found in game tree. Mismatch with %s player descriptions.",
players, num_players
)
num_players = len(players)
params = io.GameParamsA(
Expand Down