Skip to content

Commit 87f157c

Browse files
messenseh-vetinari
authored andcommitted
Refactor and add SAFETY comments to PyArrayUnicode
Replace deprecated `PyUnicode_FromUnicode` with `PyUnicode_FromKindAndData`
1 parent bc9e5fd commit 87f157c

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

bindings/python/src/tokenizer.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -259,48 +259,50 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
259259
struct PyArrayUnicode(Vec<String>);
260260
impl FromPyObject<'_> for PyArrayUnicode {
261261
fn extract(ob: &PyAny) -> PyResult<Self> {
262+
// SAFETY Making sure the pointer is a valid numpy array requires calling numpy C code
262263
if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 {
263264
return Err(exceptions::PyTypeError::new_err("Expected an np.array"));
264265
}
265266
let arr = ob.as_ptr() as *mut npyffi::PyArrayObject;
266-
if unsafe { (*arr).nd } != 1 {
267-
return Err(exceptions::PyTypeError::new_err(
268-
"Expected a 1 dimensional np.array",
269-
));
270-
}
271-
if unsafe { (*arr).flags }
272-
& (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
273-
== 0
274-
{
275-
return Err(exceptions::PyTypeError::new_err(
276-
"Expected a contiguous np.array",
277-
));
278-
}
279-
let n_elem = unsafe { *(*arr).dimensions } as usize;
280-
let (type_num, elsize, alignment, data) = unsafe {
267+
// SAFETY Getting all the metadata about the numpy array to check its sanity
268+
let (type_num, elsize, alignment, data, nd, flags) = unsafe {
281269
let desc = (*arr).descr;
282270
(
283271
(*desc).type_num,
284272
(*desc).elsize as usize,
285273
(*desc).alignment as usize,
286274
(*arr).data,
275+
(*arr).nd,
276+
(*arr).flags,
287277
)
288278
};
289279

280+
if nd != 1 {
281+
return Err(exceptions::PyTypeError::new_err(
282+
"Expected a 1 dimensional np.array",
283+
));
284+
}
285+
if flags & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) == 0 {
286+
return Err(exceptions::PyTypeError::new_err(
287+
"Expected a contiguous np.array",
288+
));
289+
}
290290
if type_num != npyffi::types::NPY_TYPES::NPY_UNICODE as i32 {
291291
return Err(exceptions::PyTypeError::new_err(
292292
"Expected a np.array[dtype='U']",
293293
));
294294
}
295295

296+
// SAFETY Looking at the raw numpy data to create new owned Rust strings via copies (so it's safe afterwards).
296297
unsafe {
298+
let n_elem = *(*arr).dimensions as usize;
297299
let all_bytes = std::slice::from_raw_parts(data as *const u8, elsize * n_elem);
298300

299301
let seq = (0..n_elem)
300302
.map(|i| {
301303
let bytes = &all_bytes[i * elsize..(i + 1) * elsize];
302-
#[allow(deprecated)]
303-
let unicode = pyo3::ffi::PyUnicode_FromUnicode(
304+
let unicode = pyo3::ffi::PyUnicode_FromKindAndData(
305+
pyo3::ffi::PyUnicode_4BYTE_KIND as _,
304306
bytes.as_ptr() as *const _,
305307
elsize as isize / alignment as isize,
306308
);

0 commit comments

Comments
 (0)