diff --git a/landmarkdiff/inference.py b/landmarkdiff/inference.py
index e7f827d..a2daeaf 100644
--- a/landmarkdiff/inference.py
+++ b/landmarkdiff/inference.py
@@ -379,6 +379,74 @@ def generate(
             "identity_check": identity_check,
             "restore_used": restore_used,
         }
+
+    def batch_procedures(
+        self,
+        image: np.ndarray,
+        procedures: list[str],
+        intensity: float = 50.0,
+        num_inference_steps: int = 30,
+        guidance_scale: float = 9.0,
+        controlnet_conditioning_scale: float = 0.9,
+        strength: float = 0.5,
+        seed: Optional[int] = None,
+        postprocess: bool = True,
+        use_gfpgan: bool = False,
+    ) -> dict:
+        """Run multiple procedures on the same image with one up-front face check.
+
+        Args:
+            image: Input BGR image.
+            procedures: List of procedure names to apply.
+            intensity: Procedure intensity (0-100).
+            num_inference_steps: Diffusion steps.
+            guidance_scale: CFG scale.
+            controlnet_conditioning_scale: ControlNet scale.
+            strength: Img2img strength.
+            seed: Random seed.
+            postprocess: Whether to apply postprocessing.
+            use_gfpgan: Whether to use GFPGAN restoration.
+
+        Returns:
+            Dict with keys:
+                - results: list of per-procedure result dicts
+                - grid: comparison grid image (input + all outputs)
+                - procedures: list of procedure names
+        """
+        if not self.is_loaded:
+            raise RuntimeError("Pipeline not loaded. Call .load() first.")
+
+        image_512 = cv2.resize(image, (512, 512))
+
+        # Fail fast: verify a face is detectable before any diffusion run
+        face = extract_landmarks(image_512)
+        if face is None:
+            raise ValueError("No face detected in image.")
+
+        results = []
+        for procedure in procedures:
+            result = self.generate(
+                image,
+                procedure=procedure,
+                intensity=intensity,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                controlnet_conditioning_scale=controlnet_conditioning_scale,
+                strength=strength,
+                seed=seed,
+                postprocess=postprocess,
+                use_gfpgan=use_gfpgan,
+            )
+            results.append(result)
+
+        # Build comparison grid: input | proc1 | proc2 | ...
+        grid = _build_comparison_grid(image_512, procedures, results)
+
+        return {
+            "results": results,
+            "grid": grid,
+            "procedures": procedures,
+        }
 
     def _generate_controlnet(
         self, image: np.ndarray, conditioning: np.ndarray,
@@ -503,6 +571,43 @@ def run_inference(
 
     print(f"Results saved to {out}/")
 
 
+def _build_comparison_grid(
+    original: np.ndarray,
+    procedures: list[str],
+    results: list[dict],
+    label_height: int = 24,
+) -> np.ndarray:
+    """Build a labeled comparison grid: original | proc1 | proc2 | ...
+
+    Args:
+        original: The resized input image (512x512 BGR).
+        procedures: List of procedure names.
+        results: List of result dicts from generate().
+        label_height: Height in pixels for the label bar above each image.
+
+    Returns:
+        Horizontally stacked grid image with labels.
+    """
+    panels = [original] + [r["output"] for r in results]
+    labels = ["Original"] + [p.capitalize() for p in procedures]
+
+    h, w = panels[0].shape[:2]
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    font_scale = 0.5
+    thickness = 1
+
+    labeled = []
+    for img, label in zip(panels, labels):
+        bar = np.zeros((label_height, w, 3), dtype=np.uint8)
+        text_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
+        x = (w - text_size[0]) // 2
+        y = (label_height + text_size[1]) // 2
+        cv2.putText(bar, label, (x, y), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
+        labeled.append(np.vstack([bar, img]))
+
+    return np.hstack(labeled)
+
+
 if __name__ == "__main__":
     import argparse