Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli)✨: Add synthetic data generation and task listing commands #3

Merged
merged 2 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,5 @@ docs/cli.md
message_log.db
catboost_info/*
examples/output/*
*.pkl
data/
341 changes: 340 additions & 1 deletion fast_seqfunc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
https://typer.tiangolo.com
"""

import random
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
import typer
from loguru import logger

from fast_seqfunc import synthetic
from fast_seqfunc.core import (
evaluate_model,
load_model,
Expand Down Expand Up @@ -223,5 +226,341 @@ def describe():
typer.echo("Painless sequence-function models for proteins and nucleotides.")


@app.command()
def generate_synthetic(
task: str = typer.Argument(
...,
help="Type of synthetic data task to generate. Options: g_count, gc_content, "
"motif_position, motif_count, length_dependent, nonlinear_composition, "
"interaction, classification, multiclass",
),
output_dir: Path = typer.Option(
Path("synthetic_data"), help="Directory to save generated datasets"
),
total_count: int = typer.Option(1000, help="Total number of sequences to generate"),
train_ratio: float = typer.Option(
0.7, help="Proportion of data to use for training set"
),
val_ratio: float = typer.Option(
0.15, help="Proportion of data to use for validation set"
),
test_ratio: float = typer.Option(
0.15, help="Proportion of data to use for test set"
),
split_data: bool = typer.Option(
True, help="Whether to split data into train/val/test sets"
),
sequence_length: int = typer.Option(
30, help="Length of each sequence (for fixed-length tasks)"
),
min_length: int = typer.Option(
20, help="Minimum sequence length (for variable-length tasks)"
),
max_length: int = typer.Option(
50, help="Maximum sequence length (for variable-length tasks)"
),
noise_level: float = typer.Option(0.1, help="Level of noise to add to the data"),
sequence_type: str = typer.Option(
"dna", help="Type of sequences to generate: dna, rna, or protein"
),
alphabet: Optional[str] = typer.Option(
None, help="Custom alphabet for sequences. Overrides sequence_type if provided."
),
motif: Optional[str] = typer.Option(
None, help="Custom motif for motif-based tasks"
),
motifs: Optional[str] = typer.Option(
None, help="Comma-separated list of motifs for motif_count task"
),
weights: Optional[str] = typer.Option(
None, help="Comma-separated list of weights for motif_count task"
),
prefix: str = typer.Option("", help="Prefix for output filenames"),
random_seed: Optional[int] = typer.Option(
None, help="Random seed for reproducibility"
),
):
"""Generate synthetic sequence-function data for testing and benchmarking.

This command creates synthetic datasets with controllable properties and
complexity to test sequence-function models. Data can be split into
train/validation/test sets.

Each task produces a different type of sequence-function relationship:

- g_count: Linear relationship based on count of G nucleotides
- gc_content: Linear relationship based on GC content
- motif_position: Function depends on the position of a motif (nonlinear)
- motif_count: Function depends on counts of multiple motifs (linear)
- length_dependent: Function depends on sequence length (nonlinear)
- nonlinear_composition: Nonlinear function of base composition
- interaction: Function depends on interactions between positions
- classification: Binary classification based on presence of motifs
- multiclass: Multi-class classification based on different patterns

Example usage:

$ fast-seqfunc generate-synthetic gc_content --output-dir data/gc_task

$ fast-seqfunc generate-synthetic motif_position --motif ATCG --noise-level 0.2

$ fast-seqfunc generate-synthetic classification \
--sequence-type protein \
--no-split-data
"""
# Set random seed if provided
if random_seed is not None:
random.seed(random_seed)
np.random.seed(random_seed)

logger.info(f"Generating synthetic data for task: {task}")

# Create output directory if it doesn't exist
output_dir.mkdir(parents=True, exist_ok=True)

# Set alphabet based on sequence type
if alphabet is None:
sequence_type = sequence_type.lower()
if sequence_type == "dna":
alphabet = "ACGT"
elif sequence_type == "rna":
alphabet = "ACGU"
elif sequence_type == "protein":
alphabet = "ACDEFGHIKLMNPQRSTVWY"
else:
logger.warning(
f"Unknown sequence type: {sequence_type}. Using DNA alphabet."
)
alphabet = "ACGT"

logger.info(f"Using alphabet: {alphabet}")

# Task-specific parameters
task_params: Dict[str, Any] = {}

# Add common parameters that apply to most tasks
if task != "length_dependent":
task_params["length"] = sequence_length

# We need to patch the generate_random_sequences function to use our alphabet
# This approach uses monkey patching to avoid having to modify all task functions
original_generate_random_sequences = synthetic.generate_random_sequences

def patched_generate_random_sequences(*args, **kwargs):
"""
Patched version of `generate_random_sequences` that uses a custom alphabet.

This function overrides the alphabet parameter with our custom alphabet while
preserving all other parameters passed to the original function.

:param args: Positional arguments to pass to the original function
:param kwargs: Keyword arguments to pass to the original function
:return: Result from the original generate_random_sequences function
"""
# Override the alphabet parameter with our custom alphabet,
# but keep other parameters
kwargs["alphabet"] = alphabet
return original_generate_random_sequences(*args, **kwargs)

# Replace the function temporarily
synthetic.generate_random_sequences = patched_generate_random_sequences

# Add task-specific parameters based on the task type
if task == "motif_position":
# Use custom motif if provided
if motif:
task_params["motif"] = motif
else:
# Default motif depends on alphabet
if len(alphabet) == 4: # DNA/RNA
task_params["motif"] = "".join(random.sample(alphabet, 4))
else: # Protein
task_params["motif"] = "".join(
random.sample(alphabet, min(4, len(alphabet)))
)
logger.info(f"Using default motif: {task_params['motif']}")

elif task == "motif_count":
# Parse custom motifs if provided
if motifs:
task_params["motifs"] = [m.strip() for m in motifs.split(",")]
else:
# Generate default motifs based on alphabet
if len(alphabet) <= 8: # DNA/RNA
task_params["motifs"] = [
"".join(random.sample(alphabet, 2)) for _ in range(4)
]
else: # Protein
task_params["motifs"] = [
"".join(random.sample(alphabet, 3)) for _ in range(4)
]
logger.info(f"Using default motifs: {task_params['motifs']}")

# Parse custom weights if provided
if weights:
try:
weight_values = [float(w.strip()) for w in weights.split(",")]
if len(weight_values) != len(task_params["motifs"]):
logger.warning(
"Number of weights doesn't match number of motifs. "
"Using default weights."
)
task_params["weights"] = [1.0, -0.5, 2.0, -1.5]
else:
task_params["weights"] = weight_values
except ValueError:
logger.warning("Invalid weight values. Using default weights.")
task_params["weights"] = [1.0, -0.5, 2.0, -1.5]
else:
task_params["weights"] = [1.0, -0.5, 2.0, -1.5]

elif task == "length_dependent":
task_params["min_length"] = min_length
task_params["max_length"] = max_length

# Validate the task
valid_tasks = [
"g_count",
"gc_content",
"motif_position",
"motif_count",
"length_dependent",
"nonlinear_composition",
"interaction",
"classification",
"multiclass",
]

if task not in valid_tasks:
logger.error(
f"Invalid task: {task}. Valid options are: {', '.join(valid_tasks)}"
)
raise typer.Exit(1)

# The task functions don't directly accept an alphabet parameter
# so we need to remove it from task_params
if "alphabet" in task_params:
del task_params["alphabet"]

# Generate the dataset
try:
df = synthetic.generate_dataset_by_task(
task=task, count=total_count, noise_level=noise_level, **task_params
)

logger.info(f"Generated {len(df)} sequences for task: {task}")

# Create filename prefix if provided
file_prefix = f"{prefix}_" if prefix else ""

# Save the full dataset if not splitting
if not split_data:
output_path = output_dir / f"{file_prefix}{task}_data.csv"
df.to_csv(output_path, index=False)
logger.info(f"Saved full dataset to {output_path}")
# Restore original function
synthetic.generate_random_sequences = original_generate_random_sequences
return

# Validate split ratios
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-10:
logger.warning("Split ratios don't sum to 1.0. Normalizing.")
total = train_ratio + val_ratio + test_ratio
train_ratio /= total
val_ratio /= total
test_ratio /= total

# Shuffle the data
df = df.sample(frac=1.0, random_state=random_seed)

# Calculate split indices
n = len(df)
train_idx = int(n * train_ratio)
val_idx = train_idx + int(n * val_ratio)

# Split the data
train_df = df.iloc[:train_idx]
val_df = df.iloc[train_idx:val_idx]
test_df = df.iloc[val_idx:]

# Save the splits
train_path = output_dir / f"{file_prefix}train.csv"
val_path = output_dir / f"{file_prefix}val.csv"
test_path = output_dir / f"{file_prefix}test.csv"

train_df.to_csv(train_path, index=False)
val_df.to_csv(val_path, index=False)
test_df.to_csv(test_path, index=False)

logger.info(f"Saved train set ({len(train_df)} samples) to {train_path}")
logger.info(f"Saved validation set ({len(val_df)} samples) to {val_path}")
logger.info(f"Saved test set ({len(test_df)} samples) to {test_path}")

# Save task metadata
metadata = {
"task": task,
"sequence_type": sequence_type,
"alphabet": alphabet,
"total_count": total_count,
"train_count": len(train_df),
"val_count": len(val_df),
"test_count": len(test_df),
"noise_level": noise_level,
**task_params,
}

metadata_path = output_dir / f"{file_prefix}metadata.csv"
pd.DataFrame([metadata]).to_csv(metadata_path, index=False)
logger.info(f"Saved metadata to {metadata_path}")

except Exception as e:
logger.error(f"Error generating synthetic data: {e}")
raise typer.Exit(1)
finally:
# Make sure to restore the original function even if an error occurs
synthetic.generate_random_sequences = original_generate_random_sequences


@app.command()
def list_synthetic_tasks():
"""List all available synthetic sequence-function data tasks with descriptions."""
tasks = {
"g_count": "A simple linear task where the function value is the count of G "
"nucleotides in the sequence.",
"gc_content": "A simple linear task where the function value is the GC content "
"(proportion of G and C) of the sequence.",
"motif_position": "A nonlinear task where the function value depends on the "
"position of a specific motif in the sequence.",
"motif_count": "A linear task where the function value is a weighted sum of "
"counts of multiple motifs in the sequence.",
"length_dependent": "A task with variable-length sequences where the function "
"value depends nonlinearly on the sequence length.",
"nonlinear_composition": "A complex nonlinear task where the function depends "
"on ratios between different nucleotide frequencies.",
"interaction": "A task testing positional interactions, "
"where specific nucleotide pairs at certain positions "
"contribute to the function.",
"classification": "A binary classification task where the class depends on the "
"presence of specific patterns in the sequence.",
"multiclass": "A multi-class classification task "
"with multiple sequence patterns "
"corresponding to different classes.",
}

typer.echo("Available synthetic sequence-function data tasks:")
typer.echo("")

for task, description in tasks.items():
typer.echo(f"{task}:")
typer.echo(f" {description}")
typer.echo("")

typer.echo("Usage:")
typer.echo(" fast-seqfunc generate-synthetic TASK [OPTIONS]")
typer.echo("")
typer.echo("For detailed options:")
typer.echo(" fast-seqfunc generate-synthetic --help")


if __name__ == "__main__":
app()