2 changes: 1 addition & 1 deletion docs/source/notebooks/AdvancedGuide.ipynb
@@ -549,7 +549,7 @@
"x, y = torch.meshgrid(torch.linspace(0, 10, 100), torch.linspace(0, 10, 100), indexing=\"ij\")\n",
"\n",
"# Batching on all dims using batched tensor input\n",
"params = G.build_params_array()\n",
"params = G.get_values()\n",
"params = params.repeat(5, 1) # 5 copies of the same params\n",
"params[:, 3] = torch.linspace(0.0, 3.14 / 2, 5) # phi\n",
"params[:, 4] = torch.linspace(0.5, 4.0, 5) # sigma\n",
32 changes: 27 additions & 5 deletions docs/source/notebooks/BeginnersGuide.ipynb
@@ -105,7 +105,7 @@
 "secondsim.to_dynamic() # all params owned by secondsim are now dynamic\n",
 "secondsim.sigma.to_static() # sigma is now static\n",
 "secondsim.I0.to_static() # I0 is now static\n",
-"params = secondsim.build_params_array() # automatically build a tensor for the dynamic params\n",
+"params = secondsim.get_values() # automatically build a tensor for the dynamic params\n",
 "plt.imshow(secondsim.brightness(x, y, params), origin=\"lower\")\n",
 "plt.axis(\"off\")\n",
 "plt.show()\n",
@@ -129,21 +129,21 @@
 "source": [
 "fig, ax = plt.subplots(1, 3, figsize=(15, 5))\n",
 "# List of tensors\n",
-"params_list = secondsim.build_params_list()\n",
+"params_list = secondsim.get_values(\"list\")\n",
 "print(\"Params list:\", params_list)\n",
 "ax[0].imshow(secondsim.brightness(x, y, params_list), origin=\"lower\")\n",
 "ax[0].axis(\"off\")\n",
 "ax[0].set_title(\"List of tensors\")\n",
 "\n",
 "# Single flattened tensor\n",
-"params_tensor = secondsim.build_params_array()\n",
+"params_tensor = secondsim.get_values(\"array\")\n",
 "print(\"Params tensor:\", params_tensor)\n",
 "ax[1].imshow(secondsim.brightness(x, y, params_tensor), origin=\"lower\")\n",
 "ax[1].axis(\"off\")\n",
 "ax[1].set_title(\"Single flattened tensor\")\n",
 "\n",
 "# Dictionary of tensors, using attribute names of either Param or Module objects\n",
-"params_dict = secondsim.build_params_dict()\n",
+"params_dict = secondsim.get_values(\"dict\")\n",
 "print(\"Params dict:\", params_dict)\n",
 "ax[2].imshow(secondsim.brightness(x, y, params_dict), origin=\"lower\")\n",
 "ax[2].axis(\"off\")\n",
@@ -188,6 +188,28 @@
 "plt.show()"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Update values in simulator\n",
+"\n",
"As we run sampling and optimization, we will want to change the stored values in the simulator. While you can update them individually with `param.value = new_value` there are also efficient ways of updating the parameters using the same format as getting the parameters!"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"current_values = secondsim.get_values()\n",
+"# something like: new_values = scipy.optimize.minimize(secondsim.likelihood, current_values).x\n",
+"new_values = current_values\n",
+"secondsim.set_values(new_values) # Yay, it's now updated.\n",
+"# This will work for array, list, or dict formats"
+]
+},
+{
 "cell_type": "markdown",
 "metadata": {},
@@ -271,7 +271,7 @@
 "outputs": [],
 "source": [
 "# same params as before since secondsim is all static or pointers to firstsim\n",
-"plt.imshow(combinedsim.brightness(x, y, combinedsim.build_params_array()), origin=\"lower\")\n",
+"plt.imshow(combinedsim.brightness(x, y, combinedsim.get_values()), origin=\"lower\")\n",
 "plt.axis(\"off\")\n",
 "plt.title(\"Combined brightness\")\n",
 "plt.show()"
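Aside on the new cells above: `set_values` is the write-side counterpart of `get_values` and accepts the same three formats. A minimal sketch of all three, assuming the `secondsim` simulator from this notebook (illustrative usage, not part of the diff):

```python
# Hypothetical usage: each get_values format can be fed straight back to set_values.
secondsim.set_values(secondsim.get_values("array"))  # flat tensor, filled in param order
secondsim.set_values(secondsim.get_values("list"))   # sequence, one entry per param
secondsim.set_values(secondsim.get_values("dict"))   # mapping, keyed by param name
```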
20 changes: 10 additions & 10 deletions docs/source/notebooks/WorkedExample.ipynb
@@ -277,7 +277,7 @@
 " likelihood.data = data[i] # Update the data for each observation\n",
 "\n",
 " # Fit the model\n",
-" x0 = likelihood.build_params_array()\n",
+" x0 = likelihood.get_values(\"array\") # default is \"array\", we just added it for show\n",
 " x0 += (\n",
 " torch.randn_like(x0) * x0 * 0.05\n",
 " ) # Add some noise to the initial guess since we cant start at the true values\n",
@@ -406,7 +406,7 @@
 "outputs": [],
 "source": [
 "# Fit light curve\n",
-"x0 = likelihood2.build_params_array()\n",
+"x0 = likelihood2.get_values()\n",
 "x0 += (\n",
 " torch.randn_like(x0) * x0 * 0.05\n",
 ") # Add some noise to the initial guess since we cant start at the true values\n",
@@ -445,7 +445,7 @@
 "outputs": [],
 "source": [
 "# extract light curve\n",
-"likelihood2.fill_dynamic_values(torch.tensor(res.x))\n",
+"likelihood2.set_values(torch.tensor(res.x))\n",
 "likelihood2.to_static(children_only=False)\n",
 "light_curve_flux = []\n",
 "light_curve_sigma = []\n",
@@ -454,7 +454,7 @@
 " model.models[0].flux.to_dynamic()\n",
 "\n",
 "# Compute uncertainty using inverse Hessian\n",
-"hess = -hessian(likelihood2, likelihood2.build_params_array(), strict=True)\n",
+"hess = -hessian(likelihood2, likelihood2.get_values(), strict=True)\n",
 "hess_inv = torch.linalg.inv(hess) # Invert the Hessian to get the covariance matrix\n",
 "light_curve_sigma = torch.sqrt(torch.diag(hess_inv)).numpy()"
 ]
@@ -552,7 +552,7 @@
 "outputs": [],
 "source": [
 "# Fit light curve\n",
-"x0 = likelihood2.build_params_array()\n",
+"x0 = likelihood2.get_values()\n",
 "x0 += (\n",
 " torch.randn_like(x0) * x0 * 0.05\n",
 ") # Add some noise to the initial guess since we cant start at the true values\n",
@@ -591,7 +591,7 @@
 "outputs": [],
 "source": [
 "# extract light curve\n",
-"likelihood2.fill_dynamic_values(torch.tensor(res.x))\n",
+"likelihood2.set_values(torch.tensor(res.x))\n",
 "light_curve_flux = []\n",
 "light_curve_sigma = []\n",
 "for model in secondmodel.models:\n",
@@ -637,7 +637,7 @@
 "source": [
 "likelihood2.to_static(False)\n",
 "lightcurvemodel.to_dynamic()\n",
-"fit_vals = likelihood2.build_params_array()\n",
+"fit_vals = likelihood2.get_values()\n",
 "hess = -hessian(likelihood2, fit_vals, strict=True)\n",
 "hess_inv = torch.linalg.inv(hess) # Invert the Hessian to get the covariance matrix\n",
 "light_curve_sigma = torch.sqrt(torch.diag(hess_inv).abs()).numpy()\n",
@@ -732,7 +732,7 @@
 " return vsim(torch.as_tensor(x, dtype=torch.float32)).numpy()\n",
 "\n",
 "\n",
-"x0 = likelihood2.build_params_array()\n",
+"x0 = likelihood2.get_values()\n",
 "nwalkers = 32\n",
 "ndim = len(x0)\n",
 "\n",
@@ -761,8 +761,8 @@
 "Galaxy.to_dynamic()\n",
 "true_values = (\n",
 " [SN.x0.value.item(), SN.y0.value.item()]\n",
-" + list(SN_lightcurve.build_params_array().numpy())\n",
-" + list(Galaxy.build_params_array().numpy())\n",
+" + list(SN_lightcurve.get_values().numpy())\n",
+" + list(Galaxy.get_values().numpy())\n",
 ")\n",
 "chain_mh = sampler.get_chain(flat=True)\n",
 "fig, axarr = plt.subplots(ndim, ndim, figsize=(12, 12))\n",
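The pattern this notebook repeats is get, optimize, set. A minimal sketch of that round trip, assuming `likelihood2` is callable on a flat parameter tensor as the `hessian` calls above imply (the objective sign and dtype here are illustrative assumptions, not from the diff):

```python
import torch
from scipy.optimize import minimize

x0 = likelihood2.get_values()  # current dynamic params as a flat tensor
res = minimize(
    lambda x: -likelihood2(torch.as_tensor(x, dtype=torch.float32)).item(),
    x0.detach().numpy(),
)
likelihood2.set_values(torch.tensor(res.x))  # write the fitted values back into the graph
```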
14 changes: 6 additions & 8 deletions src/caskade/__init__.py
@@ -18,10 +18,9 @@
     ParamTypeError,
     ActiveStateError,
     FillParamsError,
-    FillDynamicParamsError,
-    FillDynamicParamsArrayError,
-    FillDynamicParamsSequenceError,
-    FillDynamicParamsMappingError,
+    FillParamsArrayError,
+    FillParamsSequenceError,
+    FillParamsMappingError,
 )
 from .warnings import CaskadeWarning, InvalidValueWarning, SaveStateWarning
 
@@ -53,10 +52,9 @@
     "ParamTypeError",
     "ActiveStateError",
     "FillParamsError",
-    "FillDynamicParamsError",
-    "FillDynamicParamsArrayError",
-    "FillDynamicParamsSequenceError",
-    "FillDynamicParamsMappingError",
+    "FillParamsArrayError",
+    "FillParamsSequenceError",
+    "FillParamsMappingError",
     "CaskadeWarning",
     "InvalidValueWarning",
     "SaveStateWarning",
18 changes: 9 additions & 9 deletions src/caskade/backend.py
@@ -43,9 +43,9 @@
             self.make_array = self._make_array_torch
             self._array_type = self._array_type_torch
             self.concatenate = self._concatenate_torch
-            self.copy = self._copy_torch
             self.tolist = self._tolist_torch
             self.view = self._view_torch
+            self.detach = self._detach_torch
             self.as_array = self._as_array_torch
             self.to = self._to_torch
             self.to_numpy = self._to_numpy_torch
@@ -57,9 +57,9 @@
             self.make_array = self._make_array_jax
             self._array_type = self._array_type_jax
             self.concatenate = self._concatenate_jax
-            self.copy = self._copy_jax
             self.tolist = self._tolist_jax
             self.view = self._view_jax
+            self.detach = self._detach_jax
             self.as_array = self._as_array_jax
             self.to = self._to_jax
             self.to_numpy = self._to_numpy_jax
@@ -70,9 +70,9 @@
             self.make_array = self._make_array_numpy
             self._array_type = self._array_type_numpy
             self.concatenate = self._concatenate_numpy
-            self.copy = self._copy_numpy
             self.tolist = self._tolist_numpy
             self.view = self._view_numpy
+            self.detach = self._detach_numpy
             self.as_array = self._as_array_numpy
             self.to = self._to_numpy
             self.to_numpy = self._to_numpy_numpy
@@ -110,14 +110,14 @@
     def _concatenate_numpy(self, arrays, axis=0):
         return self.module.concatenate(arrays, axis=axis)
 
-    def _copy_torch(self, array):
-        return array.detach().clone()
+    def _detach_torch(self, array):
+        return array.detach()
 
-    def _copy_jax(self, array):
-        return self.module.copy(array)
+    def _detach_jax(self, array):
+        return array
 
-    def _copy_numpy(self, array):
-        return self.module.copy(array)
+    def _detach_numpy(self, array):
+        return array
 
     def _tolist_torch(self, array):
         return array.detach().cpu().tolist()
@@ -171,10 +171,10 @@
         return self.module.all(array)
 
     def log(self, array):
         return self.module.log(array)
 
     def exp(self, array):
         return self.module.exp(array)
 
     def sum(self, array, axis=None):
         return self.module.sum(array, axis=axis)
@@ -186,7 +186,7 @@
         return self.jax.nn.sigmoid(array)
 
     def _sigmoid_numpy(self, array):
         return 1 / (1 + self.module.exp(-array))
 
     def _logit_torch(self, array):
         return self.module.logit(array)
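For context on the `copy` to `detach` rename above: under torch, `detach` severs the autograd link without copying storage (the old `_copy_torch` did `detach().clone()`), while the jax and numpy backends now return the input unchanged, since jax arrays are immutable and numpy arrays carry no gradient state. A small torch-only illustration of the assumed behavior (not from the diff):

```python
import torch

a = torch.ones(3, requires_grad=True)
d = a.detach()
assert d.data_ptr() == a.data_ptr()  # same storage: no copy, unlike detach().clone()
assert not d.requires_grad           # but the autograd history is gone
```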
65 changes: 61 additions & 4 deletions src/caskade/collection.py
@@ -1,6 +1,15 @@
-from typing import Iterable
+from typing import Union, Sequence, Mapping
 
+from math import prod
+
 from .base import Node
+from .backend import backend, ArrayLike
+from .errors import (
+    ParamConfigurationError,
+    FillParamsArrayError,
+    FillParamsSequenceError,
+    FillParamsMappingError,
+)
 
 
 class NodeCollection(Node):
@@ -14,9 +23,57 @@ def to_static(self, **kwargs):
             if hasattr(node, "to_static"):
                 node.to_static(**kwargs)
 
-    def fill_values(self, values: Iterable):
-        for node, value in zip(self, values):
-            node.value = value
+    def set_values(
+        self, params: Union[ArrayLike, Sequence, Mapping], node_type="all", attribute="value"
+    ):
+        if node_type == "all":
+            node_type = "dynamic/static"
+        if isinstance(params, backend.array_type):
+            if params.shape[-1] == 0:
+                return  # No parameters to fill
+            # check for batch dimension
+            batch = len(params.shape) > 1
+            B = tuple(params.shape[:-1]) if batch else ()
+            pos = 0
+            for param in self:
+                if param.node_type not in node_type:
+                    continue
+                if not isinstance(param.shape, tuple):
+                    raise ParamConfigurationError(
+                        f"Param {param.name} has no shape. Dynamic parameters must have a shape to use {backend.array_type.__name__} input."
+                    )
+                # Handle scalar parameters
+                size = max(1, prod(param.shape))
+                try:
+                    val = backend.view(params[..., pos : pos + size], B + param.shape)
+                    setattr(param, attribute, val)
+                except (RuntimeError, IndexError, ValueError, TypeError):
+                    raise FillParamsArrayError(self.name, params, self)
+
+                pos += size
+            if pos != params.shape[-1]:
+                raise FillParamsArrayError(self.name, params, self)
+        elif isinstance(params, Sequence):
+            if len(params) == 0:
+                return
+            elif len(params) == len(self):
+                param_list = filter(lambda p: p.node_type in node_type, self)
+                for param, value in zip(param_list, params):
+                    setattr(param, attribute, value)
+            else:
+                raise FillParamsSequenceError(self.name, params, self)
+        elif isinstance(params, Mapping):
+            params_names = set(params.keys())
+            for name, param in self.children.items():
+                if name in params:
+                    params_names.remove(name)
+                    setattr(param, attribute, params[name])
+            if len(params_names) > 0:
+                raise FillParamsMappingError(self.name, self.children, next(iter(params_names)))
+        else:
+            raise TypeError(
+                f"Input params type {type(params)} not supported. Should be {backend.array_type.__name__}, Sequence, or Mapping."
+            )
 
     def copy(self):
         raise NotImplementedError
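To make the array branch of `set_values` concrete: values are consumed left to right along the last axis, each param claiming `prod(shape)` entries (minimum 1 for scalars), and any leading axes pass through as batch dimensions. A worked sketch with two hypothetical params of shapes `(2,)` and `(1,)` (names invented for illustration):

```python
import torch

params = torch.arange(12.0).reshape(4, 3)  # batch of 4, flat width 2 + 1 = 3
# set_values(params) would then assign, in collection order:
#   param_a.value = params[..., 0:2].view(4, 2)  # columns 0-1, reshaped to (4, 2)
#   param_b.value = params[..., 2:3].view(4, 1)  # column 2, reshaped to (4, 1)
# pos ends at 3 == params.shape[-1], so no FillParamsArrayError is raised.
```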