diff --git a/hidden_context/train_llm_vae_preference_model.py b/hidden_context/train_llm_vae_preference_model.py
index 06df6e9..d115658 100644
--- a/hidden_context/train_llm_vae_preference_model.py
+++ b/hidden_context/train_llm_vae_preference_model.py
@@ -15,7 +15,7 @@
     AutoModelForCausalLM,
 )
 from transformers.utils import PaddingStrategy
-from .vae_utils import VAETrainer, VAEModel
+from .vae_utils import VAETrainer, VAEModel, VQVAETrainer, VQVAE_Encoder
 from .train_llm_preference_model import (
     get_step_decay_lr_lambda,
@@ -254,6 +254,7 @@ def __call__(self, examples):
 trainer_classes: Dict[RewardModelType, Type[VAETrainer]] = {
     "vae": VAETrainer,
+    "vqvae": VQVAETrainer
 }
@@ -607,8 +608,12 @@ def up_sample_controversial(dataset, seed):
     embed_dim = script_args.embed_dim
     if not script_args.use_causal_lm:
+        if script_args.reward_model_type == "vqvae":
+            num_labels = 1
+        else:
+            num_labels = embed_dim
         model = AutoModelForSequenceClassification.from_pretrained(
-            script_args.model_name, num_labels=embed_dim, torch_dtype=torch.bfloat16
+            script_args.model_name, num_labels=num_labels, torch_dtype=torch.bfloat16
         )
         # We multiply the final linear layer's weights by 0.01 because this seems to
         # significantly stabilize training and lead to better optimization of the loss.
@@ -653,10 +658,32 @@ def up_sample_controversial(dataset, seed):
     # Train the model.
     latent_dim = script_args.latent_dim
     hidden_dim = script_args.hidden_dim
-    vae_model = VAEModel(embed_dim, hidden_dim, latent_dim, model,
-                         fixed_contexts=script_args.fixed_contexts,
-                         fixed_llm_embeddings=script_args.fixed_llm_embeddings,
-                         use_causal_lm=script_args.use_causal_lm,)
+    if script_args.reward_model_type == "vae":
+        vae_model = VAEModel(embed_dim, hidden_dim, latent_dim, model,
+                             fixed_contexts=script_args.fixed_contexts,
+                             fixed_llm_embeddings=script_args.fixed_llm_embeddings,
+                             use_causal_lm=script_args.use_causal_lm,)
+    elif script_args.reward_model_type == "vqvae":
+        if script_args.model_name == 'gpt2':
+            embed_dim = 768
+        if script_args.model_name == 'meta-llama/Llama-2-7b-hf':
+            embed_dim = 4096
+
+        if script_args.use_causal_lm:
+            context_dim = embed_dim
+        else:
+            context_dim = script_args.embed_dim
+
+        vae_model = VQVAE_Encoder(
+            script_args.n_embeddings,
+            embed_dim,
+            hidden_dim,
+            model,
+            context_dim=context_dim,
+            fixed_contexts=script_args.fixed_contexts,
+            fixed_llm_embeddings=script_args.fixed_llm_embeddings,
+            use_causal_lm=script_args.use_causal_lm,
+        )
 
     trainer = trainer_class(
         model=vae_model,
diff --git a/hidden_context/vae_utils.py b/hidden_context/vae_utils.py
index 35aeb63..1ba6311 100644
--- a/hidden_context/vae_utils.py
+++ b/hidden_context/vae_utils.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from transformers import Trainer, EvalPrediction
 import wandb
-
+import matplotlib.pyplot as plt
 
 class PairEncoder(nn.Module):
     """
@@ -458,3 +458,269 @@ def cyclical_setter(self, value):
         else:
             self.cyclical = value
         return
+
+
+class VQVAE_Encoder(nn.Module):
+    def __init__(self, n_embeddings, embed_dim, hidden_dim, llm_encoder, context_dim=None, commitment_cost=0.25, decay=0.999, epsilon=1e-5, fixed_contexts=False, fixed_llm_embeddings=False, use_causal_lm=False):
+        super(VQVAE_Encoder, self).__init__()
+        self.commitment_cost = commitment_cost
+        self.decay = decay
+        self.epsilon = epsilon
+
+        self.llm_encoder = llm_encoder
+        self.pair_encoder = PairEncoder(context_dim, hidden_dim, embed_dim)
+        self.sequence_encoder = SequenceEncoder(embed_dim, embed_dim)
+
+        #TODO: initialise using llms
+        mean_wte = llm_encoder.transformer.wte.weight.mean(0)
+        weights = torch.randn(size=(n_embeddings, embed_dim), dtype=mean_wte.dtype) + mean_wte
+        self.embedding = nn.Parameter(weights, requires_grad=True)
+        # self.register_buffer("embedding", embedding)
+        # self.register_buffer("ema_count", torch.zeros(n_embeddings))
+        # self.register_buffer("ema_weight", self.embedding.clone())
+
+        self.fixed_contexts = fixed_contexts
+        self.fixed_llm_embeddings = fixed_llm_embeddings
+        self.use_causal_lm = use_causal_lm
+
+    def encode_pair(self, e_c, e_r):
+        return self.pair_encoder(e_c, e_r)
+
+    def encode_sequence(self, sequences, seq_start_end):
+        e_z, _ = self.sequence_encoder(sequences, seq_start_end)
+        return e_z
+
+    # def discretize(self, x):
+    #     M, D = self.embedding.size()
+    #     x_flat = x.detach().reshape(-1, D)
+
+    #     distances = (-torch.cdist(x_flat, self.embedding, p=2)) ** 2
+
+    #     indices = torch.argmin(distances.float(), dim=-1)
+    #     quantized = F.embedding(indices, self.embedding)
+    #     quantized = quantized.view_as(x)
+    #     return quantized, indices
+
+    def retrieve_random_codebook(self, random_indices):
+        quantized = F.embedding(random_indices, self.embedding)
+        quantized = quantized.transpose(1, 3)
+
+        return quantized
+
+    def gt_forward(
+        self,
+        user_type,
+        seq_start_end,
+    ):
+        quantized = self.embedding[user_type.long()]
+        commitment_loss = torch.Tensor([0.0])
+        codebook_loss = torch.Tensor([0.0])
+        # import pdb; pdb.set_trace()
+        return quantized, commitment_loss, codebook_loss, user_type #, perplexity
+    def forward(
+        self,
+        context_chosen,
+        context_rejected,
+        seq_start_end,
+        user_type,
+        ground_truth_user_vector=True
+    ):
+        # import pdb; pdb.set_trace()
+        if ground_truth_user_vector:
+            return self.gt_forward(user_type, seq_start_end)
+        pair_embed = self.encode_pair(context_chosen, context_rejected)
+        x = self.encode_sequence(pair_embed, seq_start_end)
+        M, D = self.embedding.size()
+        x_flat = x.detach().reshape(-1, D)
+
+        distances = (-torch.cdist(x_flat, self.embedding, p=2)) ** 2
+
+        indices = torch.argmin(distances.float(), dim=-1)
+        encodings = F.one_hot(indices, M).float()
+        quantized = F.embedding(indices, self.embedding)
+        quantized = quantized.view_as(x)
+
+        #TODO: fix EMA loss
+        # if self.training:
+        #     self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
+        #     n = torch.sum(self.ema_count)
+        #     self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n
+
+        #     dw = torch.matmul(encodings.t(), x_flat)
+        #     self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
+        #     self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)
+
+        #TODO: look at how losses flow? do we need to pass in gradients to the embeddings or the codebook loss works?
+        codebook_loss = F.mse_loss(x_flat.detach(), quantized) * 0.1
+        e_latent_loss = F.mse_loss(x_flat, quantized.detach())
+        commitment_loss = self.commitment_cost * e_latent_loss * 0.1
+
+        quantized = x + (quantized - x).detach()
+
+        # avg_probs = torch.mean(encodings, dim=0)
+        # perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+        # import pdb; pdb.set_trace()
+        return quantized, commitment_loss, codebook_loss, indices #, perplexity
+
+    def save_model(self, path):
+        torch.save(self, path)
+
+class VQVAETrainer(VAETrainer):
+    def __init__(
+        self, *args, **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.pad_token = torch.tensor([50256])
+
+    def compute_loss(self, wrapped_model, inputs, return_outputs=False):
+        model = wrapped_model # .module
+        device = model.llm_encoder.device
+        batch_size = inputs["seq_start_end"].shape[0]
+        self.pad_token = self.pad_token.to(device)
+
+        embeddings_chosen = model.llm_encoder.transformer.wte(inputs["input_ids_chosen"])
+        embeddings_rejected = model.llm_encoder.transformer.wte(inputs["input_ids_rejected"])
+        attention_mask_padding = torch.ones_like(inputs["attention_mask_chosen"][:, None, 0])
+        attention_mask_chosen = torch.cat((attention_mask_padding, inputs["attention_mask_chosen"]), dim=-1)
+        attention_mask_rejected = torch.cat((attention_mask_padding, inputs["attention_mask_rejected"]), dim=-1)
+
+        seq_len_chosen = (inputs["input_ids_chosen"] != self.pad_token).sum(dim=1)
+        seq_len_rejected = (inputs["input_ids_rejected"] != self.pad_token).sum(dim=1)
+        seq_len = torch.cat([seq_len_chosen, seq_len_rejected])+1
+
+        if model.fixed_contexts:
+            contexts_embeddings_chosen = torch.tensor(inputs["contexts_embeddings_chosen"]).to(device).bfloat16()
+            contexts_embeddings_rejected = torch.tensor(inputs["contexts_embeddings_rejected"]).to(device).bfloat16()
+        else:
+            if model.use_causal_lm:
+                last_hidden_state_chosen = model.llm_encoder(
+                    input_ids=inputs["contexts_input_ids_chosen"],
+                    attention_mask=inputs["contexts_attention_mask_chosen"],
+                    output_hidden_states=True
+                ).hidden_states[-1]
+                masked_last_hidden_state_chosen = last_hidden_state_chosen * inputs[
+                    "contexts_attention_mask_chosen"].unsqueeze(-1)
+                token_length_chosen = torch.sum(inputs["contexts_attention_mask_chosen"], dim=1)
+                contexts_embeddings_chosen = torch.sum(masked_last_hidden_state_chosen,
+                                                       dim=1) / token_length_chosen.unsqueeze(-1)
+
+                last_hidden_state_rejected = model.llm_encoder(
+                    input_ids=inputs["contexts_input_ids_rejected"],
+                    attention_mask=inputs["contexts_attention_mask_rejected"],
+                    output_hidden_states=True
+                ).hidden_states[-1]
+                masked_last_hidden_state_rejected = last_hidden_state_rejected * inputs[
+                    "contexts_attention_mask_rejected"].unsqueeze(-1)
+                token_length_rejected = torch.sum(inputs["contexts_attention_mask_rejected"], dim=1)
+                contexts_embeddings_rejected = torch.sum(masked_last_hidden_state_rejected,
+                                                         dim=1) / token_length_rejected.unsqueeze(-1)
+            else:
+                contexts_embeddings_chosen = model.llm_encoder(
+                    inputs["contexts_input_ids_chosen"],
+                    inputs["contexts_attention_mask_chosen"]
+                )[0]
+                contexts_embeddings_rejected = model.llm_encoder(
+                    inputs["contexts_input_ids_rejected"],
+                    inputs["contexts_attention_mask_rejected"]
+                )[0]
+        seq_start_end = inputs["seq_start_end"]
+        user_type = torch.tensor(inputs["user_type"]).to(device).bfloat16()
+
+        quantized, commitment_loss, codebook_loss, indices = model(
+            contexts_embeddings_chosen,
+            contexts_embeddings_rejected,
+            seq_start_end,
+            user_type,
+            ground_truth_user_vector=False  # todo: set to True for debug usage
+        )
+        quantized = quantized.to(device).bfloat16()
+
+        embeddings_chosen = torch.cat((quantized[:, None], embeddings_chosen), dim=1)
+        embeddings_rejected = torch.cat((quantized[:, None], embeddings_rejected), dim=1)
+
+        output_dict = model.llm_encoder(
+            inputs_embeds=torch.concatenate(
+                [
+                    embeddings_chosen,
+                    embeddings_rejected,
+                ],
+                dim=0,
+            ),
+            attention_mask=torch.concatenate(
+                [
+                    attention_mask_chosen,
+                    attention_mask_rejected,
+                ],
+                dim=0,
+            ),
+            return_dict=True,
+            output_hidden_states=True
+        )
+
+        batch_indices = torch.arange(len(seq_len)).to(device)
+        hidden_states = output_dict["hidden_states"][-1][batch_indices, seq_len]
+        rewards = model.llm_encoder.score(hidden_states)
+
+        # rewards = rewards[0]
+        rewards_chosen = rewards[:batch_size]
+        rewards_rejected = rewards[batch_size:]
+
+        reproduction_loss = self.loss(rewards_chosen, rewards_rejected)
+        loss = reproduction_loss # + commitment_loss + codebook_loss
+
+        if return_outputs:
+            return loss, {
+                "rewards_chosen": rewards_chosen,
+                "rewards_rejected": rewards_rejected,
+                "commitment_loss": commitment_loss,
+                "codebook_loss": codebook_loss,
+                "z": quantized,
+                "user_type": user_type,
+                "indices": indices,
+                "embeddings": model.embedding
+            }
+        else:
+            accuracy = torch.mean((rewards_chosen > rewards_rejected).float())
+            self.log(
+                {
+                    "rewards_chosen": rewards_chosen.mean().item(),
+                    "rewards_rejected": rewards_rejected.mean().item(),
+                    "train_commitment_loss": commitment_loss.item(),
+                    "train_codebook_loss": codebook_loss.item(),
+                    "train_loss": loss.item(),
+                    "train_reproduction_loss": reproduction_loss.item(),
+                    "train_accuracy": accuracy
+                }
+            )
+            return loss
+
+    @classmethod
+    def compute_metrics(cls, eval_prediction: EvalPrediction):
+        rewards_chosen, rewards_rejected, commitment_loss, codebook_loss, z, user_type, indices, embeddings = (
+            eval_prediction.predictions
+        )
+        rewards_chosen = torch.from_numpy(rewards_chosen)
+        rewards_rejected = torch.from_numpy(rewards_rejected)
+
+        loss = cls.per_sample_loss(rewards_chosen, rewards_rejected)
+        accuracy = torch.mean((rewards_chosen > rewards_rejected).float())#torch.mean((loss < np.log(2)).float())
+
+        # import pdb; pdb.set_trace()
+        embeddings_table = wandb.Table(columns=list(range(z.shape[1])), data=embeddings)
+
+        unique_users = np.unique(user_type)
+        fig, axs = plt.subplots(1, len(unique_users), figsize=(20,5))
+        for i, uid in enumerate(unique_users):
+            user_indices = indices[np.argwhere(user_type == uid)]
+            axs[i].hist(user_indices)
+            axs[i].set_title(f"User {i}")
+        im = wandb.Image(fig)
+
+        return {
+            "reproduction_loss": loss.mean().item(),
+            "accuracy": accuracy.item(),
+            "commitment_loss": commitment_loss.mean().item(),
+            "codebook_loss": codebook_loss.mean().item(),
+            "embeddings_table": embeddings_table,
+            "latents": im
+        }
\ No newline at end of file
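
Note (not part of the patch): the quantization step in VQVAE_Encoder.forward follows the usual VQ-VAE recipe of nearest-neighbour codebook lookup, a codebook loss plus a commitment loss, and a straight-through estimator. As a reference while reviewing, below is a minimal, self-contained sketch of just that step; the VectorQuantizer name and the toy sizes in the demo are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    # Nearest-neighbour codebook lookup with a straight-through estimator.
    # n_embeddings / embed_dim mirror the VQVAE_Encoder constructor arguments.
    def __init__(self, n_embeddings, embed_dim, commitment_cost=0.25):
        super().__init__()
        self.commitment_cost = commitment_cost
        self.embedding = nn.Parameter(torch.randn(n_embeddings, embed_dim))

    def forward(self, x):
        # Squared Euclidean distance from each latent to every codebook entry.
        distances = torch.cdist(x, self.embedding, p=2) ** 2
        indices = torch.argmin(distances, dim=-1)
        quantized = F.embedding(indices, self.embedding)

        # Codebook loss pulls the selected codes toward the (detached) encoder
        # outputs; the commitment loss pulls the encoder toward its code.
        codebook_loss = F.mse_loss(quantized, x.detach())
        commitment_loss = self.commitment_cost * F.mse_loss(x, quantized.detach())

        # Straight-through estimator: the forward value is the quantized code,
        # but gradients flow back to x as if quantization were the identity.
        quantized = x + (quantized - x).detach()
        return quantized, indices, codebook_loss + commitment_loss


if __name__ == "__main__":
    vq = VectorQuantizer(n_embeddings=8, embed_dim=16)
    z = torch.randn(4, 16, requires_grad=True)
    quantized, indices, vq_loss = vq(z)
    vq_loss.backward()  # z gets the commitment gradient; the codebook gets the codebook-loss gradient
    print(indices.shape, quantized.shape, vq_loss.item())

Because the straight-through line routes the downstream gradient only to the encoder output, the codebook itself is trained solely by the codebook loss (or by an EMA update such as the commented-out block in forward). In compute_loss above, commitment_loss and codebook_loss are computed but commented out of loss, so self.embedding appears to receive no gradient from the preference objective as written; this is presumably what the "look at how losses flow" TODO refers to.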