diff --git a/configs/facility-with-weave.yaml b/configs/facility-with-weave.yaml
new file mode 100644
index 0000000..279f6fc
--- /dev/null
+++ b/configs/facility-with-weave.yaml
@@ -0,0 +1,34 @@
+# Facility dataset configuration with W&B Weave tracking enabled
+
+system_prompt:
+  file: "../use-cases/facility-support-analyzer/facility_prompt_sys.txt"
+  inputs: ["question"]
+  outputs: ["answer"]
+
+# Dataset configuration
+dataset:
+  path: "../use-cases/facility-support-analyzer/dataset.json"
+  input_field: ["fields", "input"]
+  golden_output_field: "answer"
+
+# Model configuration (minimal required settings)
+model:
+  name: "openrouter/meta-llama/llama-3.3-70b-instruct"
+  task_model: "openrouter/meta-llama/llama-3.3-70b-instruct"
+  proposer_model: "openrouter/meta-llama/llama-3.3-70b-instruct"
+
+# Metric configuration (simplified but maintains compatibility)
+metric:
+  class: "llama_prompt_ops.core.metrics.FacilityMetric"
+  strict_json: false
+  output_field: "answer"
+
+# Optimization settings
+optimization:
+  strategy: "llama"
+
+# W&B Weave tracking configuration
+weave:
+  enabled: true
+  project_name: "llama-prompt-optimization"
+  entity: null # Optional: your W&B entity name
\ No newline at end of file
diff --git a/docs/README.md b/docs/README.md
index fb1ec45..c91b21e 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -40,6 +40,33 @@ llama-prompt-ops supports various inference providers and endpoints to fit your
 - vLLM (local deployment)
 - NVIDIA NIMs (optimized containers)

+## W&B Weave Integration
+
+Track and visualize your prompt optimization experiments with W&B Weave. When enabled, Weave automatically tracks:
+
+- **Prompt Evolution**: Original and optimized prompt versions
+- **Dataset Versions**: Training, validation, and test datasets
+- **LLM Call Traces**: All model calls with inputs, outputs, tokens, and costs
+
+### Quick Start
+
+1. Add Weave configuration to your YAML file:
+```yaml
+weave:
+  enabled: true
+  project_name: "my-optimization-project"
+  entity: "my-team" # Optional
+```
+
+2. Run optimization with tracking:
+```bash
+llama-prompt-ops migrate --config config.yaml --weave
+```
+
+3. View results at: `https://wandb.ai/[entity]/[project-name]`
+
+See the [full Weave integration details](#) for advanced configuration options.
+
 ## Supported Formats at a Glance

 ### Prompt Formats
diff --git a/pyproject.toml b/pyproject.toml
index 1a3bf7b..0f43c7d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,8 @@ dependencies = [
     "litellm>=1.63.0",
     "huggingface-hub>=0.29.0",
     "datasets>=2.21.0",
-    "propcache==0.3.1"
+    "propcache==0.3.1",
+    "weave>=0.51.0"
 ]

 [project.optional-dependencies]
diff --git a/src/llama_prompt_ops/integrations/__init__.py b/src/llama_prompt_ops/integrations/__init__.py
new file mode 100644
index 0000000..2c473cf
--- /dev/null
+++ b/src/llama_prompt_ops/integrations/__init__.py
@@ -0,0 +1,9 @@
+"""
+Integration modules for llama-prompt-ops.
+
+This package contains integrations with external tracking and logging services.
+"""
+
+from .weave_tracker import WeaveTracker
+
+__all__ = ["WeaveTracker"]
\ No newline at end of file
diff --git a/src/llama_prompt_ops/integrations/weave_tracker.py b/src/llama_prompt_ops/integrations/weave_tracker.py
new file mode 100644
index 0000000..7a95ac2
--- /dev/null
+++ b/src/llama_prompt_ops/integrations/weave_tracker.py
@@ -0,0 +1,198 @@
+"""
+W&B Weave integration for tracking prompts, datasets, and LLM calls.
+
+Uses Weave's native classes:
+- weave.StringPrompt for versioned prompts (with names)
+- weave.Dataset for versioned datasets (with names)
+- Automatic LLM tracing via weave.init()
+"""
+from typing import Dict, Any, Optional, List
+import logging
+
+try:
+    import weave
+    from weave import StringPrompt, Dataset
+    WEAVE_AVAILABLE = True
+except ImportError:
+    WEAVE_AVAILABLE = False
+    weave = None
+    StringPrompt = None
+    Dataset = None
+
+from datasets import Dataset as HFDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+class WeaveTracker:
+    """
+    Lightweight W&B Weave integration using native Weave classes.
+
+    Provides:
+    - Prompt versioning via weave.StringPrompt (named objects)
+    - Dataset versioning via weave.Dataset (named objects)
+    - Automatic LLM tracing via weave.init()
+    """
+
+    def __init__(
+        self,
+        project_name: str,
+        entity: Optional[str] = None,
+        enabled: bool = True
+    ):
+        """
+        Initialize Weave tracking.
+
+        Args:
+            project_name: Weave project name
+            entity: W&B entity (optional)
+            enabled: Whether tracking is enabled
+        """
+        self.project_name = project_name
+        self.entity = entity
+        self.enabled = enabled
+
+        if not WEAVE_AVAILABLE:
+            logger.warning("Weave not available. Install with: pip install weave")
+            self.enabled = False
+            return
+
+        if self.enabled:
+            self._initialize_weave()
+
+    def _initialize_weave(self) -> None:
+        """Initialize Weave project - enables automatic LLM tracing."""
+        try:
+            if self.entity:
+                project_path = f"{self.entity}/{self.project_name}"
+            else:
+                project_path = self.project_name
+
+            weave.init(project_path)
+            logger.info(f"Weave initialized: {project_path}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Weave: {e}")
+            self.enabled = False
+
+    def is_enabled(self) -> bool:
+        """Check if Weave tracking is enabled."""
+        return self.enabled and WEAVE_AVAILABLE
+
+    def track_prompt_evolution(
+        self,
+        original_prompt: str,
+        optimized_prompt: str,
+        prompt_name: str = "system_prompt",
+        metadata: Optional[Dict[str, Any]] = None
+    ) -> Optional[str]:
+        """
+        Track prompt evolution using the same named prompt for versioning.
+
+        Args:
+            original_prompt: Original prompt text
+            optimized_prompt: Optimized prompt text
+            prompt_name: Name for both versions (creates v1, v2, etc.)
+            metadata: Optimization metadata (unused for now)
+
+        Returns:
+            Reference to published optimized prompt version
+        """
+        if not self.is_enabled():
+            return None
+
+        try:
+            # Create StringPrompts (name goes with publish, not constructor)
+            original = StringPrompt(original_prompt)
+            optimized = StringPrompt(optimized_prompt)
+
+            # Publish with same name to create versions
+            weave.publish(original, name=prompt_name)
+            optimized_ref = weave.publish(optimized, name=prompt_name)
+
+            logger.info(f"Tracked prompt evolution: {optimized_ref}")
+            return str(optimized_ref)
+
+        except Exception as e:
+            logger.error(f"Failed to track prompt evolution: {e}")
+            return None
+
+    def track_dataset(
+        self,
+        dataset: HFDataset,
+        split: str = "train",
+        metadata: Optional[Dict[str, Any]] = None
+    ) -> Optional[str]:
+        """
+        Track dataset using named weave.Dataset.
+
+        Args:
+            dataset: HuggingFace dataset to track
+            split: Dataset split name
+            metadata: Additional metadata (unused for now)
+
+        Returns:
+            Reference to published dataset
+        """
+        if not self.is_enabled():
+            return None
+
+        try:
+            # Convert HF dataset to format expected by weave.Dataset
+            rows = [dict(row) for row in dataset]
+
+            # Create named Weave Dataset for auto-versioning
+            weave_dataset = Dataset(
+                name=f"dataset_{split}",
+                rows=rows
+            )
+
+            # Publish dataset (automatically versioned by name)
+            ref = weave.publish(weave_dataset)
+            logger.info(f"Tracked dataset ({split}): {ref}")
+            return str(ref)
+
+        except Exception as e:
+            logger.error(f"Failed to track dataset: {e}")
+            return None
+
+    def get_prompt(self, name: str = "system_prompt") -> Optional[StringPrompt]:
+        """
+        Retrieve prompt using Weave refs.
+
+        Args:
+            name: Prompt name to retrieve
+
+        Returns:
+            StringPrompt object, None if not found
+        """
+        if not self.is_enabled():
+            return None
+
+        try:
+            ref = weave.ref(name)
+            return ref.get()
+        except Exception as e:
+            logger.error(f"Failed to get prompt: {e}")
+            return None
+
+    def get_dataset(self, split: str = "train") -> Optional[Dataset]:
+        """
+        Retrieve dataset using Weave refs.
+
+        Args:
+            split: Dataset split to retrieve
+
+        Returns:
+            Dataset object, None if not found
+        """
+        if not self.is_enabled():
+            return None
+
+        try:
+            ref = weave.ref(f"dataset_{split}")
+            return ref.get()
+        except Exception as e:
+            logger.error(f"Failed to get dataset: {e}")
+            return None
\ No newline at end of file
diff --git a/src/llama_prompt_ops/interfaces/cli.py b/src/llama_prompt_ops/interfaces/cli.py
index 9286d99..82679b2 100644
--- a/src/llama_prompt_ops/interfaces/cli.py
+++ b/src/llama_prompt_ops/interfaces/cli.py
@@ -30,6 +30,9 @@
 # Import template utilities
 from llama_prompt_ops.templates import get_template_content, get_template_path

+# Import Weave integration
+from llama_prompt_ops.integrations import WeaveTracker
+

 def check_api_key(api_key_env, dotenv_path=".env"):
     """Check if API key is set and return it.
@@ -739,6 +742,45 @@ def load_config(config_path):
         raise ValueError(f"Failed to load configuration from {config_path}: {str(e)}")


+def create_weave_tracker_from_config(
+    config_dict: Dict[str, Any],
+    cli_override: Optional[bool] = None
+) -> Optional[WeaveTracker]:
+    """
+    Create WeaveTracker from configuration with CLI override support.
+
+    Args:
+        config_dict: Configuration dictionary
+        cli_override: CLI override for enabling/disabling Weave (None = use config)
+
+    Returns:
+        WeaveTracker instance or None if disabled
+    """
+    weave_config = config_dict.get("weave", {})
+
+    # CLI override takes precedence over config
+    if cli_override is not None:
+        enabled = cli_override
+    else:
+        enabled = weave_config.get("enabled", False)
+
+    if not enabled:
+        return None
+
+    project_name = weave_config.get("project_name", "llama-prompt-ops")
+    entity = weave_config.get("entity")
+
+    try:
+        return WeaveTracker(
+            project_name=project_name,
+            entity=entity,
+            enabled=True
+        )
+    except Exception as e:
+        click.echo(f"Warning: Failed to initialize Weave tracker: {e}")
+        return None
+
+
 @cli.command(name="migrate")
 @click.option(
     "--config",
@@ -772,7 +814,12 @@ def load_config(config_path):
     default="INFO",
     help="Set the logging level",
 )
-def migrate(config, model, output_dir, save_yaml, api_key_env, dotenv_path, log_level):
+@click.option(
+    "--weave/--no-weave",
+    default=None,
+    help="Enable/disable Weave tracking (overrides config file setting)",
+)
+def migrate(config, model, output_dir, save_yaml, api_key_env, dotenv_path, log_level, weave):
     """
     Migrate and optimize prompts using a YAML configuration file.

@@ -816,6 +863,13 @@ def migrate(config, model, output_dir, save_yaml, api_key_env, dotenv_path, log_
     except ValueError as e:
         click.echo(f"Error: {str(e)}", err=True)
         sys.exit(1)
+
+    # Initialize Weave tracking if configured (with CLI override support)
+    weave_tracker = create_weave_tracker_from_config(config_dict, cli_override=weave)
+    if weave_tracker and weave_tracker.is_enabled():
+        click.echo(f"Weave tracking enabled for project: {weave_tracker.project_name}")
+    else:
+        click.echo("Weave tracking disabled")

     # Configure logging from file, if not overridden by CLI
     if not log_level:
@@ -942,6 +996,15 @@
     # Wrap the optimization in a try/except block to catch parallelizer errors
     try:
         click.echo("Starting prompt optimization...")
+
+        # Track dataset with Weave if enabled
+        if weave_tracker and weave_tracker.is_enabled():
+            weave_tracker.track_dataset(trainset, split="train")
+            if valset:
+                weave_tracker.track_dataset(valset, split="validation")
+            if testset:
+                weave_tracker.track_dataset(testset, split="test")
+
         optimized = migrator.optimize(
             {
                 "text": prompt_text,
@@ -955,6 +1018,18 @@
             file_path=json_file_path,
         )

+        # Track prompt evolution with Weave if enabled
+        if weave_tracker and weave_tracker.is_enabled():
+            try:
+                optimized_prompt_text = str(optimized.signature.instructions)
+                weave_tracker.track_prompt_evolution(
+                    original_prompt=prompt_text,
+                    optimized_prompt=optimized_prompt_text
+                )
+                click.echo("Prompt evolution tracked in Weave")
+            except Exception as e:
+                click.echo(f"Warning: Failed to track prompt in Weave: {e}")
+
         click.echo("\n=== Optimization Complete ===")
         click.echo(f"Results saved to: {json_file_path}")
         if save_yaml:
diff --git a/tests/integration/test_weave_integration.py b/tests/integration/test_weave_integration.py
new file mode 100644
index 0000000..2c7891c
--- /dev/null
+++ b/tests/integration/test_weave_integration.py
@@ -0,0 +1,260 @@
+"""
+Integration tests for W&B Weave tracking functionality.
+
+This test suite validates that the Weave integration correctly tracks:
+1. Prompt versioned objects
+2. Dataset versioned objects
+3. LLM call traces (via weave.init() automatic tracing)
+"""
+import os
+import tempfile
+import uuid
+from typing import Dict, List, Any
+from unittest.mock import patch, MagicMock
+
+import pytest
+import yaml
+from datasets import Dataset
+
+# Import our integration components
+from llama_prompt_ops.integrations.weave_tracker import WeaveTracker
+
+
+class TestWeaveIntegration:
+    """
+    Comprehensive test suite for Weave integration.
+
+    Tests validate the three core requirements:
+    1. Prompt versioning and tracking
+    2. Dataset versioning and tracking
+    3. Automatic LLM call tracing via weave.init()
+    """
+
+    @pytest.fixture
+    def test_project_name(self) -> str:
+        """Generate unique test project name to avoid conflicts."""
+        return f"llama-prompt-ops-test-{uuid.uuid4().hex[:8]}"
+
+    @pytest.fixture
+    def sample_dataset(self) -> Dataset:
+        """Create a sample dataset for testing."""
+        return Dataset.from_dict({
+            "question": ["What is AI?", "Explain machine learning", "Define neural networks"],
+            "answer": ["AI is artificial intelligence", "ML is a subset of AI", "Neural networks are computing systems"]
+        })
+
+    def test_weave_tracker_initialization(self, test_project_name: str):
+        """Test that WeaveTracker initializes correctly and calls weave.init()."""
+        with patch('llama_prompt_ops.integrations.weave_tracker.weave') as mock_weave:
+            tracker = WeaveTracker(project_name=test_project_name, enabled=True)
+
+            assert tracker.project_name == test_project_name
+            assert tracker.enabled is True
+
+            # Verify weave.init was called with correct project name
+            mock_weave.init.assert_called_once_with(test_project_name)
+
+        # Test disabled tracker
+        disabled_tracker = WeaveTracker(project_name=test_project_name, enabled=False)
+        assert disabled_tracker.enabled is False
+
+    def test_prompt_versioning_with_string_prompt(self, test_project_name: str):
+        """
+        REQUIREMENT 1: Test prompt versioned objects using weave.StringPrompt.
+
+        Validates that:
+        - Original and optimized prompts are tracked using weave.StringPrompt
+        - Same-named prompts create versions
+        - weave.publish() is called correctly
+        """
+        with patch('llama_prompt_ops.integrations.weave_tracker.weave') as mock_weave, \
+             patch('llama_prompt_ops.integrations.weave_tracker.StringPrompt') as mock_string_prompt_class:
+
+            mock_weave.init.return_value = MagicMock()
+            mock_weave.publish.return_value = "weave://test/project/StringPrompt/system_prompt:v1"
+
+            tracker = WeaveTracker(project_name=test_project_name, enabled=True)
+
+            original_prompt = "You are a helpful assistant."
+            optimized_prompt = "You are a helpful AI assistant specialized in Llama models."
+
+            # Track prompt evolution
+            result = tracker.track_prompt_evolution(
+                original_prompt=original_prompt,
+                optimized_prompt=optimized_prompt,
+                prompt_name="system_prompt"
+            )
+
+            assert result is not None
+
+            # Verify StringPrompt objects were created correctly (no name in constructor)
+            assert mock_string_prompt_class.call_count == 2
+            mock_string_prompt_class.assert_any_call(original_prompt)
+            mock_string_prompt_class.assert_any_call(optimized_prompt)
+
+            # Verify weave.publish was called twice with same name for versioning
+            assert mock_weave.publish.call_count == 2
+            # Check that both calls used the same name parameter
+            publish_calls = mock_weave.publish.call_args_list
+            assert publish_calls[0][1]['name'] == "system_prompt"
+            assert publish_calls[1][1]['name'] == "system_prompt"
+
+    def test_dataset_versioning_with_weave_dataset(self, test_project_name: str, sample_dataset: Dataset):
+        """
+        REQUIREMENT 2: Test dataset versioned objects using weave.Dataset.
+
+        Validates that:
+        - Datasets are converted and tracked using weave.Dataset
+        - Dataset names are set correctly for versioning
+        - weave.publish() is called correctly
+        """
+        with patch('llama_prompt_ops.integrations.weave_tracker.weave') as mock_weave, \
+             patch('llama_prompt_ops.integrations.weave_tracker.Dataset') as mock_dataset_class:
+
+            mock_weave.init.return_value = MagicMock()
+            mock_weave.publish.return_value = "weave://test/project/Dataset/dataset_train:v1"
+
+            tracker = WeaveTracker(project_name=test_project_name, enabled=True)
+
+            # Track dataset
+            result = tracker.track_dataset(
+                dataset=sample_dataset,
+                split="train"
+            )
+
+            assert result is not None
+
+            # Verify weave.Dataset was created with correct structure
+            mock_dataset_class.assert_called_once()
+            call_args = mock_dataset_class.call_args
+
+            # Check the dataset was created with correct name and data
+            assert call_args[1]["name"] == "dataset_train"
+            assert len(call_args[1]["rows"]) == 3  # Our sample has 3 rows
+
+            # Verify weave.publish was called
+            mock_weave.publish.assert_called_once()
+
+    def test_automatic_llm_tracing_via_weave_init(self, test_project_name: str):
+        """
+        REQUIREMENT 3: Test that weave.init() enables automatic LLM tracing.
+
+        Validates that:
+        - weave.init() is called during tracker initialization
+        - This enables automatic tracing for supported LLM libraries
+        - No additional wrapping is needed for basic LLM calls
+        """
+        with patch('llama_prompt_ops.integrations.weave_tracker.weave') as mock_weave:
+            mock_client = MagicMock()
+            mock_weave.init.return_value = mock_client
+
+            # Initialize tracker
+            tracker = WeaveTracker(project_name=test_project_name, enabled=True)
+
+            # Verify weave.init was called (this enables automatic LLM tracing)
+            mock_weave.init.assert_called_once_with(test_project_name)
+
+            # Verify tracker recognizes it's enabled for automatic tracing
+            assert tracker.is_enabled() is True
+
+            # The automatic tracing happens via weave.init() - no additional setup needed
+            # LLM calls made after this point will be automatically traced by Weave
+
+    def test_weave_integration_can_be_disabled(self):
+        """
+        Test that Weave tracking can be completely disabled.
+
+        This validates the requirement to choose whether to run with or without Weave.
+ """ + # Test initialization with enabled=False + tracker = WeaveTracker(project_name="test", enabled=False) + assert not tracker.is_enabled() + + # Test that tracking operations return None when disabled + result = tracker.track_prompt_evolution("original", "optimized") + assert result is None + + sample_dataset = Dataset.from_dict({"test": ["data"]}) + result = tracker.track_dataset(sample_dataset) + assert result is None + + def test_cli_integration_with_weave_config(self): + """ + Test that CLI properly creates WeaveTracker from YAML config. + """ + from llama_prompt_ops.interfaces.cli import create_weave_tracker_from_config + + # Test enabled configuration + config_with_weave = { + "weave": { + "enabled": True, + "project_name": "test-project", + "entity": "test-entity" + } + } + + with patch('llama_prompt_ops.interfaces.cli.WeaveTracker') as mock_tracker_class: + mock_tracker = MagicMock() + mock_tracker.is_enabled.return_value = True + mock_tracker_class.return_value = mock_tracker + + result = create_weave_tracker_from_config(config_with_weave) + + assert result is not None + mock_tracker_class.assert_called_once_with( + project_name="test-project", + entity="test-entity", + enabled=True + ) + + # Test disabled configuration + config_without_weave = {} + result = create_weave_tracker_from_config(config_without_weave) + assert result is None + + # Test explicitly disabled configuration + config_disabled = {"weave": {"enabled": False}} + result = create_weave_tracker_from_config(config_disabled) + assert result is None + + def test_error_handling_when_weave_unavailable(self): + """ + Test graceful degradation when Weave is unavailable. + """ + with patch('llama_prompt_ops.integrations.weave_tracker.WEAVE_AVAILABLE', False): + tracker = WeaveTracker(project_name="test", enabled=True) + + # Should disable tracking when weave is unavailable + assert not tracker.is_enabled() + + # Operations should handle gracefully without crashing + result = tracker.track_prompt_evolution("test", "optimized") + assert result is None + + def test_weave_ref_retrieval_methods(self, test_project_name: str): + """ + Test that get_prompt and get_dataset methods work with weave.ref(). + """ + with patch('llama_prompt_ops.integrations.weave_tracker.weave') as mock_weave: + mock_weave.init.return_value = MagicMock() + mock_ref = MagicMock() + mock_ref.get.return_value = {"prompt": "test prompt"} + mock_weave.ref.return_value = mock_ref + + tracker = WeaveTracker(project_name=test_project_name, enabled=True) + + # Test prompt retrieval + result = tracker.get_prompt("system_prompt") + mock_weave.ref.assert_called_with("system_prompt") + mock_ref.get.assert_called() + assert result == {"prompt": "test prompt"} + + # Test dataset retrieval + mock_ref.get.return_value = {"dataset": "test data"} + result = tracker.get_dataset("train") + mock_weave.ref.assert_called_with("dataset_train") + assert result == {"dataset": "test data"} + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file