diff --git a/fast_llm/csrc/data.cpp b/fast_llm/csrc/data.cpp
index 59598eaf..a1a24c7c 100644
--- a/fast_llm/csrc/data.cpp
+++ b/fast_llm/csrc/data.cpp
@@ -27,7 +27,7 @@
 /*
 Helper methods for fast index mapping builds.

-Changes for Fast-LLM: Use int16 for dataset index, add verbose argument to build_sample_idx.
+Changes for Fast-LLM: Use int16 for dataset index, add verbose argument to build_sample_idx, add build_padded_token_cumsum.

 */
 #include <iostream>
@@ -129,6 +129,65 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
 }

+py::array build_padded_token_cumsum(const py::array_t<int32_t>& sizes_,
+                                    const int32_t seq_length,
+                                    const int32_t token_cumsum_rate,
+                                    const int64_t offset
+                                    ) {
+    /*
+    Build token cumsums at regular intervals from document sizes with padding in mind.
+    We inject 0 or more padding tokens at the end of every sequence to fill the sequence length.
+    */
+    int32_t seq_size = 0;
+    int64_t sizes_idx = 0;
+    int32_t samples = 0;
+    auto sizes = sizes_.unchecked<1>();
+    std::vector<int64_t> token_cumsum;
+
+    int64_t cumsum = offset;
+
+    while (sizes_idx < sizes.size()) {
+        int32_t size = sizes[sizes_idx];
+        if (size > seq_length) {
+            // Skip sequences that are too long, to avoid truncations
+            if (samples % token_cumsum_rate==0) token_cumsum.push_back(cumsum);
+            sizes_idx += 1;
+            samples += 1;
+        } else if (seq_size + size > seq_length) {
+            // add padded tokens if a document does not fit in current sequence and start a new sequence
+            cumsum += seq_length - seq_size;
+            seq_size = 0;
+        } else {
+            // Increment here to account for padding. This ensures that the stored values match the beginning of the next document.
+            if (samples % token_cumsum_rate==0) token_cumsum.push_back(cumsum);
+            seq_size += size;
+            cumsum += size;
+            sizes_idx += 1;
+            samples += 1;
+        }
+    }
+
+    // Add a final (padded) entry so we know how many tokens there are in total.
+    cumsum += seq_length - seq_size;
+    token_cumsum.push_back(cumsum);
+
+
+    int64_t* token_cumsum_result = new int64_t[token_cumsum.size()];
+    memcpy(token_cumsum_result, token_cumsum.data(), token_cumsum.size() * sizeof(int64_t));
+
+    py::capsule free_when_done(token_cumsum_result, [](void *mem_) {
+        int64_t *mem = reinterpret_cast<int64_t*>(mem_);
+        delete[] mem;
+    });
+
+    const auto byte_size = sizeof(int64_t);
+    return py::array(std::vector<int64_t>{token_cumsum.size()},
+                     {byte_size},
+                     token_cumsum_result,
+                     free_when_done);
+}
+
 PYBIND11_MODULE(data, m) {
     m.def("build_sample_idx", &build_sample_idx);
+    m.def("build_padded_token_cumsum", &build_padded_token_cumsum);
 }
diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py
index 18b1eaac..c98a781e 100644
--- a/fast_llm/data/data/gpt/config.py
+++ b/fast_llm/data/data/gpt/config.py
@@ -57,6 +57,15 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Multiprocessing context. Do not touch.",
         hint=FieldHint.expert,
     )
+    truncate_documents: bool = Field(
+        default=True,
+        desc=(
+            "If enabled, documents may be truncated while being packed to fit the sequence length. "
+            "Otherwise, sequences will be padded such that every document lies entirely within a sample"
+            " (and documents exceeding the sequence length will be skipped altogether)."
+        ),
+        hint=FieldHint.feature,
+    )

     def _validate(self) -> None:
         if not self.datasets:
diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py
index 8fc33376..a0940e7c 100644
--- a/fast_llm/data/data/gpt/data.py
+++ b/fast_llm/data/data/gpt/data.py
@@ -121,6 +121,7 @@ def setup(
                 sequence_length=self._max_sequence_length,
                 vocab_size=self._vocab_size,
                 tokenizer=self._tokenizer,
+                truncate_documents=self._config.truncate_documents,
                 cross_document_attention=self._cross_document_attention,
             )
             dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index f91c537e..0f04884b 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -71,6 +71,7 @@ class GPTSamplingData(SamplingData):
     sequence_length: int
     vocab_size: int
     tokenizer: "Tokenizer"
+    truncate_documents: bool = True
     cross_document_attention: bool = True


diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py
index f5d23031..dd449358 100644
--- a/fast_llm/data/dataset/gpt/sampled.py
+++ b/fast_llm/data/dataset/gpt/sampled.py
@@ -12,12 +12,12 @@
 from fast_llm.data.dataset.abstract import SampledDataset
 from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType
 from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
-from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type
+from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.utils import Assert

 try:
-    from fast_llm.csrc.data import build_sample_idx  # noqa
+    from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx  # noqa

     _extension_available = True
 except ImportError:
@@ -89,6 +89,7 @@ def __init__(
         self._sequence_length = sampling.sequence_length
         self._cross_document_attention = sampling.cross_document_attention
         self._config = sampling.config
+        self._truncate_documents = sampling.truncate_documents
         self._device = torch.device("cuda" if self._config.gpu else "cpu")

         if sampling.cache_directory is None:
@@ -124,15 +125,35 @@ def _sample(self) -> None:
         """
         # Get the document sizes, the main information needed for sampling.
         document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)
-
-        # Calculate basic stats.
         documents_per_epoch = document_sizes.numel()
         tokens_per_epoch = document_sizes.sum().item()
+
+        # Calculate basic stats.
+        if not self._truncate_documents:
+            assert _extension_available, (
+                "The C++ extension for dataset sampling is missing."
+                " Please make sure Fast-LLM is installed correctly."
+            )
+            long_docs_filter = document_sizes > self._sequence_length + 1
+            ignored_documents = sum(long_docs_filter)
+            if ignored_documents:
+                log_main_rank(
+                    f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.",
+                    log_fn=logger.warning,
+                )
+            tokens_per_epoch = document_sizes[~long_docs_filter].sum().item()
+            if tokens_per_epoch == 0:
+                raise RuntimeError(
+                    f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
+                )

         # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
         # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
-        # but we also include that last label in the following sample,
+        # but in case of truncations we also include that last label in the following sample,
         # so we need `sequence_length * num_samples + 1` tokens in total.
-        num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / tokens_per_epoch)
+        num_epochs = math.ceil(
+            ((self._sequence_length + 1 - self._truncate_documents) * self._num_samples + 1 * self._truncate_documents)
+            / tokens_per_epoch
+        )

         # Prepare for shuffling.
         generator = torch.Generator(device=self._device)
@@ -154,6 +175,7 @@ def _sample(self) -> None:
             "num_samples": self._num_samples,
             "unshuffled_epochs": unshuffled_epochs,
             "sequence_length": self._sequence_length,
+            "truncate_documents": self._truncate_documents,
             "config": self._config.to_serialized(),
         }
         self._load_yaml_data(yaml_data)
@@ -161,6 +183,9 @@ def _sample(self) -> None:
         if self._yaml_path is not None:
             if self._yaml_path.is_file():
                 loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
+                unshuffled_tokens = loaded_yaml_data.pop("unshuffled_tokens", None)
+                if unshuffled_tokens is not None:
+                    self._unshuffled_tokens = unshuffled_tokens
                 if loaded_yaml_data != yaml_data:
                     raise RuntimeError(
                         f"Invalid dataset cache for dataset {self.name}."
@@ -172,9 +197,6 @@ def _sample(self) -> None:
                 # Dataset is already sampled, skip.
                 logger.info(f"Using existing sampling for dataset {self.name}")
                 return
-            else:
-                self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
-                yaml.safe_dump(yaml_data, self._yaml_path.open("w"))

         if shuffled_documents > 1e8:
             warnings.warn(
@@ -232,51 +254,78 @@ def _sample(self) -> None:
         # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`.
         # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation.
         # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
+        if unshuffled_epochs > 0:
+            token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
+                document_sizes,
+                offset=0,
+                # TODO: Allowing for max 100% extra tokens for padding, is that enough?
+                dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
+            )
+            if self._truncate_documents:
+                num_tokens_unshuffled = tokens_per_epoch * unshuffled_epochs
+            self._token_cumsum_unshuffled.save(token_cumsum_unshuffled)
+        else:
+            num_tokens_unshuffled = 0
+        self._unshuffled_tokens = num_tokens_unshuffled
+
+        if self._yaml_path is not None:
+            yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
+            self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
+            yaml.safe_dump(yaml_data, self._yaml_path.open("w"))
+
         if shuffled_epochs > 0:
-            token_cumsum_shuffled = self._get_token_cumsum(
+            token_cumsum_shuffled, num_tokens_shuffled = self._get_token_cumsum(
                 document_sizes[
                     # Torch indexing only works with int32 or int64
                     document_shuffling.to(
                         dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
                     )
                 ],
-                offset=unshuffled_epochs * tokens_per_epoch,
-                dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch,
+                offset=num_tokens_unshuffled,
+                # TODO: Allowing for max 100% extra tokens for padding, is that enough?
+                dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
             )
-            self._token_cumsum_shuffled.save(token_cumsum_shuffled.numpy(force=self._config.gpu))
+            self._token_cumsum_shuffled.save(token_cumsum_shuffled)
             self._document_shuffling.save(
-                document_shuffling[: (token_cumsum_shuffled.numel() + 1) * TOKEN_CUMSUM_RATE].numpy(
+                document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(
                     force=self._config.gpu
                 )
             )
             # Free memory
-            del token_cumsum_shuffled
             del document_shuffling

-        if unshuffled_epochs > 0:
-            token_cumsum_unshuffled = self._get_token_cumsum(
-                document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch
+    def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]:
+        if self._truncate_documents:
+            # Create the output tensor.
+            out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch)
+            # Get partial sums for regular intervals, excluding the last incomplete interval.
+            torch.sum(
+                sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE),
+                dim=1,
+                out=out[1:],
             )
-            self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu))
-
-    def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: torch.dtype) -> torch.Tensor:
-        # Create the output tensor.
-        out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype)
-        # Get partial sums for regular intervals, excluding the last incomplete interval.
-        torch.sum(
-            sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), dim=1, out=out[1:]
-        )
-        # Pad with the begin offset
-        out[0] = offset
-        # Calculate the cumsum.
-        out.cumsum_(0)
-        # Crop unnecessary entries.
-        return out[
-            : torch.clamp_min_(
-                torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
-                0,
+            # Pad with the begin offset
+            out[0] = offset
+            # Calculate the cumsum.
+            out.cumsum_(0)
+            # Crop unnecessary entries.
+            out = out[
+                : torch.clamp_min_(
+                    torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
+                    0,
+                )
+            ]
+            return out.numpy(force=self._config.gpu), None
+        else:
+            # TODO: dynamically handle int64 or int32 in CPP
+            out = build_padded_token_cumsum(
+                sizes.cpu().numpy(), (self._sequence_length + 1), TOKEN_CUMSUM_RATE, offset
            )
-        ]
+            num_tokens = out[-1]
+            out = out[:-1][
+                : np.clip(np.searchsorted(out, self._num_samples * (self._sequence_length + 1), side="right"), 0, None)
+            ]
+            return out, num_tokens

     def __len__(self) -> int:
         return self._num_samples
@@ -288,7 +337,9 @@ def __getitem__(self, index: int) -> typing.Any:
         The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
""" self._lazy_load() - token_start = index * self._sequence_length + # tokens at the boundary are included in only one sample when we pack without truncations + # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample + token_start = index * (self._sequence_length + 1 - self._truncate_documents) token_end = token_start + self._sequence_length + 1 if token_start < self._unshuffled_tokens: @@ -302,6 +353,7 @@ def __getitem__(self, index: int) -> typing.Any: token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1 document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset + token_count = token_start_array[token_start_cumsum_index] token_ids = [] @@ -314,6 +366,25 @@ def __getitem__(self, index: int) -> typing.Any: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() document_size = self._indexed_dataset.get_document_size(document_index) + + if not self._truncate_documents: + if document_size > self._sequence_length + 1: + # Document too long, ignore + document_sampling_index += 1 + continue + tokens_in_sample = token_count % (self._sequence_length + 1) + if document_size + tokens_in_sample > self._sequence_length + 1: + # Document belongs to the next sample, need to account for padding. + padding_size = self._sequence_length + 1 - tokens_in_sample + if token_count > token_start: + # Add padding tokens to current sample + token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + Assert.eq(token_count + padding_size, token_end) + break + else: + # Move on to the next sample. + token_count += padding_size + # Determine if the document belongs to the requested sample. if token_count + document_size >= token_start: # Determine which part of the document belong to the sample, and add it to the list. @@ -343,7 +414,9 @@ def __getitem__(self, index: int) -> typing.Any: ) token_ids = np.concatenate(token_ids, dtype=np.int64) loss_masking_spans = ( - np.stack(loss_masking_spans, dtype=np.int32) if self._config.use_loss_masking_spans else None + (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) + if self._config.use_loss_masking_spans + else None ) Assert.eq(len(token_ids), self._sequence_length + 1) @@ -357,9 +430,12 @@ def _lazy_load(self): if not hasattr(self, "_documents_per_epoch"): self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) - def _load_yaml_data(self, data: dict[str, typing.Any]): + def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"] + if unshuffled_tokens := data.get("unshuffled_tokens") is not None: + self._unshuffled_tokens = unshuffled_tokens + else: + self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch @@ -380,9 +456,12 @@ def __init__( self._indexed_dataset = indexed_dataset self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length + if not sampling.truncate_documents: + raise NotImplementedError( + "Legacy sampling only supports document truncation. Please use the latest dataset format." 
+            )
         self._cross_document_attention = sampling.cross_document_attention
         self._config = sampling.config
-        self._tokenizer = sampling.tokenizer

         if sampling.cache_directory is None:
             log_main_rank(
@@ -498,7 +577,7 @@ def __getitem__(self, idx: int) -> typing.Any:
                 for span in sample.loss_masking_spans:
                     spans.append(span + offset)
                 offset += len(sample.token_ids)
-            spans = np.stack(spans, dtype=np.int32)
+            spans = np.stack(spans, dtype=np.int32) if spans else np.array([])
         else:
             spans = None
         sequence_lengths = (
diff --git a/tests/data/common.py b/tests/data/common.py
index 5177b1f1..b3f41a1a 100644
--- a/tests/data/common.py
+++ b/tests/data/common.py
@@ -36,6 +36,7 @@ def get_sampling_data(
     tokenizer: Tokenizer | None = None,
     gpu: bool = False,
     shuffle: ShufflingType = ShufflingType.epoch,
+    truncate_documents=True,
 ) -> GPTSamplingData:
     # Config with convenient defaults.
     return GPTSamplingData(
@@ -51,6 +52,7 @@ def get_sampling_data(
         sequence_length=sequence_length,
         vocab_size=vocab_size,
         tokenizer=tokenizer,
+        truncate_documents=truncate_documents,
     )


diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py
index e622d118..38679582 100644
--- a/tests/data/test_sampling.py
+++ b/tests/data/test_sampling.py
@@ -15,6 +15,14 @@
     validate_indexed_dataset_sampling,
 )

+try:
+    from fast_llm.csrc.data import build_padded_token_cumsum  # noqa
+
+    _extension_available = True
+except ImportError:
+    _extension_available = False
+
+
 GPT_MEMMAP_SAMPLES = [
     [4709, 819, 79, 207, 277, 1790],
     [1790, 80, 6506, 1735, 542, 88],
@@ -125,3 +133,70 @@ def test_gpt_sample(seed, shuffle):
         # Check that the sequence is independent of `num_sample`.
         Assert.all_equal(samples, previous_samples[: len(samples)])
         previous_samples = samples
+
+
+@pytest.mark.skipif(not _extension_available, reason="CPP Extension not available")
+def test_build_padded_token_cumsum():
+    sizes = np.array([100, 256, 580, 600, 550, 89, 339, 430, 400, 795, 680, 50], dtype=np.int32)
+    sequence_length = 768
+    token_cumsum_rate = 4
+    offset = 0
+    # sequences with padding:
+    # [100, 256, 413 padded, 580, 189 padded, 600, 169 padded, 550, 89, 130 padded, 339, 430, 400, 369 padded, 680, 50, 39 padded]
+    # cumsums:
+    # [100, 356, 1349, 2307, 2857, 2946, 3415, 3845, 4245, 5294, 5344, 5383]
+    expected_cumsums = [0, 2307, 3845, 5383]
+    token_cumsum = build_padded_token_cumsum(sizes, sequence_length + 1, token_cumsum_rate, offset)
+    Assert.all_equal(token_cumsum, expected_cumsums)
+
+
+def get_test_seeds(num_seeds):
+    np.random.seed(42)
+    seeds = np.random.randint(0, num_seeds * 100, num_seeds)
+    return seeds.tolist()
+
+
+@pytest.mark.skipif(not _extension_available, reason="CPP Extension not available")
+def test_gpt_sample_padding():
+    for seed in get_test_seeds(100):
+        vocab_size = 30
+        np.random.seed(seed)
+        num_sequences = np.random.randint(1, 20)
+        sequence_length = np.random.randint(1, 20)
+        doc_sizes = np.random.randint(1, 2 * sequence_length, num_sequences)
+        samples = [np.random.randint(0, vocab_size, size) for size in doc_sizes]
+        expected_samples = []
+        seq_size = 0
+        token_ids = []
+        total_tokens = 0
+        for idx, sample in enumerate(samples):
+            doc_size = len(sample)
+            if doc_size > sequence_length + 1:
+                continue
+            elif doc_size + seq_size > sequence_length + 1:
+                padding_tokens = sequence_length + 1 - seq_size
+                token_ids.append([-100] * padding_tokens)
+                expected_samples.append(list(np.concatenate(token_ids)))
+                token_ids = [sample]
+                seq_size = doc_size
+                total_tokens += doc_size
+            else:
+                token_ids.append(sample)
+                seq_size += doc_size
+                total_tokens += doc_size
+        dataset = SimpleGPTIndexedDataset(samples)
+        sampling = get_sampling_data(
+            num_samples=len(expected_samples),
+            sequence_length=sequence_length,
+            vocab_size=vocab_size,
+            seed=seed,
+            shuffle=ShufflingType.disabled,
+            truncate_documents=False,
+        )
+        if total_tokens == 0:
+            with pytest.raises(RuntimeError):
+                dataset.sample(sampling)
+        else:
+            sampled = dataset.sample(sampling)
+            for idx in range(len(expected_samples)):
+                Assert.all_equal(sampled[idx].token_ids, np.array(expected_samples[idx]))
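
Reviewer note: as a sanity check on the C++ helper without building the extension, here is a minimal pure-Python sketch of the same padded-cumsum logic. The name `padded_token_cumsum` is hypothetical and used only for this illustration; the real implementation is `build_padded_token_cumsum` in fast_llm/csrc/data.cpp. The final assert reuses the inputs and expected values from `test_build_padded_token_cumsum` above.

def padded_token_cumsum(sizes, seq_length, rate, offset=0):
    # Illustrative mirror of build_padded_token_cumsum: walk the documents, padding each
    # sample to exactly seq_length tokens and recording the cumsum at every `rate`-th document.
    cumsums, cumsum, seq_size, samples, i = [], offset, 0, 0, 0
    while i < len(sizes):
        size = sizes[i]
        if size > seq_length:
            # Document longer than one full sample: skip it entirely (no truncation allowed).
            if samples % rate == 0:
                cumsums.append(cumsum)
            i += 1
            samples += 1
        elif seq_size + size > seq_length:
            # Document does not fit: pad the current sample to seq_length and start a new one.
            cumsum += seq_length - seq_size
            seq_size = 0
        else:
            # Document fits: record the cumsum at the start of every `rate`-th kept document.
            if samples % rate == 0:
                cumsums.append(cumsum)
            seq_size += size
            cumsum += size
            i += 1
            samples += 1
    # Final padded entry: the total token count, including padding.
    cumsums.append(cumsum + seq_length - seq_size)
    return cumsums

# Same inputs as test_build_padded_token_cumsum (seq_length = sequence_length + 1 = 769).
assert padded_token_cumsum(
    [100, 256, 580, 600, 550, 89, 339, 430, 400, 795, 680, 50], 769, 4
) == [0, 2307, 3845, 5383]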
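
A note on the bookkeeping that changes with `truncate_documents`, sketched below with made-up numbers (sequence_length = 768, num_samples = 10): with truncation, consecutive samples share one boundary token, so `sequence_length * num_samples + 1` tokens are needed in total; without truncation, every sample owns its full `sequence_length + 1` tokens and any shortfall is filled with -100 padding, so `(sequence_length + 1) * num_samples` tokens are needed. This is what the updated `num_epochs` numerator in `_sample` and the updated `token_start` in `__getitem__` compute.

# Sketch only; the values are illustrative, not taken from a real run.
sequence_length, num_samples = 768, 10
for truncate_documents in (True, False):
    # Total tokens required for num_samples samples (the updated num_epochs numerator).
    tokens_needed = (sequence_length + 1 - truncate_documents) * num_samples + 1 * truncate_documents
    # First token of sample 1 (the updated token_start expression).
    start_of_sample_1 = 1 * (sequence_length + 1 - truncate_documents)
    print(truncate_documents, tokens_needed, start_of_sample_1)
    # True  -> 7681, 768  (samples overlap by one label token)
    # False -> 7690, 769  (samples are disjoint; padding tokens are labelled -100)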