exploratory work with skypilot
corbt committed Nov 12, 2024
1 parent 388e73d commit 223dc0e
Showing 10 changed files with 1,995 additions and 116 deletions.
Empty file.
778 changes: 778 additions & 0 deletions best_hn/scraped_stories/prepare_data.ipynb

Large diffs are not rendered by default.

100 changes: 0 additions & 100 deletions explore.ipynb
@@ -1,100 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading dataset...\n"
]
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['chosen', 'rejected', 'chosen_prompt', 'rejected_prompt'],\n",
" num_rows: 90000\n",
"})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datasets import load_from_disk, Dataset\n",
"from transformers import (\n",
" AutoModelForSequenceClassification,\n",
" AutoTokenizer,\n",
")\n",
"from trl import RewardTrainer, RewardConfig\n",
"from peft.tuners.lora import LoraConfig\n",
"from peft.mapping import get_peft_model\n",
"import wandb\n",
"\n",
"# Configuration\n",
"model_name = \"unsloth/Llama-3.2-3B\"\n",
"dataset_path = \"./data/sample_pairs\"\n",
"output_dir = \"./reward_model_output\"\n",
"num_epochs = 1\n",
"batch_size = 8\n",
"learning_rate = 5e-5\n",
"max_length = 4096\n",
"\n",
"print(\"Loading dataset...\")\n",
"dataset: Dataset = load_from_disk(dataset_path)[\"train\"]\n",
"\n",
"dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['_data_files', '_fingerprint', '_format_columns', '_format_kwargs', '_format_type', '_output_all_columns', '_split'],\n",
" num_rows: 1\n",
"})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -10,6 +10,8 @@ dependencies = [
# "bitsandbytes>=0.44.1",
"datasets>=3.0.1",
"dicttoxml>=1.7.16",
"docker>=7.1.0",
"google-api-python-client>=2.152.0",
# "flash-attn==2.6.3",
"hf-transfer>=0.1.8",
"ipykernel>=6.29.5",
@@ -30,6 +32,7 @@ dependencies = [
"schedulefree>=1.2.7",
"scikit-learn>=1.5.2",
"seaborn>=0.13.2",
"skypilot[gcp,runpod]>=0.7.0",
# "sglang[all]>=0.3.3.post1",
"torch==2.4.0",
"tqdm>=4.66.5",
2 changes: 1 addition & 1 deletion scraped_stories/rm/model1.py
@@ -29,7 +29,7 @@
)
},
)
async def main():
def main():
import os
import torch
from datasets import Dataset
35 changes: 35 additions & 0 deletions scraped_stories/rm/update_image.py
@@ -0,0 +1,35 @@
import io
import subprocess
import tempfile
import os
import hashlib


def update_image():
# Create hash of dockerfile contents
dockerfile_hash = hashlib.sha256(dockerfile.encode()).hexdigest()[:12]
image_name = f"ghcr.io/openpipe/scraped-stories-rm:{dockerfile_hash}"

print(f"Building image {image_name}")

# Create temporary directory and write dockerfile there
with tempfile.TemporaryDirectory() as tmpdir:
dockerfile_path = os.path.join(tmpdir, "Dockerfile")
with open(dockerfile_path, "w") as f:
f.write(dockerfile)

subprocess.run(
["docker", "build", "--platform", "linux/amd64", "-t", image_name, tmpdir],
check=True,
)

print(f"Pushing image {image_name}")
subprocess.run(["docker", "push", image_name], check=True)

print(f"Image {image_name} updated")

return image_name


if __name__ == "__main__":
update_image()
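
Note: `update_image()` hashes a module-level `dockerfile` string that is not shown in this diff. A minimal sketch of what that definition might look like, purely as an assumption (the real Dockerfile contents are not part of this commit):

# Hypothetical Dockerfile contents; illustrative only, not taken from the commit.
dockerfile = """
FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04
RUN apt-get update && apt-get install -y python3 python3-pip git
RUN pip3 install torch==2.4.0 transformers datasets
WORKDIR /workspace
"""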
32 changes: 32 additions & 0 deletions scraped_stories/run_skypilot.py
@@ -0,0 +1,32 @@
import sky
import dotenv
import os
from typing import Callable

dotenv.load_dotenv()


def run_skypilot(script_path: str, env_vars: dict[str, str] | None = None):
env_vars = env_vars or dotenv.dotenv_values() # type: ignore

s3_bucket_name = os.getenv("REMOTE_BUCKET")
task = sky.Task(
run="pwd && ls",
envs=env_vars,
workdir=os.path.dirname(os.path.abspath(__file__)),
)

# task.set_resources(sky.Resources(cloud=sky.RunPod(), accelerators="H100:1"))
task.set_resources(sky.Resources(cloud=sky.RunPod()))

# task.set_storage_mounts({"/remote": sky.Storage(source=f"s3://{s3_bucket_name}")})

sky.launch(task, cluster_name="scraped-stories-rm", down=True)


def is_local():
return "SKYPILOT_TASK_ID" not in os.environ


if __name__ == "__main__":
run_skypilot("scraped_stories/rm/model1.py")
106 changes: 91 additions & 15 deletions training_helpers.py
@@ -4,8 +4,17 @@
from scipy.stats import pearsonr
from sklearn.metrics import root_mean_squared_error
import wandb
from inference import run_inference_transformers, ModelOrPath, MandT, load_model
from utils import stories_dataset, calculate_metrics_by_split
from .inference import run_inference_transformers, ModelOrPath, MandT, load_model
import math
import logging
from datasets import Dataset

# Configure logging
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s %(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)


def compute_metrics(eval_pred):
@@ -27,7 +36,7 @@ def compute_metrics(eval_pred):


def run_final_inference_and_report_metrics(
model_or_path: ModelOrPath, output_dir: Optional[str] = None, dataset=None
model_or_path: ModelOrPath, dataset=None, output_dir: Optional[str] = None
):
if output_dir is None:
if not isinstance(model_or_path, str):
@@ -39,7 +48,7 @@ def run_final_inference_and_report_metrics(
predictions_path = f"{output_dir}/dataset_predictions.parquet"

if dataset is None:
dataset = stories_dataset()
raise ValueError("dataset is required")

# Check if predictions file already exists
try:
@@ -68,16 +77,83 @@

print(metrics)

# Log metrics to wandb if it's being used
if wandb.run is not None:
for row in metrics.iter_rows(named=True):
split = row["split"]
wandb.summary.update(
{
f"final/{split}/baseline_rmse": row["baseline_rmse"],
f"final/{split}/model_rmse": row["model_rmse"],
f"final/{split}/correlation": row["model_correlation"],
}
)
for row in metrics.iter_rows(named=True):
split = row["split"]
wandb.summary.update(
{
f"final/{split}/baseline_rmse": row["baseline_rmse"],
f"final/{split}/model_rmse": row["model_rmse"],
f"final/{split}/correlation": row["model_correlation"],
}
)

return metrics


def calculate_metrics_by_split(df: pl.DataFrame) -> pl.DataFrame:
"""
Calculate correlation and RMSE metrics for each split in the dataset.
Args:
df: DataFrame with log_score, predictions and split columns
Returns:
DataFrame with metrics for each split
"""
metrics = []

for split in df["split"].unique():
split_df = df.filter(pl.col("split") == split)

# Calculate baseline (mean) metrics
average_score = split_df["log_score"].mean()
rmse_baseline = math.sqrt(
(split_df["log_score"] - average_score).pow(2).sum() / len(split_df)
)

# Calculate model metrics
rmse_model = math.sqrt(
(split_df["log_score"] - split_df["predictions"]).pow(2).sum()
/ len(split_df)
)
correlation_model = split_df.select(pl.corr("log_score", "predictions"))[
"log_score"
][0]

metrics.append(
{
"split": split,
"baseline_rmse": rmse_baseline,
"model_rmse": rmse_model,
"model_correlation": correlation_model,
"num_rows": len(split_df),
}
)

return pl.DataFrame(metrics)


def create_dataset(
df: pl.DataFrame,
split: str,
num_rows: int,
tokenizer,
max_len: int,
n_proc: int = 4,
):
df = df.with_columns(pl.col("score").log().alias("log_score"))
df = df.filter(pl.col("split") == split).head(num_rows)
df = df.with_columns(
[
pl.col("serialized").alias("text"),
pl.col("log_score").alias("label"),
]
)
dataset = Dataset.from_polars(df.select(["text", "label"]))

def tokenize_function(examples):
return tokenizer(examples["text"], truncation=False)

dataset = dataset.map(tokenize_function, batched=True, num_proc=n_proc)
dataset = dataset.filter(lambda example: len(example["input_ids"]) <= max_len)
return dataset
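
A usage sketch for `create_dataset`, assuming the polars DataFrame returned by `stories_dataset()` in utils.py and the tokenizer for the base model referenced elsewhere in the repo (row counts here are illustrative):

from transformers import AutoTokenizer
from utils import stories_dataset  # assumes the helper defined in utils.py

tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B")
stories = stories_dataset()  # polars DataFrame with "serialized", "score", and "split" columns
train_ds = create_dataset(stories, split="train", num_rows=50_000, tokenizer=tokenizer, max_len=4096)
val_ds = create_dataset(stories, split="val", num_rows=2_000, tokenizer=tokenizer, max_len=4096)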
59 changes: 59 additions & 0 deletions utils.py
@@ -247,6 +247,65 @@ def with_story_info(comments_df: pl.DataFrame) -> pl.DataFrame:
)


@cache_dataframe("./data/stories_dataset.parquet")
def stories_dataset() -> pl.DataFrame:
stories = full_dataset().filter(
(pl.col("type") == "story")
& pl.col("time").is_not_null()
& pl.col("text").is_not_null()
& pl.col("url").is_null()
& pl.col("deleted").is_null()
& pl.col("dead").is_null()
)

# There's a weird discontinuity in late 2015, just ignore it
stories = stories.filter(pl.col("time") >= pl.datetime(2016, 1, 1))

# Add a log score, it's a very skewed distribution
stories = stories.with_columns(pl.col("score").log().alias("log_score"))

progress_bar = tqdm.tqdm(total=len(stories), desc="Serializing stories")

def serialize_story(story):
progress_bar.update(1)
escaped_story = html.unescape(story["text"]).replace("<p>", "\n\n")
return f"""
{story["title"]}
{story["by"]}, {story["time"].strftime("%Y-%m-%d")}
{escaped_story}
"""

stories = stories.with_columns(
pl.struct(["title", "by", "time", "text"])
.map_elements(serialize_story, return_dtype=pl.Utf8)
.alias("serialized")
)

progress_bar.close()

stories = stories.sample(fraction=1, shuffle=True, seed=42)

split_assignments = np.random.choice(
["train", "test", "val"], size=len(stories), p=[0.8, 0.1, 0.1]
)

stories = stories.with_columns(pl.Series("split", split_assignments))

return stories.select(
"id",
"title",
"by",
"text",
"score",
"descendants",
"time",
"log_score",
"serialized",
"split",
)


def calculate_metrics_by_split(df: pl.DataFrame) -> pl.DataFrame:
"""
Calculate correlation and RMSE metrics for each split in the dataset.