diff --git a/prepare_env.sh b/prepare_env.sh
index 9b67eb3..5800d36 100755
--- a/prepare_env.sh
+++ b/prepare_env.sh
@@ -14,6 +14,12 @@ add_to_bashrc "export HF_HOME=/workspace/.cache/huggingface"
add_to_bashrc "export HF_HUB_ENABLE_HF_TRANSFER=1"
add_to_bashrc "source /workspace/.env"
+# Add uv to path
+add_to_bashrc "export PATH=\"/root/.cargo/bin:\$PATH\""
+
+# Enable CUDA debugging (synchronous kernel launches give accurate stack traces, at the cost of speed)
+add_to_bashrc "export CUDA_LAUNCH_BLOCKING=1"
+
source ~/.bashrc
# Install system dependencies
diff --git a/pyproject.toml b/pyproject.toml
index 726bbf0..08604ed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
"peft>=0.13.2",
"polars>=1.9.0",
"python-dotenv>=1.0.1",
+ "schedulefree>=1.2.7",
"scikit-learn>=1.5.2",
"seaborn>=0.13.2",
"sglang[all]>=0.3.3.post1",
diff --git a/stories-analysis.ipynb b/stories-analysis.ipynb
index 6fc81a0..3aa12f5 100644
--- a/stories-analysis.ipynb
+++ b/stories-analysis.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -32,50 +32,91 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Baseline RMSE: 1.33\n",
- "Baseline Correlation: nan\n",
- "Model RMSE: 1.11\n",
- "Model Correlation: 0.55\n"
- ]
+ "data": {
+ "text/html": [
+ "
\n",
+ "
shape: (3, 5)split | baseline_rmse | model_rmse | baseline_correlation | model_correlation |
---|
str | f64 | f64 | f64 | f64 |
"test" | 1.334575 | 1.133601 | NaN | 0.529694 |
"train" | 1.324123 | 1.102933 | NaN | 0.555844 |
"val" | 1.323152 | 1.12773 | NaN | 0.524235 |
"
+ ],
+ "text/plain": [
+ "shape: (3, 5)\n",
+ "┌───────┬───────────────┬────────────┬──────────────────────┬───────────────────┐\n",
+ "│ split ┆ baseline_rmse ┆ model_rmse ┆ baseline_correlation ┆ model_correlation │\n",
+ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
+ "│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n",
+ "╞═══════╪═══════════════╪════════════╪══════════════════════╪═══════════════════╡\n",
+ "│ test ┆ 1.334575 ┆ 1.133601 ┆ NaN ┆ 0.529694 │\n",
+ "│ train ┆ 1.324123 ┆ 1.102933 ┆ NaN ┆ 0.555844 │\n",
+ "│ val ┆ 1.323152 ┆ 1.12773 ┆ NaN ┆ 0.524235 │\n",
+ "└───────┴───────────────┴────────────┴──────────────────────┴───────────────────┘"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
"import math\n",
"\n",
- "# Calculate the RMSE for baseline prediction (average score)\n",
- "average_score = stories[\"log_score\"].mean()\n",
- "rmse_baseline = math.sqrt(\n",
- " (stories[\"log_score\"] - average_score).pow(2).sum() / len(stories)\n",
- ")\n",
- "print(f\"Baseline RMSE: {rmse_baseline:.2f}\")\n",
"\n",
- "# Calculate and print the correlation for baseline\n",
- "correlation_baseline = stories.select(pl.corr(\"log_score\", pl.lit(average_score)))[\n",
- " \"log_score\"\n",
- "][0]\n",
- "print(f\"Baseline Correlation: {correlation_baseline:.2f}\")\n",
+ "def calculate_metrics_by_split(df: pl.DataFrame) -> pl.DataFrame:\n",
+ " \"\"\"\n",
+ " Calculate correlation and RMSE metrics for each split in the dataset.\n",
"\n",
- "# Calculate the RMSE between the predictions and the log_score, which is the score the model is trained to predict\n",
- "rmse_model = math.sqrt(\n",
- " (stories[\"log_score\"] - stories[\"predictions\"]).pow(2).sum() / len(stories)\n",
- ")\n",
- "print(f\"Model RMSE: {rmse_model:.2f}\")\n",
+ " Args:\n",
+ " df: DataFrame with log_score, predictions and split columns\n",
+ "\n",
+ " Returns:\n",
+ " DataFrame with metrics for each split\n",
+ " \"\"\"\n",
+ " metrics = []\n",
+ "\n",
+ " for split in df[\"split\"].unique():\n",
+ " split_df = df.filter(pl.col(\"split\") == split)\n",
+ "\n",
+ " # Calculate baseline (mean) metrics\n",
+ " average_score = split_df[\"log_score\"].mean()\n",
+ " rmse_baseline = math.sqrt(\n",
+ " (split_df[\"log_score\"] - average_score).pow(2).sum() / len(split_df)\n",
+ " )\n",
+ "\n",
+ " # Calculate model metrics\n",
+ " rmse_model = math.sqrt(\n",
+ " (split_df[\"log_score\"] - split_df[\"predictions\"]).pow(2).sum()\n",
+ " / len(split_df)\n",
+ " )\n",
+ " correlation_model = split_df.select(pl.corr(\"log_score\", \"predictions\"))[\n",
+ " \"log_score\"\n",
+ " ][0]\n",
+ "\n",
+ " metrics.append(\n",
+ " {\n",
+ " \"split\": split,\n",
+ " \"baseline_rmse\": rmse_baseline,\n",
+ " \"model_rmse\": rmse_model,\n",
+ " \"model_correlation\": correlation_model,\n",
+ " }\n",
+ " )\n",
"\n",
- "# Calculate and print the correlation for the model\n",
- "correlation_model = stories.select(pl.corr(\"log_score\", \"predictions\"))[\"log_score\"][0]\n",
- "print(f\"Model Correlation: {correlation_model:.2f}\")\n"
+ " return pl.DataFrame(metrics)\n",
+ "\n",
+ "\n",
+ "calculate_metrics_by_split(stories)"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -127,7 +168,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -140,54 +181,52 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (1_000, 3)serialized | predicted_score | score |
---|
str | i64 | i64 |
"\n",
- "Ask HN: Who is hiring? (May 2… | 518 | 553 |
"\n",
- "Launch HN: Curvenote (YC W21)… | 153 | 108 |
"\n",
- "Ask HN: Alternatives to DuckD… | 143 | 42 |
"\n",
- "Launch HN: Penny (YC W17) – a… | 139 | 114 |
"\n",
- "Launch HN: MergeQueue (YC S21… | 115 | 122 |
… | … | … |
"\n",
- "Galaxy S8 Owners\n",
- "lidzen, 2019… | 1 | 2 |
"\n",
- "Is there any extension that c… | 1 | 2 |
"\n",
- "The Designing versus Testing … | 1 | 1 |
"\n",
- "Balloon Bouquets Chandigarh\n",
- "o… | 1 | 1 |
"\n",
- "Early Animal Evolution: A Mor… | 1 | 1 |
"
+ "shape: (142_886, 3)serialized | predicted_score | score |
---|
str | i64 | i64 |
"\n",
+ "Ask HN: Who is hiring? (Octob… | 624 | 678 |
"\n",
+ "Ask HN: Who is hiring? (Septe… | 586 | 573 |
"\n",
+ "Ask HN: Who is hiring? (May 2… | 534 | 553 |
"\n",
+ "Ask HN: Who is hiring? (Novem… | 518 | 539 |
"\n",
+ "Ask HN: Who is hiring? (March… | 502 | 592 |
… | … | … |
"\n",
+ "Meet the Israeli startup that… | 0 | 1 |
"\n",
+ "Irrfan Khan, Bhanu Athaiya Ge… | 0 | 1 |
"\n",
+ "UnlockFame – You Have a Talen… | 0 | 1 |
"\n",
+ "Feedback about My Product\n",
+ "msm… | 0 | 1 |
"\n",
+ "Easy Guidelines to Look Tall … | 0 | 1 |
"
],
"text/plain": [
- "shape: (1_000, 3)\n",
+ "shape: (142_886, 3)\n",
"┌────────────────────────────────┬─────────────────┬───────┐\n",
"│ serialized ┆ predicted_score ┆ score │\n",
"│ --- ┆ --- ┆ --- │\n",
"│ str ┆ i64 ┆ i64 │\n",
"╞════════════════════════════════╪═════════════════╪═══════╡\n",
- "│ ┆ 518 ┆ 553 │\n",
+ "│ ┆ 624 ┆ 678 │\n",
+ "│ Ask HN: Who is hiring? (Octob… ┆ ┆ │\n",
+ "│ ┆ 586 ┆ 573 │\n",
+ "│ Ask HN: Who is hiring? (Septe… ┆ ┆ │\n",
+ "│ ┆ 534 ┆ 553 │\n",
"│ Ask HN: Who is hiring? (May 2… ┆ ┆ │\n",
- "│ ┆ 153 ┆ 108 │\n",
- "│ Launch HN: Curvenote (YC W21)… ┆ ┆ │\n",
- "│ ┆ 143 ┆ 42 │\n",
- "│ Ask HN: Alternatives to DuckD… ┆ ┆ │\n",
- "│ ┆ 139 ┆ 114 │\n",
- "│ Launch HN: Penny (YC W17) – a… ┆ ┆ │\n",
- "│ ┆ 115 ┆ 122 │\n",
- "│ Launch HN: MergeQueue (YC S21… ┆ ┆ │\n",
+ "│ ┆ 518 ┆ 539 │\n",
+ "│ Ask HN: Who is hiring? (Novem… ┆ ┆ │\n",
+ "│ ┆ 502 ┆ 592 │\n",
+ "│ Ask HN: Who is hiring? (March… ┆ ┆ │\n",
"│ … ┆ … ┆ … │\n",
- "│ ┆ 1 ┆ 2 │\n",
- "│ Galaxy S8 Owners ┆ ┆ │\n",
- "│ lidzen, 2019… ┆ ┆ │\n",
- "│ ┆ 1 ┆ 2 │\n",
- "│ Is there any extension that c… ┆ ┆ │\n",
- "│ ┆ 1 ┆ 1 │\n",
- "│ The Designing versus Testing … ┆ ┆ │\n",
- "│ ┆ 1 ┆ 1 │\n",
- "│ Balloon Bouquets Chandigarh ┆ ┆ │\n",
- "│ o… ┆ ┆ │\n",
- "│ ┆ 1 ┆ 1 │\n",
- "│ Early Animal Evolution: A Mor… ┆ ┆ │\n",
+ "│ ┆ 0 ┆ 1 │\n",
+ "│ Meet the Israeli startup that… ┆ ┆ │\n",
+ "│ ┆ 0 ┆ 1 │\n",
+ "│ Irrfan Khan, Bhanu Athaiya Ge… ┆ ┆ │\n",
+ "│ ┆ 0 ┆ 1 │\n",
+ "│ UnlockFame – You Have a Talen… ┆ ┆ │\n",
+ "│ ┆ 0 ┆ 1 │\n",
+ "│ Feedback about My Product ┆ ┆ │\n",
+ "│ msm… ┆ ┆ │\n",
+ "│ ┆ 0 ┆ 1 │\n",
+ "│ Easy Guidelines to Look Tall … ┆ ┆ │\n",
"└────────────────────────────────┴─────────────────┴───────┘"
]
},
- "execution_count": 33,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -200,7 +239,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -267,6 +306,93 @@
" \"All subsets have been saved to CSV files in the ./data/sorted_stories directory.\"\n",
")\n"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "506d7a029e0b405d9af870a4e8f5a458",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at unsloth/Meta-Llama-3.1-8B and are newly initialized: ['score.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
+ " 0%| | 1/35721 [00:01<10:31:47, 1.06s/it]\n"
+ ]
+ },
+ {
+ "ename": "OutOfMemoryError",
+ "evalue": "CUDA out of memory. Tried to allocate 234.00 MiB. GPU 0 has a total capacity of 79.22 GiB of which 202.06 MiB is free. Process 2226597 has 32.55 GiB memory in use. Process 2262461 has 46.46 GiB memory in use. Of the allocated memory 30.92 GiB is allocated by PyTorch, and 994.11 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[12], line 19\u001b[0m\n\u001b[1;32m 15\u001b[0m predictions \u001b[38;5;241m=\u001b[39m run_inference_transformers(stories[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mserialized\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mto_list(), mandt)\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m stories\u001b[38;5;241m.\u001b[39mwith_columns(pl\u001b[38;5;241m.\u001b[39mSeries(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpredictions\u001b[39m\u001b[38;5;124m\"\u001b[39m, values\u001b[38;5;241m=\u001b[39mpredictions))\n\u001b[0;32m---> 19\u001b[0m stories \u001b[38;5;241m=\u001b[39m \u001b[43mstories_with_predictions_sf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m stories \u001b[38;5;241m=\u001b[39m stories\u001b[38;5;241m.\u001b[39mwith_columns(\n\u001b[1;32m 22\u001b[0m pl\u001b[38;5;241m.\u001b[39mSeries(\n\u001b[1;32m 23\u001b[0m name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpredicted_score\u001b[39m\u001b[38;5;124m\"\u001b[39m, values\u001b[38;5;241m=\u001b[39mstories[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpredictions\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mexp()\u001b[38;5;241m.\u001b[39mcast(pl\u001b[38;5;241m.\u001b[39mInt64)\n\u001b[1;32m 24\u001b[0m )\n\u001b[1;32m 25\u001b[0m )\n",
+ "File \u001b[0;32m/workspace/best-hn/utils.py:23\u001b[0m, in \u001b[0;36mcache_dataframe..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 21\u001b[0m df \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mread_parquet(path)\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 23\u001b[0m df \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCaching dataframe to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 25\u001b[0m df\u001b[38;5;241m.\u001b[39mwrite_parquet(path)\n",
+ "Cell \u001b[0;32mIn[12], line 15\u001b[0m, in \u001b[0;36mstories_with_predictions_sf\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m stories \u001b[38;5;241m=\u001b[39m stories_dataset()\n\u001b[1;32m 13\u001b[0m mandt \u001b[38;5;241m=\u001b[39m load_peft_model(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./models/stories_model_schedulefree_v1\u001b[39m\u001b[38;5;124m\"\u001b[39m, merge\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m---> 15\u001b[0m predictions \u001b[38;5;241m=\u001b[39m \u001b[43mrun_inference_transformers\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstories\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mserialized\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_list\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmandt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m stories\u001b[38;5;241m.\u001b[39mwith_columns(pl\u001b[38;5;241m.\u001b[39mSeries(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpredictions\u001b[39m\u001b[38;5;124m\"\u001b[39m, values\u001b[38;5;241m=\u001b[39mpredictions))\n",
+ "File \u001b[0;32m/workspace/best-hn/inference.py:67\u001b[0m, in \u001b[0;36mrun_inference_transformers\u001b[0;34m(prompts, model_or_path, batch_size)\u001b[0m\n\u001b[1;32m 63\u001b[0m inputs \u001b[38;5;241m=\u001b[39m tokenizer(\n\u001b[1;32m 64\u001b[0m batch, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 65\u001b[0m )\u001b[38;5;241m.\u001b[39mto(model\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 67\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 68\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 69\u001b[0m results\u001b[38;5;241m.\u001b[39mextend(logits\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mtolist())\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py:1290\u001b[0m, in \u001b[0;36mLlamaForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \u001b[38;5;124;03m Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m 1285\u001b[0m \u001b[38;5;124;03m config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m 1286\u001b[0m \u001b[38;5;124;03m `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m 1287\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1288\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1290\u001b[0m transformer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1291\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1292\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1293\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1294\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1295\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1296\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1297\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1298\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1299\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1300\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1301\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m transformer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1302\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscore(hidden_states)\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py:944\u001b[0m, in \u001b[0;36mLlamaModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 932\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 933\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 934\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 941\u001b[0m position_embeddings,\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 945\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 946\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 947\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 948\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 949\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 950\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 951\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 952\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 953\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 955\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 957\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py:691\u001b[0m, in \u001b[0;36mLlamaDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 689\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 690\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 691\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 692\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 694\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/liger_kernel/transformers/swiglu.py:21\u001b[0m, in \u001b[0;36mLigerSwiGLUMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdown_proj(\n\u001b[0;32m---> 21\u001b[0m \u001b[43mLigerSiLUMulFunction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgate_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mup_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 22\u001b[0m )\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/torch/autograd/function.py:574\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m 572\u001b[0m \u001b[38;5;66;03m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m 573\u001b[0m args \u001b[38;5;241m=\u001b[39m _functorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_setup_ctx_defined:\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 578\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 579\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 580\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstaticmethod. For more details, please see \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 581\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://pytorch.org/docs/main/notes/extending.func.html\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 582\u001b[0m )\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/liger_kernel/ops/utils.py:30\u001b[0m, in \u001b[0;36mensure_contiguous..wrapper\u001b[0;34m(ctx, *args, **kwargs)\u001b[0m\n\u001b[1;32m 28\u001b[0m args \u001b[38;5;241m=\u001b[39m [maybe_to_contiguous(arg) \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args]\n\u001b[1;32m 29\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m {k: maybe_to_contiguous(v) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m---> 30\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/liger_kernel/ops/swiglu.py:111\u001b[0m, in \u001b[0;36mLigerSiLUMulFunction.forward\u001b[0;34m(ctx, a, b)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;129m@ensure_contiguous\u001b[39m\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(ctx, a, b):\n\u001b[0;32m--> 111\u001b[0m a, b, c \u001b[38;5;241m=\u001b[39m \u001b[43mswiglu_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 112\u001b[0m ctx\u001b[38;5;241m.\u001b[39msave_for_backward(a, b)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m c\n",
+ "File \u001b[0;32m/workspace/best-hn/.venv/lib/python3.12/site-packages/liger_kernel/ops/swiglu.py:69\u001b[0m, in \u001b[0;36mswiglu_forward\u001b[0;34m(a, b)\u001b[0m\n\u001b[1;32m 67\u001b[0m a \u001b[38;5;241m=\u001b[39m a\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, n_cols)\n\u001b[1;32m 68\u001b[0m b \u001b[38;5;241m=\u001b[39m b\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, n_cols)\n\u001b[0;32m---> 69\u001b[0m c \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mempty_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 70\u001b[0m n_rows \u001b[38;5;241m=\u001b[39m a\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 72\u001b[0m BLOCK_SIZE, num_warps \u001b[38;5;241m=\u001b[39m calculate_settings(n_cols)\n",
+ "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 234.00 MiB. GPU 0 has a total capacity of 79.22 GiB of which 202.06 MiB is free. Process 2226597 has 32.55 GiB memory in use. Process 2262461 has 46.46 GiB memory in use. Of the allocated memory 30.92 GiB is allocated by PyTorch, and 994.11 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)"
+ ]
+ }
+ ],
+ "source": [
+ "from utils import (\n",
+ " cache_dataframe,\n",
+ " stories_dataset,\n",
+ ")\n",
+ "from inference import run_inference_transformers, load_peft_model\n",
+ "import polars as pl\n",
+ "\n",
+ "\n",
+ "@cache_dataframe(\"./data/stories_with_predictions_sf.parquet\")\n",
+ "def stories_with_predictions_sf():\n",
+ " stories = stories_dataset()\n",
+ "\n",
+ " mandt = load_peft_model(\"./models/stories_model_schedulefree_v1\", merge=True)\n",
+ "\n",
+ " predictions = run_inference_transformers(stories[\"serialized\"].to_list(), mandt)\n",
+ " return stories.with_columns(pl.Series(name=\"predictions\", values=predictions))\n",
+ "\n",
+ "\n",
+ "stories = stories_with_predictions_sf()\n",
+ "\n",
+ "stories = stories.with_columns(\n",
+ " pl.Series(\n",
+ " name=\"predicted_score\", values=stories[\"predictions\"].exp().cast(pl.Int64)\n",
+ " )\n",
+ ")"
+ ]
}
],
"metadata": {
diff --git a/stories_train_model_v10.py b/stories_train_model_v10.py
new file mode 100644
index 0000000..725c59f
--- /dev/null
+++ b/stories_train_model_v10.py
@@ -0,0 +1,136 @@
+import torch
+from datasets import load_dataset, Dataset
+from transformers import (
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ Trainer,
+ TrainingArguments,
+)
+from peft.tuners.lora import LoraConfig
+from peft.mapping import get_peft_model
+import wandb
+from dotenv import load_dotenv
+import polars as pl
+from utils import stories_dataset
+from liger_kernel.transformers import _apply_liger_kernel_to_instance
+from training_helpers import compute_metrics
+
+load_dotenv("/workspace/.env")
+
+# Configuration
+base_model = "google/gemma-2-9b"
+run_name = __file__.split("/")[-1].replace(".py", "")
+output_dir = f"./models/{run_name}"
+num_epochs = 1
+batch_size = 2
+gradient_accumulation_steps = 8
+learning_rate = 2e-4
+max_length = 4096
+
+# Initialize wandb
+wandb.init(project="hn_stories_model_training", name=run_name)
+
+
+def create_dataset(split, num_rows, tokenizer):
+ stories = stories_dataset()
+ stories = stories.filter(pl.col("split") == split).head(num_rows)
+
+ stories = stories.with_columns(
+ [
+ pl.col("serialized").alias("text"),
+ pl.col("log_score").alias("label"),
+ ]
+ )
+
+ stories = stories.with_columns(
+ [
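+            # Note: this tokenizer(x) call applies no truncation, so stories
+            # longer than max_length pass through at full length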
+ pl.col("text")
+ .map_elements(
+ lambda x: tokenizer(x)["input_ids"], return_dtype=pl.List(pl.Int64)
+ )
+ .alias("input_ids"),
+ ]
+ ).select(["input_ids", "label"])
+ return Dataset.from_polars(stories)
+
+
+print("Loading tokenizer and model...")
+tokenizer = AutoTokenizer.from_pretrained(
+ base_model,
+ truncation=True,
+ padding=True,
+ max_length=max_length,
+)
+
+model = AutoModelForSequenceClassification.from_pretrained(
+ base_model,
+ num_labels=1, # Regression task
+ device_map="auto",
+ attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+)
+_apply_liger_kernel_to_instance(model=model)
+
+model.config.pad_token_id = tokenizer.pad_token_id
+tokenizer.padding_side = "right"
+
+print("Configuring LoRA...")
+model = get_peft_model(
+ model,
+ LoraConfig(
+ task_type="SEQ_CLS",
+ r=8,
+ lora_alpha=16,
+ lora_dropout=0,
+ ),
+)
+
+print("Loading dataset...")
+train_stories = create_dataset("train", 1000000, tokenizer)
+validation_stories = create_dataset("val", 1000, tokenizer)
+
+
+# Configure training arguments
+training_args = TrainingArguments(
+ output_dir=output_dir,
+ num_train_epochs=num_epochs,
+ per_device_train_batch_size=batch_size,
+ per_device_eval_batch_size=batch_size,
+ learning_rate=learning_rate,
+ weight_decay=0,
+ evaluation_strategy="steps",
+ eval_steps=0.05,
+ logging_steps=100,
+ save_strategy="steps",
+ save_steps=1000,
+ report_to="wandb",
+ no_cuda=False,
+ bf16=True,
+ warmup_steps=100,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+)
+
+
+print("Initializing Trainer...")
+trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_stories,
+ eval_dataset=validation_stories,
+ tokenizer=tokenizer,
+ compute_metrics=compute_metrics,
+)
+
+print("Running initial evaluation...")
+results = trainer.evaluate()
+print("Initial evaluation complete")
+print(results)
+
+print("Starting model training...")
+trainer.train()
+
+print("Saving final model...")
+trainer.save_model(output_dir)
+tokenizer.save_pretrained(output_dir)
+
+print("Stories model training complete")
diff --git a/stories_train_model_v2.py b/stories_train_model_v2.py
index aaf74ee..0e929c7 100644
--- a/stories_train_model_v2.py
+++ b/stories_train_model_v2.py
@@ -14,6 +14,7 @@
from utils import stories_dataset
from sklearn.metrics import mean_squared_error
from liger_kernel.transformers import _apply_liger_kernel_to_instance
+from training_helpers import compute_metrics
load_dotenv("/workspace/.env")
@@ -110,12 +111,6 @@ def create_dataset(split, num_rows, tokenizer):
)
-def compute_metrics(eval_pred):
- predictions, labels = eval_pred
- rmse = mean_squared_error(labels, predictions, squared=False)
- return {"rmse": rmse}
-
-
print("Initializing Trainer...")
trainer = Trainer(
model=model,
diff --git a/stories_train_model_v3.py b/stories_train_model_v3.py
index 6ff3fb6..80a03be 100644
--- a/stories_train_model_v3.py
+++ b/stories_train_model_v3.py
@@ -12,14 +12,14 @@
from dotenv import load_dotenv
import polars as pl
from utils import stories_dataset
-from sklearn.metrics import mean_squared_error
from liger_kernel.transformers import _apply_liger_kernel_to_instance
+from training_helpers import compute_metrics
load_dotenv("/workspace/.env")
# Configuration
base_model = "unsloth/Meta-Llama-3.1-8B"
-run_name = "stories_model_v2"
+run_name = __file__.split("/")[-1].replace(".py", "")
output_dir = f"./models/{run_name}"
num_epochs = 1
batch_size = 4
@@ -108,16 +108,9 @@ def create_dataset(split, num_rows, tokenizer):
bf16=True,
warmup_steps=100,
gradient_accumulation_steps=gradient_accumulation_steps,
- # use_liger_kernel=True,
)
-def compute_metrics(eval_pred):
- predictions, labels = eval_pred
- rmse = mean_squared_error(labels, predictions, squared=False)
- return {"rmse": rmse}
-
-
print("Initializing Trainer...")
trainer = Trainer(
model=model,
diff --git a/stories_train_model_v4.py b/stories_train_model_v4.py
index 93f1b2a..794b6a9 100644
--- a/stories_train_model_v4.py
+++ b/stories_train_model_v4.py
@@ -19,9 +19,9 @@
# Configuration
base_model = "unsloth/Meta-Llama-3.1-8B"
-run_name = "stories_model_v4"
+run_name = "stories_model_schedulefree_v1"
output_dir = f"./models/{run_name}"
-num_epochs = 2
+num_epochs = 1
batch_size = 4
gradient_accumulation_steps = 4
learning_rate = 2e-4
@@ -104,6 +104,7 @@ def create_dataset(split, num_rows, tokenizer):
save_strategy="steps",
save_steps=1000,
report_to="wandb",
+ optim="schedule_free_adamw",
no_cuda=False,
bf16=True,
warmup_steps=100,
diff --git a/stories_train_model_v5.py b/stories_train_model_v5.py
new file mode 100644
index 0000000..1597b3c
--- /dev/null
+++ b/stories_train_model_v5.py
@@ -0,0 +1,140 @@
+import torch
+from datasets import load_dataset, Dataset
+from transformers import (
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ Trainer,
+ TrainingArguments,
+)
+from peft.tuners.lora import LoraConfig
+from peft.mapping import get_peft_model
+import wandb
+from dotenv import load_dotenv
+import polars as pl
+from utils import stories_dataset
+from liger_kernel.transformers import _apply_liger_kernel_to_instance
+from training_helpers import compute_metrics
+
+load_dotenv("/workspace/.env")
+
+# Configuration
+base_model = "unsloth/Meta-Llama-3.1-8B"
+run_name = "stories_train_model_v5"
+output_dir = f"./models/{run_name}"
+num_epochs = 1
+batch_size = 4
+gradient_accumulation_steps = 4
+learning_rate = 2e-4
+max_length = 4096
+
+# Initialize wandb
+wandb.init(project="hn_stories_model_training", name=run_name)
+
+
+def create_dataset(split, num_rows, tokenizer):
+ stories = stories_dataset()
+ stories = stories.filter(pl.col("split") == split).head(num_rows)
+
+ stories = stories.with_columns(
+ [
+ pl.col("serialized").alias("text"),
+ pl.col("log_score").alias("label"),
+ ]
+ )
+
+ stories = stories.with_columns(
+ [
+ pl.col("text")
+ .map_elements(
+ lambda x: tokenizer(x)["input_ids"], return_dtype=pl.List(pl.Int64)
+ )
+ .alias("input_ids"),
+ ]
+ ).select(["input_ids", "label"])
+ return Dataset.from_polars(stories)
+
+
+print("Loading tokenizer and model...")
+tokenizer = AutoTokenizer.from_pretrained(
+ base_model,
+ truncation=True,
+ padding=True,
+ max_length=max_length,
+)
+
+model = AutoModelForSequenceClassification.from_pretrained(
+ base_model,
+ num_labels=1, # Regression task
+ device_map="auto",
+ attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+)
+_apply_liger_kernel_to_instance(model=model)
+
+model.config.pad_token_id = tokenizer.pad_token_id
+tokenizer.padding_side = "right"
+
+print("Configuring LoRA...")
+model = get_peft_model(
+ model,
+ LoraConfig(
+ task_type="SEQ_CLS",
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+ ],
+ r=8,
+ lora_alpha=16,
+ lora_dropout=0,
+ ),
+)
+
+print("Loading dataset...")
+train_stories = create_dataset("train", 1000000, tokenizer)
+validation_stories = create_dataset("val", 1000, tokenizer)
+
+
+# Configure training arguments
+training_args = TrainingArguments(
+ output_dir=output_dir,
+ num_train_epochs=num_epochs,
+ per_device_train_batch_size=batch_size,
+ per_device_eval_batch_size=batch_size,
+ learning_rate=learning_rate,
+ weight_decay=0,
+ evaluation_strategy="steps",
+ eval_steps=0.05,
+ logging_steps=100,
+ save_strategy="steps",
+ save_steps=1000,
+ report_to="wandb",
+ no_cuda=False,
+ bf16=True,
+ warmup_steps=100,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+)
+
+
+print("Initializing Trainer...")
+trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_stories,
+ eval_dataset=validation_stories,
+ tokenizer=tokenizer,
+ compute_metrics=compute_metrics,
+)
+
+print("Starting model training...")
+trainer.train()
+
+print("Saving final model...")
+trainer.save_model(output_dir)
+tokenizer.save_pretrained(output_dir)
+
+print("Stories model training complete")
diff --git a/stories_train_model_v6.py b/stories_train_model_v6.py
new file mode 100644
index 0000000..fc40fe4
--- /dev/null
+++ b/stories_train_model_v6.py
@@ -0,0 +1,145 @@
+import torch
+from datasets import load_dataset, Dataset
+from transformers import (
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ Trainer,
+ TrainingArguments,
+)
+from peft.tuners.lora import LoraConfig
+from peft.mapping import get_peft_model
+import wandb
+from dotenv import load_dotenv
+import polars as pl
+from utils import stories_dataset
+from liger_kernel.transformers import _apply_liger_kernel_to_instance
+from training_helpers import compute_metrics
+import os
+
+load_dotenv("/workspace/.env")
+
+# Configuration
+base_model = "unsloth/Meta-Llama-3.1-8B"
+run_name = "stories_train_model_v6"
+output_dir = f"./models/{run_name}"
+num_epochs = 1
+batch_size = 4
+gradient_accumulation_steps = 4
+learning_rate = 2e-4
+max_length = 4096
+
+# Initialize wandb
+wandb.init(project="hn_stories_model_training", name=run_name)
+
+
+def create_dataset(split, num_rows, tokenizer):
+ stories = stories_dataset()
+ stories = stories.filter(pl.col("split") == split).head(num_rows)
+
+ stories = stories.with_columns(
+ [
+ pl.col("serialized").alias("text"),
+ pl.col("log_score").alias("label"),
+ ]
+ )
+
+ stories = stories.with_columns(
+ [
+ pl.col("text")
+ .map_elements(
+ lambda x: tokenizer(x)["input_ids"], return_dtype=pl.List(pl.Int64)
+ )
+ .alias("input_ids"),
+ ]
+ ).select(["input_ids", "label"])
+ return Dataset.from_polars(stories)
+
+
+print("Loading tokenizer and model...")
+tokenizer = AutoTokenizer.from_pretrained(
+ base_model,
+ truncation=True,
+ padding=True,
+ max_length=max_length,
+)
+
+model = AutoModelForSequenceClassification.from_pretrained(
+ base_model,
+ num_labels=1, # Regression task
+ device_map="auto",
+ attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+)
+_apply_liger_kernel_to_instance(model=model)
+
+# Freeze all parameters except the classification head
+for param in model.parameters():
+ param.requires_grad = False
+for param in model.score.parameters():
+ param.requires_grad = True
+
+model.config.pad_token_id = tokenizer.pad_token_id
+tokenizer.padding_side = "right"
+
+print("Loading dataset...")
+train_stories = create_dataset("train", 1000000, tokenizer)
+validation_stories = create_dataset("val", 1000, tokenizer)
+
+
+# Configure training arguments
+training_args = TrainingArguments(
+ output_dir=output_dir,
+ num_train_epochs=num_epochs,
+ per_device_train_batch_size=batch_size,
+ per_device_eval_batch_size=batch_size,
+ learning_rate=learning_rate,
+ weight_decay=0,
+ evaluation_strategy="steps",
+ eval_steps=0.05,
+ logging_steps=100,
+ save_strategy="steps",
+ save_steps=1000,
+ report_to="wandb",
+ no_cuda=False,
+ bf16=True,
+ warmup_steps=100,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+)
+
+
+class ClassificationHeadTrainer(Trainer):
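+    """Save/load only the classification head. The base model is frozen, so
+    persisting just the score.* tensors keeps checkpoints small (a full bf16
+    8B state dict would be roughly 16 GB)."""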
+ def _save(self, output_dir: str, state_dict=None):
+ # Only save the classification head parameters
+ if state_dict is None:
+ state_dict = self.model.state_dict()
+
+ head_state_dict = {
+ k: v for k, v in state_dict.items() if k.startswith("score.")
+ }
+
+ os.makedirs(output_dir, exist_ok=True)
+ torch.save(head_state_dict, os.path.join(output_dir, "classification_head.bin"))
+
+ def _load_state_dict_in_model(self, state_dict):
+ # Load only classification head parameters
+ self.model.score.load_state_dict(state_dict)
+
+
+print("Initializing Trainer...")
+trainer = ClassificationHeadTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_stories,
+ eval_dataset=validation_stories,
+ tokenizer=tokenizer,
+ compute_metrics=compute_metrics,
+)
+
+print("Starting model training...")
+trainer.train()
+
+print("Saving final model...")
+trainer.save_model(output_dir)
+tokenizer.save_pretrained(output_dir)
+
+print("Stories model training complete")
diff --git a/stories_train_model_v7.py b/stories_train_model_v7.py
new file mode 100644
index 0000000..8a1a4ca
--- /dev/null
+++ b/stories_train_model_v7.py
@@ -0,0 +1,132 @@
+import torch
+from datasets import load_dataset, Dataset
+from transformers import (
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ Trainer,
+ TrainingArguments,
+)
+from peft.tuners.lora import LoraConfig
+from peft.mapping import get_peft_model
+import wandb
+from dotenv import load_dotenv
+import polars as pl
+from utils import stories_dataset
+from liger_kernel.transformers import _apply_liger_kernel_to_instance
+from training_helpers import compute_metrics
+
+load_dotenv("/workspace/.env")
+
+# Configuration
+base_model = "unsloth/Meta-Llama-3.1-8B"
+run_name = __file__.split("/")[-1].replace(".py", "")
+output_dir = f"./models/{run_name}"
+num_epochs = 1
+batch_size = 4
+gradient_accumulation_steps = 4
+learning_rate = 2e-4
+max_length = 4096
+
+# Initialize wandb
+wandb.init(project="hn_stories_model_training", name=run_name)
+
+
+def create_dataset(split, num_rows, tokenizer):
+ stories = stories_dataset()
+ stories = stories.filter(pl.col("split") == split).head(num_rows)
+
+ stories = stories.with_columns(
+ [
+ pl.col("serialized").alias("text"),
+ pl.col("log_score").alias("label"),
+ ]
+ )
+
+ stories = stories.with_columns(
+ [
+ pl.col("text")
+ .map_elements(
+ lambda x: tokenizer(x)["input_ids"], return_dtype=pl.List(pl.Int64)
+ )
+ .alias("input_ids"),
+ ]
+ ).select(["input_ids", "label"])
+ return Dataset.from_polars(stories)
+
+
+print("Loading tokenizer and model...")
+tokenizer = AutoTokenizer.from_pretrained(
+ base_model,
+ truncation=True,
+ padding=True,
+ max_length=max_length,
+)
+
+model = AutoModelForSequenceClassification.from_pretrained(
+ base_model,
+ num_labels=1, # Regression task
+ device_map="auto",
+ attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+)
+_apply_liger_kernel_to_instance(model=model)
+
+model.config.pad_token_id = tokenizer.pad_token_id
+tokenizer.padding_side = "right"
+
+print("Configuring LoRA...")
+model = get_peft_model(
+ model,
+ LoraConfig(
+ task_type="SEQ_CLS",
+ r=8,
+ lora_alpha=16,
+ lora_dropout=0,
+ ),
+)
+
+print("Loading dataset...")
+train_stories = create_dataset("train", 1000000, tokenizer)
+validation_stories = create_dataset("val", 1000, tokenizer)
+
+
+# Configure training arguments
+training_args = TrainingArguments(
+ output_dir=output_dir,
+ num_train_epochs=num_epochs,
+ per_device_train_batch_size=batch_size,
+ per_device_eval_batch_size=batch_size,
+ learning_rate=learning_rate,
+ weight_decay=0,
+ evaluation_strategy="steps",
+ eval_steps=0.05,
+ logging_steps=100,
+ save_strategy="steps",
+ save_steps=1000,
+ report_to="wandb",
+ optim="schedule_free_adamw",
+ no_cuda=False,
+ bf16=True,
+ warmup_steps=100,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+)
+
+
+print("Initializing Trainer...")
+trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_stories,
+ eval_dataset=validation_stories,
+ tokenizer=tokenizer,
+ compute_metrics=compute_metrics,
+)
+
+print("Starting model training...")
+trainer.train()
+
+print("Saving final model...")
+trainer.save_model(output_dir)
+tokenizer.save_pretrained(output_dir)
+
+print("Stories model training complete")
diff --git a/stories_train_model_v8.py b/stories_train_model_v8.py
new file mode 100644
index 0000000..f226b28
--- /dev/null
+++ b/stories_train_model_v8.py
@@ -0,0 +1,131 @@
+import torch
+from datasets import load_dataset, Dataset
+from transformers import (
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ Trainer,
+ TrainingArguments,
+)
+from peft.tuners.lora import LoraConfig
+from peft.mapping import get_peft_model
+import wandb
+from dotenv import load_dotenv
+import polars as pl
+from utils import stories_dataset
+from liger_kernel.transformers import _apply_liger_kernel_to_instance
+from training_helpers import compute_metrics
+
+load_dotenv("/workspace/.env")
+
+# Configuration
+base_model = "Qwen/Qwen2.5-7B"
+run_name = __file__.split("/")[-1].replace(".py", "")
+output_dir = f"./models/{run_name}"
+num_epochs = 1
+batch_size = 4
+gradient_accumulation_steps = 4
+learning_rate = 2e-4
+max_length = 4096
+
+# Initialize wandb
+wandb.init(project="hn_stories_model_training", name=run_name)
+
+
+def create_dataset(split, num_rows, tokenizer):
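+    """Tokenize one split of the stories data into (input_ids, label) rows."""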
+ stories = stories_dataset()
+ stories = stories.filter(pl.col("split") == split).head(num_rows)
+
+ stories = stories.with_columns(
+ [
+ pl.col("serialized").alias("text"),
+ pl.col("log_score").alias("label"),
+ ]
+ )
+
+    # Tokenize with truncation so stories longer than max_length don't overflow
+    # the model's context window
+    stories = stories.with_columns(
+        [
+            pl.col("text")
+            .map_elements(
+                lambda x: tokenizer(x, truncation=True)["input_ids"],
+                return_dtype=pl.List(pl.Int64),
+            )
+            .alias("input_ids"),
+        ]
+    ).select(["input_ids", "label"])
+ return Dataset.from_polars(stories)
+
+
+print("Loading tokenizer and model...")
+tokenizer = AutoTokenizer.from_pretrained(
+    base_model,
+    # truncation/padding/max_length are per-call options, not init options;
+    # model_max_length is what caps truncation at tokenize time
+    model_max_length=max_length,
+)
+
+model = AutoModelForSequenceClassification.from_pretrained(
+ base_model,
+ num_labels=1, # Regression task
+ device_map="auto",
+ attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+)
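+# Swap in Liger's fused Triton kernels on the already-loaded model; the
+# underscore-prefixed helper is liger_kernel's instance-level entry point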
+_apply_liger_kernel_to_instance(model=model)
+
+# Defensive: fall back to EOS if the base tokenizer ships without a pad token
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+model.config.pad_token_id = tokenizer.pad_token_id
+tokenizer.padding_side = "right"
+
+print("Configuring LoRA...")
+model = get_peft_model(
+ model,
+ LoraConfig(
+ task_type="SEQ_CLS",
+ r=8,
+ lora_alpha=16,
+ lora_dropout=0,
+ ),
+)
+
+print("Loading dataset...")
+train_stories = create_dataset("train", 1000000, tokenizer)
+validation_stories = create_dataset("val", 1000, tokenizer)
+
+
+# Configure training arguments
+training_args = TrainingArguments(
+    output_dir=output_dir,
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    learning_rate=learning_rate,
+    weight_decay=0,
+    evaluation_strategy="steps",
+    eval_steps=0.05,  # evaluate every 5% of total training steps
+    logging_steps=100,
+    save_strategy="steps",
+    save_steps=1000,
+    report_to="wandb",
+    bf16=True,
+    warmup_steps=100,
+    gradient_accumulation_steps=gradient_accumulation_steps,
+)
+
+
+print("Initializing Trainer...")
+trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_stories,
+ eval_dataset=validation_stories,
+ tokenizer=tokenizer,
+ compute_metrics=compute_metrics,
+)
+
+print("Starting model training...")
+trainer.train()
+
+print("Saving final model...")
+trainer.save_model(output_dir)
+tokenizer.save_pretrained(output_dir)
+
+print("Stories model training complete")
diff --git a/stories_train_model_v9.py b/stories_train_model_v9.py
new file mode 100644
index 0000000..5e9cde6
--- /dev/null
+++ b/stories_train_model_v9.py
@@ -0,0 +1,141 @@
+import torch
+from datasets import Dataset
+from transformers import (
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ Trainer,
+ TrainingArguments,
+)
+from peft.tuners.lora import LoraConfig
+from peft.mapping import get_peft_model
+import wandb
+from dotenv import load_dotenv
+import polars as pl
+from utils import stories_dataset
+from liger_kernel.transformers import _apply_liger_kernel_to_instance
+from training_helpers import compute_metrics
+
+load_dotenv("/workspace/.env")
+
+# Configuration
+base_model = "Qwen/Qwen2.5-14B"
+run_name = __file__.split("/")[-1].replace(".py", "")
+output_dir = f"./models/{run_name}"
+num_epochs = 1
+batch_size = 2
+gradient_accumulation_steps = 8
+learning_rate = 2e-4
+max_length = 4096
+
+# Initialize wandb
+wandb.init(project="hn_stories_model_training", name=run_name)
+
+
+def create_dataset(split, num_rows, tokenizer):
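+    """Tokenize one split of the stories data into (input_ids, label) rows."""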
+ stories = stories_dataset()
+ stories = stories.filter(pl.col("split") == split).head(num_rows)
+
+ stories = stories.with_columns(
+ [
+ pl.col("serialized").alias("text"),
+ pl.col("log_score").alias("label"),
+ ]
+ )
+
+    # Tokenize with truncation so stories longer than max_length don't overflow
+    # the model's context window
+    stories = stories.with_columns(
+        [
+            pl.col("text")
+            .map_elements(
+                lambda x: tokenizer(x, truncation=True)["input_ids"],
+                return_dtype=pl.List(pl.Int64),
+            )
+            .alias("input_ids"),
+        ]
+    ).select(["input_ids", "label"])
+ return Dataset.from_polars(stories)
+
+
+print("Loading tokenizer and model...")
+tokenizer = AutoTokenizer.from_pretrained(
+    base_model,
+    # truncation/padding/max_length are per-call options, not init options;
+    # model_max_length is what caps truncation at tokenize time
+    model_max_length=max_length,
+)
+
+model = AutoModelForSequenceClassification.from_pretrained(
+ base_model,
+ num_labels=1, # Regression task
+ device_map="auto",
+ attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+)
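+# Swap in Liger's fused Triton kernels on the already-loaded model; the
+# underscore-prefixed helper is liger_kernel's instance-level entry point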
+_apply_liger_kernel_to_instance(model=model)
+
+# Defensive: fall back to EOS if the base tokenizer ships without a pad token
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+model.config.pad_token_id = tokenizer.pad_token_id
+tokenizer.padding_side = "right"
+
+print("Configuring LoRA...")
+model = get_peft_model(
+ model,
+ LoraConfig(
+ task_type="SEQ_CLS",
+ r=8,
+ lora_alpha=16,
+ lora_dropout=0,
+ ),
+)
+
+print("Loading dataset...")
+train_stories = create_dataset("train", 1000000, tokenizer)
+validation_stories = create_dataset("val", 1000, tokenizer)
+
+
+# Configure training arguments
+training_args = TrainingArguments(
+    output_dir=output_dir,
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    learning_rate=learning_rate,
+    weight_decay=0,
+    evaluation_strategy="steps",
+    eval_steps=0.05,  # evaluate every 5% of total training steps
+    logging_steps=100,
+    save_strategy="steps",
+    save_steps=1000,
+    report_to="wandb",
+    bf16=True,
+    warmup_steps=100,
+    gradient_accumulation_steps=gradient_accumulation_steps,
+)
+
+
+print("Initializing Trainer...")
+trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_stories,
+ eval_dataset=validation_stories,
+ tokenizer=tokenizer,
+ compute_metrics=compute_metrics,
+)
+
+print("Starting model training...")
+trainer.train()
+
+print("Saving final model...")
+trainer.save_model(output_dir)
+tokenizer.save_pretrained(output_dir)
+
+print("Stories model training complete")
diff --git a/training_helpers.py b/training_helpers.py
new file mode 100644
index 0000000..c0916c0
--- /dev/null
+++ b/training_helpers.py
@@ -0,0 +1,22 @@
+import torch
+from scipy.stats import pearsonr
+from sklearn.metrics import root_mean_squared_error
+
+
+def compute_metrics(eval_pred):
+ predictions, labels = eval_pred
+
+    # Trainer hands over numpy arrays; predictions arrive as (num_samples, 1)
+    # for single-label regression, so squeeze to align with the 1-D labels
+    predictions = torch.tensor(predictions).squeeze()
+    labels = torch.tensor(labels)
+
+    # Drop NaN entries so RMSE and Pearson correlation stay defined
+ valid_indices = ~torch.isnan(predictions) & ~torch.isnan(labels)
+ valid_predictions = predictions[valid_indices]
+ valid_labels = labels[valid_indices]
+
+ return {
+ "rmse": root_mean_squared_error(valid_labels, valid_predictions),
+ "correlation": pearsonr(valid_labels, valid_predictions)[0],
+ }
diff --git a/uv.lock b/uv.lock
index 72afe7a..bf95973 100644
--- a/uv.lock
+++ b/uv.lock
@@ -289,6 +289,7 @@ dependencies = [
{ name = "peft" },
{ name = "polars" },
{ name = "python-dotenv" },
+ { name = "schedulefree" },
{ name = "scikit-learn" },
{ name = "seaborn" },
{ name = "sglang", extra = ["all"] },
@@ -314,6 +315,7 @@ requires-dist = [
{ name = "peft", specifier = ">=0.13.2" },
{ name = "polars", specifier = ">=1.9.0" },
{ name = "python-dotenv", specifier = ">=1.0.1" },
+ { name = "schedulefree", specifier = ">=1.2.7" },
{ name = "scikit-learn", specifier = ">=1.5.2" },
{ name = "seaborn", specifier = ">=0.13.2" },
{ name = "sglang", extras = ["all"], editable = "deps/sglang/python" },
@@ -3057,6 +3059,12 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/82/cc/9c2cf58611daf1c83ce5d37f9de66353e23fcda36008b13fd3409a760aa3/safetensors-0.4.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:dbd280b07e6054ea68b0cb4b16ad9703e7d63cd6890f577cb98acc5354780142", size = 605580 },
]
+[[package]]
+name = "schedulefree"
+version = "1.2.7"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/3f/b8/5e0bdfcd5555654fcb976df8b0f413c1ec26efe016e7666abc2a7dd5f218/schedulefree-1.2.7.tar.gz", hash = "sha256:e97f8b8db332dafc9999a2c4ed68429606579d5eb250ebb953f233efffdc79c1", size = 19550 }
+
[[package]]
name = "scikit-learn"
version = "1.5.2"