-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
3,505 additions
and
269 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,4 +166,5 @@ last_run_prepared/ | |
|
||
data/ | ||
wandb/ | ||
reward_model_output/ | ||
reward_model_output/ | ||
models/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "deps/sglang"] | ||
path = deps/sglang | ||
url = [email protected]:OpenPipe/sglang.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<div><style>\n", | ||
".dataframe > thead > tr,\n", | ||
".dataframe > tbody > tr {\n", | ||
" text-align: right;\n", | ||
" white-space: pre-wrap;\n", | ||
"}\n", | ||
"</style>\n", | ||
"<small>shape: (340_688, 18)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>id</th><th>type</th><th>by</th><th>time</th><th>title</th><th>text</th><th>url</th><th>score</th><th>parent</th><th>top_level_parent</th><th>descendants</th><th>kids</th><th>deleted</th><th>dead</th><th>siblings_count</th><th>sibling_rank</th><th>prompt</th><th>reward</th></tr><tr><td>i64</td><td>str</td><td>str</td><td>datetime[μs]</td><td>str</td><td>str</td><td>str</td><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>list[i64]</td><td>bool</td><td>bool</td><td>u32</td><td>i64</td><td>str</td><td>f64</td></tr></thead><tbody><tr><td>29389287</td><td>"comment"</td><td>"jsc1986"</td><td>2021-11-30 05:32:42</td><td>null</td><td>"Perhaps it was just our batch,…</td><td>null</td><td>null</td><td>29389229</td><td>29387264</td><td>null</td><td>[29389323, 29389355]</td><td>null</td><td>null</td><td>6</td><td>1</td><td>"<instructions>Your goal is to …</td><td>32.25</td></tr><tr><td>2920304</td><td>"comment"</td><td>"wheels"</td><td>2011-08-24 11:35:53</td><td>null</td><td>"For the record, I don't live i…</td><td>null</td><td>null</td><td>2920148</td><td>2919708</td><td>null</td><td>null</td><td>null</td><td>null</td><td>7</td><td>1</td><td>"<instructions>Your goal is to …</td><td>30.125</td></tr><tr><td>29390682</td><td>"comment"</td><td>"ZephyrBlu"</td><td>2021-11-30 10:22:12</td><td>null</td><td>"I find it unlikely, but not ex…</td><td>null</td><td>null</td><td>29390566</td><td>29387264</td><td>null</td><td>[29390962]</td><td>null</td><td>null</td><td>5</td><td>1</td><td>"<instructions>Your goal is to …</td><td>29.25</td></tr><tr><td>29389031</td><td>"comment"</td><td>"temp7536"</td><td>2021-11-30 04:39:01</td><td>null</td><td>"I&#x27;m sorry but no. Patrick…</td><td>null</td><td>null</td><td>29388863</td><td>29387264</td><td>null</td><td>[29389537, 29389411, … 29389222]</td><td>null</td><td>null</td><td>7</td><td>1</td><td>"<instructions>Your goal is to …</td><td>28.625</td></tr><tr><td>6370703</td><td>"comment"</td><td>"enraged_camel"</td><td>2013-09-11 22:53:13</td><td>null</td><td>"I disagree completely.<p>The f…</td><td>null</td><td>null</td><td>6370519</td><td>6369530</td><td>null</td><td>[6370714, 6371017, … 6370873]</td><td>null</td><td>null</td><td>8</td><td>1</td><td>"<instructions>Your goal is to …</td><td>28.5</td></tr><tr><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td><td>…</td></tr><tr><td>13831616</td><td>"comment"</td><td>"numbsafari"</td><td>2017-03-09 18:27:59</td><td>null</td><td>"Any public timeline on HA feat…</td><td>null</td><td>null</td><td>13831415</td><td>13831277</td><td>null</td><td>[13831972, 13831764]</td><td>null</td><td>null</td><td>10</td><td>1</td><td>"<instructions>Your goal is to …</td><td>-23.25</td></tr><tr><td>29273269</td><td>"comment"</td><td>"elwell"</td><td>2021-11-19 02:55:45</td><td>null</td><td>"Would be interested in seeing …</td><td>null</td><td>null</td><td>29262357</td><td>29262357</td><td>null</td><td>[29273533]</td><td>null</td><td>null</td><td>5</td><td>1</td><td>"<instructions>Your goal is to …</td><td>-23.25</td></tr><tr><td>32693407</td><td>"comment"</td><td>"jaquilio"</td><td>2022-09-02 17:28:12</td><td>null</td><td>"How does this compare to Elect…</td><td>null</td><td>null</td><td>32524577</td><td>32524577</td><td>null</td><td>null</td><td>null</td><td>null</td><td>12</td><td>1</td><td>"<instructions>Your goal is to …</td><td>-23.625</td></tr><tr><td>1436880</td><td>"comment"</td><td>"acgourley"</td><td>2010-06-16 20:15:21</td><td>null</td><td>"Wow that's an impressive hack."</td><td>null</td><td>null</td><td>1436658</td><td>1436658</td><td>null</td><td>null</td><td>null</td><td>null</td><td>9</td><td>1</td><td>"<instructions>Your goal is to …</td><td>-23.875</td></tr><tr><td>2160465</td><td>"comment"</td><td>"pan69"</td><td>2011-01-31 05:52:40</td><td>null</td><td>"Hires image:\n", | ||
"<a href="http://w…</td><td>null</td><td>null</td><td>2160446</td><td>2160446</td><td>null</td><td>null</td><td>null</td><td>null</td><td>9</td><td>1</td><td>"<instructions>Your goal is to …</td><td>-23.875</td></tr></tbody></table></div>" | ||
], | ||
"text/plain": [ | ||
"shape: (340_688, 18)\n", | ||
"┌──────────┬─────────┬────────────┬────────────┬───┬────────────┬────────────┬───────────┬─────────┐\n", | ||
"│ id ┆ type ┆ by ┆ time ┆ … ┆ siblings_c ┆ sibling_ra ┆ prompt ┆ reward │\n", | ||
"│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ ount ┆ nk ┆ --- ┆ --- │\n", | ||
"│ i64 ┆ str ┆ str ┆ datetime[μ ┆ ┆ --- ┆ --- ┆ str ┆ f64 │\n", | ||
"│ ┆ ┆ ┆ s] ┆ ┆ u32 ┆ i64 ┆ ┆ │\n", | ||
"╞══════════╪═════════╪════════════╪════════════╪═══╪════════════╪════════════╪═══════════╪═════════╡\n", | ||
"│ 29389287 ┆ comment ┆ jsc1986 ┆ 2021-11-30 ┆ … ┆ 6 ┆ 1 ┆ <instruct ┆ 32.25 │\n", | ||
"│ ┆ ┆ ┆ 05:32:42 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ 2920304 ┆ comment ┆ wheels ┆ 2011-08-24 ┆ … ┆ 7 ┆ 1 ┆ <instruct ┆ 30.125 │\n", | ||
"│ ┆ ┆ ┆ 11:35:53 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ 29390682 ┆ comment ┆ ZephyrBlu ┆ 2021-11-30 ┆ … ┆ 5 ┆ 1 ┆ <instruct ┆ 29.25 │\n", | ||
"│ ┆ ┆ ┆ 10:22:12 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ 29389031 ┆ comment ┆ temp7536 ┆ 2021-11-30 ┆ … ┆ 7 ┆ 1 ┆ <instruct ┆ 28.625 │\n", | ||
"│ ┆ ┆ ┆ 04:39:01 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ 6370703 ┆ comment ┆ enraged_ca ┆ 2013-09-11 ┆ … ┆ 8 ┆ 1 ┆ <instruct ┆ 28.5 │\n", | ||
"│ ┆ ┆ mel ┆ 22:53:13 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", | ||
"│ 13831616 ┆ comment ┆ numbsafari ┆ 2017-03-09 ┆ … ┆ 10 ┆ 1 ┆ <instruct ┆ -23.25 │\n", | ||
"│ ┆ ┆ ┆ 18:27:59 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ 29273269 ┆ comment ┆ elwell ┆ 2021-11-19 ┆ … ┆ 5 ┆ 1 ┆ <instruct ┆ -23.25 │\n", | ||
"│ ┆ ┆ ┆ 02:55:45 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ 32693407 ┆ comment ┆ jaquilio ┆ 2022-09-02 ┆ … ┆ 12 ┆ 1 ┆ <instruct ┆ -23.625 │\n", | ||
"│ ┆ ┆ ┆ 17:28:12 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ 1436880 ┆ comment ┆ acgourley ┆ 2010-06-16 ┆ … ┆ 9 ┆ 1 ┆ <instruct ┆ -23.875 │\n", | ||
"│ ┆ ┆ ┆ 20:15:21 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"│ 2160465 ┆ comment ┆ pan69 ┆ 2011-01-31 ┆ … ┆ 9 ┆ 1 ┆ <instruct ┆ -23.875 │\n", | ||
"│ ┆ ┆ ┆ 05:52:40 ┆ ┆ ┆ ┆ ions>Your ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ goal is ┆ │\n", | ||
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ to … ┆ │\n", | ||
"└──────────┴─────────┴────────────┴────────────┴───┴────────────┴────────────┴───────────┴─────────┘" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import polars as pl\n", | ||
"\n", | ||
"df = pl.read_parquet(\"./data/top_comments_with_reward.parquet\").sort(\n", | ||
" by=\"reward\", descending=True\n", | ||
")\n", | ||
"\n", | ||
"df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"stories_df = pl.read_parquet(\"./data/stories.parquet\")\n", | ||
"\n", | ||
"stories_df = stories_df.select(pl.col(\"id\", \"title\", \"url\")).rename(\n", | ||
" {\n", | ||
" \"id\": \"story_id\",\n", | ||
" \"title\": \"story_title\",\n", | ||
" \"url\": \"story_url\",\n", | ||
" }\n", | ||
")\n", | ||
"\n", | ||
"df = df.join(stories_df, left_on=\"top_level_parent\", right_on=\"story_id\", how=\"left\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"sys:1: MapWithoutReturnDtypeWarning: Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. Specify `return_dtype` to silence this warning.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import html\n", | ||
"\n", | ||
"\n", | ||
"def unescape_html(text):\n", | ||
" return html.unescape(text)\n", | ||
"\n", | ||
"\n", | ||
"df = df.with_columns(\n", | ||
" pl.concat_str(\n", | ||
" pl.lit(\"https://news.ycombinator.com/item?id=\"),\n", | ||
" pl.col(\"id\"),\n", | ||
" ).alias(\"link\"),\n", | ||
" pl.col(\"time\").dt.strftime(\"%B %d, %Y\").alias(\"date\"),\n", | ||
" pl.col(\"text\").str.replace_all(\"<p>\", \"\\n\\n\").alias(\"text\"),\n", | ||
" pl.col(\"text\")\n", | ||
" .map_elements(unescape_html, return_dtype=pl.String)\n", | ||
" .alias(\"text_unescaped\"),\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df.select(pl.col(\"date\", \"by\", \"link\", \"story_title\", \"text\", \"reward\")).head(\n", | ||
" 100\n", | ||
").write_csv(\"./data/top_comments_with_links.csv\")\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading tokenizer and model...\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "01b762ff603b4cf9b0e5efa1f18ffc57", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at unsloth/Meta-Llama-3.1-8B and are newly initialized: ['score.weight']\n", | ||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading PEFT model...\n", | ||
"Merging PEFT model with base model...\n", | ||
"Saving merged model...\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"('./models/llama_32_8b_merged/tokenizer_config.json',\n", | ||
" './models/llama_32_8b_merged/special_tokens_map.json',\n", | ||
" './models/llama_32_8b_merged/tokenizer.json')" | ||
] | ||
}, | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import dotenv\n", | ||
"import torch\n", | ||
"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", | ||
"from peft import PeftModel\n", | ||
"\n", | ||
"dotenv.load_dotenv()\n", | ||
"\n", | ||
"print(\"Loading tokenizer and model...\")\n", | ||
"base_model = AutoModelForSequenceClassification.from_pretrained(\n", | ||
" \"unsloth/Meta-Llama-3.1-8B\",\n", | ||
" device_map=\"auto\",\n", | ||
" num_labels=1,\n", | ||
" torch_dtype=torch.bfloat16,\n", | ||
")\n", | ||
"tokenizer = AutoTokenizer.from_pretrained(\"unsloth/Meta-Llama-3.1-8B\")\n", | ||
"\n", | ||
"print(\"Loading PEFT model...\")\n", | ||
"peft_model = PeftModel.from_pretrained(\n", | ||
" base_model, \"./reward_model_output/checkpoint-30000/\", device_map=\"auto\"\n", | ||
")\n", | ||
"\n", | ||
"print(\"Merging PEFT model with base model...\")\n", | ||
"merged_model = peft_model.merge_and_unload()\n", | ||
"\n", | ||
"print(\"Saving merged model...\")\n", | ||
"merged_model.save_pretrained(\"./models/llama_32_8b_merged\")\n", | ||
"tokenizer.save_pretrained(\"./models/llama_32_8b_merged\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.