Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions shiny/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,23 @@ def add() -> None:
)
@click.option(
"--provider",
type=click.Choice(["anthropic", "openai"]),
type=click.Choice(["anthropic", "openai", "bedrock-anthropic"]),
default="anthropic",
help="AI provider to use for test generation.",
help=(
"AI provider to use for test generation. For 'bedrock-anthropic', "
"make sure your AWS credentials are configured (env vars, profile, or role) "
"and provide a Bedrock Anthropic model ID (e.g., "
"us.anthropic.claude-3-7-sonnet-20250219-v1:0)."
),
)
@click.option(
"--model",
type=str,
help="Specific model to use (optional). Examples: haiku3.5, sonnet, gpt-5, gpt-5-mini",
help=(
"Specific model to use (optional). Examples: haiku3.5, sonnet, gpt-5, gpt-5-mini; "
"or a Bedrock Anthropic model ID when using provider=bedrock-anthropic, e.g. "
"us.anthropic.claude-3-7-sonnet-20250219-v1:0"
),
)
# Param for app.py, param for test_name
def test(
Expand Down
22 changes: 15 additions & 7 deletions shiny/_main_generate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,27 @@ def validate_api_key(provider: str) -> None:
"env_var": "OPENAI_API_KEY",
"url": "https://platform.openai.com/api-keys",
},
"bedrock-anthropic": {
"env_var": None,
"url": "https://docs.aws.amazon.com/bedrock/latest/userguide/getting-started.html",
},
}

if provider not in api_configs:
raise ValidationError(f"Unsupported provider: {provider}")

config = api_configs[provider]
if not os.getenv(config["env_var"]):
raise ValidationError(
f"{config['env_var']} environment variable is not set.\n"
f"Please set your {provider.title()} API key:\n"
f" export {config['env_var']}='your-api-key-here'\n\n"
f"Get your API key from: {config['url']}"
)
if provider in ("anthropic", "openai"):
env_var = config["env_var"] # type: ignore[assignment]
if not isinstance(env_var, str) or not os.getenv(env_var):
raise ValidationError(
f"{env_var} environment variable is not set.\n"
f"Please set your {provider.title()} API key:\n"
f" export {env_var}='your-api-key-here'\n\n"
f"Get your API key from: {config['url']}"
)
else:
pass


def get_app_file_path(app_file: str | None) -> Path:
Expand Down
95 changes: 63 additions & 32 deletions shiny/pytest/_generate/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

from chatlas import ChatAnthropic, ChatOpenAI, token_usage
from chatlas import ChatAnthropic, ChatBedrockAnthropic, ChatOpenAI, token_usage
from dotenv import load_dotenv

__all__ = [
Expand Down Expand Up @@ -50,7 +50,9 @@ class ShinyTestGenerator:

def __init__(
self,
provider: Literal["anthropic", "openai"] = Config.DEFAULT_PROVIDER,
provider: Literal[
"anthropic", "openai", "bedrock-anthropic"
] = Config.DEFAULT_PROVIDER,
api_key: Optional[str] = None,
log_file: str = Config.LOG_FILE,
setup_logging: bool = True,
Expand All @@ -74,25 +76,28 @@ def __init__(
self.setup_logging()

@property
def client(self) -> Union[ChatAnthropic, ChatOpenAI]:
def client(self) -> Union[ChatAnthropic, ChatOpenAI, ChatBedrockAnthropic]:
"""Lazy-loaded chat client based on provider"""
if self._client is None:
if not self.api_key:
env_var = (
"ANTHROPIC_API_KEY"
if self.provider == "anthropic"
else "OPENAI_API_KEY"
)
self.api_key = os.getenv(env_var)
if not self.api_key:
raise ValueError(
f"Missing API key for provider '{self.provider}'. Set the environment variable "
f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key explicitly."
)
if self.provider in ("anthropic", "openai"):
if not self.api_key:
env_var = (
"ANTHROPIC_API_KEY"
if self.provider == "anthropic"
else "OPENAI_API_KEY"
)
self.api_key = os.getenv(env_var)
if not self.api_key:
raise ValueError(
f"Missing API key for provider '{self.provider}'. Set the environment variable "
f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key explicitly."
)
if self.provider == "anthropic":
self._client = ChatAnthropic(api_key=self.api_key)
elif self.provider == "openai":
self._client = ChatOpenAI(api_key=self.api_key)
elif self.provider == "bedrock-anthropic":
self._client = ChatBedrockAnthropic()
else:
raise ValueError(f"Unsupported provider: {self.provider}")
return self._client
Expand All @@ -118,6 +123,8 @@ def default_model(self) -> str:
return Config.DEFAULT_ANTHROPIC_MODEL
elif self.provider == "openai":
return Config.DEFAULT_OPENAI_MODEL
elif self.provider == "bedrock-anthropic":
return Config.DEFAULT_BEDROCK_ANTHROPIC_MODEL
else:
raise ValueError(f"Unsupported provider: {self.provider}")

Expand Down Expand Up @@ -168,6 +175,15 @@ def _resolve_model(self, model: str) -> str:

def _validate_model_for_provider(self, model: str) -> str:
"""Validate that the model is compatible with the current provider"""
if self.provider == "bedrock-anthropic":
resolved_model = model
if resolved_model.startswith("gpt-") or resolved_model.startswith("o1-"):
raise ValueError(
f"Model '{model}' is an OpenAI model but provider is set to 'bedrock-anthropic'. "
f"Use an Anthropic Bedrock model ID (e.g., 'us.anthropic.claude-3-7-sonnet-20250219-v1:0')."
)
return resolved_model

resolved_model = self._resolve_model(model)

if self.provider == "anthropic":
Expand All @@ -193,18 +209,19 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str:
model = self._validate_model_for_provider(model)

try:
if not self.api_key:
env_var = (
"ANTHROPIC_API_KEY"
if self.provider == "anthropic"
else "OPENAI_API_KEY"
)
self.api_key = os.getenv(env_var)
if not self.api_key:
raise ValueError(
f"Missing API key for provider '{self.provider}'. Set the environment variable "
f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key."
)
if self.provider in ("anthropic", "openai"):
if not self.api_key:
env_var = (
"ANTHROPIC_API_KEY"
if self.provider == "anthropic"
else "OPENAI_API_KEY"
)
self.api_key = os.getenv(env_var)
if not self.api_key:
raise ValueError(
f"Missing API key for provider '{self.provider}'. Set the environment variable "
f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key."
)
# Create chat client with the specified model
if self.provider == "anthropic":
chat = ChatAnthropic(
Expand All @@ -219,22 +236,25 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str:
system_prompt=self.system_prompt,
api_key=self.api_key,
)
elif self.provider == "bedrock-anthropic":
chat = ChatBedrockAnthropic(
model=model,
system_prompt=self.system_prompt,
max_tokens=Config.MAX_TOKENS,
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")

start_time = time.perf_counter()
response = chat.chat(prompt)
elapsed = time.perf_counter() - start_time
usage = token_usage()
# For Anthropic, token_usage() includes costs. For OpenAI, use chat.get_cost with model pricing.
token_price = None
if self.provider == "openai":
token_price = Config.OPENAI_PRICING.get(model)
try:
# Call to compute and cache costs internally; per-entry cost is computed below
_ = chat.get_cost(options="all", token_price=token_price)
except Exception:
# If cost computation fails, continue without it
pass

try:
Expand Down Expand Up @@ -530,7 +550,9 @@ def generate_test_from_code(
)

def switch_provider(
self, provider: Literal["anthropic", "openai"], api_key: Optional[str] = None
self,
provider: Literal["anthropic", "openai", "bedrock-anthropic"],
api_key: Optional[str] = None,
):
self.provider = provider
if api_key:
Expand All @@ -549,6 +571,11 @@ def create_openai_generator(
) -> "ShinyTestGenerator":
return cls(provider="openai", api_key=api_key, **kwargs)

@classmethod
def create_bedrock_anthropic_generator(cls, **kwargs) -> "ShinyTestGenerator":
    """Alternate constructor preconfigured for the AWS Bedrock Anthropic provider.

    No api_key is supplied: for this provider, AWS credentials and region are
    resolved from the environment or the AWS configuration, not from an
    explicit key. Extra keyword arguments are forwarded to ``cls`` unchanged.
    """
    return cls(provider="bedrock-anthropic", api_key=None, **kwargs)

def get_available_models(self) -> list[str]:
if self.provider == "anthropic":
return [
Expand All @@ -562,6 +589,10 @@ def get_available_models(self) -> list[str]:
for model in Config.MODEL_ALIASES.keys()
if (model.startswith("gpt-") or model.startswith("o1-"))
]
elif self.provider == "bedrock-anthropic":
# Bedrock requires full model IDs (e.g., 'us.anthropic.claude-sonnet-4-20250514-v1:0').
# We don't provide aliases here because IDs are region/account specific.
return []
else:
return []

Expand All @@ -573,7 +604,7 @@ def cli():
parser.add_argument("app_file", help="Path to the Shiny app file")
parser.add_argument(
"--provider",
choices=["anthropic", "openai"],
choices=["anthropic", "openai", "bedrock-anthropic"],
default=Config.DEFAULT_PROVIDER,
help="LLM provider to use",
)
Expand Down
Loading