-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathood_testing.py
More file actions
161 lines (131 loc) · 6.13 KB
/
ood_testing.py
File metadata and controls
161 lines (131 loc) · 6.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset
import pickle
from bll_pipeline import bll_experiment
from complexnetwork.CDSCNN import TrueCDSCNN
from complexnetwork.complexCNN import ComplexCNN_RLL
from deepensemblepipeline import ensemble_experiment
from duq_pipeline import duq_experiment
from heatmap import count_params
from random_ood_data import RandomOODData
from realnetwork.amc_cnn import AMC_CNN
# Label assigned to every out-of-distribution sample. In-distribution labels are
# re-indexed to 0..len(id_mods)-1, so this value cannot collide with them.
OOD_LABEL = 100


def load_rml_dict(path="./data/RML2016.10a_dict.pkl"):
    """Load the RadioML pickle: a dict keyed by ``(modulation, snr)`` tuples.

    Args:
        path: Location of the pickle file.

    Returns:
        The unpickled dict. Values are indexed with ``.shape[0]`` as the sample
        count below; presumably arrays of shape [N, 2, 128] — TODO confirm.
    """
    # SECURITY: pickle can execute arbitrary code on load; only use trusted files.
    with open(path, "rb") as f:
        # latin1 lets Python 3 read the numpy byte buffers of a Python-2 pickle.
        return pickle.load(f, encoding="latin1")


def stack_samples(data, keys_and_labels):
    """Stack ``data[key]`` arrays in the given order with a matching label vector.

    Args:
        data: Mapping from key to an array whose first axis indexes samples.
        keys_and_labels: Iterable of ``(key, label)`` pairs. Order is preserved,
            which matters because the seeded split permutes sample positions.

    Returns:
        ``(X, Y)``: ``np.vstack`` of the arrays, and an int64 label array with
        one entry per sample.

    Raises:
        ValueError: from ``np.vstack`` when ``keys_and_labels`` is empty.
    """
    xs, ys = [], []
    for key, label in keys_and_labels:
        samples = data[key]
        xs.append(samples)
        ys.append(np.full(samples.shape[0], label, dtype=np.int64))
    return np.vstack(xs), np.concatenate(ys)


def split_indices(n, train_frac=0.6, valid_frac=0.2, rng=None):
    """Shuffle ``range(n)`` and split it into train/valid/test index arrays.

    Args:
        n: Number of samples.
        train_frac: Fraction for training (truncated with ``int``).
        valid_frac: Fraction for validation (truncated with ``int``).
        rng: Object with a ``shuffle`` method. Defaults to NumPy's global RNG,
            so ``np.random.seed`` controls the split.

    Returns:
        ``(train_idx, valid_idx, test_idx)``. The test split gets the remainder.
    """
    rng = np.random if rng is None else rng
    indices = np.arange(n)
    rng.shuffle(indices)
    n_train = int(train_frac * n)
    n_valid = int(valid_frac * n)
    return (indices[:n_train],
            indices[n_train:n_train + n_valid],
            indices[n_train + n_valid:])


def make_loader(x, y, batch_size=100, shuffle=False):
    """Wrap NumPy features/labels as a float32/int64 ``TensorDataset`` loader."""
    dataset = TensorDataset(torch.from_numpy(x).float(), torch.from_numpy(y).long())
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


if __name__ == "__main__":
    # Load data
    p = load_rml_dict()
    snrs = sorted({key[1] for key in p})
    # snrs = snrs[4:6]
    mods = sorted({key[0] for key in p})
    print("All Classes:", mods)

    # ---------------------------------------------------
    # 1. SELECT THE OOD CLASSES (fixed indices into the sorted class list)
    # ---------------------------------------------------
    # Seeds NumPy's global RNG, which drives the train/valid/test shuffle below.
    # NOTE(review): torch is not seeded, so DataLoader shuffling and model
    # initialisation vary from run to run.
    np.random.seed(2016293)
    # np.random.seed(20011008)
    # ood_class_idx = np.random.choice(len(mods), size=5, replace=False)
    ood_class_idx = [4, 5, 8, 2, 3]
    ood_class = [mods[i] for i in ood_class_idx]
    print(f"\nOOD Class (removed from train/val): {ood_class}")
    ood_set = set(ood_class)
    id_mods = [mod for mod in mods if mod not in ood_set]
    # Models are trained on ID classes only; use this count everywhere.
    num_id_classes = len(id_mods)
    print(f"In-Distribution Classes ({num_id_classes}): {id_mods}")

    # ---------------------------------------------------
    # 2. BUILD DATASETS SEPARATELY
    # ---------------------------------------------------
    # ID data in modulation-major / SNR-minor order, relabelled 0..num_id_classes-1.
    id_label = {mod: i for i, mod in enumerate(id_mods)}
    X_id, Y_id = stack_samples(
        p, [((mod, snr), id_label[mod]) for mod in id_mods for snr in snrs])
    # OOD data in SNR-major order, all carrying the sentinel OOD_LABEL.
    X_ood, Y_ood = stack_samples(
        p, [((ood_c, snr), OOD_LABEL) for snr in snrs for ood_c in ood_class])
    N_id = len(Y_id)
    N_ood = len(Y_ood)
    print(f"\nIn-Distribution samples: {N_id}")
    print(f"OOD samples: {N_ood}")

    # ---------------------------------------------------
    # 3. TRAIN/VAL/TEST SPLIT (ID data only)
    # ---------------------------------------------------
    train_idx, valid_idx, test_idx = split_indices(N_id)
    X_train, Y_train = X_id[train_idx], Y_id[train_idx]
    X_valid, Y_valid = X_id[valid_idx], Y_id[valid_idx]
    X_test, Y_test = X_id[test_idx], Y_id[test_idx]
    print(f"\nTrain samples: {len(Y_train)}")
    print(f"Valid samples: {len(Y_valid)}")
    print(f"Test samples: {len(Y_test)}")

    # ---------------------------------------------------
    # 4-5. CONVERT TO TENSORS AND CREATE DATALOADERS
    # ---------------------------------------------------
    train_loader = make_loader(X_train, Y_train, shuffle=True)
    valid_loader = make_loader(X_valid, Y_valid)
    test_loader = make_loader(X_test, Y_test)
    ood_loader = make_loader(X_ood, Y_ood)
    print("\n✓ DataLoaders created successfully!")
    print(f" - train_loader: {len(train_loader)} batches")
    print(f" - valid_loader: {len(valid_loader)} batches")
    print(f" - test_loader: {len(test_loader)} batches")
    print(f" - ood_loader: {len(ood_loader)} batches (OOD class: {ood_class})")

    # ---------------------------------------------------
    # 6. VERIFICATION
    # ---------------------------------------------------
    print("\n--- Class Distribution Verification ---")
    for split_name, labels in (("Train", Y_train), ("Valid", Y_valid),
                               ("Test", Y_test), ("OOD", Y_ood)):
        present = np.unique(labels)
        print(f"{split_name} classes: {present} (count: {len(present)})")

    # Parameter counts for the architectures actually trained (ID classes only).
    complex_model = ComplexCNN_RLL(num_id_classes)
    real_model = AMC_CNN(num_id_classes)
    print(torch.__version__)
    print(f"complex network:{count_params(complex_model)}")
    print(f"real network:{count_params(real_model)}")

    # Alternative experiments. NOTE(review): earlier versions passed `ComplexCNN`,
    # which this file does not import; `ComplexCNN_RLL` is the imported model.
    # test_loader = DataLoader(ConcatDataset([test_loader.dataset, ood_loader.dataset]),
    #                          batch_size=110, shuffle=False)
    # bll_experiment(AMC_CNN, ComplexCNN_RLL, train_loader, valid_loader, test_loader,
    #                num_id_classes, 10, 0.0001, ood=True, ood_dataloader=ood_loader)
    # duq_experiment(AMC_CNN, ComplexCNN_RLL, train_loader, valid_loader, test_loader,
    #                num_id_classes, 10, 0.0001, ood=True, ood_loader=ood_loader)
    ensemble_experiment(AMC_CNN, ComplexCNN_RLL, train_loader, valid_loader, test_loader,
                        num_id_classes, 2, epochs=10, ood=True, ood_loader=ood_loader)