diff --git a/src/caskade/backend.py b/src/caskade/backend.py index 60de9c3..5ff8394 100644 --- a/src/caskade/backend.py +++ b/src/caskade/backend.py @@ -224,8 +224,5 @@ def _logit_jax(self, array): def _logit_numpy(self, array): return np.log(array / (1 - array)) - def sqrt(self, array): - return self.module.sqrt(array) - backend = Backend() diff --git a/src/caskade/base.py b/src/caskade/base.py index 44a0ba3..bc27881 100644 --- a/src/caskade/base.py +++ b/src/caskade/base.py @@ -61,6 +61,7 @@ def __init__( self, name: Optional[str] = None, link: Optional[Union["Node", tuple["Node"]]] = None, + description: str = "", ): if name is None: name = self.__class__.__name__ @@ -73,6 +74,7 @@ def __init__( self._parents = set() self._active = False self._type = "node" + self.description = description self.meta = meta() self.saveattrs = set() if link is not None: diff --git a/src/caskade/module.py b/src/caskade/module.py index 4b61acc..1ed1697 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -67,8 +67,8 @@ def otherfun(self, x, c = None): ) # These tuples will not be converted to NodeTuple objects graphviz_types = {"module": {"style": "solid", "color": "black", "shape": "ellipse"}} - def __init__(self, name: Optional[str] = None): - super().__init__(name=name) + def __init__(self, name: Optional[str] = None, **kwargs): + super().__init__(name=name, **kwargs) self.dynamic_params = () self.all_dynamic_value = True self.pointer_params = () diff --git a/src/caskade/param.py b/src/caskade/param.py index f43ab3d..cf8b29a 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -103,8 +103,9 @@ def __init__( dynamic_value: Optional[Union[ArrayLike, float, int]] = None, dtype: Optional[Any] = None, device: Optional[Any] = None, + **kwargs, ): - super().__init__(name=name) + super().__init__(name=name, **kwargs) if value is not None and dynamic_value is not None: raise ParamConfigurationError("Cannot set both value and dynamic value") if isinstance(value, dynamic):