Commit ce0826b

Merge pull request #7 from shimo-lab/peform_pred
Add code for model performance prediction
2 parents b4190ca + c0799f9 commit ce0826b

15 files changed

Lines changed: 1676 additions & 0 deletions

resampling-texts/README.md

Lines changed: 7 additions & 0 deletions
@@ -32,7 +32,14 @@ Using only texts selected through LS sampling allows new models to be efficientl
 <img src="figures/fig3.png" alt="fig3" style="width:90%">
 </p>
 
+### Prediction of Model's Performance (Figure 4)
 
+Using model coordinates from unique texts, we predict the average performance across six downstream tasks with ridge regression.
+See [`code_for_prediction/`](./code_for_prediction/) for details.
+
+<p align="center">
+<img src="figures/fig4.png" alt="fig4" style="width:50%">
+</p>
 
 ## 🦉 Misc.
 
Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
+### Generated by gibo (https://github.com/simonwhitaker/gibo)
+### https://raw.github.com/github/gitignore/d0b80a469983a7beece8fa1f5c48a8242318b531/Global/Vim.gitignore
+
+# Swap
+[._]*.s[a-v][a-z]
+!*.svg # comment out if you don't need vector files
+[._]*.sw[a-p]
+[._]s[a-rt-v][a-z]
+[._]ss[a-gi-z]
+[._]sw[a-p]
+
+# Session
+Session.vim
+Sessionx.vim
+
+# Temporary
+.netrwhist
+*~
+# Auto-generated tag files
+tags
+# Persistent undo
+[._]*.un~
+
+
+### https://raw.github.com/github/gitignore/d0b80a469983a7beece8fa1f5c48a8242318b531/Python.gitignore
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+
+### https://raw.github.com/github/gitignore/d0b80a469983a7beece8fa1f5c48a8242318b531/C++.gitignore
+
+# Prerequisites
+*.d
+
+# Compiled Object files
+*.slo
+*.lo
+*.o
+*.obj
+
+# Precompiled Headers
+*.gch
+*.pch
+
+# Compiled Dynamic libraries
+*.so
+*.dylib
+*.dll
+
+# Fortran module files
+*.mod
+*.smod
+
+# Compiled Static libraries
+*.lai
+*.la
+*.a
+*.lib
+
+# Executables
+*.exe
+*.out
+*.app
+
+# Added manually
+.DS_Store
+output/
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+3.13
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+## Model Performance Prediction (Section 4.3 and Appendix F)
+
+We use [`uv`](https://docs.astral.sh/uv/) for the experiment environment in this directory. See the official site for installation instructions.
+
+### Setup
+
+With `uv` available, install the required packages:
+
+```bash
+$ uv sync
+```
+
+The experiments use [`../data/modeldata_1018.pkl`](../data/modeldata_1018.pkl) and [`../data/uniq-idx-weight/`](../data/uniq-idx-weight/). See [`../README.md`](../README.md) for details on these data.
+
+### Data Preparation
+
+Prepare training and prediction splits for ridge regression with `GroupKFold`:
+
+```bash
+$ uv run src/split_data.py
+```
+
+Five-fold splits with five seeds are saved to `output/split_data/groupkfold/`.
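
As a rough illustration of what a grouped five-fold split repeated over five seeds involves, here is a minimal standard-library sketch. It is not the code in `src/split_data.py`: the helper `group_kfold_splits`, the `family_*` group labels, and the way the seed shuffles groups are all assumptions made for this example. The only property it aims to show is the one `GroupKFold` guarantees, namely that no group appears on both the training and the prediction side of a fold.

```python
import random


def group_kfold_splits(groups, n_splits=5, seed=0):
    """Yield (train_idx, test_idx) pairs in which no group spans both sides.

    Hypothetical helper for illustration; the repository uses scikit-learn's
    GroupKFold, whose fold assignment differs from this round-robin scheme.
    """
    unique_groups = sorted(set(groups))
    random.Random(seed).shuffle(unique_groups)
    # Deal the shuffled groups into folds round-robin.
    fold_of_group = {g: i % n_splits for i, g in enumerate(unique_groups)}
    for fold in range(n_splits):
        test_idx = [i for i, g in enumerate(groups) if fold_of_group[g] == fold]
        train_idx = [i for i, g in enumerate(groups) if fold_of_group[g] != fold]
        yield train_idx, test_idx


# Hypothetical grouping: 20 models drawn from 10 groups.
groups = [f"family_{i % 10}" for i in range(20)]

# One entry per (seed, fold) pair: 5 seeds x 5 folds.
splits = {
    (seed, fold): (train_idx, test_idx)
    for seed in range(5)
    for fold, (train_idx, test_idx) in enumerate(
        group_kfold_splits(groups, n_splits=5, seed=seed)
    )
}
print(len(splits))
```

Splitting by group rather than by row keeps related models out of the training set when one of them is being predicted, so the evaluation is not inflated by near-duplicates.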
+
+### Train and Predict with Ridge Regression
+
+Train the ridge regression models and generate predictions (**this step takes about half a day!**):
+
+```bash
+$ uv run src/train_and_pred.py
+```
+
+Predictions for each method (Uniform, KL, LS) are saved to `output/train_and_pred/groupkfold/`.
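
For readers unfamiliar with ridge regression, the following sketch shows the fit-then-predict step on synthetic data. It is an illustration only, not the contents of `src/train_and_pred.py`: the data are random, the function `ridge_fit` and the value `alpha=1.0` are assumptions, and the closed-form NumPy solution is used here so the example has no scikit-learn dependency. `sklearn.linear_model.Ridge` minimizes the same objective, with an unpenalized intercept.

```python
import numpy as np


def ridge_fit(X, y, alpha=1.0):
    """Closed-form ridge: minimize ||Xw + b - y||^2 + alpha * ||w||^2.

    Centering X and y leaves the intercept b unpenalized.
    """
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xc, yc = X - x_mean, y - y_mean
    w = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(X.shape[1]), Xc.T @ yc)
    b = y_mean - x_mean @ w
    return w, b


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 16))               # synthetic stand-in for model coordinates
true_w = rng.normal(size=16)
y = X @ true_w + 0.1 * rng.normal(size=100)  # synthetic stand-in for an average task score

# Hypothetical single split: first 80 rows for training, last 20 for prediction.
train, test = np.arange(80), np.arange(80, 100)
w, b = ridge_fit(X[train], y[train], alpha=1.0)
pred = X[test] @ w + b

# Correlation between predicted and held-out scores.
corr = np.corrcoef(pred, y[test])[0, 1]
print(f"correlation on held-out rows: {corr:.3f}")
```

In a grouped cross-validation setting, the same fit and predict steps would be repeated once per fold and seed, with the held-out rows chosen by group.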
+
+### Plot Figures
+
+Draw Figure 4 from the predictions:
+
+```bash
+$ uv run src/figure4.py
+```
+
+Figure 4 is saved to `output/images/`.
+
+<p align="center">
+<img src="figures/fig4.png" alt="fig4" style="width:50%">
+</p>
+
+Draw Figure 6 from the predictions:
+
+```bash
+$ uv run src/figure6_and_table2.py
+```
+
+<p align="center">
+<img src="figures/fig6.png" alt="fig6" style="width:90%">
+</p>
+
+Figure 6 is saved to `output/images/`. This script also saves the results for Table 2 to `output/summary/`.
Binary file (107 KB)
Binary file (301 KB)
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+[project]
+name = "code-for-prediction"
+version = "0.1.0"
+readme = "README.md"
+requires-python = ">=3.13"
+dependencies = [
+    "numpy",
+    "pandas",
+    "scikit-learn",
+    "fire",
+    "matplotlib",
+    "tqdm",
+    "jinja2>=3.1.6",
+]
