Skip to content

Commit

Permalink
feat: ValueCodecs. Closes #42
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Nov 15, 2023
1 parent a85c843 commit fedc9aa
Show file tree
Hide file tree
Showing 7 changed files with 633 additions and 32 deletions.
2 changes: 2 additions & 0 deletions dol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def ihead(store, n=1):
# )


from dol.kv_codecs import ValueCodecs, KeyCodecs

from dol.base import (
Collection, # base class for collections (adds to collections.abc.Collection)
MappingViewMixin,
Expand Down
316 changes: 316 additions & 0 deletions dol/kv_codecs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
# ------------------------------------ Codecs ------------------------------------------

from functools import partial
from dataclasses import dataclass
from typing import TypeVar, Generic, Callable, Iterable, Any, Optional

from dol.trans import wrap_kvs
from dol.util import Pipe, decorate_callables
from dol.signatures import Sig

# For the codecs:
import csv
import io


@Sig
def _string(string: str):
    # Signature-only stub: never called. Its `Sig` contributes the leading
    # `string` parameter when signatures are combined (e.g. `_string + _csv_rw_sig`)
    # to decorate the codec functions below.
    ...


@Sig
def _csv_rw_sig(
    dialect: str = "excel",
    *,
    delimiter: str = ",",
    quotechar: Optional[str] = '"',
    escapechar: Optional[str] = None,
    doublequote: bool = True,
    skipinitialspace: bool = False,
    lineterminator: str = "\r\n",
    quoting=0,
    strict: bool = False,
):
    # Signature-only stub: never called. Mirrors the dialect/formatting
    # parameters accepted by `csv.reader`/`csv.writer`, so codec factories
    # expose the same options as the stdlib csv module.
    ...


@Sig
def _csv_dict_extra_sig(
    fieldnames, restkey=None, restval="", extrasaction="raise", fieldcasts=None
):
    # Signature-only stub: never called. Adds the `csv.DictReader`/`DictWriter`
    # specific parameters, plus our own `fieldcasts` extension (used by
    # `csv_dict_decode` only; `csv_dict_encode` accepts and drops it).
    ...


@(_string + _csv_rw_sig)
def csv_encode(string, *args, **kwargs):
    """Encode an iterable of rows into a csv-formatted string.

    Note: despite the (signature-imposed) name, `string` is the iterable of
    rows to write. Remaining arguments are forwarded to `csv.writer`.
    """
    with io.StringIO() as buffer:
        csv.writer(buffer, *args, **kwargs).writerows(string)
        return buffer.getvalue()


@(_string + _csv_rw_sig)
def csv_decode(string, *args, **kwargs):
    """Decode a csv-formatted string into a list of rows (lists of strings).

    Remaining arguments are forwarded to `csv.reader`.
    """
    buffer = io.StringIO(string)
    try:
        return list(csv.reader(buffer, *args, **kwargs))
    finally:
        buffer.close()


@(_string + _csv_rw_sig + _csv_dict_extra_sig)
def csv_dict_encode(string, *args, **kwargs):
    r"""Encode a list of dicts into a csv string.

    Note: despite the (signature-imposed) name, `string` is the iterable of
    dicts to write. Remaining arguments are forwarded to `csv.DictWriter`.

    >>> data = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
    >>> encoded = csv_dict_encode(data, fieldnames=['a', 'b'])
    >>> encoded
    'a,b\r\n1,2\r\n3,4\r\n'
    """
    # Fix: the docstring must be raw (r"") so the doctest's expected output
    # keeps the two-character escapes `\r\n` instead of literal CR/LF, matching
    # the repr that the doctest actually produces (csv_dict_decode already does this).
    _ = kwargs.pop("fieldcasts", None)  # this one is for decoder only
    with io.StringIO() as buffer:
        writer = csv.DictWriter(buffer, *args, **kwargs)
        writer.writeheader()
        writer.writerows(string)
        return buffer.getvalue()


@(_string + _csv_rw_sig + _csv_dict_extra_sig)
def csv_dict_decode(string, *args, **kwargs):
    r"""Decode a csv string into a list of dicts.

    :param string: The csv string to decode
    :param fieldcasts: A function that takes a row and returns a row with the same keys
        but with values cast to the desired type. If a dict, it should be a mapping
        from fieldnames to cast functions. If an iterable, it should be an iterable of
        cast functions, in which case each cast function will be applied to each element
        of the row, element wise.

    >>> data = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
    >>> encoded = csv_dict_encode(data, fieldnames=['a', 'b'])
    >>> encoded
    'a,b\r\n1,2\r\n3,4\r\n'
    >>> csv_dict_decode(encoded)
    [{'a': '1', 'b': '2'}, {'a': '3', 'b': '4'}]

    See that you don't get back what you started with: the ints aren't ints anymore!
    You can resolve this by using the fieldcasts argument
    (that's our argument -- not present in builtin csv module).
    It should be a function (that transforms a dict to the one you want) or a
    list or tuple of the same size as the row (that specifies the cast function for
    each field):

    >>> csv_dict_decode(encoded, fieldnames=['a', 'b'], fieldcasts=[int] * 2)
    [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
    >>> csv_dict_decode(encoded, fieldnames=['a', 'b'], fieldcasts={'b': float})
    [{'a': '1', 'b': 2.0}, {'a': '3', 'b': 4.0}]
    """
    # Normalize fieldcasts into a row -> row function; default is identity.
    fieldcasts = kwargs.pop("fieldcasts", lambda row: row)
    if isinstance(fieldcasts, Iterable):
        if isinstance(fieldcasts, dict):
            # Mapping fieldname -> cast; fields not mentioned are passed through.
            cast_dict = dict(fieldcasts)
            cast = lambda k: cast_dict.get(k, lambda x: x)
            fieldcasts = lambda row: {k: cast(k)(v) for k, v in row.items()}
        else:
            _casts = list(fieldcasts)
            # apply each cast function to each element of the row, element wise
            fieldcasts = lambda row: {
                k: cast(v) for cast, (k, v) in zip(_casts, row.items())
            }
    with io.StringIO(string) as buffer:
        reader = csv.DictReader(buffer, *args, **kwargs)
        rows = [row for row in reader]

    def remove_first_row_if_only_header(rows):
        # When explicit `fieldnames` are passed, DictReader treats the header
        # line as a data row (each value equal to its own key); drop it in place.
        first_row = next(iter(rows), None)
        if first_row is not None and all(k == v for k, v in first_row.items()):
            rows.pop(0)

    remove_first_row_if_only_header(rows)
    return list(map(fieldcasts, rows))


def _xml_tree_encode(element, parser=None):
# Needed to replace original "text" argument with "element" to be consistent with
# ET.tostring
import xml.etree.ElementTree as ET

return ET.fromstring(text=element, parser=parser)


def _xml_tree_decode(
element,
encoding=None,
method=None,
*,
xml_declaration=None,
default_namespace=None,
short_empty_elements=True,
):
import xml.etree.ElementTree as ET

return ET.tostring(
element,
encoding,
method,
xml_declaration=xml_declaration,
default_namespace=default_namespace,
short_empty_elements=short_empty_elements,
)


# EncodedType: the stored representation (e.g. bytes, a json string).
# DecodedType: the in-memory value the user works with (e.g. a dict, a tree).
EncodedType = TypeVar('EncodedType')
DecodedType = TypeVar('DecodedType')


# TODO: Want a way to specify Encoded type and Decoded type
@dataclass
class Codec(Generic[DecodedType, EncodedType]):
    """An (encoder, decoder) pair.

    `encoder` maps a decoded value to its encoded (stored) form; `decoder`
    inverts it. Supports `encoder, decoder = codec` unpacking (via __iter__)
    and composition with `+` (via __add__).
    """

    encoder: Callable[[DecodedType], EncodedType]
    decoder: Callable[[EncodedType], DecodedType]

    def __iter__(self):
        # Enables tuple-unpacking: `encoder, decoder = codec`.
        return iter((self.encoder, self.decoder))

    def __add__(self, other):
        # Compose codecs: encode with self first then other;
        # decoding applies the inverses in reverse order.
        return Codec(
            encoder=Pipe(self.encoder, other.encoder),
            decoder=Pipe(other.decoder, self.decoder),
        )


class ValueCodec(Codec):
    """A Codec that, called on a store, wraps it so *values* are
    encoded on write and decoded on read."""

    def __call__(self, obj):
        encoder, decoder = self
        return wrap_kvs(obj, data_of_obj=encoder, obj_of_data=decoder)


class KeyCodec(Codec):
    """A Codec that, called on a store, wraps it so *keys* are
    encoded on the way in and decoded on the way out."""

    def __call__(self, obj):
        encoder, decoder = self
        return wrap_kvs(obj, id_of_key=encoder, key_of_id=decoder)


def extract_arguments(func, args, kwargs):
    """Resolve (args, kwargs) against func's signature into a plain kwargs dict,
    ignoring parameter kind and tolerating both missing and extra arguments."""
    sig = Sig(func)
    return sig.kwargs_from_args_and_kwargs(
        args, kwargs, allow_partial=True, allow_excess=True, ignore_kind=True
    )


def _var_kinds_less_signature(func):
    """Return func's Sig with any variadic (*args / **kwargs) parameters removed."""
    sig = Sig(func)
    variadic_names = (
        sig.names_of_kind[Sig.VAR_POSITIONAL]
        + sig.names_of_kind[Sig.VAR_KEYWORD]
    )
    return sig - variadic_names


def _merge_signatures(encoder, decoder, *, exclude=()):
    """Concatenate the non-variadic signatures of encoder and decoder,
    first dropping any parameters named in `exclude`."""
    encoder_sig = _var_kinds_less_signature(encoder) - exclude
    decoder_sig = _var_kinds_less_signature(decoder) - exclude
    return encoder_sig + decoder_sig


def _codec_wrap(cls, encoder: Callable, decoder: Callable, **kwargs):
    """Instantiate `cls` with encoder and decoder each partialized on the
    subset of `kwargs` that its own signature accepts."""
    encoder_kwargs = extract_arguments(encoder, (), kwargs)
    decoder_kwargs = extract_arguments(decoder, (), kwargs)
    return cls(
        encoder=partial(encoder, **encoder_kwargs),
        decoder=partial(decoder, **decoder_kwargs),
    )


def codec_wrap(cls, encoder: Callable, decoder: Callable, *, exclude=()):
    """Make a codec factory for `cls`: calling the factory with codec
    parameters returns a `cls` instance whose encoder/decoder are
    partialized on those parameters. The factory's signature is the merge
    of the encoder's and decoder's signatures (minus `exclude`)."""
    # Nested partials are kept deliberately (rather than one flat partial):
    # the resulting object's structure is what ValueCodecs.__iter__ inspects.
    cls_codec_wrap = partial(_codec_wrap, cls)
    factory = partial(cls_codec_wrap, encoder, decoder)
    merged_sig = _merge_signatures(encoder, decoder, exclude=exclude)
    return merged_sig(factory)


# Ready-made factories: value_wrap(encoder, decoder, ...) builds a ValueCodec
# factory; key_wrap(...) builds a KeyCodec factory. The __name__ assignments
# give the (partial) objects a friendlier identity for introspection.
value_wrap = partial(codec_wrap, ValueCodec)
value_wrap.__name__ = 'value_wrap'
key_wrap = partial(codec_wrap, KeyCodec)
key_wrap.__name__ = 'key_wrap'


class ValueCodecs:
    """
    A collection of value codecs using standard lib tools.

    Each class attribute is a codec *factory*: call it (with the underlying
    en/decoder's parameters) to get a ``ValueCodec``, which can then wrap a store.

    >>> json_codec = ValueCodecs.json()
    >>> encoder, decoder = json_codec
    >>> encoder({'b': 2})
    '{"b": 2}'
    >>> decoder('{"b": 2}')
    {'b': 2}
    >>> backend = dict()
    >>> interface = json_codec(backend)
    >>> interface['a'] = {'b': 2} # we write a dict
    >>> assert backend == {'a': '{"b": 2}'} # json was written in backend
    >>> interface['a'] # but this json is decoded to a dict when read from interface
    {'b': 2}
    """

    # TODO: Clean up module import pollution?
    # TODO: Import all these in module instead of class
    # TODO: Figure out a way to import these dynamically, only if a particular codec is used
    # TODO: Figure out how to give codecs annotations that can actually be inspected!

    def __iter__(self):
        # Yield the names of the codec-factory attributes of this class.
        # NOTE(review): is_value_codec requires attr_val.func.__name__ == '_codec_wrap',
        # but codec_wrap builds *nested* partials, whose .func is itself a partial
        # (which has no __name__) -- verify this filter actually matches anything
        # at runtime.
        def is_value_codec(attr_val):
            func = getattr(attr_val, 'func', None)
            name = getattr(func, '__name__', '')
            return name == '_codec_wrap'

        for attr in dir(self):
            if not attr.startswith('_'):
                attr_val = getattr(self, attr, None)
                if is_value_codec(attr_val):
                    yield attr

    # Class-body imports: each module name below is immediately shadowed by the
    # codec attribute of the same name (e.g. `pickle` the module, then `pickle`
    # the codec) -- order of the statements matters.
    import pickle, json, gzip, bz2, base64 as b64, lzma, codecs
    from dol.zipfiledol import (
        zip_compress,
        zip_decompress,
        tar_compress,
        tar_decompress,
    )

    # Serialization codecs
    pickle: Codec[Any, bytes] = value_wrap(pickle.dumps, pickle.loads)
    json: Codec[dict, str] = value_wrap(json.dumps, json.loads)
    csv: Codec[list, str] = value_wrap(csv_encode, csv_decode)
    csv_dict: Codec[list, str] = value_wrap(csv_dict_encode, csv_dict_decode)

    # Text/bytes transcoding codecs
    base64: Codec[bytes, bytes] = value_wrap(b64.b64encode, b64.b64decode)
    urlsafe_b64: Codec[bytes, bytes] = value_wrap(
        b64.urlsafe_b64encode, b64.urlsafe_b64decode
    )
    codecs: Codec[str, bytes] = value_wrap(codecs.encode, codecs.decode)

    # Note: Not clear if escaping or unescaping is the encoder or decoder here
    # I have never had the need for stores using it, so will omit for now
    # html: Codec[str, str] = value_wrap(html.unescape, html.escape)

    # Compression
    zipfile: Codec[bytes, bytes] = value_wrap(zip_compress, zip_decompress)
    gzip: Codec[bytes, bytes] = value_wrap(gzip.compress, gzip.decompress)
    bz2: Codec[bytes, bytes] = value_wrap(bz2.compress, bz2.decompress)
    tarfile: Codec[bytes, bytes] = value_wrap(tar_compress, tar_decompress)
    lzma: Codec[bytes, bytes] = value_wrap(
        lzma.compress, lzma.decompress, exclude=('format',)
    )

    import quopri, plistlib

    quopri: Codec[bytes, bytes] = value_wrap(quopri.encodestring, quopri.decodestring)
    # plistlib: Codec[bytes, bytes] = value_wrap(plistlib.dumps, plistlib.loads)

    xml_etree: Codec['xml.etree.ElementTree', bytes] = value_wrap(
        _xml_tree_encode, _xml_tree_decode
    )


class KeyCodecs:
"""
A collection of key codecs
"""
Loading

0 comments on commit fedc9aa

Please sign in to comment.