idk if this works, but here's a small clip implementation #79

Open

wants to merge 7 commits into base: dev
91 changes: 52 additions & 39 deletions muse_maskgit_pytorch/dataset.py
@@ -167,27 +167,31 @@ def __getitem__(self, index):
else:
text = descriptions
# max length from the paper
encoded = self.tokenizer.batch_encode_plus(
[str(text)],
return_tensors="pt",
padding="max_length",
max_length=MAX_LENGTH,
truncation=True,
)
if self.tokenizer is not None:
encoded = self.tokenizer.batch_encode_plus(
[str(text)],
return_tensors="pt",
padding="max_length",
max_length=MAX_LENGTH,
truncation=True,
)

input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
else:
input_ids = [0]
attn_mask = [0]

if self.using_taming:
if self.embeds:
return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text
else:
return self.transform(image), input_ids[0], attn_mask[0], []
return self.transform(image), input_ids[0], attn_mask[0], [], text
else:
if self.embeds:
return self.transform(image), input_ids[0], attn_mask[0], embed
return self.transform(image), input_ids[0], attn_mask[0], embed, text
else:
return self.transform(image), input_ids[0], attn_mask[0], []
return self.transform(image), input_ids[0], attn_mask[0], [], text


class URLTextDataset(ImageDataset):
@@ -242,27 +246,31 @@ def __getitem__(self, index):
else:
text = descriptions
# max length from the paper
encoded = self.tokenizer.batch_encode_plus(
[str(text)],
return_tensors="pt",
padding="max_length",
max_length=MAX_LENGTH,
truncation=True,
)
if self.tokenizer is not None:
encoded = self.tokenizer.batch_encode_plus(
[str(text)],
return_tensors="pt",
padding="max_length",
max_length=MAX_LENGTH,
truncation=True,
)

input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
else:
input_ids = [0]
attn_mask = [0]

if self.using_taming:
if self.embeds:
return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text
else:
return self.transform(image), input_ids[0], attn_mask[0], []
return self.transform(image), input_ids[0], attn_mask[0], [], text
else:
if self.embeds:
return self.transform(image), input_ids[0], attn_mask[0], embed
return self.transform(image), input_ids[0], attn_mask[0], embed, text
else:
return self.transform(image), input_ids[0], attn_mask[0], []
return self.transform(image), input_ids[0], attn_mask[0], [], text


class LocalTextImageDataset(Dataset):
@@ -338,26 +346,31 @@ def __getitem__(self, index):
embed = self.embeds[index]

# max length from the paper
encoded = self.tokenizer.batch_encode_plus(
[str(text)],
return_tensors="pt",
padding="max_length",
max_length=MAX_LENGTH,
truncation=True,
)
if self.tokenizer is not None:
encoded = self.tokenizer.batch_encode_plus(
[str(text)],
return_tensors="pt",
padding="max_length",
max_length=MAX_LENGTH,
truncation=True,
)

input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
else:
input_ids = [0]
attn_mask = [0]

input_ids = encoded.input_ids
attn_mask = encoded.attention_mask
if self.using_taming:
if self.embeds:
return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed
return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text
else:
return self.transform(image), input_ids[0], attn_mask[0], []
return self.transform(image), input_ids[0], attn_mask[0], [], text
else:
if self.embeds:
return self.transform(image), input_ids[0], attn_mask[0], embed
return self.transform(image), input_ids[0], attn_mask[0], embed, text
else:
return self.transform(image), input_ids[0], attn_mask[0], []
return self.transform(image), input_ids[0], attn_mask[0], [], text


def get_directory_size(path):
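Note on the dataset changes above: all three dataset classes now treat the tokenizer as optional. When `self.tokenizer` is `None`, `input_ids` and `attn_mask` fall back to the dummy value `[0]`, and every return tuple gains the raw caption `text` as a fifth element so that tokenization (e.g. with CLIP) can happen later. Below is a minimal sketch of a collate function that picks up those raw captions; the function name, model checkpoint, and DataLoader wiring are illustrative assumptions, not part of this PR.

```python
# Illustrative sketch (not part of this PR): consuming the new 5-tuple
# (image, input_ids, attn_mask, embed, text) when the dataset was built
# without a tokenizer, deferring CLIP tokenization to collation.
import torch
from torch.utils.data import DataLoader
from transformers import CLIPTokenizerFast

clip_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")

def collate_with_clip(batch):
    images, _ids, _masks, embeds, texts = zip(*batch)
    # tokenize the raw captions the dataset now returns as the fifth element
    tokenized = clip_tokenizer(
        list(texts), padding="max_length", truncation=True, return_tensors="pt"
    )
    return torch.stack(images), tokenized.input_ids, tokenized.attention_mask, embeds, texts

# loader = DataLoader(dataset, batch_size=4, collate_fn=collate_with_clip)
```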
116 changes: 84 additions & 32 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -13,7 +13,7 @@
from einops import rearrange, repeat
from torch import einsum, isnan, nn
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer

from .attn import ein_attn, sdp_attn
from .t5 import DEFAULT_T5_NAME, get_encoded_dim, get_model_and_tokenizer, t5_encode_text
@@ -169,6 +169,7 @@ def __init__(
self_cond: bool = False,
add_mask_id: bool = False,
cache_path: PathLike = None,
use_clip=False,
**kwargs,
):
super().__init__()
@@ -183,17 +184,24 @@ def __init__(
self.transformer_blocks = TransformerBlocks(dim=dim, **kwargs)
self.norm = LayerNorm(dim)

self.use_clip = use_clip
self.tokenizer = None

self.dim_out = default(dim_out, num_tokens)
self.to_logits = nn.Linear(dim, self.dim_out, bias=False)

# text conditioning
t5, tokenizer = get_model_and_tokenizer(t5_name, cache_path)
self.t5: T5EncoderModel = t5
self.tokenizer: T5Tokenizer = tokenizer
if not use_clip:
t5, tokenizer = get_model_and_tokenizer(t5_name, cache_path)
self.t5: T5EncoderModel = t5
self.tokenizer: T5Tokenizer = tokenizer

self.t5.eval()

self.t5.eval()
text_embed_dim = get_encoded_dim(t5_name)

text_embed_dim = get_encoded_dim(t5_name)
else:
text_embed_dim = 512

self.text_embed_proj = (
nn.Linear(text_embed_dim, dim, bias=False) if text_embed_dim != dim else nn.Identity()
@@ -204,8 +212,11 @@ def __init__(
self.self_cond_to_init_embed = FeedForward(dim)

def encode_text(self, *args, **kwargs):
kwargs.update(tokenizer=self.tokenizer, t5=self.t5)
return t5_encode_text(*args, **kwargs)
if not self.use_clip:
kwargs.update(tokenizer=self.tokenizer, t5=self.t5)
return t5_encode_text(*args, **kwargs)
else:
print("Using clip instead, this function shouldn't be accessed")

def forward_with_cond_scale(self, *args, cond_scale=3.0, return_embed=False, **kwargs):
if cond_scale == 1:
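Reviewer note on the hunks above: with `use_clip=True` the transformer skips loading T5 entirely and hard-codes `text_embed_dim = 512`, which matches the hidden size of the base CLIP text encoders (e.g. `openai/clip-vit-base-patch32`); larger CLIP variants use 768 or 1024 hidden units, so the projection size would need to be configurable for those. Since `encode_text` becomes a no-op in this mode, text embeddings have to be produced outside the transformer. A hedged sketch of that, with the model name and call pattern as assumptions:

```python
# Sketch only: producing text_embeds with a CLIP text encoder outside the
# transformer, since encode_text() is bypassed when use_clip=True.
import torch
from transformers import CLIPTextModel, CLIPTokenizerFast

clip = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").eval()
tok = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")

with torch.no_grad():
    batch = tok(["a photo of a cat"], padding=True, return_tensors="pt")
    # last_hidden_state is (batch, seq_len, 512) for the base model,
    # matching the hard-coded text_embed_dim = 512 above
    text_embeds = clip(**batch).last_hidden_state

# transformer = MaskGitTransformer(num_tokens=..., dim=..., use_clip=True)
# logits = transformer(ids, text_embeds=text_embeds)  # assumed call pattern
```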
@@ -406,6 +417,8 @@ def __init__(
self,
image_size,
transformer: MaskGitTransformer,
clip: CLIPTextModel = None,
clip_tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast] = None,
accelerator: Optional[Accelerator] = None,
noise_schedule: Callable = cosine_schedule,
token_critic: Optional[TokenCritic] = None,
@@ -435,6 +448,9 @@ def __init__(
self.resize_image_for_cond_image = exists(cond_image_size)
self.cond_drop_prob = cond_drop_prob

self.clip = clip
self.clip_tokenizer = clip_tokenizer

self.transformer = transformer
self.self_cond = transformer.self_cond
if not self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens:
@@ -513,34 +529,66 @@ def generate(

cond_ids = None

text_embeds = self.transformer.encode_text(texts)

demask_fn = self.transformer.forward_with_cond_scale

# whether to use token critic for scores

use_token_critic = exists(self.token_critic) and not force_not_use_token_critic

# whether to use token critic for scores
if use_token_critic:
token_critic_fn = self.token_critic.forward_with_cond_scale

# negative prompting, as in paper
if self.clip is not None and self.clip_tokenizer is not None:
clip_model = self.clip
clip_tokenizer = self.clip_tokenizer
print(texts)
inputs = [token[1:-1] for token in clip_tokenizer(texts, truncation=True).input_ids]

inputs = torch.tensor(inputs, device=self.accelerator.device)
max_embeddings_multiples = (inputs.shape[1] - 2) // (75 - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = inputs[:, i * (75 - 2) : (i + 1) * (75 - 2) + 2].clone()

# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = inputs[0, 0]
text_input_chunk[:, -1] = inputs[0, -1]
text_embedding = clip_model(text_input_chunk)[0]

if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]

text_embeddings.append(text_embedding)
text_embeds = torch.concat(text_embeddings, axis=1).to(self.accelerator.device)
else:
text_embeds = clip_model(inputs)[0].to(self.accelerator.device)
else:
text_embeds = self.transformer.encode_text(texts)

demask_fn = self.transformer.forward_with_cond_scale

neg_text_embeds = None
if exists(negative_texts):
assert len(texts) == len(negative_texts)
# negative prompting, as in paper

neg_text_embeds = None
if exists(negative_texts):
assert len(texts) == len(negative_texts)

neg_text_embeds = self.transformer.encode_text(negative_texts)
demask_fn = partial(
self.transformer.forward_with_neg_prompt,
neg_text_embeds=neg_text_embeds,
)

if use_token_critic:
token_critic_fn = partial(
self.token_critic.forward_with_neg_prompt,
neg_text_embeds=neg_text_embeds,
)
if use_token_critic:
token_critic_fn = partial(
self.token_critic.forward_with_neg_prompt,
neg_text_embeds=neg_text_embeds,
)

if self.resize_image_for_cond_image:
if cond_images is None:
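The chunking logic added above appears to mirror the long-prompt handling used in diffusers community pipelines: CLIP's text encoder accepts at most 77 tokens (BOS + 75 content tokens + EOS), so longer prompts are split into 75-token windows, each chunk is re-capped with start/end tokens, encoded separately, and the per-chunk hidden states are concatenated along the sequence axis. A small sketch of the arithmetic, with an invented sequence length for illustration:

```python
# Rough illustration of the chunking arithmetic used above (example numbers).
# CLIP's text encoder context is 77 tokens = BOS + 75 content tokens + EOS.
seq_len = 150                                          # example tokenized prompt length
max_embeddings_multiples = (seq_len - 2) // (75 - 2)   # -> 2 chunks
for i in range(max_embeddings_multiples):
    start, end = i * (75 - 2), (i + 1) * (75 - 2) + 2
    print(f"chunk {i}: tokens[{start}:{end}]")         # chunk 0: [0:75], chunk 1: [73:148]
```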
@@ -565,14 +613,18 @@

ids = ids.scatter(1, masked_indices, self.mask_id)

logits, embed = demask_fn(
ids,
text_embeds=text_embeds,
self_cond_embed=self_cond_embed,
conditioning_token_ids=cond_ids,
cond_scale=cond_scale,
return_embed=True,
)
if self.clip is None:
logits, embed = demask_fn(
ids,
text_embeds=text_embeds,
self_cond_embed=self_cond_embed,
conditioning_token_ids=cond_ids,
cond_scale=cond_scale,
return_embed=True,
)
else:
embed = text_embeds
logits = text_embeds

self_cond_embed = embed if self.self_cond else None
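Usage-wise, the new `clip` and `clip_tokenizer` arguments are simply stored on the model and, when both are present, switch `generate()` onto the CLIP path above (which then also bypasses `demask_fn` and reuses `text_embeds` for both `logits` and `embed`). A hedged construction sketch follows; the `transformer` and `vae` instances are placeholders and the `vae` keyword is assumed from the upstream `MaskGit` signature, not from this PR.

```python
# Hedged construction sketch (not from this PR): wiring the new clip /
# clip_tokenizer arguments into MaskGit. transformer and vae are placeholders.
from transformers import CLIPTextModel, CLIPTokenizerFast

clip = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
clip_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")

maskgit = MaskGit(
    image_size=256,
    transformer=transformer,        # a MaskGitTransformer built with use_clip=True
    clip=clip,
    clip_tokenizer=clip_tokenizer,
    vae=vae,                        # placeholder VQ-VAE instance (assumed keyword)
)
# images = maskgit.generate(texts=["a photo of a cat"], cond_scale=3.0)
```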
