-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
138 lines (116 loc) · 3.44 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import os
import random
from torch.backends import cudnn
import torch.utils.data.distributed
from torchvision import transforms
from torchvision import utils
from tqdm import tqdm
from cyclegan_pytorch import Generator, ImageDataset
def _build_arg_parser():
    """Build the CLI parser for CycleGAN test-time image translation."""
    cli = argparse.ArgumentParser(
        description="PyTorch implements `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks`"
    )
    cli.add_argument(
        "--dataroot",
        type=str,
        default="./data",
        help="path to datasets. (default:./data)",
    )
    cli.add_argument(
        "--dataset",
        type=str,
        default="horse2zebra",
        help="dataset name. (default:`horse2zebra`)"
        "Option: [apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, "
        "cezanne2photo, ukiyoe2photo, vangogh2photo, maps, facades, selfie2anime, "
        "iphone2dslr_flower, ae_photos, ]",
    )
    cli.add_argument("--cuda", action="store_true", help="Enables cuda")
    cli.add_argument(
        "--outf",
        default="./results",
        help="folder to output images. (default: `./results`).",
    )
    cli.add_argument(
        "--image-size",
        type=int,
        default=256,
        help="size of the data crop (squared assumed). (default:256)",
    )
    cli.add_argument(
        "--manualSeed",
        type=int,
        help="Seed for initializing training. (default:none)",
    )
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
print(args)
# Ensure the output root exists. exist_ok=True replaces the old blanket
# `except OSError: pass`, which also hid real failures such as permission errors.
os.makedirs(args.outf, exist_ok=True)

# Seed the RNGs for reproducibility; draw a random seed when none was given.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
print("Random Seed: ", args.manualSeed)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)

# cuDNN autotuning is safe here: every input has the same fixed size.
cudnn.benchmark = True

if torch.cuda.is_available() and not args.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )
# Test-split dataset: resize to the working resolution, then map pixel
# values into [-1, 1] via the (0.5, 0.5) normalization.
test_transform = transforms.Compose(
    [
        transforms.Resize(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
dataset = ImageDataset(
    root=os.path.join(args.dataroot, args.dataset),
    transform=test_transform,
    mode="test",
)

# One sample per step, in file order, so output filenames stay deterministic.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, pin_memory=True
)
# Output subfolders, one per translation direction. exist_ok=True replaces the
# old blanket `except OSError: pass`, which also hid real failures.
os.makedirs(os.path.join(args.outf, str(args.dataset), "A"), exist_ok=True)
os.makedirs(os.path.join(args.outf, str(args.dataset), "B"), exist_ok=True)

device = torch.device("cuda:0" if args.cuda else "cpu")

# One generator per translation direction.
netG_A2B = Generator().to(device)
netG_B2A = Generator().to(device)

# Load trained weights. map_location=device lets checkpoints saved on a CUDA
# machine load on a CPU-only host (torch.load would otherwise raise).
netG_A2B.load_state_dict(
    torch.load(
        os.path.join("weights", str(args.dataset), "netG_A2B.pth"),
        map_location=device,
    )
)
netG_B2A.load_state_dict(
    torch.load(
        os.path.join("weights", str(args.dataset), "netG_B2A.pth"),
        map_location=device,
    )
)

# Inference mode: switches layers such as batch-norm/dropout to eval behavior.
netG_A2B.eval()
netG_B2A.eval()
# Translate every test image in both directions and write the results to disk.
progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
# no_grad: inference only — skips autograd bookkeeping and saves memory.
# This also replaces the deprecated `.data` attribute access and makes the
# former `.detach()` calls unnecessary.
with torch.no_grad():
    for i, data in progress_bar:
        # Each batch element holds one image from domain A and one from B.
        real_images_A = data["A"].to(device)
        real_images_B = data["B"].to(device)

        # Generators emit images in [-1, 1]; shift/scale back to [0, 1].
        fake_image_A = 0.5 * (netG_B2A(real_images_B) + 1.0)
        fake_image_B = 0.5 * (netG_A2B(real_images_A) + 1.0)

        # Save with 1-based, zero-padded filenames (0001.png, 0002.png, ...).
        utils.save_image(
            fake_image_A,
            f"{args.outf}/{args.dataset}/A/{i + 1:04d}.png",
            normalize=True,
        )
        utils.save_image(
            fake_image_B,
            f"{args.outf}/{args.dataset}/B/{i + 1:04d}.png",
            normalize=True,
        )

        progress_bar.set_description(
            f"Process images {i + 1} of {len(dataloader)}"
        )