-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
80 lines (70 loc) · 3.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import glob
import io
import os
import tempfile
import traceback

import gradio as gr
import requests
from PIL import Image
from tqdm import tqdm

from logger import logger
class UI_Methods:
    """Callback handlers for an image / prompt labelling Gradio UI.

    Images are shown by gallery index; each labelled image is stored in an
    output directory as ``<index>.png`` plus a prompt file ``<index>.txt``.
    Handlers never raise: failures are logged through ``self.logger`` and the
    handler returns no-op ``gr.update()`` objects (one per output) so the UI
    keeps its current state.
    """

    # Seconds to wait for an image download in ``caption_select_image``;
    # without a timeout ``requests`` can block the handler indefinitely.
    request_timeout = 30

    def __init__(self, logger):
        # Only ``logger.error(str)`` is used by this class.
        self.logger = logger

    def _log_error(self, where, exc):
        """Log *exc* raised inside handler *where*, with the active traceback."""
        self.logger.error(f'{where} error: {exc}\n {traceback.format_exc()}\n')

    @staticmethod
    def _natural_sort_key(path):
        """Sort key: purely numeric file stems first in numeric order, then the rest by name."""
        stem = os.path.splitext(os.path.basename(path))[0]
        # isdecimal (not isdigit) guarantees int() accepts the stem.
        if stem.isdecimal():
            return (0, int(stem), stem)
        return (1, 0, stem)

    def load_images(self, image_dir):
        """Load every ``*.png`` in *image_dir* as the gallery contents.

        Returns five updates: the list of PIL images, the index box reset to
        ``'0'``, and three components switched to interactive. On failure all
        five are no-op updates.
        """
        try:
            pattern = os.path.join(glob.escape(image_dir), '*.png')
            # Sorted so the gallery index -> file mapping is stable across runs;
            # saved prompts are keyed by that index.
            path_list = sorted(glob.glob(pattern), key=self._natural_sort_key)
            image_list = []
            for path in tqdm(path_list):
                image = Image.open(path)
                # Decode now so Pillow closes the file it opened, instead of
                # holding one open descriptor per loaded image.
                image.load()
                image_list.append(image)
            return (gr.update(value=image_list), gr.update(value='0'),
                    gr.update(interactive=True), gr.update(interactive=True),
                    gr.update(interactive=True))
        except Exception as exc:
            self._log_error('load_images', exc)
            return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    def caption_select_image(self, output_dir: str, image_idx: str, images: list):
        """Show gallery entry *image_idx* together with its saved prompt, if any.

        *image_idx* wraps modulo ``len(images)``, so stepping past either end
        cycles to the other end, and the returned index is always in
        ``[0, len(images))``. The prompt is read from ``<output_dir>/<index>.txt``;
        an empty prompt is shown when that file does not exist.

        Returns three updates: the image, the prompt text and the normalised
        index as a string. On failure or an empty gallery, three no-op updates.
        """
        try:
            if not images:
                return gr.update(), gr.update(), gr.update()
            # Normalise into [0, len) so one image always maps to one prompt
            # file (e.g. -1 -> len-1 rather than '-1.txt').
            image_idx = int(image_idx) % len(images)
            # Each entry is read as a mapping; its 'data' value is fetched over
            # HTTP below, so it must be a URL.
            image_url = images[image_idx]['data']
            resp = requests.get(image_url, timeout=self.request_timeout)
            resp.raise_for_status()
            image = Image.open(io.BytesIO(resp.content))
            prompt_path = os.path.join(output_dir, f'{image_idx}.txt')
            prompt = ''
            if os.path.exists(prompt_path):
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    prompt = f.read()
            return gr.update(value=image), gr.update(value=prompt), gr.update(value=str(image_idx))
        except Exception as exc:
            self._log_error('caption_select_image', exc)
            return gr.update(), gr.update(), gr.update()

    def save_prompt(self, output_dir: str, image_idx: str, image, prompt: str):
        """Save *image* and *prompt* as ``<index>.png`` / ``<index>.txt``.

        *output_dir* is created if it does not exist. Returns an update
        advancing the index box to ``index + 1``; on failure a no-op update.
        """
        try:
            # Parse first: a bad index fails before anything is written, and the
            # file name uses the same int form caption_select_image looks up.
            idx = int(image_idx)
            os.makedirs(output_dir, exist_ok=True)
            stem = os.path.join(output_dir, str(idx))
            image.save(stem + '.png')
            with open(stem + '.txt', 'w', encoding='utf-8') as f:
                f.write(prompt)
            return gr.update(value=str(idx + 1))
        except Exception as exc:
            self._log_error('save_prompt', exc)
            return gr.update()

    def format_output_dir(self, output_dir):
        """Renumber every ``*.png`` / ``*.txt`` pair in *output_dir* as 000000, 000001, ...

        Pairs are numbered in sorted order (numeric stems numerically, then the
        rest by name). Files are moved, not re-encoded. If any image lacks its
        ``.txt`` prompt, nothing is touched and the error is logged.
        """
        try:
            pattern = os.path.join(glob.escape(output_dir), '*.png')
            image_paths = sorted(glob.glob(pattern), key=self._natural_sort_key)
            pairs = [(path, os.path.splitext(path)[0] + '.txt') for path in image_paths]
            if not pairs:
                return
            missing = [txt for _, txt in pairs if not os.path.isfile(txt)]
            if missing:
                raise FileNotFoundError(f'missing prompt file(s): {missing}')
            new_names = [f'{i:06d}' for i in range(len(pairs))]
            # Phase 1: move each pair into a fresh, empty staging directory under
            # its final name. Renumbering in place can overwrite (or, when a
            # file's target name is its own name, delete) existing pairs.
            staging_dir = tempfile.mkdtemp(prefix='.format_', dir=output_dir)
            for new_name, (png_path, txt_path) in zip(new_names, tqdm(pairs)):
                os.replace(png_path, os.path.join(staging_dir, new_name + '.png'))
                os.replace(txt_path, os.path.join(staging_dir, new_name + '.txt'))
            # Phase 2: move the renumbered files back. If this fails part-way,
            # the remaining files are still safe inside staging_dir.
            for new_name in new_names:
                for ext in ('.png', '.txt'):
                    os.replace(os.path.join(staging_dir, new_name + ext),
                               os.path.join(output_dir, new_name + ext))
            os.rmdir(staging_dir)
        except Exception as exc:
            self._log_error('format_output_dir', exc)
# Module-level handler instance built with the project's imported `logger`.
# NOTE(review): presumably imported by the UI definition to wire up Gradio
# callbacks — verify against callers.
ui_methods = UI_Methods(logger)