diff --git a/moondream/finetune/README.md b/moondream/finetune/README.md deleted file mode 100644 index 80ca2e8a..00000000 --- a/moondream/finetune/README.md +++ /dev/null @@ -1,112 +0,0 @@ -# Finetuning Moondream 2B - -This readme will walk you through the process of finetuning the text and region encoders of the Moondream 2B model. - -> Make sure to run all commands from the root directory of the project. - -## Initial Setup - -### Clone and Setup Environment -```bash -git clone https://github.com/vikhyat/moondream -cd moondream -python -m venv .venv -source .venv/bin/activate -``` - -### Install Dependencies -```bash -# Install base requirements -pip install -r requirements.txt - -# Install finetuning specific dependencies -pip install safetensors datasets bitsandbytes tqdm wandb einops -``` - -## Downloading the Base Model - -Download `model.safetensors` from the [Hugging Face repository](https://huggingface.co/vikhyatk/moondream2/tree/main) and place it in the `models` directory as `moondream_base.safetensors`. - -```bash -# Create models directory -mkdir -p models - -# Download it using curl (run from root moondream directory) -wget https://huggingface.co/vikhyatk/moondream2/resolve/main/model.safetensors -``` - -## Weights & Biases - -We use Weights & Biases (wandb) to track finetuning progress. - -To set it up to track your runs, use `wandb login`. - -This will take you through creating an account if you don't have one setup already. Enter your API key and you're ready to go. - -## Finetuning the Text Encoder - -For this example, we will be teaching Moondream to describe images. - -Given the prompt: -`\n\nQuestion: Describe this image.\n\nAnswer:` - -We return a more detailed caption of the image then you would get from the base model. - -1. Double check that you've updated MODEL_PATH to point to the base moondream model in `moondream/finetune/finetune_text.py` -2. Double check that the save path ends in `.safetensors`, otherwise the run will fail. -> Navigate to line 150 in `moondream/finetune/finetune_text.py`, -``` # Add save path - save_file( - model.state_dict(), - "moondream_finetune.safetensors", // update this line ex: "models/moondream_text_finetuned.safetensors" - ) -``` - -### Start Text Finetuning -```bash -python -m moondream.finetune.finetune_text -``` - -The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_text_finetuned.safetensors`. - -### Test the Finetuned Text Encoder - -You can test the finetuned models performance with the following command (run from root moondream directory). - -This will return the caption of the image. - -```bash -# Remember to update the paths -python -m moondream.torch.sample --model [FINETUNED_MODEL_PATH] --image "[DATASET_DIRECTORY]/test/[IMAGE_NAME]" --prompt "\n\nQuestion: Describe this image.\n\nAnswer:" -``` - -## Finetuning the Region Encoder - -For this example, we will be teaching Moondream to detect railroad cracks in images of a railway. - -Our dataset trains our model such that, - -Given the prompt: -`\n\nDetect: \n\n` - -We are returned the coordinates of a detected crack in the following format: -```{'objects': [{'x_min': [X_MIN], 'y_min': [Y_MIN], 'x_max': [X_MAX], 'y_max': [Y_MAX]}]}``` - -### Setup Dataset Dependencies - -1. Update MODEL_PATH to point to the base moondream model. -5. Double check that the save path ends in `.safetensors`, otherwise the run will fail. -> Navigate to line 244 in `moondream/finetune/finetune_region.py`. -``` # Add save path - save_file( - model.state_dict(), - "moondream_finetune.safetensors", // update this line - ) -``` - -### Start Region Finetuning -```bash -python -m moondream.finetune.finetune_region -``` - -The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_region_finetuned.safetensors`. \ No newline at end of file diff --git a/moondream/finetune/__init__.py b/moondream/finetune/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/moondream/finetune/finetune_region.py b/moondream/finetune/finetune_region.py deleted file mode 100644 index c010c0fe..00000000 --- a/moondream/finetune/finetune_region.py +++ /dev/null @@ -1,264 +0,0 @@ -import torch -from torch.utils.data import Dataset -import torch.nn.functional as F -import math -from safetensors.torch import save_file -import datasets - -from tqdm import tqdm -from bitsandbytes.optim import AdamW -import wandb - -from ..torch.weights import load_weights_into_model -from ..torch.moondream import MoondreamModel, MoondreamConfig, text_encoder -from ..torch.text import _produce_hidden -from ..torch.region import ( - decode_coordinate, - decode_size, - encode_coordinate, - encode_size, -) - - -# This is a intended to be a basic starting point. Your optimal hyperparams and data may be different. -MODEL_PATH = "" -LR = 1e-5 -EPOCHS = 1 -GRAD_ACCUM_STEPS = 128 - - -def lr_schedule(step, max_steps): - x = step / max_steps - if x < 0.1: - return 0.1 * LR + 0.9 * LR * x / 0.1 - else: - return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1))) / 2 - - -def region_loss( - hidden_states: torch.Tensor, - w, - labels: torch.Tensor, - c_idx: torch.Tensor, - s_idx: torch.Tensor, -): - l_idx = torch.arange(len(labels)) - - c_idx = c_idx - 1 - c_hidden = hidden_states[:, c_idx, :] - c_logits = decode_coordinate(c_hidden, w) - c_labels = labels[(l_idx % 4) < 2] - - c_loss = F.cross_entropy( - c_logits.view(-1, c_logits.size(-1)), - c_labels, - ) - - s_idx = s_idx - 1 - s_hidden = hidden_states[:, s_idx, :] - s_logits = decode_size(s_hidden, w).view(-1, 1024) - s_labels = labels[(l_idx % 4) >= 2] - - s_loss = F.cross_entropy(s_logits, s_labels) - - return c_loss + s_loss - - -class WasteDetection(Dataset): - def __init__(self, split: str = "train"): - self.dataset: datasets.Dataset = datasets.load_dataset( - "moondream/waste_detection", split=split - ) - self.dataset = self.dataset.shuffle(seed=111) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - row = self.dataset[idx] - image = row["image"] - boxes = row["boxes"] - labels = row["labels"] - - objects = {} - for box, label in zip(boxes, labels): - objects.setdefault(label, []).append(box) - - flat_boxes = [] - class_names = [] - for label, box_list in objects.items(): - for b in box_list: - flat_boxes.append(b) - class_names.append(label) - - flat_boxes = torch.as_tensor(flat_boxes, dtype=torch.float16) - image_id = torch.tensor([idx], dtype=torch.int64) - - return { - "image": image, - "boxes": flat_boxes, - "class_names": class_names, - "image_id": image_id, - } - - -def main(): - if torch.cuda.is_available(): - torch.set_default_device("cuda") - elif torch.backends.mps.is_available(): - torch.set_default_device("mps") - - wandb.init( - project="moondream-ft", - config={ - "EPOCHS": EPOCHS, - "GRAD_ACCUM_STEPS": GRAD_ACCUM_STEPS, - "LR": LR, - }, - ) - - config = MoondreamConfig() - model = MoondreamModel(config) - load_weights_into_model(MODEL_PATH, model) - - # If you are struggling with GPU memory, try AdamW8Bit - optimizer = AdamW( - [{"params": model.region.parameters()}], - lr=LR, - betas=(0.9, 0.95), - eps=1e-6, - ) - - dataset = WasteDetection() - - total_steps = EPOCHS * len(dataset) // GRAD_ACCUM_STEPS - pbar = tqdm(total=total_steps) - - i = 0 - for epoch in range(EPOCHS): - for sample in dataset: - i += 1 - - with torch.no_grad(): - img_emb = model._run_vision_encoder(sample["image"]) - bos_emb = text_encoder( - torch.tensor( - [[model.config.tokenizer.bos_id]], device=model.device - ), - model.text, - ) - eos_emb = text_encoder( - torch.tensor( - [[model.config.tokenizer.eos_id]], device=model.device - ), - model.text, - ) - - boxes_by_class = {} - for box, cls in zip(sample["boxes"], sample["class_names"]): - boxes_by_class.setdefault(cls, []).append(box) - - total_loss = 0.0 - for class_name, boxes_list in boxes_by_class.items(): - with torch.no_grad(): - instruction = f"\n\nDetect: {class_name}\n\n" - instruction_tokens = model.tokenizer.encode(instruction).ids - instruction_emb = text_encoder( - torch.tensor([[instruction_tokens]], device=model.device), - model.text, - ).squeeze(0) - - cs_emb = [] - cs_labels = [] - c_idx = [] - s_idx = [] - for bb in boxes_list: - l_cs = len(cs_emb) - cs_emb.extend( - [ - encode_coordinate(bb[0].unsqueeze(0), model.region), - encode_coordinate(bb[1].unsqueeze(0), model.region), - encode_size(bb[2:4], model.region), - ] - ) - c_idx.extend([l_cs, l_cs + 1]) - s_idx.append(l_cs + 2) - - # Create coordinate bin labels - coord_labels = [ - min(max(torch.round(p * 1023), 0), 1023).item() for p in bb[:2] - ] - - # Create size bin labels using log-scale mapping - s_log2_bins = [] - for s_val in bb[2:4]: - s_val = float(s_val) - s_clamped = max(s_val, 1 / 1024) - s_log2 = math.log2(s_clamped) - mapped = (s_log2 + 10.0) / 10.0 * 1023.0 - s_bin = int(round(mapped)) - s_bin = max(min(s_bin, 1023), 0) - s_log2_bins.append(s_bin) - - # Combine coordinate and size bin labels - cs_labels.extend(coord_labels + s_log2_bins) - - if len(cs_emb) == 0: - continue - cs_emb = torch.stack(cs_emb) - - inputs_embeds = torch.cat( - [bos_emb, img_emb[None], instruction_emb, cs_emb[None], eos_emb], - dim=1, - ) - prefix = inputs_embeds.size(1) - cs_emb.size(0) - c_idx = torch.tensor(c_idx) + prefix - s_idx = torch.tensor(s_idx) + prefix - - hidden = _produce_hidden( - inputs_embeds=inputs_embeds, w=model.text, config=config.text - ) - - loss = region_loss( - hidden_states=hidden, - w=model.region, - labels=torch.tensor(cs_labels, dtype=torch.int64), - c_idx=c_idx, - s_idx=s_idx, - ) - total_loss += loss - - total_loss.backward() - - if i % GRAD_ACCUM_STEPS == 0: - optimizer.step() - optimizer.zero_grad() - - lr_val = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps) - for param_group in optimizer.param_groups: - param_group["lr"] = lr_val - pbar.set_postfix( - {"step": i // GRAD_ACCUM_STEPS, "loss": total_loss.item()} - ) - pbar.update(1) - wandb.log( - { - "loss/train": total_loss.item(), - "lr": optimizer.param_groups[0]["lr"], - } - ) - wandb.finish() - - # Replace with your desired output location. - save_file( - model.state_dict(), - "moondream_finetune.safetensors", - ) - - -if __name__ == "__main__": - """ - Replace paths with your appropriate paths. - To run: python -m moondream.finetune.finetune_region - """ - main() diff --git a/moondream/finetune/finetune_text.py b/moondream/finetune/finetune_text.py deleted file mode 100644 index 429b19f9..00000000 --- a/moondream/finetune/finetune_text.py +++ /dev/null @@ -1,163 +0,0 @@ -import torch -import torch.nn as nn -from torch.utils.data import Dataset -import math -from safetensors.torch import save_file - -from tqdm import tqdm -from datasets import load_dataset -from bitsandbytes.optim import AdamW8bit -import wandb - -from ..torch.weights import load_weights_into_model -from ..torch.moondream import MoondreamModel, MoondreamConfig, text_encoder -from ..torch.text import _produce_hidden, _lm_head, TextConfig - -# This is a intended to be a basic starting point for fine-tuning the text encoder. -# Your optimal hyperparams and data may be different. -MODEL_PATH = "" -# Your data should end with the eos token. Here is the textual representation. -ANSWER_EOS = "<|endoftext|>" -LR = 3e-6 -EPOCHS = 3 -GRAD_ACCUM_STEPS = 128 - - -def lr_schedule(step, max_steps): - x = step / max_steps - if x < 0.1: - return 0.1 * LR + 0.9 * LR * x / 0.1 - else: - return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1))) / 2 - - -def text_loss( - inputs_embeds: torch.Tensor, w: nn.Module, labels: torch.Tensor, config: TextConfig -): - _, q_len, _ = inputs_embeds.shape - hidden_BTC = _produce_hidden(inputs_embeds, w, config) - lm_logits = _lm_head(hidden_BTC, w) - - loss = None - if labels is not None: - _, _, l_len = labels.shape - shift_index = (q_len - l_len) - 1 - shifted_logits = lm_logits[..., shift_index:-1, :].contiguous() - shifted_labels = labels.contiguous() - loss = nn.CrossEntropyLoss()( - shifted_logits.view(-1, shifted_logits.size(-1)), - shifted_labels.view(-1), - ) - return loss - - -class DocciDataset(Dataset): - def __init__(self, split="train"): - self.data = load_dataset("google/docci", trust_remote_code=True)[split] - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - sample = self.data[idx] - description = sample["description"] - return { - "image": sample["image"], - "qa": { - "question": "\n\nQuestion: Describe this image.\n\nAnswer:", - "answer": f"{description}{ANSWER_EOS}", - }, - } - - -def main(): - if torch.cuda.is_available(): - torch.set_default_device("cuda") - elif torch.backends.mps.is_available(): - torch.set_default_device("mps") - - wandb.init( - project="moondream-ft", - config={ - "EPOCHS": EPOCHS, - "GRAD_ACCUM_STEPS": GRAD_ACCUM_STEPS, - "LR": LR, - }, - ) - - config = MoondreamConfig() - model = MoondreamModel(config) - load_weights_into_model(MODEL_PATH, model) - - optimizer = AdamW8bit( - [ - {"params": model.text.parameters()}, - ], - lr=LR, - betas=(0.9, 0.95), - eps=1e-6, - ) - - dataset = DocciDataset("train") - - total_steps = EPOCHS * len(dataset) // GRAD_ACCUM_STEPS - pbar = tqdm(total=total_steps) - - i = 0 - for epoch in range(EPOCHS): - for sample in dataset: - i += 1 - with torch.no_grad(): - img_emb = model._run_vision_encoder(sample["image"]) - bos_emb = text_encoder( - torch.tensor([[model.config.tokenizer.bos_id]], device=model.device), - model.text, - ) - question_tokens = model.tokenizer.encode(sample["qa"]["question"]).ids - question_emb = text_encoder( - torch.tensor([[question_tokens]], device=model.device), - model.text, - ).squeeze(0) - answer_tokens = model.tokenizer.encode(sample["qa"]["answer"]).ids - answer_emb = text_encoder( - torch.tensor([[answer_tokens]], device=model.device), - model.text, - ).squeeze(0) - inputs_embeds = torch.cat( - [bos_emb, img_emb[None], question_emb, answer_emb], dim=1 - ) - loss = text_loss( - inputs_embeds=inputs_embeds, - w=model.text, - labels=torch.tensor([[answer_tokens]], device=model.device), - config=config.text, - ) - - loss.backward() - - if i % GRAD_ACCUM_STEPS == 0: - optimizer.step() - optimizer.zero_grad() - - lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps) - for param_group in optimizer.param_groups: - param_group["lr"] = lr - pbar.set_postfix({"step": i // GRAD_ACCUM_STEPS, "loss": loss.item()}) - pbar.update(1) - wandb.log( - {"loss/train": loss.item(), "lr": optimizer.param_groups[0]["lr"]} - ) - wandb.finish() - # Add save path: ex. home/model.safetensors - save_file( - model.state_dict(), - "moondream_finetune.safetensors", - ) - - -if __name__ == "__main__": - """ - Replace paths with your appropriate paths. - To run: python -m moondream.finetune.finetune_text - """ - main() diff --git a/moondream/torch/config.py b/moondream/torch/config.py index 9856d6b1..544b7888 100644 --- a/moondream/torch/config.py +++ b/moondream/torch/config.py @@ -2,17 +2,26 @@ from typing import Dict, List, Optional +@dataclass(frozen=True) +class TextMoeConfig: + num_experts: int = 64 + start_layer: int = 4 + experts_per_token: int = 8 + expert_inner_dim: int = 1024 + + @dataclass(frozen=True) class TextConfig: dim: int = 2048 ff_dim: int = 8192 n_layers: int = 24 vocab_size: int = 51200 - max_context: int = 2048 + max_context: int = 4096 n_heads: int = 32 n_kv_heads: int = 32 prefix_attn: int = 730 group_size: Optional[int] = None + moe: Optional[TextMoeConfig] = TextMoeConfig() @dataclass(frozen=True) @@ -37,7 +46,6 @@ class RegionConfig: coord_out_dim: int = 1024 size_feat_dim: int = 512 size_out_dim: int = 2048 - inner_dim: int = 8192 group_size: Optional[int] = None diff --git a/moondream/torch/hf_moondream.py b/moondream/torch/hf_moondream.py index 526b6550..3432d869 100644 --- a/moondream/torch/hf_moondream.py +++ b/moondream/torch/hf_moondream.py @@ -27,11 +27,11 @@ def extract_question(text): class HfConfig(PretrainedConfig): _auto_class = "AutoConfig" - model_type = "moondream1" + model_type = "moondream3" def __init__(self, **kwargs): super().__init__(**kwargs) - self.config = {} + self.config = {"skills": ["query", "caption", "detect", "point"]} class HfMoondream(PreTrainedModel): diff --git a/moondream/torch/layers.py b/moondream/torch/layers.py index 6fee6e8f..8d57c521 100644 --- a/moondream/torch/layers.py +++ b/moondream/torch/layers.py @@ -146,6 +146,74 @@ def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Te return x +def moe_mlp( + x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int +) -> torch.Tensor: + B, T, C = x.shape + x = x.reshape(-1, C) + + # Router computation + router_logits = mlp_module.router(x) + topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1) + topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype) + num_tokens, top_k = topk_idxs.shape + + if T == 1: + w1_weight = mlp_module.fc1.weight + w2_weight = mlp_module.fc2.weight + + # Flatten to process all token-expert pairs at once + flat_idxs = topk_idxs.view(-1) # [T*A] + flat_weights = topk_weights.view(-1) # [T*A] + + # Select expert weights + w1_selected = w1_weight[flat_idxs] # [T*A, H, D] + w2_selected = w2_weight[flat_idxs] # [T*A, D, H] + + # Expand input for all token-expert pairs + x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D] + + # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H] + x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze( + -1 + ) # [T*A, H] + x1, g = x1_full.chunk(2, dim=-1) + x1 = F.gelu(x1) * (g + 1) + + # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D] + expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D] + + # Apply weights and reshape + weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D] + weighted_outs = weighted_outs.view(num_tokens, top_k, C) # [T, A, D] + + # Sum over experts + mlp_out = weighted_outs.sum(dim=1) # [T, D] + mlp_out = mlp_out.view(B, T, C) + + return mlp_out + else: + out = x.new_zeros(x.size()) + + for expert_id in range(mlp_module.fc1.weight.shape[0]): + token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True) + if token_pos.numel() == 0: + continue + + x_tok = x.index_select(0, token_pos) + gate_tok = topk_weights[token_pos, which_k] + + h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id]) + h, g = h_full.chunk(2, dim=-1) + h = F.gelu(h) * (g + 1) + y = F.linear(h, mlp_module.fc2.weight[expert_id]) + + y.mul_(gate_tok.unsqueeze(-1)) + out.index_add_(0, token_pos, y) + + return out.view(B, T, C) + + @dataclass class AttentionWeights: qkv: LinearWeights diff --git a/moondream/torch/lora.py b/moondream/torch/lora.py index e193c98b..b6fc4b91 100644 --- a/moondream/torch/lora.py +++ b/moondream/torch/lora.py @@ -21,9 +21,12 @@ def variant_cache_dir(): def cached_variant_path(variant_id: str): - cache_dir = variant_cache_dir() / variant_id + variant, *rest = variant_id.split("/", 1) + step = rest[0] if rest else "final" + + cache_dir = variant_cache_dir() / variant os.makedirs(cache_dir, exist_ok=True) - dest = cache_dir / "final.pt" + dest = cache_dir / f"{step}.pt" if dest.exists(): return dest diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index 7c02bd99..9cfcc0dc 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -6,6 +6,7 @@ from PIL import Image from dataclasses import dataclass from tokenizers import Tokenizer +from torch.nn.attention.flex_attention import create_block_mask from .config import MoondreamConfig from .image_crops import reconstruct_from_crops @@ -49,8 +50,8 @@ DEFAULT_MAX_TOKENS = 768 DEFAULT_TEMPERATURE = 0.5 -DEFAULT_TOP_P = 0.3 -DEFAULT_MAX_OBJECTS = 50 +DEFAULT_TOP_P = 0.9 +DEFAULT_MAX_OBJECTS = 150 @dataclass(frozen=True) @@ -78,6 +79,17 @@ def update(self, pos_ids, k, v): return kout, vout +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def get_mask_mod(mask_mod, offset): + def _mask_mod(b, h, q, kv): + return mask_mod(b, h, q + offset, kv) + + return _mask_mod + + class MoondreamModel(nn.Module): def __init__( @@ -99,32 +111,14 @@ def __init__( "coord_encoder": linear_cls( config.region.coord_feat_dim, config.region.dim, dtype=dtype ), - "coord_decoder": nn.ModuleDict( - { - "fc1": linear_cls( - config.region.dim, config.region.inner_dim, dtype=dtype - ), - "fc2": linear_cls( - config.region.inner_dim, - config.region.coord_out_dim, - dtype=dtype, - ), - } + "coord_decoder": linear_cls( + config.region.dim, config.region.coord_out_dim, dtype=dtype ), "size_encoder": linear_cls( config.region.size_feat_dim, config.region.dim, dtype=dtype ), - "size_decoder": nn.ModuleDict( - { - "fc1": linear_cls( - config.region.dim, config.region.inner_dim, dtype=dtype - ), - "fc2": linear_cls( - config.region.inner_dim, - config.region.size_out_dim, - dtype=dtype, - ), - } + "size_decoder": linear_cls( + config.region.dim, config.region.size_out_dim, dtype=dtype ), } ) @@ -145,10 +139,36 @@ def __init__( attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1 self.register_buffer("attn_mask", attn_mask, persistent=False) + self.use_flex_decoding = True + self._causal_block_mask = None + self._point_gen_indices = None + # Initialize KV caches. if setup_caches: self._setup_caches() + @property + def causal_block_mask(self): + # The things we do to deal with ZeroGPU... + if self._causal_block_mask is None: + self._causal_block_mask = create_block_mask( + causal_mask, + B=None, + H=None, + Q_LEN=self.config.text.max_context, + KV_LEN=self.config.text.max_context, + ) + return self._causal_block_mask + + @property + def point_gen_indices(self): + if self._point_gen_indices is None: + self._point_gen_indices = torch.tensor( + [self.config.tokenizer.coord_id, self.config.tokenizer.eos_id], + device=self.device, + ) + return self._point_gen_indices + def _setup_caches(self): c = self.config.text for b in self.text.blocks: @@ -186,9 +206,27 @@ def _decode_one_tok( attn_mask: torch.Tensor, pos_ids: torch.Tensor, lora: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, ): - hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora) - logits = lm_head(hidden, self.text) + if self.use_flex_decoding: + torch._assert(pos_ids.shape[-1] == 1, "Invalid position ID shape") + block_index = pos_ids // self.causal_block_mask.BLOCK_SIZE[0] + mask = self.causal_block_mask[:, :, block_index] + mask.seq_lengths = (1, mask.seq_lengths[1]) + mask.mask_mod = get_mask_mod(self.causal_block_mask.mask_mod, pos_ids[0]) + else: + mask = None + + hidden = text_decoder( + x, + self.text, + attn_mask, + pos_ids, + self.config.text, + lora=lora, + flex_block_mask_slice=mask, + ) + logits = lm_head(hidden, self.text, indices=lm_head_indices) return logits, hidden def compile(self): @@ -196,13 +234,41 @@ def compile(self): if isinstance(module, QuantizedLinear): module.unpack() - # TODO: vision_projection is not being compiled + # Initialize lazy properties to avoid first-call overhead + self.causal_block_mask + self.point_gen_indices + + # TODO: vision_projection and _prefill is not being compiled self._vis_enc = torch.compile(self._vis_enc, fullgraph=True) - self._prefill = torch.compile(self._prefill, fullgraph=True) self._decode_one_tok = torch.compile( self._decode_one_tok, fullgraph=True, mode="reduce-overhead" ) + # Warm up compiled methods with dummy forward passes + device = self.device + dtype = self.vision.pos_emb.dtype + with torch.no_grad(): + # Warmup vision encoder + dummy_crops = torch.randn(1, 3, 378, 378, device=device, dtype=dtype) + self._vis_enc(dummy_crops) + + # Warmup _decode_one_tok (both normal and point generation modes) + dummy_emb = torch.randn( + 1, 1, self.config.text.dim, device=device, dtype=dtype + ) + dummy_mask = torch.ones( + 1, 1, self.config.text.max_context, device=device, dtype=torch.bool + ) + dummy_pos_ids = torch.tensor([100], device=device, dtype=torch.long) + self._decode_one_tok(dummy_emb, dummy_mask, dummy_pos_ids, None) + self._decode_one_tok( + dummy_emb, + dummy_mask, + dummy_pos_ids, + None, + lm_head_indices=self.point_gen_indices, + ) + def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor: all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device) @@ -239,7 +305,7 @@ def encode_image( lora = ( variant_state_dict(settings["variant"], device=self.device) - if settings is not None and settings["variant"] is not None + if settings is not None and "variant" in settings else None ) @@ -253,7 +319,9 @@ def encode_image( ) inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1) mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :] - pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long) + pos_ids = torch.arange( + inputs_embeds.size(1), dtype=torch.long, device=self.device + ) self._prefill(inputs_embeds, mask, pos_ids, lora) return EncodedImage( @@ -306,7 +374,9 @@ def _prefill_prompt( attn_mask = self.attn_mask mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :] - pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long) + pos_ids = torch.arange( + pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device + ) hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora) logits_BV = lm_head(hidden_BC, self.text) @@ -360,7 +430,9 @@ def _generate_reasoning( text_token_chunks = [[]] grounding_chunks = [[]] - mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool) + mask = torch.zeros( + 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool + ) mask[:, :, :pos] = 1 pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long) generated_tokens = 0 @@ -469,7 +541,9 @@ def _generate_answer( ) def generator(next_token, pos): - mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool) + mask = torch.zeros( + 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool + ) mask[:, :, :pos] = 1 pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long) generated_tokens = 0 @@ -542,7 +616,7 @@ def query( self, image: Optional[Union[Image.Image, EncodedImage]] = None, question: str = None, - reasoning: bool = False, + reasoning: bool = True, spatial_refs: Optional[SpatialRefs] = None, stream: bool = False, settings: Optional[TextSamplingSettings] = None, @@ -584,10 +658,7 @@ def query( spatial_toks.extend([coord_id, coord_id, size_id]) prompt_tokens = [ - prompt_toks - + spatial_toks - + self.tokenizer.encode(question).ids - + self.config.tokenizer.templates["query"]["suffix"] + prompt_toks + spatial_toks + self.tokenizer.encode(question).ids ] if reasoning: @@ -660,7 +731,9 @@ def _generate_points( lora: Optional[dict] = None, ): out = [] - mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool) + mask = torch.zeros( + 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool + ) mask[:, :, :pos] = 1 pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long) @@ -726,9 +799,17 @@ def _generate_points( # Decode next token (x-coordinate, or eos) mask[:, :, pos], pos_ids[0] = 1, pos - logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora) + logits, hidden = self._decode_one_tok( + next_emb, + mask, + pos_ids, + lora, + lm_head_indices=self.point_gen_indices, + ) pos += 1 - next_token = torch.argmax(logits, dim=-1) + # Map back: index 0 -> coord_id, index 1 -> eos_id + next_token_idx = torch.argmax(logits, dim=-1) + next_token = self.point_gen_indices[next_token_idx] return out @@ -862,7 +943,10 @@ def _detect_gaze( mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :] pos_ids = torch.arange( - image.pos, image.pos + prompt_emb.size(1), dtype=torch.long + image.pos, + image.pos + prompt_emb.size(1), + dtype=torch.long, + device=self.device, ) hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None) logits = lm_head(hidden, self.text) diff --git a/moondream/torch/region.py b/moondream/torch/region.py index 9224e2c2..2d7686bf 100644 --- a/moondream/torch/region.py +++ b/moondream/torch/region.py @@ -4,8 +4,6 @@ from typing import List, Tuple, Union -from .layers import mlp - SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]] @@ -54,7 +52,7 @@ def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor: Returns: A single logit representing the predicted coordinate value (x or y) """ - return mlp(hidden_state, w.coord_decoder) + return w.coord_decoder(hidden_state) def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor: @@ -90,7 +88,7 @@ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor: A tensor containing logits for 1024 bins for width and height. Shape is (2, 1024) where the first dimension corresponds to width and height. """ - return mlp(hidden_state, w.size_decoder).view(2, -1) + return w.size_decoder(hidden_state).view(2, -1) def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor: diff --git a/moondream/torch/rope.py b/moondream/torch/rope.py index 737fa080..93a3bfbb 100644 --- a/moondream/torch/rope.py +++ b/moondream/torch/rope.py @@ -6,8 +6,7 @@ def precompute_freqs_cis( dim: int, end: int, - theta: float = 10000.0, - use_scaled: bool = False, + theta: float = 1500000.0, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim)) diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index 0a666579..2c35268f 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -35,7 +35,8 @@ config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) - model.to(device) + model.to(device, dtype=torch.bfloat16) + model.compile() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -50,6 +51,16 @@ if not args.benchmark: encoded_image = model.encode_image(image) + # Text query + text_query = "What is the capital of Washington, USA? Answer in JSON format." + print("Query:", text_query) + text_response = model.query(None, text_query, reasoning=True, stream=True) + print("Reasoning:", text_response["reasoning"]) + for t in text_response["answer"]: + print(t, end="", flush=True) + print() + print() + # Short caption print("Caption: short") for t in model.caption(encoded_image, "short", stream=True)["caption"]: @@ -64,6 +75,13 @@ print() print() + # Long caption + print("Caption: long") + for t in model.caption(encoded_image, "long", stream=True)["caption"]: + print(t, end="", flush=True) + print() + print() + # Query print("Query:", args.prompt) for t in model.query( @@ -104,21 +122,22 @@ image.save("detect.jpg") # Spatial query - print("Spatial query: What is this?") - for t in model.query( - encoded_image, - "What is this?", - spatial_refs=[ - [ - (obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]) - for obj in objs - ][0] - ], - stream=True, - )["answer"]: - print(t, end="", flush=True) - print() - print() + if len(objs) > 0: + print("Spatial query: What is this?") + for t in model.query( + encoded_image, + "What is this?", + spatial_refs=[ + [ + (obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]) + for obj in objs + ][0] + ], + stream=True, + )["answer"]: + print(t, end="", flush=True) + print() + print() # Point obj = "ear" @@ -134,32 +153,35 @@ print() # Spatial query - for o in ["hand", "ear", "face"]: - for k in [(objs, "hand"), (points, "ear")]: - print(f"Spatial query: Is this a {o}? ({k[1]})") - for t in model.query( - encoded_image, - f"Is this a {o}?", - spatial_refs=[ - [ - ( - (obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]) - if "x_min" in obj - else (obj["x"], obj["y"]) - ) - for obj in k[0] - ][0] - ], - )["answer"]: - print(t, end="", flush=True) - print() + if len(objs) > 0: + for o in ["hand", "ear", "face"]: + for k in [(objs, "hand"), (points, "ear")]: + print(f"Spatial query: Is this a {o}? ({k[1]})") + for t in model.query( + encoded_image, + f"Is this a {o}?", + spatial_refs=[ + [ + ( + ( + obj["x_min"], + obj["y_min"], + obj["x_max"], + obj["y_max"], + ) + if "x_min" in obj + else (obj["x"], obj["y"]) + ) + for obj in k[0] + ][0] + ], + )["answer"]: + print(t, end="", flush=True) + print() # Detect gaze model.detect_gaze(encoded_image, (0.5, 0.5)) elif model.device.type != "mps": - torch._dynamo.reset() - model.compile() - # Warmup runs for _ in tqdm(range(5), desc="Warmup"): encoded_image = model.encode_image(image) diff --git a/moondream/torch/text.py b/moondream/torch/text.py index 07cbafd4..9997dd2d 100644 --- a/moondream/torch/text.py +++ b/moondream/torch/text.py @@ -2,9 +2,10 @@ import torch.nn as nn from torch.nn import functional as F +from torch.nn.attention.flex_attention import flex_attention from typing import Optional -from .layers import layer_norm, mlp, QuantizedLinear +from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp from .rope import apply_rotary_emb, precompute_freqs_cis from .config import TextConfig @@ -22,7 +23,8 @@ def attn( n_heads: int, n_kv_heads: int, position_ids: torch.Tensor, - lora: Optional[dict], + lora: Optional[dict] = None, + flex_block_mask_slice=None, ): bsz, q_len, d_model = x.shape head_dim = d_model // n_heads @@ -33,21 +35,38 @@ def attn( q_dim = n_heads * head_dim kv_dim = n_kv_heads * head_dim q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1) - del qkv_out q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + if hasattr(w, "tau") and w.tau is not None: + tok_feat = F.gelu(qkv_out) + tok_q = torch.tanh(torch.matmul(tok_feat, w.tau["wq"].t())).permute(0, 2, 1) + tok_v = torch.tanh(torch.matmul(tok_feat, w.tau["wv"].t())).permute(0, 2, 1) + pos = position_ids.to(q.dtype) + 1 + tau_pos = 1 + ( + torch.sigmoid(w.tau["alpha"][:, None] * pos.log()) - 0.5 + ) # (H,S) + tau_q = (tok_q + tau_pos[None]).unsqueeze(-1) # (B,H,S,1) + tau_v = (tok_v + tau_pos[None]).unsqueeze(-1) + q = q * tau_q + v = v * tau_v + q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads) k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads) if kv_cache is not None: k, v = kv_cache.update(position_ids, k, v) - out = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads - ) + if flex_block_mask_slice is not None: + torch._assert(n_heads == n_kv_heads, "gqa not supported yet") + out = flex_attention(q, k, v, block_mask=flex_block_mask_slice) + else: + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads + ) + out = out.transpose(1, 2).reshape(bsz, q_len, d_model) out0 = w.proj(out) @@ -60,78 +79,14 @@ def attn( return out -def _attn( - x: torch.Tensor, - w: torch.Tensor, - freqs_cis: torch.Tensor, - attn_mask: torch.Tensor, - n_heads: int, - n_kv_heads: int, -): - bsz, q_len, d_model = x.shape - head_dim = d_model // n_heads - pos = 0 - - qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim) - q_dim = n_heads * head_dim - kv_dim = n_kv_heads * head_dim - - q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2) - k = ( - qkv_out[..., q_dim : q_dim + kv_dim] - .view(bsz, q_len, n_kv_heads, head_dim) - .transpose(1, 2) - ) - v = ( - qkv_out[..., q_dim + kv_dim :] - .view(bsz, q_len, n_kv_heads, head_dim) - .transpose(1, 2) - ) - - position_ids = torch.arange(pos, pos + q_len, dtype=torch.long) - q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads) - k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads) - out = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads - ) - out = out.transpose(1, 2).reshape(bsz, q_len, d_model) - out = w.proj(out) - return out - - -def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig): - hidden_BTC = inputs_embeds - - bsz, q_len, d_model = inputs_embeds.shape - attn_mask = torch.zeros(q_len, q_len) - attn_mask[:730, :730] = 1 - for i in range(730, q_len): - attn_mask[i, : i + 1] = 1 - attn_mask = attn_mask.to(dtype=torch.bool) - - for i, block in enumerate(w.blocks): - l_in = layer_norm(hidden_BTC, block.ln) - l_attn = _attn( - x=l_in, - w=block.attn, - freqs_cis=w.freqs_cis, - attn_mask=attn_mask, - n_heads=config.n_heads, - n_kv_heads=config.n_kv_heads, - ) - l_mlp = mlp(l_in, block.mlp) - hidden_BTC = hidden_BTC + l_attn + l_mlp - - return hidden_BTC - - def text_decoder( x: torch.Tensor, w: nn.Module, attn_mask: torch.Tensor, position_ids: torch.Tensor, config: TextConfig, - lora: Optional[dict], + lora: Optional[dict] = None, + flex_block_mask_slice=None, ): for i, block in enumerate(w.blocks): if lora is not None: @@ -153,24 +108,62 @@ def text_decoder( n_kv_heads=config.n_kv_heads, position_ids=position_ids, lora=attn_lora, + flex_block_mask_slice=flex_block_mask_slice, ) - l_mlp = mlp(l_in, block.mlp, lora=mlp_lora) + + if config.moe is not None and i >= config.moe.start_layer: + l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token) + else: + l_mlp = mlp(l_in, block.mlp, lora=mlp_lora) + x = x + l_attn + l_mlp return x -def lm_head(hidden_BTC: torch.Tensor, w: nn.Module): +def lm_head( + hidden_BTC: torch.Tensor, w: nn.Module, indices: Optional[torch.Tensor] = None +): hidden_BC = hidden_BTC[:, -1, :] hidden_BC = layer_norm(hidden_BC, w.post_ln) - logits = w.lm_head(hidden_BC) + if indices is not None: + # Only compute logits for specified token indices + logits = hidden_BC @ w.lm_head.weight[indices].T + w.lm_head.bias[indices] + else: + logits = w.lm_head(hidden_BC) return logits -def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module): - hidden_BTC = layer_norm(hidden_BTC, w.post_ln) - logits = w.lm_head(hidden_BTC) - return logits +def build_dense_mlp(d_model, d_ffn, dtype, linear_cls): + return nn.ModuleDict( + { + "fc1": linear_cls(d_model, d_ffn, dtype=dtype), + "fc2": linear_cls(d_ffn, d_model, dtype=dtype), + } + ) + + +def build_moe_mlp(d_model, d_ffn, n_experts, dtype): + # For GeGLU, fc1 needs to output 2 * d_ffn (for gating) + return nn.ModuleDict( + { + "router": nn.Linear(d_model, n_experts, dtype=dtype), + "fc1": nn.ParameterDict( + { + "weight": nn.Parameter( + torch.empty(n_experts, 2 * d_ffn, d_model, dtype=dtype) + ) + } + ), + "fc2": nn.ParameterDict( + { + "weight": nn.Parameter( + torch.empty(n_experts, d_model, d_ffn, dtype=dtype) + ) + } + ), + } + ) def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: @@ -190,21 +183,41 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: "proj": linear_cls( config.dim, config.dim, dtype=dtype ), - } - ), - "mlp": nn.ModuleDict( - { - "fc1": linear_cls( - config.dim, config.ff_dim, dtype=dtype - ), - "fc2": linear_cls( - config.ff_dim, config.dim, dtype=dtype + "tau": nn.ParameterDict( + { + "wq": nn.Parameter( + torch.empty( + config.n_heads, qkv_dim, dtype=dtype + ) + ), + "wv": nn.Parameter( + torch.empty( + config.n_heads, qkv_dim, dtype=dtype + ) + ), + "alpha": nn.Parameter( + torch.empty(config.n_heads, dtype=dtype) + ), + } ), } ), + "mlp": ( + build_moe_mlp( + config.dim, + config.moe.expert_inner_dim, + config.moe.num_experts, + dtype, + ) + if config.moe is not None + and layer_idx >= config.moe.start_layer + else build_dense_mlp( + config.dim, config.ff_dim, dtype, linear_cls + ) + ), } ) - for _ in range(config.n_layers) + for layer_idx in range(config.n_layers) ] ), "post_ln": nn.LayerNorm(config.dim, dtype=dtype), diff --git a/moondream/torch/weights.py b/moondream/torch/weights.py index ab17b7a4..8bc8796e 100644 --- a/moondream/torch/weights.py +++ b/moondream/torch/weights.py @@ -54,20 +54,12 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - "text_model.lm_head.linear.bias": model.text["lm_head"].bias, "region_model.coordinate_encoder.weight": region["coord_encoder"].weight, "region_model.coordinate_encoder.bias": region["coord_encoder"].bias, - "region_model.coordinate_decoder.fc1.weight": region["coord_decoder"][ - "fc1" - ].weight, - "region_model.coordinate_decoder.fc1.bias": region["coord_decoder"]["fc1"].bias, - "region_model.coordinate_decoder.fc2.weight": region["coord_decoder"][ - "fc2" - ].weight, - "region_model.coordinate_decoder.fc2.bias": region["coord_decoder"]["fc2"].bias, + "region_model.coordinate_head.weight": region["coord_decoder"].weight, + "region_model.coordinate_head.bias": region["coord_decoder"].bias, "region_model.size_encoder.weight": region["size_encoder"].weight, "region_model.size_encoder.bias": region["size_encoder"].bias, - "region_model.size_decoder.fc1.weight": region["size_decoder"]["fc1"].weight, - "region_model.size_decoder.fc1.bias": region["size_decoder"]["fc1"].bias, - "region_model.size_decoder.fc2.weight": region["size_decoder"]["fc2"].weight, - "region_model.size_decoder.fc2.bias": region["size_decoder"]["fc2"].bias, + "region_model.size_head.weight": region["size_decoder"].weight, + "region_model.size_head.bias": region["size_decoder"].bias, } for i in range(len(model.vision["blocks"])): @@ -93,6 +85,7 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - for i in range(len(model.text["blocks"])): prefix = f"text_model.transformer.h.{i}" blk = model.text["blocks"][i] + is_moe = hasattr(blk.mlp, "router") weight_map.update( { f"{prefix}.ln.weight": blk["ln"].weight, @@ -101,12 +94,29 @@ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) - f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias, f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight, f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias, - f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight, - f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, - f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight, - f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, + f"{prefix}.tau_wq": blk["attn"]["tau"]["wq"], + f"{prefix}.tau_wv": blk["attn"]["tau"]["wv"], + f"{prefix}.tau_alpha": blk["attn"]["tau"]["alpha"], } ) + if is_moe: + weight_map.update( + { + f"{prefix}.gate.weight": blk["mlp"]["router"].weight, + f"{prefix}.gate.bias": blk["mlp"]["router"].bias, + f"{prefix}.mlp.experts.weight": blk["mlp"]["fc1"].weight, + f"{prefix}.mlp.output_experts.weight": blk["mlp"]["fc2"].weight, + } + ) + else: + weight_map.update( + { + f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight, + f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, + f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight, + f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, + } + ) for key, tensor in weight_map.items(): tensor.data.copy_(get_tensor(key)) diff --git a/requirements.txt b/requirements.txt index c3cea0b0..533070d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,9 @@ -accelerate==0.32.1 -huggingface-hub==0.24.0 -Pillow==10.4.0 +torch==2.8.0 +Pillow-SIMD==9.5.0.post2 +transformers==4.56.1 pyvips-binary==8.16.0 pyvips==2.2.3 -torch==2.5.1 -transformers==4.44.0 +accelerate==1.10.1 gradio==4.38.1 # Needed for running evals