diff --git a/README.md b/README.md index 05d2af60..6608a540 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,9 @@ Multiple tools support standardised formats. However, most of the times your dat This toolkit simplifies the journey of: -- Using local LLM (via vLLM) to generate examples +- Using local LLMs via multiple backends: + - vLLM for high-performance inference + - Ollama for easy model management and deployment - Modular 4 command flow - Converting your existing files to fine-tuning friendly formats - Creating synthetic datasets @@ -28,23 +30,19 @@ The tool is designed to follow a simple CLI structure with 4 commands: - `ingest` various file formats - `create` your fine-tuning format: `QA` pairs, `QA` pairs with CoT, `summary` format -- `curate`: Using Llama as a judge to curate high quality examples. -- `save-as`: After that you can simply save these to a format that your fine-tuning workflow requires. +- `curate`: Using LLMs as judges to curate high quality examples +- `save-as`: After that you can simply save these to a format that your fine-tuning workflow requires You can override any parameter or detail by either using the CLI or overiding the default YAML config. - ### Installation #### From PyPI ```bash # Create a new environment - conda create -n synthetic-data python=3.10 - conda activate synthetic-data - pip install synthetic-data-kit ``` @@ -57,44 +55,50 @@ pip install -e . ``` To get an overview of commands type: - `synthetic-data-kit --help` ### 1. Tool Setup - The tool expects respective files to be put in named folders. -- We also require a vLLM server running the LLM that we will utilise for generating our dataset. +- We support two LLM backends: + - vLLM server for high-performance inference + - Ollama for easy model management ```bash # Create directory structure mkdir -p data/{pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final} -# Start VLLM server +# Option 1: Start VLLM server # Note you will need to grab your HF Authentication from: https://huggingface.co/settings/tokens vllm serve meta-llama/Llama-3.3-70B-Instruct --port 8000 + +# Option 2: Use Ollama +# Install Ollama from https://ollama.ai +ollama serve +# Pull your desired model +ollama pull llama3.2 ``` ### 2. Usage -The flow follows 4 simple steps: `ingest`, `create`, `curate`, `save-as`, please paste your file into the respective folder: +The flow follows 4 simple steps: `ingest`, `create`, `curate`, `save-as`. 
Please paste your file into the respective folder: ```bash -# Check if VLLM server is running -synthetic-data-kit system-check +# Check if server is running (specify backend) +synthetic-data-kit system-check --backend vllm +# OR +synthetic-data-kit system-check --backend ollama # Parse a document to text synthetic-data-kit ingest docs/report.pdf -# This will save file to data/output/report.txt -# Generate QA pairs (default) -synthetic-data-kit create data/output/report.txt --type qa - -OR +# Generate QA pairs (default) with specific backend +synthetic-data-kit create data/output/report.txt --type qa --backend vllm +# OR +synthetic-data-kit create data/output/report.txt --type qa --backend ollama # Generate Chain of Thought (CoT) reasoning examples -synthetic-data-kit create data/output/report.txt --type cot - -# Both of these will save file to data/generated/report_qa_pairs.json +synthetic-data-kit create data/output/report.txt --type cot --backend ollama # Filter content based on quality synthetic-data-kit curate data/generated/report_qa_pairs.json @@ -102,18 +106,26 @@ synthetic-data-kit curate data/generated/report_qa_pairs.json # Convert to alpaca fine-tuning format and save as HF arrow file synthetic-data-kit save-as data/cleaned/report_cleaned.json --format alpaca --storage hf ``` + ## Configuration The toolkit uses a YAML configuration file (default: `configs/config.yaml`). - -Note, this can be overriden via either CLI arguments OR passing a custom YAML file +Note, this can be overriden via either CLI arguments OR passing a custom YAML file. ```yaml # Example configuration +backend: "vllm" # or "ollama" + vllm: api_base: "http://localhost:8000/v1" model: "meta-llama/Llama-3.3-70B-Instruct" +ollama: + api_base: "http://localhost:11434/v1" + model: "llama3.2" + max_retries: 3 + retry_delay: 1.0 + generation: temperature: 0.7 chunk_size: 4000 @@ -134,44 +146,41 @@ synthetic-data-kit -c my_config.yaml ingest docs/paper.pdf ## Examples -### Processing a PDF Document +### Processing a PDF Document with Ollama ```bash # Ingest PDF synthetic-data-kit ingest research_paper.pdf -# Generate QA pairs -synthetic-data-kit create data/output/research_paper.txt -n 30 --threshold 8.0 +# Generate QA pairs using Ollama +synthetic-data-kit create data/output/research_paper.txt -n 30 --backend ollama --model llama3.2 # Curate data -synthetic-data-kit curate data/generated/research_paper_qa_pairs.json -t 8.5 +synthetic-data-kit curate data/cleaned/research_paper_qa_pairs.json -t 8.5 # Save in OpenAI fine-tuning format (JSON) synthetic-data-kit save-as data/cleaned/research_paper_cleaned.json -f ft - -# Save in OpenAI fine-tuning format (HF dataset) -synthetic-data-kit save-as data/cleaned/research_paper_cleaned.json -f ft --storage hf ``` -### Processing a YouTube Video +### Processing a YouTube Video with vLLM ```bash # Extract transcript synthetic-data-kit ingest "https://www.youtube.com/watch?v=dQw4w9WgXcQ" -# Generate QA pairs with specific model -synthetic-data-kit create data/output/youtube_dQw4w9WgXcQ.txt +# Generate QA pairs with vLLM +synthetic-data-kit create data/output/youtube_dQw4w9WgXcQ.txt --backend vllm ``` ### Processing Multiple Files ```bash -# Bash script to process multiple files +# Bash script to process multiple files with Ollama for file in data/pdf/*.pdf; do filename=$(basename "$file" .pdf) synthetic-data-kit ingest "$file" - synthetic-data-kit create "data/output/${filename}.txt" -n 20 + synthetic-data-kit create "data/output/${filename}.txt" -n 20 --backend ollama 
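+  # create writes QA pairs to data/generated/; curate filters them into data/cleaned/ for the save-as step below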
synthetic-data-kit curate "data/generated/${filename}_qa_pairs.json" -t 7.5 synthetic-data-kit save-as "data/cleaned/${filename}_cleaned.json" -f chatml done @@ -218,10 +227,12 @@ graph LR SDK --> Curate[curate] SDK --> SaveAs[save-as] + SystemCheck --> VLLM[vLLM Backend] + SystemCheck --> Ollama[Ollama Backend] + Ingest --> PDFFile[PDF File] Ingest --> HTMLFile[HTML File] Ingest --> YouTubeURL[File Format] - Create --> CoT[CoT] Create --> QA[QA Pairs] @@ -237,18 +248,27 @@ graph LR ## Troubleshooting FAQs: -### VLLM Server Issues +### Backend Issues +#### VLLM Server - Ensure VLLM is installed: `pip install vllm` - Start server with: `vllm serve --port 8000` -- Check connection: `synthetic-data-kit system-check` +- Check connection: `synthetic-data-kit system-check --backend vllm` + +#### Ollama Server +- Install Ollama from https://ollama.ai +- Start server with: `ollama serve` +- Pull models with: `ollama pull ` +- Check connection: `synthetic-data-kit system-check --backend ollama` +- List available models: `synthetic-data-kit system-check --backend ollama` ### Memory Issues If you encounter CUDA out of memory errors: - Use a smaller model - Reduce batch size in config -- Start VLLM with `--gpu-memory-utilization 0.85` +- For vLLM: Start with `--gpu-memory-utilization 0.85` +- For Ollama: Use smaller model variants (e.g., llama3:7b instead of llama3:70b) ### JSON Parsing Issues diff --git a/synthetic_data_kit/cli.py b/synthetic_data_kit/cli.py index 98df5534..63c7968a 100644 --- a/synthetic_data_kit/cli.py +++ b/synthetic_data_kit/cli.py @@ -13,8 +13,15 @@ from rich.console import Console from rich.table import Table -from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_path_config +from synthetic_data_kit.utils.config import ( + load_config, + get_backend_type, + get_vllm_config, + get_ollama_config, + get_path_config +) from synthetic_data_kit.core.context import AppContext +from synthetic_data_kit.models.llm_client import LLMClient # Initialize Typer app app = typer.Typer( @@ -44,36 +51,70 @@ def callback( @app.command("system-check") def system_check( + backend: Optional[str] = typer.Option( + None, "--backend", help="Backend to check (vllm or ollama)" + ), api_base: Optional[str] = typer.Option( - None, "--api-base", help="VLLM API base URL to check" + None, "--api-base", help="API base URL to check" ) ): """ - Check if the VLLM server is running. + Check if the LLM server (VLLM or Ollama) is running. 
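+
+    Verifies that the selected backend's /models endpoint is reachable and
+    prints the command needed to start the server when it is not.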
""" - # Get VLLM server details from args or config - vllm_config = get_vllm_config(ctx.config) - api_base = api_base or vllm_config.get("api_base") + # Get backend type from args or config + backend_type = backend or get_backend_type(ctx.config) + + if backend_type == "vllm": + config = get_vllm_config(ctx.config) + default_port = 8000 + start_cmd = "vllm" + elif backend_type == "ollama": + config = get_ollama_config(ctx.config) + default_port = 11434 + start_cmd = "ollama" + else: + console.print(f"L Error: Unsupported backend type: {backend_type}", style="red") + return 1 - with console.status(f"Checking VLLM server at {api_base}..."): + # Get API base from args or config + api_base = api_base or config.get("api_base") + model = config.get("model") + + with console.status(f"Checking {backend_type.upper()} server at {api_base}..."): try: response = requests.get(f"{api_base}/models", timeout=2) if response.status_code == 200: - console.print(f" VLLM server is running at {api_base}", style="green") - console.print(f"Available models: {response.json()}") + console.print(f" {backend_type.upper()} server is running at {api_base}", style="green") + if backend_type == "ollama": + models_data = response.json()["data"] + table = Table("Model ID", "Created", "Owner") + for model_info in models_data: + table.add_row( + model_info["id"], + str(model_info.get("created", "N/A")), + model_info.get("owned_by", "N/A") + ) + console.print("\nAvailable Models:") + console.print(table) + else: + console.print(f"Available models: {response.json()}") return 0 else: - console.print(f"L VLLM server is not available at {api_base}", style="red") + console.print(f"L {backend_type.upper()} server is not available at {api_base}", style="red") console.print(f"Error: Server returned status code: {response.status_code}") except requests.exceptions.RequestException as e: - console.print(f"L VLLM server is not available at {api_base}", style="red") + console.print(f"L {backend_type.upper()} server is not available at {api_base}", style="red") console.print(f"Error: {str(e)}") - - # Show instruction to start the server - model = vllm_config.get("model") - port = vllm_config.get("port", 8000) - console.print("\nTo start the server, run:", style="yellow") - console.print(f"vllm serve {model} --port {port}", style="bold blue") + + # Show instructions to start the server + if backend_type == "vllm": + console.print("\nTo start the VLLM server, run:", style="yellow") + console.print(f"vllm serve {model} --port {default_port}", style="bold blue") + else: + console.print("\nTo start the Ollama server, run:", style="yellow") + console.print("ollama serve", style="bold blue") + console.print("\nThen pull your model:", style="yellow") + console.print(f"ollama pull {model}", style="bold blue") return 1 @@ -115,8 +156,11 @@ def create( output_dir: Optional[Path] = typer.Option( None, "--output-dir", "-o", help="Where to save the output" ), + backend: Optional[str] = typer.Option( + None, "--backend", help="Backend to use (vllm or ollama)" + ), api_base: Optional[str] = typer.Option( - None, "--api-base", help="VLLM API base URL" + None, "--api-base", help="API base URL" ), model: Optional[str] = typer.Option( None, "--model", "-m", help="Model to use" @@ -143,23 +187,16 @@ def create( """ from synthetic_data_kit.core.create import process_file - # Get VLLM server details from args or config - vllm_config = get_vllm_config(ctx.config) - api_base = api_base or vllm_config.get("api_base") - model = model or vllm_config.get("model") - - 
# Check server first + # Initialize LLM client with appropriate backend try: - response = requests.get(f"{api_base}/models", timeout=2) - if response.status_code != 200: - console.print(f"L Error: VLLM server not available at {api_base}", style="red") - console.print("Please start the VLLM server with:", style="yellow") - console.print(f"vllm serve {model}", style="bold blue") - return 1 - except requests.exceptions.RequestException: - console.print(f"L Error: VLLM server not available at {api_base}", style="red") - console.print("Please start the VLLM server with:", style="yellow") - console.print(f"vllm serve {model}", style="bold blue") + client = LLMClient( + config_path=ctx.config_path, + backend=backend, + api_base=api_base, + model_name=model + ) + except Exception as e: + console.print(f"L Error initializing LLM client: {e}", style="red") return 1 # Get output directory from args, then config, then default @@ -172,8 +209,7 @@ def create( input, output_dir, ctx.config_path, - api_base, - model, + client, content_type, num_pairs, verbose @@ -267,53 +303,51 @@ def save_as( output: Optional[Path] = typer.Option( None, "--output", "-o", help="Output file path" ), + backend: Optional[str] = typer.Option( + None, "--backend", help="Backend to use (vllm or ollama)" + ), + api_base: Optional[str] = typer.Option( + None, "--api-base", help="API base URL" + ), + model: Optional[str] = typer.Option( + None, "--model", "-m", help="Model to use" + ), ): """ - Convert to different formats for fine-tuning. - - The --format option controls the content format (how the data is structured). - The --storage option controls how the data is stored (JSON file or HF dataset). - - When using --storage hf, the output will be a directory containing a Hugging Face - dataset in Arrow format, which is optimized for machine learning workflows. + Convert content to different formats for fine-tuning. 
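+
+    The --format option controls how records are structured (jsonl, alpaca, ft,
+    or chatml); --storage controls whether the result is saved as a JSON file or
+    as a Hugging Face dataset directory in Arrow format.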
""" - from synthetic_data_kit.core.save_as import convert_format + from synthetic_data_kit.core.save_as import process_file - # Get format from args or config - if not format: - format_config = ctx.config.get("format", {}) - format = format_config.get("default", "jsonl") + # Initialize LLM client with appropriate backend + try: + client = LLMClient( + config_path=ctx.config_path, + backend=backend, + api_base=api_base, + model_name=model + ) + except Exception as e: + console.print(f"L Error initializing LLM client: {e}", style="red") + return 1 - # Set default output path if not provided - if not output: - final_dir = get_path_config(ctx.config, "output", "final") - os.makedirs(final_dir, exist_ok=True) + # Get output path + if output is None: base_name = os.path.splitext(os.path.basename(input))[0] - - if storage == "hf": - # For HF datasets, use a directory name - output = os.path.join(final_dir, f"{base_name}_{format}_hf") - else: - # For JSON files, use appropriate extension - if format == "jsonl": - output = os.path.join(final_dir, f"{base_name}.jsonl") - else: - output = os.path.join(final_dir, f"{base_name}_{format}.json") + output_dir = get_path_config(ctx.config, "output", "final") + os.makedirs(output_dir, exist_ok=True) + output = os.path.join(output_dir, f"{base_name}_{format}.{storage}") try: - with console.status(f"Converting {input} to {format} format with {storage} storage..."): - output_path = convert_format( + with console.status(f"Converting {input} to {format} format..."): + output_path = process_file( input, output, format, - ctx.config, - storage_format=storage + storage, + ctx.config_path, + client ) - - if storage == "hf": - console.print(f" Converted to {format} format and saved as HF dataset to [bold]{output_path}[/bold]", style="green") - else: - console.print(f" Converted to {format} format and saved to [bold]{output_path}[/bold]", style="green") + console.print(f" Content saved to [bold]{output_path}[/bold]", style="green") return 0 except Exception as e: console.print(f"L Error: {e}", style="red") diff --git a/synthetic_data_kit/config.yaml b/synthetic_data_kit/config.yaml index e121e13c..ec3489e6 100644 --- a/synthetic_data_kit/config.yaml +++ b/synthetic_data_kit/config.yaml @@ -10,6 +10,7 @@ paths: docx: "data/docx" ppt: "data/ppt" txt: "data/txt" + default: "data/input" # Output locations output: @@ -17,14 +18,25 @@ paths: generated: "data/generated" # Where generated content is saved cleaned: "data/cleaned" # Where cleaned content is saved final: "data/final" # Where final formatted content is saved + default: "data/output" -# VLLM server configuration +# LLM Backend Configuration +backend: "vllm" # Options: "vllm" or "ollama" + +# VLLM Configuration vllm: - api_base: "http://localhost:8000/v1" # Base URL for VLLM API - port: 8000 # Port for VLLM server - model: "meta-llama/Llama-3.3-70B-Instruct" # Default model to use - max_retries: 3 # Number of retries for API calls - retry_delay: 1.0 # Initial delay between retries (seconds) + api_base: "http://localhost:8000/v1" + port: 8000 + model: "meta-llama/Llama-3.3-70B-Instruct" + max_retries: 3 + retry_delay: 1.0 + +# Ollama Configuration +ollama: + api_base: "http://localhost:11434/v1" + model: "llama3.2" + max_retries: 3 + retry_delay: 1.0 # Ingest configuration ingest: @@ -44,7 +56,7 @@ generation: # Content curation parameters curate: threshold: 7.0 # Default quality threshold (1-10) - batch_size: 32 # Number of items per batch for rating + batch_size: 8 # Number of items per batch for rating 
inference_batch: 32 # Number of batches to process at once with VLLM temperature: 0.1 # Temperature for rating (lower = more consistent) diff --git a/synthetic_data_kit/core/create.py b/synthetic_data_kit/core/create.py index 34e3e0fd..29a83698 100644 --- a/synthetic_data_kit/core/create.py +++ b/synthetic_data_kit/core/create.py @@ -17,8 +17,7 @@ def process_file( file_path: str, output_dir: str, config_path: Optional[Path] = None, - api_base: Optional[str] = None, - model: Optional[str] = None, + llm_client: Optional[LLMClient] = None, content_type: str = "qa", num_pairs: Optional[int] = None, verbose: bool = False, @@ -29,29 +28,23 @@ def process_file( file_path: Path to the text file to process output_dir: Directory to save generated content config_path: Path to configuration file - api_base: VLLM API base URL - model: Model to use + llm_client: Initialized LLMClient instance content_type: Type of content to generate (qa, summary, cot) num_pairs: Target number of QA pairs to generate - threshold: Quality threshold for filtering (1-10) + verbose: Whether to show detailed output Returns: Path to the output file """ # Create output directory if it doesn't exist - # The reason for having this directory logic for now is explained in context.py os.makedirs(output_dir, exist_ok=True) # Read the file with open(file_path, 'r', encoding='utf-8') as f: document_text = f.read() - # Initialize LLM client - client = LLMClient( - config_path=config_path, - api_base=api_base, - model_name=model - ) + # Use provided client or create new one + client = llm_client or LLMClient(config_path=config_path) # Generate base filename for output base_name = os.path.splitext(os.path.basename(file_path))[0] @@ -75,24 +68,11 @@ def process_file( # Save output output_path = os.path.join(output_dir, f"{base_name}_qa_pairs.json") - print(f"Saving result to {output_path}") + if verbose: + print(f"Saving result to {output_path}") - # First, let's save a basic test file to confirm the directory is writable - test_path = os.path.join(output_dir, "test_write.json") - try: - with open(test_path, 'w', encoding='utf-8') as f: - f.write('{"test": "data"}') - print(f"Successfully wrote test file to {test_path}") - except Exception as e: - print(f"Error writing test file: {e}") - - # Now save the actual result - try: - with open(output_path, 'w', encoding='utf-8') as f: - json.dump(result, f, indent=2) - print(f"Successfully wrote result to {output_path}") - except Exception as e: - print(f"Error writing result file: {e}") + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(result, f, indent=2) return output_path @@ -109,10 +89,6 @@ def process_file( return output_path - # So there are two separate categories of CoT - # Simply CoT maps to "Hey I want CoT being generated" - # CoT-enhance maps to "Please enhance my dataset with CoT" - elif content_type == "cot": from synthetic_data_kit.generators.cot_generator import COTGenerator @@ -198,32 +174,26 @@ def process_file( # Enhance this conversation's messages enhanced_messages = generator.enhance_with_cot(conv_messages, include_simple_steps=verbose) - # Create enhanced conversation with same structure + # Create enhanced conversation object enhanced_conv = conversation.copy() enhanced_conv["conversations"] = enhanced_messages enhanced_conversations.append(enhanced_conv) else: - # Not the expected format, just keep original - enhanced_conversations.append(conversation) + print(f"Warning: item {i} does not have a conversations field, skipping") + 
enhanced_conversations.append(conversation) # Keep original - # Save enhanced conversations + # Save output output_path = os.path.join(output_dir, f"{base_name}_enhanced.json") - with open(output_path, 'w', encoding='utf-8') as f: - if is_single_conversation and len(enhanced_conversations) == 1: - # Save the single conversation + if is_single_conversation: json.dump(enhanced_conversations[0], f, indent=2) else: - # Save the array of conversations json.dump(enhanced_conversations, f, indent=2) - if verbose: - print(f"Enhanced {len(enhanced_conversations)} conversation(s)") - return output_path - except json.JSONDecodeError: - raise ValueError(f"Failed to parse {file_path} as JSON. For cot-enhance, input must be a valid JSON file.") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON file: {e}") else: - raise ValueError(f"Unknown content type: {content_type}") + raise ValueError(f"Unsupported content type: {content_type}") diff --git a/synthetic_data_kit/core/curate.py b/synthetic_data_kit/core/curate.py index 017f02c9..5c87ed85 100644 --- a/synthetic_data_kit/core/curate.py +++ b/synthetic_data_kit/core/curate.py @@ -15,13 +15,12 @@ from synthetic_data_kit.utils.config import get_curate_config, get_prompt from synthetic_data_kit.utils.llm_processing import convert_to_conversation_format, parse_ratings -def curate_qa_pairs( +def process_file( input_path: str, output_path: str, - threshold: Optional[float] = None, - api_base: Optional[str] = None, - model: Optional[str] = None, config_path: Optional[Path] = None, + llm_client: Optional[LLMClient] = None, + threshold: Optional[float] = None, verbose: bool = False, ) -> str: """Clean and filter QA pairs based on quality ratings @@ -29,16 +28,15 @@ def curate_qa_pairs( Args: input_path: Path to the input file with QA pairs output_path: Path to save the cleaned output - threshold: Quality threshold (1-10) - api_base: VLLM API base URL - model: Model to use config_path: Path to configuration file + llm_client: Initialized LLMClient instance + threshold: Quality threshold (1-10) verbose: Show detailed output Returns: Path to the cleaned output file """ - # Set verbose either via CLI or via env variable. 
If its via CLI, set it to env variable + # Set verbose either via CLI or via env variable if verbose: os.environ['SDK_VERBOSE'] = 'true' else: @@ -56,12 +54,8 @@ def curate_qa_pairs( if not qa_pairs: raise ValueError("No QA pairs found in the input file") - # Initialize LLM client - client = LLMClient( - config_path=config_path, - api_base=api_base, - model_name=model - ) + # Use provided client or create new one + client = llm_client or LLMClient(config_path=config_path) # Get threshold from args, then config, then default if threshold is None: @@ -114,8 +108,7 @@ def curate_qa_pairs( total_evaluated = 0 total_passed = 0 - # Process batches with simple progress indicator rather than a detailed bar - # This avoids conflicts with other output messages + # Process batches with simple progress indicator print(f"Processing {len(batches)} batches of QA pairs...") # Only use detailed progress bar in verbose mode @@ -198,91 +191,64 @@ def curate_qa_pairs( # Try processing one pair at a time as a fallback try: - if verbose: - print("Attempting to process items individually...") - - for item in original_batch: - item_json = json.dumps(item, indent=2) - rating_prompt = rating_prompt_template.format(pairs=item_json) - item_response = client.chat_completion( - [{"role": "system", "content": rating_prompt}], + for pair in original_batch: + single_batch = [pair] + single_json = json.dumps(single_batch, indent=2) + single_prompt = rating_prompt_template.format(pairs=single_json) + single_messages = [{"role": "system", "content": single_prompt}] + + single_response = client.chat_completion( + single_messages, temperature=rating_temperature ) - try: - # This should be a single item - rated_item = parse_ratings(item_response, [item]) - if rated_item and len(rated_item) > 0: - pair = rated_item[0] - if "rating" in pair: - rating = pair["rating"] - total_score += rating - total_evaluated += 1 - - if rating >= threshold: - filtered_pairs.append(pair) - total_passed += 1 - if verbose: - print(f"Successfully processed individual item with rating {rating}") - except Exception as inner_e: - if verbose: - print(f"Failed to process individual item: {str(inner_e)}") - except Exception as fallback_e: + + rated_pair = parse_ratings(single_response, single_batch)[0] + if "rating" in rated_pair: + rating = rated_pair["rating"] + total_score += rating + total_evaluated += 1 + + if rating >= threshold: + filtered_pairs.append(rated_pair) + total_passed += 1 + except Exception as e2: if verbose: - print(f"Fallback processing failed: {str(fallback_e)}") - - # Continue processing other batches rather than failing completely - pass + print(f"Error in fallback processing: {str(e2)}") - # Update progress bar if in verbose mode + # Update progress if progress_ctx and rate_task: progress_ctx.update(rate_task, advance=current_batch_size) - + except Exception as e: if verbose: - print(f"Error processing inference batch {batch_num}: {str(e)}") - - # Update progress bar if in verbose mode - if progress_ctx and rate_task: - progress_ctx.update(rate_task, advance=current_batch_size) + print(f"Error processing batch: {str(e)}") + continue - # Stop progress bar if in verbose mode + # Close progress bar if used if progress_ctx: progress_ctx.stop() - # Clear the progress line in non-verbose mode - if not verbose: - print(" " * 80, end="\r") - print("Batch processing complete.") - - # Calculate metrics - metrics = { - "total": len(qa_pairs), - "filtered": len(filtered_pairs), - "retention_rate": round(len(filtered_pairs) / len(qa_pairs), 2) 
if qa_pairs else 0, - "avg_score": round(total_score / total_evaluated, 1) if total_evaluated else 0 - } - - # Always print basic stats, even in non-verbose mode - print(f"Rated {total_evaluated} QA pairs") - print(f"Retained {total_passed} pairs (threshold: {threshold})") - print(f"Average score: {metrics['avg_score']}") - - # Convert to conversation format - conversations = convert_to_conversation_format(filtered_pairs) + # Print summary + if total_evaluated > 0: + avg_score = total_score / total_evaluated + pass_rate = (total_passed / total_evaluated) * 100 + print(f"\nProcessed {total_evaluated} QA pairs") + print(f"Average quality score: {avg_score:.2f}") + print(f"Pass rate: {pass_rate:.1f}% ({total_passed} pairs above threshold {threshold})") - # Create result with filtered pairs - result = { - "summary": summary, + # Save filtered pairs + output_data = { "qa_pairs": filtered_pairs, - "conversations": conversations, - "metrics": metrics + "summary": summary, + "metadata": { + "total_evaluated": total_evaluated, + "total_passed": total_passed, + "average_score": avg_score if total_evaluated > 0 else 0, + "threshold": threshold + } } - # Ensure output directory exists - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - # Save result with open(output_path, 'w', encoding='utf-8') as f: - json.dump(result, f, indent=2) + json.dump(output_data, f, indent=2) return output_path \ No newline at end of file diff --git a/synthetic_data_kit/core/save_as.py b/synthetic_data_kit/core/save_as.py index 69e4c9e6..2623839f 100644 --- a/synthetic_data_kit/core/save_as.py +++ b/synthetic_data_kit/core/save_as.py @@ -10,15 +10,17 @@ from pathlib import Path from typing import Optional, Dict, Any, List +from synthetic_data_kit.models.llm_client import LLMClient from synthetic_data_kit.utils.format_converter import to_jsonl, to_alpaca, to_fine_tuning, to_chatml, to_hf_dataset from synthetic_data_kit.utils.llm_processing import convert_to_conversation_format -def convert_format( +def process_file( input_path: str, output_path: str, format_type: str, - config: Optional[Dict[str, Any]] = None, - storage_format: str = "json", + storage_format: str, + config_path: Optional[Path] = None, + llm_client: Optional[LLMClient] = None, ) -> str: """Convert data to different formats @@ -26,8 +28,9 @@ def convert_format( input_path: Path to the input file output_path: Path to save the output format_type: Output format (jsonl, alpaca, ft, chatml) - config: Configuration dictionary storage_format: Storage format, either "json" or "hf" (Hugging Face dataset) + config_path: Path to configuration file + llm_client: Initialized LLMClient instance Returns: Path to the output file or directory diff --git a/synthetic_data_kit/models/llm_backends.py b/synthetic_data_kit/models/llm_backends.py new file mode 100644 index 00000000..c9326a68 --- /dev/null +++ b/synthetic_data_kit/models/llm_backends.py @@ -0,0 +1,191 @@ +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +import requests +import json +import time +import os +from pathlib import Path + +class BaseLLMBackend(ABC): + """Abstract base class for LLM inference backends""" + + @abstractmethod + def check_server(self) -> tuple: + """Check if the server is running and accessible""" + pass + + @abstractmethod + def chat_completion(self, + messages: List[Dict[str, str]], + temperature: float = None, + max_tokens: int = None, + top_p: float = None) -> str: + """Generate a chat completion""" + pass + + @abstractmethod + def 
batch_completion(self, + message_batches: List[List[Dict[str, str]]], + temperature: float = None, + max_tokens: int = None, + top_p: float = None, + batch_size: int = None) -> List[str]: + """Process multiple message sets in batches""" + pass + +class VLLMBackend(BaseLLMBackend): + """VLLM backend implementation""" + + def __init__(self, api_base: str, model: str, max_retries: int = 3, retry_delay: float = 1.0): + self.api_base = api_base + self.model = model + self.max_retries = max_retries + self.retry_delay = retry_delay + + def check_server(self) -> tuple: + try: + response = requests.get(f"{self.api_base}/models", timeout=5) + if response.status_code == 200: + return True, response.json() + return False, f"Server returned status code: {response.status_code}" + except requests.exceptions.RequestException as e: + return False, f"Server connection error: {str(e)}" + + def chat_completion(self, + messages: List[Dict[str, str]], + temperature: float = None, + max_tokens: int = None, + top_p: float = None) -> str: + data = { + "model": self.model, + "messages": messages, + "temperature": temperature if temperature is not None else 0.7, + "max_tokens": max_tokens if max_tokens is not None else 4096, + "top_p": top_p if top_p is not None else 0.95 + } + + for attempt in range(self.max_retries): + try: + response = requests.post( + f"{self.api_base}/chat/completions", + headers={"Content-Type": "application/json"}, + data=json.dumps(data), + timeout=180 + ) + response.raise_for_status() + return response.json()["choices"][0]["message"]["content"] + except (requests.exceptions.RequestException, KeyError, IndexError) as e: + if attempt == self.max_retries - 1: + raise Exception(f"Failed to get completion after {self.max_retries} attempts: {str(e)}") + time.sleep(self.retry_delay * (attempt + 1)) + + def batch_completion(self, + message_batches: List[List[Dict[str, str]]], + temperature: float = None, + max_tokens: int = None, + top_p: float = None, + batch_size: int = None) -> List[str]: + batch_size = batch_size if batch_size is not None else 32 + results = [] + + for i in range(0, len(message_batches), batch_size): + batch_chunk = message_batches[i:i+batch_size] + batch_results = [] + + for messages in batch_chunk: + content = self.chat_completion( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p + ) + batch_results.append(content) + + results.extend(batch_results) + time.sleep(0.1) + + return results + +class OllamaBackend(BaseLLMBackend): + """Ollama backend implementation""" + + def __init__(self, api_base: str = "http://localhost:11434/v1", model: str = "llama3.2", max_retries: int = 3, retry_delay: float = 1.0): + self.api_base = api_base + self.model = model + self.max_retries = max_retries + self.retry_delay = retry_delay + + def check_server(self) -> tuple: + try: + response = requests.get(f"{self.api_base}/models", timeout=5) + if response.status_code == 200: + return True, response.json() + return False, f"Server returned status code: {response.status_code}" + except requests.exceptions.RequestException as e: + return False, f"Server connection error: {str(e)}" + + def chat_completion(self, + messages: List[Dict[str, str]], + temperature: float = None, + max_tokens: int = None, + top_p: float = None) -> str: + data = { + "model": self.model, + "messages": messages, + "temperature": temperature if temperature is not None else 0.7, + "max_tokens": max_tokens if max_tokens is not None else 4096, + "top_p": top_p if top_p is not None else 0.95 + } + 
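+        # Retry transient failures: wait retry_delay * (attempt + 1) seconds after
+        # each failed attempt and raise only once max_retries attempts are exhausted.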
+ for attempt in range(self.max_retries): + try: + response = requests.post( + f"{self.api_base}/chat/completions", + headers={"Content-Type": "application/json"}, + data=json.dumps(data), + timeout=180 + ) + response.raise_for_status() + return response.json()["choices"][0]["message"]["content"] + except (requests.exceptions.RequestException, KeyError, IndexError) as e: + if attempt == self.max_retries - 1: + raise Exception(f"Failed to get completion after {self.max_retries} attempts: {str(e)}") + time.sleep(self.retry_delay * (attempt + 1)) + + def batch_completion(self, + message_batches: List[List[Dict[str, str]]], + temperature: float = None, + max_tokens: int = None, + top_p: float = None, + batch_size: int = None) -> List[str]: + batch_size = batch_size if batch_size is not None else 32 + results = [] + + for i in range(0, len(message_batches), batch_size): + batch_chunk = message_batches[i:i+batch_size] + batch_results = [] + + for messages in batch_chunk: + content = self.chat_completion( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p + ) + batch_results.append(content) + + results.extend(batch_results) + time.sleep(0.1) + + return results + + def list_models(self) -> List[Dict[str, Any]]: + """List available Ollama models""" + response = requests.get(f"{self.api_base}/models") + response.raise_for_status() + return response.json()["models"] + + def check_model_exists(self, model_name: str) -> bool: + """Check if a specific model exists""" + models = self.list_models() + return any(model["name"] == model_name for model in models) \ No newline at end of file diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py index e5d2c272..cfa901dd 100644 --- a/synthetic_data_kit/models/llm_client.py +++ b/synthetic_data_kit/models/llm_client.py @@ -3,27 +3,30 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# vLLM logic: Will be expanded to ollama and Cerebras in future. 
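For reviewers, a minimal usage sketch of the refactored client as it is wired up in the new `cli.py` call sites; the prompt text here is purely illustrative:

```python
from synthetic_data_kit.models.llm_client import LLMClient

# Backend, API base and model all fall back to config.yaml when left as None;
# the constructor raises ConnectionError if the chosen server is unreachable.
client = LLMClient(
    config_path=None,     # use the default config
    backend="ollama",     # or "vllm"; None reads the `backend:` key from config
    api_base=None,
    model_name=None,
)

# Call sites no longer care which backend is active.
reply = client.chat_completion(
    [{"role": "user", "content": "Summarize this document ..."}],  # illustrative prompt
    temperature=0.1,
)
print(reply)
```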
from typing import List, Dict, Any, Optional -import requests -import json -import time -import os from pathlib import Path -from synthetic_data_kit.utils.config import load_config, get_vllm_config +from synthetic_data_kit.utils.config import ( + load_config, + get_backend_type, + get_vllm_config, + get_ollama_config +) +from synthetic_data_kit.models.llm_backends import VLLMBackend, OllamaBackend class LLMClient: def __init__(self, config_path: Optional[Path] = None, + backend: Optional[str] = None, api_base: Optional[str] = None, model_name: Optional[str] = None, max_retries: Optional[int] = None, retry_delay: Optional[float] = None): - """Initialize an OpenAI-compatible client that connects to a VLLM server + """Initialize an LLM client that supports multiple backends Args: config_path: Path to config file (if None, uses default) + backend: Override backend type from config ("vllm" or "ollama") api_base: Override API base URL from config model_name: Override model name from config max_retries: Override max retries from config @@ -31,70 +34,47 @@ def __init__(self, """ # Load config self.config = load_config(config_path) - vllm_config = get_vllm_config(self.config) - # Set parameters, with CLI overrides taking precedence - self.api_base = api_base or vllm_config.get('api_base') - self.model = model_name or vllm_config.get('model') - self.max_retries = max_retries or vllm_config.get('max_retries') - self.retry_delay = retry_delay or vllm_config.get('retry_delay') + # Determine backend type + self.backend_type = backend or get_backend_type(self.config) + + # Initialize appropriate backend + if self.backend_type == "vllm": + vllm_config = get_vllm_config(self.config) + self.backend = VLLMBackend( + api_base=api_base or vllm_config.get('api_base'), + model=model_name or vllm_config.get('model'), + max_retries=max_retries or vllm_config.get('max_retries'), + retry_delay=retry_delay or vllm_config.get('retry_delay') + ) + elif self.backend_type == "ollama": + ollama_config = get_ollama_config(self.config) + self.backend = OllamaBackend( + api_base=api_base or ollama_config.get('api_base'), + model=model_name or ollama_config.get('model'), + max_retries=max_retries or ollama_config.get('max_retries'), + retry_delay=retry_delay or ollama_config.get('retry_delay') + ) + else: + raise ValueError(f"Unsupported backend type: {self.backend_type}") # Verify server is running - available, info = self._check_server() + available, info = self.backend.check_server() if not available: - raise ConnectionError(f"VLLM server not available at {self.api_base}: {info}") - - def _check_server(self) -> tuple: - """Check if the VLLM server is running and accessible""" - try: - response = requests.get(f"{self.api_base}/models", timeout=5) - if response.status_code == 200: - return True, response.json() - return False, f"Server returned status code: {response.status_code}" - except requests.exceptions.RequestException as e: - return False, f"Server connection error: {str(e)}" + raise ConnectionError(f"{self.backend_type.upper()} server not available: {info}") def chat_completion(self, messages: List[Dict[str, str]], temperature: float = None, max_tokens: int = None, top_p: float = None) -> str: - """Generate a chat completion using the VLLM OpenAI-compatible API""" - # Get defaults from config if not provided - generation_config = self.config.get('generation', {}) - temperature = temperature if temperature is not None else generation_config.get('temperature', 0.1) - max_tokens = max_tokens if max_tokens is not None 
else generation_config.get('max_tokens', 4096) - top_p = top_p if top_p is not None else generation_config.get('top_p', 0.95) - - data = { - "model": self.model, - "messages": messages, - "temperature": temperature, - "max_tokens": max_tokens, - "top_p": top_p - } - - for attempt in range(self.max_retries): - try: - # Only print if verbose mode is enabled - if os.environ.get('SDK_VERBOSE', 'false').lower() == 'true': - print(f"Sending request to model {self.model}...") - response = requests.post( - f"{self.api_base}/chat/completions", - headers={"Content-Type": "application/json"}, - data=json.dumps(data), - timeout=180 # Increased timeout to 180 seconds - ) - if os.environ.get('SDK_VERBOSE', 'false').lower() == 'true': - print(f"Received response with status code: {response.status_code}") - - response.raise_for_status() - return response.json()["choices"][0]["message"]["content"] - - except (requests.exceptions.RequestException, KeyError, IndexError) as e: - if attempt == self.max_retries - 1: - raise Exception(f"Failed to get completion after {self.max_retries} attempts: {str(e)}") - time.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff + """Generate a chat completion using the configured backend""" + return self.backend.chat_completion( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p + ) def batch_completion(self, message_batches: List[List[Dict[str, str]]], @@ -102,69 +82,26 @@ def batch_completion(self, max_tokens: int = None, top_p: float = None, batch_size: int = None) -> List[str]: - """Process multiple message sets in batches - - Instead of sending requests one at a time, this method processes - multiple prompts in batches to maximize throughput. It uses VLLM's - ability to efficiently batch requests. 
- """ - # Get defaults from config if not provided - generation_config = self.config.get('generation', {}) - temperature = temperature if temperature is not None else generation_config.get('temperature', 0.1) - max_tokens = max_tokens if max_tokens is not None else generation_config.get('max_tokens', 4096) - top_p = top_p if top_p is not None else generation_config.get('top_p', 0.95) - batch_size = batch_size if batch_size is not None else generation_config.get('batch_size', 32) - - verbose = os.environ.get('SDK_VERBOSE', 'false').lower() == 'true' - results = [] - - # Process message batches in chunks to avoid overloading the server - for i in range(0, len(message_batches), batch_size): - batch_chunk = message_batches[i:i+batch_size] - if verbose: - print(f"Processing batch {i//batch_size + 1}/{(len(message_batches) + batch_size - 1) // batch_size} with {len(batch_chunk)} requests") - - # Create batch request payload for VLLM - batch_requests = [] - for messages in batch_chunk: - batch_requests.append({ - "model": self.model, - "messages": messages, - "temperature": temperature, - "max_tokens": max_tokens, - "top_p": top_p - }) - - try: - # For now, we run these in parallel with multiple requests - batch_results = [] - for request_data in batch_requests: - # Only print if verbose mode is enabled - if verbose: - print(f"Sending batch request to model {self.model}...") - - response = requests.post( - f"{self.api_base}/chat/completions", - headers={"Content-Type": "application/json"}, - data=json.dumps(request_data), - timeout=180 # Increased timeout for batch processing - ) - - if verbose: - print(f"Received response with status code: {response.status_code}") - - response.raise_for_status() - content = response.json()["choices"][0]["message"]["content"] - batch_results.append(content) - - results.extend(batch_results) - - except (requests.exceptions.RequestException, KeyError, IndexError) as e: - raise Exception(f"Failed to process batch: {str(e)}") - - time.sleep(0.1) - - return results + """Process multiple message sets in batches using the configured backend""" + return self.backend.batch_completion( + message_batches=message_batches, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + batch_size=batch_size + ) + + def list_models(self) -> List[Dict[str, Any]]: + """List available models (Ollama only)""" + if self.backend_type == "ollama": + return self.backend.list_models() + raise NotImplementedError("Model listing is only supported for Ollama backend") + + def check_model_exists(self, model_name: str) -> bool: + """Check if a specific model exists (Ollama only)""" + if self.backend_type == "ollama": + return self.backend.check_model_exists(model_name) + raise NotImplementedError("Model existence check is only supported for Ollama backend") @classmethod def from_config(cls, config_path: Path) -> 'LLMClient': diff --git a/synthetic_data_kit/utils/config.py b/synthetic_data_kit/utils/config.py index b0b75765..892b8338 100644 --- a/synthetic_data_kit/utils/config.py +++ b/synthetic_data_kit/utils/config.py @@ -60,6 +60,10 @@ def get_path_config(config: Dict[str, Any], path_type: str, file_type: Optional[ else: raise ValueError(f"Unknown path type: {path_type}") +def get_backend_type(config: Dict[str, Any]) -> str: + """Get the LLM backend type from configuration""" + return config.get('backend', 'vllm') + def get_vllm_config(config: Dict[str, Any]) -> Dict[str, Any]: """Get VLLM configuration""" return config.get('vllm', { @@ -70,6 +74,15 @@ def get_vllm_config(config: 
Dict[str, Any]) -> Dict[str, Any]: 'retry_delay': 1.0 }) +def get_ollama_config(config: Dict[str, Any]) -> Dict[str, Any]: + """Get Ollama configuration""" + return config.get('ollama', { + 'api_base': 'http://localhost:11434/v1', + 'model': 'llama3.2', + 'max_retries': 3, + 'retry_delay': 1.0 + }) + def get_generation_config(config: Dict[str, Any]) -> Dict[str, Any]: """Get generation configuration""" return config.get('generation', {
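Finally, a short sketch of how the new config helpers resolve backend settings; this assumes `load_config` accepts None for the default config path, as it is called in `llm_client.py`:

```python
from synthetic_data_kit.utils.config import (
    load_config,
    get_backend_type,
    get_vllm_config,
    get_ollama_config,
)

config = load_config(None)           # None falls back to the default config.yaml
backend = get_backend_type(config)   # "vllm" unless config.yaml sets backend: "ollama"

if backend == "ollama":
    settings = get_ollama_config(config)  # ollama: block, or the fallback defaults above
else:
    settings = get_vllm_config(config)    # vllm: block from config.yaml

print(backend, settings["api_base"], settings["model"])
```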