-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_adv_exmp.py
138 lines (118 loc) · 5.08 KB
/
generate_adv_exmp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import os
from torchvision.transforms import ToPILImage
from PIL import Image
import torch
import torchattacks
def choose_data(dataset):
    """Return ``(train_loader, test_loader)`` for the requested dataset.

    Dispatches on a substring match: any name containing 'MNIST' uses
    ``data_scripts.MNIST.main``; any name containing 'CIFAR' uses
    ``data_scripts.CIFAR.encapsulate_loader``. Both receive the module-level
    ``args`` namespace (set in the ``__main__`` block).

    Raises:
        ValueError: if ``dataset`` names neither MNIST nor CIFAR (previously
            this fell through to an UnboundLocalError on the return).
    """
    if 'MNIST' in dataset:
        from data_scripts import MNIST
        return MNIST.main(args)
    if 'CIFAR' in dataset:
        from data_scripts import CIFAR
        return CIFAR.encapsulate_loader(args)
    raise ValueError('unsupported dataset: %r (expected MNIST or CIFAR)' % (dataset,))
def cuda(model):
    """Move ``model`` onto the GPU(s) if CUDA is available.

    Without CUDA the model is returned untouched. With CUDA the model is moved
    with ``.cuda()``. When more than one device is visible, it is also wrapped
    in ``torch.nn.DataParallel`` over all of them. In that case the global
    ``args.batch_size`` is scaled by the device count, so data loaders built
    afterwards use a per-GPU batch of the original size.
    """
    if not torch.cuda.is_available():
        return model

    model = model.cuda()
    n_gpus = torch.cuda.device_count()
    print('you have %d available GPU' % (n_gpus))
    if n_gpus > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(n_gpus)))
        # Side effect on the global CLI namespace: later loaders and
        # save_image() see the enlarged batch size.
        args.batch_size *= n_gpus
    return model
def define_model(net_arch, dataset='ImageNet'):
    """Return the network selected by ``net_arch``.

    ``net_arch`` is matched case-insensitively against 'resnet18',
    'mnist_net', 'cifar_net' and 'wideresnet'. The CLI advertises 'mnist_net'
    while the original check was for 'MNIST_Net', so that choice never matched.
    ``dataset`` only matters for resnet18, which has separate 'ImageNet' and
    'CIFAR' variants.

    Raises:
        ValueError: for an unknown ``net_arch``, or for resnet18 with a dataset
            other than 'ImageNet'/'CIFAR'. Both previously surfaced as an
            UnboundLocalError.
    """
    arch = net_arch.lower()
    # NOTE(review): only the wideresnet branch instantiates its model; the
    # other branches return the attribute as-is. cuda()/load() later call
    # .cuda()/.load_state_dict() on the result, so confirm these attributes are
    # already nn.Module instances (or add the call) in the model package.
    if arch == 'resnet18':
        if dataset not in ('ImageNet', 'CIFAR'):
            raise ValueError("resnet18 supports dataset 'ImageNet' or 'CIFAR', got %r" % (dataset,))
        from model import ResNet
        if dataset == 'ImageNet':
            return ResNet.resnet18_ImageNet
        return ResNet.resnet18_CIFAR
    if arch == 'mnist_net':
        from model import MNIST_Net
        return MNIST_Net.MNIST_net
    if arch == 'cifar_net':
        from model import CIFAR_Net
        return CIFAR_Net.CIFAR_Net
    if arch == 'wideresnet':
        from model import wideresnet
        return wideresnet.WideResNet()
    raise ValueError('unsupported net_arch: %r' % (net_arch,))
def load(model, path=None):
    """Load ``checkpoint['state_dict']`` from a checkpoint file into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        path: checkpoint file. Defaults to the global ``args.load``. A falsy
            path (e.g. '') skips loading entirely, matching the original
            ``if args.load`` guard.
    """
    if path is None:
        path = args.load
    if not path:
        return
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict then copies the tensors onto the model's own device.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    print('Model loaded from {}'.format(path))
def attack_method(method, model):
    """Build the torchattacks attack named by ``method`` against ``model``.

    'fgsm' uses the global ``args.eps``. 'deepfool' (also offered by the CLI,
    but previously unhandled) uses torchattacks' default steps/overshoot.

    Raises:
        ValueError: for any other method name. Previously any name other than
            'fgsm' fell through to an UnboundLocalError.
    """
    if method == 'fgsm':
        return torchattacks.FGSM(model, eps=args.eps)
    if method == 'deepfool':
        return torchattacks.DeepFool(model)
    raise ValueError('unsupported attack method: %r' % (method,))
def adv_img_root(data_root, method, eps):
    """Create and return the output directories for adversarial images.

    Layout: ``<data_root>/<method>/<eps>/{train,test}``. For 10-class datasets
    (CIFAR/MNIST) the per-class subfolders ``0``..``9`` are also created under
    both splits, because save_image writes to ``<split>/<label>/``.

    Returns:
        (train_dir, test_dir)
    """
    base = os.path.join(data_root, method, str(eps))
    adv_train_save_path = os.path.join(base, 'train')
    adv_test_save_path = os.path.join(base, 'test')
    os.makedirs(adv_train_save_path, exist_ok=True)
    os.makedirs(adv_test_save_path, exist_ok=True)
    # The original test `'CIFAR' or 'MNIST' in data_root` was always truthy,
    # so the ImageNet branch was dead code. Only an ImageNet root (that names
    # neither CIFAR nor MNIST) skips class dirs now. Generic roots, such as
    # the CLI default, keep getting them as before.
    ten_class = ('CIFAR' in data_root or 'MNIST' in data_root
                 or 'ImageNet' not in data_root)
    if ten_class:
        for split_dir in (adv_train_save_path, adv_test_save_path):
            for cla in range(10):
                os.makedirs(os.path.join(split_dir, str(cla)), exist_ok=True)
    return adv_train_save_path, adv_test_save_path
def save_image(images, labels, save_path, iter, batch_size=None, ext='.jpg'):
    """Write each image of a batch to ``save_path/<label>/<index><ext>``.

    ``index`` is ``iter * batch_size + i`` zero-padded to 5 digits, so files
    from consecutive batches don't collide. ``iter`` is the batch number; the
    name shadows the builtin but is kept for backward compatibility.

    Args:
        images: batch tensor indexed along dim 0. Each item is squeezed and
            passed to torchvision's ToPILImage.
        labels: per-image class labels; ``labels[i].item()`` names the subdir.
        save_path: split directory (e.g. the test dir from adv_img_root).
        iter: batch index used for file numbering.
        batch_size: numbering stride; defaults to the global ``args.batch_size``.
        ext: file extension. NOTE(review): '.jpg' is lossy and can largely
            erase small adversarial perturbations; consider '.png'.
    """
    if batch_size is None:
        batch_size = args.batch_size
    to_pil = ToPILImage()
    for i in range(images.size(0)):
        cla = labels[i].item()
        class_dir = os.path.join(save_path, str(cla))
        # Create the class folder on demand instead of relying on
        # adv_img_root having pre-created it.
        os.makedirs(class_dir, exist_ok=True)
        image_save_path = os.path.join(class_dir, '%05d' % (iter * batch_size + i) + ext)
        print(image_save_path)
        # detach() so ToPILImage can convert tensors that still require grad.
        adv_image = images[i].detach().squeeze().cpu()
        to_pil(adv_image).save(image_save_path)
def _attack_batches(model, attack, data_loader, save_dir):
    """Attack every batch, save the adversarial images, return (correct, total)."""
    # Send inputs to wherever cuda() left the model (GPU or CPU); the original
    # unconditional .cuda() crashed on CPU-only machines.
    device = next(model.parameters()).device
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(data_loader):
        images = images.to(device)
        labels = labels.to(device)
        # The attack needs gradients; only the evaluation pass runs without.
        adv_images = attack(images, labels)
        with torch.no_grad():
            # The class with the highest score is the prediction.
            predicted = model(adv_images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print('%d/%d' % (i, len(data_loader)))
        save_image(adv_images, labels, save_dir, i)
    return correct, total


def main(args):
    """Generate adversarial test images, save them, and report model accuracy on them.

    The adversarial output directory comes from adv_img_root. The previous
    version referenced an undefined ``labels`` and a commented-out
    ``adv_test_save_dir`` when saving, and always ran ``.cuda()``. The train
    split can be processed the same way via ``_attack_batches`` with the train
    loader and train directory.
    """
    _, adv_test_save_dir = adv_img_root(args.data_root, args.attack_method, args.eps)
    # Pass the dataset so resnet18 picks its CIFAR variant instead of
    # silently defaulting to ImageNet.
    model = define_model(args.net_arch, args.dataset)
    model = cuda(model)
    load(model)
    model.eval()
    attack = attack_method(args.attack_method, model)
    _, test_data_loader = choose_data(args.dataset)

    correct, total = _attack_batches(model, attack, test_data_loader, adv_test_save_dir)
    acc = correct / total if total else 0.0
    print('Accuracy of the network on the %d test images: %f %%' % (total, 100 * acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='attck template')
parser.add_argument('--data_root', type=str, default='/home/panmeng/data/')
parser.add_argument('--dataset', type=str, default='CIFAR',choices=['CIFAR','MNIST'])
parser.add_argument('--net_arch', type=str, default='wideresnet', choices=['resnet18', 'mnist_net', 'CIFAR_Net','wideresnet'])
parser.add_argument('--load', type=str, default='/home/panmeng/adv_frame/adv_frame/experiments/2021_08_30_15_52_59/ckp/23checkpoint.pth.tar')
parser.add_argument('--attack_method', type=str, default='fgsm', choices=['fgsm','deepfool'])
parser.add_argument('--eps',type=float, default=0.03137)
parser.add_argument('--batch_size',type=int,default=64)
parser.add_argument('--num_worker', type = int, default=4)
args = parser.parse_args()
#os.environ['CUDA_VISIBLE_DEVICES'] ='1,2,3'
main(args)