Skip to content

Commit f51a3df

Browse files
committed
Partial sync of codebase
1 parent 00ff187 commit f51a3df

File tree

4 files changed

+13
-11
lines changed

4 files changed

+13
-11
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ python = [
1414
]
1515

1616
[dependencies]
17-
pyo3 = { version = "0.26", default-features = false, features = [
17+
pyo3 = { version = "0.26.0", default-features = false, features = [
1818
"extension-module",
1919
"macros",
2020
], optional = true }
2121

2222
# tiktoken dependencies
23-
fancy-regex = "0.16"
23+
fancy-regex = "0.13.0"
2424
regex = "1.10.3"
2525
rustc-hash = "2"
2626
bstr = "1.5.0"

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ skip = [
3131
"*-manylinux_i686",
3232
"*-musllinux_i686",
3333
"*-win32",
34-
"*-musllinux_aarch64",
3534
]
3635
macos.archs = ["x86_64", "arm64"]
3736
# When cross-compiling on Intel, it is not possible to test arm64 wheels.

src/py.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl CoreBPE {
2828

2929
#[pyo3(name = "encode_ordinary")]
3030
fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
31-
py.allow_threads(|| self.encode_ordinary(text))
31+
py.detach(|| self.encode_ordinary(text))
3232
}
3333

3434
#[pyo3(name = "encode")]
@@ -38,7 +38,7 @@ impl CoreBPE {
3838
text: &str,
3939
allowed_special: HashSet<PyBackedStr>,
4040
) -> PyResult<Vec<Rank>> {
41-
py.allow_threads(|| {
41+
py.detach(|| {
4242
let allowed_special: HashSet<&str> =
4343
allowed_special.iter().map(|s| s.as_ref()).collect();
4444
match self.encode(text, &allowed_special) {
@@ -54,7 +54,7 @@ impl CoreBPE {
5454
text: &str,
5555
allowed_special: HashSet<PyBackedStr>,
5656
) -> PyResult<Py<PyAny>> {
57-
let tokens_res = py.allow_threads(|| {
57+
let tokens_res = py.detach(|| {
5858
let allowed_special: HashSet<&str> =
5959
allowed_special.iter().map(|s| s.as_ref()).collect();
6060
self.encode(text, &allowed_special)
@@ -70,7 +70,7 @@ impl CoreBPE {
7070
}
7171

7272
fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
73-
py.allow_threads(|| {
73+
py.detach(|| {
7474
match std::str::from_utf8(bytes) {
7575
// Straightforward case
7676
Ok(text) => self.encode_ordinary(text),
@@ -121,7 +121,7 @@ impl CoreBPE {
121121
text: &str,
122122
allowed_special: HashSet<PyBackedStr>,
123123
) -> PyResult<(Vec<Rank>, Py<PyList>)> {
124-
let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.allow_threads(|| {
124+
let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.detach(|| {
125125
let allowed_special: HashSet<&str> =
126126
allowed_special.iter().map(|s| s.as_ref()).collect();
127127
self._encode_unstable_native(text, &allowed_special)
@@ -155,7 +155,7 @@ impl CoreBPE {
155155

156156
#[pyo3(name = "decode_bytes")]
157157
fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
158-
match py.allow_threads(|| self.decode_bytes(&tokens)) {
158+
match py.detach(|| self.decode_bytes(&tokens)) {
159159
Ok(bytes) => Ok(PyBytes::new(py, &bytes).into()),
160160
Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
161161
}

tiktoken/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from concurrent.futures import ThreadPoolExecutor
55
from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence
66

7-
import regex
8-
97
from tiktoken import _tiktoken
108

119
if TYPE_CHECKING:
10+
import re
11+
1212
import numpy as np
1313
import numpy.typing as npt
1414

@@ -391,6 +391,9 @@ def _encode_single_piece(self, text_or_bytes: str | bytes) -> list[int]:
391391

392392
def _encode_only_native_bpe(self, text: str) -> list[int]:
393393
"""Encodes a string into tokens, but do regex splitting in Python."""
394+
# We need specifically `regex` in order to compile pat_str due to e.g. \p
395+
import regex
396+
394397
_unused_pat = regex.compile(self._pat_str)
395398
ret = []
396399
for piece in regex.findall(_unused_pat, text):

0 commit comments

Comments
 (0)