
Deploy Kura on Modal Labs for distributed GPU processing #101

@jxnl

Problem Statement

For large-scale conversation analysis (100k+ conversations), running Kura on a single machine becomes impractical. Modal Labs provides serverless GPU infrastructure that can dynamically scale processing across multiple GPUs, making it ideal for distributed batch processing of embeddings and LLM inference.

Why Modal Labs?

  1. Serverless GPU Access: Pay only for compute time used
  2. Auto-scaling: Automatically scale to hundreds of GPUs
  3. Built-in Parallelization: Easy distributed computing primitives
  4. Cost Efficiency: No idle GPU time, automatic shutdown
  5. Developer Experience: Simple Python decorators for GPU functions

Current Architecture Limitations

  • Single-machine processing bottleneck
  • Cannot parallelize across multiple GPUs
  • No fault tolerance for long-running jobs
  • Manual infrastructure management
  • Fixed resource allocation

Proposed Modal Labs Architecture

1. Core Modal Functions

import modal
from typing import Dict, List, Optional

# Define container image with dependencies
kura_image = (
    modal.Image.debian_slim()
    .pip_install(
        "kura",
        "torch",
        "transformers",
        "sentence-transformers",
        "vllm",
        "numpy",
        "pandas",
        "scikit-learn",  # needed by distributed_clustering
        "pyarrow",       # needed by save_checkpoint
    )
    .run_commands("mkdir -p /cache")
)

app = modal.App("kura-processing")

@app.cls(
    image=kura_image,
    gpu="A100-40GB",       # "A10G" is enough for smaller models
    max_containers=10,     # upper bound on parallel GPU containers
    scaledown_window=300,  # keep idle containers warm for 5 minutes
)
class DistributedKuraProcessor:
    @modal.enter()
    def load_models(self):
        """Load models once per container start, not once per batch"""
        # Proposed vLLM-backed Kura models
        from kura.embedding import VLLMEmbeddingModel
        from kura.summarisation import VLLMSummaryModel

        # Note: vLLM pre-allocates most GPU memory per engine by default
        # (gpu_memory_utilization=0.9), so two engines sharing one GPU
        # need that fraction split between them.
        self.embedding_model = VLLMEmbeddingModel(
            model_name="BAAI/bge-large-en-v1.5"
        )
        self.summary_model = VLLMSummaryModel(
            model_name="meta-llama/Llama-2-13b-chat-hf"
        )

    @modal.method()
    async def process_conversation_batch(
        self,
        conversations: List[Dict],
        batch_id: int
    ) -> Dict:
        """Summarise and embed one batch of conversations"""
        from kura.types import Conversation

        # Convert dicts back to Conversation objects
        conv_objects = [Conversation.model_validate(c) for c in conversations]

        # Summarise
        summaries = await self.summary_model.summarise_batch(conv_objects)

        # Embed the summary text
        texts = [str(s) for s in summaries]
        embeddings = await self.embedding_model.embed_batch(texts)

        # Return plain, serialisable results tagged with the batch id
        return {
            "batch_id": batch_id,
            "summaries": [s.model_dump() for s in summaries],
            "embeddings": embeddings
        }
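
Before paying for GPU time (Phase 1 below), the batch contract of process_conversation_batch can be exercised locally with stub models. This is a sketch under stated assumptions: FakeSummaryModel, FakeEmbeddingModel and process_batch_locally are hypothetical stand-ins, not Kura classes; they only mirror the summarise_batch / embed_batch method names used above, and conversations are simplified to {"text": ...} dicts.

```python
import asyncio
from typing import Dict, List


class FakeSummaryModel:
    """Hypothetical stub: 'summarises' by truncating the text to 40 characters."""

    async def summarise_batch(self, conversations: List[Dict]) -> List[str]:
        return [c["text"][:40] for c in conversations]


class FakeEmbeddingModel:
    """Hypothetical stub: embeds each text as [length, vowel count]."""

    async def embed_batch(self, texts: List[str]) -> List[List[float]]:
        return [
            [float(len(t)), float(sum(ch in "aeiou" for ch in t.lower()))]
            for t in texts
        ]


async def process_batch_locally(
    conversations: List[Dict],
    batch_id: int,
    summary_model: FakeSummaryModel,
    embedding_model: FakeEmbeddingModel,
) -> Dict:
    """Same summarise -> embed -> tagged result flow as the Modal method."""
    summaries = await summary_model.summarise_batch(conversations)
    embeddings = await embedding_model.embed_batch(summaries)
    return {"batch_id": batch_id, "summaries": summaries, "embeddings": embeddings}
```

Usage: `asyncio.run(process_batch_locally(batch, 0, FakeSummaryModel(), FakeEmbeddingModel()))` returns one summary and one embedding per conversation, tagged with batch_id 0, which is the shape the orchestrator expects to aggregate.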

2. Batch Distribution Strategy

@app.function(
    image=kura_image,
    timeout=60 * 60 * 24,  # 24 hours, Modal's maximum function timeout
)
async def distributed_kura_pipeline(
    dataset_name: str,
    n_conversations: int,
    batch_size: int = 1000,
):
    """Main orchestration function"""
    from kura.types import Conversation

    # Load conversations
    conversations = Conversation.from_hf_dataset(
        dataset_name,
        split=f"train[:{n_conversations}]"
    )

    # Create batches of serialisable dicts, each tagged with a batch id
    batch_ids = []
    conversation_batches = []
    for i in range(0, len(conversations), batch_size):
        batch_ids.append(i // batch_size)
        conversation_batches.append(
            [c.model_dump() for c in conversations[i:i + batch_size]]
        )

    # Fan out across GPU containers. Parallelism is bounded by
    # max_containers on DistributedKuraProcessor; Modal autoscales up to it.
    processor = DistributedKuraProcessor()
    results = []
    async for result in processor.process_conversation_batch.map.aio(
        conversation_batches, batch_ids
    ):
        results.append(result)
        print(f"Completed batch {result['batch_id'] + 1} of {len(batch_ids)}")

    # Aggregate results in original order
    all_summaries = []
    all_embeddings = []

    for result in sorted(results, key=lambda x: x["batch_id"]):
        all_summaries.extend(result["summaries"])
        all_embeddings.extend(result["embeddings"])

    # Clustering runs as a separate Modal function with its own resources
    return await distributed_clustering.remote.aio(all_summaries, all_embeddings)
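
The batching and ordered reassembly above can be checked without Modal. A minimal plain-Python sketch follows; make_batches, suggested_workers and reassemble are hypothetical helper names. suggested_workers encodes a rule of thumb of one worker per 5,000 conversations capped at 100, with a floor of one so that datasets under 5,000 conversations still get a worker.

```python
import math
from typing import Dict, List, Sequence


def make_batches(items: Sequence, batch_size: int) -> List[Dict]:
    """Split items into consecutive batches tagged with a batch_id."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [
        {"batch_id": i // batch_size, "items": list(items[i:i + batch_size])}
        for i in range(0, len(items), batch_size)
    ]


def suggested_workers(n_items: int, per_worker: int = 5000, cap: int = 100) -> int:
    """One worker per `per_worker` items, never fewer than 1 or more than `cap`."""
    return max(1, min(cap, math.ceil(n_items / per_worker)))


def reassemble(results: List[Dict], key: str = "items") -> List:
    """Restore the original order from results that may arrive out of order."""
    ordered: List = []
    for result in sorted(results, key=lambda r: r["batch_id"]):
        ordered.extend(result[key])
    return ordered
```

For ten items and batch_size=4 this yields batches of 4, 4 and 2; reversing them and calling reassemble restores the original list. Sorting by batch_id is what makes the orchestrator robust to completion order.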

3. Distributed Clustering

@app.function(
    image=kura_image,
    cpu=8,         # scikit-learn's MiniBatchKMeans runs on CPU, so no GPU is requested
    memory=32768,  # 32 GiB RAM
)
async def distributed_clustering(
    summaries: List[Dict],
    embeddings: List[List[float]],
    n_clusters: Optional[int] = None
) -> Dict[int, List[Dict]]:
    """Cluster the aggregated summaries by their embeddings"""
    from collections import defaultdict
    from sklearn.cluster import MiniBatchKMeans
    import numpy as np

    # Convert to a numpy array
    X = np.asarray(embeddings, dtype=np.float32)

    if n_clusters is None:
        # sqrt(n / 2) rule of thumb, with at least 2 clusters
        n_clusters = max(2, int(np.sqrt(len(summaries) / 2)))

    # MiniBatchKMeans fits on mini-batches, so it scales to large datasets
    kmeans = MiniBatchKMeans(
        n_clusters=n_clusters,
        batch_size=10000,
        max_iter=100,
        random_state=0,
    )

    labels = kmeans.fit_predict(X)

    # Group summaries by cluster label (cast numpy ints to plain ints)
    clusters = defaultdict(list)
    for summary, label in zip(summaries, labels):
        clusters[int(label)].append(summary)

    return dict(clusters)
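
The two non-library steps of the clustering function, choosing a default cluster count and grouping items by label, are easy to get subtly wrong (zero clusters on tiny inputs, numpy integer keys). This plain-Python sketch isolates them; default_n_clusters and group_by_label are hypothetical helper names, and the sqrt(n/2) heuristic is the one used above.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def default_n_clusters(n_items: int) -> int:
    """sqrt(n / 2) rule of thumb, floored at 2 so tiny inputs still cluster."""
    return max(2, int(math.sqrt(n_items / 2)))


def group_by_label(items: Sequence, labels: Sequence[int]) -> Dict[int, List]:
    """Group items by their cluster label, preserving input order within a cluster."""
    if len(items) != len(labels):
        raise ValueError("items and labels must have the same length")
    clusters: Dict[int, List] = defaultdict(list)
    for item, label in zip(items, labels):
        clusters[int(label)].append(item)
    return dict(clusters)
```

With this heuristic, 100,000 summaries give 223 clusters and 1M give 707, so the cluster count grows much more slowly than the dataset.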

4. Checkpoint Management

checkpoint_volume = modal.Volume.from_name("kura-checkpoints", create_if_missing=True)

@app.function(
    image=kura_image,
    volumes={"/cache": checkpoint_volume},
)
async def save_checkpoint(
    data: Dict,
    checkpoint_name: str,
    format: str = "parquet"
):
    """Save a checkpoint to the persistent Modal volume"""
    if format == "parquet":
        import pyarrow as pa
        import pyarrow.parquet as pq

        # Parquet expects column-oriented data: {column_name: [values, ...]}
        table = pa.table(data)
        pq.write_table(
            table,
            f"/cache/{checkpoint_name}.parquet",
            compression="snappy"
        )
    else:
        # Fallback to JSON
        import json
        with open(f"/cache/{checkpoint_name}.json", "w") as f:
            json.dump(data, f)

    # Commit so other containers and later runs see the new files
    await checkpoint_volume.commit.aio()
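
Checkpoints pay off when a restarted run can skip work that already finished. A stdlib-only sketch of per-batch checkpointing with resume follows; the function names and the batch_NNNNNN.json layout are hypothetical. It writes each file to a temporary name and then renames it with os.replace, which is atomic on a local POSIX filesystem, so a crash never leaves a half-written checkpoint that looks complete. On Modal, checkpoint_dir would be a path on the mounted volume, followed by a volume commit.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Set


def save_batch_checkpoint(checkpoint_dir: str, result: Dict) -> Path:
    """Write one batch result atomically: temp file first, then rename."""
    directory = Path(checkpoint_dir)
    directory.mkdir(parents=True, exist_ok=True)
    target = directory / f"batch_{result['batch_id']:06d}.json"
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(result, f)
    os.replace(tmp_path, target)  # readers see either nothing or the full file
    return target


def completed_batch_ids(checkpoint_dir: str) -> Set[int]:
    """Batch ids that already have a finished checkpoint file."""
    directory = Path(checkpoint_dir)
    if not directory.exists():
        return set()
    return {int(p.stem.split("_")[1]) for p in directory.glob("batch_*.json")}


def pending_batches(batches: Iterable[Dict], checkpoint_dir: str) -> List[Dict]:
    """Only the batches that still need processing after a restart."""
    done = completed_batch_ids(checkpoint_dir)
    return [b for b in batches if b["batch_id"] not in done]
```

If batches 0 and 2 of four were checkpointed before a failure, pending_batches returns only batches 1 and 3.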

5. Cost-Optimized Configuration

def get_modal_config(n_conversations: int) -> Dict:
    """Pick GPU type, parallelism and batch size for a given dataset size"""
    # GPU names use Modal's string form, as in gpu="A10G"

    if n_conversations < 10_000:
        return {
            "gpu_type": "T4",
            "n_parallel": 1,
            "batch_size": 1000,
            "scaledown_window": 60,
        }
    elif n_conversations < 100_000:
        return {
            "gpu_type": "A10G",
            "n_parallel": 10,
            "batch_size": 2000,
            "scaledown_window": 300,
        }
    elif n_conversations < 1_000_000:
        return {
            "gpu_type": "A100-40GB",
            "n_parallel": 50,
            "batch_size": 5000,
            "scaledown_window": 600,
        }
    else:
        return {
            "gpu_type": "A100-80GB",
            "n_parallel": 100,
            "batch_size": 10000,
            "scaledown_window": 1200,
        }
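
The tiers interact: batch_size fixes how many batches exist, and n_parallel only helps if there are at least that many batches. A small hypothetical helper makes this visible. It assumes every batch takes roughly the same time, so "waves" is the number of sequential rounds each container works through.

```python
import math
from typing import Dict


def plan_run(n_conversations: int, batch_size: int, n_parallel: int) -> Dict[str, int]:
    """Batches, containers actually used, and sequential waves for one run."""
    n_batches = math.ceil(n_conversations / batch_size)
    containers = min(n_parallel, n_batches)  # idle containers add nothing
    waves = math.ceil(n_batches / containers) if containers else 0
    return {"n_batches": n_batches, "containers": containers, "waves": waves}
```

For example, exactly 100,000 conversations falls into the third tier (batch_size 5000, n_parallel 50). That yields only 20 batches, so 30 of the 50 allowed containers would never start; a smaller batch size would spread the work further.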

6. CLI Integration

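# kura/cli/modal_runner.py
import click
import modal

# Assumes the Modal app, pipeline and config helper above live in one module
from kura.modal_app import app, distributed_kura_pipeline, get_modal_config  # hypothetical module path

@click.command()
@click.option('--dataset', required=True, help='HuggingFace dataset name')
@click.option('--n-conversations', type=int, required=True)
@click.option('--output-dir', default='./modal_checkpoints')
def run_on_modal(dataset, n_conversations, output_dir):
    """Run the Kura pipeline on Modal Labs"""

    # Get configuration
    config = get_modal_config(n_conversations)

    # app.run() starts an ephemeral app for the duration of the block;
    # enable_output() streams container logs to the terminal
    with modal.enable_output(), app.run():
        clusters = distributed_kura_pipeline.remote(
            dataset_name=dataset,
            n_conversations=n_conversations,
            # GPU type and parallelism are fixed in the decorators,
            # so only the batch size is forwarded at call time
            batch_size=config["batch_size"],
        )

    # Download results, e.g. with `modal volume get kura-checkpoints / <output_dir>`
    download_checkpoints(output_dir)

if __name__ == "__main__":
    run_on_modal()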

Performance Expectations

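| Conversations | Single machine | Modal Labs | Est. Modal cost | Speedup |
|---------------|----------------|------------|-----------------|---------|
| 50k           | 2-4 hours      | 15-30 min  | ~$10            | 8x      |
| 100k          | 4-8 hours      | 30-60 min  | ~$25            | 8x      |
| 500k          | 20-40 hours    | 2-4 hours  | ~$100           | 10x     |
| 1M            | 40-80 hours    | 4-8 hours  | ~$200           | 10x     |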
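
The Speedup column follows directly from the time ranges, and the cost column can be turned into a per-conversation figure to compare against the <$0.002 target under Success Criteria. The sketch below only recomputes the table's own estimates; it adds no new measurements, and the helper names are hypothetical.

```python
from typing import List, Tuple

# (conversations, single-machine hours (low, high), Modal minutes (low, high), est. cost USD)
ROWS: List[Tuple[int, Tuple[int, int], Tuple[int, int], int]] = [
    (50_000, (2, 4), (15, 30), 10),
    (100_000, (4, 8), (30, 60), 25),
    (500_000, (20, 40), (120, 240), 100),
    (1_000_000, (40, 80), (240, 480), 200),
]


def speedup(single_hours: Tuple[int, int], modal_minutes: Tuple[int, int]) -> Tuple[float, ...]:
    """Single-machine time divided by Modal time, at each end of the range."""
    return tuple(h * 60 / m for h, m in zip(single_hours, modal_minutes))


def cost_per_conversation(cost_usd: float, n_conversations: int) -> float:
    return cost_usd / n_conversations


for n, hours, minutes, cost in ROWS:
    print(n, speedup(hours, minutes), cost_per_conversation(cost, n))
```

Every row works out to $0.0002-0.00025 per conversation, roughly an order of magnitude below the $0.002 target, and both ends of each time range give the same 8x or 10x ratio.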

Advantages Over Single GPU

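  1. Near-linear Scaling: Batches are independent, so adding GPUs cuts wall-clock time roughly in proportion, until data loading or container start-up dominates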
  2. Fault Tolerance: Failed batches can be retried
  3. Cost Efficiency: No paying for idle GPUs
  4. Flexibility: Mix GPU types for different stages
  5. No Setup: Zero infrastructure management
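
Retrying failed batches is what turns independent batches into fault tolerance. Modal functions accept a `retries` argument on `@app.function`, so in production the platform can retry inputs itself; the sketch below shows the same exponential-backoff policy in plain Python, for example for a client-side call. with_retries and its parameters are hypothetical names, and the injectable sleep exists only so the policy can be tested without waiting.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def with_retries(
    fn: Callable[[], T],
    max_retries: int = 3,
    initial_delay: float = 1.0,
    backoff: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on any exception with exponentially growing delays."""
    delay = initial_delay
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_retries:
                raise  # out of retries: surface the last error
            sleep(delay)
            delay *= backoff
    raise AssertionError("unreachable")
```

With the defaults, a call that fails twice and then succeeds waits 1 s and then 2 s; one that always fails is attempted four times before the last exception propagates.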

Integration Strategy

Phase 1: Modal Functions

  • Implement core processing functions
  • Test with small datasets
  • Benchmark performance

Phase 2: Orchestration

  • Build batch distribution logic
  • Implement checkpoint management
  • Add progress tracking

Phase 3: Optimization

  • Profile GPU utilization
  • Optimize batch sizes
  • Implement adaptive scaling

Phase 4: Production

  • Add monitoring and logging
  • Implement retry logic
  • Create cost tracking
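
For the cost-tracking item in Phase 4, one approach is to record GPU-seconds per GPU type as batches finish and convert them with hourly rates. This is a sketch assuming per-second billing; CostTracker is a hypothetical class, and the rate in the example is a placeholder, not Modal's actual price, so real rates must come from Modal's pricing page.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class CostTracker:
    """Accumulates GPU-seconds per GPU type and converts them to dollars."""

    hourly_rates: Dict[str, float]  # USD per GPU-hour, supplied by the caller
    gpu_seconds: Dict[str, float] = field(default_factory=dict)

    def record(self, gpu_type: str, seconds: float) -> None:
        if gpu_type not in self.hourly_rates:
            raise KeyError(f"no rate configured for {gpu_type}")
        self.gpu_seconds[gpu_type] = self.gpu_seconds.get(gpu_type, 0.0) + seconds

    def total_cost(self) -> float:
        return sum(self.hourly_rates[g] * s / 3600 for g, s in self.gpu_seconds.items())

    def cost_per_item(self, n_items: int) -> float:
        return self.total_cost() / n_items if n_items else 0.0
```

With a placeholder rate of $2.00 per GPU-hour, two 30-minute batches cost $2.00 in total, or $0.0001 per conversation over 20,000 conversations, a number that can be compared directly with the per-conversation cost target.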

Success Criteria

  • Process 100k conversations in <1 hour
  • Achieve 80%+ GPU utilization
  • Cost <$0.002 per conversation
  • Automatic failure recovery
  • Seamless checkpoint management
  • Support for 1M+ conversations

Related Issues

Modal Labs deployment is essential for making Kura truly scalable and production-ready for large datasets.
