From 5900fa974656929b876bb216cee008f4f3f35677 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Louis=20Brul=C3=A9=20Naudet?=
Date: Mon, 17 Feb 2025 20:18:22 +0100
Subject: [PATCH] fix: resolve conflicts

---
 tiktoken/_educational.py | 120 +++++++++++++--
 tiktoken/core.py         | 323 ++++++++++++++++++++++++++++-----------
 tiktoken/load.py         |  83 +++++++++-
 tiktoken/model.py        |  34 ++++-
 tiktoken/registry.py     |  68 +++++++++
 5 files changed, 524 insertions(+), 104 deletions(-)

diff --git a/tiktoken/_educational.py b/tiktoken/_educational.py
index 317e775..1cc1744 100644
--- a/tiktoken/_educational.py
+++ b/tiktoken/_educational.py
@@ -11,7 +11,13 @@ class SimpleBytePairEncoding:
     def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
-        """Creates an Encoding object."""
+        """Creates an Encoding object.
+
+        Args:
+            pat_str (str): A regex pattern string that is used to split the input text.
+            mergeable_ranks (dict[bytes, int]): A dictionary mapping token bytes to their ranks.
+                The ranks correspond to merge priority.
+        """
         # A regex pattern string that is used to split the input text
         self.pat_str = pat_str
         # A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority
@@ -23,8 +29,17 @@ def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
     def encode(self, text: str, visualise: str | None = "colour") -> list[int]:
         """Encodes a string into tokens.

-        >>> enc.encode("hello world")
-        [388, 372]
+        Args:
+            text (str): The text to encode.
+            visualise (str | None, optional): Visualization mode. Can be 'colour', 'color',
+                'simple', or None. Defaults to 'colour'.
+
+        Returns:
+            list[int]: The encoded tokens.
+
+        Examples:
+            >>> enc.encode("hello world")
+            [388, 372]
         """
         # Use the regex to split the text into (approximately) words
         words = self._pat.findall(text)
@@ -39,35 +54,70 @@ def encode(self, text: str, visualise: str | None = "colour") -> list[int]:
     def decode_bytes(self, tokens: list[int]) -> bytes:
         """Decodes a list of tokens into bytes.

-        >>> enc.decode_bytes([388, 372])
-        b'hello world'
+        Args:
+            tokens (list[int]): The list of tokens to decode.
+
+        Returns:
+            bytes: The decoded bytes.
+
+        Examples:
+            >>> enc.decode_bytes([388, 372])
+            b'hello world'
         """
         return b"".join(self._decoder[token] for token in tokens)

     def decode(self, tokens: list[int]) -> str:
         """Decodes a list of tokens into a string.

-        Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
-        the invalid bytes with the replacement character "�".
+        Args:
+            tokens (list[int]): The list of tokens to decode.

-        >>> enc.decode([388, 372])
-        'hello world'
+        Returns:
+            str: The decoded string.
+
+        Note:
+            Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
+            the invalid bytes with the replacement character "�".
+
+        Examples:
+            >>> enc.decode([388, 372])
+            'hello world'
         """
         return self.decode_bytes(tokens).decode("utf-8", errors="replace")

     def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
         """Decodes a list of tokens into a list of bytes.

-        Useful for visualising how a string is tokenised.
+        Args:
+            tokens (list[int]): The list of tokens to decode.
+
+        Returns:
+            list[bytes]: A list of decoded bytes.

-        >>> enc.decode_tokens_bytes([388, 372])
-        [b'hello', b' world']
+        Note:
+            Useful for visualising how a string is tokenised.
+
+        Examples:
+            >>> enc.decode_tokens_bytes([388, 372])
+            [b'hello', b' world']
         """
         return [self._decoder[token] for token in tokens]

     @staticmethod
     def train(training_data: str, vocab_size: int, pat_str: str):
-        """Train a BPE tokeniser on some data!"""
+        """Train a BPE tokeniser on some data.
+
+        Args:
+            training_data (str): The text data to train on.
+            vocab_size (int): The desired size of the vocabulary.
+            pat_str (str): The regex pattern string used for tokenization.
+
+        Returns:
+            SimpleBytePairEncoding: A new tokenizer trained on the data.
+
+        Note:
+            This is an educational implementation of BPE training.
+        """
         mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
         return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks)

@@ -81,8 +131,21 @@ def from_tiktoken(encoding):


 def bpe_encode(
-    mergeable_ranks: dict[bytes, int], input: bytes, visualise: str | None = "colour"
+    mergeable_ranks: dict[bytes, int],
+    input: bytes,
+    visualise: str | None = "colour"
 ) -> list[int]:
+    """Encodes input bytes using byte pair encoding.
+
+    Args:
+        mergeable_ranks (dict[bytes, int]): Dictionary mapping token bytes to their ranks.
+        input (bytes): The input bytes to encode.
+        visualise (str | None, optional): Visualization mode. Can be 'colour', 'color',
+            'simple', or None. Defaults to 'colour'.
+
+    Returns:
+        list[int]: The encoded tokens.
+    """
     parts = [bytes([b]) for b in input]
     while True:
         # See the intermediate merges play out!
@@ -117,8 +180,26 @@ def bpe_encode(


 def bpe_train(
-    data: str, vocab_size: int, pat_str: str, visualise: str | None = "colour"
+    data: str,
+    vocab_size: int,
+    pat_str: str,
+    visualise: str | None = "colour"
 ) -> dict[bytes, int]:
+    """Trains a byte pair encoding model on the given data.
+
+    Args:
+        data (str): The text data to train on.
+        vocab_size (int): The desired size of the vocabulary.
+        pat_str (str): The regex pattern string used for tokenization.
+        visualise (str | None, optional): Visualization mode. Can be 'colour', 'color',
+            'simple', or None. Defaults to 'colour'.
+
+    Returns:
+        dict[bytes, int]: A dictionary mapping token bytes to their ranks.
+
+    Raises:
+        ValueError: If vocab_size is less than 256.
+    """
     # First, add tokens for each individual byte value
     if vocab_size < 2**8:
         raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
@@ -186,6 +267,15 @@ def bpe_train(


 def visualise_tokens(token_values: list[bytes]) -> None:
+    """Visualizes tokens by printing them with different background colors.
+
+    Args:
+        token_values (list[bytes]): List of token bytes to visualize.
+
+    Note:
+        If token boundaries do not occur at unicode character boundaries, it uses the
+        unicode replacement character to represent some fraction of a character.
+    """
     background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
     # If token boundaries do not occur at unicode character boundaries, it's unclear how best to
     # visualise the token. Here, we'll just use the unicode replacement character to represent some
diff --git a/tiktoken/core.py b/tiktoken/core.py
index 6bc9736..55f6afa 100644
--- a/tiktoken/core.py
+++ b/tiktoken/core.py
@@ -28,15 +28,17 @@ def __init__(
         See openai_public.py for examples of how to construct an Encoding object.

         Args:
-            name: The name of the encoding. It should be clear from the name of the encoding
-                what behaviour to expect, in particular, encodings with different special tokens
-                should have different names.
-            pat_str: A regex pattern string that is used to split the input text.
-            mergeable_ranks: A dictionary mapping mergeable token bytes to their ranks. The ranks
-                must correspond to merge priority.
-            special_tokens: A dictionary mapping special token strings to their token values.
-            explicit_n_vocab: The number of tokens in the vocabulary. If provided, it is checked
-                that the number of mergeable tokens and special tokens is equal to this number.
+            name (str): The name of the encoding. It should be clear from the name of the
+                encoding what behaviour to expect, in particular, encodings with different
+                special tokens should have different names.
+            pat_str (str): A regex pattern string that is used to split the input text.
+            mergeable_ranks (dict[bytes, int]): A dictionary mapping mergeable token bytes
+                to their ranks. The ranks must correspond to merge priority.
+            special_tokens (dict[str, int]): A dictionary mapping special token strings to
+                their token values.
+            explicit_n_vocab (int | None, optional): The number of tokens in the vocabulary.
+                If provided, it is checked that the number of mergeable tokens and special
+                tokens is equal to this number.
         """
         self.name = name

@@ -65,9 +67,15 @@ def encode_ordinary(self, text: str) -> list[int]:

         This is equivalent to `encode(text, disallowed_special=())` (but slightly faster).

-        ```
-        >>> enc.encode_ordinary("hello world")
-        [31373, 995]
+        Examples:
+            >>> enc.encode_ordinary("hello world")
+            [31373, 995]
+
+        Args:
+            text (str): The text to encode.
+
+        Returns:
+            list[int]: The encoded tokens.
         """
         try:
             return self._core_bpe.encode_ordinary(text)
@@ -91,24 +99,37 @@ def encode(

         Hence, by default, encode will raise an error if it encounters text that corresponds
         to a special token. This can be controlled on a per-token level using the `allowed_special`
-        and `disallowed_special` parameters. In particular:
-        - Setting `disallowed_special` to () will prevent this function from raising errors and
-          cause all text corresponding to special tokens to be encoded as natural text.
-        - Setting `allowed_special` to "all" will cause this function to treat all text
-          corresponding to special tokens to be encoded as special tokens.
+        and `disallowed_special` parameters.

-        ```
-        >>> enc.encode("hello world")
-        [31373, 995]
-        >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
-        [50256]
-        >>> enc.encode("<|endoftext|>", allowed_special="all")
-        [50256]
-        >>> enc.encode("<|endoftext|>")
-        # Raises ValueError
-        >>> enc.encode("<|endoftext|>", disallowed_special=())
-        [27, 91, 437, 1659, 5239, 91, 29]
-        ```
+        Args:
+            text (str): The text to encode.
+            allowed_special (Union[Literal["all"], AbstractSet[str]], optional): Special tokens
+                that are allowed to be encoded. If "all", all special tokens are allowed.
+                Defaults to set().
+            disallowed_special (Union[Literal["all"], Collection[str]], optional): Special
+                tokens that are not allowed to be encoded. If "all", no special tokens are
+                allowed except those in allowed_special. Defaults to "all".
+
+        Returns:
+            list[int]: The encoded tokens.
+
+        Note:
+            - Setting `disallowed_special` to () will prevent this function from raising errors
+              and cause all text corresponding to special tokens to be encoded as natural text.
+            - Setting `allowed_special` to "all" will cause this function to treat all text
+              corresponding to special tokens to be encoded as special tokens.
+
+        Examples:
+            >>> enc.encode("hello world")
+            [31373, 995]
+            >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
+            [50256]
+            >>> enc.encode("<|endoftext|>", allowed_special="all")
+            [50256]
+            >>> enc.encode("<|endoftext|>")
+            # Raises ValueError
+            >>> enc.encode("<|endoftext|>", disallowed_special=())
+            [27, 91, 437, 1659, 5239, 91, 29]
         """
         if allowed_special == "all":
             allowed_special = self.special_tokens_set
@@ -142,6 +163,18 @@ def encode_to_numpy(
         """Encodes a string into tokens, returning a numpy array.

         Avoids the overhead of copying the token buffer into a Python list.
+
+        Args:
+            text (str): The text to encode.
+            allowed_special (Union[Literal["all"], AbstractSet[str]], optional): Special tokens
+                that are allowed to be encoded. If "all", all special tokens are allowed.
+                Defaults to set().
+            disallowed_special (Union[Literal["all"], Collection[str]], optional): Special
+                tokens that are not allowed to be encoded. If "all", no special tokens are
+                allowed except those in allowed_special. Defaults to "all".
+
+        Returns:
+            npt.NDArray[np.uint32]: The encoded tokens as a numpy array.
         """
         if allowed_special == "all":
             allowed_special = self.special_tokens_set
@@ -163,10 +196,17 @@ def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> lis

         This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster).

-        ```
-        >>> enc.encode_ordinary_batch(["hello world", "goodbye world"])
-        [[31373, 995], [11274, 16390, 995]]
-        ```
+        Args:
+            text (list[str]): List of strings to encode.
+            num_threads (int, optional): Number of threads to use for parallel processing.
+                Defaults to 8.
+
+        Returns:
+            list[list[int]]: List of encoded token sequences, one for each input string.
+
+        Examples:
+            >>> enc.encode_ordinary_batch(["hello world", "goodbye world"])
+            [[31373, 995], [11274, 16390, 995]]
         """
         encoder = functools.partial(self.encode_ordinary)
         with ThreadPoolExecutor(num_threads) as e:
@@ -184,10 +224,23 @@ def encode_batch(

         See `encode` for more details on `allowed_special` and `disallowed_special`.

-        ```
-        >>> enc.encode_batch(["hello world", "goodbye world"])
-        [[31373, 995], [11274, 16390, 995]]
-        ```
+        Args:
+            text (list[str]): List of strings to encode.
+            num_threads (int, optional): Number of threads to use for parallel processing.
+                Defaults to 8.
+            allowed_special (Union[Literal["all"], AbstractSet[str]], optional): Special tokens
+                that are allowed to be encoded. If "all", all special tokens are allowed.
+                Defaults to set().
+            disallowed_special (Union[Literal["all"], Collection[str]], optional): Special
+                tokens that are not allowed to be encoded. If "all", no special tokens are
+                allowed except those in allowed_special. Defaults to "all".
+
+        Returns:
+            list[list[int]]: List of encoded token sequences, one for each input string.
+
+        Examples:
+            >>> enc.encode_batch(["hello world", "goodbye world"])
+            [[31373, 995], [11274, 16390, 995]]
         """
         if allowed_special == "all":
             allowed_special = self.special_tokens_set
@@ -212,20 +265,34 @@ def encode_with_unstable(
         """Encodes a string into stable tokens and possible completion sequences.

         Note that the stable tokens will only represent a substring of `text`.
-
-        See `encode` for more details on `allowed_special` and `disallowed_special`.
-        This API should itself be considered unstable.
-        ```
-        >>> enc.encode_with_unstable("hello fanta")
-        ([31373], [(277, 4910), (5113, 265), ..., (8842,)])
-
-        >>> text = "..."
-        >>> stable_tokens, completions = enc.encode_with_unstable(text)
-        >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens))
-        >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions)
-        ```
+        Args:
+            text (str): The text to encode.
+            allowed_special (Union[Literal["all"], AbstractSet[str]], optional): Special tokens
+                that are allowed to be encoded. If "all", all special tokens are allowed.
+                Defaults to set().
+            disallowed_special (Union[Literal["all"], Collection[str]], optional): Special
+                tokens that are not allowed to be encoded. If "all", no special tokens are
+                allowed except those in allowed_special. Defaults to "all".
+
+        Returns:
+            tuple[list[int], list[list[int]]]: A tuple containing:
+                - list[int]: The stable tokens.
+                - list[list[int]]: Possible completion sequences.
+
+        Examples:
+            >>> enc.encode_with_unstable("hello fanta")
+            ([31373], [(277, 4910), (5113, 265), ..., (8842,)])
+
+            >>> text = "..."
+            >>> stable_tokens, completions = enc.encode_with_unstable(text)
+            >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens))
+            >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode())
+            ...     for seq in completions)
+
+        See Also:
+            See `encode` for more details on `allowed_special` and `disallowed_special`.
         """
         if allowed_special == "all":
             allowed_special = self.special_tokens_set
@@ -262,62 +329,102 @@ def encode_single_token(self, text_or_bytes: str | bytes) -> int:

     def decode_bytes(self, tokens: Sequence[int]) -> bytes:
         """Decodes a list of tokens into bytes.

-        ```
-        >>> enc.decode_bytes([31373, 995])
-        b'hello world'
-        ```
+        Args:
+            tokens (Sequence[int]): The sequence of tokens to decode.
+
+        Returns:
+            bytes: The decoded bytes.
+
+        Examples:
+            >>> enc.decode_bytes([31373, 995])
+            b'hello world'
         """
         return self._core_bpe.decode_bytes(tokens)

     def decode(self, tokens: Sequence[int], errors: str = "replace") -> str:
         """Decodes a list of tokens into a string.

-        WARNING: the default behaviour of this function is lossy, since decoded bytes are not
-        guaranteed to be valid UTF-8. You can control this behaviour using the `errors` parameter,
-        for instance, setting `errors=strict`.
+        Args:
+            tokens (Sequence[int]): The sequence of tokens to decode.
+            errors (str, optional): How to handle unicode decode errors. Options include
+                'strict', 'replace', 'ignore'. Defaults to 'replace'.
+
+        Returns:
+            str: The decoded string.

-        ```
-        >>> enc.decode([31373, 995])
-        'hello world'
-        ```
+        Warning:
+            The default behaviour of this function is lossy, since decoded bytes are not
+            guaranteed to be valid UTF-8. You can control this behaviour using the `errors`
+            parameter, for instance, setting `errors=strict`.
+
+        Examples:
+            >>> enc.decode([31373, 995])
+            'hello world'
         """
         return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)

     def decode_single_token_bytes(self, token: int) -> bytes:
         """Decodes a token into bytes.

-        NOTE: this will decode all special tokens.
+        Args:
+            token (int): The token to decode.

-        Raises `KeyError` if the token is not in the vocabulary.
+        Returns:
+            bytes: The decoded bytes.

-        ```
-        >>> enc.decode_single_token_bytes(31373)
-        b'hello'
-        ```
+        Raises:
+            KeyError: If the token is not in the vocabulary.
+
+        Note:
+            This will decode all special tokens.
+
+        Examples:
+            >>> enc.decode_single_token_bytes(31373)
+            b'hello'
         """
         return self._core_bpe.decode_single_token_bytes(token)

     def decode_tokens_bytes(self, tokens: Sequence[int]) -> list[bytes]:
         """Decodes a list of tokens into a list of bytes.

-        Useful for visualising tokenisation.
-        >>> enc.decode_tokens_bytes([31373, 995])
-        [b'hello', b' world']
+        Args:
+            tokens (Sequence[int]): The sequence of tokens to decode.
+
+        Returns:
+            list[bytes]: List of decoded bytes for each token.
+
+        Note:
+            Useful for visualising tokenisation.
+
+        Examples:
+            >>> enc.decode_tokens_bytes([31373, 995])
+            [b'hello', b' world']
         """
         return [self.decode_single_token_bytes(token) for token in tokens]

     def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]:
         """Decodes a list of tokens into a string and a list of offsets.

-        Each offset is the index into text corresponding to the start of each token.
-        If UTF-8 character boundaries do not line up with token boundaries, the offset is the index
-        of the first character that contains bytes from the token.
+        Args:
+            tokens (Sequence[int]): The sequence of tokens to decode.
+
+        Returns:
+            tuple[str, list[int]]: A tuple containing:
+                - str: The decoded string.
+                - list[int]: The offsets into the string for each token.
+
+        Note:
+            Each offset is the index into text corresponding to the start of each token.
+            If UTF-8 character boundaries do not line up with token boundaries, the offset
+            is the index of the first character that contains bytes from the token.

-        This will currently raise if given tokens that decode to invalid UTF-8; this behaviour may
-        change in the future to be more permissive.
+        Warning:
+            This will currently raise if given tokens that decode to invalid UTF-8;
+            this behaviour may change in the future to be more permissive.

-        >>> enc.decode_with_offsets([31373, 995])
-        ('hello world', [0, 5])
+        Examples:
+            >>> enc.decode_with_offsets([31373, 995])
+            ('hello world', [0, 5])
         """
         token_bytes = self.decode_tokens_bytes(tokens)

@@ -334,7 +441,18 @@ def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]:
     def decode_batch(
         self, batch: Sequence[Sequence[int]], *, errors: str = "replace", num_threads: int = 8
     ) -> list[str]:
-        """Decodes a batch (list of lists of tokens) into a list of strings."""
+        """Decodes a batch (list of lists of tokens) into a list of strings.
+
+        Args:
+            batch (Sequence[Sequence[int]]): List of token sequences to decode.
+            errors (str, optional): How to handle unicode decode errors. Options include
+                'strict', 'replace', 'ignore'. Defaults to 'replace'.
+            num_threads (int, optional): Number of threads to use for parallel processing.
+                Defaults to 8.
+
+        Returns:
+            list[str]: List of decoded strings, one for each token sequence.
+        """
         decoder = functools.partial(self.decode, errors=errors)
         with ThreadPoolExecutor(num_threads) as e:
             return list(e.map(decoder, batch))
@@ -342,7 +460,16 @@ def decode_batch(
     def decode_bytes_batch(
         self, batch: Sequence[Sequence[int]], *, num_threads: int = 8
     ) -> list[bytes]:
-        """Decodes a batch (list of lists of tokens) into a list of bytes."""
+        """Decodes a batch (list of lists of tokens) into a list of bytes.
+
+        Args:
+            batch (Sequence[Sequence[int]]): List of token sequences to decode.
+            num_threads (int, optional): Number of threads to use for parallel processing.
+                Defaults to 8.
+
+        Returns:
+            list[bytes]: List of decoded bytes, one for each token sequence.
+        """
         with ThreadPoolExecutor(num_threads) as e:
             return list(e.map(self.decode_bytes, batch))

@@ -351,7 +478,11 @@ def decode_bytes_batch(
     # ====================

     def token_byte_values(self) -> list[bytes]:
-        """Returns the list of all token byte values."""
+        """Returns the list of all token byte values.
+
+        Returns:
+            list[bytes]: List of all token byte values in the vocabulary.
+        """
         return self._core_bpe.token_byte_values()

     @property
@@ -368,7 +499,14 @@ def is_special_token(self, token: int) -> bool:

     @property
     def n_vocab(self) -> int:
-        """For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
+        """Get the vocabulary size.
+
+        Returns:
+            int: The number of tokens in the vocabulary.
+
+        Note:
+            For backwards compatibility. Prefer to use `enc.max_token_value + 1`.
+        """
         return self.max_token_value + 1

     # ====================
@@ -378,19 +516,32 @@ def n_vocab(self) -> int:
     def _encode_single_piece(self, text_or_bytes: str | bytes) -> list[int]:
         """Encodes text corresponding to bytes without a regex split.

-        NOTE: this will not encode any special tokens.
+        Args:
+            text_or_bytes (Union[str, bytes]): Text to encode.

-        ```
-        >>> enc.encode_single_piece("helloqqqq")
-        [31373, 38227, 38227]
-        ```
+        Returns:
+            list[int]: The encoded tokens.
+
+        Note:
+            This will not encode any special tokens.
+
+        Examples:
+            >>> enc.encode_single_piece("helloqqqq")
+            [31373, 38227, 38227]
         """
         if isinstance(text_or_bytes, str):
             text_or_bytes = text_or_bytes.encode("utf-8")
         return self._core_bpe.encode_single_piece(text_or_bytes)

     def _encode_only_native_bpe(self, text: str) -> list[int]:
-        """Encodes a string into tokens, but do regex splitting in Python."""
+        """Encodes a string into tokens, but do regex splitting in Python.
+
+        Args:
+            text (str): The text to encode.
+
+        Returns:
+            list[int]: The encoded tokens.
+        """
         _unused_pat = regex.compile(self._pat_str)
         ret = []
         for piece in regex.findall(_unused_pat, text):
diff --git a/tiktoken/load.py b/tiktoken/load.py
index 295deb9..1a3499c 100644
--- a/tiktoken/load.py
+++ b/tiktoken/load.py
@@ -6,6 +6,18 @@


 def read_file(blobpath: str) -> bytes:
+    """Reads the contents of a file specified by the given blobpath.
+
+    Args:
+        blobpath (str): The path or URL to the file to be read.
+
+    Returns:
+        bytes: The binary content of the file.
+
+    Raises:
+        ImportError: If blobfile is not installed for local file access.
+        requests.exceptions.RequestException: If the HTTP request fails.
+    """
     if not blobpath.startswith("http://") and not blobpath.startswith("https://"):
         try:
             import blobfile
@@ -25,11 +37,37 @@ def read_file(blobpath: str) -> bytes:


 def check_hash(data: bytes, expected_hash: str) -> bool:
+    """Checks if the hash of the given data matches the expected hash.
+
+    Args:
+        data (bytes): The binary data to be hashed.
+        expected_hash (str): The expected hash value.
+
+    Returns:
+        bool: True if the actual hash matches the expected hash, False otherwise.
+    """
     actual_hash = hashlib.sha256(data).hexdigest()
     return actual_hash == expected_hash


 def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
+    """Reads and caches the contents of a file specified by the given blobpath.
+
+    If the file exists in cache and the hash matches (if provided), returns the cached content.
+    Otherwise, fetches it from the source, caches it, and returns the content.
+
+    Args:
+        blobpath (str): The path or URL to the file to be read.
+        expected_hash (str | None, optional): The expected hash value of the file content.
+            Defaults to None.
+
+    Returns:
+        bytes: The binary content of the file.
+
+    Raises:
+        ValueError: If the downloaded content's hash doesn't match the expected hash.
+        OSError: If there are issues writing to a user-specified cache directory.
+    """
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -89,6 +127,22 @@ def data_gym_to_mergeable_bpe_ranks(
     vocab_bpe_hash: str | None = None,
     encoder_json_hash: str | None = None,
 ) -> dict[bytes, int]:
+    """Converts a vocab BPE file and an encoder JSON file into mergeable BPE ranks.
+
+    Args:
+        vocab_bpe_file (str): The path to the vocabulary BPE file.
+        encoder_json_file (str): The path to the encoder JSON file.
+        vocab_bpe_hash (str | None, optional): The expected hash value of the vocabulary BPE file.
+            Defaults to None.
+        encoder_json_hash (str | None, optional): The expected hash value of the encoder JSON file.
+            Defaults to None.
+
+    Returns:
+        dict[bytes, int]: A dictionary mapping mergeable BPE tokens to their ranks.
+
+    Note:
+        This function does not implement caching internally.
+    """
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]

@@ -132,6 +186,15 @@ def decode_data_gym(value: str) -> bytes:


 def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
+    """Dumps the mergeable BPE ranks to a TikToken BPE file.
+
+    Args:
+        bpe_ranks (dict[bytes, int]): A dictionary mapping mergeable BPE tokens to their ranks.
+        tiktoken_bpe_file (str): The path to the TikToken BPE file.
+
+    Raises:
+        ImportError: If blobfile is not installed.
+    """
     try:
         import blobfile
     except ImportError as e:
@@ -143,7 +206,25 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
         f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


-def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]:
+def load_tiktoken_bpe(
+    tiktoken_bpe_file: str, expected_hash: str | None = None
+) -> dict[bytes, int]:
+    """Loads mergeable BPE ranks from a TikToken BPE file.
+
+    Args:
+        tiktoken_bpe_file (str): The path to the TikToken BPE file.
+        expected_hash (str | None, optional): The expected hash value of the file content.
+            Defaults to None.
+
+    Returns:
+        dict[bytes, int]: A dictionary mapping mergeable BPE tokens to their ranks.
+
+    Raises:
+        ValueError: If there are errors parsing the BPE file.
+
+    Note:
+        This function does not implement caching internally.
+    """
     # NB: do not add caching to this function
     contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     ret = {}
diff --git a/tiktoken/model.py b/tiktoken/model.py
index 4298ae7..39fffc8 100644
--- a/tiktoken/model.py
+++ b/tiktoken/model.py
@@ -80,7 +80,30 @@
 def encoding_name_for_model(model_name: str) -> str:
     """Returns the name of the encoding used by a model.

-    Raises a KeyError if the model name is not recognised.
+    Args:
+        model_name (str): The name of the model.
+
+    Returns:
+        str: The name of the encoding used by the model.
+
+    Raises:
+        KeyError: If the model name is not recognized or cannot be mapped to an encoding.
+
+    Note:
+        This function checks if the provided model name is directly mapped to an encoding
+        in MODEL_TO_ENCODING. If not, it attempts to match the model name with known
+        prefixes in MODEL_PREFIX_TO_ENCODING. If a match is found, it returns the
+        corresponding encoding name.
+
+    Examples:
+        >>> encoding_name_for_model("gpt2")
+        'gpt2'
+        >>> encoding_name_for_model("gpt-4")
+        'cl100k_base'
+        >>> encoding_name_for_model("nonexistent-model")
+        Traceback (most recent call last):
+            ...
+        KeyError: "Could not automatically map nonexistent-model to a tokeniser..."
+    """
     encoding_name = None
     if model_name in MODEL_TO_ENCODING:
@@ -105,6 +128,13 @@ def encoding_name_for_model(model_name: str) -> str:

 def encoding_for_model(model_name: str) -> Encoding:
     """Returns the encoding used by a model.

-    Raises a KeyError if the model name is not recognised.
+    Args:
+        model_name (str): The name of the model.
+
+    Returns:
+        Encoding: The encoding used by the model.
+
+    Raises:
+        KeyError: If the model name is not recognized or cannot be mapped to an encoding.
     """
     return get_encoding(encoding_name_for_model(model_name))
diff --git a/tiktoken/registry.py b/tiktoken/registry.py
index 17c4574..ff9090b 100644
--- a/tiktoken/registry.py
+++ b/tiktoken/registry.py
@@ -18,6 +18,22 @@

 @functools.lru_cache
 def _available_plugin_modules() -> Sequence[str]:
+    """Returns a sequence of available plugin modules.
+
+    Args:
+        None
+
+    Returns:
+        Sequence[str]: List of available plugin module names.
+
+    Note:
+        This function inspects tiktoken_ext namespace package for available plugin modules.
+        Submodules inside tiktoken_ext will be checked for ENCODING_CONSTRUCTORS attributes.
+        Uses namespace package pattern for faster pkgutil.iter_modules operation.
+
+        tiktoken_ext is implemented as a separate top-level package because namespace
+        subpackages of non-namespace packages don't work as expected with editable installs.
+    """
     # tiktoken_ext is a namespace package
     # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes
     # - we use namespace package pattern so `pkgutil.iter_modules` is fast
@@ -31,6 +47,25 @@ def _available_plugin_modules() -> Sequence[str]:


 def _find_constructors() -> None:
+    """Searches for and registers encoding constructors from available plugin modules.
+
+    This function populates the global ENCODING_CONSTRUCTORS dictionary by searching through
+    available plugin modules to find encoding constructors. It ensures there are no duplicate
+    encoding names across different plugins.
+
+    Args:
+        None
+
+    Raises:
+        ValueError: If either:
+            - A plugin module does not define ENCODING_CONSTRUCTORS
+            - There are duplicate encoding names across plugins
+
+    Note:
+        - Uses a lock to ensure thread safety when populating the global dictionary
+        - If ENCODING_CONSTRUCTORS is already populated, returns early
+        - In case of any exception, ENCODING_CONSTRUCTORS is reset to None before re-raising
+    """
     global ENCODING_CONSTRUCTORS
     with _lock:
         if ENCODING_CONSTRUCTORS is not None:
@@ -61,6 +96,27 @@ def _find_constructors() -> None:


 def get_encoding(encoding_name: str) -> Encoding:
+    """Returns an Encoding object for the given encoding name.
+
+    If the encoding has been previously loaded, returns the cached version.
+    Otherwise, constructs a new Encoding object using the appropriate constructor.
+
+    Args:
+        encoding_name (str): The name of the encoding to retrieve.
+
+    Returns:
+        Encoding: An Encoding object for the specified encoding.
+
+    Raises:
+        ValueError: If either:
+            - encoding_name is not a string
+            - the specified encoding name is unknown
+
+    Examples:
+        >>> encoding = get_encoding("gpt2")
+        >>> encoding
+        <Encoding 'gpt2'>
+    """
     if not isinstance(encoding_name, str):
         raise ValueError(f"Expected a string in get_encoding, got {type(encoding_name)}")

@@ -89,6 +145,18 @@ def get_encoding(encoding_name: str) -> Encoding:


 def list_encoding_names() -> list[str]:
+    """Lists all available encoding names that can be used with tiktoken.
+
+    Args:
+        None
+
+    Returns:
+        list[str]: List of available encoding names.
+
+    Examples:
+        >>> list_encoding_names()
+        ['gpt2', 'r50k_base', 'p50k_base', ...]
+    """
     with _lock:
         if ENCODING_CONSTRUCTORS is None:
             _find_constructors()
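A minimal usage sketch of the API documented above follows; it is illustrative only and not part of the patch. It assumes `tiktoken` is importable and reuses the gpt2 token ids quoted in the docstrings.

```python
import tiktoken

# Load a cached Encoding by name; get_encoding() raises ValueError for unknown names.
enc = tiktoken.get_encoding("gpt2")

# Ordinary encoding ignores special tokens entirely.
assert enc.encode_ordinary("hello world") == [31373, 995]

# Special tokens must be explicitly allowed, otherwise encode() raises ValueError.
assert enc.encode("<|endoftext|>", allowed_special="all") == [50256]

# Decoding is lossy by default ("replace"); pass errors="strict" to raise on invalid UTF-8.
assert enc.decode([31373, 995]) == "hello world"
assert enc.decode_tokens_bytes([31373, 995]) == [b"hello", b" world"]

# Model helpers resolve a model name to its encoding via MODEL_TO_ENCODING.
assert tiktoken.encoding_for_model("gpt2").name == "gpt2"
```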