Make project installable via pip; refactor; add some code for German evaluation; add training scripts for German and Czech
1 parent 8554011, commit 3cfe6fa
Showing 15 changed files with 2,158 additions and 5 deletions.
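The commit message says the project is made installable via pip, but the packaging file itself is not part of the diff rendered below. As a hedged illustration only, a minimal `setup.py` along these lines would be enough for `pip install -e .`; the package name and the `src/` layout are assumptions inferred from the paths used in the notebook, not taken from this commit.

```python
# Hypothetical sketch only -- the actual packaging added by this commit is not
# shown in the rendered diff. Assumes the code lives under src/, as the
# sys.path.append("./src") calls in the notebook suggest.
from setuptools import setup, find_packages

setup(
    name="enigma-transformed",            # assumed name, taken from the repo path in the notebook
    version="0.1.0",                      # placeholder version
    package_dir={"": "src"},              # map the package root to src/
    packages=find_packages(where="src"),  # pick up all packages under src/
)
```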
@@ -0,0 +1,276 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# evaluate performance along various axes of sentence complexity"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## data"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"start\n"
]
}
],
"source": [
"print(\"start\")\n",
"import sys\n",
"import os\n",
"\n",
"sys.path.append(\"./enigma-transformed/src\")\n",
"sys.path.append(\"./src\")\n",
"sys.path.append(\"../src\")\n",
"sys.path.append(\"../../src\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"eval_column: unigram_js_divergence\n"
]
}
],
"source": [
"available_eval_columns = [\n",
" \"unigram_js_divergence\", # 17602 old, 18204\n",
" \"gpt2_tokens_per_char\", # 17659 old, 18206\n",
" \"gpt2_perplexity\", # 17604 old, 18207\n",
" \"bigram_js_divergence\", # 17657, 18209\n",
" \"depth_of_parse_tree\", #17687\n",
" \"named_entities\", #17688\n",
" \"pos_js_divergence\", #17689\n",
" \"pos_bigram_js_divergence\" #17690\n",
"]\n",
"eval_column = available_eval_columns[3]\n",
"print(\"eval_column:\", eval_column)\n",
"dataset_max_len = 200"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"len: 4888014\n",
"len: 4740817\n",
"len: 4726966\n",
"data loaded\n",
"data sorted\n",
"top 100 in original_text column:\n",
"3657034 The latter would remain safely in the realm of...\n",
"66317 Ostensibly - particularly when it came to the ...\n",
"681907 The report, commissioned by the anti-mining gr...\n",
"4534661 IRAN held two sets of talks on Wednesday aimed...\n",
"1218458 It was overwhelming - even after repeated eati...\n",
" ... \n",
"3595778 The Missouri Supreme Court temporarily suspend...\n",
"3488527 Mr Shaw's decision to quit the Liberal Party o...\n",
"2245300 At the end of the week, the Moncler Gamme Roug...\n",
"632568 The transport firm revealed it has set aside m...\n",
"3916704 Dr Ali pointed to US President Barack Obama's ...\n",
"Name: original_text, Length: 100, dtype: object\n",
"bottom 100 in original_text column:\n",
"265842 Montenegro: Mladen Bozovic; Savo Pavicevic, St...\n",
"4130722 To join the Viking Cruises communities online,...\n",
"588808 Shakhter Karagandy (4-5-1): Mokin; Simcevic, M...\n",
"2209035 India (from): M Dhoni (capt & wkt) R Ashwin, S...\n",
"126085 Estonia (4-2-3-1): Pareiko; Jaager, Morozov, K...\n",
" ... \n",
"3118803 Tonga: Lilo - Vainikolo, Piutau, Piukala, Helu...\n",
"3176970 It identified them as Qa'ed al-Dahab, Ali Jall...\n",
"4138023 China got £37 billion ($60 billion,) the Phili...\n",
"2865645 Barcelona 4-0 Ajax, Milan 2-0 Celtic; Ajax 1-1...\n",
"1593711 Chelsea 1-2 Basel, Schalke 3-0 Steaua Buchares...\n",
"Name: original_text, Length: 100, dtype: object\n"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"data_path = f\"news.2013.en.trainlen.{dataset_max_len}.merged.csv\"\n",
"data = pd.read_csv(data_path)\n",
"# print len\n",
"print(\"len:\", len(data))\n",
"# filter out rows which 'lang' != True\n",
"data = data[data['lang'] == True]\n",
"print(\"len:\", len(data))\n",
"# weird is false\n",
"data = data[data['weird'] == False]\n",
"# print len\n",
"print(\"len:\", len(data))\n",
"print(\"data loaded\")\n",
"# sort by eval column\n",
"data.sort_values(eval_column, inplace=True)\n",
"print(\"data sorted\")\n",
"\n",
"# # print top 100 in original_text column\n",
"# print(\"top 100 in original_text column:\")\n",
"# print(data['original_text'][:100])\n",
"# # bottom 100 in original_text column\n",
"# print(\"bottom 100 in original_text column:\")\n",
"# print(data['original_text'][-100:])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the first 10 rows\n",
"print(f\"{data.head(10)=}\")\n",
"# last 10 rows\n",
"print(f\"{data.tail(10)=}\")\n",
"\n",
"rows = {\n",
" 0.1: data.iloc[len(data) // 1000 * 999 - 1000 : len(data) // 1000 * 999],\n",
" 1: data.iloc[len(data) // 100 * 99 - 1000 : len(data) // 100 * 99],\n",
" 5: data.iloc[len(data) // 100 * 95 - 1000 : len(data) // 100 * 95],\n",
" 10: data.iloc[len(data) // 100 * 90 - 1000 : len(data) // 100 * 90],\n",
" 15: data.iloc[len(data) // 100 * 85 - 1000 : len(data) // 100 * 85],\n",
" 20: data.iloc[len(data) // 100 * 80 - 1000 : len(data) // 100 * 80],\n",
" 25: data.iloc[len(data) // 100 * 75 - 1000 : len(data) // 100 * 75],\n",
" 30: data.iloc[len(data) // 100 * 70 - 1000 : len(data) // 100 * 70],\n",
" 35: data.iloc[len(data) // 100 * 65 - 1000 : len(data) // 100 * 65],\n",
" 40: data.iloc[len(data) // 100 * 60 - 1000 : len(data) // 100 * 60],\n",
" 45: data.iloc[len(data) // 100 * 55 - 1000 : len(data) // 100 * 55],\n",
" 50: data.iloc[len(data) // 100 * 50 - 1000 : len(data) // 100 * 50],\n",
" 55: data.iloc[len(data) // 100 * 45 - 1000 : len(data) // 100 * 45],\n",
" 60: data.iloc[len(data) // 100 * 40 - 1000 : len(data) // 100 * 40],\n",
" 65: data.iloc[len(data) // 100 * 35 - 1000 : len(data) // 100 * 35],\n",
" 70: data.iloc[len(data) // 100 * 30 - 1000 : len(data) // 100 * 30],\n",
" 75: data.iloc[len(data) // 100 * 25 - 1000 : len(data) // 100 * 25],\n",
" 80: data.iloc[len(data) // 100 * 20 - 1000 : len(data) // 100 * 20],\n",
" 85: data.iloc[len(data) // 100 * 15 - 1000 : len(data) // 100 * 15],\n",
" 90: data.iloc[len(data) // 100 * 10 - 1000 : len(data) // 100 * 10],\n",
" 95: data.iloc[len(data) // 100 * 5 - 1000 : len(data) // 100 * 5],\n",
" 99: data.iloc[len(data) // 100 * 1 - 1000 : len(data) // 100 * 1],\n",
" 99.9: data.iloc[len(data) // 1000 * 1 - 1000 : len(data) // 1000 * 1],\n",
"}\n",
"print(\"percentile rows split\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## load model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import ByT5Tokenizer, T5ForConditionalGeneration\n",
"from src.utils import levensthein_distance, print_avg_median_mode_error\n",
"from transformers import pipeline, logging\n",
"import torch\n",
"\n",
"logging.set_verbosity(logging.ERROR)\n",
"\n",
"\n",
"tokenizer = ByT5Tokenizer()\n",
"\n",
"from src.ByT5Dataset import ByT5CaesarRandomDataset, ByT5ConstEnigmaDataset\n",
"\n",
"model = T5ForConditionalGeneration.from_pretrained(\"./logs/slurm_17510/model\")\n",
"dataset_class = ByT5ConstEnigmaDataset # for 17510 model\n",
"# dataset_class = ByT5CaesarRandomDataset # for 16677 model\n",
"rows_datasets = {i: dataset_class(list(rows[i].text), dataset_max_len) for i in rows}\n",
"\n",
"print(\"rows_datasets created\")\n",
"\n",
"averages, medians, modes = {}, {}, {}\n",
"raw_data = {}\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"for i, test in rows_datasets.items():\n",
" print(f\"#############################################\")\n",
" print(f\"Testing {i}th percentile\")\n",
" error_counts = []\n",
" translate = pipeline(\"translation\", model=model, tokenizer=tokenizer, device=device)\n",
" for index in range(len(test)):\n",
" generated = translate(\n",
" test[index][\"input_text\"], max_length=(dataset_max_len + 1) * 2\n",
" )[0][\"translation_text\"]\n",
" error_counts.append(levensthein_distance(generated, test[index][\"output_text\"]))\n",
" if error_counts[-1] > 0:\n",
" print(f\"Example {index}, error count {error_counts[-1]}\")\n",
" print(\"In :\", test[index][\"input_text\"])\n",
" print(\"Gen:\", generated)\n",
" expected = test[index][\"output_text\"]\n",
" print(\"Exp:\", expected)\n",
" else:\n",
" print(f\"Example {index} OK\")\n",
" print(\"-----------------------\")\n",
"\n",
" avg, med, mode = print_avg_median_mode_error(error_counts)\n",
" averages[i] = avg\n",
" medians[i] = med\n",
" modes[i] = mode\n",
" raw_data[i] = error_counts\n",
"\n",
"print(\"Averages:\", averages)\n",
"print(\"Medians:\", medians)\n",
"print(\"Modes:\", modes)\n",
"print(\"Raw data:\", raw_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "enigmavenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
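The `rows` dictionary in the notebook above spells out every percentile slice by hand. As a sketch only (not part of this commit), the same slices could be produced with a small helper; the 1000-row window and the `len(data) // 100 * (100 - p)` boundary convention are taken from the notebook, while the function name and label list are illustrative.

```python
# Sketch only: a compact equivalent of the hand-written `rows` dict in the
# notebook above. `data` is the DataFrame already sorted by the eval column;
# the helper name and structure are illustrative, not part of this commit.
def percentile_slices(data, window=1000):
    labels = [0.1, 1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50,
              55, 60, 65, 70, 75, 80, 85, 90, 95, 99, 99.9]
    slices = {}
    n = len(data)
    for p in labels:
        if p in (0.1, 99.9):
            # fractional percentiles use thousandth-based boundaries, as in the notebook
            end = n // 1000 * round((100 - p) * 10)
        else:
            end = n // 100 * (100 - p)
        # keep the `window` rows ending at the boundary, matching the notebook's slices
        slices[p] = data.iloc[end - window : end]
    return slices

# rows = percentile_slices(data)  # would reproduce the notebook's `rows` dict
```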