Skip to content

Commit

Permalink
fix: decorators that made black choke
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Nov 15, 2023
1 parent 8ad27a4 commit 356bab8
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 47 deletions.
5 changes: 1 addition & 4 deletions dol/filesys.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,7 @@ def is_valid_key(self, k):
return bool(self._key_pattern.match(k))

def validate_key(
self,
k,
err_msg_format=_dflt_not_valid_error_msg,
err_type=KeyValidationError,
self, k, err_msg_format=_dflt_not_valid_error_msg, err_type=KeyValidationError,
):
if not self.is_valid_key(k):
raise err_type(err_msg_format.format(k))
Expand Down
25 changes: 15 additions & 10 deletions dol/kv_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def _string(string: str):

@Sig
def _csv_rw_sig(
dialect: str = "excel",
dialect: str = 'excel',
*,
delimiter: str = ",",
delimiter: str = ',',
quotechar: Optional[str] = '"',
escapechar: Optional[str] = None,
doublequote: bool = True,
skipinitialspace: bool = False,
lineterminator: str = "\r\n",
lineterminator: str = '\r\n',
quoting=0,
strict: bool = False,
):
Expand All @@ -36,27 +36,32 @@ def _csv_rw_sig(

@Sig
def _csv_dict_extra_sig(
fieldnames, restkey=None, restval="", extrasaction="raise", fieldcasts=None
fieldnames, restkey=None, restval='', extrasaction='raise', fieldcasts=None
):
...


@(_string + _csv_rw_sig)
__csv_rw_sig = _string + _csv_rw_sig
__csv_dict_sig = _string + _csv_rw_sig + _csv_dict_extra_sig


# Note: @(_string + _csv_rw_sig) made (ax)black choke
@__csv_rw_sig
def csv_encode(string, *args, **kwargs):
with io.StringIO() as buffer:
writer = csv.writer(buffer, *args, **kwargs)
writer.writerows(string)
return buffer.getvalue()


@(_string + _csv_rw_sig)
@__csv_rw_sig
def csv_decode(string, *args, **kwargs):
with io.StringIO(string) as buffer:
reader = csv.reader(buffer, *args, **kwargs)
return list(reader)


@(_string + _csv_rw_sig + _csv_dict_extra_sig)
@__csv_dict_sig
def csv_dict_encode(string, *args, **kwargs):
"""Encode a list of dicts into a csv string.
Expand All @@ -66,15 +71,15 @@ def csv_dict_encode(string, *args, **kwargs):
'a,b\r\n1,2\r\n3,4\r\n'
"""
_ = kwargs.pop("fieldcasts", None) # this one is for decoder only
_ = kwargs.pop('fieldcasts', None) # this one is for decoder only
with io.StringIO() as buffer:
writer = csv.DictWriter(buffer, *args, **kwargs)
writer.writeheader()
writer.writerows(string)
return buffer.getvalue()


@(_string + _csv_rw_sig + _csv_dict_extra_sig)
@__csv_dict_sig
def csv_dict_decode(string, *args, **kwargs):
r"""Decode a csv string into a list of dicts.
Expand Down Expand Up @@ -107,7 +112,7 @@ def csv_dict_decode(string, *args, **kwargs):
[{'a': '1', 'b': 2.0}, {'a': '3', 'b': 4.0}]
"""
fieldcasts = kwargs.pop("fieldcasts", lambda row: row)
fieldcasts = kwargs.pop('fieldcasts', lambda row: row)
if isinstance(fieldcasts, Iterable):
if isinstance(fieldcasts, dict):
cast_dict = dict(fieldcasts)
Expand Down
11 changes: 4 additions & 7 deletions dol/tests/test_kv_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def _test_codec(codec, obj, encoded=None, decoded=None):
decoded = obj
assert (
codec.encoder(obj) == encoded
), f"Expected {codec.encoder(obj)=} to equal {encoded=}"
), f'Expected {codec.encoder(obj)=} to equal {encoded=}'
assert (
codec.decoder(encoded) == decoded
), f"Expected {codec.decoder(encoded)=} to equal {decoded=}"
), f'Expected {codec.decoder(encoded)=} to equal {decoded=}'


def _test_codec_part(codec, obj, encoded, slice_):
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_value_codecs():
)

assert str(inspect.signature(ValueCodecs.pickle)) == (
"(obj, data, protocol=None, fix_imports=True, buffer_callback=None, "
'(obj, data, protocol=None, fix_imports=True, buffer_callback=None, '
"encoding='ASCII', errors='strict', buffers=())"
) # NOTE: May change according to python version. This is 3.8

Expand Down Expand Up @@ -125,10 +125,7 @@ def test_value_codecs():
_test_codec_part(ValueCodecs.tarfile(), b'hello', b'data.bin', slice(0, 8))

_test_codec_part(
ValueCodecs.lzma(),
b'hello',
b'\xfd7zXZ',
slice(0, 4),
ValueCodecs.lzma(), b'hello', b'\xfd7zXZ', slice(0, 4),
)

_test_codec_part(
Expand Down
9 changes: 2 additions & 7 deletions dol/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,7 @@ class PartialClass(cls):
__init__ = partialmethod(cls.__init__, *args, **kwargs)

copy_attrs(
PartialClass,
cls,
attrs=('__name__', '__qualname__', '__module__', '__doc__'),
PartialClass, cls, attrs=('__name__', '__qualname__', '__module__', '__doc__'),
)

return PartialClass
Expand Down Expand Up @@ -904,10 +902,7 @@ def igroupby(
if val is None:
_append_to_group_items = append_to_group_items
else:
_append_to_group_items = lambda group_items, item: (
group_items,
val(item),
)
_append_to_group_items = lambda group_items, item: (group_items, val(item),)

for item in items:
group_key = key(item)
Expand Down
26 changes: 7 additions & 19 deletions dol/zipfiledol.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,7 @@ def zip_compress(


def zip_decompress(
b: bytes,
*,
allowZip64=True,
compresslevel=None,
strict_timestamps=True,
b: bytes, *, allowZip64=True, compresslevel=None, strict_timestamps=True,
) -> bytes:
"""Decompress input bytes of a single file zip, returning the uncompressed bytes
Expand Down Expand Up @@ -492,11 +488,7 @@ def __init__(
self.zip_reader_kwargs = zip_reader_kwargs
if self.zip_reader is ZipReader:
self.zip_reader_kwargs = dict(
dict(
prefix='',
open_kws=None,
file_info_filt=ZipReader.FILES_ONLY,
),
dict(prefix='', open_kws=None, file_info_filt=ZipReader.FILES_ONLY,),
**self.zip_reader_kwargs,
)

Expand Down Expand Up @@ -620,8 +612,7 @@ def mk_flatzips_store(
from dol.util import partialclass

ZipFileStreamsReader = mk_relative_path_store(
partialclass(ZipFilesReader, zip_reader=FileStreamsOfZip),
prefix_attr='rootdir',
partialclass(ZipFilesReader, zip_reader=FileStreamsOfZip), prefix_attr='rootdir',
)
ZipFileStreamsReader.__name__ = 'ZipFileStreamsReader'
ZipFileStreamsReader.__qualname__ = 'ZipFileStreamsReader'
Expand Down Expand Up @@ -783,10 +774,7 @@ def __getitem__(self, k):

def __repr__(self):
args_str = ', '.join(
(
f"'{self.zip_filepath}'",
f"'allow_overwrites={self.allow_overwrites}'",
)
(f"'{self.zip_filepath}'", f"'allow_overwrites={self.allow_overwrites}'",)
)
return f'{self.__class__.__name__}({args_str})'

Expand Down Expand Up @@ -946,12 +934,12 @@ def is_a_mac_junk_path(path):
# ----------------------------- Extras -------------------------------------------------


def tar_compress(data_bytes, file_name="data.bin"):
def tar_compress(data_bytes, file_name='data.bin'):
import tarfile
import io

with io.BytesIO() as tar_buffer:
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
with tarfile.open(fileobj=tar_buffer, mode='w') as tar:
data_file = io.BytesIO(data_bytes)
tarinfo = tarfile.TarInfo(name=file_name)
tarinfo.size = len(data_bytes)
Expand All @@ -964,7 +952,7 @@ def tar_decompress(tar_bytes):
import io

with io.BytesIO(tar_bytes) as tar_buffer:
with tarfile.open(fileobj=tar_buffer, mode="r:") as tar:
with tarfile.open(fileobj=tar_buffer, mode='r:') as tar:
for member in tar.getmembers():
extracted_file = tar.extractfile(member)
if extracted_file:
Expand Down

0 comments on commit 356bab8

Please sign in to comment.