Skip to content

Commit aa95e55

Browse files
committed
code release
1 parent dc564ca commit aa95e55

File tree

7 files changed

+1610
-1
lines changed

7 files changed

+1610
-1
lines changed

README.md

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,66 @@
11
# NeST
2-
[AAAI 2023] This is the code for our paper `Neighborhood-Regularized Self-Training for Learning with Few Labels'.
2+
3+
This is the code for the paper `[Neighborhood-regularized Self-training for Learning with Few Labels]()' (In Proceedings of AAAI 2023).
4+
5+
# Requirements
6+
```
7+
python 3.7
8+
transformers==4.2.0
9+
pytorch==1.8.0
10+
tqdm
11+
scikit-learn
12+
faiss-cpu==1.6.4
13+
```
14+
15+
# Datasets
16+
## Datasets
17+
18+
The datasets used in this study can be find at the following link
19+
20+
| Dataset | Task | Number of Classes | Number of Train/Test |
21+
|---------------- | -------------- |-------------- | -------------- |
22+
| [Elec](http://riejohnson.com/cnn_data.html) | Sentiment | 2 | 25K / 25K |
23+
| [AG News](https://huggingface.co/datasets/ag_news) | News Topic | 2 | 120K / 7.6K |
24+
| [NYT](https://github.com/yumeng5/CatE/tree/master/datasets/nyt) | News Topic | 4 | 30K / 3.0K |
25+
| [Chemprot](https://github.com/yueyu1030/COSINE/tree/main/data/chemprot) | Chemical Relation | 10 | 12K / 1.6K |
26+
27+
## Input Format
28+
"_id" stands for the class id, and "text" is the content of the document.
29+
```
30+
{"_id": 0, "text": "Congo Official: Rwanda Troops Attacking (AP) AP - A senior Congolese official said Tuesday his nation had been invaded by neighboring Rwanda, and U.N. officials said they were investigating claims of Rwandan forces clashing with militias in the east."}
31+
{"_id": 1, "text": "Stadler Leads First Tee Open (AP) AP - Craig Stadler moved into position for his second straight victory Saturday, shooting a 9-under 63 to take a one-stroke lead over Jay Haas after the second round of the inaugural First Tee Open."}
32+
{"_id": 2, "text": "Intel Shares Edge Lower After Downgrade NEW YORK (Reuters) - Intel Corp shares slipped on Tuesday after Credit Suisse First Boston downgraded the stock, forecasting that the computer chip maker will have difficulty outperforming the overall semiconductor sector next year."}
33+
{"_id": 3, "text": "Debating the Dinosaur Extinction At least 50 percent of the world's species, including the dinosaurs, went extinct 65 million years ago. While most scientists now blame this catastrophe on a large meteorite impact, others wonder if there is more to the story."}
34+
...
35+
}
36+
```
37+
38+
## Training
39+
Please use the commands in `commands` folder for experiments.
40+
Take AG News dataset as an example, `run_agnews.sh` is used for running the experiment for self-training.
41+
42+
43+
44+
# Hyperparameter Tuning
45+
Some Key Hyperparameters are listed as follows
46+
- `k`: The number of nearest neighbors used in KNN.
47+
- `learning_rate`: The learning rate for initialzation.
48+
- `learning_rate_st`: The learning rate for self-training.
49+
- `self_training_update_period`: The update period of self-training.
50+
- `self_training_weight`: The weight to balance labeled data and unlabeled data during self-training.
51+
- `num_unlabeled`: The number of unlabeled data in the beginning.
52+
- `num_unlabeled_add`: The number of added unlabeled data in each self-training round.
53+
54+
55+
# Citation
56+
57+
Please kindly cite the following paper if you are using our datasets/codebase. Thanks!
58+
59+
```
60+
@inproceedings{xu2023neighborhood,
61+
title = "Neighborhood-regularized Self-training for Learning with Few Labels",
62+
author = "Ran Xu and Yue Yu and Hejie Cui and Xuan Kan and Yanqiao Zhu and Joyce C. Ho and Chao Zhang and Carl Yang",
63+
booktitle = "Proceedings of the Thirty-Seventh AAAI Conference on Artificial Intelligence",
64+
year = "2023",
65+
}
66+
```

commands/run_agnews.sh

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
task=agnews
2+
gpu=0
3+
n_gpu=2
4+
5+
train_seed=42
6+
label_per_class=30
7+
model_type=roberta-base
8+
train_seed=${train_seed}
9+
method=train
10+
max_seq_len=128
11+
self_training_batch_size=32
12+
eval_batch_size=256
13+
dev_labels=100
14+
steps=100
15+
logging_steps=10
16+
st_logging_steps=20
17+
epochs=18
18+
k=5
19+
20+
lr=2e-5
21+
eps=0.9
22+
self_training_weight=1
23+
gce_loss_q=0.6
24+
lr_st=1e-5
25+
batch_size=8
26+
self_training_batch_size=32
27+
self_training_update_period=1000
28+
self_training_max_step=2000
29+
num_unlabeled=2000
30+
num_unlabeled_add=2000
31+
ssl_cmd="--learning_rate_st=${lr_st} --self_training_eps=${eps} --self_training_weight=${self_training_weight} --self_training_update_period=${self_training_update_period} --gce_loss_q=${gce_loss_q} --num_unlabeled=${num_unlabeled} --num_unlabeled_add=${num_unlabeled_add}"
32+
33+
34+
model_type=${model_type} #dmis-lab/biobert-v1.1 #"allenai/scibert_scivocab_uncased"
35+
output_dir=${task}/${label_per_class}/model #../datasets/${task}-${label_per_class}-10/model
36+
mkdir -p ${output_dir}
37+
echo ${method}
38+
mkdir -p ../datasets/${task}-${label_per_class}/cache
39+
# valid_${train_label}.json
40+
train_cmd="CUDA_VISIBLE_DEVICES=${gpu} python3 main.py --do_train --do_eval --task=${task} \
41+
--train_file=train.json --dev_file=valid.json --test_file=test.json \
42+
--unlabel_file=unlabeled.json \
43+
--data_dir=../datasets/${task}-${label_per_class} --train_seed=${train_seed} \
44+
--cache_dir="../datasets/${task}-${label_per_class}/cache" \
45+
--output_dir=${output_dir} \
46+
--logging_steps=${logging_steps} --self_train_logging_steps=${st_logging_steps} --dev_labels=${dev_labels} \
47+
--gpu=${gpu} --n_gpu=${n_gpu} --num_train_epochs=${epochs} --weight_decay=1e-8 \
48+
--learning_rate=${lr} \
49+
--method=${method} --batch_size=${batch_size} --eval_batch_size=${eval_batch_size} \
50+
--self_training_batch_size=${self_training_batch_size} \
51+
--max_seq_len=${max_seq_len} --auto_load=1 \
52+
--max_steps=${steps} --model_type=${model_type} \
53+
--self_training_max_step=${self_training_max_step} \
54+
--sample_labels=${train_label} ${ssl_cmd} --k=${k} --label_per_class=${label_per_class}"
55+
echo $train_cmd
56+
eval $train_cmd

eval.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import faiss
2+
import numpy as np
3+
import os
4+
5+
def inference_knn(train_pred, train_feat, train_label, unlabeled_pred, unlabeled_feat, unlabeled_label, unlabeled_pseudo,k, gamma = 0.1, beta=0.1, prev_val = None):
6+
train_pred = np.array(train_pred)
7+
unlabeled_pred = np.array(unlabeled_pred)
8+
d = train_feat.shape[-1]
9+
index = faiss.IndexFlatL2(d)
10+
index.add(train_feat)
11+
D, I = index.search(unlabeled_feat, k)
12+
unlabeled_pred = np.expand_dims(unlabeled_pred, axis = 1)
13+
# [#unlabel, 1]
14+
# train_pred[I] ---> [#unlabel, k]
15+
# print(unlabeled_pred.shape)
16+
score = np.log((1e-10 + train_pred[I])/ (1e-10 + unlabeled_pred)) * train_pred[I]
17+
# print(score.shape)
18+
mean_kl = np.mean(np.sum(score, axis = -1), axis = -1)
19+
20+
# mean_mse = np.mean((train_pred[I] - unlabeled_pred)**2, axis = -1)
21+
# train pred (n_samples, n_class)
22+
# train pred[I] (n_samples, n_neighbor, n_class)
23+
var_mse = np.var(train_pred[I], axis = -1)
24+
25+
if prev_val is not None:
26+
current_val = prev_val * gamma + (1- gamma) * (mean_kl + var_mse * beta)
27+
else:
28+
current_val = mean_kl + var_mse * beta
29+
idx = np.argsort(current_val)
30+
31+
return idx
32+
33+
def inference_conf(train_pred, train_feat, train_label, unlabeled_pred, unlabeled_feat, unlabeled_label, unlabeled_pseudo, gamma = 0.1, prev_val = None):
34+
train_pred = np.array(train_pred)
35+
unlabeled_pred = np.array(unlabeled_pred)
36+
current_val = -np.max(unlabeled_pred, axis = -1)
37+
if prev_val is not None:
38+
current_val = prev_val * gamma + (1- gamma) * (current_val)
39+
else:
40+
current_val = current_val
41+
idx = np.argsort(current_val)
42+
43+
return idx
44+
45+
def inference_uncertainty(unlabeled_label, unlabeled_pseudo, mutual_info, gamma = 0.1, prev_val = None):
46+
if prev_val is not None:
47+
current_val = prev_val * gamma + (1- gamma) * (mutual_info)
48+
else:
49+
current_val = mutual_info
50+
idx = np.argsort(current_val)
51+
52+
return idx
53+
54+
def save_data(train_pred, train_feat, train_label, unlabeled_pred, unlabeled_feat, unlabeled_label, unlabeled_pseudo, dataset = 'agnews', n_labels = 10, n_iter = 0):
55+
if n_iter == 0:
56+
path = f"{dataset}/{n_labels}"
57+
58+
else:
59+
path = f"{dataset}/{n_labels}_{n_iter}"
60+
os.makedirs(path, exist_ok = True)
61+
62+
with open(f"{path}/train_pred.npy", 'wb') as f:
63+
np.save(f, train_pred)
64+
65+
with open(f"{path}/train_feat.npy", 'wb') as f:
66+
np.save(f, train_feat)
67+
68+
with open(f"{path}/train_label.npy", 'wb') as f:
69+
np.save(f, train_label)
70+
71+
with open(f"{path}/unlabeled_pred.npy", 'wb') as f:
72+
np.save(f, unlabeled_pred)
73+
74+
with open(f"{path}/unlabeled_feat.npy", 'wb') as f:
75+
np.save(f, unlabeled_feat)
76+
77+
with open(f"{path}/unlabeled_label.npy", 'wb') as f:
78+
np.save(f, unlabeled_label)
79+
80+
with open(f"{path}/unlabeled_pseudo.npy", 'wb') as f:
81+
np.save(f, unlabeled_pseudo)
82+
83+
84+
85+
86+
def load_pred_data(dataset = 'agnews', n_labels = 10, n_iter = 0):
87+
# os.makedirs(f"{dataset}/{n_labels}", exist_ok = True)
88+
# with open(f"{dataset}/{n_labels}/train_pred.npy", 'rb') as f:
89+
if n_iter == 0:
90+
path = f"{dataset}/{n_labels}"
91+
else:
92+
path = f"{dataset}/{n_labels}_{n_iter}"
93+
train_pred = np.load(f"{path}/train_pred.npy")
94+
95+
train_feat = np.load(f"{path}/train_feat.npy")
96+
97+
train_label = np.load(f"{path}/train_label.npy")
98+
99+
unlabeled_pred = np.load(f"{path}/unlabeled_pred.npy")
100+
101+
unlabeled_feat = np.load(f"{path}/unlabeled_feat.npy")
102+
103+
unlabeled_label = np.load(f"{path}/unlabeled_label.npy")
104+
105+
unlabeled_pseudo = np.load(f"{path}/unlabeled_pseudo.npy")
106+
107+
return train_pred, train_feat, train_label, unlabeled_pred, unlabeled_feat, unlabeled_label, unlabeled_pseudo

0 commit comments

Comments
 (0)