diff --git a/shiny/_main.py b/shiny/_main.py index f43421a71..c606c7ed8 100644 --- a/shiny/_main.py +++ b/shiny/_main.py @@ -555,14 +555,23 @@ def add() -> None: ) @click.option( "--provider", - type=click.Choice(["anthropic", "openai"]), + type=click.Choice(["anthropic", "openai", "bedrock-anthropic"]), default="anthropic", - help="AI provider to use for test generation.", + help=( + "AI provider to use for test generation. For 'bedrock-anthropic', " + "make sure your AWS credentials are configured (env vars, profile, or role) " + "and provide a Bedrock Anthropic model ID (e.g., " + "us.anthropic.claude-3-7-sonnet-20250219-v1:0)." + ), ) @click.option( "--model", type=str, - help="Specific model to use (optional). Examples: haiku3.5, sonnet, gpt-5, gpt-5-mini", + help=( + "Specific model to use (optional). Examples: haiku3.5, sonnet, gpt-5, gpt-5-mini; " + "or a Bedrock Anthropic model ID when using provider=bedrock-anthropic, e.g. " + "us.anthropic.claude-3-7-sonnet-20250219-v1:0" + ), ) # Param for app.py, param for test_name def test( diff --git a/shiny/_main_generate_test.py b/shiny/_main_generate_test.py index bdca96a7b..40f181d31 100644 --- a/shiny/_main_generate_test.py +++ b/shiny/_main_generate_test.py @@ -54,19 +54,27 @@ def validate_api_key(provider: str) -> None: "env_var": "OPENAI_API_KEY", "url": "https://platform.openai.com/api-keys", }, + "bedrock-anthropic": { + "env_var": None, + "url": "https://docs.aws.amazon.com/bedrock/latest/userguide/getting-started.html", + }, } if provider not in api_configs: raise ValidationError(f"Unsupported provider: {provider}") config = api_configs[provider] - if not os.getenv(config["env_var"]): - raise ValidationError( - f"{config['env_var']} environment variable is not set.\n" - f"Please set your {provider.title()} API key:\n" - f" export {config['env_var']}='your-api-key-here'\n\n" - f"Get your API key from: {config['url']}" - ) + if provider in ("anthropic", "openai"): + env_var = config["env_var"] # type: ignore[assignment] + if not isinstance(env_var, str) or not os.getenv(env_var): + raise ValidationError( + f"{env_var} environment variable is not set.\n" + f"Please set your {provider.title()} API key:\n" + f" export {env_var}='your-api-key-here'\n\n" + f"Get your API key from: {config['url']}" + ) + else: + pass def get_app_file_path(app_file: str | None) -> Path: diff --git a/shiny/pytest/_generate/_main.py b/shiny/pytest/_generate/_main.py index b178b0da6..3d2fbeddb 100644 --- a/shiny/pytest/_generate/_main.py +++ b/shiny/pytest/_generate/_main.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Literal, Optional, Tuple, Union -from chatlas import ChatAnthropic, ChatOpenAI, token_usage +from chatlas import ChatAnthropic, ChatBedrockAnthropic, ChatOpenAI, token_usage from dotenv import load_dotenv __all__ = [ @@ -50,7 +50,9 @@ class ShinyTestGenerator: def __init__( self, - provider: Literal["anthropic", "openai"] = Config.DEFAULT_PROVIDER, + provider: Literal[ + "anthropic", "openai", "bedrock-anthropic" + ] = Config.DEFAULT_PROVIDER, api_key: Optional[str] = None, log_file: str = Config.LOG_FILE, setup_logging: bool = True, @@ -74,25 +76,28 @@ def __init__( self.setup_logging() @property - def client(self) -> Union[ChatAnthropic, ChatOpenAI]: + def client(self) -> Union[ChatAnthropic, ChatOpenAI, ChatBedrockAnthropic]: """Lazy-loaded chat client based on provider""" if self._client is None: - if not self.api_key: - env_var = ( - "ANTHROPIC_API_KEY" - if self.provider == "anthropic" - else "OPENAI_API_KEY" - ) - self.api_key = os.getenv(env_var) - if not self.api_key: - raise ValueError( - f"Missing API key for provider '{self.provider}'. Set the environment variable " - f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key explicitly." - ) + if self.provider in ("anthropic", "openai"): + if not self.api_key: + env_var = ( + "ANTHROPIC_API_KEY" + if self.provider == "anthropic" + else "OPENAI_API_KEY" + ) + self.api_key = os.getenv(env_var) + if not self.api_key: + raise ValueError( + f"Missing API key for provider '{self.provider}'. Set the environment variable " + f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key explicitly." + ) if self.provider == "anthropic": self._client = ChatAnthropic(api_key=self.api_key) elif self.provider == "openai": self._client = ChatOpenAI(api_key=self.api_key) + elif self.provider == "bedrock-anthropic": + self._client = ChatBedrockAnthropic() else: raise ValueError(f"Unsupported provider: {self.provider}") return self._client @@ -118,6 +123,8 @@ def default_model(self) -> str: return Config.DEFAULT_ANTHROPIC_MODEL elif self.provider == "openai": return Config.DEFAULT_OPENAI_MODEL + elif self.provider == "bedrock-anthropic": + return Config.DEFAULT_BEDROCK_ANTHROPIC_MODEL else: raise ValueError(f"Unsupported provider: {self.provider}") @@ -168,6 +175,15 @@ def _resolve_model(self, model: str) -> str: def _validate_model_for_provider(self, model: str) -> str: """Validate that the model is compatible with the current provider""" + if self.provider == "bedrock-anthropic": + resolved_model = model + if resolved_model.startswith("gpt-") or resolved_model.startswith("o1-"): + raise ValueError( + f"Model '{model}' is an OpenAI model but provider is set to 'bedrock-anthropic'. " + f"Use an Anthropic Bedrock model ID (e.g., 'us.anthropic.claude-3-7-sonnet-20250219-v1:0')." + ) + return resolved_model + resolved_model = self._resolve_model(model) if self.provider == "anthropic": @@ -193,18 +209,19 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str: model = self._validate_model_for_provider(model) try: - if not self.api_key: - env_var = ( - "ANTHROPIC_API_KEY" - if self.provider == "anthropic" - else "OPENAI_API_KEY" - ) - self.api_key = os.getenv(env_var) - if not self.api_key: - raise ValueError( - f"Missing API key for provider '{self.provider}'. Set the environment variable " - f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key." - ) + if self.provider in ("anthropic", "openai"): + if not self.api_key: + env_var = ( + "ANTHROPIC_API_KEY" + if self.provider == "anthropic" + else "OPENAI_API_KEY" + ) + self.api_key = os.getenv(env_var) + if not self.api_key: + raise ValueError( + f"Missing API key for provider '{self.provider}'. Set the environment variable " + f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key." + ) # Create chat client with the specified model if self.provider == "anthropic": chat = ChatAnthropic( @@ -219,6 +236,12 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str: system_prompt=self.system_prompt, api_key=self.api_key, ) + elif self.provider == "bedrock-anthropic": + chat = ChatBedrockAnthropic( + model=model, + system_prompt=self.system_prompt, + max_tokens=Config.MAX_TOKENS, + ) else: raise ValueError(f"Unsupported provider: {self.provider}") @@ -226,15 +249,12 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str: response = chat.chat(prompt) elapsed = time.perf_counter() - start_time usage = token_usage() - # For Anthropic, token_usage() includes costs. For OpenAI, use chat.get_cost with model pricing. token_price = None if self.provider == "openai": token_price = Config.OPENAI_PRICING.get(model) try: - # Call to compute and cache costs internally; per-entry cost is computed below _ = chat.get_cost(options="all", token_price=token_price) except Exception: - # If cost computation fails, continue without it pass try: @@ -530,7 +550,9 @@ def generate_test_from_code( ) def switch_provider( - self, provider: Literal["anthropic", "openai"], api_key: Optional[str] = None + self, + provider: Literal["anthropic", "openai", "bedrock-anthropic"], + api_key: Optional[str] = None, ): self.provider = provider if api_key: @@ -549,6 +571,11 @@ def create_openai_generator( ) -> "ShinyTestGenerator": return cls(provider="openai", api_key=api_key, **kwargs) + @classmethod + def create_bedrock_anthropic_generator(cls, **kwargs) -> "ShinyTestGenerator": + # AWS credentials and region are resolved from environment or AWS config + return cls(provider="bedrock-anthropic", api_key=None, **kwargs) + def get_available_models(self) -> list[str]: if self.provider == "anthropic": return [ @@ -562,6 +589,10 @@ def get_available_models(self) -> list[str]: for model in Config.MODEL_ALIASES.keys() if (model.startswith("gpt-") or model.startswith("o1-")) ] + elif self.provider == "bedrock-anthropic": + # Bedrock requires full model IDs (e.g., 'us.anthropic.claude-sonnet-4-20250514-v1:0'). + # We don't provide aliases here because IDs are region/account specific. + return [] else: return [] @@ -573,7 +604,7 @@ def cli(): parser.add_argument("app_file", help="Path to the Shiny app file") parser.add_argument( "--provider", - choices=["anthropic", "openai"], + choices=["anthropic", "openai", "bedrock-anthropic"], default=Config.DEFAULT_PROVIDER, help="LLM provider to use", )