diff --git a/.gitignore b/.gitignore
index 42545fbf..25a4f0e8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -173,3 +173,5 @@ cython_debug/
 # PyPI configuration file
 .pypirc
+
+/results/
\ No newline at end of file
diff --git a/README.md b/README.md
index b0dba5f7..f7cc38f0 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
 The framework supports **parallel evaluation of candidates** locally or on a Slurm cluster. It maintains an archive of successful solutions, enabling knowledge transfer between different evolutionary islands. `ShinkaEvolve` is particularly well-suited for scientific tasks where there is a verifier available and the goal is to optimize performance metrics while maintaining code correctness and readability.

-![](docs/conceptual.png)
+![evolution](https://github.com/user-attachments/assets/22cf3468-17fe-4995-9e13-d602b490a54e)

 ## Documentation πŸ“

@@ -26,6 +26,7 @@ The framework supports **parallel evaluation of candidates** locally or on a Slu
 | πŸ““ **[Tutorial Notebook](examples/shinka_tutorial.ipynb)** | Interactive walkthrough of Shinka features | Hands-on examples, configuration, best practices |
 | βš™οΈ **[Configuration](docs/configuration.md)** | Comprehensive configuration reference | All config options, optimization settings, advanced features |
 | 🎨 **[WebUI](docs/webui.md)** | Interactive visualization and monitoring | Real-time tracking, result analysis, debugging tools |
+| πŸ•ΉοΈ **[Local LLM Support](https://github.com/SakanaAI/ShinkaEvolve/blob/main/docs/support_local_llm.md)** | Instructions for local LLMs | How to set up local LLMs on your machine |

 ## Installation & Quick Start πŸš€

@@ -127,6 +128,7 @@ runner.run()
 | `migration_interval` | `10` | `int` | Generations between island migrations |
 | `migration_rate` | `0.1` | `float` | Proportion of island population to migrate |
 | `island_elitism` | `True` | `bool` | Keep best programs on their original islands |
+| `migration_adaptation` | `None` | `Optional[dict]` | Optional adaptive migration configuration per island |
 | `enforce_island_separation` | `True` | `bool` | Enforce full separation between islands |
 | `parent_selection_strategy` | `"power_law"` | `str` | Parent selection: "weighted", "power_law", "beam_search" |
 | `exploitation_alpha` | `1.0` | `float` | Power-law exponent (0=uniform, 1=power-law) |
@@ -136,6 +138,11 @@ runner.run()

+Adaptive migration (success/diversity/bandit feedback that updates
+`migration_rate`, `migration_interval`, and `island_elitism` per island) is
+described in `docs/migration_adaptation.md`. A ready-to-run Hydra example lives
+at `configs/database/island_adaptive.yaml`.
+
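+For Python-API users, the same settings can be passed programmatically. A
+minimal sketch (the dataclass names follow `shinka/database/dbase.py` in this
+PR; the values shown are illustrative, not tuned defaults):
+
+```python
+from shinka.database.dbase import DatabaseConfig, MigrationAdaptationConfig
+
+# Enable adaptive migration with two of the three feedback signals;
+# unspecified sub-configs fall back to their dataclass defaults.
+db_config = DatabaseConfig(
+    num_islands=4,
+    migration_adaptation=MigrationAdaptationConfig(
+        enabled=True,
+        methods=["success", "diversity"],
+    ),
+)
+```
+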
JobConfig Parameters (click to expand) diff --git a/configs/config.yaml b/configs/config.yaml index 9702c661..b819e61f 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -2,9 +2,9 @@ defaults: - _self_ - database@_global_: island_small - evolution@_global_: small_budget - - task@_global_: mad_tf + - task@_global_: circle_packing - cluster@_global_: local - - variant@_global_: mad_tf_example + - variant@_global_: default verbose: false results_dir: results diff --git a/configs/database/island_adaptive.yaml b/configs/database/island_adaptive.yaml new file mode 100644 index 00000000..e0402dcb --- /dev/null +++ b/configs/database/island_adaptive.yaml @@ -0,0 +1,45 @@ +# Adaptive migration sample configuration +db_config: + _target_: shinka.database.DatabaseConfig + db_path: "evolution_db.sqlite" + num_islands: 4 + archive_size: 100 + elite_selection_ratio: 0.3 + num_archive_inspirations: 5 + num_top_k_inspirations: 2 + migration_interval: 10 + migration_rate: 0.1 + island_elitism: 0.1 + migration_adaptation: + enabled: true + methods: ["success", "diversity", "bandit"] + success: + window: 8 + target_improvement: 0.01 + step_up: 1.2 + step_down: 0.83 + ema_beta: 0.75 + diversity: + metric: "score_std" + low_thresh: 0.25 + high_thresh: 0.6 + adjust_strength: 0.15 + bandit: + policy_arms: + donor: ["random", "ring", "topk"] + payload: ["random", "elite", "novel"] + size: ["small", "medium", "large"] + algo: "ucb1" + ucb_c: 1.0 + epsilon: 0.1 + bounds: + rate_min: 0.02 + rate_max: 0.5 + interval_min: 2 + interval_max: 40 + elitism_min: 0.05 + elitism_max: 0.4 + weights: + success: 1.0 + diversity: 0.5 + bandit: 0.5 diff --git a/configs/database/island_large_adaptive.yaml b/configs/database/island_large_adaptive.yaml new file mode 100644 index 00000000..0a85162e --- /dev/null +++ b/configs/database/island_large_adaptive.yaml @@ -0,0 +1,49 @@ +# Adaptive variant of island_large +db_config: + _target_: shinka.database.DatabaseConfig + db_path: "evolution_db.sqlite" + num_islands: 5 + archive_size: 40 + elite_selection_ratio: 0.3 + num_archive_inspirations: 4 + num_top_k_inspirations: 2 + migration_interval: 10 + migration_rate: 0.1 + island_elitism: true + parent_selection_strategy: "weighted" + exploitation_alpha: 1.0 + exploitation_ratio: 0.2 + parent_selection_lambda: 10.0 + migration_adaptation: + enabled: true + methods: ["success", "diversity", "bandit"] + success: + window: 8 + target_improvement: 0.01 + step_up: 1.2 + step_down: 0.83 + ema_beta: 0.75 + diversity: + metric: "score_std" + low_thresh: 0.25 + high_thresh: 0.6 + adjust_strength: 0.15 + bandit: + policy_arms: + donor: ["random", "ring", "topk"] + payload: ["random", "elite", "novel"] + size: ["small", "medium", "large"] + algo: "ucb1" + ucb_c: 1.0 + epsilon: 0.1 + bounds: + rate_min: 0.02 + rate_max: 0.5 + interval_min: 2 + interval_max: 40 + elitism_min: 0.05 + elitism_max: 0.4 + weights: + success: 1.0 + diversity: 0.5 + bandit: 0.5 diff --git a/configs/database/island_medium_adaptive.yaml b/configs/database/island_medium_adaptive.yaml new file mode 100644 index 00000000..591c404c --- /dev/null +++ b/configs/database/island_medium_adaptive.yaml @@ -0,0 +1,48 @@ +# Adaptive variant of island_medium +db_config: + _target_: shinka.database.DatabaseConfig + db_path: "evolution_db.sqlite" + num_islands: 2 + archive_size: 40 + exploitation_ratio: 0.2 + elite_selection_ratio: 0.3 + num_archive_inspirations: 4 + num_top_k_inspirations: 2 + migration_interval: 10 + migration_rate: 0.0 + island_elitism: true + 
parent_selection_strategy: "weighted" + parent_selection_lambda: 10.0 + migration_adaptation: + enabled: true + methods: ["success", "diversity", "bandit"] + success: + window: 8 + target_improvement: 0.01 + step_up: 1.2 + step_down: 0.83 + ema_beta: 0.75 + diversity: + metric: "score_std" + low_thresh: 0.25 + high_thresh: 0.6 + adjust_strength: 0.15 + bandit: + policy_arms: + donor: ["random", "ring", "topk"] + payload: ["random", "elite", "novel"] + size: ["small", "medium", "large"] + algo: "ucb1" + ucb_c: 1.0 + epsilon: 0.1 + bounds: + rate_min: 0.02 + rate_max: 0.5 + interval_min: 2 + interval_max: 40 + elitism_min: 0.05 + elitism_max: 0.4 + weights: + success: 1.0 + diversity: 0.5 + bandit: 0.5 diff --git a/configs/database/island_small_adaptive.yaml b/configs/database/island_small_adaptive.yaml new file mode 100644 index 00000000..b0b3fc47 --- /dev/null +++ b/configs/database/island_small_adaptive.yaml @@ -0,0 +1,46 @@ +# Adaptive variant of island_small +db_config: + _target_: shinka.database.DatabaseConfig + db_path: "evolution_db.sqlite" + num_islands: 2 + archive_size: 20 + exploitation_ratio: 0.2 + elite_selection_ratio: 0.3 + num_archive_inspirations: 4 + num_top_k_inspirations: 2 + migration_interval: 10 + migration_rate: 0.1 + island_elitism: true + migration_adaptation: + enabled: true + methods: ["success", "diversity", "bandit"] + success: + window: 8 + target_improvement: 0.01 + step_up: 1.2 + step_down: 0.83 + ema_beta: 0.75 + diversity: + metric: "score_std" + low_thresh: 0.25 + high_thresh: 0.6 + adjust_strength: 0.15 + bandit: + policy_arms: + donor: ["random", "ring", "topk"] + payload: ["random", "elite", "novel"] + size: ["small", "medium", "large"] + algo: "ucb1" + ucb_c: 1.0 + epsilon: 0.1 + bounds: + rate_min: 0.02 + rate_max: 0.5 + interval_min: 2 + interval_max: 40 + elitism_min: 0.05 + elitism_max: 0.4 + weights: + success: 1.0 + diversity: 0.5 + bandit: 0.5 diff --git a/docs/configuration.md b/docs/configuration.md index 670df302..035e0d7e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -157,6 +157,7 @@ exp_name: "shinka_my_task" | `migration_interval` | int | 10 | Generations between island migrations | | `migration_rate` | float | 0.1 | Fraction of population migrated | | `island_elitism` | bool | true | Preserve elites per island | +| `migration_adaptation` | dict | null | Optional per-island adaptive migration settings | ### Resource Parameters @@ -169,6 +170,14 @@ exp_name: "shinka_my_task" | `conda_env` | str | `"shinka"` | Conda environment name | | `modules` | list | `[]` | Environment modules to load | +#### Adaptive Migration Settings + +The optional `migration_adaptation` block enables automatic adjustment of +`migration_rate`, `migration_interval`, and `island_elitism` per island. Each +method (`success`, `diversity`, `bandit`) can be toggled independently and has a +dedicated configuration block. Refer to `docs/migration_adaptation.md` and the +sample `configs/database/island_adaptive.yaml` file for a complete walkthrough. + ## Pre-configured Variants Shinka uses [Hydra](https://hydra.cc/) for flexible, hierarchical configuration management. The system is designed around composable configuration files that can be mixed and matched to create different experimental setups. 
diff --git a/docs/getting_started.md b/docs/getting_started.md index 23415883..03bc54c8 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -2,6 +2,8 @@ Shinka is a framework that combines Large Language Models (LLMs) with evolutionary algorithms to drive scientific discovery. This guide will help you get started with installing, configuring, and running your first evolutionary experiments. +![](../docs/conceptual.png) + ## Table of Contents 1. [What is Shinka?](#what-is-shinka) @@ -53,7 +55,7 @@ pip install uv ```bash git clone -cd shinka +cd ShinkaEvolve # Create virtual environment with Python 3.11 uv venv --python 3.11 @@ -79,7 +81,7 @@ conda activate shinka ```bash git clone -cd shinka +cd ShinkaEvolve pip install -e . ``` @@ -249,7 +251,7 @@ from shinka.core import run_shinka_eval def main(program_path: str, results_dir: str): """Main evaluation function called by Shinka""" - + metrics, correct, error_msg = run_shinka_eval( program_path=program_path, results_dir=results_dir, @@ -268,11 +270,11 @@ def main(program_path: str, results_dir: str): def validate_packing(run_output): """Returns (is_valid: bool, error_msg: str or None)""" centers, radii, reported_sum = run_output - + # Check constraints (bounds, overlaps, etc.) if constraint_violated: return False, "Specific error description" - + return True, None # Valid solution ``` @@ -280,10 +282,10 @@ def validate_packing(run_output): ```python def aggregate_metrics(results, results_dir): """Returns metrics dictionary with required structure""" - + # Extract data from results centers, radii, reported_sum = results[0] - + return { "combined_score": float(reported_sum), # PRIMARY FITNESS (higher = better) "public": { # Visible in WebUI/logs @@ -331,6 +333,75 @@ The `run_shinka_eval` function returns three values: ## Advanced Usage +### Resuming Experiments + +If you need to pause and resume an evolutionary run, or extend a completed run with more generations, Shinka supports seamless resumption from existing results. + +#### How Resuming Works + +When you specify an existing `results_dir` that contains a database, Shinka will: +- Detect the previous run automatically +- Restore the population database and all program history +- Resume meta-recommendations from the last checkpoint +- Continue from the last completed generation + +#### Using the CLI (Hydra) + +```bash +# Resume an existing run and extend to 50 generations +shinka_launch \ + variant=circle_packing_example \ + evo_config.results_dir=results_20250101_120000 \ + evo_config.num_generations=50 + +# Or with a custom task +shinka_launch \ + task=circle_packing \ + database=island_small \ + evolution=small_budget \ + cluster=local \ + evo_config.results_dir=path/to/previous/results \ + evo_config.num_generations=100 +``` + +#### Using the Python API + +```python +from shinka.core import EvolutionRunner, EvolutionConfig +from shinka.database import DatabaseConfig +from shinka.launch import LocalJobConfig + +# Point to existing results directory +evo_config = EvolutionConfig( + num_generations=50, # Extend to 50 total generations + results_dir="results_20250101_120000", # Existing results + # ... other config parameters ... 
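+    # num_generations is the total target, not an increment: a finished
+    # 20-generation run resumed with num_generations=50 runs 30 more.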
+)
+
+job_config = LocalJobConfig(
+    eval_program_path="examples/circle_packing/evaluate.py",
+)
+
+db_config = DatabaseConfig(
+    archive_size=20,
+    num_islands=2,
+)
+
+# Run will automatically detect and resume
+runner = EvolutionRunner(
+    evo_config=evo_config,
+    job_config=job_config,
+    db_config=db_config,
+)
+runner.run()
+```
+
+**Important Notes:**
+- The `num_generations` parameter should be set to the **total** number of generations you want (not additional generations)
+- For example, if you completed 20 generations and want 30 more, set `num_generations=50`
+- The database configuration (number of islands, archive size, etc.) should match the original run
+- All previous progress, including the best solutions and meta-recommendations, will be preserved
+
 ### Environment Management for Local Jobs

 When running jobs locally, you have several options for managing Python environments:
diff --git a/docs/migration_adaptation.md b/docs/migration_adaptation.md
new file mode 100644
index 00000000..4425c047
--- /dev/null
+++ b/docs/migration_adaptation.md
@@ -0,0 +1,43 @@
+# Migration Adaptation
+
+The migration adaptation option lets each island tune its migration schedule
+and policy online. When `database.migration_adaptation.enabled=true`, the
+evolution runner instantiates an adaptive controller that observes island
+improvements and diversity after each generation and adjusts:
+
+- `migration_rate` – fraction of the island population exported during a
+  migration event.
+- `migration_interval` – generations between migrations for the island.
+- `island_elitism` – fraction of the island protected from migration.
+
+## Methods
+
+Three complementary signals can be combined by listing them in
+`migration_adaptation.methods`:
+
+1. **success** – tracks the relative improvement caused by recent migrations.
+   When improvement exceeds `target_improvement`, the controller increases the
+   migration rate, shortens intervals, and slightly raises the elitism ratio.
+2. **diversity** – computes a lightweight diversity score (score standard
+   deviation of recent programs). Falling below `low_thresh` triggers more
+   exploration; exceeding `high_thresh` stretches the interval and lowers the
+   rate.
+3. **bandit** – chooses migration policies (donor routing, payload selection,
+   and migration size) using UCB1 or epsilon-greedy bandits. Rewards are the
+   normalized improvements measured by the success tracker.
+
+Each method has dedicated configuration blocks plus global `bounds` and
+`weights` sections. Bounds clamp the adaptive parameters, while weights scale
+how aggressively each method can move them.
+
+## Logging
+
+Adaptive decisions are recorded per generation in
+`<results_dir>/migration_adaptation.csv` with columns for rate, interval,
+elitism, EMA improvement, diversity, and the last bandit arm. These logs make
+it easy to visualize how migration policies evolved across islands.
+
+## Example configuration
+
+See `configs/database/island_adaptive.yaml` for a complete Hydra configuration
+that enables all three methods with reasonable defaults.
diff --git a/docs/support_local_llm.md b/docs/support_local_llm.md
new file mode 100644
index 00000000..5f406e7b
--- /dev/null
+++ b/docs/support_local_llm.md
@@ -0,0 +1,232 @@
+
+# 🧩 Integrating Local LLMs into **ShinkaEvolve**
+
+## 🧠 Overview
+
+The original **ShinkaEvolve** code does **not** include built-in support for running **local LLMs**.
+To enable this functionality, parts of the codebase can be modified to integrate locally hosted models.
+
+---
+
+## πŸ—οΈ Code Organization
+
+**ShinkaEvolve** uses a **modular architecture** that supports multiple **LLM providers**.
+The relevant code for LLM interaction is located in the **`LLM/`** folder, which manages all model communications.
+ShinkaEvolve distinguishes between two LLM types:
+
+* **Regular LLMs**
+* **Embedding LLMs**
+
+---
+
+## βš™οΈ Adding a Regular LLM
+
+To add support for a **regular LLM**, follow these steps. They walk through an example of adding support for gpt-oss models served with Unsloth, which exposes an OpenAI-compatible API (`v1/completions`).
+This LLM can then be specified in the configuration variables:
+
+```yaml
+llm_models:
+meta_llm_models:
+```
+
+---
+
+### πŸ”§ Step 1: Modify the Client
+
+The file **`client.py`** is responsible for creating clients that interact with LLMs.
+Each client instance is later used to query a specific model.
+
+To add a local model, introduce a new client configuration.
+The API URL is extracted from the model name, which follows this format:
+
+```
+local-gptoss-unsloth-url
+```
+
+#### Example
+
+```python
+elif "local-gptoss-unsloth" in model_name:
+    # Extract URL from model name
+    pattern = r"https?://"
+    match = re.search(pattern, model_name)
+    if match:
+        start_index = match.start()
+        url = model_name[start_index:]
+    else:
+        raise ValueError(f"Invalid URL in model name: {model_name}")
+
+    # Create OpenAI-compatible client
+    client = openai.OpenAI(
+        api_key="filler",
+        base_url=url
+    )
+
+    # Structured output mode (if required)
+    if structured_output:
+        client = instructor.from_openai(
+            client,
+            mode=instructor.Mode.JSON,
+        )
+```
+
+---
+
+### πŸ“ Step 2: Create the Local Query Function
+
+Inside the **`models/`** folder, create a new subfolder to store the query functions for your local models:
+
+```
+LLM/models/local/
+```
+
+> Don’t forget to include an empty `__init__.py` file.
+
+This folder should contain a **custom query function** for the local model. I called my file `local_gptoss_unsloth.py`.
+It should follow the same structure as the other functions in `LLM/models/`, but with small adjustments.
+
+#### My Key Adjustments
+
+* Replace `max_output_tokens` with **`max_tokens`** to match the local API.
+* Extract additional response metadata such as:
+
+  * `total_tokens`
+  * `thinking_tokens` (if your model includes reasoning traces)
+
+This function is later imported and registered in **`query.py`**.
+
+---
+
+### 🧩 Step 3: Update `__init__.py`
+
+Configure **`__init__.py`** to include and expose the new local query function, so it can be imported elsewhere.
+
+```
+from .local.local_gptoss_unsloth import query_local_gptoss_unsloth  # ADDED THIS LINE
+from .result import QueryResult
+
+__all__ = [
+    "query_anthropic",
+    "query_openai",
+    "query_deepseek",
+    "query_gemini",
+    "query_local_gptoss_unsloth",  # ADDED THIS LINE
+    "QueryResult",
+]
+```
+
+---
+
+### πŸ“¬ Step 4: Update `query.py`
+
+Import and register the new local query function in `query.py`.
+
+#### Imports
+
+```python
+from .models import (
+    query_anthropic,
+    query_openai,
+    query_deepseek,
+    query_gemini,
+    query_local_gptoss_unsloth,  # ADDED THIS LINE
+    QueryResult,
+)
+```
+
+#### Model Selection Logic
+
+```python
+elif "local-gptoss-unsloth" in model_name:  # ADDED THIS LINE
+    query_fn = query_local_gptoss_unsloth
+```
+
+---
+
+### 🧠 Step 5: Other Observations
+
+The file **`query.py`** also defines functions such as:
+
+* `sample_model_kwargs`
+* `sample_batch_kwargs`
+
+However, these are **not referenced anywhere else** in the repository, so no modifications are required here for now.
+
+---
+
+### βœ… Summary
+
+| Step | File | Change | Description |
+| ---- | -------------------------------------- | -------------------- | -------------------------------------------------------- |
+| 1 | `client.py` | Add new client block | Create OpenAI-compatible client for local LLM |
+| 2 | `models/local/local_gptoss_unsloth.py` | New function | Query local model, adjust tokens, extract reasoning info |
+| 3 | `__init__.py` | Add import | Expose new query function |
+| 4 | `query.py` | Register model | Add conditional for local LLM |
+| 5 | β€” | Review only | Unused helpers left unchanged |
+
+---
+
+## 🧬 Adding a Local Embedding Model
+
+For embedding models, you can use **Ollama**, which follows the **OpenAI API** format.
+The only relevant file is **`embedding.py`**.
+
+### Code Addition
+
+```python
+elif model_name.startswith("local-"):
+    # Pattern: local-(model-name)-(http or https url)
+    match = re.match(r"local-(.+?)-(https?://.+)", model_name)
+    if match:
+        model_to_use = match.group(1)
+        url = match.group(2)
+    else:
+        raise ValueError(f"Invalid local model format: {model_name}")
+
+    client = openai.OpenAI(
+        base_url=url,
+        api_key="filler"
+    )
+```
+
+#### Notes
+
+* Compatible with **any Ollama model**.
+* The model name must follow this convention:
+
+  ```
+  local-model-name-url
+  ```
+* The code extracts both `model-name` and `url`, and uses them to query Ollama.
+
+---
+
+### Query Logic
+
+The existing line in **`embedding.py`** remains unchanged:
+
+```python
+response = self.client.embeddings.create(
+    model=self.model,
+    input=code,
+    encoding_format="float"
+)
+```
+
+For local embedding models, `self.model` corresponds to the extracted model name.
+The only addition to the **`EmbeddingClient`** class is the cost handling:
+
+```python
+elif self.model_name.startswith("local-"):
+    cost = 0.0
+```
+
+---
+
+## πŸš€ Result
+
+ShinkaEvolve can now connect to **locally hosted LLMs** and **embedding models** through **OpenAI-compatible APIs**.
+This setup supports **Ollama** as well as other OpenAI-compatible servers, such as **gpt-oss** models served with **Unsloth**.
+
+If your model has different requirements, follow the same pattern with a distinct model identifier and your own custom logic.
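+
+### πŸ§ͺ Usage Sketch
+
+As a concrete (hypothetical) illustration of the naming convention, the
+snippet below wires both model types into a run. The URLs, ports, and model
+identifiers are placeholders, and the `llm_models` / `embedding_model` fields
+mirror the tutorial notebook rather than a guaranteed API:
+
+```python
+from shinka.core import EvolutionConfig
+
+evo_config = EvolutionConfig(
+    # Regular LLM served by Unsloth (the name embeds the endpoint URL)
+    llm_models=["local-gptoss-unsloth-http://localhost:8000/v1"],
+    # Embedding model served by Ollama: local-<model-name>-<url>
+    embedding_model="local-nomic-embed-text-http://localhost:11434/v1",
+)
+```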
+ diff --git a/examples/shinka_tutorial.ipynb b/examples/shinka_tutorial.ipynb index 66a71a07..c6d81899 100644 --- a/examples/shinka_tutorial.ipynb +++ b/examples/shinka_tutorial.ipynb @@ -237,6 +237,17 @@ "if not llm_models:\n", " llm_models = [\"gpt-5-mini\"] # fallback if no keys detected\n", "\n", + "# pick embedding model based on available keys\n", + "embedding_model_name = \"\"\n", + "if os.getenv(\"GEMINI_API_KEY\"):\n", + " embedding_model_name = \"gemini-embedding-001\"\n", + "elif os.getenv(\"OPENAI_API_KEY\"):\n", + " embedding_model_name = \"text-embedding-3-small\"\n", + "else:\n", + " embedding_model_name = \"text-embedding-3-small\"\n", + "print(f\"βœ… Embedding model selected: {embedding_model_name}\")\n", + "\n", + "\n", "# unique experiment directory\n", "timestamp = dt.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", "run_tag = f\"{timestamp}_weighted_fast\"\n", @@ -271,6 +282,8 @@ " max_novelty_attempts=3,\n", " # ensemble llm selection among candidates based on past performance\n", " llm_dynamic_selection=None, # e.g. \"ucb1\"\n", + " # set embedding model\n", + " embedding_model=embedding_model_name,\n", ")\n", "\n", "db_config = DatabaseConfig(\n", @@ -286,11 +299,13 @@ " enforce_island_separation=True,\n", " parent_selection_strategy=\"weighted\",\n", " parent_selection_lambda=10.0,\n", + " \n", ")\n", "\n", "job_config = LocalJobConfig(eval_program_path=\"evaluate.py\")\n", "\n", "print(\"llm_models:\", llm_models)\n", + "print(\"embedding_model:\", embedding_model_name)\n", "print(\"results_dir:\", evo_config.results_dir)" ] }, diff --git a/pyproject.toml b/pyproject.toml index f05429b6..5802a152 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "adjustText", "markdown", "aiofiles", + "google-generativeai", ] [tool.setuptools] @@ -56,8 +57,8 @@ include = ["shinka", "shinka.*"] [tool.setuptools.package-data] "*" = ["*"] -[tool.uv] -dev-dependencies = [ +[dependency-groups] +dev = [ "pytest>=6.0", "black", "isort", diff --git a/shinka/core/__init__.py b/shinka/core/__init__.py index 784a4550..86181baa 100644 --- a/shinka/core/__init__.py +++ b/shinka/core/__init__.py @@ -1,8 +1,24 @@ -from .runner import EvolutionRunner, EvolutionConfig -from .sampler import PromptSampler -from .summarizer import MetaSummarizer -from .novelty_judge import NoveltyJudge -from .wrap_eval import run_shinka_eval +try: # pragma: no cover - optional during limited test envs + from .runner import EvolutionRunner, EvolutionConfig +except Exception: # pragma: no cover + EvolutionRunner = None # type: ignore + EvolutionConfig = None # type: ignore +try: # pragma: no cover + from .sampler import PromptSampler +except Exception: # pragma: no cover + PromptSampler = None # type: ignore +try: # pragma: no cover + from .summarizer import MetaSummarizer +except Exception: # pragma: no cover + MetaSummarizer = None # type: ignore +try: # pragma: no cover + from .novelty_judge import NoveltyJudge +except Exception: # pragma: no cover + NoveltyJudge = None # type: ignore +try: # pragma: no cover + from .wrap_eval import run_shinka_eval +except Exception: # pragma: no cover + run_shinka_eval = None # type: ignore __all__ = [ "EvolutionRunner", diff --git a/shinka/core/migration_adaptation.py b/shinka/core/migration_adaptation.py new file mode 100644 index 00000000..12a0756d --- /dev/null +++ b/shinka/core/migration_adaptation.py @@ -0,0 +1,453 @@ +"""Adaptive migration controller for Shinka evolution runs.""" + +from __future__ import annotations + +import csv +import 
logging +import math +import random +import statistics +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from shinka.database.dbase import DatabaseConfig, MigrationAdaptationConfig +from shinka.database.islands import IslandMigrationParams, MigrationSummary + +if TYPE_CHECKING: # pragma: no cover - only for static checks + from shinka.database.dbase import ProgramDatabase + + +@dataclass +class PendingSuccessEval: + pre_best: float + evaluate_after: int + migration_generation: int + arm_key: Optional[str] = None + + +@dataclass +class BanditArmState: + count: int = 0 + total_reward: float = 0.0 + + def average(self) -> float: + if self.count == 0: + return 0.0 + return self.total_reward / self.count + + +@dataclass +class IslandAdaptationState: + migration_rate: float + migration_interval: int + island_elitism: float + impr_ema: float = 0.0 + diversity: float = 0.0 + pending_success: Optional[PendingSuccessEval] = None + last_policy_key: Optional[str] = None + + +class MigrationAdaptationController: + """Runtime controller that adapts migration parameters per island.""" + + def __init__( + self, + db: "ProgramDatabase", + config: DatabaseConfig, + results_dir: Path, + logger: Optional[logging.Logger] = None, + ) -> None: + self.db = db + self.config = config + self.adapt_config: Optional[MigrationAdaptationConfig] = ( + config.migration_adaptation + ) + self.logger = logger or logging.getLogger(__name__) + self.results_dir = Path(results_dir) + + self.enabled = bool(self.adapt_config and self.adapt_config.enabled) + if not self.enabled: + return + + self.methods = { + method.lower() for method in (self.adapt_config.methods or []) + } + self.weights = self.adapt_config.weights + self.bounds = self.adapt_config.bounds + self.num_islands = max(0, config.num_islands) + + self.states: Dict[int, IslandAdaptationState] = {} + self._initialize_states() + + self.bandit_enabled = "bandit" in self.methods + self.bandit_algo = (self.adapt_config.bandit.algo or "ucb1").lower() + self.bandit_ucb_c = self.adapt_config.bandit.ucb_c + self.bandit_epsilon = self.adapt_config.bandit.epsilon + self.bandit_arms: List[Tuple[str, Dict[str, str]]] = [] + self.bandit_stats: Dict[int, Dict[str, BanditArmState]] = {} + if self.bandit_enabled: + self._init_bandit_state() + + self.results_dir.mkdir(parents=True, exist_ok=True) + self.log_path = self.results_dir / "migration_adaptation.csv" + self._log_header_written = False + + self._register_with_islands() + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def on_generation_completed(self, generation: int) -> None: + if not self.enabled: + return + + if "success" in self.methods: + self._evaluate_success_metrics(generation) + + if "diversity" in self.methods: + self._update_diversity_metrics() + + self._log_state(generation) + + # ------------------------------------------------------------------ + # Initialization helpers + # ------------------------------------------------------------------ + def _register_with_islands(self) -> None: + manager = getattr(self.db, "island_manager", None) + if manager is None: + self.logger.warning( + "Migration adaptation enabled but no island manager is available." 
+ ) + self.enabled = False + return + + params = { + idx: IslandMigrationParams( + migration_rate=state.migration_rate, + migration_interval=state.migration_interval, + island_elitism=state.island_elitism, + ) + for idx, state in self.states.items() + } + manager.set_island_params_bulk(params) + manager.set_migration_callback(self._on_migration_summary) + if self.bandit_enabled: + manager.register_policy_provider(self._policy_provider) + + def _initialize_states(self) -> None: + base_rate = self._clamp_value( + self.config.migration_rate, self.bounds.rate_min, self.bounds.rate_max + ) + base_interval = max(2, self.config.migration_interval) + base_elitism = self._normalize_elitism(self.config.island_elitism) + + for idx in range(self.num_islands): + self.states[idx] = IslandAdaptationState( + migration_rate=base_rate, + migration_interval=base_interval, + island_elitism=base_elitism, + ) + + def _init_bandit_state(self) -> None: + arms_cfg = self.adapt_config.bandit.policy_arms + donors = arms_cfg.donor or ["random"] + payloads = arms_cfg.payload or ["random"] + sizes = arms_cfg.size or ["medium"] + for donor in donors: + for payload in payloads: + for size in sizes: + key = f"{donor}|{payload}|{size}" + policy = {"donor": donor, "payload": payload, "size": size} + self.bandit_arms.append((key, policy)) + + for idx in range(self.num_islands): + self.bandit_stats[idx] = { + key: BanditArmState() for key, _ in self.bandit_arms + } + + # ------------------------------------------------------------------ + # Success-based updates + # ------------------------------------------------------------------ + def _on_migration_summary(self, summary: MigrationSummary) -> None: + if not self.enabled: + return + for island_idx in summary.per_island.keys(): + state = self.states.get(island_idx) + if state is None: + continue + pre_best = self._get_island_best_score(island_idx) + state.pending_success = PendingSuccessEval( + pre_best=pre_best, + evaluate_after=summary.generation + self.adapt_config.success.window, + migration_generation=summary.generation, + arm_key=state.last_policy_key, + ) + + def _evaluate_success_metrics(self, generation: int) -> None: + beta = self.adapt_config.success.ema_beta + for island_idx, state in self.states.items(): + pending = state.pending_success + if pending is None: + continue + if generation < pending.evaluate_after: + continue + + post_best = self._get_island_best_score(island_idx) + delta = self._relative_improvement(pending.pre_best, post_best) + state.impr_ema = beta * state.impr_ema + (1 - beta) * delta + self._apply_success_update(island_idx, state.impr_ema) + + if self.bandit_enabled and pending.arm_key: + self._update_bandit_reward(island_idx, pending.arm_key, delta) + + state.pending_success = None + + def _apply_success_update(self, island_idx: int, improvement: float) -> None: + state = self.states[island_idx] + cfg = self.adapt_config.success + weight = max(0.0, self.weights.success) + + if improvement >= cfg.target_improvement: + rate_factor = cfg.step_up ** weight + interval_factor = cfg.step_down ** weight + elitism_delta = 0.02 * weight + else: + rate_factor = cfg.step_down ** weight + interval_factor = cfg.step_up ** weight + elitism_delta = -0.02 * weight + + new_rate = state.migration_rate * rate_factor + new_interval = int(round(state.migration_interval * interval_factor)) + new_elitism = state.island_elitism + elitism_delta + + self._update_island_params( + island_idx, + migration_rate=new_rate, + migration_interval=new_interval, + 
island_elitism=new_elitism, + ) + + # ------------------------------------------------------------------ + # Diversity-based updates + # ------------------------------------------------------------------ + def _update_diversity_metrics(self) -> None: + cfg = self.adapt_config.diversity + for island_idx, state in self.states.items(): + diversity = self._compute_diversity(island_idx) + state.diversity = diversity + strength = cfg.adjust_strength * max(0.0, self.weights.diversity) + if diversity < cfg.low_thresh: + new_rate = state.migration_rate * (1 + strength) + new_interval = int(round(state.migration_interval * (1 - strength))) + self._update_island_params( + island_idx, + migration_rate=new_rate, + migration_interval=max(2, new_interval), + ) + elif diversity > cfg.high_thresh: + new_rate = state.migration_rate * (1 - strength) + new_interval = int(round(state.migration_interval * (1 + strength))) + self._update_island_params( + island_idx, + migration_rate=new_rate, + migration_interval=new_interval, + ) + + # ------------------------------------------------------------------ + # Bandit policy selection + # ------------------------------------------------------------------ + def _policy_provider(self, island_idx: int) -> Optional[Dict[str, str]]: + if not self.bandit_enabled or not self.bandit_arms: + return None + policy, key = self._select_bandit_policy(island_idx) + self.states[island_idx].last_policy_key = key + return policy + + def _select_bandit_policy(self, island_idx: int) -> Tuple[Dict[str, str], str]: + stats = self.bandit_stats[island_idx] + if self.bandit_algo == "epsilon_greedy": + if random.random() < self.bandit_epsilon: + key, policy = random.choice(self.bandit_arms) + return dict(policy), key + key = max(self.bandit_arms, key=lambda item: stats[item[0]].average())[0] + return dict(self._policy_from_key(key)), key + + total_plays = sum(state.count for state in stats.values()) or 1 + + def ucb_score(arm_key: str) -> float: + state = stats[arm_key] + if state.count == 0: + return float("inf") + exploration = self.bandit_ucb_c * math.sqrt( + math.log(total_plays + 1) / (state.count + 1e-9) + ) + return state.average() + exploration + + key = max(self.bandit_arms, key=lambda item: ucb_score(item[0]))[0] + return dict(self._policy_from_key(key)), key + + def _policy_from_key(self, key: str) -> Dict[str, str]: + for arm_key, policy in self.bandit_arms: + if arm_key == key: + return dict(policy) + # Fallback to default random policy + return {"donor": "random", "payload": "random", "size": "medium"} + + def _update_bandit_reward(self, island_idx: int, arm_key: str, reward: float) -> None: + stats = self.bandit_stats.get(island_idx) + if not stats or arm_key not in stats: + return + stats[arm_key].count += 1 + stats[arm_key].total_reward += reward + + # ------------------------------------------------------------------ + # Parameter persistence + # ------------------------------------------------------------------ + def _update_island_params( + self, + island_idx: int, + *, + migration_rate: Optional[float] = None, + migration_interval: Optional[int] = None, + island_elitism: Optional[float] = None, + ) -> None: + manager = getattr(self.db, "island_manager", None) + if manager is None: + return + state = self.states[island_idx] + + if migration_rate is not None: + migration_rate = self._limit_change( + state.migration_rate, + self._clamp_value( + migration_rate, self.bounds.rate_min, self.bounds.rate_max + ), + ) + state.migration_rate = migration_rate + + if 
migration_interval is not None: + migration_interval = max(2, migration_interval) + migration_interval = int( + round( + self._limit_change( + float(state.migration_interval), + float( + self._clamp_value( + migration_interval, + self.bounds.interval_min, + self.bounds.interval_max, + ) + ), + ) + ) + ) + state.migration_interval = migration_interval + + if island_elitism is not None: + clamped = self._clamp_value( + island_elitism, self.bounds.elitism_min, self.bounds.elitism_max + ) + state.island_elitism = self._limit_change(state.island_elitism, clamped) + + manager.set_island_params( + island_idx, + migration_rate=state.migration_rate, + migration_interval=state.migration_interval, + island_elitism=state.island_elitism, + ) + + # ------------------------------------------------------------------ + # Metrics helpers + # ------------------------------------------------------------------ + def _get_island_best_score(self, island_idx: int) -> float: + cursor = self.db.cursor + cursor.execute( + "SELECT MAX(combined_score) as best FROM programs WHERE island_idx = ?", + (island_idx,), + ) + row = cursor.fetchone() + return float(row["best"]) if row and row["best"] is not None else 0.0 + + def _compute_diversity(self, island_idx: int, limit: int = 20) -> float: + cursor = self.db.cursor + cursor.execute( + "SELECT combined_score FROM programs WHERE island_idx = ? " + "ORDER BY generation DESC LIMIT ?", + (island_idx, limit), + ) + scores = [row["combined_score"] or 0.0 for row in cursor.fetchall()] + if len(scores) <= 1: + return 0.0 + std_dev = statistics.pstdev(scores) + scale = max(max(abs(score) for score in scores), 1e-6) + return min(1.0, float(std_dev / scale)) + + # ------------------------------------------------------------------ + # Logging utilities + # ------------------------------------------------------------------ + def _log_state(self, generation: int) -> None: + if not self.log_path: + return + rows = [] + for island_idx, state in self.states.items(): + rows.append( + { + "generation": generation, + "island": island_idx, + "migration_rate": f"{state.migration_rate:.4f}", + "migration_interval": state.migration_interval, + "island_elitism": f"{state.island_elitism:.4f}", + "impr_ema": f"{state.impr_ema:.5f}", + "diversity": f"{state.diversity:.5f}", + "policy": state.last_policy_key or "", + } + ) + + write_header = not self._log_header_written and not self.log_path.exists() + with self.log_path.open("a", encoding="utf-8", newline="") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "generation", + "island", + "migration_rate", + "migration_interval", + "island_elitism", + "impr_ema", + "diversity", + "policy", + ], + ) + if write_header: + writer.writeheader() + self._log_header_written = True + for row in rows: + writer.writerow(row) + + # ------------------------------------------------------------------ + # Math helpers + # ------------------------------------------------------------------ + def _clamp_value(self, value: float, min_v: float, max_v: float) -> float: + return max(min_v, min(max_v, value)) + + def _limit_change(self, current: float, proposed: float) -> float: + if current == 0: + return proposed + max_increase = current * 1.25 + max_decrease = current * 0.75 + return max(max_decrease, min(max_increase, proposed)) + + @staticmethod + def _relative_improvement(before: float, after: float) -> float: + denom = max(abs(before), 1e-12) + return (after - before) / denom + + @staticmethod + def _normalize_elitism(value: object) -> float: + if 
isinstance(value, bool): + return 0.1 if value else 0.0 + try: + return max(0.0, float(value)) + except (TypeError, ValueError): + return 0.0 diff --git a/shinka/core/runner.py b/shinka/core/runner.py index 3c818742..59dd3d78 100644 --- a/shinka/core/runner.py +++ b/shinka/core/runner.py @@ -30,6 +30,7 @@ from shinka.core.sampler import PromptSampler from shinka.core.summarizer import MetaSummarizer from shinka.core.novelty_judge import NoveltyJudge +from shinka.core.migration_adaptation import MigrationAdaptationController from shinka.logo import print_gradient_logo FOLDER_PREFIX = "gen" @@ -106,6 +107,7 @@ def __init__( self.results_dir = f"results_{timestamp}" else: self.results_dir = Path(evo_config.results_dir) + self._last_adaptation_generation_notified = -1 if self.verbose: # Create log file path in results directory @@ -158,7 +160,12 @@ def __init__( # Initialize database and scheduler db_config.db_path = str(db_path) - self.db = ProgramDatabase(config=db_config) + embedding_model_to_use = ( + evo_config.embedding_model or "text-embedding-3-small" + ) + self.db = ProgramDatabase( + config=db_config, embedding_model=embedding_model_to_use + ) self.scheduler = JobScheduler( job_type=evo_config.job_type, config=job_config, # type: ignore @@ -222,6 +229,18 @@ def __init__( max_novelty_attempts=evo_config.max_novelty_attempts, ) + self.migration_adapter: Optional[MigrationAdaptationController] = None + if ( + db_config.migration_adaptation + and db_config.migration_adaptation.enabled + ): + self.migration_adapter = MigrationAdaptationController( + db=self.db, + config=db_config, + results_dir=Path(self.results_dir), + logger=logger, + ) + # Initialize rich console for formatted output self.console = Console() @@ -231,6 +250,12 @@ def __init__( self.lang_ext = "cpp" elif self.evo_config.language == "python": self.lang_ext = "py" + elif self.evo_config.language == "rust": + self.lang_ext = "rs" + elif self.evo_config.language == "swift": + self.lang_ext = "swift" + elif self.evo_config.language in ["json", "json5"]: + self.lang_ext = "json" else: msg = f"Language {self.evo_config.language} not supported" raise ValueError(msg) @@ -299,6 +324,7 @@ def run(self): self._run_generation_0() self.completed_generations = 1 self.next_generation_to_submit = 1 + self._notify_migration_adapter(0) logger.info(f"Completed generation 0, total: 1/{target_gens}") # Now start parallel execution for remaining generations @@ -319,6 +345,10 @@ def run(self): # Update completed generations count self._update_completed_generations() + if self.completed_generations > 0: + self._notify_migration_adapter( + self.completed_generations - 1 + ) if self.verbose: logger.info( @@ -606,6 +636,15 @@ def _update_completed_generations(self): self.completed_generations = completed_up_to + def _notify_migration_adapter(self, generation: int) -> None: + if ( + self.migration_adapter + and generation >= 0 + and generation > self._last_adaptation_generation_notified + ): + self.migration_adapter.on_generation_completed(generation) + self._last_adaptation_generation_notified = generation + def _submit_new_job(self): """Submit a new job to the queue.""" current_gen = self.next_generation_to_submit @@ -1096,9 +1135,10 @@ def run_patch( # error_attempt is already set from apply_patch or default pass - # Only consider the diff summary for the original.py file!!! 
- if "original.py" in diff_summary: - diff_summary = diff_summary["original.py"] + # Only consider the diff summary for the original source file + original_filename = f"original.{self.lang_ext}" + if original_filename in diff_summary: + diff_summary = diff_summary[original_filename] meta_edit_data = { "patch_type": patch_type, diff --git a/shinka/core/wrap_eval.py b/shinka/core/wrap_eval.py index 7e1d1e5d..bf2cf92e 100644 --- a/shinka/core/wrap_eval.py +++ b/shinka/core/wrap_eval.py @@ -96,6 +96,9 @@ def run_shinka_eval( num_valid_runs = 0 num_invalid_runs = 0 + all_run_results: List[Any] = [] + execution_times: List[float] = [] + try: module = load_program(program_path) if not hasattr(module, experiment_fn_name): @@ -105,9 +108,6 @@ def run_shinka_eval( ) experiment_fn = getattr(module, experiment_fn_name) - all_run_results: List[Any] = [] - execution_times: List[float] = [] - for i in range(num_runs): kwargs: Dict[str, Any] = {} if get_experiment_kwargs: diff --git a/shinka/database/complexity.py b/shinka/database/complexity.py index 4116567e..3ae937e2 100644 --- a/shinka/database/complexity.py +++ b/shinka/database/complexity.py @@ -1,10 +1,16 @@ import ast -from radon.complexity import cc_visit -from radon.metrics import h_visit -from radon.raw import analyze import math import re +try: # pragma: no cover - optional dependency for detailed metrics + from radon.complexity import cc_visit + from radon.metrics import h_visit + from radon.raw import analyze +except ImportError: # pragma: no cover - graceful degradation + cc_visit = None + h_visit = None + analyze = None + def max_nesting_depth(code_string): """Calculate maximum nesting depth for Python code using AST.""" @@ -54,20 +60,41 @@ def analyze_python_complexity(code_string): Raises: SyntaxError: If the code cannot be parsed as valid Python """ - cc_results = cc_visit(code_string) - total_cc = sum(block.complexity for block in cc_results) - avg_cc = total_cc / len(cc_results) if cc_results else 0 - - h_metrics = h_visit(code_string) - halstead_total = h_metrics.total if h_metrics.total else None - halstead_volume = halstead_total.volume if halstead_total else 1 - halstead_difficulty = halstead_total.difficulty if halstead_total else 0 - halstead_effort = halstead_total.effort if halstead_total else 0 + if cc_visit and h_visit and analyze: + cc_results = cc_visit(code_string) + total_cc = sum(block.complexity for block in cc_results) + avg_cc = total_cc / len(cc_results) if cc_results else 0 + + h_metrics = h_visit(code_string) + halstead_total = h_metrics.total if h_metrics.total else None + halstead_volume = halstead_total.volume if halstead_total else 1 + halstead_difficulty = halstead_total.difficulty if halstead_total else 0 + halstead_effort = halstead_total.effort if halstead_total else 0 + + raw_metrics = analyze(code_string) + loc = raw_metrics.loc + lloc = raw_metrics.lloc + comments = raw_metrics.comments + else: + # Fallback lightweight metrics when radon isn't installed + loc = len(code_string.splitlines()) + non_empty_lines = [line for line in code_string.splitlines() if line.strip()] + lloc = len(non_empty_lines) + comments = sum( + 1 + for line in code_string.splitlines() + if line.strip().startswith(("#", "//")) + ) + total_cc = 1 + sum( + len(re.findall(pattern, code_string)) + for pattern in (r"\bif\b", r"\bfor\b", r"\bwhile\b", r"\btry\b") + ) + avg_cc = total_cc + halstead_volume = float(loc) + halstead_difficulty = 0.0 + halstead_effort = 0.0 - raw_metrics = analyze(code_string) - loc = raw_metrics.loc - lloc = 
raw_metrics.lloc - comments = raw_metrics.comments + nesting_depth = max_nesting_depth(code_string) mi = ( 171 @@ -76,15 +103,11 @@ def analyze_python_complexity(code_string): - 16.2 * (math.log2(loc) if loc > 0 else 0) ) - nesting_depth = max_nesting_depth(code_string) - - # Normalized scores for aggregation - norm_cc = total_cc / 10 # Assuming 10 is high complexity + norm_cc = total_cc / 10 norm_halstead = math.log2(halstead_volume + 1) / 10 norm_loc = math.log2(loc + 1) / 10 - norm_nesting = nesting_depth / 5 # Assuming depth 5 is quite nested + norm_nesting = nesting_depth / 5 - # Complexity Score (weighted sum) complexity_score = ( 0.4 * norm_cc + 0.4 * norm_halstead + 0.1 * norm_loc + 0.1 * norm_nesting ) @@ -259,8 +282,8 @@ def analyze_code_metrics(code_string, language="python"): # If Python parsing fails, fall back to C++ analysis return analyze_cpp_complexity(code_string) - # For C/C++/CUDA and other languages, use regex-based analysis - elif language in ["cpp", "c", "cuda", "c++"]: + # For C/C++/CUDA/Rust/Swift/JSON and other languages, use regex-based analysis + elif language in ["cpp", "c", "cuda", "c++", "rust", "swift", "json", "json5"]: return analyze_cpp_complexity(code_string) # For unknown languages, use simple line-based complexity diff --git a/shinka/database/dbase.py b/shinka/database/dbase.py index 69fdf543..a3ea741b 100644 --- a/shinka/database/dbase.py +++ b/shinka/database/dbase.py @@ -14,7 +14,10 @@ from .inspirations import CombinedContextSelector from .islands import CombinedIslandManager from .display import DatabaseDisplay -from shinka.llm.embedding import EmbeddingClient +try: # pragma: no cover - optional dependency during tests + from shinka.llm.embedding import EmbeddingClient +except ImportError: # pragma: no cover - allow running without llm extras + EmbeddingClient = None # type: ignore logger = logging.getLogger(__name__) @@ -48,9 +51,77 @@ def clean_nan_values(obj: Any) -> Any: return obj +@dataclass +class SuccessAdaptationConfig: + window: int = 5 + target_improvement: float = 0.01 + step_up: float = 1.2 + step_down: float = 0.83 + ema_beta: float = 0.8 + + +@dataclass +class DiversityAdaptationConfig: + metric: str = "score_std" + low_thresh: float = 0.2 + high_thresh: float = 0.6 + adjust_strength: float = 0.15 + + +@dataclass +class BanditPolicyArms: + donor: List[str] = field(default_factory=lambda: ["random", "ring"]) + payload: List[str] = field( + default_factory=lambda: ["random", "elite", "novel"] + ) + size: List[str] = field(default_factory=lambda: ["small", "medium", "large"]) + + +@dataclass +class BanditAdaptationConfig: + policy_arms: BanditPolicyArms = field(default_factory=BanditPolicyArms) + algo: str = "ucb1" + ucb_c: float = 1.0 + epsilon: float = 0.1 + + +@dataclass +class MigrationAdaptationBounds: + rate_min: float = 0.01 + rate_max: float = 0.5 + interval_min: int = 2 + interval_max: int = 50 + elitism_min: float = 0.05 + elitism_max: float = 0.5 + + +@dataclass +class MigrationAdaptationWeights: + success: float = 1.0 + diversity: float = 0.5 + bandit: float = 0.5 + + +@dataclass +class MigrationAdaptationConfig: + enabled: bool = False + methods: List[str] = field(default_factory=lambda: ["success"]) + success: SuccessAdaptationConfig = field(default_factory=SuccessAdaptationConfig) + diversity: DiversityAdaptationConfig = field( + default_factory=DiversityAdaptationConfig + ) + bandit: BanditAdaptationConfig = field(default_factory=BanditAdaptationConfig) + bounds: MigrationAdaptationBounds = field( + 
default_factory=MigrationAdaptationBounds + ) + weights: MigrationAdaptationWeights = field( + default_factory=MigrationAdaptationWeights + ) + + @dataclass class DatabaseConfig: - db_path: Optional[str] = None + db_path: str = "evolution_db.sqlite" num_islands: int = 4 archive_size: int = 100 @@ -82,6 +153,11 @@ class DatabaseConfig: # Beam search parent selection parameters num_beams: int = 5 + # Adaptive migration + migration_adaptation: Optional[MigrationAdaptationConfig] = None + # Embedding model name + embedding_model: str = "text-embedding-3-small" + def db_retry(max_retries=5, initial_delay=0.1, backoff_factor=2): """ @@ -248,12 +324,22 @@ class ProgramDatabase: populations, and an archive of elites. """ - def __init__(self, config: DatabaseConfig, read_only: bool = False): + def __init__( + self, + config: DatabaseConfig, + embedding_model: str = "text-embedding-3-small", + read_only: bool = False, + ): self.config = config self.conn: Optional[sqlite3.Connection] = None self.cursor: Optional[sqlite3.Cursor] = None self.read_only = read_only - self.embedding_client = EmbeddingClient() + # Only create embedding client if not in read-only mode + # (e.g., WebUI doesn't need it for visualization) + if not read_only: + self.embedding_client = EmbeddingClient(model_name=embedding_model) + else: + self.embedding_client = None self.last_iteration: int = 0 self.best_program_id: Optional[str] = None @@ -1707,6 +1793,9 @@ def get_most_similar_program_thread_safe( def _recompute_embeddings_and_clusters(self, num_clusters: int = 4): if self.read_only: return + if not self.embedding_client: + logger.debug("Embedding client unavailable, skipping recompute") + return if not self.cursor or not self.conn: raise ConnectionError("DB not connected.") @@ -1784,6 +1873,9 @@ def _recompute_embeddings_and_clusters_thread_safe(self, num_clusters: int = 4): """ if self.read_only: return + if not self.embedding_client: + logger.debug("Embedding client unavailable, skipping recompute") + return conn = None try: diff --git a/shinka/database/display.py b/shinka/database/display.py index 3e55439b..fdd1912b 100644 --- a/shinka/database/display.py +++ b/shinka/database/display.py @@ -600,8 +600,10 @@ def format_program_row(prog, role_name): time_display = f"{time_val:.1f}s" # Patch name and type - patch_name = prog.metadata.get("patch_name", "[dim]N/A[/dim]")[:30] - patch_type = prog.metadata.get("patch_type", "[dim]N/A[/dim]") + metadata = prog.metadata or {} + patch_name_raw = metadata.get("patch_name", "[dim]N/A[/dim]") + patch_name = (patch_name_raw or "[dim]N/A[/dim]")[:30] + patch_type = metadata.get("patch_type", "[dim]N/A[/dim]") or "[dim]N/A[/dim]" return [ role_name, diff --git a/shinka/database/inspirations.py b/shinka/database/inspirations.py index ee564dfa..42c3859d 100644 --- a/shinka/database/inspirations.py +++ b/shinka/database/inspirations.py @@ -72,6 +72,7 @@ def sample_context(self, parent: Any, n: int) -> List[Any]: self.cursor.execute( """ SELECT p.id FROM programs p + JOIN archive a ON p.id = a.program_id WHERE p.island_idx = ? AND p.correct = 1 ORDER BY p.combined_score DESC LIMIT ? @@ -93,7 +94,8 @@ def sample_context(self, parent: Any, n: int) -> List[Any]: placeholders_rand = ",".join("?" * len(insp_ids)) sql_rand = f""" SELECT p.id FROM programs p - WHERE p.island_idx = ? AND p.correct = 1 + JOIN archive a ON p.id = a.program_id + WHERE p.island_idx = ? AND p.correct = 1 AND p.id NOT IN ({placeholders_rand}) ORDER BY RANDOM() LIMIT ? 
""" @@ -111,9 +113,10 @@ def sample_context(self, parent: Any, n: int) -> List[Any]: needed = n - len(inspirations) if needed > 0: placeholders_rand = ",".join("?" * len(insp_ids)) - sql_rand = f"""SELECT id FROM programs - WHERE correct = 1 - AND id NOT IN ({placeholders_rand}) + sql_rand = f"""SELECT p.id FROM programs p + JOIN archive a ON p.id = a.program_id + WHERE p.correct = 1 + AND p.id NOT IN ({placeholders_rand}) ORDER BY RANDOM() LIMIT ? """ params_rand = list(insp_ids) + [needed] diff --git a/shinka/database/islands.py b/shinka/database/islands.py index 9975eac3..86a27f74 100644 --- a/shinka/database/islands.py +++ b/shinka/database/islands.py @@ -5,8 +5,9 @@ import time import uuid from abc import ABC, abstractmethod -from typing import Optional, Any, Dict, List from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Sequence, Set import rich.box # type: ignore import rich # type: ignore from rich.console import Console as RichConsole # type: ignore @@ -15,6 +16,21 @@ logger = logging.getLogger(__name__) +@dataclass +class IslandMigrationParams: + migration_rate: float + migration_interval: int + island_elitism: float + + +@dataclass +class MigrationSummary: + generation: int + total_migrated: int + per_island: Dict[int, int] + migrations: Dict[int, Dict[int, List[str]]] + policies_used: Dict[int, Dict[str, str]] + class IslandStrategy(ABC): """Abstract base class for island strategies.""" @@ -204,57 +220,90 @@ def __init__( self.config = config @abstractmethod - def perform_migration(self, current_generation: int) -> bool: - """Perform migration between islands. - Returns True if migration occurred.""" - pass + def perform_migration( + self, + current_generation: int, + eligible_islands: Optional[Sequence[int]] = None, + ) -> MigrationSummary: + """Perform migration between islands and return a summary.""" + raise NotImplementedError class ElitistMigrationStrategy(IslandMigrationStrategy): - """Migration strategy that protects elite programs from migration.""" + """Migration strategy that supports adaptive parameters and policies.""" - def perform_migration(self, current_generation: int) -> bool: - """ - Implements island migration by moving a subset of programs between - islands. Called periodically based on migration_interval. 
- """ - num_islands = getattr(self.config, "num_islands", 0) - migration_rate = getattr(self.config, "migration_rate", 0.1) - island_elitism = getattr(self.config, "island_elitism", True) + def __init__( + self, + cursor: sqlite3.Cursor, + conn: sqlite3.Connection, + config: Any, + param_provider: Optional[Callable[[int], Optional[IslandMigrationParams]]] = None, + policy_provider: Optional[Callable[[int], Optional[Dict[str, str]]]] = None, + ): + super().__init__(cursor, conn, config) + self.param_provider = param_provider + self.policy_provider = policy_provider - if num_islands < 2 or migration_rate <= 0: - return False # No migration needed + def perform_migration( + self, + current_generation: int, + eligible_islands: Optional[Sequence[int]] = None, + ) -> MigrationSummary: + num_islands = getattr(self.config, "num_islands", 0) + if num_islands < 2: + return MigrationSummary(current_generation, 0, {}, {}, {}) - logger.info(f"Performing island migration at generation {current_generation}") + logger.info( + f"Performing island migration at generation {current_generation}" + ) migrations_summary = defaultdict(lambda: defaultdict(list)) - # Track all programs selected for migration all_migrated_programs = set() - - # For each island, select migrants to move - for source_idx in range(num_islands): - # Count programs in this island - self.cursor.execute( - "SELECT COUNT(*) FROM programs WHERE island_idx = ?", - (source_idx,), + per_island_counts: Dict[int, int] = defaultdict(int) + policies_used: Dict[int, Dict[str, str]] = {} + + island_iter = eligible_islands if eligible_islands else range(num_islands) + island_best_scores = self._fetch_island_best_scores() + + for source_idx in island_iter: + params = self.param_provider(source_idx) if self.param_provider else None + migration_rate = ( + params.migration_rate + if params is not None + else getattr(self.config, "migration_rate", 0.1) + ) + elitism_ratio = ( + params.island_elitism + if params is not None + else self._normalize_elitism( + getattr(self.config, "island_elitism", True) + ) ) - island_size = (self.cursor.fetchone() or [0])[0] + if migration_rate <= 0: + continue + island_size = self._get_island_size(source_idx) if island_size <= 1: - continue # Skip tiny islands + continue - # Number of programs to migrate num_migrants = max(1, int(island_size * migration_rate)) + policy = self.policy_provider(source_idx) if self.policy_provider else None + if policy: + policies_used[source_idx] = policy + num_migrants = self._apply_size_policy(num_migrants, policy.get("size")) - # Select destination islands (all except source) dest_islands = [i for i in range(num_islands) if i != source_idx] if not dest_islands: continue - # Select migrants based on elitism setting - migrants = self._select_migrants(source_idx, num_migrants, island_elitism) + migrants = self._select_migrants( + source_idx, + island_size, + num_migrants, + elitism_ratio, + (policy or {}).get("payload"), + ) - # Filter out any programs already selected for migration unique_migrants = [] for migrant_id in migrants: if migrant_id not in all_migrated_programs: @@ -266,124 +315,173 @@ def perform_migration(self, current_generation: int) -> bool: "migration, skipping duplicate" ) - # Move each unique migrant to a new island + donor_policy = (policy or {}).get("donor") if policy else None for migrant_id in unique_migrants: - dest_idx = random.choice(dest_islands) + dest_idx = self._select_destination( + source_idx, + dest_islands, + donor_policy, + island_best_scores, + ) + if 
dest_idx is None: + continue self._migrate_program( migrant_id, source_idx, dest_idx, current_generation ) migrations_summary[source_idx][dest_idx].append(migrant_id) + per_island_counts[source_idx] += 1 self.conn.commit() if migrations_summary: self._print_migration_summary(migrations_summary) - total_migrated = sum( - len(progs) - for dest_dict in migrations_summary.values() - for progs in dest_dict.values() - ) + total_migrated = sum(per_island_counts.values()) logger.info(f"Migration complete. Migrated {total_migrated} programs.") - return total_migrated > 0 + return MigrationSummary( + generation=current_generation, + total_migrated=total_migrated, + per_island=dict(per_island_counts), + migrations={k: dict(v) for k, v in migrations_summary.items()}, + policies_used=policies_used, + ) + + def _get_island_size(self, island_idx: int) -> int: + self.cursor.execute( + "SELECT COUNT(*) FROM programs WHERE island_idx = ?", + (island_idx,), + ) + return (self.cursor.fetchone() or [0])[0] + + def _fetch_island_best_scores(self) -> Dict[int, float]: + self.cursor.execute( + "SELECT island_idx, MAX(combined_score) as best " + "FROM programs WHERE island_idx IS NOT NULL AND correct = 1 " + "GROUP BY island_idx" + ) + return { + row["island_idx"]: row["best"] if row["best"] is not None else 0.0 + for row in self.cursor.fetchall() + if row["island_idx"] is not None + } + + def _normalize_elitism(self, value: Any) -> float: + if isinstance(value, bool): + return 0.1 if value else 0.0 + try: + return max(0.0, float(value)) + except (TypeError, ValueError): + return 0.0 + + def _apply_size_policy(self, num_migrants: int, policy: Optional[str]) -> int: + if policy is None: + return num_migrants + factors = {"small": 0.5, "medium": 1.0, "large": 1.5} + factor = factors.get(policy, 1.0) + adjusted = max(1, int(round(num_migrants * factor))) + return adjusted + + def _select_destination( + self, + source_idx: int, + dest_islands: List[int], + donor_policy: Optional[str], + island_best_scores: Dict[int, float], + ) -> Optional[int]: + if not dest_islands: + return None + num_islands = getattr(self.config, "num_islands", 0) + if donor_policy == "ring" and num_islands > 0: + return (source_idx + 1) % num_islands + if donor_policy == "topk": + ranked = sorted( + ((idx, island_best_scores.get(idx, float("-inf"))) for idx in dest_islands), + key=lambda item: item[1], + reverse=True, + ) + if ranked: + return ranked[0][0] + # Default random destination + return random.choice(dest_islands) def _select_migrants( self, source_idx: int, + island_size: int, num_migrants: int, - island_elitism: bool, + island_elitism: float, + payload_policy: Optional[str], ) -> List[str]: - """Select which programs to migrate from an island. - Excludes generation 0 programs (initial programs and their copies) - and only considers correct programs. - """ - # Base query excludes generation 0 programs and only includes - # correct programs - selection_query = """ - SELECT id FROM programs - WHERE island_idx = ? AND generation > 0 AND correct = 1 - """ + selection_query = ( + "SELECT id FROM programs WHERE island_idx = ? " + "AND generation > 0 AND correct = 1" + ) + params: List[Any] = [source_idx] - if island_elitism: - # Get IDs of best program to protect from migration - # Also exclude generation 0 programs from elite selection and - # only consider correct programs - elite_query = """ - SELECT id FROM programs - WHERE island_idx = ? 
AND generation > 0 AND correct = 1 - ORDER BY combined_score DESC - LIMIT 1 - """ - - self.cursor.execute(elite_query, (source_idx,)) - elite_ids = [row["id"] for row in self.cursor.fetchall()] - - if elite_ids: - # Exclude elites from migration - placeholders = ",".join(["?"] * len(elite_ids)) - selection_query += f" AND id NOT IN ({placeholders})" - selection_query += " ORDER BY RANDOM() LIMIT ?" - params = [source_idx] + elite_ids + [num_migrants] - else: - selection_query += " ORDER BY RANDOM() LIMIT ?" - params = [source_idx, num_migrants] - else: - # Simple random selection (excluding generation 0, - # only correct programs) - selection_query += " ORDER BY RANDOM() LIMIT ?" - params = [source_idx, num_migrants] + elite_ids = self._get_elite_ids(source_idx, island_size, island_elitism) + if elite_ids: + placeholders = ",".join(["?"] * len(elite_ids)) + selection_query += f" AND id NOT IN ({placeholders})" + params.extend(elite_ids) - # First check how many correct non-generation-0 programs are available - self.cursor.execute( - "SELECT COUNT(*) FROM programs WHERE island_idx = ? AND " - "generation > 0 AND correct = 1", - (source_idx,), - ) - available_programs = (self.cursor.fetchone() or [0])[0] + order_clause = self._payload_order_clause(payload_policy) + selection_query += f" {order_clause} LIMIT ?" - if available_programs == 0: + available_programs = self._count_available_programs(source_idx) - len(elite_ids) + if available_programs <= 0: logger.debug( - f"No correct generation > 0 programs available for migration " - f"from island {source_idx} (generation 0 programs are " - f"protected, " - f"only correct programs migrate)" + f"No eligible programs available for migration from island {source_idx}" ) return [] - # Adjust num_migrants if there aren't enough eligible programs - actual_migrants = min(num_migrants, available_programs) - if actual_migrants != num_migrants: - logger.debug( - f"Reducing migration count from {num_migrants} to " - f"{actual_migrants} for island {source_idx} " - f"(only {available_programs} correct eligible programs " - f"available)" - ) - # Update the params list to use the adjusted count - if isinstance(params, list) and len(params) > 0: - params[-1] = actual_migrants # Last param is always the LIMIT + actual_migrants = max(0, min(num_migrants, available_programs)) + if actual_migrants == 0: + return [] + params.append(actual_migrants) - # Select migrants self.cursor.execute(selection_query, params) migrants = [row["id"] for row in self.cursor.fetchall()] - - # Validate uniqueness (should always be true, but good to check) if len(migrants) != len(set(migrants)): logger.warning( - f"Duplicate programs selected for migration from island " - f"{source_idx}. Expected {len(migrants)} unique, got " - f"{len(set(migrants))} unique." + f"Duplicate programs selected for migration from island {source_idx}." 
) - migrants = list(set(migrants)) # Remove duplicates + migrants = list(set(migrants)) logger.debug( - f"Selected {len(migrants)} unique correct migrants from island " - f"{source_idx} (excluded generation 0 programs and incorrect " - f"programs from migration)" + f"Selected {len(migrants)} migrants from island {source_idx}" ) - return migrants + def _payload_order_clause(self, payload_policy: Optional[str]) -> str: + if payload_policy == "elite": + return "ORDER BY combined_score DESC" + if payload_policy == "novel": + return "ORDER BY combined_score ASC" + return "ORDER BY RANDOM()" + + def _count_available_programs(self, island_idx: int) -> int: + self.cursor.execute( + "SELECT COUNT(*) FROM programs WHERE island_idx = ? AND generation > 0 AND correct = 1", + (island_idx,), + ) + return (self.cursor.fetchone() or [0])[0] + + def _get_elite_ids( + self, + island_idx: int, + island_size: int, + elite_ratio: float, + ) -> List[str]: + if elite_ratio <= 0 or island_size <= 1: + return [] + elite_count = max(1, int(round(island_size * elite_ratio))) + self.cursor.execute( + "SELECT id FROM programs WHERE island_idx = ? AND generation > 0 AND correct = 1 " + "ORDER BY combined_score DESC LIMIT ?", + (island_idx, elite_count), + ) + return [row["id"] for row in self.cursor.fetchall()] + def _migrate_program( self, migrant_id: str, @@ -513,17 +611,128 @@ def __init__( self.assignment_strategy = assignment_strategy or ( CopyInitialProgramIslandStrategy(cursor, conn, config) ) - self.migration_strategy = migration_strategy or ( - ElitistMigrationStrategy(cursor, conn, config) - ) + + self._island_params: Dict[int, IslandMigrationParams] = {} + self._pending_islands: Set[int] = set() + self._last_migration_generation: Dict[int, int] = {} + self._migration_callback: Optional[Callable[[MigrationSummary], None]] = None + self._policy_provider: Optional[Callable[[int], Optional[Dict[str, str]]]] = None + + self._initialize_island_params() + + if migration_strategy is None: + self.migration_strategy = ElitistMigrationStrategy( + cursor, + conn, + config, + param_provider=self._get_island_params, + policy_provider=self._resolve_policy, + ) + else: + self.migration_strategy = migration_strategy def assign_island(self, program: Any) -> None: """Assign an island to a program using the configured strategy.""" self.assignment_strategy.assign_island(program) - def perform_migration(self, current_generation: int) -> bool: + def perform_migration(self, current_generation: int) -> MigrationSummary: """Perform migration using the configured strategy.""" - return self.migration_strategy.perform_migration(current_generation) + eligible = sorted(self._pending_islands) if self._pending_islands else None + summary = self.migration_strategy.perform_migration( + current_generation, eligible_islands=eligible + ) + if summary.total_migrated > 0: + for island_idx in summary.per_island.keys(): + self._last_migration_generation[island_idx] = current_generation + self._pending_islands.clear() + if self._migration_callback: + self._migration_callback(summary) + return summary + + def set_island_params( + self, + island_idx: int, + *, + migration_rate: Optional[float] = None, + migration_interval: Optional[int] = None, + island_elitism: Optional[float] = None, + ) -> IslandMigrationParams: + params = self._get_island_params(island_idx) + if migration_rate is not None: + params.migration_rate = migration_rate + if migration_interval is not None: + params.migration_interval = max(2, int(migration_interval)) + if island_elitism 
is not None: + params.island_elitism = max(0.0, float(island_elitism)) + return params + + def set_island_params_bulk( + self, params: Dict[int, IslandMigrationParams] + ) -> None: + for idx, state in params.items(): + self._island_params[idx] = IslandMigrationParams( + migration_rate=state.migration_rate, + migration_interval=state.migration_interval, + island_elitism=state.island_elitism, + ) + + def get_island_params_snapshot(self) -> Dict[int, IslandMigrationParams]: + return { + idx: IslandMigrationParams( + migration_rate=state.migration_rate, + migration_interval=state.migration_interval, + island_elitism=state.island_elitism, + ) + for idx, state in self._island_params.items() + } + + def register_policy_provider( + self, provider: Callable[[int], Optional[Dict[str, str]]] + ) -> None: + self._policy_provider = provider + + def set_migration_callback( + self, callback: Callable[[MigrationSummary], None] + ) -> None: + self._migration_callback = callback + + def _initialize_island_params(self) -> None: + num_islands = getattr(self.config, "num_islands", 0) + default_params = self._default_params() + for idx in range(num_islands): + self._island_params[idx] = IslandMigrationParams( + migration_rate=default_params.migration_rate, + migration_interval=default_params.migration_interval, + island_elitism=default_params.island_elitism, + ) + self._last_migration_generation.setdefault(idx, 0) + + def _default_params(self) -> IslandMigrationParams: + return IslandMigrationParams( + migration_rate=getattr(self.config, "migration_rate", 0.1), + migration_interval=max(2, getattr(self.config, "migration_interval", 10)), + island_elitism=self._normalize_elitism_value( + getattr(self.config, "island_elitism", True) + ), + ) + + def _get_island_params(self, island_idx: int) -> IslandMigrationParams: + if island_idx not in self._island_params: + self._island_params[island_idx] = self._default_params() + return self._island_params[island_idx] + + def _resolve_policy(self, island_idx: int) -> Optional[Dict[str, str]]: + if not self._policy_provider: + return None + return self._policy_provider(island_idx) + + def _normalize_elitism_value(self, value: Any) -> float: + if isinstance(value, bool): + return 0.1 if value else 0.0 + try: + return max(0.0, float(value)) + except (TypeError, ValueError): + return 0.0 def get_island_idx(self, program_id: str) -> Optional[int]: """Get the island index for a given program ID.""" @@ -546,14 +755,20 @@ def are_all_islands_initialized(self) -> bool: return len(initialized_islands) >= num_islands def should_schedule_migration(self, program: Any) -> bool: - """Check if migration should be scheduled based on program - generation.""" - return ( - program.generation > 0 - and hasattr(self.config, "migration_interval") - and self.config.migration_interval > 0 - and (program.generation % self.config.migration_interval == 0) - ) + """Check if migration should be scheduled for the program's island.""" + island_idx = getattr(program, "island_idx", None) + if island_idx is None or program.generation <= 0: + return False + + params = self._get_island_params(island_idx) + interval = max(1, int(params.migration_interval)) + last_gen = self._last_migration_generation.get(island_idx, 0) + + if (program.generation - last_gen) >= interval: + if island_idx not in self._pending_islands: + self._pending_islands.add(island_idx) + return True + return False def get_island_populations(self) -> Dict[int, int]: """Get the population count for each island.""" @@ -682,6 +897,16 @@ def 
copy_program_to_islands(self, program: Any) -> List[str]: f"Created copy {new_id[:8]}... of program {program.id[:8]}... " f"for island {island_idx}" ) + + # Add the copied program to the archive if it's correct + # This ensures it can be used as inspiration for that island + if program.correct: + self.cursor.execute( + "INSERT OR IGNORE INTO archive (program_id) VALUES (?)", + (new_id,), + ) + logger.debug(f"Added copy {new_id[:8]}... to archive (correct program)") + self.conn.commit() logger.info( f"Created {len(created_ids)} copies of program " diff --git a/shinka/edit/apply_diff.py b/shinka/edit/apply_diff.py index ead28e23..7d216105 100644 --- a/shinka/edit/apply_diff.py +++ b/shinka/edit/apply_diff.py @@ -698,12 +698,12 @@ def apply_diff_patch( patch_str = _strip_trailing_whitespace(patch_str) # Remove the EVOLVE-BLOCK START and EVOLVE-BLOCK END markers - if language in ["cuda", "cpp"]: - patch_str = re.sub(r"// EVOLVE-BLOCK START\\n", "", patch_str) - patch_str = re.sub(r"// EVOLVE-BLOCK END\\n", "", patch_str) + if language in ["cuda", "cpp", "rust", "swift", "json", "json5"]: + patch_str = re.sub(r"// EVOLVE-BLOCK-START\\n", "", patch_str) + patch_str = re.sub(r"// EVOLVE-BLOCK-END\\n", "", patch_str) elif language == "python": - patch_str = re.sub(r"# EVOLVE-BLOCK START\\n", "", patch_str) - patch_str = re.sub(r"# EVOLVE-BLOCK END\\n", "", patch_str) + patch_str = re.sub(r"# EVOLVE-BLOCK-START\\n", "", patch_str) + patch_str = re.sub(r"# EVOLVE-BLOCK-END\\n", "", patch_str) else: raise ValueError(f"Language {language} not supported") @@ -730,6 +730,12 @@ def apply_diff_patch( suffix = ".cpp" elif language == "cuda": suffix = ".cu" + elif language == "rust": + suffix = ".rs" + elif language == "swift": + suffix = ".swift" + elif language in ["json", "json5"]: + suffix = ".json" else: raise ValueError(f"Language {language} not supported") diff --git a/shinka/edit/apply_full.py b/shinka/edit/apply_full.py index b7e2e2b3..ac628812 100644 --- a/shinka/edit/apply_full.py +++ b/shinka/edit/apply_full.py @@ -1,6 +1,6 @@ from pathlib import Path from typing import Optional, Union -from .apply_diff import write_git_diff, _mutable_ranges +from .apply_diff import write_git_diff, _mutable_ranges, EVOLVE_START, EVOLVE_END from shinka.llm import extract_between import logging @@ -72,10 +72,15 @@ def apply_full_patch( updated_content = "" last_end = 0 - # Check if patch_code contains EVOLVE-BLOCK markers - patch_mutable_ranges = _mutable_ranges(patch_code) + # Detect EVOLVE markers presence in the patch content + patch_has_start = EVOLVE_START.search(patch_code) is not None + patch_has_end = EVOLVE_END.search(patch_code) is not None + patch_has_both = patch_has_start and patch_has_end + patch_has_none = not patch_has_start and not patch_has_end - if patch_mutable_ranges: + if patch_has_both: + # Patch contains both EVOLVE-BLOCK markers, extract from them + patch_mutable_ranges = _mutable_ranges(patch_code) # Patch contains EVOLVE-BLOCK markers, extract from them for i, (start, end) in enumerate(mutable_ranges): # Add immutable part before this mutable range @@ -91,47 +96,158 @@ def apply_full_patch( updated_content += replacement_content last_end = end - else: + elif patch_has_none: # Patch doesn't contain EVOLVE-BLOCK markers # Assume entire patch content should replace all mutable regions if len(mutable_ranges) == 1: - # Single mutable region, replace with entire patch content + # Single mutable region. 
If the patch appears to be a full-file + # rewrite that omitted EVOLVE markers, safely extract only the + # content intended for the evolve block by matching immutable + # prefix/suffix from the original file. start, end = mutable_ranges[0] - # The mutable range ends before "EVOLVE-BLOCK-END" text - # We need to find the actual start of the comment line - if language == "python": - end_marker = "# EVOLVE-BLOCK-END" - elif language in ["cuda", "cpp"]: - end_marker = "// EVOLVE-BLOCK-END" - else: - end_marker = "# EVOLVE-BLOCK-END" # Default fallback - - end_marker_pos = original.find(end_marker, end - 5) - if end_marker_pos == -1: - # Fallback: use the original end position - end_marker_pos = end + # Immutable portions that remain outside the evolve block + immutable_prefix = original[:start] + immutable_suffix = original[end:] - # Ensure proper newline handling around the patch content - if patch_code and not patch_code.startswith("\n"): - patch_code = "\n" + patch_code + # Also compute the portions strictly outside the marker lines + # to detect full-file patches that omitted EVOLVE markers. + # Find the start and end marker line boundaries. + start_match = None + end_match = None + for m in EVOLVE_START.finditer(original): + if m.end() == start: + start_match = m + break + for m in EVOLVE_END.finditer(original): + if m.start() == end: + end_match = m + break - if patch_code and not patch_code.endswith("\n"): - patch_code = patch_code + "\n" - - updated_content = ( - original[:start] + patch_code + original[end_marker_pos:] + prefix_outside = ( + original[: start_match.start()] if start_match else immutable_prefix + ) + suffix_outside = ( + original[end_match.end() :] if end_match else immutable_suffix ) + + # Heuristic: if patch includes the same immutable prefix/suffix + # outside the markers, treat the middle part as the evolve-block + # replacement. Be tolerant to a missing trailing newline in the + # footer by checking both versions. + suffix_opts = (suffix_outside, suffix_outside.rstrip("\r\n")) + if patch_code.startswith(prefix_outside) and any( + patch_code.endswith(sfx) for sfx in suffix_opts + ): + mid_start = len(prefix_outside) + # choose the matching suffix option to compute end + sfx = next(sfx for sfx in suffix_opts if patch_code.endswith(sfx)) + mid_end = len(patch_code) - len(sfx) + replacement_content = patch_code[mid_start:mid_end] + # Ensure marker boundaries stay on their own lines. + # Add a leading newline only if there is a START marker. + if ( + start_match is not None + and replacement_content + and not replacement_content.startswith("\n") + ): + replacement_content = "\n" + replacement_content + # Add a trailing newline only if there is an END marker. + if ( + end_match is not None + and replacement_content + and not replacement_content.endswith("\n") + ): + replacement_content = replacement_content + "\n" + updated_content = ( + immutable_prefix + replacement_content + immutable_suffix + ) + else: + # Otherwise, assume the patch_code represents only the + # evolve-block payload and insert it directly between markers. + # Ensure proper newline handling around the patch content. 
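The prefix/suffix heuristic above is easiest to see on a tiny input. A self-contained walk-through under the same assumptions (all names are local to this sketch):

```python
# A full-file rewrite that omitted the EVOLVE markers: the immutable
# header and footer match the original, so the middle slice is taken
# as the evolve-block payload.
original = "# Header\n# EVOLVE-BLOCK-START\nold()\n# EVOLVE-BLOCK-END\n# Footer\n"
patch = "# Header\nnew()\n# Footer\n"

prefix_outside = "# Header\n"  # original text before the START marker line
suffix_outside = "# Footer\n"  # original text after the END marker line

assert patch.startswith(prefix_outside) and patch.endswith(suffix_outside)
payload = patch[len(prefix_outside) : len(patch) - len(suffix_outside)]
assert payload == "new()\n"
# apply_full_patch then re-inserts this payload between the original
# marker lines, padding newlines so each marker stays on its own line.
```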
+ payload = patch_code + if ( + start_match is not None + and payload + and not payload.startswith("\n") + ): + payload = "\n" + payload + if end_match is not None and payload and not payload.endswith("\n"): + payload = payload + "\n" + updated_content = immutable_prefix + payload + immutable_suffix else: - # Multiple mutable regions, this is ambiguous + # Multiple EVOLVE-BLOCK regions found, ambiguous without markers error_message = ( "Multiple EVOLVE-BLOCK regions found but patch " "doesn't specify which to replace" ) return original, 0, None, error_message, None, None + else: + # Patch contains exactly one marker (START xor END). + # Only safe to apply when original has a single evolve region. + if len(mutable_ranges) != 1: + error_message = ( + "Patch contains only one EVOLVE-BLOCK marker, but the original " + f"has {len(mutable_ranges)} editable regions; cannot determine target" + ) + return original, 0, None, error_message, None, None + + # Single target region in original + start, end = mutable_ranges[0] + immutable_prefix = original[:start] + immutable_suffix = original[end:] + + # Find exact marker locations in original for newline policy + start_match = None + end_match = None + for m in EVOLVE_START.finditer(original): + if m.end() == start: + start_match = m + break + for m in EVOLVE_END.finditer(original): + if m.start() == end: + end_match = m + break + + # Compute outside-of-markers prefix/suffix from original + prefix_outside = ( + original[: start_match.start()] if start_match else immutable_prefix + ) + suffix_outside = ( + original[end_match.end() :] if end_match else immutable_suffix + ) + + # Extract payload based on which single marker is present in patch + if patch_has_start and not patch_has_end: + m = EVOLVE_START.search(patch_code) + payload = patch_code[m.end() :] if m else patch_code + # Trim footer if the patch included it + for sfx in (suffix_outside, suffix_outside.rstrip("\r\n")): + if sfx and payload.endswith(sfx): + payload = payload[: -len(sfx)] + break + elif patch_has_end and not patch_has_start: + m = EVOLVE_END.search(patch_code) + payload = patch_code[: m.start()] if m else patch_code + # Trim header if the patch included it + for pfx in (prefix_outside, prefix_outside.rstrip("\r\n")): + if pfx and payload.startswith(pfx): + payload = payload[len(pfx) :] + break + else: + payload = patch_code + + # Normalize newlines so markers remain on their own lines + if start_match is not None and payload and not payload.startswith("\n"): + payload = "\n" + payload + if end_match is not None and payload and not payload.endswith("\n"): + payload = payload + "\n" + + updated_content = immutable_prefix + payload + immutable_suffix # Add remaining immutable content after last mutable range - if patch_mutable_ranges and mutable_ranges: + if patch_has_both and mutable_ranges: updated_content += original[mutable_ranges[-1][1] :] num_applied = 1 @@ -146,6 +262,12 @@ def apply_full_patch( suffix = ".cpp" elif language == "cuda": suffix = ".cu" + elif language == "rust": + suffix = ".rs" + elif language == "swift": + suffix = ".swift" + elif language in ["json", "json5"]: + suffix = ".json" else: raise ValueError(f"Language {language} not supported") diff --git a/shinka/edit/async_apply.py b/shinka/edit/async_apply.py index 8e542c56..e4c21202 100644 --- a/shinka/edit/async_apply.py +++ b/shinka/edit/async_apply.py @@ -118,6 +118,31 @@ async def validate_code_async( error_msg = stderr.decode() if stderr else "Unknown compilation error" return False, error_msg + elif 
language == "rust": + # Use rustc for Rust syntax checking + proc = await asyncio.create_subprocess_exec( + "rustc", + "--crate-type=lib", + "-Zparse-only", + code_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + return False, f"Validation timeout after {timeout}s" + + if proc.returncode == 0: + return True, None + else: + error_msg = stderr.decode() if stderr else "Unknown compilation error" + return False, error_msg elif language == "cpp": # Use g++ for C++ compilation check proc = await asyncio.create_subprocess_exec( @@ -128,6 +153,31 @@ async def validate_code_async( stderr=asyncio.subprocess.PIPE, ) + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + return False, f"Validation timeout after {timeout}s" + + if proc.returncode == 0: + return True, None + else: + error_msg = stderr.decode() if stderr else "Unknown compilation error" + return False, error_msg + elif language == "swift": + # Use swiftc for Swift syntax checking + proc = await asyncio.create_subprocess_exec( + "swiftc", + "-typecheck", + "-parse-as-library", + code_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + try: stdout, stderr = await asyncio.wait_for( proc.communicate(), timeout=timeout diff --git a/shinka/launch/scheduler.py b/shinka/launch/scheduler.py index 5782613e..4e824c3f 100644 --- a/shinka/launch/scheduler.py +++ b/shinka/launch/scheduler.py @@ -138,7 +138,13 @@ def _build_command(self, exec_fname_t: str, results_dir_t: str) -> List[str]: ] if self.config.extra_cmd_args: for k, v in self.config.extra_cmd_args.items(): - cmd.extend([f"--{k}", str(v)]) + # Handle boolean flags + if isinstance(v, bool): + if v: # Only append flag if True + cmd.append(f"--{k}") + else: + # For non-boolean values, append both flag and value + cmd.extend([f"--{k}", str(v)]) return cmd def run( diff --git a/shinka/llm/dynamic_sampling.py b/shinka/llm/dynamic_sampling.py index 6c038d9f..eb0cd8cb 100644 --- a/shinka/llm/dynamic_sampling.py +++ b/shinka/llm/dynamic_sampling.py @@ -28,7 +28,8 @@ def _logdiffexp(a_log, b_log): def _logexpm1(z): z = np.asarray(z, dtype=float) - return np.where(z > 50.0, z, np.log(np.expm1(z))) + with np.errstate(divide='ignore', invalid='ignore'): + return np.where(z > 50.0, z, np.log(np.expm1(z))) class BanditBase(ABC): @@ -433,12 +434,13 @@ def decay(self, factor: float) -> None: if self.use_exponential_scaling and self.asymmetric_scaling: # shrink in exp space to match original score scale s = self.s - log1p_term = np.where( - s > 0.0, - s + np.log(one_minus_factor + np.exp(-s)), - np.log1p(one_minus_factor * np.exp(s)), - ) - self.s = s + np.log(factor) - log1p_term + with np.errstate(divide='ignore', invalid='ignore'): + log1p_term = np.where( + s > 0.0, + s + np.log(one_minus_factor + np.exp(-s)), + np.log1p(one_minus_factor * np.exp(s)), + ) + self.s = s + np.log(factor) - log1p_term if self.adaptive_scale and np.isfinite(self._obs_max): means_log = self._mean() diff --git a/shinka/llm/embedding.py b/shinka/llm/embedding.py index a5c6b07c..4082ad58 100644 --- a/shinka/llm/embedding.py +++ b/shinka/llm/embedding.py @@ -1,5 +1,6 @@ import os import openai +import google.generativeai as genai import pandas as pd from typing import Union, List, Optional, Tuple import numpy as np 
@@ -20,13 +21,23 @@ "azure-text-embedding-3-large", ] +GEMINI_EMBEDDING_MODELS = [ + "gemini-embedding-exp-03-07", + "gemini-embedding-001", +] + OPENAI_EMBEDDING_COSTS = { "text-embedding-3-small": 0.02 / M, "text-embedding-3-large": 0.13 / M, } +# Gemini embedding costs (approximate - check current pricing) +GEMINI_EMBEDDING_COSTS = { + "gemini-embedding-exp-03-07": 0.0 / M, # Experimental model, often free + "gemini-embedding-001": 0.0 / M, # Check current pricing +} -def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]: +def get_client_model(model_name: str) -> tuple[Union[openai.OpenAI, str], str]: if model_name in OPENAI_EMBEDDING_MODELS: client = openai.OpenAI() model_to_use = model_name @@ -38,6 +49,14 @@ def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]: api_version=os.getenv("AZURE_API_VERSION"), azure_endpoint=os.getenv("AZURE_API_ENDPOINT"), ) + elif model_name in GEMINI_EMBEDDING_MODELS: + # Configure Gemini API + api_key = os.getenv("GEMINI_API_KEY") + if not api_key: + raise ValueError("GEMINI_API_KEY environment variable not set for Gemini models") + genai.configure(api_key=api_key) + client = "gemini" # Use string identifier for Gemini + model_to_use = model_name else: raise ValueError(f"Invalid embedding model: {model_name}") @@ -52,9 +71,10 @@ def __init__( Initialize the EmbeddingClient. Args: - model (str): The OpenAI embedding model name to use. + model (str): The OpenAI, Azure, or Gemini embedding model name to use. """ self.client, self.model = get_client_model(model_name) + self.model_name = model_name self.verbose = verbose def get_embedding( @@ -76,6 +96,34 @@ def get_embedding( single_code = True else: single_code = False + # Handle Gemini models + if self.model_name in GEMINI_EMBEDDING_MODELS: + try: + embeddings = [] + total_tokens = 0 + + for text in code: + result = genai.embed_content( + model=f"models/{self.model}", + content=text, + task_type="retrieval_document" + ) + embeddings.append(result['embedding']) + total_tokens += len(text.split()) + + cost = total_tokens * GEMINI_EMBEDDING_COSTS.get(self.model, 0.0) + + if single_code: + return embeddings[0] if embeddings else [], cost + else: + return embeddings, cost + except Exception as e: + logger.error(f"Error getting Gemini embedding: {e}") + if single_code: + return [], 0.0 + else: + return [[]], 0.0 + # Handle OpenAI and Azure models (same interface) try: response = self.client.embeddings.create( model=self.model, input=code, encoding_format="float" diff --git a/shinka/llm/models/pricing.py b/shinka/llm/models/pricing.py index c9c101a2..91e965c7 100644 --- a/shinka/llm/models/pricing.py +++ b/shinka/llm/models/pricing.py @@ -35,6 +35,10 @@ "input_price": 3.0 / M, "output_price": 15.0 / M, }, + "claude-sonnet-4-5-20250929": { + "input_price": 3.0 / M, + "output_price": 15.0 / M, + }, } OPENAI_MODELS = { @@ -114,6 +118,10 @@ "input_price": 0.05 / M, "output_price": 0.4 / M, }, + "gpt-5.1": { + "input_price": 1.25 / M, + "output_price": 10.0 / M, + }, } @@ -141,6 +149,10 @@ "input_price": 0.1 / M, "output_price": 0.4 / M, }, + "gemini-3-pro-preview" : { + "input_price": 2.0 / M, + "output_price": 12.0 / M, + }, } BEDROCK_MODELS = { @@ -176,6 +188,7 @@ REASONING_CLAUDE_MODELS = [ "claude-3-7-sonnet-20250219", "claude-4-sonnet-20250514", + "claude-sonnet-4-5-20250929", ] REASONING_DEEPSEEK_MODELS = [ @@ -186,6 +199,7 @@ "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite-preview-06-17", + "gemini-3-pro-preview", ] REASONING_AZURE_MODELS = [ diff --git 
a/shinka/llm/query.py b/shinka/llm/query.py index a7288df8..c88c7d7c 100644 --- a/shinka/llm/query.py +++ b/shinka/llm/query.py @@ -137,16 +137,13 @@ def sample_model_kwargs( r_effort = random.choice(reasoning_efforts) think_bool = r_effort != "auto" if think_bool: - thinking_tokens = [ - t - for t in THINKING_TOKENS.values() - if t < kwargs_dict["max_tokens"] and t >= 1024 - ] + t = THINKING_TOKENS[r_effort] + thinking_tokens = t if t < kwargs_dict["max_tokens"] else 1024 kwargs_dict["extra_body"] = { "extra_body": { "google": { "thinking_config": { - "thinking_budget": random.choice(thinking_tokens), + "thinking_budget": thinking_tokens, "include_thoughts": True, } } @@ -157,19 +154,17 @@ def sample_model_kwargs( REASONING_CLAUDE_MODELS + REASONING_BEDROCK_MODELS ): kwargs_dict["max_tokens"] = min(random.choice(max_tokens), 16384) - think_bool = random.choice(reasoning_efforts) != "auto" + r_effort = random.choice(reasoning_efforts) + think_bool = r_effort != "auto" if think_bool: # filter thinking tokens to be smaller than max_tokens # not auto THINKING_TOKENS - thinking_tokens = [ - t - for t in THINKING_TOKENS.values() - if t < kwargs_dict["max_tokens"] and t >= 1024 - ] + t = THINKING_TOKENS[r_effort] + thinking_tokens = t if t < kwargs_dict["max_tokens"] else 1024 # sample only from thinking tokens that are valid kwargs_dict["thinking"] = { "type": "enabled", - "budget_tokens": random.choice(thinking_tokens), + "budget_tokens": thinking_tokens, } else: diff --git a/tests/test_edit_base.py b/tests/test_edit_base.py index edc0e117..67c6f2e2 100644 --- a/tests/test_edit_base.py +++ b/tests/test_edit_base.py @@ -161,6 +161,110 @@ def new_func2(): # Should have replaced both evolve blocks with new content +def test_apply_full_patch_full_file_without_markers_extracts_block_only(): + """Full-file patch without EVOLVE markers should not copy immutable code + into the evolve block; only the block payload is replaced.""" + original_content = """# Header line\n# EVOLVE-BLOCK-START\nold_line()\n# EVOLVE-BLOCK-END\n# Footer line\n""" + + # Patch is the entire file content but with the EVOLVE markers omitted. 
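For the `query.py` hunk above: instead of filtering a list of thinking budgets and sampling again, the reasoning effort that was already sampled now maps directly to one budget, with a 1024-token floor whenever that budget would not fit under `max_tokens`. A sketch with an assumed `THINKING_TOKENS` shape (the real mapping lives in `shinka/llm` and may use different keys and values):

```python
# Assumed shape of THINKING_TOKENS; illustrative values only.
THINKING_TOKENS = {"low": 1024, "medium": 4096, "high": 16384}

def pick_thinking_budget(r_effort: str, max_tokens: int) -> int:
    # Mirrors the diff: direct lookup, falling back to the minimum
    # budget when the looked-up value would exceed max_tokens.
    t = THINKING_TOKENS[r_effort]
    return t if t < max_tokens else 1024

assert pick_thinking_budget("medium", max_tokens=8192) == 4096
assert pick_thinking_budget("high", max_tokens=8192) == 1024
```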
+ patch_content = """```python +new_line() +another_new_line() +```""" + + expected = """# Header line +# EVOLVE-BLOCK-START +new_line() +another_new_line() +# EVOLVE-BLOCK-END +# Footer line +""" + + result = apply_full_patch( + patch_str=patch_content, + original_str=original_content, + language="python", + verbose=False, + ) + updated_content, num_applied, output_path, error, patch_txt, diff_path = result + + assert error is None + assert num_applied == 1 + assert updated_content == expected + + +def test_apply_full_patch_patch_with_start_marker_only(): + """Patch has only START marker; original has both markers.""" + original_content = """# Header line +# EVOLVE-BLOCK-START +old_line() +# EVOLVE-BLOCK-END +# Footer line +""" + + patch_content = """```python +# Header line +# EVOLVE-BLOCK-START +new_line() +# Footer line +```""" + + expected = """# Header line +# EVOLVE-BLOCK-START +new_line() +# EVOLVE-BLOCK-END +# Footer line +""" + + result = apply_full_patch( + patch_str=patch_content, + original_str=original_content, + language="python", + verbose=False, + ) + updated_content, num_applied, output_path, error, patch_txt, diff_path = result + + assert error is None + assert num_applied == 1 + assert updated_content == expected + + +def test_apply_full_patch_patch_with_end_marker_only(): + """Patch has only END marker; original has both markers.""" + original_content = """# Header line +# EVOLVE-BLOCK-START +old_line() +# EVOLVE-BLOCK-END +# Footer line +""" + + patch_content = """```python +# Header line +new_line() +# EVOLVE-BLOCK-END +# Footer line +```""" + + expected = """# Header line +# EVOLVE-BLOCK-START +new_line() +# EVOLVE-BLOCK-END +# Footer line +""" + + result = apply_full_patch( + patch_str=patch_content, + original_str=original_content, + language="python", + verbose=False, + ) + updated_content, num_applied, output_path, error, patch_txt, diff_path = result + + assert error is None + assert num_applied == 1 + assert updated_content == expected + + def test_apply_full_patch_no_evolve_blocks(): """Test apply_full_patch with no EVOLVE-BLOCK regions - should error.""" original_content = """# Just regular code @@ -221,6 +325,41 @@ def new_function(): assert updated_content == original_content # Should return original content +def test_apply_full_patch_patch_with_single_marker_ambiguous_multiple_regions(): + """Single marker in patch is ambiguous when original has multiple regions.""" + original_content = """# Header +# EVOLVE-BLOCK-START +func1() +# EVOLVE-BLOCK-END + +# EVOLVE-BLOCK-START +func2() +# EVOLVE-BLOCK-END +# Footer +""" + + # Patch includes only START marker + patch_content = """```python +# Header +# EVOLVE-BLOCK-START +new_code() +# Footer +```""" + + updated_content, num_applied, output_path, error, patch_txt, diff_path = ( + apply_full_patch( + patch_str=patch_content, + original_str=original_content, + language="python", + verbose=False, + ) + ) + + assert num_applied == 0 + assert error is not None + assert "only one EVOLVE-BLOCK marker" in error + + def test_apply_full_patch_invalid_extraction(): """Test apply_full_patch with invalid code extraction.""" original_content = """# EVOLVE-BLOCK-START