stories trainer
seems to work
corbt committed Oct 18, 2024
1 parent 78c63c9 commit e02c7de
Showing 7 changed files with 1,164 additions and 3 deletions.
2 changes: 1 addition & 1 deletion prepare-dataset.ipynb
@@ -1158,7 +1158,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Building prompts: 2%|▏ | 6448/340702 [00:04<02:47, 1991.79it/s]"
"Building prompts: 100%|██████████| 340702/340702 [02:44<00:00, 2065.44it/s]\n"
]
}
],
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -18,6 +18,8 @@ dependencies = [
"peft>=0.13.2",
"polars>=1.9.0",
"python-dotenv>=1.0.1",
"scikit-learn>=1.5.2",
"seaborn>=0.13.2",
"sglang[all]>=0.3.3.post1",
"torch==2.4.0",
"tqdm>=4.66.5",
4 changes: 2 additions & 2 deletions rate-random-comments.ipynb
@@ -7,9 +7,9 @@
"outputs": [],
"source": [
"import polars as pl\n",
"from utils import dataset, build_all_prompts, run_inference_sglang\n",
"from utils import augmented_comments, build_all_prompts, run_inference_sglang\n",
"\n",
"df = dataset()"
"df = augmented_comments()"
]
},
{
136 changes: 136 additions & 0 deletions stories_train_model_v1.py
@@ -0,0 +1,136 @@
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from peft.tuners.lora import LoraConfig
from peft.mapping import get_peft_model
import wandb
from dotenv import load_dotenv
import polars as pl
from utils import stories_dataset
from sklearn.metrics import mean_squared_error
from liger_kernel.transformers import _apply_liger_kernel_to_instance

load_dotenv("/workspace/.env")

# Configuration
base_model = "unsloth/Meta-Llama-3.1-8B"
run_name = "stories_model_v1"
output_dir = f"./models/{run_name}"
num_epochs = 1
batch_size = 8
learning_rate = 2e-4
max_length = 4096

# Initialize wandb
wandb.init(project="hn_stories_model_training", name=run_name)


def create_dataset(split, num_rows, tokenizer):
    stories = stories_dataset()
    stories = stories.filter(pl.col("split") == split).head(num_rows)

    stories = stories.with_columns(
        [
            pl.col("serialized").alias("text"),
            pl.col("log_score").alias("label"),
        ]
    )

    stories = stories.with_columns(
        [
            pl.col("text")
            .map_elements(
                lambda x: tokenizer(x)["input_ids"], return_dtype=pl.List(pl.Int64)
            )
            .alias("input_ids"),
        ]
    ).select(["input_ids", "label"])
    return Dataset.from_polars(stories)


print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    truncation=True,
    padding=True,
    max_length=max_length,
)

model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=1,  # Regression task
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
_apply_liger_kernel_to_instance(model=model)

model.config.pad_token_id = tokenizer.pad_token_id
tokenizer.padding_side = "right"

print("Configuring LoRA...")
model = get_peft_model(
    model,
    LoraConfig(
        task_type="SEQ_CLS",
        r=8,
        lora_alpha=16,
        lora_dropout=0,
    ),
)

print("Loading dataset...")
train_stories = create_dataset("train", 30000, tokenizer)
validation_stories = create_dataset("val", 500, tokenizer)


# Configure training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0,
    evaluation_strategy="steps",
    eval_steps=0.1,
    logging_steps=100,
    save_strategy="steps",
    save_steps=1000,
    report_to="wandb",
    no_cuda=False,
    bf16=True,
    warmup_steps=100,
    # use_liger_kernel=True,
)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    rmse = mean_squared_error(labels, predictions, squared=False)
    return {"rmse": rmse}


print("Initializing Trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_stories,
    eval_dataset=validation_stories,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

print("Starting model training...")
trainer.train()

print("Saving final model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print("Stories model training complete")
67 changes: 67 additions & 0 deletions utils.py
@@ -6,6 +6,7 @@
import os
import html
import re
import numpy as np


def cache_dataframe(path):
@@ -24,6 +25,14 @@ def wrapper(*args, **kwargs):
            cache[path] = df
            return df

        def bust_cache():
            if path in cache:
                del cache[path]
            if os.path.exists(path):
                os.remove(path)
            print(f"Cache busted for {path}")

        wrapper.bust_cache = bust_cache
        return wrapper

    return decorator
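An aside, not in the diff itself: a tiny usage sketch of the new bust_cache hook, assuming it is attached to a cached loader such as the stories_dataset() function added further down.

# Hypothetical usage of the bust_cache hook added above.
from utils import stories_dataset

df = stories_dataset()        # first call computes and writes ./data/stories_dataset.parquet
stories_dataset.bust_cache()  # drops the in-memory copy and deletes the parquet file
df = stories_dataset()        # next call recomputes from scratch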
@@ -234,3 +243,61 @@ def with_story_info(comments_df: pl.DataFrame) -> pl.DataFrame:
    return comments_df.join(
        stories_df, left_on="top_level_parent", right_on="story_id", how="left"
    )


@cache_dataframe("./data/stories_dataset.parquet")
def stories_dataset() -> pl.DataFrame:
    stories = full_dataset().filter(
        (pl.col("type") == "story")
        & pl.col("time").is_not_null()
        & pl.col("text").is_not_null()
        & pl.col("url").is_null()
        & pl.col("deleted").is_null()
        & pl.col("dead").is_null()
    )

    # There's a weird discontinuity in late 2015, just ignore it
    stories = stories.filter(pl.col("time") >= pl.datetime(2016, 1, 1))

    # Add a log score, it's a very skewed distribution
    stories = stories.with_columns(pl.col("score").log().alias("log_score"))

    progress_bar = tqdm.tqdm(total=len(stories), desc="Serializing stories")

    def serialize_story(story):
        progress_bar.update(1)
        return f"""
{story["title"]}
{story["by"]}, {story["time"].strftime("%Y-%m-%d")}
{html.unescape(story["text"]).replace("<p>", "\n\n")}
"""

    stories = stories.with_columns(
        pl.struct(["title", "by", "time", "text"])
        .map_elements(serialize_story, return_dtype=pl.Utf8)
        .alias("serialized")
    )

    progress_bar.close()

    stories = stories.sample(fraction=1, shuffle=True, seed=42)

    split_assignments = np.random.choice(
        ["train", "test", "val"], size=len(stories), p=[0.8, 0.1, 0.1]
    )

    stories = stories.with_columns(pl.Series("split", split_assignments))

    return stories.select(
        "id",
        "title",
        "by",
        "text",
        "score",
        "descendants",
        "time",
        "log_score",
        "serialized",
        "split",
    )
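Also for context, not part of the commit: a short sketch of how the split column assigned above might be inspected and consumed, mirroring what create_dataset() in stories_train_model_v1.py does; the printed counts are illustrative only.

# Sketch: check the train/val/test assignment produced by stories_dataset().
import polars as pl
from utils import stories_dataset

stories = stories_dataset()
print(stories.group_by("split").len())  # roughly 80% train, 10% test, 10% val

# The trainer consumes the serialized text and log_score label for one split,
# e.g. create_dataset("train", 30000, tokenizer) filters like this:
train = stories.filter(pl.col("split") == "train").head(30_000)
print(train.select("serialized", "log_score").head(1))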
18 changes: 18 additions & 0 deletions uv.lock

Some generated files are not rendered by default.
