From f9dc5d68318180767136322512e63b6466dd3458 Mon Sep 17 00:00:00 2001 From: Akshay Kalkunte Date: Wed, 9 Oct 2024 23:22:28 +0000 Subject: [PATCH 1/4] stardoc_init --- .gitignore | 2 + examples/train_mistral.sh | 2 +- examples/train_stardoc.sh | 129 ++++ fast_llm/data/config.py | 32 +- fast_llm/data/data.py | 56 +- fast_llm/data/stardoc.py | 128 ++++ fast_llm/data/stardoc_data_utils/constants.py | 9 + .../data/stardoc_data_utils/conversation.py | 303 ++++++++ .../stardoc_data_utils/docowl_processor.py | 218 ++++++ .../docowl_stardoc_processor.py | 110 +++ fast_llm/data/stardoc_data_utils/mm_utils.py | 111 +++ fast_llm/data/stardoc_data_utils/utils.py | 33 + fast_llm/data/tokenizer.py | 66 ++ fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/language_model/config.py | 13 + fast_llm/layers/language_model/head.py | 1 + fast_llm/layers/multimodal_model/adapter.py | 58 ++ fast_llm/layers/multimodal_model/config.py | 76 +++ .../layers/multimodal_model/image_encoder.py | 100 +++ .../multimodal_language_embedding.py | 99 +++ fast_llm/models/auto.py | 3 + fast_llm/models/gpt/model.py | 2 +- fast_llm/models/stardoc/__init__.py | 0 fast_llm/models/stardoc/config.py | 102 +++ fast_llm/models/stardoc/conversion.py | 644 ++++++++++++++++++ fast_llm/models/stardoc/model.py | 348 ++++++++++ fast_llm/models/stardoc/trainer.py | 63 ++ fast_llm/tools/train.py | 3 +- run_multimodal.sh | 88 +++ set_env_stardoc.sh | 98 +++ 30 files changed, 2891 insertions(+), 8 deletions(-) create mode 100755 examples/train_stardoc.sh create mode 100644 fast_llm/data/stardoc.py create mode 100644 fast_llm/data/stardoc_data_utils/constants.py create mode 100644 fast_llm/data/stardoc_data_utils/conversation.py create mode 100644 fast_llm/data/stardoc_data_utils/docowl_processor.py create mode 100644 fast_llm/data/stardoc_data_utils/docowl_stardoc_processor.py create mode 100644 fast_llm/data/stardoc_data_utils/mm_utils.py create mode 100644 fast_llm/data/stardoc_data_utils/utils.py create mode 100644 fast_llm/layers/multimodal_model/adapter.py create mode 100644 fast_llm/layers/multimodal_model/config.py create mode 100644 fast_llm/layers/multimodal_model/image_encoder.py create mode 100644 fast_llm/layers/multimodal_model/multimodal_language_embedding.py create mode 100644 fast_llm/models/stardoc/__init__.py create mode 100644 fast_llm/models/stardoc/config.py create mode 100644 fast_llm/models/stardoc/conversion.py create mode 100644 fast_llm/models/stardoc/model.py create mode 100644 fast_llm/models/stardoc/trainer.py create mode 100644 run_multimodal.sh create mode 100755 set_env_stardoc.sh diff --git a/.gitignore b/.gitignore index 41502c68f..ec960a900 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,5 @@ venv.bak/ # Project specifics /.idea/ /.vscode/ +output/ +stardoc_hf_model/ \ No newline at end of file diff --git a/examples/train_mistral.sh b/examples/train_mistral.sh index 5745e38c8..e4b03c475 100644 --- a/examples/train_mistral.sh +++ b/examples/train_mistral.sh @@ -61,7 +61,7 @@ export TRAINING_ARGS="\ " export PERFORMANCE_ARGS="\ ---micro_batch_size=1 \ +--micro_batch_size=4 \ --training_dtype=bf16 \ --zero_stage=2 \ --num_workers=8 \ diff --git a/examples/train_stardoc.sh b/examples/train_stardoc.sh new file mode 100755 index 000000000..ea9865ac0 --- /dev/null +++ b/examples/train_stardoc.sh @@ -0,0 +1,129 @@ +# Required or optional environment variables +export PROJECT_DIR="/mnt/akshay/stardoc-FastLLM/Fast-LLM/output" +export PROJECT_NAME="stardoc-debug" +export PROJECT_VERSION="1.0" +# export 
DATA_PATH_LIST= +export DATA_PATH="/mnt/stardoc/datasets/save_hf/BigDoc-MultiTurn-v0.3" +# export DATA_PATH_JSON= +# export PRETRAINED_MISTRAL_PATH= +# export PRETRAINED_MIXTRAL_PATH= +export PRETRAINED_STARDOC_PATH="/mnt/akshay/stardoc-FastLLM/Fast-LLM/stardoc_hf_model/stardoc_checkpoint" +export TOKENIZER_PATH="/mnt/core_llm/models/mistral/HF/Mistral-7B-v0.3" + +export PYTHONHASHSEED=12345 + +export CMD_ARGS="fast-llm train stardoc" +# export CMD_ARGS="python /mnt/akshay/stardoc-FastLLM/Fast-LLM/fast_llm/tools/train.py stardoc" + +export MODEL_ARGS_PRETRAINED="\ +--pretrained_checkpoint_type=huggingface \ +--pretrained_checkpoint_path=$PRETRAINED_STARDOC_PATH \ +--use_pretrained_config=1 \ +" + +export MODEL_ARGS_ARCHITECTURE="\ +--num_layers=32 \ +--hidden_size=4096 \ +--vocab_size=32000 \ +--num_attention_heads=32 \ +--head_groups=8 \ +--add_linear_biases=0 \ +--ffn_hidden_size=14336 \ +--kv_channels=128 \ +--use_rotary_embeddings=1 \ +--rotary_embedding_scale=-9.210340371976184 \ +--gated=1 \ +--activation_type=silu \ +--normalization_type=rms_norm \ +--tie_word_embeddings=0 \ +--window_size=8192 \ +" + +export MULTIMODAL_ARGS ="\ +--image_encoder_hidden_size=1024 \ +--num_image_tokens=256 \ +--max_num_images=7 \ +--image_encoder_type=clip \ +" + +export DATA_ARGS="\ +--split=9998,2,0 \ +--dataset_type=stardoc \ +--dataset_source=multimodal \ +--data_path=$DATA_PATH \ +--tokenizer_type=PreTrainedTokenizer \ +--tokenizer_path=$TOKENIZER_PATH \ +" + +export TRAINING_ARGS="\ +--batch_size=32 \ +--sequence_length=8192 \ +--train_iters=500000 \ +--weight_decay=0.1 \ +--adam_beta1=0.9 \ +--adam_beta2=0.95 \ +--clip_grad=1.0 \ +--lr=0.0001 \ +--lr_warmup_iters=1000 \ +--lr_decay_style=cosine \ +--lr_decay_iters=500000 \ +--min_lr=0.000003 \ +" + +export PERFORMANCE_ARGS="\ +--micro_batch_size=4 \ +--training_dtype=bf16 \ +--zero_stage=2 \ +--num_workers=8 \ +" + +export MONITORING_ARGS="\ +--validation_iters=25 \ +--validation_interval=1000 \ +--log_interval=10 \ +--log_offset=0 \ +--checkpoint_interval=500 \ +--max_checkpoints=5 \ +--export_interval=25000 \ +" +# --wandb_status_interval=25000 \ +# --wandb_entity_name=$WANDB_ENTITY_NAME \ +# --wandb_project_name=$PROJECT_NAME \ +# --wandb_group_name=$PROJECT_VERSION \ +# " + +export ALL_ARGS="\ +$CMD_ARGS \ +$MODEL_ARGS_PRETRAINED \ +$MULTIMODAL_ARGS \ +$DATA_ARGS \ +$TRAINING_ARGS \ +$PERFORMANCE_ARGS \ +$MONITORING_ARGS \ +" + +export PROFILE_ARGS="\ +--profile_cuda=1 \ +--profile_skip=10 \ +--profile_wait=95 \ +--profile_warmup=2 \ +--profile_cycles=3 \ +--profile_export=1 \ +" + + +run_local () { # run(name, num_gpus, base_cmd) + echo $1 $2 $3 + export TORCHRUN="torchrun --nproc-per-node=$2 --nnodes=1 --no-python" + $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 +} + +run_c10d () { # run(name, num_nodes, base_cmd) + echo $1 $2 $3 + export TORCHRUN="torchrun --nproc-per-node=8 --nnodes=$2 --no-python --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR" + $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 +} + +run_local debug 8 "$ALL_ARGS" +# run_c10d mistral_example 16 "$ALL_ARGS" +# run_c10d mixtral_example 16 "$ALL_ARGS $MIXTRAL_ARGS --train_iters=50" diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 357f0e568..ee755c9cc 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -10,6 +10,7 @@ class DatasetType(str, enum.Enum): """ gpt = "gpt" + stardoc = "stardoc" class DatasetSource(str, enum.Enum): @@ -23,6 +24,7 @@ class DatasetSource(str, 
enum.Enum): file = "file" sample = "sample" random = "random" + multimodal = "multimodal" class MultiprocessingContext(str, enum.Enum): @@ -110,7 +112,7 @@ def _validate(self): EOD = "<|endoftext|>" TokenizerFromFile = "TokenizerFromFile" - +PreTrainedTokenizer = "PreTrainedTokenzier" @config_class() class TokenizerConfig(Config): @@ -123,13 +125,39 @@ class TokenizerConfig(Config): default="TokenizerFromFile", desc="Unused.", hint=FieldHint.deprecated, - valid=check_field(Assert.eq, TokenizerFromFile), ) tokenizer_file: str | None = Field( default=None, desc="Path to the tokenizer file.", hint=FieldHint.core, ) + tokenizer_path: str | None = Field( + default=None, + desc="Path to pretrained tokenizer", + hint=FieldHint.core, + ) + # max_seq_length: int | None = Field( + # default=8192, + # desc="Max. sequence length", + # hint=FieldHint.core, + # ) + + # def build_tokenizer(self, max_sequence_length=8192): + # self.validate() + # """Initialize tokenizer.""" + # log_main_rank(f"> building {self.tokenizer_type}, {self.tokenizer_type or self.tokenizer_file} tokenizer ...") + + # # Select and instantiate the tokenizer. + # if self.tokenizer_type == "TokenizerFromFile": + # assert self.tokenizer_file is not None + # tokenizer = Tokenizer(self.tokenizer_file, special_tokens=[EOD]) + # elif self.tokenizer_type == "PreTrainedTokenizer": + # assert self.tokenizer_path is not None + # tokenizer = HuggingfacePreTrainedTokenizer(self.tokenizer_path, max_seq_length=max_sequence_length) + # else: + # raise NotImplementedError(f"{self.tokenizer_type} tokenizer is not implemented.") + + # return tokenizer @config_class() diff --git a/fast_llm/data/data.py b/fast_llm/data/data.py index ba11477d9..f270f8044 100644 --- a/fast_llm/data/data.py +++ b/fast_llm/data/data.py @@ -9,11 +9,12 @@ import torch import torch.utils.data -from fast_llm.data.config import DataConfig, DatasetSource, DatasetType +from fast_llm.data.config import DataConfig, DatasetSource, DatasetType, EOD from fast_llm.data.dataset import BlendedDataset, SampledDataset, Sampler from fast_llm.data.gpt import DummyGPTDataset, GPTDataset, GPTSampledDataset +from fast_llm.data.stardoc import StarDocDataset from fast_llm.data.mmap import MMapIndexedDataset -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.tokenizer import Tokenizer, HuggingfacePreTrainedTokenizer from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.run.run import get_dataset_cache_dir, is_main_rank, log_main_rank @@ -86,6 +87,12 @@ def __init__( assert len(dataset_prefixes) == len(set(dataset_prefixes)) dataset_weights = normalize_probs([float(x) for x in self._config.data_path[::2]]) self._build_and_sample_dataset = self._build_and_sample_gpt_dataset + elif self._config.dataset_source == DatasetSource.multimodal: + # FastLLM Split logic is overriden. 
Huggingface dataset defines the split + Assert.eq(len(self._config.data_path), 1) + Assert.eq(self._config.dataset_type, DatasetType.stardoc) + dataset_prefixes, dataset_weights = [None], [1.0] + self._build_and_sample_dataset = self._build_and_sample_stardoc_dataset elif self._config.dataset_source == DatasetSource.sample: Assert.eq(len(self._config.data_path), 1) dataset_prefixes, dataset_weights = [self._config.data_path[0].strip()], [1.0] @@ -115,6 +122,23 @@ def __init__( for name, prefix in zip(dataset_names, dataset_prefixes) } self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)} + + def build_tokenizer(self, max_sequence_length): + """Initialize tokenizer.""" + log_main_rank(f"> building {self._config.tokenizer.tokenizer_type}, {self._config.tokenizer.tokenizer_type or self._config.tokenizer.tokenizer_file} tokenizer ...") + + # Select and instantiate the tokenizer. + if self._config.tokenizer.tokenizer_type == "TokenizerFromFile": + assert self._config.tokenizer.tokenizer_file is not None + tokenizer = Tokenizer(self._config.tokenizer) + elif self._config.tokenizer.tokenizer_type == "PreTrainedTokenizer": + assert self._config.tokenizer.tokenizer_path is not None + tokenizer = HuggingfacePreTrainedTokenizer(self._config.tokenizer, max_sequence_length=max_sequence_length) + else: + raise NotImplementedError(f"{self.config.tokenizer.tokenizer_type} tokenizer is not implemented.") + + return tokenizer + def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]): """ @@ -123,7 +147,7 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int """ Assert.leq(set(samples_per_phase), set(self._phase_split)) log_main_rank(f"Preparing {self._num_datasets} datasets. 
This may take several minutes.") - self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.fim_rate > 0 else None + self._tokenizer = self.build_tokenizer(self._max_sequence_length) if (self._config.fim.fim_rate > 0 or self._config.dataset_type == DatasetType.stardoc) else None self._distributed = distributed self._cache_dir = get_dataset_cache_dir() self._samples_per_phase = samples_per_phase @@ -228,3 +252,29 @@ def _build_and_sample_dummy_dataset(self, name: str, dataset_samples_per_phase: ) for phase in dataset_samples_per_phase } + + def _build_and_sample_stardoc_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]): + #TODO: Only training split implemented for now + sampled_dataset = {} + sampled_dataset[PhaseType.training] = StarDocDataset( + im_size=224, + num_samples=-1, + num_im_tokens=256, + transforms=False, + multi_imgs=True, + split="train", + tokenizer=self._tokenizer, + config=self._config, + ) + sampled_dataset[PhaseType.validation] = StarDocDataset( + im_size=224, + num_samples=-1, + num_im_tokens=256, + transforms=False, + multi_imgs=True, + split="val", + tokenizer=self._tokenizer, + config=self._config, + ) + + return sampled_dataset diff --git a/fast_llm/data/stardoc.py b/fast_llm/data/stardoc.py new file mode 100644 index 000000000..415c35c41 --- /dev/null +++ b/fast_llm/data/stardoc.py @@ -0,0 +1,128 @@ +import os +import io +import logging +import time + +import torch +import torch.nn as nn +from PIL import Image +from torch.utils.data import Dataset +from datasets import load_dataset, load_from_disk + +from fast_llm.data.config import DataConfig +from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.stardoc_data_utils.docowl_processor import DocProcessor +from fast_llm.data.stardoc_data_utils.utils import ( + convert_queries_and_annotations_to_messages, + image_loading_function, +) +from fast_llm.data.stardoc_data_utils.docowl_stardoc_processor import docowl_text_preprocess_v1 + +logger = logging.getLogger(__name__) + +class StarDocDataset(Dataset): + def __init__( + self, + im_size: int = 224, + num_samples: int = -1, + num_im_tokens: int = 256, + transforms: bool = False, + multi_imgs: bool = True, + split: str = "train", + tokenizer: Tokenizer | None = None, + config: DataConfig | None = None, + ): + self.im_size = im_size + self.transforms = transforms + self.num_samples = num_samples + self.num_im_tokens = num_im_tokens + self.multi_imgs = multi_imgs + self.tokenizer = tokenizer + self.split=split + + # Use DocOwl processor + self.processor = DocProcessor(image_size=self.im_size, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=True) + + dataset_path = config.data_path[0] + + # TODO: config validation issue + multimodal_load_local = True + if multimodal_load_local: + # Load from a locally cached copy of the dataset + self.data_dict = load_from_disk(dataset_path) + self.data = self.data_dict[split] + else: + # Load the required spit from HF + # TODO: configurable cache_dir + self.data = load_dataset(dataset_path, split=split, cache_dir="/mnt/core_llm/cache/", num_proc=os.cpu_count()-1) + + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + sample = self.data[idx] + images = sample.get('image', []) + + if self.multi_imgs and not isinstance(images, list): + images = [images] + + images = image_loading_function(self.data[idx]["image"]) + + sample_id = sample["sample_id"] + dataset_name = 
sample["dataset_name"] + queries = sample["queries"] + annotations = sample["annotations"] + task_name = sample["task_name"] + + if images[0].size[0] < 10 or images[0].size[1] < 10: + logger.error("Dummy images with small resolution of < 10x10 seen. Handling these is not implemented") + + sample_tokenized_buffer = [] + labels = [] + + # Add BOS token at the beginning of the sample + sample_tokenized_buffer.append(self.tokenizer.bos_token_id) + + # Dummy image token ID + dummy_image_token_id = self.tokenizer.tokenize("[control_8]", add_special_tokens=False) + assert len(dummy_image_token_id) == 1 + + # tokenized IDs for "USER:" and "ASSISTANT:" + user_ids = self.tokenizer.tokenize("USER: ", add_special_tokens=False) + assistant_ids = self.tokenizer.tokenize(" ASSISTANT: ", add_special_tokens=False) + sample_tokenized_buffer.extend(user_ids) + + # Add dummy tokens for all image tokens + if len(images) > 0: + # Get all the crops and process them + all_images, _, processed_query = self.processor(images=images, query="[control_8]") + crop_splits = processed_query.split("[control_8]")[:-1] + assert len(crop_splits) == len(all_images) + for crop_split_part in crop_splits: + sample_tokenized_buffer.extend(self.tokenizer.tokenize(crop_split_part.strip(), add_special_tokens=False)) + sample_tokenized_buffer.extend(dummy_image_token_id * self.num_im_tokens) + + # Don't learn on any image tokens + [labels.append(-200) for x in range(len(sample_tokenized_buffer))] + + assert(len(queries) == len(annotations)) + for i, (q, a) in enumerate(zip(queries, annotations)): + if i>0: + sample_tokenized_buffer.extend(user_ids) + sample_tokenized_buffer.extend(self.tokenizer.tokenize(q, add_special_tokens=False)) + sample_tokenized_buffer.extend(assistant_ids) + sample_tokenized_buffer.extend(self.tokenizer.tokenize(a, add_special_tokens=False)) + + # Add EOS token at the end of the sample + sample_tokenized_buffer.append(self.tokenizer.eos_token_id) + labels.extend(sample_tokenized_buffer[len(labels):len(sample_tokenized_buffer)]) + assert len(sample_tokenized_buffer) == len(labels) + + return { + 'input_ids': torch.tensor(sample_tokenized_buffer), + 'labels': torch.tensor(labels), + 'images': all_images, + } \ No newline at end of file diff --git a/fast_llm/data/stardoc_data_utils/constants.py b/fast_llm/data/stardoc_data_utils/constants.py new file mode 100644 index 000000000..db134ee6a --- /dev/null +++ b/fast_llm/data/stardoc_data_utils/constants.py @@ -0,0 +1,9 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "./demo_logs" + +# Model Constants +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 +DEFAULT_IMAGE_TOKEN = "<|image|>" \ No newline at end of file diff --git a/fast_llm/data/stardoc_data_utils/conversation.py b/fast_llm/data/stardoc_data_utils/conversation.py new file mode 100644 index 000000000..c5a9ef4b2 --- /dev/null +++ b/fast_llm/data/stardoc_data_utils/conversation.py @@ -0,0 +1,303 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple +from fast_llm.data.stardoc_data_utils.constants import DEFAULT_IMAGE_TOKEN + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + TWO_NO_SYS = auto() + MPT = auto() + PLAIN = auto() + LLAMA_2 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: 
str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + # init_msg = init_msg[0].replace("", "").strip() + # if 'mmtag' in self.version: + # messages[0] = (init_role, init_msg) + # messages.insert(0, (self.roles[0], "")) + # messages.insert(1, (self.roles[1], "Received.")) + # else: + # messages[0] = (init_role, "\n" + init_msg) + init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip() + messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO_NO_SYS: + seps = [self.sep, self.sep2] + ret = "" + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + from PIL import Image + msg, image, image_process_mode = msg + if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = 
expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize336": + image = image.resize((336, 336)) + elif image_process_mode == "Resize": + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if longest_edge != max(image.size): + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + + if return_pil: + images.append(image) + else: + buffered = BytesIO() + image.save(buffered, format="PNG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + images.append(img_b64_str) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + msg, image, image_process_mode = msg + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'user upload image' + msg = img_str + msg.replace('<|image|>', '').strip() + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_vicuna_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ("Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. 
Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_mplug_owl2 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO_NO_SYS, + sep=" ", + sep2="", +) + +# default_conversation = conv_vicuna_v1 +default_conversation = conv_mplug_owl2 +conv_templates = { + "default": conv_vicuna_v0, + "v0": conv_vicuna_v0, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "mplug_owl2": conv_mplug_owl2, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) \ No newline at end of file diff --git a/fast_llm/data/stardoc_data_utils/docowl_processor.py b/fast_llm/data/stardoc_data_utils/docowl_processor.py new file mode 100644 index 000000000..427f7825e --- /dev/null +++ b/fast_llm/data/stardoc_data_utils/docowl_processor.py @@ -0,0 +1,218 @@ +from einops import rearrange, repeat +import torch +from torchvision import transforms +from PIL import Image, ImageFile +import random +from torchvision.ops.boxes import box_area + +from torchvision.transforms.transforms import InterpolationMode +from torchvision.transforms import functional as F +import numpy as np +#from icecream import ic + +ImageFile.LOAD_TRUNCATED_IMAGES = True +ImageFile.MAX_IMAGE_PIXELS = None +Image.MAX_IMAGE_PIXELS = None + +def box_iou(boxes1, area1, boxes2, eps=1e-5): + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / (union+eps) + return iou, union + +def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5): + # anchors x1 y1 x2 y2 + + # image_size: (h, w) + # xyxy + input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0) + + boxes1 = anchors + boxes2 = input_image_bbox + boxes3 = anchors.clone() + # y2 + boxes3[:,3] = input_image_size[0]/input_image_size[1]*anchors[:,2] # 用于算分辨率无关的iou + + area1 = anchors_areas + + iou, _ = 
box_iou(boxes1, area1, boxes2) + iou = iou.squeeze(1) + shape_iou, _ = box_iou(boxes1, area1, boxes3) + shape_iou = shape_iou.diag() + # 优先匹配形状接近 再匹配分辨率接近 + index = torch.argmax(shape_iou*100+iou,dim=0) + return index + +class AnchorResize(torch.nn.Module): + + def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None): + super().__init__() + # xyxy + self.anchors = torch.tensor( + [[0, 0, _[1]*image_size[1], _[0]*image_size[0]] + for _ in anchors], requires_grad=False + ) + + self.anchor_areas = box_area(self.anchors) + + self.interpolation = interpolation + self.antialias = antialias + + def forward(self, img, skip_resize=False): + """ + Args: + img (PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0])) + target_size = self.anchors[selected_anchor][2:].tolist() # w,h + if skip_resize: + # for debug + return selected_anchor + return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor + + def __repr__(self) -> str: + detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})" + return f"{self.__class__.__name__}{detail}" + +grid_dict = { + 'grid_1':[ + (1,1)], + 'grid_4':[ + (1,1), + (1,2),(2,1), + (1,3),(3,1), + (2,2),(1,4),(4,1)], + 'grid_9':[ + (1,1), + (1,2),(2,1), + (1,3),(3,1), + (2,2),(1,4),(4,1), + (1,5),(5,1), + (1,6),(6,1),(2,3),(3,2), + (1,7),(7,1), + (4,2),(2,4),(1,8),(8,1), + (3,3),(1,9),(9,1)], + 'grid_3x3':[ + (3,3)], + 'grid_20':[ + (1, 1), + (1, 2), (2, 1), + (1, 3), (3, 1), (1, 4), (2, 2), (4, 1), + (1, 5), (5, 1), + (1, 6), (2, 3), (3, 2), (6, 1), + (1, 7), (7, 1), + (1, 8), (2, 4), (4, 2), (8, 1), + (1, 9), (3, 3), (9, 1), + (1, 10), (2, 5), (5, 2), (10, 1), + (1, 11), (11, 1), + (2, 6), (3, 4), (4, 3), (6, 2), + (2, 7), (7, 2), + (3, 5), (5, 3), + (2, 8), (4, 4), (8, 2), + (2, 9), (3, 6), (6, 3), (9, 2), + (2, 10), (4, 5), (5, 4), (10, 2)] +} + +class DocProcessor(): + def __init__(self, image_size=224, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=False): + self.add_global_img = add_global_img + self.add_textual_crop_indicator = add_textual_crop_indicator + self.media_token= "[control_8]" + # h,w + if isinstance(image_size, int): + image_size = (image_size, image_size) + self.image_size = image_size + # h,w + anchors = grid_dict[anchors] + self.anchors = [tuple(_) for _ in anchors] + self.anchor_max = max([max(_) for _ in self.anchors]) + # xywh -> xyxy + self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC) + self.old_resizer = transforms.Resize(image_size,interpolation=InterpolationMode.BICUBIC) + self.image_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + def _process_image(self, images): + new_images = [] + new_patch_position = [] + num_image_mult = [] + for image in images: + if self.add_global_img: + nocut_image = self.image_transform(self.old_resizer(image)).unsqueeze(0) + + image, selected_anchor = self.resizer(image) + image_input = self.image_transform(image) # h,w,3 -> 3,h,w + # rearrange(x,'B C (n1 h) (n2 w) -> (B n1 n2) C h w', n1=self.down_sample[0], n2=self.down_sample[1]) + image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) 
C h w', h=self.image_size[0], w=self.image_size[1]) + + if self.add_global_img: + image_input = torch.cat([nocut_image, image_input], dim=0) + + anchor = self.anchors[selected_anchor] # w,h + patch_position = torch.cat([ + repeat(torch.arange(anchor[0]), 'num_h -> num_h num_w 1', num_w=anchor[1]), + repeat(torch.arange(anchor[1]), 'num_w -> num_h num_w 1', num_h=anchor[0])],dim=2) + patch_position = rearrange(patch_position, 'num_h num_w p-> (num_h num_w) p', p=2) # num_patch, (ph,pw) + + if self.add_global_img: + patch_position = torch.cat([torch.ones(1,2).long()*self.anchor_max, patch_position], dim=0) + + new_images.append(image_input) + new_patch_position.append(patch_position) + num_image_mult.append(patch_position.shape[0]) + + new_images = torch.cat(new_images,dim=0) + new_patch_position = torch.cat(new_patch_position, dim=0) + return new_images, new_patch_position, num_image_mult + + def __call__(self, images=None, query=None): + assert images is not None + + if not isinstance(images, list): + images = [images] + image_pils = [] + for image in images: + if isinstance(image, str): + image = Image.open(image).convert('RGB') + else: + image = image.convert('RGB') + # ic(image.size) + image_pils.append(image) + + image_data, patch_position, num_image_mult = self._process_image(image_pils) + + assert self.media_token in query + text_list = query.split(self.media_token) + text = text_list[0] + image_token_ptr = 0 + for next_text in text_list[1:]: + if self.add_textual_crop_indicator: + # generate image placeholders with interleaved texutual crop indicator + # e.g. <|image|><|image|><|image|>... + for patch_pos in patch_position.tolist(): + # global non-crop image + if patch_pos[0] == self.anchor_max and patch_pos[1] == self.anchor_max: + text += '[control_8]' + else: + row_col = 'row'+str(patch_pos[0])+'_col'+str(patch_pos[1]) + text += '[control_8]' + else: + # generate successive image placeholders for a image, 1 crop img == 1 <|image|> + text += '[control_8]'*num_image_mult[image_token_ptr] + text += next_text + image_token_ptr += 1 + + return image_data, patch_position, text \ No newline at end of file diff --git a/fast_llm/data/stardoc_data_utils/docowl_stardoc_processor.py b/fast_llm/data/stardoc_data_utils/docowl_stardoc_processor.py new file mode 100644 index 000000000..88c6b1023 --- /dev/null +++ b/fast_llm/data/stardoc_data_utils/docowl_stardoc_processor.py @@ -0,0 +1,110 @@ +import torch +import transformers +from fast_llm.data.stardoc_data_utils import conversation as conversation_lib +from fast_llm.data.stardoc_data_utils.mm_utils import tokenizer_image_token +from fast_llm.data.stardoc_data_utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from typing import Dict + + +def docowl_text_preprocess_v1( + source, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + split: str = "train", +) -> Dict: + """ + source: list of {'role':'user'/'assistant', 'content':xxxx} + """ + conv = conversation_lib.conv_mplug_owl2.copy() + # conv.roles: ("USER", "ASSISTANT") + roles = {"user": conv.roles[0], "assistant": conv.roles[1]} + + if split == "train" or split == "val" or split == "test": + + # Apply prompt templates + conversations = [] + + # Skip the first one if it is not from human + if roles[source[0]["role"]] != conv.roles[0]: + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["role"]] + assert role == conv.roles[j % 2] + conv.append_message(role, sentence["content"]) + + # 
conv.get_prompt(): USER: {content} ASSISTANT: {content}USER: {content} ASSISTANT: {content}... + conversations.append(conv.get_prompt()) + + # Tokenize conversations + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer.tokenize( + conversations, + return_tensors="pt", + padding="longest", + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " # ' ASSISTANT: ' + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) # split by + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) # split each round by ' ASSISTANT: ' + if len(parts) != 2: + break + parts[0] += sep # input query, ignore for loss + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.max_seq_length: # ignore padding + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + else: + text = source[0]["content"] + roles = conv.roles # ("USER", "ASSISTANT") + conv.append_message(conv.roles[0], text) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0) + stop_str = conv.sep2 + keywords = [stop_str] + return dict( + input_ids=input_ids, + labels=input_ids, + stop_str=stop_str, + keywords=keywords, + ) \ No newline at end of file diff --git a/fast_llm/data/stardoc_data_utils/mm_utils.py b/fast_llm/data/stardoc_data_utils/mm_utils.py new file mode 100644 index 000000000..01cb71053 --- /dev/null +++ b/fast_llm/data/stardoc_data_utils/mm_utils.py @@ -0,0 +1,111 @@ +from PIL import Image +from io import BytesIO +import base64 + +import torch +from transformers import StoppingCriteria +from fast_llm.data.stardoc_data_utils.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg=None): + if model_cfg is not None: + image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) + else: + image_aspect_ratio = 'resize' + new_images = [] + if image_aspect_ratio == 'pad': + for image in images: + image = expand2square(image, tuple(int(x*255) for x in 
image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + new_images.append(image) + elif image_aspect_ratio == 'resize': + for image in images: + max_edge = max(image.size) + image = image.resize((max_edge, max_edge)) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + new_images.append(image) + else: + return image_processor(images, return_tensors='pt')['pixel_values'] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer.tokenize(chunk, max_length=tokenizer.max_seq_length, truncation=True) if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): + return True + outputs = self.tokenizer.detokenize_batch(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False \ No newline at end of file diff --git a/fast_llm/data/stardoc_data_utils/utils.py b/fast_llm/data/stardoc_data_utils/utils.py new file mode 100644 index 000000000..15a9a46d2 --- /dev/null +++ b/fast_llm/data/stardoc_data_utils/utils.py @@ -0,0 +1,33 @@ +from PIL import Image +import io + +def convert_queries_and_annotations_to_messages(queries, annotations): + messages = [] + # Add each query and annotation as a user-assistant pair + for i, (q, a) in enumerate(zip(queries, annotations)): + if i == 0: + # Prepend "<|image|>" to the first query + q = 
f"<|image|>{q}" + messages.append({"role": "user", "content": q}) + messages.append({"role": "assistant", "content": a}) + return messages + +def image_loading_function(images): + """ + Load an image from a file path + """ + assert images is not None + if not isinstance(images, list): + images = [images] + image_pils = [] + for image in images: + if isinstance(image, bytes): + image = Image.open(io.BytesIO(image)) + elif isinstance(image, str): + image = Image.open(image) + elif isinstance(image, Image.Image): + pass + else: + raise ValueError(f"Unsupported image type: {type(image)}") + image_pils.append(image) + return image_pils \ No newline at end of file diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 5f3116de0..825a460c9 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -40,3 +40,69 @@ def detokenize(self, token_ids): @property def eod(self): return self.eod_id + +class HuggingfacePreTrainedTokenizer: + """ + A Huggingface (transformers) tokenizer which uses from_pretrained() to load tokenizer + """ + + def __init__(self, config: TokenizerConfig, max_sequence_length: int): + log_main_rank(f"> loading tokenizer from {config.tokenizer_file} ...") + + self.tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_path) + # self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) + self._inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()} + self._max_sequence_length = max_sequence_length + + @property + def vocab_size(self): + return len(self.tokenizer) + + @property + def vocab(self): + return self.tokenizer.vocab + + @property + def inv_vocab(self): + return self._inv_vocab + + @property + def max_seq_length(self): + return self._max_sequence_length + + @property + def bos_token_id(self): + if self.tokenizer.bos_token_id: + return self.tokenizer.bos_token_id + else: + raise ValueError("BOS token not set in tokenizer") + + @property + def eos_token_id(self): + if self.tokenizer.eos_token_id: + return self.tokenizer.eos_token_id + else: + raise ValueError("EOS token not set in tokenizer") + + @property + def pad_token_id(self): + if self.tokenizer.pad_token_id: + log_main_rank("PAD token being set to EOS token") + return self.tokenizer.pad_token_id + else: + return self.tokenizer.eos_token_id + + def tokenize(self, text, **kwargs): + return self.tokenizer.encode(text, **kwargs) + + def detokenize(self, token_ids, **kwargs): + return self.tokenizer.decode(token_ids, **kwargs) + + def detokenize_batch(self, token_ids, **kwargs): + return self.tokenizer.batch_decode(token_ids, **kwargs) + + @property + def eod(self): + return self.eod_id + + diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 437aa7b08..2985f766e 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -239,6 +239,8 @@ def setup( def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: # noqa + if module._parameters[key] is None: + continue meta = typing.cast(ParameterMeta, module._parameters[key]) # noqa module._parameters[key] = self._parameter_buffers[self._parameter_index[meta.tensor_name]] # noqa i += 1 diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9b6c78d8d..7a0d8b3b5 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -4,6 +4,7 @@ from fast_llm.engine.distributed.config import 
DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig +from fast_llm.layers.multimodal_model.config import MultimodalModelArchitectureConfig, MultimodalModelBaseConfig from fast_llm.utils import Assert @@ -26,6 +27,7 @@ class LanguageModelKwargs: # TODO: These are generic labels = "labels" phase = "phase" + tokens = "tokens" @config_class() @@ -35,6 +37,11 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) + multimodal_model: MultimodalModelArchitectureConfig = Field( + default_factory=MultimodalModelArchitectureConfig, + desc="Configuration for the multimodal components (image encoder and adapter).", + hint=FieldHint.core, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -63,6 +70,7 @@ def _validate(self): def setup_tensor_space(self, tensor_space: TensorSpace): self.transformer.setup_tensor_space(tensor_space) + self.multimodal_model.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions @@ -98,6 +106,11 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): transformer: TransformerConfig = Field( default_factory=TransformerConfig, desc="Configuration for the transformer.", hint=FieldHint.core ) + multimodal_model: MultimodalModelBaseConfig = Field( + default_factory=MultimodalModelBaseConfig, + desc="Configuration for the multimodal components (image encoder and adapter).", + hint=FieldHint.core, + ) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 611a63309..18abdf7d9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -174,6 +174,7 @@ def _logits_cross_entropy_forward_backward( kwargs: dict, losses: dict | None = None, ): + print(f'Loss head input_ shape: {input_.shape} Labels shape: {labels.shape}') logits, context = output_parallel_linear_forward( input_=input_, weight=weight, diff --git a/fast_llm/layers/multimodal_model/adapter.py b/fast_llm/layers/multimodal_model/adapter.py new file mode 100644 index 000000000..b70d21391 --- /dev/null +++ b/fast_llm/layers/multimodal_model/adapter.py @@ -0,0 +1,58 @@ +import logging +import copy +import torch +from torch import nn + +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.multimodal_model.config import MultimodalModelBaseConfig, MultimodalModelDimNames, MultimodalModelKwargs +from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.tensor import ParameterMeta, TensorMeta, TensorSpace, TensorDim, init_normal_ + +logger = logging.getLogger(__name__) + +class Adapter(torch.nn.Module): + + # Ensure the layer is on its own stage. 
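+    # Descriptive note (added comment, based on the code in this file): the adapter is a
+    # dropout (p=0.1) followed by a single bias-enabled Linear that projects image-encoder
+    # features (MultimodalModelDimNames.image_encoder_hidden_size) to the transformer hidden
+    # size (TransformerDimNames.hidden), so encoder outputs can be fused with token embeddings.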
+ layer_count: float = 1000.0 + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super(Adapter, self).__init__() + self._distributed_config = tensor_space.distributed_config + self._tensor_space = tensor_space + self._residual_dtype = ( + self._distributed_config.optimization_dtype + if config.transformer.full_precision_residual + else self._distributed_config.training_dtype + ).torch + + in_dim = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.image_encoder_hidden_size) + out_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + + self.dropout = nn.Dropout(p=0.1) + self.adapter_fc = Linear( + in_dim, + out_dim, + bias=True, + weight_init_method=init_normal_(std=config.transformer.init_method_std), + ) + + def _forward(self, input_: torch.Tensor, losses: dict | None = None, metrics: dict | None = None): + hidden_states = self.dropout(input_) + out = self.adapter_fc(hidden_states) + + return out.to(dtype=self._residual_dtype) + + def forward(self, input_, kwargs, losses: dict | None = None, metrics: dict | None = None): + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[MultimodalModelKwargs.adapter_hidden_dims], + tensor_name="Adapter output", + dtype=self._residual_dtype, + ) + + return self._forward(input_) \ No newline at end of file diff --git a/fast_llm/layers/multimodal_model/config.py b/fast_llm/layers/multimodal_model/config.py new file mode 100644 index 000000000..be689fab4 --- /dev/null +++ b/fast_llm/layers/multimodal_model/config.py @@ -0,0 +1,76 @@ +import enum +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace + +from fast_llm.utils import Assert + +class MultimodalModelDimNames: + # Image encoder dimensions + max_num_images = "max_num_images" + image_pixel_count = "image_pixel_count" + num_image_tokens = "num_image_tokens" + image_encoder_hidden_size = "image_encoder_hidden_size" + +class MultimodalModelKwargs: + image_encoder_hidden_dims = "image_encoder_hidden_dims" + adapter_hidden_dims = "adapter_hidden_dims" + +class ImageEncoderType(str, enum.Enum): + clip = "clip" + docowl = "docowl" + +@config_class() +class MultimodalModelArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + + image_encoder_hidden_size: int = Field( + default=1024, + desc="Hidden size of image encoder.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + num_image_tokens: int = Field( + default=256, + desc="Number of image tokens.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + max_num_images: int = Field( + default=10, + desc="Max. number of images in a sample. 
We pad to ensure shapes are consistent.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + image_resolution: int = Field( + default=448, + desc="Resolution of image", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + def _validate(self): + super()._validate() + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(MultimodalModelDimNames.max_num_images, self.max_num_images)) + tensor_space.add_tensor_dim(TensorDim(MultimodalModelDimNames.num_image_tokens, self.num_image_tokens)) + tensor_space.add_tensor_dim(TensorDim(MultimodalModelDimNames.image_pixel_count, self.image_resolution * self.image_resolution)) + tensor_space.add_tensor_dim(TensorDim(MultimodalModelDimNames.image_encoder_hidden_size, self.image_encoder_hidden_size)) + + +@config_class() +class MultimodalModelBaseConfig(MultimodalModelArchitectureConfig, BaseModelConfig): + """ + A configuration class for defining the model configuration of encoder and adapter components of multi-modal model. + """ + _abstract = False + + image_encoder_type: ImageEncoderType = Field( + default=ImageEncoderType.clip, + desc="Type of image encoder", + hint=FieldHint.feature, + ) + + def _validate(self): + super()._validate() \ No newline at end of file diff --git a/fast_llm/layers/multimodal_model/image_encoder.py b/fast_llm/layers/multimodal_model/image_encoder.py new file mode 100644 index 000000000..d26305d74 --- /dev/null +++ b/fast_llm/layers/multimodal_model/image_encoder.py @@ -0,0 +1,100 @@ +import logging +import copy +import torch + +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.layers.multimodal_model.config import MultimodalModelKwargs, MultimodalModelBaseConfig +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.tensor import ParameterMeta, TensorMeta, TensorDim, init_normal_ + +logger = logging.getLogger(__name__) + +class ImageEncoder(torch.nn.Module): + + # Ensure the layer is on its own stage. 
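+    # Descriptive note (added comment, based on the code in this file): when
+    # image_encoder_type == "clip", this wraps a pretrained open_clip ViT-L/14 visual tower
+    # (laion2b_s32b_b82k) with output_tokens=True and a copy of its ln_post as ln_vision.
+    # After loading, every parameter is swapped for a Fast-LLM ParameterMeta on the meta
+    # device (see get_fastllm_parameter below) so the engine can materialize and initialize
+    # the weights. The forward pass flattens (batch, num_images, C, H, W) into a single batch
+    # of crops, encodes them, applies ln_vision, and restores the per-sample image dimension.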
+ layer_count: float = 1000.0 + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super(ImageEncoder, self).__init__() + self._distributed_config = tensor_space.distributed_config + self._tensor_space = tensor_space + self._residual_dtype = ( + self._distributed_config.optimization_dtype + if config.transformer.full_precision_residual + else self._distributed_config.training_dtype + ).torch + self.image_encoder_type = config.multimodal_model.image_encoder_type + + if self.image_encoder_type.lower() == "clip": + import open_clip + + model, _, _ = open_clip.create_model_and_transforms( + "ViT-L-14", pretrained="laion2b_s32b_b82k" + ) + + self.visual_encoder = model.visual + self.visual_encoder.output_tokens = True + self.ln_vision = copy.deepcopy(self.visual_encoder.ln_post) + else: + logger.error(f'Unknown image encoder specified: {self.image_encoder_type.lower()}') + + # Replace all parameters with Parameter(MetaParameter(...)) + with torch.no_grad(): + for name, param in self.named_parameters(): + module = self + name_parts = name.split('.') + # We have to traverse to the correct parent module and change the parameter there + for part in name_parts[:-1]: + module = getattr(module, part) + + # Replace prameter with FastLLM meta parameter + setattr(module, name_parts[-1], self.get_fastllm_parameter(name, param)) + + def get_fastllm_parameter(self, param_name, param): + param_dims = tuple([TensorDim(name=f'{param_name}_{idx}', global_size=x, parallel_dim=None) for idx, x in enumerate(param.shape)]) + return ParameterMeta(param.to("meta"), tensor_name=param_name, dims=param_dims, init_method=init_normal_(std=0.02)) + + def _forward(self, input_: tuple[torch.Tensor], losses: dict | None = None, metrics: dict | None = None): + if not self.image_encoder_type.lower() == "clip": + raise ValueError(f'clip is the only image encoder type currrently supported') + + # TODO: Remove padding images + # _bsz_im, num_img, ch, im_width, im_height = image_input + # image_input = image_input.view(_bsz_im * num_img, *image_input.shape[2:]) + # num_values_per_image = image_input.shape[1:].numel() + # real_images_inds = (image_input == 0.0).sum(dim=(-1, -2, -3)) != num_values_per_image + # image_input = image_input[real_images_inds].contiguous() + + # (bsz, num_img, ch, im_h, im_w) -> (bsz*num_img, ch, im_h, im_w) + + + # Convert the input images tensor to residual dtype. 
This is torch.float32 by default + input_ = input_.to(self._residual_dtype) + + _bsz_im, num_img, ch, im_width, im_height = input_.shape + input_ = input_.view(_bsz_im * num_img, *input_.shape[2:]).contiguous() + + out = self.visual_encoder(input_)[1] + out = self.ln_vision(out) + + # (bsz*num_img, im_tokens, h) -> (bsz, num_img, im_tokens, h) + out = out.view(_bsz_im, num_img, *out.shape[1:]).contiguous() + + return out.to(dtype=self._residual_dtype) + + def forward(self, input_, kwargs, losses: dict | None = None, metrics: dict | None = None): + if input_ is None: + raise ValueError(f'You must define a max_num_images > 0 if image_encoder is enabled') + + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[MultimodalModelKwargs.image_encoder_hidden_dims], + tensor_name="Image encoder output", + dtype=self._residual_dtype, + ) + + return self._forward(input_) \ No newline at end of file diff --git a/fast_llm/layers/multimodal_model/multimodal_language_embedding.py b/fast_llm/layers/multimodal_model/multimodal_language_embedding.py new file mode 100644 index 000000000..64301d9c0 --- /dev/null +++ b/fast_llm/layers/multimodal_model/multimodal_language_embedding.py @@ -0,0 +1,99 @@ +import torch + +from fast_llm.core.distributed import set_generator +from fast_llm.core.ops import reduce_forward, split +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.utils import Assert + +WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" + + +class MultiModalLanguageModelEmbedding(Layer): + """ + An embedding layer that fuses multi-modal features with language embedding. + Consists of multi-modal embeddings (tensor-parallel), + together with optional absolute position embeddings and dropout. + """ + + # Ensure the layer is on its own stage. 
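+    # Image features from the image encoder / adapter stages are fused here: in _forward they
+    # overwrite the text embeddings at the positions of the (currently hardcoded) image token.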
+    layer_count: float = 1000.0
+
+    def __init__(
+        self,
+        config: LanguageModelBaseConfig,
+        tensor_space: TensorSpace,
+    ):
+        super().__init__()
+        config.validate()
+        self._distributed_config = tensor_space.distributed_config
+        self._tensor_space = tensor_space
+        self._residual_dtype = (
+            self._distributed_config.optimization_dtype
+            if config.transformer.full_precision_residual
+            else self._distributed_config.training_dtype
+        ).torch
+        self._group_size = self._distributed_config.tensor_parallel
+        self._sequence_parallel = self._distributed_config.sequence_tensor_parallel
+        self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings
+        self._dropout_p = config.transformer.hidden_dropout
+        self._use_absolute_position_embeddings = config.use_absolute_position_embeddings
+
+        hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden)
+        vocab_dim = tensor_space.get_tensor_dim(
+            LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
+        )
+
+        if self._parallel_embeddings:
+            self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size
+            self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size
+
+        self.word_embeddings_weight = ParameterMeta.from_dims(
+            (vocab_dim, hidden_dim),
+            init_method=init_normal_(std=config.init_method_std_embed),
+        )
+        if self._use_absolute_position_embeddings:
+            self.position_embeddings_weight = ParameterMeta.from_dims(
+                (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim),
+                init_method=init_normal_(std=config.init_method_std_embed),
+                allow_sequence_tensor_parallel=not config.parallel_embeddings,
+            )
+
+    @torch.compile
+    def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, tokens: torch.Tensor | None):
+        Assert.eq(position_ids is not None, self._use_absolute_position_embeddings)
+        Assert.eq(tokens is not None, True)
+        group = self._tensor_space.distributed.tensor_group
+
+        text_embeddings = torch.embedding(self.word_embeddings_weight, tokens)
+
+        bsz, num_imgs, _, hidden_size = input_.shape
+
+        # TODO: Hardcoded image token
+        image_token_mask = tokens == 10
+
+        embeddings = text_embeddings.clone()
+        embeddings[image_token_mask] = input_.view(-1, hidden_size)
+
+        if self._use_absolute_position_embeddings:
+            embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight)
+
+        with set_generator(
+            self._tensor_space.distributed.tp_generator
+            if self._sequence_parallel
+            else self._tensor_space.distributed.pp_generator
+        ):
+            embeddings = torch.dropout(embeddings, self._dropout_p, self.training)
+        return embeddings.to(dtype=self._residual_dtype)
+
+    def forward(self, input_, kwargs, losses: dict | None = None, metrics: dict | None = None):
+        if isinstance(input_, TensorMeta):
+            return TensorMeta.from_dims(
+                kwargs[TransformerKwargs.hidden_dims],
+                tensor_name="Embedding output",
+                dtype=self._residual_dtype,
+            )
+        return self._forward(input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.tokens))
\ No newline at end of file
diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py
index f1b534035..f94708b74 100644
--- a/fast_llm/models/auto.py
+++ b/fast_llm/models/auto.py
@@ -1,10 +1,12 @@
 from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig
+from fast_llm.models.stardoc.config import StarDocModelConfig, StarDocTrainerConfig
 from fast_llm.utils import Registry
 
 model_registry = Registry(
"Model", { "gpt": GPTModelConfig, + "stardoc": StarDocModelConfig, }, ) @@ -12,5 +14,6 @@ "Model", { "gpt": GPTTrainerConfig, + "stardoc": StarDocTrainerConfig, }, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index ebfb579f2..3b9d1dd6e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -98,7 +98,7 @@ def preprocess_meta(self, input_: BatchConfig | torch.Tensor, phase: PhaseType) if phase != PhaseType.inference: sequence_length -= 1 micro_sequence_length = sequence_length - + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) diff --git a/fast_llm/models/stardoc/__init__.py b/fast_llm/models/stardoc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/models/stardoc/config.py b/fast_llm/models/stardoc/config.py new file mode 100644 index 000000000..bc8efcae5 --- /dev/null +++ b/fast_llm/models/stardoc/config.py @@ -0,0 +1,102 @@ +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig + +if typing.TYPE_CHECKING: + from fast_llm.engine.multi_stage.conversion import ModelConverter + + +@config_class() +class StarDocArchitectureConfig(LanguageModelArchitectureConfig): + _abstract = False + + @classmethod + def _from_dict( + cls, + default: dict, + strict: bool = True, + flat: bool = False, + ): + # TODO v0.2: Remove backward compatibility fix + if "transposed_mlp_weight" in default: + assert default.pop("transposed_mlp_weight") + return super()._from_dict(default, strict, flat) + + @classmethod + def get_converter_class(cls, model_type: str | None = None) -> type["ModelConverter"]: + from fast_llm.models.stardoc.conversion import AutoStarDocConverter + + return AutoStarDocConverter if model_type is None else AutoStarDocConverter.converter_map[model_type] + + +@config_class() +class StarDocBaseModelConfig(LanguageModelBaseConfig, StarDocArchitectureConfig): + architecture_cls = StarDocArchitectureConfig + + @classmethod + def _from_dict( + cls, + default: dict, + strict: bool = True, + flat: bool = False, + ): + # TODO v0.2: Remove backward compatibility fix + if "layer_norm_impl" in default: + assert "normalization_implementation" not in default + default["normalization_implementation"] = default.pop("layer_norm_impl") + if "fused_mlp" in default: + del default["fused_mlp"] + return super()._from_dict(default, strict, flat) + + +@config_class() +class StarDocModelConfig(FastLLMModelConfig): + _abstract = False + base_model: StarDocBaseModelConfig = Field(default_factory=StarDocBaseModelConfig) + + @classmethod + def get_model_class(cls): + from fast_llm.models.stardoc.model import StarDocModel + + return StarDocModel + + @classmethod + def get_huggingface_model_class(cls): + from fast_llm.models.stardoc.huggingface import HuggingfaceStarDocModelForCausalLM + + return HuggingfaceStarDocModelForCausalLM + + +@config_class() +class PretrainedStarDocModelConfig(PretrainedFastLLMModelConfig): + _abstract = False + model: StarDocModelConfig = Field(default_factory=StarDocModelConfig) + + +@config_class() +class StarDocTrainerConfig(PretrainedStarDocModelConfig, TrainerConfig): + def 
_setup(self): + super()._setup() + if self.batch.sequence_length is None: + # TODO: Drop this. + self.batch.sequence_length = self.base_model.max_position_embeddings + + @classmethod + def get_trainer_class(cls): + from fast_llm.models.stardoc.trainer import StarDocTrainer + + return StarDocTrainer + + +class HuggingfaceModelType: + """ + An enum for the huggingface models with conversion support. + """ + + starcoder2 = "starcoder2" + llama = "llama" + mistral = "mistral" + mixtral = "mixtral" \ No newline at end of file diff --git a/fast_llm/models/stardoc/conversion.py b/fast_llm/models/stardoc/conversion.py new file mode 100644 index 000000000..04530e305 --- /dev/null +++ b/fast_llm/models/stardoc/conversion.py @@ -0,0 +1,644 @@ +import abc +import math +import typing + +import torch + +from fast_llm.engine.multi_stage.conversion import ( + AutoModelConverter, + ConstantExportParamConverter, + ConstantImportParamConverter, + HuggingfaceModelConverter, + IgnoreImportParamConverter, + IgnoreWeightConverter, + MappedConfigParamConverter, + ParamConverter, + SplitWeightConverter, + WeightConverter, +) +from fast_llm.functional.config import ActivationType +from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.transformer.config import RoutingType +from fast_llm.models.stardoc.config import StarDocArchitectureConfig, StarDocBaseModelConfig, HuggingfaceModelType +from fast_llm.tensor import SafeTensorSlice + +if typing.TYPE_CHECKING: + pass + + +class QueryWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings. + _config: StarDocArchitectureConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (query,) = weight + if self._config.transformer.complex_rotary_embeddings: + query = convert_rotary_complex_to_real(query[:], self._config.transformer.kv_channels, 0) + return (query,) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (query,) = weight + if self._config.transformer.complex_rotary_embeddings: + query = convert_rotary_real_to_complex(query[:], self._config.transformer.kv_channels, 0) + return (query,) + + +class KeyValueWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings, and keeps the key and value separate. + _config: StarDocArchitectureConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (key_value,) = weight + key, value = key_value[:].chunk(2) + if self._config.transformer.complex_rotary_embeddings: + key = convert_rotary_complex_to_real(key, self._config.transformer.kv_channels, 0) + return key, value + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + key, value = weight + if self._config.transformer.complex_rotary_embeddings: + key = convert_rotary_real_to_complex(key[:], self._config.transformer.kv_channels, 0) + key_value = torch.cat([key[:], value[:]]) + return (key_value,) + + +class MLPLayer2Converter(WeightConverter): + # Similar to SplitWeightConverter, but handles the optional MLP transpose. 
+    # Still ok for non-gated (trivial split) and biases (trivial 1d transpose)
+    _config: StarDocArchitectureConfig
+
+    def export_weight(
+        self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+    ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+        (merged_weight,) = weight
+        return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1))
+
+    def import_weight(
+        self, weight: tuple[torch.Tensor | SafeTensorSlice, ...]
+    ) -> tuple[torch.Tensor | SafeTensorSlice, ...]:
+        merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1)
+        return (merged_weight.t().contiguous(),)
+
+
+class CommonHuggingfaceConverter(HuggingfaceModelConverter):
+    """
+    Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral).
+    """
+
+    config: StarDocArchitectureConfig
+    _base_model_cls = StarDocBaseModelConfig
+
+    @abc.abstractmethod
+    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
+        pass
+
+    @classmethod
+    def _create_config_converters(cls) -> list[ParamConverter]:
+        return super()._create_config_converters() + [
+            ConstantImportParamConverter(("multimodal_model", "image_encoder_hidden_size",), None, 1024),
+            ConstantImportParamConverter(("multimodal_model", "num_image_tokens",), None, 256),
+            ConstantImportParamConverter(("multimodal_model", "max_num_images",), None, 10),
+            ConstantImportParamConverter(("use_position_embeddings",), None, False),
+            ConstantImportParamConverter(("transformer", "use_rotary_embeddings"), None, True),
+            MappedConfigParamConverter(
+                ("transformer", "rotary_embedding_scale"), "rope_theta", lambda x: -math.log(x), lambda x: math.exp(-x)
+            ),
+            MappedConfigParamConverter(
+                ("transformer", "activation_type"),
+                "hidden_act",
+                ActivationType.from_hf_name,
+                lambda activation_type: activation_type.hf_name,
+            ),
+            ParamConverter(("transformer", "num_layers"), "num_hidden_layers"),
+            ParamConverter(("transformer", "hidden_size"), "hidden_size"),
+            ParamConverter(("transformer", "num_attention_heads"), "num_attention_heads"),
+            ParamConverter(("transformer", "head_groups"), "num_key_value_heads"),
+            ParamConverter(("transformer", "ffn_hidden_size"), "intermediate_size"),
+            ParamConverter(("vocab_size",), "vocab_size"),
+            ParamConverter(("tie_word_embeddings",), "tie_word_embeddings"),
+        ]
+
+    def _create_weight_converters(self) -> list[WeightConverter]:
+        converters = []
+        num_layers = self.config.transformer.num_layers
+        norm_bias: bool = self.config.transformer.normalization.normalization_type == NormalizationType.layer_norm
+        linear_bias: bool = self.config.transformer.add_linear_biases
+
+        # Fast-LLM layer index map: 0 = image encoder, 1 = adapter, 2 = embeddings,
+        # 3 .. num_layers + 2 = transformer layers, num_layers + 3 = final norm and output head.
+
+        # Vision encoder
+        converters.append(WeightConverter("layers.0.visual_encoder.class_embedding", "visual_encoder.class_embedding"))
+        converters.append(WeightConverter("layers.0.visual_encoder.positional_embedding", "visual_encoder.positional_embedding"))
+        converters.append(WeightConverter("layers.0.visual_encoder.proj", "visual_encoder.proj"))
+        converters.append(WeightConverter("layers.0.visual_encoder.conv1.weight", "visual_encoder.conv1.weight"))
+        converters.append(WeightConverter("layers.0.visual_encoder.ln_pre.weight", "visual_encoder.ln_pre.weight"))
+        converters.append(WeightConverter("layers.0.visual_encoder.ln_pre.bias", "visual_encoder.ln_pre.bias"))
+
+        # The CLIP ViT-L/14 visual transformer has 24 residual blocks with identical parameter
+        # layouts, so the per-block converters are generated in a loop.
+        for block in range(24):
+            for suffix in (
+                "ln_1.weight",
+                "ln_1.bias",
+                "attn.in_proj_weight",
+                "attn.in_proj_bias",
+                "attn.out_proj.weight",
+                "attn.out_proj.bias",
+                "ln_2.weight",
+                "ln_2.bias",
+                "mlp.c_fc.weight",
+                "mlp.c_fc.bias",
+                "mlp.c_proj.weight",
+                "mlp.c_proj.bias",
+            ):
+                converters.append(
+                    WeightConverter(
+                        f"layers.0.visual_encoder.transformer.resblocks.{block}.{suffix}",
+                        f"visual_encoder.transformer.resblocks.{block}.{suffix}",
+                    )
+                )
+
+        converters.append(WeightConverter("layers.0.visual_encoder.ln_post.weight", "visual_encoder.ln_post.weight"))
+        converters.append(WeightConverter("layers.0.visual_encoder.ln_post.bias", "visual_encoder.ln_post.bias"))
+        converters.append(WeightConverter("layers.0.ln_vision.weight", "ln_vision.weight"))
+        converters.append(WeightConverter("layers.0.ln_vision.bias", "ln_vision.bias"))
+
+        # Adapter
+        converters.append(WeightConverter("layers.1.adapter_fc.weight", "c_fc.weight"))
+        converters.append(WeightConverter("layers.1.adapter_fc.bias", "c_fc.bias"))
+
+        # Embedding and output
+        if self.config.tie_word_embeddings:
+            converters.append(WeightConverter("layers.2.word_embeddings_weight", "model.embed_tokens.weight"))
+            converters.append(IgnoreWeightConverter((), "lm_head.weight"))
+        else:
+            converters.append(WeightConverter("layers.2.word_embeddings_weight", "model.embed_tokens.weight"))
+            converters.append(WeightConverter(f"layers.{num_layers + 3}.output_weights", "lm_head.weight"))
+
+        # Final norm
+        converters += self._get_weight_and_bias_converters(
+            f"layers.{num_layers + 3}.final_norm", "model.norm", norm_bias
+        )
+
+        for i in range(num_layers):
+            # Self-attn
+            converters += self._get_weight_and_bias_converters(
+                f"layers.{i+3}.self_attn.query",
+                f"model.layers.{i}.self_attn.q_proj",
+                linear_bias,
+                QueryWeightConverter,
+            )
+            converters += self._get_weight_and_bias_converters(
+                f"layers.{i+3}.self_attn.key_value",
+                (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"),
+                linear_bias,
+                KeyValueWeightConverter,
+            )
+            converters += self._get_weight_and_bias_converters(
+                f"layers.{i+3}.self_attn.dense", f"model.layers.{i}.self_attn.o_proj", linear_bias
+            )
+
+            # Norm
+            converters += self._get_weight_and_bias_converters(
+                f"layers.{i+3}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias
+            )
+            converters += self._get_weight_and_bias_converters(
+                f"layers.{i+3}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias
+            )
+
+            # MLP
+            converters += self._get_mlp_converters(f"layers.{i+3}", f"model.layers.{i}")
+
+        return converters
+
+    def _get_weight_and_bias_converters(
+        self,
+        fast_llm_prefix: str | tuple[str, ...],
+        hf_prefix: str | tuple[str, ...],
+        use_bias: bool,
+        cls=WeightConverter,
+    ):
+        if isinstance(fast_llm_prefix, str):
+            fast_llm_prefix = (fast_llm_prefix,)
+        if isinstance(hf_prefix, str):
+            hf_prefix = (hf_prefix,)
+        converters = [
+            cls(
+                tuple(f"{prefix}.weight" for prefix in fast_llm_prefix),
+                tuple(f"{prefix}.weight" for prefix in hf_prefix),
+                self.config,
+            )
+        ]
+        if use_bias:
+            converters.append(
+                cls(
+                    tuple(f"{prefix}.bias" for prefix in fast_llm_prefix),
+                    tuple(f"{prefix}.bias" for prefix in hf_prefix),
+                    self.config,
+                )
+            )
+        return converters
+
+
+class 
Starcoder2HuggingfaceConverter(CommonHuggingfaceConverter): + model_type = HuggingfaceModelType.starcoder2 + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(None, "architectures", ["Starcoder2ForCausalLM"]), + ConstantImportParamConverter( + ("transformer", "normalization", "normalization_type"), None, NormalizationType.layer_norm + ), + ParamConverter(("transformer", "normalization", "layer_norm_eps"), "norm_epsilon"), + ConstantImportParamConverter(("transformer", "gated"), None, False), + ConstantImportParamConverter(("transformer", "add_linear_biases"), None, True), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): + linear_bias: bool = self.config.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", linear_bias + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", linear_bias, MLPLayer2Converter + ), + ] + + +class CommonLlamaHuggingfaceConverter(CommonHuggingfaceConverter, abc.ABC): + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter( + ("transformer", "normalization", "normalization_type"), None, NormalizationType.rms_norm + ), + ParamConverter(("transformer", "normalization", "layer_norm_eps"), "rms_norm_eps"), + ConstantImportParamConverter(("transformer", "gated"), None, True), + ConstantImportParamConverter(("transformer", "add_linear_biases"), None, False), + ] + + +class LlamaHuggingfaceConverter(CommonLlamaHuggingfaceConverter): + model_type = HuggingfaceModelType.llama + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(None, "architectures", ["LlamaForCausalLM"]), + # TODO: Llama supports biases + ConstantExportParamConverter(None, "attention_bias", False), + ConstantExportParamConverter(None, "mlp_bias", False), + ConstantExportParamConverter(None, "rope_scaling", False), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): + linear_bias: bool = self.config.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + +class MistralHuggingfaceConverter(CommonLlamaHuggingfaceConverter): + model_type = HuggingfaceModelType.mistral + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(None, "architectures", ["MistralForCausalLM"]), + IgnoreImportParamConverter(None, "sliding_window", None), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}.mlp.down_proj.weight", self.config + ), + ] + + +class MixtralHuggingfaceConverter(CommonLlamaHuggingfaceConverter): + model_type = 
HuggingfaceModelType.mixtral + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(None, "architectures", ["MixtralForCausalLM"]), + ConstantImportParamConverter(("transformer", "expert_routing_type"), None, RoutingType.topk), + ParamConverter(("transformer", "num_experts"), "num_local_experts"), + ParamConverter(("transformer", "num_experts_per_token"), "num_experts_per_tok"), + IgnoreImportParamConverter(None, "sliding_window", None), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): + num_experts = self.config.transformer.num_experts + return [ + WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"), + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + tuple( + f"{hf_prefix}.block_sparse_moe.experts.{i}.{w}.weight" + for i in range(num_experts) + for w in ("w1", "w3") + ), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + tuple(f"{hf_prefix}.block_sparse_moe.experts.{i}.w2.weight" for i in range(num_experts)), + self.config, + ), + ] + + +class AutoStarDocConverter(AutoModelConverter, HuggingfaceModelConverter, abc.ABC): + converter_map = { + HuggingfaceModelType.starcoder2: Starcoder2HuggingfaceConverter, + HuggingfaceModelType.llama: LlamaHuggingfaceConverter, + HuggingfaceModelType.mistral: MistralHuggingfaceConverter, + HuggingfaceModelType.mixtral: MixtralHuggingfaceConverter, + } \ No newline at end of file diff --git a/fast_llm/models/stardoc/model.py b/fast_llm/models/stardoc/model.py new file mode 100644 index 000000000..f6aa5cfac --- /dev/null +++ b/fast_llm/models/stardoc/model.py @@ -0,0 +1,348 @@ +import logging + +import torch + +from fast_llm.engine.base_model.base_model import BaseModel, LossDef +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.multimodal_model.config import MultimodalModelDimNames, MultimodalModelKwargs +from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding +from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor +from fast_llm.layers.multimodal_model.multimodal_language_embedding import MultiModalLanguageModelEmbedding +from fast_llm.layers.multimodal_model.image_encoder import ImageEncoder +from fast_llm.layers.multimodal_model.adapter import Adapter + +from fast_llm.layers.transformer.config import ( + RoutingType, + TransformerDimNames, + TransformerKwargs, + TransformerLossNames, +) +from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, RotaryEmbeddingPreprocessor +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.models.stardoc.config import StarDocBaseModelConfig, StarDocModelConfig +from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.utils import Assert, div + +logger = logging.getLogger(__name__) + + +class StarDocBaseModel(BaseModel): + 
""" + A transformer-based language model generalizing the StarDoc model architecture. + """ + + _is_setup: bool = False + _config: StarDocBaseModelConfig + _rotary_embedding_frequencies: torch.Tensor + _position_ids: torch.Tensor + _mask: torch.Tensor + _mask_value: torch.Tensor + _tensor_cache_max_sequence_length: int = -1 + config_cls = StarDocBaseModelConfig + + def __init__( + self, + config: BaseModelConfig, + distributed_config: DistributedConfig, + ): + super().__init__(config, distributed_config) + self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config) + if self._config.use_absolute_position_embeddings: + self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space) + if self._config.transformer.use_rotary_position_embeddings: + self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( + self._config.transformer, self._tensor_space + ) + if not self._use_flash_attention: + self._backup_attention_preprocessor = BackupAttentionPreprocessor( + self._config.transformer, self._tensor_space + ) + + def get_layers(self): + return [ + ImageEncoder(self._config, self._tensor_space), + Adapter(self._config, self._tensor_space), + MultiModalLanguageModelEmbedding(self._config, self._tensor_space), + *[ + TransformerLayer( + self._config.transformer, + self._tensor_space, + layer_index=i + 1, + ) + for i in range(self._config.transformer.num_layers) + ], + LanguageModelHead(self._config, self._tensor_space), + ] + + def setup(self, distributed: Distributed): + assert not self._is_setup + assert distributed.config is self._tensor_space.distributed_config + self._tensor_space.setup(distributed) + self._is_setup = True + + def preprocess_meta(self, input_: BatchConfig | torch.Tensor, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: + # TODO: How much of this is generalizable? + # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence + + if isinstance(input_, BatchConfig): + micro_batch_size = input_.micro_batch_size + sequence_length = input_.sequence_length + micro_sequence_length = input_.micro_sequence_length + else: + micro_batch_size, sequence_length = input_.shape + if phase != PhaseType.inference: + sequence_length -= 1 + micro_sequence_length = sequence_length + + print(f'Sequence length for meta {sequence_length}') + + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) + batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) + + if isinstance(input_, BatchConfig): + micro_sequence_length = input_.micro_sequence_length + + if micro_sequence_length is None: + micro_sequence_length = sequence_length + else: + Assert.multiple(sequence_length, micro_sequence_length) + + local_micro_sequence_length = div( + micro_sequence_length, self._tensor_space.distributed_config.sequence_data_parallel + ) + + need_sequence_first = ( + self._tensor_space.distributed_config.sequence_tensor_parallel + or sequence_length > local_micro_sequence_length + ) + if self._config.sequence_first is None: + sequence_first = need_sequence_first + else: + sequence_first = self._config.sequence_first + assert not (need_sequence_first and not sequence_first) + + sequence_q_dim = TensorDim(TransformerDimNames.sequence_q, local_micro_sequence_length) + + # TODO: Calculate hidden dims elsewhere? 
+ hidden_sequence_q_dim = ( + TensorDim( + TransformerDimNames.sequence_q_tp, + micro_sequence_length, + self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), + ) + if self._tensor_space.distributed_config.sequence_tensor_parallel + else sequence_q_dim + ) + hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, hidden_dim) + ) + + max_num_images = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.max_num_images) + image_pixel_count = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.image_pixel_count) + num_image_tokens = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.num_image_tokens) + image_encoder_hidden_size = self._tensor_space.get_tensor_dim(MultimodalModelDimNames.image_encoder_hidden_size) + + image_encoder_hidden_dims = ( + (batch_dim, max_num_images, num_image_tokens, image_encoder_hidden_size) + ) + adapter_hidden_dims = ( + (batch_dim, max_num_images, num_image_tokens, hidden_dim) + ) + + common_kwargs = { + LanguageModelKwargs.phase: phase, + TransformerKwargs.sequence_first: sequence_first, + TransformerKwargs.hidden_dims: hidden_dims, + TransformerKwargs.sequence_length: sequence_length, + TransformerKwargs.sequence_q_dim: sequence_q_dim, + MultimodalModelKwargs.image_encoder_hidden_dims: image_encoder_hidden_dims, + MultimodalModelKwargs.adapter_hidden_dims: adapter_hidden_dims, + } + + # For stardoc, since image tokens and text tokens need to be merged, sequence parallel is complicated + Assert.eq(micro_sequence_length, sequence_length) + Assert.eq(local_micro_sequence_length, sequence_length) + + preprocessed_meta = [] + for sequence_k_past in range( + local_micro_sequence_length * self._tensor_space.distributed_config.sequence_data_rank, + sequence_length, + micro_sequence_length, + ): + sequence_k = sequence_k_past + local_micro_sequence_length + sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) + + tokens = TensorMeta.from_dims( + hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 + ) + + image_data = TensorMeta.from_dims( + ( + batch_dim, + max_num_images, + image_pixel_count, + ), + tensor_name="image_data", + dtype=torch.float32, + ) + + kwargs = { + **common_kwargs, + LanguageModelKwargs.tokens: tokens, + TransformerKwargs.sequence_k_dim: sequence_k_dim, + } + if phase != PhaseType.inference: + kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( + hidden_dims[:2], tensor_name="labels", dtype=torch.int64 + ) + if self._config.use_absolute_position_embeddings: + self._position_embedding_preprocessor.preprocess_meta(kwargs) + if self._config.transformer.use_rotary_position_embeddings: + self._rotary_embedding_preprocessor.preprocess_meta(kwargs) + if not self._use_flash_attention: + self._backup_attention_preprocessor.preprocess_meta(kwargs) + preprocessed_meta.append((image_data, kwargs)) + + return preprocessed_meta + + def preprocess( + self, + batch: dict, + preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, + *, + phase: PhaseType, + iteration: int, + metrics: dict | None = None, + ) -> list[tuple[torch.Tensor, dict]]: + # TODO: How much of this is generalizable? 
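+ # Build the concrete tensors for this batch, matching the metas from `preprocess_meta`:
+ # input_ids, labels and images are moved to the device, then tokens and labels are sliced
+ # per micro-sequence and packaged with the kwargs (position ids, rotary frequencies,
+ # backup attention masks) expected by the downstream layers.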
+ assert self._is_setup + + if preprocessed_meta is None: + preprocessed_meta = self.preprocess_meta(batch, phase) + + _, common_kwargs = preprocessed_meta[0] + sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size + sequence_first = common_kwargs[TransformerKwargs.sequence_first] + sequence_length = common_kwargs[TransformerKwargs.sequence_length] + + tokens = batch["input_ids"] + labels = batch["labels"] + image_data = batch["images"] + + # Move input_ids, labels and images to device + tokens = tokens.to( + device=self._tensor_space.distributed.device, + dtype=torch.int64, + non_blocking=True, + ).contiguous() + labels = labels.to( + device=self._tensor_space.distributed.device, + dtype=torch.int64, + non_blocking=True, + ).contiguous() + image_data = image_data.to( + device=self._tensor_space.distributed.device, + dtype=torch.float32, + non_blocking=True, + ).contiguous() + + if self._config.use_absolute_position_embeddings: + self._position_embedding_preprocessor.create_tensors(sequence_length) + if self._config.transformer.use_rotary_position_embeddings: + self._rotary_embedding_preprocessor.create_tensors(sequence_length) + if not self._use_flash_attention: + self._backup_attention_preprocessor.create_tensors(sequence_length) + + # TODO: Pasts and presents for inference? + preprocessed = [] + presents = None + for tokens_meta, kwargs_meta in preprocessed_meta: + sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size + tokens = tokens[:, sequence_k - sequence_q : sequence_k].contiguous() + print(f'Tokens sequence_k: {sequence_k} sequence_q: {sequence_q} shape: {tokens.shape}') + + pasts = presents + presents = None if sequence_k == sequence_length else [] + kwargs = { + **kwargs_meta, + LanguageModelKwargs.tokens: tokens, + TransformerKwargs.past_key_values: pasts, + TransformerKwargs.presents: presents, + + } + if phase != PhaseType.inference: + labels = labels[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() + print(f'Labels sequence_k: {sequence_k} sequence_q: {sequence_q} shape: {labels.shape}') + kwargs[LanguageModelKwargs.labels] = labels + + if self._config.use_absolute_position_embeddings: + self._position_embedding_preprocessor.preprocess(kwargs) + if self._config.transformer.use_rotary_position_embeddings: + self._rotary_embedding_preprocessor.preprocess(kwargs) + if not self._use_flash_attention: + self._backup_attention_preprocessor.preprocess(kwargs) + preprocessed.append((image_data, kwargs)) + + return preprocessed + + @property + def embedding(self) -> LanguageModelEmbedding: + return self.layers[0] + + @property + def transformer_layers(self) -> list[TransformerLayer]: + return self.layers[1:-1] + + @property + def model_head(self) -> LanguageModelHead: + return self.layers[-1] + + def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: + return ( + {WORD_EMBEDDINGS_WEIGHT: (self.embedding.word_embeddings_weight, (0, len(self) - 1))} + if self._config.tie_word_embeddings + else {} + ) + + @property + def loss_defs(self) -> list[LossDef]: + loss_defs = [ + LossDef(name=LanguageModelLossNames.language_model_loss, formatted_name="language model loss", count=1) + ] + if ( + self._config.transformer.num_experts > 1 + and self._config.transformer.expert_routing_type == RoutingType.topk + ): + loss_defs.append( + LossDef( + name=TransformerLossNames.load_balancing_loss, + formatted_name="load balancing loss", + count=self._config.transformer.num_layers, + ) + ) + if self._config.transformer.expert_z_loss_coefficient: 
+ loss_defs.append( + LossDef( + name=TransformerLossNames.router_z_loss, + formatted_name="router z loss", + count=self._config.transformer.num_layers, + ) + ) + if self._config.logit_z_loss: + LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) + return loss_defs + + +class StarDocModel(FastLLMModel): + config_class = StarDocModelConfig + base_model_class = StarDocBaseModel diff --git a/fast_llm/models/stardoc/trainer.py b/fast_llm/models/stardoc/trainer.py new file mode 100644 index 000000000..885714de3 --- /dev/null +++ b/fast_llm/models/stardoc/trainer.py @@ -0,0 +1,63 @@ +import logging + +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.training.trainer import Trainer +from fast_llm.models.stardoc.config import StarDocTrainerConfig +from fast_llm.models.stardoc.model import StarDocModel + +logger = logging.getLogger(__name__) + + +class StarDocTrainer(Trainer): + _abstract = False + config_class = StarDocTrainerConfig + model_class = StarDocModel + + def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: + # TODO: Do in model, automate/generalize, get other stats + """Get tflop/s/GPU from global-batch-size and elapsed-time""" + checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 + transformer_config = self._config.base_model.transformer + sequence_length = self._config.batch.sequence_length + + tokens = self._config.batch.batch_size * sequence_length + transformer_flops_base = 2 * checkpoint_activations_factor * tokens * transformer_config.num_layers + dense_flops_base = transformer_flops_base * transformer_config.hidden_size + # Query, key, value, dense. + flops_per_iteration = ( + 2 + * (transformer_config.num_attention_heads + transformer_config.head_groups) + * transformer_config.kv_channels + * dense_flops_base + ) + # MLP + flops_per_iteration += ( + (2 + transformer_config.gated) + * transformer_config.ffn_hidden_size + * dense_flops_base + * transformer_config.num_experts_per_token + ) + + # LM-head + flops_per_iteration += 6 * tokens * transformer_config.hidden_size * self._config.base_model.vocab_size + + # Attention-matrix computation + attn_flops_base = transformer_flops_base * transformer_config.projection_size + if transformer_config.window_size is None: + # Ignore masked values (s**2/2) + attn_flops = attn_flops_base * sequence_length + model_tflops = flops_per_iteration + attn_flops + else: + # s*w - w**2/2 + attn_flops = ( + 2 + * attn_flops_base + * transformer_config.window_size + * (1 - transformer_config.window_size / 2 / sequence_length) + ) + model_tflops = flops_per_iteration + attn_flops + + # Partial recomputation (normal is 2 ops * ckpt_factor = 6, adding 1 for recomputing Q x K) + hardware_flops = flops_per_iteration + 7 / 6 * attn_flops + ratio = elapsed_time_per_iteration * self._config.distributed.world_size * 1e12 + return model_tflops / ratio, hardware_flops / ratio diff --git a/fast_llm/tools/train.py b/fast_llm/tools/train.py index 956075b7c..cd871659f 100644 --- a/fast_llm/tools/train.py +++ b/fast_llm/tools/train.py @@ -4,6 +4,7 @@ import requests import yaml +import sys from fast_llm.config import NoAutoValidate from fast_llm.engine.training.config import TrainerConfig @@ -116,4 +117,4 @@ def train(args=None): if __name__ == "__main__": - train() + train(sys.argv[1:]) diff --git a/run_multimodal.sh b/run_multimodal.sh new file mode 100644 index 000000000..ed72d7a25 --- /dev/null +++ b/run_multimodal.sh @@ -0,0 +1,88 @@ 
+BASE_JOB_NAME="mistral-7b-FastLLM-stardoc-debug-local" + +export PYTHONHASHSEED=12345 + +export MODEL_ARGS="\ +--pretrained_checkpoint_type=huggingface \ +--pretrained_checkpoint_path=/data/git/Fast-LLM/stardoc_hf/stardoc_checkpoint \ +--use_pretrained_config=1 \ +--attention_dropout=0.0 \ +--hidden_dropout=0.0 \ +--max_num_images=5 \ +--image_resolution=224 \ +--num_image_tokens=256 \ +--image_encoder_hidden_size=1024 \ +--image_encoder_type=clip \ +" + +export STAGE_ARGS="\ +--zero_stage=3 \ +" + +export OPTIMIZER_ARGS="\ +--lr=0.000001 \ +--lr_decay_style=cosine \ +--lr_decay_iters=250 \ +--lr_warmup_iters=100 \ +--min_lr=0.0 \ +--weight_decay=0.1 \ +--adam_beta1=.9 \ +--adam_beta2=.95 \ +--clip_grad=1.0 \ +" + +export DATA_ARGS="\ +--split=9998,2,0 \ +--dataset_source=multimodal \ +--data_path=/data/datasets/stardoc/BigDoc-MultiTurn-v0.3 \ +--tokenizer_type=PreTrainedTokenizer \ +--tokenizer_path=/data/models/mistral/HF/Mistral-7B-v0.3 \ +" + +export SCHEDULE_ARGS="\ +--batch_size=32 \ +--micro_batch_size=1 \ +--sequence_length=8192 \ +" + +export DISTRIBUTED_ARGS="\ +--training_dtype=bf16 \ +--distributed_timeout=600 \ +--seed=984059 \ +--sequence_data_parallel=1 \ +" + +export RUN_ARGS="\ +--log_interval=10 \ +--log_offset=0 \ +--checkpoint_interval=500 \ +--max_checkpoints=5 \ +--export_interval=25000 \ +" + +export TRAINING_ARGS="\ +--train_iters=40000 \ +--validation_iters=25000000 \ +--validation_interval=1000000 \ +--test_iters=0 \ +--num_workers=1 \ +" + +export ALL_ARGS="\ +$MODEL_ARGS \ +$STAGE_ARGS \ +$DATA_ARGS \ +$SCHEDULE_ARGS \ +$OPTIMIZER_ARGS \ +$DISTRIBUTED_ARGS \ +$TRAINING_ARGS \ +$RUN_ARGS \ +" + +torchrun --nproc-per-node=8 \ + --log-dir=output/$BASE_JOB_NAME/logs \ + --redirects=3 \ + pretrain_fast_llm.py $ALL_ARGS --experiment_dir="output/$BASE_JOB_NAME/" + +# torchrun --nproc-per-node=8 \ +# pretrain_fast_llm.py $ALL_ARGS --experiment_dir="output/$BASE_JOB_NAME/" \ No newline at end of file diff --git a/set_env_stardoc.sh b/set_env_stardoc.sh new file mode 100755 index 000000000..cf86c498a --- /dev/null +++ b/set_env_stardoc.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +export PROJECT_DIR="/mnt/akshay/stardoc-FastLLM/Fast-LLM/output" +export PROJECT_NAME="stardoc_debug" +export PROJECT_VERSION="1.0" +export RUN_NAME="debug" + +export DATA_PATH="/mnt/stardoc/datasets/save_hf/BigDoc-MultiTurn-v0.13" +export PRETRAINED_STARDOC_PATH="/mnt/akshay/stardoc-FastLLM/Fast-LLM/stardoc_hf_model/stardoc_checkpoint" +export TOKENIZER_PATH="/mnt/core_llm/models/mistral/HF/Mistral-7B-v0.3" + +export CMD_ARGS="fast-llm train gpt" + +export MODEL_ARGS_PRETRAINED="\ +--pretrained_checkpoint_type=huggingface \ +--pretrained_checkpoint_path=$PRETRAINED_STARDOC_PATH \ +--use_pretrained_config=1 \ +" + +export MODEL_ARGS_ARCHITECTURE="\ +--num_layers=32 \ +--hidden_size=4096 \ +--vocab_size=32000 \ +--num_attention_heads=32 \ +--head_groups=8 \ +--add_linear_biases=0 \ +--ffn_hidden_size=14336 \ +--kv_channels=128 \ +--use_rotary_embeddings=1 \ +--rotary_embedding_scale=-9.210340371976184 \ +--gated=1 \ +--activation_type=silu \ +--normalization_type=rms_norm \ +--tie_word_embeddings=0 \ +--window_size=8192 \ +" + +export DATA_ARGS="\ +--split=9998,2,0 \ +--dataset_source=multimodal \ +--data_path=$DATA_PATH \ +--tokenizer_type=PreTrainedTokenizer \ +--tokenizer_path=$TOKENIZER_PATH \ +" + +export TRAINING_ARGS="\ +--batch_size=8 \ +--sequence_length=8192 \ +--train_iters=500000 \ +--weight_decay=0.1 \ +--adam_beta1=0.9 \ +--adam_beta2=0.95 \ +--clip_grad=1.0 \ +--lr=0.0001 \ 
+--lr_warmup_iters=1000 \ +--lr_decay_style=cosine \ +--lr_decay_iters=500000 \ +--min_lr=0.000003 \ +" + +export PERFORMANCE_ARGS="\ +--micro_batch_size=1 \ +--training_dtype=bf16 \ +--zero_stage=2 \ +--num_workers=8 \ +" + +export MONITORING_ARGS="\ +--validation_iters=25 \ +--validation_interval=1000 \ +--log_interval=10 \ +--log_offset=0 \ +--checkpoint_interval=500 \ +--max_checkpoints=5 \ +--export_interval=25000 \ +--wandb_status_interval=25000 \ +--wandb_entity_name=$WANDB_ENTITY_NAME \ +--wandb_project_name=$PROJECT_NAME \ +--wandb_group_name=$PROJECT_VERSION \ +" + +export ALL_ARGS="\ +$CMD_ARGS \ +$MODEL_ARGS_PRETRAINED \ +$DATA_ARGS \ +$TRAINING_ARGS \ +$PERFORMANCE_ARGS \ +$MONITORING_ARGS \ +" + +export PROFILE_ARGS="\ +--profile_cuda=1 \ +--profile_skip=10 \ +--profile_wait=95 \ +--profile_warmup=2 \ +--profile_cycles=3 \ +--profile_export=1 \ +" \ No newline at end of file From 162438f0d01879a5485a0255d988013313604059 Mon Sep 17 00:00:00 2001 From: Akshay Kalkunte Date: Thu, 17 Oct 2024 00:33:09 +0000 Subject: [PATCH 2/4] Code cleanup --- examples/train_mistral.sh | 2 +- examples/train_stardoc.sh | 47 +++--- examples/train_stardoc_akshay.sh | 156 ++++++++++++++++++ fast_llm/data/stardoc.py | 10 +- fast_llm/layers/language_model/head.py | 1 - .../layers/multimodal_model/image_encoder.py | 2 +- .../multimodal_language_embedding.py | 1 - fast_llm/models/stardoc/model.py | 3 - run_multimodal.sh | 88 ---------- set_env_stardoc.sh | 98 ----------- 10 files changed, 188 insertions(+), 220 deletions(-) create mode 100755 examples/train_stardoc_akshay.sh delete mode 100644 run_multimodal.sh delete mode 100755 set_env_stardoc.sh diff --git a/examples/train_mistral.sh b/examples/train_mistral.sh index e4b03c475..5745e38c8 100644 --- a/examples/train_mistral.sh +++ b/examples/train_mistral.sh @@ -61,7 +61,7 @@ export TRAINING_ARGS="\ " export PERFORMANCE_ARGS="\ ---micro_batch_size=4 \ +--micro_batch_size=1 \ --training_dtype=bf16 \ --zero_stage=2 \ --num_workers=8 \ diff --git a/examples/train_stardoc.sh b/examples/train_stardoc.sh index ea9865ac0..5607f1537 100755 --- a/examples/train_stardoc.sh +++ b/examples/train_stardoc.sh @@ -1,19 +1,15 @@ # Required or optional environment variables -export PROJECT_DIR="/mnt/akshay/stardoc-FastLLM/Fast-LLM/output" -export PROJECT_NAME="stardoc-debug" -export PROJECT_VERSION="1.0" -# export DATA_PATH_LIST= -export DATA_PATH="/mnt/stardoc/datasets/save_hf/BigDoc-MultiTurn-v0.3" -# export DATA_PATH_JSON= -# export PRETRAINED_MISTRAL_PATH= -# export PRETRAINED_MIXTRAL_PATH= -export PRETRAINED_STARDOC_PATH="/mnt/akshay/stardoc-FastLLM/Fast-LLM/stardoc_hf_model/stardoc_checkpoint" -export TOKENIZER_PATH="/mnt/core_llm/models/mistral/HF/Mistral-7B-v0.3" +# export PROJECT_DIR= +# export PROJECT_NAME= +# export PROJECT_VERSION= +# export DATA_PATH= +# export PRETRAINED_STARDOC_PATH= +# export TOKENIZER_PATH= -export PYTHONHASHSEED=12345 +# export HF_HOME= +# export HF_TOKEN= export CMD_ARGS="fast-llm train stardoc" -# export CMD_ARGS="python /mnt/akshay/stardoc-FastLLM/Fast-LLM/fast_llm/tools/train.py stardoc" export MODEL_ARGS_PRETRAINED="\ --pretrained_checkpoint_type=huggingface \ @@ -39,10 +35,10 @@ export MODEL_ARGS_ARCHITECTURE="\ --window_size=8192 \ " -export MULTIMODAL_ARGS ="\ +export MULTIMODAL_ARGS="\ --image_encoder_hidden_size=1024 \ --num_image_tokens=256 \ ---max_num_images=7 \ +--max_num_images=10 \ --image_encoder_type=clip \ " @@ -56,7 +52,7 @@ export DATA_ARGS="\ " export TRAINING_ARGS="\ ---batch_size=32 \ +--batch_size=8 \ 
--sequence_length=8192 \ --train_iters=500000 \ --weight_decay=0.1 \ @@ -71,9 +67,9 @@ export TRAINING_ARGS="\ " export PERFORMANCE_ARGS="\ ---micro_batch_size=4 \ +--micro_batch_size=1 \ --training_dtype=bf16 \ ---zero_stage=2 \ +--zero_stage=3 \ --num_workers=8 \ " @@ -85,16 +81,16 @@ export MONITORING_ARGS="\ --checkpoint_interval=500 \ --max_checkpoints=5 \ --export_interval=25000 \ +--wandb_status_interval=25000 \ +--wandb_entity_name=$WANDB_ENTITY_NAME \ +--wandb_project_name=$PROJECT_NAME \ +--wandb_group_name=$PROJECT_VERSION \ " -# --wandb_status_interval=25000 \ -# --wandb_entity_name=$WANDB_ENTITY_NAME \ -# --wandb_project_name=$PROJECT_NAME \ -# --wandb_group_name=$PROJECT_VERSION \ -# " export ALL_ARGS="\ $CMD_ARGS \ $MODEL_ARGS_PRETRAINED \ +$MODEL_ARGS_ARCHITECTURE \ $MULTIMODAL_ARGS \ $DATA_ARGS \ $TRAINING_ARGS \ @@ -111,7 +107,6 @@ export PROFILE_ARGS="\ --profile_export=1 \ " - run_local () { # run(name, num_gpus, base_cmd) echo $1 $2 $3 export TORCHRUN="torchrun --nproc-per-node=$2 --nnodes=1 --no-python" @@ -124,6 +119,6 @@ run_c10d () { # run(name, num_nodes, base_cmd) $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 } -run_local debug 8 "$ALL_ARGS" -# run_c10d mistral_example 16 "$ALL_ARGS" -# run_c10d mixtral_example 16 "$ALL_ARGS $MIXTRAL_ARGS --train_iters=50" +run_local stardoc_example 8 "$ALL_ARGS" +# run_c10d stardoc_example 16 "$ALL_ARGS" +# run_c10d stardoc_example 16 "$ALL_ARGS $MIXTRAL_ARGS --train_iters=50" diff --git a/examples/train_stardoc_akshay.sh b/examples/train_stardoc_akshay.sh new file mode 100755 index 000000000..97e6fd90e --- /dev/null +++ b/examples/train_stardoc_akshay.sh @@ -0,0 +1,156 @@ +# Required or optional environment variables +export PROJECT_DIR="/mnt/akshay/stardoc-FastLLM/Fast-LLM/output" +export PROJECT_NAME="stardoc-debug" +export PROJECT_VERSION="1.0" +# export DATA_PATH_LIST= +export DATA_PATH="/mnt/stardoc/datasets/save_hf/BigDoc-MultiTurn-v0.3" +# export DATA_PATH_JSON= +# export PRETRAINED_MISTRAL_PATH= +# export PRETRAINED_MIXTRAL_PATH= +export PRETRAINED_STARDOC_PATH="/mnt/akshay/stardoc-FastLLM/Fast-LLM/stardoc_hf_model/stardoc_checkpoint" +export TOKENIZER_PATH="/mnt/core_llm/models/mistral/HF/Mistral-7B-v0.3" + +export HF_HOME=/mnt/stardoc/hf +export HF_TOKEN=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 + +export PYTHONHASHSEED=12345 + +export CMD_ARGS="fast-llm train stardoc" +# export CMD_ARGS="python /mnt/akshay/stardoc-FastLLM/Fast-LLM/fast_llm/tools/train.py stardoc" + +export MODEL_ARGS_PRETRAINED="\ +--pretrained_checkpoint_type=huggingface \ +--pretrained_checkpoint_path=$PRETRAINED_STARDOC_PATH \ +--use_pretrained_config=1 \ +" + +export MODEL_ARGS_ARCHITECTURE="\ +--num_layers=32 \ +--hidden_size=4096 \ +--vocab_size=32000 \ +--num_attention_heads=32 \ +--head_groups=8 \ +--add_linear_biases=0 \ +--ffn_hidden_size=14336 \ +--kv_channels=128 \ +--use_rotary_embeddings=1 \ +--rotary_embedding_scale=-9.210340371976184 \ +--gated=1 \ +--activation_type=silu \ +--normalization_type=rms_norm \ +--tie_word_embeddings=0 \ +--window_size=8192 \ +" + +export MULTIMODAL_ARGS="\ +--image_encoder_hidden_size=1024 \ +--num_image_tokens=256 \ +--max_num_images=10 \ +--image_encoder_type=clip \ +" + +export DATA_ARGS="\ +--split=9998,2,0 \ +--dataset_type=stardoc \ +--dataset_source=multimodal \ +--data_path=$DATA_PATH \ +--tokenizer_type=PreTrainedTokenizer \ +--tokenizer_path=$TOKENIZER_PATH \ +" + +export TRAINING_ARGS="\ +--batch_size=8 \ +--sequence_length=8192 \ +--train_iters=500000 \ 
+--weight_decay=0.1 \ +--adam_beta1=0.9 \ +--adam_beta2=0.95 \ +--clip_grad=1.0 \ +--lr=0.0001 \ +--lr_warmup_iters=1000 \ +--lr_decay_style=cosine \ +--lr_decay_iters=500000 \ +--min_lr=0.000003 \ +" + +export PERFORMANCE_ARGS="\ +--micro_batch_size=1 \ +--training_dtype=bf16 \ +--zero_stage=3 \ +--num_workers=8 \ +" + +export MONITORING_ARGS="\ +--validation_iters=25 \ +--validation_interval=1000 \ +--log_interval=10 \ +--log_offset=0 \ +--checkpoint_interval=500 \ +--max_checkpoints=5 \ +--export_interval=25000 \ +" +# --wandb_status_interval=25000 \ +# --wandb_entity_name=$WANDB_ENTITY_NAME \ +# --wandb_project_name=$PROJECT_NAME \ +# --wandb_group_name=$PROJECT_VERSION \ +# " + +export ALL_ARGS="\ +$CMD_ARGS \ +$MODEL_ARGS_PRETRAINED \ +$MODEL_ARGS_ARCHITECTURE \ +$MULTIMODAL_ARGS \ +$DATA_ARGS \ +$TRAINING_ARGS \ +$PERFORMANCE_ARGS \ +$MONITORING_ARGS \ +" + +export PROFILE_ARGS="\ +--profile_cuda=1 \ +--profile_skip=10 \ +--profile_wait=95 \ +--profile_warmup=2 \ +--profile_cycles=3 \ +--profile_export=1 \ +" + +# cd /mnt/akshay/stardoc-FastLLM/Fast-LLM + +# PIP_NO_INPUT=1 pip3 install --user --no-cache-dir --no-dependencies -e . +# export PATH="$PATH:$HOME/.local/bin/" +# make -C ./fast_llm/csrc/ + +run_local () { # run(name, num_gpus, base_cmd) + echo $1 $2 $3 + export TORCHRUN="torchrun --nproc-per-node=$2 --nnodes=1 --no-python" + $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 +} + +run_c10d () { # run(name, num_nodes, base_cmd) + echo $1 $2 $3 + export TORCHRUN="torchrun --nproc-per-node=8 --nnodes=$2 --no-python --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR" + $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 +} + +run_distributed () { + echo $1 $2 $3 + # Master address and rank + MASTER_ADDR="dns-$EAI_PROCESS_AGENT-0" + MASTER_PORT=8001 + NODE_RANK="$EAI_PROCESS_AGENT_INDEX" + LOG_DIR="$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1" + + echo "MASTER ADDR: $MASTER_ADDR" + echo "PORT: $MASTER_PORT" + echo "NODE_RANK: $NODE_RANK" + echo "LOG_DIR: $LOG_DIR" + + export TORCHRUN="torchrun --nproc-per-node=8 --nnodes=$2 --no-python --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --node_rank=$NODE_RANK --log-dir=$LOG_DIR --redirects=3" + $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 +} + +run_local debug 8 "$ALL_ARGS" +# run_distributed debug 2 "$ALL_ARGS" +# run_c10d mistral_example 16 "$ALL_ARGS" +# run_c10d mixtral_example 16 "$ALL_ARGS $MIXTRAL_ARGS --train_iters=50" diff --git a/fast_llm/data/stardoc.py b/fast_llm/data/stardoc.py index 415c35c41..ad80fd32b 100644 --- a/fast_llm/data/stardoc.py +++ b/fast_llm/data/stardoc.py @@ -17,6 +17,7 @@ image_loading_function, ) from fast_llm.data.stardoc_data_utils.docowl_stardoc_processor import docowl_text_preprocess_v1 +from fast_llm.data.stardoc_data_utils.constants import IGNORE_INDEX logger = logging.getLogger(__name__) @@ -106,7 +107,7 @@ def __getitem__(self, idx): sample_tokenized_buffer.extend(dummy_image_token_id * self.num_im_tokens) # Don't learn on any image tokens - [labels.append(-200) for x in range(len(sample_tokenized_buffer))] + [labels.append(IGNORE_INDEX) for x in range(len(sample_tokenized_buffer))] assert(len(queries) == len(annotations)) for i, (q, a) in enumerate(zip(queries, annotations)): @@ -120,6 +121,13 @@ def __getitem__(self, idx): sample_tokenized_buffer.append(self.tokenizer.eos_token_id) labels.extend(sample_tokenized_buffer[len(labels):len(sample_tokenized_buffer)]) assert len(sample_tokenized_buffer) == 
len(labels) + + # Right pad to max. sequence length + n_pad_tokens = self.tokenizer.max_seq_length - len(sample_tokenized_buffer) + sample_tokenized_buffer = sample_tokenized_buffer + n_pad_tokens*[self.tokenizer.pad_token_id] + + # Add an extra pad token to the labels at the end to support shifting left + labels = labels + (n_pad_tokens + 1) *[IGNORE_INDEX] return { 'input_ids': torch.tensor(sample_tokenized_buffer), diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 18abdf7d9..611a63309 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -174,7 +174,6 @@ def _logits_cross_entropy_forward_backward( kwargs: dict, losses: dict | None = None, ): - print(f'Loss head input_ shape: {input_.shape} Labels shape: {labels.shape}') logits, context = output_parallel_linear_forward( input_=input_, weight=weight, diff --git a/fast_llm/layers/multimodal_model/image_encoder.py b/fast_llm/layers/multimodal_model/image_encoder.py index d26305d74..d078bd591 100644 --- a/fast_llm/layers/multimodal_model/image_encoder.py +++ b/fast_llm/layers/multimodal_model/image_encoder.py @@ -56,7 +56,7 @@ def __init__( def get_fastllm_parameter(self, param_name, param): param_dims = tuple([TensorDim(name=f'{param_name}_{idx}', global_size=x, parallel_dim=None) for idx, x in enumerate(param.shape)]) - return ParameterMeta(param.to("meta"), tensor_name=param_name, dims=param_dims, init_method=init_normal_(std=0.02)) + return ParameterMeta(param.to("meta"), tensor_name=param_name, dims=param_dims, init_method=init_normal_(std=0.02), requires_grad=True, allow_no_grad=True) def _forward(self, input_: tuple[torch.Tensor], losses: dict | None = None, metrics: dict | None = None): if not self.image_encoder_type.lower() == "clip": diff --git a/fast_llm/layers/multimodal_model/multimodal_language_embedding.py b/fast_llm/layers/multimodal_model/multimodal_language_embedding.py index 64301d9c0..47aa85595 100644 --- a/fast_llm/layers/multimodal_model/multimodal_language_embedding.py +++ b/fast_llm/layers/multimodal_model/multimodal_language_embedding.py @@ -66,7 +66,6 @@ def __init__( def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, tokens: torch.Tensor | None): Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) Assert.eq(tokens is not None) - group = self._tensor_space.distributed.tensor_group text_embeddings = torch.embedding(self.word_embeddings_weight, tokens) diff --git a/fast_llm/models/stardoc/model.py b/fast_llm/models/stardoc/model.py index f6aa5cfac..b20101f30 100644 --- a/fast_llm/models/stardoc/model.py +++ b/fast_llm/models/stardoc/model.py @@ -269,7 +269,6 @@ def preprocess( for tokens_meta, kwargs_meta in preprocessed_meta: sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size tokens = tokens[:, sequence_k - sequence_q : sequence_k].contiguous() - print(f'Tokens sequence_k: {sequence_k} sequence_q: {sequence_q} shape: {tokens.shape}') pasts = presents presents = None if sequence_k == sequence_length else [] @@ -278,11 +277,9 @@ def preprocess( LanguageModelKwargs.tokens: tokens, TransformerKwargs.past_key_values: pasts, TransformerKwargs.presents: presents, - } if phase != PhaseType.inference: labels = labels[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() - print(f'Labels sequence_k: {sequence_k} sequence_q: {sequence_q} shape: {labels.shape}') kwargs[LanguageModelKwargs.labels] = labels if self._config.use_absolute_position_embeddings: diff --git 
a/run_multimodal.sh b/run_multimodal.sh deleted file mode 100644 index ed72d7a25..000000000 --- a/run_multimodal.sh +++ /dev/null @@ -1,88 +0,0 @@ -BASE_JOB_NAME="mistral-7b-FastLLM-stardoc-debug-local" - -export PYTHONHASHSEED=12345 - -export MODEL_ARGS="\ ---pretrained_checkpoint_type=huggingface \ ---pretrained_checkpoint_path=/data/git/Fast-LLM/stardoc_hf/stardoc_checkpoint \ ---use_pretrained_config=1 \ ---attention_dropout=0.0 \ ---hidden_dropout=0.0 \ ---max_num_images=5 \ ---image_resolution=224 \ ---num_image_tokens=256 \ ---image_encoder_hidden_size=1024 \ ---image_encoder_type=clip \ -" - -export STAGE_ARGS="\ ---zero_stage=3 \ -" - -export OPTIMIZER_ARGS="\ ---lr=0.000001 \ ---lr_decay_style=cosine \ ---lr_decay_iters=250 \ ---lr_warmup_iters=100 \ ---min_lr=0.0 \ ---weight_decay=0.1 \ ---adam_beta1=.9 \ ---adam_beta2=.95 \ ---clip_grad=1.0 \ -" - -export DATA_ARGS="\ ---split=9998,2,0 \ ---dataset_source=multimodal \ ---data_path=/data/datasets/stardoc/BigDoc-MultiTurn-v0.3 \ ---tokenizer_type=PreTrainedTokenizer \ ---tokenizer_path=/data/models/mistral/HF/Mistral-7B-v0.3 \ -" - -export SCHEDULE_ARGS="\ ---batch_size=32 \ ---micro_batch_size=1 \ ---sequence_length=8192 \ -" - -export DISTRIBUTED_ARGS="\ ---training_dtype=bf16 \ ---distributed_timeout=600 \ ---seed=984059 \ ---sequence_data_parallel=1 \ -" - -export RUN_ARGS="\ ---log_interval=10 \ ---log_offset=0 \ ---checkpoint_interval=500 \ ---max_checkpoints=5 \ ---export_interval=25000 \ -" - -export TRAINING_ARGS="\ ---train_iters=40000 \ ---validation_iters=25000000 \ ---validation_interval=1000000 \ ---test_iters=0 \ ---num_workers=1 \ -" - -export ALL_ARGS="\ -$MODEL_ARGS \ -$STAGE_ARGS \ -$DATA_ARGS \ -$SCHEDULE_ARGS \ -$OPTIMIZER_ARGS \ -$DISTRIBUTED_ARGS \ -$TRAINING_ARGS \ -$RUN_ARGS \ -" - -torchrun --nproc-per-node=8 \ - --log-dir=output/$BASE_JOB_NAME/logs \ - --redirects=3 \ - pretrain_fast_llm.py $ALL_ARGS --experiment_dir="output/$BASE_JOB_NAME/" - -# torchrun --nproc-per-node=8 \ -# pretrain_fast_llm.py $ALL_ARGS --experiment_dir="output/$BASE_JOB_NAME/" \ No newline at end of file diff --git a/set_env_stardoc.sh b/set_env_stardoc.sh deleted file mode 100755 index cf86c498a..000000000 --- a/set_env_stardoc.sh +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash - -export PROJECT_DIR="/mnt/akshay/stardoc-FastLLM/Fast-LLM/output" -export PROJECT_NAME="stardoc_debug" -export PROJECT_VERSION="1.0" -export RUN_NAME="debug" - -export DATA_PATH="/mnt/stardoc/datasets/save_hf/BigDoc-MultiTurn-v0.13" -export PRETRAINED_STARDOC_PATH="/mnt/akshay/stardoc-FastLLM/Fast-LLM/stardoc_hf_model/stardoc_checkpoint" -export TOKENIZER_PATH="/mnt/core_llm/models/mistral/HF/Mistral-7B-v0.3" - -export CMD_ARGS="fast-llm train gpt" - -export MODEL_ARGS_PRETRAINED="\ ---pretrained_checkpoint_type=huggingface \ ---pretrained_checkpoint_path=$PRETRAINED_STARDOC_PATH \ ---use_pretrained_config=1 \ -" - -export MODEL_ARGS_ARCHITECTURE="\ ---num_layers=32 \ ---hidden_size=4096 \ ---vocab_size=32000 \ ---num_attention_heads=32 \ ---head_groups=8 \ ---add_linear_biases=0 \ ---ffn_hidden_size=14336 \ ---kv_channels=128 \ ---use_rotary_embeddings=1 \ ---rotary_embedding_scale=-9.210340371976184 \ ---gated=1 \ ---activation_type=silu \ ---normalization_type=rms_norm \ ---tie_word_embeddings=0 \ ---window_size=8192 \ -" - -export DATA_ARGS="\ ---split=9998,2,0 \ ---dataset_source=multimodal \ ---data_path=$DATA_PATH \ ---tokenizer_type=PreTrainedTokenizer \ ---tokenizer_path=$TOKENIZER_PATH \ -" - -export TRAINING_ARGS="\ ---batch_size=8 \ 
---sequence_length=8192 \ ---train_iters=500000 \ ---weight_decay=0.1 \ ---adam_beta1=0.9 \ ---adam_beta2=0.95 \ ---clip_grad=1.0 \ ---lr=0.0001 \ ---lr_warmup_iters=1000 \ ---lr_decay_style=cosine \ ---lr_decay_iters=500000 \ ---min_lr=0.000003 \ -" - -export PERFORMANCE_ARGS="\ ---micro_batch_size=1 \ ---training_dtype=bf16 \ ---zero_stage=2 \ ---num_workers=8 \ -" - -export MONITORING_ARGS="\ ---validation_iters=25 \ ---validation_interval=1000 \ ---log_interval=10 \ ---log_offset=0 \ ---checkpoint_interval=500 \ ---max_checkpoints=5 \ ---export_interval=25000 \ ---wandb_status_interval=25000 \ ---wandb_entity_name=$WANDB_ENTITY_NAME \ ---wandb_project_name=$PROJECT_NAME \ ---wandb_group_name=$PROJECT_VERSION \ -" - -export ALL_ARGS="\ -$CMD_ARGS \ -$MODEL_ARGS_PRETRAINED \ -$DATA_ARGS \ -$TRAINING_ARGS \ -$PERFORMANCE_ARGS \ -$MONITORING_ARGS \ -" - -export PROFILE_ARGS="\ ---profile_cuda=1 \ ---profile_skip=10 \ ---profile_wait=95 \ ---profile_warmup=2 \ ---profile_cycles=3 \ ---profile_export=1 \ -" \ No newline at end of file From b2d0c6eab82c58cb80fe0ba58b5e809c82620670 Mon Sep 17 00:00:00 2001 From: Akshay Kalkunte Date: Thu, 17 Oct 2024 01:27:02 +0000 Subject: [PATCH 3/4] cleanup 2 --- .gitignore | 4 +- examples/train_stardoc_akshay.sh | 156 ------------------------------- fast_llm/data/config.py | 1 + fast_llm/models/gpt/model.py | 2 +- 4 files changed, 3 insertions(+), 160 deletions(-) delete mode 100755 examples/train_stardoc_akshay.sh diff --git a/.gitignore b/.gitignore index ec960a900..83d992634 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,4 @@ venv.bak/ # Project specifics /.idea/ -/.vscode/ -output/ -stardoc_hf_model/ \ No newline at end of file +/.vscode/ \ No newline at end of file diff --git a/examples/train_stardoc_akshay.sh b/examples/train_stardoc_akshay.sh deleted file mode 100755 index 97e6fd90e..000000000 --- a/examples/train_stardoc_akshay.sh +++ /dev/null @@ -1,156 +0,0 @@ -# Required or optional environment variables -export PROJECT_DIR="/mnt/akshay/stardoc-FastLLM/Fast-LLM/output" -export PROJECT_NAME="stardoc-debug" -export PROJECT_VERSION="1.0" -# export DATA_PATH_LIST= -export DATA_PATH="/mnt/stardoc/datasets/save_hf/BigDoc-MultiTurn-v0.3" -# export DATA_PATH_JSON= -# export PRETRAINED_MISTRAL_PATH= -# export PRETRAINED_MIXTRAL_PATH= -export PRETRAINED_STARDOC_PATH="/mnt/akshay/stardoc-FastLLM/Fast-LLM/stardoc_hf_model/stardoc_checkpoint" -export TOKENIZER_PATH="/mnt/core_llm/models/mistral/HF/Mistral-7B-v0.3" - -export HF_HOME=/mnt/stardoc/hf -export HF_TOKEN=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 - -export PYTHONHASHSEED=12345 - -export CMD_ARGS="fast-llm train stardoc" -# export CMD_ARGS="python /mnt/akshay/stardoc-FastLLM/Fast-LLM/fast_llm/tools/train.py stardoc" - -export MODEL_ARGS_PRETRAINED="\ ---pretrained_checkpoint_type=huggingface \ ---pretrained_checkpoint_path=$PRETRAINED_STARDOC_PATH \ ---use_pretrained_config=1 \ -" - -export MODEL_ARGS_ARCHITECTURE="\ ---num_layers=32 \ ---hidden_size=4096 \ ---vocab_size=32000 \ ---num_attention_heads=32 \ ---head_groups=8 \ ---add_linear_biases=0 \ ---ffn_hidden_size=14336 \ ---kv_channels=128 \ ---use_rotary_embeddings=1 \ ---rotary_embedding_scale=-9.210340371976184 \ ---gated=1 \ ---activation_type=silu \ ---normalization_type=rms_norm \ ---tie_word_embeddings=0 \ ---window_size=8192 \ -" - -export MULTIMODAL_ARGS="\ ---image_encoder_hidden_size=1024 \ ---num_image_tokens=256 \ ---max_num_images=10 \ ---image_encoder_type=clip \ -" - -export DATA_ARGS="\ ---split=9998,2,0 \ 
---dataset_type=stardoc \ ---dataset_source=multimodal \ ---data_path=$DATA_PATH \ ---tokenizer_type=PreTrainedTokenizer \ ---tokenizer_path=$TOKENIZER_PATH \ -" - -export TRAINING_ARGS="\ ---batch_size=8 \ ---sequence_length=8192 \ ---train_iters=500000 \ ---weight_decay=0.1 \ ---adam_beta1=0.9 \ ---adam_beta2=0.95 \ ---clip_grad=1.0 \ ---lr=0.0001 \ ---lr_warmup_iters=1000 \ ---lr_decay_style=cosine \ ---lr_decay_iters=500000 \ ---min_lr=0.000003 \ -" - -export PERFORMANCE_ARGS="\ ---micro_batch_size=1 \ ---training_dtype=bf16 \ ---zero_stage=3 \ ---num_workers=8 \ -" - -export MONITORING_ARGS="\ ---validation_iters=25 \ ---validation_interval=1000 \ ---log_interval=10 \ ---log_offset=0 \ ---checkpoint_interval=500 \ ---max_checkpoints=5 \ ---export_interval=25000 \ -" -# --wandb_status_interval=25000 \ -# --wandb_entity_name=$WANDB_ENTITY_NAME \ -# --wandb_project_name=$PROJECT_NAME \ -# --wandb_group_name=$PROJECT_VERSION \ -# " - -export ALL_ARGS="\ -$CMD_ARGS \ -$MODEL_ARGS_PRETRAINED \ -$MODEL_ARGS_ARCHITECTURE \ -$MULTIMODAL_ARGS \ -$DATA_ARGS \ -$TRAINING_ARGS \ -$PERFORMANCE_ARGS \ -$MONITORING_ARGS \ -" - -export PROFILE_ARGS="\ ---profile_cuda=1 \ ---profile_skip=10 \ ---profile_wait=95 \ ---profile_warmup=2 \ ---profile_cycles=3 \ ---profile_export=1 \ -" - -# cd /mnt/akshay/stardoc-FastLLM/Fast-LLM - -# PIP_NO_INPUT=1 pip3 install --user --no-cache-dir --no-dependencies -e . -# export PATH="$PATH:$HOME/.local/bin/" -# make -C ./fast_llm/csrc/ - -run_local () { # run(name, num_gpus, base_cmd) - echo $1 $2 $3 - export TORCHRUN="torchrun --nproc-per-node=$2 --nnodes=1 --no-python" - $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 -} - -run_c10d () { # run(name, num_nodes, base_cmd) - echo $1 $2 $3 - export TORCHRUN="torchrun --nproc-per-node=8 --nnodes=$2 --no-python --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR" - $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 -} - -run_distributed () { - echo $1 $2 $3 - # Master address and rank - MASTER_ADDR="dns-$EAI_PROCESS_AGENT-0" - MASTER_PORT=8001 - NODE_RANK="$EAI_PROCESS_AGENT_INDEX" - LOG_DIR="$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1" - - echo "MASTER ADDR: $MASTER_ADDR" - echo "PORT: $MASTER_PORT" - echo "NODE_RANK: $NODE_RANK" - echo "LOG_DIR: $LOG_DIR" - - export TORCHRUN="torchrun --nproc-per-node=8 --nnodes=$2 --no-python --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --node_rank=$NODE_RANK --log-dir=$LOG_DIR --redirects=3" - $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1 -} - -run_local debug 8 "$ALL_ARGS" -# run_distributed debug 2 "$ALL_ARGS" -# run_c10d mistral_example 16 "$ALL_ARGS" -# run_c10d mixtral_example 16 "$ALL_ARGS $MIXTRAL_ARGS --train_iters=50" diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index f77a16e70..81b724d65 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -114,6 +114,7 @@ class TokenizerConfig(Config): default="TokenizerFromFile", desc="Unused.", hint=FieldHint.deprecated, + valid=check_field(Assert.eq, TokenizerFromFile), ) path: str | None = Field( default=None, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index e93b8ad86..55f96831b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -97,7 +97,7 @@ def preprocess_meta(self, input_: BatchConfig | torch.Tensor, phase: PhaseType) if phase != PhaseType.inference: sequence_length -= 1 micro_sequence_length = sequence_length - + batch_data = 
self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) From 00be01f00ca6c1df5e0bf38d6626a675ea1e4e19 Mon Sep 17 00:00:00 2001 From: Akshay Kalkunte Date: Thu, 14 Nov 2024 00:24:49 +0000 Subject: [PATCH 4/4] Re-factor to new changes in public repo --- examples/stardoc_config.yaml | 78 ++++++++ fast_llm/data/config.py | 49 ++++- fast_llm/data/data.py | 26 +-- .../stardoc_data_utils/docowl_processor.py | 10 +- fast_llm/data/tokenizer.py | 88 +++------ fast_llm/layers/language_model/config.py | 12 -- .../layers/multimodal_model/image_encoder.py | 2 +- .../multimodal_language_embedding.py | 2 +- fast_llm/models/stardoc/config.py | 78 ++++---- fast_llm/models/stardoc/conversion.py | 10 +- fast_llm/models/stardoc/data.py | 177 ++++++++++++++++++ fast_llm/models/stardoc/head.py | 6 + fast_llm/models/stardoc/huggingface.py | 18 ++ fast_llm/models/stardoc/model.py | 2 +- .../stardoc/stardoc_dataset.py} | 41 ++-- fast_llm/models/stardoc/trainer.py | 64 +------ 16 files changed, 434 insertions(+), 229 deletions(-) create mode 100644 examples/stardoc_config.yaml create mode 100644 fast_llm/models/stardoc/data.py create mode 100644 fast_llm/models/stardoc/head.py create mode 100644 fast_llm/models/stardoc/huggingface.py rename fast_llm/{data/stardoc.py => models/stardoc/stardoc_dataset.py} (79%) diff --git a/examples/stardoc_config.yaml b/examples/stardoc_config.yaml new file mode 100644 index 000000000..0c170dea5 --- /dev/null +++ b/examples/stardoc_config.yaml @@ -0,0 +1,78 @@ +training: + train_iters: 1000 + num_workers: 2 + logs: + interval: 10 + checkpoint: + interval: 1000 + keep: 10 + export: + interval: 1000 + validation: + iterations: null + test_iters: 0 +pretrained: + path: ".../stardoc_checkpoint" + format: huggingface +batch: + sequence_length: 8192 + micro_batch_size: 1 + batch_size: 8 +data: + split: [0.9, 0.1, 0] + path: ".../stardoc_data_config.json" + tokenizer: + format: TokenzierFromFile + path: ".../Mistral-7B-v0.3/tokenizer.json" + special_tokens: + eos_token: "" + bos_token: "" + pad_token: "[control_8]" + image_placeholder_token: "[control_9]" +optimizer: + learning_rate: + base: 1.0e-05 + decay_style: constant + warmup_iterations: 0 + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 +model: + base_model: + transformer: + normalization: + type: rms_norm + epsilon: 1.0e-05 + num_layers: 32 + hidden_size: 4096 + ffn_hidden_size: 14336 + num_attention_heads: 32 + head_groups: 8 + add_linear_biases: false + use_rotary_embeddings: true + gated: true + activation_type: silu + triton_rotary: true + kv_channels: 128 + rotary_embedding_scale: -9.210340371976184 + window_size: 4096 + init_method_std: 0.009021 + attention_dropout: 0.0 + hidden_dropout: 0.0 + multimodal_model: + image_encoder_hidden_size: 1024 + num_image_tokens: 256 + max_num_images: 10 + image_resolution: 448 + image_encoder_type: clip + vocab_size: 32000 + tie_word_embeddings: false + multi_stage: + zero_stage: 3 + distributed: + training_dtype: bf16 + distributed_timeout: 3600 + seed: 984059 + +run: + experiment_dir: stardoc \ No newline at end of file diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 81b724d65..06ff476c8 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -99,9 +99,47 @@ def _validate(self): Assert.in_range_incl(self.rate, 0, 1) -EOD = "<|endoftext|>" TokenizerFromFile = "TokenizerFromFile" -PreTrainedTokenizer = "PreTrainedTokenzier" + + 
+@config_class() +class SpecialTokensConfig(Config): + """ + Define special tokens like EOS, BOS, PAD and image_placeholder tokens + """ + + bos_token: str | None = Field( + default=None, + desc="Beginning of sequence token", + hint=FieldHint.core, + ) + eos_token: str | None = Field( + default="<|endoftext|>", + desc="End of sequence token", + hint=FieldHint.core, + ) + pad_token: str | None = Field( + default=None, + desc="Pad token", + hint=FieldHint.core, + ) + image_placeholder_token: str | None = Field( + default=None, + desc="Placeholder token for images. Used only in multi-modal models", + hint=FieldHint.core, + ) + + def get_special_tokens(self): + special_tokens = [ + self.bos_token, + self.eos_token, + self.pad_token, + self.image_placeholder_token, + ] + + # Only return special tokens that are set + return [token for token in special_tokens if token is not None] + @config_class() class TokenizerConfig(Config): @@ -114,16 +152,15 @@ class TokenizerConfig(Config): default="TokenizerFromFile", desc="Unused.", hint=FieldHint.deprecated, - valid=check_field(Assert.eq, TokenizerFromFile), ) path: str | None = Field( default=None, desc="Path to the tokenizer file.", hint=FieldHint.core, ) - tokenizer_path: str | None = Field( - default=None, - desc="Path to pretrained tokenizer", + special_tokens: SpecialTokensConfig = Field( + default_factory=SpecialTokensConfig, + desc="Define special tokens.", hint=FieldHint.core, ) diff --git a/fast_llm/data/data.py b/fast_llm/data/data.py index 8b55cbdc0..ea1a79066 100644 --- a/fast_llm/data/data.py +++ b/fast_llm/data/data.py @@ -12,7 +12,7 @@ from fast_llm.data.config import AbstractData, DataConfig, DatasetSource from fast_llm.data.dataset import BlendedDataset, SampledDataset, Sampler from fast_llm.data.gpt import DummyGPTDataset, GPTDataset, GPTSampledDataset -from fast_llm.data.stardoc import StarDocDataset +from fast_llm.models.stardoc.stardoc_dataset import StarDocDataset from fast_llm.data.mmap import MMapIndexedDataset from fast_llm.data.tokenizer import Tokenizer, HuggingfacePreTrainedTokenizer from fast_llm.engine.config_utils.run import get_run, log_main_rank @@ -86,11 +86,6 @@ def __init__( assert len(dataset_prefixes) == len(set(dataset_prefixes)) dataset_weights = normalize_probs([float(x) for x in self._config.path[::2]]) self._build_and_sample_dataset = self._build_and_sample_gpt_dataset - elif self._config.dataset_source == DatasetSource.multimodal: - # FastLLM Split logic is overriden. Huggingface dataset defines the split - Assert.eq(len(self._config.data_path), 1) - dataset_prefixes, dataset_weights = [None], [1.0] - self._build_and_sample_dataset = self._build_and_sample_stardoc_dataset elif self._config.format == DatasetSource.sample: Assert.eq(len(self._config.path), 1) dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] @@ -121,23 +116,6 @@ def __init__( } self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)} - def build_tokenizer(self, max_sequence_length): - """Initialize tokenizer.""" - log_main_rank(f"> building {self._config.tokenizer.tokenizer_type}, {self._config.tokenizer.tokenizer_type or self._config.tokenizer.tokenizer_file} tokenizer ...") - - # Select and instantiate the tokenizer. 
- if self._config.tokenizer.tokenizer_type == "TokenizerFromFile": - assert self._config.tokenizer.tokenizer_file is not None - tokenizer = Tokenizer(self._config.tokenizer) - elif self._config.tokenizer.tokenizer_type == "PreTrainedTokenizer": - assert self._config.tokenizer.tokenizer_path is not None - tokenizer = HuggingfacePreTrainedTokenizer(self._config.tokenizer, max_sequence_length=max_sequence_length) - else: - raise NotImplementedError(f"{self.config.tokenizer.tokenizer_type} tokenizer is not implemented.") - - return tokenizer - - def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]): """ Load the datasets, and prepare or load the samplings. @@ -146,7 +124,7 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int run = get_run() Assert.leq(set(samples_per_phase), set(self._phase_split)) log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.") - self._tokenizer = self.build_tokenizer(self._max_sequence_length) if (self._config.fim.fim_rate > 0 or self._config.dataset_source == DatasetSource.multimodal) else None + self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None self._distributed = distributed self._cache_dir = run.dataset_cache_dir self._samples_per_phase = samples_per_phase diff --git a/fast_llm/data/stardoc_data_utils/docowl_processor.py b/fast_llm/data/stardoc_data_utils/docowl_processor.py index 427f7825e..dfc71b99e 100644 --- a/fast_llm/data/stardoc_data_utils/docowl_processor.py +++ b/fast_llm/data/stardoc_data_utils/docowl_processor.py @@ -125,10 +125,10 @@ def __repr__(self) -> str: } class DocProcessor(): - def __init__(self, image_size=224, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=False): + def __init__(self, image_size=224, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=False, media_token=""): self.add_global_img = add_global_img self.add_textual_crop_indicator = add_textual_crop_indicator - self.media_token= "[control_8]" + self.media_token= media_token # h,w if isinstance(image_size, int): image_size = (image_size, image_size) @@ -205,13 +205,13 @@ def __call__(self, images=None, query=None): for patch_pos in patch_position.tolist(): # global non-crop image if patch_pos[0] == self.anchor_max and patch_pos[1] == self.anchor_max: - text += '[control_8]' + text += '' + self.media_token else: row_col = 'row'+str(patch_pos[0])+'_col'+str(patch_pos[1]) - text += '[control_8]' + text += '' + self.media_token else: # generate successive image placeholders for a image, 1 crop img == 1 <|image|> - text += '[control_8]'*num_image_mult[image_token_ptr] + text += self.media_token*num_image_mult[image_token_ptr] text += next_text image_token_ptr += 1 diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 5a368c4e5..933ecb738 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -1,6 +1,6 @@ from transformers import PreTrainedTokenizerFast -from fast_llm.data.config import EOD, TokenizerConfig +from fast_llm.data.config import TokenizerConfig from fast_llm.engine.config_utils.run import log_main_rank @@ -9,51 +9,18 @@ class Tokenizer: A Huggingface (transformers) tokenizer. 
""" - def __init__(self, config: TokenizerConfig): + def __init__(self, config: TokenizerConfig, max_sequence_length=None): log_main_rank(f"> loading tokenizer from {config.path} ...") - special_tokens = [EOD] - self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.path, errors="replace", max_len=None) + self._config = config + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.path, errors="replace", max_len=max_sequence_length) + special_tokens = config.special_tokens.get_special_tokens() self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) - self.eod_id = self.tokenizer.vocab[EOD] + # Token->id mapping for additional special-tokens self.special_tokens = {tok: self.tokenizer.vocab[tok] for tok in special_tokens} self._inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()} - - @property - def vocab_size(self): - return len(self.tokenizer) - - @property - def vocab(self): - return self.tokenizer.vocab - - @property - def inv_vocab(self): - return self._inv_vocab - - def tokenize(self, text): - return self.tokenizer.encode(text) - - def detokenize(self, token_ids): - return self.tokenizer.decode(token_ids) - - @property - def eod(self): - return self.eod_id - -class HuggingfacePreTrainedTokenizer: - """ - A Huggingface (transformers) tokenizer which uses from_pretrained() to load tokenizer - """ - - def __init__(self, config: TokenizerConfig, max_sequence_length: int): - log_main_rank(f"> loading tokenizer from {config.tokenizer_file} ...") - - self.tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_path) - # self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) - self._inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()} self._max_sequence_length = max_sequence_length - + @property def vocab_size(self): return len(self.tokenizer) @@ -67,42 +34,47 @@ def inv_vocab(self): return self._inv_vocab @property - def max_seq_length(self): + def max_sequence_length(self): return self._max_sequence_length @property def bos_token_id(self): - if self.tokenizer.bos_token_id: - return self.tokenizer.bos_token_id + bos_token = self._config.special_tokens.bos_token + if bos_token is not None: + return self.special_tokens[bos_token] else: raise ValueError("BOS token not set in tokenizer") @property def eos_token_id(self): - if self.tokenizer.eos_token_id: - return self.tokenizer.eos_token_id + eos_token = self._config.special_tokens.eos_token + if eos_token is not None: + return self.special_tokens[eos_token] else: raise ValueError("EOS token not set in tokenizer") @property def pad_token_id(self): - if self.tokenizer.pad_token_id: - log_main_rank("PAD token being set to EOS token") - return self.tokenizer.pad_token_id + pad_token = self._config.special_tokens.pad_token + if pad_token is not None: + return self.special_tokens[pad_token] else: - return self.tokenizer.eos_token_id + raise ValueError("PAD token not set in tokenizer") - def tokenize(self, text, **kwargs): - return self.tokenizer.encode(text, **kwargs) + @property + def image_placeholder_token_id(self): + image_placeholder_token = self._config.special_tokens.image_placeholder_token + if image_placeholder_token is not None: + return self.special_tokens[image_placeholder_token] + else: + raise ValueError("Image placeholder token not set in tokenizer") - def detokenize(self, token_ids, **kwargs): - return self.tokenizer.decode(token_ids, **kwargs) + def tokenize(self, text): + return self.tokenizer.encode(text) - def detokenize_batch(self, 
token_ids, **kwargs): - return self.tokenizer.batch_decode(token_ids, **kwargs) + def detokenize(self, token_ids): + return self.tokenizer.decode(token_ids) @property def eod(self): - return self.eod_id - - + return self.eos_token_id \ No newline at end of file diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 070726498..c6f8380bf 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -4,7 +4,6 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig -from fast_llm.layers.multimodal_model.config import MultimodalModelArchitectureConfig, MultimodalModelBaseConfig from fast_llm.utils import Assert @@ -37,11 +36,6 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) - multimodal_model: MultimodalModelArchitectureConfig = Field( - default_factory=MultimodalModelArchitectureConfig, - desc="Configuration for the multimodal components (image encoder and adapter).", - hint=FieldHint.core, - ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -70,7 +64,6 @@ def _validate(self): def setup_tensor_space(self, tensor_space: TensorSpace): self.transformer.setup_tensor_space(tensor_space) - self.multimodal_model.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions @@ -120,11 +113,6 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): transformer: TransformerConfig = Field( default_factory=TransformerConfig, desc="Configuration for the transformer.", hint=FieldHint.core ) - multimodal_model: MultimodalModelBaseConfig = Field( - default_factory=MultimodalModelBaseConfig, - desc="Configuration for the multimodal components (image encoder and adapter).", - hint=FieldHint.core, - ) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", diff --git a/fast_llm/layers/multimodal_model/image_encoder.py b/fast_llm/layers/multimodal_model/image_encoder.py index d078bd591..6e5b1191d 100644 --- a/fast_llm/layers/multimodal_model/image_encoder.py +++ b/fast_llm/layers/multimodal_model/image_encoder.py @@ -56,7 +56,7 @@ def __init__( def get_fastllm_parameter(self, param_name, param): param_dims = tuple([TensorDim(name=f'{param_name}_{idx}', global_size=x, parallel_dim=None) for idx, x in enumerate(param.shape)]) - return ParameterMeta(param.to("meta"), tensor_name=param_name, dims=param_dims, init_method=init_normal_(std=0.02), requires_grad=True, allow_no_grad=True) + return ParameterMeta(param.to("meta"), tensor_name=param_name, dims=param_dims, init_method=init_normal_(std=0.02), allow_no_grad=True) def _forward(self, input_: tuple[torch.Tensor], losses: dict | None = None, metrics: dict | None = None): if not self.image_encoder_type.lower() == "clip": diff --git a/fast_llm/layers/multimodal_model/multimodal_language_embedding.py b/fast_llm/layers/multimodal_model/multimodal_language_embedding.py index 47aa85595..9d5275715 100644 --- a/fast_llm/layers/multimodal_model/multimodal_language_embedding.py +++ b/fast_llm/layers/multimodal_model/multimodal_language_embedding.py @@ -72,7 +72,7 @@ 
def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, toke bsz, num_imgs, _, hidden_size = input_.shape # TODO: Hardcoded image token - image_token_mask = tokens == 10 + image_token_mask = tokens == 11 embeddings = text_embeddings.clone() embeddings[image_token_mask] = input_.view(-1, hidden_size) diff --git a/fast_llm/models/stardoc/config.py b/fast_llm/models/stardoc/config.py index bc8efcae5..390069eac 100644 --- a/fast_llm/models/stardoc/config.py +++ b/fast_llm/models/stardoc/config.py @@ -1,55 +1,57 @@ import typing from fast_llm.config import Field, FieldHint, config_class -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig +from fast_llm.data.config import DataConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig + +from fast_llm.models.gpt.config import ( + GPTArchitectureConfig, + GPTBaseModelConfig, + GPTTrainerConfig, +) + +from fast_llm.layers.multimodal_model.config import MultimodalModelArchitectureConfig, MultimodalModelBaseConfig +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace if typing.TYPE_CHECKING: from fast_llm.engine.multi_stage.conversion import ModelConverter @config_class() -class StarDocArchitectureConfig(LanguageModelArchitectureConfig): - _abstract = False +class StarDocDataConfig(DataConfig): + # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything. + pass - @classmethod - def _from_dict( - cls, - default: dict, - strict: bool = True, - flat: bool = False, - ): - # TODO v0.2: Remove backward compatibility fix - if "transposed_mlp_weight" in default: - assert default.pop("transposed_mlp_weight") - return super()._from_dict(default, strict, flat) +@config_class() +class StarDocArchitectureConfig(GPTArchitectureConfig): + multimodal_model: MultimodalModelArchitectureConfig = Field( + default_factory=MultimodalModelArchitectureConfig, + desc="Configuration for the multimodal components (image encoder and adapter).", + hint=FieldHint.core, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + super().setup_tensor_space(tensor_space) + self.multimodal_model.setup_tensor_space(tensor_space) + @classmethod def get_converter_class(cls, model_type: str | None = None) -> type["ModelConverter"]: from fast_llm.models.stardoc.conversion import AutoStarDocConverter return AutoStarDocConverter if model_type is None else AutoStarDocConverter.converter_map[model_type] - @config_class() -class StarDocBaseModelConfig(LanguageModelBaseConfig, StarDocArchitectureConfig): +class StarDocBaseModelConfig(GPTBaseModelConfig, StarDocArchitectureConfig): architecture_cls = StarDocArchitectureConfig - @classmethod - def _from_dict( - cls, - default: dict, - strict: bool = True, - flat: bool = False, - ): - # TODO v0.2: Remove backward compatibility fix - if "layer_norm_impl" in default: - assert "normalization_implementation" not in default - default["normalization_implementation"] = default.pop("layer_norm_impl") - if "fused_mlp" in default: - del default["fused_mlp"] - return super()._from_dict(default, strict, flat) + multimodal_model: MultimodalModelBaseConfig = Field( + default_factory=MultimodalModelBaseConfig, + desc="Configuration for the multimodal components (image encoder and adapter).", + hint=FieldHint.core, + ) + 
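+    # Base-model counterpart of StarDocArchitectureConfig.multimodal_model;
+    # keeping it on the StarDoc configs leaves the shared GPT/language-model
+    # configs free of image-encoder and adapter settings.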
@config_class() @@ -63,12 +65,6 @@ def get_model_class(cls): return StarDocModel - @classmethod - def get_huggingface_model_class(cls): - from fast_llm.models.stardoc.huggingface import HuggingfaceStarDocModelForCausalLM - - return HuggingfaceStarDocModelForCausalLM - @config_class() class PretrainedStarDocModelConfig(PretrainedFastLLMModelConfig): @@ -77,13 +73,7 @@ class PretrainedStarDocModelConfig(PretrainedFastLLMModelConfig): @config_class() -class StarDocTrainerConfig(PretrainedStarDocModelConfig, TrainerConfig): - def _setup(self): - super()._setup() - if self.batch.sequence_length is None: - # TODO: Drop this. - self.batch.sequence_length = self.base_model.max_position_embeddings - +class StarDocTrainerConfig(PretrainedStarDocModelConfig, GPTTrainerConfig): @classmethod def get_trainer_class(cls): from fast_llm.models.stardoc.trainer import StarDocTrainer diff --git a/fast_llm/models/stardoc/conversion.py b/fast_llm/models/stardoc/conversion.py index 04530e305..e5186a34e 100644 --- a/fast_llm/models/stardoc/conversion.py +++ b/fast_llm/models/stardoc/conversion.py @@ -129,7 +129,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = [] num_layers = self.config.transformer.num_layers - norm_bias: bool = self.config.transformer.normalization.normalization_type == NormalizationType.layer_norm + norm_bias: bool = self.config.transformer.normalization.type == NormalizationType.layer_norm linear_bias: bool = self.config.transformer.add_linear_biases # Vision encoder @@ -517,9 +517,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(None, "architectures", ["Starcoder2ForCausalLM"]), ConstantImportParamConverter( - ("transformer", "normalization", "normalization_type"), None, NormalizationType.layer_norm + ("transformer", "normalization", "type"), None, NormalizationType.layer_norm ), - ParamConverter(("transformer", "normalization", "layer_norm_eps"), "norm_epsilon"), + ParamConverter(("transformer", "normalization", "epsilon"), "norm_epsilon"), ConstantImportParamConverter(("transformer", "gated"), None, False), ConstantImportParamConverter(("transformer", "add_linear_biases"), None, True), ] @@ -541,9 +541,9 @@ class CommonLlamaHuggingfaceConverter(CommonHuggingfaceConverter, abc.ABC): def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter( - ("transformer", "normalization", "normalization_type"), None, NormalizationType.rms_norm + ("transformer", "normalization", "type"), None, NormalizationType.rms_norm ), - ParamConverter(("transformer", "normalization", "layer_norm_eps"), "rms_norm_eps"), + ParamConverter(("transformer", "normalization", "epsilon"), "rms_norm_eps"), ConstantImportParamConverter(("transformer", "gated"), None, True), ConstantImportParamConverter(("transformer", "add_linear_biases"), None, False), ] diff --git a/fast_llm/models/stardoc/data.py b/fast_llm/models/stardoc/data.py new file mode 100644 index 000000000..3d39b0da2 --- /dev/null +++ b/fast_llm/models/stardoc/data.py @@ -0,0 +1,177 @@ +import json +import logging +import math +import pathlib +import typing +import warnings +import numpy as np + +from fast_llm.models.stardoc.config import StarDocDataConfig +from fast_llm.models.stardoc.stardoc_dataset import StarDocDataset +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from 
fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.config_utils.run import get_run, log_main_rank +from fast_llm.data.data import Data +from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.dataset import BlendedDataset, SampledDataset, Sampler +from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +def normalize_probs(p: list[float]) -> list[float]: + p = np.array(p) + Assert.custom(lambda x: np.all(x >= 0), p) + p_sum = p.sum() + Assert.gt(p_sum, 0) + return (p / p_sum).tolist() + + +class StarDocData(Data): + """ + A class for all dataset needs for StarDoc. + """ + _sampled_datasets: dict[PhaseType, dict[str, SampledDataset]] + _blended_datasets: dict[PhaseType, SampledDataset] + _tokenizer: Tokenizer | None + _distributed: Distributed + _cache_dir: pathlib.Path | None + _samples_per_phase: dict[PhaseType, int] + _phases: typing.ClassVar[tuple[PhaseType, ...]] = (PhaseType.training, PhaseType.validation, PhaseType.test) + + def __init__( + self, + config: StarDocDataConfig, + distributed_config: DistributedConfig, + vocab_size: int, + max_sequence_length: int, + ): + """ + Create the data and gather some basic information on the dataset(s). + Should be `setup` before use. + """ + self._config = config.validate() + self._distributed_config = distributed_config.validate() + self._vocab_size = vocab_size + self._max_sequence_length = max_sequence_length + Assert.eq(len(self._config.split), len(self._phases)) + self._phase_split = { + phase: ratio for phase, ratio in zip(self._phases, normalize_probs(self._config.split)) if ratio > 0 + } + data_base_path = None + Assert.eq(len(self._config.path), 1) + data_path = pathlib.Path(self._config.path[0]) + dataset_defs = json.load(data_path.open("r")) + data_base_path = data_path.parent + dataset_prefixes = [dataset_def["prefix"] for dataset_def in dataset_defs["datasets"]] + dataset_weights = normalize_probs([dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]) + self._build_and_sample_dataset = self._build_and_sample_stardoc_dataset + + dataset_names = [ + f"dataset_{i}_{'dummy' if prefix is None else prefix.replace('/','__')}" + for i, prefix in enumerate(dataset_prefixes) + ] + self._num_datasets = len(dataset_names) + self._dataset_prefixes = { + name: ( + None + if prefix is None + else ( + pathlib.Path(prefix).resolve() + if data_base_path is None + else (pathlib.Path(data_base_path) / prefix).resolve() + ) + ) + for name, prefix in zip(dataset_names, dataset_prefixes) + } + self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)} + + def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]): + """ + Load the datasets. This may take a while and a significant amount of cpu memory. + """ + run = get_run() + Assert.leq(set(samples_per_phase), set(self._phase_split)) + log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.") + self._tokenizer = Tokenizer(self._config.tokenizer, max_sequence_length=self._max_sequence_length) + self._distributed = distributed + self._cache_dir = run.dataset_cache_dir + self._samples_per_phase = samples_per_phase + if self._cache_dir is None: + warnings.warn(f"Using the dataset directory for the index cache.") + + # Build and split datasets. 
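+        # One sampled dataset is built per (dataset, phase) pair below; the
+        # per-phase datasets are then combined into a single BlendedDataset,
+        # or used directly when only one dataset is configured.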
+ self._sampled_datasets = {phase: {} for phase in self._samples_per_phase} + for i, (name, weight) in enumerate(self._dataset_weights.items()): + if i % 100 == 0 and i > 0: + log_main_rank(f"Prepared {i} of {self._num_datasets} datasets.") + dataset_samples_per_phase = {} + for phase, samples_per_phase in self._samples_per_phase.items(): + expected_samples = self._dataset_weights[name] * samples_per_phase + # Add 5 times the standard deviation (of a binomial distribution) + # so the probability of sampling more than this amount during blending is negligible. + dataset_samples_per_phase[phase] = math.ceil( + expected_samples + + 5 * math.sqrt(expected_samples * self._dataset_weights[name] * (1 - self._dataset_weights[name])) + ) + sampled_datasets = self._build_and_sample_dataset(name, dataset_samples_per_phase) + for phase, dataset in sampled_datasets.items(): + self._sampled_datasets[phase][name] = dataset + + self._blended_datasets = { + phase: ( + list(datasets.values())[0] + if len(datasets) == 1 + else BlendedDataset( + list(datasets.values()), + weights=[self._dataset_weights[name] for name in datasets], + name=phase.value, + num_samples=self._samples_per_phase[phase], + cache_dir=self._cache_dir, + group=self._distributed.world_group, + verbose=run.is_main_rank, + data_sample_warn_time_ms=self._config.data_sample_warn_time_ms, + ) + ) + for phase, datasets in self._sampled_datasets.items() + } + + def get_iterator( + self, + batch_config: BatchConfig, + phase: PhaseType, + *, + consumed_samples: int, + num_workers: int, + prefetch_factor: int | None = None, + ): + # TODO: Adjust or reimplement. + return super().get_iterator( + batch_config, + phase, + consumed_samples=consumed_samples, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) + + def _build_and_sample_stardoc_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]): + sampled_datasets = {} + for phase, num_samples in dataset_samples_per_phase.items(): + if num_samples == 0: + continue + + # TODO: Get image handling parameters from config + sampled_datasets[phase] = StarDocDataset( + self._dataset_prefixes[name], + im_size=224, + num_samples=num_samples, + num_im_tokens=256, + transforms=False, + multi_imgs=True, + split=phase, + tokenizer=self._tokenizer, + config=self._config, + ) + + return sampled_datasets \ No newline at end of file diff --git a/fast_llm/models/stardoc/head.py b/fast_llm/models/stardoc/head.py new file mode 100644 index 000000000..786e36929 --- /dev/null +++ b/fast_llm/models/stardoc/head.py @@ -0,0 +1,6 @@ +from fast_llm.layers.language_model.head import LanguageModelHead + + +class CustomHead(LanguageModelHead): + # TODO: Implement custom parts + pass diff --git a/fast_llm/models/stardoc/huggingface.py b/fast_llm/models/stardoc/huggingface.py new file mode 100644 index 000000000..7db4e73f8 --- /dev/null +++ b/fast_llm/models/stardoc/huggingface.py @@ -0,0 +1,18 @@ +from fast_llm.models.custom.config import CustomModelConfig +from fast_llm.models.custom.model import CustomModel +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelConfig, HuggingfaceGPTModelForCausalLM + + +class HuggingfaceCustomModelConfig(HuggingfaceGPTModelConfig): + model_type = "fast_llm_gpt_custom" + model_config_class = CustomModelConfig + fast_llm_config: CustomModelConfig + + +class HuggingfaceCustomModelForCausalLM(HuggingfaceGPTModelForCausalLM): + # TODO: Implement changes in huggingface interface, if any. + # Ex.: Return predictions instead of logits. 
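+    # NOTE: the attributes below still reference the generic Custom model
+    # classes; they are placeholders until a StarDoc-specific HF wrapper exists.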
+ config_class = HuggingfaceCustomModelConfig + config: HuggingfaceCustomModelConfig + model_class = CustomModel + _fast_llm_model: CustomModel diff --git a/fast_llm/models/stardoc/model.py b/fast_llm/models/stardoc/model.py index b20101f30..f5f264481 100644 --- a/fast_llm/models/stardoc/model.py +++ b/fast_llm/models/stardoc/model.py @@ -342,4 +342,4 @@ def loss_defs(self) -> list[LossDef]: class StarDocModel(FastLLMModel): config_class = StarDocModelConfig - base_model_class = StarDocBaseModel + base_model_class = StarDocBaseModel \ No newline at end of file diff --git a/fast_llm/data/stardoc.py b/fast_llm/models/stardoc/stardoc_dataset.py similarity index 79% rename from fast_llm/data/stardoc.py rename to fast_llm/models/stardoc/stardoc_dataset.py index ad80fd32b..1fcee9367 100644 --- a/fast_llm/data/stardoc.py +++ b/fast_llm/models/stardoc/stardoc_dataset.py @@ -18,12 +18,16 @@ ) from fast_llm.data.stardoc_data_utils.docowl_stardoc_processor import docowl_text_preprocess_v1 from fast_llm.data.stardoc_data_utils.constants import IGNORE_INDEX +from fast_llm.engine.distributed.config import PhaseType + logger = logging.getLogger(__name__) + class StarDocDataset(Dataset): def __init__( self, + dataset_path: str | None = None, im_size: int = 224, num_samples: int = -1, num_im_tokens: int = 256, @@ -39,23 +43,28 @@ def __init__( self.num_im_tokens = num_im_tokens self.multi_imgs = multi_imgs self.tokenizer = tokenizer - self.split=split + phase_map = { + PhaseType.training: "train", + PhaseType.validation: "val", + PhaseType.test: "test", + } + self.split=phase_map[split] # Use DocOwl processor - self.processor = DocProcessor(image_size=self.im_size, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=True) + self.processor = DocProcessor(image_size=self.im_size, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=True, media_token=self.tokenizer._config.special_tokens.image_placeholder_token) - dataset_path = config.data_path[0] + assert dataset_path is not None # TODO: config validation issue multimodal_load_local = True if multimodal_load_local: # Load from a locally cached copy of the dataset self.data_dict = load_from_disk(dataset_path) - self.data = self.data_dict[split] + self.data = self.data_dict[self.split] else: # Load the required spit from HF # TODO: configurable cache_dir - self.data = load_dataset(dataset_path, split=split, cache_dir="/mnt/core_llm/cache/", num_proc=os.cpu_count()-1) + self.data = load_dataset(dataset_path, split=self.split, cache_dir="/mnt/core_llm/cache/", num_proc=os.cpu_count()-1) if self.num_samples != -1: self.data = self.data.select(range(self.num_samples)) @@ -87,24 +96,20 @@ def __getitem__(self, idx): # Add BOS token at the beginning of the sample sample_tokenized_buffer.append(self.tokenizer.bos_token_id) - # Dummy image token ID - dummy_image_token_id = self.tokenizer.tokenize("[control_8]", add_special_tokens=False) - assert len(dummy_image_token_id) == 1 - # tokenized IDs for "USER:" and "ASSISTANT:" - user_ids = self.tokenizer.tokenize("USER: ", add_special_tokens=False) - assistant_ids = self.tokenizer.tokenize(" ASSISTANT: ", add_special_tokens=False) + user_ids = self.tokenizer.tokenize("USER: ") + assistant_ids = self.tokenizer.tokenize(" ASSISTANT: ") sample_tokenized_buffer.extend(user_ids) # Add dummy tokens for all image tokens if len(images) > 0: # Get all the crops and process them - all_images, _, processed_query = self.processor(images=images, query="[control_8]") - crop_splits = 
processed_query.split("[control_8]")[:-1] + all_images, _, processed_query = self.processor(images=images, query=self.tokenizer._config.special_tokens.image_placeholder_token) + crop_splits = processed_query.split(self.tokenizer._config.special_tokens.image_placeholder_token)[:-1] assert len(crop_splits) == len(all_images) for crop_split_part in crop_splits: - sample_tokenized_buffer.extend(self.tokenizer.tokenize(crop_split_part.strip(), add_special_tokens=False)) - sample_tokenized_buffer.extend(dummy_image_token_id * self.num_im_tokens) + sample_tokenized_buffer.extend(self.tokenizer.tokenize(crop_split_part.strip())) + sample_tokenized_buffer.extend([self.tokenizer.image_placeholder_token_id] * self.num_im_tokens) # Don't learn on any image tokens [labels.append(IGNORE_INDEX) for x in range(len(sample_tokenized_buffer))] @@ -113,9 +118,9 @@ def __getitem__(self, idx): for i, (q, a) in enumerate(zip(queries, annotations)): if i>0: sample_tokenized_buffer.extend(user_ids) - sample_tokenized_buffer.extend(self.tokenizer.tokenize(q, add_special_tokens=False)) + sample_tokenized_buffer.extend(self.tokenizer.tokenize(q)) sample_tokenized_buffer.extend(assistant_ids) - sample_tokenized_buffer.extend(self.tokenizer.tokenize(a, add_special_tokens=False)) + sample_tokenized_buffer.extend(self.tokenizer.tokenize(a)) # Add EOS token at the end of the sample sample_tokenized_buffer.append(self.tokenizer.eos_token_id) @@ -123,7 +128,7 @@ def __getitem__(self, idx): assert len(sample_tokenized_buffer) == len(labels) # Right pad to max. sequence length - n_pad_tokens = self.tokenizer.max_seq_length - len(sample_tokenized_buffer) + n_pad_tokens = self.tokenizer.max_sequence_length - len(sample_tokenized_buffer) sample_tokenized_buffer = sample_tokenized_buffer + n_pad_tokens*[self.tokenizer.pad_token_id] # Add an extra pad token to the labels at the end to support shifting left diff --git a/fast_llm/models/stardoc/trainer.py b/fast_llm/models/stardoc/trainer.py index 885714de3..9f77eb790 100644 --- a/fast_llm/models/stardoc/trainer.py +++ b/fast_llm/models/stardoc/trainer.py @@ -1,63 +1,19 @@ -import logging - -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.training.trainer import Trainer from fast_llm.models.stardoc.config import StarDocTrainerConfig +from fast_llm.models.stardoc.data import StarDocData from fast_llm.models.stardoc.model import StarDocModel - -logger = logging.getLogger(__name__) +from fast_llm.models.gpt.trainer import GPTTrainer -class StarDocTrainer(Trainer): +class StarDocTrainer(GPTTrainer): _abstract = False + _config: StarDocTrainerConfig config_class = StarDocTrainerConfig model_class = StarDocModel - def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: - # TODO: Do in model, automate/generalize, get other stats - """Get tflop/s/GPU from global-batch-size and elapsed-time""" - checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 - transformer_config = self._config.base_model.transformer - sequence_length = self._config.batch.sequence_length - - tokens = self._config.batch.batch_size * sequence_length - transformer_flops_base = 2 * checkpoint_activations_factor * tokens * transformer_config.num_layers - dense_flops_base = transformer_flops_base * transformer_config.hidden_size - # Query, key, value, dense. 
- flops_per_iteration = ( - 2 - * (transformer_config.num_attention_heads + transformer_config.head_groups) - * transformer_config.kv_channels - * dense_flops_base - ) - # MLP - flops_per_iteration += ( - (2 + transformer_config.gated) - * transformer_config.ffn_hidden_size - * dense_flops_base - * transformer_config.num_experts_per_token + def _get_data(self): + return StarDocData( + config=self._config.data, + distributed_config=self._config.distributed, + vocab_size=self._config.base_model.vocab_size, + max_sequence_length=self._config.batch.sequence_length, ) - - # LM-head - flops_per_iteration += 6 * tokens * transformer_config.hidden_size * self._config.base_model.vocab_size - - # Attention-matrix computation - attn_flops_base = transformer_flops_base * transformer_config.projection_size - if transformer_config.window_size is None: - # Ignore masked values (s**2/2) - attn_flops = attn_flops_base * sequence_length - model_tflops = flops_per_iteration + attn_flops - else: - # s*w - w**2/2 - attn_flops = ( - 2 - * attn_flops_base - * transformer_config.window_size - * (1 - transformer_config.window_size / 2 / sequence_length) - ) - model_tflops = flops_per_iteration + attn_flops - - # Partial recomputation (normal is 2 ops * ckpt_factor = 6, adding 1 for recomputing Q x K) - hardware_flops = flops_per_iteration + 7 / 6 * attn_flops - ratio = elapsed_time_per_iteration * self._config.distributed.world_size * 1e12 - return model_tflops / ratio, hardware_flops / ratio
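
For reviewers, a minimal usage sketch of the special-token plumbing introduced in this patch (SpecialTokensConfig, TokenizerConfig.special_tokens, and the refactored Tokenizer). The tokenizer path below is hypothetical, and the keyword-style construction of the @config_class dataclasses is an assumption; adapt to from_dict()/validate() if the config machinery requires it.

from fast_llm.data.config import SpecialTokensConfig, TokenizerConfig
from fast_llm.data.tokenizer import Tokenizer

# Mirror examples/stardoc_config.yaml: Mistral-style BOS/EOS plus two control
# tokens reused as PAD and image placeholder.
special_tokens = SpecialTokensConfig(
    bos_token="<s>",
    eos_token="</s>",
    pad_token="[control_8]",
    image_placeholder_token="[control_9]",
)
tokenizer_config = TokenizerConfig(
    path="/path/to/tokenizer.json",  # hypothetical path
    special_tokens=special_tokens,
)

tokenizer = Tokenizer(tokenizer_config, max_sequence_length=8192)

# Ids are resolved from the configured strings; unset tokens raise ValueError
# instead of silently falling back to another token.
pad_id = tokenizer.pad_token_id
image_id = tokenizer.image_placeholder_token_id

# StarDocDataset assembles samples the same way: text tokens plus a fixed
# number of image-placeholder ids per image crop, BOS/EOS at the ends.
sample = (
    [tokenizer.bos_token_id]
    + tokenizer.tokenize("USER: describe the document")
    + [image_id] * 256
    + [tokenizer.eos_token_id]
)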