From f3921b7ef6f24c9e9acd12ac712b1409cd20952a Mon Sep 17 00:00:00 2001 From: Alix Damman Date: Wed, 11 Dec 2019 11:07:02 +0100 Subject: [PATCH 1/2] updated test_session.h5 file and test_init_session_hdf() function --- larray/tests/data/test_session.h5 | Bin 14152 -> 2144424 bytes larray/tests/test_session.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/larray/tests/data/test_session.h5 b/larray/tests/data/test_session.h5 index d17d199b6526b80bf55f0832dee632173867489b..028dec515689444d8cb9d19b67aea0c41500271a 100644 GIT binary patch literal 2144424 zcmeI*O^jXFT>$VmcG65zr*Ru98K_|R9uRTIX)6pAVYjj45I43QCpgezzp*FrG|9{` zw%Z9Rl?4=1WT}Bn!BC3KmoQZY0aeHgB(h-PrI#$a@Ctd+g%>QcgM02d|99WqnP-m^ z&vc&W_oR8}-TTh{`rUKR{hxd9-1pT3PwapAp&xpvlU}2vo%=f5>yPpFf8}-g%2+)r zhU5Da<-df)u&G1X}v*+e|W2VX%FP8MBlAbPU zl;fxcA1vjyU9Puv#_B0j`A)Uw=u<}zRt328XfbWxD(8dA93k)fCjPOF8M>u6~|=f6iDYD^p1}i$J@6 z=F-{eoG$HlI82rg-AEtLAG$NMGq1fmd+tmbALTsPo-%TU&=oz97GEY0_N z8NU3@Jbq73I-}oOns1^EAO3d9aF>&4=f+!J8>O8a={&MxJN4GK4DnXk;kRq2<{ruU zx8G$yvl|Y9c01K;56bZ?wqknb)YL*X@P+Ox=ccMgDcX}gk>L*0+)g#o^|hbNGl==V zQ-*gumdD3>o%r>o`Cc!>mtV-^$8yqH=q}B7sSLk;GLMhr^5ya&mV2!X-+ZXdXoDxc z{=*OF6}?&3ccZjZ+dBD(Fj7TkDIS~a&YbF=IezxFxg2?>_%I!c!vD{XZx4)fjWO(KaPK)Tdn=s*Q)uY=^p6YuKB5#&b`u|ueS5y zysc$}Q=FP!BO`f4TDSCyH1_Y6QUPw{udDAz>-lWivZbDIw9IGg$dziz={25{^8TSJ ze|n|N)Bkz%bItaqcztjjzW2zHy{T$7`BUz}aromqKb{s-JFumiV}70Z zK>08rDI3?h^n(L3Z1T~$6qU|aWl>HmR6VlhmGY#{So*F#KIN)AYUx9H#c+S#^i_LV z-HGn5Zi!!;Z(}a)-0d~XQ+Ges%>`(M`NtHo?AV6(657MD{_XMY`K0ZOdKcfN^j6ak z%GWktX})RN^cp0Sziqycys_P}zNsBdZB4ue&2MsIa@XYJlOLP>c(bQgTF;nYs*?{^ zf9X(|!hsZ8K3x5!D%#Ae8t@`OfB=E-o50rUyliUAw^xmJj-~e0>(az+Ij-G*aPRYn zj^y|>9h2f%F>~(KxfHiH=TF~XeID(co$t<^nVUT`b*i3S=LdSz6x=rFHOx?%NVa`JW@xrqn(-Vt5bFU>Mb4TV!qGptg}_a57p!5$``kb zpZi_X@v_qytKSF5&(}+cnYOffZPfU=@mrt$8|TH^{p4mn@_z5_$4WEps@RW{g|wYN zI)tg5r2Q26_byD=r$C#2-TUiubt+A8JW5wCU;WFO?ky+v?$Z1otjqHkr(ZdBvO8Z& zb7x?AE|XW&t$aP;S~)Lzsbsv}dA5lk&lz*AiMdNRi$ME4+gRD~g;I{fVRGW}y8qLi 
z$k(H)>j3qqs1La`NABCFQD58V^Hve<9V=w$f-nP}A$KL(ZIe(lV zxV?-V_mV)newEW@o&Wx^rCYf+=dn|ikzY5;*^8*J+dq-VPy9koW4$iEoQEUd+hzFr zt9g7JC+2>1X}*hP_)2TLa<1WhkMo>2el~Bzjj|~lrJZWx59hI0a~?a+)kPZn_x#T+ z-M@q5@c%7UDr!ugZllKG?eo~QCdnKKnP~;`Z^hS&!xY*#3TL#Odw){r%Yf^D9rL zyY8wT(&SJ5*!wzz`mxoyk97K@%D=nbjj8{d&Zl&%{_ydar_!mK`jz_2>HK`e5ploY zM@vV0P<;JHxvsw2`1aL-4hrmUPs`+B~29ew3|n{Di7iwjSe1A6=V zM!WuOtzv?d_0j4t1q_4h%_mE}Nmuuq*GBD6HF47V#)Iwq4x9B@9uFq>h`YLlw4J{{ z9_)`F-|cmc=X%R&;=*)B;?A#YtbG6Q<#L{CwQ=F;U&<>K?N(Vm$tDpP6c-*Yf$F?{(iHL-8^yOwmR0>Qv!~tA3I9>ccN7QwpRCpr>yBV=~Q;oU!F9dRqfTIt$B3s z{_!naZv0A>w|%UAk`(&s$KsF&?;4HFZ>HXPa&~@d=d0&lIXm5*pYOi5 zcwyV|`Kgn$GiT1upF4SW@y*2x_tzO0FFa8FPh_3t&BdoXUs?S0$mHT5Ego4Me|qHU zEnitYz4gNV^<}~|$(BRAt2|GS+Dy)fv0>$*lk1<& z`CWg>4x8%j^BTwQIk*WVM$KMEknE_XO!>SvKW}&2!&Le%{KU`;nvG zM7sL>k$x(TWc%?Q{=N#E0=b#diL?9Un&X(3~H=v*+-| z`VKm8cxuvYQm`0wT=MEu~q;c%WejQc4g zjrj1y{L=k9`1tmnp*p^`E6BPG_U;sYFEFM_xg#P^xogY|EpcB{+%no{J&aAbS}rw!1TuCNGj@OZ?q2t z2oNAZfWSQ?&^{j*=TYwV{c7#!R=4BI-tRzEvUhs^ssFx+b$@K7uCT|7lK z`gzyR=kIS6*xUX7!`0u<`C~qL$opmyXrFh=Pw(n<7S1mfNY8hjs79r9SCvk_GJEpn ziRyd2)a$XOv*`^*d41Td^LoXUE^Fv`XDQZC|rx12UA+<0U=u$G6TGN(-;wI^RbB_MSUmRqg0h z=c3+CmK_lFVZ7|Grk$LqbBt8KdKu5(-nFXwn9qM*D+_AJV^QBD9ppT?K7Mzo&L1~% z;gFr*Y?YU8tGtYscHmeU-?ZPOx!tc1;AiUJ@^)(C2e($jP z{qI}13$cHj^iF^G;l2;$RgEK19@?nK>Gj^P&GO#kb&poN5XX`DKd#+h&hE#~=d``t z{`)NYfE2^{7Xwu9_Rhu z7Z}t}9V_K11lFpbdTB8)p^3l7>JcN=QykOpr5t_g=)rx}m{bqap|Uy-)`!BTzO2pf z6gJ28!N332Uyt=u^Ka5eeA68N9?r3MFV4v=&B#a|5phkVaU8z#W?c+c@%s=TDaXI6 zfb_jK>hUk%SJCuSPy9*WO0V!Sw(4f7Co5#BM@2t%^e+bG$m64(AN%uL=iBJt_Q$`Y z^~&z4-J>qWnQ=2$6FK>|Z;C{x>i$`m^ zTq%eYl}dFx`{f5&oC#g|cD?ruLd%1gB0%ZHo(%Ai6riGJ$!a(<&z;`?Shb9=O2{nW9) z&INb1{nQ!&>!I5ByzTP7{BP)g$y%PtKm2o_Q(uLh7gY z_luu9y8p3sWNO-(+@4O&&(F?RSG3#bx2lIFPp98&XyTQPaXmH1E^#e4{c1$Dq1Efe z-`2@0`l+v!00{%TPu{KlQDD9F!xEk9Pj*KioRs zM*ntDoYbxl`w#AY{?L(JANuQ0{lxHUBb)!Hdw0(pYnJyO`yJJePJNl^&rFsb5cOf9 z?61RRd{d7%+V#{6;?a5^t#LoKRbINS@^Y-S1ILz?m%H0fjebkCR%02 zd*0Ol_36(*st{-THkgdjw8`e 
zef=BtN^~xl_`WH>>)TJg`t7_PtL>-0{VzFx#4mZs`(_as)K8r&ALO=Ch>8Ix3agG2yjzu$}kR?^i!pYOZnxHlZn=Znp3Ur&AOpW~aA(<&UT zUkCne9brqTVeZzSHZe-z(+klTG2NJ}Ez$ zBiez9D|vXKC^hHX-djI)P6sMqQ#i90^K?jkHRY5X1NXQVJ96CftC-3s!yR@-FZ+YMQ?%ntO-|zq3>tFZ9 zz14~?r7s9(v;=5=#H#-gq2nc&b%8#6hy`r9$<|wJ&G+R+wl_>G0*wF$U1|o2;G|Xd ziU_0|VN4V6!3;)F!&_UrH-gZL2gV-X03(D&4L*5h7t&Mo7_VGl`%0ANX)!@apY1!x z_LXtd`U*QTMtHc^EG9)**j_`b0-v`pVW+4hsxq|V+KH2;nI$mZR#E+wX$7DU zK)ps_d6s}~eMiP9ym@jJ7I!mDK%gg})>u Date: Tue, 10 Dec 2019 12:17:49 +0100 Subject: [PATCH 2/2] fix #832 : implemented CheckedSession, CheckedParameters and CheckedArray classes (based on preliminary code written by gdementen) --- README.rst | 6 + doc/source/api.rst | 25 +- doc/source/changes/version_0_33.rst.inc | 31 +- environment.yml | 3 +- larray/__init__.py | 3 + larray/core/axis.py | 4 +- larray/core/checked.py | 537 +++++++++++++++++++ larray/core/group.py | 6 +- larray/core/session.py | 17 +- larray/tests/common.py | 14 + larray/tests/data/test_session.h5 | Bin 2144424 -> 2144784 bytes larray/tests/test_checked_session.py | 670 ++++++++++++++++++++++++ larray/tests/test_session.py | 290 +++++----- make_release.py | 4 +- setup.py | 2 - 15 files changed, 1462 insertions(+), 150 deletions(-) create mode 100644 larray/core/checked.py create mode 100644 larray/tests/test_checked_session.py diff --git a/README.rst b/README.rst index f6243dc6c..ff32c8881 100644 --- a/README.rst +++ b/README.rst @@ -129,6 +129,12 @@ For plotting - `matplotlib `__: required for plotting. +Miscellaneous +~~~~~~~~~~~~~ + +- `pydantic `__: + required to use `CheckedSession`. + .. _start-documentation: Documentation diff --git a/doc/source/api.rst b/doc/source/api.rst index 029535af8..47289b1bf 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -795,7 +795,6 @@ Modifying Session.add Session.update - Session.get Session.apply Session.transpose @@ -821,6 +820,30 @@ Load/Save Session.to_hdf Session.to_pickle +CheckedArray +============ + +.. 
autosummary:: + :toctree: _generated/ + + CheckedArray + +CheckedSession +============== + +.. autosummary:: + :toctree: _generated/ + + CheckedSession + +CheckedParameters +================= + +.. autosummary:: + :toctree: _generated/ + + CheckedParameters + .. _api-editor: Editor diff --git a/doc/source/changes/version_0_33.rst.inc b/doc/source/changes/version_0_33.rst.inc index c288ba1f1..03c38d856 100644 --- a/doc/source/changes/version_0_33.rst.inc +++ b/doc/source/changes/version_0_33.rst.inc @@ -20,30 +20,21 @@ New features * added official support for Python 3.9 (0.32.3 already supports it even though it was not mentioned). -* added a feature (see the :ref:`miscellaneous section ` for details). It works on :ref:`api-axis` and - :ref:`api-group` objects. +* added :py:obj:`CheckedSession`, :py:obj:`CheckedParameters` and :py:obj:`CheckedArray` objects. - Here is an example of the new feature: + `CheckedSession` is intended to be inherited by user defined classes in which the variables of a model + are declared. By declaring variables, users will speed up the development of their models using the auto-completion + (the feature in which development tools like PyCharm try to predict the variable or function a user intends + to enter after only a few characters have been typed). All user defined classes inheriting from `CheckedSession` + will have access to the same methods as `Session` objects. - >>> arr = ndtest((2, 3)) - >>> arr - a\b b0 b1 b2 - a0 0 1 2 - a1 3 4 5 + `CheckedParameters` is the same as `CheckedSession` but the declared variables cannot be + modified after initialization. - And it can also be used like this: + The special :py:func:`CheckedArray` type represents an Array object with fixed axes and/or dtype. + It is intended to be only used along with :py:class:`CheckedSession`. - >>> arr = ndtest("a=a0..a2") - >>> arr - a a0 a1 a2 - 0 1 2 - -* added another feature in the editor (closes :editor_issue:`1`). - - .. 
note:: - - - It works for foo bar ! - - It does not work for foo baz ! + Closes :issue:`832`. .. _misc: diff --git a/environment.yml b/environment.yml index 8b42a3e4d..40bbf5b8c 100644 --- a/environment.yml +++ b/environment.yml @@ -12,4 +12,5 @@ dependencies: - pytest>=3.5 - flake8 - pip: - - pytest-flake8 \ No newline at end of file + - pytest-flake8 + - pydantic==1.5 \ No newline at end of file diff --git a/larray/__init__.py b/larray/__init__.py index bc6402e3f..8cd0eaf94 100644 --- a/larray/__init__.py +++ b/larray/__init__.py @@ -8,6 +8,7 @@ eye, all, any, sum, prod, cumsum, cumprod, min, max, mean, ptp, var, std, median, percentile, stack, zip_array_values, zip_array_items) from larray.core.session import Session, local_arrays, global_arrays, arrays +from larray.core.checked import CheckedArray, CheckedSession, CheckedParameters from larray.core.constants import nan, inf, pi, e, euler_gamma from larray.core.metadata import Metadata from larray.core.ufuncs import wrap_elementwise_array_func, maximum, minimum, where @@ -55,6 +56,8 @@ 'median', 'percentile', 'stack', 'zip_array_values', 'zip_array_items', # session 'Session', 'local_arrays', 'global_arrays', 'arrays', + # constrained + 'CheckedArray', 'CheckedSession', 'CheckedParameters', # constants 'nan', 'inf', 'pi', 'e', 'euler_gamma', # metadata diff --git a/larray/core/axis.py b/larray/core/axis.py index e498b35d4..5ce1008c3 100644 --- a/larray/core/axis.py +++ b/larray/core/axis.py @@ -839,7 +839,7 @@ def __getitem__(self, key): ----- key is label-based (slice and fancy indexing are supported) """ - # if isinstance(key, basestring): + # if isinstance(key, str): # key = to_keys(key) def isscalar(k): @@ -862,7 +862,7 @@ def isscalar(k): and key.name in self ): return LGroup(key.name, None, self) - # elif isinstance(key, basestring) and key in self: + # elif isinstance(key, str) and key in self: # TODO: this is an awful workaround to avoid the "processing" of string keys which exist as is in the axis # 
(probably because the string was used in an aggregate function to create the label) # res = LGroup(slice(None), None, self) diff --git a/larray/core/checked.py b/larray/core/checked.py new file mode 100644 index 000000000..71dd6ccb6 --- /dev/null +++ b/larray/core/checked.py @@ -0,0 +1,537 @@ +from abc import ABCMeta +from copy import deepcopy +import warnings + +import numpy as np + +from typing import TYPE_CHECKING, Type, Any, Dict, Set, List, no_type_check + +from larray.core.metadata import Metadata +from larray.core.axis import AxisCollection +from larray.core.group import Group +from larray.core.array import Array, full +from larray.core.session import Session + + +class NotLoaded: + pass + + +try: + import pydantic +except ImportError: + pydantic = None + +# moved the not implemented versions of Checked* classes in the beginning of the module +# otherwise PyCharm do not provide auto-completion for methods of CheckedSession +# (imported from Session) +if not pydantic: + def CheckedArray(axes: AxisCollection, dtype: np.dtype = float) -> Type[Array]: + raise NotImplementedError("CheckedArray cannot be used because pydantic is not installed") + + class CheckedSession: + def __init__(self, *args, **kwargs): + raise NotImplementedError("CheckedSession class cannot be instantiated " + "because pydantic is not installed") + + class CheckedParameters: + def __init__(self, *args, **kwargs): + raise NotImplementedError("CheckedParameters class cannot be instantiated " + "because pydantic is not installed") +else: + from pydantic.fields import ModelField + from pydantic.class_validators import Validator + from pydantic.main import BaseConfig + + # the implementation of the class below is inspired by the 'ConstrainedBytes' class + # from the types.py module of the 'pydantic' library + class CheckedArrayImpl(Array): + expected_axes: AxisCollection + dtype: np.dtype = np.dtype(float) + + # see https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__ 
+ @classmethod + def __get_validators__(cls): + # one or more validators may be yielded which will be called in the + # order to validate the input, each validator will receive as an input + # the value returned from the previous validator + yield cls.validate + + @classmethod + def validate(cls, value, field: ModelField): + if not (isinstance(value, Array) or np.isscalar(value)): + raise TypeError(f"Expected object of type '{Array.__name__}' or a scalar for " + f"the variable '{field.name}' but got object of type '{type(value).__name__}'") + + # check axes + if isinstance(value, Array): + error_msg = f"Array '{field.name}' was declared with axes {cls.expected_axes} but got array " \ + f"with axes {value.axes}" + # check for extra axes + extra_axes = value.axes - cls.expected_axes + if extra_axes: + raise ValueError(f"{error_msg} (unexpected {extra_axes} " + f"{'axes' if len(extra_axes) > 1 else 'axis'})") + # check compatible axes + try: + cls.expected_axes.check_compatible(value.axes) + except ValueError as error: + error_msg = str(error).replace("incompatible axes", f"Incompatible axis for array '{field.name}'") + raise ValueError(error_msg) + # broadcast + transpose if needed + value = value.expand(cls.expected_axes) + # check dtype + if value.dtype != cls.dtype: + value = value.astype(cls.dtype) + return value + else: + return full(axes=cls.expected_axes, fill_value=value, dtype=cls.dtype) + + # the implementation of the function below is inspired by the 'conbytes' function + # from the types.py module of the 'pydantic' library + + def CheckedArray(axes: AxisCollection, dtype: np.dtype = float) -> Type[Array]: + # XXX: for a very weird reason I don't know, I have to put the fake import below + # to get autocompletion from PyCharm + from larray.core.checked import CheckedArrayImpl + """ + Represents a constrained array. It is intended to only be used along with :py:class:`CheckedSession`. 
+ + Its axes are assumed to be "frozen", meaning they are constant all along the execution of the program. + A constraint on the dtype of the data can be also specified. + + Parameters + ---------- + axes: AxisCollection + Axes of the checked array. + dtype: data-type, optional + Data-type for the checked array. Defaults to float. + + Returns + ------- + Array + Constrained array. + """ + if axes is not None and not isinstance(axes, AxisCollection): + axes = AxisCollection(axes) + _dtype = np.dtype(dtype) + + class ArrayDefValue(CheckedArrayImpl): + expected_axes = axes + dtype = _dtype + + return ArrayDefValue + + class AbstractCheckedSession: + pass + + # Simplified version of the ModelMetaclass class from pydantic: + # https://github.com/samuelcolvin/pydantic/blob/master/pydantic/main.py#L195 + + class ModelMetaclass(ABCMeta): + @no_type_check # noqa C901 + def __new__(mcs, name, bases, namespace, **kwargs): + from pydantic.fields import Undefined + from pydantic.class_validators import extract_validators, inherit_validators + from pydantic.types import PyObject + from pydantic.typing import is_classvar, resolve_annotations + from pydantic.utils import lenient_issubclass, validate_field_name + from pydantic.main import inherit_config, prepare_config, UNTOUCHED_TYPES + + fields: Dict[str, ModelField] = {} + config = BaseConfig + validators: Dict[str, List[Validator]] = {} + + for base in reversed(bases): + if issubclass(base, AbstractCheckedSession) and base != AbstractCheckedSession: + config = inherit_config(base.__config__, config) + fields.update(deepcopy(base.__fields__)) + validators = inherit_validators(base.__validators__, validators) + + config = inherit_config(namespace.get('Config'), config) + validators = inherit_validators(extract_validators(namespace), validators) + + # update fields inherited from base classes + for field in fields.values(): + field.set_config(config) + extra_validators = validators.get(field.name, []) + if extra_validators: + 
field.class_validators.update(extra_validators) + # re-run prepare to add extra validators + field.populate_validators() + + prepare_config(config, name) + + # extract and build fields + class_vars = set() + if (namespace.get('__module__'), namespace.get('__qualname__')) != \ + ('larray.core.checked', 'CheckedSession'): + untouched_types = UNTOUCHED_TYPES + config.keep_untouched + + # annotation only fields need to come first in fields + annotations = resolve_annotations(namespace.get('__annotations__', {}), + namespace.get('__module__', None)) + for ann_name, ann_type in annotations.items(): + if is_classvar(ann_type): + class_vars.add(ann_name) + elif not ann_name.startswith('_'): + validate_field_name(bases, ann_name) + value = namespace.get(ann_name, Undefined) + if (isinstance(value, untouched_types) and ann_type != PyObject + and not lenient_issubclass(getattr(ann_type, '__origin__', None), Type)): + continue + fields[ann_name] = ModelField.infer(name=ann_name, value=value, annotation=ann_type, + class_validators=validators.get(ann_name, []), + config=config) + + for var_name, value in namespace.items(): + # 'var_name not in annotations' because namespace.items() contains annotated fields + # with default values + # 'var_name not in class_vars' to avoid to update a field if it was redeclared (by mistake) + if (var_name not in annotations and not var_name.startswith('_') + and not isinstance(value, untouched_types) and var_name not in class_vars): + validate_field_name(bases, var_name) + # the method ModelField.infer() fails to infer the type of Group objects + # (which are interpreted as ndarray objects) + annotation = type(value) if isinstance(value, Group) else annotations.get(var_name) + inferred = ModelField.infer(name=var_name, value=value, annotation=annotation, + class_validators=validators.get(var_name, []), config=config) + if var_name in fields and inferred.type_ != fields[var_name].type_: + raise TypeError(f'The type of {name}.{var_name} differs 
from the new default value; ' + f'if you wish to change the type of this field, please use a type ' + f'annotation') + fields[var_name] = inferred + + new_namespace = { + '__config__': config, + '__fields__': fields, + '__field_defaults__': {n: f.default for n, f in fields.items() if not f.required}, + '__validators__': validators, + **{n: v for n, v in namespace.items() if n not in fields}, + } + return super().__new__(mcs, name, bases, new_namespace, **kwargs) + + class CheckedSession(Session, AbstractCheckedSession, metaclass=ModelMetaclass): + """ + This class is intended to be inherited by user defined classes in which the variables of a model are declared. + Each declared variable is constrained by a type defined explicitly or deduced from the given default value + (see examples below). + All classes inheriting from `CheckedSession` will have access to all methods of the :py:class:`Session` class. + + The special :py:func:`CheckedArray` type represents an Array object with fixed axes and/or dtype. + This prevents users from modifying the dimensions (and labels) and/or the dtype of an array by mistake + and makes sure that the definition of an array always remains valid in the model. + + By declaring variables, users will speed up the development of their models using the auto-completion + (the feature in which development tools like PyCharm try to predict the variable or function a user intends + to enter after only a few characters have been typed). + + As for normal Session objects, it is still possible to add undeclared variables to instances of + classes inheriting from `CheckedSession` but this must be done with caution. + + Parameters + ---------- + *args : str or dict of {str: object} or iterable of tuples (str, object) + Path to the file containing the session to load or + list/tuple/dictionary containing couples (name, object). 
+ **kwargs : dict of {str: object} + + * Objects to add written as name=object + * meta : list of pairs or dict or OrderedDict or Metadata, optional + Metadata (title, description, author, creation_date, ...) associated with the array. + Keys must be strings. Values must be of type string, int, float, date, time or datetime. + + Warnings + -------- + The :py:meth:`CheckedSession.filter`, :py:meth:`CheckedSession.compact` + and :py:meth:`CheckedSession.apply` methods return a simple Session object. + The type of the declared variables (and the value for the declared constants) will + no longer be checked. + + See Also + -------- + Session, CheckedParameters + + Examples + -------- + + Content of file 'parameters.py' + + >>> from larray import * + >>> FIRST_YEAR = 2020 + >>> LAST_YEAR = 2030 + >>> AGE = Axis('age=0..10') + >>> GENDER = Axis('gender=male,female') + >>> TIME = Axis(f'time={FIRST_YEAR}..{LAST_YEAR}') + + Content of file 'model.py' + + >>> class ModelVariables(CheckedSession): + ... # --- declare variables with defined types --- + ... # Their values will be defined at runtime but must match the specified type. + ... birth_rate: Array + ... births: Array + ... # --- declare variables with a default value --- + ... # The default value will be used to set the variable if no value is passed at instantiation (see below). + ... # Their type is deduced from their default value and cannot be changed at runtime. + ... target_age = AGE[:2] >> '0-2' + ... population = zeros((AGE, GENDER, TIME), dtype=int) + ... # --- declare checked arrays --- + ... # The checked arrays have axes assumed to be "frozen", meaning they are + ... # constant all along the execution of the program. + ... mortality_rate: CheckedArray((AGE, GENDER)) + ... # For checked arrays, the default value can be given as a scalar. + ... # Optionally, a dtype can be also specified (defaults to float). + ... 
deaths: CheckedArray((AGE, GENDER, TIME), dtype=int) = 0 + + >>> variant_name = "baseline" + >>> # Instantiation --> create an instance of the ModelVariables class. + >>> # Warning: All variables declared without a default value must be set. + >>> m = ModelVariables(birth_rate = zeros((AGE, GENDER)), + ... births = zeros((AGE, GENDER, TIME), dtype=int), + ... mortality_rate = 0) + + >>> # ==== model ==== + >>> # In the definition of ModelVariables, the 'birth_rate' variable, has been declared as an Array object. + >>> # This means that the 'birth_rate' variable will always remain of type Array. + >>> # Any attempt to assign a non-Array value to 'birth_rate' will make the program to crash. + >>> m.birth_rate = Array([0.045, 0.055], GENDER) # OK + >>> m.birth_rate = [0.045, 0.055] # Fails + Traceback (most recent call last): + ... + pydantic.errors.ArbitraryTypeError: instance of Array expected + >>> # However, the arrays 'birth_rate', 'births' and 'population' have not been declared as 'CheckedArray'. + >>> # Thus, axes and dtype of these arrays are not protected, leading to potentially unexpected behavior + >>> # of the model. + >>> # example 1: Let's say we want to calculate the new births for the year 2025 and we assume that + >>> # the birth rate only differ by gender. + >>> # In the line below, we add an additional TIME axis to 'birth_rate' while it was initialized + >>> # with the AGE and GENDER axes only + >>> m.birth_rate = full((AGE, GENDER, TIME), fill_value=Array([0.045, 0.055], GENDER)) + >>> # here 'new_births' have the AGE, GENDER and TIME axes instead of the AGE and GENDER axes only + >>> new_births = m.population['female', 2025] * m.birth_rate + >>> print(new_births.info) + 11 x 2 x 11 + age [11]: 0 1 2 ... 8 9 10 + gender [2]: 'male' 'female' + time [11]: 2020 2021 2022 ... 
2028 2029 2030 + dtype: float64 + memory used: 1.89 Kb + >>> # and the line below will crash + >>> m.births[2025] = new_births # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Value {time} axis is not present in target subset {age, gender}. + A value can only have the same axes or fewer axes than the subset being targeted + >>> # now let's try to do the same for deaths and making the same mistake as for 'birth_rate'. + >>> # The program will crash now at the first step instead of letting you go further + >>> m.mortality_rate = full((AGE, GENDER, TIME), fill_value=sequence(AGE, inc=0.02)) \ + # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Array 'mortality_rate' was declared with axes {age, gender} but got array with axes + {age, gender, time} (unexpected {time} axis) + + >>> # example 2: let's say we want to calculate the new births for all years. + >>> m.birth_rate = full((AGE, GENDER, TIME), fill_value=Array([0.045, 0.055], GENDER)) + >>> new_births = m.population['female'] * m.birth_rate + >>> # here 'new_births' has the same axes as 'births' but is a float array instead of + >>> # an integer array as 'births'. + >>> # The line below will make the 'births' array become a float array while + >>> # it was initialized as an integer array + >>> m.births = new_births + >>> print(m.births.info) + 11 x 11 x 2 + age [11]: 0 1 2 ... 8 9 10 + time [11]: 2020 2021 2022 ... 2028 2029 2030 + gender [2]: 'male' 'female' + dtype: float64 + memory used: 1.89 Kb + >>> # now let's try to do the same for deaths. + >>> m.mortality_rate = full((AGE, GENDER), fill_value=sequence(AGE, inc=0.02)) + >>> # here the result of the multiplication of the 'population' array by the 'mortality_rate' array + >>> # is automatically converted to an integer array + >>> m.deaths = m.population * m.mortality_rate + >>> print(m.deaths.info) # doctest: +SKIP + 11 x 2 x 11 + age [11]: 0 1 2 ... 
8 9 10 + gender [2]: 'male' 'female' + time [11]: 2020 2021 2022 ... 2028 2029 2030 + dtype: int32 + memory used: 968 bytes + + >>> # note that it still possible to add undeclared variables to a checked session + >>> # but this must be done with caution. + >>> m.undeclared_var = 'undeclared_var' + + >>> # ==== output ==== + >>> # save all variables in an HDF5 file + >>> m.save(f'{variant_name}.h5', display=True) + dumping birth_rate ... done + dumping births ... done + dumping mortality_rate ... done + dumping deaths ... done + dumping target_age ... done + dumping population ... done + dumping undeclared_var ... done + """ + if TYPE_CHECKING: + # populated by the metaclass, defined here to help IDEs only + __fields__: Dict[str, ModelField] = {} + __field_defaults__: Dict[str, Any] = {} + __validators__: Dict[str, List[Validator]] = {} + __config__: Type[BaseConfig] = BaseConfig + + class Config: + # whether to allow arbitrary user types for fields (they are validated simply by checking + # if the value is an instance of the type). If False, RuntimeError will be raised on model declaration. + # (default: False) + arbitrary_types_allowed = True + # whether to validate field defaults + validate_all = True + # whether to ignore, allow, or forbid extra attributes during model initialization (and after). + # Accepts the string values of 'ignore', 'allow', or 'forbid', or values of the Extra enum + # (default: Extra.ignore) + extra = 'allow' + # whether to perform validation on assignment to attributes + validate_assignment = True + # whether or not models are faux-immutable, i.e. whether __setattr__ is allowed. + # (default: True) + allow_mutation = True + + # Warning: order of fields is not preserved. + # As of v1.0 of pydantic all fields with annotations (whether annotation-only or with a default value) + # will precede all fields without an annotation. Within their respective groups, fields remain in the + # order they were defined. 
+ # See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering + def __init__(self, *args, **kwargs): + meta = kwargs.pop('meta', Metadata()) + Session.__init__(self, meta=meta) + + # create an intermediate Session object to not call the __setattr__ + # and __setitem__ overridden in the present class and in case a filepath + # is given as only argument + # todo: refactor Session.load() to use a private function which returns the handler directly + # so that we can get the items out of it and avoid this + input_data = dict(Session(*args, **kwargs)) + + # --- declared variables + for name, field in self.__fields__.items(): + value = input_data.pop(field.name, NotLoaded()) + + if isinstance(value, NotLoaded): + if field.default is None: + warnings.warn(f"No value passed for the declared variable '{field.name}'", stacklevel=2) + self.__setattr__(name, value, skip_allow_mutation=True, skip_validation=True) + else: + self.__setattr__(name, field.default, skip_allow_mutation=True) + else: + self.__setattr__(name, value, skip_allow_mutation=True) + + # --- undeclared variables + for name, value in input_data.items(): + self.__setattr__(name, value, skip_allow_mutation=True) + + # code of the method below has been partly borrowed from pydantic.BaseModel.__setattr__() + def _check_key_value(self, name: str, value: Any, skip_allow_mutation: bool, skip_validation: bool) -> Any: + config = self.__config__ + if not config.extra and name not in self.__fields__: + raise ValueError(f"Variable '{name}' is not declared in '{self.__class__.__name__}'. " + f"Adding undeclared variables is forbidden. 
" + f"List of declared variables is: {list(self.__fields__.keys())}.") + if not skip_allow_mutation and not config.allow_mutation: + raise TypeError(f"Cannot change the value of the variable '{name}' since '{self.__class__.__name__}' " + f"is immutable and does not support item assignment") + known_field = self.__fields__.get(name, None) + if known_field: + if not skip_validation: + value, error_ = known_field.validate(value, self.dict(exclude={name}), loc=name, cls=self.__class__) + if error_: + raise error_.exc + else: + warnings.warn(f"'{name}' is not declared in '{self.__class__.__name__}'", stacklevel=3) + return value + + def __setitem__(self, key, value, skip_allow_mutation=False, skip_validation=False): + if key != 'meta': + value = self._check_key_value(key, value, skip_allow_mutation, skip_validation) + # we need to keep the attribute in sync + object.__setattr__(self, key, value) + self._objects[key] = value + + def __setattr__(self, key, value, skip_allow_mutation=False, skip_validation=False): + if key != 'meta': + value = self._check_key_value(key, value, skip_allow_mutation, skip_validation) + # we need to keep the attribute in sync + object.__setattr__(self, key, value) + Session.__setattr__(self, key, value) + + def __getstate__(self) -> Dict[str, Any]: + return {'__dict__': self.__dict__} + + def __setstate__(self, state: Dict[str, Any]) -> None: + object.__setattr__(self, '__dict__', state['__dict__']) + + def dict(self, exclude: Set[str] = None): + d = dict(self.items()) + for name in exclude: + if name in d: + del d[name] + return d + + class CheckedParameters(CheckedSession): + """ + Same as py:class:`CheckedSession` but declared variables cannot be modified after initialization. + + Parameters + ---------- + *args : str or dict of {str: object} or iterable of tuples (str, object) + Path to the file containing the session to load or + list/tuple/dictionary containing couples (name, object). 
+ **kwargs : dict of {str: object} + + * Objects to add written as name=object + * meta : list of pairs or dict or OrderedDict or Metadata, optional + Metadata (title, description, author, creation_date, ...) associated with the array. + Keys must be strings. Values must be of type string, int, float, date, time or datetime. + + See Also + -------- + CheckedSession + + Examples + -------- + + Content of file 'parameters.py' + + >>> from larray import * + >>> class Parameters(CheckedParameters): + ... # --- declare variables with fixed values --- + ... # The given values can never be changed + ... FIRST_YEAR = 2020 + ... LAST_YEAR = 2030 + ... AGE = Axis('age=0..10') + ... GENDER = Axis('gender=male,female') + ... TIME = Axis(f'time={FIRST_YEAR}..{LAST_YEAR}') + ... # --- declare variables with defined types --- + ... # Their values must be defined at initialized and will be frozen after. + ... variant_name: str + + Content of file 'model.py' + + >>> # instantiation --> create an instance of the ModelVariables class + >>> # all variables declared without value must be set + >>> P = Parameters(variant_name='variant_1') + >>> # once an instance is created, its variables can be accessed but not modified + >>> P.variant_name + 'variant_1' + >>> P.variant_name = 'new_variant' # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + TypeError: Cannot change the value of the variable 'variant_name' since 'Parameters' + is immutable and does not support item assignment + """ + class Config: + # whether or not models are faux-immutable, i.e. whether __setattr__ is allowed. 
+ # (default: True) + allow_mutation = False diff --git a/larray/core/group.py b/larray/core/group.py index 00ddbacfa..f885a2653 100644 --- a/larray/core/group.py +++ b/larray/core/group.py @@ -452,7 +452,7 @@ def _seq_str_to_seq(s, stack_depth=1, parse_single_int=False): Parameters ---------- - s : basestring + s : str string to parse Returns @@ -496,7 +496,7 @@ def _to_key(v, stack_depth=1, parse_single_int=False): Parameters ---------- - v : int or basestring or tuple or list or slice or Array or Group + v : int or str or tuple or list or slice or Array or Group value to convert into a key usable for indexing Returns @@ -598,7 +598,7 @@ def _to_keys(value, stack_depth=1): Parameters ---------- - value : int or basestring or tuple or list or slice or Array or Group + value : int or str or tuple or list or slice or Array or Group (collection of) value(s) to convert into key(s) usable for indexing Returns diff --git a/larray/core/session.py b/larray/core/session.py index a3d34d817..d753469c0 100644 --- a/larray/core/session.py +++ b/larray/core/session.py @@ -85,6 +85,7 @@ def __init__(self, *args, **kwargs): self.meta = meta if len(args) == 1: + assert len(kwargs) == 0 a0 = args[0] if isinstance(a0, str): # assume a0 is a filename @@ -915,7 +916,7 @@ def copy(self): r"""Returns a copy of the session. """ # this actually *does* a copy of the internal mapping (the mapping is not reused-as is) - return Session(self._objects) + return self.__class__(self._objects) def keys(self): r""" @@ -1042,7 +1043,12 @@ def opmethod(self, other): except Exception: res_item = nan res.append((name, res_item)) - return Session(res) + try: + # XXX: print a warning? + ses = self.__class__(res) + except Exception: + ses = Session(res) + return ses opmethod.__name__ = opfullname return opmethod @@ -1072,7 +1078,12 @@ def opmethod(self): except Exception: res_array = nan res.append((k, res_array)) - return Session(res) + try: + # XXX: print a warning? 
+ ses = self.__class__(res) + except Exception: + ses = Session(res) + return ses opmethod.__name__ = opfullname return opmethod diff --git a/larray/tests/common.py b/larray/tests/common.py index b9b11ce43..0b0131b1f 100644 --- a/larray/tests/common.py +++ b/larray/tests/common.py @@ -174,3 +174,17 @@ def must_warn(warn_cls=None, msg=None, match=None, check_file=True, check_num=Tr warning_path = caught_warnings[0].filename assert warning_path == caller_path, \ f"{warning_path} != {caller_path}" + + +@contextmanager +def must_raise(warn_cls=None, msg=None, match=None): + if msg is not None and match is not None: + raise ValueError("bad test: can't use both msg and match arguments") + elif msg is not None: + match = re.escape(msg) + + try: + with pytest.raises(warn_cls, match=match) as error: + yield error + finally: + pass diff --git a/larray/tests/data/test_session.h5 b/larray/tests/data/test_session.h5 index 028dec515689444d8cb9d19b67aea0c41500271a..f27b6bf3a98bdf16b3304f22c1e876ed3394fb5e 100644 GIT binary patch delta 3024 zcmZuzZ){Ul6o2n-!0WYak9N!skUhhn6~(68_)|o2B9Tc3sgo1}%38M8R=f7u7Bfs_ zMfo!jwdEoy3y~rS6J@r0QRYIlL}C{G(hnpq#+b&$WFHKfi7|u^;(7PJ*Vit-=GS-6 zz2}~De&?Lq+mXNC;ZuJD>`0tQKjxC=aXiOyHh}tw@KGPBhQ7E?M1B-jBD;{k=BXJp zb6!;$c9=K;KK=fLug{n8F`8M)07#278`WElLOC{yr%XjY+D0DDoNgX(sN!Vg=7xruqMMCjJ?4T44NH2xL;e*mH zU->;~|v?Otj+XH#k(Az+IKwaIfr(^!ehB zo^WV?+>wa0SEwI?^YzPdyi*^7c00Crt#h@lUFS5OF1%w$LT+7yAsax%K_1jJnPd~o zYc_qj1r^Q8w%SI7iu1_+xsk!(9JZ!sI;F|Ed2C*<=fES?dRa{_G_!d-dvODUv-C8% zaG1?IH9hGAI_NEzxRXtA!4%VQX^NBQ*+36)Mw7)0Y`#O&lRvN4%V_fW^>p4tic1A- zP$P?kE67$yPBuH{PSHL+y~&m@{LYgZY(oQPf**i>k^lEa_!RGbMOzy7Mo6%%( zFPoRxi!1G~))RV}Z1=Obpy`R24y@(OxYz`@+nml{A<`ZbzeJfZ|HdO}5gD$tBgsN; z*GaX1bk*-8xfK~SS)a}?CE0ob35)ajMqVk1s}b_GlaknQ+G<}=5B62k3jh{?Ras$u zdybN95qMH)+HC4nC*f%zKh-r+vzKPp!Ex?&%&1=V0Zq#x?cYjS6l z&Fe{>lB)H%N0}@wV{iJMzWF`LnFW*kaW*-vTzukZ9!=Y6nqRh%Q!6UcoLZ@;xwtBy zuj3(s+jx{PwA8F&!x6c^>MpwHpcxo%RXtQ@AnPF8)(4Pxtup)sl$L9!K2I6eXlX~B 
zy3gr$)h9`|%#fkBZgkl~er;pHQMX3361?UgByON?iyLRV9e1g2O{HBr<6GwMz-rot zY){h&u35;<4Hc{lPwEkHo74y{vtxr>ZF_2PW@Gt7$myDgGpT?6Y$Z_N_2<>@OUEq?LQc$Wl0@k37uBTWtohV;q)rruuJCWlbcc4Sb}YCo_H4v zq~nx;v*r7T+#fzrZaF})$3Ba zQdO6~D?*}mDU8s%On=&qY8%QfD2sd7XkCg|H>2goQgvO7Yn(E2j#hS>JS^})M_V4`)|b6l3PvBnmpP@cX=aapZSL1 zuSd`oBSL2H&l_FX_)`L`19P)rH@WrQ(edM><2(QXY5=u>Ie@un=m0${*H05f1d xpzg-#c>TW*P;x5>B47c)03`;1R%5Fw%Vg!GA&(W&;2K delta 2617 zcmZuzZ)_7~7{7bhZP(jm*Rd>R!>tE%>u_KTBNW6j7egkWDj#M+lsce}3QJLA8iFv# zpADSS`XH1EI-EdR7VPaN)W%3Z_@N=h)K3yZOh1`rA+aQcEHP1^d*8cWJH6(~d(YeF zp6B=bJw3LfYZf{oV@Z7m6=uz7glqeP0vX zcd~_(jxLJRecuqn zy6>HiDt!>L3Zh93wmGhmL^Z4u*N{++x>KwtM5=;KEB79==pu7mc2#M@1>Lu=X8VRR z{?|+6Ms+{FgY7%m!iC>1jhoQ@^d7b^Xv?!`aeBrTl*#ITUW_W&AlD*T9gXagy!yE1 z4i7^!l4uM^(hhf9og^tihvcyAzmtCxyi$z#mqe8&$hHY&Vws}U^UzUeSG#NL2uv@v z!(d&D*zf9h_pj|=*WamL*>b8}bu?__Nlt&d*C@{YKshHK9Kf%HHu{L1^ zL&at++boc96Fg`Rh~}(N;lnGb@SavC{CydP9}LcQva5|gOY52M+Q3#r?%2DfacSMp z9c25uSX`*NG%lw5Nja+gqNzyc+SU1%-+40BMkRY9LAAR`GUs!KWCKn`S;C_t0uTB2 z9#U;9L{e4E)Tqh|IPv^F5?M_-dR>g8XEltk30N1_F-%(De&b2ZO$)b}>1eOIUoZ*} ztx6ZdHr0oc)cy6jJv=#VgXKwcENd3s zas?GUdnt;A07H@6Lsf_e)v0}`LiYeIdBO%qZWO67x6df)+;8?h5$qw~ntD<37u~_| zb-NpU2hzkhNXyy}v9jDCE$jSQemC zqUh(w5&a-!M|#LEv#y#u@8iLHNf6DK8KvR)s<;>SD~v|wGNsWw$$Alb$4>m;&|Vy< zEgDJWB1{-uVv`1!#C3y9=4x>-!dGc8#)@DSdf~ec^VgE(l9`^eiQIs@@dWwN+-A)% zU$vA)X1bU~`a8xVG|eVTQ8(Ru%G(5m1?AboabW9olZUm18y1Is6da13ByJj{dS?t$ zzFP*V(5yx(G_yb)vk+9Z2Gv4Cr_=gJEz4{^2vyfCHjBHZHOr=d8{yXwJIbKF@n!2H z{i4n0(^DHlR}$OU2btU3iF0WIq1v6qNi@G7VCsh;v2TT@rS->MsS#I-MG4 fV@erQ%9v8dlrpB2F{O+tWlSk!N;y^Yr OK + cs['h'] = zeros_like(h) + + # trying to add an undeclared variable -> prints a warning message + with must_warn(UserWarning, msg=f"'i' is not declared in '{cs.__class__.__name__}'"): + cs['i'] = ndtest((3, 3)) + + # trying to set a variable with an object of different type -> should fail + # a) type given explicitly + # -> Axis + with must_raise(TypeError, msg="instance of Axis expected"): + cs['a'] = 0 + # -> CheckedArray + with 
must_raise(TypeError, msg="Expected object of type 'Array' or a scalar for the variable 'h' but got " + "object of type 'ndarray'"): + cs['h'] = h.data + # b) type deduced from the given default value + with must_raise(TypeError, msg="instance of Axis expected"): + cs['b'] = ndtest((3, 3)) + + # trying to set a CheckedArray variable using a scalar -> OK + cs['h'] = 5 + + # trying to set a CheckedArray variable using an array with axes in different order -> OK + cs['h'] = h.transpose() + assert cs.h.axes.names == h.axes.names + + # broadcasting (missing axis) is allowed + cs['h'] = ndtest(a3) + assert_array_nan_equal(cs['h']['b0'], cs['h']['b1']) + + # trying to set a CheckedArray variable using an array with wrong axes -> should fail + # a) extra axis + with must_raise(ValueError, msg="Array 'h' was declared with axes {a, b} but got array with axes {a, b, c} " + "(unexpected {c} axis)"): + cs['h'] = ndtest((a3, b2, 'c=c0..c2')) + # b) incompatible axis + msg = """\ +Incompatible axis for array 'h': +Axis(['a0', 'a1', 'a2', 'a3', 'a4'], 'a') +vs +Axis(['a0', 'a1', 'a2', 'a3'], 'a')""" + with must_raise(ValueError, msg=msg): + cs['h'] = h.append('a', 0, 'a4') + + +def test_getattr_cs(checkedsession): + test_getattr(checkedsession) + + +def test_setattr_cs(checkedsession): + cs = checkedsession + + # only change values of an array -> OK + cs.h = zeros_like(h) + + # trying to add an undeclared variable -> prints a warning message + with must_warn(UserWarning, msg=f"'i' is not declared in '{cs.__class__.__name__}'"): + cs.i = ndtest((3, 3)) + + # trying to set a variable with an object of different type -> should fail + # a) type given explicitly + # -> Axis + with must_raise(TypeError, msg="instance of Axis expected"): + cs.a = 0 + # -> CheckedArray + with must_raise(TypeError, msg="Expected object of type 'Array' or a scalar for the variable 'h' but got " + "object of type 'ndarray'"): + cs.h = h.data + # b) type deduced from the given default value + with 
must_raise(TypeError, msg="instance of Axis expected"): + cs.b = ndtest((3, 3)) + + # trying to set a CheckedArray variable using a scalar -> OK + cs.h = 5 + + # trying to set a CheckedArray variable using an array with axes in different order -> OK + cs.h = h.transpose() + assert cs.h.axes.names == h.axes.names + + # broadcasting (missing axis) is allowed + cs.h = ndtest(a3) + assert_array_nan_equal(cs.h['b0'], cs.h['b1']) + + # trying to set a CheckedArray variable using an array with wrong axes -> should fail + # a) extra axis + with must_raise(ValueError, msg="Array 'h' was declared with axes {a, b} but got array with axes {a, b, c} " + "(unexpected {c} axis)"): + cs.h = ndtest((a3, b2, 'c=c0..c2')) + # b) incompatible axis + msg = """\ +Incompatible axis for array 'h': +Axis(['a0', 'a1', 'a2', 'a3', 'a4'], 'a') +vs +Axis(['a0', 'a1', 'a2', 'a3'], 'a')""" + with must_raise(ValueError, msg=msg): + cs.h = h.append('a', 0, 'a4') + + +def test_add_cs(checkedsession): + cs = checkedsession + test_add(cs) + + u = Axis('u=u0..u2') + with must_warn(UserWarning, msg=f"'u' is not declared in '{cs.__class__.__name__}'", check_file=False): + cs.add(u) + + +def test_iter_cs(checkedsession): + # As of v1.0 of pydantic all fields with annotations (whether annotation-only or with a default value) + # will precede all fields without an annotation. Within their respective groups, fields remain in the + # order they were defined. 
+ # See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering + expected = [a, a2, a01, c, e, g, f, h, b, b024, anonymous, ano01, d] + assert_seq_equal(checkedsession, expected) + + +def test_filter_cs(checkedsession): + # see comment in test_iter_cs() about fields ordering + cs = checkedsession + cs.ax = 'ax' + assert_seq_equal(cs.filter(), [a, a2, a01, c, e, g, f, h, b, b024, anonymous, ano01, d, 'ax']) + assert_seq_equal(cs.filter('a*'), [a, a2, a01, anonymous, ano01, 'ax']) + assert list(cs.filter('a*', dict)) == [] + assert list(cs.filter('a*', str)) == ['ax'] + assert list(cs.filter('a*', Axis)) == [a, a2, anonymous] + assert list(cs.filter(kind=Axis)) == [a, a2, b, anonymous] + assert list(cs.filter('a01', Group)) == [a01] + assert list(cs.filter(kind=Group)) == [a01, b024, ano01] + assert_seq_equal(cs.filter(kind=Array), [e, g, f, h]) + assert list(cs.filter(kind=dict)) == [{}] + assert list(cs.filter(kind=(Axis, Group))) == [a, a2, a01, b, b024, anonymous, ano01] + + +def test_names_cs(checkedsession): + assert checkedsession.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024', + 'c', 'd', 'e', 'f', 'g', 'h'] + + +def _test_io_cs(tmpdir, meta, engine, ext): + filename = f"test_{engine}.{ext}" if 'csv' not in engine else f"test_{engine}{ext}" + fpath = tmp_path(tmpdir, filename) + + is_excel_or_csv = 'excel' in engine or 'csv' in engine + + # Save and load + # ------------- + + # a) - all typed variables have a defined value + # - no extra variables are added + csession = TestCheckedSession(a=a, a2=a2, a01=a01, d=d, e=e, g=g, f=f, h=h, meta=meta) + csession.save(fpath, engine=engine) + cs = TestCheckedSession() + cs.load(fpath, engine=engine) + # --- keys --- + assert list(cs.keys()) == list(csession.keys()) + # --- variables with default values --- + assert cs.b.equals(b) + assert cs.b024.equals(b024) + assert cs.anonymous.equals(anonymous) + assert cs.ano01.equals(ano01) + assert cs.d == d + # --- typed variables --- + # Array is 
support by all formats + assert cs.e.equals(e) + assert cs.g.equals(g) + assert cs.f.equals(f) + assert cs.h.equals(h) + # Axis and Group are not supported by the Excel and CSV formats + if is_excel_or_csv: + assert isinstance(cs.a, NotLoaded) + assert isinstance(cs.a2, NotLoaded) + assert isinstance(cs.a01, NotLoaded) + else: + assert cs.a.equals(a) + assert cs.a2.equals(a2) + assert cs.a01.equals(a01) + # --- dtype of Axis variables --- + if not is_excel_or_csv: + for key in cs.filter(kind=Axis).keys(): + assert cs[key].dtype == csession[key].dtype + # --- metadata --- + if engine != 'pandas_excel': + assert cs.meta == meta + + # b) - not all typed variables have a defined value + # - no extra variables are added + csession = TestCheckedSession(a=a, d=d, e=e, h=h, meta=meta) + if 'csv' in engine: + import shutil + shutil.rmtree(fpath) + csession.save(fpath, engine=engine) + cs = TestCheckedSession() + cs.load(fpath, engine=engine) + # --- keys --- + assert list(cs.keys()) == list(csession.keys()) + # --- variables with default values --- + assert cs.b.equals(b) + assert cs.b024.equals(b024) + assert cs.anonymous.equals(anonymous) + assert cs.ano01.equals(ano01) + assert cs.d == d + # --- typed variables --- + # Array is support by all formats + assert cs.e.equals(e) + assert isinstance(cs.g, NotLoaded) + assert isinstance(cs.f, NotLoaded) + assert cs.h.equals(h) + # Axis and Group are not supported by the Excel and CSV formats + if is_excel_or_csv: + assert isinstance(cs.a, NotLoaded) + assert isinstance(cs.a2, NotLoaded) + assert isinstance(cs.a01, NotLoaded) + else: + assert cs.a.equals(a) + assert isinstance(cs.a2, NotLoaded) + assert isinstance(cs.a01, NotLoaded) + + # c) - all typed variables have a defined value + # - extra variables are added + i = ndtest(6) + j = ndtest((3, 3)) + k = ndtest((2, 2)) + csession = TestCheckedSession(a=a, a2=a2, a01=a01, d=d, e=e, g=g, f=f, h=h, k=k, j=j, i=i, meta=meta) + csession.save(fpath, engine=engine) + cs = 
TestCheckedSession() + cs.load(fpath, engine=engine) + # --- names --- + # we do not use keys() since order of undeclared variables + # may not be preserved (at least for the HDF format) + assert cs.names == csession.names + # --- extra variable --- + assert cs.i.equals(i) + assert cs.j.equals(j) + assert cs.k.equals(k) + + # Update a Group + an Axis + an array (overwrite=False) + # ----------------------------------------------------- + csession = TestCheckedSession(a=a, a2=a2, a01=a01, d=d, e=e, g=g, f=f, h=h, meta=meta) + csession.save(fpath, engine=engine) + a4 = Axis('a=0..3') + a4_01 = a3['0,1'] >> 'a01' + e2 = ndtest((a4, 'b=b0..b2')) + h2 = full_like(h, fill_value=10) + TestCheckedSession(a=a4, a01=a4_01, e=e2, h=h2).save(fpath, overwrite=False, engine=engine) + cs = TestCheckedSession() + cs.load(fpath, engine=engine) + # --- variables with default values --- + assert cs.b.equals(b) + assert cs.b024.equals(b024) + assert cs.anonymous.equals(anonymous) + assert cs.ano01.equals(ano01) + # --- typed variables --- + # Array is support by all formats + assert cs.e.equals(e2) + assert cs.h.equals(h2) + if engine == 'pandas_excel': + # Session.save() via engine='pandas_excel' always overwrite the output Excel files + # arrays 'g' and 'f' have been dropped + assert isinstance(cs.g, NotLoaded) + assert isinstance(cs.f, NotLoaded) + # Axis and Group are not supported by the Excel and CSV formats + assert isinstance(cs.a, NotLoaded) + assert isinstance(cs.a2, NotLoaded) + assert isinstance(cs.a01, NotLoaded) + elif is_excel_or_csv: + assert cs.g.equals(g) + assert cs.f.equals(f) + # Axis and Group are not supported by the Excel and CSV formats + assert isinstance(cs.a, NotLoaded) + assert isinstance(cs.a2, NotLoaded) + assert isinstance(cs.a01, NotLoaded) + else: + assert list(cs.keys()) == list(csession.keys()) + assert cs.a.equals(a4) + assert cs.a2.equals(a2) + assert cs.a01.equals(a4_01) + assert cs.g.equals(g) + assert cs.f.equals(f) + if engine != 
'pandas_excel': + assert cs.meta == meta + + # Load only some objects + # ---------------------- + csession = TestCheckedSession(a=a, a2=a2, a01=a01, d=d, e=e, g=g, f=f, h=h, meta=meta) + csession.save(fpath, engine=engine) + cs = TestCheckedSession() + names_to_load = ['e', 'h'] if is_excel_or_csv else ['a', 'a01', 'a2', 'e', 'h'] + cs.load(fpath, names=names_to_load, engine=engine) + # --- keys --- + assert list(cs.keys()) == list(csession.keys()) + # --- variables with default values --- + assert cs.b.equals(b) + assert cs.b024.equals(b024) + assert cs.anonymous.equals(anonymous) + assert cs.ano01.equals(ano01) + assert cs.d == d + # --- typed variables --- + # Array is support by all formats + assert cs.e.equals(e) + assert isinstance(cs.g, NotLoaded) + assert isinstance(cs.f, NotLoaded) + assert cs.h.equals(h) + # Axis and Group are not supported by the Excel and CSV formats + if is_excel_or_csv: + assert isinstance(cs.a, NotLoaded) + assert isinstance(cs.a2, NotLoaded) + assert isinstance(cs.a01, NotLoaded) + else: + assert cs.a.equals(a) + assert cs.a2.equals(a2) + assert cs.a01.equals(a01) + + return fpath + + +@needs_pytables +def test_h5_io_cs(tmpdir, meta): + _test_io_cs(tmpdir, meta, engine='pandas_hdf', ext='h5') + + +@needs_openpyxl +def test_xlsx_pandas_io_cs(tmpdir, meta): + _test_io_cs(tmpdir, meta, engine='pandas_excel', ext='xlsx') + + +@needs_xlwings +def test_xlsx_xlwings_io_cs(tmpdir, meta): + _test_io_cs(tmpdir, meta, engine='xlwings_excel', ext='xlsx') + + +def test_csv_io_cs(tmpdir, meta): + _test_io_cs(tmpdir, meta, engine='pandas_csv', ext='csv') + + +def test_pickle_io_cs(tmpdir, meta): + _test_io_cs(tmpdir, meta, engine='pickle', ext='pkl') + + +def test_pickle_roundtrip_cs(checkedsession, meta): + cs = checkedsession + cs.meta = meta + s = pickle.dumps(cs) + res = pickle.loads(s) + assert res.equals(cs) + assert res.meta == meta + + +def test_element_equals_cs(checkedsession): + test_element_equals(checkedsession) + + +def 
test_eq_cs(checkedsession): + test_eq(checkedsession) + + +def test_ne_cs(checkedsession): + test_ne(checkedsession) + + +def test_sub_cs(checkedsession): + cs = checkedsession + session_cls = cs.__class__ + + # session - session + other = session_cls(a=a, a2=a2, a01=a01, e=e - 1, g=zeros_like(g), f=zeros_like(f), h=ones_like(h)) + diff = cs - other + assert isinstance(diff, session_cls) + # --- non-array variables --- + assert diff.b is b + assert diff.b024 is b024 + assert diff.a is a + assert diff.a2 is a2 + assert diff.anonymous is anonymous + assert diff.a01 is a01 + assert diff.ano01 is ano01 + assert diff.c is c + assert diff.d is d + # --- array variables --- + assert_array_nan_equal(diff.e, np.full((2, 3), 1, dtype=np.int32)) + assert_array_nan_equal(diff.g, g) + assert_array_nan_equal(diff.f, f) + assert_array_nan_equal(diff.h, h - ones_like(h)) + + # session - scalar + diff = cs - 2 + assert isinstance(diff, session_cls) + # --- non-array variables --- + assert diff.b is b + assert diff.b024 is b024 + assert diff.a is a + assert diff.a2 is a2 + assert diff.anonymous is anonymous + assert diff.a01 is a01 + assert diff.ano01 is ano01 + assert diff.c is c + assert diff.d is d + # --- non constant arrays --- + assert_array_nan_equal(diff.e, e - 2) + assert_array_nan_equal(diff.g, g - 2) + assert_array_nan_equal(diff.f, f - 2) + assert_array_nan_equal(diff.h, h - 2) + + # session - dict(Array and scalar) + other = {'e': ones_like(e), 'h': 1} + diff = cs - other + assert isinstance(diff, session_cls) + # --- non-array variables --- + assert diff.b is b + assert diff.b024 is b024 + assert diff.a is a + assert diff.a2 is a2 + assert diff.anonymous is anonymous + assert diff.a01 is a01 + assert diff.ano01 is ano01 + assert diff.c is c + assert diff.d is d + # --- non constant arrays --- + assert_array_nan_equal(diff.e, e - ones_like(e)) + assert isnan(diff.g).all() + assert isnan(diff.f).all() + assert_array_nan_equal(diff.h, h - 1) + + # session - array + axes = 
cs.h.axes + cs.e = ndtest(axes) + cs.g = ones_like(cs.h) + diff = cs - ones(axes) + assert isinstance(diff, session_cls) + # --- non-array variables --- + assert diff.b is b + assert diff.b024 is b024 + assert diff.a is a + assert diff.a2 is a2 + assert diff.anonymous is anonymous + assert diff.a01 is a01 + assert diff.ano01 is ano01 + assert diff.c is c + assert diff.d is d + # --- non constant arrays --- + assert_array_nan_equal(diff.e, cs.e - ones(axes)) + assert_array_nan_equal(diff.g, cs.g - ones(axes)) + assert isnan(diff.f).all() + assert_array_nan_equal(diff.h, cs.h - ones(axes)) + + +def test_rsub_cs(checkedsession): + cs = checkedsession + session_cls = cs.__class__ + + # scalar - session + diff = 2 - cs + assert isinstance(diff, session_cls) + # --- non-array variables --- + assert diff.b is b + assert diff.b024 is b024 + assert diff.a is a + assert diff.a2 is a2 + assert diff.anonymous is anonymous + assert diff.a01 is a01 + assert diff.ano01 is ano01 + assert diff.c is c + assert diff.d is d + # --- non constant arrays --- + assert_array_nan_equal(diff.e, 2 - e) + assert_array_nan_equal(diff.g, 2 - g) + assert_array_nan_equal(diff.f, 2 - f) + assert_array_nan_equal(diff.h, 2 - h) + + # dict(Array and scalar) - session + other = {'e': ones_like(e), 'h': 1} + diff = other - cs + assert isinstance(diff, session_cls) + # --- non-array variables --- + assert diff.b is b + assert diff.b024 is b024 + assert diff.a is a + assert diff.a2 is a2 + assert diff.anonymous is anonymous + assert diff.a01 is a01 + assert diff.ano01 is ano01 + assert diff.c is c + assert diff.d is d + # --- non constant arrays --- + assert_array_nan_equal(diff.e, ones_like(e) - e) + assert isnan(diff.g).all() + assert isnan(diff.f).all() + assert_array_nan_equal(diff.h, 1 - h) + + +def test_neg_cs(checkedsession): + cs = checkedsession + neg_cs = -cs + # --- non-array variables --- + assert isnan(neg_cs.b) + assert isnan(neg_cs.b024) + assert isnan(neg_cs.a) + assert isnan(neg_cs.a2) + 
assert isnan(neg_cs.anonymous) + assert isnan(neg_cs.a01) + assert isnan(neg_cs.ano01) + assert isnan(neg_cs.c) + assert isnan(neg_cs.d) + # --- non constant arrays --- + assert_array_nan_equal(neg_cs.e, -e) + assert_array_nan_equal(neg_cs.g, -g) + assert_array_nan_equal(neg_cs.f, -f) + assert_array_nan_equal(neg_cs.h, -h) + + +if __name__ == "__main__": + pytest.main() diff --git a/larray/tests/test_session.py b/larray/tests/test_session.py index bc11d93f9..486df9421 100644 --- a/larray/tests/test_session.py +++ b/larray/tests/test_session.py @@ -12,8 +12,8 @@ from larray.tests.common import (assert_array_nan_equal, inputpath, tmp_path, needs_xlwings, needs_pytables, needs_openpyxl, must_warn) from larray.inout.common import _supported_scalars_types -from larray import (Session, Axis, Array, Group, isnan, zeros_like, ndtest, ones_like, ones, full, - local_arrays, global_arrays, arrays) +from larray import (Session, Axis, Array, Group, isnan, zeros_like, ndtest, ones_like, + ones, full, full_like, stack, local_arrays, global_arrays, arrays) # avoid flake8 errors @@ -37,18 +37,25 @@ def assert_seq_equal(got, expected): a = Axis('a=a0..a2') a2 = Axis('a=a0..a4') +a3 = Axis('a=a0..a3') anonymous = Axis(4) a01 = a['a0,a1'] >> 'a01' ano01 = a['a0,a1'] b = Axis('b=0..4') +b2 = Axis('b=b0..b4') b024 = b[[0, 2, 4]] >> 'b024' c = 'c' d = {} e = ndtest([(2, 'a'), (3, 'b')]) _e = ndtest((3, 3)) -f = ndtest((Axis(3), Axis(2))) +f = ndtest((Axis(3), Axis(2)), dtype=float) g = ndtest([(2, 'a'), (4, 'b')]) -h = ndtest(('a=a0..a2', 'b=b0..b4')) +h = ndtest((a3, b2)) +k = ndtest((3, 3)) + +# ########################### # +# SESSION # +# ########################### # @pytest.fixture() @@ -58,12 +65,12 @@ def session(): def test_init_session(meta): - s = Session(b, b024, a, a01, a2=a2, anonymous=anonymous, ano01=ano01, c=c, d=d, e=e, f=f, g=g, h=h) - assert s.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024', 'c', 'd', 'e', 'f', 'g', 'h'] + s = Session(b, b024, a, a01, 
a2=a2, anonymous=anonymous, ano01=ano01, c=c, d=d, e=e, g=g, f=f, h=h) + assert list(s.keys()) == ['b', 'b024', 'a', 'a01', 'a2', 'anonymous', 'ano01', 'c', 'd', 'e', 'g', 'f', 'h'] # TODO: format auto-detection does not work in this case # s = Session('test_session_csv') - # assert s.names == ['e', 'f', 'g'] + # assert list(s.keys()) == ['e', 'f', 'g'] # metadata s = Session(b, b024, a, a01, a2=a2, anonymous=anonymous, ano01=ano01, c=c, d=d, e=e, f=f, g=g, h=h, meta=meta) @@ -73,14 +80,14 @@ def test_init_session(meta): @needs_xlwings def test_init_session_xlsx(): s = Session(inputpath('demography_eurostat.xlsx')) - assert s.names == ['births', 'deaths', 'immigration', 'population', - 'population_5_countries', 'population_benelux'] + assert list(s.keys()) == ['population', 'population_benelux', 'population_5_countries', + 'births', 'deaths', 'immigration'] @needs_pytables def test_init_session_hdf(): s = Session(inputpath('test_session.h5')) - assert s.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024', 'e', 'f', 'g', 'h'] + assert list(s.keys()) == ['e', 'f', 'g', 'h', 'a', 'a2', 'anonymous', 'b', 'a01', 'ano01', 'b024'] def test_getitem(session): @@ -175,13 +182,16 @@ def test_names(session): assert session.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024', 'c', 'd', 'e', 'f', 'g', 'h'] # add them in the "wrong" order - session.add(i='i') session.add(j='j') + session.add(i='i') assert session.names == ['a', 'a01', 'a2', 'ano01', 'anonymous', 'b', 'b024', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] -def _test_io(fpath, session, meta, engine): +def _test_io(tmpdir, session, meta, engine, ext): + filename = f"test_{engine}.{ext}" if 'csv' not in engine else f"test_{engine}{ext}" + fpath = tmp_path(tmpdir, filename) + is_excel_or_csv = 'excel' in engine or 'csv' in engine kind = Array if is_excel_or_csv else (Axis, Group, Array) + _supported_scalars_types @@ -203,21 +213,22 @@ def _test_io(fpath, session, meta, engine): assert s.meta == meta # 
update a Group + an Axis + an array (overwrite=False) - a3 = Axis('a=0..3') - a3_01 = a3['0,1'] >> 'a01' - e2 = ndtest((a3, 'b=b0..b2')) - Session(a=a3, a01=a3_01, e=e2).save(fpath, overwrite=False, engine=engine) + a4 = Axis('a=0..3') + a4_01 = a3['0,1'] >> 'a01' + e2 = ndtest((a4, 'b=b0..b2')) + h2 = full_like(h, fill_value=10) + Session(a=a4, a01=a4_01, e=e2, h=h2).save(fpath, overwrite=False, engine=engine) s = Session() s.load(fpath, engine=engine) if engine == 'pandas_excel': # Session.save() via engine='pandas_excel' always overwrite the output Excel files - assert s.names == ['e'] + assert s.names == ['e', 'h'] elif is_excel_or_csv: assert s.names == ['e', 'f', 'g', 'h'] else: assert s.names == session.names - assert s['a'].equals(a3) - assert s['a01'].equals(a3_01) + assert s['a'].equals(a4) + assert s['a01'].equals(a4_01) assert_array_nan_equal(s['e'], e2) if engine != 'pandas_excel': assert s.meta == meta @@ -225,12 +236,14 @@ def _test_io(fpath, session, meta, engine): # load only some objects session.save(fpath, engine=engine) s = Session() - names_to_load = ['e', 'f'] if is_excel_or_csv else ['a', 'a01', 'a2', 'anonymous', 'e', 'f'] + names_to_load = ['e', 'f'] if is_excel_or_csv else ['a', 'a01', 'a2', 'anonymous', 'e', 'f', 's_bool', 's_int'] s.load(fpath, names=names_to_load, engine=engine) assert s.names == names_to_load if engine != 'pandas_excel': assert s.meta == meta + return fpath + def _add_scalars_to_session(s): # 's' for scalar @@ -247,7 +260,6 @@ def _add_scalars_to_session(s): @needs_pytables def test_h5_io(tmpdir, session, meta): session = _add_scalars_to_session(session) - fpath = tmp_path(tmpdir, 'test_session.h5') msg = "\nyour performance may suffer as PyTables will pickle object types" regex = re.compile(msg, flags=re.MULTILINE) @@ -255,27 +267,24 @@ def test_h5_io(tmpdir, session, meta): # for some reason the PerformanceWarning is not detected as such, so this does not work: # with pytest.warns(tables.PerformanceWarning): with 
pytest.warns(Warning, match=regex): - _test_io(fpath, session, meta, engine='pandas_hdf') + _test_io(tmpdir, session, meta, engine='pandas_hdf', ext='h5') @needs_openpyxl def test_xlsx_pandas_io(tmpdir, session, meta): - fpath = tmp_path(tmpdir, 'test_session.xlsx') - _test_io(fpath, session, meta, engine='pandas_excel') + _test_io(tmpdir, session, meta, engine='pandas_excel', ext='xlsx') @needs_xlwings def test_xlsx_xlwings_io(tmpdir, session, meta): - fpath = tmp_path(tmpdir, 'test_session.xlsx') - _test_io(fpath, session, meta, engine='xlwings_excel') + _test_io(tmpdir, session, meta, engine='xlwings_excel', ext='xlsx') def test_csv_io(tmpdir, session, meta): - fpath = tmp_path(tmpdir, 'test_session_csv') try: - _test_io(fpath, session, meta, engine='pandas_csv') + fpath = _test_io(tmpdir, session, meta, engine='pandas_csv', ext='csv') - names = session.filter(kind=Array).names + names = Session({k: v for k, v in session.items() if isinstance(v, Array)}).names # test loading with a pattern pattern = os.path.join(fpath, '*.csv') @@ -303,8 +312,16 @@ def test_csv_io(tmpdir, session, meta): def test_pickle_io(tmpdir, session, meta): session = _add_scalars_to_session(session) - fpath = tmp_path(tmpdir, 'test_session.pkl') - _test_io(fpath, session, meta, engine='pickle') + _test_io(tmpdir, session, meta, engine='pickle', ext='pkl') + + +def test_pickle_roundtrip(session, meta): + original = session.filter(kind=Array) + original.meta = meta + s = pickle.dumps(original) + res = pickle.loads(s) + assert res.equals(original) + assert res.meta == meta def test_to_globals(session): @@ -337,84 +354,128 @@ def test_to_globals(session): def test_element_equals(session): - sess = session.filter(kind=(Axis, Group, Array)) - expected = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e), ('g', g), ('f', f), ('h', h)]) - assert all(sess.element_equals(expected)) - - other = Session([('a', a), ('a2', a2), 
('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e), ('f', f), ('h', h)]) - res = sess.element_equals(other) - assert res.ndim == 1 - assert res.axes.names == ['name'] - assert np.array_equal(res.axes.labels[0], ['b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01', - 'e', 'g', 'f', 'h']) - assert list(res) == [False, False, True, True, True, True, True, True, False, True, True] - - e2 = e.copy() - e2.i[1, 1] = 42 - other = Session([('a', a), ('a2', a2), ('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e2), ('f', f), ('h', h)]) - res = sess.element_equals(other) - assert res.axes.names == ['name'] - assert np.array_equal(res.axes.labels[0], ['b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01', - 'e', 'g', 'f', 'h']) - assert list(res) == [False, False, True, True, True, True, True, False, False, True, True] + session_cls = session.__class__ + other_session = session_cls([(key, value) for key, value in session.items()]) + + keys = [key for key, value in session.items() if isinstance(value, (Axis, Group, Array))] + expected_res = full(Axis(keys, 'name'), fill_value=True, dtype=bool) + + # ====== same sessions ====== + res = session.element_equals(other_session) + assert res.axes == expected_res.axes + assert res.equals(expected_res) + + # ====== session with missing/extra items ====== + # delete some items + for deleted_key in ['b', 'b024', 'g']: + del other_session[deleted_key] + expected_res[deleted_key] = False + # add one item + other_session['k'] = k + expected_res = expected_res.append('name', False, label='k') + + res = session.element_equals(other_session) + assert res.axes == expected_res.axes + assert res.equals(expected_res) + + # ====== session with a modified array ====== + h2 = h.copy() + h2['a1', 'b1'] = 42 + other_session['h'] = h2 + expected_res['h'] = False + + res = session.element_equals(other_session) + assert res.axes == expected_res.axes + assert res.equals(expected_res) + + +def to_boolean_array_eq(res): + 
return stack([(key, item.all() if isinstance(item, Array) else item) + for key, item in res.items()], 'name') def test_eq(session): - sess = session.filter(kind=(Axis, Group, Array)) - expected = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e), ('g', g), ('f', f), ('h', h)]) - assert all([item.all() if isinstance(item, Array) else item - for item in (sess == expected).values()]) - - other = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e), ('f', f), ('h', h)]) - res = sess == other - assert list(res.keys()) == ['b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01', - 'e', 'g', 'f', 'h'] - assert [item.all() if isinstance(item, Array) else item - for item in res.values()] == [True, True, True, True, True, True, True, True, False, True, True] - - e2 = e.copy() - e2.i[1, 1] = 42 - other = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e2), ('f', f), ('h', h)]) - res = sess == other - assert [item.all() if isinstance(item, Array) else item - for item in res.values()] == [True, True, True, True, True, True, True, False, False, True, True] + session_cls = session.__class__ + other_session = session_cls([(key, value) for key, value in session.items()]) + expected_res = full(Axis(list(session.keys()), 'name'), fill_value=True, dtype=bool) + + # ====== same sessions ====== + res = session == other_session + res = to_boolean_array_eq(res) + assert res.axes == expected_res.axes + assert res.equals(expected_res) + + # ====== session with missing/extra items ====== + del other_session['g'] + expected_res['g'] = False + other_session['k'] = k + expected_res = expected_res.append('name', False, label='k') + + res = session == other_session + res = to_boolean_array_eq(res) + assert res.axes == expected_res.axes + assert res.equals(expected_res) + + # 
====== session with a modified array ====== + h2 = h.copy() + h2['a1', 'b1'] = 42 + other_session['h'] = h2 + expected_res['h'] = False + + res = session == other_session + assert res['h'].equals(session['h'] == other_session['h']) + res = to_boolean_array_eq(res) + assert res.axes == expected_res.axes + assert res.equals(expected_res) + + +def to_boolean_array_ne(res): + return stack([(key, item.any() if isinstance(item, Array) else item) + for key, item in res.items()], 'name') def test_ne(session): - sess = session.filter(kind=(Axis, Group, Array)) - expected = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e), ('g', g), ('f', f), ('h', h)]) - assert ([(~item).all() if isinstance(item, Array) else not item - for item in (sess != expected).values()]) - - other = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e), ('f', f), ('h', h)]) - res = sess != other - assert list(res.keys()) == ['b', 'b024', 'a', 'a2', 'anonymous', 'a01', 'ano01', - 'e', 'g', 'f', 'h'] - assert [(~item).all() if isinstance(item, Array) else not item - for item in res.values()] == [True, True, True, True, True, True, True, True, False, True, True] - - e2 = e.copy() - e2.i[1, 1] = 42 - other = Session([('b', b), ('b024', b024), ('a', a), ('a2', a2), ('anonymous', anonymous), - ('a01', a01), ('ano01', ano01), ('e', e2), ('f', f), ('h', h)]) - res = sess != other - assert [(~item).all() if isinstance(item, Array) else not item - for item in res.values()] == [True, True, True, True, True, True, True, False, False, True, True] + session_cls = session.__class__ + other_session = session_cls([(key, value) for key, value in session.items()]) + expected_res = full(Axis(list(session.keys()), 'name'), fill_value=False, dtype=bool) + + # ====== same sessions ====== + res = session != other_session + res = to_boolean_array_ne(res) + assert res.axes == 
expected_res.axes + assert res.equals(expected_res) + + # ====== session with missing/extra items ====== + del other_session['g'] + expected_res['g'] = True + other_session['k'] = k + expected_res = expected_res.append('name', True, label='k') + + res = session != other_session + res = to_boolean_array_ne(res) + assert res.axes == expected_res.axes + assert res.equals(expected_res) + + # ====== session with a modified array ====== + h2 = h.copy() + h2['a1', 'b1'] = 42 + other_session['h'] = h2 + expected_res['h'] = True + + res = session != other_session + assert res['h'].equals(session['h'] != other_session['h']) + res = to_boolean_array_ne(res) + assert res.axes == expected_res.axes + assert res.equals(expected_res) def test_sub(session): sess = session # session - session - other = Session({'e': e - 1, 'f': ones_like(f)}) + other = Session({'e': e, 'f': f}) + other['e'] = e - 1 + other['f'] = ones_like(f) diff = sess - other assert_array_nan_equal(diff['e'], np.full((2, 3), 1, dtype=np.int32)) assert_array_nan_equal(diff['f'], f - ones_like(f)) @@ -444,12 +505,12 @@ def test_sub(session): # session - array axes = [a, b] - sess = Session([('a', a), ('a01', a01), ('c', c), ('e', ndtest(axes)), - ('f', full(axes, fill_value=3)), ('g', ndtest('c=c0..c2'))]) - diff = sess - ones(axes) - assert_array_nan_equal(diff['e'], sess['e'] - ones(axes)) - assert_array_nan_equal(diff['f'], sess['f'] - ones(axes)) - assert_array_nan_equal(diff['g'], sess['g'] - ones(axes)) + other = Session([('a', a), ('a01', a01), ('c', c), ('e', ndtest((a, b))), + ('f', full((a, b), fill_value=3)), ('g', ndtest('c=c0..c2'))]) + diff = other - ones(axes) + assert_array_nan_equal(diff['e'], other['e'] - ones(axes)) + assert_array_nan_equal(diff['f'], other['f'] - ones(axes)) + assert_array_nan_equal(diff['g'], other['g'] - ones(axes)) assert diff.a is a assert diff.a01 is a01 assert diff.c is c @@ -480,7 +541,11 @@ def test_rsub(session): def test_div(session): sess = session - other = 
Session({'e': e - 1, 'f': f + 1}) + session_cls = session.__class__ + + other = session_cls({'e': e, 'f': f}) + other['e'] = e - 1 + other['f'] = f + 1 with must_warn(RuntimeWarning, msg="divide by zero encountered during operation"): res = sess / other @@ -527,15 +592,6 @@ def test_rdiv(session): assert res.c is c -def test_pickle_roundtrip(session, meta): - original = session.filter(kind=Array) - original.meta = meta - s = pickle.dumps(original) - res = pickle.loads(s) - assert res.equals(original) - assert res.meta == meta - - def test_local_arrays(): h = ndtest(2) _h = ndtest(3) @@ -554,12 +610,12 @@ def test_local_arrays(): def test_global_arrays(): # exclude private global arrays s = global_arrays() - s_expected = Session([('e', e), ('f', f), ('g', g), ('h', h)]) + s_expected = Session([('e', e), ('f', f), ('g', g), ('h', h), ('k', k)]) assert s.equals(s_expected) # all global arrays s = global_arrays(include_private=True) - s_expected = Session([('e', e), ('_e', _e), ('f', f), ('g', g), ('h', h)]) + s_expected = Session([('e', e), ('_e', _e), ('f', f), ('g', g), ('h', h), ('k', k)]) assert s.equals(s_expected) @@ -569,12 +625,12 @@ def test_arrays(): # exclude private arrays s = arrays() - s_expected = Session([('e', e), ('f', f), ('g', g), ('h', h), ('i', i)]) + s_expected = Session([('e', e), ('f', f), ('g', g), ('h', h), ('i', i), ('k', k)]) assert s.equals(s_expected) # all arrays s = arrays(include_private=True) - s_expected = Session([('_e', _e), ('_i', _i), ('e', e), ('f', f), ('g', g), ('h', h), ('i', i)]) + s_expected = Session([('_e', _e), ('_i', _i), ('e', e), ('f', f), ('g', g), ('h', h), ('i', i), ('k', k)]) assert s.equals(s_expected) diff --git a/make_release.py b/make_release.py index a01d2c92b..1b6575228 100644 --- a/make_release.py +++ b/make_release.py @@ -39,9 +39,11 @@ def update_metapackage(local_repository, release_name, public_release=True, **ex print(f'Updating larrayenv metapackage to version {version}') # - excluded versions 5.0 
and 5.1 of ipykernel because these versions make the console useless after any exception # https://github.com/larray-project/larray-editor/issues/166 + # - pydantic: cannot define numpy ndarray / pandas obj / LArray field with default value + # since version 1.6 check_call(['conda', 'metapackage', 'larrayenv', version, '--dependencies', f'larray =={version}', f'larray-editor =={version}', f'larray_eurostat =={version}', - "qtconsole", "matplotlib", "pyqt", "qtpy", "pytables", + "qtconsole", "matplotlib", "pyqt", "qtpy", "pytables", "pydantic <=1.5", "xlsxwriter", "xlrd", "xlwt", "openpyxl", "xlwings", "ipykernel !=5.0,!=5.1.0", '--user', 'larray-project', '--home', 'http://github.com/larray-project/larray', diff --git a/setup.py b/setup.py index 64911a30f..6f0cbcb77 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import os from setuptools import setup, find_packages