@@ -8,7 +8,7 @@
 from ...loaders import PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -686,46 +686,108 @@ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_train
         x = torch.cat(x_arr, dim=0)
         return x
 
-    def patchify(self, x, max_seq, img_sizes=None):
-        pz2 = self.config.patch_size * self.config.patch_size
-        if isinstance(x, torch.Tensor):
-            B, C = x.shape[0], x.shape[1]
-            device = x.device
-            dtype = x.dtype
+    def patchify(self, hidden_states):
+        batch_size, channels, height, width = hidden_states.shape
+        patch_size = self.config.patch_size
+        patch_height, patch_width = height // patch_size, width // patch_size
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        # create img_sizes
+        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
+
+        # create hidden_states_masks
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+            hidden_states_masks[:, : patch_height * patch_width] = 1.0
         else:
-            B, C = len(x), x[0].shape[0]
-            device = x[0].device
-            dtype = x[0].dtype
-        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
+            hidden_states_masks = None
+
+        # create img_ids
+        img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+        row_indices = torch.arange(patch_height, device=device)[:, None]
+        col_indices = torch.arange(patch_width, device=device)[None, :]
+        img_ids[..., 1] = img_ids[..., 1] + row_indices
+        img_ids[..., 2] = img_ids[..., 2] + col_indices
+        img_ids = img_ids.reshape(patch_height * patch_width, -1)
+
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+            img_ids_pad[: patch_height * patch_width, :] = img_ids
+            img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+        else:
+            img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        # patchify hidden_states
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            out = torch.zeros(
+                (batch_size, channels, self.max_seq, patch_size * patch_size),
+                dtype=dtype,
+                device=device,
+            )
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height * patch_width, patch_size * patch_size
+            )
+            out[:, :, 0 : patch_height * patch_width] = hidden_states
+            hidden_states = out
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                batch_size, self.max_seq, patch_size * patch_size * channels
+            )
 
-        if img_sizes is not None:
-            for i, img_size in enumerate(img_sizes):
-                x_masks[i, 0 : img_size[0] * img_size[1]] = 1
-            B, C, S, _ = x.shape
-            x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
-        elif isinstance(x, torch.Tensor):
-            B, C, Hp1, Wp2 = x.shape
-            pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
-            x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
-            x = x.permute(0, 2, 4, 3, 5, 1)
-            x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
-            img_sizes = [[pH, pW]] * B
-            x_masks = None
         else:
-            raise NotImplementedError
-        return x, x_masks, img_sizes
+            # Handle square latents
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+            hidden_states = hidden_states.reshape(
+                batch_size, patch_height * patch_width, patch_size * patch_size * channels
+            )
+
+        return hidden_states, hidden_states_masks, img_sizes, img_ids
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         timesteps: torch.LongTensor = None,
-        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_t5: torch.Tensor = None,
+        encoder_hidden_states_llama3: torch.Tensor = None,
         pooled_embeds: torch.Tensor = None,
-        img_sizes: Optional[List[Tuple[int, int]]] = None,
         img_ids: Optional[torch.Tensor] = None,
+        img_sizes: Optional[List[Tuple[int, int]]] = None,
+        hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        **kwargs,
     ):
+        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+        if encoder_hidden_states is not None:
+            deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+            deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
+            encoder_hidden_states_t5 = encoder_hidden_states[0]
+            encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+        if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+            deprecation_message = (
+                "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
+            )
+            deprecate("img_ids", "0.34.0", deprecation_message)
+
+        if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+            raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+        elif hidden_states_masks is not None and hidden_states.ndim != 3:
+            raise ValueError(
+                "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensor with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+            )
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
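
Side note (not part of the diff): the reworked `patchify` now builds the per-patch position ids itself. A minimal standalone sketch of that id construction, using an assumed 4×6 grid of latent patches for illustration:

```python
import torch

# Channel 0 is left at zero; channel 1 carries the patch row index and
# channel 2 the patch column index, mirroring the logic added in patchify.
patch_height, patch_width = 4, 6  # example grid size, assumed for illustration

img_ids = torch.zeros(patch_height, patch_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(patch_height)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(patch_width)[None, :]
img_ids = img_ids.reshape(patch_height * patch_width, -1)

print(img_ids.shape)  # torch.Size([24, 3])
```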
@@ -745,42 +807,19 @@ def forward(
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype
 
-        if hidden_states.shape[-2] != hidden_states.shape[-1]:
-            B, C, H, W = hidden_states.shape
-            patch_size = self.config.patch_size
-            pH, pW = H // patch_size, W // patch_size
-            out = torch.zeros(
-                (B, C, self.max_seq, patch_size * patch_size),
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-            )
-            hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
-            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
-            hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
-            out[:, :, 0 : pH * pW] = hidden_states
-            hidden_states = out
+        # Patchify the input
+        if hidden_states_masks is None:
+            hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
+
+        # Embed the hidden states
+        hidden_states = self.x_embedder(hidden_states)
 
         # 0. time
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
 
-        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if hidden_states_masks is None:
-            pH, pW = img_sizes[0]
-            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = (
-                img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
-                .unsqueeze(0)
-                .repeat(batch_size, 1, 1)
-            )
-        hidden_states = self.x_embedder(hidden_states)
-
-        T5_encoder_hidden_states = encoder_hidden_states[0]
-        encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []
@@ -789,9 +828,9 @@ def forward(
                 enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                 new_encoder_hidden_states.append(enc_hidden_state)
             encoder_hidden_states = new_encoder_hidden_states
-            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
-            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-            encoder_hidden_states.append(T5_encoder_hidden_states)
+            encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+            encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+            encoder_hidden_states.append(encoder_hidden_states_t5)
 
         txt_ids = torch.zeros(
             batch_size,
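
For reference (again not part of the diff): the square-latent branch of the new `patchify` is a plain reshape/permute of the latent grid. A self-contained sketch with toy dimensions, assumed only for illustration:

```python
import torch

# Toy dimensions, assumed for illustration only.
batch_size, channels, height, width, patch_size = 1, 16, 8, 8, 2
hidden_states = torch.randn(batch_size, channels, height, width)

patch_height, patch_width = height // patch_size, width // patch_size
hidden_states = hidden_states.reshape(batch_size, channels, patch_height, patch_size, patch_width, patch_size)
hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
hidden_states = hidden_states.reshape(batch_size, patch_height * patch_width, patch_size * patch_size * channels)

# (batch, num_patches, patch_size**2 * channels)
print(hidden_states.shape)  # torch.Size([1, 16, 64])
```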