Fix folds in metrics
jilljenn committed Apr 5, 2024
1 parent 0b4b84d commit 4788ba9
Showing 4 changed files with 525 additions and 16 deletions.
25 changes: 10 additions & 15 deletions eval_metrics.py
@@ -51,7 +51,7 @@ def all_metrics(results, test):
print('This is not the right fold', len(y), len(test))
sys.exit(0)

for user, pred, true in zip(test['user_id'], y_pred, y):
for user, pred, true in zip(test['user'], y_pred, y):
predictions_per_user[user]['pred'].append(pred)
predictions_per_user[user]['y'].append(true)

@@ -135,6 +135,7 @@ def all_metrics(results, test):
val = 0
for subgroup, auc, nb in zip(attr_ids, metrics_per_sensitive_attr['auc'], nb_samples):
candidates[subgroup] = (-auc, -nb)
print(len(candidates), 'groups and', test[SENSITIVE_ATTR].nunique(), 'schools in test')

x = []
nb = []
@@ -145,6 +146,9 @@ def all_metrics(results, test):
nb.append(-yi)
val += 1
plt.stem(x, nb, use_line_collection=True)
plt.xlabel('AUC value')
plt.ylabel('Number of samples in group')
plt.title('For each group, number of samples per AUC value')
plt.show()

# Display ids of the subgroups (sensitive attribute) that have the lowest/highest AUC
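Note: the three lines added above label a stem plot of group sizes against per-group AUC. As a rough, self-contained illustration of the kind of diagnostic this produces (not the repository's actual code; the column names y_true, y_score and the school_id grouping key are placeholders):

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import roc_auc_score

def plot_group_auc(df: pd.DataFrame, attr: str = 'school_id') -> None:
    # Compute AUC per subgroup of the sensitive attribute, skipping one-class groups.
    aucs, sizes = [], []
    for _, group in df.groupby(attr):
        if group['y_true'].nunique() < 2:
            continue  # AUC is undefined when a group contains a single class
        aucs.append(roc_auc_score(group['y_true'], group['y_score']))
        sizes.append(len(group))
    plt.stem(aucs, sizes)
    plt.xlabel('AUC value')
    plt.ylabel('Number of samples in group')
    plt.title('For each group, number of samples per AUC value')
    plt.show()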
@@ -173,29 +177,20 @@ def all_metrics(results, test):
plt.show()

if __name__ == '__main__':
os.chdir('data/assistments09')
os.chdir('data/assistments2009full')
# os.chdir('data/fr_en')

# indices = np.load('folds/weak278607fold0.npy')
# indices = np.load('folds/278607fold0.npy')
# indices = np.load('folds/50weak278607fold0.npy')
# indices = np.load('folds/weak926646fold0.npy')
# indices = np.load('folds/1199731fold0.npy')
# indices = np.load('folds/50weak341791fold0.npy')
indices = np.load('folds/50weak341791fold0.npy')
print(len(indices))

df = pd.read_csv('needed.csv')
test = df.iloc[indices]
df = pd.read_csv('data.csv')

# r = re.compile(r'results-(.*).json')

# ndcg_ = defaultdict(list)
for filename in sorted(glob.glob('results*2020*'))[::-1][:1]:
for filename in sorted(glob.glob('results*2024*'))[::-1][:1]:

print(filename)

with open(filename) as f:
results = json.load(f)

all_metrics(results, test)
i_test = results['predictions'][0]['i_test']
all_metrics(results, df.iloc[i_test])
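Note: with this change, the evaluation script recovers the test split from the indices stored inside the results JSON instead of loading a fold file from disk. A minimal sketch of that flow, assuming the results file has the shape written by lr.py below (file names here are illustrative, not the actual ones):

import json
import pandas as pd

df = pd.read_csv('data.csv')
with open('results-example.json') as f:  # placeholder file name
    results = json.load(f)

i_test = results['predictions'][0]['i_test']  # indices saved at training time
test = df.iloc[i_test]
# all_metrics(results, test)  # metrics are computed on exactly the stored split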
1 change: 1 addition & 0 deletions lr.py
@@ -71,6 +71,7 @@
if dataset == 'Test':
predictions.append({
'fold': i,
'i_test': i_test.tolist(),
'pred': y_pred.tolist(),
'y': y.tolist()
})
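Note: the single added line stores each fold's test indices next to the predictions, which is what lets eval_metrics.py above re-slice the data identically. A minimal sketch of that pattern with scikit-learn's KFold (model, variable, and file names are assumptions, not the repository's exact code):

import json
import numpy as np
from sklearn.model_selection import KFold

X = np.random.rand(100, 5)          # placeholder features
y = np.random.randint(0, 2, 100)    # placeholder binary labels

predictions = []
for i, (i_train, i_test) in enumerate(KFold(n_splits=5, shuffle=True, random_state=0).split(X)):
    # ... fit a model on X[i_train], y[i_train], then score X[i_test] ...
    y_pred = np.random.rand(len(i_test))  # stand-in for real model scores
    predictions.append({
        'fold': i,
        'i_test': i_test.tolist(),  # the addition made in this commit
        'pred': y_pred.tolist(),
        'y': y[i_test].tolist(),
    })

with open('results-example.json', 'w') as f:
    json.dump({'predictions': predictions}, f)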
