|
| 1 | +import faiss |
| 2 | +import numpy as np |
| 3 | +import os |
| 4 | + |
| 5 | +def inference_knn(train_pred, train_feat, train_label, unlabeled_pred, unlabeled_feat, unlabeled_label, unlabeled_pseudo,k, gamma = 0.1, beta=0.1, prev_val = None): |
| 6 | + train_pred = np.array(train_pred) |
| 7 | + unlabeled_pred = np.array(unlabeled_pred) |
| 8 | + d = train_feat.shape[-1] |
| 9 | + index = faiss.IndexFlatL2(d) |
| 10 | + index.add(train_feat) |
| 11 | + D, I = index.search(unlabeled_feat, k) |
| 12 | + unlabeled_pred = np.expand_dims(unlabeled_pred, axis = 1) |
| 13 | + # [#unlabel, 1] |
| 14 | + # train_pred[I] ---> [#unlabel, k] |
| 15 | + # print(unlabeled_pred.shape) |
| 16 | + score = np.log((1e-10 + train_pred[I])/ (1e-10 + unlabeled_pred)) * train_pred[I] |
| 17 | + # print(score.shape) |
| 18 | + mean_kl = np.mean(np.sum(score, axis = -1), axis = -1) |
| 19 | + |
| 20 | + # mean_mse = np.mean((train_pred[I] - unlabeled_pred)**2, axis = -1) |
| 21 | + # train pred (n_samples, n_class) |
| 22 | + # train pred[I] (n_samples, n_neighbor, n_class) |
| 23 | + var_mse = np.var(train_pred[I], axis = -1) |
| 24 | + |
| 25 | + if prev_val is not None: |
| 26 | + current_val = prev_val * gamma + (1- gamma) * (mean_kl + var_mse * beta) |
| 27 | + else: |
| 28 | + current_val = mean_kl + var_mse * beta |
| 29 | + idx = np.argsort(current_val) |
| 30 | + |
| 31 | + return idx |
| 32 | + |
| 33 | +def inference_conf(train_pred, train_feat, train_label, unlabeled_pred, unlabeled_feat, unlabeled_label, unlabeled_pseudo, gamma = 0.1, prev_val = None): |
| 34 | + train_pred = np.array(train_pred) |
| 35 | + unlabeled_pred = np.array(unlabeled_pred) |
| 36 | + current_val = -np.max(unlabeled_pred, axis = -1) |
| 37 | + if prev_val is not None: |
| 38 | + current_val = prev_val * gamma + (1- gamma) * (current_val) |
| 39 | + else: |
| 40 | + current_val = current_val |
| 41 | + idx = np.argsort(current_val) |
| 42 | + |
| 43 | + return idx |
| 44 | + |
| 45 | +def inference_uncertainty(unlabeled_label, unlabeled_pseudo, mutual_info, gamma = 0.1, prev_val = None): |
| 46 | + if prev_val is not None: |
| 47 | + current_val = prev_val * gamma + (1- gamma) * (mutual_info) |
| 48 | + else: |
| 49 | + current_val = mutual_info |
| 50 | + idx = np.argsort(current_val) |
| 51 | + |
| 52 | + return idx |
| 53 | + |
| 54 | +def save_data(train_pred, train_feat, train_label, unlabeled_pred, unlabeled_feat, unlabeled_label, unlabeled_pseudo, dataset = 'agnews', n_labels = 10, n_iter = 0): |
| 55 | + if n_iter == 0: |
| 56 | + path = f"{dataset}/{n_labels}" |
| 57 | + |
| 58 | + else: |
| 59 | + path = f"{dataset}/{n_labels}_{n_iter}" |
| 60 | + os.makedirs(path, exist_ok = True) |
| 61 | + |
| 62 | + with open(f"{path}/train_pred.npy", 'wb') as f: |
| 63 | + np.save(f, train_pred) |
| 64 | + |
| 65 | + with open(f"{path}/train_feat.npy", 'wb') as f: |
| 66 | + np.save(f, train_feat) |
| 67 | + |
| 68 | + with open(f"{path}/train_label.npy", 'wb') as f: |
| 69 | + np.save(f, train_label) |
| 70 | + |
| 71 | + with open(f"{path}/unlabeled_pred.npy", 'wb') as f: |
| 72 | + np.save(f, unlabeled_pred) |
| 73 | + |
| 74 | + with open(f"{path}/unlabeled_feat.npy", 'wb') as f: |
| 75 | + np.save(f, unlabeled_feat) |
| 76 | + |
| 77 | + with open(f"{path}/unlabeled_label.npy", 'wb') as f: |
| 78 | + np.save(f, unlabeled_label) |
| 79 | + |
| 80 | + with open(f"{path}/unlabeled_pseudo.npy", 'wb') as f: |
| 81 | + np.save(f, unlabeled_pseudo) |
| 82 | + |
| 83 | + |
| 84 | + |
| 85 | + |
| 86 | +def load_pred_data(dataset = 'agnews', n_labels = 10, n_iter = 0): |
| 87 | + # os.makedirs(f"{dataset}/{n_labels}", exist_ok = True) |
| 88 | + # with open(f"{dataset}/{n_labels}/train_pred.npy", 'rb') as f: |
| 89 | + if n_iter == 0: |
| 90 | + path = f"{dataset}/{n_labels}" |
| 91 | + else: |
| 92 | + path = f"{dataset}/{n_labels}_{n_iter}" |
| 93 | + train_pred = np.load(f"{path}/train_pred.npy") |
| 94 | + |
| 95 | + train_feat = np.load(f"{path}/train_feat.npy") |
| 96 | + |
| 97 | + train_label = np.load(f"{path}/train_label.npy") |
| 98 | + |
| 99 | + unlabeled_pred = np.load(f"{path}/unlabeled_pred.npy") |
| 100 | + |
| 101 | + unlabeled_feat = np.load(f"{path}/unlabeled_feat.npy") |
| 102 | + |
| 103 | + unlabeled_label = np.load(f"{path}/unlabeled_label.npy") |
| 104 | + |
| 105 | + unlabeled_pseudo = np.load(f"{path}/unlabeled_pseudo.npy") |
| 106 | + |
| 107 | + return train_pred, train_feat, train_label, unlabeled_pred, unlabeled_feat, unlabeled_label, unlabeled_pseudo |
0 commit comments