Commit: prediction code for stories working
corbt committed Oct 19, 2024
1 parent e02c7de commit b9d56ee
Showing 10 changed files with 1,484 additions and 113 deletions.
76 changes: 76 additions & 0 deletions inference.py
@@ -0,0 +1,76 @@
from typing import Union, Optional
from dataclasses import dataclass
import torch
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from liger_kernel.transformers import _apply_liger_kernel_to_instance
from peft.peft_model import PeftModel


@dataclass
class MandT:
    model: Union[AutoModelForSequenceClassification, PeftModel]
    tokenizer: AutoTokenizer


def load_peft_model(model_path: str, merge: bool = False) -> MandT:
    # Load the bfloat16 base model with a single regression head.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, num_labels=1, device_map="auto", torch_dtype=torch.bfloat16
    )

    if merge:
        # Fold the LoRA weights into the base model, then apply the Liger kernel patches.
        model = PeftModel.from_pretrained(model, model_path)
        model = model.merge_and_unload()
        _apply_liger_kernel_to_instance(model)
    else:
        # Patch the base model first and keep the LoRA adapter separate.
        _apply_liger_kernel_to_instance(model)
        model = PeftModel.from_pretrained(model, model_path)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return MandT(model, tokenizer)


def run_inference_transformers(
    prompts: list[str],
    model_or_path: Union[MandT, str],
    batch_size: int = 4,
) -> list[float]:
    if isinstance(model_or_path, str):
        mandt = load_peft_model(model_or_path, merge=True)
    else:
        mandt = model_or_path

    model = mandt.model
    tokenizer = mandt.tokenizer

    # Tokenize all prompts
    tokenized_prompts = [
        tokenizer.encode(prompt, add_special_tokens=True) for prompt in prompts
    ]

    # Sort prompts by token count (longest first) so each batch pads to a similar length
    sorted_indices = sorted(
        range(len(tokenized_prompts)), key=lambda i: -len(tokenized_prompts[i])
    )
    sorted_prompts = [prompts[i] for i in sorted_indices]

    results = []
    for i in tqdm(
        range(0, len(sorted_prompts), batch_size),
        total=(len(sorted_prompts) + batch_size - 1) // batch_size,
    ):
        batch = sorted_prompts[i : i + batch_size]
        inputs = tokenizer(
            batch, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.squeeze(-1)
        results.extend(logits.cpu().tolist())

    # Reorder results to match original prompt order
    original_order_results = [0.0] * len(prompts)
    for i, result in zip(sorted_indices, results):
        original_order_results[i] = result

    return original_order_results
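
For reference, here is a minimal usage sketch of run_inference_transformers (illustrative only, not part of the committed files). It assumes a fine-tuned checkpoint saved under ./models/stories_model_v2 (the output_dir used by the training scripts below) and uses placeholder strings where serialized stories would normally go.

# Usage sketch: score a few prompts with a saved checkpoint.
from inference import load_peft_model, run_inference_transformers

mandt = load_peft_model("./models/stories_model_v2", merge=True)
scores = run_inference_transformers(
    ["placeholder serialized story 1", "placeholder serialized story 2"],
    mandt,
    batch_size=8,
)
print(scores)  # one predicted value (trained against log_score) per prompt, in the original order
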
347 changes: 347 additions & 0 deletions stories-analysis.ipynb

Large diffs are not rendered by default.

Empty file added stories-inference.ipynb
Empty file.
458 changes: 458 additions & 0 deletions stories-test-inference.ipynb

Large diffs are not rendered by default.

136 changes: 136 additions & 0 deletions stories_train_model_v2.py
@@ -0,0 +1,136 @@
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from peft.tuners.lora import LoraConfig
from peft.mapping import get_peft_model
import wandb
from dotenv import load_dotenv
import polars as pl
from utils import stories_dataset
from sklearn.metrics import mean_squared_error
from liger_kernel.transformers import _apply_liger_kernel_to_instance

load_dotenv("/workspace/.env")

# Configuration
base_model = "unsloth/Meta-Llama-3.1-8B"
run_name = "stories_model_v2"
output_dir = f"./models/{run_name}"
num_epochs = 1
batch_size = 4
learning_rate = 2e-4
max_length = 4096

# Initialize wandb
wandb.init(project="hn_stories_model_training", name=run_name)


def create_dataset(split, num_rows, tokenizer):
    stories = stories_dataset()
    stories = stories.filter(pl.col("split") == split).head(num_rows)

    stories = stories.with_columns(
        [
            pl.col("serialized").alias("text"),
            pl.col("log_score").alias("label"),
        ]
    )

    stories = stories.with_columns(
        [
            pl.col("text")
            .map_elements(
                lambda x: tokenizer(x)["input_ids"], return_dtype=pl.List(pl.Int64)
            )
            .alias("input_ids"),
        ]
    ).select(["input_ids", "label"])
    return Dataset.from_polars(stories)


print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    truncation=True,
    padding=True,
    max_length=max_length,
)

model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=1,  # Regression task
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
_apply_liger_kernel_to_instance(model=model)

model.config.pad_token_id = tokenizer.pad_token_id
tokenizer.padding_side = "right"

print("Configuring LoRA...")
model = get_peft_model(
    model,
    LoraConfig(
        task_type="SEQ_CLS",
        r=8,
        lora_alpha=16,
        lora_dropout=0,
    ),
)

print("Loading dataset...")
train_stories = create_dataset("train", 1000000, tokenizer)
validation_stories = create_dataset("val", 1000, tokenizer)


# Configure training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0,
    evaluation_strategy="steps",
    eval_steps=0.05,
    logging_steps=100,
    save_strategy="steps",
    save_steps=1000,
    report_to="wandb",
    no_cuda=False,
    bf16=True,
    warmup_steps=100,
    # use_liger_kernel=True,
)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    rmse = mean_squared_error(labels, predictions, squared=False)
    return {"rmse": rmse}


print("Initializing Trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_stories,
    eval_dataset=validation_stories,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

print("Starting model training...")
trainer.train()

print("Saving final model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print("Stories model training complete")
138 changes: 138 additions & 0 deletions stories_train_model_v3.py
@@ -0,0 +1,138 @@
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from peft.tuners.lora import LoraConfig
from peft.mapping import get_peft_model
import wandb
from dotenv import load_dotenv
import polars as pl
from utils import stories_dataset
from sklearn.metrics import mean_squared_error
from liger_kernel.transformers import _apply_liger_kernel_to_instance

load_dotenv("/workspace/.env")

# Configuration
base_model = "unsloth/Meta-Llama-3.1-8B"
run_name = "stories_model_v2"
output_dir = f"./models/{run_name}"
num_epochs = 1
batch_size = 4
gradient_accumulation_steps = 4
learning_rate = 2e-4
max_length = 4096

# Initialize wandb
wandb.init(project="hn_stories_model_training", name=run_name)


def create_dataset(split, num_rows, tokenizer):
    stories = stories_dataset()
    stories = stories.filter(pl.col("split") == split).head(num_rows)

    stories = stories.with_columns(
        [
            pl.col("serialized").alias("text"),
            pl.col("log_score").alias("label"),
        ]
    )

    stories = stories.with_columns(
        [
            pl.col("text")
            .map_elements(
                lambda x: tokenizer(x)["input_ids"], return_dtype=pl.List(pl.Int64)
            )
            .alias("input_ids"),
        ]
    ).select(["input_ids", "label"])
    return Dataset.from_polars(stories)


print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    truncation=True,
    padding=True,
    max_length=max_length,
)

model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=1,  # Regression task
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
_apply_liger_kernel_to_instance(model=model)

model.config.pad_token_id = tokenizer.pad_token_id
tokenizer.padding_side = "right"

print("Configuring LoRA...")
model = get_peft_model(
    model,
    LoraConfig(
        task_type="SEQ_CLS",
        r=8,
        lora_alpha=16,
        lora_dropout=0,
    ),
)

print("Loading dataset...")
train_stories = create_dataset("train", 1000000, tokenizer)
validation_stories = create_dataset("val", 1000, tokenizer)


# Configure training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0,
    evaluation_strategy="steps",
    eval_steps=0.05,
    logging_steps=100,
    save_strategy="steps",
    save_steps=1000,
    report_to="wandb",
    no_cuda=False,
    bf16=True,
    warmup_steps=100,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # use_liger_kernel=True,
)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    rmse = mean_squared_error(labels, predictions, squared=False)
    return {"rmse": rmse}


print("Initializing Trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_stories,
    eval_dataset=validation_stories,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

print("Starting model training...")
trainer.train()

print("Saving final model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print("Stories model training complete")