From 3006b7aca102c904704c86d7c919add9ceffd15c Mon Sep 17 00:00:00 2001 From: "Zhang Ch. N." Date: Sat, 21 Sep 2024 20:52:39 +0800 Subject: [PATCH] Reuse, reduce, recycle Save pickle file in results folder; Save input latent list in pickle, too. --- scripts/inference.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index fe2a2345..dec4a6d8 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -39,7 +39,7 @@ def main(args): audio_basename = os.path.basename(audio_path).split('.')[0] output_basename = f"{input_basename}_{audio_basename}" result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs - crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input + crop_coord_save_path = os.path.join(args.result_dir, input_basename+".pkl") # only related to video input os.makedirs(result_img_save_path,exist_ok =True) if args.output_vid_name is None: @@ -72,24 +72,26 @@ def main(args): if os.path.exists(crop_coord_save_path) and args.use_saved_coord: print("using extracted coordinates") with open(crop_coord_save_path,'rb') as f: - coord_list = pickle.load(f) + saved_lists = pickle.load(f) + coord_list = saved_lists['coord'] + input_latent_list = saved_lists['latent'] frame_list = read_imgs(input_img_list) else: print("extracting landmarks...time consuming") coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) + input_latent_list = [] + for bbox, frame in zip(coord_list, frame_list): + if bbox == coord_placeholder: + continue + x1, y1, x2, y2 = bbox + crop_frame = frame[y1:y2, x1:x2] + crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) + latents = vae.get_latents_for_unet(crop_frame) + input_latent_list.append(latents) with open(crop_coord_save_path, 'wb') as f: - pickle.dump(coord_list, f) + pickle.dump({'coord': coord_list, 'latent': input_latent_list}, f) - i = 0 - input_latent_list = [] - for bbox, frame in zip(coord_list, frame_list): - if bbox == coord_placeholder: - continue - x1, y1, x2, y2 = bbox - crop_frame = frame[y1:y2, x1:x2] - crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) - latents = vae.get_latents_for_unet(crop_frame) - input_latent_list.append(latents) + # to smooth the first and the last frame frame_list_cycle = frame_list + frame_list[::-1]