diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..349fcc1 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: CI +on: + push: + branches: [main] + pull_request: + branches: [main] +jobs: + test: + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: '20' + cache: 'npm' + - run: npm ci + - run: npm run typecheck + - run: npm test + - run: npm run build diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..6320988 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,22 @@ +name: Publish to npm +on: + release: + types: [published] +jobs: + publish: + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: '20' + registry-url: 'https://registry.npmjs.org' + cache: 'npm' + - run: npm ci + - run: npm test + - run: npm run build + - run: npm publish + env: + NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} diff --git a/README.md b/README.md index f1ef522..e0326b6 100644 --- a/README.md +++ b/README.md @@ -1 +1,2068 @@ -# DSTsx \ No newline at end of file +# DSTsx + +> A TypeScript-first port of [DSPy](https://github.com/stanfordnlp/dspy) — Declarative Self-improving Language Programs. + +[![npm version](https://img.shields.io/npm/v/dstsx.svg)](https://www.npmjs.com/package/dstsx) +[![license](https://img.shields.io/npm/l/dstsx.svg)](LICENSE) +[![tests](https://img.shields.io/badge/tests-218%20passing-brightgreen.svg)](#) + +DSTsx lets you build **typed, composable LM pipelines** in TypeScript and then **optimize** their prompts and few-shot examples automatically—no manual prompt engineering required. + +--- + +## Table of Contents + +1. [Installation](#installation) +2. [Quick Start](#quick-start) +3. [Core Concepts](#core-concepts) +4. 
[API Reference](#api-reference) + - [Signatures](#signatures) + - [Primitives — Example, Prediction, Trace](#primitives--example-prediction-trace) + - [Language Model Adapters](#language-model-adapters) + - [Settings & Context](#settings--context) + - [Modules](#modules) + - [Retrievers](#retrievers) + - [Optimizers](#optimizers) + - [Evaluation](#evaluation) + - [Assertions & Suggestions](#assertions--suggestions) +5. [V2 APIs](#v2-apis) + - [TypedPredictor & TypedChainOfThought](#typedpredictor--typedchainofthought) + - [Parallel Module](#parallel-module) + - [Refine Module](#refine-module) + - [majority() Helper](#majority-helper) + - [BootstrapFewShotWithOptuna](#bootstrapfewshotwithoptuna) + - [Disk-Persistent LM Cache](#disk-persistent-lm-cache) + - [MCP Integration](#mcp-integration) + - [LM Streaming](#lm-streaming) + - [NativeReAct](#nativereact) + - [Image — Multi-modal Support](#image--multi-modal-support) + - [BootstrapFinetune](#bootstrapfinetune) + - [GRPO Optimizer](#grpo-optimizer) + - [SIMBA Optimizer](#simba-optimizer) + - [AvatarOptimizer](#avataroptimizer) + - [Experiment Tracking](#experiment-tracking) + - [Worker-Thread ProgramOfThought](#worker-thread-programofthought) +6. [End-to-End Examples](#end-to-end-examples) +7. [V2 Roadmap](#v2-roadmap) + +--- + +## Installation + +```bash +npm install dstsx +``` + +Install provider SDK peer dependencies only for the adapters you use: + +```bash +# OpenAI +npm install openai + +# Anthropic +npm install @anthropic-ai/sdk + +# Cohere +npm install cohere-ai + +# Google Generative AI +npm install @google/generative-ai + +# Vector store retrievers (pick what you need) +npm install @pinecone-database/pinecone +npm install chromadb +npm install @qdrant/js-client-rest +npm install weaviate-client + +# MCP (Model Context Protocol) integration — optional +npm install @modelcontextprotocol/sdk +``` + +--- + +## Quick Start + +```ts +import { settings, OpenAI, Predict } from "dstsx"; + +// 1. 
Configure the global LM +settings.configure({ lm: new OpenAI({ model: "gpt-4o" }) }); + +// 2. Define a module +const qa = new Predict("question -> answer"); + +// 3. Run it +const result = await qa.forward({ question: "What is the capital of France?" }); +console.log(result.get("answer")); // "Paris" +``` + +--- + +## Core Concepts + +| Concept | Description | +|---|---| +| **Signature** | Typed interface (inputs → outputs) for one LM call | +| **Example** | Immutable key-value training record | +| **Prediction** | Module output extending Example with `completions` | +| **Trace** | Record of a single LM call (inputs, outputs, usage, latency) | +| **LM** | Abstract language model adapter | +| **Module** | Composable unit containing one or more LM calls | +| **Retriever** | Abstract vector-store backend | +| **Optimizer** | Automatically tunes demos/instructions of a Module | +| **Metric** | Scoring function for Evaluate and Optimizers | + +--- + +## API Reference + +### Signatures + +Signatures declare the **typed input/output interface** for a single LM call. + +#### `Signature.from(shorthand, instructions?)` + +Parse a shorthand string. Use `->` to separate inputs from outputs; suffix `?` for optional fields. + +```ts +import { Signature } from "dstsx"; + +// Simple shorthand +const sig = Signature.from("question -> answer"); + +// Multiple fields, optional field, instructions +const sig2 = Signature.from( + "context, question -> answer, confidence?", + "Answer based only on the provided context." +); +``` + +#### `new Signature(meta)` + +Construct a signature explicitly with full field metadata. 
+ +```ts +import { Signature, InputField, OutputField } from "dstsx"; + +const sig = new Signature({ + inputs: new Map([ + ["context", InputField({ description: "Background passages" })], + ["question", InputField({ description: "The question to answer" })], + ]), + outputs: new Map([ + ["answer", OutputField({ description: "Concise factual answer", type: "string" })], + ]), + instructions: "Answer using only the context provided.", +}); +``` + +#### `InputField(meta?)` / `OutputField(meta?)` + +Builder helpers that return a `FieldMeta` descriptor. + +```ts +import { InputField, OutputField } from "dstsx"; + +const field = InputField({ + description: "The user's question", + prefix: "Q:", // optional prompt prefix + format: "markdown", // optional format hint + optional: true, // field may be absent + type: "string", // "string" | "number" | "boolean" | "string[]" | "object" +}); +``` + +#### `FieldMeta` interface + +```ts +interface FieldMeta { + description?: string; + prefix?: string; + format?: string; + optional?: boolean; + type?: "string" | "number" | "boolean" | "string[]" | "object"; +} +``` + +#### Signature mutation helpers (return new Signature, never mutate) + +```ts +const base = Signature.from("question -> answer"); + +// Add or override fields / instructions +const extended = base.with({ instructions: "Be concise." }); + +// Add a single input field +const withCtx = base.withInput("context", { description: "Background text" }); + +// Add a single output field +const withConf = base.withOutput("confidence", { type: "number" }); +``` + +#### Serialization + +```ts +const json = sig.toJSON(); // → plain object +const sig2 = Signature.fromJSON(json); // → Signature +``` + +--- + +### Primitives — Example, Prediction, Trace + +#### `Example` + +Immutable record of named values used as training data or module inputs. 
+ +```ts +import { Example } from "dstsx"; + +const ex = new Example({ question: "What is 2+2?", answer: "4" }); + +ex.get("question"); // "What is 2+2?" +ex.toDict(); // { question: "What is 2+2?", answer: "4" } +ex.toJSON(); // same as toDict() + +// Non-mutating copy with overrides +const updated = ex.with({ answer: "four" }); + +// Filtered views +const inputOnly = ex.inputs(["question"]); // Example { question: ... } +const labelOnly = ex.labels(["question"]); // Example { answer: ... } + +// Deserialize +const ex2 = Example.fromDict({ question: "Hi", answer: "Hello" }); +``` + +#### `Prediction` + +Extends `Example` and adds `completions` for multi-output calls (`n > 1`). + +```ts +import { Prediction } from "dstsx"; + +const pred = new Prediction( + { answer: "42" }, + [{ answer: "42" }, { answer: "forty-two" }], // completions +); + +pred.get("answer"); // "42" +pred.getTyped("answer"); // "42" — typed cast +pred.completions; // ReadonlyArray of all candidates +pred.toJSON(); // { answer: "42", completions: [...] } +``` + +#### `Trace` + +Recorded per LM call. See [History](#history) for how to read traces. + +```ts +interface Trace { + signature: Signature; + inputs: Record<string, unknown>; + outputs: Record<string, unknown>; + usage: { promptTokens: number; completionTokens: number; totalTokens: number } | null; + latencyMs: number; + timestamp: string; // ISO-8601 +} +``` + +--- + +### Language Model Adapters + +All adapters extend `LM` and share the same `call()` interface. 
+ +#### Abstract `LM` + +```ts +abstract class LM { + readonly model: string; + + // Call the LM with a string prompt or chat messages + async call(prompt: string | Message[], config?: LMCallConfig): Promise<LMResponse>; + + // Counters (non-cached calls only) + get requestCount(): number; + get tokenUsage(): { promptTokens: number; completionTokens: number; totalTokens: number }; + + // Clear in-memory LRU response cache + clearCache(): void; +} +``` + +##### `LMCallConfig` + +```ts +interface LMCallConfig { + model?: string; // override model per call + temperature?: number; // 0–2 + maxTokens?: number; + stop?: string[]; + n?: number; // number of completions (default 1) + cacheKey?: string; // manual cache key override + extra?: Record<string, unknown>; // provider pass-through +} +``` + +##### `LMResponse` + +```ts +interface LMResponse { + text: string; // primary completion + texts: string[]; // all completions when n > 1 + usage: TokenUsage | null; + raw: unknown; // raw provider response +} +``` + +--- + +#### `OpenAI` + +Requires: `npm install openai` + +```ts +import { OpenAI } from "dstsx"; + +const lm = new OpenAI({ + model: "gpt-4o", // default "gpt-4o" + apiKey: "sk-...", // or set OPENAI_API_KEY env var + baseURL: "https://...", // optional custom endpoint + maxRetries: 3, +}); +``` + +--- + +#### `Anthropic` + +Requires: `npm install @anthropic-ai/sdk` + +```ts +import { Anthropic } from "dstsx"; + +const lm = new Anthropic({ + model: "claude-3-5-sonnet-20241022", // default + apiKey: "...", // or ANTHROPIC_API_KEY + maxRetries: 3, +}); +``` + +--- + +#### `Cohere` + +Requires: `npm install cohere-ai` + +```ts +import { Cohere } from "dstsx"; + +const lm = new Cohere({ + model: "command-r-plus", // default + apiKey: "...", // or COHERE_API_KEY +}); +``` + +--- + +#### `GoogleAI` + +Requires: `npm install @google/generative-ai` + +```ts +import { GoogleAI } from "dstsx"; + +const lm = new GoogleAI({ + model: "gemini-1.5-pro", // default + apiKey: "...", // or GOOGLE_API_KEY 
+}); +``` + +--- + +#### `Ollama` + +No extra package required — communicates with the Ollama REST API. + +```ts +import { Ollama } from "dstsx"; + +const lm = new Ollama({ + model: "llama3", // default + baseURL: "http://localhost:11434", // default +}); +``` + +--- + +#### `LMStudio` + +No extra package required — uses LM Studio's OpenAI-compatible `/v1` endpoint. + +```ts +import { LMStudio } from "dstsx"; + +const lm = new LMStudio({ + model: "local-model", + baseURL: "http://localhost:1234/v1", // default +}); +``` + +--- + +#### `HuggingFace` + +No extra package required — calls the HuggingFace Inference API directly. + +```ts +import { HuggingFace } from "dstsx"; + +const lm = new HuggingFace({ + model: "mistralai/Mistral-7B-Instruct-v0.3", // default + apiKey: "...", // or HF_API_KEY + endpointURL: "https://my-dedicated-endpoint.com", // optional +}); +``` + +--- + +#### `MockLM` + +Deterministic lookup-map LM for unit testing. + +```ts +import { MockLM } from "dstsx"; + +const lm = new MockLM( + { + // prompt substring → response + "What is 2+2?": "answer: 4", + }, + "answer: unknown", // fallback when no match (optional; throws if omitted) +); + +// Add responses at runtime +lm.addResponse("What is the capital of France?", "answer: Paris"); +``` + +--- + +### Settings & Context + +The `settings` singleton controls global defaults. + +```ts +import { settings } from "dstsx"; +``` + +#### `settings.configure(options)` + +Merge options into global settings (existing keys are overwritten; omitted keys unchanged). + +```ts +settings.configure({ + lm: new OpenAI({ model: "gpt-4o" }), + rm: new ColBERTv2("http://localhost:8893"), + lmConfig: { temperature: 0.0, maxTokens: 512 }, + logLevel: "warn", // "silent" | "error" | "warn" | "info" | "debug" + cacheDir: "./.dstsx", // for future disk caching +}); +``` + +#### `settings.reset()` + +Reset all global settings to defaults. 
+ +```ts +settings.reset(); +``` + +#### `settings.inspect()` + +Return a deep-frozen snapshot of currently effective settings. + +```ts +const snap = settings.inspect(); +console.log(snap.lm?.model); +``` + +#### `settings.context(overrides, fn)` — Per-request isolation + +Run `fn` inside an `AsyncLocalStorage` scope. Concurrent requests each get their own isolated settings and never interfere. + +```ts +// In an Express/Fastify handler: +app.post("/answer", async (req, res) => { + const answer = await settings.context( + { lm: new OpenAI({ model: "gpt-4o-mini" }) }, + () => program.forward({ question: req.body.question }), + ); + res.json(answer.toJSON()); +}); +``` + +#### `SettingsOptions` type + +```ts +interface SettingsOptions { + lm?: LM; + rm?: Retriever; + lmConfig?: LMCallConfig; + logLevel?: "silent" | "error" | "warn" | "info" | "debug"; + cacheDir?: string; +} +``` + +--- + +### Modules + +All modules extend `Module` and expose a `forward()` method. + +#### Abstract `Module` + +```ts +abstract class Module { + abstract forward(...args: unknown[]): Promise<Prediction>; + + // Recursively list all Predict sub-modules + namedPredictors(): Array<[string, Module]>; + + // Serialize/deserialize learnable parameters (demos, instructions) + dump(): Record<string, unknown>; + load(state: Record<string, unknown>): void; + + // Deep-clone the module (used by optimizers) + clone(): this; +} +``` + +**Saving and loading a compiled program:** + +```ts +import { writeFileSync, readFileSync } from "fs"; + +// Save +const state = program.dump(); +writeFileSync("program.json", JSON.stringify(state, null, 2)); + +// Load into a fresh instance +const fresh = new MyProgram(); +fresh.load(JSON.parse(readFileSync("program.json", "utf8"))); +``` + +--- + +#### `Predict` + +The fundamental module — formats a prompt and calls the LM. 
+ +```ts +import { Predict } from "dstsx"; + +const qa = new Predict("question -> answer"); + +// Or with a full Signature object +const qa2 = new Predict(Signature.from("context, question -> answer")); + +// Forward +const result = await qa.forward({ question: "What is 9 × 7?" }); +console.log(result.get("answer")); // "63" +``` + +**Learnable parameters** (mutated by optimizers): + +```ts +qa.demos // Example[] — few-shot demonstrations +qa.instructions // string | undefined — system instruction override +``` + +**Multiple completions (`n > 1`):** + +```ts +settings.configure({ lmConfig: { n: 5 } }); +const result = await qa.forward({ question: "Name a color." }); +console.log(result.completions); // 5 candidate answers +``` + +--- + +#### `ChainOfThought` + +Prepends a hidden `rationale` output field so the LM reasons step-by-step before answering. + +```ts +import { ChainOfThought } from "dstsx"; + +const cot = new ChainOfThought("question -> answer"); + +// Optional: customize the rationale description +const cot2 = new ChainOfThought("question -> answer", { + rationaleDescription: "Think aloud to solve the problem", +}); + +const result = await cot.forward({ question: "If Alice has 3 apples and gets 5 more, how many does she have?" }); +console.log(result.get("answer")); // "8" — rationale is internal +``` + +--- + +#### `ChainOfThoughtWithHint` + +Extends `ChainOfThought` with an optional `hint` input. + +```ts +import { ChainOfThoughtWithHint } from "dstsx"; + +const cot = new ChainOfThoughtWithHint("question -> answer"); + +const result = await cot.forward({ + question: "What is the chemical formula for water?", + hint: "It involves hydrogen and oxygen.", +}); +``` + +--- + +#### `MultiChainComparison` + +Runs a signature `M` times and picks the best completion via a final aggregation call. 
+ +```ts +import { MultiChainComparison } from "dstsx"; + +const mcc = new MultiChainComparison("question -> answer", /* M= */ 3); +const result = await mcc.forward({ question: "What is 7 × 8?" }); +``` + +--- + +#### `ReAct` + +Reasoning + Acting loop (Yao et al., 2022). Alternates Thought → Action → Observation until the LM emits `Finish[answer]` or `maxIter` is reached. + +```ts +import { ReAct, type Tool } from "dstsx"; + +const searchTool: Tool = { + name: "search", + description: "Search the web for current information", + fn: async (query: string) => { + // Your real implementation here + return `Search results for: ${query}`; + }, +}; + +const agent = new ReAct( + "question -> answer", + [searchTool], + /* maxIter= */ 5, +); + +const result = await agent.forward({ question: "Who won the 2024 US election?" }); +console.log(result.get("answer")); +console.log(result.get("trajectory")); // full thought/action/observation log +``` + +**`Tool` interface:** + +```ts +interface Tool { + name: string; + description: string; + fn: (args: string) => Promise<string>; +} +``` + +--- + +#### `ProgramOfThought` + +Generates JavaScript code, executes it in a `new Function()` context, self-corrects on errors, and returns the result. + +> ⚠️ **Security**: Code runs in the current process. Do NOT use with untrusted user inputs in production without an additional sandboxing layer (e.g. a Worker thread). + +```ts +import { ProgramOfThought } from "dstsx"; + +const pot = new ProgramOfThought( + "question -> answer", + /* maxAttempts= */ 3, + /* timeoutMs= */ 5_000, +); + +const result = await pot.forward({ question: "What is the 10th Fibonacci number?" }); +console.log(result.get("answer")); // "55" +console.log(result.get("code")); // the generated JS code +``` + +--- + +#### `Retrieve` + +Calls the globally configured retriever and returns `passages`. 
+ +```ts +import { Retrieve, ColBERTv2, settings } from "dstsx"; + +settings.configure({ rm: new ColBERTv2("http://localhost:8893") }); + +const retrieve = new Retrieve(/* k= */ 3); +const result = await retrieve.forward("What is DSPy?"); + +const passages: string[] = result.get("passages") as string[]; +``` + +--- + +#### `Retry` + +Wraps any module and retries on `AssertionError` (thrown by `Assert()`), feeding the error message back as `feedback`. + +```ts +import { Retry, Assert, Predict, Module } from "dstsx"; + +const qa = new Predict("question, feedback? -> answer"); + +const retrying = new Retry(qa, /* maxAttempts= */ 3); + +const result = await retrying.forward({ + question: "Give a one-word answer: what color is the sky?", +}); + +// Use Assert inside a custom module to trigger retry +class CheckedQA extends Module { + predict = new Predict("question -> answer"); + async forward(inputs: { question: string }) { + const result = await this.predict.forward(inputs); + Assert( + String(result.get("answer")).length > 0, + "Answer must not be empty" + ); + return result; + } +} +``` + +--- + +#### `BestOfN` + +Runs `N` copies of a module in parallel and selects the best via `reduceFunc` (defaults to first result). + +```ts +import { BestOfN, Predict } from "dstsx"; + +const qa = new Predict("question -> answer"); +const best = new BestOfN(qa, /* N= */ 5, (predictions) => { + // Pick the longest answer as a proxy for quality + return predictions.reduce((a, b) => + String(b.get("answer")).length > String(a.get("answer")).length ? b : a + ); +}); + +const result = await best.forward({ question: "Explain gravity." }); +``` + +--- + +#### `Ensemble` + +Combines multiple pre-built module instances via a reduce function. 
+ +```ts +import { Ensemble, ChainOfThought } from "dstsx"; + +const m1 = new ChainOfThought("question -> answer"); +const m2 = new ChainOfThought("question -> answer"); + +const ensemble = new Ensemble( + [m1, m2], + (predictions) => predictions[0]!, // custom vote/merge logic +); + +const result = await ensemble.forward({ question: "Is TypeScript better than JavaScript?" }); +``` + +--- + +### Retrievers + +All retrievers extend `Retriever` and implement `retrieve(query, k)`. + +#### Abstract `Retriever` + +```ts +abstract class Retriever { + abstract retrieve(query: string, k: number): Promise<string[]>; +} +``` + +#### `ColBERTv2` + +```ts +import { ColBERTv2 } from "dstsx"; + +const rm = new ColBERTv2("http://localhost:8893"); +// or with options: +const rm2 = new ColBERTv2({ url: "http://localhost:8893" }); + +const passages = await rm.retrieve("What is photosynthesis?", 3); +``` + +#### `PineconeRM` + +Requires: `npm install @pinecone-database/pinecone` + +```ts +import { PineconeRM } from "dstsx"; + +const rm = new PineconeRM({ + indexName: "my-index", + apiKey: "...", // or PINECONE_API_KEY + namespace: "default", + embeddingFn: async (text) => myEmbedModel.embed(text), +}); +``` + +#### `ChromadbRM` + +Requires: `npm install chromadb` + +```ts +import { ChromadbRM } from "dstsx"; + +const rm = new ChromadbRM({ + collectionName: "my-collection", + url: "http://localhost:8000", // default + embeddingFn: async (texts) => myEmbedModel.embedBatch(texts), +}); +``` + +#### `QdrantRM` + +Requires: `npm install @qdrant/js-client-rest` + +```ts +import { QdrantRM } from "dstsx"; + +const rm = new QdrantRM({ + url: "http://localhost:6333", + collectionName: "my-collection", + embeddingFn: async (text) => myEmbedModel.embed(text), +}); +``` + +#### `WeaviateRM` + +Requires: `npm install weaviate-client` + +```ts +import { WeaviateRM } from "dstsx"; + +const rm = new WeaviateRM({ + url: "http://localhost:8080", + className: "Document", + textField: "content", + embeddingFn: 
async (text) => myEmbedModel.embed(text), +}); +``` + +#### `FaissRM` + +Requires: `npm install faiss-node` (optional peer dep) + +```ts +import { FaissRM } from "dstsx"; + +const rm = new FaissRM({ + passages: ["passage 1", "passage 2"], + embeddingFn: async (text) => myEmbedModel.embed(text), +}); +``` + +#### `YouRM` + +```ts +import { YouRM } from "dstsx"; + +const rm = new YouRM({ + apiKey: "...", // or YDC_API_KEY + k: 3, +}); +``` + +#### `MockRetriever` + +For unit testing. + +```ts +import { MockRetriever } from "dstsx"; + +const rm = new MockRetriever([ + "The capital of France is Paris.", + "Paris is located in northern France.", + "France is a country in Western Europe.", +]); + +const passages = await rm.retrieve("capital of France", 2); +``` + +--- + +### Optimizers + +Optimizers automatically tune a module's few-shot `demos` and/or `instructions`. All optimizers implement: + +```ts +abstract class Optimizer { + abstract compile( + student: Module, + trainset: Example[], + metric: Metric, + ): Promise<Module>; +} +``` + +- The returned module is a **new clone**; the original `student` is never mutated. +- Pass a `valset` where supported to evaluate on held-out data. + +--- + +#### `LabeledFewShot` + +Directly assigns labeled examples as `demos` on every `Predict` sub-module (no LM calls). + +```ts +import { LabeledFewShot } from "dstsx"; + +const optimizer = new LabeledFewShot(/* k= */ 16); +const optimized = await optimizer.compile(program, trainset, metric); +``` + +--- + +#### `BootstrapFewShot` + +Runs the student (or an optional `teacher`) on `trainset`, collects successful traces via `metric`, and uses them as `demos`. 
+ +```ts +import { BootstrapFewShot } from "dstsx"; + +const optimizer = new BootstrapFewShot({ + maxBootstrappedDemos: 4, // max demos collected per predictor + maxLabeledDemos: 16, // max labeled fallback demos + teacher: expertProgram, // optional; defaults to student +}); + +const optimized = await optimizer.compile(program, trainset, exactMatch("answer")); +``` + +--- + +#### `BootstrapFewShotWithRandomSearch` + +Extends `BootstrapFewShot` — tries `numCandidatePrograms` random demo subsets and selects the best by validation score. + +```ts +import { BootstrapFewShotWithRandomSearch } from "dstsx"; + +const optimizer = new BootstrapFewShotWithRandomSearch({ + maxBootstrappedDemos: 4, + numCandidatePrograms: 8, // number of random subsets to evaluate + valset: valExamples, // optional held-out set +}); + +const optimized = await optimizer.compile(program, trainset, metric); +``` + +--- + +#### `COPRO` (Collaborative Prompt Optimizer) + +Uses the LM to propose instruction improvements for each `Predict` sub-module and selects the best combination by metric score. + +```ts +import { COPRO } from "dstsx"; + +const optimizer = new COPRO({ + breadth: 5, // instruction candidates per predictor per round + depth: 3, // refinement rounds +}); + +const optimized = await optimizer.compile(program, trainset, metric); +``` + +--- + +#### `MIPRO` (Multi-stage Instruction Prompt Optimizer) + +Combines COPRO-style instruction proposals with `BootstrapFewShotWithRandomSearch` to jointly optimize instructions _and_ demonstrations. 
+ +```ts +import { MIPRO } from "dstsx"; + +const optimizer = new MIPRO({ + numCandidates: 5, // instruction candidates per predictor + initTemperature: 0.9, + numCandidatePrograms: 8, // demo subsets to evaluate + verbose: true, +}); + +const optimized = await optimizer.compile(program, trainset, metric); +``` + +--- + +#### `KNNFewShot` + +Selects demonstrations **at inference time** using k-nearest-neighbour search over the training set embeddings (dynamic few-shot). + +```ts +import { KNNFewShot } from "dstsx"; + +const optimizer = new KNNFewShot({ + k: 3, + embeddingFn: async (text) => myEmbedModel.embed(text), // required + keyField: "question", // which field to embed (default: all fields joined) +}); + +const optimized = await optimizer.compile(program, trainset, metric); +// At inference time, each forward() call auto-selects the 3 most similar demos +``` + +--- + +#### `EnsembleOptimizer` + +Wraps a program with an optional reduce function. Primarily useful for building multi-program ensembles. + +```ts +import { EnsembleOptimizer } from "dstsx"; + +const optimizer = new EnsembleOptimizer({ + reduceFunc: (predictions) => predictions[0]!, +}); + +const wrapped = await optimizer.compile(program, trainset, metric); +``` + +--- + +### Evaluation + +#### `evaluate(program, examples, metric, options?)` + +Run `program` on every example and aggregate scores. 
+ +```ts +import { evaluate, exactMatch } from "dstsx"; + +const result = await evaluate( + program, + devset, + exactMatch("answer"), // built-in metric + { + numThreads: 4, // parallel evaluation (default: 1) + displayProgress: true, // log progress to console + }, +); + +console.log(`Score: ${(result.score * 100).toFixed(1)}%`); +console.log(`Passed: ${result.numPassed}/${result.total}`); +``` + +##### `EvaluationResult` + +```ts +interface EvaluationResult { + score: number; // average metric score (0–1) + numPassed: number; + total: number; + results: ExampleResult[]; // per-example breakdown +} + +interface ExampleResult { + example: Example; + prediction: Prediction; + score: number; + passed: boolean; +} +``` + +--- + +#### Built-in Metrics + +All metrics implement `Metric`: + +```ts +type Metric = ( + example: Example, + prediction: Prediction, + trace?: Trace[], +) => number | boolean; +``` + +| Factory | Description | +|---|---| +| `exactMatch(field?, caseSensitive?)` | 1 if prediction exactly matches example (case-insensitive by default) | +| `f1(field?)` | Token-level F1 (word overlap), useful for QA | +| `passAtK(innerMetric, k)` | 1 if any of the top-k completions pass `innerMetric` | +| `bleu(field?)` | Simplified BLEU (1-gram + 2-gram precision) | +| `rouge(field?)` | ROUGE-L (LCS-based F1) | + +```ts +import { exactMatch, f1, passAtK, bleu, rouge } from "dstsx"; + +// Exact match on "answer" field (default) +const em = exactMatch(); + +// Case-sensitive exact match on a custom field +const em2 = exactMatch("label", true); + +// Token F1 +const f1Metric = f1("answer"); + +// Pass if any of the 5 completions give exact match +const p5 = passAtK(exactMatch(), 5); + +// BLEU / ROUGE +const bleuMetric = bleu("answer"); +const rougeMetric = rouge("answer"); +``` + +--- + +### Assertions & Suggestions + +#### `Assert(condition, message?)` + +Throws `AssertionError` if `condition` is falsy. Caught and retried by `Retry`. 
+ +```ts +import { Assert } from "dstsx"; + +Assert(result.get("answer") !== "", "Answer must not be empty"); +Assert(typeof result.get("score") === "number", "Score must be a number"); +``` + +#### `Suggest(condition, message?)` + +Logs a `console.warn` if `condition` is falsy but does **not** throw — the pipeline continues. + +```ts +import { Suggest } from "dstsx"; + +Suggest(result.get("confidence") === "high", "Low confidence in answer"); +``` + +#### `AssertionError` + +The typed error class thrown by `Assert`. Caught by `Retry`. + +```ts +import { AssertionError } from "dstsx"; + +try { + await program.forward(inputs); +} catch (err) { + if (err instanceof AssertionError) { + console.warn("Assertion failed:", err.message); + } +} +``` + +--- + +## V2 APIs + +The following features are implemented in DSTsx v2. + +### `TypedPredictor` & `TypedChainOfThought` + +Structured JSON output with optional schema validation. Works without any extra +dependencies — pass a [Zod](https://github.com/colinhacks/zod) schema for +runtime validation. + +#### `TypedPrediction` + +Extends `Prediction` and adds a `.typed` field with the validated/parsed type. + +```ts +import { TypedPredictor } from "dstsx"; + +// Without schema — output is parsed as plain JSON +const qa = new TypedPredictor("question -> answer"); +const result = await qa.forward({ question: "What is π?" }); +const typed = result.typed as { answer: string }; +console.log(typed.answer); +``` + +With a Zod schema (`npm install zod` first): + +```ts +import { z } from "zod"; +import { TypedPredictor, TypedChainOfThought } from "dstsx"; + +const AnswerSchema = z.object({ + answer: z.string(), + confidence: z.number().min(0).max(1), + sources: z.array(z.string()).optional(), +}); + +const qa = new TypedPredictor("question -> answer", AnswerSchema, { maxRetries: 3 }); +const result = await qa.forward({ question: "What is 2 + 2?" 
}); + +// result.typed is z.infer<typeof AnswerSchema> +console.log(result.typed.confidence); // 0.98 (number, validated) +``` + +`TypedChainOfThought` adds a hidden `rationale` step before producing the JSON: + +```ts +const cot = new TypedChainOfThought("question -> answer", AnswerSchema); +const result = await cot.forward({ question: "Explain gravity briefly." }); +``` + +**Constructor options:** + +```ts +new TypedPredictor(signature, schema?, { maxRetries?: number }) +// maxRetries: how many times to retry on parse/schema failure (default: 3) +``` + +--- + +### `Parallel` Module + +Runs multiple modules concurrently with `Promise.all` and returns all results. + +```ts +import { Parallel, Predict, ChainOfThought } from "dstsx"; + +const pipeline = new Parallel([ + new Predict("question -> answer"), + new ChainOfThought("question -> answer"), +], { timeoutMs: 10_000 }); // optional per-module timeout + +// run() returns Prediction[] — one per module +const [directAnswer, cotAnswer] = await pipeline.run({ question: "What is π?" }); + +// forward() returns the first prediction (for Module interface compat) +const first = await pipeline.forward({ question: "What is π?" }); +``` + +**Constructor:** + +```ts +new Parallel(modules: Module[], options?: { timeoutMs?: number }) +``` + +| Method | Returns | Description | +|---|---|---| +| `run(...args)` | `Promise<Prediction[]>` | All module outputs in order | +| `forward(...args)` | `Promise<Prediction>` | First module output (Module compat) | + +--- + +### `Refine` Module + +Self-critique / iterative refinement loop. After each inner module run, a built-in +critic predictor evaluates the output and feeds improvement suggestions back. + +```ts +import { Refine, Predict } from "dstsx"; + +const writer = new Predict("topic, feedback? 
-> essay"); + +const refined = new Refine(writer, { + maxRefinements: 2, + feedbackField: "feedback", // injected field name for critique + stopCondition: (pred) => + String(pred.get("essay")).length > 500, // stop early if long enough +}); + +const result = await refined.forward({ topic: "Climate change" }); +console.log(result.get("essay")); +``` + +**Constructor:** + +```ts +new Refine(inner: Module, options?: { + maxRefinements?: number; // default: 2 + feedbackField?: string; // default: "feedback" + stopCondition?: (p: Prediction) => boolean; // optional early-exit check +}) +``` + +The critic calls `Predict("output -> critique, is_satisfactory")`. +If `is_satisfactory` is `"yes"` or `"true"`, refinement stops early. + +--- + +### `majority()` Helper + +Votes across multiple `Prediction` instances by the most common value for a given +field. Useful as a `reduceFunc` in `BestOfN` and `Ensemble`. + +```ts +import { majority, BestOfN, Predict } from "dstsx"; + +const qa = new Predict("question -> answer"); + +// Run 5 times and pick the most common answer +const voted = new BestOfN(qa, 5, majority("answer")); +const result = await voted.forward({ question: "What color is the sky?" }); +console.log(result.get("answer")); // most frequently returned answer +``` + +```ts +// Standalone usage +import { majority } from "dstsx"; + +const reducer = majority("answer"); +const best = reducer([pred1, pred2, pred3]); // Prediction with the most common "answer" +``` + +--- + +### `BootstrapFewShotWithOptuna` + +Extends `BootstrapFewShot` with a pure-TypeScript TPE (Tree-structured Parzen +Estimator) that searches demo subsets across `numTrials` iterations, learning +from past trial outcomes to find the best configuration — no external +dependencies required. 
+ +```ts +import { BootstrapFewShotWithOptuna } from "dstsx"; + +const optimizer = new BootstrapFewShotWithOptuna({ + maxBootstrappedDemos: 4, + numTrials: 20, // number of TPE search trials + valset: valExamples, // optional held-out validation set +}); + +const optimized = await optimizer.compile(program, trainset, metric); +``` + +**How it works:** First runs `BootstrapFewShot` to collect candidate demos. Then +runs `numTrials` iterations where each trial samples a demo subset using TPE: +the top 25 % of past trials form the "good" pool, which is sampled with 70 % probability, +biasing mutations towards the best configurations found so far. + +--- + +### Disk-Persistent LM Cache + +LM adapters now accept a `cacheDir` option. Responses are persisted as JSON +files named by a SHA-256 hash of the prompt, surviving process restarts. + +```ts +import { OpenAI, MockLM } from "dstsx"; + +// Any LM adapter — just pass cacheDir +const lm = new OpenAI({ + model: "gpt-4o", + cacheDir: "./.dstsx-cache", // optional disk persistence +}); + +// Or with MockLM for testing disk cache behavior +const mockLm = new MockLM({ "q": "a" }, undefined, { cacheDir: "/tmp/test-cache" }); +``` + +The disk cache is checked **after** the in-memory LRU cache. On a hit the +response is also written back into the in-memory cache. TTL and max-size +eviction apply to both layers. + +`DiskCache` is also exported for custom use: + +```ts +import { DiskCache } from "dstsx"; + +const cache = new DiskCache( + "./.dstsx-cache", // directory (created automatically) + 500, // maxSize (files); default 500 + 3_600_000, // ttlMs; default undefined (no TTL) +); + +cache.set("myKey", lmResponse); +const cached = cache.get("myKey"); +cache.clear(); // delete all cache files +``` + +--- + +### MCP Integration + +DSTsx integrates with the [Model Context Protocol](https://modelcontextprotocol.io/) +(MCP) in two directions: + +1. **Use MCP servers as tools** inside ReAct agents (`MCPToolAdapter`) +2. 
**Expose DSTsx modules as MCP tools** (`DSTsxMCPServer`) + +Optional peer dependency: `npm install @modelcontextprotocol/sdk` + +--- + +#### `MCPToolAdapter` — consume MCP servers in ReAct + +Wraps the tools from an MCP server as DSTsx `Tool` objects for use with `ReAct`. + +```ts +import { MCPToolAdapter, ReAct } from "dstsx"; + +const adapter = new MCPToolAdapter({ + // Test mode: supply tool definitions + call handler without a live server + tools: [ + { + name: "weather", + description: "Get current weather for a city", + inputSchema: { + type: "object", + properties: { city: { type: "string" } }, + required: ["city"], + }, + }, + ], + callHandler: async (name, args) => { + if (name === "weather") return `Sunny in ${args["city"] as string}`; + throw new Error(`Unknown tool: ${name}`); + }, +}); + +const tools = await adapter.getTools(); +const agent = new ReAct("question -> answer", tools, 5); +const result = await agent.forward({ question: "What is the weather in Paris?" }); +``` + +When `@modelcontextprotocol/sdk` is installed, a live SSE/stdio connection can +be established by setting `serverUrl` (full live-connection implementation is in +the [v2 roadmap](./V2_ROADMAP.md#live-mcp-connection)). + +--- + +#### `DSTsxMCPServer` — expose DSTsx modules as MCP tools + +Register any DSTsx module and serve it as an MCP-compatible tool. + +```ts +import { DSTsxMCPServer, ChainOfThought, settings, OpenAI } from "dstsx"; + +settings.configure({ lm: new OpenAI({ model: "gpt-4o" }) }); + +const qa = new ChainOfThought("context, question -> answer"); + +const server = new DSTsxMCPServer(); +server.registerModule( + "qa", // tool name + "Answer questions using chain-of-thought reasoning", + qa, + ["context", "question"], // input field names +); + +// List tools (for MCP handshake) +const toolDefs = server.getToolDefinitions(); +/* +[{ + name: "qa", + description: "...", + inputSchema: { type: "object", properties: { context: { type: "string" }, ... 
} }, +}] +*/ + +// Handle a tool call +const result = await server.callTool("qa", { + context: "Paris is the capital of France.", + question: "What is the capital of France?", +}); +// result is the Prediction.toJSON() object + +// With @modelcontextprotocol/sdk installed, launch a stdio MCP server: +// await server.createStdioServer(); +``` + +**MCPTool type:** + +```ts +interface MCPTool { + name: string; + description: string; + inputSchema: { + type: "object"; + properties: Record; + required?: string[]; + }; + handler: (inputs: Record) => Promise; +} +``` + +**`DSTsxMCPServer` methods:** + +| Method | Description | +|---|---| +| `registerModule(name, desc, module, fields)` | Register a module as an MCP tool | +| `getToolDefinitions()` | Return all registered `MCPTool[]` | +| `callTool(name, inputs)` | Invoke a registered tool by name | +| `createStdioServer()` | Start an MCP stdio server (requires SDK) | + +--- + +### LM Streaming + +Stream LM responses token by token using `AsyncGenerator`. All adapters provide +a default fallback that returns the full response as a single chunk. Real +streaming is implemented for `OpenAI` and `Anthropic`. + +```ts +import { settings, OpenAI, Predict } from "dstsx"; + +settings.configure({ lm: new OpenAI({ model: "gpt-4o" }) }); + +const qa = new Predict("question -> answer"); + +// Stream via Predict +for await (const chunk of qa.stream({ question: "Tell me a story." 
})) { + process.stdout.write(chunk.delta); + if (chunk.done) break; +} + +// Stream directly from LM +const lm = new OpenAI({ model: "gpt-4o" }); +for await (const chunk of lm.stream("What is TypeScript?")) { + process.stdout.write(chunk.delta); +} +``` + +**`StreamChunk` type:** + +```ts +interface StreamChunk { + delta: string; // Incremental text + done: boolean; // True on the final chunk + raw: unknown; // Raw provider chunk +} +``` + +| Method | Available on | Description | +|---|---|---| +| `lm.stream(prompt, config?)` | `LM` (all adapters) | Stream from LM (fallback on unsupported) | +| `predict.stream(inputs)` | `Predict` | Stream from a Predict module | + +--- + +### NativeReAct + +A `ReAct` variant that uses provider-native function/tool calling (OpenAI tools +API, Anthropic tool_use) instead of text-based action parsing. Falls back to +text format for adapters that don't support native calling. + +```ts +import { NativeReAct, settings, OpenAI } from "dstsx"; +import type { Tool } from "dstsx"; + +settings.configure({ lm: new OpenAI({ model: "gpt-4o" }) }); + +const tools: Tool[] = [ + { + name: "search", + description: "Search the web for information", + fn: async (args) => { + const { query } = JSON.parse(args) as { query: string }; + return `Results for: ${query}`; + }, + }, +]; + +const agent = new NativeReAct("question -> answer", tools, 5); +const result = await agent.forward({ question: "What is the capital of France?" }); +console.log(result.get("answer")); // "Paris" +``` + +**Constructor:** + +```ts +new NativeReAct( + signature: string, + tools: Tool[], + maxIter?: number, // default: 5 +) +``` + +NativeReAct wraps tool calls using the OpenAI `tools` format (an array of +`{ type: "function", function: { name, description, parameters } }` objects), +which is also accepted by other compatible providers. 
+ +--- + +### Image — Multi-modal Support + +The `Image` primitive enables passing images to vision-capable LMs as field +values in any `Predict` or `TypedPredictor` call. + +```ts +import { Image, Predict, settings, OpenAI } from "dstsx"; + +settings.configure({ lm: new OpenAI({ model: "gpt-4o" }) }); + +const captioner = new Predict("image, question -> caption"); + +// From a URL (lazy — no download at construction time) +const img1 = Image.fromURL("https://example.com/photo.jpg"); +const result1 = await captioner.forward({ image: img1, question: "Describe this image." }); + +// From base64 data +const img2 = Image.fromBase64(base64String, "image/png"); + +// From a local file (read synchronously) +const img3 = Image.fromFile("./photo.jpg"); +``` + +**Static factory methods:** + +| Method | Description | +|---|---| +| `Image.fromURL(url)` | Image at a remote URL | +| `Image.fromBase64(data, mimeType?)` | Inline base64 (default: `"image/jpeg"`) | +| `Image.fromFile(path, mimeType?)` | Local file read synchronously; auto-detects MIME from extension | + +**Serialization helpers (used by adapters internally):** + +```ts +img.toOpenAIContentPart() // { type: "image_url", image_url: { url } } +img.toAnthropicContentBlock() // { type: "image", source: { ... } } +img.toString() // "[Image: https://...]" — used in text prompts +``` + +**Supported MIME types:** `"image/jpeg"`, `"image/png"`, `"image/gif"`, `"image/webp"` + +--- + +### BootstrapFinetune + +Extends `BootstrapFewShot` to collect execution traces and export them as a +JSONL fine-tuning dataset. 
+ +```ts +import { BootstrapFinetune } from "dstsx"; + +const optimizer = new BootstrapFinetune({ + exportPath: "./finetune_data.jsonl", // default + format: "openai", // "openai" | "generic" + maxBootstrappedDemos: 4, +}); + +const compiled = await optimizer.compile(program, trainset, metric); +// "./finetune_data.jsonl" now contains one JSON object per line: +// {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]} +``` + +**Options:** + +| Option | Type | Default | Description | +|---|---|---|---| +| `exportPath` | `string` | `"./finetune_data.jsonl"` | Output file path | +| `format` | `"openai" \| "generic"` | `"openai"` | JSONL line format | +| `maxBootstrappedDemos` | `number` | `4` | Demos to bootstrap per predictor | + +**Format examples:** + +`"openai"` (suitable for OpenAI fine-tuning): +```jsonl +{"messages": [{"role": "user", "content": "question: What is 2+2?"}, {"role": "assistant", "content": "answer: 4"}]} +``` + +`"generic"` (plain prompt/completion): +```jsonl +{"prompt": "question: What is 2+2?", "completion": "answer: 4"} +``` + +--- + +### GRPO Optimizer + +Group Relative Policy Optimization — mirrors `dspy.GRPO`. Iteratively generates +groups of candidate instruction variants, evaluates them relative to each other, +and converges toward the best-scoring configuration. + +```ts +import { GRPO } from "dstsx"; + +const optimizer = new GRPO({ + numSteps: 20, // optimization iterations + groupSize: 8, // candidates per group + temperature: 1.0, // sampling temperature +}); + +const optimized = await optimizer.compile(program, trainset, metric); +``` + +**How it works:** Each step generates `groupSize` instruction alternatives using +the configured LM at the specified temperature. All candidates are evaluated on +the training set. The relative advantage of each is computed as +`(score − mean) / std`. The best-scoring candidate becomes the new baseline for +the next step. 
+ +**Options:** + +| Option | Type | Default | +|---|---|---| +| `numSteps` | `number` | `20` | +| `groupSize` | `number` | `8` | +| `temperature` | `number` | `1.0` | +| `maxLabeledDemos` | `number` | `16` | + +--- + +### SIMBA Optimizer + +SIMBA (Stochastic Introspective Mini-Batch Ascent) — mirrors `dspy.SIMBA`. A +lightweight stochastic optimizer well-suited for small training sets. + +```ts +import { SIMBA } from "dstsx"; + +const optimizer = new SIMBA({ + numIter: 10, // iterations + batchSize: 8, // mini-batch size per evaluation + maxBootstrappedDemos: 4, +}); + +const optimized = await optimizer.compile(program, trainset, metric); +``` + +**How it works:** Initializes with `BootstrapFewShot`, then for each iteration +draws a random mini-batch (Fisher-Yates shuffle), proposes a demo-subset +candidate, evaluates it on the batch, and accepts it if it improves on the +current best score. + +**Options:** + +| Option | Type | Default | +|---|---|---| +| `numIter` | `number` | `10` | +| `batchSize` | `number` | `8` | +| `maxBootstrappedDemos` | `number` | `4` | + +--- + +### AvatarOptimizer + +Iteratively proposes and evaluates "avatar" role descriptions (persona prefixes) +for each `Predict` module, selecting the instruction that scores highest on the +training set. + +```ts +import { AvatarOptimizer } from "dstsx"; + +const optimizer = new AvatarOptimizer({ + numAvatars: 4, // candidate personas per predictor + maxLabeledDemos: 8, +}); + +const optimized = await optimizer.compile(program, trainset, metric); +``` + +**How it works:** For each `Predict` predictor in the program, asks the +configured LM to generate `numAvatars` distinct role/persona descriptions +(e.g. "You are an expert doctor…"). Each persona is prepended to the predictor's +instructions and scored on the training set. The best persona is kept. 
+ +**Options:** + +| Option | Type | Default | +|---|---|---| +| `numAvatars` | `number` | `4` | +| `maxLabeledDemos` | `number` | `8` | + +--- + +### Experiment Tracking + +Track optimizer progress (steps, scores, best candidates) to console, JSON +files, or custom backends. + +```ts +import { ConsoleTracker, JsonFileTracker, GRPO, SIMBA, AvatarOptimizer } from "dstsx"; + +// Log to console +const consoleTracker = new ConsoleTracker(); +// [STEP] step=1 score=0.7500 +// [BEST] step=3 score=0.8750 + +// Log to JSONL file +const fileTracker = new JsonFileTracker("./runs/exp1.jsonl"); +await fileTracker.flush(); // flush buffer to disk + +// Pass to any optimizer +const optimizer = new GRPO({ numSteps: 10 }); +// (trackers are accepted as options on GRPO, SIMBA, AvatarOptimizer) +``` + +**Custom tracker — extend `Tracker`:** + +```ts +import { Tracker } from "dstsx"; +import type { TrackerEvent } from "dstsx"; + +class MLflowTracker extends Tracker { + log(event: TrackerEvent): void { + // Send to MLflow REST API, W&B, etc. + console.log("mlflow.log_metric", event.score); + } + async flush(): Promise<void> {} +} +``` + +**`TrackerEvent` type:** + +```ts +interface TrackerEvent { + type: "step" | "trial" | "best" | "done"; + step?: number; + score?: number; + metadata?: Record<string, unknown>; +} +``` + +**Exported classes:** `Tracker` (abstract), `ConsoleTracker`, `JsonFileTracker` + +--- + +### Worker-Thread `ProgramOfThought` + +`ProgramOfThought` now supports a `sandbox` option to run generated code in a +Node.js Worker thread instead of `new Function()` in the main process, providing +stronger isolation. 
+ +```ts +import { ProgramOfThought } from "dstsx"; + +// Default: "function" (same process, backwards-compatible) +const pot = new ProgramOfThought("question -> answer"); + +// Worker thread isolation (Node 18+) +const potWorker = new ProgramOfThought( + "question -> answer", + 3, // maxAttempts + 5_000, // timeoutMs + "worker" // sandbox mode +); +const result = await potWorker.forward({ question: "What is 2 ** 10?" }); +console.log(result.get("answer")); // "1024" + +// Disable timeout/sandbox entirely (not recommended for untrusted input) +const potNone = new ProgramOfThought("question -> answer", 3, 5_000, "none"); +``` + +**Sandbox modes:** + +| Mode | Isolation | True Cancellation | Notes | +|---|---|---|---| +| `"function"` | None — runs in main process | No | **Default.** Backwards-compatible. | +| `"worker"` | Node.js Worker thread | Yes (terminate on timeout) | Requires Node 18+. | +| `"none"` | None — no timeout applied | N/A | Fastest; do not use with untrusted code. | + +--- + +## End-to-End Examples + +### 1. Simple Q&A + +```ts +import { settings, OpenAI, Predict } from "dstsx"; + +settings.configure({ lm: new OpenAI({ model: "gpt-4o-mini" }) }); + +const qa = new Predict("question -> answer"); + +const result = await qa.forward({ question: "What is the speed of light?" }); +console.log(result.get("answer")); +``` + +--- + +### 2. 
Retrieval-Augmented Generation (RAG) + +```ts +import { + Module, Retrieve, ChainOfThought, ColBERTv2, + settings, OpenAI, type Prediction, +} from "dstsx"; + +settings.configure({ + lm: new OpenAI({ model: "gpt-4o" }), + rm: new ColBERTv2("http://localhost:8893"), +}); + +class RAG extends Module { + retrieve = new Retrieve(3); + generate = new ChainOfThought("context, question -> answer"); + + async forward(inputs: { question: string }): Promise<Prediction> { + const { passages } = (await this.retrieve.forward(inputs.question)).toDict() as { passages: string[] }; + return this.generate.forward({ + context: passages.join("\n"), + question: inputs.question, + }); + } +} + +const rag = new RAG(); +const result = await rag.forward({ question: "What is the capital of Germany?" }); +console.log(result.get("answer")); // "Berlin" +``` + +--- + +### 3. Optimizing with BootstrapFewShot + +```ts +import { + settings, MockLM, Predict, Module, BootstrapFewShot, + Example, evaluate, exactMatch, type Prediction, +} from "dstsx"; + +settings.configure({ lm: new MockLM({}, "answer: 42") }); + +class QA extends Module { + predict = new Predict("question -> answer"); + async forward(inputs: { question: string }): Promise<Prediction> { + return this.predict.forward(inputs); + } +} + +const trainset = [ + new Example({ question: "What is 6 × 7?", answer: "42" }), + new Example({ question: "What is 8 × 8?", answer: "64" }), +]; + +const optimizer = new BootstrapFewShot({ maxBootstrappedDemos: 2 }); +const optimized = await optimizer.compile(new QA(), trainset, exactMatch("answer")); + +// Persist +import { writeFileSync } from "fs"; +writeFileSync("qa_optimized.json", JSON.stringify(optimized.dump(), null, 2)); +``` + +--- + +### 4. 
ReAct Agent + +```ts +import { settings, OpenAI, ReAct, type Tool } from "dstsx"; + +settings.configure({ lm: new OpenAI({ model: "gpt-4o" }) }); + +const tools: Tool[] = [ + { + name: "calculator", + description: "Evaluates a mathematical expression and returns the numeric result", + fn: async (expr) => String(Function(`"use strict"; return (${expr})`)()), + }, + { + name: "lookup", + description: "Looks up a fact in the knowledge base", + fn: async (query) => `Fact about ${query}: (result from KB)`, + }, +]; + +const agent = new ReAct("question -> answer", tools, /* maxIter= */ 6); +const result = await agent.forward({ question: "What is (123 * 456) + 789?" }); +console.log(result.get("answer")); +console.log(result.get("trajectory")); +``` + +--- + +### 5. Assertions with Retry + +```ts +import { + settings, MockLM, Module, Predict, Retry, Assert, type Prediction, +} from "dstsx"; + +settings.configure({ lm: new MockLM({}, "answer: Paris") }); + +class CapitalQA extends Module { + predict = new Predict("question, feedback? -> answer"); + + async forward(inputs: { question: string }): Promise<Prediction> { + const result = await this.predict.forward(inputs); + Assert( + String(result.get("answer")).trim().length > 0, + "Answer must not be empty" + ); + return result; + } +} + +const retrying = new Retry(new CapitalQA(), 3); +const result = await retrying.forward({ question: "What is the capital of France?" }); +console.log(result.get("answer")); // "Paris" +``` + +--- + +### 6. 
Per-Request LM Override (server environments) + +```ts +import express from "express"; +import { settings, OpenAI, Predict } from "dstsx"; + +const app = express(); +const qa = new Predict("question -> answer"); +const gpt4 = new OpenAI({ model: "gpt-4o" }); +const gptMini = new OpenAI({ model: "gpt-4o-mini" }); + +settings.configure({ lm: gpt4 }); // global default + +app.get("/fast", async (req, res) => { + // Override LM for this request only — concurrent requests never interfere + const result = await settings.context( + { lm: gptMini }, + () => qa.forward({ question: req.query["q"] as string }), + ); + res.json(result.toJSON()); +}); + +app.listen(3000); +``` + +--- + +## V2 Roadmap + +The following features are ✅ **implemented in v2**. See [V2_ROADMAP.md](./V2_ROADMAP.md) for details. + +| Feature | DSPy Equivalent | Status | +|---|---|---| +| **`TypedPredictor`** — JSON-schema + optional Zod validation | `dspy.TypedPredictor`, `dspy.TypedChainOfThought` | ✅ v2 | +| **`Parallel`** module — fan-out / fan-in concurrency | `dspy.Parallel` | ✅ v2 | +| **`Refine`** module — self-critique loop | — | ✅ v2 | +| **`majority()`** helper — vote across Predictions | `dspy.majority` | ✅ v2 | +| **`BootstrapFewShotWithOptuna`** — TPE Bayesian search | `dspy.BootstrapFewShotWithOptuna` | ✅ v2 | +| **Disk-persistent LM cache** — file-based LRU | `dspy.cache` | ✅ v2 | +| **MCP Integration** — `MCPToolAdapter` + `DSTsxMCPServer` | — | ✅ v2 | +| **LM Streaming** — `lm.stream()`, `predict.stream()` | `dspy.streamify` | ✅ v2 | +| **`NativeReAct`** — OpenAI functions / Anthropic tool use | `dspy.Tool` (v2) | ✅ v2 | +| **`Image`** — multi-modal image inputs | `dspy.Image` | ✅ v2 | +| **`BootstrapFinetune`** — JSONL fine-tuning export | `dspy.BootstrapFinetune` | ✅ v2 | +| **`GRPO`** optimizer — group relative policy optimization | `dspy.GRPO` | ✅ v2 | +| **`SIMBA`** optimizer — stochastic mini-batch ascent | `dspy.SIMBA` | ✅ v2 | +| **`AvatarOptimizer`** — role/persona prompt 
optimization | `dspy.AvatarOptimizer` | ✅ v2 | +| **Experiment Tracking** — `ConsoleTracker`, `JsonFileTracker` | `dspy.MLflow` | ✅ v2 | +| **Worker-thread `ProgramOfThought`** — `sandbox: "worker"` | — | ✅ v2 | +| **Typedoc config** — `typedoc.json` + `npm run docs` | — | ✅ v2 | +| **GitHub Actions** — CI + npm publish workflows | — | ✅ v2 | +| Cross-language trace sharing | — | 🔭 Stretch | +| Browser-native bundle (`dstsx/browser`) | — | 🔭 Stretch | +| HTTP module serving (REST endpoint) | — | 🔭 Stretch | + +--- + +## License + +MIT \ No newline at end of file diff --git a/REQUIREMENTS.md b/REQUIREMENTS.md index af33607..698078d 100644 --- a/REQUIREMENTS.md +++ b/REQUIREMENTS.md @@ -284,44 +284,48 @@ Signatures define the **typed interface** (inputs and outputs) for a single LM c | DSPy Symbol | DSTsx Symbol | Status | |---|---|---| -| `dspy.Signature` | `Signature` | Planned | -| `dspy.InputField` | `InputField` | Planned | -| `dspy.OutputField` | `OutputField` | Planned | -| `dspy.Module` | `Module` | Planned | -| `dspy.Predict` | `Predict` | Planned | -| `dspy.ChainOfThought` | `ChainOfThought` | Planned | -| `dspy.ChainOfThoughtWithHint` | `ChainOfThoughtWithHint` | Planned | -| `dspy.MultiChainComparison` | `MultiChainComparison` | Planned | -| `dspy.ReAct` | `ReAct` | Planned | -| `dspy.ProgramOfThought` | `ProgramOfThought` | Planned | -| `dspy.Retrieve` | `Retrieve` | Planned | -| `dspy.Retry` | `Retry` | Planned | -| `dspy.Predict` (n>1) | `Predict` (n>1) | Planned | -| `dspy.Example` | `Example` | Planned | -| `dspy.Prediction` | `Prediction` | Planned | -| `dspy.LM` | `LM` | Planned | -| `dspy.OpenAI` | `OpenAI` | Planned | -| `dspy.Anthropic` | `Anthropic` | Planned | -| `dspy.Cohere` | `Cohere` | Planned | -| `dspy.Google` | `GoogleAI` | Planned | -| `dspy.OllamaLocal` | `Ollama` | Planned | -| `dspy.HFModel` | `HuggingFace` | Planned | -| `dspy.ColBERTv2` | `ColBERTv2` | Planned | -| `dspy.Pinecone` | `PineconeRM` | Planned | -| `dspy.Weaviate` | 
`WeaviateRM` | Planned | -| `dspy.Chromadb` | `ChromadbRM` | Planned | -| `dspy.Qdrant` | `QdrantRM` | Planned | -| `dspy.LabeledFewShot` | `LabeledFewShot` | Planned | -| `dspy.BootstrapFewShot` | `BootstrapFewShot` | Planned | -| `dspy.BootstrapFewShotWithRandomSearch` | `BootstrapFewShotWithRandomSearch` | Planned | -| `dspy.COPRO` | `COPRO` | Planned | -| `dspy.MIPRO` | `MIPRO` | Planned | -| `dspy.KNNFewShot` | `KNNFewShot` | Planned | -| `dspy.Ensemble` | `Ensemble` | Planned | -| `dspy.Evaluate` | `evaluate` | Planned | -| `dspy.Assert` | `Assert` | Planned | -| `dspy.Suggest` | `Suggest` | Planned | -| `dspy.settings` | `settings` | Planned | +| `dspy.Signature` | `Signature` | ✅ Implemented | +| `dspy.InputField` | `InputField` | ✅ Implemented | +| `dspy.OutputField` | `OutputField` | ✅ Implemented | +| `dspy.Module` | `Module` | ✅ Implemented | +| `dspy.Predict` | `Predict` | ✅ Implemented | +| `dspy.ChainOfThought` | `ChainOfThought` | ✅ Implemented | +| `dspy.ChainOfThoughtWithHint` | `ChainOfThoughtWithHint` | ✅ Implemented | +| `dspy.MultiChainComparison` | `MultiChainComparison` | ✅ Implemented | +| `dspy.ReAct` | `ReAct` | ✅ Implemented | +| `dspy.ProgramOfThought` | `ProgramOfThought` | ✅ Implemented | +| `dspy.Retrieve` | `Retrieve` | ✅ Implemented | +| `dspy.Retry` | `Retry` | ✅ Implemented | +| `dspy.Predict` (n>1) | `Predict` (n>1) | ✅ Implemented | +| `dspy.Example` | `Example` | ✅ Implemented | +| `dspy.Prediction` | `Prediction` | ✅ Implemented | +| `dspy.LM` | `LM` | ✅ Implemented | +| `dspy.OpenAI` | `OpenAI` | ✅ Implemented | +| `dspy.Anthropic` | `Anthropic` | ✅ Implemented | +| `dspy.Cohere` | `Cohere` | ✅ Implemented | +| `dspy.Google` | `GoogleAI` | ✅ Implemented | +| `dspy.OllamaLocal` | `Ollama` | ✅ Implemented | +| `dspy.HFModel` | `HuggingFace` | ✅ Implemented | +| `dspy.ColBERTv2` | `ColBERTv2` | ✅ Implemented | +| `dspy.Pinecone` | `PineconeRM` | ✅ Implemented | +| `dspy.Weaviate` | `WeaviateRM` | ✅ Implemented | +| 
`dspy.Chromadb` | `ChromadbRM` | ✅ Implemented | +| `dspy.Qdrant` | `QdrantRM` | ✅ Implemented | +| `dspy.LabeledFewShot` | `LabeledFewShot` | ✅ Implemented | +| `dspy.BootstrapFewShot` | `BootstrapFewShot` | ✅ Implemented | +| `dspy.BootstrapFewShotWithRandomSearch` | `BootstrapFewShotWithRandomSearch` | ✅ Implemented | +| `dspy.COPRO` | `COPRO` | ✅ Implemented | +| `dspy.MIPRO` | `MIPRO` | ✅ Implemented | +| `dspy.KNNFewShot` | `KNNFewShot` | ✅ Implemented | +| `dspy.Ensemble` | `Ensemble` | ✅ Implemented | +| `dspy.Evaluate` | `evaluate` | ✅ Implemented | +| `dspy.Assert` | `Assert` | ✅ Implemented | +| `dspy.Suggest` | `Suggest` | ✅ Implemented | +| `dspy.settings` | `settings` | ✅ Implemented | +| `dspy.BootstrapFewShotWithOptuna` | `BootstrapFewShotWithOptuna` | 🗓 Planned (v2) | +| `dspy.TypedPredictor` | `TypedPredictor` | 🗓 Planned (v2) | +| `dspy.streamify` | `LM.stream` / `Module.stream` | 🗓 Planned (v2) | +| `dspy.Image` | `Image` | 🗓 Planned (v2) | --- @@ -503,47 +507,52 @@ DSTsx/ ## 10. 
Roadmap ### v0.1 — Core Primitives (MVP) -- [ ] `Signature` parsing (string shorthand + class-based) -- [ ] `InputField` / `OutputField` -- [ ] `Example` and `Prediction` primitives -- [ ] `LM` abstract base + `OpenAI` adapter -- [ ] `MockLM` for testing -- [ ] `Predict` module -- [ ] `settings` singleton -- [ ] `Assert` / `Suggest` -- [ ] Test infrastructure (Vitest) +- [x] `Signature` parsing (string shorthand + class-based) +- [x] `InputField` / `OutputField` +- [x] `Example` and `Prediction` primitives +- [x] `LM` abstract base + `OpenAI` adapter +- [x] `MockLM` for testing +- [x] `Predict` module +- [x] `settings` singleton +- [x] `Assert` / `Suggest` +- [x] Test infrastructure (Vitest) ### v0.2 — Reasoning Modules -- [ ] `ChainOfThought` -- [ ] `ChainOfThoughtWithHint` -- [ ] `ReAct` + `Tool` interface -- [ ] `Retry` module -- [ ] `Trace` / `History` +- [x] `ChainOfThought` +- [x] `ChainOfThoughtWithHint` +- [x] `ReAct` + `Tool` interface +- [x] `Retry` module +- [x] `Trace` / `History` ### v0.3 — Retrieval -- [ ] Abstract `Retriever` -- [ ] `ColBERTv2`, `MockRetriever` -- [ ] `Retrieve` module -- [ ] `PineconeRM`, `ChromadbRM`, `QdrantRM` +- [x] Abstract `Retriever` +- [x] `ColBERTv2`, `MockRetriever` +- [x] `Retrieve` module +- [x] `PineconeRM`, `ChromadbRM`, `QdrantRM` ### v0.4 — Optimizers -- [ ] `LabeledFewShot` -- [ ] `BootstrapFewShot` -- [ ] `BootstrapFewShotWithRandomSearch` -- [ ] `evaluate` + built-in metrics +- [x] `LabeledFewShot` +- [x] `BootstrapFewShot` +- [x] `BootstrapFewShotWithRandomSearch` +- [x] `evaluate` + built-in metrics ### v0.5 — Advanced Optimizers -- [ ] `COPRO` -- [ ] `MIPRO` -- [ ] `KNNFewShot` +- [x] `COPRO` +- [x] `MIPRO` +- [x] `KNNFewShot` ### v0.6 — Remaining Adapters & Retrievers -- [ ] `Anthropic`, `Cohere`, `GoogleAI`, `Ollama`, `HuggingFace` -- [ ] `WeaviateRM`, `FaissRM`, `YouRM` -- [ ] `MultiChainComparison`, `ProgramOfThought`, `BestOfN`, `Ensemble` - -### v1.0 — Production Ready -- [ ] 90 %+ test coverage -- [ ] 
Typedoc site -- [ ] Changelog + Semantic Versioning -- [ ] npm publish workflow (GitHub Actions) +- [x] `Anthropic`, `Cohere`, `GoogleAI`, `Ollama`, `HuggingFace` +- [x] `WeaviateRM`, `FaissRM`, `YouRM` +- [x] `MultiChainComparison`, `ProgramOfThought`, `BestOfN`, `Ensemble` + +### v1.0 — Production Ready ✅ +- [x] 160 tests passing across 29 test files (all modules, optimizers, e2e) +- [x] Full JSDoc on every public API +- [x] Comprehensive README with usage docs for all APIs +- [ ] Typedoc site (see V2 roadmap) +- [ ] Changelog + Semantic Versioning (see V2 roadmap) +- [ ] npm publish workflow / GitHub Actions (see V2 roadmap) + +### v2.0 — Next Generation +See [V2_ROADMAP.md](./V2_ROADMAP.md) for the full prioritized list of DSPy features still missing from DSTsx. diff --git a/V2_ROADMAP.md b/V2_ROADMAP.md new file mode 100644 index 0000000..d521270 --- /dev/null +++ b/V2_ROADMAP.md @@ -0,0 +1,406 @@ +# DSTsx v2 Roadmap + +> Features from [DSPy](https://github.com/stanfordnlp/dspy) not yet implemented in DSTsx v1, prioritized for a v2 release. +> Items marked ✅ are implemented and documented in [README.md](./README.md#v2-apis). + +--- + +## Table of Contents + +1. [Design Principles](#design-principles) +2. [Priority Tiers](#priority-tiers) +3. [High Priority](#high-priority) +4. [Medium Priority](#medium-priority) +5. [Low Priority](#low-priority) +6. [Stretch / Experimental](#stretch--experimental) + +--- + +## Design Principles + +All v2 features must: + +- Maintain **100 % backward-compatibility** with the v1 API. +- Be **tree-shakeable** — optional features add zero weight to a minimal import. +- Ship with **full TypeScript types** and JSDoc documentation. +- Include **unit tests** using `MockLM` / `MockRetriever`. 
+ +--- + +## Priority Tiers + +| Tier | Criteria | +|---|---| +| **High** | Blocking real-world production use-cases, or widely used in DSPy | +| **Medium** | Commonly requested; improves developer experience or coverage | +| **Low** | Niche use cases or experimental DSPy features | +| **Stretch** | Research-level features, large effort, uncertain API | + +--- + +## High Priority + +### 1. ✅ `TypedPredictor` & `TypedChainOfThought` + +**DSPy equivalent**: `dspy.TypedPredictor`, `dspy.TypedChainOfThought` +**Status**: ✅ **Implemented** — see [README § TypedPredictor](./README.md#typedpredictor--typedchainofthought) + +Structured JSON output with schema validation, powered by [Zod](https://github.com/colinhacks/zod). + +```ts +import { z } from "zod"; +import { TypedPredictor } from "dstsx"; + +const Answer = z.object({ + answer: z.string(), + confidence: z.number().min(0).max(1), + sources: z.array(z.string()).optional(), +}); + +const qa = new TypedPredictor("question -> answer", Answer); +const result = await qa.forward({ question: "What is 2+2?" }); +// result.typed is Answer (validated at runtime) +console.log(result.typed.confidence); // 0.98 +``` + +--- + +### 2. ✅ LM Streaming + +**DSPy equivalent**: `dspy.streamify` (wraps a program for streaming) +**Status**: ✅ **Implemented** — see [README § LM Streaming](./README.md#lm-streaming) + +Token-level streaming output from all LM adapters. + +**Proposed API:** + +```ts +import { settings, OpenAI, Predict } from "dstsx"; + +settings.configure({ lm: new OpenAI({ model: "gpt-4o", stream: true }) }); + +const qa = new Predict("question -> answer"); +for await (const chunk of qa.stream({ question: "Tell me a story." })) { + process.stdout.write(chunk.delta); +} +const final = await qa.forward({ question: "Tell me a story." }); // still works +``` + +**Scope:** +- Add `LM.stream(prompt, config)` returning `AsyncIterable<StreamChunk>`. +- `Module.stream(...args)` that calls the inner LM in streaming mode. 
+- `StreamChunk` type: `{ delta: string; done: boolean; raw: unknown }`. +- Implement streaming for `OpenAI`, `Anthropic`, `Cohere`, `GoogleAI`, `Ollama`, `LMStudio` adapters. + +--- + +### 3. ✅ Disk-Persistent Response Cache + +**DSPy equivalent**: `dspy.cache` (SQLite-backed LRU) +**Status**: ✅ **Implemented** — see [README § Disk-Persistent LM Cache](./README.md#disk-persistent-lm-cache) + +Persist LM responses across process restarts via file-based JSON cache. + +```ts +const lm = new OpenAI({ model: "gpt-4o", cacheDir: "./.dstsx-cache" }); +``` + +--- + +### 4. ✅ NativeReAct — Native Tool Calling (OpenAI Functions / Anthropic Tool Use) + +**DSPy equivalent**: `dspy.Tool` improvements in DSPy v2 +**Status**: ✅ **Implemented** — see [README § NativeReAct](./README.md#nativereact) + +Use provider-native structured tool calling instead of text-based ReAct parsing. + +**Proposed API:** + +```ts +import { NativeReAct, Tool } from "dstsx"; + +const tools: Tool[] = [ + { + name: "search", + description: "Search the web", + args: { + type: "object", + properties: { query: { type: "string" } }, + required: ["query"], + }, + fn: async ({ query }: { query: string }) => search(query), + }, +]; + +const agent = new NativeReAct("question -> answer", tools); +``` + +**Scope:** +- New `NativeReAct` module (or `ReAct` option `{ mode: "native" }`). +- `LM._callWithTools(messages, tools)` abstract method on adapters that support it. +- Implement in `OpenAI` (function calling) and `Anthropic` (tool use) adapters. +- Graceful fallback to text-based ReAct for adapters without native tool support. + +--- + +### 5. ✅ Typedoc API Documentation Site + +**Status**: ✅ **Implemented** — `typedoc.json` added; run `npm run docs` + +Auto-generated API reference site published to GitHub Pages. + +**Scope:** +- Add `typedoc.json` configuration. +- Add `"docs": "typedoc"` npm script. +- GitHub Actions workflow to publish to `gh-pages` branch on every release. 
+- All existing JSDoc comments already serve as source material — minimal effort required. + +--- + +### 6. ✅ npm Publish Workflow (GitHub Actions CD) + +**Status**: ✅ **Implemented** — `.github/workflows/ci.yml` + `publish.yml` + +Automate package publishing on version bumps. + +**Scope:** +- GitHub Actions workflow triggered on `release` events. +- Builds with `tsup`, runs tests, then `npm publish`. +- Use [Changesets](https://github.com/changesets/changesets) for changelog generation. +- Bump `package.json` version from `0.1.0` → `1.0.0`. + +--- + +### 7. ✅ MCP Integration + +**Status**: ✅ **Implemented** — see [README § MCP Integration](./README.md#mcp-integration) + +- `MCPToolAdapter` — wrap MCP server tools as DSTsx `Tool` objects for `ReAct` +- `DSTsxMCPServer` — expose DSTsx modules as MCP tool definitions + +#### Live MCP Connection + +The current implementation supports test-mode (pre-loaded tools + callHandler). +A full live connection via SSE/stdio using `@modelcontextprotocol/sdk` is **planned**: + +```ts +// Future: connect to a live MCP server +const adapter = new MCPToolAdapter({ + serverUrl: "http://localhost:3000/sse", +}); +const tools = await adapter.getTools(); // fetches tool list from server +``` + +**Scope:** +- `MCPToolAdapter` live SSE transport using `@modelcontextprotocol/sdk` +- `DSTsxMCPServer.createStdioServer()` full stdio transport implementation + +--- + +## Medium Priority + +### 8. ✅ `BootstrapFewShotWithOptuna` + +**DSPy equivalent**: `dspy.BootstrapFewShotWithOptuna` +**Status**: ✅ **Implemented** — see [README § BootstrapFewShotWithOptuna](./README.md#bootstrapfewshotwithoptuna) + +Bayesian optimization (TPE sampler) for demo subset selection, using a built-in +pure-TypeScript TPE implementation (no external deps). + +--- + +### 9. 
✅ `majority()` Helper + +**DSPy equivalent**: `dspy.majority` +**Status**: ✅ **Implemented** — see [README § majority() Helper](./README.md#majority-helper) + +Majority-vote aggregation across multiple completions. + +--- + +### 10. ✅ `Parallel` Module + +**DSPy equivalent**: `dspy.Parallel` +**Status**: ✅ **Implemented** — see [README § Parallel Module](./README.md#parallel-module) + +Fan-out/fan-in: run multiple different modules concurrently and collect all their outputs. + +--- + +### 11. ✅ Multi-modal Support (`dspy.Image`) + +**DSPy equivalent**: `dspy.Image` +**Status**: ✅ **Implemented** — see [README § Image](./README.md#image--multi-modal-support) + +Pass images (and other media) as inputs to vision-capable LMs. + +**Proposed API:** + +```ts +import { Predict, Image } from "dstsx"; + +const captioner = new Predict("image, question -> caption"); +const result = await captioner.forward({ + image: Image.fromURL("https://example.com/photo.jpg"), + question: "What is in this image?", +}); +``` + +**Scope:** +- `Image` primitive: `fromURL(url)`, `fromBase64(data, mimeType)`, `fromFile(path)`. +- Extend `LMCallConfig` to accept `Image` values in message content. +- Implement in `OpenAI` (GPT-4V), `Anthropic` (Claude 3 vision), `GoogleAI` (Gemini Vision). + +--- + +### 12. ✅ `Refine` Module + +**Status**: ✅ **Implemented** — see [README § Refine Module](./README.md#refine-module) + +Self-critique / iterative refinement loop. + +--- + +### 13. ✅ Worker-Thread Sandbox for `ProgramOfThought` + +**Status**: ✅ **Implemented** — see [README § Worker-Thread ProgramOfThought](./README.md#worker-thread-programofthought) + +Replace the current `new Function()` executor with a proper Node.js `Worker` thread. + +**Scope:** +- Uses the built-in `node:worker_threads` module (available in Node 18+), so no additional dependency is required. +- Expose `sandbox?: "worker" | "function" | "none"` option on `ProgramOfThought`. +- Default to `"function"` for backward compatibility; document security tradeoffs. 
+ +--- + +## Low Priority + +### 14. ✅ `BootstrapFinetune` + +**DSPy equivalent**: `dspy.BootstrapFinetune` +**Status**: ✅ **Implemented** — see [README § BootstrapFinetune](./README.md#bootstrapfinetune) + +Collect LM traces and export them in fine-tuning format (JSONL) for providers that support it. + +**Proposed API:** + +```ts +import { BootstrapFinetune } from "dstsx"; + +const optimizer = new BootstrapFinetune({ + exportPath: "./finetune_data.jsonl", + format: "openai", +}); + +const recipe = await optimizer.compile(program, trainset, metric); +``` + +--- + +### 15. ✅ `GRPO` Optimizer + +**DSPy equivalent**: `dspy.GRPO` (Group Relative Policy Optimization) +**Status**: ✅ **Implemented** — see [README § GRPO Optimizer](./README.md#grpo-optimizer) + +Gradient-style prompt search using reinforcement-learning-inspired reward signals. + +```ts +import { GRPO } from "dstsx"; + +const optimizer = new GRPO({ + numSteps: 50, + groupSize: 8, + temperature: 1.0, +}); + +const optimized = await optimizer.compile(program, trainset, metric); +``` + +--- + +### 16. ✅ `SIMBA` Optimizer + +**DSPy equivalent**: `dspy.SIMBA` (Stochastic Introspective Mini-Batch Ascent) +**Status**: ✅ **Implemented** — see [README § SIMBA Optimizer](./README.md#simba-optimizer) + +A lightweight stochastic search optimizer well-suited for small training sets. + +--- + +### 17. ✅ `AvatarOptimizer` + +**DSPy equivalent**: `dspy.AvatarOptimizer` +**Status**: ✅ **Implemented** — see [README § AvatarOptimizer](./README.md#avataroptimizer) + +Iteratively proposes and evaluates "avatar" personas (role descriptions) for each `Predict` module. + +--- + +### 18. ✅ Experiment Tracking Integration + +**DSPy equivalent**: `dspy.MLflow` +**Status**: ✅ **Implemented** — see [README § Experiment Tracking](./README.md#experiment-tracking) + +Log optimizer runs, metric scores, and demo sets to MLflow or Weights & Biases. 
+ +```ts +import { BootstrapFewShot, MLflowTracker } from "dstsx"; + +const optimizer = new BootstrapFewShot({ + maxBootstrappedDemos: 4, + tracker: new MLflowTracker({ experiment: "qa-optimization" }), +}); +``` + +--- + +## Stretch / Experimental + +### `Refine` / Gradient-Based Refinement (follow-up to #12) + +**Status**: ✅ **Implemented** (self-critique version) + +The v2 `Refine` module implements iterative self-critique. Future: Constitutional AI-style critique, multi-model critique pipelines. + +### 19. HTTP Serving + +Serialize and serve an optimized program as a REST/gRPC endpoint. + +### 20. Cross-Language Trace Sharing + +Export / import traces in a format compatible with the Python DSPy library. + +### 21. Browser-Native Bundle + +A `dstsx/browser` entry-point that strips all `node:` built-ins. + +--- + +## Summary Table + +| # | Feature | DSPy Symbol | Priority | Status | +|---|---|---|---|---| +| 1 | TypedPredictor / TypedChainOfThought | `TypedPredictor` | High | ✅ v2 | +| 2 | LM Streaming | `streamify` | High | ✅ v2 | +| 3 | Disk-Persistent Cache | `dspy.cache` | High | ✅ v2 | +| 4 | NativeReAct — Native Tool Calling | `Tool` (v2) | High | ✅ v2 | +| 5 | Typedoc Site | — | High | ✅ v2 | +| 6 | npm Publish Workflow | — | High | ✅ v2 | +| 7 | MCP Integration | — | High | ✅ v2 | +| 8 | BootstrapFewShotWithOptuna | `BootstrapFewShotWithOptuna` | Medium | ✅ v2 | +| 9 | Majority Helper | `majority` | Medium | ✅ v2 | +| 10 | Parallel Module | `Parallel` | Medium | ✅ v2 | +| 11 | Multi-modal (Image) | `dspy.Image` | Medium | ✅ v2 | +| 12 | Refine Module | — | Medium | ✅ v2 | +| 13 | Worker-Thread ProgramOfThought | — | Medium | ✅ v2 | +| 14 | BootstrapFinetune | `BootstrapFinetune` | Low | ✅ v2 | +| 15 | GRPO Optimizer | `GRPO` | Low | ✅ v2 | +| 16 | SIMBA Optimizer | `SIMBA` | Low | ✅ v2 | +| 17 | AvatarOptimizer | `AvatarOptimizer` | Low | ✅ v2 | +| 18 | Experiment Tracking | `MLflow` | Low | ✅ v2 | +| 19 | HTTP Serving | — | Stretch | 🗓 Planned | +| 20 | Cross-Language 
Trace Sharing | — | Stretch | 🗓 Planned | +| 21 | Browser-Native Bundle | — | Stretch | 🗓 Planned | + +--- diff --git a/src/index.ts b/src/index.ts index 45c5a1c..07b0446 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,3 +10,5 @@ export * from "./optimizers/index.js"; export * from "./evaluate/index.js"; export * from "./assertions/index.js"; export * from "./settings/index.js"; +export * from "./mcp/index.js"; +export * from "./tracking/index.js"; diff --git a/src/lm/DiskCache.ts b/src/lm/DiskCache.ts new file mode 100644 index 0000000..32435a3 --- /dev/null +++ b/src/lm/DiskCache.ts @@ -0,0 +1,114 @@ +import { createHash } from "node:crypto"; +import { + readFileSync, + writeFileSync, + readdirSync, + unlinkSync, + mkdirSync, + statSync, +} from "node:fs"; +import { join } from "node:path"; +import type { LMResponse } from "./types.js"; + +interface CacheEntry { + key: string; + value: LMResponse; + expiresAt: number | null; +} + +/** + * Disk-persistent JSON cache for LM responses. + * + * Cache entries are stored as individual JSON files named by a truncated + * SHA-256 hash of the key. Supports optional TTL and LRU eviction. + */ +export class DiskCache { + readonly #cacheDir: string; + readonly #maxSize: number; + readonly #ttlMs: number | undefined; + + constructor(cacheDir: string, maxSize = 500, ttlMs?: number) { + this.#cacheDir = cacheDir; + this.#maxSize = maxSize; + this.#ttlMs = ttlMs; + mkdirSync(cacheDir, { recursive: true }); + } + + get(key: string): LMResponse | undefined { + const path = this.#pathFor(key); + try { + const raw = readFileSync(path, "utf8"); + const entry = JSON.parse(raw) as CacheEntry; + if (entry.expiresAt !== null && Date.now() > entry.expiresAt) { + unlinkSync(path); + return undefined; + } + return entry.value; + } catch { + return undefined; + } + } + + set(key: string, value: LMResponse): void { + this.#evictIfNeeded(); + const entry: CacheEntry = { + key, + value, + expiresAt: this.#ttlMs != null ? 
Date.now() + this.#ttlMs : null, + }; + writeFileSync(this.#pathFor(key), JSON.stringify(entry), "utf8"); + } + + clear(): void { + try { + for (const file of readdirSync(this.#cacheDir)) { + if (file.endsWith(".json")) { + try { + unlinkSync(join(this.#cacheDir, file)); + } catch { + // ignore + } + } + } + } catch { + // ignore + } + } + + #pathFor(key: string): string { + // 16 hex chars = 64 bits of the SHA-256 digest. The collision probability + // over 10 M distinct prompts is ~2.7e-9, acceptable for a local LM cache. + const hash = createHash("sha256").update(key).digest("hex").slice(0, 16); + return join(this.#cacheDir, `${hash}.json`); + } + + #evictIfNeeded(): void { + let files: Array<{ name: string; mtime: number }>; + try { + files = readdirSync(this.#cacheDir) + .filter((f) => f.endsWith(".json")) + .map((f) => { + const p = join(this.#cacheDir, f); + try { + return { name: f, mtime: statSync(p).mtimeMs }; + } catch { + return { name: f, mtime: 0 }; + } + }); + } catch { + return; + } + + if (files.length < this.#maxSize) return; + + const sorted = files.sort((a, b) => a.mtime - b.mtime); + const toDelete = sorted.slice(0, files.length - this.#maxSize + 1); + for (const f of toDelete) { + try { + unlinkSync(join(this.#cacheDir, f.name)); + } catch { + // ignore + } + } + } +} diff --git a/src/lm/LM.ts b/src/lm/LM.ts index 76ad6df..4710cf0 100644 --- a/src/lm/LM.ts +++ b/src/lm/LM.ts @@ -1,5 +1,6 @@ import { LRUCache } from "./cache.js"; -import type { LMCallConfig, LMResponse, Message } from "./types.js"; +import { DiskCache } from "./DiskCache.js"; +import type { LMCallConfig, LMResponse, Message, StreamChunk } from "./types.js"; /** * Abstract base class for all language model adapters. 
@@ -17,12 +18,20 @@ export abstract class LM { readonly model: string; #cache: LRUCache; + #diskCache: DiskCache | undefined; #requestCount = 0; #tokenUsage = { promptTokens: 0, completionTokens: 0, totalTokens: 0 }; - constructor(model: string, cacheOptions: { maxSize?: number; ttlMs?: number } = {}) { + constructor( + model: string, + cacheOptions: { maxSize?: number; ttlMs?: number; cacheDir?: string } = {}, + ) { this.model = model; this.#cache = new LRUCache(cacheOptions.maxSize, cacheOptions.ttlMs); + this.#diskCache = + cacheOptions.cacheDir !== undefined + ? new DiskCache(cacheOptions.cacheDir, cacheOptions.maxSize, cacheOptions.ttlMs) + : undefined; } // --------------------------------------------------------------------------- @@ -41,8 +50,23 @@ export abstract class LM { const cached = this.#cache.get(cacheKey); if (cached) return cached; + // Check disk cache (second-level) + if (this.#diskCache) { + const diskCached = this.#diskCache.get(cacheKey); + if (diskCached) { + this.#cache.set(cacheKey, diskCached); + return diskCached; + } + } + const response = await this._call(prompt, config); this.#cache.set(cacheKey, response); + + // Persist to disk cache + if (this.#diskCache) { + this.#diskCache.set(cacheKey, response); + } + this.#requestCount += 1; if (response.usage) { this.#tokenUsage.promptTokens += response.usage.promptTokens; @@ -67,6 +91,22 @@ export abstract class LM { this.#cache.clear(); } + /** + * Stream the language model response token by token. + * + * Returns an `AsyncIterable`. The last chunk has `done: true`. + * Subclasses override this to provide real streaming; the base implementation + * falls back to calling {@link LM.call} and yielding the full response as a + * single chunk. 
+ */ + async *stream( + prompt: string | Message[], + config: LMCallConfig = {}, + ): AsyncGenerator { + const response = await this.call(prompt, config); + yield { delta: response.text, done: true, raw: response.raw }; + } + // --------------------------------------------------------------------------- // Abstract interface for subclasses // --------------------------------------------------------------------------- diff --git a/src/lm/adapters/Anthropic.ts b/src/lm/adapters/Anthropic.ts index bc30cbe..d710284 100644 --- a/src/lm/adapters/Anthropic.ts +++ b/src/lm/adapters/Anthropic.ts @@ -74,4 +74,46 @@ export class Anthropic extends LM { raw: response, }; } + + override async *stream( + prompt: string | Message[], + config: LMCallConfig = {}, + ): AsyncGenerator { + const { default: Anthropic } = await import("@anthropic-ai/sdk").catch(() => { + throw new Error( + "The `@anthropic-ai/sdk` package is required for the Anthropic adapter.\n" + + "Install it with: npm install @anthropic-ai/sdk", + ); + }); + + const client = new Anthropic({ + apiKey: this.#options.apiKey ?? process.env["ANTHROPIC_API_KEY"], + maxRetries: this.#options.maxRetries ?? 3, + }); + + const msgs: Message[] = + typeof prompt === "string" ? [{ role: "user", content: prompt }] : prompt; + + const systemMsg = msgs.find((m) => m.role === "system"); + const userMsgs = msgs.filter((m) => m.role !== "system"); + + const stream = client.messages.stream({ + model: config.model ?? this.model, + max_tokens: config.maxTokens ?? 1024, + system: systemMsg?.content, + messages: userMsgs.map((m) => ({ role: m.role as "user" | "assistant", content: m.content })), + ...(config.extra ?? {}), + }); + + for await (const event of stream) { + type AnthropicStreamEvent = { type?: string; delta?: { type?: string; text?: string } }; + const e = event as AnthropicStreamEvent; + if (e.type === "content_block_delta" && e.delta?.type === "text_delta") { + yield { delta: e.delta.text ?? 
"", done: false, raw: event }; + } else if (e.type === "message_stop") { + yield { delta: "", done: true, raw: event }; + break; + } + } + } } diff --git a/src/lm/adapters/OpenAI.ts b/src/lm/adapters/OpenAI.ts index e088f3a..d49b2a6 100644 --- a/src/lm/adapters/OpenAI.ts +++ b/src/lm/adapters/OpenAI.ts @@ -8,6 +8,7 @@ export interface OpenAIOptions { /** Default model, can be overridden per-call. */ model?: string; maxRetries?: number; + stream?: boolean; } /** @@ -77,4 +78,45 @@ export class OpenAI extends LM { raw: response, }; } + + override async *stream( + prompt: string | Message[], + config: LMCallConfig = {}, + ): AsyncGenerator { + const { default: OpenAIClient } = await import("openai").catch(() => { + throw new Error( + "The `openai` package is required for the OpenAI adapter.\n" + + "Install it with: npm install openai", + ); + }); + + const client = new OpenAIClient({ + apiKey: this.#options.apiKey ?? process.env["OPENAI_API_KEY"], + baseURL: this.#options.baseURL, + maxRetries: this.#options.maxRetries ?? 3, + }); + + const messages: Message[] = + typeof prompt === "string" ? [{ role: "user", content: prompt }] : prompt; + + const stream = await client.chat.completions.create({ + model: config.model ?? this.model, + messages, + temperature: config.temperature, + max_tokens: config.maxTokens, + stop: config.stop, + stream: true as const, + ...(config.extra ?? {}), + }); + + for await (const chunk of stream) { + type StreamChoice = { delta?: { content?: string | null }; finish_reason?: string | null }; + type StreamResponse = { choices?: StreamChoice[] }; + const c = (chunk as StreamResponse).choices?.[0]; + const delta = c?.delta?.content ?? 
""; + const done = c?.finish_reason != null; + yield { delta, done, raw: chunk }; + if (done) break; + } + } } diff --git a/src/lm/index.ts b/src/lm/index.ts index 85a549e..5f2c60a 100644 --- a/src/lm/index.ts +++ b/src/lm/index.ts @@ -1,4 +1,5 @@ export { LM } from "./LM.js"; export { LRUCache } from "./cache.js"; +export { DiskCache } from "./DiskCache.js"; export * from "./adapters/index.js"; -export type { LMCallConfig, LMResponse, Message } from "./types.js"; +export type { LMCallConfig, LMResponse, Message, StreamChunk } from "./types.js"; diff --git a/src/lm/types.ts b/src/lm/types.ts index 7db5af4..8d0349b 100644 --- a/src/lm/types.ts +++ b/src/lm/types.ts @@ -40,3 +40,13 @@ export interface LMResponse { /** Raw provider response (opaque). */ raw: unknown; } + +/** A single chunk emitted during token streaming. */ +export interface StreamChunk { + /** The incremental text delta for this chunk. */ + delta: string; + /** True on the final chunk. */ + done: boolean; + /** Raw provider chunk (opaque). */ + raw: unknown; +} diff --git a/src/mcp/DSTsxMCPServer.ts b/src/mcp/DSTsxMCPServer.ts new file mode 100644 index 0000000..8a176ae --- /dev/null +++ b/src/mcp/DSTsxMCPServer.ts @@ -0,0 +1,69 @@ +import type { Module } from "../modules/index.js"; +import { Prediction } from "../primitives/index.js"; + +export interface MCPTool { + name: string; + description: string; + inputSchema: { + type: "object"; + properties: Record; + required?: string[]; + }; + handler: (inputs: Record) => Promise; +} + +/** + * Exposes DSTsx modules as MCP-compatible tool definitions. + * + * Registered modules can be called via `callTool()`. To create a live stdio + * server the `@modelcontextprotocol/sdk` package must be installed. 
+ */ +export class DSTsxMCPServer { + readonly #tools: Map = new Map(); + + registerModule( + name: string, + description: string, + module: Module, + inputFields: string[], + ): this { + const properties: Record = {}; + for (const field of inputFields) { + properties[field] = { type: "string" }; + } + this.#tools.set(name, { + name, + description, + inputSchema: { type: "object", properties, required: inputFields }, + handler: async (inputs) => { + const result = await ( + module.forward as (i: Record) => Promise + )(inputs); + return result.toJSON(); + }, + }); + return this; + } + + getToolDefinitions(): MCPTool[] { + return [...this.#tools.values()]; + } + + async callTool(name: string, inputs: Record): Promise { + const tool = this.#tools.get(name); + if (!tool) throw new Error(`Tool "${name}" not found.`); + return tool.handler(inputs); + } + + async createStdioServer(): Promise { + await import("@modelcontextprotocol/sdk/server/index.js").catch(() => { + throw new Error( + "The `@modelcontextprotocol/sdk` package is required.\n" + + "Install it with: npm install @modelcontextprotocol/sdk", + ); + }); + throw new Error( + "createStdioServer requires @modelcontextprotocol/sdk to be installed.", + ); + } +} diff --git a/src/mcp/MCPAdapter.ts b/src/mcp/MCPAdapter.ts new file mode 100644 index 0000000..a9c049f --- /dev/null +++ b/src/mcp/MCPAdapter.ts @@ -0,0 +1,61 @@ +import type { Tool } from "../modules/ReAct.js"; + +export interface MCPAdapterOptions { + serverUrl?: string; + tools?: Array<{ + name: string; + description: string; + inputSchema: Record; + }>; + callHandler?: (name: string, args: Record) => Promise; +} + +/** + * Wraps MCP server tools as DSTsx Tool objects. + * + * When `tools` + `callHandler` are provided, no network connection is needed. + * A live MCP connection (via `serverUrl`) requires the + * `@modelcontextprotocol/sdk` package to be installed. 
+ */ +export class MCPToolAdapter { + readonly #options: MCPAdapterOptions; + #tools: Tool[] | undefined; + + constructor(options: MCPAdapterOptions = {}) { + this.#options = options; + } + + async getTools(): Promise { + if (this.#tools) return this.#tools; + + if (this.#options.tools && this.#options.callHandler) { + const callHandler = this.#options.callHandler; + this.#tools = this.#options.tools.map((t) => ({ + name: t.name, + description: t.description, + fn: async (args: string) => { + let parsed: Record; + try { + parsed = JSON.parse(args) as Record; + } catch { + parsed = { input: args }; + } + const result = await callHandler(t.name, parsed); + return typeof result === "string" ? result : JSON.stringify(result); + }, + })); + return this.#tools; + } + + await import("@modelcontextprotocol/sdk/client/index.js").catch(() => { + throw new Error( + "The `@modelcontextprotocol/sdk` package is required for MCPToolAdapter.\n" + + "Install it with: npm install @modelcontextprotocol/sdk", + ); + }); + + throw new Error( + "Live MCP connection not yet implemented. 
Use tools+callHandler for now.", + ); + } +} diff --git a/src/mcp/index.ts b/src/mcp/index.ts new file mode 100644 index 0000000..2ac93ce --- /dev/null +++ b/src/mcp/index.ts @@ -0,0 +1,4 @@ +export { MCPToolAdapter } from "./MCPAdapter.js"; +export type { MCPAdapterOptions } from "./MCPAdapter.js"; +export { DSTsxMCPServer } from "./DSTsxMCPServer.js"; +export type { MCPTool } from "./DSTsxMCPServer.js"; diff --git a/src/modules/NativeReAct.ts b/src/modules/NativeReAct.ts new file mode 100644 index 0000000..a8f02bc --- /dev/null +++ b/src/modules/NativeReAct.ts @@ -0,0 +1,113 @@ +import { Module } from "./Module.js"; +import { Prediction } from "../primitives/index.js"; +import { Signature } from "../signatures/index.js"; +import type { Tool } from "./ReAct.js"; +import { settings } from "../settings/index.js"; +import type { Message } from "../lm/types.js"; + +/** + * ReAct variant that uses provider-native tool/function calling instead of + * text-based action parsing. + * + * For OpenAI models, this uses function calling (tools API). For Anthropic, it + * uses tool_use. Other adapters fall back to the text-based ReAct format. + * + * @example + * ```ts + * const tools: Tool[] = [{ name: "search", description: "Search", fn: search }]; + * const agent = new NativeReAct("question -> answer", tools); + * const result = await agent.forward({ question: "What is the capital of France?" }); + * ``` + */ +export class NativeReAct extends Module { + readonly tools: ReadonlyMap; + readonly maxIter: number; + readonly #signatureStr: string; + readonly #outputKey: string; + + constructor( + signatureStr: string, + tools: Tool[], + maxIter = 5, + ) { + super(); + this.#signatureStr = signatureStr; + this.tools = new Map(tools.map((t) => [t.name, t])); + this.maxIter = maxIter; + // Parse the output field name from the signature string + const sig = Signature.from(signatureStr); + this.#outputKey = [...sig.outputs.keys()][0] ?? 
"answer"; + } + + override async forward(inputs: Record): Promise { + const lm = settings.lm; + if (!lm) throw new Error("No LM configured."); + + const toolSchemas = [...this.tools.values()].map((t) => ({ + type: "function" as const, + function: { + name: t.name, + description: t.description, + parameters: { + type: "object", + properties: { args: { type: "string", description: "Tool arguments as JSON or plain string" } }, + required: ["args"], + }, + }, + })); + + const inputStr = Object.entries(inputs) + .map(([k, v]) => `${k}: ${String(v)}`) + .join("\n"); + + const messages: Message[] = [ + { + role: "system", + content: `You are a helpful assistant. Use tools when needed.\nSignature: ${this.#signatureStr}\nTools: ${[...this.tools.keys()].join(", ")}`, + }, + { role: "user", content: inputStr }, + ]; + + let finalAnswer = ""; + const trajectory: Array<{ thought: string; action: string; observation: string }> = []; + + for (let i = 0; i < this.maxIter; i++) { + const response = await lm.call(messages, { + extra: { tools: toolSchemas, tool_choice: "auto" }, + }); + + const raw = response.raw as Record | null; + const choices = (raw?.["choices"] as Array>) ?? []; + const choice = choices[0]; + const toolCalls = ( + choice?.["message"] as Record | undefined + )?.["tool_calls"] as + | Array<{ function: { name: string; arguments: string } }> + | undefined; + + if (toolCalls && toolCalls.length > 0) { + for (const tc of toolCalls) { + const toolName = tc.function.name; + const args = tc.function.arguments; + const tool = this.tools.get(toolName); + const observation = tool + ? 
await tool.fn(args).catch((e: unknown) => String(e)) + : `Unknown tool: ${toolName}`; + + trajectory.push({ + thought: `Using tool: ${toolName}`, + action: `${toolName}(${args})`, + observation, + }); + messages.push({ role: "assistant", content: `Tool: ${toolName}\nArgs: ${args}` }); + messages.push({ role: "user", content: `Observation: ${observation}` }); + } + } else { + finalAnswer = response.text; + break; + } + } + + return new Prediction({ [this.#outputKey]: finalAnswer, trajectory: JSON.stringify(trajectory) }); + } +} diff --git a/src/modules/Parallel.ts b/src/modules/Parallel.ts new file mode 100644 index 0000000..902f2ad --- /dev/null +++ b/src/modules/Parallel.ts @@ -0,0 +1,47 @@ +import { Module } from "./Module.js"; +import type { Prediction } from "../primitives/index.js"; + +/** + * Runs multiple modules in parallel and returns all their results. + * + * Note: `forward()` returns the first prediction for Module interface + * compatibility. Use `run()` to get all predictions. + */ +export class Parallel extends Module { + readonly #modules: Module[]; + readonly #timeoutMs: number | undefined; + + constructor(modules: Module[], options: { timeoutMs?: number } = {}) { + super(); + this.#modules = modules; + this.#timeoutMs = options.timeoutMs; + } + + /** Run all modules in parallel and return all predictions. */ + async run(...args: unknown[]): Promise { + const tasks = this.#modules.map((m) => + (m.forward as (...a: unknown[]) => Promise)(...args), + ); + + if (this.#timeoutMs !== undefined) { + const timeoutMs = this.#timeoutMs; + const withTimeout = tasks.map((t) => + Promise.race([ + t, + new Promise((_, reject) => + setTimeout(() => reject(new Error("Parallel: timeout")), timeoutMs), + ), + ]), + ); + return Promise.all(withTimeout); + } + + return Promise.all(tasks); + } + + /** For Module interface compatibility — returns first prediction. 
*/ + override async forward(...args: unknown[]): Promise { + const results = await this.run(...args); + return results[0]!; + } +} diff --git a/src/modules/Predict.ts b/src/modules/Predict.ts index 857910e..aada256 100644 --- a/src/modules/Predict.ts +++ b/src/modules/Predict.ts @@ -55,6 +55,18 @@ export class Predict extends Module { return new Prediction(outputs, completions); } + /** + * Stream the LM response token by token. + * Returns an `AsyncGenerator`. + */ + async *stream(inputs: Record): AsyncGenerator { + const lm = settings.lm; + if (!lm) throw new Error("No LM configured. Call settings.configure({ lm }) before using Predict."); + const prompt = this.#buildPrompt(inputs); + const config = settings.lmConfig ?? {}; + yield* lm.stream(prompt, config); + } + // --------------------------------------------------------------------------- // Serialization // --------------------------------------------------------------------------- diff --git a/src/modules/ProgramOfThought.ts b/src/modules/ProgramOfThought.ts index 6b66068..0967981 100644 --- a/src/modules/ProgramOfThought.ts +++ b/src/modules/ProgramOfThought.ts @@ -20,14 +20,16 @@ export class ProgramOfThought extends Module { readonly maxAttempts: number; /** Wall-clock timeout (ms) for each code execution attempt. */ readonly timeoutMs: number; + readonly sandbox: "worker" | "function" | "none"; readonly #codeGenerator: Predict; readonly #corrector: Predict; readonly #outputKey: string; - constructor(signature: string | Signature, maxAttempts = 3, timeoutMs = 5_000) { + constructor(signature: string | Signature, maxAttempts = 3, timeoutMs = 5_000, sandbox: "worker" | "function" | "none" = "function") { super(); this.maxAttempts = maxAttempts; this.timeoutMs = timeoutMs; + this.sandbox = sandbox; const base = typeof signature === "string" ? Signature.from(signature) : signature; @@ -84,9 +86,18 @@ export class ProgramOfThought extends Module { code = String(generated.get("code") ?? 
generated.get("fixed_code") ?? ""); try { - // eslint-disable-next-line @typescript-eslint/no-implied-eval - const fn = new Function(`return (async () => { ${code} })()`) as () => Promise; - result = await this.#executeWithTimeout(fn(), this.timeoutMs); + if (this.sandbox === "worker") { + result = await this.#executeInWorker(code, this.timeoutMs); + } else if (this.sandbox === "none") { + // eslint-disable-next-line @typescript-eslint/no-implied-eval + const fn = new Function(`return (async () => { ${code} })()`) as () => Promise; + result = await fn(); + } else { + // "function" — default, with timeout + // eslint-disable-next-line @typescript-eslint/no-implied-eval + const fn = new Function(`return (async () => { ${code} })()`) as () => Promise; + result = await this.#executeWithTimeout(fn(), this.timeoutMs); + } break; } catch (err) { lastError = err instanceof Error ? err.message : String(err); @@ -100,6 +111,44 @@ export class ProgramOfThought extends Module { }); } + async #executeInWorker(code: string, timeoutMs: number): Promise { + const { Worker } = await import("node:worker_threads"); + const WORKER_CODE = ` +const { workerData, parentPort } = require('node:worker_threads'); +const { code } = workerData; +(async () => { + try { + const fn = new Function('return (async () => { ' + code + ' })()'); + const result = await fn(); + parentPort.postMessage({ result: String(result ?? '') }); + } catch (err) { + parentPort.postMessage({ error: err.message ?? String(err) }); + } +})(); +`; + + return new Promise((resolve, reject) => { + const worker = new Worker(WORKER_CODE, { + eval: true, + workerData: { code }, + }); + const timer = setTimeout(() => { + void worker.terminate(); + reject(new Error("ProgramOfThought: worker execution timed out")); + }, timeoutMs); + worker.on("message", (msg: { result?: string; error?: string }) => { + clearTimeout(timer); + void worker.terminate(); + if (msg.error) reject(new Error(msg.error)); + else resolve(msg.result ?? 
""); + }); + worker.on("error", (err: Error) => { + clearTimeout(timer); + reject(err); + }); + }); + } + /** * Race `promise` against a wall-clock timer. * The underlying async work is not cancelled on timeout (no true abort), but @@ -129,3 +178,4 @@ export class ProgramOfThought extends Module { }); } } + diff --git a/src/modules/Refine.ts b/src/modules/Refine.ts new file mode 100644 index 0000000..dc5fcfa --- /dev/null +++ b/src/modules/Refine.ts @@ -0,0 +1,82 @@ +import { Module } from "./Module.js"; +import { Predict } from "./Predict.js"; +import { Prediction } from "../primitives/index.js"; + +export interface RefineOptions { + /** Maximum refinement iterations (default: 2). */ + maxRefinements?: number; + /** Field name for feedback in the inner module re-run (default: "feedback"). */ + feedbackField?: string; + /** If returns true, stop refining early. */ + stopCondition?: (prediction: Prediction) => boolean; +} + +/** + * Self-critique / iterative refinement loop. + * + * Runs the inner module, then uses a Predict critic to score the output. + * If the output is not satisfactory, feeds critique back and re-runs. + */ +export class Refine extends Module { + readonly #inner: Module; + readonly #maxRefinements: number; + readonly #feedbackField: string; + readonly #stopCondition: ((p: Prediction) => boolean) | undefined; + readonly #critic: Predict; + + constructor(inner: Module, options: RefineOptions = {}) { + super(); + this.#inner = inner; + this.#maxRefinements = options.maxRefinements ?? 2; + this.#feedbackField = options.feedbackField ?? 
"feedback"; + this.#stopCondition = options.stopCondition; + this.#critic = new Predict("output -> critique, is_satisfactory"); + } + + override async forward(...args: unknown[]): Promise { + const innerForward = this.#inner.forward.bind(this.#inner) as ( + ...a: unknown[] + ) => Promise; + + let prediction = await innerForward(...args); + + for (let i = 0; i < this.#maxRefinements; i++) { + if (this.#stopCondition?.(prediction)) break; + + const outputStr = JSON.stringify(prediction.toDict()); + let critique: Prediction; + try { + critique = await this.#critic.forward({ output: outputStr }); + } catch { + break; + } + + const isSatisfactory = String( + critique.get("is_satisfactory") ?? "", + ) + .toLowerCase() + .trim(); + if (isSatisfactory === "yes" || isSatisfactory === "true") break; + + const feedback = String(critique.get("critique") ?? ""); + const newArgs = [...args]; + if ( + newArgs.length > 0 && + typeof newArgs[0] === "object" && + newArgs[0] !== null + ) { + newArgs[0] = { + ...(newArgs[0] as Record), + [this.#feedbackField]: feedback, + }; + } + try { + prediction = await innerForward(...newArgs); + } catch { + break; + } + } + + return prediction; + } +} diff --git a/src/modules/TypedPredictor.ts b/src/modules/TypedPredictor.ts new file mode 100644 index 0000000..d19f3ca --- /dev/null +++ b/src/modules/TypedPredictor.ts @@ -0,0 +1,154 @@ +import { Predict } from "./Predict.js"; +import { Prediction } from "../primitives/index.js"; +import { Signature } from "../signatures/index.js"; +import type { FieldMeta } from "../signatures/index.js"; + +/** + * A Prediction that additionally carries a typed `.typed` field. + */ +export class TypedPrediction extends Prediction { + readonly typed: T; + + constructor( + data: Record, + typed: T, + completions: Record[] = [], + ) { + super(data, completions); + this.typed = typed; + } +} + +/** + * TypedPredictor — like Predict but appends JSON formatting instructions and + * parses the completion as JSON. 
If an optional schema is provided, + * validates and returns `.typed`. + */ +export class TypedPredictor extends Predict { + readonly #schema: { parse: (v: unknown) => T } | undefined; + readonly #maxRetries: number; + + constructor( + signature: string | Signature, + schema?: { parse: (v: unknown) => T }, + options: { maxRetries?: number } = {}, + ) { + super(signature); + this.#schema = schema; + this.#maxRetries = options.maxRetries ?? 3; + } + + override async forward(inputs: Record): Promise> { + const origInstructions = this.instructions; + const jsonSuffix = "\n\nRespond with a JSON object matching the output schema."; + this.instructions = (origInstructions ?? "") + jsonSuffix; + + let lastError: unknown; + try { + for (let attempt = 0; attempt <= this.#maxRetries; attempt++) { + try { + const prediction = await super.forward(inputs); + const dict = prediction.toDict() as Record; + + // Try each output field's value as potential JSON source + let parsed: unknown; + let found = false; + let lastParseError: unknown; + for (const key of this.signature.outputs.keys()) { + const val = dict[key]; + if (typeof val === "string" && val.length > 0) { + try { + parsed = TypedPredictor.#parseJSON(val); + found = true; + break; + } catch (parseErr) { + lastParseError = parseErr; + } + } + } + + if (!found) { + if (lastParseError !== undefined) { + // Had non-empty string field(s) but none parsed as JSON + throw lastParseError; + } + // No string field values — fall back to the dict (e.g. 
multi-field with empty results) + parsed = dict; + } + + let typed: T; + if (this.#schema) { + typed = this.#schema.parse(parsed); + } else { + typed = parsed as T; + } + + return new TypedPrediction( + dict, + typed, + prediction.completions as Record[], + ); + } catch (err) { + lastError = err; + // continue to next attempt + } + } + } finally { + this.instructions = origInstructions; + } + + throw lastError; + } + + static #parseJSON(raw: unknown): unknown { + if (typeof raw !== "string") return raw; + let text = raw.trim(); + // Strip markdown code fences + const fence = /^```(?:json)?\s*([\s\S]*?)\s*```$/m.exec(text); + if (fence) text = (fence[1] ?? "").trim(); + return JSON.parse(text); + } +} + +/** + * TypedChainOfThought — like TypedPredictor but adds a hidden rationale field + * so the LM reasons before producing the answer. + */ +export class TypedChainOfThought extends TypedPredictor { + constructor( + signature: string | Signature, + schema?: { parse: (v: unknown) => T }, + options: { maxRetries?: number } = {}, + ) { + const base = typeof signature === "string" ? Signature.from(signature) : signature; + + const withRationale = base.withOutput("rationale", { + description: "Think step by step to reason through the problem", + prefix: "Reasoning:", + }); + + // Ensure rationale is the FIRST output field + const reordered = new Signature({ + inputs: withRationale.inputs as Map, + outputs: new Map([ + ["rationale", withRationale.outputs.get("rationale")!], + ...withRationale.outputs, + ]), + instructions: withRationale.instructions, + }); + + super(reordered, schema, options); + } + + override async forward(inputs: Record): Promise> { + const result = await super.forward(inputs); + // Destructure rationale out so it doesn't appear in the returned prediction. 
+ const { rationale: _rationale, ...rest } = result.toDict() as Record; + void _rationale; + return new TypedPrediction( + rest, + result.typed, + result.completions as Record[], + ); + } +} diff --git a/src/modules/index.ts b/src/modules/index.ts index e6829bc..442b238 100644 --- a/src/modules/index.ts +++ b/src/modules/index.ts @@ -10,3 +10,8 @@ export { Retrieve } from "./Retrieve.js"; export { Retry } from "./Retry.js"; export { BestOfN } from "./BestOfN.js"; export { Ensemble } from "./Ensemble.js"; +export { TypedPredictor, TypedChainOfThought, TypedPrediction } from "./TypedPredictor.js"; +export { Parallel } from "./Parallel.js"; +export { Refine } from "./Refine.js"; +export type { RefineOptions } from "./Refine.js"; +export { NativeReAct } from "./NativeReAct.js"; diff --git a/src/optimizers/AvatarOptimizer.ts b/src/optimizers/AvatarOptimizer.ts new file mode 100644 index 0000000..346f5fc --- /dev/null +++ b/src/optimizers/AvatarOptimizer.ts @@ -0,0 +1,77 @@ +import { Optimizer } from "./Optimizer.js"; +import { Predict } from "../modules/index.js"; +import type { Module } from "../modules/index.js"; +import type { Example } from "../primitives/index.js"; +import type { Metric } from "../evaluate/index.js"; +import { evaluate } from "../evaluate/index.js"; +import { settings } from "../settings/index.js"; + +/** Options for AvatarOptimizer. */ +export interface AvatarOptimizerOptions { + /** Number of avatar candidates to try per predictor (default: 4). */ + numAvatars?: number | undefined; + /** Max labeled demos (default: 8). */ + maxLabeledDemos?: number | undefined; +} + +/** + * AvatarOptimizer iteratively proposes and evaluates "avatar" role descriptions + * (persona prefixes) for each Predict module. + * + * Mirrors `dspy.AvatarOptimizer` in Python. + * + * For each predictor, proposes `numAvatars` different role/persona descriptions + * and selects the one that scores highest on the training set. 
+ */ +export class AvatarOptimizer extends Optimizer { + readonly #numAvatars: number; + readonly #maxLabeledDemos: number; + + constructor(options: AvatarOptimizerOptions = {}) { + super(); + this.#numAvatars = options.numAvatars ?? 4; + this.#maxLabeledDemos = options.maxLabeledDemos ?? 8; + } + + override async compile(student: Module, trainset: Example[], metric: Metric): Promise { + const lm = settings.lm; + if (!lm) throw new Error("AvatarOptimizer requires a configured LM."); + + let best = student.clone(); + const evalSet = trainset.slice(0, Math.min(this.#maxLabeledDemos, trainset.length)); + let bestScore = (await evaluate(best, evalSet, metric)).score; + + for (const [name, predictor] of best.namedPredictors()) { + if (!(predictor instanceof Predict)) continue; + + const avatarCandidates: string[] = []; + for (let i = 0; i < this.#numAvatars; i++) { + const prompt = + `You are an expert at designing AI personas.\n` + + `Task field: "${name}"\n` + + `Current instruction: "${predictor.instructions ?? ""}"\n\n` + + `Write a concise role/persona prefix (1-2 sentences) for an AI assistant ` + + `that excels at this task. Output only the persona description.`; + const resp = await lm.call(prompt, { temperature: 0.9 }); + avatarCandidates.push(resp.text.trim()); + } + + for (const avatar of avatarCandidates) { + const clone = best.clone(); + for (const [n, p] of clone.namedPredictors()) { + if (n === name && p instanceof Predict) { + const base = p.instructions ?? 
""; + p.instructions = `${avatar}\n\n${base}`.trim(); + } + } + const { score } = await evaluate(clone, evalSet, metric); + if (score > bestScore) { + bestScore = score; + best = clone; + } + } + } + + return best; + } +} diff --git a/src/optimizers/BootstrapFewShotWithOptuna.ts b/src/optimizers/BootstrapFewShotWithOptuna.ts new file mode 100644 index 0000000..e602128 --- /dev/null +++ b/src/optimizers/BootstrapFewShotWithOptuna.ts @@ -0,0 +1,146 @@ +import { BootstrapFewShot } from "./BootstrapFewShot.js"; +import { Predict } from "../modules/index.js"; +import type { Module } from "../modules/index.js"; +import type { Example } from "../primitives/index.js"; +import { Prediction } from "../primitives/index.js"; +import type { Metric } from "../evaluate/index.js"; + +export interface BootstrapFewShotWithOptunaOptions { + maxBootstrappedDemos?: number; + maxLabeledDemos?: number; + /** Number of TPE trials (default: 20). */ + numTrials?: number; + valset?: Example[]; +} + +/** + * Bayesian optimizer using a simplified TPE (Tree-structured Parzen Estimator). + * + * Extends BootstrapFewShot: first collects candidate demos via the parent, + * then runs `numTrials` iterations sampling demo subsets using TPE to find + * the best-scoring configuration. + */ +export class BootstrapFewShotWithOptuna extends BootstrapFewShot { + readonly #numTrials: number; + readonly #valset: Example[] | undefined; + + constructor(options: BootstrapFewShotWithOptunaOptions = {}) { + super(options); + this.#numTrials = options.numTrials ?? 
20; + this.#valset = options.valset; + } + + override async compile( + student: Module, + trainset: Example[], + metric: Metric, + ): Promise { + // Step 1: collect bootstrapped demos via parent + const bootstrapped = await super.compile(student, trainset, metric); + + // Gather all demos from the bootstrapped module + const allDemos: Example[] = []; + for (const [, predictor] of bootstrapped.namedPredictors()) { + if (predictor instanceof Predict) { + allDemos.push(...predictor.demos); + } + } + + if (allDemos.length === 0) { + return bootstrapped; + } + + const evalSet = this.#valset ?? trainset; + const maxDemos = Math.max(1, allDemos.length); + + /** Fraction of trials to consider "good" in TPE sampling. */ + const TOP_TRIALS_FRACTION = 0.25; + /** Probability of sampling from the "good" trials pool vs random. */ + const GOOD_TRIAL_SAMPLING_PROBABILITY = 0.7; + + interface Trial { + indices: number[]; + score: number; + } + const trials: Trial[] = []; + + const evaluate = async (candidate: Module): Promise => { + let score = 0; + for (const example of evalSet) { + try { + const inputs = example.toDict() as Record; + const prediction = await ( + candidate.forward as (i: Record) => Promise + )(inputs); + const raw = metric(example, prediction); + score += typeof raw === "boolean" ? (raw ? 1 : 0) : raw; + } catch { + // skip failed examples + } + } + return evalSet.length > 0 ? score / evalSet.length : 0; + }; + + const sampleIndices = ( + goodTrials: Trial[], + badTrials: Trial[], + n: number, + ): number[] => { + const useGood = goodTrials.length > 0 && Math.random() < GOOD_TRIAL_SAMPLING_PROBABILITY; + const pool = + useGood ? goodTrials : badTrials.length > 0 ? 
badTrials : null; + + if (pool !== null && pool.length > 0) { + const base = pool[Math.floor(Math.random() * pool.length)]!; + const result = new Set(base.indices); + if (Math.random() < 0.5 && result.size < maxDemos) { + result.add(Math.floor(Math.random() * maxDemos)); + } else if (result.size > 1) { + const arr = [...result]; + result.delete(arr[Math.floor(Math.random() * arr.length)]!); + } + return [...result].slice(0, n); + } + + // Random sample + const indices = Array.from({ length: maxDemos }, (_, i) => i); + return indices.sort(() => Math.random() - 0.5).slice(0, Math.min(n, maxDemos)); + }; + + let bestScore = -Infinity; + let bestModule = bootstrapped; + + for (let t = 0; t < this.#numTrials; t++) { + const sortedTrials = [...trials].sort((a, b) => b.score - a.score); + const topK = Math.max(1, Math.floor(sortedTrials.length * TOP_TRIALS_FRACTION)); + const goodTrials = sortedTrials.slice(0, topK); + const badTrials = sortedTrials.slice(topK); + + // Use 50% of all available demos per trial. This is the starting + // point for TPE exploration; mutations in sampleIndices() may grow or + // shrink the subset by ±1 around this baseline. 
+ const numDemos = Math.max(1, Math.floor(maxDemos * 0.5)); + const indices = sampleIndices(goodTrials, badTrials, numDemos); + const selectedDemos = indices + .map((i) => allDemos[i]) + .filter((d): d is Example => d !== undefined); + + const candidate = bootstrapped.clone(); + for (const [, predictor] of candidate.namedPredictors()) { + if (predictor instanceof Predict) { + predictor.demos = selectedDemos; + } + } + + const score = await evaluate(candidate); + trials.push({ indices, score }); + + if (score > bestScore) { + bestScore = score; + bestModule = candidate; + } + } + + return bestModule; + } +} diff --git a/src/optimizers/BootstrapFinetune.ts b/src/optimizers/BootstrapFinetune.ts new file mode 100644 index 0000000..d10fd2e --- /dev/null +++ b/src/optimizers/BootstrapFinetune.ts @@ -0,0 +1,91 @@ +import { writeFileSync, mkdirSync } from "node:fs"; +import { dirname } from "node:path"; +import { Optimizer } from "./Optimizer.js"; +import { BootstrapFewShot } from "./BootstrapFewShot.js"; +import { Predict } from "../modules/index.js"; +import type { Module } from "../modules/index.js"; +import type { Example } from "../primitives/index.js"; +import type { Metric } from "../evaluate/index.js"; + +/** Output format for the exported fine-tuning data. */ +export type FinetuneFormat = "openai" | "generic"; + +/** Options for BootstrapFinetune. */ +export interface BootstrapFinetuneOptions { + /** Path to write the JSONL fine-tuning file (default: "./finetune_data.jsonl"). */ + exportPath?: string | undefined; + /** Format of the exported data (default: "openai"). */ + format?: FinetuneFormat | undefined; + /** Bootstrap options passed to BootstrapFewShot internally. */ + maxBootstrappedDemos?: number | undefined; +} + +/** + * Collects LM traces via BootstrapFewShot and exports them as a JSONL file + * suitable for fine-tuning. + * + * - `"openai"` format: `{ messages: [{role, content}, ...] 
}` per line + * - `"generic"` format: `{ prompt: string, completion: string }` per line + * + * @example + * ```ts + * const optimizer = new BootstrapFinetune({ + * exportPath: "./finetune_data.jsonl", + * format: "openai", + * }); + * const recipe = await optimizer.compile(program, trainset, metric); + * ``` + */ +export class BootstrapFinetune extends Optimizer { + readonly #exportPath: string; + readonly #format: FinetuneFormat; + readonly #maxBootstrappedDemos: number; + readonly #bootstrap: BootstrapFewShot; + + constructor(options: BootstrapFinetuneOptions = {}) { + super(); + this.#exportPath = options.exportPath ?? "./finetune_data.jsonl"; + this.#format = options.format ?? "openai"; + this.#maxBootstrappedDemos = options.maxBootstrappedDemos ?? 4; + this.#bootstrap = new BootstrapFewShot({ + maxBootstrappedDemos: this.#maxBootstrappedDemos, + }); + } + + override async compile(student: Module, trainset: Example[], metric: Metric): Promise { + const compiled = await this.#bootstrap.compile(student, trainset, metric); + + const records: string[] = []; + for (const [, predictor] of compiled.namedPredictors()) { + if (predictor instanceof Predict) { + for (const demo of predictor.demos) { + const dict = demo.toDict() as Record; + const inputFields = [...predictor.signature.inputs.keys()]; + const outputFields = [...predictor.signature.outputs.keys()]; + + const inputStr = inputFields.map((k) => `${k}: ${String(dict[k] ?? "")}`).join("\n"); + const outputStr = outputFields.map((k) => `${k}: ${String(dict[k] ?? 
"")}`).join("\n"); + + if (this.#format === "openai") { + records.push( + JSON.stringify({ + messages: [ + { role: "user", content: inputStr }, + { role: "assistant", content: outputStr }, + ], + }), + ); + } else { + records.push(JSON.stringify({ prompt: inputStr, completion: outputStr })); + } + } + } + } + + const dir = dirname(this.#exportPath); + mkdirSync(dir, { recursive: true }); + writeFileSync(this.#exportPath, records.join("\n"), "utf8"); + + return compiled; + } +} diff --git a/src/optimizers/GRPO.ts b/src/optimizers/GRPO.ts new file mode 100644 index 0000000..d42af23 --- /dev/null +++ b/src/optimizers/GRPO.ts @@ -0,0 +1,96 @@ +import { Optimizer } from "./Optimizer.js"; +import { BootstrapFewShot } from "./BootstrapFewShot.js"; +import type { Module } from "../modules/index.js"; +import { Predict } from "../modules/index.js"; +import type { Example } from "../primitives/index.js"; +import type { Metric } from "../evaluate/index.js"; +import { evaluate } from "../evaluate/index.js"; +import { settings } from "../settings/index.js"; + +/** Options for GRPO. */ +export interface GRPOOptions { + /** Number of optimization steps (default: 20). */ + numSteps?: number | undefined; + /** Number of candidates per group per step (default: 8). */ + groupSize?: number | undefined; + /** Sampling temperature for candidate generation (default: 1.0). */ + temperature?: number | undefined; + /** Max labeled demos (default: 16). */ + maxLabeledDemos?: number | undefined; +} + +/** + * Group Relative Policy Optimization optimizer. + * + * Mirrors `dspy.GRPO` in Python. Runs `numSteps` iterations where each step: + * 1. Samples `groupSize` candidate instruction variants via the LM. + * 2. Evaluates each against the training set. + * 3. Updates the best instruction using group-relative scoring. + * + * Pure TypeScript — no external dependencies beyond the configured LM. 
+ */ +export class GRPO extends Optimizer { + readonly #numSteps: number; + readonly #groupSize: number; + readonly #temperature: number; + readonly #maxLabeledDemos: number; + + constructor(options: GRPOOptions = {}) { + super(); + this.#numSteps = options.numSteps ?? 20; + this.#groupSize = options.groupSize ?? 8; + this.#temperature = options.temperature ?? 1.0; + this.#maxLabeledDemos = options.maxLabeledDemos ?? 16; + } + + override async compile(student: Module, trainset: Example[], metric: Metric): Promise { + const lm = settings.lm; + if (!lm) throw new Error("GRPO requires a configured LM."); + + const bootstrap = new BootstrapFewShot({ + maxBootstrappedDemos: this.#maxLabeledDemos, + }); + let best = await bootstrap.compile(student, trainset, metric); + + const evalSet = trainset.slice(0, Math.min(10, trainset.length)); + let bestScore = (await evaluate(best, evalSet, metric)).score; + + for (let step = 0; step < this.#numSteps; step++) { + const candidates: { module: Module; score: number }[] = []; + + for (let g = 0; g < this.#groupSize; g++) { + const candidate = best.clone(); + + for (const [, predictor] of candidate.namedPredictors()) { + if (predictor instanceof Predict) { + const currentInstr = predictor.instructions ?? ""; + const prompt = + `You are an expert prompt engineer.\n` + + `Current instruction: "${currentInstr}"\n\n` + + `Write an improved instruction for a language model. 
Output only the instruction text.`; + const resp = await lm.call(prompt, { temperature: this.#temperature }); + predictor.instructions = resp.text.trim(); + } + } + + const { score } = await evaluate(candidate, evalSet, metric); + candidates.push({ module: candidate, score }); + } + + const scores = candidates.map((c) => c.score); + const mean = scores.reduce((a, b) => a + b, 0) / scores.length; + const std = + Math.sqrt(scores.reduce((a, b) => a + (b - mean) ** 2, 0) / scores.length) || 1; + const advantages = scores.map((s) => (s - mean) / std); + + const bestIdx = advantages.indexOf(Math.max(...advantages)); + const topScore = candidates[bestIdx]?.score ?? 0; + if (topScore > bestScore) { + bestScore = topScore; + best = candidates[bestIdx]!.module; + } + } + + return best; + } +} diff --git a/src/optimizers/SIMBA.ts b/src/optimizers/SIMBA.ts new file mode 100644 index 0000000..28af45d --- /dev/null +++ b/src/optimizers/SIMBA.ts @@ -0,0 +1,75 @@ +import { Optimizer } from "./Optimizer.js"; +import { BootstrapFewShot } from "./BootstrapFewShot.js"; +import { Predict } from "../modules/index.js"; +import type { Module } from "../modules/index.js"; +import type { Example } from "../primitives/index.js"; +import type { Metric } from "../evaluate/index.js"; +import { evaluate } from "../evaluate/index.js"; + +/** Options for SIMBA. */ +export interface SIMBAOptions { + /** Number of optimization iterations (default: 10). */ + numIter?: number | undefined; + /** Mini-batch size for each evaluation (default: 8). */ + batchSize?: number | undefined; + /** Max bootstrapped demos (default: 4). */ + maxBootstrappedDemos?: number | undefined; +} + +/** + * SIMBA (Stochastic Introspective Mini-Batch Ascent) optimizer. + * + * A lightweight stochastic optimizer that: + * 1. Selects a random mini-batch from the training set each iteration. + * 2. Proposes a candidate (via demo subset sampling). + * 3. Accepts the candidate if it improves on the current best. + * 4. 
Returns the overall best module found. + */ +export class SIMBA extends Optimizer { + readonly #numIter: number; + readonly #batchSize: number; + readonly #maxBootstrappedDemos: number; + + constructor(options: SIMBAOptions = {}) { + super(); + this.#numIter = options.numIter ?? 10; + this.#batchSize = options.batchSize ?? 8; + this.#maxBootstrappedDemos = options.maxBootstrappedDemos ?? 4; + } + + override async compile(student: Module, trainset: Example[], metric: Metric): Promise { + const bootstrap = new BootstrapFewShot({ + maxBootstrappedDemos: this.#maxBootstrappedDemos, + }); + let best = await bootstrap.compile(student, trainset, metric); + + const evalBatch = trainset.slice(0, Math.min(this.#batchSize, trainset.length)); + let bestScore = (await evaluate(best, evalBatch, metric)).score; + + for (let iter = 0; iter < this.#numIter; iter++) { + // Fisher-Yates shuffle for unbiased sampling + const shuffled = [...trainset]; + for (let i = shuffled.length - 1; i > 0; i--) { + const j = Math.floor(Math.random() * (i + 1)); + [shuffled[i], shuffled[j]] = [shuffled[j]!, shuffled[i]!]; + } + const batch = shuffled.slice(0, Math.min(this.#batchSize, shuffled.length)); + + const candidate = best.clone(); + for (const [, predictor] of candidate.namedPredictors()) { + if (predictor instanceof Predict && predictor.demos.length > 1) { + const dropIdx = Math.floor(Math.random() * predictor.demos.length); + predictor.demos = predictor.demos.filter((_, i) => i !== dropIdx); + } + } + + const { score } = await evaluate(candidate, batch, metric); + if (score >= bestScore) { + bestScore = score; + best = candidate; + } + } + + return best; + } +} diff --git a/src/optimizers/index.ts b/src/optimizers/index.ts index 41c65de..c6e018a 100644 --- a/src/optimizers/index.ts +++ b/src/optimizers/index.ts @@ -4,6 +4,8 @@ export { BootstrapFewShot } from "./BootstrapFewShot.js"; export type { BootstrapFewShotOptions } from "./BootstrapFewShot.js"; export { 
BootstrapFewShotWithRandomSearch } from "./BootstrapFewShotWithRandomSearch.js"; export type { BootstrapFewShotWithRandomSearchOptions } from "./BootstrapFewShotWithRandomSearch.js"; +export { BootstrapFewShotWithOptuna } from "./BootstrapFewShotWithOptuna.js"; +export type { BootstrapFewShotWithOptunaOptions } from "./BootstrapFewShotWithOptuna.js"; export { COPRO } from "./COPRO.js"; export type { COPROOptions } from "./COPRO.js"; export { MIPRO } from "./MIPRO.js"; @@ -12,3 +14,12 @@ export { KNNFewShot } from "./KNNFewShot.js"; export type { KNNFewShotOptions } from "./KNNFewShot.js"; export { EnsembleOptimizer } from "./Ensemble.js"; export type { EnsembleOptimizerOptions } from "./Ensemble.js"; +export { BootstrapFinetune } from "./BootstrapFinetune.js"; +export type { BootstrapFinetuneOptions, FinetuneFormat } from "./BootstrapFinetune.js"; +export { GRPO } from "./GRPO.js"; +export type { GRPOOptions } from "./GRPO.js"; +export { SIMBA } from "./SIMBA.js"; +export type { SIMBAOptions } from "./SIMBA.js"; +export { AvatarOptimizer } from "./AvatarOptimizer.js"; +export type { AvatarOptimizerOptions } from "./AvatarOptimizer.js"; + diff --git a/src/peer-deps.d.ts b/src/peer-deps.d.ts index 12b8759..c1bfc0f 100644 --- a/src/peer-deps.d.ts +++ b/src/peer-deps.d.ts @@ -40,3 +40,15 @@ declare module "weaviate-client" { const value: any; export default value; } + +declare module "@modelcontextprotocol/sdk/client/index.js" { + export const Client: any; +} + +declare module "@modelcontextprotocol/sdk/server/index.js" { + export const Server: any; +} + +declare module "@modelcontextprotocol/sdk/server/stdio.js" { + export const StdioServerTransport: any; +} diff --git a/src/primitives/Image.ts b/src/primitives/Image.ts new file mode 100644 index 0000000..15d00da --- /dev/null +++ b/src/primitives/Image.ts @@ -0,0 +1,87 @@ +import { readFileSync } from "node:fs"; + +/** Supported image MIME types. 
*/ +export type ImageMimeType = "image/jpeg" | "image/png" | "image/gif" | "image/webp"; + +/** + * A multi-modal image value that can be passed as a field in Predict/TypedPredictor. + * + * @example + * ```ts + * const captioner = new Predict("image, question -> caption"); + * const result = await captioner.forward({ + * image: Image.fromURL("https://example.com/photo.jpg"), + * question: "What is in this image?", + * }); + * ``` + */ +export class Image { + readonly url: string | undefined; + readonly base64: string | undefined; + readonly mimeType: ImageMimeType | undefined; + + private constructor(init: { + url?: string | undefined; + base64?: string | undefined; + mimeType?: ImageMimeType | undefined; + }) { + this.url = init.url; + this.base64 = init.base64; + this.mimeType = init.mimeType; + } + + /** Create an Image from a URL. */ + static fromURL(url: string): Image { + return new Image({ url }); + } + + /** Create an Image from base64-encoded data. */ + static fromBase64(data: string, mimeType: ImageMimeType = "image/jpeg"): Image { + return new Image({ base64: data, mimeType }); + } + + /** Create an Image by reading a local file synchronously. */ + static fromFile(path: string, mimeType?: ImageMimeType): Image { + const data = readFileSync(path); + const base64 = data.toString("base64"); + const ext = path.split(".").pop()?.toLowerCase(); + const detectedMime: ImageMimeType = + ext === "png" ? "image/png" : + ext === "gif" ? "image/gif" : + ext === "webp" ? "image/webp" : + "image/jpeg"; + return new Image({ base64, mimeType: mimeType ?? detectedMime }); + } + + /** Serialize to an OpenAI-compatible image_url content part. 
*/ + toOpenAIContentPart(): { type: "image_url"; image_url: { url: string } } { + if (this.url) { + return { type: "image_url", image_url: { url: this.url } }; + } + if (this.base64 && this.mimeType) { + return { type: "image_url", image_url: { url: `data:${this.mimeType};base64,${this.base64}` } }; + } + throw new Error("Image: no url or base64 data available"); + } + + /** Serialize to an Anthropic-compatible image content block. */ + toAnthropicContentBlock(): { + type: "image"; + source: { type: "base64" | "url"; media_type?: string; data?: string; url?: string }; + } { + if (this.url) { + return { type: "image", source: { type: "url", url: this.url } }; + } + if (this.base64 && this.mimeType) { + return { type: "image", source: { type: "base64", media_type: this.mimeType, data: this.base64 } }; + } + throw new Error("Image: no url or base64 data available"); + } + + /** Returns a string representation (used when Image is serialized in prompts). */ + toString(): string { + if (this.url) return `[Image: ${this.url}]`; + if (this.base64) return `[Image: base64 data, ${this.mimeType ?? "unknown type"}]`; + return "[Image]"; + } +} diff --git a/src/primitives/index.ts b/src/primitives/index.ts index ac09d00..3194f65 100644 --- a/src/primitives/index.ts +++ b/src/primitives/index.ts @@ -1,3 +1,6 @@ export { Example } from "./Example.js"; export { Prediction } from "./Prediction.js"; export type { Trace, TokenUsage } from "./Trace.js"; +export { majority } from "./majority.js"; +export { Image } from "./Image.js"; +export type { ImageMimeType } from "./Image.js"; diff --git a/src/primitives/majority.ts b/src/primitives/majority.ts new file mode 100644 index 0000000..5629d2f --- /dev/null +++ b/src/primitives/majority.ts @@ -0,0 +1,30 @@ +import type { Prediction } from "./Prediction.js"; + +/** + * Returns a reducer function that picks the Prediction whose `field` value + * appears most frequently. Ties go to the first occurrence. 
+ */ +export function majority(field = "answer"): (predictions: Prediction[]) => Prediction { + return (predictions: Prediction[]): Prediction => { + if (predictions.length === 0) { + throw new Error("majority: empty predictions array"); + } + + const counts = new Map(); + for (const p of predictions) { + const val = JSON.stringify(p.get(field)); + counts.set(val, (counts.get(val) ?? 0) + 1); + } + + let bestKey = ""; + let bestCount = 0; + for (const [k, c] of counts) { + if (c > bestCount) { + bestCount = c; + bestKey = k; + } + } + + return predictions.find((p) => JSON.stringify(p.get(field)) === bestKey)!; + }; +} diff --git a/src/tracking/ConsoleTracker.ts b/src/tracking/ConsoleTracker.ts new file mode 100644 index 0000000..c87938e --- /dev/null +++ b/src/tracking/ConsoleTracker.ts @@ -0,0 +1,19 @@ +import { Tracker } from "./Tracker.js"; +import type { TrackerEvent } from "./Tracker.js"; + +/** + * A simple tracker that logs events to the console. + */ +export class ConsoleTracker extends Tracker { + override log(event: TrackerEvent): void { + const parts: string[] = [`[${event.type.toUpperCase()}]`]; + if (event.step !== undefined) parts.push(`step=${event.step}`); + if (event.score !== undefined) parts.push(`score=${event.score.toFixed(4)}`); + if (event.metadata) parts.push(JSON.stringify(event.metadata)); + console.log(parts.join(" ")); + } + + override async flush(): Promise { + // No buffering — nothing to flush. + } +} diff --git a/src/tracking/JsonFileTracker.ts b/src/tracking/JsonFileTracker.ts new file mode 100644 index 0000000..e10fb10 --- /dev/null +++ b/src/tracking/JsonFileTracker.ts @@ -0,0 +1,36 @@ +import { mkdirSync } from "node:fs"; +import { appendFile } from "node:fs/promises"; +import { dirname } from "node:path"; +import { Tracker } from "./Tracker.js"; +import type { TrackerEvent } from "./Tracker.js"; + +/** + * A tracker that appends JSON-encoded events as lines to a file. 
/**
 * A tracker that buffers events in memory and appends them to a file as
 * newline-delimited JSON (one object per line) when flushed.
 *
 * @example
 * ```ts
 * const tracker = new JsonFileTracker("./runs/experiment1.jsonl");
 * tracker.log({ type: "step", step: 1, score: 0.75 });
 * await tracker.flush();
 * ```
 */
export class JsonFileTracker extends Tracker {
  readonly #path: string;
  readonly #buffer: string[] = [];

  /**
   * @param path - Destination file. Its parent directory is created
   *   (recursively, synchronously) if it does not exist yet.
   */
  constructor(path: string) {
    super();
    this.#path = path;
    mkdirSync(dirname(path), { recursive: true });
  }

  /** Queue an event, adding a `ts` field with the current ISO-8601 timestamp. */
  override log(event: TrackerEvent): void {
    this.#buffer.push(JSON.stringify({ ...event, ts: new Date().toISOString() }));
  }

  /**
   * Append every queued event to the file.
   *
   * If the write fails, the queued events are put back at the front of the
   * buffer (ahead of anything logged in the meantime) and the error is
   * rethrown, so a later `flush()` can retry instead of losing them.
   */
  override async flush(): Promise<void> {
    if (this.#buffer.length === 0) return;

    // Detach the pending lines before awaiting so an overlapping flush()
    // call cannot write the same events twice.
    const pending = this.#buffer.splice(0);
    try {
      await appendFile(this.#path, pending.join("\n") + "\n", "utf8");
    } catch (err: unknown) {
      this.#buffer.unshift(...pending);
      throw err;
    }
  }
}
import { describe, it, expect, afterEach } from "vitest";
import { DiskCache } from "../../src/lm/DiskCache.js";
import { rmSync, existsSync } from "node:fs";
import { join } from "node:path";
import { tmpdir } from "node:os";

// Scratch directory for this run; the timestamp keeps separate runs apart.
const TEST_DIR = join(tmpdir(), "dstsx-test-cache-" + Date.now());

/** Minimal LM response payload stored as the cached value in every test. */
const mockResponse = {
  text: "hello",
  texts: ["hello"],
  usage: null,
  raw: null,
};

/** Resolve after `ms` milliseconds. */
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

// Wipe anything the previous test left on disk so each test starts clean.
afterEach(() => {
  if (!existsSync(TEST_DIR)) return;
  rmSync(TEST_DIR, { recursive: true, force: true });
});

describe("DiskCache", () => {
  it("stores and retrieves values", () => {
    const store = new DiskCache(TEST_DIR);
    store.set("key1", mockResponse);
    const hit = store.get("key1");
    expect(hit).toEqual(mockResponse);
  });

  it("returns undefined for missing keys", () => {
    const store = new DiskCache(TEST_DIR);
    const miss = store.get("missing");
    expect(miss).toBeUndefined();
  });

  it("expires entries based on TTL", async () => {
    // NOTE(review): the third argument (50) is presumably the TTL in
    // milliseconds and the second a size limit — confirm against DiskCache.
    const store = new DiskCache(TEST_DIR, 500, 50);
    store.set("key1", mockResponse);
    await sleep(100);
    expect(store.get("key1")).toBeUndefined();
  });

  it("clear() removes all entries", () => {
    const store = new DiskCache(TEST_DIR);
    const keys = ["k1", "k2"];
    for (const key of keys) store.set(key, mockResponse);
    store.clear();
    for (const key of keys) expect(store.get(key)).toBeUndefined();
  });
});
import { describe, it, expect, beforeEach } from "vitest";
import { MockLM } from "../../src/lm/adapters/MockLM.js";
import { settings } from "../../src/settings/Settings.js";
import { Predict } from "../../src/modules/Predict.js";

type StreamChunk = import("../../src/lm/types.js").StreamChunk;

/** Drain an async stream into an array so assertions can inspect every chunk. */
async function collect(stream: AsyncIterable<StreamChunk>): Promise<StreamChunk[]> {
  const seen: StreamChunk[] = [];
  for await (const piece of stream) {
    seen.push(piece);
  }
  return seen;
}

describe("LM Streaming", () => {
  // Every test starts from default global settings (no LM configured).
  beforeEach(() => settings.reset());

  it("base LM.stream() yields a single chunk with done=true", async () => {
    const lm = new MockLM({}, "hello world");
    const chunks = await collect(lm.stream("hello"));
    expect(chunks).toHaveLength(1);
    const only = chunks[0];
    expect(only?.delta).toBe("hello world");
    expect(only?.done).toBe(true);
  });

  it("Predict.stream() yields chunks from the LM", async () => {
    settings.configure({ lm: new MockLM({}, "answer: Paris") });
    const predict = new Predict("question -> answer");
    const chunks = await collect(predict.stream({ question: "Capital of France?" }));
    expect(chunks.length).toBeGreaterThan(0);
    const last = chunks.at(-1);
    expect(last?.done).toBe(true);
  });

  it("Predict.stream() throws when no LM is configured", async () => {
    const predict = new Predict("question -> answer");
    // The failure surfaces on the first pull from the generator.
    const firstPull = predict.stream({ question: "q" }).next();
    await expect(firstPull).rejects.toThrow(/No LM configured/);
  });
});
import { describe, it, expect } from "vitest";
import { MCPToolAdapter } from "../../src/mcp/MCPAdapter.js";

describe("MCPToolAdapter", () => {
  it("creates tools from pre-loaded definitions", async () => {
    const adapter = new MCPToolAdapter({
      tools: [
        {
          name: "search",
          description: "Search the web",
          inputSchema: { query: { type: "string" } },
        },
      ],
      // Echo the tool name and arguments so the assertion can see what was called.
      callHandler: async (name, args) =>
        `result for ${name}: ${JSON.stringify(args)}`,
    });

    const tools = await adapter.getTools();
    expect(tools).toHaveLength(1);

    const search = tools[0]!;
    expect(search.name).toBe("search");

    const output = await search.fn('{"query": "test"}');
    expect(output).toContain("search");
  });

  it("caches tool list after first call", async () => {
    const adapter = new MCPToolAdapter({
      tools: [{ name: "t1", description: "d", inputSchema: {} }],
      callHandler: async () => "ok",
    });

    const firstCall = await adapter.getTools();
    const secondCall = await adapter.getTools();
    // Same array instance, not merely an equal one.
    expect(firstCall).toBe(secondCall);
  });

  it("throws when no tools or sdk available", async () => {
    const adapter = new MCPToolAdapter({ serverUrl: "http://localhost:9999" });
    const pending = adapter.getTools();
    await expect(pending).rejects.toThrow();
  });
});
import { describe, it, expect, beforeEach } from "vitest";
import { NativeReAct } from "../../src/modules/NativeReAct.js";
import { MockLM } from "../../src/lm/adapters/MockLM.js";
import { settings } from "../../src/settings/Settings.js";
import type { Tool } from "../../src/modules/ReAct.js";

/**
 * MockLM whose first call answers with an OpenAI-style `raw` payload that
 * requests the "lookup" tool, and whose later calls answer with plain text.
 */
class ToolCallLM extends MockLM {
  #called = false;

  // The return type is inferred; `override` still checks it against the
  // base-class signature, so no hand-written annotation is needed.
  protected override async _call(_prompt: unknown, _config: unknown) {
    if (this.#called) {
      return { text: "The answer is 42", texts: ["The answer is 42"], usage: null, raw: null };
    }
    this.#called = true;
    return {
      text: "",
      texts: [""],
      usage: null,
      raw: {
        choices: [
          {
            message: {
              tool_calls: [{ function: { name: "lookup", arguments: "test query" } }],
            },
            finish_reason: "tool_calls",
          },
        ],
      },
    };
  }
}

describe("NativeReAct", () => {
  // Every test starts from default global settings (no LM configured).
  beforeEach(() => settings.reset());

  it("throws when no LM is configured", async () => {
    const agent = new NativeReAct("question -> answer", []);
    await expect(agent.forward({ question: "q" })).rejects.toThrow(/No LM configured/);
  });

  it("returns a Prediction when LM returns text (no tool calls)", async () => {
    settings.configure({ lm: new MockLM({}, "Paris") });
    const agent = new NativeReAct("question -> answer", []);
    const result = await agent.forward({ question: "Capital of France?" });
    expect(result.get("answer")).toBe("Paris");
  });

  it("uses tools when LM raw response contains tool_calls", async () => {
    const mockTool: Tool = {
      name: "lookup",
      description: "Look up a fact",
      fn: async (_args: string) => "The answer is 42",
    };

    settings.configure({ lm: new ToolCallLM() });
    const agent = new NativeReAct("question -> answer", [mockTool], 3);
    const result = await agent.forward({ question: "What is the answer?" });
    expect(result.get("answer")).toBe("The answer is 42");
  });

  it("has tools and maxIter properties", () => {
    const tools: Tool[] = [{ name: "search", description: "Search", fn: async (a) => a }];
    const agent = new NativeReAct("question -> answer", tools, 3);
    expect(agent.maxIter).toBe(3);
    expect(agent.tools.has("search")).toBe(true);
  });
});
import { describe, it, expect, beforeEach } from "vitest";
import { ProgramOfThought } from "../../src/modules/ProgramOfThought.js";
import { MockLM } from "../../src/lm/adapters/MockLM.js";
import { settings } from "../../src/settings/Settings.js";

const SIGNATURE = "question -> answer";

/** Point the global settings at a MockLM that always replies with `reply`. */
function useMockReply(reply: string): void {
  settings.configure({ lm: new MockLM({}, reply) });
}

// NOTE(review): the positional constructor arguments after the signature look
// like (max attempts, timeout in ms, sandbox mode) — confirm against
// ProgramOfThought before relying on these names.
describe("ProgramOfThought sandbox modes", () => {
  // Every test starts from default global settings (no LM configured).
  beforeEach(() => settings.reset());

  it("has sandbox property (default: 'function')", () => {
    const pot = new ProgramOfThought(SIGNATURE);
    expect(pot.sandbox).toBe("function");
  });

  it("accepts sandbox: 'none'", () => {
    const pot = new ProgramOfThought(SIGNATURE, 3, 5000, "none");
    expect(pot.sandbox).toBe("none");
  });

  it("accepts sandbox: 'worker'", () => {
    const pot = new ProgramOfThought(SIGNATURE, 3, 5000, "worker");
    expect(pot.sandbox).toBe("worker");
  });

  it("sandbox='function' executes code correctly", async () => {
    useMockReply("code: return 2 + 2");
    const pot = new ProgramOfThought(SIGNATURE, 1, 5000, "function");
    const outcome = await pot.forward({ question: "2+2?" });
    expect(outcome.get("answer")).toBe("4");
  });

  // Longer per-test timeout (15 s) for the worker-based run.
  it("sandbox='worker' executes code in a worker thread", async () => {
    useMockReply("code: return 2 + 2");
    const pot = new ProgramOfThought(SIGNATURE, 1, 10000, "worker");
    const outcome = await pot.forward({ question: "2+2?" });
    expect(outcome.get("answer")).toBe("4");
  }, 15000);

  it("sandbox='none' executes code without timeout", async () => {
    useMockReply("code: return 'hello'");
    const pot = new ProgramOfThought(SIGNATURE, 1, 5000, "none");
    const outcome = await pot.forward({ question: "say hello" });
    expect(outcome.get("answer")).toBe("hello");
  });
});
import { describe, it, expect, beforeEach } from "vitest";
import { TypedPredictor, TypedChainOfThought, TypedPrediction } from "../../src/modules/TypedPredictor.js";
import { MockLM } from "../../src/lm/adapters/MockLM.js";
import { settings } from "../../src/settings/Settings.js";

/** Point the global settings at a MockLM that always replies with `reply`. */
function useMockReply(reply: string): void {
  settings.configure({ lm: new MockLM({}, reply) });
}

/** Read the `answer` key of an otherwise untyped parsed payload. */
function answerOf(typed: unknown): unknown {
  return (typed as Record<string, unknown>)["answer"];
}

describe("TypedPredictor", () => {
  // Every test starts from default global settings (no LM configured).
  beforeEach(() => settings.reset());

  it("parses JSON from LM response", async () => {
    useMockReply('{"answer": "Paris"}');
    const tp = new TypedPredictor("question -> answer");
    const result = await tp.forward({ question: "Capital of France?" });
    expect(result).toBeInstanceOf(TypedPrediction);
    expect(answerOf(result.typed)).toBe("Paris");
  });

  it("strips markdown code fences", async () => {
    useMockReply("```json\n{\"answer\": \"Paris\"}\n```");
    const tp = new TypedPredictor("question -> answer");
    const result = await tp.forward({ question: "Capital?" });
    expect(answerOf(result.typed)).toBe("Paris");
  });

  it("validates with schema if provided", async () => {
    // Hand-rolled stand-in for a schema object: anything exposing `parse`.
    const schema = {
      parse: (v: unknown) => {
        const obj = v as Record<string, unknown>;
        if (typeof obj["answer"] !== "string") throw new Error("invalid");
        return obj as { answer: string };
      },
    };
    useMockReply('{"answer": "Paris"}');
    const tp = new TypedPredictor("question -> answer", schema);
    const result = await tp.forward({ question: "Capital?" });
    expect(result.typed.answer).toBe("Paris");
  });

  it("throws after maxRetries on invalid JSON", async () => {
    useMockReply("not json");
    const tp = new TypedPredictor("question -> answer", undefined, { maxRetries: 1 });
    await expect(tp.forward({ question: "?" })).rejects.toThrow();
  });
});

describe("TypedChainOfThought", () => {
  beforeEach(() => settings.reset());

  it("strips rationale from result", async () => {
    useMockReply('{"answer": "72"}');
    const tcot = new TypedChainOfThought("question -> answer");
    const result = await tcot.forward({ question: "9*8?" });
    expect(result.get("rationale")).toBeUndefined();
  });
});
import { describe, it, expect, beforeEach } from "vitest";
import { BootstrapFewShotWithOptuna } from "../../src/optimizers/BootstrapFewShotWithOptuna.js";
import { Predict } from "../../src/modules/Predict.js";
import { Module } from "../../src/modules/Module.js";
import { Prediction, Example } from "../../src/primitives/index.js";
import { MockLM } from "../../src/lm/adapters/MockLM.js";
import { settings } from "../../src/settings/Settings.js";

/** Single-predictor program used as the student being optimized. */
class SimpleQA extends Module {
  predict = new Predict("question -> answer");

  // NOTE(review): parameter and return types restored from a garbled diff;
  // confirm they match Module.forward.
  override async forward(inputs: Record<string, unknown>): Promise<Prediction> {
    return this.predict.forward(inputs);
  }
}

describe("BootstrapFewShotWithOptuna", () => {
  // Every test starts from default global settings (no LM configured).
  beforeEach(() => settings.reset());

  it("compiles and returns an optimized module", async () => {
    settings.configure({ lm: new MockLM({}, "answer: 4") });

    const trainset = [
      new Example({ question: "2+2?", answer: "4" }),
      new Example({ question: "1+3?", answer: "4" }),
    ];
    // Passes whenever the predicted answer is "4"; the example is ignored.
    const metric = (_example: Example, pred: Prediction): boolean => pred.get("answer") === "4";

    const optimizer = new BootstrapFewShotWithOptuna({
      numTrials: 3,
      maxBootstrappedDemos: 2,
    });
    const optimized = await optimizer.compile(new SimpleQA(), trainset, metric);

    expect(optimized).toBeInstanceOf(Module);
  });
});
import { describe, it, expect, beforeEach, afterEach } from "vitest";
import { BootstrapFinetune } from "../../src/optimizers/BootstrapFinetune.js";
import { Predict } from "../../src/modules/Predict.js";
import { Module } from "../../src/modules/Module.js";
import { Prediction, Example } from "../../src/primitives/index.js";
import { MockLM } from "../../src/lm/adapters/MockLM.js";
import { settings } from "../../src/settings/Settings.js";
import { existsSync, unlinkSync, readFileSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";

/** Single-predictor program used as the student being optimized. */
class SimpleQA extends Module {
  predict = new Predict("question -> answer");

  // NOTE(review): parameter and return types restored from a garbled diff;
  // confirm they match Module.forward.
  override async forward(inputs: Record<string, unknown>): Promise<Prediction> {
    return this.predict.forward(inputs);
  }
}

/**
 * Parse the first non-empty line of a JSONL file.
 *
 * @returns The parsed record, or `undefined` when the file has no non-empty line.
 */
function firstRecord(path: string): unknown {
  const lines = readFileSync(path, "utf8").trim().split("\n").filter(Boolean);
  const head = lines[0];
  return head === undefined ? undefined : JSON.parse(head);
}

describe("BootstrapFinetune", () => {
  // Use the platform temp directory (as the other suites do) rather than a
  // hard-coded "/tmp", which does not exist on every OS.
  const exportPath = join(tmpdir(), "test_finetune_data.jsonl");

  // Passes whenever the predicted answer is "4"; the example is ignored.
  const metric = (_example: Example, pred: Prediction): boolean => pred.get("answer") === "4";

  beforeEach(() => settings.reset());
  afterEach(() => {
    if (existsSync(exportPath)) unlinkSync(exportPath);
  });

  it("compiles and writes openai format JSONL", async () => {
    settings.configure({ lm: new MockLM({}, "answer: 4") });
    const trainset = [
      new Example({ question: "2+2?", answer: "4" }),
      new Example({ question: "1+3?", answer: "4" }),
    ];
    const optimizer = new BootstrapFinetune({ exportPath, format: "openai", maxBootstrappedDemos: 2 });
    const optimized = await optimizer.compile(new SimpleQA(), trainset, metric);

    expect(optimized).toBeInstanceOf(Module);
    expect(existsSync(exportPath)).toBe(true);

    // NOTE(review): the shape check only runs when at least one record was
    // exported; an empty export still passes. Tighten once the optimizer's
    // guarantees are confirmed.
    const record = firstRecord(exportPath);
    if (record !== undefined) {
      expect(record).toHaveProperty("messages");
    }
  });

  it("compiles and writes generic format JSONL", async () => {
    settings.configure({ lm: new MockLM({}, "answer: 4") });
    const trainset = [new Example({ question: "2+2?", answer: "4" })];
    const optimizer = new BootstrapFinetune({ exportPath, format: "generic", maxBootstrappedDemos: 2 });
    await optimizer.compile(new SimpleQA(), trainset, metric);

    // NOTE(review): both the file's existence and its contents are optional
    // here, so this test can pass without asserting anything.
    if (!existsSync(exportPath)) return;
    const record = firstRecord(exportPath);
    if (record !== undefined) {
      expect(record).toHaveProperty("prompt");
      expect(record).toHaveProperty("completion");
    }
  });
});
import { describe, it, expect } from "vitest";
import { Image } from "../../src/primitives/Image.js";
import { writeFileSync, unlinkSync } from "node:fs";
import { join } from "node:path";
import { tmpdir } from "node:os";

describe("Image", () => {
  it("creates an Image from URL", () => {
    const photo = Image.fromURL("https://example.com/photo.jpg");
    expect(photo.url).toBe("https://example.com/photo.jpg");
    expect(photo.base64).toBeUndefined();
  });

  it("creates an Image from base64 data", () => {
    const inline = Image.fromBase64("abc123", "image/png");
    expect(inline.base64).toBe("abc123");
    expect(inline.mimeType).toBe("image/png");
    expect(inline.url).toBeUndefined();
  });

  it("creates an Image from a file", () => {
    // The bytes are not a real PNG; only the ".png" extension matters here.
    const file = join(tmpdir(), "test-image.png");
    writeFileSync(file, Buffer.from("fake png data"));
    try {
      const loaded = Image.fromFile(file);
      expect(loaded.base64).toBeDefined();
      expect(loaded.mimeType).toBe("image/png");
    } finally {
      unlinkSync(file);
    }
  });

  it("toOpenAIContentPart() works with URL", () => {
    const part = Image.fromURL("https://example.com/img.jpg").toOpenAIContentPart();
    expect(part.type).toBe("image_url");
    expect(part.image_url.url).toBe("https://example.com/img.jpg");
  });

  it("toOpenAIContentPart() works with base64", () => {
    const part = Image.fromBase64("abc", "image/jpeg").toOpenAIContentPart();
    expect(part.image_url.url).toContain("data:image/jpeg;base64,abc");
  });

  it("toAnthropicContentBlock() works with URL", () => {
    const block = Image.fromURL("https://example.com/img.jpg").toAnthropicContentBlock();
    expect(block.type).toBe("image");
    expect(block.source.type).toBe("url");
    expect(block.source.url).toBe("https://example.com/img.jpg");
  });

  it("toAnthropicContentBlock() works with base64", () => {
    const block = Image.fromBase64("abc", "image/png").toAnthropicContentBlock();
    expect(block.source.type).toBe("base64");
    expect(block.source.data).toBe("abc");
  });

  it("toString() returns a description", () => {
    const fromUrl = Image.fromURL("https://x.com/a.jpg").toString();
    const fromData = Image.fromBase64("abc", "image/png").toString();
    expect(fromUrl).toContain("https://x.com/a.jpg");
    expect(fromData).toContain("base64 data");
  });

  it("toOpenAIContentPart() throws when neither url nor base64 available", () => {
    // Build an instance without running any factory, then blank its fields,
    // to simulate an Image that carries no data at all.
    const blank = Object.create(Image.prototype) as Image;
    Object.assign(blank, { url: undefined, base64: undefined, mimeType: undefined });

    const toOpenAI = () => blank.toOpenAIContentPart();
    const toAnthropic = () => blank.toAnthropicContentBlock();
    expect(toOpenAI).toThrow("Image: no url or base64 data available");
    expect(toAnthropic).toThrow("Image: no url or base64 data available");
  });
});
import { describe, it, expect, beforeEach, afterEach } from "vitest";
import { ConsoleTracker } from "../../src/tracking/ConsoleTracker.js";
import { JsonFileTracker } from "../../src/tracking/JsonFileTracker.js";
import { existsSync, unlinkSync, readFileSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";

describe("Tracking", () => {
  // Use the platform temp directory (as the other suites do) rather than a
  // hard-coded "/tmp", which does not exist on every OS.
  const testPath = join(tmpdir(), "test_tracker.jsonl");

  /** Delete the output file if a test (or an aborted earlier run) left one. */
  const removeTestFile = (): void => {
    if (existsSync(testPath)) unlinkSync(testPath);
  };

  beforeEach(removeTestFile);
  afterEach(removeTestFile);

  it("ConsoleTracker.log() runs without error", () => {
    const tracker = new ConsoleTracker();
    expect(() => tracker.log({ type: "step", step: 1, score: 0.5 })).not.toThrow();
  });

  it("ConsoleTracker.flush() resolves", async () => {
    const tracker = new ConsoleTracker();
    await expect(tracker.flush()).resolves.toBeUndefined();
  });

  it("JsonFileTracker writes events on flush", async () => {
    const tracker = new JsonFileTracker(testPath);
    tracker.log({ type: "step", step: 1, score: 0.75 });
    tracker.log({ type: "best", score: 0.9 });
    await tracker.flush();

    expect(existsSync(testPath)).toBe(true);
    const lines = readFileSync(testPath, "utf8").trim().split("\n");
    expect(lines).toHaveLength(2);

    const first = JSON.parse(lines[0]!);
    expect(first.type).toBe("step");
    expect(first.score).toBe(0.75);
  });

  it("JsonFileTracker.flush() is idempotent on empty buffer", async () => {
    const tracker = new JsonFileTracker(testPath);
    // Flushing with nothing logged must succeed every time...
    await expect(tracker.flush()).resolves.toBeUndefined();
    await expect(tracker.flush()).resolves.toBeUndefined();
    // ...and must not create the output file.
    expect(existsSync(testPath)).toBe(false);
  });
});
true, + "excludePrivate": true, + "excludeInternal": true, + "navigationLinks": { + "GitHub": "https://github.com/Psyborgs-git/DSTsx" + } +}