Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions landmarkdiff/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,74 @@ def generate(
"identity_check": identity_check,
"restore_used": restore_used,
}

def batch_procedures(
    self,
    image: np.ndarray,
    procedures: list[str],
    intensity: float = 50.0,
    num_inference_steps: int = 30,
    guidance_scale: float = 9.0,
    controlnet_conditioning_scale: float = 0.9,
    strength: float = 0.5,
    seed: Optional[int] = None,
    postprocess: bool = True,
    use_gfpgan: bool = False,
) -> dict:
    """Run multiple procedures on the same input image.

    A face-detection check is performed once up front so that an image
    with no detectable face fails fast, before any expensive diffusion
    runs. Note that each ``generate()`` call still performs its own
    landmark extraction internally; the detection here is validation
    only, not shared state.

    Args:
        image: Input BGR image.
        procedures: List of procedure names to apply.
        intensity: Procedure intensity (0-100).
        num_inference_steps: Diffusion steps.
        guidance_scale: CFG scale.
        controlnet_conditioning_scale: ControlNet scale.
        strength: Img2img strength.
        seed: Random seed.
        postprocess: Whether to apply postprocessing.
        use_gfpgan: Whether to use GFPGAN restoration.

    Returns:
        Dict with keys:
            - results: list of per-procedure result dicts
            - grid: comparison grid image (input + all outputs)
            - procedures: list of procedure names

    Raises:
        RuntimeError: If the pipeline has not been loaded.
        ValueError: If no face is detected in the input image.
    """
    if not self.is_loaded:
        raise RuntimeError("Pipeline not loaded. Call .load() first.")

    image_512 = cv2.resize(image, (512, 512))

    # Fail fast: abort before running any diffusion if no face is found.
    # (The landmarks themselves are not reused — generate() re-detects.)
    if extract_landmarks(image_512) is None:
        raise ValueError("No face detected in image.")

    results = []
    for procedure in procedures:
        result = self.generate(
            image,
            procedure=procedure,
            intensity=intensity,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            strength=strength,
            seed=seed,
            postprocess=postprocess,
            use_gfpgan=use_gfpgan,
        )
        results.append(result)

    # Build comparison grid: input | proc1 | proc2 | ...
    grid = _build_comparison_grid(image_512, procedures, results)

    return {
        "results": results,
        "grid": grid,
        "procedures": procedures,
    }

def _generate_controlnet(
self, image: np.ndarray, conditioning: np.ndarray,
Expand Down Expand Up @@ -503,6 +571,43 @@ def run_inference(
print(f"Results saved to {out}/")


def _build_comparison_grid(
    original: np.ndarray,
    procedures: list[str],
    results: list[dict],
    label_height: int = 24,
) -> np.ndarray:
    """Build a labeled comparison grid: original | proc1 | proc2 | ...

    Panels whose dimensions differ from the reference (the original
    image) are resized to match, so the stacking below cannot fail when
    a result's "output" is at a different resolution.

    Args:
        original: The resized input image (512x512 BGR).
        procedures: List of procedure names.
        results: List of result dicts from generate(); each must carry
            an "output" BGR image.
        label_height: Height in pixels for the label bar above each image.

    Returns:
        Horizontally stacked grid image with labels.
    """
    panels = [original] + [r["output"] for r in results]
    labels = ["Original"] + [p.capitalize() for p in procedures]

    # All panels are normalized to the reference (original) size.
    h, w = panels[0].shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1

    labeled = []
    for img, label in zip(panels, labels):
        # np.vstack/np.hstack require identical widths/heights; resize
        # any mismatched panel (no-op when sizes already agree).
        if img.shape[:2] != (h, w):
            img = cv2.resize(img, (w, h))
        bar = np.zeros((label_height, w, 3), dtype=np.uint8)
        text_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
        # Clamp x so a label wider than the panel is left-aligned rather
        # than drawn off-canvas at a negative origin.
        x = max((w - text_size[0]) // 2, 0)
        # Approximate vertical centering of the text baseline in the bar.
        y = (label_height + text_size[1]) // 2
        cv2.putText(bar, label, (x, y), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
        labeled.append(np.vstack([bar, img]))

    return np.hstack(labeled)


if __name__ == "__main__":
import argparse

Expand Down
Loading