Commit
Update evaluation code:
- make the project installable via pip
- refactor
- add code for German evaluation
- add training scripts for German and Czech
JanProvaznik committed Dec 2, 2023
1 parent 8554011 commit 3cfe6fa
Showing 15 changed files with 2,158 additions and 5 deletions.
10 changes: 9 additions & 1 deletion README.md
@@ -4,8 +4,14 @@
This project explores the possibility of using a pretrained language model to decrypt ciphers. The aim is also to discover what linguistic features of a text the model learns to use by varying the test set and measuring accuracy.



## Docs
### How to run
- install the dependencies and the local package
```
pip install -r requirements.txt
pip install -e .
```
#### Slurm cluster
- basic setting: `sbatch -p gpu -c1 --gpus=1 --mem=16G <bash_script_path>`
- use `run_notebook.sh <notebook_path>` to run a Jupyter notebook on the Slurm cluster
@@ -102,4 +108,6 @@ This project explores the possibility of using a pretrained language model to de

#### learning rate
- has to be quite high because we're not fine-tuning for a typical language task but for a quite strange translation task
- usually use the Hugging Face default LR schedule for `Seq2SeqTrainer` (linear decay) and set a relative warmup (e.g. 0.2 of total steps)
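The warmup-then-linear-decay shape described above can be sketched in plain Python. This is illustrative only; in practice the schedule is built by `Seq2SeqTrainer` itself, and the function name below is hypothetical:

```python
def linear_schedule_lr(step: int, total_steps: int, peak_lr: float,
                       warmup_ratio: float = 0.2) -> float:
    """LR at a given step: linear warmup to peak_lr, then linear decay to 0."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # linear warmup from 0 to peak_lr
        return peak_lr * step / max(1, warmup_steps)
    # linear decay from peak_lr (at the end of warmup) to 0 (at total_steps)
    return peak_lr * (total_steps - step) / max(1, total_steps - warmup_steps)


if __name__ == "__main__":
    for s in (0, 100, 200, 500, 1000):
        print(s, linear_schedule_lr(s, total_steps=1000, peak_lr=1e-3))
```

With `warmup_ratio=0.2` and 1000 total steps, the LR climbs for the first 200 steps and then decays linearly to zero, matching the relative-warmup setting mentioned above.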


418 changes: 418 additions & 0 deletions analysis.ipynb

Large diffs are not rendered by default.

276 changes: 276 additions & 0 deletions evaluation_gpuevaluate.ipynb
@@ -0,0 +1,276 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# evaluate performance along various axes of sentence complexity"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## data"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"start\n"
]
}
],
"source": [
"print(\"start\")\n",
"import sys\n",
"import os\n",
"\n",
"sys.path.append(\"./enigma-transformed/src\")\n",
"sys.path.append(\"./src\")\n",
"sys.path.append(\"../src\")\n",
"sys.path.append(\"../../src\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"eval_column: unigram_js_divergence\n"
]
}
],
"source": [
"available_eval_columns = [\n",
" \"unigram_js_divergence\", # 17602 old, 18204\n",
" \"gpt2_tokens_per_char\", # 17659 old, 18206\n",
" \"gpt2_perplexity\", # 17604 old, 18207\n",
" \"bigram_js_divergence\", # 17657, 18209\n",
" \"depth_of_parse_tree\", #17687\n",
" \"named_entities\", #17688\n",
" \"pos_js_divergence\", #17689\n",
" \"pos_bigram_js_divergence\" #17690\n",
"]\n",
"eval_column = available_eval_columns[3]\n",
"print(\"eval_column:\", eval_column)\n",
"dataset_max_len = 200"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"len: 4888014\n",
"len: 4740817\n",
"len: 4726966\n",
"data loaded\n",
"data sorted\n",
"top 100 in original_text column:\n",
"3657034 The latter would remain safely in the realm of...\n",
"66317 Ostensibly - particularly when it came to the ...\n",
"681907 The report, commissioned by the anti-mining gr...\n",
"4534661 IRAN held two sets of talks on Wednesday aimed...\n",
"1218458 It was overwhelming - even after repeated eati...\n",
" ... \n",
"3595778 The Missouri Supreme Court temporarily suspend...\n",
"3488527 Mr Shaw's decision to quit the Liberal Party o...\n",
"2245300 At the end of the week, the Moncler Gamme Roug...\n",
"632568 The transport firm revealed it has set aside m...\n",
"3916704 Dr Ali pointed to US President Barack Obama's ...\n",
"Name: original_text, Length: 100, dtype: object\n",
"bottom 100 in original_text column:\n",
"265842 Montenegro: Mladen Bozovic; Savo Pavicevic, St...\n",
"4130722 To join the Viking Cruises communities online,...\n",
"588808 Shakhter Karagandy (4-5-1): Mokin; Simcevic, M...\n",
"2209035 India (from): M Dhoni (capt & wkt) R Ashwin, S...\n",
"126085 Estonia (4-2-3-1): Pareiko; Jaager, Morozov, K...\n",
" ... \n",
"3118803 Tonga: Lilo - Vainikolo, Piutau, Piukala, Helu...\n",
"3176970 It identified them as Qa'ed al-Dahab, Ali Jall...\n",
"4138023 China got £37 billion ($60 billion,) the Phili...\n",
"2865645 Barcelona 4-0 Ajax, Milan 2-0 Celtic; Ajax 1-1...\n",
"1593711 Chelsea 1-2 Basel, Schalke 3-0 Steaua Buchares...\n",
"Name: original_text, Length: 100, dtype: object\n"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"data_path = f\"news.2013.en.trainlen.{dataset_max_len}.merged.csv\"\n",
"data = pd.read_csv(data_path)\n",
"# print len\n",
"print(\"len:\", len(data))\n",
"# filter out rows which 'lang' != True\n",
"data = data[data['lang'] == True]\n",
"print(\"len:\", len(data))\n",
"# weird is false\n",
"data = data[data['weird'] == False]\n",
"# print len\n",
"print(\"len:\", len(data))\n",
"print(\"data loaded\")\n",
"# sort by eval column\n",
"data.sort_values(eval_column, inplace=True)\n",
"print(\"data sorted\")\n",
"\n",
"# # print top 100 in original_text column\n",
"# print(\"top 100 in original_text column:\")\n",
"# print(data['original_text'][:100])\n",
"# # bottom 100 in original_text column\n",
"# print(\"bottom 100 in original_text column:\")\n",
"# print(data['original_text'][-100:])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the first 10 rows\n",
"print(f\"{data.head(10)=}\")\n",
"# last 10 rows\n",
"print(f\"{data.tail(10)=}\")\n",
"\n",
"rows = {\n",
" 0.1: data.iloc[len(data) // 1000 * 999 - 1000 : len(data) // 1000 * 999],\n",
" 1: data.iloc[len(data) // 100 * 99 - 1000 : len(data) // 100 * 99],\n",
" 5: data.iloc[len(data) // 100 * 95 - 1000 : len(data) // 100 * 95],\n",
" 10: data.iloc[len(data) // 100 * 90 - 1000 : len(data) // 100 * 90],\n",
" 15: data.iloc[len(data) // 100 * 85 - 1000 : len(data) // 100 * 85],\n",
" 20: data.iloc[len(data) // 100 * 80 - 1000 : len(data) // 100 * 80],\n",
" 25: data.iloc[len(data) // 100 * 75 - 1000 : len(data) // 100 * 75],\n",
" 30: data.iloc[len(data) // 100 * 70 - 1000 : len(data) // 100 * 70],\n",
" 35: data.iloc[len(data) // 100 * 65 - 1000 : len(data) // 100 * 65],\n",
" 40: data.iloc[len(data) // 100 * 60 - 1000 : len(data) // 100 * 60],\n",
" 45: data.iloc[len(data) // 100 * 55 - 1000 : len(data) // 100 * 55],\n",
" 50: data.iloc[len(data) // 100 * 50 - 1000 : len(data) // 100 * 50],\n",
" 55: data.iloc[len(data) // 100 * 45 - 1000 : len(data) // 100 * 45],\n",
" 60: data.iloc[len(data) // 100 * 40 - 1000 : len(data) // 100 * 40],\n",
" 65: data.iloc[len(data) // 100 * 35 - 1000 : len(data) // 100 * 35],\n",
" 70: data.iloc[len(data) // 100 * 30 - 1000 : len(data) // 100 * 30],\n",
" 75: data.iloc[len(data) // 100 * 25 - 1000 : len(data) // 100 * 25],\n",
" 80: data.iloc[len(data) // 100 * 20 - 1000 : len(data) // 100 * 20],\n",
" 85: data.iloc[len(data) // 100 * 15 - 1000 : len(data) // 100 * 15],\n",
" 90: data.iloc[len(data) // 100 * 10 - 1000 : len(data) // 100 * 10],\n",
" 95: data.iloc[len(data) // 100 * 5 - 1000 : len(data) // 100 * 5],\n",
" 99: data.iloc[len(data) // 100 * 1 - 1000 : len(data) // 100 * 1],\n",
" 99.9: data.iloc[len(data) // 1000 * 1 - 1000 : len(data) // 1000 * 1],\n",
"}\n",
"print(\"percentile rows split\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## load model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import ByT5Tokenizer, T5ForConditionalGeneration\n",
"from src.utils import levensthein_distance, print_avg_median_mode_error\n",
"from transformers import pipeline, logging\n",
"import torch\n",
"\n",
"logging.set_verbosity(logging.ERROR)\n",
"\n",
"\n",
"tokenizer = ByT5Tokenizer()\n",
"\n",
"from src.ByT5Dataset import ByT5CaesarRandomDataset, ByT5ConstEnigmaDataset\n",
"\n",
"model = T5ForConditionalGeneration.from_pretrained(\"./logs/slurm_17510/model\")\n",
"dataset_class = ByT5ConstEnigmaDataset # for 17510 model\n",
"# dataset_class = ByT5CaesarRandomDataset # for 16677 model\n",
"rows_datasets = {i: dataset_class(list(rows[i].text), dataset_max_len) for i in rows}\n",
"\n",
"print(\"rows_datasets created\")\n",
"\n",
"averages, medians, modes = {}, {}, {}\n",
"raw_data = {}\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"for i, test in rows_datasets.items():\n",
" print(f\"#############################################\")\n",
" print(f\"Testing {i}th percentile\")\n",
" error_counts = []\n",
" translate = pipeline(\"translation\", model=model, tokenizer=tokenizer, device=device)\n",
" for index in range(len(test)):\n",
" generated = translate(\n",
" test[index][\"input_text\"], max_length=(dataset_max_len + 1) * 2\n",
" )[0][\"translation_text\"]\n",
" error_counts.append(levensthein_distance(generated, test[index][\"output_text\"]))\n",
" if error_counts[-1] > 0:\n",
" print(f\"Example {index}, error count {error_counts[-1]}\")\n",
" print(\"In :\", test[index][\"input_text\"])\n",
" print(\"Gen:\", generated)\n",
" expected = test[index][\"output_text\"]\n",
" print(\"Exp:\", expected)\n",
" else:\n",
" print(f\"Example {index} OK\")\n",
" print(\"-----------------------\")\n",
"\n",
" avg, med, mode = print_avg_median_mode_error(error_counts)\n",
" averages[i] = avg\n",
" medians[i] = med\n",
" modes[i] = mode\n",
" raw_data[i] = error_counts\n",
"\n",
"print(\"Averages:\", averages)\n",
"print(\"Medians:\", medians)\n",
"print(\"Modes:\", modes)\n",
"print(\"Raw data:\", raw_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "enigmavenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
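The evaluation loop in the notebook above scores generations with `levensthein_distance` from `src.utils`, which is not shown in this diff. A minimal sketch of what such a helper presumably computes, assuming it is the standard dynamic-programming Levenshtein edit distance:

```python
def levenshtein_distance(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions,
    or substitutions needed to turn a into b."""
    # prev[j] holds the distance between a[:i] and b[:j] for the previous row i
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(
                prev[j] + 1,               # deletion of ca
                cur[j - 1] + 1,            # insertion of cb
                prev[j - 1] + (ca != cb),  # substitution (free if equal)
            ))
        prev = cur
    return prev[len(b)]


if __name__ == "__main__":
    print(levenshtein_distance("kitten", "sitting"))  # 3
```

An error count of 0 in the notebook's output then means the model reproduced the expected plaintext exactly.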