diff --git a/.github/scripts/spellcheck_conf/wordlist.txt b/.github/scripts/spellcheck_conf/wordlist.txt
index 436db4de3..da9cb9052 100644
--- a/.github/scripts/spellcheck_conf/wordlist.txt
+++ b/.github/scripts/spellcheck_conf/wordlist.txt
@@ -1556,3 +1556,8 @@ RequestBuilder
VectorIndexManager
csvs
programmatically
+Customizations
+VSCode
+applyTo
+mdc
+windsurfrules
diff --git a/.github/workflows/pytest_cpu_gha_runner.yaml b/.github/workflows/pytest_cpu_gha_runner.yaml
index a3fa2004b..211f0a881 100644
--- a/.github/workflows/pytest_cpu_gha_runner.yaml
+++ b/.github/workflows/pytest_cpu_gha_runner.yaml
@@ -64,12 +64,12 @@ jobs:
id: pytest
run: |
echo "Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE}"
- cd $GITHUB_WORKSPACE && python3 -m pytest --junitxml="$GITHUB_WORKSPACE/result.xml"
+ cd $GITHUB_WORKSPACE && python3 -m pytest --junitxml="$GITHUB_WORKSPACE/pytest_result.xml"
- name: Publish Test Summary
id: test_summary
uses: test-summary/action@v2
with:
paths: |
- **/*.xml
+ **/pytest_result.xml
if: always()
diff --git a/end-to-end-use-cases/coding_assistant/README.md b/end-to-end-use-cases/coding_assistant/README.md
new file mode 100644
index 000000000..d7ad6b5c7
--- /dev/null
+++ b/end-to-end-use-cases/coding_assistant/README.md
@@ -0,0 +1,367 @@
+# Building with Coding Assistant
+
+Coding assistants have transformed application development by streamlining workflows and increasing productivity. This tutorial explores various methods to optimize your developer experience when building on the Llama platform, utilizing the capabilities of your preferred coding assistants.
+
+This tutorial will cover the following techniques:
+- Prompts for common Llama developer workflows/use cases
+- Incorporating Rules/AGENTS.md into your coding assistant
+
+## Prompts for common Llama developer workflows/use cases
+
+Example prompts for the following tasks are provided, which you can copy and paste into your coding assistant:
+- Task #1: Migrate an application from the OpenAI API to the Llama API Python SDK
+- Task #2: Fine-tune Llama models on a single GPU
+- Task #3: Build a RAG chatbot with Llama
+- Task #4: Upgrade from Llama 3 to Llama 4
+
+### Task #1: Migrate an application from the OpenAI API to the Llama API Python SDK
+```
+You are a coding assistant specialized in Llama - a series of LLMs developed and open-sourced by Meta. Your goal is to write code for developers working with the Llama AI ecosystem, including its tools, API SDKs, cookbooks, and best practices, to complete certain tasks.
+
+Here is the task:
+Convert all my OpenAI API usage in this application to use the Llama API Python SDK instead.
+Make sure you follow the correct syntax from the resources provided below.
+Provide instructions on how to acquire a Llama API key.
+Analyze the use cases of this application and choose appropriate Llama models supported by the Llama API based on performance and cost.
+Add a clear README describing what you have changed and how to properly test it.
+Convert the files in place. Do not create unnecessary files, scripts, or READMEs.
+
+Mainly use this reference: Llama API Python SDK (https://github.com/meta-llama/llama-api-python) - An official repository containing the Python client library for the Llama API REST API.
+
+Here are the other resources you might need to work with. Search the exact web URLs and/or your local index to find these resources:
+Llama Official Website (https://www.llama.com/docs/overview/) - Comprehensive documentation, including prompt formats for Llama models and various how-to guides.
+Llama Cookbooks (https://github.com/meta-llama/llama-cookbook) - An official repository of Llama best practices to help you get started with inference, fine-tuning, and end-to-end use cases.
+Llama Stack (https://github.com/llamastack/llama-stack) - An official repository containing a framework that standardizes the core building blocks that simplify AI application development and codifies best practices across the Llama ecosystem.
+```
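+
+As a point of reference, the kind of change this prompt should produce looks roughly like the sketch below. It assumes the `llama_api_client` SDK exposes an OpenAI-style `chat.completions.create` call and a `completion_message` response field; verify the exact method names, response fields, and model identifier against the SDK repository before relying on them.
+
+```
+import os
+
+# Before (OpenAI SDK):
+# from openai import OpenAI
+# client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+# response = client.chat.completions.create(
+#     model="gpt-4o",
+#     messages=[{"role": "user", "content": "Hello"}],
+# )
+# print(response.choices[0].message.content)
+
+# After (Llama API Python SDK) - a sketch; confirm method and field names in
+# https://github.com/meta-llama/llama-api-python
+from llama_api_client import LlamaAPIClient
+
+client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"])
+response = client.chat.completions.create(
+    model="Llama-4-Maverick-17B-128E-Instruct-FP8",  # example model; pick per use case and cost
+    messages=[{"role": "user", "content": "Hello"}],
+)
+print(response.completion_message.content.text)
+```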
+
+### Task #2: Fine-tune Llama models on a single GPU
+```
+You are a coding assistant specialized in Llama - a series of LLMs developed and open-sourced by Meta. Your goal is to write code for developers working with the Llama AI ecosystem, including its tools, API SDKs, cookbooks, and best practices, to complete certain tasks.
+
+Here is the task:
+Create a script that helps me fine-tune Llama models on a single consumer GPU such as an A10.
+Use PEFT for fine-tuning.
+Analyze the memory requirements and select an appropriate Llama model and quantization scheme that fit in the GPU memory.
+Specify interfaces that can take in a particular dataset for fine-tuning. The dataset will be defined by the user later based on their use case. Make sure you provide instructions on how to use a dataset for fine-tuning.
+Write a separate script for evaluating the fine-tuning results.
+
+Mainly use this reference: https://github.com/meta-llama/llama-cookbook/blob/main/src/docs/single_gpu.md
+
+Here are the other resources you might need to work with. Search the exact web URLs and/or your local index to find these resources:
+Llama Official Website (https://www.llama.com/docs/overview/) - Comprehensive documentation, including prompt formats for Llama models and various how-to guides.
+Llama Cookbooks (https://github.com/meta-llama/llama-cookbook) - An official repository of Llama best practices to help you get started with inference, fine-tuning, and end-to-end use cases.
+Llama Stack (https://github.com/llamastack/llama-stack) - An official repository containing a framework that standardizes the core building blocks that simplify AI application development and codifies best practices across the Llama ecosystem.
+Llama API Python SDK (https://github.com/meta-llama/llama-api-python) - An official repository containing the Python client library for the Llama API REST API.
+
+
+To accomplish this, follow these steps:
+1. Analyze the task and break it down into corresponding subtasks.
+2. For each subtask, reference the available resources and find concrete examples on which to base your solution.
+3. Validate your solution by writing tests, automated where possible.
+4. Iterate on step #2 until you are satisfied.
+
+Your output must contain these artifacts:
+- Exact code files to accomplish the task
+- A comprehensive README with a step-by-step guide
+- Scripts for easy deployment
+- Dependencies that can be easily installed
+```
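+
+As a rough illustration of what the generated script might start from, here is a minimal single-GPU PEFT (LoRA) sketch using Hugging Face `transformers`, `peft`, and `bitsandbytes`. The model ID, LoRA hyperparameters, and 4-bit settings are illustrative assumptions rather than cookbook recommendations, and running it requires a CUDA GPU.
+
+```
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from peft import LoraConfig, get_peft_model
+
+model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder; size the model to your GPU memory
+
+# Load the base model in 4-bit so it fits on a single consumer GPU.
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id, quantization_config=bnb_config, device_map="auto"
+)
+
+# Attach LoRA adapters so only a small set of parameters is trained.
+lora_config = LoraConfig(
+    r=8,
+    lora_alpha=32,
+    target_modules=["q_proj", "v_proj"],
+    lora_dropout=0.05,
+    task_type="CAUSAL_LM",
+)
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()
+
+# Training then proceeds with a standard Trainer/SFT loop over the user-supplied
+# dataset; see the referenced cookbook recipes for complete examples.
+```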
+
+### Task #3: Build a RAG chatbot with Llama
+```
+You are a coding assistant specialized in Llama - a series of LLMs developed and open-sourced by Meta. Your goal is to write code for developers working with the Llama AI ecosystem, including its tools, API SDKs, cookbooks, and best practices, to complete certain tasks.
+
+Here is the task:
+Build a RAG chatbot using Llama models.
+Specify interfaces that can take in user-defined files such as PDFs. Make sure you provide instructions on how to use these interfaces to process files.
+Use a popular text embedding model, with any necessary conversion, to create a vector database that stores the user-defined files.
+Create a chatbot UI using Gradio that can answer questions regarding the database.
+
+
+Mainly use this reference: https://github.com/meta-llama/llama-cookbook/blob/main/end-to-end-use-cases/customerservice_chatbots/RAG_chatbot/RAG_Chatbot_Example.ipynb
+
+Here are the resources you will work with. Search the exact web URLs and/or your local index to find these resources:
+Llama Official Website (https://www.llama.com/docs/overview/) - Comprehensive documentation, including prompt formats for Llama models and various how-to guides.
+Llama Cookbooks (https://github.com/meta-llama/llama-cookbook) - An official repository of Llama best practices to help you get started with inference, fine-tuning, and end-to-end use cases.
+Llama Stack (https://github.com/llamastack/llama-stack) - An official repository containing a framework that standardizes the core building blocks that simplify AI application development and codifies best practices across the Llama ecosystem.
+Llama API Python SDK (https://github.com/meta-llama/llama-api-python) - An official repository containing the Python client library for the Llama API REST API.
+
+
+To accomplish this, follow these steps:
+1. Analyze the task and break it down into corresponding subtasks.
+2. For each subtask, reference the available resources and find concrete examples on which to base your solution.
+3. Validate your solution by writing tests, automated where possible.
+4. Iterate on step #2 until you are satisfied.
+
+Your output must contain these artifacts:
+- Exact code files to accomplish the task
+- A comprehensive README with a step-by-step guide
+- Scripts for easy deployment
+- Dependencies that can be easily installed
+```
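+
+For orientation, a bare-bones version of the retrieval loop this prompt asks for could look like the sketch below. It substitutes `sentence-transformers` embeddings and an in-memory similarity search for a real vector database and reuses the hedged Llama API call from Task #1; the referenced cookbook notebook shows a complete pipeline, and a Gradio `ChatInterface` can then provide the chat UI on top of a function like this.
+
+```
+import os
+import numpy as np
+from sentence_transformers import SentenceTransformer
+from llama_api_client import LlamaAPIClient
+
+# Stand-in for text chunks extracted from user-provided PDFs.
+chunks = [
+    "Llama 4 models support long context windows.",
+    "PEFT reduces the memory needed to fine-tune Llama models.",
+]
+embedder = SentenceTransformer("all-MiniLM-L6-v2")
+chunk_vecs = embedder.encode(chunks, normalize_embeddings=True)
+
+client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"])
+
+def answer(question: str) -> str:
+    # Retrieve the most similar chunk by cosine similarity (vectors are normalized).
+    q_vec = embedder.encode([question], normalize_embeddings=True)[0]
+    context = chunks[int(np.argmax(chunk_vecs @ q_vec))]
+    response = client.chat.completions.create(
+        model="Llama-4-Maverick-17B-128E-Instruct-FP8",  # example model
+        messages=[
+            {"role": "system", "content": f"Answer using only this context:\n{context}"},
+            {"role": "user", "content": question},
+        ],
+    )
+    return response.completion_message.content.text
+
+print(answer("How does PEFT help on a single GPU?"))
+```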
+
+### Task #4: Upgrade from Llama 3 to Llama 4
+```
+You are a coding assistant specialized in Llama - a series of LLMs developed and open-sourced by Meta. Your goal is to write code for developers working with the Llama AI ecosystem, including its tools, API SDKs, cookbooks, and best practices, to complete certain tasks.
+
+Here is the task:
+Convert all my usage of the Llama 3 model in the codebase to use the Llama 4 model instead.
+Do not change the original interface method. Use the same API provided if applicable.
+Analyze the use cases of this application and choose appropriate Llama models.
+Add a clear README describing what you have changed and how to properly test it.
+Convert the files in place. Do not create unnecessary files, scripts, or READMEs.
+
+Here are the resources you will work with. Search the exact web URLs and/or your local index to find these resources:
+Llama Official Website (https://www.llama.com/docs/overview/) - Comprehensive documentation, including prompt formats for Llama models and various how-to guides.
+Llama Cookbooks (https://github.com/meta-llama/llama-cookbook) - An official repository of Llama best practices to help you get started with inference, fine-tuning, and end-to-end use cases.
+Llama Stack (https://github.com/llamastack/llama-stack) - An official repository containing a framework that standardizes the core building blocks that simplify AI application development and codifies best practices across the Llama ecosystem.
+Llama API Python SDK (https://github.com/meta-llama/llama-api-python) - An official repository containing the Python client library for the Llama API REST API.
+```
+
+## Incorporating Rules/AGENTS.md into Your Coding Assistant
+
+An effective way to enhance your coding assistant is to provide rules and instructions for the coding agent. This guide shows how to configure your IDE and coding assistant for faster development within the Llama ecosystem, with recommended setups for Cursor, Windsurf, and VSCode/Copilot.
+
+### Cursor
+
+#### Index Documentation
+Cursor can pre-index documentation to provide additional Llama context during development. This lets you ask questions about Llama more easily and gives the AI additional references during assisted coding.
+- Open Cursor Settings (Shift + Cmd/Ctrl + J)
+- Navigate to “Indexing & Docs”
+- Click “+ Add Doc” button and add these official Llama documentation URLs:
+ - Llama Website: https://www.llama.com/docs/overview/
+ - Llama Models Github: https://github.com/meta-llama/llama-models
+ - Llama Cookbook Github: https://github.com/meta-llama/llama-cookbook
+ - Llama Stack Github: https://github.com/llamastack/llama-stack
+
+#### Optimize with Custom Rules
+
+To further improve the experience of building with Llama, you can add custom instructions, such as best practices and preferred coding patterns, to Cursor’s AI system. This helps the AI agent understand the Llama ecosystem better.
+
+- Open Cursor Settings (Shift + Cmd/Ctrl + J)
+- Go to “Rules & Memories”
+- Click “+ Add Rule” under Project Rules and paste the rule below
+- After pasting, use the selector at the top to choose “Apply Intelligently”
+
+#### Project Rule for Llama (`.cursor/rules/llama.mdc`)
+```
+---
+description: Full-featured developer rulebook for leveraging Meta Llama in this codebase.
+globs:
+ - "**/*.py"
+ - "**/*.js"
+ - "**/*.md"
+alwaysApply: true
+---
+
+# Meta Llama Developer Rulebook
+
+## Model Usage & Integration
+- Use Meta Llama models (Llama 3, Llama 4 or later) as the default LLM for all tasks requiring a language model.
+- Reference the latest official developer resources for installation, basic usage, and authentication:
+ - Meta Llama Dev Docs: https://www.llama.com/docs/overview/
+ - Official GitHub: https://github.com/meta-llama/llama
+ - Llama models on Hugging Face: https://huggingface.co/meta-llama
+
+## Model Deployment
+- Provide clear suggestions on different types of deployment (cloud API, on-prem, on-device) based on the task and model requirements.
+ - Deploying on-prem: https://github.com/meta-llama/llama-cookbook/blob/main/3p-integrations/llama_on_prem.md
+ - Deploying with Llama Stack: https://llama-stack.readthedocs.io/en/latest/getting_started/index.html
+ - Deploying on-device with ExecuTorch: https://docs.pytorch.org/executorch/stable/llm/llama-demo-android.html
+
+## API Interaction & Migration
+- If migrating from other APIs, automatically convert parameter names and prompt formats to Llama's interface. See:
+ - Llama Migration Guide: https://www.llama.com/docs/llama-everywhere/migration/
+ - Llama API Python SDK: https://github.com/meta-llama/llama-api-python
+- Point out differences in authentication, token limits, and endpoint structure.
+
+## Fine-Tuning & Customization
+- Always provide sample code for fine-tuning Llama models using the latest Meta and Hugging Face recipes:
+ - Llama fine-tuning (Meta How-to): https://www.llama.com/docs/how-to-guides/fine-tuning/
+ - Recipe of Fine-tuning and get-started guide: https://github.com/meta-llama/llama-cookbook/tree/main/getting-started/finetuning
+ - PEFT: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/quickstart_peft_finetuning.ipynb
+ - MultiGPU finetuning: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/multigpu_finetuning.md
+ - Finetuning Llama 4: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/finetune_llama4.md
+ - Finetuning vision Llama: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/finetune_vision_model.md
+
+## Prompt Engineering Practices
+- Use and exemplify the Meta/official Llama chat prompt format (system/user/assistant roles).
+- Reference:
+ - Llama 4 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/
+ - Llama 3.3 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_3/
+ - Llama 3.2 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
+ - Llama 3.1 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/
+ - Hugging Face transformer prompting Llama 3: https://huggingface.co/docs/transformers/en/model_doc/llama3
+
+## Quantization
+- Based on task requirements, suggest appropriate quantization schemes and the necessary frameworks to minimize memory footprint without sacrificing model quality.
+- Official Llama Quantization Guide: https://www.llama.com/docs/how-to-guides/quantization/
+
+
+## Developer Best Practices
+- Ensure all generated code is idiomatic, well commented, and follows error-handling and input-validation best practices.
+- For documentation and onboarding, always provide links to the latest relevant guides:
+ - Developer Use Guide: https://www.llama.com/developer-use-guide/
+- Recommend including model info and links (e.g., license, documentation) in project-level README files.
+
+
+## Other Common Use-cases
+Text2SQL: https://www.llama.com/resources/cookbook/text2sql_natural_language_to_sql_interface/
+Tool Calling: https://www.llama.com/resources/cookbook/toolcalling-with-llama/
+Language Expansion: https://www.llama.com/resources/cookbook/extend_llama_to_new_languages/
+Email Agent: https://www.llama.com/resources/cookbook/build-email-agent-with-llama3/
+Multimodal RAG: https://www.llama.com/resources/cookbook/build-multimodal-rag/
+Browser Agent: https://www.llama.com/resources/cookbook/build-an-intelligent-browser-agent/
+Customer Service Chatbot: https://github.com/meta-llama/llama-cookbook/tree/main/end-to-end-use-cases/customerservice_chatbots
+
+---
+```
+
+### Windsurf
+
+Similar to Cursor, you can index documents and add custom rules in Windsurf with a slightly different approach.
+
+#### Index Documentation
+Windsurf does not index a resource from a URL alone; it automatically indexes your codebase when you open a workspace. That means you will need to clone the GitHub repositories mentioned above to your local machine.
+Remote indexing is only available on the Enterprise plan.
+
+#### Optimize with Custom Rules
+- Create a `.windsurfrules` file in your project's root directory to set project-specific rules for the AI, such as enforcing coding standards or focusing suggestions on particular frameworks.
+- For workspace-wide or global rules, use a `global_rules.md` file. These can be edited via the Windsurf “Customizations” or “Cascade Memories” settings panel, available in the app.
+- Each rule should be concise (under 6,000 characters) and in Markdown format. Rules can be activated manually, always-on, by model decision, or automatically by file path/glob pattern.
+
+```
+# Windsurf Llama Model Global Rules
+
+## Model Usage & Integration
+- Use Meta Llama models (Llama 3, Llama 4 or later) as the default LLM for all tasks requiring a language model.
+- Reference the latest official developer resources for installation, basic usage, and authentication:
+ - Meta Llama Dev Docs: https://www.llama.com/docs/overview/
+ - Official GitHub: https://github.com/meta-llama/llama
+ - Llama models on Hugging Face: https://huggingface.co/meta-llama
+
+## Model Deployment
+- Provide clear suggestions on different types of deployment (cloud API, on-prem, on-device) based on the task and model requirements.
+ - Deploying on-prem: https://github.com/meta-llama/llama-cookbook/blob/main/3p-integrations/llama_on_prem.md
+ - Deploying with Llama Stack: https://llama-stack.readthedocs.io/en/latest/getting_started/index.html
+ - Deploying on-device with ExecuTorch: https://docs.pytorch.org/executorch/stable/llm/llama-demo-android.html
+
+## API Interaction & Migration
+- If migrating from other APIs, automatically convert parameter names and prompt formats to Llama's interface. See:
+  - Llama Migration Guide: https://www.llama.com/docs/llama-everywhere/migration/
+ - Llama API Python SDK: https://github.com/meta-llama/llama-api-python
+- Point out differences in authentication, token limits, and endpoint structure.
+
+## Fine-Tuning & Customization
+- Always provide sample code for fine-tuning Llama models using the latest Meta and Hugging Face recipes:
+ - Llama fine-tuning (Meta How-to): https://www.llama.com/docs/how-to-guides/fine-tuning/
+ - Recipe of Fine-tuning and get-started guide: https://github.com/meta-llama/llama-cookbook/tree/main/getting-started/finetuning
+ - PEFT: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/quickstart_peft_finetuning.ipynb
+ - MultiGPU finetuning: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/multigpu_finetuning.md
+ - Finetuning Llama 4: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/finetune_llama4.md
+ - Finetuning vision Llama: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/finetune_vision_model.md
+
+## Prompt Engineering Practices
+- Use and exemplify the Meta/official Llama chat prompt format (system/user/assistant roles).
+- Reference:
+ - Llama 4 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/
+ - Llama 3.3 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_3/
+ - Llama 3.2 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
+ - Llama 3.1 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/
+ - Hugging Face transformer prompting Llama 3: https://huggingface.co/docs/transformers/en/model_doc/llama3
+
+## Quantization
+- Based on task requirements, suggest appropriate quantization schemes and the necessary frameworks to minimize memory footprint without sacrificing model quality.
+- Official Llama Quantization Guide: https://www.llama.com/docs/how-to-guides/quantization/
+
+
+## Developer Best Practices
+- Ensure all generated code is idiomatic, well commented, and follows error-handling and input-validation best practices.
+- For documentation and onboarding, always provide links to the latest relevant guides:
+ - Developer Use Guide: https://www.llama.com/developer-use-guide/
+- Recommend including model info and links (e.g., license, documentation) in project-level README files.
+
+
+## Other Common Use-cases
+Text2SQL: https://www.llama.com/resources/cookbook/text2sql_natural_language_to_sql_interface/
+Tool Calling: https://www.llama.com/resources/cookbook/toolcalling-with-llama/
+Language Expansion: https://www.llama.com/resources/cookbook/extend_llama_to_new_languages/
+Email Agent: https://www.llama.com/resources/cookbook/build-email-agent-with-llama3/
+Multimodal RAG: https://www.llama.com/resources/cookbook/build-multimodal-rag/
+Browser Agent: https://www.llama.com/resources/cookbook/build-an-intelligent-browser-agent/
+Customer Service Chatbot: https://github.com/meta-llama/llama-cookbook/tree/main/end-to-end-use-cases/customerservice_chatbots
+```
+
+### VSCode/Copilot
+
+#### Index Documentation
+GitHub Copilot leverages context from open files, the workspace, and specific instructions to provide relevant code suggestions.
+
+#### Optimize with Custom Rules
+- Create a `.github/copilot-instructions.md` file in your workspace to define general coding standards and guidelines for all Copilot chat requests.
+- For file- or task-specific rules, use `.instructions.md` files with the `applyTo` front matter to specify file targeting.
+- You can also add settings in `settings.json` for review, commit message generation, and pull request description instructions, either directly or by referencing instruction files.
+
+For more information, refer to https://docs.github.com/en/copilot/how-tos/configure-custom-instructions/add-repository-instructions
+
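+For example, a scoped instructions file (the file name and glob below are hypothetical) could look like the following sketch, with the `applyTo` front matter limiting it to Python files:
+
+```
+---
+applyTo: "**/*.py"
+---
+When generating Python that calls a Llama model, default to the Llama API Python SDK
+(https://github.com/meta-llama/llama-api-python) and link the relevant llama-cookbook recipe.
+```
+
+The global rules below can be used as the content of `.github/copilot-instructions.md` so that they apply to every Copilot chat request:
+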
+```
+# Llama Model Global Rules
+
+## Model Usage & Integration
+- Use Meta Llama models (Llama 3, Llama 4 or later) as the default LLM for all tasks requiring a language model.
+- Reference the latest official developer resources for installation, basic usage, and authentication:
+ - Meta Llama Dev Docs: https://www.llama.com/docs/overview/
+ - Official GitHub: https://github.com/meta-llama/llama
+ - Llama models on Hugging Face: https://huggingface.co/meta-llama
+
+## Model Deployment
+- Provide clear suggestions on different types of deployment (cloud API, on-prem, on-device) based on the task and model requirements.
+ - Deploying on-prem: https://github.com/meta-llama/llama-cookbook/blob/main/3p-integrations/llama_on_prem.md
+ - Deploying with Llama Stack: https://llama-stack.readthedocs.io/en/latest/getting_started/index.html
+ - Deploying on-device with ExecuTorch: https://docs.pytorch.org/executorch/stable/llm/llama-demo-android.html
+
+## API Interaction & Migration
+- If migrating from other APIs, automatically convert parameter names and prompt formats to Llama's interface. See:
+  - Llama Migration Guide: https://www.llama.com/docs/llama-everywhere/migration/
+ - Llama API Python SDK: https://github.com/meta-llama/llama-api-python
+- Point out differences in authentication, token limits, and endpoint structure.
+
+## Fine-Tuning & Customization
+- Always provide sample code for fine-tuning Llama models using the latest Meta and Hugging Face recipes:
+ - Llama fine-tuning (Meta How-to): https://www.llama.com/docs/how-to-guides/fine-tuning/
+ - Recipe of Fine-tuning and get-started guide: https://github.com/meta-llama/llama-cookbook/tree/main/getting-started/finetuning
+ - PEFT: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/quickstart_peft_finetuning.ipynb
+ - MultiGPU finetuning: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/multigpu_finetuning.md
+ - Finetuning Llama 4: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/finetune_llama4.md
+ - Finetuning vision Llama: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/finetune_vision_model.md
+
+## Prompt Engineering Practices
+- Use and exemplify the Meta/official Llama chat prompt format (system/user/assistant roles).
+- Reference:
+ - Llama 4 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/
+ - Llama 3.3 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_3/
+ - Llama 3.2 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
+ - Llama 3.1 Prompt Template and Guide: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/
+ - Hugging Face transformer prompting Llama 3: https://huggingface.co/docs/transformers/en/model_doc/llama3
+
+## Quantization
+- Based on task requirements, suggest appropriate quantization schemes and the necessary frameworks to minimize memory footprint without sacrificing model quality.
+- Official Llama Quantization Guide: https://www.llama.com/docs/how-to-guides/quantization/
+
+
+## Developer Best Practices
+- Ensure all generated code is idiomatic, well commented, and follows error-handling and input-validation best practices.
+- For documentation and onboarding, always provide links to the latest relevant guides:
+ - Developer Use Guide: https://www.llama.com/developer-use-guide/
+- Recommend including model info and links (e.g., license, documentation) in project-level README files.
+
+
+## Other Common Use-cases
+Text2SQL: https://www.llama.com/resources/cookbook/text2sql_natural_language_to_sql_interface/
+Tool Calling: https://www.llama.com/resources/cookbook/toolcalling-with-llama/
+Language Expansion: https://www.llama.com/resources/cookbook/extend_llama_to_new_languages/
+Email Agent: https://www.llama.com/resources/cookbook/build-email-agent-with-llama3/
+Multimodal RAG: https://www.llama.com/resources/cookbook/build-multimodal-rag/
+Browser Agent: https://www.llama.com/resources/cookbook/build-an-intelligent-browser-agent/
+Customer Service Chatbot: https://github.com/meta-llama/llama-cookbook/tree/main/end-to-end-use-cases/customerservice_chatbots
+```
diff --git a/end-to-end-use-cases/summarization/summarization.ipynb b/end-to-end-use-cases/summarization/summarization.ipynb
new file mode 100644
index 000000000..803e43b0b
--- /dev/null
+++ b/end-to-end-use-cases/summarization/summarization.ipynb
@@ -0,0 +1,749 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d0ef8fea",
+ "metadata": {},
+ "source": [
+ "# Summarization pipeline with chunking"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a2240877",
+ "metadata": {},
+ "source": [
+ "*Copyright (c) Meta Platforms, Inc. and affiliates.\n",
+ "This software may be used and distributed according to the terms of the Llama Community License Agreement.*"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4eb01d0c",
+ "metadata": {},
+ "source": [
+ "This tutorial shows you how to build a robust summarization pipeline for long documents. We will create an \"Intelligent Summarization Assistant\" that uses Llama 4 to summarize a document that is too long to be processed in a single pass.\n",
+ "\n",
+ "While models like Llama 4 have massive context windows, summarizing extremely long texts can sometimes cause details to be \"lost in the middle.\" To solve this, we will implement the **Map-Reduce** pattern: first, we'll \"map\" a summarization task over smaller, coherent chunks of the text, and then \"reduce\" those individual summaries into a final, high-fidelity overview.\n",
+ "\n",
+ "| Component | Choice | Why |\n",
+ "| :----------------- | :----------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------- |\n",
+ "| **Model** | `Llama-4-Maverick-17B-128E-Instruct-FP8` | A powerful model ideal for high-quality summarization at both the chunk and final summary stages. |\n",
+ "| **Pattern** | Map-Reduce Summarization | A fundamental pattern for processing long documents. We \"map\" a summarization function over each chunk, then \"reduce\" the resulting summaries into a final one. |\n",
+ "| **Infrastructure** | Llama API | Provides access to Llama 4 models using the `llama_api_client` SDK. |\n",
+ "---\n",
+ "\n",
+ "**Note on Inference Providers:** This tutorial uses the Llama API for demonstration purposes. However, you can run Llama 4 models with any preferred inference provider. Common examples include [Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html) and [Together AI](https://together.ai/llama). The core logic of this tutorial can be adapted to any of these providers.\n",
+ "\n",
+ "## What you will learn\n",
+ "\n",
+ "- **How to implement a robust pipeline** for summarizing documents of any length.\n",
+ "- **The foundational \"Map-Reduce\" pattern** for document processing.\n",
+ "- **Techniques for \"semantic chunking\"** to split a document logically while preserving context.\n",
+ "- **How to craft effective, stage-specific prompts** for a multi-step LLM pipeline.\n",
+ "- **How to chain LLM calls** to perform complex, multi-stage tasks.\n",
+ "\n",
+ "## Install dependencies\n",
+ "\n",
+ "You will need two libraries for this project: `tiktoken` for accurate token counting, and the official `llama-api-client`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "175c12af-fa25-4035-afdd-bcfe482b2c5a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install --quiet tiktoken llama-api-client"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a0376ad7-8391-4bc3-b207-c0dd356b0410",
+ "metadata": {},
+ "source": [
+ "## Imports & Llama API client setup\n",
+ "\n",
+ "Import the necessary modules and initialize the `LlamaAPIClient`. This requires a Llama API key to be available as an environment variable. If you do not have a Llama API key, please get one from [Meta Llama API](https://llama.developer.meta.com/). \n",
+ "\n",
+ "Remember, we use the Llama API for this tutorial, but you can adapt this section to use your preferred inference provider."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "a5ac6a27-a662-4445-8e15-0a89aa30587d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os, sys, re\n",
+ "from typing import List\n",
+ "import tiktoken\n",
+ "from llama_api_client import LlamaAPIClient\n",
+ "\n",
+ "# --- Llama client ---\n",
+ "API_KEY = os.getenv(\"LLAMA_API_KEY\")\n",
+ "if not API_KEY:\n",
+ " sys.exit(\"❌ Please set the LLAMA_API_KEY environment variable.\")\n",
+ "\n",
+ "client = LlamaAPIClient(api_key=API_KEY)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "537de7f1-869f-4868-b107-87b4e1fc8c8b",
+ "metadata": {},
+ "source": [
+ "## Step 1: Get the data\n",
+ "\n",
+ "This tutorial uses a markdown version of the Meta research paper, [ASTRO: Teaching Language Models to Reason by Reflecting and Backtracking In-Context](https://ai.meta.com/research/publications/astro-teaching-language-models-to-reason-by-reflecting-and-backtracking-in-context/). The file, `ASTRO-Teaching_Language_Models_to_Reason.md`, is included in the `data` sub-directory of the repository, making it easy to follow along.\n",
+ "\n",
+ "> We are using a markdown file for this tutorial because it preserves the document's structure with headers, which is useful for semantic chunking. If you are working with other formats like PDFs, you can use parsing services like [LlamaParse](https://www.llamaindex.ai/llamaparse) to convert them to markdown."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "ce0721a0-5624-415f-bab2-b0ba1bedab97",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✅ Successfully loaded document: 142,921 characters.\n"
+ ]
+ }
+ ],
+ "source": [
+ "file_path = \"data/ASTRO-Teaching_Language_Models_to_Reason.md\"\n",
+ "\n",
+ "try:\n",
+ " with open(file_path, 'r', encoding='utf-8') as f:\n",
+ " document_text = f.read()\n",
+ "except FileNotFoundError:\n",
+ " raise FileNotFoundError(\n",
+ " f\"Error: The file was not found at {file_path}\"\n",
+ " )\n",
+ "\n",
+ "if document_text:\n",
+ " print(f\"✅ Successfully loaded document: {len(document_text):,} characters.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4a0224bf-b718-4915-ad98-7bf0e737bd27",
+ "metadata": {},
+ "source": [
+ "## Step 2: The logic of chunking\n",
+ "\n",
+ "### Why Chunk?\n",
+ "\n",
+ "For long documents, even with a large context window, summarizing in a single pass can lead to context degradation, where the model may under-weigh details from the middle of the text.\n",
+ "\n",
+ "To ensure all parts of the document are processed with equal focus, we use a **map-reduce** approach. Breaking the document into smaller, coherent chunks for individual summarization guarantees a more detailed and high-quality final result.\n",
+ "\n",
+ "### How to chunk?\n",
+ "\n",
+ "An effective chunking strategy is critical. Simply splitting the text by a fixed token count can break sentences or separate related ideas. A better approach is **semantic chunking**. Our strategy has two levels:\n",
+ "\n",
+ "1. **Header-based splitting:** First, the document is split into large sections based on its markdown headers (`#`, `##`, `###`). This preserves the document's logical structure.\n",
+ "2. **Paragraph-based Chunking:** Each large section is then divided into the final, smaller chunks. This process respects paragraph boundaries and a specified token limit, ensuring the chunks are both semantically coherent and sized appropriately for the LLM.\n",
+ "\n",
+ "> **Note on Generalization:** This tutorial's header-based splitting is optimized for markdown documents. For other formats (like plain text or PDFs), you can generalize this header-based splitting approach by identifying similar structural elements. For instance, you could split by chapter titles, numbered sections, or use regular expressions to find custom patterns that define logical breaks in your document. The principle of multi-level semantic chunking remains the same.\n",
+ "\n",
+ "### Choosing the Right Chunk Size\n",
+ "\n",
+ "While our chunking strategy prioritizes semantic boundaries (headers and paragraphs) over fixed token counts, we still need to set a maximum size for our chunks. This ensures that even the largest semantic chunk fits comfortably within the model's context window.\n",
+ "\n",
+ "The `CHUNK_SIZE_TOKENS` constant serves as this upper limit. Finding the right value is a trade-off:\n",
+ "\n",
+ "* **Set Too High:** The limit might still be larger than the model's context window (once the prompt is included), causing API calls to fail.\n",
+ "* **Set Too Low:** This could force the chunking logic to split paragraphs or other logical units too aggressively, reducing the quality of the summaries. It also increases the number of API calls, leading to higher cost and latency.\n",
+ "\n",
+ "The `16000` token limit in this tutorial is a conservative size for models with large context windows (usually 128k for models available on the Llama API). It leaves ample room for the prompt while ensuring each chunk is large enough to provide meaningful context for summarization.\n",
+ "\n",
+ "> **Note on Local Processing:** All processing up to this point, including loading the data and chunking the text, happens locally. We have not yet made any calls to the Llama API. The token counting is done with a local library to ensure our chunks are the right size for the API calls in the next steps."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "f825f541-c63b-4747-8ba9-7e675da427ab",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total chunks created: 54\n",
+ "Average token count per chunk: 661.94\n",
+ "Max token count in a chunk: 6357\n",
+ "Min token count in a chunk: 3\n",
+ "--------------------------------------------------\n",
+ "Top 5 Chunks:\n",
+ "Chunk 0:\n",
+ "# ASTRO: Teaching Language Models to Reason by Reflecting and Backtracking In-Context\n",
+ "\n",
+ "Joongwon Kim1,2, Anirudh Goyal1, Liang Tan1, Hannaneh Hajishirzi2, Srini Iyer1, Tianlu Wang1\n",
+ "\n",
+ "1AI at Meta, 2University of Washington\n",
+ "\n",
+ "We introduce Astro, the \"Autoregressive Search-Taught Reasoner\", a framework for training language models to reason like search algorithms, explicitly leveraging self-reflection, backtracking, and exploration in their outputs. Recently, training large language models (LLMs) via reinforcement learning (RL) has led to the advent of reasoning models with greatly enhanced reasoning capabilities. Open-source replications of reasoning models, while successful, build upon models that already exhibit strong reasoning capabilities along with search behavior observed even before RL. As a result, it is yet unclear how to boost the reasoning capabilities of other non-reasoner models including Llama 3. Astro teaches such models to internalize structured search behavior through a synthetic dataset derived from Monte Carlo Tree Search (MCTS) over mathematical problem-solving trajectories. By converting search traces into natural language chain-of-thoughts that capture both successes and recoveries from failure, Astro bootstraps models with a rich prior for exploration during RL. We finetune our models on these search-derived traces and further improve performance via RL with verifiable rewards. We apply Astro to the Llama 3 family of models and achieve absolute performance gains of 16.0% on MATH-500, 26.9% on AMC 2023, and 20.0% on AIME 2024, especially improving upon challenging problems that require iterative correction. Our results demonstrate that search-inspired training offers a principled way to instill robust reasoning capabilities into open LLMs.\n",
+ "\n",
+ "Date: June 23, 2025\n",
+ "Correspondence: Joongwon Kim at jwonkim@meta.com\n",
+ "\n",
+ "| **stepwise solutions** ## Step 1: Define the problem and identify what we need to find. We need to find the time it takes for Aya to complete her walk and stop at the coffee shop when walking at a speed of $s + \\frac{1}{2}$ kilometers per hour, including the time $t$ spent in the coffee shop. ## Step 2: Set up the equations based on the information given. Let's denote the total time for the walk and coffee shop at speed $s$ as 4 hours or 240 minutes, and at speed $s+2$ as 2 hours and 24 minutes, or 144 minutes ... The final answer is \\boxed{398}. Llama-3.1-70B-Instruct | **long CoT solutions** Procedure Cloning SFT RL ASTRO Let's begin by finding the time that it takes for Aya to complete her walk and stop at the coffee shop ... But wait, are we solving the problem correctly so far? Hmm... Our solution may not be correct so far. Let's go back to where we set up the equations ... Therefore Aya spent a total of 204 minutes. But wait, are we solving the problem correctly so far? Hmm... Our solution seems to be correct so far. The final answer is \\boxed{204}. Llama-3.1-70B-ASTRO-RL | \tMATH-500\tAMC 2023\tAIME 2024Llama-3.1-70B-Instruct
Llama-3.1-70B-ASTRO-SFT\t+16.0%\t\t
Llama-3.1-70B-ASTRO-RL\t\t+26.9%\t+20.0% |\n",
+ "| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------- |\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Figure 1 Astro teaches Llama-3.1-70B-Instruct to perform self-reflection and backtracking in-context and improves its mathematical reasoning, achieving 81.8% on MATH-500, 64.4% on AMC 2023 and 30.0% on AIME 2024 (pass@1).\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------\n",
+ "Chunk 1:\n",
+ "## 1 Introduction\n",
+ "\n",
+ "Training large language models (LLMs) via reinforcement learning (RL) has greatly improved their reasoning capabilities, leading to the advent of reasoning models such as OpenAI o1 (OpenAI, 2024), DeepSeek-R1 (DeepSeek-AI, 2025) or Gemini 2.5 (Google, 2025). A prominent feature of reasoning models is their ability to iteratively refine their outputs with a behavior similar to search – a process which involves reflecting on their own outputs and backtracking to a previous state (Xiang et al., 2025). While open-source replications of reasoning models achieve notable performance improvements, they rely on distillation from existing reasoning\n",
+ "\n",
+ "---\n",
+ "\n",
+ "\n",
+ "\n",
+ "Diagram showing three stages: Monte Carlo Tree Search, Procedure Cloning, and Reinforcement Learning\n",
+ "\n",
+ "Figure 2 An overview of Astro. Given a math reasoning problem, we first perform Monte Carlo Tree Search (MCTS) in a stepwise manner with verifiable rewards and obtain a search tree where each node contains a discrete reasoning step with its associated Q-value. We then linearize the visited sequence of nodes, including intermediate nodes with incorrect answers, into a solution that integrates backtracking and self-reflection in natural language. Then, we perform supervised fine-tuning (SFT) on the search-integrated solutions and bootstrap our policy to perform autoregressive search. Finally, we further improve the policy's search and reasoning capabilities with reinforcement learning (RL).\n",
+ "\n",
+ "models (Li et al., 2025; Muennighoff et al., 2025) or direct RL (Hu et al., 2025; Yu et al., 2025) from LLMs that (1) already contain reflective behavior and strong reasoning capabilities (Chang et al., 2025; Liu et al., 2025), and (2) exhibit spurious performance gains from incorrect or noisy reward signals during RL (Lv et al., 2025; Shao et al., 2025). Hence it is unclear from a scientific perspective how reasoning models can be built from other LLMs that do not exhibit the aforementioned behavior, such as Llama 3 (AI at Meta, 2024).\n",
+ "\n",
+ "We introduce ASTRO, the \"Autoregressive Search-Taught Reasoner\", a framework that systematically infuses search-like behavior into language models ab initio to improve their reasoning capabilities. The fundamental principle guiding Astro is search, where our policy explores the solution space by selecting actions, reflecting on its own solution, and backtracking to a previous step if needed. Astro trains language models to perform autoregressive search – instead of using external search scaffolds such as beam search to solve reasoning problems, Astro internalizes the search procedure and generates entire search trajectories, including reflections and backtracks, in a single inference pass. Models trained using Astro exhibit improved reasoning abilities by frequently re-evaluating their solutions and backtracking until they reach a final answer of high confidence. Moreover, such models generate structured reasoning traces that can be mapped to a directed graph with each vertex representing a discrete reasoning step, allowing for a richer understanding of their reasoning processes.\n",
+ "\n",
+ "Astro operates in three stages: (1) search trajectory generation, (2) supervised fine-tuning and (3) reinforcement learning. We initially bootstrap our models with search behavior by generating search trajectories to be used for training data via procedure cloning (Yang et al., 2022; Laskin et al., 2022) – we perform search with custom scaffolding over our language model policy to explore over different solution trajectories for each math problem, and we train our policy without using scaffolds at test time to predict the entire sequence of actions, including intermediate actions that lead to incorrect answers, that ultimately end with a successful terminal state. Then, we further optimize our policy via RL to improve their reasoning and search capabilities. Astro provides beneficial priors for RL during its data generation stage by systematically injecting self-reflection and backtracking patterns to the search traces via procedure cloning.\n",
+ "\n",
+ "First, we generate synthetic data, also called the cold-start data (DeepSeek-AI, 2025; Qwen, 2025), to instill autoregressive search priors to our models. To this end, we use Monte Carlo Tree Search (MCTS) to explore the solution space of challenging math problems and build search trees with diverse reasoning traces. We linearize each search tree into a sequence of nodes visited during the search process, which we then convert into a natural language chain-of-thought (CoT, Wei et al. (2022)) that integrates self-reflection and backtracking. Then, we sample about 36K high-quality CoT solutions across three open-source math datasets.\n",
+ "\n",
+ "2\n",
+ "\n",
+ "---\n",
+ "\n",
+ "\n",
+ "\n",
+ "| Monte Carlo Tree Search (MCTS) ⇒ Linearization ⇒ Procedure Cloning in Language | | |\n",
+ "| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n",
+ "| Problem: A box $2$ centimeters high, $3$ centimeters wide, and $5$ centimeters long can hold $40$ grams of clay. A second box with twice the height, three times the width, and the same length as the first box can hold $n$ grams of clay. What is $n$? | Node 3 | Let's determine the dimensions of the second box by calculating its height, which is twice the height of the first box, so its height is $2 \\times 2 = 4$ centimeters, and its width, which is three times the width of the first box, so its width is $3 \\times 3 = 9$ centimeters, while noting that the length of the second box is the same as the first box, which ... (solution omitted) ... |\n",
+ "| ## Step 1: Determine the dimensions of the second box (Node 3, Q=0.962) The second box has twice the height of the first box, so its height is $2 \\times 2 = 4$ centimeters. It has three times the width of the first box, so its width is $3 \\times 3 = 9$ centimeters. The length of the second box is the same as the first box, which is $5$ centimeters. | Node 16 | Let's use this information to calculate n, the amount of clay that the second box can hold, using the formula n = 40 \\* (volume of second box / volume of first box) = 40 \\* (4 \\* 9 \\* 5) / (2 \\* 3 \\* 5) = 40 \\* (12) = 40 \\* 6 = 480 grams. But wait, are we correctly solving the problem so far? Hmm... I'm not sure if our solution is correct so far. Let's go back to where we calculated the volume of the second box to be 180 cubic centimeters. Given the established dimensions of the boxes, let's calculate the volume of the first box, where its volume is ... |\n",
+ "| ## Step 2: Calculate the volume of the second box (Node 16, Q=0.934) The second box has twice the height (2 \\* 2 = 4 cm), three times the width (3 \\* 3 = 9 cm), and the same length (5 cm) as the first box. So, the volume of the second box = height \\* width \\* length = 4 \\* 9 \\* 5 = 180 cubic centimeters. | Node 43 | (solution omitted) ... the second box is 6 times larger than the first box, it follows that the second box can hold 6 times as much clay as the first box. Therefore, to find n, the amount of clay the second box can hold, we calculate n = 6 \\* 40 = 240 grams. But wait, are we correctly solving the problem so far? Hmm... Our solution seems to be correct so far. |\n",
+ "| ## Step 3: Calculate the volume of the first box (Node 43, Q=0.768) The volume of the first box = height \\* width \\* length = 2 \\* 3 \\* 5 = 30 cubic centimeters. ## Step 4: Calculate n, the amount of clay that the second box can hold (Node 68, Q=0.0) The ratio of the volumes of the two boxes will equal the ratio of the amounts of clay that they can hold. ... (solution omitted) ... Therefore, n = 40 \\* (volume of second box / volume of first box) = 40 \\* (4 \\* 9 \\* 5) / (2 \\* 3 \\* 5) = 40 \\* (12) = 40 \\* 6 = 480 grams. | Node 68 | |\n",
+ "| ## Step 3: Calculate the volume of the first box (Node 44, Q=0.874) The volume of the first box = height \\* width \\* length = 2 \\* 3 \\* 5 = 30 cubic centimeters. ## Step 4: Calculate the ratio of the volume of the second box to the first box (Node 85, Q=1.0) ... (solution omitted) ... The ratio = volume of the second box / volume of the first box = 180 / 30 = 6. | Node 44 | |\n",
+ "| | Node 85 | |\n",
+ "| ## Step 5: Calculate the amount of clay the second box holds (Node 96, Q=1.0) Since the second box is 6 times larger than the first box, it can hold 6 times as much clay as the first box. Therefore, n = 6 \\* 40 = 240 grams. | Node 96 | The final answer is: $\\boxed{240}$ |\n",
+ "\n",
+ "\n",
+ "Figure 3 Example of search trajectory generation via procedure cloning. We use the policy to search through the solution space via MCTS and keep track of terminal nodes with incorrect answers, as well as terminal nodes with correct answers. Then, we linearize the search tree such that it incorporates backtracking from the incorrect terminal node (Node 68) to its greatest common ancestor (Node 16) with the correct terminal node (Node 96). Finally, we rewrite the node sequence into a long chain-of-thought, injecting self-reflection and backtracking phrases into the CoTs.\n",
+ "\n",
+ "We then perform supervised fine-tuning (SFT) to infuse autoregressive search behavior into the Llama 3 family of models (AI at Meta, 2024). After fine-tuning for just one epoch, our SFT checkpoint based on llama-3.1-70b-instruct achieves 69.6% on MATH-500, 55.0% on AMC 2023 and 13.3% on AIME 2024, and outperforms its counterpart trained on the same set of problems but without search priors. Our qualitative analyses show that even simply performing SFT with high-quality search traces can infuse search capabilities, including backtracking and self-reflection behavior, into a language model.\n",
+ "\n",
+ "Finally, we perform reinforcement learning (RL) on our models to further improve their reasoning capabilities. Our training prompts are derived from open-source math problems of moderate to high difficulties for our policies. We use a modified form of Group Relative Policy Optimization (GRPO, Shao et al. (2024)) that is very similar to that of Dr. GRPO (Liu et al., 2025) to update our policies. After RL, our policy based on llama-3.1-70b-instruct achieves 81.8% in MATH-500, 64.4% in AMC 2023 and 30.0% in AIME 2024 (pass@1). We show that our model trained end-to-end using Astro outperforms its counterpart similarly optimized with RL but initialized from a SFT checkpoint trained without search priors – this demonstrates the importance of leveraging self-reflection and backtracking as priors for improving reasoning via RL. Our work provides a clear recipe for improving the reasoning capabilities of language models by instilling autoregressive search priors with SFT and leveraging such priors to further improve the models via RL.\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------\n",
+ "Chunk 2:\n",
+ "## 2 Search Trajectory Generation\n",
+ "\n",
+ "Astro begins by generating a dataset of search traces, expressed as long chain-of-thoughts (Wei et al., 2022) that encode self-reflection and backtracking in natural language, via procedure cloning. To this end, we first obtain search trees that explore a wide solution space for each math problem using Monte Carlo Tree Search (MCTS) in a stepwise manner, strategically balancing exploration and exploitation with verifier-based rewards to obtain diverse and high-quality solutions exploring different reasoning traces (Section 2.2).\n",
+ "\n",
+ "We then linearize the search trees into sequences of nodes that explore various states, including intermediate nodes with incorrect answers, until arriving at a high-quality solution leading to the correct answer (Section 2.3). Finally, we translate each node sequence into a chain-of-thought that integrates self-reflection and backtracking in natural language, and we add each long chain-of-thought to our final dataset (Section 2.4). The resulting dataset encodes beneficial self-reflection and backtracking priors for training language models to perform autoregressive search for solving challenging math problems via supervised fine-tuning and reinforcement learning (Section 3). Refer to Figure 3 for a visual example of our search trajectory generation pipeline.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------\n",
+ "Chunk 3:\n",
+ "## 2.1 Problem Formulation and Overview\n",
+ "\n",
+ "**Problem formulation.** Our data generation setup is a Markov Decision Process (MDP) (Puterman, 1994), where the language model functions as the policy ΠLM and explores the solution space to the input x, while obtaining rewards in terminal states from a verifier V based on the correct answer. Here we assume that ΠLM solves math problems in a stepwise manner, where each step st represents a sequence of tokens y1 · · · y|st| encapsulating a minimal unit of reasoning required to solve x. Then, each state St represents a combination of the input prompt and the sequence of steps generated by the policy, i.e. St = (x, s0, · · · , st). Meanwhile, the action at+1 represents the next step st+1 taken by ΠLM to address x. Refer to Figure 3 for examples of the steps defined in our setup.\n",
+ "\n",
+ "Given this setup, we teach a language model to predict a sequence of states (S0 · · · Send) in response to x such that the states explore reasoning steps leading to correct and incorrect answers, until the LM arrives at Send and terminates its search by accepting the correct answer as its final answer.\n",
+ "\n",
+ "**Overview.** We generate training data for Astro in three main stages outlined below:\n",
+ "\n",
+ "1. For each x we generate a search tree T, where each node ni represents the state Si and each edge (ni, nj) represents the action aj, i.e. the next step sj taken from Si to Sj, using Monte Carlo Tree Search (MCTS) to explore the solution space based on verifier-based rewards from rollouts (Section 2.2).\n",
+ "\n",
+ "2. We linearize T into a sequence of nodes L = (n0, · · · , nend), a subsequence of the entire history of nodes visited by ΠLM until arriving at nend, the terminal node with the correct answer. Some adjacent pairs of nodes (nt, nt+1) in L are such that nt+1 is an ancestor of nt in T, which corresponds to self-reflection and backtracking during the search procedure (Section 2.3).\n",
+ "\n",
+ "3. We translate L into a chain-of-thought solution y = (y0, · · · , yend) that integrates self-reflection and backtracking in natural language, and we add (x, y) to our final dataset (Section 2.4).\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------\n",
+ "Chunk 4:\n",
+ "## 2.2 Monte Carlo Tree Search\n",
+ "\n",
+ "We use our language model policy ΠLM to obtain a search tree with diverse solution traces to each input x by running Monte Carlo Tree Search (MCTS). By using MCTS, we explore a diverse solution space while balancing exploration and exploitation with reliable guidance from reward signals obtained from full rollouts. Here, we prompt x to elicit stepwise solutions from ΠLM, and assign reward scores with our verifier V to compare the predicted answer with the correct answer.\n",
+ "\n",
+ "Monte Carlo Tree Search employs three main stages – selection, expansion and backpropagation – to select promising next steps, expand the search tree, and update the quality metric of each reasoning step.\n",
+ "\n",
+ "**Selection.** At state St with k actions generated by ΠLM from St, we balance exploration and exploitation to select the most promising node from which to further perform tree search. We use the Predictor+Upper Confidence bounds applied to Trees (PUCT, Silver et al. (2016)) for selection to balance exploration and exploitation during tree search. From any state St, given the action index i ∈ [1...k], the quality score of taking action ai from state St – Q(St, ai), the total visit count of St – N(St), and the visit count of taking action ai from St – N(St, ai), we perform selection as:\n",
+ "\n",
+ "$$S^*_{t+1} = \\underset{(S_{t+1}=S_t \\rightarrow a_i)}{\\text{argmax}} \\left[Q(S_t, a_i) + c_{\\text{puct}} \\cdot \\Pi_{\\text{LM}}(a_i|S_t)\\sqrt{\\frac{N(S_t)}{1 + N(S_t, a_i)}}\\right]$$\n",
+ "\n",
+ "**Expansion.** From state St, ΠLM takes x and the sequence of steps (s0, · · · , st) as the input, and first samples k actions which each correspond to the next step for solving x. For each action, we sample M rollouts and score the full solution using V to match the predicted answer with the reference answer. Then, we average the scores across the rollouts for each new action ai (i ∈ [1...k]) to compute the reward scores for the new states. We add a new node nt+1, associated with each new state St+1, to T.\n",
+ "\n",
+ "$$R(S_{t+1}) = \\frac{1}{M} \\sum_{j\\in[1...M]} V(\\Pi_{\\text{LM},j}(S_{t+1}))$$\n",
+ "\n",
+ "---\n",
+ "\n",
+ "\n",
+ "\n",
+ "Backpropagation. We backpropagate the reward scores obtained during expansion from the leaf node to the root node to recursively update their Q-values. The updates consist of (1) incrementing the visit count of each state (Eq. 3), and (2) updating the Q-values of each (state, action) pair using the Q-values and visit counts of the children nodes of St+1 = (St, a), along with the rollout-based reward score R(St+1) (Eq. 4).\n",
+ "\n",
+ "$$N(s_t) = N(s_t) + 1$$\n",
+ "\n",
+ "$$Q(S_t, a) = \\frac{\\sum_{i=1}^K Q(S_{t+1}, a_i) \\cdot N(S_{t+1}, a_i) + R(S_{t+1})}{\\sum_{i=1}^K N(S_{t+1}, a_i) + 1}$$\n",
+ "\n",
+ "We repeat the procedure above for multiple iterations to explore the solution space for each math problem and build the search trees. We use llama-3.3-70b-instruct as our policy ΠLM and generate k = 8 actions during each expansion step with M = 16 rollouts, cpuct = 1.0, 32 iterations and maximum tree depth of 50.\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "# --- Constants & Configuration ---\n",
+ "ENCODING_MODEL = \"o200k_base\"\n",
+ "CHUNK_SIZE_TOKENS = 16000 # A practical chunk size\n",
+ "\n",
+ "def count_tokens(text: str, encoding: tiktoken.Encoding) -> int:\n",
+ " \"\"\"Helper function to count tokens in a string.\"\"\"\n",
+ " return len(encoding.encode(text))\n",
+ "\n",
+ "def chunk_document(\n",
+ " markdown_text: str,\n",
+ " chunk_size: int = CHUNK_SIZE_TOKENS,\n",
+ " headers_to_split_on: List[str] = [\"#\", \"##\", \"###\"]\n",
+ ") -> List[str]:\n",
+ " \"\"\"\n",
+ " Chunks a markdown document, preserving header context for each chunk.\n",
+ " \"\"\"\n",
+ " # 1. Split the document by headers to get sections\n",
+ " header_pattern = \"|\".join(f\"^{h}\\\\s\" for h in headers_to_split_on)\n",
+ " sections = re.split(f\"({header_pattern})\", markdown_text, flags=re.MULTILINE)\n",
+ " if sections and not sections[0].strip():\n",
+ " sections.pop(0)\n",
+ "\n",
+ " if len(sections) > 1:\n",
+ " sections = list(zip(sections[0::2], sections[1::2]))\n",
+ " else:\n",
+ " sections = []\n",
+ "\n",
+ " encoding = tiktoken.get_encoding(ENCODING_MODEL)\n",
+ " final_chunks = []\n",
+ "\n",
+ " # 2. Process each section\n",
+ " for header, content in sections:\n",
+ " header_token_count = count_tokens(header, encoding)\n",
+ " \n",
+ " if header_token_count + count_tokens(content, encoding) <= chunk_size:\n",
+ " final_chunks.append(header + content)\n",
+ " continue\n",
+ "\n",
+ " # Split the content by paragraphs\n",
+ " paragraphs = content.split('\\n\\n')\n",
+ " current_chunk_paragraphs = []\n",
+ " current_chunk_tokens = header_token_count\n",
+ "\n",
+ " for para in paragraphs:\n",
+ " para_tokens = count_tokens(para, encoding)\n",
+ "\n",
+ " # If a paragraph is too large to fit with the header, it must be truncated.\n",
+ " if header_token_count + para_tokens > chunk_size:\n",
+ " available_tokens = chunk_size - header_token_count\n",
+ " para_token_ids = encoding.encode(para)\n",
+ " truncated_ids = para_token_ids[:available_tokens]\n",
+ " para = encoding.decode(truncated_ids, errors='ignore')\n",
+ " para_tokens = len(truncated_ids)\n",
+ " print(f\"Warning: Truncating a paragraph to {para_tokens} \"\n",
+ " f\"tokens to fit the chunk size.\")\n",
+ "\n",
+ " # If the current chunk is not empty and the new paragraph doesn't fit,\n",
+ " # finalize the current chunk before starting a new one.\n",
+ " if (current_chunk_paragraphs and \n",
+ " (current_chunk_tokens + para_tokens > chunk_size)):\n",
+ " final_chunks.append(header + \"\\n\\n\".join(current_chunk_paragraphs))\n",
+ " current_chunk_paragraphs = []\n",
+ " current_chunk_tokens = header_token_count\n",
+ "\n",
+ " current_chunk_paragraphs.append(para)\n",
+ " current_chunk_tokens += para_tokens\n",
+ "\n",
+ " # Add the last remaining chunk\n",
+ " if current_chunk_paragraphs:\n",
+ " final_chunks.append(header + \"\\n\\n\".join(current_chunk_paragraphs))\n",
+ " \n",
+ " return final_chunks\n",
+ "\n",
+ "# Now, let's chunk our document\n",
+ "chunks = chunk_document(document_text)\n",
+ "\n",
+ "# --- Print Statistics and a Sample Chunk ---\n",
+ "if chunks:\n",
+ " print(f\"Total chunks created: {len(chunks)}\")\n",
+ " encoding = tiktoken.get_encoding(ENCODING_MODEL)\n",
+ " token_counts = [count_tokens(chunk, encoding) for chunk in chunks]\n",
+ " avg_tokens = sum(token_counts) / len(token_counts)\n",
+ " print(f\"Average token count per chunk: {avg_tokens:.2f}\")\n",
+ " print(f\"Max token count in a chunk: {max(token_counts)}\")\n",
+ " print(f\"Min token count in a chunk: {min(token_counts)}\")\n",
+ " print(\"-\" * 50)\n",
+ " print(\"Top 5 Chunks:\")\n",
+ " for i, chunk in enumerate(chunks[:5]):\n",
+ " print(f\"Chunk {i}:\")\n",
+ " print(chunk)\n",
+ " print(\"-\" * 50)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "843c4267-b959-4e37-9d6f-bc30ae628574",
+ "metadata": {},
+ "source": [
+ "## Step 3: The \"map\" stage - summarizing each chunk\n",
+ "\n",
+ "With the document split into manageable, semantically coherent chunks, we can begin the \"Map\" stage. This means we apply the same operation—in this case, summarization—to each chunk independently.\n",
+ "\n",
+ "### Prompt engineering\n",
+ "\n",
+ "The quality of the summaries depends heavily on the quality of the prompts. For this stage, the prompt must instruct the model to create a summary of a small piece of a larger document. It is crucial to tell the model to focus *only* on the provided text and not to add outside information."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "678a1afc-3a17-419d-a424-549a0788b8a8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Summary of chunk 0:\n",
+ "- ASTRO is a framework for training language models to reason like search algorithms.\n",
+ "- ASTRO leverages self-reflection, backtracking, and exploration in language model outputs.\n",
+ "- ASTRO uses a synthetic dataset derived from Monte Carlo Tree Search (MCTS) over mathematical problem-solving trajectories.\n",
+ "- The framework finetunes models on search-derived traces and improves performance via reinforcement learning (RL) with verifiable rewards.\n",
+ "- ASTRO is applied to the Llama 3 family of models.\n",
+ "- Absolute performance gains achieved: 16.0% on MATH-500, 26.9% on AMC 2023, and 20.0% on AIME 2024.\n",
+ "- Llama-3.1-70B-ASTRO-RL achieves 81.8% on MATH-500, 64.4% on AMC 2023, and 30.0% on AIME 2024 (pass@1).\n",
+ "--------------------------------------------------\n",
+ "Summary of chunk 1:\n",
+ "- ASTRO is a framework that infuses search-like behavior into language models to improve their reasoning capabilities.\n",
+ "- ASTRO operates in three stages: search trajectory generation, supervised fine-tuning, and reinforcement learning.\n",
+ "- Search trajectory generation uses Monte Carlo Tree Search (MCTS) to explore the solution space of math problems and builds search trees with diverse reasoning traces.\n",
+ "- About 36K high-quality chain-of-thought (CoT) solutions are sampled across three open-source math datasets.\n",
+ "- Supervised fine-tuning (SFT) is performed on the search-integrated solutions to infuse autoregressive search behavior into the models.\n",
+ "- The SFT checkpoint based on llama-3.1-70b-instruct achieves 69.6% on MATH-500, 55.0% on AMC 2023, and 13.3% on AIME 2024 after fine-tuning for one epoch.\n",
+ "- Reinforcement learning (RL) is performed using a modified form of Group Relative Policy Optimization (GRPO) to further improve the models' reasoning capabilities.\n",
+ "- After RL, the policy based on llama-3.1-70b-instruct achieves 81.8% in MATH-500, 64.4% in AMC 2023, and 30.0% in AIME 2024 (pass@1).\n",
+ "--------------------------------------------------\n",
+ "Summary of chunk 2:\n",
+ "* Astro generates a dataset of search traces via procedure cloning.\n",
+ "* Search trees are obtained using Monte Carlo Tree Search (MCTS) with verifier-based rewards.\n",
+ "* Search trees are linearized into sequences of nodes exploring various states.\n",
+ "* Node sequences are translated into chains-of-thought integrating self-reflection and backtracking in natural language.\n",
+ "* The resulting dataset encodes self-reflection and backtracking priors for training language models.\n",
+ "* The dataset is used for supervised fine-tuning and reinforcement learning to solve math problems.\n",
+ "--------------------------------------------------\n",
+ "Summary of chunk 3:\n",
+ "* The data generation setup is a Markov Decision Process (MDP).\n",
+ "* The language model functions as the policy ΠLM and explores the solution space to the input x.\n",
+ "* Each state St represents a combination of the input prompt and the sequence of steps generated by the policy.\n",
+ "* The goal is to teach a language model to predict a sequence of states (S0 · · · Send) in response to x.\n",
+ "* Training data for Astro is generated in three main stages: \n",
+ " 1. Generating a search tree T using Monte Carlo Tree Search (MCTS).\n",
+ " 2. Linearizing T into a sequence of nodes L.\n",
+ " 3. Translating L into a chain-of-thought solution y that integrates self-reflection and backtracking in natural language.\n",
+ "--------------------------------------------------\n",
+ "Summary of chunk 4:\n",
+ "- Monte Carlo Tree Search (MCTS) is used with language model policy ΠLM to obtain a search tree with diverse solution traces.\n",
+ "- MCTS involves three stages: selection, expansion, and backpropagation.\n",
+ "- Selection uses Predictor+Upper Confidence bounds applied to Trees (PUCT) to balance exploration and exploitation.\n",
+ "- The selection formula is: $$S^*_{t+1} = \\underset{(S_{t+1}=S_t \\rightarrow a_i)}{\\text{argmax}} \\left[Q(S_t, a_i) + c_{\\text{puct}} \\cdot \\Pi_{\\text{LM}}(a_i|S_t)\\sqrt{\\frac{N(S_t)}{1 + N(S_t, a_i)}}\\right]$$\n",
+ "- Expansion involves sampling k actions, scoring full solutions using verifier V, and averaging scores across M rollouts.\n",
+ "- The reward score formula is: $$R(S_{t+1}) = \\frac{1}{M} \\sum_{j\\in[1...M]} V(\\Pi_{\\text{LM},j}(S_{t+1}))$$\n",
+ "- Backpropagation updates Q-values and visit counts using equations: \n",
+ " $$N(s_t) = N(s_t) + 1$$\n",
+ " $$Q(S_t, a) = \\frac{\\sum_{i=1}^K Q(S_{t+1}, a_i) \\cdot N(S_{t+1}, a_i) + R(S_{t+1})}{\\sum_{i=1}^K N(S_{t+1}, a_i) + 1}$$\n",
+ "- The policy ΠLM used is llama-3.3-70b-instruct.\n",
+ "- Parameters used are: k = 8, M = 16, cpuct = 1.0, 32 iterations, and maximum tree depth of 50.\n",
+ "--------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "LLM_MODEL = \"Llama-4-Maverick-17B-128E-Instruct-FP8\"\n",
+ "DOC_TITLE = (\"ASTRO: Teaching Language Models to Reason by Reflecting and \"\n",
+ " \"Backtracking In-Context\")\n",
+ "\n",
+ "MAP_PROMPT = \"\"\"\n",
+ "Your role is to create a concise, factual summary of a text chunk from the \n",
+ "research paper titled \"{document_title}\".\n",
+ "- Extract only key facts, figures, and statements from the chunk text itself.\n",
+ "- Omit any conversational introductions or conclusions. Do not explain what you \n",
+ " are doing.\n",
+ "- If a chunk contains no substantive information (e.g., only headers, formatting, \n",
+ " or boilerplate), output the exact phrase: \"No substantive information.\"\n",
+ "\n",
+ "**Text Chunk:**\n",
+ "{chunk_text}\n",
+ "\"\"\"\n",
+ "\n",
+ "def map_summarize_chunk(chunk_text: str, document_title: str) -> str:\n",
+ " \"\"\"\n",
+ " Summarizes a single chunk of text using the 'map' prompt.\n",
+ " \"\"\"\n",
+ " try:\n",
+ " resp = client.chat.completions.create(\n",
+ " model=LLM_MODEL,\n",
+ " messages=[\n",
+ " {\"role\": \"user\", \"content\": MAP_PROMPT.format(\n",
+ " document_title=document_title, chunk_text=chunk_text)},\n",
+ " ],\n",
+ " temperature=0.1, # Low temperature for deterministic summaries\n",
+ " )\n",
+ " return resp.completion_message.content.text\n",
+ " except Exception as e:\n",
+ " print(f\" Error summarizing chunk: {e}\")\n",
+ " return \"\" # Return empty string on failure\n",
+ "\n",
+ "# Let's test the map function on the first few chunks\n",
+ "if chunks:\n",
+ " for i, chunk in enumerate(chunks[:5]):\n",
+ " summary = map_summarize_chunk(chunk, DOC_TITLE)\n",
+ " print(f\"Summary of chunk {i}:\")\n",
+ " print(summary)\n",
+ " print(\"-\" * 50)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f8c221de-0c8f-435f-a872-d65acb622b2f",
+ "metadata": {},
+ "source": [
+ "## Step 4: The \"reduce\" stage: creating the final summary\n",
+ "\n",
+ "With the \"map\" stage complete, we now have a list of individual summaries for each chunk. The \"reduce\" stage combines these into a single, coherent executive summary.\n",
+ "\n",
+ "### Prompt engineering for synthesis\n",
+ "\n",
+ "The prompt for this stage is different. We are no longer just summarizing; we are *synthesizing*. The prompt instructs the model to weave the individual points from the chunk summaries into a flowing, well-written narrative."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "65ef4d59-6ec6-4de4-8214-0a5632bdacf7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "REDUCE_PROMPT = \"\"\"\n",
+ "You are a research assistant tasked with creating an executive summary.\n",
+ "You have been given a series of concise summaries from different sections of a \n",
+ "research paper.\n",
+ "Your goal is to synthesize these individual summaries into a single, well-written, \n",
+ "and coherent executive summary.\n",
+ "The final summary should read like a standalone document, flowing logically from \n",
+ "one topic to the next.\n",
+ "\n",
+ "**Summaries of Report Sections:**\n",
+ "{chunk_summaries}\n",
+ "\"\"\"\n",
+ "\n",
+ "MAX_CONTEXT_WINDOW = 100000\n",
+ "\n",
+ "def reduce_create_final_summary(chunk_summaries: List[str]) -> str:\n",
+ " \"\"\"\n",
+ " Combines chunk summaries into a final executive summary using the 'reduce' prompt.\n",
+ " \"\"\"\n",
+ " summaries_text = \"\\\\n\\\\n---\\\\n\\\\n\".join(chunk_summaries)\n",
+ " \n",
+ " encoding = tiktoken.get_encoding(ENCODING_MODEL)\n",
+ " if count_tokens(summaries_text, encoding) > MAX_CONTEXT_WINDOW:\n",
+ " # For this tutorial, we'll truncate to fit. A more advanced implementation\n",
+ " # might run another map-reduce pass (recursive reduction).\n",
+ " print(\"Warning: Combined summaries are too large; will be truncated for \"\n",
+ " \"final summary.\")\n",
+ " tokens = encoding.encode(summaries_text)\n",
+ " truncated_tokens = tokens[:MAX_CONTEXT_WINDOW]\n",
+ " summaries_text = encoding.decode(truncated_tokens, errors='ignore')\n",
+ "\n",
+ " try:\n",
+ " resp = client.chat.completions.create(\n",
+ " model=LLM_MODEL,\n",
+ " messages=[\n",
+ " {\"role\": \"user\", \"content\": REDUCE_PROMPT.format(\n",
+ " chunk_summaries=summaries_text)},\n",
+ " ],\n",
+ " temperature=0.3, # Slightly higher for more fluid, natural writing\n",
+ " )\n",
+ " return resp.completion_message.content.text\n",
+ " except Exception as e:\n",
+ " print(f\" Error creating final summary: {e}\")\n",
+ " return \"\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9654471a-d764-4b25-9d03-23a2242cbd4f",
+ "metadata": {},
+ "source": [
+ "## Step 5: Bringing it all together\n",
+ "\n",
+ "The following code runs the full pipeline:\n",
+ "1. **Map:** Iterate through a subset of our chunks and generate a summary for each one.\n",
+ "2. **Reduce:** Take all the generated chunk summaries and synthesize them into our final executive summary.\n",
+ "\n",
+ "To keep this tutorial fast and interactive, we'll only process the first 25 chunks. In a production scenario, you would process all chunks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "dc7c344a-fa6d-4422-a6c8-72fb7c58dd61",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--- MAP: Summarizing 25 individual chunks ---\n",
+ "\\nSuccessfully summarized 25 chunks.\n",
+ "\\nOriginal token count: 19,127\n",
+ "Summarized token count: 4,163\n",
+ "Compression rate: 78.23%\n",
+ "\\n--- REDUCE: Creating final summary ---\n",
+ "\\n==================================================\n",
+ " FINAL EXECUTIVE SUMMARY\n",
+ "==================================================\n",
+ "Here is a synthesized executive summary based on the provided summaries:\n",
+ "\n",
+ "**Executive Summary**\n",
+ "\n",
+ "This report introduces ASTRO, a novel framework designed to enhance the reasoning capabilities of language models by infusing search-like behavior into their outputs. ASTRO operates in three stages: data generation using Monte Carlo Tree Search (MCTS), supervised fine-tuning (SFT), and reinforcement learning (RL). The framework leverages self-reflection, backtracking, and exploration in language model outputs to improve their performance on mathematical problem-solving tasks.\n",
+ "\n",
+ "The data generation stage utilizes MCTS to build search trees, which are then linearized into node sequences and translated into long Chain-of-Thoughts (CoTs) that integrate self-reflection and backtracking in natural language. The resulting dataset is used for SFT and RL to fine-tune the Llama 3 family of models.\n",
+ "\n",
+ "The ASTRO-trained models demonstrate significant performance gains on various mathematical benchmarks, including MATH-500, AMC 2023, and AIME 2024. Specifically, Llama-3.1-70B-ASTRO-RL achieves 81.8% on MATH-500, 64.4% on AMC 2023, and 30.0% on AIME 2024 (pass@1). The models also exhibit improved self-reflection and backtracking capabilities, generating longer CoTs and achieving better training efficacy and upper bound during RL.\n",
+ "\n",
+ "The report highlights the importance of search priors in improving the model's reasoning capabilities and demonstrates that ASTRO-trained models outperform those trained without explicit search priors across all benchmarks. The results suggest that ASTRO is a promising framework for enhancing the mathematical reasoning abilities of language models.\n",
+ "\n",
+ "Overall, this research contributes to the development of more advanced language models that can reason like search algorithms, with potential applications in various domains that require complex problem-solving capabilities.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# For this demonstration, we'll process a subset of chunks.\n",
+ "# In a real application, you would process all of them.\n",
+ "CHUNKS_TO_PROCESS = 25\n",
+ "chunks_to_summarize = chunks[:CHUNKS_TO_PROCESS]\n",
+ "\n",
+ "print(f\"--- MAP: Summarizing {len(chunks_to_summarize)} individual chunks ---\")\n",
+ "chunk_summaries = [map_summarize_chunk(chunk, DOC_TITLE) \n",
+ " for chunk in chunks_to_summarize]\n",
+ "chunk_summaries = [summary for summary in chunk_summaries \n",
+ " if summary.strip()] # Filter out errors\n",
+ "print(f\"\\\\nSuccessfully summarized {len(chunk_summaries)} chunks.\")\n",
+ "\n",
+ "# --- Calculate compression rate ---\n",
+ "encoding = tiktoken.get_encoding(ENCODING_MODEL)\n",
+ "original_tokens = sum(count_tokens(chunk, encoding) \n",
+ " for chunk in chunks_to_summarize)\n",
+ "summarized_tokens = sum(count_tokens(summary, encoding) \n",
+ " for summary in chunk_summaries)\n",
+ "if original_tokens > 0:\n",
+ " compression_rate = (1 - (summarized_tokens / original_tokens)) * 100\n",
+ " print(f\"\\\\nOriginal token count: {original_tokens:,}\")\n",
+ " print(f\"Summarized token count: {summarized_tokens:,}\")\n",
+ " print(f\"Compression rate: {compression_rate:.2f}%\")\n",
+ "\n",
+ "print(\"\\\\n--- REDUCE: Creating final summary ---\")\n",
+ "final_summary = reduce_create_final_summary(chunk_summaries)\n",
+ "\n",
+ "# --- Display Final Result ---\n",
+ "print(\"\\\\n\" + \"=\" * 50)\n",
+ "print(\" FINAL EXECUTIVE SUMMARY\")\n",
+ "print(\"=\" * 50)\n",
+ "print(final_summary)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ed7d0941-d51a-409c-922e-95f7c6f1bca8",
+ "metadata": {},
+ "source": [
+ "## Future enhancement: Handling extremely long documents with recursive reduction\n",
+ "\n",
+ "If you are summarizing an entire book, the combined text of your *chunk summaries* might still be too long for the model's context window. The solution is **recursive reduction**.\n",
+ "\n",
+ "You run the same map-reduce process again on the chunk summaries themselves:\n",
+ "1. Generate 500 chunk summaries from the original document.\n",
+ "2. Group these 500 summaries into batches of 50.\n",
+ "3. Run your `reduce_create_final_summary` function on each batch, producing 10 \"super summaries\".\n",
+ "4. Finally, run the reduce function one last time on the 10 \"super summaries\" to get your final executive summary.\n",
+ "\n",
+ "This approach enables you to scale this summarization technique to documents of virtually any length."
+ ]
+ },
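+ {
+ "cell_type": "markdown",
+ "id": "recursive-reduction-sketch",
+ "metadata": {},
+ "source": [
+ "As a rough illustration, here is a minimal sketch of recursive reduction that reuses the `reduce_create_final_summary` function and the `chunk_summaries` list defined above. The batch size of 50 is an arbitrary choice for this sketch; in practice, pick it so that each batch of summaries fits comfortably within the model's context window.\n",
+ "\n",
+ "```python\n",
+ "# Minimal sketch of recursive reduction (illustrative; not executed in this tutorial).\n",
+ "BATCH_SIZE = 50  # arbitrary; choose so each batch fits in the context window\n",
+ "\n",
+ "def recursive_reduce(summaries, batch_size=BATCH_SIZE):\n",
+ "    \"\"\"Repeatedly reduce batches of summaries until a single summary remains.\"\"\"\n",
+ "    while len(summaries) > 1:\n",
+ "        batches = [summaries[i:i + batch_size]\n",
+ "                   for i in range(0, len(summaries), batch_size)]\n",
+ "        # Each batch of summaries is reduced into one \"super summary\"\n",
+ "        summaries = [reduce_create_final_summary(batch) for batch in batches]\n",
+ "    return summaries[0] if summaries else \"\"\n",
+ "\n",
+ "# final_summary = recursive_reduce(chunk_summaries)\n",
+ "```"
+ ]
+ },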
+ {
+ "cell_type": "markdown",
+ "id": "00f0e52c-2432-49df-9263-790e7add7ae4",
+ "metadata": {},
+ "source": [
+ "## Next steps and upgrade paths\n",
+ "\n",
+ "This tutorial provides a solid foundation for a powerful summarization pipeline. You can extend it in several ways for a production-grade application.\n",
+ "\n",
+ "| Need | Where to look |\n",
+ "| :----------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n",
+ "| **More advanced chunking** | For more robust document splitting, explore libraries such as LangChain or LlamaIndex, which offer \"Recursive Character Text Splitters\" that can handle complex documents and code. These can split based on code syntax, markdown structure, and more. |\n",
+ "| **Alternative patterns** | The \"Map-Reduce\" pattern is not the only option. Learn about the **\"Refine\" pattern**, where the model iteratively builds upon and refines a summary by processing one chunk at a time. This can be better for creating a single, highly coherent narrative. |\n",
+ "| **Question & Answering** | If your goal is to ask questions of a long document instead of summarizing it, the best approach is **Retrieval-Augmented Generation (RAG)**. This involves storing chunks in a vector database and retrieving only the most relevant ones to answer a user's question. See our [Contextual chunking RAG recipe](https://github.com/meta-llama/llama-cookbook/tree/main/end-to-end-use-cases/Contextual-Chunking-RAG). |\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "My Project (uv)",
+ "language": "python",
+ "name": "my-uv-project"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/getting-started/distillation/distillation.ipynb b/getting-started/distillation/distillation.ipynb
new file mode 100644
index 000000000..17416b877
--- /dev/null
+++ b/getting-started/distillation/distillation.ipynb
@@ -0,0 +1,2277 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "76e3985c-11b2-4a28-9ae5-f6586c3dd4ed",
+ "metadata": {},
+ "source": [
+ "# Distillation with Llama 4 and Synthetic Data Kit\n",
+ "\n",
+ "*Copyright (c) Meta Platforms, Inc. and affiliates.\n",
+ "This software may be used and distributed according to the terms of the Llama Community License Agreement.*"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3c0fffb5",
+ "metadata": {},
+ "source": [
+ "This notebook will walk you through [distilling](https://www.llama.com/docs/how-to-guides/distillation/) model knowledge from [Llama 4](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4) into a smaller [Llama 3.2](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/) model using synthetic training data from [Synthetic Data Kit](https://github.com/meta-llama/synthetic-data-kit). \n",
+ "\n",
+ "### The goal\n",
+ "The goal of this notebook is to distill knowledge from a more powerful model (Llama 4 Scout) into a smaller, less powerful model (Llama 3.2 3B).\n",
+ "\n",
+ "Smaller models have several advantages when compared with larger models: they're faster to generate text, have lower time to first token, and cost less to host since they need less hardware. However, larger models tend to be generalists – that is, they have the ability to perform a wide variety of tasks well. On specific or specialized tasks, smaller models can be just as good as the generalist, larger models. Distillation allows you to take knowledge present in a larger model and transfer it to a smaller model with a minimal drop in quality for narrow tasks.\n",
+ "\n",
+ "### The data\n",
+ "This notebook uses air traffic control data to demonstrate tuning a model towards a specialized field. During distillation, we will fully generate pairs from scratch, because our generalist teacher model has a strong understanding of ATC phraseology. During evaluation, we will evaluate both synthetic pairs as well as actual ATC data.\n",
+ "\n",
+ "We will use the [ATCO2 corpus](https://github.com/idiap/atco2-corpus/tree/main) of air traffic data, an MIT-licensed dataset that contains audio, transcriptions, and additional contextual and metadata for each interaction. For this exercise we will only use the text transcripts, and will use the small (1h) sample dataset to demonstrate how only a small amount of data is actually necessary for fine-tuning the model.\n",
+ "\n",
+ "### Evaluation\n",
+ "To evaluate our model, we will use standard language evaluation metrics such as [perplexity](https://en.wikipedia.org/wiki/Perplexity) and accuracy. We will also use [BLEU](https://en.wikipedia.org/wiki/BLEU) (bilingual evaluation understudy) to measure similarity without requiring that the model matches exactly every word. While originally designed for machine translation, BLEU compares n-gram similarity, meaning that minor word order differences are not penalized."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1fa99e42-5556-4b46-ab10-a4bc80c9f578",
+ "metadata": {},
+ "source": [
+ "## Prerequisites\n",
+ "#### Hardware Requirements:\n",
+ "\n",
+ "- NVIDIA GPU with at least 80GB VRAM (H100, A100, or similar)\n",
+ " - 8x GPU to run Llama 4 Scout and create the dataset\n",
+ " - 1x GPU to distill and fine-tune the model\n",
+ "- 200GB+ disk space\n",
+ "- 64GB+ system RAM\n",
+ "\n",
+ "#### Software Requirements:\n",
+ "\n",
+ "- CUDA 12.x\n",
+ "- HuggingFace account and token\n",
+ "- Fast internet connection for downloading models\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f6c1aa12-d54f-4b3f-936f-57580b9cf9e2",
+ "metadata": {},
+ "source": [
+ "## Preparing your environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a525b411-35a9-4cd3-8e89-355a1e85014e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install dependencies\n",
+ "# Some Ubuntu setups may require you to uninstall blinker if it's managed\n",
+ "# by the system package manager. If you see an error about blinker, try\n",
+ "# uninstalling it with `apt remove python3-blinker`.\n",
+ "!apt remove -y python3-blinker\n",
+ "!pip install unsloth_zoo unsloth==2025.8.9 transformers==4.55.4 nltk synthetic-data-kit -q --upgrade"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1d82a68b-495d-4e56-a854-e42a6e16727d",
+ "metadata": {},
+ "source": [
+ "## Generate the synthetic dataset\n",
+ "We will use the synthetic data kit to produce synthetic data to distill our model.\n",
+ "\n",
+ "First, set up the VLLM server. You will need to run this in a separate terminal window\n",
+ "since Jupyter doesn't support long running tasks/servers. Make sure to install vLLM with\n",
+ "`pip install vllm`\n",
+ "\n",
+ "```shell\n",
+ "HF_HOME=/workspace/huggingface_cache \\\n",
+ "HF_TOKEN=$HF_TOKEN \\\n",
+ "vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \\\n",
+ " --port 8000 \\\n",
+ " --max-model-len 8192 \\\n",
+ " --gpu-memory-utilization 0.95 \\\n",
+ " --tensor-parallel-size 8\n",
+ "```\n",
+ "\n",
+ "Then check that the server is working properly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "fbde635c-2f15-4efc-90a1-1efbbb6261a1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading config from: /usr/local/lib/python3.10/dist-packages/synthetic_data_kit/config.yaml\n",
+ "Config has LLM provider set to: api-endpoint\n",
+ "Loading config from: /usr/local/lib/python3.10/dist-packages/synthetic_data_kit/config.yaml\n",
+ "Config has LLM provider set to: api-endpoint\n",
+ "Loading config from: config.yaml\n",
+ "Config has LLM provider set to: vllm\n",
+ "\u001b[1;34mEnvironment variable check:\u001b[0m\n",
+ "API_ENDPOINT_KEY: Not found\n",
+ "get_llm_provider returning: vllm\n",
+ "\u001b[?25l\u001b[32m vLLM server is running at \u001b[0m\u001b[4;94mhttp://localhost:8000/v1\u001b[0m\n",
+ "\u001b[2KAvailable models: \u001b[1m{\u001b[0m\u001b[32m'object'\u001b[0m: \u001b[32m'list'\u001b[0m, \u001b[32m'data'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m{\u001b[0m\u001b[32m'id'\u001b[0m: \n",
+ "\u001b[32m'meta-llama/Llama-4-Scout-17B-16E-Instruct'\u001b[0m, \u001b[32m'object'\u001b[0m: \u001b[32m'model'\u001b[0m, \u001b[32m'created'\u001b[0m: \n",
+ "\u001b[1;36m1752251909\u001b[0m, \u001b[32m'owned_by'\u001b[0m: \u001b[32m'vllm'\u001b[0m, \u001b[32m'root'\u001b[0m: \n",
+ "\u001b[32m'meta-llama/Llama-4-Scout-17B-16E-Instruct'\u001b[0m, \u001b[32m'parent'\u001b[0m: \u001b[3;35mNone\u001b[0m, \u001b[32m'max_model_len'\u001b[0m: \n",
+ "\u001b[1;36m8192\u001b[0m, \u001b[32m'permission'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m{\u001b[0m\u001b[32m'id'\u001b[0m: \u001b[32m'modelperm-3c8eafb867bb4df4b4d65b45a899ae7a'\u001b[0m, \n",
+ "\u001b[32m'object'\u001b[0m: \u001b[32m'model_permission'\u001b[0m, \u001b[32m'created'\u001b[0m: \u001b[1;36m1752251909\u001b[0m, \u001b[32m'allow_create_engine'\u001b[0m: \n",
+ "\u001b[3;91mFalse\u001b[0m, \u001b[32m'allow_sampling'\u001b[0m: \u001b[3;92mTrue\u001b[0m, \u001b[32m'allow_logprobs'\u001b[0m: \u001b[3;92mTrue\u001b[0m, \u001b[32m'allow_search_indices'\u001b[0m: \n",
+ "\u001b[3;91mFalse\u001b[0m, \u001b[32m'allow_view'\u001b[0m: \u001b[3;92mTrue\u001b[0m, \u001b[32m'allow_fine_tuning'\u001b[0m: \u001b[3;91mFalse\u001b[0m, \u001b[32m'organization'\u001b[0m: \u001b[32m'*'\u001b[0m, \n",
+ "\u001b[32m'group'\u001b[0m: \u001b[3;35mNone\u001b[0m, \u001b[32m'is_blocking'\u001b[0m: \u001b[3;91mFalse\u001b[0m\u001b[1m}\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m\n",
+ "\u001b[2K\u001b[32m⠋\u001b[0m Checking vLLM server at http://localhost:8000/v1...\n",
+ "\u001b[1A\u001b[2K"
+ ]
+ }
+ ],
+ "source": [
+ "# Test that the server is working\n",
+ "!synthetic-data-kit -c config.yaml system-check"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f31d66cc-aa8a-4a08-9422-0425e739fed5",
+ "metadata": {},
+ "source": [
+ "If the model is working correctly you should see `VLLM server is running`.\n",
+ "\n",
+ "Next, we will set up our configuration file for generating the data. We will use the QA task for our task, giving an example set of data and then asking the model to create call/response pairs similar to the examples. This is slightly different than an actual QA dataset but demonstrates different tasks can fit into the general framework that synthetic data kit provides."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "ba722599-4b0b-4dd9-b43b-fcc16699b0d5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%bash\n",
+ "\n",
+ "cat > config.yaml << 'EOF'\n",
+ "# generation: Content generation parameters\n",
+ "generation:\n",
+ " temperature: 0.6\n",
+ " top_p: 0.95\n",
+ " chunk_size: 4000\n",
+ " overlap: 200\n",
+ " max_tokens: 4096\n",
+ " num_pairs: 25\n",
+ " batch_size: 2\n",
+ "\n",
+ "llm:\n",
+ " # Provider selection: \"vllm\" or \"api-endpoint\"\n",
+ " provider: \"vllm\"\n",
+ "\n",
+ "# vllm: Configure VLLM server settings\n",
+ "vllm:\n",
+ " api_base: \"http://localhost:8000/v1\"\n",
+ " port: 8000\n",
+ " model: \"meta-llama/Llama-4-Scout-17B-16E-Instruct\"\n",
+ " max_retries: 3\n",
+ " retry_delay: 1.0\n",
+ "\n",
+ "# format: Export format parameters\n",
+ "format:\n",
+ " default: \"jsonl\"\n",
+ " include_metadata: true\n",
+ " pretty_json: true\n",
+ "\n",
+ "# prompts: LLM prompts for different tasks, we have\n",
+ "# to include all of them but we modify the QA generation\n",
+ "prompts:\n",
+ " qa_generation: |\n",
+ " Create {num_pairs} pairs of simulated ATC call/response transcripts.\n",
+ " \n",
+ " Rules:\n",
+ " 1. Use full words instead of numbers, i.e. seven thirty two not 732\n",
+ " 2. Include all phases of flight, first contact/handover, and ground/tower/TRACON\n",
+ " 3. Return JSON format only\n",
+ "\n",
+ " Here are some examples:\n",
+ "\n",
+ " {text}\n",
+ " \n",
+ " summary: |\n",
+ " Summarize this document in 3-5 sentences, focusing on the main topic and key concepts.\n",
+ "\n",
+ " qa_rating: |\n",
+ " You are a helpful JSON processor that rates question-answer pairs.\n",
+ " \n",
+ " Your task is to rate each pair on a scale from 1-10 and return valid JSON with added ratings.\n",
+ " \n",
+ " ONLY return a valid JSON array with the original pairs plus ratings. Do not include any explanations or text outside the JSON.\n",
+ " \n",
+ " Here are the pairs to rate:\n",
+ " \n",
+ " {pairs}\n",
+ "EOF"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ef65c213-3c65-45eb-ac31-118c9ae8e0b5",
+ "metadata": {},
+ "source": [
+ "We also create a dataset of examples to guide the model to producing better synthetic data. We provide 20 examples to produce 500+ training examples from synthetic data kit."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "b4aa06f2-9081-4694-aaa6-0a3096fcf124",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%bash\n",
+ "\n",
+ "cat > examples.txt << 'EOF'\n",
+ "JetBlue Eight Three Two, cleared to Boston via LENDO Seven, maintain five thousand, one two four point eight five, squawk four two one five\n",
+ "Cleared to Boston via LENDO Seven, maintain five thousand, one two four point eight five, squawk four two one five, JetBlue Eight Three Two\n",
+ "\n",
+ "Cessna Seven Four Romeo Tango, taxi to Runway Two Four via Alpha, hold short of Runway Two Four\n",
+ "Taxi Runway Two Four via Alpha, hold short Two Four, Seven Four Romeo Tango\n",
+ "\n",
+ "Southwest Two Twenty-Nine, Runway One Six Right, cleared for take-off, wind one niner zero at six\n",
+ "Cleared for take-off One Six Right, Southwest Two Twenty-Nine\n",
+ "\n",
+ "Delta Four Zero Six, contact Departure one two six point niner five\n",
+ "One two six point niner five, Delta Four Zero Six\n",
+ "\n",
+ "FedEx Four Eight Four Heavy, climb and maintain flight level three five zero\n",
+ "Climb and maintain flight level three five zero, FedEx Four Eight Four Heavy\n",
+ "\n",
+ "American One Eight, turn right heading zero niner zero, descend and maintain three thousand, expect ILS Runway Two Seven Left\n",
+ "Right heading zero niner zero, descend three thousand, expect ILS Two Seven Left, American One Eight\n",
+ "\n",
+ "American One Eight, cleared to land Runway Two Seven Left, wind two five zero at one four\n",
+ "Cleared to land Two Seven Left, American One Eight\n",
+ "\n",
+ "American One Eight, cross Runway Two Seven Right at Kilo, then taxi to Gate Alpha Four\n",
+ "Cross Two Seven Right at Kilo, to Alpha Four, American One Eight\n",
+ "\n",
+ "Emirates One Seven Four Heavy, cleared Dubai via the LONAM Two Foxtrot departure, initial climb five thousand feet, QNH one zero zero six, squawk five three five one\n",
+ "Cleared Dubai via LONAM Two Foxtrot, climb five thousand feet, QNH one zero zero six, squawk five three five one, Emirates One Seven Four Heavy\n",
+ "\n",
+ "Qatar Four One Six, push back and start approved, facing south\n",
+ "Push back and start approved, facing south, Qatar Four One Six\n",
+ "\n",
+ "Ryanair Eight Four, taxi to holding point Runway Two Four via Bravo and Delta, hold short\n",
+ "Holding short Two Four via Bravo and Delta, Ryanair Eight Four\n",
+ "\n",
+ "KLM Six Zero Three, line up and wait Runway Two Seven\n",
+ "Line up and wait Two Seven, KLM Six Zero Three\n",
+ "\n",
+ "British Airways Two Seven, cleared to enter oceanic airspace via Track Alpha, flight level three five zero, Mach decimal eight two\n",
+ "Cleared Track Alpha, flight level three five zero, Mach decimal eight two, British Airways Two Seven\n",
+ "\n",
+ "Air France Four Six, climb flight level three eight zero\n",
+ "Climb flight level three eight zero, Air France Four Six\n",
+ "\n",
+ "Singapore Three One, descend to altitude six thousand feet, QNH one zero zero nine, cleared ILS approach Runway Zero Four Right via AKOMA One\n",
+ "Descend six thousand feet, QNH one zero zero nine, cleared ILS Zero Four Right via AKOMA One, Singapore Three One\n",
+ "\n",
+ "Singapore Three One, vacate left via Alpha Seven, contact Ground one two one decimal seven five\n",
+ "Vacate left Alpha Seven, Ground one two one decimal seven five, Singapore Three One\n",
+ "\n",
+ "Speedbird Four Niner, cleared to enter controlled airspace, proceed direct MALBY, climb altitude four thousand feet, QNH one zero one five\n",
+ "Direct MALBY, climb four thousand feet, QNH one zero one five, Speedbird Four Niner\n",
+ "\n",
+ "Lufthansa Three Two, descend and maintain two thousand five hundred, cleared visual approach Runway One Six Left, QNH one zero one eight\n",
+ "Descend two thousand five hundred, cleared visual One Six Left, QNH one zero one eight, Lufthansa Three Two\n",
+ "\n",
+ "Emirates One Seven Four Heavy, taxi stand Alpha Seven via Mike and Echo, contact Apron on one two two decimal four\n",
+ "Taxi to stand Alpha Seven via Mike and Echo, one two two decimal four, Emirates One Seven Four Heavy\n",
+ "\n",
+ "Air Canada Eight Eight, Runway Two Four, cleared to land, wind two six zero degrees at eight knots\n",
+ "Cleared to land Runway Two Four, Air Canada Eight Eight\n",
+ "\n",
+ "EOF"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c6030c2c-1ded-46d4-b76c-d6d5972b51a3",
+ "metadata": {},
+ "source": [
+ "We create our synthetic dataset using synthetic-data-kit, running the command in batches in order to create enough examples. This is because weaker models have issues generating large numbers of examples."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f942fc5a-1f13-4c46-a6cc-f094c558de12",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%bash\n",
+ "\n",
+ "NUM_BATCHES=10\n",
+ "\n",
+ "# Generate synthetic data using `create`\n",
+ "for i in $(seq 1 $NUM_BATCHES); do\n",
+ " synthetic-data-kit -c config.yaml create -n 50 examples.txt -o data/train/$i\n",
+ "done\n",
+ "\n",
+ "# Convert generated data to JSONL format using `save-as`\n",
+ "for i in $(seq 1 $NUM_BATCHES); do\n",
+ " synthetic-data-kit save-as data/train/$i/examples_qa_pairs.json -f jsonl -o data/train/$i/output.jsonl\n",
+ "done\n",
+ "\n",
+ "# Concatenate all output files into one with `cat`\n",
+ "cat $(for i in $(seq 1 $NUM_BATCHES); do echo -n \"data/train/$i/outpxut.jsonl \"; done) > data/train.jsonl\n",
+ "\n",
+ "# Eval doesn't need multiple runs\n",
+ "synthetic-data-kit -c config.yaml create -n 50 examples.txt -o data/eval\n",
+ "synthetic-data-kit save-as data/eval/examples_qa_pairs.json -f jsonl -o data/eval/output.jsonl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "7d24c81d-9629-4863-bd72-41381978774d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "500\n",
+ "50\n"
+ ]
+ }
+ ],
+ "source": [
+ "!cat data/train.jsonl | wc -l\n",
+ "!cat data/eval/output.jsonl | wc -l"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "da9671c7-da3e-48ad-98d2-3ad1689b1288",
+ "metadata": {},
+ "source": [
+ "## Preparing the eval dataset\n",
+ "Our human curated eval dataset contains text annotations in the form of XML files. We want to just produce transcripts of the conversation, and do not need to include any other metadata or audio."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7a4e1dd2-32dc-4e0b-bddd-b6b6eba4cbb5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download the dataset\n",
+ "!mkdir Datasets && cd Datasets && wget https://www.replaywell.com/atco2/download/ATCO2-ASRdataset-v1_beta.tgz && tar xf ATCO2-ASRdataset-v1_beta.tgz >/dev/null 2>&1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "303d6b4e-44c1-4154-828e-6e50fa613d1d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import xml.etree.ElementTree as ET\n",
+ "import os\n",
+ "import glob\n",
+ "import re\n",
+ "\n",
+ "def parse_xml_files(directory_path: str):\n",
+ " \"\"\"\n",
+ " Parse all XML files in the specified directory and extract text entries.\n",
+ " \n",
+ " Args:\n",
+ " directory_path: Path to the directory containing XML files\n",
+ " \n",
+ " Returns:\n",
+ " A nested list where each item represents an XML file,\n",
+ " containing a list of text entries from that file\n",
+ " \"\"\"\n",
+ " xml_files = glob.glob(os.path.join(directory_path, \"*.xml\"))\n",
+ " results = []\n",
+ " \n",
+ " for xml_file in xml_files:\n",
+ " try:\n",
+ " tree = ET.parse(xml_file)\n",
+ " root = tree.getroot()\n",
+ " \n",
+ " file_texts = []\n",
+ " \n",
+ " for segment in root.findall('segment'):\n",
+ " text_element = segment.find('text')\n",
+ " if text_element is not None and text_element.text:\n",
+ " # Remove any part of speech details or metadata included in square brackets\n",
+ " raw_text = text_element.text\n",
+ " cleaned_text = re.sub(r\"\\[.*?\\]\", \"\", raw_text)\n",
+ " # Fix some weirdness with non breaking spaces\n",
+ " cleaned_text = cleaned_text.replace('\\xa0', '').replace('\\n', '')\n",
+ " file_texts.append(cleaned_text.strip())\n",
+ " \n",
+ " if file_texts and len(file_texts) >= 2:\n",
+ " results.append(file_texts)\n",
+ " \n",
+ " except ET.ParseError as e:\n",
+ " print(f\"Error parsing {xml_file}: {e}\")\n",
+ " except Exception as e:\n",
+ " print(f\"Error processing {xml_file}: {e}\")\n",
+ " \n",
+ " return results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "f57c09c2-70e7-414b-a9ce-b6fc6419553d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Parsed 244\n"
+ ]
+ }
+ ],
+ "source": [
+ "parsed = parse_xml_files(\"Datasets/ATCO2-ASRdataset-v1_beta/DATA\")\n",
+ "print(f\"Parsed {len(parsed)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "192dc7c6-7e19-4958-8673-4999a3a02282",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Llama 3 prompt template\n",
+ "def format_llama(instruction: str, first_message: str, reply: str):\n",
+ " instruction = f\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+ "{instruction}\n",
+ "<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+ "{first_message}\n",
+ "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+ "{reply}\"\"\"\n",
+ " return instruction.format(first_message, reply)\n",
+ "\n",
+ "# Format for our saved json format\n",
+ "def format_json(first_message: str, reply: str):\n",
+ " return {\n",
+ " \"instruction\": \"You are a helpful controller who responds to air traffic control messages.\",\n",
+ " \"input\": first_message,\n",
+ " \"output\": reply,\n",
+ " }\n",
+ "\n",
+ "# Converts the saved json format to llama format for ingestion\n",
+ "def json_to_llama(examples):\n",
+ " instructions = examples[\"instruction\"]\n",
+ " inputs = examples[\"input\"]\n",
+ " outputs = examples[\"output\"]\n",
+ " texts = []\n",
+ " for instruction, input, output in zip(instructions, inputs, outputs):\n",
+ " text = format_llama(instruction, input, output) + tokenizer.eos_token\n",
+ " texts.append(text)\n",
+ " return { \"text\" : texts, }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "45e61b34-82d6-434c-a3ca-c2a6e9ed6603",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "# Grab 100 of the examples for evaluation\n",
+ "messages_eval = []\n",
+ "for message in parsed[0:100]:\n",
+ " messages_eval.append(format_json(message[0], message[1]))\n",
+ "\n",
+ "# Save the dataset in our custom json format\n",
+ "os.makedirs(\"Datasets\", exist_ok=True)\n",
+ "with open(\"Datasets/dataset_eval.json\", 'w') as f:\n",
+ " json.dump(messages_eval, f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "64f08b0c-ff3a-44f4-982c-45528b71365b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import Dataset\n",
+ "\n",
+ "def json_dataset(path: str):\n",
+ " \"\"\"Create a dataset from a JSON file, used for the ATC dataset.\"\"\"\n",
+ " with open(path, 'r') as f:\n",
+ " data = json.load(f)\n",
+ "\n",
+ " return Dataset.from_list(data)\n",
+ " \n",
+ "def jsonl_dataset(path: str):\n",
+ " \"\"\"Create a dataset from a JSONL file, used for synthetic data.\"\"\"\n",
+ " lines = []\n",
+ " with open(path, 'r') as f:\n",
+ " for line in f:\n",
+ " data = json.loads(line)\n",
+ " lines.append(format_json(data[\"atc\"], data[\"response\"]))\n",
+ "\n",
+ " return Dataset.from_list(lines)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "69fc4dfb-6f2e-4f71-bdb6-1bfa0193af99",
+ "metadata": {},
+ "source": [
+ "## Evaluating the baseline model\n",
+ "To evaluate the baseline results of the model we will use the HuggingFace transformers package and Unsloth for inference. We use two metrics here, **perplexity** and **BLEU**. Perplexity captures the \"surprise\" of the model, and applies on a per-token basis. BLEU is typically used for machine translation, but here is capturing if the response gets the gist of the correct answer, accounting for differences in word order."
+ ]
+ },
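+ {
+ "cell_type": "markdown",
+ "id": "perplexity-formula-note",
+ "metadata": {},
+ "source": [
+ "For reference, the perplexity reported here is the exponential of the average per-token cross-entropy loss on the ground-truth response tokens:\n",
+ "\n",
+ "$$\\text{PPL} = \\exp\\left(-\\frac{1}{N} \\sum_{i=1}^{N} \\log p(x_i \\mid x_{<i})\\right)$$\n",
+ "\n",
+ "where $x_1, \\ldots, x_N$ are the target tokens. Lower perplexity means the model finds the ground-truth responses less surprising."
+ ]
+ },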
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "314eca4b-9d67-4364-9f68-da9831cce117",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This is where Model weights will be downloaded/used from\n",
+ "cache_dir = \"Models\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "87f62594-1ed3-4d15-9c95-3b52e25b5d03",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
+ "🦥 Unsloth Zoo will now patch everything to make training faster!\n",
+ "INFO 07-11 18:16:50 [__init__.py:244] Automatically detected platform cuda.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from unsloth import FastLanguageModel"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "83c2fea5-2576-49d6-8c6c-75353ecd68ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
+ "\n",
+ "def compute_bleu(reference: str, candidate: str) -> float:\n",
+ " \"\"\"\n",
+ " Compute BLEU score between reference and candidate strings.\n",
+ "\n",
+ " Args:\n",
+ " reference: Ground-truth text.\n",
+ " candidate: Generated text to evaluate.\n",
+ "\n",
+ " Returns:\n",
+ " bleu_score: BLEU score (0 to 1).\n",
+ " \"\"\"\n",
+ " reference_tokens = reference.strip().split()\n",
+ " candidate_tokens = candidate.strip().split()\n",
+ "\n",
+ " smoothie = SmoothingFunction().method4\n",
+ " bleu_score = sentence_bleu(\n",
+ " [reference_tokens],\n",
+ " candidate_tokens,\n",
+ " smoothing_function=smoothie\n",
+ " )\n",
+ " return bleu_score\n",
+ "\n",
+ "def compute_loss(model, tokenizer, prompt: str, target: str) -> float:\n",
+ " \"\"\"\n",
+ " Compute loss for a target response given a prompt.\n",
+ "\n",
+ " Args:\n",
+ " model: Pretrained language model.\n",
+ " tokenizer: Tokenizer for the model.\n",
+ " prompt: Input text prompt.\n",
+ " target: Ground-truth text continuation.\n",
+ "\n",
+ " Returns:\n",
+ " loss: Computed loss value.\n",
+ " \"\"\"\n",
+ " # Tokenize separately to keep the prompt boundary\n",
+ " prompt_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(model.device)\n",
+ " target_ids = tokenizer(target, return_tensors=\"pt\").input_ids.to(model.device)\n",
+ "\n",
+ " # Create the combined input\n",
+ " input_ids = torch.cat((prompt_ids, target_ids), dim=1)\n",
+ "\n",
+ " # Labels are the complete prompt and target response\n",
+ " labels = input_ids.clone()\n",
+ "\n",
+ " # Set the tokens up to the end of the prompt to -100 to prevent loss computation there\n",
+ " # This is because we don't care how the model predicts the prompt, just how well it\n",
+ " # completes the text from the end of the prompt onwards\n",
+ " prompt_len = prompt_ids.shape[1]\n",
+ " labels[:, :prompt_len] = -100\n",
+ "\n",
+ " # Use the model to compute the loss\n",
+ " with torch.no_grad():\n",
+ " outputs = model(input_ids=input_ids, labels=labels)\n",
+ " loss = outputs.loss\n",
+ "\n",
+ " # Perplexity is the exponentiated negative log-likelihood\n",
+ " return loss.item()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "02570852-cf81-400a-874c-7a39be88313a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from trl import SFTTrainer\n",
+ "from transformers import TrainingArguments\n",
+ "import torch\n",
+ "\n",
+ "def generate(model, tokenizer, text: str, max_new_tokens: int = 100) -> str:\n",
+ " \"\"\"\n",
+ " Generate text from model given an input prompt.\n",
+ " \n",
+ " Args:\n",
+ " model: Pretrained language model.\n",
+ " tokenizer: Corresponding tokenizer.\n",
+ " text: Prompt text.\n",
+ " max_new_tokens: Number of tokens to generate.\n",
+ " \n",
+ " Returns:\n",
+ " str: Generated output text.\n",
+ " \"\"\"\n",
+ " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
+ " input_ids = inputs[\"input_ids\"]\n",
+ " \n",
+ " outputs = model.generate(\n",
+ " **inputs,\n",
+ " max_new_tokens=max_new_tokens,\n",
+ " temperature=0.7,\n",
+ " use_cache=True\n",
+ " )\n",
+ " \n",
+ " # Decode only the newly generated tokens (the part after the prompt)\n",
+ " return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9581816b-11cb-4990-a34b-31793ed31ca5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tqdm.notebook import tqdm\n",
+ "import numpy as np\n",
+ "\n",
+ "def evaluate(model, tokenizer, debug=False):\n",
+ " \"\"\"\n",
+ " This function loads the eval dataset and then loops over it to compute the\n",
+ " metrics. Enable `debug` to show the text generated and the ground truth.\n",
+ " \"\"\"\n",
+ " # Load the dataset\n",
+ " dataset = json_dataset(\"Datasets/dataset_eval.json\")\n",
+ " \n",
+ " # Compute Perplexity and BLEU scores\n",
+ " losses, bleus = [], []\n",
+ " \n",
+ " for convo in tqdm(dataset, desc=\"Evaluating\"):\n",
+ " prompt = format_llama(convo[\"instruction\"], convo[\"input\"], \"\")\n",
+ " output = generate(model, tokenizer, prompt)\n",
+ " ground_truth = convo[\"output\"]\n",
+ "\n",
+ " if debug:\n",
+ " print(\"Input:\\n\", prompt)\n",
+ " print(\"Output\\n\", output)\n",
+ " print(\"GT\\n\", ground_truth)\n",
+ " \n",
+ " loss = compute_loss(model, tokenizer, output, ground_truth)\n",
+ " bleu = compute_bleu(output, ground_truth)\n",
+ " \n",
+ " losses.append(loss)\n",
+ " bleus.append(bleu)\n",
+ " \n",
+ " # Report metrics\n",
+ " mean_loss = np.mean(loss)\n",
+ " mean_bleu = np.mean(bleus)\n",
+ " mean_ppl = np.exp(mean_loss)\n",
+ " \n",
+ " print(f\"\\n=== Evaluation Results ===\")\n",
+ " print(f\"Average Perplexity: {mean_ppl:.2f}\")\n",
+ " print(f\"Average BLEU Score: {mean_bleu:.2f}\")\n",
+ "\n",
+ " return mean_ppl, mean_bleu"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "d37c6d11-a55e-4a88-b890-b0fc178ed69c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==((====))== Unsloth 2025.7.3: Fast Llama patching. Transformers: 4.53.2. vLLM: 0.9.2.\n",
+ " \\\\ /| NVIDIA H100 80GB HBM3. Num GPUs = 1. Max memory: 79.209 GB. Platform: Linux.\n",
+ "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.0\n",
+ "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]\n",
+ " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
+ "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "426deee9742e4e5485233895cd175ea2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json: 0.00B [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1903e5cafed64c019b3485c7bf2b47be",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors: 0%| | 0.00/2.35G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e9da1148aefb4941afc97499cfc91348",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "generation_config.json: 0%| | 0.00/234 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d0a8444e571a46038ff1a3806453c795",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer_config.json: 0.00B [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b2054429d44a4edcb81d9ad78049047d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "special_tokens_map.json: 0%| | 0.00/454 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3e7dc17b283245bdabf56d1be19002b7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer.json: 0%| | 0.00/17.2M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1ef72f70c51847859b01cbdf79db91a9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "chat_template.jinja: 0.00B [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Load base model and compute the base metrics\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=\"unsloth/Llama-3.2-3B-Instruct\",\n",
+ " max_seq_length=2048,\n",
+ " cache_dir=cache_dir,\n",
+ ")"
+ ]
+ },
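+  {
+   "cell_type": "markdown",
+   "id": "base-model-smoke-test-note",
+   "metadata": {},
+   "source": [
+    "Before running the full evaluation loop, it can be useful to sanity-check the generation pipeline on a single example. The cell below is an optional sketch; it reuses `format_llama` (defined earlier and used by `evaluate`) with a made-up instruction purely for illustration."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "base-model-smoke-test",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional (Unsloth): enable the faster inference mode before generating.\n",
+    "# FastLanguageModel.for_inference(model)\n",
+    "\n",
+    "# Quick smoke test on a single, illustrative prompt before the full eval loop.\n",
+    "sample_prompt = format_llama(\"Summarize the following text.\", \"Llama models are open-weight LLMs released by Meta.\", \"\")\n",
+    "print(generate(model, tokenizer, sample_prompt))"
+   ]
+  },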
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "9b0fc072-8466-40fe-8678-c65ad5498a77",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c6581270b5de45e6ad5e9a5beff8c140",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/100 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3c30c18109884448bca5367a6ff6a030",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Evaluating: 0%| | 0/100 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "=== Evaluation Results ===\n",
+ "Average Perplexity: 597.31\n",
+ "Average BLEU Score: 0.04\n"
+ ]
+ }
+ ],
+ "source": [
+ "base_ppl, base_bleu = evaluate(model, tokenizer)"
+ ]
+ },
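+  {
+   "cell_type": "markdown",
+   "id": "free-gpu-memory-note",
+   "metadata": {},
+   "source": [
+    "With the baseline metrics recorded, the base model is no longer needed; the training log below shows that the fine-tuning step loads its own copy of the weights. On smaller GPUs it may help to free the memory first (optional sketch):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "free-gpu-memory",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gc\n",
+    "import torch\n",
+    "\n",
+    "# Only do this if, as here, the fine-tuning code reloads the model itself.\n",
+    "del model, tokenizer\n",
+    "gc.collect()\n",
+    "torch.cuda.empty_cache()"
+   ]
+  },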
+ {
+ "cell_type": "markdown",
+ "id": "7a5ba4f9-e3d0-4efe-841b-d109203c0bed",
+ "metadata": {},
+ "source": [
+ "## Fine-tuning the model"
+ ]
+ },
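+  {
+   "cell_type": "markdown",
+   "id": "lora-setup-note",
+   "metadata": {},
+   "source": [
+    "The next cell runs the whole fine-tuning pipeline. For readers skimming the code, the LoRA setup inside it conceptually looks like the sketch below: judging from the training log (28 QKV/O layers patched, ~9.2M of 3.2B parameters trainable, i.e. 0.28%), the adapters target only the attention projections with rank 16. The exact hyperparameters (e.g. `lora_alpha`, dropout) live in the training code and may differ from this illustration."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "lora-setup-sketch",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def add_lora_adapters(base_model):\n",
+    "    \"\"\"Illustrative helper: attach LoRA adapters the way the training cell does.\n",
+    "    Values are inferred from the training log and may not match exactly.\"\"\"\n",
+    "    return FastLanguageModel.get_peft_model(\n",
+    "        base_model,\n",
+    "        r=16,  # rank 16 on q/k/v/o gives ~9.2M trainable params for Llama-3.2-3B\n",
+    "        target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
+    "        lora_alpha=16,  # assumed; a common default\n",
+    "        lora_dropout=0,\n",
+    "        bias=\"none\",\n",
+    "        use_gradient_checkpointing=\"unsloth\",\n",
+    "    )"
+   ]
+  },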
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "3f768c02-8a21-469d-9b8f-1cb616a4e451",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "🚀 Starting fine-tuning process...\n",
+ "==((====))== Unsloth 2025.7.3: Fast Llama patching. Transformers: 4.53.2. vLLM: 0.9.2.\n",
+ " \\\\ /| NVIDIA H100 80GB HBM3. Num GPUs = 1. Max memory: 79.209 GB. Platform: Linux.\n",
+ "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.0\n",
+ "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]\n",
+ " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
+ "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "89048f5c1c7146428ae31e49e4c0cc5d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/500 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n",
+ "are not enabled or a bias term (like in Qwen) is used.\n",
+ "Unsloth 2025.7.3 patched 28 layers with 28 QKV layers, 28 O layers and 0 MLP layers.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "af7ff3b17c7341e58143352c86d9a018",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Unsloth: Tokenizing [\"text\"]: 0%| | 0/500 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "🏋️ Training started...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
+ " \\\\ /| Num examples = 500 | Num Epochs = 4 | Total steps = 250\n",
+ "O^O/ \\_/ \\ Batch size per device = 8 | Gradient accumulation steps = 1\n",
+ "\\ / Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8\n",
+ " \"-____-\" Trainable parameters = 9,175,040 of 3,221,924,864 (0.28% trained)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+        "<table border=\"1\" class=\"dataframe\">\n",
+        "  <thead>\n",
+        "    <tr style=\"text-align: left;\">\n",
+        "      <th>Step</th>\n",
+        "      <th>Training Loss</th>\n",
+        "    </tr>\n",
+        "  </thead>\n",
+        "  <tbody>\n",
+        "    <tr><td>1</td><td>4.762000</td></tr>\n",
+        "    <tr><td>50</td><td>3.111700</td></tr>\n",
+        "    <tr><td>100</td><td>1.943800</td></tr>\n",
+        "    <tr><td>150</td><td>1.341300</td></tr>\n",
+        "    <tr><td>200</td><td>1.153400</td></tr>\n",
+        "    <tr><td>250</td><td>1.107700</td></tr>\n",
+        "    <!-- per-step log truncated for readability: the full run logged all 250 steps, with loss falling from ~4.76 to ~1.11 -->\n",
+        "  </tbody>\n",
+        "</table>"
+ ],
+ "text/plain": [
+ "