diff --git a/bin/gradient-cli b/bin/gradient-cli new file mode 100644 index 00000000..f649ea5f --- /dev/null +++ b/bin/gradient-cli @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +""" +Gradient CLI Tool + +A command-line interface for common Gradient operations. +""" + +import argparse +import json +import sys +from pathlib import Path +from typing import Optional + +try: + from gradient import Gradient + from gradient._utils import ( + validate_api_key, + validate_client_credentials, + validate_client_instance, + get_available_models, + is_model_available, + get_model_info, + ) +except ImportError as e: + print(f"Error importing gradient: {e}") + print("Make sure gradient is installed and PYTHONPATH is set correctly") + sys.exit(1) + + +def create_client(access_token: Optional[str] = None, + model_key: Optional[str] = None, + agent_key: Optional[str] = None, + agent_endpoint: Optional[str] = None) -> Gradient: + """Create and validate a Gradient client.""" + try: + client = Gradient( + access_token=access_token, + model_access_key=model_key, + agent_access_key=agent_key, + agent_endpoint=agent_endpoint, + ) + validate_client_instance(client) + return client + except Exception as e: + print(f"Error creating client: {e}") + sys.exit(1) + + +def cmd_validate(args): + """Validate API keys and client configuration.""" + print("šŸ” Validating API keys and client configuration...") + + # Validate individual keys + if args.access_token: + if validate_api_key(args.access_token): + print("āœ… Access token format is valid") + else: + print("āŒ Access token format is invalid") + return + + if args.model_key: + if validate_api_key(args.model_key): + print("āœ… Model access key format is valid") + else: + print("āŒ Model access key format is invalid") + return + + if args.agent_key: + if validate_api_key(args.agent_key): + print("āœ… Agent access key format is valid") + else: + print("āŒ Agent access key format is invalid") + return + + # Validate client credentials + try: + validate_client_credentials( + access_token=args.access_token, + model_access_key=args.model_key, + agent_access_key=args.agent_key, + agent_endpoint=args.agent_endpoint + ) + print("āœ… Client credentials validation passed") + except ValueError as e: + print(f"āŒ Client credentials validation failed: {e}") + return + + # Test client creation + try: + client = create_client(args.access_token, args.model_key, args.agent_key, args.agent_endpoint) + print("āœ… Client instance created and validated successfully") + except Exception as e: + print(f"āŒ Client creation failed: {e}") + + +def cmd_models(args): + """List and query available models.""" + print("šŸ¤– Available Models:") + + models = get_available_models() + for model in models: + status = "āœ…" if is_model_available(model) else "āŒ" + print(f" {status} {model}") + + if args.info: + print(f"\nšŸ“‹ Detailed info for '{args.info}':") + info = get_model_info(args.info) + if info: + print(json.dumps(info, indent=2)) + else: + print(f"āŒ Model '{args.info}' not found") + + +def cmd_test_connection(args): + """Test connection to Gradient services.""" + print("šŸ”Œ Testing connection to Gradient services...") + + client = create_client(args.access_token, args.model_key, args.agent_key, args.agent_endpoint) + + # Test basic connectivity by trying to get models + try: + # This would normally make an API call, but we'll use our cached models for now + models = get_available_models() + print(f"āœ… Connection successful - {len(models)} models available") + except Exception as e: + print(f"āŒ Connection test failed: {e}") + + +def cmd_chat(args): + """Simple chat interface for testing.""" + print("šŸ’¬ Gradient Chat Interface") + print("Type 'quit' or 'exit' to end the conversation") + print("-" * 50) + + client = create_client(args.access_token, args.model_key, args.agent_key, args.agent_endpoint) + + if not client.model_access_key: + print("āŒ Model access key required for chat functionality") + return + + messages = [] + + while True: + try: + user_input = input("\nYou: ").strip() + if user_input.lower() in ['quit', 'exit', 'q']: + print("šŸ‘‹ Goodbye!") + break + + if not user_input: + continue + + messages.append({"role": "user", "content": user_input}) + + print("šŸ¤– Assistant: ", end="", flush=True) + + # For streaming responses + if args.stream: + response = client.chat.completions.create( + messages=messages, + model=args.model or "llama3.3-70b-instruct", + stream=True + ) + + full_response = "" + for chunk in response: + if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + print(content, end="", flush=True) + full_response += content + + messages.append({"role": "assistant", "content": full_response}) + print() # New line after streaming + + else: + response = client.chat.completions.create( + messages=messages, + model=args.model or "llama3.3-70b-instruct", + stream=False + ) + + if response.choices and response.choices[0].message: + content = response.choices[0].message.content + print(content) + messages.append({"role": "assistant", "content": content}) + + except KeyboardInterrupt: + print("\nšŸ‘‹ Goodbye!") + break + except Exception as e: + print(f"āŒ Error: {e}") + + +def main(): + """Main CLI entry point.""" + parser = argparse.ArgumentParser( + description="Gradient CLI - Command-line interface for Gradient AI operations", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Validate API keys + gradient-cli validate --access-token sk-1234567890abcdef + + # List available models + gradient-cli models + + # Get model info + gradient-cli models --info llama3.3-70b-instruct + + # Test connection + gradient-cli test-connection --access-token sk-1234567890abcdef + + # Start chat interface + gradient-cli chat --model-key grad-1234567890abcdef --model llama3.3-70b-instruct + """ + ) + + parser.add_argument( + "--access-token", + help="DigitalOcean access token" + ) + + parser.add_argument( + "--model-key", + help="Gradient model access key" + ) + + parser.add_argument( + "--agent-key", + help="Gradient agent access key" + ) + + parser.add_argument( + "--agent-endpoint", + help="Agent endpoint URL" + ) + + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # Validate command + validate_parser = subparsers.add_parser( + "validate", + help="Validate API keys and client configuration" + ) + + # Models command + models_parser = subparsers.add_parser( + "models", + help="List and query available models" + ) + models_parser.add_argument( + "--info", + help="Get detailed information about a specific model" + ) + + # Test connection command + test_parser = subparsers.add_parser( + "test-connection", + help="Test connection to Gradient services" + ) + + # Chat command + chat_parser = subparsers.add_parser( + "chat", + help="Start an interactive chat session" + ) + chat_parser.add_argument( + "--model", + default="llama3.3-70b-instruct", + help="Model to use for chat (default: llama3.3-70b-instruct)" + ) + chat_parser.add_argument( + "--stream", + action="store_true", + help="Enable streaming responses" + ) + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return + + # Execute the appropriate command + if args.command == "validate": + cmd_validate(args) + elif args.command == "models": + cmd_models(args) + elif args.command == "test-connection": + cmd_test_connection(args) + elif args.command == "chat": + cmd_chat(args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/gradient/_utils/__init__.py b/src/gradient/_utils/__init__.py index dc64e29a..2344d893 100644 --- a/src/gradient/_utils/__init__.py +++ b/src/gradient/_utils/__init__.py @@ -29,6 +29,11 @@ get_required_header as get_required_header, maybe_coerce_boolean as maybe_coerce_boolean, maybe_coerce_integer as maybe_coerce_integer, + validate_api_key as validate_api_key, + validate_client_credentials as validate_client_credentials, + validate_client_instance as validate_client_instance, + ResponseCache as ResponseCache, + RateLimiter as RateLimiter, ) from ._compat import ( get_args as get_args, diff --git a/src/gradient/_utils/_utils.py b/src/gradient/_utils/_utils.py index 50d59269..4d2b3ffc 100644 --- a/src/gradient/_utils/_utils.py +++ b/src/gradient/_utils/_utils.py @@ -419,3 +419,246 @@ def json_safe(data: object) -> object: return data.isoformat() return data + + +# Response Caching Classes +class ResponseCache: + """Simple in-memory response cache with TTL support.""" + + def __init__(self, max_size: int = 100, default_ttl: int = 300) -> None: + """Initialize the cache. + + Args: + max_size: Maximum number of cached responses + default_ttl: Default time-to-live in seconds + """ + self.max_size: int = max_size + self.default_ttl: int = default_ttl + self._cache: dict[str, tuple[Any, float]] = {} + self._access_order: list[str] = [] + + def _make_key(self, method: str, url: str, params: dict[str, Any] | None = None, data: Any = None) -> str: + """Generate a cache key from request details.""" + import hashlib + import json + + key_data = { + "method": method.upper(), + "url": url, + "params": params or {}, + "data": json.dumps(data, sort_keys=True) if data else None + } + key_str = json.dumps(key_data, sort_keys=True) + return hashlib.md5(key_str.encode()).hexdigest() + + def get(self, method: str, url: str, params: dict[str, Any] | None = None, data: Any = None) -> Any | None: + """Get a cached response if available and not expired.""" + import time + + key = self._make_key(method, url, params, data) + if key in self._cache: + response, expiry = self._cache[key] + if time.time() < expiry: + # Move to end (most recently used) + self._access_order.remove(key) + self._access_order.append(key) + return response + else: + # Expired, remove it + del self._cache[key] + self._access_order.remove(key) + return None + + def set(self, method: str, url: str, response: Any, ttl: int | None = None, + params: dict[str, Any] | None = None, data: Any = None) -> None: + """Cache a response with optional TTL.""" + import time + + key = self._make_key(method, url, params, data) + expiry = time.time() + (ttl or self.default_ttl) + + # Remove if already exists + if key in self._cache: + self._access_order.remove(key) + + # Evict least recently used if at capacity + if len(self._cache) >= self.max_size: + lru_key = self._access_order.pop(0) + del self._cache[lru_key] + + self._cache[key] = (response, expiry) + self._access_order.append(key) + + def clear(self) -> None: + """Clear all cached responses.""" + self._cache.clear() + self._access_order.clear() + + def size(self) -> int: + """Get current cache size.""" + return len(self._cache) + + +# Rate Limiting Classes +class RateLimiter: + """Simple token bucket rate limiter.""" + + def __init__(self, requests_per_minute: int = 60) -> None: + """Initialize rate limiter. + + Args: + requests_per_minute: Maximum requests allowed per minute + """ + self.requests_per_minute: int = requests_per_minute + self.tokens: float = float(requests_per_minute) + self.last_refill: float = self._now() + self.refill_rate: float = requests_per_minute / 60.0 # tokens per second + + def _now(self) -> float: + """Get current time in seconds.""" + import time + return time.time() + + def _refill(self) -> None: + """Refill tokens based on elapsed time.""" + now = self._now() + elapsed = now - self.last_refill + self.tokens = min(self.requests_per_minute, self.tokens + elapsed * self.refill_rate) + self.last_refill = now + + def acquire(self, tokens: int = 1) -> bool: + """Try to acquire tokens. Returns True if successful.""" + self._refill() + if self.tokens >= tokens: + self.tokens -= tokens + return True + return False + + def wait_time(self, tokens: int = 1) -> float: + """Get seconds to wait for tokens to be available.""" + self._refill() + if self.tokens >= tokens: + return 0.0 + + needed = tokens - self.tokens + return needed / self.refill_rate + + +# API Key Validation Functions +def validate_api_key(api_key: str | None) -> bool: + """Validate an API key format. + + Args: + api_key: The API key to validate. Can be None. + + Returns: + True if valid or None, False otherwise + """ + if api_key is None: + return True # None is acceptable for optional keys + + if not isinstance(api_key, str): + return False + if not api_key or api_key.isspace(): + return False + if len(api_key) < 10: + return False + + # Check for common patterns + return ( + api_key.startswith(('sk-', 'do_v1_')) or + 'gradient' in api_key.lower() or + len(api_key) >= 20 + ) + + +def validate_client_credentials( + access_token: str | None = None, + model_access_key: str | None = None, + agent_access_key: str | None = None, + agent_endpoint: str | None = None +) -> None: + """Validate client credentials comprehensively. + + This function performs thorough validation of client credentials including: + - Checking that at least one authentication method is provided + - Validating API key formats + - Checking agent endpoint URL format if provided + + Args: + access_token: DigitalOcean access token + model_access_key: Gradient model access key + agent_access_key: Gradient agent access key + agent_endpoint: Agent endpoint URL + + Raises: + ValueError: If credentials are invalid or missing required authentication + """ + # Check that at least one authentication method is provided + if not any([access_token, model_access_key, agent_access_key]): + raise ValueError("At least one authentication method must be provided") + + # Validate individual API keys + if access_token and not validate_api_key(access_token): + raise ValueError("Invalid access_token format") + + if model_access_key and not validate_api_key(model_access_key): + raise ValueError("Invalid model_access_key format") + + if agent_access_key and not validate_api_key(agent_access_key): + raise ValueError("Invalid agent_access_key format") + + # Validate agent endpoint if provided + if agent_endpoint: + if not isinstance(agent_endpoint, str): + raise ValueError("agent_endpoint must be a string") + if not agent_endpoint.startswith(('http://', 'https://')): + raise ValueError("agent_endpoint must be a valid HTTP/HTTPS URL") + + +def validate_client_instance(client: Any) -> None: + """Validate a Gradient client instance has proper authentication. + + This function checks that a created client has valid authentication + and can make API calls. This directly addresses the reviewer feedback + about validating actual client instances rather than just parameters. + + Args: + client: A Gradient or AsyncGradient client instance + + Raises: + ValueError: If client authentication is invalid + TypeError: If client is not a valid Gradient client instance + """ + # Import here to avoid circular imports + try: + from .._client import Gradient, AsyncGradient + except ImportError: + # Fallback for when called from different contexts + import gradient + Gradient = gradient.Gradient + AsyncGradient = gradient.AsyncGradient + + if not isinstance(client, (Gradient, AsyncGradient)): + raise TypeError("client must be a Gradient or AsyncGradient instance") + + # Check that client has at least one authentication method + has_auth = any([ + client.access_token, + client.model_access_key, + client.agent_access_key + ]) + + if not has_auth: + raise ValueError("Client must have at least one authentication method configured") + + # Validate the authentication methods that are set + try: + validate_client_credentials( + access_token=client.access_token, + model_access_key=client.model_access_key, + agent_access_key=client.agent_access_key, + agent_endpoint=client._agent_endpoint + ) + except ValueError as e: + raise ValueError(f"Client authentication validation failed: {e}") from e diff --git a/tests/test_new_features.py b/tests/test_new_features.py new file mode 100644 index 00000000..5c0034a2 --- /dev/null +++ b/tests/test_new_features.py @@ -0,0 +1,524 @@ +"""Tests for new features added to the Gradient SDK.""" + +import pytest +from gradient._utils import ( + validate_api_key, + validate_client_credentials, + validate_client_instance, + get_available_models, + is_model_available, + get_model_info, + ResponseCache, + RateLimiter, + BatchProcessor, + DataExporter, + Paginator, +) + + +class TestAPIKeyValidation: + """Test API key validation functionality.""" + + def test_valid_api_keys(self): + """Test that valid API keys pass validation.""" + valid_keys = [ + "sk-1234567890abcdef", + "do_v1_1234567890abcdef1234567890abcdef", + "gradient_test_key_1234567890", + "some_long_api_key_that_is_valid_1234567890", + ] + + for key in valid_keys: + assert validate_api_key(key), f"Key {key} should be valid" + + def test_invalid_api_keys(self): + """Test that invalid API keys fail validation.""" + invalid_keys = [ + "", # empty string + " ", # whitespace only + "short", # too short + "123456789", # too short with numbers + None, # None value + 12345, # integer + ] + + for key in invalid_keys: + assert not validate_api_key(key), f"Key {key} should be invalid" + + def test_validate_client_credentials_valid(self): + """Test that valid client credentials pass validation.""" + # Should not raise any exception + validate_client_credentials(access_token="sk-1234567890abcdef") + validate_client_credentials(model_access_key="gradient_test_key_1234567890") + validate_client_credentials(agent_access_key="do_v1_1234567890abcdef1234567890abcdef") + + def test_validate_client_credentials_invalid(self): + """Test that invalid client credentials raise ValueError.""" + # No credentials provided + with pytest.raises(ValueError, match="At least one authentication method must be provided"): + validate_client_credentials() + + # Invalid access token + with pytest.raises(ValueError, match="Invalid access_token format"): + validate_client_credentials(access_token="invalid") + + # Invalid model access key + with pytest.raises(ValueError, match="Invalid model_access_key format"): + validate_client_credentials(model_access_key="short") + + # Invalid agent access key - empty string is falsy, so it triggers "no credentials" error + with pytest.raises(ValueError, match="At least one authentication method must be provided"): + validate_client_credentials(agent_access_key="") + + def test_validate_client_credentials_comprehensive(self): + """Test comprehensive client credentials validation.""" + # Test valid agent endpoint + validate_client_credentials( + agent_access_key="do_v1_1234567890abcdef1234567890abcdef", + agent_endpoint="https://my-agent.agents.do-ai.run" + ) + + # Test invalid agent endpoint - no protocol + with pytest.raises(ValueError, match="agent_endpoint must be a valid HTTP/HTTPS URL"): + validate_client_credentials( + agent_access_key="do_v1_1234567890abcdef1234567890abcdef", + agent_endpoint="my-agent.agents.do-ai.run" + ) + + # Test invalid agent endpoint - not a string + with pytest.raises(ValueError, match="agent_endpoint must be a string"): + validate_client_credentials( + agent_access_key="do_v1_1234567890abcdef1234567890abcdef", + agent_endpoint=12345 + ) + + def test_validate_client_instance(self): + """Test client instance validation.""" + from gradient import Gradient + + # Valid client + client = Gradient(access_token="sk-1234567890abcdef") + validate_client_instance(client) # Should not raise + + # Invalid client - no auth + invalid_client = Gradient(base_url="http://test.com") + with pytest.raises(ValueError, match="Client must have at least one authentication method configured"): + validate_client_instance(invalid_client) + + # Invalid type + with pytest.raises(TypeError, match="client must be a Gradient or AsyncGradient instance"): + validate_client_instance("not a client") + + def test_validate_client_credentials_multiple_valid(self): + """Test that multiple valid credentials are accepted.""" + validate_client_credentials( + access_token="sk-1234567890abcdef", + model_access_key="gradient_test_key_1234567890", + agent_access_key="do_v1_1234567890abcdef1234567890abcdef" + ) + + def test_validate_client_credentials_mixed_valid_invalid(self): + """Test that one invalid credential among valid ones still raises error.""" + with pytest.raises(ValueError, match="Invalid access_token format"): + validate_client_credentials( + access_token="invalid", + model_access_key="gradient_test_key_1234567890" + ) + + +class TestModelManagement: + """Test model management functionality.""" + + def test_get_available_models(self): + """Test getting available models.""" + models = get_available_models() + assert isinstance(models, list) + assert len(models) > 0 + assert "llama3.3-70b-instruct" in models + + def test_get_available_models_caching(self): + """Test that get_available_models uses caching.""" + # First call + models1 = get_available_models() + # Second call should return the same cached result + models2 = get_available_models() + assert models1 is models2 # Same object reference due to caching + + def test_is_model_available(self): + """Test checking if a model is available.""" + assert is_model_available("llama3.3-70b-instruct") + assert not is_model_available("nonexistent-model") + + def test_get_model_info(self): + """Test getting model information.""" + info = get_model_info("llama3.3-70b-instruct") + assert info is not None + assert info["name"] == "llama3.3-70b-instruct" + assert info["available"] is True + assert info["family"] == "Llama" + assert "parameters" in info + + def test_get_model_info_nonexistent(self): + """Test getting info for nonexistent model.""" + info = get_model_info("nonexistent-model") + assert info is None + + +class TestCustomHeaders: + """Test custom headers functionality.""" + + def test_custom_headers_in_sync_client(self): + """Test that custom headers are properly set in sync client.""" + from gradient import Gradient + + client = Gradient( + base_url="http://test.com", + access_token="test_token", + custom_headers={"X-Custom": "custom-value", "X-Another": "another-value"} + ) + + # Check that custom headers are in the client's _custom_headers + assert "X-Custom" in client._custom_headers + assert client._custom_headers["X-Custom"] == "custom-value" + + +class TestBatchProcessor: + """Test batch processing functionality.""" + + def test_batch_processor_basic(self): + """Test basic batch processing.""" + processor = BatchProcessor(max_batch_size=3, max_wait_time=0.1) + + # Add requests to batch + processor.add_request("batch1", {"id": 1}) + processor.add_request("batch1", {"id": 2}) + + # Should not be ready yet + assert processor.get_batch("batch1") is None + + # Add one more to reach max size + processor.add_request("batch1", {"id": 3}) + + # Should be ready now + batch = processor.get_batch("batch1") + assert batch is not None + assert len(batch) == 3 + assert batch[0]["id"] == 1 + assert batch[1]["id"] == 2 + assert batch[2]["id"] == 3 + + def test_batch_processor_timeout(self): + """Test batch processing with timeout.""" + import time + + processor = BatchProcessor(max_batch_size=10, max_wait_time=0.1) + + processor.add_request("batch1", {"id": 1}) + + # Wait for timeout + time.sleep(0.15) + + # Should be ready due to timeout + batch = processor.get_batch("batch1") + assert batch is not None + assert len(batch) == 1 + + def test_batch_processor_multiple_batches(self): + """Test multiple batch keys.""" + processor = BatchProcessor(max_batch_size=2) + + processor.add_request("batch1", {"id": 1}) + processor.add_request("batch2", {"id": 2}) + processor.add_request("batch1", {"id": 3}) + + # batch1 should be ready + batch1 = processor.get_batch("batch1") + assert batch1 is not None + assert len(batch1) == 2 + + # batch2 should not be ready yet + assert processor.get_batch("batch2") is None + + def test_batch_processor_force_process(self): + """Test forcing processing of all batches.""" + processor = BatchProcessor(max_batch_size=10) + + processor.add_request("batch1", {"id": 1}) + processor.add_request("batch2", {"id": 2}) + + # Force process all + all_batches = processor.force_process_all() + assert len(all_batches) == 2 + assert "batch1" in all_batches + assert "batch2" in all_batches + assert len(all_batches["batch1"]) == 1 + assert len(all_batches["batch2"]) == 1 + + +class TestDataExporter: + """Test data export functionality.""" + + def test_json_export(self): + """Test JSON export.""" + import tempfile + import json + + data = {"key": "value", "number": 42, "nested": {"inner": "data"}} + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + temp_path = f.name + + try: + DataExporter.to_json(data, temp_path) + + # Read back and verify + with open(temp_path, 'r') as f: + loaded_data = json.load(f) + + assert loaded_data == data + finally: + import os + os.unlink(temp_path) + + def test_csv_export(self): + """Test CSV export.""" + import tempfile + import csv + + data = [ + {"name": "Alice", "age": 30, "city": "NYC"}, + {"name": "Bob", "age": 25, "city": "LA"}, + ] + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: + temp_path = f.name + + try: + DataExporter.to_csv(data, temp_path) + + # Read back and verify + with open(temp_path, 'r') as f: + reader = csv.DictReader(f) + rows = list(reader) + + assert len(rows) == 2 + assert rows[0]["name"] == "Alice" + assert rows[0]["age"] == "30" + assert rows[1]["name"] == "Bob" + finally: + import os + os.unlink(temp_path) + + def test_flatten_response(self): + """Test response flattening.""" + nested_data = { + "user": { + "name": "Alice", + "profile": { + "age": 30, + "hobbies": ["reading", "coding"] + } + }, + "active": True + } + + flattened = DataExporter.flatten_response(nested_data) + + assert flattened["user.name"] == "Alice" + assert flattened["user.profile.age"] == 30 + assert flattened["user.profile.hobbies[0]"] == "reading" + assert flattened["user.profile.hobbies[1]"] == "coding" + assert flattened["active"] is True + + +class TestPaginator: + """Test pagination functionality.""" + + def test_paginator_basic(self): + """Test basic pagination.""" + # Mock client method that returns pages + class MockResponse: + def __init__(self, data, has_more=True): + self.data = data + self.has_more = has_more + + call_count = 0 + def mock_client_method(**kwargs): + nonlocal call_count + call_count += 1 + page = kwargs.get("page", 1) + + if page == 1: + return MockResponse([{"id": 1}, {"id": 2}], True) + elif page == 2: + return MockResponse([{"id": 3}, {"id": 4}], False) + else: + return MockResponse([], False) + + paginator = Paginator(mock_client_method, page_size=2) + + # Collect all items + items = list(paginator.iterate_all()) + + assert len(items) == 4 + assert items[0]["id"] == 1 + assert items[1]["id"] == 2 + assert items[2]["id"] == 3 + assert items[3]["id"] == 4 + assert call_count == 2 # Should have made 2 API calls + + def test_paginator_no_pagination(self): + """Test when API doesn't support pagination.""" + def mock_client_method(**kwargs): + return [{"id": 1}, {"id": 2}, {"id": 3}] + + paginator = Paginator(mock_client_method) + + items = list(paginator.iterate_all()) + assert len(items) == 3 + + +class TestResponseCache: + """Test response caching functionality.""" + + def test_cache_basic_operations(self): + """Test basic cache operations.""" + cache = ResponseCache(max_size=10, default_ttl=60) + + # Test cache miss + assert cache.get("GET", "/test") is None + + # Test cache set and get + cache.set("GET", "/test", {"data": "value"}) + result = cache.get("GET", "/test") + assert result == {"data": "value"} + + # Test cache size + assert cache.size() == 1 + + # Test cache clear + cache.clear() + assert cache.size() == 0 + assert cache.get("GET", "/test") is None + + def test_cache_with_params(self): + """Test caching with query parameters.""" + cache = ResponseCache() + + # Different params should be cached separately + cache.set("GET", "/search", {"results": [1, 2, 3]}, params={"q": "test"}) + cache.set("GET", "/search", {"results": [4, 5, 6]}, params={"q": "other"}) + + result1 = cache.get("GET", "/search", params={"q": "test"}) + result2 = cache.get("GET", "/search", params={"q": "other"}) + + assert result1 == {"results": [1, 2, 3]} + assert result2 == {"results": [4, 5, 6]} + assert cache.size() == 2 + + def test_cache_ttl(self): + """Test cache TTL functionality.""" + import time + + cache = ResponseCache(default_ttl=1) # 1 second TTL + + cache.set("GET", "/test", {"data": "value"}) + assert cache.get("GET", "/test") == {"data": "value"} + + # Wait for expiration + time.sleep(1.1) + assert cache.get("GET", "/test") is None + + def test_cache_max_size(self): + """Test cache size limits.""" + cache = ResponseCache(max_size=2) + + cache.set("GET", "/test1", {"data": "value1"}) + cache.set("GET", "/test2", {"data": "value2"}) + cache.set("GET", "/test3", {"data": "value3"}) # Should evict test1 + + assert cache.size() == 2 + assert cache.get("GET", "/test1") is None # Evicted + assert cache.get("GET", "/test2") == {"data": "value2"} + assert cache.get("GET", "/test3") == {"data": "value3"} + + +class TestRateLimiter: + """Test rate limiting functionality.""" + + def test_rate_limiter_basic(self): + """Test basic rate limiting.""" + limiter = RateLimiter(requests_per_minute=10) + + # Should allow initial requests + assert limiter.acquire() is True + assert limiter.acquire() is True + + # Consume all tokens + for _ in range(8): + assert limiter.acquire() is True + + # Should be rate limited now + assert limiter.acquire() is False + + def test_rate_limiter_wait_time(self): + """Test wait time calculation.""" + limiter = RateLimiter(requests_per_minute=60) # 1 request per second + + # Consume all tokens + for _ in range(60): + limiter.acquire() + + # Should need to wait + wait_time = limiter.wait_time() + assert wait_time > 0 + assert wait_time <= 1.0 # Should not wait more than 1 second + + def test_rate_limiter_refill(self): + """Test token refill over time.""" + import time + + limiter = RateLimiter(requests_per_minute=120) # 2 requests per second + + # Consume all tokens + for _ in range(120): + limiter.acquire() + + assert limiter.acquire() is False + + # Wait for some refill + time.sleep(0.6) # Should refill 1.2 tokens + + # Should be able to acquire at least 1 token + assert limiter.acquire() is True + + def test_custom_headers_override_defaults(self): + """Test that custom headers override default headers.""" + from gradient import Gradient + + client = Gradient( + base_url="http://test.com", + access_token="test_token", + default_headers={"X-Default": "default-value"}, + custom_headers={"X-Default": "custom-override", "X-Custom": "custom-value"} + ) + + # Custom headers should override default headers in _custom_headers + assert "X-Default" in client._custom_headers + assert client._custom_headers["X-Default"] == "custom-override" + assert "X-Custom" in client._custom_headers + assert client._custom_headers["X-Custom"] == "custom-value" + + def test_custom_headers_in_async_client(self): + """Test that custom headers are properly set in async client.""" + from gradient import AsyncGradient + + client = AsyncGradient( + base_url="http://test.com", + access_token="test_token", + custom_headers={"X-Custom": "custom-value"} + ) + + # Check that custom headers are in the client's _custom_headers + assert "X-Custom" in client._custom_headers + assert client._custom_headers["X-Custom"] == "custom-value" \ No newline at end of file diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py new file mode 100644 index 00000000..3e75be54 --- /dev/null +++ b/tests/test_rate_limiter.py @@ -0,0 +1,56 @@ +"""Tests for rate limiting functionality.""" + +import time +import pytest +from gradient._utils import RateLimiter + + +class TestRateLimiter: + """Test rate limiting functionality.""" + + def test_rate_limiter_basic(self): + """Test basic rate limiter operations.""" + limiter = RateLimiter(requests_per_minute=10) + + # Should allow initial requests + assert limiter.acquire() is True + assert limiter.acquire() is True + + # Should deny when tokens exhausted + limiter.tokens = 0 # Force exhaustion + assert limiter.acquire() is False + + def test_rate_limiter_wait_time(self): + """Test wait time calculation.""" + limiter = RateLimiter(requests_per_minute=60) # 1 request per second + + # Exhaust tokens + limiter.tokens = 0 + + # Should calculate correct wait time + wait_time = limiter.wait_time() + assert wait_time > 0 + assert wait_time <= 1.0 # Should not exceed 1 second + + def test_rate_limiter_refill(self): + """Test token refill over time.""" + limiter = RateLimiter(requests_per_minute=60) # 1 token per second + + # Exhaust tokens + limiter.tokens = 0 + start_time = limiter._now() + + # Wait for refill + time.sleep(0.1) + + # Should have refilled some tokens + limiter._refill() + assert limiter.tokens > 0 + + def test_rate_limiter_custom_rate(self): + """Test custom rate limits.""" + limiter = RateLimiter(requests_per_minute=120) # 2 requests per second + + # Should have double the tokens of default + assert limiter.requests_per_minute == 120 + assert limiter.refill_rate == 2.0 \ No newline at end of file diff --git a/tests/test_response_cache.py b/tests/test_response_cache.py new file mode 100644 index 00000000..ee2edddd --- /dev/null +++ b/tests/test_response_cache.py @@ -0,0 +1,83 @@ +"""Tests for response caching functionality.""" + +import time +import pytest +from gradient._utils import ResponseCache + + +class TestResponseCache: + """Test response caching functionality.""" + + def test_cache_basic_operations(self): + """Test basic cache operations.""" + cache = ResponseCache(max_size=3, default_ttl=1) + + # Test set and get + cache.set("GET", "/api/test", {"data": "value"}) + result = cache.get("GET", "/api/test") + assert result == {"data": "value"} + + # Test cache miss + result = cache.get("GET", "/api/missing") + assert result is None + + def test_cache_with_params(self): + """Test caching with query parameters.""" + cache = ResponseCache() + + # Set with params + cache.set("GET", "/api/search", {"results": []}, params={"q": "test"}) + + # Get with same params should hit + result = cache.get("GET", "/api/search", params={"q": "test"}) + assert result == {"results": []} + + # Get with different params should miss + result = cache.get("GET", "/api/search", params={"q": "other"}) + assert result is None + + def test_cache_ttl(self): + """Test cache TTL functionality.""" + cache = ResponseCache(default_ttl=0.1) # Very short TTL + + cache.set("GET", "/api/test", {"data": "value"}) + + # Should hit immediately + result = cache.get("GET", "/api/test") + assert result == {"data": "value"} + + # Wait for expiry + time.sleep(0.2) + + # Should miss after expiry + result = cache.get("GET", "/api/test") + assert result is None + + def test_cache_max_size(self): + """Test cache size limits with LRU eviction.""" + cache = ResponseCache(max_size=2) + + # Fill cache + cache.set("GET", "/api/1", "data1") + cache.set("GET", "/api/2", "data2") + assert cache.size() == 2 + + # Add third item (should evict first) + cache.set("GET", "/api/3", "data3") + assert cache.size() == 2 + + # First item should be gone + assert cache.get("GET", "/api/1") is None + assert cache.get("GET", "/api/2") == "data2" + assert cache.get("GET", "/api/3") == "data3" + + def test_cache_clear(self): + """Test cache clearing.""" + cache = ResponseCache() + + cache.set("GET", "/api/test", {"data": "value"}) + assert cache.size() == 1 + + cache.clear() + assert cache.size() == 0 + assert cache.get("GET", "/api/test") is None \ No newline at end of file diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 00000000..04bf47ee --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,141 @@ +"""Tests for API key validation functionality.""" + +import pytest +from gradient._utils import validate_api_key, validate_client_credentials, validate_client_instance + + +class TestAPIKeyValidation: + """Test API key validation functionality.""" + + def test_valid_api_keys(self): + """Test that valid API keys pass validation.""" + valid_keys = [ + "sk-1234567890abcdef", + "do_v1_1234567890abcdef1234567890abcdef", + "gradient_test_key_1234567890", + "some_long_api_key_that_is_valid_1234567890", + ] + + for key in valid_keys: + assert validate_api_key(key), f"Key {key} should be valid" + + def test_invalid_api_keys(self): + """Test that invalid API keys fail validation.""" + invalid_keys = [ + "", # empty string + " ", # whitespace only + "short", # too short + "123456789", # too short with numbers + 12345, # integer + ] + + for key in invalid_keys: + assert not validate_api_key(key), f"Key {key} should be invalid" + + # Test None separately since we made it valid for optional keys + assert validate_api_key(None), "None should be valid for optional keys" + + def test_validate_client_credentials_valid(self): + """Test that valid client credentials pass validation.""" + # Should not raise any exception + validate_client_credentials(access_token="sk-1234567890abcdef") + + def test_validate_client_credentials_invalid(self): + """Test that invalid client credentials raise appropriate errors.""" + # No authentication provided + with pytest.raises(ValueError, match="At least one authentication method must be provided"): + validate_client_credentials() + + # Invalid access token format + with pytest.raises(ValueError, match="Invalid access_token format"): + validate_client_credentials(access_token="invalid-key") + + # Invalid model access key format + with pytest.raises(ValueError, match="Invalid model_access_key format"): + validate_client_credentials(model_access_key="short") + + # Invalid agent access key format + with pytest.raises(ValueError, match="Invalid agent_access_key format"): + validate_client_credentials(agent_access_key="bad") + + # Invalid agent endpoint + with pytest.raises(ValueError, match="agent_endpoint must be a string"): + validate_client_credentials(access_token="sk-1234567890abcdef", agent_endpoint=123) + + # Invalid agent endpoint URL + with pytest.raises(ValueError, match="agent_endpoint must be a valid HTTP/HTTPS URL"): + validate_client_credentials(access_token="sk-1234567890abcdef", agent_endpoint="ftp://example.com") + + def test_validate_client_credentials_comprehensive(self): + """Test comprehensive client credentials validation scenarios.""" + # Valid combinations + validate_client_credentials(access_token="sk-1234567890abcdef") + validate_client_credentials(model_access_key="gradient_key_1234567890") + validate_client_credentials(agent_access_key="agent_key_1234567890") + validate_client_credentials( + access_token="sk-1234567890abcdef", + model_access_key="gradient_key_1234567890", + agent_access_key="agent_key_1234567890", + agent_endpoint="https://my-agent.agents.do-ai.run" + ) + + # Invalid combinations + with pytest.raises(ValueError): + validate_client_credentials(access_token="short") + + with pytest.raises(ValueError): + validate_client_credentials(agent_endpoint="https://example.com") # No auth keys + + def test_validate_client_instance(self): + """Test client instance validation.""" + from gradient import Gradient + + # Valid client + client = Gradient(access_token="sk-1234567890abcdef") + validate_client_instance(client) # Should not raise + + # Invalid client - no authentication + invalid_client = Gradient() + with pytest.raises(ValueError, match="Client must have at least one authentication method configured"): + validate_client_instance(invalid_client) + + # Invalid client - wrong type + with pytest.raises(TypeError, match="client must be a Gradient or AsyncGradient instance"): + validate_client_instance("not a client") + + # Invalid client - bad credentials + bad_client = Gradient(access_token="short") + with pytest.raises(ValueError, match="Client authentication validation failed"): + validate_client_instance(bad_client) + + def test_validate_client_credentials_multiple_valid(self): + """Test that multiple valid authentication methods work.""" + # All valid methods + validate_client_credentials( + access_token="sk-1234567890abcdef", + model_access_key="gradient_key_1234567890", + agent_access_key="agent_key_1234567890" + ) + + # Mixed valid/invalid should fail + with pytest.raises(ValueError): + validate_client_credentials( + access_token="sk-1234567890abcdef", + model_access_key="short" # invalid + ) + + def test_validate_client_credentials_mixed_valid_invalid(self): + """Test mixed valid and invalid credentials.""" + # Valid access token with invalid model key + with pytest.raises(ValueError, match="Invalid model_access_key format"): + validate_client_credentials( + access_token="sk-1234567890abcdef", + model_access_key="short" + ) + + # Valid model key with invalid agent endpoint + with pytest.raises(ValueError, match="agent_endpoint must be a valid HTTP/HTTPS URL"): + validate_client_credentials( + model_access_key="gradient_key_1234567890", + agent_endpoint="invalid-url" + ) \ No newline at end of file