@@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to
236
236
`np.ndarray` or `torch.Tensor`:
237
237
The denormalized image array.
238
238
"""
239
- return (images / 2 + 0.5 ).clamp (0 , 1 )
239
+ return (images * 0.5 + 0.5 ).clamp (0 , 1 )
240
240
241
241
@staticmethod
242
242
def convert_to_rgb (image : PIL .Image .Image ) -> PIL .Image .Image :
@@ -537,6 +537,26 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
537
537
538
538
return image
539
539
540
+ def _denormalize_conditionally (
541
+ self , images : torch .Tensor , do_denormalize : Optional [List [bool ]] = None
542
+ ) -> torch .Tensor :
543
+ r"""
544
+ Denormalize a batch of images based on a condition list.
545
+
546
+ Args:
547
+ images (`torch.Tensor`):
548
+ The input image tensor.
549
+ do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
550
+ A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
551
+ value of `do_normalize` in the `VaeImageProcessor` config.
552
+ """
553
+ if do_denormalize is None :
554
+ return self .denormalize (images ) if self .config .do_normalize else images
555
+
556
+ return torch .stack (
557
+ [self .denormalize (images [i ]) if do_denormalize [i ] else images [i ] for i in range (images .shape [0 ])]
558
+ )
559
+
540
560
def get_default_height_width (
541
561
self ,
542
562
image : Union [PIL .Image .Image , np .ndarray , torch .Tensor ],
@@ -752,12 +772,7 @@ def postprocess(
752
772
if output_type == "latent" :
753
773
return image
754
774
755
- if do_denormalize is None :
756
- do_denormalize = [self .config .do_normalize ] * image .shape [0 ]
757
-
758
- image = torch .stack (
759
- [self .denormalize (image [i ]) if do_denormalize [i ] else image [i ] for i in range (image .shape [0 ])]
760
- )
775
+ image = self ._denormalize_conditionally (image , do_denormalize )
761
776
762
777
if output_type == "pt" :
763
778
return image
@@ -966,12 +981,7 @@ def postprocess(
966
981
deprecate ("Unsupported output_type" , "1.0.0" , deprecation_message , standard_warn = False )
967
982
output_type = "np"
968
983
969
- if do_denormalize is None :
970
- do_denormalize = [self .config .do_normalize ] * image .shape [0 ]
971
-
972
- image = torch .stack (
973
- [self .denormalize (image [i ]) if do_denormalize [i ] else image [i ] for i in range (image .shape [0 ])]
974
- )
984
+ image = self ._denormalize_conditionally (image , do_denormalize )
975
985
976
986
image = self .pt_to_numpy (image )
977
987
0 commit comments