-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
412 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<div><style>\n", | ||
".dataframe > thead > tr,\n", | ||
".dataframe > tbody > tr {\n", | ||
" text-align: right;\n", | ||
" white-space: pre-wrap;\n", | ||
"}\n", | ||
"</style>\n", | ||
"<small>shape: (200, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>id</th><th>predictions</th></tr><tr><td>i64</td><td>f64</td></tr></thead><tbody><tr><td>15410850</td><td>-1.03125</td></tr><tr><td>19662594</td><td>-3.671875</td></tr><tr><td>19390782</td><td>-2.984375</td></tr><tr><td>35481487</td><td>-3.28125</td></tr><tr><td>27015197</td><td>-3.5</td></tr><tr><td>…</td><td>…</td></tr><tr><td>11173806</td><td>-1.351562</td></tr><tr><td>40002472</td><td>-3.140625</td></tr><tr><td>12095865</td><td>-2.765625</td></tr><tr><td>39592207</td><td>-3.40625</td></tr><tr><td>19140970</td><td>1.6640625</td></tr></tbody></table></div>" | ||
], | ||
"text/plain": [ | ||
"shape: (200, 2)\n", | ||
"┌──────────┬─────────────┐\n", | ||
"│ id ┆ predictions │\n", | ||
"│ --- ┆ --- │\n", | ||
"│ i64 ┆ f64 │\n", | ||
"╞══════════╪═════════════╡\n", | ||
"│ 15410850 ┆ -1.03125 │\n", | ||
"│ 19662594 ┆ -3.671875 │\n", | ||
"│ 19390782 ┆ -2.984375 │\n", | ||
"│ 35481487 ┆ -3.28125 │\n", | ||
"│ 27015197 ┆ -3.5 │\n", | ||
"│ … ┆ … │\n", | ||
"│ 11173806 ┆ -1.351562 │\n", | ||
"│ 40002472 ┆ -3.140625 │\n", | ||
"│ 12095865 ┆ -2.765625 │\n", | ||
"│ 39592207 ┆ -3.40625 │\n", | ||
"│ 19140970 ┆ 1.6640625 │\n", | ||
"└──────────┴─────────────┘" | ||
] | ||
}, | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import polars as pl\n", | ||
"\n", | ||
"df = pl.read_parquet(\"s3://best-hn-data/rm/models/model1/dataset_predictions.parquet\")\n", | ||
"# df.group_by(\"label\").agg(pl.col(\"prediction\").mean()).sort(\"prediction\", descending=True)\n", | ||
"df\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.1" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import io | ||
import subprocess | ||
import tempfile | ||
import os | ||
import hashlib | ||
import asyncio | ||
from panza import limit_concurrency | ||
|
||
|
||
# limit_concurrency will ensure that we aren't building the same image twice at the same time |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
) | ||
}, | ||
) | ||
async def main(): | ||
def main(): | ||
import os | ||
import torch | ||
from datasets import Dataset | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# uv run modal run --detach scraped_stories.rm.model1 | ||
|
||
from .modal_app import app | ||
import modal | ||
import logging | ||
|
||
# Name of the bucket mounted at /remote by the function decorator below.
# This literal is what the module keeps whenever modal.is_local() is false.
s3_bucket_name = "[placeholder]"

if modal.is_local():
    import os

    from dotenv import load_dotenv

    # Load variables from a .env file before looking up REMOTE_BUCKET.
    load_dotenv()

    # Fall back to the placeholder assigned above when the variable is absent.
    s3_bucket_name = os.environ.get("REMOTE_BUCKET", s3_bucket_name)
    # NOTE(review): indentation was lost in the extracted source; this log call
    # is assumed to belong to the local-only branch — confirm.
    logging.info(f"Using S3 bucket: {s3_bucket_name}")
|
||
|
||
@app.function(
    secrets=[modal.Secret.from_dotenv(".env")],
    gpu="H100",
    memory=32768 * 2,
    timeout=3600 * 24,
    volumes={
        "/remote": modal.CloudBucketMount(
            bucket_name=s3_bucket_name,
            secret=modal.Secret.from_dotenv(".env"),
            read_only=False,
        )
    },
)
def main():
    """Fine-tune a LoRA-adapted sequence-regression model and save it.

    Steps, in order: read the scraped-stories parquet from the bucket named
    by the ``REMOTE_BUCKET`` environment variable and sample 50,000 rows,
    load the base model with a single output label, wrap it with LoRA,
    build the train/validation datasets, train with the Hugging Face
    ``Trainer``, save the model and tokenizer under
    ``/remote/rm/models/<run_name>``, then run the final inference and
    metrics helper on the sampled frame.

    Raises:
        RuntimeError: if ``REMOTE_BUCKET`` is unset or empty, instead of
            attempting to read from ``s3://None/...``.
    """
    import os
    import torch
    from datasets import Dataset
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )
    from peft.tuners.lora import LoraConfig
    from peft.mapping import get_peft_model
    import wandb
    import polars as pl
    from liger_kernel.transformers import _apply_liger_kernel_to_instance
    from .training_helpers import (
        compute_metrics,
        run_final_inference_and_report_metrics,
        MandT,
        create_dataset,
    )

    # The root logger defaults to WARNING, which silently drops every INFO
    # progress message below. Make sure a handler exists (no-op when one is
    # already installed) and raise only this module's logger to INFO so that
    # third-party libraries keep their own verbosity.
    logging.basicConfig()
    log = logging.getLogger(__name__)
    log.setLevel(logging.INFO)

    # Configuration
    base_model = "unsloth/Llama-3.2-1B"
    # File name without directory or extension, e.g. "model1" for model1.py.
    run_name = os.path.splitext(os.path.basename(__file__))[0]
    output_dir = f"/remote/rm/models/{run_name}"
    num_epochs = 1
    batch_size = 4
    gradient_accumulation_steps = 4
    learning_rate = 2e-4
    max_length = 4096

    # Fail fast, before a wandb run is opened, when the bucket is unknown.
    bucket = os.getenv("REMOTE_BUCKET")
    if not bucket:
        raise RuntimeError(
            "REMOTE_BUCKET is not set; cannot locate the training dataset"
        )

    # Initialize wandb
    wandb.init(project="hn_scraped_stories", name=run_name)

    log.info("Loading dataset...")
    df = pl.read_parquet(
        f"s3://{bucket}/scraped-stories-with-datetime.parquet"
    ).sample(n=50000, seed=42)
    log.info("Loaded %s rows", df.height)

    log.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        truncation=True,
        padding=True,
        max_length=max_length,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=1,  # Regression task
        device_map="auto",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    _apply_liger_kernel_to_instance(model=model)

    # Checkpointing is enabled before the LoRA wrap below; keep this order.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": True}
    )

    model.config.pad_token_id = tokenizer.pad_token_id
    tokenizer.padding_side = "right"

    log.info("Configuring LoRA...")
    model = get_peft_model(
        model,
        LoraConfig(
            task_type="SEQ_CLS",
            r=8,
            lora_alpha=16,
            lora_dropout=0,
        ),
    )

    log.info("Transforming datasets...")
    train_stories = create_dataset(df, "train", 50000, tokenizer, max_length)
    print(f"Train stories: {len(train_stories)}")
    validation_stories = create_dataset(df, "val", 500, tokenizer, max_length)
    print(f"Validation stories: {len(validation_stories)}")

    # Configure training arguments
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` and deprecate `no_cuda` — confirm against the pinned
    # version before changing either keyword.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        weight_decay=0,
        evaluation_strategy="steps",
        eval_steps=0.05,
        logging_steps=100,
        save_strategy="steps",
        save_steps=1000,
        report_to="wandb",
        no_cuda=False,
        bf16=True,
        warmup_ratio=0.1,
        gradient_accumulation_steps=gradient_accumulation_steps,
    )

    log.info("Initializing Trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_stories,
        eval_dataset=validation_stories,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    log.info("Starting model training...")
    trainer.train()

    log.info("Saving final model...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    log.info("Running final inference and reporting metrics...")
    # The helper's return value was bound to an unused local; it is called
    # for its reporting side effects only.
    run_final_inference_and_report_metrics(
        MandT(model, tokenizer), df, output_dir
    )

    log.info("Model training complete")
|
||
|
||
@app.local_entrypoint()
def main_local():
    """Local entrypoint: announce the launch, then invoke ``main.remote()``."""
    print("Running main locally")
    # The work itself is done by `main`, called through its `.remote()` handle.
    main.remote()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# uv run modal run --detach scraped_stories.rm.model1 | ||
|
||
from .modal_app import app | ||
import modal | ||
import logging | ||
|
||
# Name of the bucket mounted at /remote by the function decorator below.
# This literal is what the module keeps whenever modal.is_local() is false.
s3_bucket_name = "[placeholder]"

if modal.is_local():
    import os

    from dotenv import load_dotenv

    # Load variables from a .env file before looking up REMOTE_BUCKET.
    load_dotenv()

    # Fall back to the placeholder assigned above when the variable is absent.
    s3_bucket_name = os.environ.get("REMOTE_BUCKET", s3_bucket_name)
    # NOTE(review): indentation was lost in the extracted source; this log call
    # is assumed to belong to the local-only branch — confirm.
    logging.info(f"Using S3 bucket: {s3_bucket_name}")
|
||
|
||
@app.function(
    secrets=[modal.Secret.from_dotenv(".env")],
    gpu="H100",
    memory=32768 * 2,
    timeout=3600 * 24,
    volumes={
        "/remote": modal.CloudBucketMount(
            bucket_name=s3_bucket_name,
            secret=modal.Secret.from_dotenv(".env"),
            read_only=False,
        )
    },
)
def main():
    """Run final inference and metric reporting for the saved ``model1``.

    Loads the model stored under ``/remote/rm/models/model1``, reads the
    full scraped-stories parquet from the bucket named by the
    ``REMOTE_BUCKET`` environment variable, and passes both to
    ``run_final_inference_and_report_metrics``.

    Raises:
        RuntimeError: if ``REMOTE_BUCKET`` is unset or empty, instead of
            attempting to read from ``s3://None/...``.
    """
    import os
    import polars as pl
    from .training_helpers import (
        run_final_inference_and_report_metrics,
        load_model,
    )

    # The root logger defaults to WARNING, which silently drops every INFO
    # progress message below. Make sure a handler exists (no-op when one is
    # already installed) and raise only this module's logger to INFO.
    logging.basicConfig()
    log = logging.getLogger(__name__)
    log.setLevel(logging.INFO)

    # Configuration
    model_dir = "/remote/rm/models/model1"

    # Fail fast, before the model is loaded, when the bucket is unknown.
    bucket = os.getenv("REMOTE_BUCKET")
    if not bucket:
        raise RuntimeError(
            "REMOTE_BUCKET is not set; cannot locate the inference dataset"
        )

    log.info("Loading model from %s", model_dir)
    model = load_model(model_dir)

    log.info("Loading dataset...")
    df = pl.read_parquet(
        f"s3://{bucket}/scraped-stories-with-datetime.parquet"
    )

    log.info("Running final inference and reporting metrics...")
    run_final_inference_and_report_metrics(model, df, model_dir)

    log.info("Inference complete")
|
||
|
||
@app.local_entrypoint()
def main_local():
    """Local entrypoint: announce the launch, then invoke ``main.remote()``."""
    print("Running main locally")
    # The work itself is done by `main`, called through its `.remote()` handle.
    main.remote()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,5 @@ | ||
import modal | ||
print("Hello, world!") | ||
|
||
app = modal.App() | ||
import polars as pl | ||
|
||
|
||
@app.function() | ||
async def main(): | ||
# TESTING CODE, REMOVE | ||
import os | ||
|
||
print("Ok, it ran") | ||
|
||
print("contents of /remote:") | ||
print(os.listdir("/remote")) | ||
|
||
# END TESTING CODE | ||
|
||
|
||
@app.local_entrypoint() | ||
def main_local(): | ||
print("Running main locally") | ||
main.remote() | ||
print("imported polars") |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.