Skip to content

Commit d981d32

Browse files
benniekissabetlen
andauthored
feat: Enable detokenizing special tokens with special=True (abetlen#1596)
* enable detokenizing special tokens * enable skipping_special_tokens in hf_tokenizer detokenize() * process prev_tokens * fix doc strings * Revert changes to LlamaTokenizer prev_tokens and set special to False by default --------- Co-authored-by: Andrei <[email protected]>
1 parent 9cba3b8 commit d981d32

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

llama_cpp/llama.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,8 @@ def tokenize(
578578
579579
Args:
580580
text: The utf-8 encoded string to tokenize.
581+
add_bos: Whether to add a beginning of sequence token.
582+
special: Whether to tokenize special tokens.
581583
582584
Raises:
583585
RuntimeError: If the tokenization failed.
@@ -588,18 +590,19 @@ def tokenize(
588590
return self.tokenizer_.tokenize(text, add_bos, special)
589591

590592
def detokenize(
591-
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
593+
self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
592594
) -> bytes:
593595
"""Detokenize a list of tokens.
594596
595597
Args:
596598
tokens: The list of tokens to detokenize.
597-
prev_tokens: The list of previous tokens. Offset mapping will be performed if provided
599+
prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
600+
special: Whether to detokenize special tokens.
598601
599602
Returns:
600603
The detokenized string.
601604
"""
602-
return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)
605+
return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special)
603606

604607
def set_cache(self, cache: Optional[BaseLlamaCache]):
605608
"""Set the cache.

llama_cpp/llama_tokenizer.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,22 @@ def tokenize(
1919
"""Tokenize the text into tokens.
2020
2121
Args:
22-
text: The text to tokenize.
22+
text: The utf-8 encoded string to tokenize.
2323
add_bos: Whether to add a beginning of sequence token.
24-
special: Whether to tokenize text literally or as special tokens."""
24+
special: Whether to tokenize special tokens.
25+
"""
2526
raise NotImplementedError
2627

2728
@abc.abstractmethod
2829
def detokenize(
29-
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
30+
self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
3031
) -> bytes:
3132
"""Detokenize the tokens into text.
3233
3334
Args:
34-
tokens: The tokens to detokenize.
35-
prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens.
35+
tokens: The list of tokens to detokenize.
36+
prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
37+
special: Whether to detokenize special tokens.
3638
"""
3739
raise NotImplementedError
3840

@@ -47,9 +49,9 @@ def tokenize(
4749
return self._model.tokenize(text, add_bos=add_bos, special=special)
4850

4951
def detokenize(
50-
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
52+
self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
5153
) -> bytes:
52-
return self._model.detokenize(tokens)
54+
return self._model.detokenize(tokens, special=special)
5355

5456
def encode(
5557
self, text: str, add_bos: bool = True, special: bool = True
@@ -78,18 +80,19 @@ def tokenize(
7880
)
7981

8082
def detokenize(
81-
self, tokens: List[int], prev_tokens: Optional[List[int]] = None
83+
self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
8284
) -> bytes:
85+
skip_special_tokens = not special
8386
if prev_tokens is not None:
84-
text = self.hf_tokenizer.decode(prev_tokens + tokens).encode(
87+
text = self.hf_tokenizer.decode(prev_tokens + tokens, skip_special_tokens=skip_special_tokens).encode(
8588
"utf-8", errors="ignore"
8689
)
87-
prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
90+
prev_text = self.hf_tokenizer.decode(prev_tokens, skip_special_tokens=skip_special_tokens).encode(
8891
"utf-8", errors="ignore"
8992
)
9093
return text[len(prev_text) :]
9194
else:
92-
return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
95+
return self.hf_tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens).encode("utf-8", errors="ignore")
9396

9497
@classmethod
9598
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":

0 commit comments

Comments
 (0)