diff --git a/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py b/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py index e45f431d0b9d..51764e8de6df 100644 --- a/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py +++ b/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py @@ -299,20 +299,23 @@ def preprocess( # Follows diffusers.VaeImageProcessor.postprocess def postprocess(self, sample: torch.Tensor, output_type: str = "pil"): - if output_type not in ["pt", "np", "pil"]: + if output_type not in {"pt", "np", "pil"}: raise ValueError( f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']" ) # Equivalent to diffusers.VaeImageProcessor.denormalize - sample = (sample / 2 + 0.5).clamp(0, 1) + sample = (sample / 2 + 0.5).clamp_(0, 1) if output_type == "pt": return sample + # Only move to CPU and numpy if necessary + if sample.device.type != "cpu": + sample = sample.cpu() # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy - sample = sample.cpu().permute(0, 2, 3, 1).numpy() + sample = sample.permute(0, 2, 3, 1).contiguous().numpy() if output_type == "np": return sample + # Output_type must be 'pil' - sample = numpy_to_pil(sample) - return sample + return numpy_to_pil(sample) diff --git a/src/diffusers/utils/pil_utils.py b/src/diffusers/utils/pil_utils.py index 76678070b697..7a9a90803cc6 100644 --- a/src/diffusers/utils/pil_utils.py +++ b/src/diffusers/utils/pil_utils.py @@ -38,16 +38,15 @@ def numpy_to_pil(images): """ Convert a numpy image or a batch of images to a PIL image. """ + # If single HWC image, expand dims to NHWC if images.ndim == 3: images = images[None, ...] - images = (images * 255).round().astype("uint8") + images = (images * 255).round().astype("uint8", copy=False) if images.shape[-1] == 1: # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + return [Image.fromarray(image[..., 0], mode="L") for image in images] else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images + return [Image.fromarray(image) for image in images] def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image: