Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/caskade/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@
return self.module.all(array)

def log(self, array):
return self.module.log(array)

Check warning on line 201 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS ubuntu-latest - Backend torch

invalid value encountered in log

Check warning on line 201 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS windows-latest - Backend numpy

invalid value encountered in log

Check warning on line 201 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS ubuntu-latest - Backend numpy

invalid value encountered in log

Check warning on line 201 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS macOS-latest - Backend numpy

invalid value encountered in log

def exp(self, array):
return self.module.exp(array)

Check warning on line 204 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS ubuntu-latest - Backend torch

overflow encountered in exp

Check warning on line 204 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS windows-latest - Backend numpy

overflow encountered in exp

Check warning on line 204 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS ubuntu-latest - Backend numpy

overflow encountered in exp

Check warning on line 204 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS macOS-latest - Backend numpy

overflow encountered in exp

def sum(self, array, axis=None):
return self.module.sum(array, axis=axis)
Expand All @@ -213,7 +213,7 @@
return self.jax.nn.sigmoid(array)

def _sigmoid_numpy(self, array):
return 1 / (1 + self.module.exp(-array))

Check warning on line 216 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS ubuntu-latest - Backend torch

overflow encountered in exp

Check warning on line 216 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS windows-latest - Backend numpy

overflow encountered in exp

Check warning on line 216 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS ubuntu-latest - Backend numpy

overflow encountered in exp

Check warning on line 216 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.10 - OS macOS-latest - Backend numpy

overflow encountered in exp

def _logit_torch(self, array):
return self.module.logit(array)
Expand All @@ -224,8 +224,5 @@
def _logit_numpy(self, array):
return np.log(array / (1 - array))

def sqrt(self, array):
return self.module.sqrt(array)


backend = Backend()
2 changes: 2 additions & 0 deletions src/caskade/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
self,
name: Optional[str] = None,
link: Optional[Union["Node", tuple["Node"]]] = None,
description: str = "",
):
if name is None:
name = self.__class__.__name__
Expand All @@ -73,6 +74,7 @@ def __init__(
self._parents = set()
self._active = False
self._type = "node"
self.description = description
self.meta = meta()
self.saveattrs = set()
if link is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/caskade/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def otherfun(self, x, c = None):
) # These tuples will not be converted to NodeTuple objects
graphviz_types = {"module": {"style": "solid", "color": "black", "shape": "ellipse"}}

def __init__(self, name: Optional[str] = None):
super().__init__(name=name)
def __init__(self, name: Optional[str] = None, **kwargs):
super().__init__(name=name, **kwargs)
self.dynamic_params = ()
self.all_dynamic_value = True
self.pointer_params = ()
Expand Down
3 changes: 2 additions & 1 deletion src/caskade/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ def __init__(
dynamic_value: Optional[Union[ArrayLike, float, int]] = None,
dtype: Optional[Any] = None,
device: Optional[Any] = None,
**kwargs,
):
super().__init__(name=name)
super().__init__(name=name, **kwargs)
if value is not None and dynamic_value is not None:
raise ParamConfigurationError("Cannot set both value and dynamic value")
if isinstance(value, dynamic):
Expand Down
Loading