Skip to content

Commit

Permalink
script for inference at checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
JanProvaznik committed Apr 28, 2024
1 parent 4048c05 commit ee672c8
Showing 1 changed file with 41 additions and 31 deletions.
72 changes: 41 additions & 31 deletions evaluation_batchedgpuevaluate_other_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# evaluate performance along various axes of sentence complexity"
"# inference "
]
},
{
Expand Down Expand Up @@ -79,6 +79,7 @@
"from src.ByT5Dataset import ByT5ConstEnigmaDataset, ByT5CaesarRandomDataset, ByT5NoisyVignere2Dataset, ByT5NoisyConstEnigmaDataset, ByT5NoisyVignere3Dataset\n",
"from src.evaluation import Model\n",
"from src.ByT5Dataset import ByT5Dataset\n",
"import argparse\n",
"\n",
"models = {\n",
" 'caesar': Model(ByT5CaesarRandomDataset, 'caesar', 'en', 16677),\n",
Expand Down Expand Up @@ -125,44 +126,53 @@
" 'cs_noisevignere3_3000': Model(ByT5NoisyVignere3Dataset, 'cs_noisevignere3_3000', 'cs', 22989 , True, 3000, .15), # 23177\n",
" 'cs_noisevignere3_3500': Model(ByT5NoisyVignere3Dataset, 'cs_noisevignere3_3500', 'cs', 22989 , True, 3500, .15), # 23178\n",
" 'cs_noisevignere3_4000': Model(ByT5NoisyVignere3Dataset, 'cs_noisevignere3_4000', 'cs', 22989 , True, 4000, .15), # 23179\n",
" # enigmas\n",
"\n",
" # de enigma 23190\n",
" 'de_noiseconstenigma_500': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_500', 'de', 23190 , True, 500, .15), # \n",
" 'de_noiseconstenigma_1000': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_1000', 'de', 23190 , True, 1000, .15), #\n",
" 'de_noiseconstenigma_1500': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_1500', 'de', 23190 , True, 1500, .15), #\n",
" 'de_noiseconstenigma_2000': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_2000', 'de', 23190 , True, 2000, .15), #\n",
" 'de_noiseconstenigma_2500': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_2500', 'de', 23190 , True, 2500, .15), #\n",
" 'de_noiseconstenigma_3000': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_3000', 'de', 23190 , True, 3000, .15), #\n",
" 'de_noiseconstenigma_3500': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_3500', 'de', 23190 , True, 3500, .15), #\n",
" 'de_noiseconstenigma_4000': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_4000', 'de', 23190 , True, 4000, .15), #\n",
"\n",
"\n",
" 'de_noiseconstenigma_500' : Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_500' , 'de', 23190 , True, 500 , .15), # 23639\n",
" 'de_noiseconstenigma_1000': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_1000', 'de', 23190 , True, 1000, .15), # 23640\n",
" 'de_noiseconstenigma_1500': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_1500', 'de', 23190 , True, 1500, .15), # 23641\n",
" 'de_noiseconstenigma_2000': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_2000', 'de', 23190 , True, 2000, .15), # 23642\n",
" 'de_noiseconstenigma_2500': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_2500', 'de', 23190 , True, 2500, .15), # 23643\n",
" 'de_noiseconstenigma_3000': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_3000', 'de', 23190 , True, 3000, .15), # 23644\n",
" 'de_noiseconstenigma_3500': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_3500', 'de', 23190 , True, 3500, .15), # 23645\n",
" 'de_noiseconstenigma_4000': Model(ByT5NoisyConstEnigmaDataset, 'de_noiseconstenigma_4000', 'de', 23190 , True, 4000, .15), # 23646\n",
"\n",
" # cs enigma 23167 \n",
" 'cs_noiseconstenigma_500': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_500', 'cs', 23167 , True, 500, .15), #\n",
" 'cs_noiseconstenigma_1000': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_1000', 'cs', 23167 , True, 1000, .15), #\n",
" 'cs_noiseconstenigma_1500': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_1500', 'cs', 23167 , True, 1500, .15), #\n",
" 'cs_noiseconstenigma_2000': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_2000', 'cs', 23167 , True, 2000, .15), #\n",
" 'cs_noiseconstenigma_2500': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_2500', 'cs', 23167 , True, 2500, .15), #\n",
" 'cs_noiseconstenigma_3000': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_3000', 'cs', 23167 , True, 3000, .15), #\n",
" 'cs_noiseconstenigma_3500': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_3500', 'cs', 23167 , True, 3500, .15), #\n",
" 'cs_noiseconstenigma_4000': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_4000', 'cs', 23167 , True, 4000, .15), #\n",
" 'cs_noiseconstenigma_500': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_500', 'cs', 23167 , True, 500, .15), # 23647\n",
" 'cs_noiseconstenigma_1000': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_1000', 'cs', 23167 , True, 1000, .15), # 23648\n",
" 'cs_noiseconstenigma_1500': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_1500', 'cs', 23167 , True, 1500, .15), # 23649\n",
" 'cs_noiseconstenigma_2000': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_2000', 'cs', 23167 , True, 2000, .15), # 23650\n",
" 'cs_noiseconstenigma_2500': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_2500', 'cs', 23167 , True, 2500, .15), # 23651\n",
" 'cs_noiseconstenigma_3000': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_3000', 'cs', 23167 , True, 3000, .15), # 23652\n",
" 'cs_noiseconstenigma_3500': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_3500', 'cs', 23167 , True, 3500, .15), # 23653\n",
" 'cs_noiseconstenigma_4000': Model(ByT5NoisyConstEnigmaDataset, 'cs_noiseconstenigma_4000', 'cs', 23167 , True, 4000, .15), # 23654\n",
"\n",
" # en enigma 23609\n",
" 'en_noiseconstenigma_500': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_500', 'en', 23609 , True, 500, .15), #\n",
" 'en_noiseconstenigma_1000': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_1000', 'en', 23609 , True, 1000, .15), #\n",
" 'en_noiseconstenigma_1500': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_1500', 'en', 23609 , True, 1500, .15), #\n",
" 'en_noiseconstenigma_2000': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_2000', 'en', 23609 , True, 2000, .15), #\n",
" 'en_noiseconstenigma_2500': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_2500', 'en', 23609 , True, 2500, .15), #\n",
" 'en_noiseconstenigma_3000': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_3000', 'en', 23609 , True, 3000, .15), #\n",
" 'en_noiseconstenigma_3500': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_3500', 'en', 23609 , True, 3500, .15), #\n",
" 'en_noiseconstenigma_4000': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_4000', 'en', 23609 , True, 4000, .15), #\n",
" \n",
" 'en_noiseconstenigma_500': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_500', 'en', 23609 , True, 500, .15), # 24303\n",
" 'en_noiseconstenigma_1000': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_1000', 'en', 23609 , True, 1000, .15), # 24304\n",
" 'en_noiseconstenigma_1500': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_1500', 'en', 23609 , True, 1500, .15), # 24306\n",
" 'en_noiseconstenigma_2000': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_2000', 'en', 23609 , True, 2000, .15), # 24307\n",
" 'en_noiseconstenigma_2500': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_2500', 'en', 23609 , True, 2500, .15), # 24309\n",
" 'en_noiseconstenigma_3000': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_3000', 'en', 23609 , True, 3000, .15), # 24310\n",
" 'en_noiseconstenigma_3500': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_3500', 'en', 23609 , True, 3500, .15), # 24311\n",
" 'en_noiseconstenigma_4000': Model(ByT5NoisyConstEnigmaDataset, 'en_noiseconstenigma_4000', 'en', 23609 , True, 4000, .15), # 24312\n",
"\n",
"\n",
"}\n",
"\n",
"# evaluated_name = 'en_noisevignere_checkpoint-5000'\n",
"evaluated_name = 'cs_noisevignere3_4000'\n",
"# evaluated_name = 'cs_noisevignere3_4000'\n",
"# Create an argument parser\n",
"parser = argparse.ArgumentParser()\n",
"parser.add_argument('--eval_model', type=str, help='Name of the evaluated model')\n",
"\n",
"args, _ = parser.parse_known_args()\n",
"\n",
"# Get the evaluated name from script arguments\n",
"evaluated_name = args.eval_model\n",
"\n",
"\n",
"model_metadata = models[evaluated_name]\n",
"\n",
"data_path = f'news.2013.{model_metadata.language}.trainlen.200.evaluation.100000.csv'\n",
Expand Down Expand Up @@ -277,7 +287,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "enigmavenv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -295,5 +305,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

0 comments on commit ee672c8

Please sign in to comment.