Skip to content

Commit 2c90912

Browse files
[mypyc] Add bytearray support (#10891)
bytearray is treated as a subtype of bytes by mypy, even though they behave differently in some cases. We keep this design and accept bytearrays when the static type of a value is bytes.
1 parent 84504b0 commit 2c90912

File tree

3 files changed

+49
-5
lines changed

3 files changed

+49
-5
lines changed

mypyc/codegen/emit.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,8 @@ def emit_cast(self,
451451

452452
# TODO: Verify refcount handling.
453453
if (is_list_rprimitive(typ) or is_dict_rprimitive(typ) or is_set_rprimitive(typ)
454-
or is_str_rprimitive(typ) or is_bytes_rprimitive(typ) or is_range_rprimitive(typ)
455-
or is_float_rprimitive(typ) or is_int_rprimitive(typ) or is_bool_rprimitive(typ)):
454+
or is_str_rprimitive(typ) or is_range_rprimitive(typ) or is_float_rprimitive(typ)
455+
or is_int_rprimitive(typ) or is_bool_rprimitive(typ) or is_bit_rprimitive(typ)):
456456
if declare_dest:
457457
self.emit_line('PyObject *{};'.format(dest))
458458
if is_list_rprimitive(typ):
@@ -463,8 +463,6 @@ def emit_cast(self,
463463
prefix = 'PySet'
464464
elif is_str_rprimitive(typ):
465465
prefix = 'PyUnicode'
466-
elif is_bytes_rprimitive(typ):
467-
prefix = 'PyBytes'
468466
elif is_range_rprimitive(typ):
469467
prefix = 'PyRange'
470468
elif is_float_rprimitive(typ):
@@ -484,6 +482,18 @@ def emit_cast(self,
484482
'else {',
485483
err,
486484
'}')
485+
elif is_bytes_rprimitive(typ):
486+
if declare_dest:
487+
self.emit_line('PyObject *{};'.format(dest))
488+
check = '(PyBytes_Check({}) || PyByteArray_Check({}))'
489+
if likely:
490+
check = '(likely{})'.format(check)
491+
self.emit_arg_check(src, dest, typ, check.format(src, src), optional)
492+
self.emit_lines(
493+
' {} = {};'.format(dest, src),
494+
'else {',
495+
err,
496+
'}')
487497
elif is_tuple_rprimitive(typ):
488498
if declare_dest:
489499
self.emit_line('{} {};'.format(self.ctype(typ), dest))

mypyc/test-data/fixtures/ir.py

+8
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ def __ne__(self, x: object) -> bool: pass
106106
def __getitem__(self, i: int) -> int: pass
107107
def join(self, x: Iterable[object]) -> bytes: pass
108108

109+
class bytearray:
110+
@overload
111+
def __init__(self) -> None: pass
112+
@overload
113+
def __init__(self, x: object) -> None: pass
114+
@overload
115+
def __init__(self, string: str, encoding: str, err: str = ...) -> None: pass
116+
109117
class bool(int):
110118
def __init__(self, o: object = ...) -> None: ...
111119
@overload

mypyc/test-data/run-bytes.test

+27-1
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,30 @@ def test_len() -> None:
5252
# Use bytes() to avoid constant folding
5353
b = b'foo' + bytes()
5454
assert len(b) == 3
55-
assert len(bytes()) == 0
55+
assert len(bytes()) == 0
56+
57+
[case testBytearrayBasics]
58+
from typing import Any
59+
60+
def test_init() -> None:
61+
brr1: bytes = bytearray(3)
62+
assert brr1 == bytearray(b'\x00\x00\x00')
63+
assert brr1 == b'\x00\x00\x00'
64+
l = [10, 20, 30, 40]
65+
brr2: bytes = bytearray(l)
66+
assert brr2 == bytearray(b'\n\x14\x1e(')
67+
assert brr2 == b'\n\x14\x1e('
68+
brr3: bytes = bytearray(range(5))
69+
assert brr3 == bytearray(b'\x00\x01\x02\x03\x04')
70+
assert brr3 == b'\x00\x01\x02\x03\x04'
71+
brr4: bytes = bytearray('string', 'utf-8')
72+
assert brr4 == bytearray(b'string')
73+
assert brr4 == b'string'
74+
75+
def f(b: bytes) -> bool:
76+
return True
77+
78+
def test_bytearray_passed_into_bytes() -> None:
79+
assert f(bytearray(3))
80+
brr1: Any = bytearray()
81+
assert f(brr1)

0 commit comments

Comments
 (0)