diff --git a/.github/workflows/cd.yaml b/.github/workflows/cd.yaml index b408cd4e..d70914a0 100644 --- a/.github/workflows/cd.yaml +++ b/.github/workflows/cd.yaml @@ -23,14 +23,14 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Build sdist and wheel run: pipx run build - - uses: actions/upload-artifact@v4 + - uses: actions/upload-artifact@v5 with: path: dist @@ -45,11 +45,11 @@ jobs: permissions: id-token: write steps: - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 name: Install Python with: python-version: "3.10" - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v6 with: name: artifact path: dist @@ -58,7 +58,7 @@ jobs: ls -ltrh ls -ltrh dist - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1.12.4 + uses: pypa/gh-action-pypi-publish@v1.13.0 with: repository-url: https://test.pypi.org/legacy/ verbose: true @@ -91,10 +91,10 @@ jobs: if: github.event_name == 'release' && github.event.action == 'published' steps: - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v6 with: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@v1.12.4 + - uses: pypa/gh-action-pypi-publish@v1.13.0 if: startsWith(github.ref, 'refs/tags') diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 22e0d18d..b14db8cf 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -20,7 +20,7 @@ jobs: steps: - uses: actions/checkout@master - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" - name: Record State @@ -44,12 +44,23 @@ jobs: python -m pip install pytest-github-actions-annotate-failures - name: Install AstroPhot run: | - pip install -e . + pip install -e ".[dev]" pip show ${{ env.PROJECT_NAME }} shell: bash - - name: Test with pytest + - name: Test with pytest [torch] run: | - pytest -vvv --cov=${{ env.PROJECT_NAME }} --cov-report=xml --cov-report=term tests/ + coverage run --source=${{ env.PROJECT_NAME }} -m pytest tests/ + shell: bash + env: + CASKADE_BACKEND: torch + - name: Extra coverage report for jax checks + run: | + echo "Running extra coverage report for jax checks" + coverage run --append --source=${{ env.PROJECT_NAME }} -m pytest tests/ + shell: bash + env: + JAX_ENABLE_X64: True + CASKADE_BACKEND: jax - name: Upload coverage reports to Codecov with GitHub Action uses: codecov/codecov-action@v5 diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml index 5a542e87..eeffce2e 100644 --- a/.github/workflows/testing.yaml +++ b/.github/workflows/testing.yaml @@ -15,12 +15,12 @@ jobs: runs-on: ${{matrix.os}} strategy: matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.10", "3.11", "3.12"] os: [ubuntu-latest, windows-latest, macOS-latest] steps: - uses: actions/checkout@master - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Record State @@ -35,12 +35,13 @@ jobs: python -m pip install --upgrade pip pip install pytest pip install wheel + pip install jax if [ -f requirements.txt ]; then pip install -r requirements.txt; fi shell: bash - name: Install AstroPhot run: | cd $GITHUB_WORKSPACE/ - pip install . + pip install .[dev] pip show astrophot shell: bash - name: Test with pytest diff --git a/.gitignore b/.gitignore index 7844bc7d..763be6f1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ tests/*.yaml docs/source/tutorials/*.fits docs/source/tutorials/*.yaml docs/source/tutorials/*.jpg -docs/autophot.*rst +docs/source/astrophotdocs/*.ipynb docs/modules.rst pip_cheatsheet.txt .gitpod.yml diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 6ef33248..3989c638 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -20,11 +20,14 @@ sphinx: build: os: "ubuntu-20.04" tools: - python: "3.9" + python: "3.12" apt_packages: - pandoc # Specify pandoc to be installed via apt-get + - graphviz jobs: pre_build: + # Build docstring jupyter notebooks + - "python make_docs.py" # Generate the Sphinx configuration for this Jupyter Book so it builds. - "jupyter-book config sphinx docs/source/" # Create font cache ahead of jupyter book diff --git a/astrophot/__init__.py b/astrophot/__init__.py index 91b884a8..f863ad11 100644 --- a/astrophot/__init__.py +++ b/astrophot/__init__.py @@ -1,8 +1,28 @@ import argparse import requests import torch -from .parse_config import galfit_config, basic_config -from . import models, image, plots, utils, fit, param, AP_config +from . import config, models, plots, utils, fit, image, errors +from .param import forward, Param, Module + +from .image import ( + Image, + ImageList, + TargetImage, + TargetImageList, + SIPModelImage, + SIPTargetImage, + CMOSModelImage, + CMOSTargetImage, + JacobianImage, + JacobianImageList, + PSFImage, + ModelImage, + ModelImageList, + Window, + WindowList, +) +from .models import Model +from .backend_obj import backend, ArrayLike try: from ._version import version as VERSION # noqa @@ -21,32 +41,10 @@ def run_from_terminal() -> None: """ - Execute AstroPhot from the command line with various options. - - This function uses the `argparse` module to parse command line arguments and execute the appropriate functionality. - It accepts the following arguments: - - - `filename`: the path to the configuration file. Or just 'tutorial' to download tutorials. - - `--config`: the type of configuration file being provided. One of: astrophot, galfit. - - `-v`, `--version`: print the current AstroPhot version to screen. - - `--log`: set the log file name for AstroPhot. Use 'none' to suppress the log file. - - `-q`: quiet flag to stop command line output, only print to log file. - - `--dtype`: set the float point precision. Must be one of: float64, float32. - - `--device`: set the device for AstroPhot to use for computations. Must be one of: cpu, gpu. - - If the `filename` argument is not provided, it raises a `RuntimeError`. - If the `filename` argument is `tutorial` or `tutorials`, - it downloads tutorials from various URLs and saves them locally. - - This function logs messages using the `AP_config` module, - which sets the logging output based on the `--log` and `-q` arguments. - The `dtype` and `device` of AstroPhot can also be set using the `--dtype` and `--device` arguments, respectively. - - Returns: - None + Running from terminal no longer supported. This is only used for convenience to download the tutorials. """ - AP_config.ap_logger.debug("running from the terminal, not sure if it will catch me.") + config.logger.debug("running from the terminal, not sure if it will catch me.") parser = argparse.ArgumentParser( prog="astrophot", description="Fast and flexible astronomical image photometry package. For the documentation go to: https://astrophot.readthedocs.io", @@ -58,14 +56,14 @@ def run_from_terminal() -> None: metavar="configfile", help="the path to the configuration file. Or just 'tutorial' to download tutorials.", ) - parser.add_argument( - "--config", - type=str, - default="astrophot", - choices=["astrophot", "galfit"], - metavar="format", - help="The type of configuration file being being provided. One of: astrophot, galfit.", - ) + # parser.add_argument( + # "--config", + # type=str, + # default="astrophot", + # choices=["astrophot", "galfit"], + # metavar="format", + # help="The type of configuration file being being provided. One of: astrophot, galfit.", + # ) parser.add_argument( "-v", "--version", @@ -73,45 +71,45 @@ def run_from_terminal() -> None: version=f"%(prog)s {__version__}", help="print the current AstroPhot version to screen", ) - parser.add_argument( - "--log", - type=str, - metavar="logfile.log", - help="set the log file name for AstroPhot. use 'none' to suppress the log file.", - ) - parser.add_argument( - "-q", - action="store_true", - help="quiet flag to stop command line output, only print to log file", - ) - parser.add_argument( - "--dtype", - type=str, - choices=["float64", "float32"], - metavar="datatype", - help="set the float point precision. Must be one of: float64, float32", - ) - parser.add_argument( - "--device", - type=str, - choices=["cpu", "gpu"], - metavar="device", - help="set the device for AstroPhot to use for computations. Must be one of: cpu, gpu", - ) + # parser.add_argument( + # "--log", + # type=str, + # metavar="logfile.log", + # help="set the log file name for AstroPhot. use 'none' to suppress the log file.", + # ) + # parser.add_argument( + # "-q", + # action="store_true", + # help="quiet flag to stop command line output, only print to log file", + # ) + # parser.add_argument( + # "--dtype", + # type=str, + # choices=["float64", "float32"], + # metavar="datatype", + # help="set the float point precision. Must be one of: float64, float32", + # ) + # parser.add_argument( + # "--device", + # type=str, + # choices=["cpu", "gpu"], + # metavar="device", + # help="set the device for AstroPhot to use for computations. Must be one of: cpu, gpu", + # ) args = parser.parse_args() if args.log is not None: - AP_config.set_logging_output( + config.set_logging_output( stdout=not args.q, filename=None if args.log == "none" else args.log ) elif args.q: - AP_config.set_logging_output(stdout=not args.q, filename="AstroPhot.log") + config.set_logging_output(stdout=not args.q, filename="AstroPhot.log") if args.dtype is not None: - AP_config.dtype = torch.float64 if args.dtype == "float64" else torch.float32 + config.DTYPE = torch.float64 if args.dtype == "float64" else torch.float32 if args.device is not None: - AP_config.device = "cpu" if args.device == "cpu" else "cuda:0" + config.DEVICE = "cpu" if args.device == "cpu" else "cuda:0" if args.filename is None: raise RuntimeError( @@ -128,7 +126,6 @@ def run_from_terminal() -> None: "https://raw.github.com/Autostronomy/AstroPhot/main/docs/source/tutorials/BasicPSFModels.ipynb", "https://raw.github.com/Autostronomy/AstroPhot/main/docs/source/tutorials/AdvancedPSFModels.ipynb", "https://raw.github.com/Autostronomy/AstroPhot/main/docs/source/tutorials/ConstrainedModels.ipynb", - "https://raw.github.com/Autostronomy/AstroPhot-tutorials/main/docs/tutorials/simple_config.py", ] for url in tutorials: try: @@ -140,12 +137,41 @@ def run_from_terminal() -> None: f"WARNING: couldn't find tutorial: {url[url.rfind('/')+1:]} check internet connection" ) - AP_config.ap_logger.info("collected the tutorials") - elif args.config == "astrophot": - basic_config(args.filename) - elif args.config == "galfit": - galfit_config(args.filename) + config.logger.info("collected the tutorials") else: - raise ValueError( - f"Unrecognized configuration file format {args.config}. Should be one of: astrophot, galfit" - ) + raise ValueError(f"Unrecognized request") + + +__all__ = ( + "models", + "image", + "Model", + "Image", + "ImageList", + "TargetImage", + "TargetImageList", + "SIPModelImage", + "SIPTargetImage", + "CMOSModelImage", + "CMOSTargetImage", + "JacobianImage", + "JacobianImageList", + "PSFImage", + "ModelImage", + "ModelImageList", + "Window", + "WindowList", + "plots", + "utils", + "fit", + "forward", + "Param", + "errors", + "Module", + "config", + "backend", + "run_from_terminal", + "__version__", + "__author__", + "__email__", +) diff --git a/astrophot/backend_obj.py b/astrophot/backend_obj.py new file mode 100644 index 00000000..dc12fec7 --- /dev/null +++ b/astrophot/backend_obj.py @@ -0,0 +1,533 @@ +import os +import importlib +from typing import Annotated + +from torch import Tensor, dtype, device +import torch +import numpy as np +import caskade as ck + +from . import config + +ArrayLike = Annotated[ + Tensor, + "One of: torch.Tensor or jax.numpy.ndarray depending on the chosen backend.", +] +dtypeLike = Annotated[ + dtype, + "One of: torch.dtype or jax.numpy.dtype depending on the chosen backend.", +] +deviceLike = Annotated[ + device, + "One of: torch.device or jax.DeviceArray depending on the chosen backend.", +] + + +class Backend: + def __init__(self, backend=None): + self.backend = backend + + @property + def backend(self): + return self._backend + + @backend.setter + def backend(self, backend): + if backend is None: + backend = os.getenv("CASKADE_BACKEND", "torch") + ck.backend.backend = backend + self._load_backend(backend) + self._backend = backend + + def _load_backend(self, backend): + if backend == "torch": + self.module = importlib.import_module("torch") + self.setup_torch() + elif backend == "jax": + self.module = importlib.import_module("jax.numpy") + self.setup_jax() + else: + raise ValueError(f"Unsupported backend: {backend}") + + def setup_torch(self): + config.DTYPE = torch.float64 + config.DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" + self.make_array = self._make_array_torch + self._array_type = self._array_type_torch + self.concatenate = self._concatenate_torch + self.copy = self._copy_torch + self.tolist = self._tolist_torch + self.view = self._view_torch + self.as_array = self._as_array_torch + self.to = self._to_torch + self.to_numpy = self._to_numpy_torch + self.gammaln = self._gammaln_torch + self.logit = self._logit_torch + self.sigmoid = self._sigmoid_torch + self.repeat = self._repeat_torch + self.stack = self._stack_torch + self.transpose = self._transpose_torch + self.upsample2d = self._upsample2d_torch + self.pad = self._pad_torch + self.LinAlgErr = self.module._C._LinAlgError + self.roll = self._roll_torch + self.clamp = self._clamp_torch + self.flatten = self._flatten_torch + self.conv2d = self._conv2d_torch + self.mean = self._mean_torch + self.sum = self._sum_torch + self.max = self._max_torch + self.topk = self._topk_torch + self.bessel_j1 = self._bessel_j1_torch + self.bessel_k1 = self._bessel_k1_torch + self.lgamma = self._lgamma_torch + self.hessian = self._hessian_torch + self.jacobian = self._jacobian_torch + self.jacfwd = self._jacfwd_torch + self.grad = self._grad_torch + self.vmap = self._vmap_torch + self.long = self._long_torch + self.detach = lambda x: x.detach() + self.fill_at_indices = self._fill_at_indices_torch + self.add_at_indices = self._add_at_indices_torch + self.and_at_indices = self._and_at_indices_torch + + def setup_jax(self): + self.jax = importlib.import_module("jax") + self.jax.config.update("jax_enable_x64", True) + config.DTYPE = None + config.DEVICE = None + self.make_array = self._make_array_jax + self._array_type = self._array_type_jax + self.concatenate = self._concatenate_jax + self.copy = self._copy_jax + self.tolist = self._tolist_jax + self.view = self._view_jax + self.as_array = self._as_array_jax + self.to = self._to_jax + self.to_numpy = self._to_numpy_jax + self.gammaln = self._gammaln_jax + self.logit = self._logit_jax + self.sigmoid = self._sigmoid_jax + self.repeat = self._repeat_jax + self.stack = self._stack_jax + self.transpose = self._transpose_jax + self.upsample2d = self._upsample2d_jax + self.pad = self._pad_jax + self.LinAlgErr = Exception + self.roll = self._roll_jax + self.clamp = self._clamp_jax + self.flatten = self._flatten_jax + self.conv2d = self._conv2d_jax + self.mean = self._mean_jax + self.sum = self._sum_jax + self.max = self._max_jax + self.topk = self._topk_jax + self.bessel_j1 = self._bessel_j1_jax + self.bessel_k1 = self._bessel_k1_jax + self.lgamma = self._lgamma_jax + self.hessian = self._hessian_jax + self.jacobian = self._jacobian_jax + self.jacfwd = self._jacfwd_jax + self.grad = self._grad_jax + self.vmap = self._vmap_jax + self.long = self._long_jax + self.detach = lambda x: x + self.fill_at_indices = self._fill_at_indices_jax + self.add_at_indices = self._add_at_indices_jax + self.and_at_indices = self._and_at_indices_jax + + @property + def array_type(self): + return self._array_type() + + def _make_array_torch(self, array, dtype=None, device=None): + return self.module.tensor(array, dtype=dtype, device=device) + + def _make_array_jax(self, array, dtype=None, **kwargs): + return self.module.array(array, dtype=dtype) + + def _array_type_torch(self): + return self.module.Tensor + + def _array_type_jax(self): + return self.module.ndarray + + def _concatenate_torch(self, arrays, dim=0): + return self.module.cat(arrays, dim=dim) + + def _concatenate_jax(self, arrays, dim=0): + return self.module.concatenate(arrays, axis=dim) + + def _copy_torch(self, array): + return array.detach().clone() + + def _copy_jax(self, array): + return self.module.copy(array) + + def _tolist_torch(self, array): + return array.detach().cpu().tolist() + + def _tolist_jax(self, array): + return array.block_until_ready().tolist() + + def _view_torch(self, array, shape): + return array.reshape(shape) + + def _view_jax(self, array, shape): + return array.reshape(shape) + + def _as_array_torch(self, array, dtype=None, device=None): + return self.module.as_tensor(array, dtype=dtype, device=device) + + def _as_array_jax(self, array, dtype=None, **kwargs): + return self.module.asarray(array, dtype=dtype) + + def _to_torch(self, array, dtype=None, device=None): + return array.to(dtype=dtype, device=device) + + def _to_jax(self, array, dtype=None, device=None): + return self.jax.device_put(array.astype(dtype), device=device) + + def _to_numpy_torch(self, array): + return array.detach().cpu().numpy() + + def _to_numpy_jax(self, array): + return np.array(array.block_until_ready()) + + def _repeat_torch(self, a, repeats, axis=None): + return self.module.repeat_interleave(a, repeats, dim=axis) + + def _repeat_jax(self, a, repeats, axis=None): + return self.module.repeat(a, repeats, axis=axis) + + def _stack_torch(self, arrays, dim=0): + return self.module.stack(arrays, dim=dim) + + def _stack_jax(self, arrays, dim=0): + return self.module.stack(arrays, axis=dim) + + def _transpose_torch(self, array, *args): + return self.module.transpose(array, *args) + + def _transpose_jax(self, array, *args): + permutation = np.arange(array.ndim) + permutation[np.sort(args)] = args + return self.module.transpose(array, permutation) + + def _gammaln_torch(self, array): + return self.module.special.gammaln(array) + + def _gammaln_jax(self, array): + return self.jax.scipy.special.gammaln(array) + + def _sigmoid_torch(self, array): + return self.module.sigmoid(array) + + def _sigmoid_jax(self, array): + return self.jax.nn.sigmoid(array) + + def _logit_torch(self, array): + return self.module.logit(array) + + def _logit_jax(self, array): + return self.jax.scipy.special.logit(array) + + def _upsample2d_torch(self, array, scale_factor, method): + U = self.module.nn.Upsample(scale_factor=scale_factor, mode=method) + array = U(array) / scale_factor**2 + return array + + def _upsample2d_jax(self, array, scale_factor, method): + if method == "nearest": + method = "bilinear" # no nearest neighbor interpolation in jax + new_shape = list(array.shape) + new_shape[-2] = array.shape[-2] * scale_factor + new_shape[-1] = array.shape[-1] * scale_factor + return self.jax.image.resize(array, new_shape, method=method) + + def _pad_torch(self, array, padding, mode): + return self.module.nn.functional.pad(array, padding[-4:], mode=mode) + + def _pad_jax(self, array, padding, mode): + if mode == "replicate": + mode = "edge" + padding = np.array(padding).reshape(-1, 2) + return self.module.pad(array, padding, mode=mode) + + def _roll_torch(self, array, shifts, dims): + return self.module.roll(array, shifts, dims=dims) + + def _roll_jax(self, array, shifts, dims): + return self.module.roll(array, shifts, axis=dims) + + def _clamp_torch(self, array, min, max): + return self.module.clamp(array, min, max) + + def _clamp_jax(self, array, min, max): + return self.module.clip(array, min, max) + + def _long_torch(self, array): + return array.long() + + def _long_jax(self, array): + return self.module.astype(array, self.module.int64) + + def _conv2d_torch(self, input, kernel, padding, stride=1): + return self.module.nn.functional.conv2d( + input, + kernel, + padding=padding, + stride=stride, + ) + + def _conv2d_jax(self, input, kernel, padding, stride=1): + return self.jax.lax.conv_general_dilated( + input, kernel, window_strides=(stride, stride), padding=padding + ) + + def _mean_torch(self, array, dim=None): + return self.module.mean(array, dim=dim) + + def _mean_jax(self, array, dim=None): + return self.module.mean(array, axis=dim) + + def _sum_torch(self, array, dim=None): + return self.module.sum(array, dim=dim) + + def _sum_jax(self, array, dim=None): + return self.module.sum(array, axis=dim) + + def _max_torch(self, array, dim=None): + return array.amax(dim=dim) + + def _max_jax(self, array, dim=None): + return self.module.max(array, axis=dim) + + def _topk_torch(self, array, k): + return self.module.topk(array, k=k) + + def _topk_jax(self, array, k): + return self.jax.lax.top_k(array, k=k) + + def _bessel_j1_torch(self, array): + return self.module.special.bessel_j1(array) + + def _bessel_j1_jax(self, array): + return self.jax.scipy.special.bessel_jn(array, v=1)[-1] + + def _bessel_k1_torch(self, array): + return self.module.special.modified_bessel_k1(array) + + def _bessel_k1_jax(self, array): + return self.jax.scipy.special.kn(1, array) + + def _lgamma_torch(self, array): + return self.module.lgamma(array) + + def _lgamma_jax(self, array): + return self.jax.lax.lgamma(array) + + def _grad_torch(self, func): + return self.module.func.grad(func) + + def _grad_jax(self, func): + return self.jax.grad(func) + + def _jacobian_torch(self, func, x, strategy="forward-mode", vectorize=True, create_graph=False): + return self.module.autograd.functional.jacobian( + func, x, strategy=strategy, vectorize=vectorize, create_graph=create_graph + ) + + def _jacobian_jax(self, func, x, strategy="forward-mode", vectorize=True, create_graph=False): + if "forward" in strategy: + # n = x.size + # eye = self.module.eye(n) + # Jt = self.jax.vmap(lambda s: self.jax.jvp(func, (x,), (s,))[1])(eye) + # return self.module.moveaxis(Jt, 0, -1) + return self.jax.jacfwd(func)(x) + return self.jax.jacrev(func)(x) + + def _jacfwd_torch(self, func, argnums=0): + return self.module.func.jacfwd(func, argnums=argnums) + + def _jacfwd_jax(self, func, argnums=0): + return self.jax.jacfwd(func, argnums=argnums) + + def _hessian_torch(self, func): + return self.module.func.hessian(func) + + def _hessian_jax(self, func): + return self.jax.hessian(func) + + def _vmap_torch(self, *args, **kwargs): + return self.module.vmap(*args, **kwargs) + + def _vmap_jax(self, *args, **kwargs): + return self.jax.vmap(*args, **kwargs) + + def _fill_at_indices_torch(self, array, indices, values): + array[indices] = values + return array + + def _fill_at_indices_jax(self, array, indices, values): + array = array.at[indices].set(values) + return array + + def _add_at_indices_torch(self, array, indices, values): + array[indices] += values + return array + + def _add_at_indices_jax(self, array, indices, values): + array = array.at[indices].add(values) + return array + + def _and_at_indices_torch(self, array, indices, values): + array[indices] &= values + return array + + def _and_at_indices_jax(self, array, indices, values): + array = array.at[indices].set(array[indices] & values) + return array + + def _flatten_torch(self, array, start_dim=0, end_dim=-1): + return array.flatten(start_dim, end_dim) + + def _flatten_jax(self, array, start_dim=0, end_dim=-1): + shape = tuple(array.shape) + end_dim = (end_dim % len(shape)) + 1 + new_shape = shape[:start_dim] + (-1,) + shape[end_dim:] + return self.module.reshape(array, new_shape) + + def arange(self, *args, dtype=None, device=None): + return self.module.arange(*args, dtype=dtype, device=device) + + def linspace(self, start, end, steps, dtype=None, device=None): + return self.module.linspace(start, end, steps, dtype=dtype, device=device) + + def meshgrid(self, *arrays, indexing="ij"): + return self.module.meshgrid(*arrays, indexing=indexing) + + def searchsorted(self, array, value): + return self.module.searchsorted(array, value) + + def any(self, array): + return self.module.any(array) + + def all(self, array): + return self.module.all(array) + + def log(self, array): + return self.module.log(array) + + def log10(self, array): + return self.module.log10(array) + + def exp(self, array): + return self.module.exp(array) + + def sin(self, array): + return self.module.sin(array) + + def cos(self, array): + return self.module.cos(array) + + def cosh(self, array): + return self.module.cosh(array) + + def sqrt(self, array): + return self.module.sqrt(array) + + def abs(self, array): + return self.module.abs(array) + + def floor(self, array): + return self.module.floor(array) + + def tanh(self, array): + return self.module.tanh(array) + + def arctan(self, array): + return self.module.arctan(array) + + def arctan2(self, y, x): + return self.module.arctan2(y, x) + + def arcsin(self, array): + return self.module.arcsin(array) + + def round(self, array): + return self.module.round(array) + + def zeros(self, shape, dtype=None, device=None): + return self.module.zeros(shape, dtype=dtype, device=device) + + def zeros_like(self, array, dtype=None): + return self.module.zeros_like(array, dtype=dtype) + + def ones(self, shape, dtype=None, device=None): + return self.module.ones(shape, dtype=dtype, device=device) + + def ones_like(self, array, dtype=None): + return self.module.ones_like(array, dtype=dtype) + + def empty(self, shape, dtype=None, device=None): + return self.module.empty(shape, dtype=dtype, device=device) + + def eye(self, n, dtype=None, device=None): + return self.module.eye(n, dtype=dtype, device=device) + + def diag(self, array): + return self.module.diag(array) + + def outer(self, a, b): + return self.module.outer(a, b) + + def minimum(self, a, b): + return self.module.minimum(a, b) + + def maximum(self, a, b): + return self.module.maximum(a, b) + + def isnan(self, array): + return self.module.isnan(array) + + def isfinite(self, array): + return self.module.isfinite(array) + + def where(self, condition, x, y): + return self.module.where(condition, x, y) + + def allclose(self, a, b, rtol=1e-5, atol=1e-8): + return self.module.allclose(a, b, rtol=rtol, atol=atol) + + @property + def linalg(self): + return self.module.linalg + + @property + def fft(self): + return self.module.fft + + @property + def inf(self): + return self.module.inf + + @property + def bool(self): + return self.module.bool + + @property + def int32(self): + return self.module.int32 + + @property + def float32(self): + return self.module.float32 + + @property + def float64(self): + return self.module.float64 + + +backend = Backend() diff --git a/astrophot/AP_config.py b/astrophot/config.py similarity index 76% rename from astrophot/AP_config.py rename to astrophot/config.py index 722ccc2c..3f11da8c 100644 --- a/astrophot/AP_config.py +++ b/astrophot/config.py @@ -2,29 +2,28 @@ import logging import torch -__all__ = ["ap_dtype", "ap_device", "ap_logger", "set_logging_output"] +__all__ = ["DTYPE", "DEVICE", "logger", "set_logging_output"] -ap_dtype = torch.float64 -ap_device = "cuda:0" if torch.cuda.is_available() else "cpu" -ap_verbose = 0 +DTYPE = torch.float64 +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" logging.basicConfig( filename="AstroPhot.log", level=logging.INFO, format="%(asctime)s:%(levelname)s: %(message)s", ) -ap_logger = logging.getLogger() +logger = logging.getLogger() out_handler = logging.StreamHandler(sys.stdout) out_handler.setLevel(logging.INFO) out_handler.setFormatter(logging.Formatter("%(message)s")) -ap_logger.addHandler(out_handler) +logger.addHandler(out_handler) def set_logging_output(stdout=True, filename=None, **kwargs): """ Change the logging system for AstroPhot. Here you can set whether output prints to screen or to a logging file. - This function will remove all handlers from the current logger in ap_logger, + This function will remove all handlers from the current logger in logger, then add new handlers based on the input to the function. Parameters: @@ -39,11 +38,11 @@ def set_logging_output(stdout=True, filename=None, **kwargs): """ hi = 0 - while hi < len(ap_logger.handlers): - if isinstance(ap_logger.handlers[hi], logging.StreamHandler): - ap_logger.removeHandler(ap_logger.handlers[hi]) - elif isinstance(ap_logger.handlers[hi], logging.FileHandler): - ap_logger.removeHandler(ap_logger.handlers[hi]) + while hi < len(logger.handlers): + if isinstance(logger.handlers[hi], logging.StreamHandler): + logger.removeHandler(logger.handlers[hi]) + elif isinstance(logger.handlers[hi], logging.FileHandler): + logger.removeHandler(logger.handlers[hi]) else: hi += 1 @@ -51,8 +50,8 @@ def set_logging_output(stdout=True, filename=None, **kwargs): out_handler = logging.StreamHandler(sys.stdout) out_handler.setLevel(kwargs.get("stdout_level", logging.INFO)) out_handler.setFormatter(kwargs.get("stdout_formatter", logging.Formatter("%(message)s"))) - ap_logger.addHandler(out_handler) - ap_logger.debug("logging now going to stdout") + logger.addHandler(out_handler) + logger.debug("logging now going to stdout") if filename is not None: out_handler = logging.FileHandler(filename) out_handler.setLevel(kwargs.get("filename_level", logging.INFO)) @@ -62,5 +61,5 @@ def set_logging_output(stdout=True, filename=None, **kwargs): logging.Formatter("%(asctime)s:%(levelname)s: %(message)s"), ) ) - ap_logger.addHandler(out_handler) - ap_logger.debug("logging now going to %s" % filename) + logger.addHandler(out_handler) + logger.debug("logging now going to %s" % filename) diff --git a/astrophot/errors/__init__.py b/astrophot/errors/__init__.py index 88392248..924f120c 100644 --- a/astrophot/errors/__init__.py +++ b/astrophot/errors/__init__.py @@ -1,5 +1,16 @@ -from .base import * -from .fit import * -from .image import * -from .models import * -from .param import * +from .base import AstroPhotError, SpecificationConflict +from .fit import OptimizeStopFail, OptimizeStopSuccess +from .image import InvalidWindow, InvalidData, InvalidImage +from .models import InvalidTarget, UnrecognizedModel + +__all__ = ( + "AstroPhotError", + "SpecificationConflict", + "OptimizeStopFail", + "OptimizeStopSuccess", + "InvalidWindow", + "InvalidData", + "InvalidImage", + "InvalidTarget", + "UnrecognizedModel", +) diff --git a/astrophot/errors/base.py b/astrophot/errors/base.py index 0f6a2433..b64b0b4b 100644 --- a/astrophot/errors/base.py +++ b/astrophot/errors/base.py @@ -1,4 +1,4 @@ -__all__ = ("AstroPhotError", "NameNotAllowed", "SpecificationConflict") +__all__ = ("AstroPhotError", "SpecificationConflict") class AstroPhotError(Exception): @@ -6,20 +6,8 @@ class AstroPhotError(Exception): Base exception for all AstroPhot processes. """ - ... - - -class NameNotAllowed(AstroPhotError): - """ - Used for invalid names of AstroPhot objects - """ - - ... - class SpecificationConflict(AstroPhotError): """ Raised when the inputs to an object are conflicting and/or ambiguous """ - - ... diff --git a/astrophot/errors/fit.py b/astrophot/errors/fit.py index 19d9dede..0aa61620 100644 --- a/astrophot/errors/fit.py +++ b/astrophot/errors/fit.py @@ -1,11 +1,15 @@ from .base import AstroPhotError -__all__ = ("OptimizeStop",) +__all__ = ("OptimizeStopFail", "OptimizeStopSuccess") -class OptimizeStop(AstroPhotError): +class OptimizeStopFail(AstroPhotError): """ - Raised at any point to stop optimization process. + Raised at any point to stop optimization process due to failure. """ - pass + +class OptimizeStopSuccess(AstroPhotError): + """ + Raised at any point to stop optimization process due to success condition. + """ diff --git a/astrophot/errors/image.py b/astrophot/errors/image.py index ef77642a..cdf73fc4 100644 --- a/astrophot/errors/image.py +++ b/astrophot/errors/image.py @@ -1,12 +1,6 @@ from .base import AstroPhotError -__all__ = ( - "InvalidWindow", - "ConflicingWCS", - "InvalidData", - "InvalidImage", - "InvalidWCS", -) +__all__ = ("InvalidWindow", "InvalidData", "InvalidImage") class InvalidWindow(AstroPhotError): @@ -14,36 +8,14 @@ class InvalidWindow(AstroPhotError): Raised whenever a window is misspecified """ - ... - - -class ConflicingWCS(InvalidWindow): - """ - Raised when windows are compared and have WCS prescriptions which do not agree - """ - - ... - class InvalidData(AstroPhotError): """ - Raised when an image object can't determine the data it is holding. + Raised when the data provided to an image is invalid or cannot be processed. """ - ... - class InvalidImage(AstroPhotError): """ Raised when an image object cannot be used as given. """ - - ... - - -class InvalidWCS(AstroPhotError): - """ - Raised when the WCS is not appropriate as given. - """ - - ... diff --git a/astrophot/errors/models.py b/astrophot/errors/models.py index 9de693f4..78cfdc4c 100644 --- a/astrophot/errors/models.py +++ b/astrophot/errors/models.py @@ -1,14 +1,6 @@ from .base import AstroPhotError -__all__ = ("InvalidModel", "InvalidTarget", "UnrecognizedModel") - - -class InvalidModel(AstroPhotError): - """ - Catches when a model object is inappropriate for this instance. - """ - - ... +__all__ = ("InvalidTarget", "UnrecognizedModel") class InvalidTarget(AstroPhotError): @@ -16,12 +8,8 @@ class InvalidTarget(AstroPhotError): Catches when a target object is assigned incorrectly. """ - ... - class UnrecognizedModel(AstroPhotError): """ Raised when the user tries to invoke a model that does not exist. """ - - ... diff --git a/astrophot/errors/param.py b/astrophot/errors/param.py deleted file mode 100644 index afa068a3..00000000 --- a/astrophot/errors/param.py +++ /dev/null @@ -1,11 +0,0 @@ -from .base import AstroPhotError - -__all__ = ("InvalidParameter",) - - -class InvalidParameter(AstroPhotError): - """ - Catches when a parameter object is assigned incorrectly. - """ - - ... diff --git a/astrophot/fit/__init__.py b/astrophot/fit/__init__.py index 13976d48..987035bc 100644 --- a/astrophot/fit/__init__.py +++ b/astrophot/fit/__init__.py @@ -1,33 +1,24 @@ -from .base import * -from .lm import * -from .oldlm import * -from .gradient import * -from .iterative import * -from .minifit import * +from .lm import LM, LMfast +from .gradient import Grad, Slalom +from .iterative import Iter, IterParam +from .scipy_fit import ScipyFit +from .minifit import MiniFit +from .hmc import HMC +from .mala import MALA +from .mhmcmc import MHMCMC +from . import func -try: - from .hmc import * - from .nuts import * -except AssertionError as e: - print("Could not load HMC or NUTS due to:", str(e)) -from .mhmcmc import * - -""" -base: This module defines the base class BaseOptimizer, - which is used as the parent class for all optimization algorithms in AstroPhot. - This module contains helper functions used across multiple optimization algorithms, - such as computing gradients and making copies of models. - -LM: This module defines the class LM, - which uses the Levenberg-Marquardt algorithm to perform optimization. - This algorithm adjusts the learning rate at each step to find the optimal value. - -Grad: This module defines the class Gradient-Optimizer, - which uses a simple gradient descent algorithm to perform optimization. - This algorithm adjusts the learning rate at each step to find the optimal value. - -Iterative: This module defines the class Iter, - which uses an iterative algorithm to perform Optimization. - This algorithm repeatedly fits each model individually until they all converge. - -""" +__all__ = [ + "LM", + "LMfast", + "Grad", + "Iter", + "MALA", + "IterParam", + "ScipyFit", + "MiniFit", + "HMC", + "MHMCMC", + "Slalom", + "func", +] diff --git a/astrophot/fit/base.py b/astrophot/fit/base.py index be9a2e56..b9152f9f 100644 --- a/astrophot/fit/base.py +++ b/astrophot/fit/base.py @@ -1,84 +1,65 @@ from typing import Sequence, Optional import numpy as np -import torch from scipy.optimize import minimize from scipy.special import gammainc -from .. import AP_config +from .. import config +from ..backend_obj import backend, ArrayLike +from ..models import Model +from ..image import Window -__all__ = ["BaseOptimizer"] +__all__ = ("BaseOptimizer",) -class BaseOptimizer(object): +class BaseOptimizer: """ Base optimizer object that other optimizers inherit from. Ensures consistent signature for the classes. - Parameters: - model: an AstroPhot_Model object that will have its (unlocked) parameters optimized [AstroPhot_Model] - initial_state: optional initialization for the parameters as a 1D tensor [tensor] - max_iter: maximum allowed number of iterations [int] - relative_tolerance: tolerance for counting success steps as: 0 < (Chi2^2 - Chi1^2)/Chi1^2 < tol [float] + **Args:** + - `model`: an AstroPhot_Model object that will have its (unlocked) parameters optimized [AstroPhot_Model] + - `initial_state`: optional initialization for the parameters as a 1D tensor [tensor] + - `relative_tolerance`: tolerance for counting success steps as: $0 < (\\chi_2^2 - \\chi_1^2)/\\chi_1^2 < \\text{tol}$ [float] + - `fit_window`: optional window to fit the model on [Window] + - `verbose`: verbosity level for the optimizer [int] + - `max_iter`: maximum allowed number of iterations [int] + - `save_steps`: optional string for path to save the model at each step (fitter dependent), e.g. "model_step_{step}.hdf5" [str] + - `fit_valid`: whether to fit while forcing parameters into valid range, or allow any value for each parameter. Default True [bool] """ def __init__( self, - model: "AstroPhot_Model", + model: Model, initial_state: Sequence = None, relative_tolerance: float = 1e-3, - fit_window: Optional["Window"] = None, - **kwargs, + fit_window: Optional[Window] = None, + verbose: int = 1, + max_iter: int = None, + save_steps: Optional[str] = None, + fit_valid: bool = True, ) -> None: - """ - Initializes a new instance of the class. - - Args: - model (object): An object representing the model. - initial_state (Optional[Sequence]): The initial state of the model could be any tensor. - If `None`, the model's default initial state will be used. - relative_tolerance (float): The relative tolerance for the optimization. - fit_parameters_identity (Optional[tuple]): a tuple of parameter identity strings which tell the LM optimizer which parameters of the model to fit. - **kwargs (dict): Additional keyword arguments. - - Attributes: - model (object): An object representing the model. - verbose (int): The verbosity level. - current_state (Tensor): The current state of the model. - max_iter (int): The maximum number of iterations. - iteration (int): The current iteration number. - save_steps (Optional[str]): Save intermediate results to this path. - relative_tolerance (float): The relative tolerance for the optimization. - lambda_history (List[ndarray]): A list of the optimization steps. - loss_history (List[float]): A list of the optimization losses. - message (str): An informational message. - """ self.model = model - self.verbose = kwargs.get("verbose", 0) + self.verbose = verbose + + if initial_state is None: + self.current_state = model.get_values() + else: + self.current_state = backend.as_array( + initial_state, dtype=config.DTYPE, device=config.DEVICE + ) if fit_window is None: self.fit_window = self.model.window else: self.fit_window = fit_window & self.model.window - if initial_state is None: - self.model.initialize() - initial_state = self.model.parameters.vector_representation() - else: - initial_state = torch.as_tensor( - initial_state, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - - self.current_state = torch.as_tensor( - initial_state, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - if self.verbose > 1: - AP_config.ap_logger.info(f"initial state: {self.current_state}") - self.max_iter = kwargs.get("max_iter", 100 * len(initial_state)) + self.max_iter = max_iter if max_iter is not None else 100 * len(self.current_state) self.iteration = 0 - self.save_steps = kwargs.get("save_steps", None) + self.save_steps = save_steps + self.fit_valid = fit_valid self.relative_tolerance = relative_tolerance self.lambda_history = [] @@ -86,71 +67,28 @@ def __init__( self.message = "" def fit(self) -> "BaseOptimizer": - """ - Raises: - NotImplementedError: Error is raised if this method is not implemented in a subclass of BaseOptimizer. - """ raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization") - def step(self, current_state: torch.Tensor = None) -> None: - """Args: - current_state (torch.Tensor, optional): Current state of the model parameters. Defaults to None. - - Raises: - NotImplementedError: Error is raised if this method is not implemented in a subclass of BaseOptimizer. - """ + def step(self, current_state: ArrayLike = None) -> None: raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization") def chi2min(self) -> float: """ Returns the minimum value of chi^2 loss in the loss history. - - Returns: - float: Minimum value of chi^2 loss. """ return np.nanmin(self.loss_history) def res(self) -> np.ndarray: - """Returns the value of lambda (regularization strength) at which minimum chi^2 loss was achieved. - - Returns: ndarray which is the Value of lambda at which minimum chi^2 loss was achieved. - """ + """Returns the value of lambda (state parameters) at which minimum loss was achieved.""" N = np.isfinite(self.loss_history) if np.sum(N) == 0: - AP_config.ap_logger.warning( + config.logger.warning( "Getting optimizer res with no real loss history, using current state" ) - return self.current_state.detach().cpu().numpy() + return backend.to_numpy(self.current_state) return np.array(self.lambda_history)[N][np.argmin(np.array(self.loss_history)[N])] def res_loss(self): + """returns the minimum value from the loss history.""" N = np.isfinite(self.loss_history) return np.min(np.array(self.loss_history)[N]) - - @staticmethod - def chi2contour(n_params: int, confidence: float = 0.682689492137) -> float: - """ - Calculates the chi^2 contour for the given number of parameters. - - Args: - n_params (int): The number of parameters. - confidence (float, optional): The confidence interval (default is 0.682689492137). - - Returns: - float: The calculated chi^2 contour value. - - Raises: - RuntimeError: If unable to compute the Chi^2 contour for the given number of parameters. - - """ - - def _f(x: float, nu: int) -> float: - """Helper function for calculating chi^2 contour.""" - return (gammainc(nu / 2, x / 2) - confidence) ** 2 - - for method in ["L-BFGS-B", "Powell", "Nelder-Mead"]: - res = minimize(_f, x0=n_params, args=(n_params,), method=method, tol=1e-8) - - if res.success: - return res.x[0] - raise RuntimeError(f"Unable to compute Chi^2 contour for ndf: {ndf}") diff --git a/astrophot/fit/func/__init__.py b/astrophot/fit/func/__init__.py new file mode 100644 index 00000000..58da703e --- /dev/null +++ b/astrophot/fit/func/__init__.py @@ -0,0 +1,13 @@ +from .lm import lm_step, hessian, gradient, hessian_poisson, gradient_poisson +from .slalom import slalom_step +from .mala import mala + +__all__ = [ + "lm_step", + "hessian", + "gradient", + "slalom_step", + "hessian_poisson", + "gradient_poisson", + "mala", +] diff --git a/astrophot/fit/func/lm.py b/astrophot/fit/func/lm.py new file mode 100644 index 00000000..2e88c816 --- /dev/null +++ b/astrophot/fit/func/lm.py @@ -0,0 +1,158 @@ +import numpy as np + +from ...errors import OptimizeStopFail, OptimizeStopSuccess +from ...backend_obj import backend +from ... import config + + +def nll(D, M, W): + """ + Negative log-likelihood for Gaussian noise. + D: data + M: model prediction + W: weights + """ + return 0.5 * backend.sum(W * (D - M) ** 2) + + +def nll_poisson(D, M): + """ + Negative log-likelihood for Poisson noise. + D: data + M: model prediction + """ + return backend.sum(M - D * backend.log(M + 1e-10)) # Adding small value to avoid log(0) + + +def gradient(J, W, D, M): + return J.T @ (W * (D - M))[:, None] + + +def gradient_poisson(J, D, M): + return J.T @ (D / M - 1)[:, None] + + +def hessian(J, W): + return J.T @ (W[:, None] * J) + + +def hessian_poisson(J, D, M): + return J.T @ ((D / (M**2 + 1e-10))[:, None] * J) + + +def damp_hessian(hess, L): + I = backend.eye(len(hess), dtype=config.DTYPE, device=config.DEVICE) + D = backend.ones_like(hess) - I + return hess * (I + D / (1 + L)) + L * I * backend.diag(hess) + + +def solve(hess, grad, L): + hessD = damp_hessian(hess, L) # (N, N) + while True: + try: + h = backend.linalg.solve(hessD, grad) + break + except backend.LinAlgErr: + hessD = hessD + L * backend.eye(len(hessD), dtype=config.DTYPE, device=config.DEVICE) + L = L * 2 + return hessD, h + + +def lm_step( + x, + data, + model, + weight, + jacobian, + L=1.0, + Lup=9.0, + Ldn=11.0, + tolerance=1e-4, + likelihood="gaussian", +): + L0 = L + M0 = backend.detach(model(x)) # (M,) + J = backend.detach(jacobian(x)) # (M, N) + + if likelihood == "gaussian": + nll0 = nll(data, M0, weight).item() + grad = gradient(J, weight, data, M0) # (N, 1) + hess = hessian(J, weight) # (N, N) + elif likelihood == "poisson": + nll0 = nll_poisson(data, M0).item() + grad = gradient_poisson(J, data, M0) # (N, 1) + hess = hessian_poisson(J, data, M0) # (N, N) + else: + raise ValueError(f"Unsupported likelihood: {likelihood}") + + del J + + if backend.allclose(grad, backend.zeros_like(grad)): + raise OptimizeStopSuccess("Gradient is zero, optimization converged.") + + best = {"x": backend.zeros_like(x), "nll": nll0, "L": L} + scary = {"x": None, "nll": np.inf, "L": None, "rho": np.inf} + nostep = True + improving = None + for i in range(10): + hessD, h = solve(hess, grad, L) # (N, N), (N, 1) + M1 = model(x + h.squeeze(1)) # (M,) + if likelihood == "gaussian": + nll1 = nll(data, M1, weight).item() + elif likelihood == "poisson": + nll1 = nll_poisson(data, M1).item() + + # Handle nan chi2 + if not np.isfinite(nll1): + L *= Lup + if improving is True: + break + improving = False + continue + + if backend.allclose(h, backend.zeros_like(h)) and L < 0.1: + if i == 0: + raise OptimizeStopSuccess("Step with zero length means optimization complete.") + break + + # actual nll improvement vs expected from linearization + rho = (nll0 - nll1) / backend.abs(h.T @ hessD @ h - 2 * grad.T @ h).item() + + if (nll1 < (nll0 + tolerance) and abs(rho - 1) < abs(scary["rho"] - 1)) or ( + nll1 < scary["nll"] and rho > -10 + ): + scary = {"x": x + h.squeeze(1), "nll": nll1, "L": L0, "rho": rho} + + # Avoid highly non-linear regions + if rho < 0.1 or rho > 2: + L *= Lup + if improving is True: + break + improving = False + continue + + if nll1 < best["nll"]: # new best + best = {"x": x + h.squeeze(1), "nll": nll1, "L": L} + nostep = False + L /= Ldn + if L < 1e-8 or improving is False: + break + improving = True + elif improving is True: # were improving, now not improving + break + else: # not improving and bad chi2, damp more + L *= Lup + if L >= 1e9: + break + improving = False + + # If we are improving chi2 by more than 10% then we can stop + if (best["nll"] - nll0) / nll0 < -0.1: + break + + if nostep: + if scary["x"] is not None and (scary["nll"] - nll0) / nll0 < tolerance: + return scary + raise OptimizeStopFail("Could not find step to improve chi^2") + + return best diff --git a/astrophot/fit/func/mala.py b/astrophot/fit/func/mala.py new file mode 100644 index 00000000..e6ae0b30 --- /dev/null +++ b/astrophot/fit/func/mala.py @@ -0,0 +1,73 @@ +import numpy as np +from tqdm import tqdm + + +def mala( + initial_state, # (num_chains, D) + log_prob, # (num_chains, D) -> (num_chains,) + log_prob_grad, # (num_chains, D) -> (num_chains, D) + num_samples, + epsilon, + mass_matrix, # covariance + progress=True, + desc="MALA", +): + x = np.array(initial_state, copy=True) + C, D = x.shape + + # mass, inv_mass, L + mass = np.array(mass_matrix, copy=False) # (D, D) + inv_mass = np.linalg.inv(mass) # (D, D) + L = np.linalg.cholesky(mass) # (D, D) + + samples = np.zeros((num_samples, C, D), dtype=x.dtype) # (N, C, D) + acceptance_rate = np.zeros([0]) # (0,) + logp = np.zeros((num_samples, C), dtype=x.dtype) # (N, C) + + # Cache current state + logp_cur = log_prob(x) # (C,) + grad_cur = log_prob_grad(x) # (C, D) + + # Random number generator + rng = np.random.default_rng(np.random.randint(1e9)) + + it = range(num_samples) + if progress: + it = tqdm(it, desc=desc, position=0, leave=True) + + for t in it: + # proposal using current grad + mu_x = 0.5 * (epsilon**2) * (grad_cur @ mass) # (C, D) + noise = rng.standard_normal((C, D)) @ L.T # (C, D) + x_prop = x + mu_x + epsilon * noise # (C, D) + + # Evaluate proposal + logp_prop = log_prob(x_prop) # (C,) + grad_prop = log_prob_grad(x_prop) # (C, D) + + mu_xprop = 0.5 * (epsilon**2) * (grad_prop @ mass) # (C, D) + + # q(x|x') \propto \exp(-0.5|x - x' - mu(x')|^2 / \epsilon^2) + d1 = x - x_prop - mu_xprop # for q(x | x') + d2 = x_prop - x - mu_x # for q(x'| x) + + logq1 = -0.5 * np.einsum("bi,ij,bj->b", d1, inv_mass, d1) / epsilon**2 # (C,) + logq2 = -0.5 * np.einsum("bi,ij,bj->b", d2, inv_mass, d2) / epsilon**2 # (C,) + + log_alpha = (logp_prop - logp_cur) + (logq1 - logq2) # (C,) + + accept = np.log(rng.random(C)) < log_alpha # (C,) + acceptance_rate = np.concatenate([acceptance_rate, accept]) + + # Update all three pieces in-place where accepted + x[accept] = x_prop[accept] # (C, D) + logp_cur[accept] = logp_prop[accept] # (C,) + grad_cur[accept] = grad_prop[accept] # (C, D) + + samples[t] = x.copy() + logp[t] = logp_cur.copy() + + if progress: + it.set_postfix(acc_rate=f"{acceptance_rate.mean():0.2f}") + + return samples, logp diff --git a/astrophot/fit/func/slalom.py b/astrophot/fit/func/slalom.py new file mode 100644 index 00000000..1bb76d68 --- /dev/null +++ b/astrophot/fit/func/slalom.py @@ -0,0 +1,49 @@ +import numpy as np + +from ...errors import OptimizeStopFail, OptimizeStopSuccess +from ...backend_obj import backend + + +def slalom_step(f, g, x0, m, S, N=10, up=1.3, down=0.5): + l = [f(x0).item()] + d = [0.0] + grad = g(x0) + if backend.allclose(grad, backend.zeros_like(grad)): + raise OptimizeStopSuccess("success: Gradient is zero, optimization converged.") + + D = grad + m + D = D / backend.linalg.norm(D) + seeking = False + for _ in range(N): + l.append(f(x0 - S * D).item()) + d.append(S) + + # Check if the last value is finite + if not np.isfinite(l[-1]): + l.pop() + d.pop() + S *= down + continue + + if seeking and np.argmin(l) == len(l) - 1: + # If we are seeking a minimum and the last value is the minimum, we can stop + break + + if len(l) < 3: + # Seek better step size based on loss improvement + if l[-1] < l[-2]: + S *= up + else: + S *= down + else: + O = np.polyfit(d[-3:], l[-3:], 2) + if O[0] > 0: + S = -O[1] / (2 * O[0]) + seeking = True + else: + S *= down + seeking = False + + if np.argmin(l) == 0: + raise OptimizeStopFail("fail: cannot find step to improve.") + return d[np.argmin(l)], l[np.argmin(l)], grad diff --git a/astrophot/fit/gp.py b/astrophot/fit/gp.py deleted file mode 100644 index f01212f3..00000000 --- a/astrophot/fit/gp.py +++ /dev/null @@ -1 +0,0 @@ -# Gaussian Process Regression diff --git a/astrophot/fit/gradient.py b/astrophot/fit/gradient.py index 4152be4c..11ae29a3 100644 --- a/astrophot/fit/gradient.py +++ b/astrophot/fit/gradient.py @@ -1,65 +1,61 @@ # Traditional gradient descent with Adam from time import time from typing import Sequence +from caskade import ValidContext import torch import numpy as np from .base import BaseOptimizer -from .. import AP_config +from .. import config +from ..backend_obj import backend, ArrayLike +from ..models import Model +from ..errors import OptimizeStopFail, OptimizeStopSuccess +from . import func +from ..utils.decorators import combine_docstrings __all__ = ["Grad"] +@combine_docstrings class Grad(BaseOptimizer): - """A gradient descent optimization wrapper for AstroPhot_Model objects. + """A gradient descent optimization wrapper for AstroPhot Model objects. The default method is "NAdam", a variant of the Adam optimization algorithm. This optimizer uses a combination of gradient descent and Nesterov momentum for faster convergence. The optimizer is instantiated with a set of initial parameters and optimization options provided by the user. The `fit` method performs the optimization, taking a series of gradient steps until a stopping criteria is met. - Parameters: - model (AstroPhot_Model): an AstroPhot_Model object with which to perform optimization. - initial_state (torch.Tensor, optional): an optional initial state for optimization. - method (str, optional): the optimization method to use for the update step. Defaults to "NAdam". - patience (int or None, optional): the number of iterations without improvement before the optimizer will exit early. Defaults to None. - optim_kwargs (dict, optional): a dictionary of keyword arguments to pass to the pytorch optimizer. - - Attributes: - model (AstroPhot_Model): the AstroPhot_Model object to optimize. - current_state (torch.Tensor): the current state of the parameters being optimized. - iteration (int): the number of iterations performed during the optimization. - loss_history (list): the history of loss values at each iteration of the optimization. - lambda_history (list): the history of parameter values at each iteration of the optimization. - optimizer (torch.optimizer): the PyTorch optimizer object being used. - patience (int or None): the number of iterations without improvement before the optimizer will exit early. - method (str): the optimization method being used. - optim_kwargs (dict): the dictionary of keyword arguments passed to the PyTorch optimizer. - - + **Args:** + - `likelihood` (str, optional): The likelihood function to use for the optimization. Defaults to "gaussian". + - `method` (str, optional): the optimization method to use for the update step. Defaults to "NAdam". + - `optim_kwargs` (dict, optional): a dictionary of keyword arguments to pass to the pytorch optimizer. + - `patience` (int, optional): number of steps with no improvement before stopping the optimization. Defaults to 10. + - `report_freq` (int, optional): frequency of reporting the optimization progress. Defaults to 10 steps. """ - def __init__(self, model: "AstroPhot_Model", initial_state: Sequence = None, **kwargs) -> None: - """Initialize the gradient descent optimizer. - - Args: - - model: instance of the model to be optimized. - - initial_state: Initial state of the model. - - patience: (optional) If a positive integer, then stop the optimization if there has been no improvement in the loss for this number of iterations. - - method: (optional) The name of the optimization method to use. Default is NAdam. - - optim_kwargs: (optional) Keyword arguments to be passed to the optimizer. - """ + def __init__( + self, + model: Model, + initial_state: Sequence = None, + likelihood="gaussian", + method="NAdam", + optim_kwargs={}, + patience: int = 10, + report_freq=10, + **kwargs, + ) -> None: super().__init__(model, initial_state, **kwargs) - self.model.parameters.flat_detach() + + self.likelihood = likelihood # set parameters from the user - self.patience = kwargs.get("patience", None) - self.method = kwargs.get("method", "NAdam").strip() - self.optim_kwargs = kwargs.get("optim_kwargs", {}) - self.report_freq = kwargs.get("report_freq", 10) + self.patience = patience + self.method = method + self.optim_kwargs = optim_kwargs + self.report_freq = report_freq - # Default learning rate if none given. Equalt to 1 / sqrt(parames) + # Default learning rate if none given. Equal to 1 / sqrt(parames) if "lr" not in self.optim_kwargs: self.optim_kwargs["lr"] = 0.1 / (len(self.current_state) ** (0.5)) @@ -69,26 +65,22 @@ def __init__(self, model: "AstroPhot_Model", initial_state: Sequence = None, **k (self.current_state,), **self.optim_kwargs ) - def compute_loss(self) -> torch.Tensor: - Ym = self.model(parameters=self.current_state, as_representation=True).flatten("data") - Yt = self.model.target[self.model.window].flatten("data") - W = ( - self.model.target[self.model.window].flatten("variance") - if self.model.target.has_variance - else 1.0 - ) - ndf = len(Yt) - len(self.current_state) - if self.model.target.has_mask: - mask = self.model.target[self.model.window].flatten("mask") - ndf -= torch.sum(mask) - mask = torch.logical_not(mask) - loss = torch.sum((Ym[mask] - Yt[mask]) ** 2 / W[mask]) / ndf + def density(self, state: torch.Tensor) -> torch.Tensor: + """ + Returns the density of the model at the given state vector. This is used + to calculate the likelihood of the model at the given state. Based on + ``self.likelihood``, will be either the Gaussian or Poisson negative log + likelihood. + """ + if self.likelihood == "gaussian": + return -self.model.gaussian_log_likelihood(state) + elif self.likelihood == "poisson": + return -self.model.poisson_log_likelihood(state) else: - loss = torch.sum((Ym - Yt) ** 2 / W) / ndf - return loss + raise ValueError(f"Unknown likelihood type: {self.likelihood}") def step(self) -> None: - """Take a single gradient step. Take a single gradient step. + """Take a single gradient step. Computes the loss function of the model, computes the gradient of the parameters using automatic differentiation, @@ -98,24 +90,23 @@ def step(self) -> None: self.iteration += 1 self.optimizer.zero_grad() - self.model.parameters.flat_detach() - - loss = self.compute_loss() + self.current_state.requires_grad = True + loss = self.density(self.current_state) loss.backward() - self.loss_history.append(loss.detach().cpu().item()) - self.lambda_history.append(np.copy(self.current_state.detach().cpu().numpy())) + self.loss_history.append(backend.to_numpy(loss)) + self.lambda_history.append(np.copy(backend.to_numpy(self.current_state))) if ( self.iteration % int(self.max_iter / self.report_freq) == 0 ) or self.iteration == self.max_iter: if self.verbose > 0: - AP_config.ap_logger.info(f"iter: {self.iteration}, loss: {loss.item()}") + config.logger.info(f"iter: {self.iteration}, posterior density: {loss.item():.6e}") if self.verbose > 1: - AP_config.ap_logger.info(f"gradient: {self.current_state.grad}") + config.logger.info(f"gradient: {self.current_state.grad}") self.optimizer.step() - def fit(self) -> "BaseOptimizer": + def fit(self) -> BaseOptimizer: """ Perform an iterative fit of the model parameters using the specified optimizer. @@ -138,17 +129,140 @@ def fit(self) -> "BaseOptimizer": self.message = self.message + " fail no improvement" break L = np.sort(self.loss_history) - if len(L) >= 3 and 0 < L[1] - L[0] < 1e-6 and 0 < L[2] - L[1] < 1e-6: + if len(L) >= 5 and 0 < (L[4] - L[0]) / L[0] < self.relative_tolerance: self.message = self.message + " success" break except KeyboardInterrupt: self.message = self.message + " fail interrupted" # Set the model parameters to the best values from the fit and clear any previous model sampling - self.model.parameters.vector_set_representation(self.res()) + self.model.set_values(torch.tensor(self.res(), dtype=config.DTYPE, device=config.DEVICE)) if self.verbose > 1: - AP_config.ap_logger.info( + config.logger.info( f"Grad Fitting complete in {time() - start_fit} sec with message: {self.message}" ) - self.model.parameters.flat_detach() + return self + + +class Slalom(BaseOptimizer): + """Slalom optimizer for Model objects. + + Slalom is a gradient descent optimization algorithm that uses a few + evaluations along the direction of the gradient to find the optimal step + size. This is done by assuming that the posterior density is a parabola and + then finding the minimum. + + The optimizer quickly finds the minimum of the posterior density along the + gradient direction, then updates the gradient at the new position and + repeats. This continues until it reaches a set of 5 steps which collectively + improve the posterior density by an amount smaller than the + `relative_tolerance` threshold, indicating that convergence has been + achieved. Note that this convergence criteria is not a guarantee, simply a + heuristic. The default tolerance was such that the optimizer will + substantially improve from the starting point, and do so quickly, but may + not reach all the way to the minimum of the posterior density. Like other + gradient descent algorithms, Slalom slows down considerably when trying to + achieve very high precision. + + **Args:** + - `S` (float, optional): The initial step size for the Slalom optimizer. Defaults to 1e-4. + - `likelihood` (str, optional): The likelihood function to use for the optimization. Defaults to "gaussian". + - `report_freq` (int, optional): Frequency of reporting the optimization progress. Defaults to 10 steps. + - `relative_tolerance` (float, optional): The relative tolerance for convergence. Defaults to 1e-4. + - `momentum` (float, optional): The momentum factor for the Slalom optimizer. Defaults to 0.5. + - `max_iter` (int, optional): The maximum number of iterations for the optimizer. Defaults to 1000. + """ + + def __init__( + self, + model: Model, + initial_state: Sequence = None, + S=1e-4, + likelihood: str = "gaussian", + report_freq: int = 10, + relative_tolerance: float = 1e-4, + momentum: float = 0.5, + max_iter: int = 1000, + **kwargs, + ) -> None: + """Initialize the Slalom optimizer.""" + super().__init__( + model, initial_state, relative_tolerance=relative_tolerance, max_iter=max_iter, **kwargs + ) + self.likelihood = likelihood + self.S = S + self.report_freq = report_freq + self.momentum = momentum + + def density(self, state: ArrayLike) -> ArrayLike: + """Calculate the density of the model at the given state. Based on + ``self.likelihood``, will be either the Gaussian or Poisson negative log + likelihood.""" + if self.likelihood == "gaussian": + return -self.model.gaussian_log_likelihood(state) + elif self.likelihood == "poisson": + return -self.model.poisson_log_likelihood(state) + else: + raise ValueError(f"Unknown likelihood type: {self.likelihood}") + + def fit(self) -> BaseOptimizer: + """Perform the Slalom optimization.""" + + grad_func = backend.grad(self.density) + momentum = backend.zeros_like(self.current_state) + self.S_history = [self.S] + self.loss_history = [self.density(self.current_state).item()] + self.lambda_history = [backend.to_numpy(self.current_state)] + self.start_fit = time() + + for i in range(self.max_iter): + + try: + # Perform the Slalom step + vstate = self.model.to_valid(self.current_state) + with ValidContext(self.model): + self.S, loss, grad = func.slalom_step( + self.density, grad_func, vstate, m=momentum, S=self.S + ) + self.current_state = self.model.from_valid( + vstate - self.S * (grad + momentum) / backend.linalg.norm(grad + momentum) + ) + momentum = self.momentum * (momentum + grad) + except OptimizeStopSuccess as e: + self.message = self.message + str(e) + break + except OptimizeStopFail as e: + if backend.allclose(momentum, backend.zeros_like(momentum)): + self.message = self.message + str(e) + break + momentum = backend.zeros_like(self.current_state) + continue + # Log the loss + self.S_history.append(self.S) + self.loss_history.append(loss) + self.lambda_history.append(backend.to_numpy(self.current_state)) + + if self.verbose > 0 and (i % int(self.report_freq) == 0 or i == self.max_iter - 1): + config.logger.info( + f"iter: {i}, step size: {self.S:.6e}, posterior density: {loss:.6e}" + ) + + if len(self.loss_history) >= 5: + relative_loss = (self.loss_history[-5] - self.loss_history[-1]) / self.loss_history[ + -1 + ] + if relative_loss < self.relative_tolerance: + self.message = self.message + " success" + break + else: + self.message = self.message + " fail. max iteration reached" + + # Set the model parameters to the best values from the fit + self.model.set_values( + backend.as_array(self.res(), dtype=config.DTYPE, device=config.DEVICE) + ) + if self.verbose > 0: + config.logger.info( + f"Slalom Fitting complete in {time() - self.start_fit} sec with message: {self.message}" + ) return self diff --git a/astrophot/fit/hmc.py b/astrophot/fit/hmc.py index f5e7d466..f726f8ee 100644 --- a/astrophot/fit/hmc.py +++ b/astrophot/fit/hmc.py @@ -2,17 +2,22 @@ from typing import Optional, Sequence import torch -import pyro -import pyro.distributions as dist -from pyro.infer import MCMC as pyro_MCMC -from pyro.infer import HMC as pyro_HMC -from pyro.infer.mcmc.adaptation import BlockMassMatrix -from pyro.ops.welford import WelfordCovariance + +try: + import pyro + import pyro.distributions as dist + from pyro.infer import MCMC as pyro_MCMC + from pyro.infer import HMC as pyro_HMC + from pyro.infer.mcmc.adaptation import BlockMassMatrix + from pyro.ops.welford import WelfordCovariance +except ImportError: + pyro = None from .base import BaseOptimizer -from ..models import AstroPhot_Model +from ..models import Model +from .. import config -__all__ = ["HMC"] +__all__ = ("HMC",) ########################################### @@ -24,10 +29,11 @@ def new_configure(self, mass_matrix_shape, adapt_mass_matrix=True, options={}): """ Sets up an initial mass matrix. - :param dict mass_matrix_shape: a dict that maps tuples of site names to the shape of + **Args:** + - `mass_matrix_shape`: a dict that maps tuples of site names to the shape of the corresponding mass matrix. Each tuple of site names corresponds to a block. - :param bool adapt_mass_matrix: a flag to decide whether an adaptation scheme will be used. - :param dict options: tensor options to construct the initial mass matrix. + - `adapt_mass_matrix`: a flag to decide whether an adaptation scheme will be used. + - `options`: tensor options to construct the initial mass matrix. """ inverse_mass_matrix = {} for site_names, shape in mass_matrix_shape.items(): @@ -53,48 +59,56 @@ def new_configure(self, mass_matrix_shape, adapt_mass_matrix=True, options={}): class HMC(BaseOptimizer): """Hamiltonian Monte-Carlo sampler wrapper for the Pyro package. - This MCMC algorithm uses gradients of the Chi^2 to more - efficiently explore the probability distribution. Consider using - the NUTS sampler instead of HMC, as it is generally better in most - aspects. + This MCMC algorithm uses gradients of the $\\chi^2$ to more + efficiently explore the probability distribution. More information on HMC can be found at: https://en.wikipedia.org/wiki/Hamiltonian_Monte_Carlo, https://arxiv.org/abs/1701.02434, and http://www.mcmchandbook.net/HandbookChapter5.pdf - Args: - model (AstroPhot_Model): The model which will be sampled. - initial_state (Optional[Sequence], optional): A 1D array with the values for each parameter in the model. These values should be in the form of "as_representation" in the model. Defaults to None. - max_iter (int, optional): The number of sampling steps to perform. Defaults to 1000. - epsilon (float, optional): The length of the integration step to perform for each leapfrog iteration. The momentum update will be of order epsilon * score. Defaults to 1e-5. - leapfrog_steps (int, optional): Number of steps to perform with leapfrog integrator per sample of the HMC. Defaults to 20. - inv_mass (float or array, optional): Inverse Mass matrix (covariance matrix) which can tune the behavior in each dimension to ensure better mixing when sampling. Defaults to the identity. - progress_bar (bool, optional): Whether to display a progress bar during sampling. Defaults to True. - prior (distribution, optional): Prior distribution for the parameters. Defaults to None. - warmup (int, optional): Number of warmup steps before actual sampling begins. Defaults to 100. - hmc_kwargs (dict, optional): Additional keyword arguments for the HMC sampler. Defaults to {}. - mcmc_kwargs (dict, optional): Additional keyword arguments for the MCMC process. Defaults to {}. + **Args:** + - `max_iter` (int, optional): The number of sampling steps to perform. Defaults to 1000. + - `epsilon` (float, optional): The length of the integration step to perform for each leapfrog iteration. The momentum update will be of order epsilon * score. Defaults to 1e-5. + - `leapfrog_steps` (int, optional): Number of steps to perform with leapfrog integrator per sample of the HMC. Defaults to 10. + - `inv_mass` (float or array, optional): Inverse Mass matrix (covariance matrix) which can tune the behavior in each dimension to ensure better mixing when sampling. Defaults to the identity. + - `progress_bar` (bool, optional): Whether to display a progress bar during sampling. Defaults to True. + - `prior` (distribution, optional): Prior distribution for the parameters. Defaults to None. + - `warmup` (int, optional): Number of warmup steps before actual sampling begins. Defaults to 100. + - `hmc_kwargs` (dict, optional): Additional keyword arguments for the HMC sampler. Defaults to {}. + - `mcmc_kwargs` (dict, optional): Additional keyword arguments for the MCMC process. Defaults to {}. """ def __init__( self, - model: AstroPhot_Model, + model: Model, initial_state: Optional[Sequence] = None, max_iter: int = 1000, + inv_mass: Optional[torch.Tensor] = None, + epsilon: float = 1e-4, + leapfrog_steps: int = 10, + progress_bar: bool = True, + prior: Optional[dist.Distribution] = None, + warmup: int = 100, + hmc_kwargs: dict = {}, + mcmc_kwargs: dict = {}, + likelihood: str = "gaussian", **kwargs, ): + if pyro is None: + raise ImportError("Pyro must be installed to use HMC.") super().__init__(model, initial_state, max_iter=max_iter, **kwargs) - self.inv_mass = kwargs.get("inv_mass", None) - self.epsilon = kwargs.get("epsilon", 1e-3) - self.leapfrog_steps = kwargs.get("leapfrog_steps", 20) - self.progress_bar = kwargs.get("progress_bar", True) - self.prior = kwargs.get("prior", None) - self.warmup = kwargs.get("warmup", 100) - self.hmc_kwargs = kwargs.get("hmc_kwargs", {}) - self.mcmc_kwargs = kwargs.get("mcmc_kwargs", {}) + self.inv_mass = inv_mass + self.epsilon = epsilon + self.leapfrog_steps = leapfrog_steps + self.progress_bar = progress_bar + self.prior = prior + self.warmup = warmup + self.hmc_kwargs = hmc_kwargs + self.mcmc_kwargs = mcmc_kwargs + self.likelihood = likelihood self.acceptance = None def fit( @@ -105,21 +119,20 @@ def fit( Records the chain for later examination. - Args: + **Args:** state (torch.Tensor, optional): Model parameters as a 1D tensor. - Returns: - HMC: An instance of the HMC class with updated chain. - """ def step(model, prior): x = pyro.sample("x", prior) # Log-likelihood function - model.parameters.flat_detach() - log_likelihood_value = -model.negative_log_likelihood( - parameters=x, as_representation=True - ) + if self.likelihood == "gaussian": + log_likelihood_value = model.gaussian_log_likelihood(params=x) + elif self.likelihood == "poisson": + log_likelihood_value = model.poisson_log_likelihood(params=x) + else: + raise ValueError(f"Unsupported likelihood type: {self.likelihood}") # Observe the log-likelihood pyro.factor("obs", log_likelihood_value) @@ -145,7 +158,7 @@ def step(model, prior): hmc_kernel.mass_matrix_adapter.inverse_mass_matrix = {("x",): self.inv_mass} # Provide an initial guess for the parameters - init_params = {"x": self.model.parameters.vector_representation()} + init_params = {"x": self.model.get_values()} # Run MCMC with the HMC sampler and the initial guess mcmc_kwargs = { @@ -163,9 +176,8 @@ def step(model, prior): # Extract posterior samples chain = mcmc.get_samples()["x"] - with torch.no_grad(): - for i in range(len(chain)): - chain[i] = self.model.parameters.vector_transform_rep_to_val(chain[i]) self.chain = chain - + self.model.set_values( + torch.as_tensor(self.chain[-1], dtype=config.DTYPE, device=config.DEVICE) + ) return self diff --git a/astrophot/fit/iterative.py b/astrophot/fit/iterative.py index ff04b934..2e9330ca 100644 --- a/astrophot/fit/iterative.py +++ b/astrophot/fit/iterative.py @@ -1,126 +1,109 @@ # Apply a different optimizer iteratively -from typing import Dict, Any, Sequence, Union -import os +from typing import Dict, Any, Union, Sequence, Literal from time import time -import random +from functools import partial +from caskade import ValidContext import numpy as np import torch from .base import BaseOptimizer -from ..models import AstroPhot_Model +from ..models import Model from .lm import LM -from ..param import Param_Mask -from .. import AP_config +from .. import config +from ..backend_obj import backend +from ..errors import OptimizeStopSuccess, OptimizeStopFail +from . import func -__all__ = ["Iter", "Iter_LM"] +__all__ = [ + "Iter", + # "Iter_LM" +] class Iter(BaseOptimizer): """Optimizer wrapper that performs optimization iteratively. - This optimizer applies a different optimizer to a group model iteratively. - It can be used for complex fits or when the number of models to fit is too large to fit in memory. - - Args: - model: An `AstroPhot_Model` object to perform optimization on. - method: The optimizer class to apply at each iteration step. - initial_state: Optional initial state for optimization, defaults to None. - max_iter: Maximum number of iterations, defaults to 100. - method_kwargs: Keyword arguments to pass to `method`. - **kwargs: Additional keyword arguments. - - Attributes: - ndf: Degrees of freedom of the data. - method: The optimizer class to apply at each iteration step. Default: Levenberg-Marquardt - method_kwargs: Keyword arguments to pass to `method`. - iteration: The number of iterations performed. - lambda_history: A list of the states at each iteration step. - loss_history: A list of the losses at each iteration step + This optimizer applies the LM optimizer to a group model iteratively one + model at a time. It can be used for complex fits or when the number of + models to fit is too large to fit in memory. Note that it will iterate + through the group model, but if models within the group are themselves group + models, then they will be optimized as a whole. This gives some flexibility + to structure the models in a useful way. + + If not given, the `lm_kwargs` will be set to a relative tolerance of 1e-3 + and a maximum of 15 iterations. This is to allow for faster convergence, it + is not worthwhile for a single model to spend lots of time optimizing when + its neighbors havent converged. + + **Args:** + - `max_iter`: Maximum number of iterations, defaults to 100. + - `lm_kwargs`: Keyword arguments to pass to `LM` optimizer. """ def __init__( self, - model: AstroPhot_Model, - method: BaseOptimizer = LM, + model: Model, initial_state: np.ndarray = None, max_iter: int = 100, - method_kwargs: Dict[str, Any] = {}, + lm_kwargs: Dict[str, Any] = {"verbose": 0}, **kwargs: Dict[str, Any], - ) -> None: + ): super().__init__(model, initial_state, max_iter=max_iter, **kwargs) - self.method = method - self.method_kwargs = method_kwargs - if "relative_tolerance" not in method_kwargs and isinstance(method, LM): + self.current_state = model.get_values() + self.lm_kwargs = lm_kwargs + if "relative_tolerance" not in lm_kwargs: # Lower tolerance since it's not worth fine tuning a model when its neighbors will be shifting soon anyway - self.method_kwargs["relative_tolerance"] = 1e-3 - self.method_kwargs["max_iter"] = 15 + self.lm_kwargs["relative_tolerance"] = 1e-3 + self.lm_kwargs["max_iter"] = 15 # # pixels # parameters - self.ndf = self.model.target[self.model.window].flatten("data").size(0) - len( + self.ndf = self.model.target[self.model.window].flatten("data").shape[0] - len( self.current_state ) - if self.model.target.has_mask: - # subtract masked pixels from degrees of freedom - self.ndf -= torch.sum(self.model.target[self.model.window].flatten("mask")).item() + # subtract masked pixels from degrees of freedom + self.ndf -= backend.sum(self.model.target[self.model.window].flatten("mask")).item() - def sub_step(self, model: "AstroPhot_Model") -> None: + def sub_step(self, model: Model, update_uncertainty=False): """ Perform optimization for a single model. - - Args: - model: The model to perform optimization on. """ self.Y -= model() - initial_target = model.target - model.target = model.target[model.window] - self.Y[model.window] - res = self.method(model, **self.method_kwargs).fit() - model.parameters.flat_detach() + initial_values = model.target.copy() + model.target = model.target - self.Y + res = LM(model, **self.lm_kwargs).fit(update_uncertainty=update_uncertainty) self.Y += model() if self.verbose > 1: - AP_config.ap_logger.info(res.message) - model.target = initial_target + config.logger.info(res.message) + model.target = initial_values - def step(self) -> None: + def step(self): """ Perform a single iteration of optimization. """ if self.verbose > 0: - AP_config.ap_logger.info("--------iter-------") + config.logger.info("--------iter-------") # Fit each model individually - for model in self.model.models.values(): + for model in self.model.models: if self.verbose > 0: - AP_config.ap_logger.info(model.name) + config.logger.info(model.name) self.sub_step(model) # Update the current state - self.current_state = self.model.parameters.vector_representation() + self.current_state = self.model.get_values() # Update the loss value with torch.no_grad(): if self.verbose > 0: - AP_config.ap_logger.info("Update Chi^2 with new parameters") - self.Y = self.model( - parameters=self.current_state, - as_representation=True, - ) + config.logger.info("Update Chi^2 with new parameters") + self.Y = self.model(params=self.current_state) D = self.model.target[self.model.window].flatten("data") - V = ( - self.model.target[self.model.window].flatten("variance") - if self.model.target.has_variance - else 1.0 - ) - if self.model.target.has_mask: - M = self.model.target[self.model.window].flatten("mask") - loss = ( - torch.sum((((D - self.Y.flatten("data")) ** 2) / V)[torch.logical_not(M)]) - / self.ndf - ) - else: - loss = torch.sum(((D - self.Y.flatten("data")) ** 2 / V)) / self.ndf + V = self.model.target[self.model.window].flatten("variance") + M = self.model.target[self.model.window].flatten("mask") + loss = backend.sum((((D - self.Y.flatten("data")) ** 2) / V)[~M]) / self.ndf if self.verbose > 0: - AP_config.ap_logger.info(f"Loss: {loss.item()}") - self.lambda_history.append(np.copy((self.current_state).detach().cpu().numpy())) + config.logger.info(f"Loss: {loss.item()}") + self.lambda_history.append(np.copy(backend.to_numpy(self.current_state))) self.loss_history.append(loss.item()) # Test for convergence @@ -135,26 +118,16 @@ def step(self) -> None: self.iteration += 1 - def fit(self) -> "BaseOptimizer": + def fit(self) -> BaseOptimizer: """ - Fit the models to the target. - - + Perform the iterative fitting process until convergence or maximum iterations reached. """ - self.iteration = 0 - self.Y = self.model(parameters=self.current_state, as_representation=True) + self.Y = self.model(params=self.current_state) start_fit = time() try: while True: self.step() - if self.save_steps is not None: - self.model.save( - os.path.join( - self.save_steps, - f"{self.model.name}_Iteration_{self.iteration:03d}.yaml", - ) - ) if self.iteration > 2 and self._count_finish >= 2: self.message = self.message + "success" break @@ -165,174 +138,361 @@ def fit(self) -> "BaseOptimizer": except KeyboardInterrupt: self.message = self.message + "fail interrupted" - self.model.parameters.vector_set_representation(self.res()) + self.model.set_values( + backend.as_array(self.res(), dtype=config.DTYPE, device=config.DEVICE) + ) if self.verbose > 1: - AP_config.ap_logger.info( + config.logger.info( f"Iter Fitting complete in {time() - start_fit} sec with message: {self.message}" ) return self -class Iter_LM(BaseOptimizer): +class IterParam(BaseOptimizer): """Optimization wrapper that call LM optimizer on subsets of variables. - Iter_LM takes the full set of parameters for a model and breaks - them down into chunks as specified by the user. It then calls - Levenberg-Marquardt optimization on the subset of parameters, and - iterates through all subsets until every parameter has been - optimized. It cycles through these chunks until convergence. This - method is very powerful in situations where the full optimization - problem cannot fit in memory, or where the optimization problem is - too complex to tackle as a single large problem. In full LM - optimization a single problematic parameter can ripple into issues - with every other parameter, so breaking the problem down can - sometimes make an otherwise intractable problem easier. For small - problems with only a few models, it is likely better to optimize - the full problem with LM as, when it works, LM is faster than the - Iter_LM method. + IterParam takes the full set of parameters for a model and breaks them down + into chunks as specified by the user. It then calls Levenberg-Marquardt + optimization on the subset of parameters, and iterates through all subsets + until every parameter has been optimized. It cycles through these chunks + until convergence. This method is very powerful in situations where the full + optimization problem cannot fit in memory, or where the optimization problem + is too complex to tackle as a single large problem. In full LM optimization + a single problematic parameter can ripple into issues with every other + parameter, so breaking the problem down can sometimes make an otherwise + intractable problem easier. For small problems with only a few models, it is + likely better to optimize the full problem with LM as, when it works, LM is + faster than the IterParam method. Args: - chunks (Union[int, tuple]): Specify how to break down the model parameters. If an integer, at each iteration the algorithm will break the parameters into groups of that size. If a tuple, should be a tuple of tuples of strings which give an explicit pairing of parameters to optimize, note that it is allowed to have variable size chunks this way. Default: 50 - method (str): How to iterate through the chunks. Should be one of: random, sequential. Default: random + chunks (Union[int, tuple]): Specify how to break down the model + parameters. If an integer, at each iteration the algorithm will break the + parameters into groups of that size. If a tuple, should be a tuple of + arrays of length num_dimensions which act as selectors for the parameters + to fit (1 to include, 0 to exclude). Default: 50 + chunk_order (str): How to iterate through the chunks. Should be one of: random, + sequential. Default: sequential """ def __init__( self, - model: "AstroPhot_Model", + model: Model, initial_state: Sequence = None, chunks: Union[int, tuple] = 50, + chunk_order: Literal["random", "sequential"] = "sequential", max_iter: int = 100, - method: str = "random", - LM_kwargs: dict = {}, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__(model, initial_state, max_iter=max_iter, **kwargs) + relative_tolerance: float = 1e-5, + Lup=11.0, + Ldn=9.0, + L0=1.0, + max_step_iter: int = 10, + ndf=None, + W=None, + likelihood="gaussian", + **kwargs, + ): + + super().__init__( + model, + initial_state, + max_iter=max_iter, + relative_tolerance=relative_tolerance, + **kwargs, + ) + # Maximum number of iterations of the algorithm + self.max_iter = max_iter + # Maximum number of steps while searching for chi^2 improvement on a single jacobian evaluation + self.max_step_iter = max_step_iter + self.Lup = Lup + self.Ldn = Ldn + self.L = L0 + self.likelihood = likelihood + if self.likelihood not in ["gaussian", "poisson"]: + raise ValueError(f"Unsupported likelihood: {self.likelihood}") + self.chunks = self.make_chunks(chunks) + self.chunk_order = chunk_order + + # mask + fit_mask = self.model.fit_mask() + if isinstance(fit_mask, tuple): + fit_mask = backend.concatenate(tuple(FM.flatten() for FM in fit_mask)) + else: + fit_mask = fit_mask.flatten() - self.chunks = chunks - self.method = method - self.LM_kwargs = LM_kwargs + mask = self.model.target[self.fit_window].flatten("mask") + mask = mask | fit_mask + self.mask = ~mask - # # pixels # parameters - self.ndf = self.model.target[self.model.window].flatten("data").numel() - len( - self.current_state + if backend.sum(self.mask).item() == 0: + raise OptimizeStopSuccess("No data to fit. All pixels are masked") + + # Initialize optimizer attributes + self.Y = self.model.target[self.fit_window].flatten("data")[self.mask] + + # 1 / (sigma^2) + if W is not None: + self.W = backend.as_array(W, dtype=config.DTYPE, device=config.DEVICE).flatten()[ + self.mask + ] + else: + self.W = self.model.target[self.fit_window].flatten("weight")[self.mask] + + # The forward model which computes the output image given input parameters + self.full_forward = lambda x: model(window=self.fit_window, params=x).flatten("data")[ + self.mask + ] + self.forward = [] + # Compute the jacobian + self.jacobian = [] + + f = lambda c, state, x: model( + window=self.fit_window, + params=backend.fill_at_indices(backend.copy(state), self.chunks[c], x), + ).flatten("data")[self.mask] + j = backend.jacfwd( + lambda c, state, x: self.model( + window=self.fit_window, + params=backend.fill_at_indices(backend.copy(state), self.chunks[c], x), + ).flatten("data")[self.mask], + argnums=2, ) - if self.model.target.has_mask: - # subtract masked pixels from degrees of freedom - self.ndf -= torch.sum(self.model.target[self.model.window].flatten("mask")).item() + for c in range(len(self.chunks)): + self.forward.append(partial(f, c)) + self.jacobian.append(partial(j, c)) - def step(self): - # These store the chunking information depending on which chunk mode is selected - param_ids = list(self.model.parameters.vector_identities()) - init_param_ids = list(self.model.parameters.vector_identities()) - _chunk_index = 0 - _chunk_choices = None - res = None + # variable to store covariance matrix if it is ever computed + self._covariance_matrix = None + + # Degrees of freedom + if ndf is None: + self.ndf = max(1.0, len(self.Y) - len(self.current_state)) + else: + self.ndf = ndf + + def make_chunks(self, chunks): + if isinstance(chunks, int): + new_chunks = [] + for i in range(0, len(self.current_state), chunks): + chunk = np.zeros(len(self.current_state), dtype=bool) + chunk[i : i + chunks] = True + new_chunks.append(chunk) + chunks = new_chunks + return chunks + + def iter_chunks(self): + if self.chunk_order == "random": + chunk_ids = list(range(len(self.chunks))) + np.random.shuffle(chunk_ids) + elif self.chunk_order == "sequential": + chunk_ids = list(range(len(self.chunks))) + else: + raise ValueError( + f"Unrecognized chunk_order: {self.chunk_order}. Should be one of: random, sequential" + ) + return chunk_ids + + def chi2_ndf(self): + return ( + backend.sum(self.W * (self.Y - self.full_forward(self.current_state)) ** 2) / self.ndf + ) + + def poisson_2nll_ndf(self): + M = self.full_forward(self.current_state) + return 2 * backend.sum(M - self.Y * backend.log(M + 1e-10)) / self.ndf + + @torch.no_grad() + def fit(self, update_uncertainty=True) -> BaseOptimizer: + """This performs the fitting operation. It iterates the LM step + function until convergence is reached. Includes a message + after fitting to indicate how the fitting exited. Typically if + the message returns a "success" then the algorithm found a + minimum. This may be the desired solution, or a pathological + local minimum, this often depends on the initial conditions. + + """ + if len(self.current_state) == 0: + if self.verbose > 0: + config.logger.warning("No parameters to optimize. Exiting fit") + self.message = "No parameters to optimize. Exiting fit" + return self + + if self.likelihood == "gaussian": + quantity = "Chi^2/DoF" + self.loss_history = [self.chi2_ndf().item()] + elif self.likelihood == "poisson": + quantity = "2NLL/DoF" + self.loss_history = [self.poisson_2nll_ndf().item()] + self._covariance_matrix = None + self.L_history = [self.L] + self.lambda_history = [backend.to_numpy(backend.copy(self.current_state))] if self.verbose > 0: - AP_config.ap_logger.info("--------iter-------") + config.logger.info( + f"==Starting LM fit for '{self.model.name}' with {len(self.current_state)} dynamic parameters and {len(self.Y)} pixels==" + ) - # Loop through all the chunks - while True: - chunk = torch.zeros(len(init_param_ids), dtype=torch.bool, device=AP_config.ap_device) - if isinstance(self.chunks, int): - if len(param_ids) == 0: - break - if self.method == "random": - # Draw a random chunk of ids - for pid in random.sample(param_ids, min(len(param_ids), self.chunks)): - chunk[init_param_ids.index(pid)] = True - else: - # Draw the next chunk of ids - for pid in param_ids[: self.chunks]: - chunk[init_param_ids.index(pid)] = True - # Remove the selected ids from the list - for p in np.array(init_param_ids)[chunk.detach().cpu().numpy()]: - param_ids.pop(param_ids.index(p)) - elif isinstance(self.chunks, (tuple, list)): - if _chunk_choices is None: - # Make a list of the chunks as given explicitly - _chunk_choices = list(range(len(self.chunks))) - if self.method == "random": - if len(_chunk_choices) == 0: - break - # Select a random chunk from the given groups - sub_index = random.choice(_chunk_choices) - _chunk_choices.pop(_chunk_choices.index(sub_index)) - for pid in self.chunks[sub_index]: - chunk[param_ids.index(pid)] = True - else: - if _chunk_index >= len(self.chunks): - break - # Select the next chunk in order - for pid in self.chunks[_chunk_index]: - chunk[param_ids.index(pid)] = True - _chunk_index += 1 - else: - raise ValueError( - "Unrecognized chunks value, should be one of int, tuple. not: {type(self.chunks)}" - ) - if self.verbose > 1: - AP_config.ap_logger.info(str(chunk)) - del res - with Param_Mask(self.model.parameters, chunk): - res = LM( - self.model, - ndf=self.ndf, - **self.LM_kwargs, - ).fit() + for _ in range(self.max_iter): + # Report status if self.verbose > 0: - AP_config.ap_logger.info(f"chunk loss: {res.res_loss()}") - if self.verbose > 1: - AP_config.ap_logger.info(f"chunk message: {res.message}") + config.logger.info(f"{quantity}: {self.loss_history[-1]:.6g}, L: {self.L:.3g}") + + # Perform fitting + chunk_L = [] + for c in self.iter_chunks(): + try: + if self.fit_valid: + with ValidContext(self.model): + valid_state = self.model.to_valid(self.current_state) + res = func.lm_step( + x=valid_state[self.chunks[c]], + data=self.Y, + model=partial(self.forward[c], valid_state), + weight=self.W, + jacobian=partial(self.jacobian[c], valid_state), + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + ) + self.current_state = self.model.from_valid( + backend.fill_at_indices( + valid_state, self.chunks[c], backend.copy(res["x"]) + ) + ) + else: + res = func.lm_step( + x=self.current_state[self.chunks[c]], + data=self.Y, + model=partial(self.forward[c], self.current_state), + weight=self.W, + jacobian=partial(self.jacobian[c], self.current_state), + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + ) + self.current_state = backend.fill_at_indices( + self.current_state, self.chunks[c], backend.copy(res["x"]) + ) + except OptimizeStopFail: + if self.verbose > 0: + config.logger.warning( + f"Could not find step to improve Chi^2 on chunk {c}, moving to next chunk" + ) + continue + except OptimizeStopSuccess as e: + continue # success on individual chunk is not enough to stop overall fit + chunk_L.append(res["L"]) + + # Record progress + self.L = np.clip(np.max(chunk_L), 1e-9, 1e9) + self.L_history.append(self.L) + self.loss_history.append(2 * res["nll"] / self.ndf) + self.lambda_history.append(backend.to_numpy(backend.copy(self.current_state))) + if self.check_convergence(): + break + + else: + self.message = self.message + "fail. Maximum iterations" - self.loss_history.append(res.res_loss()) - self.lambda_history.append( - self.model.parameters.vector_representation().detach().cpu().numpy() - ) if self.verbose > 0: - AP_config.ap_logger.info(f"Loss: {self.loss_history[-1]}") + config.logger.info( + f"Final {quantity}: {np.nanmin(self.loss_history):.6g}, L: {self.L_history[np.nanargmin(self.loss_history)]:.3g}. Converged: {self.message}" + ) - # test for convergence - if self.iteration >= 2 and ( - (-self.relative_tolerance * 1e-3) - < ((self.loss_history[-2] - self.loss_history[-1]) / self.loss_history[-1]) - < (self.relative_tolerance / 10) - ): - self._count_finish += 1 - else: - self._count_finish = 0 + self.model.set_values( + backend.as_array(self.res(), dtype=config.DTYPE, device=config.DEVICE) + ) + if update_uncertainty: + self.update_uncertainty() - self.iteration += 1 + return self - def fit(self): - self.iteration = 0 + def check_convergence(self) -> bool: + """Check if the optimization has converged based on the last + iteration's chi^2 and the relative tolerance. + """ + if len(self.loss_history) < 3: + return False + good_history = [self.loss_history[0]] + for l in self.loss_history[1:]: + if good_history[-1] > l: + good_history.append(l) + if len(self.loss_history) - len(good_history) >= 10: + self.message = self.message + "success by immobility. Convergence not guaranteed" + return True + if len(good_history) < 3: + return False + if (good_history[-2] - good_history[-1]) / good_history[ + -1 + ] < self.relative_tolerance and self.L < 0.1: + self.message = self.message + "success" + return True + if len(good_history) < 10: + return False + if (good_history[-10] - good_history[-1]) / good_history[-1] < self.relative_tolerance: + self.message = self.message + "success by immobility. Convergence not guaranteed" + return True + return False + + @property + @torch.no_grad() + def covariance_matrix(self): + """The covariance matrix for the model at the current + parameters. This can be used to construct a full Gaussian PDF for the + parameters using: $\\mathcal{N}(\\mu,\\Sigma)$ where $\\mu$ is the + optimized parameters and $\\Sigma$ is the covariance matrix. - start_fit = time() - try: - while True: - self.step() - if self.save_steps is not None: - self.model.save( - os.path.join( - self.save_steps, - f"{self.model.name}_Iteration_{self.iteration:03d}.yaml", - ) - ) - if self.iteration > 2 and self._count_finish >= 2: - self.message = self.message + "success" - break - elif self.iteration >= self.max_iter: - self.message = self.message + f"fail max iterations reached: {self.iteration}" - break + """ - except KeyboardInterrupt: - self.message = self.message + "fail interrupted" + if self._covariance_matrix is not None: + return self._covariance_matrix + + N = len(self.current_state) + self._covariance_matrix = backend.zeros((N, N), dtype=config.DTYPE, device=config.DEVICE) + for c in self.iter_chunks(): + J = self.jacobian[c](self.current_state, self.current_state[self.chunks[c]]) + if self.likelihood == "gaussian": + hess = func.hessian(J, self.W) + elif self.likelihood == "poisson": + hess = func.hessian_poisson(J, self.Y, self.full_forward(self.current_state)) + try: + sub_covariance_matrix = backend.linalg.inv(hess) + except: + config.logger.warning( + "WARNING: Hessian is singular, likely at least one parameter is non-physical. Will use pseudo-inverse of Hessian to continue but results should be inspected." + ) + sub_covariance_matrix = backend.linalg.pinv(hess) - self.model.parameters.vector_set_representation(self.res()) - if self.verbose > 1: - AP_config.ap_logger.info( - f"Iter Fitting complete in {time() - start_fit} sec with message: {self.message}" + ids = backend.meshgrid( + backend.as_array(np.arange(N)[self.chunks[c]], dtype=int, device=config.DEVICE), + backend.as_array(np.arange(N)[self.chunks[c]], dtype=int, device=config.DEVICE), + indexing="ij", ) + self._covariance_matrix = backend.fill_at_indices( + self._covariance_matrix, (ids[0], ids[1]), sub_covariance_matrix + ) + return self._covariance_matrix - return self + @torch.no_grad() + def update_uncertainty(self) -> None: + """Call this function after optimization to set the uncertainties for + the parameters. This will use the diagonal of the covariance + matrix to update the uncertainties. See the covariance_matrix + function for the full representation of the uncertainties. + + """ + # set the uncertainty for each parameter + cov = self.covariance_matrix + if backend.all(backend.isfinite(cov)): + try: + self.model.set_values( + backend.sqrt(backend.abs(backend.diag(cov))), attribute="uncertainty" + ) + except RuntimeError as e: + config.logger.warning(f"Unable to update uncertainty due to: {e}") + else: + config.logger.warning( + "Unable to update uncertainty due to non finite covariance matrix" + ) diff --git a/astrophot/fit/lm.py b/astrophot/fit/lm.py index 44d40460..a6aeb6ec 100644 --- a/astrophot/fit/lm.py +++ b/astrophot/fit/lm.py @@ -1,15 +1,17 @@ # Levenberg-Marquardt algorithm from typing import Sequence -from functools import partial import torch import numpy as np from .base import BaseOptimizer -from .. import AP_config -from ..errors import OptimizeStop +from .. import config +from ..backend_obj import backend, ArrayLike +from . import func +from ..errors import OptimizeStopFail, OptimizeStopSuccess +from ..param import ValidContext -__all__ = ("LM",) +__all__ = ("LM", "LMfast") class LM(BaseOptimizer): @@ -28,65 +30,67 @@ class LM(BaseOptimizer): The cost function that the LM algorithm tries to minimize is of the form: - .. math:: - f(\\boldsymbol{\\beta}) = \\frac{1}{2}\\sum_{i=1}^{N} r_i(\\boldsymbol{\\beta})^2 + $$f(\\boldsymbol{\\beta}) = \\frac{1}{2}\\sum_{i=1}^{N} r_i(\\boldsymbol{\\beta})^2$$ - where :math:`\\boldsymbol{\\beta}` is the vector of parameters, - :math:`r_i` are the residuals, and :math:`N` is the number of + where $\\boldsymbol{\\beta}$ is the vector of parameters, + $r_i$ are the residuals, and $N$ is the number of observations. The LM algorithm iteratively performs the following update to the parameters: - .. math:: - \\boldsymbol{\\beta}_{n+1} = \\boldsymbol{\\beta}_{n} - (J^T J + \\lambda diag(J^T J))^{-1} J^T \\boldsymbol{r} + $$\\boldsymbol{\\beta}_{n+1} = \\boldsymbol{\\beta}_{n} - (J^T J + \\lambda diag(J^T J))^{-1} J^T \\boldsymbol{r}$$ where: - - :math:`J` is the Jacobian matrix whose elements are :math:`J_{ij} = \\frac{\\partial r_i}{\\partial \\beta_j}`, - - :math:`\\boldsymbol{r}` is the vector of residuals :math:`r_i(\\boldsymbol{\\beta})`, - - :math:`\\lambda` is a damping factor which is adjusted at each iteration. - - When :math:`\\lambda = 0` this can be seen as the Gauss-Newton - method. In the limit that :math:`\\lambda` is large, the - :math:`J^T J` matrix (an approximation of the Hessian) becomes - subdominant and the update essentially points along :math:`J^T - \\boldsymbol{r}` which is the gradient. In this scenario the - gradient descent direction is also modified by the :math:`\\lambda - diag(J^T J)` scaling which in some sense makes each gradient + - $J$ is the Jacobian matrix whose elements are $J_{ij} = \\frac{\\partial r_i}{\\partial \\beta_j}$, + - $\\boldsymbol{r}$ is the vector of residuals $r_i(\\boldsymbol{\\beta})$, + - $\\lambda$ is a damping factor which is adjusted at each iteration. + + When $\\lambda = 0$ this can be seen as the Gauss-Newton + method. In the limit that $\\lambda$ is large, the + $J^T J$ matrix (an approximation of the Hessian) becomes + subdominant and the update essentially points along $J^T + \\boldsymbol{r}$ which is the gradient. In this scenario the + gradient descent direction is also modified by the $\\lambda + diag(J^T J)$ scaling which in some sense makes each gradient unitless and further improves the step. Note as well that as - :math:`\\lambda` gets larger the step taken will be smaller, which + $\\lambda$ gets larger the step taken will be smaller, which helps to ensure convergence when the initial guess of the parameters are far from the optimal solution. - Note that the residuals :math:`r_i` are typically also scaled by + Note that the residuals $r_i$ are typically also scaled by the variance of the pixels, but this does not change the equations above. For a detailed explanation of the LM method see the article by Henri Gavin on which much of the AstroPhot LM implementation is based:: - @article{Gavin2019, - title={The Levenberg-Marquardt algorithm for nonlinear least squares curve-fitting problems}, - author={Gavin, Henri P}, - journal={Department of Civil and Environmental Engineering, Duke University}, - volume={19}, - year={2019} - } + ```{latex} + @article{Gavin2019, + title={The Levenberg-Marquardt algorithm for nonlinear least squares curve-fitting problems}, + author={Gavin, Henri P}, + journal={Department of Civil and Environmental Engineering, Duke University}, + volume={19}, + year={2019} + } + ``` as well as the paper on LM geodesic acceleration by Mark Transtrum:: - @article{Tanstrum2012, - author = {{Transtrum}, Mark K. and {Sethna}, James P.}, - title = "{Improvements to the Levenberg-Marquardt algorithm for nonlinear least-squares minimization}", - year = 2012, - doi = {10.48550/arXiv.1201.5885}, - adsurl = {https://ui.adsabs.harvard.edu/abs/2012arXiv1201.5885T}, - } - - The damping factor :math:`\\lambda` is adjusted at each iteration: + ```{latex} + @article{Tanstrum2012, + author = {{Transtrum}, Mark K. and {Sethna}, James P.}, + title = "{Improvements to the Levenberg-Marquardt algorithm for nonlinear least-squares minimization}", + year = 2012, + doi = {10.48550/arXiv.1201.5885}, + adsurl = {https://ui.adsabs.harvard.edu/abs/2012arXiv1201.5885T}, + } + ``` + + The damping factor $\\lambda$ is adjusted at each iteration: it is effectively increased when we are far from the solution, and decreased when we are close to it. In practice, the algorithm - attempts to pick the smallest :math:`\\lambda` that is can while - making sure that the :math:`\\chi^2` decreases at each step. + attempts to pick the smallest $\\lambda$ that is can while + making sure that the $\\chi^2$ decreases at each step. The main advantage of the LM algorithm is its adaptability. When the current estimate is far from the optimum, the algorithm @@ -98,7 +102,7 @@ class LM(BaseOptimizer): enhancements to improve its performance. For example, the Jacobian may be approximated with finite differences, geodesic acceleration can be used to speed up convergence, and more sophisticated - strategies can be used to adjust the damping factor :math:`\\lambda`. + strategies can be used to adjust the damping factor $\\lambda$. The exact performance of the LM algorithm will depend on the nature of the problem, including the complexity of the function @@ -109,47 +113,6 @@ class LM(BaseOptimizer): state, and various other optional parameters as inputs and seeks to find the parameters that minimize the cost function. - Args: - model: The model to be optimized. - initial_state (Sequence): Initial values for the parameters to be optimized. - max_iter (int): Maximum number of iterations for the algorithm. - relative_tolerance (float): Tolerance level for relative change in cost function value to trigger termination of the algorithm. - fit_parameters_identity: Used to select a subset of parameters. This is mostly used internally. - verbose: Controls the verbosity of the output during optimization. A higher value results in more detailed output. If not provided, defaults to 0 (no output). - max_step_iter (optional): The maximum number of steps while searching for chi^2 improvement on a single Jacobian evaluation. Default is 10. - curvature_limit (optional): Controls how cautious the optimizer is for changing curvature. It should be a number greater than 0, where smaller is more cautious. Default is 1. - Lup and Ldn (optional): These adjust the step sizes for the damping parameter. Default is 5 and 3 respectively. - L0 (optional): This is the starting damping parameter. For easy problems with good initialization, this can be set lower. Default is 1. - acceleration (optional): Controls the use of geodesic acceleration, which can be helpful in some scenarios. Set 1 for full acceleration, 0 for no acceleration. Default is 0. - - Here is some basic usage of the LM optimizer: - - .. code-block:: python - - import astrophot as ap - - # build model - # ... - - # Initialize model parameters - model.initialize() - - # Fit the parameters - result = ap.fit.lm(model, verbose=1) - - # Check that a minimum was found - print(result.message) - - # See the minimum chi^2 value - print(f"min chi2: {result.res_loss()}") - - # Update parameter uncertainties - result.update_uncertainty() - - # Extract multivariate Gaussian of uncertainties - mu = result.res() - cov = result.covariance_matrix - """ def __init__( @@ -158,7 +121,12 @@ def __init__( initial_state: Sequence = None, max_iter: int = 100, relative_tolerance: float = 1e-5, + Lup=11.0, + Ldn=9.0, + L0=1.0, + max_step_iter: int = 10, ndf=None, + likelihood="gaussian", **kwargs, ): @@ -169,263 +137,69 @@ def __init__( relative_tolerance=relative_tolerance, **kwargs, ) - # The forward model which computes the output image given input parameters - self.forward = partial(model, as_representation=True) - # Compute the jacobian in representation units (defined for -inf, inf) - self.jacobian = partial(model.jacobian, as_representation=True) - self.jacobian_natural = partial(model.jacobian, as_representation=False) # Maximum number of iterations of the algorithm self.max_iter = max_iter # Maximum number of steps while searching for chi^2 improvement on a single jacobian evaluation - self.max_step_iter = kwargs.get("max_step_iter", 10) - # sets how cautious the optimizer is for changing curvature, should be number greater than 0, where smaller is more cautious - self.curvature_limit = kwargs.get("curvature_limit", 1.0) - # These are the adjustment step sized for the damping parameter - self._Lup = kwargs.get("Lup", 11.0) - self._Ldn = kwargs.get("Ldn", 9.0) - # This is the starting damping parameter, for easy problems with good initialization, this can be set lower - self.L = kwargs.get("L0", 1.0) - # Geodesic acceleration is helpful in some scenarios. By default it is turned off. Set 1 for full acceleration, 0 for no acceleration. - self.acceleration = kwargs.get("acceleration", 0.0) - # Initialize optimizer attributes - self.Y = self.model.target[self.fit_window].flatten("data") - - # 1 / (sigma^2) - kW = kwargs.get("W", None) - if kW is not None: - self.W = torch.as_tensor( - kW, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ).flatten() - elif model.target.has_variance: - self.W = self.model.target[self.fit_window].flatten("weight") - else: - self.W = torch.ones_like(self.Y) + self.max_step_iter = max_step_iter + self.Lup = Lup + self.Ldn = Ldn + self.L = L0 + self.likelihood = likelihood + if self.likelihood not in ["gaussian", "poisson"]: + raise ValueError(f"Unsupported likelihood: {self.likelihood}") # mask fit_mask = self.model.fit_mask() if isinstance(fit_mask, tuple): - fit_mask = torch.cat(tuple(FM.flatten() for FM in fit_mask)) + fit_mask = backend.concatenate(tuple(FM.flatten() for FM in fit_mask)) else: fit_mask = fit_mask.flatten() - if torch.sum(fit_mask).item() == 0: + if backend.sum(fit_mask).item() == 0: fit_mask = None - if model.target.has_mask: - mask = self.model.target[self.fit_window].flatten("mask") - if fit_mask is not None: - mask = mask | fit_mask - self.mask = torch.logical_not(mask) - elif fit_mask is not None: - self.mask = torch.logical_not(fit_mask) - else: - self.mask = None - if self.mask is not None and torch.sum(self.mask).item() == 0: - raise OptimizeStop("No data to fit. All pixels are masked") + + mask = self.model.target[self.fit_window].flatten("mask") + if fit_mask is not None: + mask = mask | fit_mask + self.mask = ~mask + if backend.sum(self.mask).item() == 0: + raise OptimizeStopSuccess("No data to fit. All pixels are masked") + + # Initialize optimizer attributes + self.Y = self.model.target[self.fit_window].flatten("data")[self.mask] + + # 1 / (sigma^2) + kW = kwargs.get("W", None) + if kW is not None: + self.W = backend.as_array(kW, dtype=config.DTYPE, device=config.DEVICE).flatten()[ + self.mask + ] + self.W = self.model.target[self.fit_window].flatten("weight")[self.mask] + + # The forward model which computes the output image given input parameters + self.forward = lambda x: model(window=self.fit_window, params=x).flatten("data")[self.mask] + # Compute the jacobian + self.jacobian = lambda x: model.jacobian(window=self.fit_window, params=x).flatten("data")[ + self.mask + ] # variable to store covariance matrix if it is ever computed self._covariance_matrix = None # Degrees of freedom if ndf is None: - if self.mask is None: - self.ndf = max(1.0, len(self.Y) - len(self.current_state)) - else: - self.ndf = max(1.0, torch.sum(self.mask).item() - len(self.current_state)) + self.ndf = max(1.0, len(self.Y) - len(self.current_state)) else: self.ndf = ndf - def Lup(self): - """ - Increases the damping parameter for more gradient-like steps. Used internally. - """ - self.L = min(1e9, self.L * self._Lup) - - def Ldn(self): - """ - Decreases the damping parameter for more Gauss-Newton like steps. Used internally. - """ - self.L = max(1e-9, self.L / self._Ldn) - - @torch.no_grad() - def step(self, chi2) -> torch.Tensor: - """Performs one step of the LM algorithm. Computes Jacobian, infers - hessian and gradient, solves for step vector and iterates on - damping parameter magnitude until a step with some improvement - in chi2 is found. Used internally. - - """ - Y0 = self.forward(parameters=self.current_state).flatten("data") - J = self.jacobian(parameters=self.current_state).flatten("data") - r = self._r(Y0, self.Y, self.W) - self.hess = self._hess(J, self.W) - self.grad = self._grad(J, self.W, Y0, self.Y) - init_chi2 = chi2 - nostep = True - best = (torch.zeros_like(self.current_state), init_chi2, self.L) - scarry_best = (None, init_chi2, self.L) - direction = "none" - iteration = 0 - d = 0.1 - for iteration in range(self.max_step_iter): - # In a scenario where LM is having a hard time proposing a good step, but the damping is really low, just jump up to normal damping levels - if iteration > self.max_step_iter / 2 and self.L < 1e-3: - self.L = 1.0 - - # compute LM update step - h = self._h(self.L, self.grad, self.hess) - - # Compute goedesic acceleration - Y1 = self.forward(parameters=self.current_state + d * h).flatten("data") - - rh = self._r(Y1, self.Y, self.W) - - rpp = self._rpp(J, d, rh - r, self.W, h) - - if self.L > 1e-4: - a = -self._h(self.L, rpp, self.hess) / 2 - else: - a = torch.zeros_like(h) - - # Evaluate new step - ha = h + a * self.acceleration - Y1 = self.forward(parameters=self.current_state + ha).flatten("data") - - # Compute and report chi^2 - chi2 = self._chi2(Y1.detach()).item() - if self.verbose > 1: - AP_config.ap_logger.info(f"sub step L: {self.L}, Chi^2/DoF: {chi2}") - - # Skip if chi^2 is nan - if not np.isfinite(chi2): - if self.verbose > 1: - AP_config.ap_logger.info("Skip due to non-finite values") - self.Lup() - if direction == "better": - break - direction = "worse" - continue - - # Keep track of chi^2 improvement even if it fails curvature test - if chi2 <= scarry_best[1]: - scarry_best = (ha, chi2, self.L) - - # Check for high curvature, in which case linear approximation is not valid. avoid this step - rho = torch.linalg.norm(a) / torch.linalg.norm(h) - if rho > self.curvature_limit: - if self.verbose > 1: - AP_config.ap_logger.info("Skip due to large curvature") - self.Lup() - if direction == "better": - break - direction = "worse" - continue - - # Check for Chi^2 improvement - if chi2 < best[1]: - if self.verbose > 1: - AP_config.ap_logger.info("new best chi^2") - best = (ha, chi2, self.L) - nostep = False - self.Ldn() - if self.L <= 1e-8 or direction == "worse": - break - direction = "better" - elif chi2 > best[1] and direction in ["none", "worse"]: - if self.verbose > 1: - AP_config.ap_logger.info("chi^2 is worse") - self.Lup() - if self.L == 1e9: - break - direction = "worse" - else: - break - - # If a step substantially improves the chi^2, stop searching for better step, simply exit the loop and accept the good step - if (best[1] - init_chi2) / init_chi2 < -0.1: - if self.verbose > 1: - AP_config.ap_logger.info("Large step taken, ending search for good step") - break - - if nostep: - if scarry_best[0] is not None: - if self.verbose > 1: - AP_config.ap_logger.warning( - "no low curvature step found, taking high curvature step" - ) - return scarry_best - raise OptimizeStop("Could not find step to improve chi^2") - - return best - - @staticmethod - @torch.no_grad() - def _h(L, grad, hess) -> torch.Tensor: - I = torch.eye(len(grad), dtype=grad.dtype, device=grad.device) - D = torch.ones_like(hess) - I - # Alternate damping scheme - # (hess + 1e-2 * L**2 * I) * (1 + L**2 * I) ** 2 / (1 + L**2), - h = torch.linalg.solve( - hess * (I + D / (1 + L)) + L * I * (1 + torch.diag(hess)), - grad, - ) - - return h - - @torch.no_grad() - def _chi2(self, Ypred) -> torch.Tensor: - if self.mask is None: - return torch.sum(self.W * (self.Y - Ypred) ** 2) / self.ndf - else: - return torch.sum((self.W * (self.Y - Ypred) ** 2)[self.mask]) / self.ndf - - @torch.no_grad() - def _r(self, Y, Ypred, W) -> torch.Tensor: - if self.mask is None: - return W * (Y - Ypred) - else: - return W[self.mask] * (Y[self.mask] - Ypred[self.mask]) - - @torch.no_grad() - def _hess(self, J, W) -> torch.Tensor: - if self.mask is None: - return J.T @ (W.view(len(W), -1) * J) - else: - return J[self.mask].T @ (W[self.mask].view(len(W[self.mask]), -1) * J[self.mask]) - - @torch.no_grad() - def _grad(self, J, W, Y, Ypred) -> torch.Tensor: - if self.mask is None: - return -J.T @ self._r(Y, Ypred, W) - else: - return -J[self.mask].T @ self._r(Y, Ypred, W) - - @torch.no_grad() - def _rpp(self, J, d, dr, W, h): - if self.mask is None: - return J.T @ ((2 / d) * ((dr / d - W * (J @ h)))) - else: - return J[self.mask].T @ ((2 / d) * ((dr / d - W[self.mask] * (J[self.mask] @ h)))) - - @torch.no_grad() - def update_hess_grad(self, natural=False) -> None: - """Updates the stored hessian matrix and gradient vector. This can be - used to compute the quantities in their natural parameter - representation. During normal optimization the hessian and - gradient are computed in a re-mapped parameter space where - parameters are defined form -inf to inf. + def chi2_ndf(self): + return backend.sum(self.W * (self.Y - self.forward(self.current_state)) ** 2) / self.ndf - """ - if natural: - J = self.jacobian_natural( - parameters=self.model.parameters.vector_transform_rep_to_val(self.current_state) - ).flatten("data") - else: - J = self.jacobian(parameters=self.current_state).flatten("data") - Ypred = self.forward(parameters=self.current_state).flatten("data") - self.hess = self._hess(J, self.W) - self.grad = self._grad(J, self.W, self.Y, Ypred) + def poisson_2nll_ndf(self): + M = self.forward(self.current_state) + return 2 * backend.sum(M - self.Y * backend.log(M + 1e-10)) / self.ndf @torch.no_grad() - def fit(self) -> BaseOptimizer: + def fit(self, update_uncertainty=True) -> BaseOptimizer: """This performs the fitting operation. It iterates the LM step function until convergence is reached. Includes a message after fitting to indicate how the fitting exited. Typically if @@ -437,85 +211,144 @@ def fit(self) -> BaseOptimizer: if len(self.current_state) == 0: if self.verbose > 0: - AP_config.ap_logger.warning("No parameters to optimize. Exiting fit") + config.logger.warning("No parameters to optimize. Exiting fit") + self.message = "No parameters to optimize. Exiting fit" return self + if self.likelihood == "gaussian": + quantity = "Chi^2/DoF" + self.loss_history = [self.chi2_ndf().item()] + elif self.likelihood == "poisson": + quantity = "2NLL/DoF" + self.loss_history = [self.poisson_2nll_ndf().item()] self._covariance_matrix = None - self.loss_history = [ - self._chi2(self.forward(parameters=self.current_state).flatten("data")).item() - ] self.L_history = [self.L] - self.lambda_history = [self.current_state.detach().clone().cpu().numpy()] + self.lambda_history = [backend.to_numpy(backend.copy(self.current_state))] + if self.verbose > 0: + config.logger.info( + f"==Starting LM fit for '{self.model.name}' with {len(self.current_state)} dynamic parameters and {len(self.Y)} pixels==" + ) - for iteration in range(self.max_iter): + for _ in range(self.max_iter): if self.verbose > 0: - AP_config.ap_logger.info(f"Chi^2/DoF: {self.loss_history[-1]}, L: {self.L}") + config.logger.info(f"{quantity}: {self.loss_history[-1]:.6g}, L: {self.L:.3g}") try: - res = self.step(chi2=self.loss_history[-1]) - except OptimizeStop: + if self.fit_valid: + with ValidContext(self.model): + res = func.lm_step( + x=self.model.to_valid(self.current_state), + data=self.Y, + model=self.forward, + weight=self.W, + jacobian=self.jacobian, + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + ) + self.current_state = self.model.from_valid(backend.copy(res["x"])) + else: + res = func.lm_step( + x=self.current_state, + data=self.Y, + model=self.forward, + weight=self.W, + jacobian=self.jacobian, + L=self.L, + Lup=self.Lup, + Ldn=self.Ldn, + likelihood=self.likelihood, + ) + self.current_state = backend.copy(res["x"]) + except OptimizeStopFail: + if self.verbose > 0: + config.logger.warning("Could not find step to improve Chi^2, stopping") + self.message = ( + self.message + + "success by immobility. Could not find step to improve Chi^2. Convergence not guaranteed" + ) + break + except OptimizeStopSuccess as e: if self.verbose > 0: - AP_config.ap_logger.warning("Could not find step to improve Chi^2, stopping") - self.message = self.message + "fail. Could not find step to improve Chi^2" + config.logger.info(f"Optimization converged successfully: {e}") + self.message = self.message + "success" break - self.L = res[2] - self.current_state = (self.current_state + res[0]).detach() - self.L_history.append(self.L) - self.loss_history.append(res[1]) - self.lambda_history.append(self.current_state.detach().clone().cpu().numpy()) - - self.Ldn() - - if len(self.loss_history) >= 3: - if (self.loss_history[-3] - self.loss_history[-1]) / self.loss_history[ - -1 - ] < self.relative_tolerance and self.L < 0.1: - self.message = self.message + "success" - break - if len(self.loss_history) > 10: - if (self.loss_history[-10] - self.loss_history[-1]) / self.loss_history[ - -1 - ] < self.relative_tolerance: - self.message = ( - self.message + "success by immobility. Convergence not guaranteed" - ) - break + self.L = np.clip(res["L"], 1e-9, 1e9) + self.L_history.append(res["L"]) + self.loss_history.append(2 * res["nll"] / self.ndf) + self.lambda_history.append(backend.to_numpy(backend.copy(self.current_state))) + + if self.check_convergence(): + break else: self.message = self.message + "fail. Maximum iterations" if self.verbose > 0: - AP_config.ap_logger.info( - f"Final Chi^2/DoF: {self.loss_history[-1]}, L: {self.L_history[-1]}. Converged: {self.message}" + config.logger.info( + f"Final {quantity}: {np.nanmin(self.loss_history):.6g}, L: {self.L_history[np.nanargmin(self.loss_history)]:.3g}. Converged: {self.message}" ) - self.model.parameters.vector_set_representation(self.res()) + + self.model.set_values( + backend.as_array(self.res(), dtype=config.DTYPE, device=config.DEVICE) + ) + if update_uncertainty: + self.update_uncertainty() return self + def check_convergence(self) -> bool: + """Check if the optimization has converged based on the last + iteration's chi^2 and the relative tolerance. + """ + if len(self.loss_history) < 3: + return False + good_history = [self.loss_history[0]] + for l in self.loss_history[1:]: + if good_history[-1] > l: + good_history.append(l) + if len(self.loss_history) - len(good_history) >= 10: + self.message = self.message + "success by immobility. Convergence not guaranteed" + return True + if len(good_history) < 3: + return False + if (good_history[-2] - good_history[-1]) / good_history[ + -1 + ] < self.relative_tolerance and self.L < 0.1: + self.message = self.message + "success" + return True + if len(good_history) < 10: + return False + if (good_history[-10] - good_history[-1]) / good_history[-1] < self.relative_tolerance: + self.message = self.message + "success by immobility. Convergence not guaranteed" + return True + return False + @property @torch.no_grad() - def covariance_matrix(self) -> torch.Tensor: + def covariance_matrix(self) -> ArrayLike: """The covariance matrix for the model at the current - parameters. This can be used to construct a full Gaussian PDF - for the parameters using: :math:`\\mathcal{N}(\\mu,\\Sigma)` - where :math:`\\mu` is the optimized parameters and - :math:`\\Sigma` is the covariance matrix. + parameters. This can be used to construct a full Gaussian PDF for the + parameters using: $\\mathcal{N}(\\mu,\\Sigma)$ where $\\mu$ is the + optimized parameters and $\\Sigma$ is the covariance matrix. """ if self._covariance_matrix is not None: return self._covariance_matrix - self.update_hess_grad(natural=True) + J = self.jacobian(self.current_state) + if self.likelihood == "gaussian": + hess = func.hessian(J, self.W) + elif self.likelihood == "poisson": + hess = func.hessian_poisson(J, self.Y, self.forward(self.current_state)) try: - self._covariance_matrix = torch.linalg.inv(self.hess) + self._covariance_matrix = backend.linalg.inv(hess) except: - AP_config.ap_logger.warning( - "WARNING: Hessian is singular, likely at least one model is non-physical. Will massage Hessian to continue but results should be inspected." + config.logger.warning( + "WARNING: Hessian is singular, likely at least one parameter is non-physical. Will use pseudo-inverse of Hessian to continue but results should be inspected." ) - self.hess += torch.eye( - len(self.grad), dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) * (torch.diag(self.hess) == 0) - self._covariance_matrix = torch.linalg.inv(self.hess) + self._covariance_matrix = backend.linalg.pinv(hess) return self._covariance_matrix @torch.no_grad() @@ -528,12 +361,22 @@ def update_uncertainty(self) -> None: """ # set the uncertainty for each parameter cov = self.covariance_matrix - if torch.all(torch.isfinite(cov)): + if backend.all(backend.isfinite(cov)): try: - self.model.parameters.vector_set_uncertainty(torch.sqrt(torch.abs(torch.diag(cov)))) + self.model.set_values( + backend.sqrt(backend.abs(backend.diag(cov))), attribute="uncertainty" + ) except RuntimeError as e: - AP_config.ap_logger.warning(f"Unable to update uncertainty due to: {e}") + config.logger.warning(f"Unable to update uncertainty due to: {e}") else: - AP_config.ap_logger.warning( + config.logger.warning( "Unable to update uncertainty due to non finite covariance matrix" ) + + +class LMfast(LM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.jacobian = backend.jacfwd( + lambda x: self.model(window=self.fit_window, params=x).flatten("data")[self.mask] + ) diff --git a/astrophot/fit/mala.py b/astrophot/fit/mala.py new file mode 100644 index 00000000..fe2b7cce --- /dev/null +++ b/astrophot/fit/mala.py @@ -0,0 +1,123 @@ +# Metropolis-Adjusted Langevin Algorithm sampler +from typing import Optional, Sequence + +import numpy as np + +from .base import BaseOptimizer +from ..models import Model +from .. import config +from ..backend_obj import backend +from . import func + +__all__ = ("MALA",) + + +class MALA(BaseOptimizer): + """Metropolis-Adjusted Langevin Algorithm (MALA) sampler, based on: + https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm . This + is a gradient-based MCMC sampler that uses the gradient of the + log-likelihood to propose new samples. These gradient based proposals can + lead to more efficient sampling of the parameter space. This is especially + true when the mass_matrix is set well. A good guess for the mass matrix is + the covariance matrix of the likelihood at the maximum likelihood point. + Which can be found fairly easily with the LM optimizer (see the fitting + methods tutorial). + + **Args:** + - `chains`: The number of MCMC chains to run in parallel. Default is 4. + - `epsilon`: The step size for the MALA sampler. Default is 1e-2. + - `mass_matrix`: The mass matrix for the MALA sampler. If None, the identity matrix is used. + - `progress_bar`: Whether to show a progress bar during sampling. Default is True. + - `likelihood`: The likelihood function to use for the MCMC sampling. Can be "gaussian" or "poisson". Default is "gaussian". + """ + + def __init__( + self, + model: Model, + initial_state: Optional[Sequence] = None, + chains=4, + epsilon: float = 1e-2, + mass_matrix: Optional[np.ndarray] = None, + max_iter: int = 1000, + progress_bar: bool = True, + likelihood="gaussian", + **kwargs, + ): + super().__init__(model, initial_state, max_iter=max_iter, **kwargs) + self.chain = [] + if len(self.current_state.shape) == 2: + self.chains = self.current_state.shape[0] + else: + self.chains = chains + self.likelihood = likelihood + self.epsilon = epsilon + self.mass_matrix = mass_matrix + self.progress_bar = progress_bar + + def density_func(self): + """ + Returns the density of the model at the given state vector. + This is used to calculate the likelihood of the model at the given state. + """ + if self.likelihood == "gaussian": + vll = backend.vmap(self.model.gaussian_log_likelihood) + elif self.likelihood == "poisson": + vll = backend.vmap(self.model.poisson_log_likelihood) + else: + raise ValueError(f"Unknown likelihood type: {self.likelihood}") + + def dens(state: np.ndarray) -> np.ndarray: + state = backend.as_array(state, dtype=config.DTYPE, device=config.DEVICE) + return backend.to_numpy(vll(state)) + + return dens + + def density_grad_func(self): + """ + Returns the gradient of the density of the model at the given state vector. + This is used to calculate the gradient of the likelihood of the model at the given state. + """ + if self.likelihood == "gaussian": + vll_grad = backend.vmap(backend.grad(self.model.gaussian_log_likelihood)) + elif self.likelihood == "poisson": + vll_grad = backend.vmap(backend.grad(self.model.poisson_log_likelihood)) + else: + raise ValueError(f"Unknown likelihood type: {self.likelihood}") + + def grad(state: np.ndarray) -> np.ndarray: + state = backend.as_array(state, dtype=config.DTYPE, device=config.DEVICE) + return backend.to_numpy(vll_grad(state)) + + return grad + + def fit(self): + + Px = self.density_func() + dPdx = self.density_grad_func() + + initial_state = backend.to_numpy(self.current_state) + if len(initial_state.shape) == 1: + initial_state = np.repeat(initial_state[None, :], self.chains, axis=0) + + if self.mass_matrix is None: + D = initial_state.shape[1] + self.mass_matrix = np.eye(D, dtype=initial_state.dtype) + + self.chain, self.logp = func.mala( + initial_state, + Px, + dPdx, + self.max_iter, + self.epsilon, + self.mass_matrix, + progress=self.progress_bar, + desc="MALA", + ) + # Fill model with max logp sample + max_logp_index = np.argmax(self.logp) + max_logp_index = np.unravel_index(max_logp_index, self.logp.shape) + self.model.set_values( + backend.as_array(self.chain[max_logp_index], dtype=config.DTYPE, device=config.DEVICE) + ) + + return self diff --git a/astrophot/fit/mhmcmc.py b/astrophot/fit/mhmcmc.py index ffb437eb..0ef9506b 100644 --- a/astrophot/fit/mhmcmc.py +++ b/astrophot/fit/mhmcmc.py @@ -1,51 +1,78 @@ # Metropolis-Hasting Markov-Chain Monte-Carlo from typing import Optional, Sequence -import torch -from tqdm import tqdm + import numpy as np + +try: + import emcee +except ImportError: + emcee = None + from .base import BaseOptimizer -from .. import AP_config +from ..models import Model +from .. import config +from ..backend_obj import backend -__all__ = ["MHMCMC"] +__all__ = ("MHMCMC",) class MHMCMC(BaseOptimizer): """Metropolis-Hastings Markov-Chain Monte-Carlo sampler, based on: - https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm . This - is a naive implementation of a standard MCMC, it is far from - optimal and should not be used for anything but the most basic - scenarios. + https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm . This is simply + a thin wrapper for the Emcee package, which is a well-known MCMC sampler. - Args: - model (AstroPhot_Model): The model which will be sampled. - initial_state (Optional[Sequence]): A 1D array with the values for each parameter in the model. Note that these values should be in the form of "as_representation" in the model. - max_iter (int): The number of sampling steps to perform. Default 1000 - epsilon (float or array): The random step length to take at each iteration. This is the standard deviation for the normal distribution sampling. Default 1e-2 + Note that the Emcee sampler requires multiple walkers to sample the + parameter space efficiently. The number of walkers is set to twice the + number of parameters by default, but can be made higher (not lower) if desired. + This is done by passing a 2D array of shape (nwalkers, ndim) to the `fit` method. + **Args:** + - `likelihood`: The likelihood function to use for the MCMC sampling. Can be "gaussian" or "poisson". Default is "gaussian". """ def __init__( self, - model: "AstroPhot_Model", + model: Model, initial_state: Optional[Sequence] = None, max_iter: int = 1000, + likelihood="gaussian", **kwargs, ): super().__init__(model, initial_state, max_iter=max_iter, **kwargs) - self.epsilon = kwargs.get("epsilon", 1e-2) - self.progress_bar = kwargs.get("progress_bar", True) - self.report_after = kwargs.get("report_after", int(self.max_iter / 10)) + if emcee is None: + raise ImportError( + "The emcee package is required for MHMCMC sampling. Please install it with `pip install emcee` or the like." + ) + self.likelihood = likelihood self.chain = [] - self._accepted = 0 - self._sampled = 0 + + def density(self): + """ + Returns the density of the model at the given state vector. + This is used to calculate the likelihood of the model at the given state. + """ + if self.likelihood == "gaussian": + vll = backend.vmap(self.model.gaussian_log_likelihood) + elif self.likelihood == "poisson": + vll = backend.vmap(self.model.poisson_log_likelihood) + else: + raise ValueError(f"Unknown likelihood type: {self.likelihood}") + + def dens(state: np.ndarray) -> np.ndarray: + state = backend.as_array(state, dtype=config.DTYPE, device=config.DEVICE) + return backend.to_numpy(vll(state)) + + return dens def fit( self, - state: Optional[torch.Tensor] = None, + state: Optional[np.ndarray] = None, nsamples: Optional[int] = None, restart_chain: bool = True, + skip_initial_state_check: bool = True, + flat_chain: bool = True, ): """ Performs the MCMC sampling using a Metropolis Hastings acceptance step and records the chain for later examination. @@ -56,66 +83,20 @@ def fit( if state is None: state = self.current_state - chi2 = self.sample(state) + if len(state.shape) == 1: + nwalkers = state.shape[0] * 2 + state = state * np.random.normal(loc=1, scale=0.01, size=(nwalkers, state.shape[0])) + else: + nwalkers = state.shape[0] + ndim = state.shape[1] + sampler = emcee.EnsembleSampler(nwalkers, ndim, self.density(), vectorize=True) + state = sampler.run_mcmc(state, nsamples, skip_initial_state_check=skip_initial_state_check) if restart_chain: - self.chain = [] + self.chain = sampler.get_chain(flat=flat_chain) else: - self.chain = list(self.chain) - - iterator = tqdm(range(nsamples)) if self.progress_bar else range(nsamples) - for i in iterator: - state, chi2 = self.step(state, chi2) - self.append_chain(state) - if i % self.report_after == 0 and i > 0 and self.verbose > 0: - AP_config.ap_logger.info(f"Acceptance: {self.acceptance}") - if self.verbose > 0: - AP_config.ap_logger.info(f"Acceptance: {self.acceptance}") - self.current_state = state - self.chain = np.stack(self.chain) - return self - - def append_chain(self, state: torch.Tensor): - """ - Add a state vector to the MCMC chain - """ - - self.chain.append( - self.model.parameters.vector_transform_rep_to_val(state).detach().cpu().clone().numpy() + self.chain = np.append(self.chain, sampler.get_chain(flat=flat_chain), axis=0) + self.model.set_values( + backend.as_array(self.chain[-1], dtype=config.DTYPE, device=config.DEVICE) ) - - @staticmethod - def accept(log_alpha): - """ - Evaluates randomly if a given proposal is accepted. This is done in log space which is more natural for the evaluation in the step. - """ - return torch.log(torch.rand(log_alpha.shape)) < log_alpha - - @torch.no_grad() - def sample(self, state: torch.Tensor): - """ - Samples the model at the proposed state vector values - """ - return self.model.negative_log_likelihood(parameters=state, as_representation=True) - - @torch.no_grad() - def step(self, state: torch.Tensor, chi2: torch.Tensor) -> torch.Tensor: - """ - Takes one step of the HMC sampler by integrating along a path initiated with a random momentum. - """ - - proposal_state = torch.normal(mean=state, std=self.epsilon) - proposal_chi2 = self.sample(proposal_state) - log_alpha = chi2 - proposal_chi2 - accept = self.accept(log_alpha) - self._accepted += accept - self._sampled += 1 - return proposal_state if accept else state, proposal_chi2 if accept else chi2 - - @property - def acceptance(self): - """ - Returns the ratio of accepted states to total states sampled. - """ - - return self._accepted / self._sampled + return self diff --git a/astrophot/fit/minifit.py b/astrophot/fit/minifit.py index a08b00d5..fe46921e 100644 --- a/astrophot/fit/minifit.py +++ b/astrophot/fit/minifit.py @@ -4,18 +4,34 @@ import numpy as np from .base import BaseOptimizer -from ..models import AstroPhot_Model +from ..models import Model from .lm import LM -from .. import AP_config +from .. import config __all__ = ["MiniFit"] class MiniFit(BaseOptimizer): + """MiniFit optimizer that applies a fitting method to a downsampled version + of the model's target image. + + This is useful for quickly optimizing parameters on a smaller scale before + applying them to the full resolution image. With fewer pixels, the optimization + can be faster and more efficient, especially for large images. + + This Optimizer can wrap any optimizer that follows the BaseOptimizer interface. + + **Args:** + - `downsample_factor`: Factor by which to downsample the target image. Default is 2. + - `max_pixels`: Maximum number of pixels in the downsampled image. Default is 10000. + - `method`: The optimizer method to use, e.g., `LM` for Levenberg-Marquardt. Default is `LM`. + - `method_kwargs`: Additional keyword arguments to pass to the optimizer method. + """ + def __init__( self, - model: AstroPhot_Model, - downsample_factor: int = 1, + model: Model, + downsample_factor: int = 2, max_pixels: int = 10000, method: BaseOptimizer = LM, initial_state: np.ndarray = None, @@ -37,12 +53,12 @@ def fit(self) -> BaseOptimizer: target_area = self.model.target[self.model.window] while True: small_target = target_area.reduce(self.downsample_factor) - if small_target.size < self.max_pixels: + if np.prod(small_target._data.shape) < self.max_pixels: break self.downsample_factor += 1 if self.verbose > 0: - AP_config.ap_logger.info(f"Downsampling target by {self.downsample_factor}x") + config.logger.info(f"Downsampling target by {self.downsample_factor}x") self.small_target = small_target self.model.target = small_target diff --git a/astrophot/fit/nuts.py b/astrophot/fit/nuts.py deleted file mode 100644 index 3fcee171..00000000 --- a/astrophot/fit/nuts.py +++ /dev/null @@ -1,171 +0,0 @@ -# No U-Turn Sampler variant of Hamiltonian Monte-Carlo -from typing import Optional, Sequence - -import torch -import pyro -import pyro.distributions as dist -from pyro.infer import MCMC as pyro_MCMC -from pyro.infer import NUTS as pyro_NUTS -from pyro.infer.mcmc.adaptation import BlockMassMatrix -from pyro.ops.welford import WelfordCovariance - -from .base import BaseOptimizer -from ..models import AstroPhot_Model - -__all__ = ["NUTS"] - - -########################################### -# !Overwrite pyro configuration behavior! -# currently this is the only way to provide -# mass matrix manually -########################################### -def new_configure(self, mass_matrix_shape, adapt_mass_matrix=True, options={}): - """ - Sets up an initial mass matrix. - - :param dict mass_matrix_shape: a dict that maps tuples of site names to the shape of - the corresponding mass matrix. Each tuple of site names corresponds to a block. - :param bool adapt_mass_matrix: a flag to decide whether an adaptation scheme will be used. - :param dict options: tensor options to construct the initial mass matrix. - """ - inverse_mass_matrix = {} - for site_names, shape in mass_matrix_shape.items(): - self._mass_matrix_size[site_names] = shape[0] - diagonal = len(shape) == 1 - inverse_mass_matrix[site_names] = ( - torch.full(shape, self._init_scale, **options) - if diagonal - else torch.eye(*shape, **options) * self._init_scale - ) - if adapt_mass_matrix: - adapt_scheme = WelfordCovariance(diagonal=diagonal) - self._adapt_scheme[site_names] = adapt_scheme - - if len(self.inverse_mass_matrix.keys()) == 0: - self.inverse_mass_matrix = inverse_mass_matrix - - -BlockMassMatrix.configure = new_configure -############################################ - - -class NUTS(BaseOptimizer): - """No U-Turn Sampler (NUTS) implementation for Hamiltonian Monte Carlo - (HMC) based MCMC sampling. - - This is a wrapper for the Pyro package: https://docs.pyro.ai/en/stable/index.html - - The NUTS class provides an implementation of the No-U-Turn Sampler - (NUTS) algorithm, which is a variation of the Hamiltonian Monte - Carlo (HMC) method for Markov Chain Monte Carlo (MCMC) - sampling. This implementation uses the Pyro library to perform the - sampling. The NUTS algorithm utilizes gradients of the target - distribution to more efficiently explore the probability - distribution of the model. - - More information on HMC and NUTS can be found at: - https://en.wikipedia.org/wiki/Hamiltonian_Monte_Carlo, - https://arxiv.org/abs/1701.02434, and - http://www.mcmchandbook.net/HandbookChapter5.pdf - - Args: - model (AstroPhot_Model): The model which will be sampled. - initial_state (Optional[Sequence], optional): A 1D array with the values for each parameter in the model. These values should be in the form of "as_representation" in the model. Defaults to None. - max_iter (int, optional): The number of sampling steps to perform. Defaults to 1000. - epsilon (float, optional): The step size for the NUTS sampler. Defaults to 1e-3. - inv_mass (Optional[Tensor], optional): Inverse Mass matrix (covariance matrix) for the Hamiltonian system. Defaults to None. - progress_bar (bool, optional): If True, display a progress bar during sampling. Defaults to True. - prior (Optional[Distribution], optional): Prior distribution for the model parameters. Defaults to None. - warmup (int, optional): Number of warmup (or burn-in) steps to perform before sampling. Defaults to 100. - nuts_kwargs (Dict[str, Any], optional): A dictionary of additional keyword arguments to pass to the NUTS sampler. Defaults to {}. - mcmc_kwargs (Dict[str, Any], optional): A dictionary of additional keyword arguments to pass to the MCMC function. Defaults to {}. - - Methods: - fit(state: Optional[torch.Tensor] = None, nsamples: Optional[int] = None, restart_chain: bool = True) -> 'NUTS': - Performs the MCMC sampling using a NUTS HMC and records the chain for later examination. - - """ - - def __init__( - self, - model: AstroPhot_Model, - initial_state: Optional[Sequence] = None, - max_iter: int = 1000, - **kwargs, - ): - super().__init__(model, initial_state, max_iter=max_iter, **kwargs) - - self.inv_mass = kwargs.get("inv_mass", None) - self.epsilon = kwargs.get("epsilon", 1e-4) - self.progress_bar = kwargs.get("progress_bar", True) - self.prior = kwargs.get("prior", None) - self.warmup = kwargs.get("warmup", 100) - self.nuts_kwargs = kwargs.get("nuts_kwargs", {}) - self.mcmc_kwargs = kwargs.get("mcmc_kwargs", {}) - - def fit( - self, - state: Optional[torch.Tensor] = None, - nsamples: Optional[int] = None, - restart_chain: bool = True, - ): - """ - Performs the MCMC sampling using a NUTS HMC and records the chain for later examination. - """ - - def step(model, prior): - x = pyro.sample("x", prior) - # Log-likelihood function - model.parameters.flat_detach() - log_likelihood_value = -model.negative_log_likelihood( - parameters=x, as_representation=True - ) - # Observe the log-likelihood - pyro.factor("obs", log_likelihood_value) - - if self.prior is None: - self.prior = dist.Normal( - self.current_state, - torch.ones_like(self.current_state) * 1e2 + torch.abs(self.current_state) * 1e2, - ) - - # Set up the NUTS sampler - nuts_kwargs = { - "jit_compile": False, - "ignore_jit_warnings": True, - "step_size": self.epsilon, - "full_mass": True, - "adapt_step_size": True, - "adapt_mass_matrix": self.inv_mass is None, - } - nuts_kwargs.update(self.nuts_kwargs) - nuts_kernel = pyro_NUTS(step, **nuts_kwargs) - if self.inv_mass is not None: - nuts_kernel.mass_matrix_adapter.inverse_mass_matrix = {("x",): self.inv_mass} - - # Provide an initial guess for the parameters - init_params = {"x": self.model.parameters.vector_representation()} - - # Run MCMC with the NUTS sampler and the initial guess - mcmc_kwargs = { - "num_samples": self.max_iter, - "warmup_steps": self.warmup, - "initial_params": init_params, - "disable_progbar": not self.progress_bar, - } - mcmc_kwargs.update(self.mcmc_kwargs) - mcmc = pyro_MCMC(nuts_kernel, **mcmc_kwargs) - - mcmc.run(self.model, self.prior) - self.iteration += self.max_iter - - # Extract posterior samples - chain = mcmc.get_samples()["x"] - - with torch.no_grad(): - for i in range(len(chain)): - chain[i] = self.model.parameters.vector_transform_rep_to_val(chain[i]) - self.chain = chain - - return self diff --git a/astrophot/fit/oldlm.py b/astrophot/fit/oldlm.py deleted file mode 100644 index 8df1e884..00000000 --- a/astrophot/fit/oldlm.py +++ /dev/null @@ -1,712 +0,0 @@ -# Levenberg-Marquardt algorithm -import os -from time import time -from typing import List, Callable, Optional, Sequence, Any - -import torch -from torch.autograd.functional import jacobian -import numpy as np - -from .base import BaseOptimizer -from .. import AP_config - -__all__ = ["oldLM", "LM_Constraint"] - - -@torch.no_grad() -@torch.jit.script -def Broyden_step(J, h, Yp, Yph): - delta = torch.matmul(J, h) - # avoid constructing a second giant jacobian matrix, instead go one row at a time - for j in range(J.shape[1]): - J[:, j] += (Yph - Yp - delta) * h[j] / torch.linalg.norm(h) - return J - - -class oldLM(BaseOptimizer): - """based heavily on: - @article{gavin2019levenberg, - title={The Levenberg-Marquardt algorithm for nonlinear least squares curve-fitting problems}, - author={Gavin, Henri P}, - journal={Department of Civil and Environmental Engineering, Duke University}, - volume={19}, - year={2019} - } - - The Levenberg-Marquardt algorithm bridges the gap between a - gradient descent optimizer and a Newton's Method optimizer. The - Hessian for the Newton's Method update is too complex to evaluate - with automatic differentiation (memory scales roughly as - parameters^2 * pixels^2) and so an approximation is made using the - Jacobian of the image pixels wrt to the parameters of the - model. Automatic differentiation provides an exact Jacobian as - opposed to a finite differences approximation. - - Once a Hessian H and gradient G have been determined, the update - step is defined as h which is the solution to the linear equation: - - (H + L*I)h = G - - where L is the Levenberg-Marquardt damping parameter and I is the - identity matrix. For small L this is just the Newton's method, for - large L this is just a small gradient descent step (approximately - h = grad/L). The method implemented is modified from Gavin 2019. - - Args: - model (AstroPhot_Model): object with which to perform optimization - initial_state (Optional[Sequence]): an initial state for optimization - epsilon4 (Optional[float]): approximation accuracy requirement, for any rho < epsilon4 the step will be rejected. Default 0.1 - epsilon5 (Optional[float]): numerical stability factor, added to the diagonal of the Hessian. Default 1e-8 - constraints (Optional[Union[LM_Constraint,tuple[LM_Constraint]]]): Constraint objects which control the fitting process. - L0 (Optional[float]): initial value for L factor in (H +L*I)h = G. Default 1. - Lup (Optional[float]): amount to increase L when rejecting an update step. Default 11. - Ldn (Optional[float]): amount to decrease L when accetping an update step. Default 9. - - """ - - def __init__( - self, - model: "AstroPhot_Model", - initial_state: Sequence = None, - max_iter: int = 100, - fit_parameters_identity: Optional[tuple] = None, - **kwargs, - ): - super().__init__( - model, - initial_state, - max_iter=max_iter, - fit_parameters_identity=fit_parameters_identity, - **kwargs, - ) - - # Set optimizer parameters - self.epsilon4 = kwargs.get("epsilon4", 0.1) - self.epsilon5 = kwargs.get("epsilon5", 1e-8) - self.Lup = kwargs.get("Lup", 11.0) - self.Ldn = kwargs.get("Ldn", 9.0) - self.L = kwargs.get("L0", 1e-3) - self.use_broyden = kwargs.get("use_broyden", False) - - # Initialize optimizer attributes - self.Y = self.model.target[self.fit_window].flatten("data") - # 1 / sigma^2 - self.W = ( - 1.0 / self.model.target[self.fit_window].flatten("variance") - if model.target.has_variance - else 1.0 - ) - # # pixels # parameters - self.ndf = len(self.Y) - len(self.current_state) - self.J = None - self.full_jac = False - self.current_Y = None - self.prev_Y = [None, None] - if self.model.target.has_mask: - self.mask = self.model.target[self.fit_window].flatten("mask") - # subtract masked pixels from degrees of freedom - self.ndf -= torch.sum(self.mask) - self.L_history = [] - self.decision_history = [] - self.rho_history = [] - self._count_converged = 0 - self.ndf = kwargs.get("ndf", self.ndf) - self._covariance_matrix = None - - # update attributes with constraints - self.constraints = kwargs.get("constraints", None) - if self.constraints is not None and isinstance(self.constraints, LM_Constraint): - self.constraints = (self.constraints,) - - if self.constraints is not None: - for con in self.constraints: - self.Y = torch.cat((self.Y, con.reference_value)) - self.W = torch.cat((self.W, 1 / con.weight)) - self.ndf -= con.reduce_ndf - if self.model.target.has_mask: - self.mask = torch.cat( - ( - self.mask, - torch.zeros_like(con.reference_value, dtype=torch.bool), - ) - ) - - def L_up(self, Lup=None): - if Lup is None: - Lup = self.Lup - self.L = min(1e9, self.L * Lup) - - def L_dn(self, Ldn=None): - if Ldn is None: - Ldn = self.Ldn - self.L = max(1e-9, self.L / Ldn) - - def step(self, current_state=None) -> None: - """ - Levenberg-Marquardt update step - """ - if current_state is not None: - self.current_state = current_state - - if self.iteration > 0: - if self.verbose > 0: - AP_config.ap_logger.info("---------iter---------") - else: - if self.verbose > 0: - AP_config.ap_logger.info("---------init---------") - - h = self.update_h() - if self.verbose > 1: - AP_config.ap_logger.info(f"h: {h.detach().cpu().numpy()}") - - self.update_Yp(h) - loss = self.update_chi2() - if self.verbose > 0: - AP_config.ap_logger.info(f"LM loss: {loss.item()}") - - if self.iteration == 0: - self.prev_Y[1] = self.current_Y - self.loss_history.append(loss.detach().cpu().item()) - self.L_history.append(self.L) - self.lambda_history.append(np.copy((self.current_state + h).detach().cpu().numpy())) - - if self.iteration > 0 and not torch.isfinite(loss): - if self.verbose > 0: - AP_config.ap_logger.warning("nan loss") - self.decision_history.append("nan") - self.rho_history.append(None) - self._count_reject += 1 - self.iteration += 1 - self.L_up() - return - elif self.iteration > 0: - lossmin = np.nanmin(self.loss_history[:-1]) - rho = self.rho(lossmin, loss, h) - if self.verbose > 1: - AP_config.ap_logger.debug( - f"LM loss: {loss.item()}, best loss: {np.nanmin(self.loss_history[:-1])}, loss diff: {np.nanmin(self.loss_history[:-1]) - loss.item()}, L: {self.L}" - ) - self.rho_history.append(rho) - if self.verbose > 1: - AP_config.ap_logger.debug(f"rho: {rho.item()}") - - if rho > self.epsilon4: - if self.verbose > 0: - AP_config.ap_logger.info("accept") - self.decision_history.append("accept") - self.prev_Y[0] = self.prev_Y[1] - self.prev_Y[1] = torch.clone(self.current_Y) - self.current_state += h - self.L_dn() - self._count_reject = 0 - if 0 < ((lossmin - loss) / loss) < self.relative_tolerance: - self._count_finish += 1 - else: - self._count_finish = 0 - else: - if self.verbose > 0: - AP_config.ap_logger.info("reject") - self.decision_history.append("reject") - self.L_up() - self._count_reject += 1 - return - else: - self.decision_history.append("init") - self.rho_history.append(None) - - if ( - (not self.use_broyden) - or self.J is None - or self.iteration < 2 - or "reset" in self.decision_history[-2:] - or rho < self.epsilon4 - or self._count_reject > 0 - or self.iteration >= (2 * len(self.current_state)) - or self.decision_history[-1] == "nan" - ): - if self.verbose > 1: - AP_config.ap_logger.debug("full jac") - self.update_J_AD() - else: - if self.verbose > 1: - AP_config.ap_logger.debug("Broyden jac") - self.update_J_Broyden(h, self.prev_Y[0], self.current_Y) - - self.update_hess() - self.update_grad(self.prev_Y[1]) - self.iteration += 1 - - def fit(self): - self.iteration = 0 - self._count_reject = 0 - self._count_finish = 0 - self.grad_only = False - - start_fit = time() - try: - while True: - if self.verbose > 0: - AP_config.ap_logger.info(f"L: {self.L}") - - # take LM step - self.step() - - # Save the state of the model - if self.save_steps is not None and self.decision_history[-1] == "accept": - self.model.save( - os.path.join( - self.save_steps, - f"{self.model.name}_Iteration_{self.iteration:03d}.yaml", - ) - ) - - lam, L, loss = self.progress_history() - - # Check for convergence - if ( - self.decision_history.count("accept") > 2 - and self.decision_history[-1] == "accept" - and L[-1] < 0.1 - and ((loss[-2] - loss[-1]) / loss[-1]) < (self.relative_tolerance / 10) - ): - self._count_converged += 1 - elif self.iteration >= self.max_iter: - self.message = self.message + f"fail max iterations reached: {self.iteration}" - break - elif not torch.all(torch.isfinite(self.current_state)): - self.message = self.message + "fail non-finite step taken" - break - elif ( - self.L >= (1e9 - 1) and self._count_reject >= 8 and not self.take_low_rho_step() - ): - self.message = ( - self.message - + "fail by immobility, unable to find improvement or even small bad step" - ) - break - if self._count_converged >= 3: - self.message = self.message + "success" - break - lam, L, loss = self.accept_history() - if len(loss) >= 10: - loss10 = np.array(loss[-10:]) - if ( - np.all( - np.abs((loss10[0] - loss10[-1]) / loss10[-1]) < self.relative_tolerance - ) - and L[-1] < 0.1 - ): - self.message = self.message + "success" - break - if ( - np.all( - np.abs((loss10[0] - loss10[-1]) / loss10[-1]) < self.relative_tolerance - ) - and L[-1] >= 0.1 - ): - self.message = ( - self.message - + "fail by immobility, possible bad area of parameter space." - ) - break - except KeyboardInterrupt: - self.message = self.message + "fail interrupted" - - if self.message.startswith("fail") and self._count_finish > 0: - self.message = ( - self.message - + ". possibly converged to numerical precision and could not make a better step." - ) - self.model.parameters.set_values( - self.res(), - as_representation=True, - parameters_identity=self.fit_parameters_identity, - ) - if self.verbose > 1: - AP_config.ap_logger.info( - f"LM Fitting complete in {time() - start_fit} sec with message: {self.message}" - ) - - return self - - def update_uncertainty(self): - # set the uncertainty for each parameter - cov = self.covariance_matrix - if torch.all(torch.isfinite(cov)): - try: - self.model.parameters.set_uncertainty( - torch.sqrt(torch.abs(torch.diag(cov))), - as_representation=False, - parameters_identity=self.fit_parameters_identity, - ) - except RuntimeError as e: - AP_config.ap_logger.warning(f"Unable to update uncertainty due to: {e}") - - @torch.no_grad() - def undo_step(self) -> None: - AP_config.ap_logger.info("undoing step, trying to recover") - assert ( - self.decision_history.count("accept") >= 2 - ), "cannot undo with not enough accepted steps, retry with new parameters" - assert len(self.decision_history) == len(self.lambda_history) - assert len(self.decision_history) == len(self.L_history) - found_accept = False - for i in reversed(range(len(self.decision_history))): - if not found_accept and self.decision_history[i] == "accept": - found_accept = True - continue - if self.decision_history[i] != "accept": - continue - self.current_state = torch.tensor( - self.lambda_history[i], - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - self.L = self.L_history[i] * self.Lup - - def take_low_rho_step(self) -> bool: - for i in reversed(range(len(self.decision_history))): - if "accept" in self.decision_history[i]: - return False - if self.rho_history[i] is not None and self.rho_history[i] > 0: - if self.verbose > 0: - AP_config.ap_logger.info( - f"taking a low rho step for some progress: {self.rho_history[i]}" - ) - self.current_state = torch.tensor( - self.lambda_history[i], - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - self.L = self.L_history[i] - - self.loss_history.append(self.loss_history[i]) - self.L_history.append(self.L) - self.lambda_history.append(np.copy((self.current_state).detach().cpu().numpy())) - self.decision_history.append("low rho accept") - self.rho_history.append(self.rho_history[i]) - - with torch.no_grad(): - self.update_Yp(torch.zeros_like(self.current_state)) - self.prev_Y[0] = self.prev_Y[1] - self.prev_Y[1] = self.current_Y - self.update_J_AD() - self.update_hess() - self.update_grad(self.prev_Y[1]) - self.iteration += 1 - self.count_reject = 0 - return True - - @torch.no_grad() - def update_h(self) -> torch.Tensor: - """Solves the LM update linear equation (H + L*I)h = G to determine - the proposal for how to adjust the parameters to decrease the - chi2. - - """ - h = torch.zeros_like(self.current_state) - if self.iteration == 0: - return h - - h = torch.linalg.solve( - ( - self.hess - + self.L**2 - * torch.eye(len(self.grad), dtype=AP_config.ap_dtype, device=AP_config.ap_device) - ) - * ( - 1 - + self.L**2 - * torch.eye(len(self.grad), dtype=AP_config.ap_dtype, device=AP_config.ap_device) - ) - ** 2 - / (1 + self.L**2), - self.grad, - ) - return h - - @torch.no_grad() - def update_Yp(self, h): - """ - Updates the current model values for each pixel - """ - # Sample model at proposed state - self.current_Y = self.model( - parameters=self.current_state + h, - as_representation=True, - parameters_identity=self.fit_parameters_identity, - window=self.fit_window, - ).flatten("data") - - # Add constraint evaluations - if self.constraints is not None: - for con in self.constraints: - self.current_Y = torch.cat((self.current_Y, con(self.model))) - - @torch.no_grad() - def update_chi2(self): - """ - Updates the chi squared / ndf value - """ - # Apply mask if needed - if self.model.target.has_mask: - loss = ( - torch.sum(((self.Y - self.current_Y) ** 2 * self.W)[torch.logical_not(self.mask)]) - / self.ndf - ) - else: - loss = torch.sum((self.Y - self.current_Y) ** 2 * self.W) / self.ndf - - return loss - - def update_J_AD(self) -> None: - """ - Update the jacobian using automatic differentiation, produces an accurate jacobian at the current state. - """ - # Free up memory - del self.J - if "cpu" not in AP_config.ap_device: - torch.cuda.empty_cache() - - # Compute jacobian on image - self.J = self.model.jacobian( - torch.clone(self.current_state).detach(), - as_representation=True, - parameters_identity=self.fit_parameters_identity, - window=self.fit_window, - ).flatten("data") - - # compute the constraint jacobian if needed - if self.constraints is not None: - for con in self.constraints: - self.J = torch.cat((self.J, con.jacobian(self.model))) - - # Apply mask if needed - if self.model.target.has_mask: - self.J[self.mask] = 0.0 - - # Note that the most recent jacobian was a full autograd jacobian - self.full_jac = True - - def update_J_natural(self) -> None: - """ - Update the jacobian using automatic differentiation, produces an accurate jacobian at the current state. Use this method to get the jacobian in the parameter space instead of representation space. - """ - # Free up memory - del self.J - if "cpu" not in AP_config.ap_device: - torch.cuda.empty_cache() - - # Compute jacobian on image - self.J = self.model.jacobian( - torch.clone( - self.model.parameters.transform( - self.current_state, - to_representation=False, - parameters_identity=self.fit_parameters_identity, - ) - ).detach(), - as_representation=False, - parameters_identity=self.fit_parameters_identity, - window=self.fit_window, - ).flatten("data") - - # compute the constraint jacobian if needed - if self.constraints is not None: - for con in self.constraints: - self.J = torch.cat((self.J, con.jacobian(self.model))) - - # Apply mask if needed - if self.model.target.has_mask: - self.J[self.mask] = 0.0 - - # Note that the most recent jacobian was a full autograd jacobian - self.full_jac = False - - @torch.no_grad() - def update_J_Broyden(self, h, Yp, Yph) -> None: - """ - Use the Broyden update to approximate the new Jacobian tensor at the current state. Less accurate, but far faster. - """ - - # Update the Jacobian - self.J = Broyden_step(self.J, h, Yp, Yph) - - # Apply mask if needed - if self.model.target.has_mask: - self.J[self.mask] = 0.0 - - # compute the constraint jacobian if needed - if self.constraints is not None: - for con in self.constraints: - self.J = torch.cat((self.J, con.jacobian(self.model))) - - # Note that the most recent jacobian update was with Broyden step - self.full_jac = False - - @torch.no_grad() - def update_hess(self) -> None: - """ - Update the Hessian using the jacobian most recently computed on the image. - """ - - if isinstance(self.W, float): - self.hess = torch.matmul(self.J.T, self.J) - else: - self.hess = torch.matmul(self.J.T, self.W.view(len(self.W), -1) * self.J) - self.hess += self.epsilon5 * torch.eye( - len(self.current_state), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - @property - @torch.no_grad() - def covariance_matrix(self) -> torch.Tensor: - if self._covariance_matrix is not None: - return self._covariance_matrix - self.update_J_natural() - self.update_hess() - try: - self._covariance_matrix = 2 * torch.linalg.inv(self.hess) - except: - AP_config.ap_logger.warning( - "WARNING: Hessian is singular, likely at least one model is non-physical. Will massage Hessian to continue but results should be inspected." - ) - self.hess += torch.eye( - len(self.grad), dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) * (torch.diag(self.hess) == 0) - self._covariance_matrix = 2 * torch.linalg.inv(self.hess) - return self._covariance_matrix - - @torch.no_grad() - def update_grad(self, Yph) -> None: - """ - Update the gradient using the model evaluation on all pixels - """ - self.grad = torch.matmul(self.J.T, self.W * (self.Y - Yph)) - - @torch.no_grad() - def rho(self, Xp, Xph, h) -> torch.Tensor: - return ( - self.ndf - * (Xp - Xph) - / abs( - torch.dot( - h, - self.L**2 * (torch.abs(torch.diag(self.hess) - self.epsilon5) * h) + self.grad, - ) - ) - ) - - def accept_history(self) -> (List[np.ndarray], List[np.ndarray], List[float]): - lambdas = [] - Ls = [] - losses = [] - - for l in range(len(self.decision_history)): - if "accept" in self.decision_history[l] and np.isfinite(self.loss_history[l]): - lambdas.append(self.lambda_history[l]) - Ls.append(self.L_history[l]) - losses.append(self.loss_history[l]) - return lambdas, Ls, losses - - def progress_history(self) -> (List[np.ndarray], List[np.ndarray], List[float]): - lambdas = [] - Ls = [] - losses = [] - - for l in range(len(self.decision_history)): - if self.decision_history[l] == "accept": - lambdas.append(self.lambda_history[l]) - Ls.append(self.L_history[l]) - losses.append(self.loss_history[l]) - return lambdas, Ls, losses - - -class LM_Constraint: - """Add an arbitrary constraint to the LM optimization algorithm. - - Expresses a constraint between parameters in the LM optimization - routine. Constraints may be used to bias parameters to have - certain behaviour, for example you may require the radius of one - model to be larger than that of another, or may require two models - to have the same position on the sky. The constraints defined in - this object are fuzzy constraints and so can be broken to some - degree, the amount of constraint breaking is determined my how - informative the data is and how strong the constraint weight is - set. To create a constraint, first construct a function which - takes as argument a 1D tensor of the model parameters and gives as - output a real number (or 1D tensor of real numbers) which is zero - when the constraint is satisfied and non-zero increasing based on - how much the constraint is violated. For example: - - def example_constraint(P): - return (P[1] - P[0]) * (P[1] > P[0]).int() - - which enforces that parameter 1 is less than parameter 0. Note - that we do not use any control flow "if" statements and instead - incorporate the condition through multiplication, this is - important as it allows pytorch to compute derivatives through the - expression and performs far faster on GPU since no communication - is needed back and forth to handle the if-statement. Keep this in - mind while constructing your constraint function. Also, make sure - that any math operations are performed by pytorch so it can - construct a computational graph. Bayond the requirement that the - constraint be differentiable, there is no limitation on what - constraints can be built with this system. - - Args: - constraint_func (Callable[torch.Tensor, torch.Tensor]): python function which takes in a 1D tensor of parameters and generates real values in a tensor. - constraint_args (Optional[tuple]): An optional tuple of arguments for the constraint function that will be unpacked when calling the function. - weight (torch.Tensor): The weight of this constraint in the range (0,inf). Smaller values mean a stronger constraint, larger values mean a weaker constraint. Default 1. - representation_parameters (bool): if the constraint_func expects the parameters in the form of their representation or their standard value. Default False - out_len (int): the length of the output tensor by constraint_func. Default 1 - reference_value (torch.Tensor): The value at which the constraint is satisfied. Default 0. - reduce_ndf (float): Amount by which to reduce the degrees of freedom. Default 0. - - """ - - def __init__( - self, - constraint_func: Callable[[torch.Tensor, Any], torch.Tensor], - constraint_args: tuple = (), - representation_parameters: bool = False, - out_len: int = 1, - reduce_ndf: float = 0.0, - weight: Optional[torch.Tensor] = None, - reference_value: Optional[torch.Tensor] = None, - **kwargs, - ): - self.constraint_func = constraint_func - self.constraint_args = constraint_args - self.representation_parameters = representation_parameters - self.out_len = out_len - self.reduce_ndf = reduce_ndf - self.reference_value = torch.as_tensor( - reference_value if reference_value is not None else torch.zeros(out_len), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - self.weight = torch.as_tensor( - weight if weight is not None else torch.ones(out_len), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - def jacobian(self, model: "AstroPhot_Model"): - jac = jacobian( - lambda P: self.constraint_func(P, *self.constraint_args), - model.parameters.get_vector(as_representation=self.representation_parameters), - strategy="forward-mode", - vectorize=True, - create_graph=False, - ) - - return jac.reshape(-1, np.sum(model.parameters.vector_len())) - - def __call__(self, model: "AstroPhot_Model"): - return self.constraint_func( - model.parameters.get_vector(as_representation=self.representation_parameters), - *self.constraint_args, - ) diff --git a/astrophot/fit/scipy_fit.py b/astrophot/fit/scipy_fit.py new file mode 100644 index 00000000..41031631 --- /dev/null +++ b/astrophot/fit/scipy_fit.py @@ -0,0 +1,109 @@ +from typing import Sequence, Literal + +from scipy.optimize import minimize +import numpy as np + +from .base import BaseOptimizer +from .. import config +from ..backend_obj import backend + +__all__ = ("ScipyFit",) + + +class ScipyFit(BaseOptimizer): + """Scipy-based optimizer for fitting models to data using various + optimization methods. + + The optimizer uses the `scipy.optimize.minimize` function to perform the + fitting. The Scipy package is widely used and well tested for optimization + tasks. It supports a variety of methods, however only a subset allow users to + define boundaries for the parameters. This wrapper is only for those methods. + + **Args:** + - `model`: The model to fit, which should be an instance of `Model`. + - `initial_state`: Initial guess for the model parameters as a 1D tensor. + - `method`: The optimization method to use. Default is "Nelder-Mead", but can be set to any of: "Nelder-Mead", "L-BFGS-B", "TNC", "SLSQP", "Powell", or "trust-constr". + - `ndf`: Optional number of degrees of freedom for the fit. If not provided, it is calculated as the number of data points minus the number of parameters. + """ + + def __init__( + self, + model, + initial_state: Sequence = None, + method: Literal[ + "Nelder-Mead", "L-BFGS-B", "TNC", "SLSQP", "Powell", "trust-constr" + ] = "Nelder-Mead", + likelihood: Literal["gaussian", "poisson"] = "gaussian", + ndf=None, + **kwargs, + ): + + super().__init__(model, initial_state, **kwargs) + self.method = method + self.likelihood = likelihood + + # Degrees of freedom + if ndf is None: + sub_target = self.model.target[self.model.window] + ndf = ( + np.prod(sub_target.flatten("data").shape) + - backend.sum(sub_target.flatten("mask")).item() + ) + self.ndf = max(1.0, ndf - len(self.current_state)) + else: + self.ndf = ndf + + def numpy_bounds(self): + """Convert the model's parameter bounds to a format suitable for scipy.optimize.""" + bounds = [] + for param in self.model.dynamic_params: + if param.shape == (): + bound = [None, None] + if param.valid[0] is not None: + bound[0] = backend.to_numpy(param.valid[0]) + if param.valid[1] is not None: + bound[1] = backend.to_numpy(param.valid[1]) + bounds.append(tuple(bound)) + else: + for i in range(np.prod(param.value.shape)): + bound = [None, None] + if param.valid[0] is not None: + bound[0] = backend.to_numpy(param.valid[0].flatten()[i]) + if param.valid[1] is not None: + bound[1] = backend.to_numpy(param.valid[1].flatten()[i]) + bounds.append(tuple(bound)) + return bounds + + def density(self, state: Sequence) -> float: + if self.likelihood == "gaussian": + return -self.model.gaussian_log_likelihood( + backend.as_array(state, dtype=config.DTYPE, device=config.DEVICE) + ).item() + elif self.likelihood == "poisson": + return -self.model.poisson_log_likelihood( + backend.as_array(state, dtype=config.DTYPE, device=config.DEVICE) + ).item() + else: + raise ValueError(f"Unknown likelihood type: {self.likelihood}") + + def fit(self): + + res = minimize( + lambda x: self.density(x), + self.current_state, + method=self.method, + bounds=self.numpy_bounds(), + options={ + "maxiter": self.max_iter, + }, + ) + self.scipy_res = res + self.message = self.message + f"success: {res.success}, message: {res.message}" + self.current_state = backend.as_array(res.x, dtype=config.DTYPE, device=config.DEVICE) + if self.verbose > 0: + config.logger.info( + f"Final 2NLL/DoF: {2*self.density(res.x)/self.ndf:.6g}. Converged: {self.message}" + ) + self.model.set_values(self.current_state) + + return self diff --git a/astrophot/image/__init__.py b/astrophot/image/__init__.py index 68ac134c..cc3615f8 100644 --- a/astrophot/image/__init__.py +++ b/astrophot/image/__init__.py @@ -1,8 +1,28 @@ -from .image_object import * -from .image_header import * -from .target_image import * -from .jacobian_image import * -from .psf_image import * -from .model_image import * -from .window_object import * -from .wcs import * +from .image_object import Image, ImageList +from .target_image import TargetImage, TargetImageList +from .sip_image import SIPModelImage, SIPTargetImage +from .cmos_image import CMOSModelImage, CMOSTargetImage +from .jacobian_image import JacobianImage, JacobianImageList +from .psf_image import PSFImage +from .model_image import ModelImage, ModelImageList +from .window import Window, WindowList +from . import func + +__all__ = ( + "Image", + "ImageList", + "TargetImage", + "TargetImageList", + "SIPModelImage", + "SIPTargetImage", + "CMOSModelImage", + "CMOSTargetImage", + "JacobianImage", + "JacobianImageList", + "PSFImage", + "ModelImage", + "ModelImageList", + "Window", + "WindowList", + "func", +) diff --git a/astrophot/image/cmos_image.py b/astrophot/image/cmos_image.py new file mode 100644 index 00000000..cc9e6766 --- /dev/null +++ b/astrophot/image/cmos_image.py @@ -0,0 +1,41 @@ +from .target_image import TargetImage +from .mixins import CMOSMixin +from .model_image import ModelImage +from ..backend_obj import backend +from .. import config + + +class CMOSModelImage(CMOSMixin, ModelImage): + """A ModelImage with CMOS-specific functionality.""" + + def fluxdensity_to_flux(self): + # CMOS pixels only sensitive in sub area, so scale the flux density + self._data = self._data * self.pixel_area * self.subpixel_scale**2 + + +class CMOSTargetImage(CMOSMixin, TargetImage): + """ + A TargetImage with CMOS-specific functionality. + This class is used to represent a target image with CMOS-specific features. + It inherits from TargetImage and CMOSMixin. + """ + + def model_image(self, upsample: int = 1, pad: int = 0, **kwargs) -> CMOSModelImage: + """Model the image with CMOS-specific features.""" + if upsample > 1 or pad > 0: + raise NotImplementedError("Upsampling and padding are not implemented for CMOS images.") + + kwargs = { + "subpixel_loc": self.subpixel_loc, + "subpixel_scale": self.subpixel_scale, + "_data": backend.zeros(self._data.shape[:2], dtype=config.DTYPE, device=config.DEVICE), + "CD": self.CD.value, + "crpix": self.crpix, + "crtan": self.crtan.value, + "crval": self.crval.value, + "zeropoint": self.zeropoint, + "identity": self.identity, + "name": self.name + "_model", + **kwargs, + } + return CMOSModelImage(**kwargs) diff --git a/astrophot/image/func/__init__.py b/astrophot/image/func/__init__.py new file mode 100644 index 00000000..f0723080 --- /dev/null +++ b/astrophot/image/func/__init__.py @@ -0,0 +1,38 @@ +from .image import ( + pixel_center_meshgrid, + cmos_pixel_center_meshgrid, + pixel_corner_meshgrid, + pixel_simpsons_meshgrid, + pixel_quad_meshgrid, + rotate, +) +from .wcs import ( + world_to_plane_gnomonic, + plane_to_world_gnomonic, + pixel_to_plane_linear, + plane_to_pixel_linear, + sip_delta, + sip_coefs, + sip_backward_transform, + sip_matrix, +) +from .window import window_or, window_and + +__all__ = ( + "pixel_center_meshgrid", + "cmos_pixel_center_meshgrid", + "pixel_corner_meshgrid", + "pixel_simpsons_meshgrid", + "pixel_quad_meshgrid", + "rotate", + "world_to_plane_gnomonic", + "plane_to_world_gnomonic", + "pixel_to_plane_linear", + "plane_to_pixel_linear", + "sip_delta", + "sip_coefs", + "sip_backward_transform", + "sip_matrix", + "window_or", + "window_and", +) diff --git a/astrophot/image/func/image.py b/astrophot/image/func/image.py new file mode 100644 index 00000000..74737a1f --- /dev/null +++ b/astrophot/image/func/image.py @@ -0,0 +1,45 @@ +from ...utils.integration import quad_table +from ...backend_obj import backend, ArrayLike + + +def pixel_center_meshgrid(shape: tuple[int, int], dtype, device) -> tuple: + i = backend.arange(shape[0], dtype=dtype, device=device) + j = backend.arange(shape[1], dtype=dtype, device=device) + return backend.meshgrid(i, j, indexing="ij") + + +def cmos_pixel_center_meshgrid( + shape: tuple[int, int], loc: tuple[float, float], dtype, device +) -> tuple: + i = backend.arange(shape[0], dtype=dtype, device=device) + loc[0] + j = backend.arange(shape[1], dtype=dtype, device=device) + loc[1] + return backend.meshgrid(i, j, indexing="ij") + + +def pixel_corner_meshgrid(shape: tuple[int, int], dtype, device) -> tuple: + i = backend.arange(shape[0] + 1, dtype=dtype, device=device) - 0.5 + j = backend.arange(shape[1] + 1, dtype=dtype, device=device) - 0.5 + return backend.meshgrid(i, j, indexing="ij") + + +def pixel_simpsons_meshgrid(shape: tuple[int, int], dtype, device) -> tuple: + i = 0.5 * backend.arange(2 * shape[0] + 1, dtype=dtype, device=device) - 0.5 + j = 0.5 * backend.arange(2 * shape[1] + 1, dtype=dtype, device=device) - 0.5 + return backend.meshgrid(i, j, indexing="ij") + + +def pixel_quad_meshgrid(shape: tuple[int, int], dtype, device, order=3) -> tuple: + i, j = pixel_center_meshgrid(shape, dtype, device) + di, dj, w = quad_table(order, dtype, device) + i = backend.repeat(i[..., None], order**2, -1) + di.flatten() + j = backend.repeat(j[..., None], order**2, -1) + dj.flatten() + return i, j, w.flatten() + + +def rotate(theta: ArrayLike, x: ArrayLike, y: ArrayLike) -> tuple: + """ + Applies a rotation matrix to the X,Y coordinates + """ + s = theta.sin() + c = theta.cos() + return c * x - s * y, s * x + c * y diff --git a/astrophot/image/func/wcs.py b/astrophot/image/func/wcs.py new file mode 100644 index 00000000..8d811256 --- /dev/null +++ b/astrophot/image/func/wcs.py @@ -0,0 +1,178 @@ +import numpy as np +from ...backend_obj import backend +from ... import config + +deg_to_rad = np.pi / 180 +rad_to_deg = 180 / np.pi +rad_to_arcsec = rad_to_deg * 3600 +arcsec_to_rad = deg_to_rad / 3600 + + +def world_to_plane_gnomonic(ra, dec, ra0, dec0, x0=0.0, y0=0.0): + """ + Convert world coordinates (RA, Dec) to plane coordinates (x, y) using the gnomonic projection. + + **Args:** + - `ra`: (torch.Tensor) Right Ascension in degrees. + - `dec`: (torch.Tensor) Declination in degrees. + - `ra0`: (torch.Tensor) Reference Right Ascension in degrees. + - `dec0`: (torch.Tensor) Reference Declination in degrees. + + **Returns:** + - `x`: (torch.Tensor) x coordinate in arcseconds. + - `y`: (torch.Tensor) y coordinate in arcseconds. + """ + ra = ra * deg_to_rad + dec = dec * deg_to_rad + ra0 = ra0 * deg_to_rad + dec0 = dec0 * deg_to_rad + + cosc = backend.sin(dec0) * backend.sin(dec) + backend.cos(dec0) * backend.cos( + dec + ) * backend.cos(ra - ra0) + + x = backend.cos(dec) * backend.sin(ra - ra0) + + y = backend.cos(dec0) * backend.sin(dec) - backend.sin(dec0) * backend.cos(dec) * backend.cos( + ra - ra0 + ) + + return x * rad_to_arcsec / cosc + x0, y * rad_to_arcsec / cosc + y0 + + +def plane_to_world_gnomonic(x, y, ra0, dec0, x0=0.0, y0=0.0, s=1e-10): + """ + Convert plane coordinates (x, y) to world coordinates (RA, Dec) using the gnomonic projection. + + **Args:** + - `x`: (Tensor) x coordinate in arcseconds. + - `y`: (Tensor) y coordinate in arcseconds. + - `ra0`: (Tensor) Reference Right Ascension in degrees. + - `dec0`: (Tensor) Reference Declination in degrees. + - `s`: (float) Small constant to avoid division by zero. + + **Returns:** + - `ra`: (Tensor) Right Ascension in degrees. + - `dec`: (Tensor) Declination in degrees. + """ + x = (x - x0) * arcsec_to_rad + y = (y - y0) * arcsec_to_rad + ra0 = ra0 * deg_to_rad + dec0 = dec0 * deg_to_rad + + rho = backend.sqrt(x**2 + y**2) + s + c = backend.arctan(rho) + + ra = ra0 + backend.arctan2( + x * backend.sin(c), + rho * backend.cos(dec0) * backend.cos(c) - y * backend.sin(dec0) * backend.sin(c), + ) + + dec = backend.arcsin( + backend.cos(c) * backend.sin(dec0) + y * backend.sin(c) * backend.cos(dec0) / rho + ) + + return ra * rad_to_deg, dec * rad_to_deg + + +def pixel_to_plane_linear(i, j, i0, j0, CD, x0=0.0, y0=0.0): + """ + Convert pixel coordinates to a tangent plane using the WCS information. This + matches the FITS convention for linear transformations. + + **Args:** + - `i` (Tensor): The first coordinate of the pixel in pixel units. + - `j` (Tensor): The second coordinate of the pixel in pixel units. + - `i0` (Tensor): The i reference pixel coordinate in pixel units. + - `j0` (Tensor): The j reference pixel coordinate in pixel units. + - `CD` (Tensor): The CD matrix in arcsec per pixel. This 2x2 matrix is used to convert + from pixel to arcsec units and also handles rotation/skew. + - `x0` (float): The x reference coordinate in arcseconds. + - `y0` (float): The y reference coordinate in arcseconds. + + **Returns:** + - Tuple[Tensor, Tensor]: Tuple containing the x and y coordinates in arcseconds + """ + uv = backend.stack((i.flatten() - i0, j.flatten() - j0), dim=0) + xy = CD @ uv + + return xy[0].reshape(i.shape) + x0, xy[1].reshape(i.shape) + y0 + + +def sip_coefs(order): + coefs = [] + for p in range(order + 1): + for q in range(order + 1 - p): + coefs.append((p, q)) + return tuple(coefs) + + +def sip_matrix(u, v, order): + M = backend.zeros( + (len(u), (order + 1) * (order + 2) // 2), dtype=config.DTYPE, device=config.DEVICE + ) + for i, (p, q) in enumerate(sip_coefs(order)): + M = backend.fill_at_indices(M, (slice(None), i), u**p * v**q) + return M + + +def sip_backward_transform(u, v, U, V, A_ORDER, B_ORDER): + """ + Credit: Shu Liu and Lei Hi, see here: + https://github.com/Roman-Supernova-PIT/sfft/blob/master/sfft/utils/CupyWCSTransform.py + + Compute the backward transformation from (U, V) to (u, v) + """ + + FP_UV = sip_matrix(U, V, A_ORDER) + GP_UV = sip_matrix(U, V, B_ORDER) + + AP = backend.linalg.lstsq(FP_UV, (u.flatten() - U).reshape(-1, 1))[0].squeeze(1) + BP = backend.linalg.lstsq(GP_UV, (v.flatten() - V).reshape(-1, 1))[0].squeeze(1) + return AP, BP + + +def sip_delta(u, v, sipA=(), sipB=()): + """ + u = j - j0 + v = i - i0 + sipA = dict(tuple(int,int), float) + The SIP coefficients, where the keys are tuples of powers (i, j) and the values are the coefficients. + For example, {(1, 2): 0.1} means delta_u = 0.1 * (u * v^2). + """ + delta_u = backend.zeros_like(u) + delta_v = backend.zeros_like(v) + # Get all used coefficient powers + all_a = set(s[0] for s in sipA) | set(s[0] for s in sipB) + all_b = set(s[1] for s in sipA) | set(s[1] for s in sipB) + # Pre-compute all powers of u and v + u_a = dict((a, u**a) for a in all_a) + v_b = dict((b, v**b) for b in all_b) + for a, b in sipA: + delta_u = delta_u + sipA[(a, b)] * (u_a[a] * v_b[b]) + for a, b in sipB: + delta_v = delta_v + sipB[(a, b)] * (u_a[a] * v_b[b]) + return delta_u, delta_v + + +def plane_to_pixel_linear(x, y, i0, j0, CD, x0=0.0, y0=0.0): + """ + Convert tangent plane coordinates to pixel coordinates using the WCS + information. This matches the FITS convention for linear transformations. + + **Args:** + - `x`: (Tensor) The first coordinate of the pixel in arcsec. + - `y`: (Tensor) The second coordinate of the pixel in arcsec. + - `i0`: (Tensor) The i reference pixel coordinate in pixel units. + - `j0`: (Tensor) The j reference pixel coordinate in pixel units. + - `CD`: (Tensor) The CD matrix in arcsec per pixel. + - `x0`: (float) The x reference coordinate in arcsec. + - `y0`: (float) The y reference coordinate in arcsec. + + **Returns:** + - Tuple[Tensor, Tensor]: Tuple containing the i and j pixel coordinates in pixel units. + """ + xy = backend.stack((x.flatten() - x0, y.flatten() - y0), dim=0) + uv = backend.linalg.inv(CD) @ xy + + return uv[0].reshape(x.shape) + i0, uv[1].reshape(y.shape) + j0 diff --git a/astrophot/image/func/window.py b/astrophot/image/func/window.py new file mode 100644 index 00000000..46be8061 --- /dev/null +++ b/astrophot/image/func/window.py @@ -0,0 +1,16 @@ +from ...backend_obj import backend + + +def window_or(other_origin, self_end, other_end): + + new_origin = backend.minimum(-0.5 * backend.ones_like(other_origin), other_origin) + new_end = backend.maximum(self_end, other_end) + + return new_origin, new_end + + +def window_and(other_origin, self_end, other_end): + new_origin = backend.maximum(-0.5 * backend.ones_like(other_origin), other_origin) + new_end = backend.minimum(self_end, other_end) + + return new_origin, new_end diff --git a/astrophot/image/image_header.py b/astrophot/image/image_header.py deleted file mode 100644 index ea74e127..00000000 --- a/astrophot/image/image_header.py +++ /dev/null @@ -1,334 +0,0 @@ -from typing import Optional, Union, Any - -import torch -import numpy as np -from astropy.io import fits -from astropy.wcs import WCS as AstropyWCS - -from .window_object import Window -from .. import AP_config - -__all__ = ["Image_Header"] - - -class Image_Header: - """Store meta-information for images to be used in AstroPhot. - - The Image_Header object stores all meta information which tells - AstroPhot what is contained in an image array of pixels. This - includes coordinate systems and how to transform between them (see - :doc:`coordinates`). The image header will also know the image - zeropoint if that data is available. - - Args: - window : Window or None, optional - A Window object defining the area of the image in the coordinate - systems. Default is None. - filename : str or None, optional - The name of a file containing the image data. Default is None. - zeropoint : float or None, optional - The image's zeropoint, used for flux calibration. Default is None. - metadata : dict or None, optional - Any information the user wishes to associate with this image, stored in a python dictionary. Default is None. - - """ - - north = np.pi / 2.0 - - def __init__( - self, - *, - data_shape: Optional[torch.Tensor] = None, - wcs: Optional[AstropyWCS] = None, - window: Optional[Window] = None, - filename: Optional[str] = None, - zeropoint: Optional[Union[float, torch.Tensor]] = None, - metadata: Optional[dict] = None, - identity: str = None, - state: Optional[dict] = None, - fits_state: Optional[dict] = None, - **kwargs: Any, - ) -> None: - # Record identity - if identity is None: - self.identity = str(id(self)) - else: - self.identity = identity - - # set Zeropoint - self.zeropoint = zeropoint - - # set metadata for the image - self.metadata = metadata - - if filename is not None: - self.load(filename) - return - elif state is not None: - self.set_state(state) - return - elif fits_state is not None: - self.set_fits_state(fits_state) - return - - # Set Window - if window is None: - data_shape = torch.as_tensor(data_shape, dtype=torch.int32, device=AP_config.ap_device) - # If window is not provided, create one based on provided information - self.window = Window( - pixel_shape=torch.flip(data_shape, (0,)), - wcs=wcs, - **kwargs, - ) - else: - # When the Window object is provided - self.window = window - - @property - def zeropoint(self): - """The photometric zeropoint of the image, used as a flux reference - point. - - """ - return self._zeropoint - - @zeropoint.setter - def zeropoint(self, zp): - if zp is None: - self._zeropoint = None - return - - self._zeropoint = ( - torch.as_tensor(zp, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - .clone() - .detach() - ) - - @property - def origin(self) -> torch.Tensor: - """ - Returns the location of the origin (pixel coordinate -0.5, -0.5) of the image window in the tangent plane (arcsec). - - Returns: - torch.Tensor: A 1D tensor of shape (2,) containing the (x, y) coordinates of the origin. - """ - return self.window.origin - - @property - def shape(self) -> torch.Tensor: - """ - Returns the shape (size) of the image window (arcsec, arcsec). - - Returns: - torch.Tensor: A 1D tensor of shape (2,) containing the (width, height) of the window in arcsec. - """ - return self.window.shape - - @property - def center(self) -> torch.Tensor: - """ - Returns the center of the image window (arcsec). - - Returns: - torch.Tensor: A 1D tensor of shape (2,) containing the (x, y) coordinates of the center. - """ - return self.window.center - - def world_to_plane(self, *args, **kwargs): - return self.window.world_to_plane(*args, **kwargs) - - def plane_to_world(self, *args, **kwargs): - return self.window.plane_to_world(*args, **kwargs) - - def plane_to_pixel(self, *args, **kwargs): - return self.window.plane_to_pixel(*args, **kwargs) - - def pixel_to_plane(self, *args, **kwargs): - return self.window.pixel_to_plane(*args, **kwargs) - - def plane_to_pixel_delta(self, *args, **kwargs): - return self.window.plane_to_pixel_delta(*args, **kwargs) - - def pixel_to_plane_delta(self, *args, **kwargs): - return self.window.pixel_to_plane_delta(*args, **kwargs) - - def world_to_pixel(self, *args, **kwargs): - return self.window.world_to_pixel(*args, **kwargs) - - def pixel_to_world(self, *args, **kwargs): - return self.window.pixel_to_world(*args, **kwargs) - - def get_coordinate_meshgrid(self): - return self.window.get_coordinate_meshgrid() - - def get_coordinate_corner_meshgrid(self): - return self.window.get_coordinate_corner_meshgrid() - - def get_coordinate_simps_meshgrid(self): - return self.window.get_coordinate_simps_meshgrid() - - @property - def pixelscale(self): - return self.window.pixelscale - - @property - def pixel_length(self): - return self.window.pixel_length - - @property - def pixel_area(self): - return self.window.pixel_area - - def shift(self, shift): - """Adjust the position of the image described by the header. This will - not adjust the data represented by the header, only the - coordinate system that maps pixel coordinates to the plane - coordinates. - - """ - self.window.shift(shift) - - def pixel_shift(self, shift): - self.window.pixel_shift(shift) - - def copy(self, **kwargs): - """Produce a copy of this image with all of the same properties. This - can be used when one wishes to make temporary modifications to - an image and then will want the original again. - - """ - copy_kwargs = { - "zeropoint": self.zeropoint, - "metadata": self.metadata, - "window": self.window.copy(), - "identity": self.identity, - } - copy_kwargs.update(kwargs) - return self.__class__(**copy_kwargs) - - def get_window(self, window, **kwargs): - """Get a sub-region of the image as defined by a window on the sky.""" - copy_kwargs = { - "window": self.window & window, - } - copy_kwargs.update(kwargs) - return self.copy(**copy_kwargs) - - def to(self, dtype=None, device=None): - if dtype is None: - dtype = AP_config.ap_dtype - if device is None: - device = AP_config.ap_device - self.window.to(dtype=dtype, device=device) - if self.zeropoint is not None: - self.zeropoint.to(dtype=dtype, device=device) - return self - - def crop(self, pixels): # fixme data_shape? - """Reduce the size of an image by cropping some number of pixels off - the borders. If pixels is a single value, that many pixels are - cropped off all sides. If pixels is two values then a different - crop is done in x vs y. If pixels is four values then crop on - all sides are specified explicitly. - - formatted as: - [crop all sides] or - [crop x, crop y] or - [crop x low, crop y low, crop x high, crop y high] - - """ - self.window.crop_pixel(pixels) - return self - - def rescale_pixel(self, scale: int, **kwargs): - if scale == 1: - return self - - return self.copy( - window=self.window.rescale_pixel(scale), - **kwargs, - ) - - def get_state(self): - """Returns a dictionary with necessary information to recreate the - Image_Header object. - - """ - state = {} - if self.zeropoint is not None: - state["zeropoint"] = self.zeropoint.item() - state["window"] = self.window.get_state() - if self.metadata is not None: - state["metadata"] = self.metadata - return state - - def set_state(self, state): - self.zeropoint = state.get("zeropoint", self.zeropoint) - self.window = Window(state=state["window"]) - self.metadata = state.get("metadata", self.metadata) - - def get_fits_state(self): - state = {} - state.update(self.window.get_fits_state()) - if self.zeropoint is not None: - state["ZEROPNT"] = str(self.zeropoint.detach().cpu().item()) - if self.metadata is not None: - state["METADATA"] = str(self.metadata) - return state - - def set_fits_state(self, state): - """ - Updates the state of the Image_Header using information saved in a FITS header (more generally, a properly formatted dictionary will also work but not yet). - """ - self.zeropoint = eval(state.get("ZEROPNT", "None")) - self.metadata = state.get("METADATA", None) - self.window = Window(fits_state=state) - - def _save_image_list(self): - """ - Constructs a FITS header object which has the necessary information to recreate the Image_Header object. - """ - img_header = fits.Header() - img_header["IMAGE"] = "PRIMARY" - img_header["WINDOW"] = str(self.window.get_state()) - if self.zeropoint is not None: - img_header["ZEROPNT"] = str(self.zeropoint.detach().cpu().item()) - if self.metadata is not None: - img_header["METADATA"] = str(self.metadata) - return img_header - - def save(self, filename=None, overwrite=True): - """ - Save header to a FITS file. - """ - image_list = self._save_image_list() - hdul = fits.HDUList(image_list) - if filename is not None: - hdul.writeto(filename, overwrite=overwrite) - return hdul - - def load(self, filename): - """ - load header from a FITS file. - """ - hdul = fits.open(filename) - for hdu in hdul: - if "IMAGE" in hdu.header and hdu.header["IMAGE"] == "PRIMARY": - self.set_fits_state(hdu.header) - break - return hdul - - def __str__(self): - state = self.get_state() - state.update(self.window.get_state()) - keys = ["pixel_shape", "pixelscale", "reference_imageij", "reference_imagexy"] - if "zeropoint" in state: - keys.append("zeropoint") - if "metadata" in state: - keys.append("metadata") - return "\n".join(f"{key}: {state[key]}" for key in keys) - - def __repr__(self): - state = self.get_state() - state.update(self.window.get_state()) - return "\n".join(f"{key}: {state[key]}" for key in sorted(state.keys())) diff --git a/astrophot/image/image_object.py b/astrophot/image/image_object.py index e94b4caf..7d989342 100644 --- a/astrophot/image/image_object.py +++ b/astrophot/image/image_object.py @@ -1,267 +1,306 @@ -from typing import Optional, Union, Any, Sequence, Tuple +from typing import Optional, Tuple, Union import torch -from torch.nn.functional import pad import numpy as np -from astropy.io import fits from astropy.wcs import WCS as AstropyWCS +from astropy.io import fits + +from ..param import Module, Param, forward +from .. import config +from ..backend_obj import backend, ArrayLike +from ..utils.conversions.units import deg_to_arcsec, arcsec_to_deg +from .window import Window, WindowList +from ..errors import InvalidImage, SpecificationConflict -from .window_object import Window, Window_List -from .image_header import Image_Header -from .. import AP_config -from ..errors import SpecificationConflict, ConflicingWCS, InvalidData, InvalidWindow +# from .base import BaseImage +from . import func -__all__ = ["Image", "Image_List"] +__all__ = ["Image", "ImageList"] -class Image(object): +class Image(Module): """Core class to represent images with pixel values, pixel scale, - and a window defining the spatial coordinates on the sky. - It supports arithmetic operations with other image objects while preserving logical image boundaries. - It also provides methods for determining the coordinate locations of pixels - - Parameters: - data: the matrix of pixel values for the image - pixelscale: the length of one side of a pixel in arcsec/pixel - window: an AstroPhot Window object which defines the spatial coordinates on the sky - filename: a filename from which to load the image. - zeropoint: photometric zero point for converting from pixel flux to magnitude - metadata: Any information the user wishes to associate with this image, stored in a python dictionary - origin: The origin of the image in the coordinate system. + and a window defining the spatial coordinates on the sky. It supports + arithmetic operations with other image objects while preserving logical + image boundaries. It also provides methods for determining the coordinate + locations of pixels + + **Args:** + - `data`: The image data as a tensor of pixel values. If not provided, a tensor of zeros will be created. + - `zeropoint`: The zeropoint of the image, which is used to convert from pixel flux to magnitude. + - `crpix`: The reference pixel coordinates in the image, which is used to convert from pixel coordinates to tangent plane coordinates. + - `pixelscale`: The side length of a pixel, used to create a simple diagonal CD matrix. + - `wcs`: An optional Astropy WCS object to initialize the image. + - `filename`: The filename to load the image from. If provided, the image will be loaded from the file. + - `hduext`: The HDU extension to load from the FITS file specified in `filename`. + - `identity`: An optional identity string for the image. + + these parameters are added to the optimization model: + + **Parameters:** + - `crval`: The reference coordinate of the image in degrees [RA, DEC]. + - `crtan`: The tangent plane coordinate of the image in arcseconds [x, y]. + - `CD`: The coordinate transformation matrix in arcseconds/pixel. """ + default_CD = ((1.0, 0.0), (0.0, 1.0)) + expect_ctype = (("RA---TAN",), ("DEC--TAN",)) + base_scale = 1.0 + def __init__( self, *, - data: Optional[torch.Tensor] = None, - header: Optional[Image_Header] = None, + data: Optional[ArrayLike] = None, + CD: Optional[Union[float, ArrayLike]] = None, + zeropoint: Optional[Union[float, ArrayLike]] = None, + crpix: Union[ArrayLike, tuple] = (0.0, 0.0), + crtan: Union[ArrayLike, tuple] = (0.0, 0.0), + crval: Union[ArrayLike, tuple] = (0.0, 0.0), + pixelscale: Optional[Union[ArrayLike, float]] = None, wcs: Optional[AstropyWCS] = None, - pixelscale: Optional[Union[float, torch.Tensor]] = None, - window: Optional[Window] = None, filename: Optional[str] = None, - zeropoint: Optional[Union[float, torch.Tensor]] = None, - metadata: Optional[dict] = None, - origin: Optional[Sequence] = None, - center: Optional[Sequence] = None, + hduext: int = 0, identity: str = None, - state: Optional[dict] = None, - fits_state: Optional[dict] = None, - **kwargs: Any, - ) -> None: - """Initialize an instance of the APImage class. - - Parameters: - ----------- - data : numpy.ndarray or None, optional - The image data. Default is None. - wcs : astropy.wcs.wcs.WCS or None, optional - A WCS object which defines a coordinate system for the image. Note that AstroPhot only handles basic WCS conventions. It will use the WCS object to get `wcs.pixel_to_world(-0.5, -0.5)` to determine the position of the origin in world coordinates. It will also extract the `pixel_scale_matrix` to index pixels going forward. - pixelscale : float or None, optional - The physical scale of the pixels in the image, in units of arcseconds. Default is None. - window : Window or None, optional - A Window object defining the area of the image to use. Default is None. - filename : str or None, optional - The name of a file containing the image data. Default is None. - zeropoint : float or None, optional - The image's zeropoint, used for flux calibration. Default is None. - metadata : dict or None, optional - Any information the user wishes to associate with this image, stored in a python dictionary. Default is None. - origin : numpy.ndarray or None, optional - The origin of the image in the coordinate system, as a 1D array of length 2. Default is None. - center : numpy.ndarray or None, optional - The center of the image in the coordinate system, as a 1D array of length 2. Default is None. - - Returns: - -------- - None - """ - self._data = None - - if state is not None: - self.header = Image_Header(state=state["header"]) - elif fits_state is not None: - self.set_fits_state(fits_state) - return - elif header is None: - if data is None and window is None and filename is None: - raise InvalidData("Image must have either data or a window to construct itself.") - self.header = Image_Header( - data_shape=None if data is None else data.shape, - pixelscale=pixelscale, - wcs=wcs, - window=window, - filename=filename, - zeropoint=zeropoint, - metadata=metadata, - origin=origin, - center=center, - identity=identity, - **kwargs, - ) + name: Optional[str] = None, + _data: Optional[ArrayLike] = None, + ): + super().__init__(name=name) + if _data is None: + self.data = data # units: flux else: - self.header = header + self._data = _data + self.crtan = Param( + "crtan", + crtan, + shape=(2,), + units="arcsec", + dtype=config.DTYPE, + device=config.DEVICE, + ) + self.zeropoint = zeropoint - if filename is not None: - self.load(filename) - elif state is not None: - self.set_state(state) - elif fits_state is not None: - self.data = fits_state[0]["DATA"] + if identity is None: + self.identity = id(self) else: - # set the data - if data is None: - self.data = torch.zeros( - torch.flip(self.window.pixel_shape, (0,)).detach().cpu().tolist(), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, + self.identity = identity + + if wcs is not None: + if wcs.wcs.ctype[0] not in self.expect_ctype[0]: + config.logger.warning( + "Astropy WCS not tangent plane coordinate system! May not be compatible with AstroPhot." + ) + if wcs.wcs.ctype[1] not in self.expect_ctype[1]: + config.logger.warning( + "Astropy WCS not tangent plane coordinate system! May not be compatible with AstroPhot." ) - else: - self.data = data - self.to() + crval = wcs.wcs.crval + crpix = np.array(wcs.wcs.crpix)[::-1] - 1 # handle FITS 1-indexing - # # Check that image data and header are in agreement (this requires talk back from GPU to CPU so is only used for testing) - # assert np.all(np.flip(np.array(self.data.shape)[:2]) == self.window.pixel_shape.numpy()), f"data shape {np.flip(np.array(self.data.shape)[:2])}, window shape {self.window.pixel_shape.numpy()}" + if CD is not None: + config.logger.warning("WCS CD set with supplied WCS, ignoring user supplied CD!") + CD = deg_to_arcsec * wcs.pixel_scale_matrix - @property - def north(self): - return self.header.north + # set the data + self.crval = Param( + "crval", crval, shape=(2,), units="deg", dtype=config.DTYPE, device=config.DEVICE + ) + self.crpix = crpix + + if isinstance(CD, (float, int)): + CD = np.array([[CD, 0.0], [0.0, CD]], dtype=np.float64) + elif CD is None and pixelscale is not None: + CD = np.array([[pixelscale, 0.0], [0.0, pixelscale]], dtype=np.float64) + elif CD is None: + CD = self.default_CD + + self.CD = Param( + "CD", + CD, + shape=(2, 2), + units="arcsec/pixel", + dtype=config.DTYPE, + device=config.DEVICE, + ) - @property - def pixel_area(self): - return self.header.pixel_area + if filename is not None: + self.load(filename, hduext=hduext) + return @property - def pixel_length(self): - return self.header.pixel_length - - def world_to_plane(self, *args, **kwargs): - return self.window.world_to_plane(*args, **kwargs) - - def plane_to_world(self, *args, **kwargs): - return self.window.plane_to_world(*args, **kwargs) - - def plane_to_pixel(self, *args, **kwargs): - return self.window.plane_to_pixel(*args, **kwargs) - - def pixel_to_plane(self, *args, **kwargs): - return self.window.pixel_to_plane(*args, **kwargs) - - def plane_to_pixel_delta(self, *args, **kwargs): - return self.window.plane_to_pixel_delta(*args, **kwargs) - - def pixel_to_plane_delta(self, *args, **kwargs): - return self.window.pixel_to_plane_delta(*args, **kwargs) - - def world_to_pixel(self, *args, **kwargs): - return self.window.world_to_pixel(*args, **kwargs) - - def pixel_to_world(self, *args, **kwargs): - return self.window.pixel_to_world(*args, **kwargs) - - def get_coordinate_meshgrid(self): - return self.window.get_coordinate_meshgrid() - - def get_coordinate_corner_meshgrid(self): - return self.window.get_coordinate_corner_meshgrid() + def data(self): + """The image data, which is a tensor of pixel values.""" + return backend.transpose(self._data, 1, 0) - def get_coordinate_simps_meshgrid(self): - return self.window.get_coordinate_simps_meshgrid() + @data.setter + def data(self, value: Optional[ArrayLike]): + """Set the image data. If value is None, the data is initialized to an empty tensor.""" + if value is None: + self._data = backend.empty((0, 0), dtype=config.DTYPE, device=config.DEVICE) + else: + # Transpose since pytorch uses (j, i) indexing when (i, j) is more natural for coordinates + self._data = backend.transpose( + backend.as_array(value, dtype=config.DTYPE, device=config.DEVICE), 1, 0 + ) @property - def origin(self) -> torch.Tensor: - """ - Returns the origin (bottom-left corner) of the image window. + def crpix(self) -> np.ndarray: + """The reference pixel coordinates in the image, which is used to convert from pixel coordinates to tangent plane coordinates.""" + return self._crpix - Returns: - torch.Tensor: A 1D tensor of shape (2,) containing the (x, y) coordinates of the origin. - """ - return self.header.window.origin + @crpix.setter + def crpix(self, value: Union[ArrayLike, tuple]): + self._crpix = np.asarray(value, dtype=np.float64) @property - def shape(self) -> torch.Tensor: - """ - Returns the shape (size) of the image window. - - Returns: - torch.Tensor: A 1D tensor of shape (2,) containing the (width, height) of the window in pixels. - """ - return self.header.window.shape + def zeropoint(self) -> ArrayLike: + """The zeropoint of the image, which is used to convert from pixel flux to magnitude.""" + return self._zeropoint + + @zeropoint.setter + def zeropoint(self, value): + """Set the zeropoint of the image.""" + if value is None: + self._zeropoint = None + else: + self._zeropoint = backend.as_array(value, dtype=config.DTYPE, device=config.DEVICE) @property - def center(self) -> torch.Tensor: - """ - Returns the center of the image window. - - Returns: - torch.Tensor: A 1D tensor of shape (2,) containing the (x, y) coordinates of the center. - """ - return self.header.window.center + def window(self) -> Window: + return Window(window=((0, 0), self._data.shape[:2]), image=self) @property - def size(self) -> torch.Tensor: - """ - Returns the size of the image window, the number of pixels in the image. + def center(self): + shape = backend.as_array(self._data.shape[:2], dtype=config.DTYPE, device=config.DEVICE) + return backend.stack(self.pixel_to_plane(*((shape - 1) / 2))) - Returns: - torch.Tensor: A 0D tensor containing the number of pixels. - """ - return self.header.window.size + # @property + # def shape(self): + # """The shape of the image data.""" + # return self.data.shape @property - def window(self): - return self.header.window + @forward + def pixel_area(self, CD): + """The area inside a pixel in arcsec^2""" + return backend.abs(backend.linalg.det(CD)) @property + @forward def pixelscale(self): - return self.header.pixelscale - - @property - def zeropoint(self): - return self.header.zeropoint - - @property - def metadata(self): - return self.header.metadata + """The approximate side length of a pixel, which is just + sqrt(pixel_area). For square pixels this is the actual pixel + length, for rectangular pixels it is a kind of average. - @property - def identity(self): - return self.header.identity + The pixelscale is not used for exact calculations + and instead sets a size scale within an image. - @property - def data(self) -> torch.Tensor: """ - Returns the image data. - """ - return self._data + return backend.sqrt(self.pixel_area) - @data.setter - def data(self, data) -> None: - """Set the image data.""" - self.set_data(data) + @forward + def pixel_to_plane( + self, + i: ArrayLike, + j: ArrayLike, + crtan: ArrayLike, + CD: ArrayLike, + ) -> Tuple[ArrayLike, ArrayLike]: + return func.pixel_to_plane_linear(i, j, *self.crpix, CD, *crtan) + + @forward + def plane_to_pixel( + self, + x: ArrayLike, + y: ArrayLike, + crtan: ArrayLike, + CD: ArrayLike, + ) -> Tuple[ArrayLike, ArrayLike]: + return func.plane_to_pixel_linear(x, y, *self.crpix, CD, *crtan) + + @forward + def plane_to_world( + self, x: ArrayLike, y: ArrayLike, crval: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + return func.plane_to_world_gnomonic(x, y, *crval) + + @forward + def world_to_plane( + self, ra: ArrayLike, dec: ArrayLike, crval: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + return func.world_to_plane_gnomonic(ra, dec, *crval) + + @forward + def world_to_pixel(self, ra: ArrayLike, dec: ArrayLike) -> Tuple[ArrayLike, ArrayLike]: + """A wrapper which applies :meth:`world_to_plane` then + :meth:`plane_to_pixel`, see those methods for further + information. - def set_data(self, data: Union[torch.Tensor, np.ndarray], require_shape: bool = True): """ - Set the image data. + return self.plane_to_pixel(*self.world_to_plane(ra, dec)) - Args: - data (torch.Tensor or numpy.ndarray): The image data. - require_shape (bool): Whether to check that the shape of the data is the same as the current data. + @forward + def pixel_to_world(self, i: ArrayLike, j: ArrayLike) -> Tuple[ArrayLike, ArrayLike]: + """A wrapper which applies :meth:`pixel_to_plane` then + :meth:`plane_to_world`, see those methods for further + information. - Raises: - SpecificationConflict: If `require_shape` is `True` and the shape of the data is different from the current data. """ - if self._data is not None and require_shape and data.shape != self._data.shape: - raise SpecificationConflict( - f"Attempting to change image data with tensor that has a different shape! ({data.shape} vs {self._data.shape}) Use 'require_shape = False' if this is desired behaviour." - ) - - if data is None: - self.data = torch.tensor((), dtype=AP_config.ap_dtype, device=AP_config.ap_device) - elif isinstance(data, torch.Tensor): - self._data = data.to(dtype=AP_config.ap_dtype, device=AP_config.ap_device) - else: - self._data = torch.as_tensor(data, dtype=AP_config.ap_dtype, device=AP_config.ap_device) + return self.plane_to_world(*self.pixel_to_plane(i, j)) + + def pixel_center_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]: + """Get a meshgrid of pixel coordinates in the image, centered on the pixel grid.""" + return func.pixel_center_meshgrid(self._data.shape, config.DTYPE, config.DEVICE) + + def pixel_corner_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]: + """Get a meshgrid of pixel coordinates in the image, with corners at the pixel grid.""" + return func.pixel_corner_meshgrid(self._data.shape, config.DTYPE, config.DEVICE) + + def pixel_simpsons_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]: + """Get a meshgrid of pixel coordinates in the image, with Simpson's rule sampling.""" + return func.pixel_simpsons_meshgrid(self._data.shape, config.DTYPE, config.DEVICE) + + def pixel_quad_meshgrid(self, order=3) -> Tuple[ArrayLike, ArrayLike]: + """Get a meshgrid of pixel coordinates in the image, with quadrature sampling.""" + return func.pixel_quad_meshgrid(self._data.shape, config.DTYPE, config.DEVICE, order=order) + + @forward + def coordinate_center_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]: + """Get a meshgrid of coordinate locations in the image, centered on the pixel grid.""" + i, j = self.pixel_center_meshgrid() + return self.pixel_to_plane(i, j) + + @forward + def coordinate_corner_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]: + """Get a meshgrid of coordinate locations in the image, with corners at the pixel grid.""" + i, j = self.pixel_corner_meshgrid() + return self.pixel_to_plane(i, j) + + @forward + def coordinate_simpsons_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]: + """Get a meshgrid of coordinate locations in the image, with Simpson's rule sampling.""" + i, j = self.pixel_simpsons_meshgrid() + return self.pixel_to_plane(i, j) + + @forward + def coordinate_quad_meshgrid(self, order=3) -> Tuple[ArrayLike, ArrayLike]: + """Get a meshgrid of coordinate locations in the image, with quadrature sampling.""" + i, j, _ = self.pixel_quad_meshgrid(order=order) + return self.pixel_to_plane(i, j) + + def copy_kwargs(self, **kwargs) -> dict: + kwargs = { + "_data": backend.copy(self._data), + "CD": self.CD.value, + "crpix": self.crpix, + "crval": self.crval.value, + "crtan": self.crtan.value, + "zeropoint": self.zeropoint, + "identity": self.identity, + "name": self.name, + **kwargs, + } + return kwargs def copy(self, **kwargs): """Produce a copy of this image with all of the same properties. This @@ -269,81 +308,59 @@ def copy(self, **kwargs): an image and then will want the original again. """ - return self.__class__( - data=torch.clone(self.data), - header=self.header.copy(**kwargs), - **kwargs, - ) + return self.__class__(**self.copy_kwargs(**kwargs)) def blank_copy(self, **kwargs): """Produces a blank copy of the image which has the same properties except that its data is now filled with zeros. """ - return self.__class__( - data=torch.zeros_like(self.data), - header=self.header.copy(**kwargs), + kwargs = { + "_data": backend.zeros_like(self._data), **kwargs, - ) + } + return self.copy(**kwargs) - def get_window(self, window, **kwargs): - """Get a sub-region of the image as defined by a window on the sky.""" - return self.__class__( - data=self.data[self.window.get_self_indices(window)], - header=self.header.get_window(window, **kwargs), - **kwargs, - ) + def crop(self, pixels: Union[int, Tuple[int, int], Tuple[int, int, int, int]], **kwargs): + """Crop the image by the number of pixels given. This will crop + the image in all four directions by the number of pixels given. - def to(self, dtype=None, device=None): - if dtype is None: - dtype = AP_config.ap_dtype - if device is None: - device = AP_config.ap_device - if self._data is not None: - self._data = self._data.to(dtype=dtype, device=device) - self.header.to(dtype=dtype, device=device) - return self + given data shape (N, M) the new shape will be: - def crop(self, pixels): - # does this show up? - if len(pixels) == 1: # same crop in all dimension - self.set_data( - self.data[ - pixels[0].int() : (self.data.shape[0] - pixels[0]).int(), - pixels[0].int() : (self.data.shape[1] - pixels[0]).int(), - ], - require_shape=False, - ) + crop - int: crop the same number of pixels on all sides. new shape (N - 2*crop, M - 2*crop) + crop - (int, int): crop each dimension by the number of pixels given. new shape (N - 2*crop[1], M - 2*crop[0]) + crop - (int, int, int, int): crop each side by the number of pixels given assuming (x low, x high, y low, y high). new shape (N - crop[2] - crop[3], M - crop[0] - crop[1]) + """ + if isinstance(pixels, int): + data = self._data[ + pixels : self._data.shape[0] - pixels, + pixels : self._data.shape[1] - pixels, + ] + crpix = self.crpix - pixels + elif len(pixels) == 1: # same crop in all dimension + crop = pixels if isinstance(pixels, int) else pixels[0] + data = self._data[ + crop : self._data.shape[0] - crop, + crop : self._data.shape[1] - crop, + ] + crpix = self.crpix - crop elif len(pixels) == 2: # different crop in each dimension - self.set_data( - self.data[ - pixels[1].int() : (self.data.shape[0] - pixels[1]).int(), - pixels[0].int() : (self.data.shape[1] - pixels[0]).int(), - ], - require_shape=False, - ) + data = self._data[ + pixels[0] : self._data.shape[0] - pixels[0], + pixels[1] : self._data.shape[1] - pixels[1], + ] + crpix = self.crpix - pixels elif len(pixels) == 4: # different crop on all sides - self.set_data( - self.data[ - pixels[2].int() : (self.data.shape[0] - pixels[3]).int(), - pixels[0].int() : (self.data.shape[1] - pixels[1]).int(), - ], - require_shape=False, + data = self._data[ + pixels[0] : self._data.shape[0] - pixels[1], + pixels[2] : self._data.shape[1] - pixels[3], + ] + crpix = self.crpix - pixels[0::2] + else: + raise ValueError( + f"Invalid crop shape {pixels}, must be (int,), (int, int), or (int, int, int, int)!" ) - self.header = self.header.crop(pixels) - return self - - def flatten(self, attribute: str = "data") -> np.ndarray: - return getattr(self, attribute).reshape(-1) - - def get_coordinate_meshgrid(self): - return self.header.get_coordinate_meshgrid() - - def get_coordinate_corner_meshgrid(self): - return self.header.get_coordinate_corner_meshgrid() - - def get_coordinate_simps_meshgrid(self): - return self.header.get_coordinate_simps_meshgrid() + return self.copy(_data=data, crpix=crpix, **kwargs) def reduce(self, scale: int, **kwargs): """This operation will downsample an image by the factor given. If @@ -354,319 +371,364 @@ def reduce(self, scale: int, **kwargs): pixels are condensed, but the pixel size is increased correspondingly. - Parameters: - scale: factor by which to condense the image pixels. Each scale X scale region will be summed [int] - + **Args:** + - `scale` (int): The scale factor by which to reduce the image. """ if not isinstance(scale, int) and not ( - isinstance(scale, torch.Tensor) and scale.dtype is torch.int32 + isinstance(scale, ArrayLike) and scale.dtype is backend.int32 ): raise SpecificationConflict(f"Reduce scale must be an integer! not {type(scale)}") if scale == 1: return self - MS = self.data.shape[0] // scale - NS = self.data.shape[1] // scale - return self.__class__( - data=self.data[: MS * scale, : NS * scale] - .reshape(MS, scale, NS, scale) - .sum(axis=(1, 3)), - header=self.header.rescale_pixel(scale, **kwargs), + MS = self._data.shape[0] // scale + NS = self._data.shape[1] // scale + + data = self._data[: MS * scale, : NS * scale].reshape(MS, scale, NS, scale).sum(axis=(1, 3)) + CD = self.CD.value * scale + crpix = (self.crpix + 0.5) / scale - 0.5 + return self.copy( + _data=data, + CD=CD, + crpix=crpix, **kwargs, ) - def expand(self, padding: Tuple[float]) -> None: - """ - Args: - padding tuple[float]: length 4 tuple with amounts to pad each dimension in physical units + def to(self, dtype=None, device=None): + if dtype is None: + dtype = config.DTYPE + if device is None: + device = config.DEVICE + super().to(dtype=dtype, device=device) + self._data = backend.to(self._data, dtype=dtype, device=device) + if self.zeropoint is not None: + self.zeropoint = backend.to(self.zeropoint, dtype=dtype, device=device) + return self + + def flatten(self, attribute: str = "data") -> ArrayLike: + return backend.flatten(getattr(self, attribute), end_dim=1) + + def fits_info(self) -> dict: + return { + "CTYPE1": "RA---TAN", + "CTYPE2": "DEC--TAN", + "CRVAL1": self.crval.value[0].item(), + "CRVAL2": self.crval.value[1].item(), + "CRPIX1": self.crpix[0] + 1, + "CRPIX2": self.crpix[1] + 1, + "CRTAN1": self.crtan.value[0].item(), + "CRTAN2": self.crtan.value[1].item(), + "CD1_1": self.CD.value[0][0].item() * arcsec_to_deg, + "CD1_2": self.CD.value[0][1].item() * arcsec_to_deg, + "CD2_1": self.CD.value[1][0].item() * arcsec_to_deg, + "CD2_2": self.CD.value[1][1].item() * arcsec_to_deg, + "MAGZP": self.zeropoint.item() if self.zeropoint is not None else -999, + "IDNTY": self.identity, + } + + def fits_images(self): + return [ + fits.PrimaryHDU( + backend.to_numpy(backend.transpose(self._data, 1, 0)), + header=fits.Header(self.fits_info()), + ) + ] + + def get_astropywcs(self, **kwargs): + kwargs = { + "NAXIS": 2, + "NAXIS1": self.shape[0].item(), + "NAXIS2": self.shape[1].item(), + **self.fits_info(), + **kwargs, + } + return AstropyWCS(kwargs) + + def save(self, filename: str): + hdulist = fits.HDUList(self.fits_images()) + hdulist.writeto(filename, overwrite=True) + + def load(self, filename: Union[str, fits.HDUList], hduext: int = 0): + """Load an image from a FITS file. This will load the primary HDU + and set the data, CD, crpix, crval, and crtan attributes + accordingly. If the WCS is not tangent plane, it will warn the user. + """ - padding = np.array(padding) - if np.any(padding < 0): - raise SpecificationConflict("negative padding not allowed in expand method") - pad_boundaries = tuple(np.int64(np.round(np.array(padding) / self.pixelscale))) - self.data = pad(self.data, pad=pad_boundaries, mode="constant", value=0) - self.header.expand(padding) - - def get_state(self): - state = {} - state["type"] = self.__class__.__name__ - state["data"] = self.data.detach().cpu().tolist() - state["header"] = self.header.get_state() - return state - - def set_state(self, state): - self.set_data(state["data"], require_shape=False) - self.header.set_state(state["header"]) - - def get_fits_state(self): - states = [{}] - states[0]["DATA"] = self.data.detach().cpu().numpy() - states[0]["HEADER"] = self.header.get_fits_state() - states[0]["HEADER"]["IMAGE"] = "PRIMARY" - return states - - def set_fits_state(self, states): - for state in states: - if state["HEADER"]["IMAGE"] == "PRIMARY": - self.set_data(np.array(state["DATA"], dtype=np.float64), require_shape=False) - self.header.set_fits_state(state["HEADER"]) - break - - def save(self, filename=None, overwrite=True): - states = self.get_fits_state() - img_list = [fits.PrimaryHDU(states[0]["DATA"], header=fits.Header(states[0]["HEADER"]))] - for state in states[1:]: - img_list.append(fits.ImageHDU(state["DATA"], header=fits.Header(state["HEADER"]))) - hdul = fits.HDUList(img_list) - if filename is not None: - hdul.writeto(filename, overwrite=overwrite) - return hdul + if isinstance(filename, str): + hdulist = fits.open(filename) + else: + hdulist = filename + self.data = np.array(hdulist[hduext].data, dtype=np.float64) + + self.CD = ( + np.array( + ( + (hdulist[hduext].header["CD1_1"], hdulist[hduext].header["CD1_2"]), + (hdulist[hduext].header["CD2_1"], hdulist[hduext].header["CD2_2"]), + ), + dtype=np.float64, + ) + * deg_to_arcsec + ) + self.crpix = (hdulist[hduext].header["CRPIX1"] - 1, hdulist[hduext].header["CRPIX2"] - 1) + self.crval = (hdulist[hduext].header["CRVAL1"], hdulist[hduext].header["CRVAL2"]) + if "CRTAN1" in hdulist[hduext].header and "CRTAN2" in hdulist[hduext].header: + self.crtan = (hdulist[hduext].header["CRTAN1"], hdulist[hduext].header["CRTAN2"]) + if "MAGZP" in hdulist[hduext].header and hdulist[hduext].header["MAGZP"] > -998: + self.zeropoint = hdulist[hduext].header["MAGZP"] + self.identity = hdulist[hduext].header.get("IDNTY", str(id(self))) + return hdulist + + def corners( + self, + ) -> Tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike]: + pixel_lowleft = backend.make_array((-0.5, -0.5), dtype=config.DTYPE, device=config.DEVICE) + pixel_lowright = backend.make_array( + (self._data.shape[0] - 0.5, -0.5), dtype=config.DTYPE, device=config.DEVICE + ) + pixel_upleft = backend.make_array( + (-0.5, self._data.shape[1] - 0.5), dtype=config.DTYPE, device=config.DEVICE + ) + pixel_upright = backend.make_array( + (self._data.shape[0] - 0.5, self._data.shape[1] - 0.5), + dtype=config.DTYPE, + device=config.DEVICE, + ) + lowleft = self.pixel_to_plane(*pixel_lowleft) + lowright = self.pixel_to_plane(*pixel_lowright) + upleft = self.pixel_to_plane(*pixel_upleft) + upright = self.pixel_to_plane(*pixel_upright) + return (lowleft, lowright, upright, upleft) + + @torch.no_grad() + def get_indices(self, other: Window): + if other.image is self: + return slice(max(0, other.i_low), min(self._data.shape[0], other.i_high)), slice( + max(0, other.j_low), min(self._data.shape[1], other.j_high) + ) + shift = np.round(self.crpix - other.crpix).astype(int) + return slice( + min(max(0, other.i_low + shift[0]), self._data.shape[0]), + max(0, min(other.i_high + shift[0], self._data.shape[0])), + ), slice( + min(max(0, other.j_low + shift[1]), self._data.shape[1]), + max(0, min(other.j_high + shift[1], self._data.shape[1])), + ) - def load(self, filename): - hdul = fits.open(filename) - states = list({"DATA": hdu.data, "HEADER": hdu.header} for hdu in hdul) - self.set_fits_state(states) + @torch.no_grad() + def get_other_indices(self, other: Window): + if other.image == self: + shape = other.shape + return slice( + max(0, -other.i_low), min(self._data.shape[0] - other.i_low, shape[0]) + ), slice(max(0, -other.j_low), min(self._data.shape[1] - other.j_low, shape[1])) + raise ValueError() + + def get_window(self, other: Union[Window, "Image"], indices=None, **kwargs): + """Get a new image object which is a window of this image + corresponding to the other image's window. This will return a + new image object with the same properties as this one, but with + the data cropped to the other image's window. + + """ + if indices is None: + indices = self.get_indices(other if isinstance(other, Window) else other.window) + new_img = self.copy( + _data=self._data[indices], + crpix=self.crpix - np.array((indices[0].start, indices[1].start)), + **kwargs, + ) + return new_img def __sub__(self, other): if isinstance(other, Image): - new_img = self[other.window].copy() - new_img.data -= other.data[self.window.get_other_indices(other)] + new_img = self[other] + new_img._data = new_img._data - other[self]._data return new_img else: new_img = self.copy() - new_img.data -= other + new_img._data = new_img._data - other return new_img def __add__(self, other): if isinstance(other, Image): - new_img = self[other.window].copy() - new_img.data += other.data[self.window.get_other_indices(other)] + new_img = self[other] + new_img._data = new_img._data + other[self]._data return new_img else: new_img = self.copy() - new_img.data += other + new_img._data = new_img._data + other return new_img def __iadd__(self, other): if isinstance(other, Image): - self.data[other.window.get_other_indices(self)] += other.data[ - self.window.get_other_indices(other) - ] + self._data = backend.add_at_indices( + self._data, + self.get_indices(other.window), + other._data[other.get_indices(self.window)], + ) else: - self.data += other + self._data = self._data + other return self def __isub__(self, other): if isinstance(other, Image): - self.data[other.window.get_other_indices(self)] -= other.data[ - self.window.get_other_indices(other) - ] + self._data = backend.add_at_indices( + self._data, + self.get_indices(other.window), + -other._data[other.get_indices(self.window)], + ) else: - self.data -= other + self._data = self._data - other return self def __getitem__(self, *args): - if len(args) == 1 and isinstance(args[0], Window): + if len(args) == 1 and isinstance(args[0], (Image, Window)): return self.get_window(args[0]) - if len(args) == 1 and isinstance(args[0], Image): - return self.get_window(args[0].window) - raise ValueError("Unrecognized Image getitem request!") - - def __str__(self): - return f"image pixelscale: {self.pixelscale.detach().cpu().numpy()} origin: {self.origin.detach().cpu().numpy()} shape: {self.shape.detach().cpu().numpy()}" - - def __repr__(self): - return f"image pixelscale: {self.pixelscale.detach().cpu().numpy()} origin: {self.origin.detach().cpu().numpy()} shape: {self.shape.detach().cpu().numpy()} center: {self.center.detach().cpu().numpy()}\ndata: {self.data.detach().cpu().numpy()}" + return super().__getitem__(*args) -class Image_List(Image): - def __init__(self, image_list, window=None): - self.image_list = list(image_list) - self.check_wcs() - self.window = window - - def check_wcs(self): - """Ensure the WCS systems being used by all the windows in this list - are consistent with each other. They should all project world - coordinates onto the same tangent plane. - - """ - ref = torch.stack(tuple(I.window.reference_radec for I in self.image_list)) - if not torch.allclose(ref, ref[0]): - raise ConflicingWCS( - "Reference (world) coordinate mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means." - ) - ref = torch.stack(tuple(I.window.reference_planexy for I in self.image_list)) - if not torch.allclose(ref, ref[0]): - raise ConflicingWCS( - "Reference (tangent plane) coordinate mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means." - ) - - if len(set(I.window.projection for I in self.image_list)) > 1: - raise ConflicingWCS( - "Projection mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means." +class ImageList(Module): + def __init__(self, images, name=None): + super().__init__(name=name) + self.images = list(images) + if not all(isinstance(image, Image) for image in self.images): + raise InvalidImage( + f"Image_List can only hold Image objects, not {tuple(type(image) for image in self.images)}" ) - @property - def window(self): - return Window_List(list(image.window for image in self.image_list)) - - @window.setter - def window(self, window): - if window is None: - return - - if not isinstance(window, Window_List): - raise InvalidWindow("Target_List must take a Window_List object as its window") - - for i in range(len(self.image_list)): - self.image_list[i] = self.image_list[i][window.window_list[i]] - - @property - def pixelscale(self): - return tuple(image.pixelscale for image in self.image_list) - - @property - def zeropoint(self): - return tuple(image.zeropoint for image in self.image_list) - @property def data(self): - return tuple(image.data for image in self.image_list) + return tuple(image.data for image in self.images) - @data.setter - def data(self, data): - for image, dat in zip(self.image_list, data): - image.data = dat + @property + def _data(self): + return tuple(image._data for image in self.images) def copy(self): return self.__class__( - tuple(image.copy() for image in self.image_list), + tuple(image.copy() for image in self.images), ) def blank_copy(self): return self.__class__( - tuple(image.blank_copy() for image in self.image_list), + tuple(image.blank_copy() for image in self.images), ) - def get_window(self, window): + def get_window(self, other: "ImageList"): return self.__class__( - tuple(image[win] for image, win in zip(self.image_list, window)), + tuple(image[win] for image, win in zip(self.images, other.images)), ) - def index(self, other): - if isinstance(other, Image) and hasattr(other, "identity"): - for i, self_image in enumerate(self.image_list): - if other.identity == self_image.identity: - return i - else: - raise ValueError("Could not find identity match between image list and input image") - raise NotImplementedError(f"Image_List cannot get index for {type(other)}") + def index(self, other: Image): + for i, image in enumerate(self.images): + if other.identity == image.identity: + return i + else: + raise IndexError( + f"Could not find identity match between image list {self.name} and input image {other.name}" + ) + + def match_indices(self, other: "ImageList"): + """Match the indices of the images in this list with those in another Image_List.""" + indices = [] + for other_image in other.images: + try: + i = self.index(other_image) + except IndexError: + continue + indices.append(i) + return indices def to(self, dtype=None, device=None): if dtype is not None: - dtype = AP_config.ap_dtype + dtype = config.DTYPE if device is not None: - device = AP_config.ap_device - for image in self.image_list: - image.to(dtype=dtype, device=device) + device = config.DEVICE + super().to(dtype=dtype, device=device) return self - def crop(self, *pixels): - raise NotImplementedError("Crop function not available for Image_List object") - - def get_coordinate_meshgrid(self): - return tuple(image.get_coordinate_meshgrid() for image in self.image_list) - - def get_coordinate_corner_meshgrid(self): - return tuple(image.get_coordinate_corner_meshgrid() for image in self.image_list) - - def get_coordinate_simps_meshgrid(self): - return tuple(image.get_coordinate_simps_meshgrid() for image in self.image_list) - - def flatten(self, attribute="data"): - return torch.cat(tuple(image.flatten(attribute) for image in self.image_list)) - - def reduce(self, scale): - if scale == 1: - return self - - return self.__class__( - tuple(image.reduce(scale) for image in self.image_list), - ) + def flatten(self, attribute: str = "data") -> ArrayLike: + return backend.concatenate(tuple(image.flatten(attribute) for image in self.images)) def __sub__(self, other): - if isinstance(other, Image_List): + if isinstance(other, ImageList): new_list = [] - for self_image, other_image in zip(self.image_list, other.image_list): + for other_image in other.images: + i = self.index(other_image) + self_image = self.images[i] new_list.append(self_image - other_image) return self.__class__(new_list) else: - new_list = [] - for self_image, other_image in zip(self.image_list, other): - new_list.append(self_image - other_image) - return self.__class__(new_list) + raise ValueError("Subtraction of Image_List only works with another Image_List object!") def __add__(self, other): - if isinstance(other, Image_List): + if isinstance(other, ImageList): new_list = [] - for self_image, other_image in zip(self.image_list, other.image_list): + for other_image in other.images: + try: + i = self.index(other_image) + except IndexError: + continue + self_image = self.images[i] new_list.append(self_image + other_image) return self.__class__(new_list) else: - new_list = [] - for self_image, other_image in zip(self.image_list, other): - new_list.append(self_image + other_image) - return self.__class__(new_list) + raise ValueError("Addition of Image_List only works with another Image_List object!") def __isub__(self, other): - if isinstance(other, Image_List): - for self_image, other_image in zip(self.image_list, other.image_list): - self_image -= other_image + if isinstance(other, ImageList): + for other_image in other.images: + try: + i = self.index(other_image) + except IndexError: + continue + self.images[i] -= other_image + elif isinstance(other, Image): + i = self.index(other) + self.images[i] -= other else: - for self_image, other_image in zip(self.image_list, other): - self_image -= other_image + raise ValueError("Subtraction of Image_List only works with another Image_List object!") return self def __iadd__(self, other): - if isinstance(other, Image_List): - for self_image, other_image in zip(self.image_list, other.image_list): - self_image += other_image + if isinstance(other, ImageList): + for other_image in other.images: + try: + i = self.index(other_image) + except IndexError: + continue + self.images[i] += other_image + elif isinstance(other, Image): + i = self.index(other) + self.images[i] += other else: - for self_image, other_image in zip(self.image_list, other): - self_image += other_image + raise ValueError("Addition of Image_List only works with another Image_List object!") return self - def save(self, filename=None, overwrite=True): - raise NotImplementedError("Save/load not yet available for image lists") - - def load(self, filename): - raise NotImplementedError("Save/load not yet available for image lists") - def __getitem__(self, *args): - if len(args) == 1 and isinstance(args[0], Window): - return self.get_window(args[0]) - if len(args) == 1 and isinstance(args[0], Image): - return self.get_window(args[0].window) - if all(isinstance(arg, (int, slice)) for arg in args): - return self.image_list.__getitem__(*args) - raise ValueError("Unrecognized Image_List getitem request!") - - def __str__(self): - return "image list of:\n" + "\n".join(image.__str__() for image in self.image_list) - - def __repr__(self): - return "image list of:\n" + "\n".join(image.__repr__() for image in self.image_list) + if len(args) == 1: + if isinstance(args[0], ImageList): + new_list = [] + for other_image in args[0].images: + i = self.index(other_image) + new_list.append(self.images[i].get_window(other_image)) + return self.__class__(new_list) + elif isinstance(args[0], WindowList): + new_list = [] + for other_window in args[0].windows: + i = self.index(other_window.image) + new_list.append(self.images[i].get_window(other_window)) + return self.__class__(new_list) + elif isinstance(args[0], Image): + i = self.index(args[0]) + return self.images[i].get_window(args[0]) + elif isinstance(args[0], Window): + i = self.index(args[0].image) + return self.images[i].get_window(args[0]) + elif isinstance(args[0], int): + return self.images[args[0]] + super().__getitem__(*args) def __iter__(self): - return (img for img in self.image_list) - - # self._index = 0 - # return self - - # def __next__(self): - # if self._index >= len(self.image_list): - # raise StopIteration - # img = self.image_list[self._index] - # self._index += 1 - # return img + return (img for img in self.images) diff --git a/astrophot/image/jacobian_image.py b/astrophot/image/jacobian_image.py index cf8e42ba..caaef243 100644 --- a/astrophot/image/jacobian_image.py +++ b/astrophot/image/jacobian_image.py @@ -1,16 +1,14 @@ -from typing import List +from typing import List, Union -import torch - -from .image_object import Image, Image_List -from .. import AP_config +from .image_object import Image, ImageList from ..errors import SpecificationConflict, InvalidImage +from ..backend_obj import backend -__all__ = ["Jacobian_Image", "Jacobian_Image_List"] +__all__ = ("JacobianImage", "JacobianImageList") ###################################################################### -class Jacobian_Image(Image): +class JacobianImage(Image): """Jacobian of a model evaluated in an image. Image object which represents the evaluation of a jacobian on an @@ -23,103 +21,54 @@ class Jacobian_Image(Image): def __init__( self, parameters: List[str], - target_identity: str, **kwargs, ): super().__init__(**kwargs) - self.target_identity = target_identity self.parameters = list(parameters) if len(self.parameters) != len(set(self.parameters)): raise SpecificationConflict("Every parameter should be unique upon jacobian creation") - def flatten(self, attribute: str = "data"): - return getattr(self, attribute).reshape((-1, len(self.parameters))) - def copy(self, **kwargs): - return super().copy( - parameters=self.parameters, target_identity=self.target_identity, **kwargs - ) - - def get_state(self): - state = super().get_state() - state["target_identity"] = self.target_identity - state["parameters"] = self.parameters - return state - - def set_state(self, state): - super().set_state(state) - self.target_identity = state["target_identity"] - self.parameters = state["parameters"] - - def get_fits_state(self): - states = super().get_fits_state() - for state in states: - if state["HEADER"]["IMAGE"] == "PRIMARY": - state["HEADER"]["TRGTID"] = self.target_identity - state["HEADER"]["PARAMS"] = str(self.parameters) - return states - - def set_fits_state(self, states): - super().set_fits_state(states) - for state in states: - if state["HEADER"]["IMAGE"] == "PRIMARY": - self.target_identity = state["HEADER"]["TRGTID"] - self.parameters = eval(state["HEADER"]["params"]) - - def __add__(self, other): - raise NotImplementedError("Jacobian images cannot add like this, use +=") - - def __sub__(self, other): - raise NotImplementedError("Jacobian images cannot subtract") - - def __isub__(self, other): - raise NotImplementedError("Jacobian images cannot subtract") - - def __iadd__(self, other): - if not isinstance(other, Jacobian_Image): + return super().copy(parameters=self.parameters, **kwargs) + + def match_parameters(self, other: Union["JacobianImage", "JacobianImageList", List]): + self_i = [] + other_i = [] + other_parameters = other if isinstance(other, list) else other.parameters + for i, other_param in enumerate(other_parameters): + if other_param in self.parameters: + self_i.append(self.parameters.index(other_param)) + other_i.append(i) + return self_i, other_i + + def __iadd__(self, other: "JacobianImage"): + if not isinstance(other, JacobianImage): raise InvalidImage("Jacobian images can only add with each other, not: type(other)") - # exclude null jacobian images - if other.data is None: - return self - if self.data is None: - return other - - full_window = self.window | other.window - - self_indices = other.window.get_other_indices(self) - other_indices = self.window.get_other_indices(other) - for i, other_identity in enumerate(other.parameters): - if other_identity in self.parameters: - other_loc = self.parameters.index(other_identity) - else: - self.set_data( - torch.cat( - ( - self.data, - torch.zeros( - self.data.shape[0], - self.data.shape[1], - 1, - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ), - ), - dim=2, - ), - require_shape=False, - ) - self.parameters.append(other_identity) - other_loc = -1 - self.data[self_indices[0], self_indices[1], other_loc] += other.data[ - other_indices[0], other_indices[1], i - ] + self_indices = self.get_indices(other.window) + other_indices = other.get_indices(self.window) + for self_i, other_i in zip(*self.match_parameters(other)): + self._data = backend.add_at_indices( + self._data, + self_indices + (self_i,), + other._data[other_indices[0], other_indices[1], other_i], + ) return self + def plane_to_world(self, x, y): + raise NotImplementedError( + "JacobianImage does not support plane_to_world conversion. There is no meaningful world position of a PSF image." + ) + + def world_to_plane(self, ra, dec): + raise NotImplementedError( + "JacobianImage does not support world_to_plane conversion. There is no meaningful world position of a PSF image." + ) + ###################################################################### -class Jacobian_Image_List(Image_List, Jacobian_Image): +class JacobianImageList(ImageList): """For joint modelling, represents Jacobians evaluated on a list of images. @@ -132,44 +81,35 @@ class Jacobian_Image_List(Image_List, Jacobian_Image): """ - def __init__(self, image_list): - super().__init__(image_list) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not all(isinstance(image, (JacobianImage, JacobianImageList)) for image in self.images): + raise InvalidImage( + f"JacobianImageList can only hold JacobianImage objects, not {tuple(type(image) for image in self.images)}" + ) - def flatten(self, attribute="data"): - if len(self.image_list) > 1: - for image in self.image_list[1:]: - if self.image_list[0].parameters != image.parameters: + @property + def parameters(self) -> List[str]: + """List of parameters for the jacobian images in this list.""" + if not self.images: + return [] + return self.images[0].parameters + + def flatten(self, attribute: str = "data"): + if len(self.images) > 1: + for image in self.images[1:]: + if self.images[0].parameters != image.parameters: raise SpecificationConflict( "Jacobian image list sub-images track different parameters. Please initialize with all parameters that will be used." ) - return torch.cat(tuple(image.flatten(attribute) for image in self.image_list)) - - def __add__(self, other): - raise NotImplementedError("Jacobian images cannot add like this, use +=") - - def __sub__(self, other): - raise NotImplementedError("Jacobian images cannot subtract") - - def __isub__(self, other): - raise NotImplementedError("Jacobian images cannot subtract") - - def __iadd__(self, other): - if isinstance(other, Jacobian_Image_List): - for other_image in other.image_list: - for self_image in self.image_list: - if other_image.target_identity == self_image.target_identity: - self_image += other_image - break - else: - self.image_list.append(other_image) - elif isinstance(other, Jacobian_Image): - for self_image in self.image_list: - if other.target_identity == self_image.target_identity: - self_image += other - break - else: - self.image_list.append(other_image) - else: - for self_image, other_image in zip(self.image_list, other): - self_image += other_image - return self + return backend.concatenate(tuple(image.flatten(attribute) for image in self.images), dim=0) + + def match_parameters(self, other: Union[JacobianImage, "JacobianImageList", List[str]]): + self_i = [] + other_i = [] + other_parameters = other if isinstance(other, list) else other.parameters + for i, other_param in enumerate(other_parameters): + if other_param in self.parameters: + self_i.append(self.parameters.index(other_param)) + other_i.append(i) + return self_i, other_i diff --git a/astrophot/image/mixins/__init__.py b/astrophot/image/mixins/__init__.py new file mode 100644 index 00000000..00c57f96 --- /dev/null +++ b/astrophot/image/mixins/__init__.py @@ -0,0 +1,5 @@ +from .data_mixin import DataMixin +from .sip_mixin import SIPMixin +from .cmos_mixin import CMOSMixin + +__all__ = ("DataMixin", "SIPMixin", "CMOSMixin") diff --git a/astrophot/image/mixins/cmos_mixin.py b/astrophot/image/mixins/cmos_mixin.py new file mode 100644 index 00000000..f3ac2c05 --- /dev/null +++ b/astrophot/image/mixins/cmos_mixin.py @@ -0,0 +1,59 @@ +from typing import Optional, Tuple + +from .. import func +from ... import config + + +class CMOSMixin: + """ + A mixin class for CMOS image processing. This class can be used to add + CMOS-specific functionality to image processing classes. + """ + + def __init__( + self, + *args, + subpixel_loc: Tuple[float, float] = (0, 0), + subpixel_scale: float = 1.0, + filename: Optional[str] = None, + **kwargs, + ): + super().__init__(*args, filename=filename, **kwargs) + if filename is not None: + return + self.subpixel_loc = subpixel_loc + self.subpixel_scale = subpixel_scale + + @property + def base_scale(self): + """Get the base scale of the image, which is the subpixel scale.""" + return self.subpixel_scale + + def pixel_center_meshgrid(self): + """Get a meshgrid of pixel coordinates in the image, centered on the pixel grid.""" + return func.cmos_pixel_center_meshgrid( + self._data.shape, self.subpixel_loc, config.DTYPE, config.DEVICE + ) + + def copy(self, **kwargs): + return super().copy( + subpixel_loc=self.subpixel_loc, subpixel_scale=self.subpixel_scale, **kwargs + ) + + def fits_info(self): + info = super().fits_info() + info["SPIXLOC1"] = self.subpixel_loc[0] + info["SPIXLOC2"] = self.subpixel_loc[1] + info["SPIXSCL"] = self.subpixel_scale + return info + + def load(self, filename: str, hduext: int = 0): + hdulist = super().load(filename, hduext=hduext) + if "SPIXLOC1" in hdulist[hduext].header: + self.subpixel_loc = ( + hdulist[0].header.get("SPIXLOC1", 0), + hdulist[0].header.get("SPIXLOC2", 0), + ) + if "SPIXSCL" in hdulist[hduext].header: + self.subpixel_scale = hdulist[0].header.get("SPIXSCL", 1.0) + return hdulist diff --git a/astrophot/image/mixins/data_mixin.py b/astrophot/image/mixins/data_mixin.py new file mode 100644 index 00000000..88bfde35 --- /dev/null +++ b/astrophot/image/mixins/data_mixin.py @@ -0,0 +1,301 @@ +from typing import Union, Optional + +import numpy as np +from astropy.io import fits + +from ...utils.initialize import auto_variance +from ... import config +from ...backend_obj import backend, ArrayLike +from ...errors import SpecificationConflict +from ..image_object import Image +from ..window import Window + + +class DataMixin: + """Mixin for data handling in image objects. + + This mixin provides functionality for handling variance and mask, + as well as other ancillary data. + + **Args:** + - `mask`: A boolean mask indicating which pixels to ignore. + - `std`: Standard deviation of the image pixels. + - `variance`: Variance of the image pixels. + - `weight`: Weights for the image pixels. + + Note that only one of `std`, `variance`, or `weight` should be + provided at a time. If multiple are provided, an error will be raised. + """ + + def __init__( + self, + *args, + mask: Optional[ArrayLike] = None, + std: Optional[ArrayLike] = None, + variance: Optional[ArrayLike] = None, + weight: Optional[ArrayLike] = None, + _mask: Optional[ArrayLike] = None, + _weight: Optional[ArrayLike] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + + if _mask is None: + self.mask = mask + else: + self._mask = _mask + if (std is not None) + (variance is not None) + (weight is not None) > 1: + raise SpecificationConflict( + "Can only define one of: std, variance, or weight for a given image." + ) + + if _weight is not None: + self._weight = _weight + elif std is not None: + self.std = std + elif variance is not None: + self.variance = variance + else: + self.weight = weight + + # Set nan pixels to be masked automatically + if backend.any(backend.isnan(self._data)).item(): + self._mask = self._mask | backend.isnan(self._data) + + @property + def std(self): + """Stores the standard deviation of the image pixels. This represents + the uncertainty in each pixel value. It should always have the + same shape as the image data. In the case where the standard + deviation is not known, a tensor of ones will be created to + stand in as the standard deviation values. + + The standard deviation is not stored directly, instead it is + computed as $\\sqrt{1/W}$ where $W$ is the weights. + + """ + return backend.sqrt(self.variance) + + @std.setter + def std(self, std): + if std is None: + self._weight = None + return + if isinstance(std, str) and std == "auto": + self.weight = "auto" + return + self.weight = 1 / std**2 + + @property + def variance(self): + """Stores the variance of the image pixels. This represents the + uncertainty in each pixel value. It should always have the + same shape as the image data. In the case where the variance + is not known, a tensor of ones will be created to stand in as + the variance values. + + The variance is not stored directly, instead it is + computed as $\\frac{1}{W}$ where $W$ is the + weights. + + """ + return backend.where(self.weight == 0, backend.inf, 1 / self.weight) + + @property + def _variance(self): + return backend.where(self._weight == 0, backend.inf, 1 / self._weight) + + @variance.setter + def variance(self, variance): + if variance is None: + self._weight = None + return + if isinstance(variance, str) and variance == "auto": + self.weight = "auto" + return + self.weight = 1 / variance + + @property + def weight(self): + """Stores the weight of the image pixels. This represents the + uncertainty in each pixel value. It should always have the + same shape as the image data. In the case where the weight + is not known, a tensor of ones will be created to stand in as + the weight values. + + The weights are used to proprtionately scale residuals in the + likelihood. Most commonly this shows up as a :math:`\\chi^2` + like: + + $$\\chi^2 = (\\vec{y} - \\vec{f(\\theta)})^TW(\\vec{y} - \\vec{f(\\theta)})$$ + + which can be optimized to find parameter values. Using the + Jacobian, which in this case is the derivative of every pixel + wrt every parameter, the weight matrix also appears in the + gradient: + + $$\\vec{g} = J^TW(\\vec{y} - \\vec{f(\\theta)})$$ + + and the hessian approximation used in Levenberg-Marquardt: + + $$H \\approx J^TWJ$$ + + """ + return backend.transpose(self._weight, 1, 0) + + @weight.setter + def weight(self, weight): + if weight is None: + self._weight = backend.ones_like(self._data) + return + if isinstance(weight, str) and weight == "auto": + weight = 1 / auto_variance(self.data, self.mask) + self._weight = backend.transpose( + backend.as_array(weight, dtype=config.DTYPE, device=config.DEVICE), 1, 0 + ) + if self._weight.shape != self._data.shape: + raise SpecificationConflict( + f"weight/variance must have same shape as data ({weight.shape} vs {self.data.shape})" + ) + + @property + def _weight(self): + return self.__weight + + @_weight.setter + def _weight(self, value): + if value is None: + value = backend.ones_like(self._data) + self.__weight = value + + @property + def mask(self): + """The mask stores a tensor of boolean values which indicate any + pixels to be ignored. These pixels will be skipped in + likelihood evaluations and in parameter optimization. It is + common practice to mask pixels with pathological values such + as due to cosmic rays or satellites passing through the image. + + In a mask, a True value indicates that the pixel is masked and + should be ignored. False indicates a normal pixel which will + inter into most calculations. + + If no mask is provided, all pixels are assumed valid. + + """ + return backend.transpose(self._mask, 1, 0) + + @mask.setter + def mask(self, mask): + if mask is None: + self._mask = backend.zeros_like(self._data, dtype=backend.bool) + return + self._mask = backend.transpose( + backend.as_array(mask, dtype=backend.bool, device=config.DEVICE), 1, 0 + ) + if self._mask.shape != self._data.shape: + raise SpecificationConflict( + f"mask must have same shape as data ({mask.shape} vs {self.data.shape})" + ) + + @property + def _mask(self): + return self.__mask + + @_mask.setter + def _mask(self, value): + if value is None: + value = backend.zeros_like(self._data, dtype=backend.bool) + self.__mask = value + + def to(self, dtype=None, device=None): + """Converts the stored `Target_Image` data, variance, psf, etc to a + given data type and device. + + """ + if dtype is not None: + dtype = config.DTYPE + if device is not None: + device = config.DEVICE + super().to(dtype=dtype, device=device) + + self._weight = backend.to(self._weight, dtype=dtype, device=device) + self._mask = backend.to(self._mask, dtype=backend.bool, device=device) + return self + + def copy_kwargs(self, **kwargs): + """Produce a copy of this image with all of the same properties. This + can be used when one wishes to make temporary modifications to + an image and then will want the original again. + + """ + kwargs = {"_mask": self._mask, "_weight": self._weight, **kwargs} + return super().copy_kwargs(**kwargs) + + def get_window(self, other: Union[Image, Window], indices=None, **kwargs): + """Get a sub-region of the image as defined by an other image on the sky.""" + if indices is None: + indices = self.get_indices(other if isinstance(other, Window) else other.window) + return super().get_window( + other, + _weight=self._weight[indices], + _mask=self._mask[indices], + indices=indices, + **kwargs, + ) + + def fits_images(self): + images = super().fits_images() + images.append(fits.ImageHDU(backend.to_numpy(self.weight), name="WEIGHT")) + images.append( + fits.ImageHDU( + backend.to_numpy(self.mask).astype(int), + name="MASK", + ) + ) + return images + + def load(self, filename: str, hduext: int = 0): + """Load the image from a FITS file. This will load the data, WCS, and + any ancillary data such as variance, mask, and PSF. + + """ + hdulist = super().load(filename, hduext=hduext) + if "WEIGHT" in hdulist: + self.weight = np.array(hdulist["WEIGHT"].data, dtype=np.float64) + if "MASK" in hdulist: + self.mask = np.array(hdulist["MASK"].data, dtype=bool) + elif "DQ" in hdulist: + self.mask = np.array(hdulist["DQ"].data, dtype=bool) + return hdulist + + def reduce(self, scale: int, **kwargs) -> Image: + """Returns a new `TargetImage` object with a reduced resolution + compared to the current image. `scale` should be an integer + indicating how much to reduce the resolution. If the + `TargetImage` was originally (48,48) pixels across with a + pixelscale of 1 and `reduce(2)` is called then the image will + be (24,24) pixels and the pixelscale will be 2. If `reduce(3)` + is called then the returned image will be (16,16) pixels + across and the pixelscale will be 3. + + """ + MS = self._data.shape[0] // scale + NS = self._data.shape[1] // scale + + return super().reduce( + scale=scale, + _weight=( + 1 + / backend.sum( + self._variance[: MS * scale, : NS * scale].reshape(MS, scale, NS, scale), + dim=(1, 3), + ) + ), + _mask=( + backend.max( + self._mask[: MS * scale, : NS * scale].reshape(MS, scale, NS, scale), dim=(1, 3) + ) + ), + **kwargs, + ) diff --git a/astrophot/image/mixins/sip_mixin.py b/astrophot/image/mixins/sip_mixin.py new file mode 100644 index 00000000..0acc0457 --- /dev/null +++ b/astrophot/image/mixins/sip_mixin.py @@ -0,0 +1,255 @@ +from typing import Union, Optional, Tuple + +from ..image_object import Image +from ..window import Window +from .. import func +from ... import config +from ...backend_obj import backend, ArrayLike +from ...utils.interpolate import interp2d +from ...param import forward + + +class SIPMixin: + """A mixin class for SIP (Simple Image Polynomial) distortion model.""" + + expect_ctype = (("RA---TAN-SIP",), ("DEC--TAN-SIP",)) + + def __init__( + self, + *args, + sipA: dict[Tuple[int, int], float] = {}, + sipB: dict[Tuple[int, int], float] = {}, + sipAP: dict[Tuple[int, int], float] = {}, + sipBP: dict[Tuple[int, int], float] = {}, + pixel_area_map: Optional[ArrayLike] = None, + distortion_ij: Optional[ArrayLike] = None, + distortion_IJ: Optional[ArrayLike] = None, + filename: Optional[str] = None, + **kwargs, + ): + super().__init__(*args, filename=filename, **kwargs) + if filename is not None: + return + self.sipA = sipA + self.sipB = sipB + self.sipAP = sipAP + self.sipBP = sipBP + + if len(self.sipAP) == 0 and len(self.sipA) > 0: + self.compute_backward_sip_coefs() + + self.update_distortion_model( + distortion_ij=distortion_ij, distortion_IJ=distortion_IJ, pixel_area_map=pixel_area_map + ) + + @forward + def pixel_to_plane( + self, + i: ArrayLike, + j: ArrayLike, + crtan: ArrayLike, + CD: ArrayLike, + ) -> Tuple[ArrayLike, ArrayLike]: + di = interp2d(self.distortion_ij[0], i, j, padding_mode="border") + dj = interp2d(self.distortion_ij[1], i, j, padding_mode="border") + return func.pixel_to_plane_linear(i + di, j + dj, *self.crpix, CD, *crtan) + + @forward + def plane_to_pixel( + self, + x: ArrayLike, + y: ArrayLike, + crtan: ArrayLike, + CD: ArrayLike, + ) -> Tuple[ArrayLike, ArrayLike]: + I, J = func.plane_to_pixel_linear(x, y, *self.crpix, CD, *crtan) + dI = interp2d(self.distortion_IJ[0], I, J, padding_mode="border") + dJ = interp2d(self.distortion_IJ[1], I, J, padding_mode="border") + return I + dI, J + dJ + + @property + def pixel_area_map(self): + return self._pixel_area_map + + @property + def A_ORDER(self) -> int: + if self.sipA: + return max(a + b for a, b in self.sipA) + return 0 + + @property + def B_ORDER(self) -> int: + if self.sipB: + return max(a + b for a, b in self.sipB) + return 0 + + def compute_backward_sip_coefs(self): + """ + Credit: Shu Liu and Lei Hi, see here: + https://github.com/Roman-Supernova-PIT/sfft/blob/master/sfft/utils/CupyWCSTransform.py + + Compute the backward transformation from (U, V) to (u, v) + """ + i, j = self.pixel_center_meshgrid() + u, v = i - self.crpix[0], j - self.crpix[1] + du, dv = func.sip_delta(u, v, self.sipA, self.sipB) + U = (u + du).flatten() + V = (v + dv).flatten() + AP, BP = func.sip_backward_transform( + u.flatten(), v.flatten(), U, V, self.A_ORDER, self.B_ORDER + ) + self.sipAP = dict( + ((p, q), ap.item()) for (p, q), ap in zip(func.sip_coefs(self.A_ORDER), AP) + ) + self.sipBP = dict( + ((p, q), bp.item()) for (p, q), bp in zip(func.sip_coefs(self.B_ORDER), BP) + ) + + def update_distortion_model( + self, + distortion_ij: Optional[ArrayLike] = None, + distortion_IJ: Optional[ArrayLike] = None, + pixel_area_map: Optional[ArrayLike] = None, + ): + """ + Update the pixel area map based on the current SIP coefficients. + """ + + # Pixelized distortion model + ############################################################# + if distortion_ij is None or distortion_IJ is None: + i, j = self.pixel_center_meshgrid() + u, v = i - self.crpix[0], j - self.crpix[1] + if distortion_ij is None: + distortion_ij = backend.stack(func.sip_delta(u, v, self.sipA, self.sipB), dim=0) + if distortion_IJ is None: + # fixme maybe + distortion_IJ = backend.stack(func.sip_delta(u, v, self.sipAP, self.sipBP), dim=0) + self.distortion_ij = distortion_ij + self.distortion_IJ = distortion_IJ + + # Pixel area map + ############################################################# + if pixel_area_map is not None: + self._pixel_area_map = pixel_area_map + return + i, j = self.pixel_corner_meshgrid() + x, y = self.pixel_to_plane(i, j) + + # Shoelace formula for pixel area + # 1: [:-1, :-1] + # 2: [:-1, 1:] + # 3: [1:, 1:] + # 4: [1:, :-1] + A = 0.5 * ( + x[:-1, :-1] * y[:-1, 1:] + + x[:-1, 1:] * y[1:, 1:] + + x[1:, 1:] * y[1:, :-1] + + x[1:, :-1] * y[:-1, :-1] + - ( + x[:-1, 1:] * y[:-1, :-1] + + x[1:, 1:] * y[:-1, 1:] + + x[1:, :-1] * y[1:, 1:] + + x[:-1, :-1] * y[1:, :-1] + ) + ) + self._pixel_area_map = backend.abs(A) + + def to(self, dtype=None, device=None): + if dtype is None: + dtype = config.DTYPE + if device is None: + device = config.DEVICE + super().to(dtype=dtype, device=device) + self._pixel_area_map = backend.to(self._pixel_area_map, dtype=dtype, device=device) + self.distortion_ij = backend.to(self.distortion_ij, dtype=dtype, device=device) + self.distortion_IJ = backend.to(self.distortion_IJ, dtype=dtype, device=device) + + def copy_kwargs(self, **kwargs): + kwargs = { + "sipA": self.sipA, + "sipB": self.sipB, + "sipAP": self.sipAP, + "sipBP": self.sipBP, + "pixel_area_map": self.pixel_area_map, + "distortion_ij": self.distortion_ij, + "distortion_IJ": self.distortion_IJ, + **kwargs, + } + return super().copy_kwargs(**kwargs) + + def get_window(self, other: Union[Image, Window], indices=None, **kwargs): + """Get a sub-region of the image as defined by an other image on the sky.""" + if indices is None: + indices = self.get_indices(other if isinstance(other, Window) else other.window) + return super().get_window( + other, + pixel_area_map=self.pixel_area_map[indices], + distortion_ij=self.distortion_ij[:, indices[0], indices[1]], + distortion_IJ=self.distortion_IJ[:, indices[0], indices[1]], + indices=indices, + **kwargs, + ) + + def fits_info(self): + info = super().fits_info() + info["CTYPE1"] = "RA---TAN-SIP" + info["CTYPE2"] = "DEC--TAN-SIP" + a_order = 0 + for a, b in self.sipA: + info[f"A_{a}_{b}"] = self.sipA[(a, b)] + a_order = max(a_order, a + b) + info["A_ORDER"] = a_order + b_order = 0 + for a, b in self.sipB: + info[f"B_{a}_{b}"] = self.sipB[(a, b)] + b_order = max(b_order, a + b) + info["B_ORDER"] = b_order + ap_order = 0 + for a, b in self.sipAP: + info[f"AP_{a}_{b}"] = self.sipAP[(a, b)] + ap_order = max(ap_order, a + b) + info["AP_ORDER"] = ap_order + bp_order = 0 + for a, b in self.sipBP: + info[f"BP_{a}_{b}"] = self.sipBP[(a, b)] + bp_order = max(bp_order, a + b) + info["BP_ORDER"] = bp_order + return info + + def load(self, filename: str, hduext: int = 0): + hdulist = super().load(filename, hduext=hduext) + self.sipA = {} + if "A_ORDER" in hdulist[hduext].header: + a_order = hdulist[hduext].header["A_ORDER"] + for i in range(a_order + 1): + for j in range(a_order + 1 - i): + key = (i, j) + if f"A_{i}_{j}" in hdulist[hduext].header: + self.sipA[key] = hdulist[hduext].header[f"A_{i}_{j}"] + self.sipB = {} + if "B_ORDER" in hdulist[hduext].header: + b_order = hdulist[hduext].header["B_ORDER"] + for i in range(b_order + 1): + for j in range(b_order + 1 - i): + key = (i, j) + if f"B_{i}_{j}" in hdulist[hduext].header: + self.sipB[key] = hdulist[hduext].header[f"B_{i}_{j}"] + self.sipAP = {} + if "AP_ORDER" in hdulist[hduext].header: + ap_order = hdulist[hduext].header["AP_ORDER"] + for i in range(ap_order + 1): + for j in range(ap_order + 1 - i): + key = (i, j) + if f"AP_{i}_{j}" in hdulist[hduext].header: + self.sipAP[key] = hdulist[hduext].header[f"AP_{i}_{j}"] + self.sipBP = {} + if "BP_ORDER" in hdulist[hduext].header: + bp_order = hdulist[hduext].header["BP_ORDER"] + for i in range(bp_order + 1): + for j in range(bp_order + 1 - i): + key = (i, j) + if f"BP_{i}_{j}" in hdulist[hduext].header: + self.sipBP[key] = hdulist[hduext].header[f"BP_{i}_{j}"] + self.update_distortion_model() + return hdulist diff --git a/astrophot/image/model_image.py b/astrophot/image/model_image.py index 9215d00e..3d969338 100644 --- a/astrophot/image/model_image.py +++ b/astrophot/image/model_image.py @@ -1,16 +1,11 @@ -import torch - -from .. import AP_config -from .image_object import Image, Image_List -from .window_object import Window -from ..utils.interpolate import shift_Lanczos_torch +from .image_object import Image, ImageList from ..errors import InvalidImage -__all__ = ["Model_Image", "Model_Image_List"] +__all__ = ["ModelImage", "ModelImageList"] ###################################################################### -class Model_Image(Image): +class ModelImage(Image): """Image object which represents the sampling of a model at the given coordinates of the image. Extra arithmetic operations are available which can update model values in the image. The whole @@ -19,160 +14,17 @@ class Model_Image(Image): """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.target_identity = kwargs.get("target_identity", None) - self.to() - - def clear_image(self): - self.data = torch.zeros_like(self.data) - - def shift_origin(self, shift, is_prepadded=True): - self.window.shift(shift) - pix_shift = self.plane_to_pixel_delta(shift) - if torch.any(torch.abs(pix_shift) > 1): - raise NotImplementedError("Shifts larger than 1 pixel are currently not handled") - self.data = shift_Lanczos_torch( - self.data, - pix_shift[0], - pix_shift[1], - min(min(self.data.shape), 10), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - img_prepadded=is_prepadded, - ) - - def get_window(self, window: Window, **kwargs): - return super().get_window(window, target_identity=self.target_identity, **kwargs) - - def reduce(self, scale, **kwargs): - return super().reduce(scale, target_identity=self.target_identity, **kwargs) - - def replace(self, other, data=None): - if isinstance(other, Image): - if self.window.overlap_frac(other.window) == 0.0: # fixme control flow - return - other_indices = self.window.get_other_indices(other) - self_indices = other.window.get_other_indices(self) - if self.data[self_indices].nelement() == 0 or other.data[other_indices].nelement() == 0: - return - self.data[self_indices] = other.data[other_indices] - elif isinstance(other, Window): - self.data[self.window.get_self_indices(other)] = data - else: - self.data = other - - def copy(self, **kwargs): - return super().copy(target_identity=self.target_identity, **kwargs) - - def blank_copy(self, **kwargs): - return super().blank_copy(target_identity=self.target_identity, **kwargs) - - def get_state(self): - state = super().get_state() - state["target_identity"] = self.target_identity - return state - - def set_state(self, state): - super().set_state(state) - self.target_identity = state["target_identity"] - - def get_fits_state(self): - states = super().get_fits_state() - for state in states: - if state["HEADER"]["IMAGE"] == "PRIMARY": - state["HEADER"]["TRGTID"] = self.target_identity - return states - - def set_fits_state(self, states): - super().set_fits_state(states) - for state in states: - if state["HEADER"]["IMAGE"] == "PRIMARY": - self.target_identity = state["HEADER"]["TRGTID"] + def fluxdensity_to_flux(self): + self._data = self._data * self.pixel_area ###################################################################### -class Model_Image_List(Image_List, Model_Image): +class ModelImageList(ImageList): + """A list of ModelImage objects.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not all(isinstance(image, Model_Image) for image in self.image_list): + if not all(isinstance(image, (ModelImage, ModelImageList)) for image in self.images): raise InvalidImage( - f"Model_Image_List can only hold Model_Image objects, not {tuple(type(image) for image in self.image_list)}" + f"Model_Image_List can only hold Model_Image objects, not {tuple(type(image) for image in self.images)}" ) - - def clear_image(self): - for image in self.image_list: - image.clear_image() - - def shift_origin(self, shift): - raise NotImplementedError() - - def replace(self, other, data=None): - if data is None: - for image, oth in zip(self.image_list, other): - image.replace(oth) - else: - for image, oth, dat in zip(self.image_list, other, data): - image.replace(oth, dat) - - @property - def target_identity(self): - targets = tuple(image.target_identity for image in self.image_list) - if any(tar_id is None for tar_id in targets): - return None - return targets - - def __isub__(self, other): - if isinstance(other, Model_Image_List): - for other_image, zip_self_image in zip(other.image_list, self.image_list): - if other_image.target_identity is None or self.target_identity is None: - zip_self_image -= other_image - continue - for self_image in self.image_list: - if other_image.target_identity == self_image.target_identity: - self_image -= other_image - break - else: - self.image_list.append(other_image) - elif isinstance(other, Model_Image): - if other.target_identity is None or zip_self_image.target_identity is None: - zip_self_image -= other_image - else: - for self_image in self.image_list: - if other.target_identity == self_image.target_identity: - self_image -= other - break - else: - self.image_list.append(other) - else: - for self_image, other_image in zip(self.image_list, other): - self_image -= other_image - return self - - def __iadd__(self, other): - if isinstance(other, Model_Image_List): - for other_image, zip_self_image in zip(other.image_list, self.image_list): - if other_image.target_identity is None or self.target_identity is None: - zip_self_image += other_image - continue - for self_image in self.image_list: - if other_image.target_identity == self_image.target_identity: - self_image += other_image - break - else: - self.image_list.append(other_image) - elif isinstance(other, Model_Image): - if other.target_identity is None or self.target_identity is None: - for self_image in self.image_list: - self_image += other - else: - for self_image in self.image_list: - if other.target_identity == self_image.target_identity: - self_image += other - break - else: - self.image_list.append(other) - else: - for self_image, other_image in zip(self.image_list, other): - self_image += other_image - return self diff --git a/astrophot/image/psf_image.py b/astrophot/image/psf_image.py index ff267270..d0cc05b5 100644 --- a/astrophot/image/psf_image.py +++ b/astrophot/image/psf_image.py @@ -1,158 +1,99 @@ -from typing import List, Optional, Union +from typing import List, Optional -import torch import numpy as np from .image_object import Image -from .image_header import Image_Header -from .model_image import Model_Image -from .jacobian_image import Jacobian_Image -from astropy.io import fits -from .. import AP_config -from ..errors import SpecificationConflict +from .jacobian_image import JacobianImage +from .. import config +from ..backend_obj import backend, ArrayLike +from .mixins import DataMixin -__all__ = ["PSF_Image"] +__all__ = ["PSFImage"] -class PSF_Image(Image): +class PSFImage(DataMixin, Image): """Image object which represents a model of PSF (Point Spread Function). - PSF_Image inherits from the base Image class and represents the model of a point spread function. + PSFImage inherits from the base Image class and represents the model of a point spread function. The point spread function characterizes the response of an imaging system to a point source or point object. - The shape of the PSF data must be odd. - - Attributes: - data (torch.Tensor): The image data of the PSF. - identity (str): The identity of the image. Default is None. - - Methods: - psf_border_int: Calculates and returns the convolution border size of the PSF image in integer format. - psf_border: Calculates and returns the convolution border size of the PSF image in the units of pixelscale. - _save_image_list: Saves the image list to the PSF HDU header. - reduce: Reduces the size of the image using a given scale factor. + The shape of the PSF data should be odd (for your sanity) but this is not enforced. """ - has_mask = False - has_variance = False - def __init__(self, *args, **kwargs): - """ - Initializes the PSF_Image class. - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - band (str, optional): The band of the image. Default is None. - """ + kwargs.update({"crval": (0, 0), "crpix": (0, 0), "crtan": (0, 0)}) super().__init__(*args, **kwargs) - - self.window.reference_radec = (0, 0) - self.window.reference_planexy = (0, 0) - self.window.reference_imageij = np.flip(np.array(self.data.shape, dtype=float) - 1.0) / 2 - self.window.reference_imagexy = (0, 0) - - def set_data(self, data: Union[torch.Tensor, np.ndarray], require_shape: bool = True): - super().set_data(data=data, require_shape=require_shape) - - if torch.any((torch.tensor(self.data.shape) % 2) != 1): - raise SpecificationConflict(f"psf must have odd shape, not {self.data.shape}") - if torch.any(self.data < 0): - AP_config.ap_logger.warning("psf data should be non-negative") + self.crpix = (np.array(self._data.shape[:2], dtype=np.float64) - 1.0) / 2 def normalize(self): """Normalizes the PSF image to have a sum of 1.""" - self.data /= torch.sum(self.data) + norm = backend.sum(self._data) + self._data = self._data / norm + self._weight = self._weight * norm**2 @property - def mask(self): - return torch.zeros_like(self.data, dtype=bool) - - @property - def psf_border_int(self): - """Calculates and returns the border size of the PSF image in integer - format. This is the border used for padding before convolution. - - Returns: - torch.Tensor: The border size of the PSF image in integer format. - - """ - return torch.ceil( - ( - 1 - + torch.flip( - torch.tensor( - self.data.shape, - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ), - (0,), - ) - ) - / 2 - ).int() - - def _save_image_list(self, image_list): - """Saves the image list to the PSF HDU header. - - Args: - image_list (list): The list of images to be saved. - psf_header (astropy.io.fits.Header): The header of the PSF HDU. - """ - img_header = self.header._save_image_list() - img_header["IMAGE"] = "PSF" - image_list.append(fits.ImageHDU(self.data.detach().cpu().numpy(), header=img_header)) + def psf_pad(self) -> int: + return max(self._data.shape[:2]) // 2 def jacobian_image( self, parameters: Optional[List[str]] = None, - data: Optional[torch.Tensor] = None, + data: Optional[ArrayLike] = None, **kwargs, - ): + ) -> JacobianImage: """ - Construct a blank `Jacobian_Image` object formatted like this current `PSF_Image` object. Mostly used internally. + Construct a blank `JacobianImage` object formatted like this current `PSFImage` object. Mostly used internally. """ if parameters is None: data = None parameters = [] elif data is None: - data = torch.zeros( - (*self.data.shape, len(parameters)), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, + data = backend.zeros( + (*self._data.shape, len(parameters)), + dtype=config.DTYPE, + device=config.DEVICE, ) - return Jacobian_Image( - parameters=parameters, - target_identity=self.identity, - data=data, - header=self.header, + kwargs = { + "CD": self.CD.value, + "crpix": self.crpix, + "crtan": self.crtan.value, + "crval": self.crval.value, + "zeropoint": self.zeropoint, + "identity": self.identity, **kwargs, - ) + } + return JacobianImage(parameters=parameters, _data=data, **kwargs) - def model_image(self, data: Optional[torch.Tensor] = None, **kwargs): + def model_image(self, **kwargs) -> "PSFImage": """ - Construct a blank `Model_Image` object formatted like this current `Target_Image` object. Mostly used internally. + Construct a blank `ModelImage` object formatted like this current `TargetImage` object. Mostly used internally. """ - return Model_Image( - data=torch.zeros_like(self.data) if data is None else data, - header=self.header, - target_identity=self.identity, + kwargs = { + "_data": backend.zeros_like(self._data), + "CD": self.CD.value, + "crpix": self.crpix, + "crtan": self.crtan.value, + "crval": self.crval.value, + "identity": self.identity, **kwargs, + } + return PSFImage(**kwargs) + + @property + def zeropoint(self): + return None + + @zeropoint.setter + def zeropoint(self, value): + """PSFImage does not support zeropoint.""" + pass + + def plane_to_world(self, x, y): + raise NotImplementedError( + "PSFImage does not support plane_to_world conversion. There is no meaningful world position of a PSF image." ) - def expand(self, padding): - raise NotImplementedError("expand not available for PSF_Image") - - def get_fits_state(self): - states = [{}] - states[0]["DATA"] = self.data.detach().cpu().numpy() - states[0]["HEADER"] = self.header.get_fits_state() - states[0]["HEADER"]["IMAGE"] = "PSF" - return states - - def set_fits_state(self, states): - for state in states: - if state["HEADER"]["IMAGE"] == "PSF": - self.set_data(np.array(state["DATA"], dtype=np.float64), require_shape=False) - self.header = Image_Header(fits_state=state["HEADER"]) - break + def world_to_plane(self, ra, dec): + raise NotImplementedError( + "PSFImage does not support world_to_plane conversion. There is no meaningful world position of a PSF image." + ) diff --git a/astrophot/image/sip_image.py b/astrophot/image/sip_image.py new file mode 100644 index 00000000..8e921be7 --- /dev/null +++ b/astrophot/image/sip_image.py @@ -0,0 +1,166 @@ +from typing import Tuple, Union + +from .target_image import TargetImage +from .model_image import ModelImage +from .mixins import SIPMixin +from ..backend_obj import backend, ArrayLike +from .. import config + + +class SIPModelImage(SIPMixin, ModelImage): + """ + A ModelImage with SIP distortion coefficients.""" + + def crop(self, pixels: Union[int, Tuple[int, int], Tuple[int, int, int, int]], **kwargs): + """ + Crop the image by the number of pixels given. This will crop + the image in all four directions by the number of pixels given. + """ + if isinstance(pixels, int): # same crop in all dimension + crop = (slice(pixels, -pixels), slice(pixels, -pixels)) + elif len(pixels) == 1: # same crop in all dimension + crop = (slice(pixels[0], -pixels[0]), slice(pixels[0], -pixels[0])) + elif len(pixels) == 2: # different crop in each dimension + crop = ( + slice(pixels[1], -pixels[1]), + slice(pixels[0], -pixels[0]), + ) + elif len(pixels) == 4: # different crop on all sides + crop = ( + slice(pixels[0], -pixels[1]), + slice(pixels[2], -pixels[3]), + ) + else: + raise ValueError( + f"Invalid crop shape {pixels}, must be int, (int,), (int, int), or (int, int, int, int)!" + ) + kwargs = { + "pixel_area_map": self.pixel_area_map[crop], + "distortion_ij": self.distortion_ij[:, crop[0], crop[1]], + "distortion_IJ": self.distortion_IJ[:, crop[0], crop[1]], + **kwargs, + } + return super().crop(pixels, **kwargs) + + def reduce(self, scale: int, **kwargs): + """This operation will downsample an image by the factor given. If + scale = 2 then 2x2 blocks of pixels will be summed together to + form individual larger pixels. A new image object will be + returned with the appropriate pixelscale and data tensor. Note + that the window does not change in this operation since the + pixels are condensed, but the pixel size is increased + correspondingly. + + **Args:** + - `scale`: factor by which to condense the image pixels. Each scale X scale region will be summed [int] + + """ + if not isinstance(scale, int) and not ( + isinstance(scale, ArrayLike) and scale.dtype is backend.int32 + ): + raise SpecificationConflict(f"Reduce scale must be an integer! not {type(scale)}") + if scale == 1: + return self + + MS = self._data.shape[0] // scale + NS = self._data.shape[1] // scale + + kwargs = { + "pixel_area_map": ( + backend.sum( + self.pixel_area_map[: MS * scale, : NS * scale].reshape(MS, scale, NS, scale), + dim=(1, 3), + ) + ), + "distortion_ij": ( + backend.mean( + self.distortion_ij[:, : MS * scale, : NS * scale].reshape( + 2, MS, scale, NS, scale + ), + dim=(2, 4), + ) + ), + "distortion_IJ": ( + backend.mean( + self.distortion_IJ[:, : MS * scale, : NS * scale].reshape( + 2, MS, scale, NS, scale + ), + dim=(2, 4), + ) + ), + **kwargs, + } + return super().reduce( + scale=scale, + **kwargs, + ) + + def fluxdensity_to_flux(self): + self._data = self._data * self.pixel_area_map + + +class SIPTargetImage(SIPMixin, TargetImage): + """ + A TargetImage with SIP distortion coefficients. + This class is used to represent a target image with SIP distortion coefficients. + It inherits from TargetImage and SIPMixin. + """ + + def model_image(self, upsample: int = 1, pad: int = 0, **kwargs) -> SIPModelImage: + new_area_map = self.pixel_area_map + new_distortion_ij = self.distortion_ij + new_distortion_IJ = self.distortion_IJ + if upsample > 1: + new_area_map = ( + backend.upsample2d(new_area_map[None, None], upsample, "nearest") + .squeeze(0) + .squeeze(0) + ) + new_distortion_ij = backend.upsample2d( + new_distortion_ij[:, None], upsample, "bilinear" + ).squeeze(1) + new_distortion_IJ = backend.upsample2d( + new_distortion_IJ[:, None], upsample, "bilinear" + ).squeeze(1) + if pad > 0: + new_area_map = ( + backend.pad( + new_area_map[None, None], + (0, 0, 0, 0, pad, pad, pad, pad), + mode="replicate", + ) + .squeeze(0) + .squeeze(0) + ) + new_distortion_ij = backend.pad( + new_distortion_ij[:, None], (0, 0, 0, 0, pad, pad, pad, pad), mode="replicate" + ).squeeze(1) + new_distortion_IJ = backend.pad( + new_distortion_IJ[:, None], (0, 0, 0, 0, pad, pad, pad, pad), mode="replicate" + ).squeeze(1) + kwargs = { + "pixel_area_map": new_area_map, + "sipA": self.sipA, + "sipB": self.sipB, + "sipAP": self.sipAP, + "sipBP": self.sipBP, + "distortion_ij": new_distortion_ij, + "distortion_IJ": new_distortion_IJ, + "_data": backend.zeros( + ( + self._data.shape[0] * upsample + 2 * pad, + self._data.shape[1] * upsample + 2 * pad, + ), + dtype=config.DTYPE, + device=config.DEVICE, + ), + "CD": self.CD.value / upsample, + "crpix": (self.crpix + 0.5) * upsample + pad - 0.5, + "crtan": self.crtan.value, + "crval": self.crval.value, + "zeropoint": self.zeropoint, + "identity": self.identity, + "name": self.name + "_model", + **kwargs, + } + return SIPModelImage(**kwargs) diff --git a/astrophot/image/target_image.py b/astrophot/image/target_image.py index 94408723..cb047d37 100644 --- a/astrophot/image/target_image.py +++ b/astrophot/image/target_image.py @@ -1,20 +1,23 @@ from typing import List, Optional -import torch import numpy as np +from astropy.io import fits -from .image_object import Image, Image_List -from .jacobian_image import Jacobian_Image, Jacobian_Image_List -from .model_image import Model_Image, Model_Image_List -from .psf_image import PSF_Image -from .. import AP_config -from ..utils.initialize import auto_variance -from ..errors import SpecificationConflict, InvalidImage +from .image_object import Image, ImageList +from .jacobian_image import JacobianImage, JacobianImageList +from .model_image import ModelImage, ModelImageList +from .psf_image import PSFImage +from .. import config +from ..backend_obj import backend, ArrayLike +from ..errors import InvalidImage +from .mixins import DataMixin +from ..utils.decorators import combine_docstrings -__all__ = ["Target_Image", "Target_Image_List"] +__all__ = ["TargetImage", "TargetImageList"] -class Target_Image(Image): +@combine_docstrings +class TargetImage(DataMixin, Image): """Image object which represents the data to be fit by a model. It can include a variance image, mask, and PSF as anciliary data which describes the target image. @@ -27,32 +30,32 @@ class Target_Image(Image): Basic usage: - .. code-block:: python + ```{python} + import astrophot as ap - import astrophot as ap + # Create target image + image = ap.image.Target_Image( + data="pixel data", + wcs="astropy WCS object", + variance="pixel uncertainties", + psf="point spread function as PSF_Image object", + mask="True for pixels to ignore", + ) - # Create target image - image = ap.image.Target_Image( - data="pixel data", - wcs="astropy WCS object", - variance="pixel uncertainties", - psf="point spread function as PSF_Image object", - mask=" True for pixels to ignore", - ) + # Display the data + fig, ax = plt.subplots() + ap.plots.target_image(fig, ax, image) + plt.show() - # Display the data - fig, ax = plt.subplots() - ap.plots.target_image(fig, ax, image) - plt.show() + # Save the image + image.save("mytarget.fits") - # Save the image - image.save("mytarget.fits") + # Load the image + image2 = ap.image.Target_Image(filename="mytarget.fits") - # Load the image - image2 = ap.image.Target_Image(filename="mytarget.fits") - - # Make low resolution version - lowrez = image.reduce(2) + # Make low resolution version + lowrez = image.reduce(2) + ``` Some important information to keep in mind. First, providing an `astropy WCS` object is the best way to keep track of coordinates @@ -79,628 +82,247 @@ class Target_Image(Image): """ - image_count = 0 - - def __init__(self, *args, **kwargs): + def __init__(self, *args, psf=None, **kwargs): super().__init__(*args, **kwargs) - if not self.has_mask: - self.set_mask(kwargs.get("mask", None)) - if not self.has_weight and "weight" in kwargs: - self.set_weight(kwargs.get("weight", None)) - elif not self.has_variance and "variance" in kwargs: - self.set_variance(kwargs.get("variance", None)) if not self.has_psf: - self.set_psf(kwargs.get("psf", None), kwargs.get("psf_upscale", 1)) - - # Set nan pixels to be masked automatically - if torch.any(torch.isnan(self.data)).item(): - self.set_mask(torch.logical_or(self.mask, torch.isnan(self.data))) - - @property - def standard_deviation(self): - """Stores the standard deviation of the image pixels. This represents - the uncertainty in each pixel value. It should always have the - same shape as the image data. In the case where the standard - deviation is not known, a tensor of ones will be created to - stand in as the standard deviation values. - - The standard deviation is not stored directly, instead it is - computed as :math:`\\sqrt{1/W}` where :math:`W` is the - weights. - - """ - if self.has_variance: - return torch.sqrt(self.variance) - return torch.ones_like(self.data) + self.psf = psf @property - def variance(self): - """Stores the variance of the image pixels. This represents the - uncertainty in each pixel value. It should always have the - same shape as the image data. In the case where the variance - is not known, a tensor of ones will be created to stand in as - the variance values. - - The variance is not stored directly, instead it is - computed as :math:`\\frac{1}{W}` where :math:`W` is the - weights. - - """ - if self.has_variance: - return torch.where(self._weight == 0, torch.inf, 1 / self._weight) - return torch.ones_like(self.data) - - @variance.setter - def variance(self, variance): - self.set_variance(variance) - - @property - def has_variance(self): - """Returns True when the image object has stored variance values. If - this is False and the variance property is called then a - tensor of ones will be returned. - - """ + def has_psf(self) -> bool: + """Returns True when the target image object has a PSF model.""" try: - return self._weight is not None - except AttributeError: - return False - - @property - def weight(self): - """Stores the weight of the image pixels. This represents the - uncertainty in each pixel value. It should always have the - same shape as the image data. In the case where the weight - is not known, a tensor of ones will be created to stand in as - the weight values. - - The weights are used to proprtionately scale residuals in the - likelihood. Most commonly this shows up as a :math:`\\chi^2` - like: - - .. math:: - - \\chi^2 = (\\vec{y} - \\vec{f(\\theta)})^TW(\\vec{y} - \\vec{f(\\theta)}) - - which can be optimized to find parameter values. Using the - Jacobian, which in this case is the derivative of every pixel - wrt every parameter, the weight matrix also appears in the - gradient: - - .. math:: - - \\vec{g} = J^TW(\\vec{y} - \\vec{f(\\theta)}) - - and the hessian approximation used in Levenberg-Marquardt: - - .. math:: - - H \\approx J^TWJ - - """ - if self.has_weight: - return self._weight - return torch.ones_like(self.data) - - @weight.setter - def weight(self, weight): - self.set_weight(weight) - - @property - def has_weight(self): - """Returns True when the image object has stored weight values. If - this is False and the weight property is called then a - tensor of ones will be returned. - - """ - try: - return self._weight is not None - except AttributeError: - self._weight = None - return False - - @property - def mask(self): - """The mask stores a tensor of boolean values which indicate any - pixels to be ignored. These pixels will be skipped in - likelihood evaluations and in parameter optimization. It is - common practice to mask pixels with pathological values such - as due to cosmic rays or satellites passing through the image. - - In a mask, a True value indicates that the pixel is masked and - should be ignored. False indicates a normal pixel which will - inter into most calculaitons. - - If no mask is provided, all pixels are assumed valid. - - """ - if self.has_mask: - return self._mask - return torch.zeros_like(self.data, dtype=torch.bool) - - @mask.setter - def mask(self, mask): - self.set_mask(mask) - - @property - def has_mask(self): - """ - Single boolean to indicate if a mask has been provided by the user. - """ - try: - return self._mask is not None + return self._psf is not None except AttributeError: return False @property def psf(self): - """Stores the point-spread-function for this target. This should be a - `PSF_Image` object which represents the scattering of a point - source of light. It can also be an `AstroPhot_Model` object - which will contribute its own parameters to an optimization - problem. + """The PSF for the `TargetImage`. This is used to convolve the + model with the PSF before evaluating the likelihood. The PSF + should be a `PSFImage` object or an `AstroPhot` PSFModel. - The PSF stored for a `Target_Image` object is passed to all - models applied to that target which have a `psf_mode` that is - not `none`. This means they will all use the same PSF - model. If one wishes to define a variable PSF across an image, - then they should pass the PSF objects to the `AstroPhot_Model`'s - directly instead of to a `Target_Image`. - - Raises: - - AttributeError: if this is called without a PSF defined + If no PSF is provided, then the image will not be convolved + with a PSF and the model will be evaluated directly on the + image pixels. """ - if self.has_psf: + try: return self._psf - raise AttributeError("This image does not have a PSF") + except AttributeError: + return None @psf.setter def psf(self, psf): - self.set_psf(psf) - - @property - def has_psf(self): - try: - return self._psf is not None - except AttributeError: - return False - - def set_variance(self, variance): - """ - Provide a variance tensor for the image. Variance is equal to :math:`\\sigma^2`. This should have the same shape as the data. - """ - if variance is None: - self._weight = None - return - if isinstance(variance, str) and variance == "auto": - self.set_weight("auto") - return - self.set_weight(1 / variance) - - def set_weight(self, weight): - """Provide a weight tensor for the image. Weight is equal to :math:`\\frac{1}{\\sigma^2}`. This should have the same - shape as the data. - - """ - if weight is None: - self._weight = None - return - if isinstance(weight, str) and weight == "auto": - weight = 1 / auto_variance(self.data, self.mask) - if weight.shape != self.data.shape: - raise SpecificationConflict( - f"weight/variance must have same shape as data ({weight.shape} vs {self.data.shape})" - ) - self._weight = ( - weight.to(dtype=AP_config.ap_dtype, device=AP_config.ap_device) - if isinstance(weight, torch.Tensor) - else torch.as_tensor(weight, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - ) - - def set_psf(self, psf, psf_upscale=1): - """Provide a psf for the `Target_Image`. This is stored and passed to + """Provide a psf for the `TargetImage`. This is stored and passed to models which need to be convolved. The PSF doesn't need to have the same pixelscale as the image. It should be some multiple of the resolution of the - `Target_Image` though. So if the image has a pixelscale of 1, + `TargetImage` though. So if the image has a pixelscale of 1, the psf may have a pixelscale of 1, 1/2, 1/3, 1/4 and so on. """ + if hasattr(self, "_psf"): + del self._psf # remove old psf if it exists + from ..models import Model + if psf is None: self._psf = None - return - if isinstance(psf, PSF_Image): + elif isinstance(psf, PSFImage): self._psf = psf - return - - self._psf = PSF_Image( - data=psf, - psf_upscale=psf_upscale, - pixelscale=self.pixelscale / psf_upscale, - identity=self.identity, - ) - - def set_mask(self, mask): - """ - Set the boolean mask which will indicate which pixels to ignore. A mask value of True means the pixel will be ignored. - """ - if mask is None: - self._mask = None - return - if mask.shape != self.data.shape: - raise SpecificationConflict( - f"mask must have same shape as data ({mask.shape} vs {self.data.shape})" + elif isinstance(psf, Model): + self._psf = psf + else: + self._psf = PSFImage( + data=psf, + CD=self.CD, + name=self.name + "_psf", ) - self._mask = ( - mask.to(dtype=torch.bool, device=AP_config.ap_device) - if isinstance(mask, torch.Tensor) - else torch.as_tensor(mask, dtype=torch.bool, device=AP_config.ap_device) - ) - def to(self, dtype=None, device=None): - """Converts the stored `Target_Image` data, variance, psf, etc to a - given data type and device. + def copy_kwargs(self, **kwargs): + kwargs = {"psf": self.psf, **kwargs} + return super().copy_kwargs(**kwargs) - """ - super().to(dtype=dtype, device=device) - if dtype is not None: - dtype = AP_config.ap_dtype - if device is not None: - device = AP_config.ap_device - - if self.has_weight: - self._weight = self._weight.to(dtype=dtype, device=device) + def fits_images(self): + images = super().fits_images() if self.has_psf: - self._psf = self._psf.to(dtype=dtype, device=device) - if self.has_mask: - self._mask = self.mask.to(dtype=torch.bool, device=device) - return self - - def or_mask(self, mask): - """ - Combines the currently stored mask with a provided new mask using the boolean `or` operator. - """ - self._mask = torch.logical_or(self.mask, mask) - - def and_mask(self, mask): - """ - Combines the currently stored mask with a provided new mask using the boolean `and` operator. - """ - self._mask = torch.logical_and(self.mask, mask) - - def copy(self, **kwargs): - """Produce a copy of this image with all of the same properties. This - can be used when one wishes to make temporary modifications to - an image and then will want the original again. - - """ - return super().copy( - mask=self._mask, - psf=self._psf, - weight=self._weight, - **kwargs, - ) + if isinstance(self.psf, PSFImage): + images.append( + fits.ImageHDU( + backend.to_numpy(self.psf.data), + name="PSF", + header=fits.Header(self.psf.fits_info()), + ) + ) + else: + config.logger.warning("Unable to save PSF to FITS, not a PSF_Image.") + return images - def blank_copy(self, **kwargs): - """Produces a blank copy of the image which has the same properties - except that its data is not filled with zeros. + def load(self, filename: str, hduext: int = 0): + """Load the image from a FITS file. This will load the data, WCS, and + any ancillary data such as variance, mask, and PSF. """ - return super().blank_copy(mask=self._mask, psf=self._psf, **kwargs) - - def get_window(self, window, **kwargs): - """Get a sub-region of the image as defined by a window on the sky.""" - indices = self.window.get_self_indices(window) - return super().get_window( - window=window, - weight=self._weight[indices] if self.has_weight else None, - mask=self._mask[indices] if self.has_mask else None, - psf=self._psf, - **kwargs, - ) + hdulist = super().load(filename, hduext=hduext) + if "PSF" in hdulist: + self.psf = PSFImage( + data=np.array(hdulist["PSF"].data, dtype=np.float64), + CD=( + (hdulist["PSF"].header["CD1_1"], hdulist["PSF"].header["CD1_2"]), + (hdulist["PSF"].header["CD2_1"], hdulist["PSF"].header["CD2_2"]), + ), + ) + return hdulist def jacobian_image( self, - parameters: Optional[List[str]] = None, - data: Optional[torch.Tensor] = None, + parameters: List[str], + data: Optional[ArrayLike] = None, **kwargs, - ): + ) -> JacobianImage: """ - Construct a blank `Jacobian_Image` object formatted like this current `Target_Image` object. Mostly used internally. + Construct a blank `JacobianImage` object formatted like this current `TargetImage` object. Mostly used internally. """ - if parameters is None: - data = None - parameters = [] - elif data is None: - data = torch.zeros( - (*self.data.shape, len(parameters)), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, + if data is None: + data = backend.zeros( + (*self._data.shape, len(parameters)), + dtype=config.DTYPE, + device=config.DEVICE, ) - return Jacobian_Image( - parameters=parameters, - target_identity=self.identity, - data=data, - header=self.header, + kwargs = { + "CD": self.CD.value, + "crpix": self.crpix, + "crtan": self.crtan.value, + "crval": self.crval.value, + "zeropoint": self.zeropoint, + "identity": self.identity, + "name": self.name + "_jacobian", **kwargs, - ) + } + return JacobianImage(parameters=parameters, _data=data, **kwargs) - def model_image(self, data: Optional[torch.Tensor] = None, **kwargs): + def model_image(self, upsample: int = 1, pad: int = 0, **kwargs) -> ModelImage: """ - Construct a blank `Model_Image` object formatted like this current `Target_Image` object. Mostly used internally. + Construct a blank `ModelImage` object formatted like this current `TargetImage` object. Mostly used internally. """ - return Model_Image( - data=torch.zeros_like(self.data) if data is None else data, - header=self.header, - target_identity=self.identity, + kwargs = { + "_data": backend.zeros( + ( + self._data.shape[0] * upsample + 2 * pad, + self._data.shape[1] * upsample + 2 * pad, + ), + dtype=config.DTYPE, + device=config.DEVICE, + ), + "CD": self.CD.value / upsample, + "crpix": (self.crpix + 0.5) * upsample + pad - 0.5, + "crtan": self.crtan.value, + "crval": self.crval.value, + "zeropoint": self.zeropoint, + "identity": self.identity, + "name": self.name + "_model", **kwargs, - ) + } + return ModelImage(**kwargs) + + def psf_image(self, data: ArrayLike, upscale: int = 1, **kwargs) -> PSFImage: + kwargs = { + "data": data, + "CD": self.CD.value / upscale, + "identity": self.identity, + "name": self.name + "_psf", + **kwargs, + } + return PSFImage(**kwargs) - def reduce(self, scale, **kwargs): - """Returns a new `Target_Image` object with a reduced resolution + def reduce(self, scale: int, **kwargs) -> "TargetImage": + """Returns a new `TargetImage` object with a reduced resolution compared to the current image. `scale` should be an integer indicating how much to reduce the resolution. If the - `Target_Image` was originally (48,48) pixels across with a + `TargetImage` was originally (48,48) pixels across with a pixelscale of 1 and `reduce(2)` is called then the image will be (24,24) pixels and the pixelscale will be 2. If `reduce(3)` is called then the returned image will be (16,16) pixels across and the pixelscale will be 3. """ - MS = self.data.shape[0] // scale - NS = self.data.shape[1] // scale return super().reduce( scale=scale, - variance=( - self.variance[: MS * scale, : NS * scale] - .reshape(MS, scale, NS, scale) - .sum(axis=(1, 3)) - if self.has_variance - else None - ), - mask=( - self.mask[: MS * scale, : NS * scale] - .reshape(MS, scale, NS, scale) - .amax(axis=(1, 3)) - if self.has_mask - else None - ), - psf=self.psf.reduce(scale) if self.has_psf else None, + psf=(self.psf.reduce(scale) if isinstance(self.psf, PSFImage) else None), **kwargs, ) - def expand(self, padding): - """ - `Target_Image` doesn't have expand yet. - """ - raise NotImplementedError("expand not available for Target_Image yet") - - def get_state(self): - state = super().get_state() - - if self.has_weight: - state["weight"] = self.weight.detach().cpu().tolist() - if self.has_mask: - state["mask"] = self.mask.detach().cpu().tolist() - if self.has_psf: - state["psf"] = self.psf.get_state() - - return state - - def set_state(self, state): - super().set_state(state) - - self.weight = state.get("weight", None) - self.mask = state.get("mask", None) - if "psf" in state: - self.psf = PSF_Image(state=state["psf"]) - - def get_fits_state(self): - states = super().get_fits_state() - if self.has_weight: - states.append( - { - "DATA": self.weight.detach().cpu().numpy(), - "HEADER": {"IMAGE": "WEIGHT"}, - } - ) - if self.has_mask: - states.append( - { - "DATA": self.mask.detach().cpu().numpy().astype(int), - "HEADER": {"IMAGE": "MASK"}, - } - ) - if self.has_psf: - states += self.psf.get_fits_state() - - return states - - def set_fits_state(self, states): - super().set_fits_state(states) - for state in states: - if state["HEADER"]["IMAGE"] == "WEIGHT": - self.weight = np.array(state["DATA"], dtype=np.float64) - if state["HEADER"]["IMAGE"] == "mask": - self.mask = np.array(state["DATA"], dtype=bool) - if state["HEADER"]["IMAGE"] == "PSF": - self.psf = PSF_Image(fits_state=states) - -class Target_Image_List(Image_List, Target_Image): +class TargetImageList(ImageList): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not all(isinstance(image, Target_Image) for image in self.image_list): + if not all(isinstance(image, (TargetImage, TargetImageList)) for image in self.images): raise InvalidImage( - f"Target_Image_List can only hold Target_Image objects, not {tuple(type(image) for image in self.image_list)}" + f"Target_Image_List can only hold Target_Image objects, not {tuple(type(image) for image in self.images)}" ) @property def variance(self): - return tuple(image.variance for image in self.image_list) + return tuple(image.variance for image in self.images) + + @property + def _variance(self): + return tuple(image._variance for image in self.images) @variance.setter def variance(self, variance): - for image, var in zip(self.image_list, variance): - image.set_variance(var) + for image, var in zip(self.images, variance): + image.variance = var @property - def has_variance(self): - return any(image.has_variance for image in self.image_list) + def weight(self): + return tuple(image.weight for image in self.images) @property - def weight(self): - return tuple(image.weight for image in self.image_list) + def _weight(self): + return tuple(image._weight for image in self.images) @weight.setter def weight(self, weight): - for image, wgt in zip(self.image_list, weight): - image.set_weight(wgt) + for image, wgt in zip(self.images, weight): + image.weight = wgt - @property - def has_weight(self): - return any(image.has_weight for image in self.image_list) - - def jacobian_image(self, parameters: List[str], data: Optional[List[torch.Tensor]] = None): - if data is None: - data = [None] * len(self.image_list) - return Jacobian_Image_List( - list(image.jacobian_image(parameters, dat) for image, dat in zip(self.image_list, data)) - ) - - def model_image(self, data: Optional[List[torch.Tensor]] = None): + def jacobian_image( + self, parameters: List[str], data: Optional[List[ArrayLike]] = None + ) -> JacobianImageList: if data is None: - data = [None] * len(self.image_list) - return Model_Image_List( - list(image.model_image(data=dat) for image, dat in zip(self.image_list, data)) + data = tuple(None for _ in range(len(self.images))) + return JacobianImageList( + list(image.jacobian_image(parameters, dat) for image, dat in zip(self.images, data)) ) - def match_indices(self, other): - indices = [] - if isinstance(other, Target_Image_List): - for other_image in other.image_list: - for isi, self_image in enumerate(self.image_list): - if other_image.identity == self_image.identity: - indices.append(isi) - break - else: - indices.append(None) - elif isinstance(other, Target_Image): - for isi, self_image in enumerate(self.image_list): - if other.identity == self_image.identity: - indices = isi - break - else: - indices = None - return indices - - def __isub__(self, other): - if isinstance(other, Target_Image_List): - for other_image in other.image_list: - for self_image in self.image_list: - if other_image.identity == self_image.identity: - self_image -= other_image - break - else: - self.image_list.append(other_image) - elif isinstance(other, Target_Image): - for self_image in self.image_list: - if other.identity == self_image.identity: - self_image -= other - break - elif isinstance(other, Model_Image_List): - for other_image in other.image_list: - for self_image in self.image_list: - if other_image.target_identity == self_image.identity: - self_image -= other_image - break - elif isinstance(other, Model_Image): - for self_image in self.image_list: - if other.target_identity == self_image.identity: - self_image -= other - else: - for self_image, other_image in zip(self.image_list, other): - self_image -= other_image - return self - - def __iadd__(self, other): - if isinstance(other, Target_Image_List): - for other_image in other.image_list: - for self_image in self.image_list: - if other_image.identity == self_image.identity: - self_image += other_image - break - else: - self.image_list.append(other_image) - elif isinstance(other, Target_Image): - for self_image in self.image_list: - if other.identity == self_image.identity: - self_image += other - elif isinstance(other, Model_Image_List): - for other_image in other.image_list: - for self_image in self.image_list: - if other_image.target_identity == self_image.identity: - self_image += other_image - break - elif isinstance(other, Model_Image): - for self_image in self.image_list: - if other.target_identity == self_image.identity: - self_image += other - else: - for self_image, other_image in zip(self.image_list, other): - self_image += other_image - return self + def model_image(self) -> ModelImageList: + return ModelImageList(list(image.model_image() for image in self.images)) @property def mask(self): - return tuple(image.mask for image in self.image_list) + return tuple(image.mask for image in self.images) + + @property + def _mask(self): + return tuple(image._mask for image in self.images) @mask.setter def mask(self, mask): - for image, M in zip(self.image_list, mask): - image.set_mask(M) - - @property - def has_mask(self): - return any(image.has_mask for image in self.image_list) + for image, M in zip(self.images, mask): + image.mask = M @property def psf(self): - return tuple(image.psf for image in self.image_list) + return tuple(image.psf for image in self.images) @psf.setter def psf(self, psf): - for image, P in zip(self.image_list, psf): - image.set_psf(P) + for image, P in zip(self.images, psf): + image.psf = P @property - def has_psf(self): - return any(image.has_psf for image in self.image_list) - - @property - def psf_border(self): - return tuple(image.psf_border for image in self.image_list) - - @property - def psf_border_int(self): - return tuple(image.psf_border_int for image in self.image_list) - - def set_variance(self, variance, img): - self.image_list[img].set_variance(variance) - - def set_psf(self, psf, img): - self.image_list[img].set_psf(psf) - - def set_mask(self, mask, img): - self.image_list[img].set_mask(mask) - - def or_mask(self, mask): - raise NotImplementedError() - - def and_mask(self, mask): - raise NotImplementedError() + def has_psf(self) -> bool: + return any(image.has_psf for image in self.images) diff --git a/astrophot/image/wcs.py b/astrophot/image/wcs.py deleted file mode 100644 index 6f0f71a6..00000000 --- a/astrophot/image/wcs.py +++ /dev/null @@ -1,808 +0,0 @@ -import torch -import numpy as np - -from .. import AP_config -from ..utils.conversions.units import deg_to_arcsec -from ..errors import InvalidWCS - -__all__ = ("WPCS", "PPCS", "WCS") - -deg_to_rad = np.pi / 180 -rad_to_deg = 180 / np.pi -rad_to_arcsec = rad_to_deg * 3600 -arcsec_to_rad = deg_to_rad / 3600 - - -class WPCS: - """World to Plane Coordinate System in AstroPhot. - - AstroPhot performs its operations on a tangent plane to the - celestial sphere, this class handles projections between the sphere and the - tangent plane. It holds variables for the reference (RA,DEC) where - the tangent plane contacts the sphere, and the type of projection - being performed. Note that (RA,DEC) coordinates should always be - in degrees while the tangent plane is in arcsecs. - - Attributes: - reference_radec: The reference (RA,DEC) coordinates in degrees where the tangent plane contacts the sphere. - reference_planexy: The reference tangent plane coordinates in arcsec where the tangent plane contacts the sphere. - projection: The projection system used to convert from (RA,DEC) onto the tangent plane. Should be one of: gnomonic (default), orthographic, steriographic - - """ - - # Softening length used for numerical stability and/or integration stability to avoid discontinuities (near R=0). This is in units of arcsec. - softening = 1e-3 - - default_reference_radec = (0, 0) - default_reference_planexy = (0, 0) - default_projection = "gnomonic" - - def __init__(self, **kwargs): - self.projection = kwargs.get("projection", self.default_projection) - self.reference_radec = kwargs.get("reference_radec", self.default_reference_radec) - self.reference_planexy = kwargs.get("reference_planexy", self.default_reference_planexy) - - def world_to_plane(self, world_RA, world_DEC=None): - """Take a coordinate on the world coordinate system, also called the - celesial sphere, (RA, DEC in degrees) and transform it to the - corresponding tangent plane coordinate - (arcsec). Transformation is done based on the chosen - projection (default gnomonic) and reference positions. See the - :doc:`coordinates` documentation for more details on how the - transformation is performed. - - """ - - if world_DEC is None: - return torch.stack(self.world_to_plane(*world_RA)) - - world_RA = torch.as_tensor(world_RA, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - world_DEC = torch.as_tensor(world_DEC, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - - if self.projection == "gnomonic": - coords = self._world_to_plane_gnomonic( - world_RA, - world_DEC, - ) - elif self.projection == "orthographic": - coords = self._world_to_plane_orthographic( - world_RA, - world_DEC, - ) - elif self.projection == "steriographic": - coords = self._world_to_plane_steriographic( - world_RA, - world_DEC, - ) - return ( - coords[0] + self.reference_planexy[0], - coords[1] + self.reference_planexy[1], - ) - - def plane_to_world(self, plane_x, plane_y=None): - """Take a coordinate on the tangent plane (arcsec), and transform it - to the corresponding world coordinate (RA, DEC in - degrees). Transformation is done based on the chosen - projection (default gnomonic) and reference positions. See the - :doc:`coordinates` documentation for more details on how the - transformation is performed. - - """ - - if plane_y is None: - return torch.stack(self.plane_to_world(*plane_x)) - plane_x = torch.as_tensor( - plane_x - self.reference_planexy[0], - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - plane_y = torch.as_tensor( - plane_y - self.reference_planexy[1], - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - if self.projection == "gnomonic": - return self._plane_to_world_gnomonic( - plane_x, - plane_y, - ) - if self.projection == "orthographic": - return self._plane_to_world_orthographic( - plane_x, - plane_y, - ) - if self.projection == "steriographic": - return self._plane_to_world_steriographic( - plane_x, - plane_y, - ) - - @property - def projection(self): - """ - The mathematical projection formula which described how world coordinates are mapped to the tangent plane. - """ - return self._projection - - @projection.setter - def projection(self, proj): - if proj not in ( - "gnomonic", - "orthographic", - "steriographic", - ): - raise InvalidWCS( - f"Unrecognized projection: {proj}. Should be one of: gnomonic, orthographic, steriographic" - ) - self._projection = proj - - @property - def reference_radec(self): - """ - RA DEC (world) coordinates where the tangent plane meets the celestial sphere. These should be in degrees. - """ - return self._reference_radec - - @reference_radec.setter - def reference_radec(self, radec): - self._reference_radec = torch.as_tensor( - radec, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - - @property - def reference_planexy(self): - """ - x y tangent plane coordinates where the tangent plane meets the celestial sphere. These should be in arcsec. - """ - return self._reference_planexy - - @reference_planexy.setter - def reference_planexy(self, planexy): - self._reference_planexy = torch.as_tensor( - planexy, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - - def _project_world_to_plane(self, world_RA, world_DEC): - """ - Recurring core calculation in all the projections from world to plane. - - Args: - world_RA: Right ascension in degrees - world_DEC: Declination in degrees - """ - return ( - torch.cos(world_DEC * deg_to_rad) - * torch.sin((world_RA - self.reference_radec[0]) * deg_to_rad) - * rad_to_arcsec, - ( - torch.cos(self.reference_radec[1] * deg_to_rad) * torch.sin(world_DEC * deg_to_rad) - - torch.sin(self.reference_radec[1] * deg_to_rad) - * torch.cos(world_DEC * deg_to_rad) - * torch.cos((world_RA - self.reference_radec[0]) * deg_to_rad) - ) - * rad_to_arcsec, - ) - - def _project_plane_to_world(self, plane_x, plane_y, rho, c): - """ - Recurring core calculation in all the projections from plane to world. - - Args: - plane_x: tangent plane x coordinate in arcseconds. - plane_y: tangent plane y coordinate in arcseconds. - rho: polar radius on tangent plane. - c: coordinate term dependent on the projection. - """ - return ( - ( - self._reference_radec[0] * deg_to_rad - + torch.arctan2( - plane_x * arcsec_to_rad * torch.sin(c), - rho * torch.cos(self.reference_radec[1] * deg_to_rad) * torch.cos(c) - - plane_y - * arcsec_to_rad - * torch.sin(self.reference_radec[1] * deg_to_rad) - * torch.sin(c), - ) - ) - * rad_to_deg, - torch.arcsin( - torch.cos(c) * torch.sin(self.reference_radec[1] * deg_to_rad) - + plane_y - * arcsec_to_rad - * torch.sin(c) - * torch.cos(self.reference_radec[1] * deg_to_rad) - / rho - ) - * rad_to_deg, - ) - - def _world_to_plane_gnomonic(self, world_RA, world_DEC): - """Gnomonic projection: (RA,DEC) to tangent plane. - - Performs Gnomonic projection of (RA,DEC) coordinates onto a - tangent plane. The tangent plane makes contact at the location - of the `reference_radec` variable. In a gnomonic projection, - great circles are mapped to straight lines. The gnomonic - projection represents the image formed by a spherical lens, - and is sometimes known as the rectilinear projection. - - Args: - world_RA: Right ascension in degrees - world_DEC: Declination in degrees - - See: https://mathworld.wolfram.com/GnomonicProjection.html - - """ - C = torch.sin(self.reference_radec[1] * deg_to_rad) * torch.sin( - world_DEC * deg_to_rad - ) + torch.cos(self.reference_radec[1] * deg_to_rad) * torch.cos( - world_DEC * deg_to_rad - ) * torch.cos( - (world_RA - self.reference_radec[0]) * deg_to_rad - ) - x, y = self._project_world_to_plane(world_RA, world_DEC) - return x / C, y / C - - def _plane_to_world_gnomonic(self, plane_x, plane_y): - """Inverse Gnomonic projection: tangent plane to (RA,DEC). - - Performs the inverse Gnomonic projection of tangent plane - coordinates into (RA,DEC) coordinates. The tangent plane makes - contact at the location of the `reference_radec` variable. In - a gnomonic projection, great circles are mapped to straight - lines. The gnomonic projection represents the image formed by - a spherical lens, and is sometimes known as the rectilinear - projection. - - Args: - plane_x: tangent plane x coordinate in arcseconds. - plane_y: tangent plane y coordinate in arcseconds. - - See: https://mathworld.wolfram.com/GnomonicProjection.html - - """ - rho = (torch.sqrt(plane_x**2 + plane_y**2) + self.softening) * arcsec_to_rad - c = torch.arctan(rho) - - ra, dec = self._project_plane_to_world(plane_x, plane_y, rho, c) - return ra, dec - - def _world_to_plane_steriographic(self, world_RA, world_DEC): - """Steriographic projection: (RA,DEC) to tangent plane - - Performs Steriographic projection of (RA,DEC) coordinates onto - a tangent plane. The tangent plane makes contact at the - location of the `reference_radec` variable. The steriographic - projection preserves circles and angle measures. - - Args: - world_RA: Right ascension in degrees - world_DEC: Declination in degrees - - See: https://mathworld.wolfram.com/StereographicProjection.html - - """ - C = ( - 1 - + torch.sin(world_DEC * deg_to_rad) * torch.sin(self._reference_radec[1] * deg_to_rad) - + torch.cos(world_DEC * deg_to_rad) - * torch.cos(self._reference_radec[1] * deg_to_rad) - * torch.cos((world_RA - self._reference_radec[0]) * deg_to_rad) - ) / 2 - x, y = self._project_world_to_plane(world_RA, world_DEC) - return x / C, y / C - - def _plane_to_world_steriographic(self, plane_x, plane_y): - """Inverse Steriographic projection: tangent plane to (RA,DEC). - - Performs the inverse Steriographic projection of tangent plane - coordinates into (RA,DEC) coordinates. The tangent plane makes - contact at the location of the `reference_radec` variable. The - steriographic projection preserves circles and angle measures. - - Args: - plane_x: tangent plane x coordinate in arcseconds. The origin of the tangent plane is the contact point with the sphere, represented by `reference_radec`. - plane_y: tangent plane y coordinate in arcseconds. The origin of the tangent plane is the contact point with the sphere, represented by `reference_radec`. - - See: https://mathworld.wolfram.com/StereographicProjection.html - - """ - rho = (torch.sqrt(plane_x**2 + plane_y**2) + self.softening) * arcsec_to_rad - c = 2 * torch.arctan(rho / 2) - ra, dec = self._project_plane_to_world(plane_x, plane_y, rho, c) - return ra, dec - - def _world_to_plane_orthographic(self, world_RA, world_DEC): - """Orthographic projection: (RA,DEC) to tangent plane - - Performs Orthographic projection of (RA,DEC) coordinates onto - a tangent plane. The tangent plane makes contact at the - location of the `reference_radec` variable. The point of - perspective for the orthographic projection is at infinite - distance. This projection is perhaps better suited to - represent the view of an exoplanet, however it is included - here for completeness. - - Args: - world_RA: Right ascension in degrees - world_DEC: Declination in degrees - - See: https://mathworld.wolfram.com/OrthographicProjection.html - - """ - x, y = self._project_world_to_plane(world_RA, world_DEC) - return x, y - - def _plane_to_world_orthographic(self, plane_x, plane_y): - """Inverse Orthographic projection: tangent plane to (RA,DEC). - - Performs the inverse Orthographic projection of tangent plane - coordinates into (RA,DEC) coordinates. The tangent plane makes - contact at the location of the `reference_radec` variable. The - point of perspective for the orthographic projection is at - infinite distance. This projection is perhaps better suited to - represent the view of an exoplanet, however it is included - here for completeness. - - Args: - plane_x: tangent plane x coordinate in arcseconds. The origin of the tangent plane is the contact point with the sphere, represented by `reference_radec`. - plane_y: tangent plane y coordinate in arcseconds. The origin of the tangent plane is the contact point with the sphere, represented by `reference_radec`. - - See: https://mathworld.wolfram.com/OrthographicProjection.html - - """ - rho = (torch.sqrt(plane_x**2 + plane_y**2) + self.softening) * arcsec_to_rad - c = torch.arcsin(rho) - - ra, dec = self._project_plane_to_world(plane_x, plane_y, rho, c) - return ra, dec - - def get_state(self): - """Returns a dictionary with the information needed to recreate the - WPCS object. - - """ - return { - "projection": self.projection, - "reference_radec": self.reference_radec.detach().cpu().tolist(), - "reference_planexy": self.reference_planexy.detach().cpu().tolist(), - } - - def set_state(self, state): - """Takes a state dictionary and re-creates the state of the WPCS - object. - - """ - self.projection = state.get("projection", self.default_projection) - self.reference_radec = state.get("reference_radec", self.default_reference_radec) - self.reference_planexy = state.get("reference_planexy", self.default_reference_planexy) - - def get_fits_state(self): - """ - Similar to get_state, except specifically tailored to be stored in a FITS format. - """ - return { - "PROJ": self.projection, - "REFRADEC": str(self.reference_radec.detach().cpu().tolist()), - "REFPLNXY": str(self.reference_planexy.detach().cpu().tolist()), - } - - def set_fits_state(self, state): - """ - Reads and applies the state from the get_fits_state function. - """ - self.projection = state["PROJ"] - self.reference_radec = eval(state["REFRADEC"]) - self.reference_planexy = eval(state["REFPLNXY"]) - - def copy(self, **kwargs): - """Create a copy of the WPCS object with the same projection - parameters. - - """ - copy_kwargs = { - "projection": self.projection, - "reference_radec": self.reference_radec, - "reference_planexy": self.reference_planexy, - } - copy_kwargs.update(kwargs) - return self.__class__( - **copy_kwargs, - ) - - def to(self, dtype=None, device=None): - """ - Convert all stored tensors to a new device and data type - """ - if dtype is None: - dtype = AP_config.ap_dtype - if device is None: - device = AP_config.ap_device - self._reference_radec = self._reference_radec.to(dtype=dtype, device=device) - self._reference_planexy = self._reference_planexy.to(dtype=dtype, device=device) - - def __str__(self): - return f"WPCS reference_radec: {self.reference_radec.detach().cpu().tolist()}, reference_planexy: {self.reference_planexy.detach().cpu().tolist()}" - - def __repr__(self): - return f"WPCS reference_radec: {self.reference_radec.detach().cpu().tolist()}, reference_planexy: {self.reference_planexy.detach().cpu().tolist()}, projection: {self.projection}" - - -class PPCS: - """ - plane to pixel coordinate system - - - Args: - pixelscale : float or None, optional - The physical scale of the pixels in the image, this is - represented as a matrix which projects pixel units into sky - units: ``pixelscale @ pixel_vec = sky_vec``. The pixel - scale matrix can be thought of in four components: - :math:`\\vec{s} @ F @ R @ S` where :math:`\\vec{s}` is the side - length of the pixels, :math:`F` is a diagonal matrix of {1,-1} - which flips the axes orientation, :math:`R` is a rotation - matrix, and :math:`S` is a shear matrix which turns - rectangular pixels into parallelograms. Default is None. - reference_imageij : Sequence or None, optional - The pixel coordinate at which the image is fixed to the - tangent plane. By default this is (-0.5, -0.5) or the bottom - corner of the [0,0] indexed pixel. - reference_imagexy : Sequence or None, optional - The tangent plane coordinate at which the image is fixed, - corresponding to the reference_imageij coordinate. These two - reference points ar pinned together, any rotations would occur - about this point. By default this is (0., 0.). - - """ - - default_reference_imageij = (-0.5, -0.5) - default_reference_imagexy = (0, 0) - default_pixelscale = 1 - - def __init__(self, *, wcs=None, pixelscale=None, **kwargs): - - self.reference_imageij = kwargs.get("reference_imageij", self.default_reference_imageij) - self.reference_imagexy = kwargs.get("reference_imagexy", self.default_reference_imagexy) - - # Collect the pixelscale of the pixel grid - if wcs is not None and pixelscale is None: - self.pixelscale = deg_to_arcsec * wcs.pixel_scale_matrix - elif pixelscale is not None: - if wcs is not None and isinstance(pixelscale, float): - AP_config.ap_logger.warning( - "Overriding WCS pixelscale with manual input! To remove this message, either let WCS define pixelscale, or input full pixelscale matrix" - ) - self.pixelscale = pixelscale - else: - AP_config.ap_logger.warning( - "Assuming pixelscale of 1! To remove this message please provide the pixelscale explicitly" - ) - self.pixelscale = self.default_pixelscale - - @property - def pixelscale(self): - """Matrix defining the shape of pixels in the tangent plane, these - can be any parallelogram defined by the matrix. - - """ - return self._pixelscale - - @pixelscale.setter - def pixelscale(self, pix): - if pix is None: - self._pixelscale = None - return - - self._pixelscale = ( - torch.as_tensor(pix, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - .clone() - .detach() - ) - if self._pixelscale.numel() == 1: - self._pixelscale = torch.tensor( - [[self._pixelscale.item(), 0.0], [0.0, self._pixelscale.item()]], - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - self._pixel_area = torch.linalg.det(self.pixelscale).abs() - self._pixel_length = self._pixel_area.sqrt() - self._pixelscale_inv = torch.linalg.inv(self.pixelscale) - - @property - def pixel_area(self): - """The area inside a pixel in arcsec^2""" - return self._pixel_area - - @property - def pixel_length(self): - """The approximate length of a pixel, which is just - sqrt(pixel_area). For square pixels this is the actual pixel - length, for rectangular pixels it is a kind of average. - - The pixel_length is typically not used for exact calculations - and instead sets a size scale within an image. - - """ - return self._pixel_length - - @property - def reference_imageij(self): - """pixel coordinates where the pixel grid is fixed to the tangent - plane. These should be in pixel units where (0,0) is the - center of the [0,0] indexed pixel. However, it is still in xy - format, meaning that the first index gives translations in the - x-axis (horizontal-axis) of the image. - - """ - return self._reference_imageij - - @reference_imageij.setter - def reference_imageij(self, imageij): - self._reference_imageij = torch.as_tensor( - imageij, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - - @property - def reference_imagexy(self): - """plane coordinates where the image grid is fixed to the tangent - plane. These should be in arcsec. - - """ - return self._reference_imagexy - - @reference_imagexy.setter - def reference_imagexy(self, imagexy): - self._reference_imagexy = torch.as_tensor( - imagexy, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - - def pixel_to_plane(self, pixel_i, pixel_j=None): - """Take in a coordinate on the regular pixel grid, where 0,0 is the - center of the [0,0] indexed pixel. This coordinate is - transformed into the tangent plane coordinate system (arcsec) - based on the pixel scale and reference positions. If the pixel - scale matrix is :math:`P`, the reference pixel is - :math:`\\vec{r}_{pix}`, the reference tangent plane point is - :math:`\\vec{r}_{tan}`, and the coordinate to transform is - :math:`\\vec{c}_{pix}` then the coordinate in the tangent plane - is: - - .. math:: - - \\vec{c}_{tan} = [P(\\vec{c}_{pix} - \\vec{r}_{pix})] + \\vec{r}_{tan} - - """ - if pixel_j is None: - return torch.stack(self.pixel_to_plane(*pixel_i)) - coords = torch.mm( - self.pixelscale, - torch.stack((pixel_i.reshape(-1), pixel_j.reshape(-1))) - - self.reference_imageij.view(2, 1), - ) + self.reference_imagexy.view(2, 1) - return coords[0].reshape(pixel_i.shape), coords[1].reshape(pixel_j.shape) - - def plane_to_pixel(self, plane_x, plane_y=None): - """Take a coordinate on the tangent plane (arcsec) and transform it to - the corresponding pixel grid coordinate (pixel units where - (0,0) is the [0,0] indexed pixel). Transformation is done - based on the pixel scale and reference positions. If the pixel - scale matrix is :math:`P`, the reference pixel is - :math:`\\vec{r}_{pix}`, the reference tangent plane point is - :math:`\\vec{r}_{tan}`, and the coordinate to transform is - :math:`\\vec{c}_{tan}` then the coordinate in the pixel grid - is: - - .. math:: - - \\vec{c}_{pix} = [P^{-1}(\\vec{c}_{tan} - \\vec{r}_{tan})] + \\vec{r}_{pix} - - """ - if plane_y is None: - return torch.stack(self.plane_to_pixel(*plane_x)) - coords = torch.mm( - self._pixelscale_inv, - torch.stack((plane_x.reshape(-1), plane_y.reshape(-1))) - - self.reference_imagexy.view(2, 1), - ) + self.reference_imageij.view(2, 1) - return coords[0].reshape(plane_x.shape), coords[1].reshape(plane_y.shape) - - def pixel_to_plane_delta(self, pixel_delta_i, pixel_delta_j=None): - """Take a translation in pixel space and determine the corresponding - translation in the tangent plane (arcsec). Essentially this performs - the pixel scale matrix multiplication without any reference - coordinates applied. - - """ - if pixel_delta_j is None: - return torch.stack(self.pixel_to_plane_delta(*pixel_delta_i)) - coords = torch.mm( - self.pixelscale, - torch.stack((pixel_delta_i.reshape(-1), pixel_delta_j.reshape(-1))), - ) - return coords[0].reshape(pixel_delta_i.shape), coords[1].reshape(pixel_delta_j.shape) - - def plane_to_pixel_delta(self, plane_delta_x, plane_delta_y=None): - """Take a translation in tangent plane space (arcsec) and determine - the corresponding translation in pixel space. Essentially this - performs the pixel scale matrix multiplication without any - reference coordinates applied. - - """ - if plane_delta_y is None: - return torch.stack(self.plane_to_pixel_delta(*plane_delta_x)) - coords = torch.mm( - self._pixelscale_inv, - torch.stack((plane_delta_x.reshape(-1), plane_delta_y.reshape(-1))), - ) - return coords[0].reshape(plane_delta_x.shape), coords[1].reshape(plane_delta_y.shape) - - def copy(self, **kwargs): - """Create a copy of the PPCS object with the same projection - parameters. - - """ - copy_kwargs = { - "pixelscale": self.pixelscale, - "reference_imageij": self.reference_imageij, - "reference_imagexy": self.reference_imagexy, - } - copy_kwargs.update(kwargs) - return self.__class__( - **copy_kwargs, - ) - - def get_state(self): - return { - "pixelscale": self.pixelscale.detach().cpu().tolist(), - "reference_imageij": self.reference_imageij.detach().cpu().tolist(), - "reference_imagexy": self.reference_imagexy.detach().cpu().tolist(), - } - - def set_state(self, state): - self.pixelscale = state.get("pixelscale", self.default_pixelscale) - self.reference_imageij = state.get("reference_imageij", self.default_reference_imageij) - self.reference_imagexy = state.get("reference_imagexy", self.default_reference_imagexy) - - def get_fits_state(self): - """ - Similar to get_state, except specifically tailored to be stored in a FITS format. - """ - return { - "PXLSCALE": str(self.pixelscale.detach().cpu().tolist()), - "REFIMGIJ": str(self.reference_imageij.detach().cpu().tolist()), - "REFIMGXY": str(self.reference_imagexy.detach().cpu().tolist()), - } - - def set_fits_state(self, state): - """ - Reads and applies the state from the get_fits_state function. - """ - self.pixelscale = eval(state["PXLSCALE"]) - self.reference_imageij = eval(state["REFIMGIJ"]) - self.reference_imagexy = eval(state["REFIMGXY"]) - - def to(self, dtype=None, device=None): - """ - Convert all stored tensors to a new device and data type - """ - if dtype is None: - dtype = AP_config.ap_dtype - if device is None: - device = AP_config.ap_device - self._pixelscale = self._pixelscale.to(dtype=dtype, device=device) - self._reference_imageij = self._reference_imageij.to(dtype=dtype, device=device) - self._reference_imagexy = self._reference_imagexy.to(dtype=dtype, device=device) - - def __str__(self): - return f"PPCS reference_imageij: {self.reference_imageij.detach().cpu().tolist()}, reference_imagexy: {self.reference_imagexy.detach().cpu().tolist()}" - - def __repr__(self): - return f"PPCS reference_imageij: {self.reference_imageij.detach().cpu().tolist()}, reference_imagexy: {self.reference_imagexy.detach().cpu().tolist()}, pixelscale: {self.pixelscale.detach().cpu().tolist()}" - - -class WCS(WPCS, PPCS): - """ - Full world coordinate system defines mappings from world to tangent plane to pixel grid and all other variations. - """ - - def __init__(self, *args, wcs=None, **kwargs): - if kwargs.get("state", None) is not None: - self.set_state(kwargs["state"]) - return - - if wcs is not None: - if wcs.wcs.ctype[0] != "RA---TAN": - AP_config.ap_logger.warning( - "Astropy WCS not tangent plane coordinate system! May not be compatible with AstroPhot." - ) - if wcs.wcs.ctype[1] != "DEC--TAN": - AP_config.ap_logger.warning( - "Astropy WCS not tangent plane coordinate system! May not be compatible with AstroPhot." - ) - - if wcs is not None: - kwargs["reference_radec"] = kwargs.get("reference_radec", wcs.wcs.crval) - kwargs["reference_imageij"] = wcs.wcs.crpix - WPCS.__init__(self, *args, wcs=wcs, **kwargs) - sky_coord = wcs.pixel_to_world(*wcs.wcs.crpix) - kwargs["reference_imagexy"] = self.world_to_plane( - torch.tensor( - (sky_coord.ra.deg, sky_coord.dec.deg), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - ) - else: - WPCS.__init__(self, *args, **kwargs) - - PPCS.__init__(self, *args, wcs=wcs, **kwargs) - - def world_to_pixel(self, world_RA, world_DEC=None): - """A wrapper which applies :meth:`world_to_plane` then - :meth:`plane_to_pixel`, see those methods for further - information. - - """ - if world_DEC is None: - return torch.stack(self.world_to_pixel(*world_RA)) - return self.plane_to_pixel(*self.world_to_plane(world_RA, world_DEC)) - - def pixel_to_world(self, pixel_i, pixel_j=None): - """A wrapper which applies :meth:`pixel_to_plane` then - :meth:`plane_to_world`, see those methods for further - information. - - """ - if pixel_j is None: - return torch.stack(self.pixel_to_world(*pixel_i)) - return self.plane_to_world(*self.pixel_to_plane(pixel_i, pixel_j)) - - def copy(self, **kwargs): - copy_kwargs = { - "pixelscale": self.pixelscale, - "reference_imageij": self.reference_imageij, - "reference_imagexy": self.reference_imagexy, - "projection": self.projection, - "reference_radec": self.reference_radec, - "reference_planexy": self.reference_planexy, - } - copy_kwargs.update(kwargs) - return self.__class__( - **copy_kwargs, - ) - - def to(self, dtype=None, device=None): - WPCS.to(self, dtype, device) - PPCS.to(self, dtype, device) - - def get_state(self): - state = WPCS.get_state(self) - state.update(PPCS.get_state(self)) - return state - - def set_state(self, state): - WPCS.set_state(self, state) - PPCS.set_state(self, state) - - def get_fits_state(self): - """ - Similar to get_state, except specifically tailored to be stored in a FITS format. - """ - state = WPCS.get_fits_state(self) - state.update(PPCS.get_fits_state(self)) - return state - - def set_fits_state(self, state): - """ - Reads and applies the state from the get_fits_state function. - """ - WPCS.set_fits_state(self, state) - PPCS.set_fits_state(self, state) - - def __str__(self): - return f"WCS:\n{WPCS.__str__(self)}\n{PPCS.__str__(self)}" - - def __repr__(self): - return f"WCS:\n{WPCS.__repr__(self)}\n{PPCS.__repr__(self)}" diff --git a/astrophot/image/window.py b/astrophot/image/window.py new file mode 100644 index 00000000..397e3cde --- /dev/null +++ b/astrophot/image/window.py @@ -0,0 +1,166 @@ +from typing import Union, Tuple, List + +import numpy as np + +from ..errors import InvalidWindow + +__all__ = ("Window",) + + +class Window: + def __init__( + self, + window: Union[Tuple[int, int, int, int], Tuple[Tuple[int, int], Tuple[int, int]]], + image: "Image", + ): + self.extent = window + self.image = image + + @property + def identity(self): + return self.image.identity + + @property + def crpix(self): + return self.image.crpix + + @property + def shape(self): + return (self.i_high - self.i_low, self.j_high - self.j_low) + + @property + def extent(self): + return (self.i_low, self.i_high, self.j_low, self.j_high) + + @extent.setter + def extent( + self, value: Union[Tuple[int, int, int, int], Tuple[Tuple[int, int], Tuple[int, int]]] + ): + if len(value) == 4: + self.i_low, self.i_high, self.j_low, self.j_high = value + elif len(value) == 2: + self.i_low, self.j_low = value[0] + self.i_high, self.j_high = value[1] + else: + raise ValueError( + "Extent must be formatted as (i_low, i_high, j_low, j_high) or ((i_low, j_low), (i_high, j_high))" + ) + + def chunk(self, chunk_size: int) -> List["Window"]: + # number of pixels on each axis + px = self.i_high - self.i_low + py = self.j_high - self.j_low + # total number of chunks desired + chunk_tot = int(np.ceil((px * py) / chunk_size)) + # number of chunks on each axis + cx = int(np.ceil(np.sqrt(chunk_tot * px / py))) + cy = int(np.ceil(chunk_tot / cx)) + # number of pixels on each axis per chunk + stepx = int(np.ceil(px / cx)) + stepy = int(np.ceil(py / cy)) + # create the windows + windows = [] + for i in range(self.i_low, self.i_high, stepx): + for j in range(self.j_low, self.j_high, stepy): + i_high = min(i + stepx, self.i_high) + j_high = min(j + stepy, self.j_high) + windows.append(Window((i, i_high, j, j_high), self.image)) + return windows + + def pad(self, pad: int): + self.i_low -= pad + self.i_high += pad + self.j_low -= pad + self.j_high += pad + + def copy(self): + return Window((self.i_low, self.i_high, self.j_low, self.j_high), self.image) + + def __or__(self, other: "Window"): + if not isinstance(other, Window): + raise TypeError(f"Cannot combine Window with {type(other)}") + if self.image != other.image: + raise InvalidWindow( + f"Cannot combine Windows from different images: {self.image.identity} and {other.image.identity}" + ) + new_i_low = min(self.i_low, other.i_low) + new_i_high = max(self.i_high, other.i_high) + new_j_low = min(self.j_low, other.j_low) + new_j_high = max(self.j_high, other.j_high) + return Window((new_i_low, new_i_high, new_j_low, new_j_high), self.image) + + def __ior__(self, other: "Window"): + if not isinstance(other, Window): + raise TypeError(f"Cannot combine Window with {type(other)}") + if self.image != other.image: + raise InvalidWindow( + f"Cannot combine Windows from different images: {self.image.identity} and {other.image.identity}" + ) + self.i_low = min(self.i_low, other.i_low) + self.i_high = max(self.i_high, other.i_high) + self.j_low = min(self.j_low, other.j_low) + self.j_high = max(self.j_high, other.j_high) + return self + + def __and__(self, other: "Window"): + if not isinstance(other, Window): + raise TypeError(f"Cannot intersect Window with {type(other)}") + if self.image.identity != other.image.identity: + raise InvalidWindow( + f"Cannot combine Windows from different images: {self.image.identity} and {other.image.identity}" + ) + if ( + self.i_high <= other.i_low + or self.i_low >= other.i_high + or self.j_high <= other.j_low + or self.j_low >= other.j_high + ): + return Window((0, 0, 0, 0), self.image) + # fixme handle crpix + new_i_low = max(self.i_low, other.i_low) + new_i_high = min(self.i_high, other.i_high) + new_j_low = max(self.j_low, other.j_low) + new_j_high = min(self.j_high, other.j_high) + return Window((new_i_low, new_i_high, new_j_low, new_j_high), self.image) + + def __str__(self): + return f"Window({self.i_low}, {self.i_high}, {self.j_low}, {self.j_high})" + + +class WindowList: + def __init__(self, windows: list[Window]): + if not all(isinstance(window, Window) for window in windows): + raise InvalidWindow( + f"Window_List can only hold Window objects, not {tuple(type(window) for window in windows)}" + ) + self.windows = windows + + def index(self, other: Window) -> int: + for i, window in enumerate(self.windows): + if other.identity == window.identity: + return i + else: + raise IndexError("Could not find identity match between window list and input window") + + def __and__(self, other: "WindowList"): + if not isinstance(other, WindowList): + raise TypeError(f"Cannot intersect WindowList with {type(other)}") + if len(self.windows) == 0 or len(other.windows) == 0: + return WindowList([]) + new_windows = [] + for other_window in other.windows: + try: + i = self.index(other_window) + except IndexError: + continue # skip if the window is not in self.windows + new_windows.append(self.windows[i] & other_window) + return WindowList(new_windows) + + def __getitem__(self, index): + return self.windows[index] + + def __len__(self): + return len(self.windows) + + def __iter__(self): + return iter(self.windows) diff --git a/astrophot/image/window_object.py b/astrophot/image/window_object.py deleted file mode 100644 index d237d016..00000000 --- a/astrophot/image/window_object.py +++ /dev/null @@ -1,668 +0,0 @@ -import torch -from astropy.wcs import WCS as AstropyWCS - -from .. import AP_config -from .wcs import WCS -from ..errors import ConflicingWCS, SpecificationConflict - -__all__ = ["Window", "Window_List"] - - -class Window(WCS): - """class to define a window on the sky in coordinate space. These - windows can undergo arithmetic and preserve logical behavior. Image - objects can also be indexed using windows and will return an - appropriate subsection of their data. - - There are several ways to tell a Window object where to - place itself. The simplest method is to pass an - Astropy WCS object such as:: - - H = ap.image.Window(wcs = wcs) - - this will automatically place your image at the correct RA, DEC - and assign the correct pixel scale, etc. WARNING, it will default to - setting the reference RA DEC at the reference RA DEC of the wcs - object; if you have multiple images you should force them all to - have the same reference world coordinate by passing - ``reference_radec = (ra, dec)``. See the :doc:`coordinates` - documentation for more details. There are several other ways to - initialize a window. If you provide ``origin_radec`` then - it will place the image origin at the requested RA DEC - coordinates. If you provide ``center_radec`` then it will place - the image center at the requested RA DEC coordinates. Note that in - these cases the fixed point between the pixel grid and image plane - is different (pixel origin and center respectively); so if you - have rotated pixels in your pixel scale matrix then everything - will be rotated about different points (pixel origin and center - respectively). If you provide ``origin`` or ``center`` then those - are coordinates in the tangent plane (arcsec) and they will - correspondingly become fixed points. For arbitrary control over - the pixel positioning, use ``reference_imageij`` and - ``reference_imagexy`` to fix the pixel and tangent plane - coordinates respectively to each other, any rotation or shear will - happen about that fixed point. - - Args: - origin : Sequence or None, optional - The origin of the image in the tangent plane coordinate system - (arcsec), as a 1D array of length 2. Default is None. - origin_radec : Sequence or None, optional - The origin of the image in the world coordinate system (RA, - DEC in degrees), as a 1D array of length 2. Default is None. - center : Sequence or None, optional - The center of the image in the tangent plane coordinate system - (arcsec), as a 1D array of length 2. Default is None. - center_radec : Sequence or None, optional - The center of the image in the world coordinate system (RA, - DEC in degrees), as a 1D array of length 2. Default is None. - wcs: An astropy.wcs.WCS object which gives information about the - origin and orientation of the window. - reference_radec: world coordinates on the celestial sphere (RA, - DEC in degrees) where the tangent plane makes contact. This should - be the same for every image in multi-image analysis. - reference_planexy: tangent plane coordinates (arcsec) where it - makes contact with the celesial sphere. This should typically be - (0,0) though that is not stricktly enforced (it is assumed if not - given). This reference coordinate should be the same for all - images in multi-image analysis. - reference_imageij: pixel coordinates about which the image is - defined. For example in an Astropy WCS object the wcs.wcs.crpix - array gives the pixel coordinate reference point for which the - world coordinate mapping (wcs.wcs.crval) is defined. One may think - of the referenced pixel location as being "pinned" to the tangent - plane. This may be different for each image in multi-image - analysis.. - reference_imagexy: tangent plane coordinates (arcsec) about - which the image is defined. This is the pivot point about which the - pixelscale matrix operates, therefore if the pixelscale matrix - defines a rotation then this is the coordinate about which the - rotation will be performed. This may be different for each image in - multi-image analysis. - - """ - - def __init__( - self, - *, - pixel_shape=None, - origin=None, - origin_radec=None, - center=None, - center_radec=None, - state=None, - fits_state=None, - wcs=None, - **kwargs, - ): - # If loading from a previous state, simply update values and end init - if state is not None: - self.set_state(state) - return - if fits_state is not None: - self.set_fits_state(fits_state) - return - - # Collect the shape of the window - if pixel_shape is not None: - self.pixel_shape = pixel_shape - else: - self.pixel_shape = wcs.pixel_shape - - # Determine relative positioning of tangent plane and pixel grid. Also world coordinates and tangent plane - if not sum(C is not None for C in [wcs, origin_radec, center_radec, origin, center]) <= 1: - raise SpecificationConflict( - "Please provide only one reference position for the window, otherwise the placement is ambiguous" - ) - - # Image coordinates provided by WCS - if wcs is not None: - super().__init__(wcs=wcs, **kwargs) - # Image reference position from RA and DEC of image origin - elif origin_radec is not None: - # Origin given, it is reference point - origin_radec = torch.as_tensor( - origin_radec, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - kwargs["reference_radec"] = kwargs.get("reference_radec", origin_radec) - super().__init__(**kwargs) - self.reference_imageij = (-0.5, -0.5) - self.reference_imagexy = self.world_to_plane(origin_radec) - # Image reference position from RA and DEC of image center - elif center_radec is not None: - pix_center = self.pixel_shape.to(dtype=AP_config.ap_dtype) / 2 - 0.5 - center_radec = torch.as_tensor( - center_radec, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - kwargs["reference_radec"] = kwargs.get("reference_radec", center_radec) - super().__init__(**kwargs) - center = self.world_to_plane(center_radec) - self.reference_imageij = pix_center - self.reference_imagexy = center - # Image reference position from tangent plane position of image origin - elif origin is not None: - kwargs.update( - { - "reference_imageij": (-0.5, -0.5), - "reference_imagexy": origin, - } - ) - super().__init__(**kwargs) - # Image reference position from tangent plane position of image center - elif center is not None: - pix_center = self.pixel_shape.to(dtype=AP_config.ap_dtype) / 2 - 0.5 - kwargs.update( - { - "reference_imageij": pix_center, - "reference_imagexy": center, - } - ) - super().__init__(**kwargs) - # Image origin assumed to be at tangent plane origin - else: - super().__init__(**kwargs) - - @property - def shape(self): - dtype, device = self.pixelscale.dtype, self.pixelscale.device - S1 = self.pixel_shape.to(dtype=dtype, device=device) - S1[1] = 0.0 - S2 = self.pixel_shape.to(dtype=dtype, device=device) - S2[0] = 0.0 - return torch.stack( - ( - torch.linalg.norm(self.pixelscale @ S1), - torch.linalg.norm(self.pixelscale @ S2), - ) - ) - - @shape.setter - def shape(self, shape): - if shape is None: - self._pixel_shape = None - return - shape = torch.as_tensor(shape, dtype=self.pixelscale.dtype, device=self.pixelscale.device) - self.pixel_shape = shape / torch.sqrt(torch.sum(self.pixelscale**2, dim=0)) - - @property - def pixel_shape(self): - return self._pixel_shape - - @pixel_shape.setter - def pixel_shape(self, shape): - if shape is None: - self._pixel_shape = None - return - self._pixel_shape = torch.as_tensor(shape, device=AP_config.ap_device) - self._pixel_shape = torch.round(self.pixel_shape).to( - dtype=torch.int32, device=AP_config.ap_device - ) - - @property - def size(self): - """The number of pixels in the window""" - return torch.prod(self.pixel_shape) - - @property - def end(self): - return self.pixel_to_plane_delta( - self.pixel_shape.to(dtype=self.pixelscale.dtype, device=self.pixelscale.device) - ) - - @property - def origin(self): - return self.pixel_to_plane(-0.5 * torch.ones_like(self.reference_imageij)) - - @property - def center(self): - return self.origin + self.end / 2 - - def copy(self, **kwargs): - copy_kwargs = {"pixel_shape": torch.clone(self.pixel_shape)} - copy_kwargs.update(kwargs) - return super().copy(**copy_kwargs) - - def to(self, dtype=None, device=None): - if dtype is None: - dtype = AP_config.ap_dtype - if device is None: - device = AP_config.ap_device - super().to(dtype=dtype, device=device) - self.pixel_shape = self.pixel_shape.to(dtype=dtype, device=device) - - def rescale_pixel(self, scale, **kwargs): - return self.copy( - pixelscale=self.pixelscale * scale, - pixel_shape=self.pixel_shape // scale, - reference_imageij=(self.reference_imageij + 0.5) / scale - 0.5, - **kwargs, - ) - - @staticmethod - @torch.no_grad() - def _get_indices(ref_window, obj_window): - other_origin_pix = torch.round(ref_window.plane_to_pixel(obj_window.origin) + 0.5).int() - new_origin_pix = torch.maximum(torch.zeros_like(other_origin_pix), other_origin_pix) - - other_pixel_end = torch.round( - ref_window.plane_to_pixel(obj_window.origin + obj_window.end) + 0.5 - ).int() - new_pixel_end = torch.minimum(ref_window.pixel_shape, other_pixel_end) - return slice(new_origin_pix[1], new_pixel_end[1]), slice( - new_origin_pix[0], new_pixel_end[0] - ) - - def get_self_indices(self, obj): - """ - Return an index slicing tuple for obj corresponding to this window - """ - if isinstance(obj, Window): - return self._get_indices(self, obj) - return self._get_indices(self, obj.window) - - def get_other_indices(self, obj): - """ - Return an index slicing tuple for obj corresponding to this window - """ - if isinstance(obj, Window): - return self._get_indices(obj, self) - return self._get_indices(obj.window, self) - - def overlap_frac(self, other): - overlap = self & other - overlap_area = torch.prod(overlap.shape) - full_area = torch.prod(self.shape) + torch.prod(other.shape) - overlap_area - return overlap_area / full_area - - def shift(self, shift): - """ - Shift the location of the window by a specified amount in tangent plane coordinates - """ - self.reference_imagexy = self.reference_imagexy + shift - return self - - def pixel_shift(self, shift): - """ - Shift the location of the window by a specified amount in pixel grid coordinates - """ - - self.reference_imageij = self.reference_imageij - shift - return self - - def get_astropywcs(self, **kwargs): - wargs = { - "NAXIS": 2, - "NAXIS1": self.pixel_shape[0].item(), - "NAXIS2": self.pixel_shape[1].item(), - "CTYPE1": "RA---TAN", - "CTYPE2": "DEC--TAN", - "CRVAL1": self.pixel_to_world(self.reference_imageij)[0].item(), - "CRVAL2": self.pixel_to_world(self.reference_imageij)[1].item(), - "CRPIX1": self.reference_imageij[0].item(), - "CRPIX2": self.reference_imageij[1].item(), - "CD1_1": self.pixelscale[0][0].item(), - "CD1_2": self.pixelscale[0][1].item(), - "CD2_1": self.pixelscale[1][0].item(), - "CD2_2": self.pixelscale[1][1].item(), - } - wargs.update(kwargs) - return AstropyWCS(wargs) - - def get_state(self): - state = super().get_state() - state["pixel_shape"] = self.pixel_shape.detach().cpu().tolist() - return state - - def set_state(self, state): - super().set_state(state) - self.pixel_shape = torch.tensor( - state["pixel_shape"], dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - - def get_fits_state(self): - state = super().get_fits_state() - state["PXL_SHPE"] = str(self.pixel_shape.detach().cpu().tolist()) - return state - - def set_fits_state(self, state): - super().set_fits_state(state) - self.pixel_shape = torch.tensor( - eval(state["PXL_SHPE"]), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - def crop_pixel(self, pixels): - """ - [crop all sides] or - [crop x, crop y] or - [crop x low, crop y low, crop x high, crop y high] - """ - if len(pixels) == 1: - self.pixel_shape = self.pixel_shape - 2 * pixels[0] - self.reference_imageij = self.reference_imageij - pixels[0] - elif len(pixels) == 2: - pix_shift = torch.as_tensor( - pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - self.pixel_shape = self.pixel_shape - 2 * pix_shift - self.reference_imageij = self.reference_imageij - pix_shift - elif len(pixels) == 4: # different crop on all sides - pixels = torch.as_tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - self.pixel_shape = self.pixel_shape - pixels[:2] - pixels[2:] - self.reference_imageij = self.reference_imageij - pixels[:2] - else: - raise ValueError(f"Unrecognized pixel crop format: {pixels}") - return self - - def crop_to_pixel(self, pixels): - """ - format: [[xmin, xmax],[ymin,ymax]] - """ - pixels = torch.tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - self.reference_imageij = self.reference_imageij - pixels[:, 0] - self.pixel_shape = pixels[:, 1] - pixels[:, 0] - return self - - def pad_pixel(self, pixels): - """ - [pad all sides] or - [pad x, pad y] or - [pad x low, pad y low, pad x high, pad y high] - """ - if len(pixels) == 1: - self.pixel_shape = self.pixel_shape + 2 * pixels[0] - self.reference_imageij = self.reference_imageij + pixels[0] - elif len(pixels) == 2: - pix_shift = torch.as_tensor( - pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - self.pixel_shape = self.pixel_shape + 2 * pix_shift - self.reference_imageij = self.reference_imageij + pix_shift - elif len(pixels) == 4: # different crop on all sides - pixels = torch.as_tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - self.pixel_shape = self.pixel_shape + pixels[:2] + pixels[2:] - self.reference_imageij = self.reference_imageij + pixels[:2] - else: - raise ValueError(f"Unrecognized pixel crop format: {pixels}") - return self - - @torch.no_grad() - def get_coordinate_meshgrid(self): - """Returns a meshgrid with tangent plane coordinates for the center - of every pixel. - - """ - pix = self.pixel_shape.to(dtype=AP_config.ap_dtype) - xsteps = torch.arange(pix[0], dtype=AP_config.ap_dtype, device=AP_config.ap_device) - ysteps = torch.arange(pix[1], dtype=AP_config.ap_dtype, device=AP_config.ap_device) - meshx, meshy = torch.meshgrid( - xsteps, - ysteps, - indexing="xy", - ) - Coords = self.pixel_to_plane(meshx, meshy) - return torch.stack(Coords) - - @torch.no_grad() - def get_coordinate_corner_meshgrid(self): - """Returns a meshgrid with tangent plane coordinates for the corners - of every pixel. - - """ - pix = self.pixel_shape.to(dtype=AP_config.ap_dtype) - xsteps = ( - torch.arange(pix[0] + 1, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - 0.5 - ) - ysteps = ( - torch.arange(pix[1] + 1, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - 0.5 - ) - meshx, meshy = torch.meshgrid( - xsteps, - ysteps, - indexing="xy", - ) - Coords = self.pixel_to_plane(meshx, meshy) - return torch.stack(Coords) - - @torch.no_grad() - def get_coordinate_simps_meshgrid(self): - """Returns a meshgrid with tangent plane coordinates for performing - simpsons method pixel integration (all corners, centers, and - middle of each edge). This is approximately 4 times more - points than the standard :meth:`get_coordinate_meshgrid`. - - """ - pix = self.pixel_shape.to(dtype=AP_config.ap_dtype) - xsteps = ( - 0.5 - * torch.arange( - 2 * (pix[0]) + 1, - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - 0.5 - ) - ysteps = ( - 0.5 - * torch.arange( - 2 * (pix[1]) + 1, - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - 0.5 - ) - meshx, meshy = torch.meshgrid( - xsteps, - ysteps, - indexing="xy", - ) - Coords = self.pixel_to_plane(meshx, meshy) - return torch.stack(Coords) - - # Window Comparison operators - @torch.no_grad() - def __eq__(self, other): - return ( - torch.all(self.pixel_shape == other.pixel_shape) - and torch.all(self.pixelscale == other.pixelscale) - and (self.projection == other.projection) - and ( - torch.all( - self.pixel_to_plane(torch.zeros_like(self.reference_imageij)) - == other.pixel_to_plane(torch.zeros_like(other.reference_imageij)) - ) - ) - ) # fixme more checks? - - @torch.no_grad() - def __ne__(self, other): - return not self == other - - # Window interaction operators - @torch.no_grad() - def __or__(self, other): - other_origin_pix = self.plane_to_pixel(other.origin) - new_origin_pix = torch.minimum(-0.5 * torch.ones_like(other_origin_pix), other_origin_pix) - - other_pixel_end = self.plane_to_pixel(other.origin + other.end) - new_pixel_end = torch.maximum( - self.pixel_shape.to(dtype=AP_config.ap_dtype), other_pixel_end - ) - return self.copy( - origin=self.pixel_to_plane(new_origin_pix), - pixel_shape=new_pixel_end - new_origin_pix, - ) - - @torch.no_grad() - def __ior__(self, other): - other_origin_pix = self.plane_to_pixel(other.origin) - new_origin_pix = torch.minimum(-0.5 * torch.ones_like(other_origin_pix), other_origin_pix) - - other_pixel_end = self.plane_to_pixel(other.origin + other.end) - new_pixel_end = torch.maximum( - self.pixel_shape.to(dtype=AP_config.ap_dtype), other_pixel_end - ) - - self.reference_imageij = self.reference_imageij - (new_origin_pix + 0.5) - self.pixel_shape = new_pixel_end - new_origin_pix - return self - - @torch.no_grad() - def __and__(self, other): - other_origin_pix = self.plane_to_pixel(other.origin) - new_origin_pix = torch.maximum(-0.5 * torch.ones_like(other_origin_pix), other_origin_pix) - - other_pixel_end = self.plane_to_pixel(other.origin + other.end) - new_pixel_end = torch.minimum( - self.pixel_shape.to(dtype=AP_config.ap_dtype) - 0.5, other_pixel_end - ) - return self.copy( - origin=self.pixel_to_plane(new_origin_pix), - pixel_shape=new_pixel_end - new_origin_pix, - ) - - @torch.no_grad() - def __iand__(self, other): - other_origin_pix = self.plane_to_pixel(other.origin) - new_origin_pix = torch.maximum(-0.5 * torch.ones_like(other_origin_pix), other_origin_pix) - - other_pixel_end = self.plane_to_pixel(other.origin + other.end) - new_pixel_end = torch.minimum( - self.pixel_shape.to(dtype=AP_config.ap_dtype), other_pixel_end - ) - - self.reference_imageij = self.reference_imageij - (new_origin_pix + 0.5) - self.pixel_shape = new_pixel_end - new_origin_pix - return self - - def __str__(self): - return f"window origin: {self.origin.detach().cpu().tolist()}, shape: {self.shape.detach().cpu().tolist()}, center: {self.center.detach().cpu().tolist()}, pixelscale: {self.pixelscale.detach().cpu().tolist()}" - - def __repr__(self): - return ( - f"window pixel_shape: {self.pixel_shape.detach().cpu().tolist()}, shape: {self.shape.detach().cpu().tolist()}\n" - + super().__repr__() - ) - - -class Window_List(Window): - def __init__(self, window_list=None, state=None): - if state is not None: - self.set_state(state) - else: - if window_list is None: - window_list = [] - self.window_list = list(window_list) - - self.check_wcs() - - def check_wcs(self): - """Ensure the WCS systems being used by all the windows in this list - are consistent with each other. They should all project world - coordinates onto the same tangent plane. - - """ - windows = tuple( - W.reference_radec for W in filter(lambda w: w is not None, self.window_list) - ) - if len(windows) == 0: - return - ref = torch.stack(windows) - if not torch.allclose(ref, ref[0]): - raise ConflicingWCS( - "Reference (world) coordinate mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means." - ) - - ref = torch.stack( - tuple(W.reference_planexy for W in filter(lambda w: w is not None, self.window_list)) - ) - if not torch.allclose(ref, ref[0]): - raise ConflicingWCS( - "Reference (tangent plane) coordinate mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means." - ) - - if len(set(W.projection for W in filter(lambda w: w is not None, self.window_list))) > 1: - raise ConflicingWCS( - "Projection mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means." - ) - - @property - @torch.no_grad() - def origin(self): - return tuple(w.origin for w in self) - - @property - @torch.no_grad() - def shape(self): - return tuple(w.shape for w in self) - - @property - @torch.no_grad() - def center(self): - return tuple(w.center for w in self) - - def shift_origin(self, shift): - raise NotImplementedError("shift origin not implemented for window list") - - def copy(self): - return self.__class__(list(w.copy() for w in self)) - - def to(self, dtype=None, device=None): - if dtype is None: - dtype = AP_config.ap_dtype - if device is None: - device = AP_config.ap_device - for window in self: - window.to(dtype, device) - - def get_state(self): - return list(window.get_state() for window in self) - - def set_state(self, state): - self.window_list = list(Window(state=st) for st in state) - - # Window interaction operators - @torch.no_grad() - def __or__(self, other): - new_windows = list((sw | ow) for sw, ow in zip(self, other)) - return self.__class__(window_list=new_windows) - - @torch.no_grad() - def __ior__(self, other): - for sw, ow in zip(self, other): - sw |= ow - return self - - @torch.no_grad() - def __and__(self, other): - new_windows = list((sw & ow) for sw, ow in zip(self, other)) - return self.__class__(window_list=new_windows) - - @torch.no_grad() - def __iand__(self, other): - for sw, ow in zip(self, other): - sw &= ow - return self - - # Window Comparison operators - @torch.no_grad() - def __eq__(self, other): - results = list((sw == ow).view(-1) for sw, ow in zip(self, other)) - return torch.all(torch.cat(results)) - - @torch.no_grad() - def __ne__(self, other): - return not self == other - - def __len__(self): - return len(self.window_list) - - def __iter__(self): - return (win for win in self.window_list) - - def __str__(self): - return "Window List: \n" + ("\n".join(list(str(window) for window in self)) + "\n") - - def __repr__(self): - return "Window List: \n" + ("\n".join(list(repr(window) for window in self)) + "\n") diff --git a/astrophot/models/__init__.py b/astrophot/models/__init__.py index 81edb2c8..6858ddca 100644 --- a/astrophot/models/__init__.py +++ b/astrophot/models/__init__.py @@ -1,28 +1,238 @@ -from .core_model import * -from .model_object import * -from .galaxy_model_object import * -from .ray_model import * -from .sersic_model import * -from .group_model_object import * -from .sky_model_object import * -from .flatsky_model import * -from .planesky_model import * -from .gaussian_model import * -from .multi_gaussian_expansion_model import * -from .spline_model import * -from .relspline_model import * -from .psf_model_object import * -from .pixelated_psf_model import * -from .eigen_psf import * -from .superellipse_model import * -from .edgeon_model import * -from .exponential_model import * -from .foureirellipse_model import * -from .wedge_model import * -from .warp_model import * -from .moffat_model import * -from .nuker_model import * -from .zernike_model import * -from .airy_psf import * -from .point_source import * -from .group_psf_model import * +# Base model object +from .base import Model + +# Primary model types +from .model_object import ComponentModel +from .psf_model_object import PSFModel +from .group_model_object import GroupModel +from .group_psf_model import PSFGroupModel + +# Component model main types +from .galaxy_model_object import GalaxyModel +from .sky_model_object import SkyModel +from .point_source import PointSource + +# subtypes of PSFModel +from .basis import PixelBasisPSF +from .airy import AiryPSF +from .pixelated_psf import PixelatedPSF + +# Subtypes of SkyModel +from .flatsky import FlatSky +from .planesky import PlaneSky +from .bilinear_sky import BilinearSky + +# Special galaxy types +from .edgeon import EdgeonModel, EdgeonSech, EdgeonIsothermal +from .multi_gaussian_expansion import MultiGaussianExpansion +from .gaussian_ellipsoid import GaussianEllipsoid + +# Standard models based on a core radial profile +from .sersic import ( + SersicGalaxy, + SersicPSF, + SersicFourierEllipse, + SersicSuperEllipse, + SersicWarp, + SersicRay, + SersicWedge, +) +from .exponential import ( + ExponentialGalaxy, + ExponentialPSF, + ExponentialSuperEllipse, + ExponentialFourierEllipse, + ExponentialWarp, + ExponentialRay, + ExponentialWedge, +) +from .gaussian import ( + GaussianGalaxy, + GaussianPSF, + GaussianSuperEllipse, + GaussianFourierEllipse, + GaussianWarp, + GaussianRay, + GaussianWedge, +) +from .moffat import ( + MoffatGalaxy, + MoffatPSF, + Moffat2DPSF, + MoffatFourierEllipse, + MoffatRay, + MoffatWedge, + MoffatWarp, + MoffatSuperEllipse, +) +from .ferrer import ( + FerrerGalaxy, + FerrerPSF, + FerrerSuperEllipse, + FerrerFourierEllipse, + FerrerWarp, + FerrerRay, + FerrerWedge, +) +from .king import ( + KingGalaxy, + KingPSF, + KingSuperEllipse, + KingFourierEllipse, + KingWarp, + KingRay, + KingWedge, +) +from .nuker import ( + NukerGalaxy, + NukerPSF, + NukerFourierEllipse, + NukerSuperEllipse, + NukerWarp, + NukerRay, + NukerWedge, +) +from .spline import ( + SplineGalaxy, + SplinePSF, + SplineFourierEllipse, + SplineSuperEllipse, + SplineWarp, + SplineRay, + SplineWedge, +) + +from .mixins import ( + RadialMixin, + WedgeMixin, + RayMixin, + ExponentialMixin, + iExponentialMixin, + FerrerMixin, + iFerrerMixin, + GaussianMixin, + iGaussianMixin, + KingMixin, + iKingMixin, + MoffatMixin, + iMoffatMixin, + NukerMixin, + iNukerMixin, + SersicMixin, + iSersicMixin, + SplineMixin, + iSplineMixin, + SampleMixin, + InclinedMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, + TruncationMixin, +) +from . import func + + +__all__ = ( + "Model", + "ComponentModel", + "PSFModel", + "GroupModel", + "PSFGroupModel", + "GalaxyModel", + "SkyModel", + "PointSource", + "PixelBasisPSF", + "AiryPSF", + "PixelatedPSF", + "FlatSky", + "PlaneSky", + "BilinearSky", + "EdgeonModel", + "EdgeonSech", + "EdgeonIsothermal", + "MultiGaussianExpansion", + "GaussianEllipsoid", + "SersicGalaxy", + "SersicPSF", + "SersicFourierEllipse", + "SersicSuperEllipse", + "SersicWarp", + "SersicRay", + "SersicWedge", + "ExponentialGalaxy", + "ExponentialPSF", + "ExponentialSuperEllipse", + "ExponentialFourierEllipse", + "ExponentialWarp", + "ExponentialRay", + "ExponentialWedge", + "GaussianGalaxy", + "GaussianPSF", + "GaussianSuperEllipse", + "GaussianFourierEllipse", + "GaussianWarp", + "GaussianRay", + "GaussianWedge", + "MoffatGalaxy", + "MoffatPSF", + "Moffat2DPSF", + "MoffatFourierEllipse", + "MoffatRay", + "MoffatWedge", + "MoffatWarp", + "MoffatSuperEllipse", + "FerrerGalaxy", + "FerrerPSF", + "FerrerSuperEllipse", + "FerrerFourierEllipse", + "FerrerWarp", + "FerrerRay", + "FerrerWedge", + "KingGalaxy", + "KingPSF", + "KingSuperEllipse", + "KingFourierEllipse", + "KingWarp", + "KingRay", + "KingWedge", + "NukerGalaxy", + "NukerPSF", + "NukerFourierEllipse", + "NukerSuperEllipse", + "NukerWarp", + "NukerRay", + "NukerWedge", + "SplineGalaxy", + "SplinePSF", + "SplineFourierEllipse", + "SplineWarp", + "SplineSuperEllipse", + "SplineRay", + "SplineWedge", + "RadialMixin", + "WedgeMixin", + "RayMixin", + "ExponentialMixin", + "iExponentialMixin", + "FerrerMixin", + "iFerrerMixin", + "GaussianMixin", + "iGaussianMixin", + "KingMixin", + "iKingMixin", + "MoffatMixin", + "iMoffatMixin", + "NukerMixin", + "iNukerMixin", + "SersicMixin", + "iSersicMixin", + "SplineMixin", + "iSplineMixin", + "SampleMixin", + "InclinedMixin", + "SuperEllipseMixin", + "FourierEllipseMixin", + "WarpMixin", + "TruncationMixin", + "func", +) diff --git a/astrophot/models/_model_methods.py b/astrophot/models/_model_methods.py deleted file mode 100644 index d934c490..00000000 --- a/astrophot/models/_model_methods.py +++ /dev/null @@ -1,482 +0,0 @@ -from typing import Optional, Union -import io -from copy import deepcopy - -import numpy as np -import torch -from torch.autograd.functional import jacobian as torchjac - -from ..param import Parameter_Node, Param_Mask -from ..utils.decorators import default_internal -from ..utils.interpolate import ( - _shift_Lanczos_kernel_torch, - simpsons_kernel, - curvature_kernel, - interp2d, -) -from ..image import ( - Window, - Jacobian_Image, - Window_List, - PSF_Image, -) -from ..utils.operations import ( - fft_convolve_torch, - grid_integrate, - single_quad_integrate, -) -from ..errors import SpecificationConflict -from .core_model import AstroPhot_Model -from .. import AP_config - - -@default_internal -def angular_metric(self, X, Y, image=None, parameters=None): - return torch.atan2(Y, X) - - -@default_internal -def radius_metric(self, X, Y, image=None, parameters=None): - return torch.sqrt(X**2 + Y**2 + self.softening**2) - - -@classmethod -def build_parameter_specs(cls, user_specs=None): - parameter_specs = {} - for base in cls.__bases__: - try: - parameter_specs.update(base.build_parameter_specs()) - except AttributeError: - pass - parameter_specs.update(cls.parameter_specs) - parameter_specs = deepcopy(parameter_specs) - if isinstance(user_specs, dict): - for p in user_specs: - # If the user supplied a parameter object subclass, simply use that as is - if isinstance(user_specs[p], Parameter_Node): - parameter_specs[p] = user_specs[p] - elif isinstance( - user_specs[p], dict - ): # if the user supplied parameter specifications, update the defaults - parameter_specs[p].update(user_specs[p]) - else: - parameter_specs[p]["value"] = user_specs[p] - - return parameter_specs - - -def build_parameters(self): - for p in self.__class__._parameter_order: - # skip if the parameter already exists - if p in self.parameters: - continue - # If a parameter object is provided, simply use as-is - if isinstance(self.parameter_specs[p], Parameter_Node): - self.parameters.link(self.parameter_specs[p].to()) - elif isinstance(self.parameter_specs[p], dict): - self.parameters.link(Parameter_Node(p, **self.parameter_specs[p])) - else: - raise ValueError(f"unrecognized parameter specification for {p}") - - -def _sample_init(self, image, parameters, center): - if self.sampling_mode == "midpoint": - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - center[..., None, None] - mid = self.evaluate_model(X=X, Y=Y, image=image, parameters=parameters) - kernel = curvature_kernel(AP_config.ap_dtype, AP_config.ap_device) - # convolve curvature kernel to numericall compute second derivative - curvature = torch.nn.functional.pad( - torch.nn.functional.conv2d( - mid.view(1, 1, *mid.shape), - kernel.view(1, 1, *kernel.shape), - padding="valid", - ), - (1, 1, 1, 1), - mode="replicate", - ).squeeze() - return mid + curvature, mid - elif self.sampling_mode == "simpsons": - Coords = image.get_coordinate_simps_meshgrid() - X, Y = Coords - center[..., None, None] - dens = self.evaluate_model(X=X, Y=Y, image=image, parameters=parameters) - kernel = simpsons_kernel(dtype=AP_config.ap_dtype, device=AP_config.ap_device) - # midpoint is just every other sample in the simpsons grid - mid = dens[1::2, 1::2] - simps = torch.nn.functional.conv2d( - dens.view(1, 1, *dens.shape), kernel, stride=2, padding="valid" - ) - return mid.squeeze(), simps.squeeze() - elif "quad" in self.sampling_mode: - quad_level = int(self.sampling_mode[self.sampling_mode.find(":") + 1 :]) - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - center[..., None, None] - res, ref = single_quad_integrate( - X=X, - Y=Y, - image_header=image.header, - eval_brightness=self.evaluate_model, - eval_parameters=parameters, - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - quad_level=quad_level, - ) - return ref, res - elif self.sampling_mode == "trapezoid": - Coords = image.get_coordinate_corner_meshgrid() - X, Y = Coords - center[..., None, None] - dens = self.evaluate_model(X=X, Y=Y, image=image, parameters=parameters) - kernel = ( - torch.ones((1, 1, 2, 2), dtype=AP_config.ap_dtype, device=AP_config.ap_device) / 4.0 - ) - trapz = torch.nn.functional.conv2d(dens.view(1, 1, *dens.shape), kernel, padding="valid") - trapz = trapz.squeeze() - kernel = curvature_kernel(AP_config.ap_dtype, AP_config.ap_device) - curvature = torch.nn.functional.pad( - torch.nn.functional.conv2d( - trapz.view(1, 1, *trapz.shape), - kernel.view(1, 1, *kernel.shape), - padding="valid", - ), - (1, 1, 1, 1), - mode="replicate", - ).squeeze() - return trapz + curvature, trapz - - raise SpecificationConflict( - f"{self.name} has unknown sampling mode: {self.sampling_mode}. Should be one of: midpoint, simpsons, quad:level, trapezoid" - ) - - -def _integrate_reference(self, image_data, image_header, parameters): - return torch.sum(image_data) / image_data.numel() - - -def _sample_integrate(self, deep, reference, image, parameters, center): - if self.integrate_mode == "none": - pass - elif self.integrate_mode == "threshold": - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - center[..., None, None] - ref = self._integrate_reference( - deep, image.header, parameters - ) # fixme, error can be over 100% on initial sampling reference is invalid - error = torch.abs((deep - reference)) - select = error > (self.sampling_tolerance * ref) - intdeep = grid_integrate( - X=X[select], - Y=Y[select], - image_header=image.header, - eval_brightness=self.evaluate_model, - eval_parameters=parameters, - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - quad_level=self.integrate_quad_level, - gridding=self.integrate_gridding, - max_depth=self.integrate_max_depth, - reference=self.sampling_tolerance * ref, - ) - deep[select] = intdeep - else: - raise SpecificationConflict( - f"{self.name} has unknown integration mode: {self.integrate_mode}. Should be one of: none, threshold" - ) - return deep - - -def _shift_psf(self, psf, shift, shift_method="bilinear", keep_pad=True): - if shift_method == "bilinear": - psf_data = torch.nn.functional.pad(psf.data, (1, 1, 1, 1)) - X, Y = torch.meshgrid( - torch.arange( - psf_data.shape[1], - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - shift[0], - torch.arange( - psf_data.shape[0], - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - shift[1], - indexing="xy", - ) - shift_psf = interp2d(psf_data, X.clone(), Y.clone()) - if not keep_pad: - shift_psf = shift_psf[1:-1, 1:-1] - - elif "lanczos" in shift_method: - lanczos_order = int(shift_method[shift_method.find(":") + 1 :]) - psf_data = torch.nn.functional.pad( - psf.data, (lanczos_order, lanczos_order, lanczos_order, lanczos_order) - ) - LL = _shift_Lanczos_kernel_torch( - -shift[0], - -shift[1], - lanczos_order, - AP_config.ap_dtype, - AP_config.ap_device, - ) - shift_psf = torch.nn.functional.conv2d( - psf_data.view(1, 1, *psf_data.shape), - LL.view(1, 1, *LL.shape), - padding="same", - ).squeeze() - if not keep_pad: - shift_psf = shift_psf[lanczos_order:-lanczos_order, lanczos_order:-lanczos_order] - else: - raise SpecificationConflict(f"unrecognized subpixel shift method: {shift_method}") - return shift_psf - - -def _sample_convolve(self, image, shift, psf, shift_method="bilinear"): - """ - image: Image object with image.data pixel matrix - shift: the amount of shifting to do in pixel units - psf: a PSF_Image object - """ - if shift is not None: - shift_psf = self._shift_psf(psf, shift, shift_method) - else: - shift_psf = psf.data - shift_psf = shift_psf / torch.sum(shift_psf) - - if self.psf_convolve_mode == "fft": - image.data = fft_convolve_torch(image.data, shift_psf, img_prepadded=True) - elif self.psf_convolve_mode == "direct": - image.data = torch.nn.functional.conv2d( - image.data.view(1, 1, *image.data.shape), - torch.flip( - shift_psf.view(1, 1, *shift_psf.shape), - dims=(2, 3), - ), - padding="same", - ).squeeze() - else: - raise ValueError(f"unrecognized psf_convolve_mode: {self.psf_convolve_mode}") - - -@torch.no_grad() -def jacobian( - self, - parameters: Optional[torch.Tensor] = None, - as_representation: bool = False, - window: Optional[Window] = None, - pass_jacobian: Optional[Jacobian_Image] = None, - **kwargs, -): - """Compute the Jacobian matrix for this model. - - The Jacobian matrix represents the partial derivatives of the - model's output with respect to its input parameters. It is useful - in optimization and model fitting processes. This method - simplifies the process of computing the Jacobian matrix for - astronomical image models and is primarily used by the - Levenberg-Marquardt algorithm for model fitting tasks. - - Args: - parameters (Optional[torch.Tensor]): A 1D parameter tensor to override the - current model's parameters. - as_representation (bool): Indicates if the parameters argument is - provided as real values or representations - in the (-inf, inf) range. Default is False. - parameters_identity (Optional[tuple]): Specifies which parameters are to be - considered in the computation. - window (Optional[Window]): A window object specifying the region of interest - in the image. - **kwargs: Additional keyword arguments. - - Returns: - Jacobian_Image: A Jacobian_Image object containing the computed Jacobian matrix. - - """ - if window is None: - window = self.window - else: - if isinstance(window, Window_List): - window = window.window_list[pass_jacobian.index(self.target)] - window = self.window & window - - # skip jacobian calculation if no parameters match criteria - if torch.sum(self.parameters.vector_mask()) == 0 or window.overlap_frac(self.window) <= 0: - return self.target[window].jacobian_image() - - # Set the parameters if provided and check the size of the parameter list - if parameters is not None: - if as_representation: - self.parameters.vector_set_representation(parameters) - else: - self.parameters.vector_set_values(parameters) - if torch.sum(self.parameters.vector_mask()) > self.jacobian_chunksize: - return self._chunk_jacobian( - as_representation=as_representation, - window=window, - **kwargs, - ) - if torch.max(window.pixel_shape) > self.image_chunksize: - return self._chunk_image_jacobian( - as_representation=as_representation, - window=window, - **kwargs, - ) - - # Compute the jacobian - full_jac = torchjac( - lambda P: self( - image=None, - parameters=P, - as_representation=as_representation, - window=window, - ).data, - ( - self.parameters.vector_representation().detach() - if as_representation - else self.parameters.vector_values().detach() - ), - strategy="forward-mode", - vectorize=True, - create_graph=False, - ) - - # Store the jacobian as a Jacobian_Image object - jac_img = self.target[window].jacobian_image( - parameters=self.parameters.vector_identities(), - data=full_jac, - ) - return jac_img - - -@torch.no_grad() -def _chunk_image_jacobian( - self, - as_representation: bool = False, - parameters_identity: Optional[tuple] = None, - window: Optional[Window] = None, - **kwargs, -): - """Evaluates the Jacobian in smaller chunks to reduce memory usage. - - For models acting on large windows it can be prohibitive to build - the full Jacobian in a single pass. Instead this function breaks - the image into chunks as determined by `self.image_chunksize` - evaluates the Jacobian only for the sub-images, it then builds up - the full Jacobian as a separate tensor. - - This is for internal use and should be called by the - `self.jacobian` function when appropriate. - - """ - - pids = self.parameters.vector_identities() - jac_img = self.target[window].jacobian_image( - parameters=pids, - ) - - pixel_shape = window.pixel_shape.detach().cpu().numpy() - Ncells = np.int64(np.round(np.ceil(pixel_shape / self.image_chunksize))) - cellsize = np.int64(np.round(window.pixel_shape / Ncells)) - - for nx in range(Ncells[0]): - for ny in range(Ncells[1]): - subwindow = window.copy() - subwindow.crop_to_pixel( - ( - (cellsize[0] * nx, min(pixel_shape[0], cellsize[0] * (nx + 1))), - (cellsize[1] * ny, min(pixel_shape[1], cellsize[1] * (ny + 1))), - ) - ) - jac_img += self.jacobian( - parameters=None, - as_representation=as_representation, - window=subwindow, - **kwargs, - ) - - return jac_img - - -@torch.no_grad() -def _chunk_jacobian( - self, - as_representation: bool = False, - parameters_identity: Optional[tuple] = None, - window: Optional[Window] = None, - **kwargs, -): - """Evaluates the Jacobian in small chunks to reduce memory usage. - - For models with many parameters it can be prohibitive to build the - full Jacobian in a single pass. Instead this function breaks the - list of parameters into chunks as determined by - `self.jacobian_chunksize` evaluates the Jacobian only for those, - it then builds up the full Jacobian as a separate tensor. This is - for internal use and should be called by the `self.jacobian` - function when appropriate. - - """ - pids = self.parameters.vector_identities() - jac_img = self.target[window].jacobian_image( - parameters=pids, - ) - - for ichunk in range(0, len(pids), self.jacobian_chunksize): - mask = torch.zeros(len(pids), dtype=torch.bool, device=AP_config.ap_device) - mask[ichunk : ichunk + self.jacobian_chunksize] = True - with Param_Mask(self.parameters, mask): - jac_img += self.jacobian( - parameters=None, - as_representation=as_representation, - window=window, - **kwargs, - ) - - return jac_img - - -def load(self, filename: Union[str, dict, io.TextIOBase] = "AstroPhot.yaml", new_name=None): - """Used to load the model from a saved state. - - Sets the model window to the saved value and updates all - parameters with the saved information. This overrides the - current parameter settings. - - Args: - filename: The source from which to load the model parameters. Can be a string (the name of the file on disc), a dictionary (formatted as if from self.get_state), or an io.TextIOBase (a file stream to load the file from). - - """ - state = AstroPhot_Model.load(filename) - if new_name is None: - new_name = state["name"] - self.name = new_name - # Use window saved state to initialize model window - self.window = Window(**state["window"]) - # reassign target in case a target list was given - self._target_identity = state["target_identity"] - self.target = self.target - # Set any attributes which were not default - for key in self.track_attrs: - if key in state: - setattr(self, key, state[key]) - # Load the parameter group, this is handled by the parameter group object - if isinstance(state["parameters"], Parameter_Node): - self.parameters = state["parameters"] - else: - self.parameters = Parameter_Node(self.name, state=state["parameters"]) - # Move parameters to the appropriate device and dtype - self.parameters.to(dtype=AP_config.ap_dtype, device=AP_config.ap_device) - # Re-create the aux PSF model if there was one - if "psf" in state: - if state["psf"].get("type", "AstroPhot_Model") == "PSF_Image": - self.psf = PSF_Image(state=state["psf"]) - else: - print(state["psf"]) - state["psf"]["parameters"] = self.parameters[state["psf"]["name"]] - self.set_aux_psf( - AstroPhot_Model( - name=state["psf"]["name"], - filename=state["psf"], - target=self.target, - ) - ) - return state diff --git a/astrophot/models/_shared_methods.py b/astrophot/models/_shared_methods.py index be31ef0d..5a81c017 100644 --- a/astrophot/models/_shared_methods.py +++ b/astrophot/models/_shared_methods.py @@ -1,94 +1,48 @@ -import functools - from scipy.stats import binned_statistic, iqr import numpy as np import torch from scipy.optimize import minimize -from ..utils.initialize import isophotes -from ..utils.parametric_profiles import ( - sersic_torch, - gaussian_torch, - exponential_torch, - spline_torch, - moffat_torch, - nuker_torch, -) -from ..utils.conversions.coordinates import ( - Rotate_Cartesian, -) -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..image import ( - Image_List, - Model_Image_List, - Target_Image_List, - Window_List, -) -from ..param import Param_Unlock, Param_SoftLimits -from .. import AP_config - - -# Target Selector Decorator -###################################################################### -def select_target(func): - @functools.wraps(func) - def targeted(self, target=None, **kwargs): - if target is None: - send_target = self.target - elif isinstance(target, Target_Image_List) and not isinstance(self.target, Image_List): - for sub_target in target: - if sub_target.identity == self.target.identity: - send_target = sub_target - break - else: - raise RuntimeError("{self.name} could not find matching target to initialize with") - else: - send_target = target - return func(self, target=send_target, **kwargs) - - return targeted - +from ..utils.decorators import ignore_numpy_warnings +from .. import config +from ..backend_obj import backend -def select_sample(func): - @functools.wraps(func) - def targeted(self, image=None, **kwargs): - if isinstance(image, Model_Image_List) and not isinstance(self.target, Image_List): - for i, sub_image in enumerate(image): - if sub_image.target_identity == self.target.identity: - send_image = sub_image - if "window" in kwargs and isinstance(kwargs["window"], Window_List): - kwargs["window"] = kwargs["window"].window_list[i] - break - else: - raise RuntimeError(f"{self.name} could not find matching image to sample with") - else: - send_image = image - return func(self, image=send_image, **kwargs) - return targeted - - -def _sample_image(image, transform, metric, parameters, rad_bins=None): - dat = image.data.detach().cpu().clone().numpy() +def _sample_image( + image, + transform, + radius, + angle=None, + rad_bins=None, + angle_range=None, + cycle=2 * np.pi, +): + dat = backend.to_numpy(image._data).copy() # Fill masked pixels - if image.has_mask: - mask = image.mask.detach().cpu().numpy() - dat[mask] = np.median(dat[np.logical_not(mask)]) + mask = backend.to_numpy(image._mask) + dat[mask] = np.median(dat[~mask]) # Subtract median of edge pixels to avoid effect of nearby sources edge = np.concatenate((dat[:, 0], dat[:, -1], dat[0, :], dat[-1, :])) dat -= np.median(edge) # Get the radius of each pixel relative to object center - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - X, Y = transform(X, Y, image, parameters) - R = metric(X, Y, image, parameters).detach().cpu().numpy().flatten() + x, y = transform(*image.coordinate_center_meshgrid(), params=()) + R = backend.to_numpy(radius(x, y, params=())).flatten() + + if angle_range is not None: + T = backend.to_numpy(angle(x, y, params=())).flatten() + T = (T - angle_range[0]) % cycle + CHOOSE = T < (angle_range[1] - angle_range[0]) + R = R[CHOOSE] + dat = dat.flatten()[CHOOSE] + raveldat = dat.ravel() # Bin fluxes by radius if rad_bins is None: - rad_bins = np.logspace(np.log10(R.min() * 0.9), np.log10(R.max() * 1.1), 11) + rad_bins = np.logspace( + np.log10(R.min() * 0.9 + image.pixelscale / 2), np.log10(R.max() * 1.1), 11 + ) else: rad_bins = np.array(rad_bins) - raveldat = dat.ravel() I = ( binned_statistic(R, raveldat, statistic="median", bins=rad_bins)[0] ) / image.pixel_area.item() @@ -97,21 +51,29 @@ def _sample_image(image, transform, metric, parameters, rad_bins=None): R = (rad_bins[:-1] + rad_bins[1:]) / 2 # Ensure enough values are positive - I[I <= 0] = np.min(I[np.logical_and(np.isfinite(I), I > 0)]) + N = np.isfinite(I) + I[~N] = np.interp(R[~N], R[N], I[N]) + if np.sum(I > 0) <= 3: + I = np.abs(I) + N = I > 0 + if not np.all(N): + I[~N] = np.interp(R[~N], R[N], I[N]) # Ensure decreasing brightness with radius in outer regions for i in range(5, len(I)): - if I[i] >= I[i - 1] and np.isfinite(I[i - 1]): - I[i] = I[i - 1] - np.abs(I[i - 1] * 0.1) + if I[i] >= I[i - 1]: + I[i] = I[i - 1] * 0.9 # Convert to log scale S = S / (I * np.log(10)) I = np.log10(I) # Ensure finite N = np.isfinite(I) if not np.all(N): - I[np.logical_not(N)] = np.interp(R[np.logical_not(N)], R[N], I[N]) + I[~N] = np.interp(R[~N], R[N], I[N]) N = np.isfinite(S) if not np.all(N): - S[np.logical_not(N)] = np.abs(np.interp(R[np.logical_not(N)], R[N], S[N])) + S[~N] = np.abs(np.interp(R[~N], R[N], S[N])) + Sm = np.median(S) + S[S < Sm] = Sm # remove very small uncertainties return R, I, S @@ -120,534 +82,76 @@ def _sample_image(image, transform, metric, parameters, rad_bins=None): ###################################################################### @torch.no_grad() @ignore_numpy_warnings -def parametric_initialize( - model, parameters, target, prof_func, params, x0_func, force_uncertainty=None -): - if all(list(parameters[param].value is not None for param in params)): +def parametric_initialize(model, target, prof_func, params, x0_func): + if all(list(model[param].initialized for param in params)): return - # Get the sub-image area corresponding to the model image - target_area = target[model.window] - R, I, S = _sample_image( - target_area, model.transform_coordinates, model.radius_metric, parameters - ) + R, I, S = _sample_image(target, model.transform_coordinates, model.radius_metric) x0 = list(x0_func(model, R, I)) for i, param in enumerate(params): - x0[i] = x0[i] if parameters[param].value is None else parameters[param].value.item() + x0[i] = x0[i] if not model[param].initialized else model[param].npvalue - def optim(x, r, f): - residual = (f - np.log10(prof_func(r, *x))) ** 2 + def optim(x, r, f, u): + residual = ((f - np.nan_to_num(np.log10(prof_func(r, *x)), nan=np.min(f))) / u) ** 2 N = np.argsort(residual) return np.mean(residual[N][:-2]) - res = minimize(optim, x0=x0, args=(R, I), method="Nelder-Mead") - if not res.success and AP_config.ap_verbose >= 2: - AP_config.ap_logger.warning( - f"initialization fit not successful for {model.name}, falling back to defaults" - ) + res = minimize(optim, x0=x0, args=(R, I, S), method="Nelder-Mead") + + if res.success: + x0 = res.x - if force_uncertainty is None: - reses = [] - for i in range(10): - N = np.random.randint(0, len(R), len(R)) - reses.append(minimize(optim, x0=x0, args=(R[N], I[N]), method="Nelder-Mead")) - for param, resx, x0x in zip(params, res.x, x0): - with Param_Unlock(parameters[param]), Param_SoftLimits(parameters[param]): - if parameters[param].value is None: - parameters[param].value = resx if res.success else x0x - if force_uncertainty is None and parameters[param].uncertainty is None: - parameters[param].uncertainty = np.std( - list(subres.x[params.index(param)] for subres in reses) - ) - elif force_uncertainty is not None: - parameters[param].uncertainty = force_uncertainty[params.index(param)] + for param, x0x in zip(params, x0): + if not model[param].initialized: + x0x = backend.as_array(x0x, dtype=config.DTYPE, device=config.DEVICE) + if not model[param].is_valid(x0x): + x0x = model[param].soft_valid(x0x) + model[param].value = x0x @torch.no_grad() @ignore_numpy_warnings def parametric_segment_initialize( model=None, - parameters=None, target=None, prof_func=None, params=None, x0_func=None, segments=None, - force_uncertainty=None, ): - if all(list(model[param].value is not None for param in params)): + if all(list(model[param].initialized for param in params)): return - # Get the sub-image area corresponding to the model image - target_area = target[model.window] - target_dat = target_area.data.detach().cpu().numpy().copy() - if target_area.has_mask: - mask = target_area.mask.detach().cpu().numpy() - target_dat[mask] = np.median(target_dat[np.logical_not(mask)]) - edge = np.concatenate( - ( - target_dat[:, 0], - target_dat[:, -1], - target_dat[0, :], - target_dat[-1, :], - ) - ) - edge_average = np.median(edge) - edge_scatter = iqr(edge, rng=(16, 84)) / 2 - # Convert center coordinates to target area array indices - icenter = target_area.plane_to_pixel(model["center"].value) - - iso_info = isophotes( - target_dat - edge_average, - (icenter[1].item(), icenter[0].item()), - threshold=3 * edge_scatter, - pa=(model["PA"].value - target.north).item() if "PA" in model else 0.0, - q=model["q"].value.item() if "q" in model else 1.0, - n_isophotes=15, - more=True, - ) - R = np.array(list(iso["R"] for iso in iso_info)) * target.pixel_length.item() - was_none = list(False for i in range(len(params))) - val = {} - unc = {} - for i, p in enumerate(params): - if model[p].value is None: - was_none[i] = True - val[p] = np.zeros(segments) - unc[p] = np.zeros(segments) - for r in range(segments): - flux = [] - for iso in iso_info: - modangles = ( - iso["angles"] - - ((model["PA"].value - target.north).detach().cpu().item() + r * np.pi / segments) - ) % np.pi - flux.append( - np.median( - iso["isovals"][ - np.logical_or( - modangles < (0.5 * np.pi / segments), - modangles >= (np.pi * (1 - 0.5 / segments)), - ) - ] - ) - ) - flux = np.array(flux) / target.pixel_area.item() - if np.sum(flux < 0) >= 1: - flux -= np.min(flux) - np.abs(np.min(flux) * 0.1) - flux = np.log10(flux) - - x0 = list(x0_func(model, R, flux)) - for i, param in enumerate(params): - x0[i] = x0[i] if was_none[i] else model[param].value.detach().cpu().numpy()[r] - res = minimize( - lambda x: np.mean((flux - np.log10(prof_func(R, *x))) ** 2), - x0=x0, - method="Nelder-Mead", - ) - if force_uncertainty is None: - reses = [] - for i in range(10): - N = np.random.randint(0, len(R), len(R)) - reses.append( - minimize( - lambda x: np.mean((flux - np.log10(prof_func(R, *x))) ** 2), - x0=x0, - method="Nelder-Mead", - ) - ) - for i, param in enumerate(params): - if was_none[i]: - val[param][r] = res.x[i] if res.success else x0[i] - if force_uncertainty is None and model[param].uncertainty is None: - unc[r] = np.std(list(subres.x[params.index(param)] for subres in reses)) - elif force_uncertainty is not None: - unc[r] = force_uncertainty[params.index(param)][r] - - with Param_Unlock(model[param]), Param_SoftLimits(model[param]): - model[param].value = val[param] - model[param].uncertainty = unc[param] - - -# Evaluate_Model -###################################################################### -@default_internal -def radial_evaluate_model(self, X=None, Y=None, image=None, parameters=None): - if X is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - return self.radial_model( - self.radius_metric(X, Y, image=image, parameters=parameters), - image=image, - parameters=parameters, - ) - - -@default_internal -def transformed_evaluate_model(self, X=None, Y=None, image=None, parameters=None, **kwargs): - if X is None or Y is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - X, Y = self.transform_coordinates(X, Y, image, parameters) - return self.radial_model( - self.radius_metric(X, Y, image=image, parameters=parameters), - image=image, - parameters=parameters, - ) - - -# Transform Coordinates -###################################################################### -@default_internal -def inclined_transform_coordinates(self, X, Y, image=None, parameters=None): - X, Y = Rotate_Cartesian(-(parameters["PA"].value - image.north), X, Y) - return ( - X, - Y / parameters["q"].value, - ) - - -# Exponential -###################################################################### -@default_internal -def exponential_radial_model(self, R, image=None, parameters=None): - return exponential_torch( - R, - parameters["Re"].value, - image.pixel_area * 10 ** parameters["Ie"].value, - ) - - -@default_internal -def exponential_iradial_model(self, i, R, image=None, parameters=None): - return exponential_torch( - R, - parameters["Re"].value[i], - image.pixel_area * 10 ** parameters["Ie"].value[i], - ) - - -# Sersic -###################################################################### -@default_internal -def sersic_radial_model(self, R, image=None, parameters=None): - return sersic_torch( - R, - parameters["n"].value, - parameters["Re"].value, - image.pixel_area * 10 ** parameters["Ie"].value, - ) - - -@default_internal -def sersic_iradial_model(self, i, R, image=None, parameters=None): - return sersic_torch( - R, - parameters["n"].value[i], - parameters["Re"].value[i], - image.pixel_area * 10 ** parameters["Ie"].value[i], - ) - - -# Moffat -###################################################################### -@default_internal -def moffat_radial_model(self, R, image=None, parameters=None): - return moffat_torch( - R, - parameters["n"].value, - parameters["Rd"].value, - image.pixel_area * 10 ** parameters["I0"].value, - ) - - -@default_internal -def moffat_iradial_model(self, i, R, image=None, parameters=None): - return moffat_torch( - R, - parameters["n"].value[i], - parameters["Rd"].value[i], - image.pixel_area * 10 ** parameters["I0"].value[i], - ) - -# Nuker Profile -###################################################################### -@default_internal -def nuker_radial_model(self, R, image=None, parameters=None): - return nuker_torch( - R, - parameters["Rb"].value, - image.pixel_area * 10 ** parameters["Ib"].value, - parameters["alpha"].value, - parameters["beta"].value, - parameters["gamma"].value, - ) - - -@default_internal -def nuker_iradial_model(self, i, R, image=None, parameters=None): - return nuker_torch( - R, - parameters["Rb"].value[i], - image.pixel_area * 10 ** parameters["Ib"].value[i], - parameters["alpha"].value[i], - parameters["beta"].value[i], - parameters["gamma"].value[i], - ) - - -# Gaussian -###################################################################### -@default_internal -def gaussian_radial_model(self, R, image=None, parameters=None): - return gaussian_torch( - R, - parameters["sigma"].value, - image.pixel_area * 10 ** parameters["flux"].value, - ) - - -@default_internal -def gaussian_iradial_model(self, i, R, image=None, parameters=None): - return gaussian_torch( - R, - parameters["sigma"].value[i], - image.pixel_area * 10 ** parameters["flux"].value[i], - ) - - -# Spline -###################################################################### -@torch.no_grad() -@ignore_numpy_warnings -@select_target -@default_internal -def spline_initialize(self, target=None, parameters=None, **kwargs): - super(self.__class__, self).initialize(target=target, parameters=parameters) - - if parameters["I(R)"].value is not None and parameters["I(R)"].prof is not None: - return - - # Create the I(R) profile radii if needed - if parameters["I(R)"].prof is None: - new_prof = [0, 2 * target.pixel_length] - while new_prof[-1] < torch.max(self.window.shape / 2): - new_prof.append(new_prof[-1] + torch.max(2 * target.pixel_length, new_prof[-1] * 0.2)) - new_prof.pop() - new_prof.pop() - new_prof.append(torch.sqrt(torch.sum((self.window.shape / 2) ** 2))) - parameters["I(R)"].prof = new_prof - - profR = parameters["I(R)"].prof.detach().cpu().numpy() - target_area = target[self.window] - R, I, S = _sample_image( - target_area, - self.transform_coordinates, - self.radius_metric, - parameters, - rad_bins=[profR[0]] + list((profR[:-1] + profR[1:]) / 2) + [profR[-1] * 100], - ) - with Param_Unlock(parameters["I(R)"]), Param_SoftLimits(parameters["I(R)"]): - parameters["I(R)"].value = I - parameters["I(R)"].uncertainty = S - - -@torch.no_grad() -@ignore_numpy_warnings -@select_target -@default_internal -def spline_segment_initialize( - self, target=None, parameters=None, segments=1, symmetric=True, **kwargs -): - super(self.__class__, self).initialize(target=target, parameters=parameters) - - if parameters["I(R)"].value is not None and parameters["I(R)"].prof is not None: - return - - # Create the I(R) profile radii if needed - if parameters["I(R)"].prof is None: - new_prof = [0, 2 * target.pixel_length] - while new_prof[-1] < torch.max(self.window.shape / 2): - new_prof.append(new_prof[-1] + torch.max(2 * target.pixel_length, new_prof[-1] * 0.2)) - new_prof.pop() - new_prof.pop() - new_prof.append(torch.sqrt(torch.sum((self.window.shape / 2) ** 2))) - parameters["I(R)"].prof = new_prof - - profR = parameters["I(R)"].prof.detach().cpu().numpy() - target_area = target[self.window] - target_dat = target_area.data.detach().cpu().numpy().copy() - if target_area.has_mask: - mask = target_area.mask.detach().cpu().numpy() - target_dat[mask] = np.median(target_dat[np.logical_not(mask)]) - Coords = target_area.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - X, Y = self.transform_coordinates(X, Y, target, parameters) - R = self.radius_metric(X, Y, target, parameters).detach().cpu().numpy() - T = self.angular_metric(X, Y, target, parameters).detach().cpu().numpy() - rad_bins = [profR[0]] + list((profR[:-1] + profR[1:]) / 2) + [profR[-1] * 100] - raveldat = target_dat.ravel() - val = np.zeros((segments, len(parameters["I(R)"].prof))) - unc = np.zeros((segments, len(parameters["I(R)"].prof))) + cycle = np.pi if model.symmetric else 2 * np.pi + w = cycle / segments + v = w * np.arange(segments) + values = [] for s in range(segments): - if segments % 2 == 0 and symmetric: - angles = (T - (s * np.pi / segments)) % np.pi - TCHOOSE = np.logical_or( - angles < (np.pi / segments), angles >= (np.pi * (1 - 1 / segments)) - ) - elif segments % 2 == 1 and symmetric: - angles = (T - (s * np.pi / segments)) % (2 * np.pi) - TCHOOSE = np.logical_or( - angles < (np.pi / segments), angles >= (np.pi * (2 - 1 / segments)) - ) - angles = (T - (np.pi + s * np.pi / segments)) % (2 * np.pi) - TCHOOSE = np.logical_or( - TCHOOSE, - np.logical_or(angles < (np.pi / segments), angles >= (np.pi * (2 - 1 / segments))), - ) - elif segments % 2 == 0 and not symmetric: - angles = (T - (s * 2 * np.pi / segments)) % (2 * np.pi) - TCHOOSE = torch.logical_or( - angles < (2 * np.pi / segments), - angles >= (2 * np.pi * (1 - 1 / segments)), - ) - else: - angles = (T - (s * 2 * np.pi / segments)) % (2 * np.pi) - TCHOOSE = torch.logical_or( - angles < (2 * np.pi / segments), angles >= (np.pi * (2 - 1 / segments)) - ) - TCHOOSE = TCHOOSE.ravel() - I = ( - binned_statistic( - R.ravel()[TCHOOSE], raveldat[TCHOOSE], statistic="median", bins=rad_bins - )[0] - ) / target.pixel_area.item() - N = np.isfinite(I) - if not np.all(N): - I[np.logical_not(N)] = np.interp(profR[np.logical_not(N)], profR[N], I[N]) - S = binned_statistic( - R.ravel(), - raveldat, - statistic=lambda d: iqr(d, rng=[16, 84]) / 2, - bins=rad_bins, - )[0] - N = np.isfinite(S) - if not np.all(N): - S[np.logical_not(N)] = np.interp(profR[np.logical_not(N)], profR[N], S[N]) - val[s] = np.log10(np.abs(I)) - unc[s] = S / (np.abs(I) * np.log(10)) - with Param_Unlock(parameters["I(R)"]), Param_SoftLimits(parameters["I(R)"]): - parameters["I(R)"].value = val - parameters["I(R)"].uncertainty = unc - - -@default_internal -def spline_radial_model(self, R, image=None, parameters=None): - return ( - spline_torch( - R, - parameters["I(R)"].prof, - parameters["I(R)"].value, - extend=self.extend_profile, - ) - * image.pixel_area - ) - - -@default_internal -def spline_iradial_model(self, i, R, image=None, parameters=None): - return ( - spline_torch( - R, - parameters["I(R)"].prof, - parameters["I(R)"].value[i], - extend=self.extend_profile, + angle_range = (v[s] - w / 2, v[s] + w / 2) + # Get the sub-image area corresponding to the model image + R, I, S = _sample_image( + target, + model.transform_coordinates, + model.radius_metric, + angle=model.angular_metric, + angle_range=angle_range, + cycle=cycle, ) - * image.pixel_area - ) - -# RelSpline -###################################################################### -@torch.no_grad() -@ignore_numpy_warnings -@select_target -@default_internal -def relspline_initialize(self, target=None, parameters=None, **kwargs): - super(self.__class__, self).initialize(target=target, parameters=parameters) + x0 = list(x0_func(model, R, I)) - target_area = target[self.window] - target_dat = target_area.data.detach().cpu().numpy().copy() - if target_area.has_mask: - mask = target_area.mask.detach().cpu().numpy() - target_dat[mask] = np.median(target_dat[np.logical_not(mask)]) - if parameters["I0"].value is None: - center = target_area.plane_to_pixel(parameters["center"].value) - flux = target_dat[center[1].int().item(), center[0].int().item()] - with Param_Unlock(parameters["I0"]), Param_SoftLimits(parameters["I0"]): - parameters["I0"].value = np.log10(np.abs(flux) / target_area.pixel_area.item()) - parameters["I0"].uncertainty = 0.01 + def optim(x, r, f, u): + residual = ((f - np.log10(prof_func(r, *x))) / u) ** 2 + N = np.argsort(residual) + return np.mean(residual[N][:-2]) - if parameters["dI(R)"].value is not None and parameters["dI(R)"].prof is not None: - return + res = minimize(optim, x0=x0, args=(R, I, S), method="Nelder-Mead") + if res.success: + x0 = res.x - # Create the I(R) profile radii if needed - if parameters["dI(R)"].prof is None: - new_prof = [2 * target.pixel_length] - while new_prof[-1] < torch.max(self.window.shape / 2): - new_prof.append(new_prof[-1] + torch.max(2 * target.pixel_length, new_prof[-1] * 0.2)) - new_prof.pop() - new_prof.pop() - new_prof.append(torch.sqrt(torch.sum((self.window.shape / 2) ** 2))) - parameters["dI(R)"].prof = new_prof - - profR = parameters["dI(R)"].prof.detach().cpu().numpy() - - Coords = target_area.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - X, Y = self.transform_coordinates(X, Y, target, parameters) - R = self.radius_metric(X, Y, target, parameters).detach().cpu().numpy() - rad_bins = [profR[0]] + list((profR[:-1] + profR[1:]) / 2) + [profR[-1] * 100] - raveldat = target_dat.ravel() - - I = ( - binned_statistic(R.ravel(), raveldat, statistic="median", bins=rad_bins)[0] - ) / target.pixel_area.item() - N = np.isfinite(I) - if not np.all(N): - I[np.logical_not(N)] = np.interp(profR[np.logical_not(N)], profR[N], I[N]) - if I[-1] >= I[-2]: - I[-1] = I[-2] / 2 - S = binned_statistic( - R.ravel(), raveldat, statistic=lambda d: iqr(d, rng=[16, 84]) / 2, bins=rad_bins - )[0] - N = np.isfinite(S) - if not np.all(N): - S[np.logical_not(N)] = np.interp(profR[np.logical_not(N)], profR[N], S[N]) - with Param_Unlock(parameters["dI(R)"]), Param_SoftLimits(parameters["dI(R)"]): - parameters["dI(R)"].value = np.log10(np.abs(I)) - parameters["I0"].value.item() - parameters["dI(R)"].uncertainty = S / (np.abs(I) * np.log(10)) - - -@default_internal -def relspline_radial_model(self, R, image=None, parameters=None): - return ( - spline_torch( - R, - torch.cat( - ( - torch.zeros_like(parameters["I0"].value).unsqueeze(-1), - parameters["dI(R)"].prof, - ) - ), - torch.cat( - ( - parameters["I0"].value.unsqueeze(-1), - parameters["I0"].value + parameters["dI(R)"].value, - ) - ), - extend=self.extend_profile, - ) - * image.pixel_area - ) + values.append(x0) + values = np.stack(values).T + for param, v in zip(params, values): + if not model[param].initialized: + model[param].value = v diff --git a/astrophot/models/airy.py b/astrophot/models/airy.py new file mode 100644 index 00000000..ffc8d70e --- /dev/null +++ b/astrophot/models/airy.py @@ -0,0 +1,74 @@ +import torch +import numpy as np + +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from .psf_model_object import PSFModel +from .mixins import RadialMixin +from ..param import forward +from ..backend_obj import backend, ArrayLike + +__all__ = ("AiryPSF",) + + +@combine_docstrings +class AiryPSF(RadialMixin, PSFModel): + """The Airy disk is an analytic description of the diffraction pattern + for a circular aperture. + + The diffraction pattern is described exactly by the configuration + of the lens system under the assumption that all elements are + perfect. This expression goes as: + + $$I(\\theta) = I_0\\left[\\frac{2J_1(x)}{x}\\right]^2$$ + $$x = ka\\sin(\\theta) = \\frac{2\\pi a r}{\\lambda R}$$ + + where $I(\\theta)$ is the intensity as a function of the + angular position within the diffraction system along its main + axis, $I_0$ is the central intensity of the airy disk, + $J_1$ is the Bessel function of the first kind of order one, + $k = \\frac{2\\pi}{\\lambda}$ is the wavenumber of the + light, $a$ is the aperture radius, $r$ is the radial + position from the center of the pattern, $R$ is the distance + from the circular aperture to the observation plane. + + In the `Airy_PSF` class we combine the parameters + $a,R,\\lambda$ into a single ratio to be optimized (or fixed + by the optical configuration). + + **Parameters:** + - `I0`: The central intensity of the airy disk in flux/arcsec^2. + - `aRL`: The ratio of the aperture radius to the + product of the wavelength and the distance from the aperture to the + observation plane, $\\frac{a}{R \\lambda}$. + + """ + + _model_type = "airy" + _parameter_specs = { + "I0": {"units": "flux/arcsec^2", "value": 1.0, "shape": (), "dynamic": False}, + "aRL": {"units": "a/(R lambda)", "shape": (), "dynamic": True}, + } + usable = True + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if self.I0.initialized and self.aRL.initialized: + return + icenter = self.target.plane_to_pixel(*self.center.value) + + if not self.I0.initialized: + mid_chunk = self.target._data[ + int(icenter[0]) - 2 : int(icenter[0]) + 2, + int(icenter[1]) - 2 : int(icenter[1]) + 2, + ] + self.I0.value = backend.mean(mid_chunk) / self.target.pixel_area + if not self.aRL.initialized: + self.aRL.value = (5.0 / 8.0) * 2 * self.target.pixelscale + + @forward + def radial_model(self, R: ArrayLike, I0: ArrayLike, aRL: ArrayLike) -> ArrayLike: + x = 2 * np.pi * aRL * R + return I0 * (2 * backend.bessel_j1(x) / x) ** 2 diff --git a/astrophot/models/airy_psf.py b/astrophot/models/airy_psf.py deleted file mode 100644 index 81bed4ed..00000000 --- a/astrophot/models/airy_psf.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch - -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ._shared_methods import select_target -from .psf_model_object import PSF_Model -from ..param import Param_Unlock, Param_SoftLimits - -__all__ = ("Airy_PSF",) - - -class Airy_PSF(PSF_Model): - """The Airy disk is an analytic description of the diffraction pattern - for a circular aperture. - - The diffraction pattern is described exactly by the configuration - of the lens system under the assumption that all elements are - perfect. This expression goes as: - - .. math:: - - I(\\theta) = I_0\\left[\\frac{2J_1(x)}{x}\\right]^2 - - x = ka\\sin(\\theta) = \\frac{2\\pi a r}{\\lambda R} - - where :math:`I(\\theta)` is the intensity as a function of the - angular position within the diffraction system along its main - axis, :math:`I_0` is the central intensity of the airy disk, - :math:`J_1` is the Bessel function of the first kind of order one, - :math:`k = \\frac{2\\pi}{\\lambda}` is the wavenumber of the - light, :math:`a` is the aperture radius, :math:`r` is the radial - position from the center of the pattern, :math:`R` is the distance - from the circular aperture to the observation plane. - - In the `Airy_PSF` class we combine the parameters - :math:`a,R,\\lambda` into a single ratio to be optimized (or fixed - by the optical configuration). - - """ - - model_type = f"airy {PSF_Model.model_type}" - parameter_specs = { - "I0": {"units": "log10(flux/arcsec^2)", "value": 0.0, "locked": True}, - "aRL": {"units": "a/(R lambda)"}, - } - _parameter_order = PSF_Model._parameter_order + ("I0", "aRL") - usable = True - model_integrated = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - if (parameters["I0"].value is not None) and (parameters["aRL"].value is not None): - return - target_area = target[self.window] - icenter = target_area.plane_to_pixel(parameters["center"].value) - - if parameters["I0"].value is None: - with Param_Unlock(parameters["I0"]), Param_SoftLimits(parameters["I0"]): - parameters["I0"].value = torch.log10( - torch.mean( - target_area.data[ - int(icenter[0]) - 2 : int(icenter[0]) + 2, - int(icenter[1]) - 2 : int(icenter[1]) + 2, - ] - ) - / target.pixel_area.item() - ) - parameters["I0"].uncertainty = torch.std( - target_area.data[ - int(icenter[0]) - 2 : int(icenter[0]) + 2, - int(icenter[1]) - 2 : int(icenter[1]) + 2, - ] - ) / (torch.abs(parameters["I0"].value) * target.pixel_area) - if parameters["aRL"].value is None: - with Param_Unlock(parameters["aRL"]), Param_SoftLimits(parameters["aRL"]): - parameters["aRL"].value = (5.0 / 8.0) * 2 * target.pixel_length - parameters["aRL"].uncertainty = parameters["aRL"].value * self.default_uncertainty - - @default_internal - def radial_model(self, R, image=None, parameters=None): - x = 2 * torch.pi * parameters["aRL"].value * R - - return (image.pixel_area * 10 ** parameters["I0"].value) * ( - 2 * torch.special.bessel_j1(x) / x - ) ** 2 - - from ._shared_methods import radial_evaluate_model as evaluate_model diff --git a/astrophot/models/base.py b/astrophot/models/base.py new file mode 100644 index 00000000..04a3b99e --- /dev/null +++ b/astrophot/models/base.py @@ -0,0 +1,271 @@ +from typing import Optional, Union +from copy import deepcopy + +import numpy as np + +from caskade import Param as CParam +from ..param import Module, forward, Param +from ..utils.decorators import classproperty +from ..image import Window, ImageList, ModelImage, ModelImageList +from ..errors import UnrecognizedModel, InvalidWindow +from .. import config +from ..backend_obj import backend, ArrayLike +from . import func + +__all__ = ("Model",) + + +###################################################################### +class Model(Module): + """Base class for all AstroPhot models.""" + + _model_type = "model" + _parameter_specs = {} + # Softening length used for numerical stability and/or integration stability to avoid discontinuities (near R=0) + softening = 1e-3 # arcsec + _options = ("softening",) + usable = False + + def __new__(cls, *, filename=None, model_type=None, **kwargs): + if filename is not None: + state = Model.load(filename) + MODELS = Model.List_Models() + for M in MODELS: + if M.model_type == state["model_type"]: + return super(Model, cls).__new__(M) + else: + raise UnrecognizedModel(f"Unknown AstroPhot model type: {state['model_type']}") + elif model_type is not None: + MODELS = Model.List_Models() # all_subclasses(Model) + for M in MODELS: + if M.model_type == model_type: + return super(Model, cls).__new__(M) + else: + raise UnrecognizedModel(f"Unknown AstroPhot model type: {model_type}") + + return super().__new__(cls) + + def __init__(self, *, name=None, target=None, window=None, mask=None, filename=None, **kwargs): + super().__init__(name=name) + self.target = target + self.window = window + self.mask = mask + + # Set any user defined options for the model + for kwarg in list(kwargs.keys()): + if kwarg in self.options: + setattr(self, kwarg, kwargs.pop(kwarg)) + + # Create Param objects for this Module + parameter_specs = self.build_parameter_specs(kwargs, self.parameter_specs) + for key in parameter_specs: + param = Param(key, **parameter_specs[key], dtype=config.DTYPE, device=config.DEVICE) + setattr(self, key, param) + + self.saveattrs.update(self.options) + self.saveattrs.add("window.extent") + + kwargs.pop("model_type", None) # model_type is set by __new__ + if len(kwargs) > 0: + raise TypeError( + f"Unrecognized keyword arguments for {self.__class__.__name__}: {', '.join(kwargs.keys())}" + ) + + @classproperty + def model_type(cls) -> str: + collected = [] + for subcls in cls.mro(): + if subcls is object: + continue + mt = subcls.__dict__.get("_model_type", None) + if mt: + collected.append(mt) + return " ".join(collected) + + @classproperty + def options(cls) -> set: + options = set() + for subcls in cls.mro(): + if subcls is object: + continue + options.update(subcls.__dict__.get("_options", [])) + return options + + @classproperty + def parameter_specs(cls) -> dict: + """Collects all parameter specifications from the class hierarchy.""" + specs = {} + for subcls in reversed(cls.mro()): + if subcls is object: + continue + specs.update(getattr(subcls, "_parameter_specs", {})) + return specs + + def build_parameter_specs(self, kwargs, parameter_specs) -> dict: + parameter_specs = deepcopy(parameter_specs) + + for p in list(kwargs.keys()): + if p not in parameter_specs: + continue + if isinstance(kwargs[p], dict): + parameter_specs[p].update(kwargs.pop(p)) + else: + parameter_specs[p]["value"] = kwargs.pop(p) + if isinstance(parameter_specs[p].get("value", None), CParam) or callable( + parameter_specs[p].get("value", None) + ): + parameter_specs[p]["dynamic"] = False + + return parameter_specs + + @forward + def gaussian_log_likelihood( + self, + window: Optional[Window] = None, + ) -> ArrayLike: + """ + Compute the negative log likelihood of the model wrt the target image in the appropriate window. + """ + + if window is None: + window = self.window + model = self(window=window).data + data = self.target[window] + weight = data.weight + mask = data.mask + data = data.data + if isinstance(data, tuple): + nll = 0.5 * sum( + backend.sum(((da - mo) ** 2 * wgt)[~ma]) + for mo, da, wgt, ma in zip(model, data, weight, mask) + ) + else: + nll = 0.5 * backend.sum(((data - model) ** 2 * weight)[~mask]) + + return -nll + + @forward + def poisson_log_likelihood( + self, + window: Optional[Window] = None, + ) -> ArrayLike: + """ + Compute the negative log likelihood of the model wrt the target image in the appropriate window. + """ + if window is None: + window = self.window + model = self(window=window).data + data = self.target[window] + mask = data.mask + data = data.data + + if isinstance(data, tuple): + nll = sum( + backend.sum((mo - da * backend.log(mo + 1e-10) + backend.lgamma(da + 1))[~ma]) + for mo, da, ma in zip(model, data, mask) + ) + else: + nll = backend.sum( + (model - data * backend.log(model + 1e-10) + backend.lgamma(data + 1))[~mask] + ) + + return -nll + + def hessian(self, likelihood="gaussian"): + if likelihood == "gaussian": + return backend.hessian(self.gaussian_log_likelihood)(self.get_values()) + elif likelihood == "poisson": + return backend.hessian(self.poisson_log_likelihood)(self.get_values()) + else: + raise ValueError(f"Unknown likelihood type: {likelihood}") + + def total_flux(self, window=None) -> ArrayLike: + F = self(window=window) + return backend.sum(F.data) + + def total_flux_uncertainty(self, window=None) -> ArrayLike: + jac = self.jacobian(window=window).flatten("data") + dF = backend.sum(jac, dim=0) # VJP for sum(total_flux) + current_uncertainty = self.build_params_array_uncertainty() + return backend.sqrt(backend.sum((dF * current_uncertainty) ** 2)) + + def total_magnitude(self, window=None) -> ArrayLike: + """Compute the total magnitude of the model in the given window.""" + F = self.total_flux(window=window) + return -2.5 * backend.log10(F) + self.target.zeropoint + + def total_magnitude_uncertainty(self, window=None) -> ArrayLike: + """Compute the uncertainty in the total magnitude of the model in the given window.""" + F = self.total_flux(window=window) + dF = self.total_flux_uncertainty(window=window) + return 2.5 * (dF / F) / np.log(10) + + @property + def window(self) -> Optional[Window]: + """The window defines a region on the sky in which this model will be + optimized and typically evaluated. Two models with + non-overlapping windows are in effect independent of each + other. If there is another model with a window that spans both + of them, then they are tenuously connected. + + If not provided, the model will assume a window equal to the + target it is fitting. Note that in this case the window is not + explicitly set to the target window, so if the model is moved + to another target then the fitting window will also change. + + """ + if self._window is None: + if self.target is None: + raise ValueError( + "This model has no target or window, these must be provided by the user" + ) + return self.target.window + return self._window + + @window.setter + def window(self, window): + if window is None: + self._window = None + elif isinstance(window, Window): + self._window = window + elif len(window) in [2, 4]: + self._window = Window(window, image=self.target) + else: + raise InvalidWindow(f"Unrecognized window format: {str(window)}") + + @classmethod + def List_Models(cls, usable: Optional[bool] = None, types: bool = False) -> set: + MODELS = func.all_subclasses(cls) + result = set() + for model in MODELS: + if not (model.__dict__.get("usable", False) is usable or usable is None): + continue + if types: + result.add(model.model_type) + else: + result.add(model) + return result + + @forward + def radius_metric(self, x, y): + return backend.sqrt(x**2 + y**2 + self.softening**2) + + @forward + def angular_metric(self, x, y): + return backend.arctan2(y, x) + + def to(self, dtype=None, device=None): + if dtype is None: + dtype = config.DTYPE + if device is None: + device = config.DEVICE + super().to(dtype=dtype, device=device) + + @forward + def __call__( + self, + window: Optional[Window] = None, + **kwargs, + ) -> Union[ModelImage, ModelImageList]: + + return self.sample(window=window, **kwargs) diff --git a/astrophot/models/basis.py b/astrophot/models/basis.py new file mode 100644 index 00000000..1064943a --- /dev/null +++ b/astrophot/models/basis.py @@ -0,0 +1,114 @@ +from typing import Union, Tuple +import torch +import numpy as np + +from .psf_model_object import PSFModel +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from ..utils.interpolate import interp2d +from .. import config +from ..backend_obj import backend, ArrayLike +from ..errors import SpecificationConflict +from ..param import forward +from . import func +from ..utils.initialize import polar_decomposition + +__all__ = ["BasisPSF"] + + +@combine_docstrings +class PixelBasisPSF(PSFModel): + """point source model which uses multiple images as a basis for the + PSF as its representation for point sources. Using bilinear interpolation it + will shift the PSF within a pixel to accurately represent the center + location of a point source. There is no functional form for this object type + as any image can be supplied. Bilinear interpolation is very fast and + accurate for smooth models, so it is possible to do the expensive + interpolation before optimization and save time. + + **Parameters:** + - `weights`: The weights of the basis set of images in units of flux. + - `PA`: The position angle of the PSF in radians. + - `scale`: The scale of the PSF in arcseconds per grid unit. + """ + + _model_type = "basis" + _parameter_specs = { + "weights": {"units": "flux", "dynamic": True}, + "PA": {"units": "radians", "shape": (), "dynamic": True}, + "scale": {"units": "arcsec/grid-unit", "shape": (), "dynamic": True}, + } + usable = True + + def __init__(self, *args, basis: Union[str, ArrayLike] = "zernike:3", **kwargs): + """Initialize the PixelBasisPSF model with a basis set of images.""" + super().__init__(*args, **kwargs) + self.basis = basis + + @property + def basis(self): + """The basis set of images used to form the eigen point source.""" + return self._basis + + @basis.setter + def basis(self, value: Union[str, ArrayLike]): + """Set the basis set of images. If value is None, the basis is initialized to an empty tensor.""" + if value is None: + raise SpecificationConflict( + "PixelBasisPSF requires a basis set of images to be provided." + ) + elif isinstance(value, str) and value.startswith("zernike:"): + self._basis = value + else: + # Transpose since pytorch uses (j, i) indexing when (i, j) is more natural for coordinates + self._basis = backend.transpose( + backend.as_array(value, dtype=config.DTYPE, device=config.DEVICE), 2, 1 + ) + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + target_area = self.target[self.window] + if not self.PA.initialized: + R, _ = polar_decomposition(self.target.CD.npvalue) + self.PA.value = np.arccos(np.abs(R[0, 0])) + if not self.scale.initialized: + self.scale.value = self.target.pixelscale.item() + if isinstance(self.basis, str) and self.basis.startswith("zernike:"): + order = int(self.basis.split(":")[1]) + nm = func.zernike_n_m_list(order) + N = int( + target_area._data.shape[0] * self.target.pixelscale.item() / self.scale.value.item() + ) + X, Y = np.meshgrid( + np.linspace(-1, 1, N) * (N - 1) / N, + np.linspace(-1, 1, N) * (N - 1) / N, + indexing="ij", + ) + R = np.sqrt(X**2 + Y**2) + Phi = np.arctan2(Y, X) + basis = [] + for n, m in nm: + basis.append(func.zernike_n_m_modes(R, Phi, n, m)) + self.basis = np.stack(basis, axis=0) + + if not self.weights.initialized: + w = np.zeros(self.basis.shape[0]) + w[0] = 1.0 + self.weights.value = w + + @forward + def transform_coordinates( + self, x: ArrayLike, y: ArrayLike, PA: ArrayLike, scale: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + x, y = super().transform_coordinates(x, y) + i, j = func.rotate(-PA, x, y) + pixel_center = (self.basis.shape[1] - 1) / 2, (self.basis.shape[2] - 1) / 2 + return i / scale + pixel_center[0], j / scale + pixel_center[1] + + @forward + def brightness(self, x: ArrayLike, y: ArrayLike, weights: ArrayLike) -> ArrayLike: + x, y = self.transform_coordinates(x, y) + return backend.sum( + backend.vmap(lambda w, b: w * interp2d(b, x, y))(weights, self.basis), dim=0 + ) diff --git a/astrophot/models/bilinear_sky.py b/astrophot/models/bilinear_sky.py new file mode 100644 index 00000000..0d4873a3 --- /dev/null +++ b/astrophot/models/bilinear_sky.py @@ -0,0 +1,89 @@ +from typing import Tuple +import numpy as np +import torch + +from .sky_model_object import SkyModel +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from ..utils.interpolate import interp2d +from ..param import forward +from ..backend_obj import backend, ArrayLike +from . import func +from ..utils.initialize import polar_decomposition + +__all__ = ["BilinearSky"] + + +@combine_docstrings +class BilinearSky(SkyModel): + """Sky background model using a coarse bilinear grid for the sky flux. + + **Parameters:** + - `I`: sky brightness grid + - `PA`: position angle of the sky grid in radians. + - `scale`: scale of the sky grid in arcseconds per grid unit. + + """ + + _model_type = "bilinear" + _parameter_specs = { + "I": {"units": "flux/arcsec^2", "dynamic": True}, + "PA": {"units": "radians", "shape": (), "dynamic": True}, + "scale": {"units": "arcsec/grid-unit", "shape": (), "dynamic": True}, + } + sampling_mode = "midpoint" + usable = True + + def __init__(self, *args, nodes: Tuple[int, int] = (3, 3), **kwargs): + """Initialize the BilinearSky model with a grid of nodes.""" + super().__init__(*args, **kwargs) + self.nodes = nodes + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if self.I.initialized: + self.nodes = tuple(self.I.value.shape) + + if not self.PA.initialized: + R, _ = polar_decomposition(self.target.CD.npvalue) + self.PA.value = np.arccos(np.abs(R[0, 0])) + if not self.scale.initialized: + self.scale.value = ( + self.target.pixelscale.item() * self.target._data.shape[0] / self.nodes[0] + ) + + if self.I.initialized: + return + + target_dat = self.target[self.window] + dat = backend.to_numpy(target_dat._data).copy() + mask = backend.to_numpy(target_dat._mask).copy() + dat[mask] = np.nanmedian(dat) + iS = dat.shape[0] // self.nodes[0] + jS = dat.shape[1] // self.nodes[1] + + self.I.value = ( + np.median( + dat[: iS * self.nodes[0], : jS * self.nodes[1]].reshape( + iS, self.nodes[0], jS, self.nodes[1] + ), + axis=(0, 2), + ) + / self.target.pixel_area.item() + ) + + @forward + def transform_coordinates( + self, x: ArrayLike, y: ArrayLike, I: ArrayLike, PA: ArrayLike, scale: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + x, y = super().transform_coordinates(x, y) + i, j = func.rotate(-PA, x, y) + pixel_center = (I.shape[0] - 1) / 2, (I.shape[1] - 1) / 2 + return i / scale + pixel_center[0], j / scale + pixel_center[1] + + @forward + def brightness(self, x: ArrayLike, y: ArrayLike, I: ArrayLike) -> ArrayLike: + x, y = self.transform_coordinates(x, y) + return interp2d(I, x, y) diff --git a/astrophot/models/core_model.py b/astrophot/models/core_model.py deleted file mode 100644 index 00ce26a6..00000000 --- a/astrophot/models/core_model.py +++ /dev/null @@ -1,502 +0,0 @@ -import io -from typing import Optional - -import torch -import yaml -import numpy as np - -from ..utils.conversions.dict_to_hdf5 import dict_to_hdf5, hdf5_to_dict -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..image import Window, Target_Image, Target_Image_List -from ..param import Parameter_Node -from ._shared_methods import select_target, select_sample -from .. import AP_config -from ..errors import NameNotAllowed, InvalidTarget, UnrecognizedModel, InvalidWindow - -__all__ = ("AstroPhot_Model",) - - -def all_subclasses(cls): - return set(cls.__subclasses__()).union( - [s for c in cls.__subclasses__() for s in all_subclasses(c)] - ) - - -###################################################################### -class AstroPhot_Model(object): - """Core class for all AstroPhot models and model like objects. This - class defines the signatures to interact with AstroPhot models - both for users and internal functions. - - Basic usage: - - .. code-block:: python - - import astrophot as ap - - # Create a model object - model = ap.models.AstroPhot_Model( - name="unique name", - model_type="choose a model type", - target="Target_Image object", - window="[[xmin, xmax],[ymin,ymax]]", # , - parameters="dict of parameter specifications if desired", - ) - - # Initialize parameters that weren't set on creation - model.initialize() - - # Fit model to target - result = ap.fit.lm(model, verbose=1).fit() - - # Plot the model - fig, ax = plt.subplots() - ap.plots.model_image(fig, ax, model) - plt.show() - - # Sample the model - img = model() - pixels = img.data - - AstroPhot models are one of the main ways that one interacts with - the code, either by setting model parameters or passing models to - other objects, one can perform a huge variety of fitting - tasks. The subclass `Component_Model` should be thought of as the - basic unit when constructing a model of an image while a - `Group_Model` is a composite structure that may represent a - complex object, a region of an image, or even a model spanning - many images. Constructing the `Component_Model`s is where most - work goes, these store the actual parameters that will be - optimized. It is important to remember that a `Component_Model` - only ever applies to a single image and a single component (star, - galaxy, or even sub-component of one of those) in that image. - - A complex representation is made by stacking many - `Component_Model`s together, in total this may result in a very - large number of parameters. Trying to find starting values for all - of these parameters can be tedious and error prone, so instead all - built-in AstroPhot models can self initialize and find reasonable - starting parameters for most situations. Even still one may find - that for extremely complex fits, it is more stable to first run an - iterative fitter before global optimization to start the models in - better initial positions. - - Args: - name (Optional[str]): every AstroPhot model should have a unique name - model_type (str): a model type string can determine which kind of AstroPhot model is instantiated. - target (Optional[Target_Image]): A Target_Image object which stores information about the image which the model is trying to fit. - filename (Optional[str]): name of a file to load AstroPhot parameters, window, and name. The model will still need to be told its target, device, and other information - - """ - - model_type = "model" - default_uncertainty = 1e-2 # During initialization, uncertainty will be assumed 1% of initial value if no uncertainty is given - usable = False - model_names = [] - special_kwargs = ["parameters", "filename", "model_type", "usable"] - - def __new__(cls, *, filename=None, model_type=None, **kwargs): - if filename is not None: - state = AstroPhot_Model.load(filename) - MODELS = AstroPhot_Model.List_Models() - for M in MODELS: - if M.model_type == state["model_type"]: - return super(AstroPhot_Model, cls).__new__(M) - else: - raise UnrecognizedModel(f"Unknown AstroPhot model type: {state['model_type']}") - elif model_type is not None: - MODELS = AstroPhot_Model.List_Models() # all_subclasses(AstroPhot_Model) - for M in MODELS: - if M.model_type == model_type: - return super(AstroPhot_Model, cls).__new__(M) - else: - raise UnrecognizedModel(f"Unknown AstroPhot model type: {model_type}") - - return super().__new__(cls) - - def __init__(self, *, name=None, target=None, window=None, locked=False, **kwargs): - if not hasattr(self, "_window"): - self._window = None - if not hasattr(self, "_target"): - self._target = None - self.name = name - AP_config.ap_logger.debug("Creating model named: {self.name}") - self.parameters = Parameter_Node(self.name) - self.target = target - self.window = window - self._locked = locked - self.mask = kwargs.get("mask", None) - - # Set any user defined attributes for the model - for kwarg in kwargs: - # Skip parameters with special behaviour - if kwarg in self.special_kwargs: - continue - # Set the model parameter - setattr(self, kwarg, kwargs[kwarg]) - - @property - def name(self): - """The name for this model as a string. The name should be unique - though this is not enforced here. The name should not contain - the `|` or `:` characters as these are reserved for internal - use. If one tries to set the name of a model as `None` (for - example by not providing a name for the model) then a new - unique name will be generated. The unique name is just the - model type for this model with an extra unique id appended to - the end in the format of `[#]` where `#` is a number that - increases until a unique name is found. - - """ - return self._name - - @name.setter - def name(self, name): - try: - if name == self.name: - return - except AttributeError: - pass - if name is None: - i = 0 - while True: - proposed_name = f"{self.model_type} [{i}]" - if proposed_name in AstroPhot_Model.model_names: - i += 1 - else: - name = proposed_name - break - if ":" in name or "|" in name: - raise NameNotAllowed( - "characters '|' and ':' are reserved for internal model operations please do not include these in a model name" - ) - self._name = name - AstroPhot_Model.model_names.append(name) - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - """When this function finishes, all parameters should have numerical - values (non None) that are reasonable estimates of the final - values. - - """ - pass - - @property - def is_initialized(self): - """Returns True if all parameters have been initialized.""" - return all((not P.leaf) or (P.value is not None) for P in self.parameters) - - def make_model_image(self, window: Optional[Window] = None): - """This is called to create a blank `Model_Image` object of the - correct format for this model. This is typically used - internally to construct the model image before filling the - pixel values with the model. - - """ - if window is None: - window = self.window - else: - window = self.window & window - return self.target[window].model_image() - - def sample(self, image=None, window=None, parameters=None, *args, **kwargs): - """Calling this function should fill the given image with values - sampled from the given model. - - """ - pass - - def fit_mask(self): - """ - Return a mask to be used for fitting this model. This will block out - pixels that are not relevant to the model. - """ - return torch.zeros_like(self.target[self.window].mask) - - def negative_log_likelihood( - self, - parameters=None, - as_representation=False, - ): - """ - Compute the negative log likelihood of the model wrt the target image in the appropriate window. - """ - if parameters is not None: - if as_representation: - self.parameters.vector_set_representation(parameters) - else: - self.parameters.vector_set_values(parameters) - - model = self.sample() - data = self.target[self.window] - weight = data.weight - if self.target.has_mask: - if isinstance(data, Target_Image_List): - mask = tuple(torch.logical_not(submask) for submask in data.mask) - chi2 = sum( - torch.sum(((mo - da).data ** 2 * wgt)[ma]) / 2.0 - for mo, da, wgt, ma in zip(model, data, weight, mask) - ) - else: - mask = torch.logical_not(data.mask) - chi2 = torch.sum(((model - data).data ** 2 * weight)[mask]) / 2.0 - else: - if isinstance(data, Target_Image_List): - chi2 = sum( - torch.sum(((mo - da).data ** 2 * wgt)) / 2.0 - for mo, da, wgt in zip(model, data, weight) - ) - else: - chi2 = torch.sum(((model - data).data ** 2 * weight)) / 2.0 - - return chi2 - - def jacobian( - self, - parameters=None, - **kwargs, - ): - raise NotImplementedError("please use a subclass of AstroPhot_Model") - - @default_internal - def total_flux(self, parameters=None, window=None): - F = self(parameters=parameters, window=window, image=None) - return torch.sum(F.data) - - @default_internal - def total_flux_uncertainty(self, parameters=None, window=None): - current_state = parameters.vector_values() - jac = self.jacobian(parameters=current_state, window=window).flatten("data") - dF = torch.sum(jac, dim=0) # VJP for sum(total_flux) - current_uncertainty = self.parameters.vector_uncertainty() - return torch.sqrt(torch.sum((dF * current_uncertainty) ** 2)) - - @default_internal - def total_magnitude(self, parameters=None, window=None): - """Returns the total magnitude of the model in the given window.""" - F = self.total_flux(parameters=parameters, window=window) - return -2.5 * torch.log10(F) + self.target.header.zeropoint - - @default_internal - def total_magnitude_uncertainty(self, parameters=None, window=None): - """Returns the uncertainty in the total magnitude of the model in the given window.""" - F = self.total_flux(parameters=parameters, window=window) - dF = self.total_flux_uncertainty(parameters=parameters, window=window) - return torch.abs(2.5 * dF / (F * np.log(10))) - - @property - def window(self): - """The window defines a region on the sky in which this model will be - optimized and typically evaluated. Two models with - non-overlapping windows are in effect independent of each - other. If there is another model with a window that spans both - of them, then they are tenuously connected. - - If not provided, the model will assume a window equal to the - target it is fitting. Note that in this case the window is not - explicitly set to the target window, so if the model is moved - to another target then the fitting window will also change. - - """ - if self._window is None: - if self.target is None: - raise ValueError( - "This model has no target or window, these must be provided by the user" - ) - return self.target.window.copy() - return self._window - - def set_window(self, window): - if window is None: - # If no window given, set to none - self._window = None - elif isinstance(window, Window): - # If window object given, use that - self._window = window - elif len(window) == 2: - # If window given in pixels, use relative to target - self._window = self.target.window.copy().crop_to_pixel(window) - else: - raise InvalidWindow(f"Unrecognized window format: {str(window)}") - - @window.setter - def window(self, window): - self.set_window(window) - - @property - def target(self): - return self._target - - @target.setter - def target(self, tar): - if not (tar is None or isinstance(tar, Target_Image)): - raise InvalidTarget("AstroPhot_Model target must be a Target_Image instance.") - self._target = tar - - @property - def locked(self): - """Set when the model should remain fixed going forward. This model - will be bypassed when fitting parameters, however it will - still be sampled for generating the model image. - - Warning: - - This feature is not yet fully functional and should be avoided for now. It is included here for the sake of testing. - - """ - return self._locked - - @locked.setter - def locked(self, val): - self._locked = val - - @property - def parameter_order(self): - """Returns the model parameters in the order they are kept for - flattening, such as when evaluating the model with a tensor of - parameter values. - - """ - return tuple(P.name for P in self.parameters) - - def __str__(self): - """String representation for the model.""" - return self.parameters.__str__() - - def __repr__(self): - """Detailed string representation for the model.""" - return yaml.dump(self.get_state(), indent=2) - - def get_state(self, *args, **kwargs): - """Returns a dictionary of the state of the model with its name, - type, parameters, and other important information. This - dictionary is what gets saved when a model saves to disk. - - """ - state = { - "name": self.name, - "model_type": self.model_type, - } - return state - - def save(self, filename="AstroPhot.yaml"): - """Saves a model object to disk. By default the file type should be - yaml, this is the only file type which gets tested, though - other file types such as json and hdf5 should work. - - """ - if filename.endswith(".yaml"): - state = self.get_state() - with open(filename, "w") as f: - yaml.dump(state, f, indent=2) - elif filename.endswith(".json"): - import json - - state = self.get_state() - with open(filename, "w") as f: - json.dump(state, f, indent=2) - elif filename.endswith(".hdf5"): - import h5py - - state = self.get_state() - with h5py.File(filename, "w") as F: - dict_to_hdf5(F, state) - else: - if isinstance(filename, str) and "." in filename: - raise ValueError( - f"Unrecognized filename format: {filename[filename.find('.'):]}, must be one of: .json, .yaml, .hdf5" - ) - else: - raise ValueError( - f"Unrecognized filename format: {str(filename)}, must be one of: .json, .yaml, .hdf5" - ) - - @classmethod - def load(cls, filename="AstroPhot.yaml"): - """ - Loads a saved model object. - """ - if isinstance(filename, dict): - state = filename - elif isinstance(filename, io.TextIOBase): - state = yaml.load(filename, Loader=yaml.FullLoader) - elif filename.endswith(".yaml"): - with open(filename, "r") as f: - state = yaml.load(f, Loader=yaml.FullLoader) - elif filename.endswith(".json"): - import json - - with open(filename, "r") as f: - state = json.load(f) - elif filename.endswith(".hdf5"): - import h5py - - with h5py.File(filename, "r") as F: - state = hdf5_to_dict(F) - else: - if isinstance(filename, str) and "." in filename: - raise ValueError( - f"Unrecognized filename format: {filename[filename.find('.'):]}, must be one of: .json, .yaml, .hdf5" - ) - else: - raise ValueError( - f"Unrecognized filename format: {str(filename)}, must be one of: .json, .yaml, .hdf5 or python dictionary." - ) - return state - - @classmethod - def List_Models(cls, usable=None): - MODELS = all_subclasses(cls) - if usable is not None: - for model in list(MODELS): - if model.usable is not usable: - MODELS.remove(model) - return MODELS - - @classmethod - def List_Model_Names(cls, usable=None): - MODELS = cls.List_Models(usable=usable) - names = [] - for model in MODELS: - names.append(model.model_type) - return list(sorted(names, key=lambda n: n[::-1])) - - def __eq__(self, other): - return self is other - - def __getitem__(self, key): - return self.parameters[key] - - def __contains__(self, key): - return self.parameters.__contains__(key) - - def __del__(self): - try: - i = AstroPhot_Model.model_names.index(self.name) - AstroPhot_Model.model_names.pop(i) - except: - pass - - @select_sample - def __call__( - self, - image=None, - parameters=None, - window=None, - as_representation=False, - **kwargs, - ): - - if parameters is None: - parameters = self.parameters - elif isinstance(parameters, torch.Tensor): - if as_representation: - self.parameters.vector_set_representation(parameters) - else: - self.parameters.vector_set_values(parameters) - parameters = self.parameters - return self.sample(image=image, window=window, parameters=parameters, **kwargs) diff --git a/astrophot/models/edgeon.py b/astrophot/models/edgeon.py new file mode 100644 index 00000000..115e2334 --- /dev/null +++ b/astrophot/models/edgeon.py @@ -0,0 +1,136 @@ +from typing import Tuple +import torch +import numpy as np + +from .model_object import ComponentModel +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from . import func +from ..backend_obj import backend, ArrayLike +from ..param import forward + +__all__ = ["EdgeonModel", "EdgeonSech", "EdgeonIsothermal"] + + +class EdgeonModel(ComponentModel): + """General Edge-On galaxy model to be subclassed for any specific + representation such as radial light profile or the structure of + the galaxy on the sky. Defines an edgeon galaxy as an object with + a position angle, no inclination information is included. + + **Parameters:** + - `PA`: Position angle of the edgeon disk in radians. + + """ + + _model_type = "edgeon" + _parameter_specs = { + "PA": { + "units": "radians", + "valid": (0, np.pi), + "cyclic": True, + "shape": (), + "dynamic": True, + }, + } + usable = False + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + if self.PA.initialized: + return + target_area = self.target[self.window] + dat = backend.to_numpy(target_area._data).copy() + mask = backend.to_numpy(target_area._mask) + dat[mask] = np.median(dat[~mask]) + + edge = np.concatenate((dat[:, 0], dat[:, -1], dat[0, :], dat[-1, :])) + edge_average = np.median(edge) + dat = dat - edge_average + + x, y = target_area.coordinate_center_meshgrid() + x = backend.to_numpy(x - self.center.value[0]) + y = backend.to_numpy(y - self.center.value[1]) + mu20 = np.median(dat * np.abs(x)) + mu02 = np.median(dat * np.abs(y)) + mu11 = np.median(dat * x * y / np.sqrt(np.abs(x * y))) + M = np.array([[mu20, mu11], [mu11, mu02]]) + if np.any(np.iscomplex(M)) or np.any(~np.isfinite(M)): + self.PA.value = np.pi / 2 + else: + self.PA.value = (0.5 * np.arctan2(2 * mu11, mu20 - mu02)) % np.pi + + @forward + def transform_coordinates( + self, x: ArrayLike, y: ArrayLike, PA: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + x, y = super().transform_coordinates(x, y) + return func.rotate(-(PA + np.pi / 2), x, y) + + +class EdgeonSech(EdgeonModel): + """An edgeon profile where the vertical distribution is a sech^2 + profile, subclasses define the radial profile. + + **Parameters:** + - `I0`: The central intensity of the sech^2 profile in flux/arcsec^2. + - `hs`: The scale height of the sech^2 profile in arcseconds. + """ + + _model_type = "sech2" + _parameter_specs = { + "I0": {"units": "flux/arcsec^2", "shape": (), "dynamic": True}, + "hs": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + } + usable = False + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + if self.I0.initialized and self.hs.initialized: + return + target_area = self.target[self.window] + icenter = target_area.plane_to_pixel(*self.center.value) + + if not self.I0.initialized: + chunk = target_area.data[ + int(icenter[0]) - 2 : int(icenter[0]) + 2, + int(icenter[1]) - 2 : int(icenter[1]) + 2, + ] + self.I0.value = backend.mean(chunk) / self.target.pixel_area + if not self.hs.initialized: + self.hs.value = max(self.window.shape) * target_area.pixelscale * 0.1 + + @forward + def brightness(self, x: ArrayLike, y: ArrayLike, I0: ArrayLike, hs: ArrayLike) -> ArrayLike: + x, y = self.transform_coordinates(x, y) + return I0 * self.radial_model(x) / (backend.cosh((y + self.softening) / hs) ** 2) + + +@combine_docstrings +class EdgeonIsothermal(EdgeonSech): + """A self-gravitating locally-isothermal edgeon disk. This comes from + van der Kruit & Searle 1981. + + **Parameters:** + - `rs`: Scale radius of the isothermal disk in arcseconds. + """ + + _model_type = "isothermal" + _parameter_specs = {"rs": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}} + usable = True + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + if self.rs.initialized: + return + self.rs.value = max(self.window.shape) * self.target.pixelscale * 0.4 + + @forward + def radial_model(self, R: ArrayLike, rs: ArrayLike) -> ArrayLike: + Rscaled = backend.abs(R / rs) + return Rscaled * backend.exp(-Rscaled) * backend.bessel_k1(Rscaled + self.softening / rs) diff --git a/astrophot/models/edgeon_model.py b/astrophot/models/edgeon_model.py deleted file mode 100644 index 83f6be84..00000000 --- a/astrophot/models/edgeon_model.py +++ /dev/null @@ -1,191 +0,0 @@ -from typing import Optional - -from scipy.stats import iqr -import torch -import numpy as np - -from .model_object import Component_Model -from ._shared_methods import select_target -from ..utils.initialize import isophotes -from ..utils.angle_operations import Angle_Average -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..param import Param_Unlock, Param_SoftLimits, Parameter_Node -from ..image import Image -from ..utils.conversions.coordinates import ( - Rotate_Cartesian, -) - -__all__ = ["Edgeon_Model"] - - -class Edgeon_Model(Component_Model): - """General Edge-On galaxy model to be subclassed for any specific - representation such as radial light profile or the structure of - the galaxy on the sky. Defines an edgeon galaxy as an object with - a position angle, no inclination information is included. - - """ - - model_type = f"edgeon {Component_Model.model_type}" - parameter_specs = { - "PA": { - "units": "rad", - "limits": (0, np.pi), - "cyclic": True, - "uncertainty": 0.06, - }, - } - _parameter_order = Component_Model._parameter_order + ("PA",) - usable = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters: Optional[Parameter_Node] = None, **kwargs): - super().initialize(target=target, parameters=parameters) - if parameters["PA"].value is not None: - return - target_area = target[self.window] - edge = np.concatenate( - ( - target_area.data[:, 0].detach().cpu().numpy(), - target_area.data[:, -1].detach().cpu().numpy(), - target_area.data[0, :].detach().cpu().numpy(), - target_area.data[-1, :].detach().cpu().numpy(), - ) - ) - edge_average = np.median(edge) - edge_scatter = iqr(edge, rng=(16, 84)) / 2 - icenter = target_area.plane_to_pixel(parameters["center"].value) - - iso_info = isophotes( - target_area.data.detach().cpu().numpy() - edge_average, - (icenter[1].detach().cpu().item(), icenter[0].detach().cpu().item()), - threshold=3 * edge_scatter, - pa=0.0, - q=1.0, - n_isophotes=15, - ) - with Param_Unlock(parameters["PA"]), Param_SoftLimits(parameters["PA"]): - parameters["PA"].value = ( - -( - ( - Angle_Average( - list(iso["phase2"] for iso in iso_info[-int(len(iso_info) / 3) :]) - ) - / 2 - ) - + target.north - ) - ) % np.pi - parameters["PA"].uncertainty = parameters["PA"].value * self.default_uncertainty - - @default_internal - def transform_coordinates(self, X, Y, image=None, parameters=None): - return Rotate_Cartesian(-(parameters["PA"].value - image.north), X, Y) - - @default_internal - def evaluate_model( - self, - X=None, - Y=None, - image: Image = None, - parameters: Parameter_Node = None, - **kwargs, - ): - if X is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - XX, YY = self.transform_coordinates(X, Y, image=image, parameters=parameters) - - return self.brightness_model( - torch.abs(XX), torch.abs(YY), image=image, parameters=parameters - ) - - -class Edgeon_Sech(Edgeon_Model): - """An edgeon profile where the vertical distribution is a sech^2 - profile, subclasses define the radial profile. - - """ - - model_type = f"sech2 {Edgeon_Model.model_type}" - parameter_specs = { - "I0": {"units": "log10(flux/arcsec^2)"}, - "hs": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Edgeon_Model._parameter_order + ("I0", "hs") - usable = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters: Optional[Parameter_Node] = None, **kwargs): - super().initialize(target=target, parameters=parameters) - if (parameters["I0"].value is not None) and (parameters["hs"].value is not None): - return - target_area = target[self.window] - icenter = target_area.plane_to_pixel(parameters["center"].value) - - if parameters["I0"].value is None: - with Param_Unlock(parameters["I0"]), Param_SoftLimits(parameters["I0"]): - parameters["I0"].value = torch.log10( - torch.mean( - target_area.data[ - int(icenter[0]) - 2 : int(icenter[0]) + 2, - int(icenter[1]) - 2 : int(icenter[1]) + 2, - ] - ) - / target.pixel_area.item() - ) - parameters["I0"].uncertainty = torch.std( - target_area.data[ - int(icenter[0]) - 2 : int(icenter[0]) + 2, - int(icenter[1]) - 2 : int(icenter[1]) + 2, - ] - ) / (torch.abs(parameters["I0"].value) * target.pixel_area) - if parameters["hs"].value is None: - with Param_Unlock(parameters["hs"]), Param_SoftLimits(parameters["hs"]): - parameters["hs"].value = torch.max(self.window.shape) * 0.1 - parameters["hs"].uncertainty = parameters["hs"].value / 2 - - @default_internal - def brightness_model(self, X, Y, image=None, parameters=None): - return ( - (image.pixel_area * 10 ** parameters["I0"].value) - * self.radial_model(X, image=image, parameters=parameters) - / (torch.cosh((Y + self.softening) / parameters["hs"].value) ** 2) - ) - - -class Edgeon_Isothermal(Edgeon_Sech): - """A self-gravitating locally-isothermal edgeon disk. This comes from - van der Kruit & Searle 1981. - - """ - - model_type = f"isothermal {Edgeon_Sech.model_type}" - parameter_specs = { - "rs": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Edgeon_Sech._parameter_order + ("rs",) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters: Optional[Parameter_Node] = None, **kwargs): - super().initialize(target=target, parameters=parameters) - if parameters["rs"].value is not None: - return - with Param_Unlock(parameters["rs"]), Param_SoftLimits(parameters["rs"]): - parameters["rs"].value = torch.max(self.window.shape) * 0.4 - parameters["rs"].uncertainty = parameters["rs"].value / 2 - - @default_internal - def radial_model(self, R, image=None, parameters=None): - Rscaled = torch.abs((R + self.softening) / parameters["rs"].value) - return Rscaled * torch.exp(-Rscaled) * torch.special.scaled_modified_bessel_k1(Rscaled) diff --git a/astrophot/models/eigen_psf.py b/astrophot/models/eigen_psf.py deleted file mode 100644 index 64d09ca0..00000000 --- a/astrophot/models/eigen_psf.py +++ /dev/null @@ -1,145 +0,0 @@ -import torch -import numpy as np - -from .psf_model_object import PSF_Model -from ..image import PSF_Image -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..utils.interpolate import interp2d -from ._shared_methods import select_target -from ..param import Param_Unlock, Param_SoftLimits -from .. import AP_config - -__all__ = ["Eigen_PSF"] - - -class Eigen_PSF(PSF_Model): - """point source model which uses multiple images as a basis for the - PSF as its representation for point sources. Using bilinear - interpolation it will shift the PSF within a pixel to accurately - represent the center location of a point source. There is no - functional form for this object type as any image can be - supplied. Note that as an argument to the model at construction - one can provide "psf" as an AstroPhot PSF_Image object. Since only - bilinear interpolation is performed, it is recommended to provide - the PSF at a higher resolution than the image if it is near the - nyquist sampling limit. Bilinear interpolation is very fast and - accurate for smooth models, so this way it is possible to do the - expensive interpolation before optimization and save time. Note - that if you do this you must provide the PSF as a PSF_Image object - with the correct pixelscale (essentially just divide the - pixelscale by the upsampling factor you used). - - Args: - eigen_basis (tensor): This is the basis set of images used to form the eigen point source, it should be a tensor with shape (N x W x H) where N is the number of eigen images, and W/H are the dimensions of the image. - eigen_pixelscale (float): This is the pixelscale associated with the eigen basis images. - - Parameters: - flux: the total flux of the point source model, represented as the log of the total flux. - weights: the relative amplitude of the Eigen basis modes. - - """ - - model_type = f"eigen {PSF_Model.model_type}" - parameter_specs = { - "flux": {"units": "log10(flux/arcsec^2)", "value": 0.0, "locked": True}, - "weights": {"units": "unitless"}, - } - _parameter_order = PSF_Model._parameter_order + ("flux", "weights") - usable = True - model_integrated = True - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if "eigen_basis" not in kwargs: - AP_config.ap_logger.warning( - "Eigen basis not supplied! Assuming psf as single basis element. Please provide Eigen basis or just use an empirical PSF image." - ) - self.eigen_basis = torch.clone(self.target.data).unsqueeze(0) - self.parameters["weights"].locked = True - else: - self.eigen_basis = torch.as_tensor( - kwargs["eigen_basis"], - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - if kwargs.get("normalize_eigen_basis", True): - self.eigen_basis = self.eigen_basis / torch.sum( - self.eigen_basis, axis=(1, 2) - ).unsqueeze(1).unsqueeze(2) - self.eigen_pixelscale = torch.as_tensor( - kwargs.get( - "eigen_pixelscale", - 1.0 if self.target is None else self.target.pixelscale, - ), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - target_area = target[self.window] - with Param_Unlock(parameters["flux"]), Param_SoftLimits(parameters["flux"]): - if parameters["flux"].value is None: - parameters["flux"].value = torch.log10( - torch.abs(torch.sum(target_area.data)) / target.pixel_area - ) - if parameters["flux"].uncertainty is None: - parameters["flux"].uncertainty = ( - torch.abs(parameters["flux"].value) * self.default_uncertainty - ) - with ( - Param_Unlock(parameters["weights"]), - Param_SoftLimits(parameters["weights"]), - ): - if parameters["weights"].value is None: - W = np.zeros(len(self.eigen_basis)) - W[0] = 1.0 - parameters["weights"].value = W - if parameters["weights"].uncertainty is None: - parameters["weights"].uncertainty = ( - torch.ones_like(parameters["weights"].value) * self.default_uncertainty - ) - - @default_internal - def evaluate_model(self, X=None, Y=None, image=None, parameters=None, **kwargs): - if X is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - - psf_model = PSF_Image( - data=torch.clamp( - torch.sum( - self.eigen_basis.detach() - * (parameters["weights"].value / torch.linalg.norm(parameters["weights"].value)) - .unsqueeze(1) - .unsqueeze(2), - axis=0, - ), - min=0.0, - ), - pixelscale=self.eigen_pixelscale.detach(), - ) - - # Convert coordinates into pixel locations in the psf image - pX, pY = psf_model.plane_to_pixel(X, Y) - - # Select only the pixels where the PSF image is defined - select = torch.logical_and( - torch.logical_and(pX > -0.5, pX < psf_model.data.shape[1] - 0.5), - torch.logical_and(pY > -0.5, pY < psf_model.data.shape[0] - 0.5), - ) - - # Zero everywhere outside the psf - result = torch.zeros_like(X) - - # Use bilinear interpolation of the PSF at the requested coordinates - result[select] = interp2d(psf_model.data, pX[select], pY[select]) - - # Ensure positive values - result = torch.clamp(result, min=0.0) - - return result * (image.pixel_area * 10 ** parameters["flux"].value) diff --git a/astrophot/models/exponential.py b/astrophot/models/exponential.py new file mode 100644 index 00000000..84cb82ef --- /dev/null +++ b/astrophot/models/exponential.py @@ -0,0 +1,59 @@ +from .galaxy_model_object import GalaxyModel +from ..utils.decorators import combine_docstrings +from .psf_model_object import PSFModel +from .mixins import ( + ExponentialMixin, + iExponentialMixin, + RadialMixin, + WedgeMixin, + RayMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, +) + +__all__ = [ + "ExponentialGalaxy", + "ExponentialPSF", + "ExponentialSuperEllipse", + "ExponentialFourierEllipse", + "ExponentialWarp", + "ExponentialRay", + "ExponentialWedge", +] + + +@combine_docstrings +class ExponentialGalaxy(ExponentialMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class ExponentialPSF(ExponentialMixin, RadialMixin, PSFModel): + _parameter_specs = {"Ie": {"units": "flux/arcsec^2", "value": 1.0, "dynamic": False}} + usable = True + + +@combine_docstrings +class ExponentialSuperEllipse(ExponentialMixin, RadialMixin, SuperEllipseMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class ExponentialFourierEllipse(ExponentialMixin, RadialMixin, FourierEllipseMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class ExponentialWarp(ExponentialMixin, RadialMixin, WarpMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class ExponentialRay(iExponentialMixin, RayMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class ExponentialWedge(iExponentialMixin, WedgeMixin, GalaxyModel): + usable = True diff --git a/astrophot/models/exponential_model.py b/astrophot/models/exponential_model.py deleted file mode 100644 index 1f78bea7..00000000 --- a/astrophot/models/exponential_model.py +++ /dev/null @@ -1,388 +0,0 @@ -from typing import Optional - -import torch - -from .galaxy_model_object import Galaxy_Model -from .warp_model import Warp_Galaxy -from .ray_model import Ray_Galaxy -from .psf_model_object import PSF_Model -from .superellipse_model import SuperEllipse_Galaxy, SuperEllipse_Warp -from .foureirellipse_model import FourierEllipse_Galaxy, FourierEllipse_Warp -from .wedge_model import Wedge_Galaxy -from ._shared_methods import ( - parametric_initialize, - parametric_segment_initialize, - select_target, -) -from ..param import Parameter_Node -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..utils.parametric_profiles import exponential_np - -__all__ = [ - "Exponential_Galaxy", - "Exponential_PSF", - "Exponential_SuperEllipse", - "Exponential_SuperEllipse_Warp", - "Exponential_Warp", - "Exponential_Ray", - "Exponential_Wedge", -] - - -def _x0_func(model_params, R, F): - return R[4], F[4] - - -def _wrap_exp(R, re, ie): - return exponential_np(R, re, 10**ie) - - -class Exponential_Galaxy(Galaxy_Model): - """basic galaxy model with a exponential profile for the radial light - profile. The light profile is defined as: - - I(R) = Ie * exp(-b1(R/Re - 1)) - - where I(R) is the brightness as a function of semi-major axis, Ie - is the brightness at the half light radius, b1 is a constant not - involved in the fit, R is the semi-major axis, and Re is the - effective radius. - - Parameters: - Ie: Brightness at half light radius, represented as the log of the brightness divided by pixelscale squared. This is proportional to a surface brightness - Re: half light radius, represented in arcsec. This parameter cannot go below zero. - - """ - - model_type = f"exponential {Galaxy_Model.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Galaxy_Model._parameter_order + ("Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters: Optional[Parameter_Node] = None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_exp, ("Re", "Ie"), _x0_func) - - from ._shared_methods import exponential_radial_model as radial_model - - -class Exponential_PSF(PSF_Model): - """basic point source model with a exponential profile for the radial light - profile. - - I(R) = Ie * exp(-b1(R/Re - 1)) - - where I(R) is the brightness as a function of semi-major axis, Ie - is the brightness at the half light radius, b1 is a constant not - involved in the fit, R is the semi-major axis, and Re is the - effective radius. - - Parameters: - Ie: Brightness at half light radius, represented as the log of the brightness divided by pixelscale squared. This is proportional to a surface brightness - Re: half light radius, represented in arcsec. This parameter cannot go below zero. - - """ - - model_type = f"exponential {PSF_Model.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)", "value": 0.0, "locked": True}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = PSF_Model._parameter_order + ("Re", "Ie") - usable = True - model_integrated = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_exp, ("Re", "Ie"), _x0_func) - - from ._shared_methods import exponential_radial_model as radial_model - from ._shared_methods import radial_evaluate_model as evaluate_model - - -class Exponential_SuperEllipse(SuperEllipse_Galaxy): - """super ellipse galaxy model with a exponential profile for the radial - light profile. - - I(R) = Ie * exp(-b1(R/Re - 1)) - - where I(R) is the brightness as a function of semi-major axis, Ie - is the brightness at the half light radius, b1 is a constant not - involved in the fit, R is the semi-major axis, and Re is the - effective radius. - - Parameters: - Ie: Brightness at half light radius, represented as the log of the brightness divided by pixelscale squared. This is proportional to a surface brightness - Re: half light radius, represented in arcsec. This parameter cannot go below zero. - - """ - - model_type = f"exponential {SuperEllipse_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = SuperEllipse_Galaxy._parameter_order + ("Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_exp, ("Re", "Ie"), _x0_func) - - from ._shared_methods import exponential_radial_model as radial_model - - -class Exponential_SuperEllipse_Warp(SuperEllipse_Warp): - """super ellipse warp galaxy model with a exponential profile for the - radial light profile. - - I(R) = Ie * exp(-b1(R/Re - 1)) - - where I(R) is the brightness as a function of semi-major axis, Ie - is the brightness at the half light radius, b1 is a constant not - involved in the fit, R is the semi-major axis, and Re is the - effective radius. - - Parameters: - Ie: Brightness at half light radius, represented as the log of the brightness divided by pixelscale squared. This is proportional to a surface brightness - Re: half light radius, represented in arcsec. This parameter cannot go below zero. - - """ - - model_type = f"exponential {SuperEllipse_Warp.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = SuperEllipse_Warp._parameter_order + ("Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_exp, ("Re", "Ie"), _x0_func) - - from ._shared_methods import exponential_radial_model as radial_model - - -class Exponential_FourierEllipse(FourierEllipse_Galaxy): - """fourier mode perturbations to ellipse galaxy model with an - exponential profile for the radial light profile. - - I(R) = Ie * exp(-b1(R/Re - 1)) - - where I(R) is the brightness as a function of semi-major axis, Ie - is the brightness at the half light radius, b1 is a constant not - involved in the fit, R is the semi-major axis, and Re is the - effective radius. - - Parameters: - Ie: Brightness at half light radius, represented as the log of the brightness divided by pixelscale squared. This is proportional to a surface brightness - Re: half light radius, represented in arcsec. This parameter cannot go below zero. - - """ - - model_type = f"exponential {FourierEllipse_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = FourierEllipse_Galaxy._parameter_order + ("Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_exp, ("Re", "Ie"), _x0_func) - - from ._shared_methods import exponential_radial_model as radial_model - - -class Exponential_FourierEllipse_Warp(FourierEllipse_Warp): - """fourier mode perturbations to ellipse galaxy model with a exponential - profile for the radial light profile. - - I(R) = Ie * exp(-b1(R/Re - 1)) - - where I(R) is the brightness as a function of semi-major axis, Ie - is the brightness at the half light radius, b1 is a constant not - involved in the fit, R is the semi-major axis, and Re is the - effective radius. - - Parameters: - Ie: Brightness at half light radius, represented as the log of the brightness divided by pixelscale squared. This is proportional to a surface brightness - Re: half light radius, represented in arcsec. This parameter cannot go below zero. - - """ - - model_type = f"exponential {FourierEllipse_Warp.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = FourierEllipse_Warp._parameter_order + ("Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_exp, ("Re", "Ie"), _x0_func) - - from ._shared_methods import exponential_radial_model as radial_model - - -class Exponential_Warp(Warp_Galaxy): - """warped coordinate galaxy model with a exponential profile for the - radial light model. - - I(R) = Ie * exp(-b1(R/Re - 1)) - - where I(R) is the brightness as a function of semi-major axis, Ie - is the brightness at the half light radius, b1 is a constant not - involved in the fit, R is the semi-major axis, and Re is the - effective radius. - - Parameters: - Ie: Brightness at half light radius, represented as the log of the brightness divided by pixelscale squared. This is proportional to a surface brightness - Re: half light radius, represented in arcsec. This parameter cannot go below zero. - - """ - - model_type = f"exponential {Warp_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Warp_Galaxy._parameter_order + ("Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_exp, ("Re", "Ie"), _x0_func) - - from ._shared_methods import exponential_radial_model as radial_model - - -class Exponential_Ray(Ray_Galaxy): - """ray galaxy model with a sersic profile for the radial light - model. The functional form of the Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius. - - Parameters: - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"exponential {Ray_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Ray_Galaxy._parameter_order + ("Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_segment_initialize( - model=self, - parameters=parameters, - target=target, - prof_func=_wrap_exp, - params=("Re", "Ie"), - x0_func=_x0_func, - segments=self.rays, - ) - - from ._shared_methods import exponential_iradial_model as iradial_model - - -class Exponential_Wedge(Wedge_Galaxy): - """wedge galaxy model with a exponential profile for the radial light - model. The functional form of the Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius. - - Parameters: - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"exponential {Wedge_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Wedge_Galaxy._parameter_order + ("Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_segment_initialize( - model=self, - parameters=parameters, - target=target, - prof_func=_wrap_exp, - params=("Re", "Ie"), - x0_func=_x0_func, - segments=self.wedges, - ) - - from ._shared_methods import exponential_iradial_model as iradial_model diff --git a/astrophot/models/ferrer.py b/astrophot/models/ferrer.py new file mode 100644 index 00000000..39c87d70 --- /dev/null +++ b/astrophot/models/ferrer.py @@ -0,0 +1,60 @@ +from .galaxy_model_object import GalaxyModel +from .psf_model_object import PSFModel +from .mixins import ( + FerrerMixin, + RadialMixin, + WedgeMixin, + RayMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, + iFerrerMixin, +) +from ..utils.decorators import combine_docstrings + + +__all__ = ( + "FerrerGalaxy", + "FerrerPSF", + "FerrerSuperEllipse", + "FerrerFourierEllipse", + "FerrerWarp", + "FerrerRay", + "FerrerWedge", +) + + +@combine_docstrings +class FerrerGalaxy(FerrerMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class FerrerPSF(FerrerMixin, RadialMixin, PSFModel): + _parameter_specs = {"I0": {"units": "flux/arcsec^2", "value": 1.0, "dynamic": False}} + usable = True + + +@combine_docstrings +class FerrerSuperEllipse(FerrerMixin, SuperEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class FerrerFourierEllipse(FerrerMixin, FourierEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class FerrerWarp(FerrerMixin, WarpMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class FerrerRay(iFerrerMixin, RayMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class FerrerWedge(iFerrerMixin, WedgeMixin, GalaxyModel): + usable = True diff --git a/astrophot/models/flatsky.py b/astrophot/models/flatsky.py new file mode 100644 index 00000000..75170e81 --- /dev/null +++ b/astrophot/models/flatsky.py @@ -0,0 +1,43 @@ +import numpy as np +import torch + +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from .sky_model_object import SkyModel +from ..backend_obj import backend, ArrayLike +from ..param import forward + +__all__ = ["FlatSky"] + + +@combine_docstrings +class FlatSky(SkyModel): + """Model for the sky background in which all values across the image + are the same. + + **Parameters:** + - `I`: brightness for the sky, represented as the log of the brightness over pixel scale squared, this is proportional to a surface brightness + + """ + + _model_type = "flat" + _parameter_specs = {"I": {"units": "flux/arcsec^2", "dynamic": True}} + usable = True + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if self.I.initialized: + return + + target_area = self.target[self.window] + dat = backend.to_numpy(target_area._data).copy() + mask = backend.to_numpy(target_area._mask) + dat[mask] = np.median(dat[~mask]) + + self.I.value = np.median(dat) / self.target.pixel_area.item() + + @forward + def brightness(self, x: ArrayLike, y: ArrayLike, I: ArrayLike) -> ArrayLike: + return backend.ones_like(x) * I diff --git a/astrophot/models/flatsky_model.py b/astrophot/models/flatsky_model.py deleted file mode 100644 index e9ee06bc..00000000 --- a/astrophot/models/flatsky_model.py +++ /dev/null @@ -1,55 +0,0 @@ -import numpy as np -from scipy.stats import iqr -import torch - -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..param import Param_Unlock, Param_SoftLimits -from .sky_model_object import Sky_Model -from ._shared_methods import select_target - -__all__ = ["Flat_Sky"] - - -class Flat_Sky(Sky_Model): - """Model for the sky background in which all values across the image - are the same. - - Parameters: - sky: brightness for the sky, represented as the log of the brightness over pixel scale squared, this is proportional to a surface brightness - - """ - - model_type = f"flat {Sky_Model.model_type}" - parameter_specs = { - "F": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = Sky_Model._parameter_order + ("F",) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - with Param_Unlock(parameters["F"]), Param_SoftLimits(parameters["F"]): - if parameters["F"].value is None: - parameters["F"].value = torch.log10( - torch.abs(torch.median(target[self.window].data)) / target.pixel_area - ) - if parameters["F"].uncertainty is None: - parameters["F"].uncertainty = ( - ( - iqr( - target[self.window].data.detach().cpu().numpy(), - rng=(31.731 / 2, 100 - 31.731 / 2), - ) - / (2.0 * target.pixel_area.item()) - ) - / np.sqrt(np.prod(self.window.shape.detach().cpu().numpy())) - ) / (10 ** parameters["F"].value.item() * np.log(10)) - - def evaluate_model(self, X=None, Y=None, image=None, parameters=None, **kwargs): - ref = image.data if X is None else X - return torch.ones_like(ref) * (image.pixel_area * 10 ** parameters["F"].value) diff --git a/astrophot/models/foureirellipse_model.py b/astrophot/models/foureirellipse_model.py deleted file mode 100644 index 3cdcf417..00000000 --- a/astrophot/models/foureirellipse_model.py +++ /dev/null @@ -1,242 +0,0 @@ -import torch -import numpy as np - -from ..utils.decorators import ignore_numpy_warnings, default_internal -from .galaxy_model_object import Galaxy_Model -from .warp_model import Warp_Galaxy -from ._shared_methods import select_target -from ..param import Param_Unlock, Param_SoftLimits -from .. import AP_config - -__all__ = ["FourierEllipse_Galaxy", "FourierEllipse_Warp"] - - -class FourierEllipse_Galaxy(Galaxy_Model): - """Expanded galaxy model which includes a Fourier transformation in - its radius metric. This allows for the expression of arbitrarily - complex isophotes instead of pure ellipses. This is a common - extension of the standard elliptical representation. The form of - the Fourier perturbations is: - - R' = R * exp(sum_m(a_m * cos(m * theta + phi_m))) - - where R' is the new radius value, R is the original ellipse - radius, a_m is the amplitude of the m'th Fourier mode, m is the - index of the Fourier mode, theta is the angle around the ellipse, - and phi_m is the phase of the m'th fourier mode. This - representation is somewhat different from other Fourier mode - implementations where instead of an expoenntial it is just 1 + - sum_m(...), we opt for this formulation as it is more numerically - stable. It cannot ever produce negative radii, but to first order - the two representation are the same as can be seen by a Taylor - expansion of exp(x) = 1 + x + O(x^2). - - One can create extremely complex shapes using different Fourier - modes, however usually it is only low order modes that are of - interest. For intuition, the first Fourier mode is roughly - equivalent to a lopsided galaxy, one side will be compressed and - the opposite side will be expanded. The second mode is almost - never used as it is nearly degenerate with ellipticity. The third - mode is an alternate kind of lopsidedness for a galaxy which makes - it somewhat triangular, meaning that it is wider on one side than - the other. The fourth mode is similar to a boxyness/diskyness - parameter which tends to make more pronounced peanut shapes since - it is more rounded than a superellipse representation. Modes - higher than 4 are only useful in very specialized situations. In - general one should consider carefully why the Fourier modes are - being used for the science case at hand. - - Parameters: - am: Tensor of amplitudes for the Fourier modes, indicates the strength of each mode. - phi_m: Tensor of phases for the Fourier modes, adjusts the orientation of the mode perturbation relative to the major axis. It is cyclically defined in the range [0,2pi) - - """ - - model_type = f"fourier {Galaxy_Model.model_type}" - parameter_specs = { - "am": {"units": "none"}, - "phim": {"units": "radians", "limits": (0, 2 * np.pi), "cyclic": True}, - } - _parameter_order = Galaxy_Model._parameter_order + ("am", "phim") - modes = (1, 3, 4) - track_attrs = Galaxy_Model.track_attrs + ["modes"] - usable = False - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.modes = torch.tensor( - kwargs.get("modes", FourierEllipse_Galaxy.modes), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - @default_internal - def angular_metric(self, X, Y, image=None, parameters=None): - return torch.atan2(Y, X) - - @default_internal - def radius_metric(self, X, Y, image=None, parameters=None): - R = super().radius_metric(X, Y, image, parameters) - theta = self.angular_metric(X, Y, image, parameters) - return R * torch.exp( - torch.sum( - parameters["am"].value.view(len(self.modes), -1) - * torch.cos( - self.modes.view(len(self.modes), -1) * theta.view(-1) - + parameters["phim"].value.view(len(self.modes), -1) - ), - 0, - ).view(theta.shape) - ) - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - with Param_Unlock(parameters["am"]), Param_SoftLimits(parameters["am"]): - if parameters["am"].value is None: - parameters["am"].value = torch.zeros( - len(self.modes), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - if parameters["am"].uncertainty is None: - parameters["am"].uncertainty = torch.tensor( - self.default_uncertainty * np.ones(len(self.modes)), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - with Param_Unlock(parameters["phim"]), Param_SoftLimits(parameters["phim"]): - if parameters["phim"].value is None: - parameters["phim"].value = torch.zeros( - len(self.modes), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - if parameters["phim"].uncertainty is None: - parameters["phim"].uncertainty = ( - torch.tensor( # Uncertainty assumed to be 5 degrees if not provided - (5 * np.pi / 180) * np.ones(len(self.modes)), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - ) - - -class FourierEllipse_Warp(Warp_Galaxy): - """Expanded warp galaxy model which includes a Fourier transformation - in its radius metric. This allows for the expression of - arbitrarily complex isophotes instead of pure ellipses. This is a - common extension of the standard elliptical representation. The - form of the Fourier perturbations is: - - R' = R * exp(sum_m(a_m * cos(m * theta + phi_m))) - - where R' is the new radius value, R is the original ellipse - radius, a_m is the amplitude of the m'th Fourier mode, m is the - index of the Fourier mode, theta is the angle around the ellipse, - and phi_m is the phase of the m'th fourier mode. This - representation is somewhat different from other Fourier mode - implementations where instead of an expoenntial it is just 1 + - sum_m(...), we opt for this formulation as it is more numerically - stable. It cannot ever produce negative radii, but to first order - the two representation are the same as can be seen by a Taylor - expansion of exp(x) = 1 + x + O(x^2). - - One can create extremely complex shapes using different Fourier - modes, however usually it is only low order modes that are of - interest. For intuition, the first Fourier mode is roughly - equivalent to a lopsided galaxy, one side will be compressed and - the opposite side will be expanded. The second mode is almost - never used as it is nearly degenerate with ellipticity. The third - mode is an alternate kind of lopsidedness for a galaxy which makes - it somewhat triangular, meaning that it is wider on one side than - the other. The fourth mode is similar to a boxyness/diskyness - parameter which tends to make more pronounced peanut shapes since - it is more rounded than a superellipse representation. Modes - higher than 4 are only useful in very specialized situations. In - general one should consider carefully why the Fourier modes are - being used for the science case at hand. - - Parameters: - am: Tensor of amplitudes for the Fourier modes, indicates the strength of each mode. - phi_m: Tensor of phases for the Fourier modes, adjusts the orientation of the mode perturbation relative to the major axis. It is cyclically defined in the range [0,2pi) - - """ - - model_type = f"fourier {Warp_Galaxy.model_type}" - parameter_specs = { - "am": {"units": "none"}, - "phim": {"units": "radians", "limits": (0, 2 * np.pi), "cyclic": True}, - } - _parameter_order = Warp_Galaxy._parameter_order + ("am", "phim") - modes = (1, 3, 4) - track_attrs = Galaxy_Model.track_attrs + ["modes"] - usable = False - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.modes = torch.tensor( - kwargs.get("modes", FourierEllipse_Warp.modes), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - - @default_internal - def angular_metric(self, X, Y, image=None, parameters=None): - return torch.atan2(Y, X) - - @default_internal - def radius_metric(self, X, Y, image=None, parameters=None): - R = super().radius_metric(X, Y, image, parameters) - theta = self.angular_metric(X, Y, image, parameters) - return R * torch.exp( - torch.sum( - parameters["am"].value.view(len(self.modes), -1) - * torch.cos( - self.modes.view(len(self.modes), -1) * theta.view(-1) - + parameters["phim"].value.view(len(self.modes), -1) - ), - 0, - ).view(theta.shape) - ) - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - with Param_Unlock(parameters["am"]), Param_SoftLimits(parameters["am"]): - if parameters["am"].value is None: - parameters["am"].value = torch.zeros( - len(self.modes), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - if parameters["am"].uncertainty is None: - parameters["am"].uncertainty = torch.tensor( - self.default_uncertainty * np.ones(len(self.modes)), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - with Param_Unlock(parameters["phim"]), Param_SoftLimits(parameters["phim"]): - if parameters["phim"].value is None: - parameters["phim"].value = torch.zeros( - len(self.modes), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - if parameters["phim"].uncertainty is None: - parameters["phim"].uncertainty = torch.tensor( - (5 * np.pi / 180) - * np.ones( - len(self.modes) - ), # Uncertainty assumed to be 5 degrees if not provided - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) diff --git a/astrophot/models/func/__init__.py b/astrophot/models/func/__init__.py new file mode 100644 index 00000000..79e7e8e6 --- /dev/null +++ b/astrophot/models/func/__init__.py @@ -0,0 +1,53 @@ +from .base import all_subclasses +from .integration import ( + quad_table, + pixel_center_integrator, + pixel_simpsons_integrator, + pixel_quad_integrator, + single_quad_integrate, + recursive_quad_integrate, + upsample, + bright_integrate, +) +from .convolution import ( + convolve, + curvature_kernel, +) +from .sersic import sersic, sersic_n_to_b +from .moffat import moffat +from .ferrer import ferrer +from .king import king +from .gaussian import gaussian +from .gaussian_ellipsoid import euler_rotation_matrix +from .exponential import exponential +from .nuker import nuker +from .spline import spline +from .transform import rotate +from .zernike import zernike_n_m_list, zernike_n_m_modes + +__all__ = ( + "all_subclasses", + "quad_table", + "pixel_center_integrator", + "pixel_simpsons_integrator", + "pixel_quad_integrator", + "convolve", + "curvature_kernel", + "sersic", + "sersic_n_to_b", + "moffat", + "ferrer", + "king", + "gaussian", + "euler_rotation_matrix", + "exponential", + "nuker", + "spline", + "single_quad_integrate", + "recursive_quad_integrate", + "upsample", + "bright_integrate", + "rotate", + "zernike_n_m_list", + "zernike_n_m_modes", +) diff --git a/astrophot/models/func/base.py b/astrophot/models/func/base.py new file mode 100644 index 00000000..de9906ca --- /dev/null +++ b/astrophot/models/func/base.py @@ -0,0 +1,4 @@ +def all_subclasses(cls): + return set(cls.__subclasses__()).union( + [s for c in cls.__subclasses__() for s in all_subclasses(c)] + ) diff --git a/astrophot/models/func/convolution.py b/astrophot/models/func/convolution.py new file mode 100644 index 00000000..aea0ecbc --- /dev/null +++ b/astrophot/models/func/convolution.py @@ -0,0 +1,31 @@ +from functools import lru_cache + +from ...backend_obj import backend, ArrayLike + + +def convolve(image: ArrayLike, psf: ArrayLike) -> ArrayLike: + + image_fft = backend.fft.rfft2(image, s=image.shape) + psf_fft = backend.fft.rfft2(psf, s=image.shape) + + convolved_fft = image_fft * psf_fft + convolved = backend.fft.irfft2(convolved_fft, s=image.shape) + return backend.roll( + convolved, + shifts=(-(psf.shape[0] // 2), -(psf.shape[1] // 2)), + dims=(0, 1), + ) + + +@lru_cache(maxsize=32) +def curvature_kernel(dtype, device): + kernel = backend.as_array( + [ + [0.0, 1.0, 0.0], + [1.0, -4.0, 1.0], + [0.0, 1.0, 0.0], + ], # [[1., -2.0, 1.], [-2.0, 4, -2.0], [1.0, -2.0, 1.0]], + device=device, + dtype=dtype, + ) + return kernel diff --git a/astrophot/models/func/exponential.py b/astrophot/models/func/exponential.py new file mode 100644 index 00000000..91fe4250 --- /dev/null +++ b/astrophot/models/func/exponential.py @@ -0,0 +1,16 @@ +from ...backend_obj import backend, ArrayLike +from .sersic import sersic_n_to_b + +b = sersic_n_to_b(1.0) + + +def exponential(R: ArrayLike, Re: ArrayLike, Ie: ArrayLike) -> ArrayLike: + """Exponential 1d profile function, specifically designed for pytorch + operations. + + **Args:** + - `R`: Radius tensor at which to evaluate the exponential function + - `Re`: Effective radius in the same units as R + - `Ie`: Effective surface density + """ + return Ie * backend.exp(-b * ((R / Re) - 1.0)) diff --git a/astrophot/models/func/ferrer.py b/astrophot/models/func/ferrer.py new file mode 100644 index 00000000..b34c82db --- /dev/null +++ b/astrophot/models/func/ferrer.py @@ -0,0 +1,22 @@ +import torch +from ...backend_obj import backend, ArrayLike + + +def ferrer( + R: ArrayLike, rout: ArrayLike, alpha: ArrayLike, beta: ArrayLike, I0: ArrayLike +) -> ArrayLike: + """ + Modified Ferrer profile. + + **Args:** + - `R`: Radius tensor at which to evaluate the modified Ferrer function + - `rout`: Outer radius of the profile + - `alpha`: Power-law index + - `beta`: Exponent for the modified Ferrer function + - `I0`: Central intensity + """ + return backend.where( + R < rout, + I0 * ((1 - (backend.clamp(R, 0, rout) / rout) ** (2 - beta)) ** alpha), + backend.zeros_like(R), + ) diff --git a/astrophot/models/func/gaussian.py b/astrophot/models/func/gaussian.py new file mode 100644 index 00000000..7a4085e1 --- /dev/null +++ b/astrophot/models/func/gaussian.py @@ -0,0 +1,17 @@ +import torch +from ...backend_obj import backend, ArrayLike +import numpy as np + +sq_2pi = np.sqrt(2 * np.pi) + + +def gaussian(R: ArrayLike, sigma: ArrayLike, flux: ArrayLike) -> ArrayLike: + """Gaussian 1d profile function, specifically designed for pytorch + operations. + + **Args:** + - `R`: Radii tensor at which to evaluate the gaussian function + - `sigma`: Standard deviation of the gaussian in the same units as R + - `flux`: Central surface density + """ + return (flux / (sq_2pi * sigma)) * backend.exp(-0.5 * (R / sigma) ** 2) diff --git a/astrophot/models/func/gaussian_ellipsoid.py b/astrophot/models/func/gaussian_ellipsoid.py new file mode 100644 index 00000000..4b07e9cf --- /dev/null +++ b/astrophot/models/func/gaussian_ellipsoid.py @@ -0,0 +1,24 @@ +from ...backend_obj import backend, ArrayLike + + +def euler_rotation_matrix(alpha: ArrayLike, beta: ArrayLike, gamma: ArrayLike) -> ArrayLike: + """Compute the rotation matrix from Euler angles. + + See the Z_alpha X_beta Z_gamma convention for the order of rotations here: + https://en.wikipedia.org/wiki/Euler_angles + """ + ca = backend.cos(alpha) + sa = backend.sin(alpha) + cb = backend.cos(beta) + sb = backend.sin(beta) + cg = backend.cos(gamma) + sg = backend.sin(gamma) + R = backend.stack( + ( + backend.stack((ca * cg - cb * sa * sg, -ca * sg - cb * cg * sa, sb * sa)), + backend.stack((cg * sa + ca * cb * sg, ca * cb * cg - sa * sg, -ca * sb)), + backend.stack((sb * sg, sb * cg, cb)), + ), + dim=-1, + ) + return R diff --git a/astrophot/models/func/integration.py b/astrophot/models/func/integration.py new file mode 100644 index 00000000..b5009ba8 --- /dev/null +++ b/astrophot/models/func/integration.py @@ -0,0 +1,140 @@ +from typing import Tuple +import numpy as np + +from ...utils.integration import quad_table +from ...backend_obj import backend, ArrayLike +from ... import config + + +def pixel_center_integrator(Z: ArrayLike) -> ArrayLike: + return Z + + +def pixel_simpsons_integrator(Z: ArrayLike) -> ArrayLike: + kernel = ( + backend.as_array( + [[[[1, 4, 1], [4, 16, 4], [1, 4, 1]]]], dtype=config.DTYPE, device=config.DEVICE + ) + / 36.0 + ) + Z = backend.conv2d(Z.reshape(1, 1, *Z.shape), kernel, padding="valid", stride=2) + return Z.squeeze(0).squeeze(0) + + +def pixel_quad_integrator(Z: ArrayLike, w: ArrayLike = None, order: int = 3) -> ArrayLike: + """ + Integrate the pixel values using quadrature weights. + + **Args:** + - `Z`: The tensor containing pixel values. + - `w`: The quadrature weights. + - `order`: The order of the quadrature. + """ + if w is None: + _, _, w = quad_table(order, config.DTYPE, config.DEVICE) + Z = Z * w + return backend.sum(Z, dim=-1) + + +def upsample(i: ArrayLike, j: ArrayLike, order: int, scale: float) -> Tuple[ArrayLike, ArrayLike]: + dp = ( + backend.linspace(-1, 1, order, dtype=config.DTYPE, device=config.DEVICE) + * (order - 1) + / (2.0 * order) + ) + di, dj = backend.meshgrid(dp, dp, indexing="xy") + + si = backend.repeat(i[..., None], order**2, -1) + scale * di.flatten() + sj = backend.repeat(j[..., None], order**2, -1) + scale * dj.flatten() + return si, sj + + +def single_quad_integrate( + i: ArrayLike, j: ArrayLike, brightness_ij, scale: float, quad_order: int = 3 +) -> Tuple[ArrayLike, ArrayLike]: + di, dj, w = quad_table(quad_order, config.DTYPE, config.DEVICE) + qi = backend.repeat(i[..., None], quad_order**2, -1) + scale * di.flatten() + qj = backend.repeat(j[..., None], quad_order**2, -1) + scale * dj.flatten() + z = brightness_ij(qi, qj) + z0 = backend.mean(z, dim=-1) + z = backend.sum(z * w.flatten(), dim=-1) + return z, z0 + + +def recursive_quad_integrate( + i: ArrayLike, + j: ArrayLike, + brightness_ij: callable, + curve_frac: float, + scale: float = 1.0, + quad_order: int = 3, + gridding: int = 5, + _current_depth: int = 0, + max_depth: int = 1, +) -> ArrayLike: + z, z0 = single_quad_integrate(i, j, brightness_ij, scale, quad_order) + + if _current_depth >= max_depth: + return z + + N = max(1, int(np.prod(z.shape) * curve_frac)) + select = backend.topk(backend.abs(z - z0).flatten(), N)[1] + + integral_flat = z.flatten() + + si, sj = upsample(i.flatten()[select], j.flatten()[select], gridding, scale) + + integral_flat = backend.fill_at_indices( + integral_flat, + select, + backend.mean( + recursive_quad_integrate( + si, + sj, + brightness_ij, + curve_frac=curve_frac, + scale=scale / gridding, + quad_order=quad_order, + gridding=gridding, + _current_depth=_current_depth + 1, + max_depth=max_depth, + ), + dim=-1, + ), + ) + + return integral_flat.reshape(z.shape) + + +def bright_integrate( + z: ArrayLike, + i: ArrayLike, + j: ArrayLike, + brightness_ij: callable, + bright_frac: float, + scale: float = 1.0, + quad_order: int = 3, + gridding: int = 5, + max_depth: int = 2, +): + trace = [] + for d in range(max_depth): + N = max(1, int(np.prod(z.shape) * bright_frac)) + z_flat = z.flatten() + select = backend.topk(z_flat, N)[1] + trace.append([z_flat, select, z.shape]) + if d > 0: + i, j = upsample(i.flatten()[select], j.flatten()[select], gridding, scale) + scale = scale / gridding + else: + i, j = i.flatten()[select].reshape(-1, 1), j.flatten()[select].reshape(-1, 1) + z, _ = single_quad_integrate(i, j, brightness_ij, scale, quad_order) + trace.append([z, None, z.shape]) + + for _ in reversed(range(1, max_depth + 1)): + T = trace.pop(-1) + trace[-1][0] = backend.fill_at_indices( + trace[-1][0], trace[-1][1], backend.mean(T[0].reshape(T[2]), dim=-1) + ) + + return trace[0][0].reshape(trace[0][2]) diff --git a/astrophot/models/func/king.py b/astrophot/models/func/king.py new file mode 100644 index 00000000..7246160b --- /dev/null +++ b/astrophot/models/func/king.py @@ -0,0 +1,21 @@ +from ...backend_obj import backend, ArrayLike + + +def king(R: ArrayLike, Rc: ArrayLike, Rt: ArrayLike, alpha: ArrayLike, I0: ArrayLike) -> ArrayLike: + """ + Empirical King profile. + + **Args:** + - `R`: Radial distance from the center of the profile. + - `Rc`: Core radius of the profile. + - `Rt`: Truncation radius of the profile. + - `alpha`: Power-law index of the profile. + - `I0`: Central intensity of the profile. + """ + beta = 1 / (1 + (Rt / Rc) ** 2) ** (1 / alpha) + gamma = 1 / (1 + (R / Rc) ** 2) ** (1 / alpha) + return backend.where( + R < Rt, + I0 * ((backend.clamp(gamma, 0, 1) - beta) / (1 - beta)) ** alpha, + backend.zeros_like(R), + ) diff --git a/astrophot/models/func/moffat.py b/astrophot/models/func/moffat.py new file mode 100644 index 00000000..d50a0c3a --- /dev/null +++ b/astrophot/models/func/moffat.py @@ -0,0 +1,14 @@ +from ...backend_obj import ArrayLike + + +def moffat(R: ArrayLike, n: ArrayLike, Rd: ArrayLike, I0: ArrayLike) -> ArrayLike: + """Moffat 1d profile function + + **Args:** + - `R`: Radii tensor at which to evaluate the moffat function + - `n`: concentration index + - `Rd`: scale length in the same units as R + - `I0`: central surface density + + """ + return I0 / (1 + (R / Rd) ** 2) ** n diff --git a/astrophot/models/func/nuker.py b/astrophot/models/func/nuker.py new file mode 100644 index 00000000..a5f34b25 --- /dev/null +++ b/astrophot/models/func/nuker.py @@ -0,0 +1,28 @@ +from ...backend_obj import ArrayLike + + +def nuker( + R: ArrayLike, + Rb: ArrayLike, + Ib: ArrayLike, + alpha: ArrayLike, + beta: ArrayLike, + gamma: ArrayLike, +) -> ArrayLike: + """Nuker 1d profile function + + **Args:** + - `R`: Radii tensor at which to evaluate the nuker function + - `Ib`: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. + - `Rb`: scale length radius + - `alpha`: sharpness of transition between power law slopes + - `beta`: outer power law slope + - `gamma`: inner power law slope + + """ + return ( + Ib + * (2 ** ((beta - gamma) / alpha)) + * ((R / Rb) ** (-gamma)) + * ((1 + (R / Rb) ** alpha) ** ((gamma - beta) / alpha)) + ) diff --git a/astrophot/models/func/sersic.py b/astrophot/models/func/sersic.py new file mode 100644 index 00000000..79165fd7 --- /dev/null +++ b/astrophot/models/func/sersic.py @@ -0,0 +1,32 @@ +from ...backend_obj import backend, ArrayLike + + +C1 = 4 / 405 +C2 = 46 / 25515 +C3 = 131 / 1148175 +C4 = -2194697 / 30690717750 + + +def sersic_n_to_b(n: float) -> float: + """Compute the `b(n)` for a sersic model. This factor ensures that + the $R_e$ and $I_e$ parameters do in fact correspond + to the half light values and not some other scale + radius/intensity. + + """ + x = 1 / n + return 2 * n - 1 / 3 + x * (C1 + x * (C2 + x * (C3 + C4 * x))) + + +def sersic(R: ArrayLike, n: ArrayLike, Re: ArrayLike, Ie: ArrayLike) -> ArrayLike: + """Seric 1d profile function, specifically designed for pytorch + operations + + **Args:** + - `R`: Radii tensor at which to evaluate the sersic function + - `n`: sersic index restricted to n > 0.36 + - `Re`: Effective radius in the same units as R + - `Ie`: Effective surface density + """ + bn = sersic_n_to_b(n) + return Ie * backend.exp(-bn * ((R / Re) ** (1 / n) - 1)) diff --git a/astrophot/models/func/spline.py b/astrophot/models/func/spline.py new file mode 100644 index 00000000..0fdb344b --- /dev/null +++ b/astrophot/models/func/spline.py @@ -0,0 +1,67 @@ +from ...backend_obj import backend, ArrayLike +from ... import config + + +def _h_poly(t: ArrayLike) -> ArrayLike: + """Helper function to compute the 'h' polynomial matrix used in the + cubic spline. + + Args: + t (Tensor): A 1D tensor representing the normalized x values. + + Returns: + Tensor: A 2D tensor of size (4, len(t)) representing the 'h' polynomial matrix. + + """ + + tt = t[None, :] ** (backend.arange(4, device=config.DEVICE)[:, None]) + A = backend.as_array( + [ + [1.0, 0.0, -3.0, 2.0], + [0.0, 1.0, -2.0, 1.0], + [0.0, 0.0, 3.0, -2.0], + [0.0, 0.0, -1.0, 1.0], + ], + dtype=config.DTYPE, + device=config.DEVICE, + ) + return A @ tt + + +def cubic_spline_torch(x: ArrayLike, y: ArrayLike, xs: ArrayLike) -> ArrayLike: + """Compute the 1D cubic spline interpolation for the given data points + using PyTorch. + + **Args:** + - `x` (Tensor): A 1D tensor representing the x-coordinates of the known data points. + - `y` (Tensor): A 1D tensor representing the y-coordinates of the known data points. + - `xs` (Tensor): A 1D tensor representing the x-coordinates of the positions where + the cubic spline function should be evaluated. + """ + m = (y[1:] - y[:-1]) / (x[1:] - x[:-1]) + m = backend.concatenate([m[0].flatten(), (m[1:] + m[:-1]) / 2, m[-1].flatten()]) + idxs = backend.searchsorted(x[:-1], xs) - 1 + dx = x[idxs + 1] - x[idxs] + hh = _h_poly((xs - x[idxs]) / dx) + ret = hh[0] * y[idxs] + hh[1] * m[idxs] * dx + hh[2] * y[idxs + 1] + hh[3] * m[idxs + 1] * dx + return ret + + +def spline(R: ArrayLike, profR: ArrayLike, profI: ArrayLike, extend: str = "zeros") -> ArrayLike: + """Spline 1d profile function, cubic spline between points up + to second last point beyond which is linear + + **Args:** + - `R`: Radii tensor at which to evaluate the spline function + - `profR`: radius values for the surface density profile in the same units as `R` + - `profI`: surface density values for the surface density profile + - `extend`: How to extend the spline beyond the last point. Options are 'zeros' or 'const'. + """ + I = cubic_spline_torch(profR, profI, R.flatten()).reshape(*R.shape) + if extend == "zeros": + backend.fill_at_indices(I, R > profR[-1], 0) + elif extend == "const": + backend.fill_at_indices(I, R > profR[-1], profI[-1]) + else: + raise ValueError(f"Unknown extend option: {extend}. Use 'zeros' or 'const'.") + return I diff --git a/astrophot/models/func/transform.py b/astrophot/models/func/transform.py new file mode 100644 index 00000000..b9252589 --- /dev/null +++ b/astrophot/models/func/transform.py @@ -0,0 +1,11 @@ +from typing import Tuple +from ...backend_obj import backend, ArrayLike + + +def rotate(theta: ArrayLike, x: ArrayLike, y: ArrayLike) -> Tuple[ArrayLike, ArrayLike]: + """ + Applies a rotation matrix to the X,Y coordinates + """ + s = backend.sin(theta) + c = backend.cos(theta) + return c * x - s * y, s * x + c * y diff --git a/astrophot/models/func/zernike.py b/astrophot/models/func/zernike.py new file mode 100644 index 00000000..34efa822 --- /dev/null +++ b/astrophot/models/func/zernike.py @@ -0,0 +1,38 @@ +from functools import lru_cache +from scipy.special import binom +import numpy as np + + +@lru_cache(maxsize=1024) +def coefficients(n: int, m: int) -> list[tuple[int, float]]: + C = [] + for k in range(int((n - abs(m)) / 2) + 1): + C.append( + ( + k, + (-1) ** k * binom(n - k, k) * binom(n - 2 * k, (n - abs(m)) / 2 - k), + ) + ) + return C + + +def zernike_n_m_list(n: int) -> list[tuple[int, int]]: + nm = [] + for n_i in range(n + 1): + for m_i in range(-n_i, n_i + 1, 2): + nm.append((n_i, m_i)) + return nm + + +def zernike_n_m_modes(rho: np.ndarray, phi: np.ndarray, n: int, m: int) -> np.ndarray: + Z = np.zeros_like(rho) + for k, c in coefficients(n, m): + R = rho ** (n - 2 * k) + T = 1.0 + if m < 0: + T = np.sin(abs(m) * phi) + elif m > 0: + T = np.cos(m * phi) + + Z = Z + c * R * T + return Z * (rho <= 1).astype(np.float64) diff --git a/astrophot/models/galaxy_model_object.py b/astrophot/models/galaxy_model_object.py index bbe2ec57..53beb529 100644 --- a/astrophot/models/galaxy_model_object.py +++ b/astrophot/models/galaxy_model_object.py @@ -1,116 +1,14 @@ -from typing import Optional +from .model_object import ComponentModel +from .mixins import InclinedMixin +from ..utils.decorators import combine_docstrings -import torch -import numpy as np -from scipy.stats import iqr -from ..utils.initialize import isophotes -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..utils.angle_operations import Angle_COM_PA -from ..utils.conversions.coordinates import ( - Rotate_Cartesian, -) -from ..param import Param_Unlock, Param_SoftLimits, Parameter_Node -from .model_object import Component_Model -from ._shared_methods import select_target +__all__ = ["GalaxyModel"] -__all__ = ["Galaxy_Model"] +@combine_docstrings +class GalaxyModel(InclinedMixin, ComponentModel): + """Intended to represent a galaxy or extended component in an image.""" - -class Galaxy_Model(Component_Model): - """General galaxy model to be subclassed for any specific - representation. Defines a galaxy as an object with a position - angle and axis ratio, or effectively a tilted disk. Most - subclassing models should simply define a radial model or update - to the coordinate transform. The definition of the position angle and axis ratio used here is simply a scaling along the minor axis. The transformation can be written as: - - X, Y = meshgrid(image) - X', Y' = Rot(theta, X, Y) - Y'' = Y' / q - - where X Y are the coordinates of an image, X' Y' are the rotated - coordinates, Rot is a rotation matrix by angle theta applied to the - initial X Y coordinates, Y'' is the scaled semi-minor axis, and q - is the axis ratio. - - Parameters: - q: axis ratio to scale minor axis from the ratio of the minor/major axis b/a, this parameter is unitless, it is restricted to the range (0,1) - PA: position angle of the smei-major axis relative to the image positive x-axis in radians, it is a cyclic parameter in the range [0,pi) - - """ - - model_type = f"galaxy {Component_Model.model_type}" - parameter_specs = { - "q": {"units": "b/a", "limits": (0, 1), "uncertainty": 0.03}, - "PA": { - "units": "radians", - "limits": (0, np.pi), - "cyclic": True, - "uncertainty": 0.06, - }, - } - _parameter_order = Component_Model._parameter_order + ("q", "PA") + _model_type = "galaxy" usable = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters: Optional[Parameter_Node] = None, **kwargs): - super().initialize(target=target, parameters=parameters) - - if not (parameters["PA"].value is None or parameters["q"].value is None): - return - target_area = target[self.window] - target_dat = target_area.data.detach().cpu().numpy().copy() - if target_area.has_mask: - mask = target_area.mask.detach().cpu().numpy() - target_dat[mask] = np.median(target_dat[np.logical_not(mask)]) - edge = np.concatenate( - ( - target_dat[:, 0], - target_dat[:, -1], - target_dat[0, :], - target_dat[-1, :], - ) - ) - edge_average = np.nanmedian(edge) - edge_scatter = iqr(edge[np.isfinite(edge)], rng=(16, 84)) / 2 - icenter = target_area.plane_to_pixel(parameters["center"].value) - - if parameters["PA"].value is None: - weights = target_dat - edge_average - Coords = target_area.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - X, Y = X.detach().cpu().numpy(), Y.detach().cpu().numpy() - if target_area.has_mask: - seg = np.logical_not(target_area.mask.detach().cpu().numpy()) - PA = Angle_COM_PA(weights[seg], X[seg], Y[seg]) - else: - PA = Angle_COM_PA(weights, X, Y) - - with Param_Unlock(parameters["PA"]), Param_SoftLimits(parameters["PA"]): - parameters["PA"].value = (PA + target_area.north) % np.pi - if parameters["PA"].uncertainty is None: - parameters["PA"].uncertainty = (5 * np.pi / 180) * torch.ones_like( - parameters["PA"].value - ) # default uncertainty of 5 degrees is assumed - if parameters["q"].value is None: - q_samples = np.linspace(0.2, 0.9, 15) - iso_info = isophotes( - target_area.data.detach().cpu().numpy() - edge_average, - (icenter[1].detach().cpu().item(), icenter[0].detach().cpu().item()), - threshold=3 * edge_scatter, - pa=(parameters["PA"].value - target.north).detach().cpu().item(), - q=q_samples, - ) - with Param_Unlock(parameters["q"]), Param_SoftLimits(parameters["q"]): - parameters["q"].value = q_samples[ - np.argmin(list(iso["amplitude2"] for iso in iso_info)) - ] - if parameters["q"].uncertainty is None: - parameters["q"].uncertainty = parameters["q"].value * self.default_uncertainty - - from ._shared_methods import inclined_transform_coordinates as transform_coordinates - from ._shared_methods import transformed_evaluate_model as evaluate_model diff --git a/astrophot/models/gaussian.py b/astrophot/models/gaussian.py new file mode 100644 index 00000000..900c8241 --- /dev/null +++ b/astrophot/models/gaussian.py @@ -0,0 +1,61 @@ +from .galaxy_model_object import GalaxyModel + +from .psf_model_object import PSFModel +from .mixins import ( + GaussianMixin, + RadialMixin, + WedgeMixin, + RayMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, + iGaussianMixin, +) +from ..utils.decorators import combine_docstrings + + +__all__ = [ + "GaussianGalaxy", + "GaussianPSF", + "GaussianSuperEllipse", + "GaussianFourierEllipse", + "GaussianWarp", + "GaussianRay", + "GaussianWedge", +] + + +@combine_docstrings +class GaussianGalaxy(GaussianMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class GaussianPSF(GaussianMixin, RadialMixin, PSFModel): + _parameter_specs = {"flux": {"units": "flux", "value": 1.0, "dynamic": False}} + usable = True + + +@combine_docstrings +class GaussianSuperEllipse(GaussianMixin, SuperEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class GaussianFourierEllipse(GaussianMixin, FourierEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class GaussianWarp(GaussianMixin, WarpMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class GaussianRay(iGaussianMixin, RayMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class GaussianWedge(iGaussianMixin, WedgeMixin, GalaxyModel): + usable = True diff --git a/astrophot/models/gaussian_ellipsoid.py b/astrophot/models/gaussian_ellipsoid.py new file mode 100644 index 00000000..23fab669 --- /dev/null +++ b/astrophot/models/gaussian_ellipsoid.py @@ -0,0 +1,153 @@ +import torch +import numpy as np + +from .model_object import ComponentModel +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from . import func +from ..param import forward +from ..backend_obj import backend, ArrayLike + +__all__ = ["GaussianEllipsoid"] + + +@combine_docstrings +class GaussianEllipsoid(ComponentModel): + """Model that represents a galaxy as a 3D Gaussian ellipsoid. + + The model is triaxial, meaning it has three different standard deviations + along the three axes. The orientation of the ellipsoid is defined by Euler + angles. + + If all three Euler angles are set to zero, the ellipsoid is aligned with the + image axes meaning sigma_a gives the std along the x axis of the tangent + plane, sigma_b gives the std along the y axis of the tangent plane, and + sigma_z gives the std into the tangent plane. We use the ZXZ convention for + the Euler angles. This means that for a disk galaxy, one can naturally + consider sigma_c as the disk thickness and sigma_a=sigma_b as the disk + radius; setting the Euler angles to zero would leave the disk face-on in the + x-y tangent plane. + + Note: + the model is highly degenerate, meaning that it is not possible to + uniquely determine the parameters from the data. The model is useful if + one already has a 3D model of the galaxy in mind and wants to produce + mock data. Alternately, if one applies some constraints on the + parameters, such as sigma_a = sigma_b and alpha=0, then the model will + be better determined. In that case, beta is related to the inclination + of the disk and gamma is related to the position angle of the disk. The + initialization for this model assumes exactly this interpretation with a + disk thickness of sigma_c = 0.2 *sigma_a. + + **Parameters:** + - `sigma_a`: Standard deviation of the Gaussian along the alpha axis in arcseconds. + - `sigma_b`: Standard deviation of the Gaussian along the beta axis in arcseconds. + - `sigma_c`: Standard deviation of the Gaussian along the gamma axis in arcseconds. + - `alpha`: Euler angle representing the rotation around the alpha axis in radians. + - `beta`: Euler angle representing the rotation around the beta axis in radians. + - `gamma`: Euler angle representing the rotation around the gamma axis in radians. + - `flux`: Total flux of the galaxy in arbitrary units. + + """ + + _model_type = "gaussianellipsoid" + _parameter_specs = { + "sigma_a": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "sigma_b": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "sigma_c": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "alpha": { + "units": "radians", + "valid": (0, 2 * np.pi), + "cyclic": True, + "shape": (), + "dynamic": True, + }, + "beta": { + "units": "radians", + "valid": (0, 2 * np.pi), + "cyclic": True, + "shape": (), + "dynamic": True, + }, + "gamma": { + "units": "radians", + "valid": (0, 2 * np.pi), + "cyclic": True, + "shape": (), + "dynamic": True, + }, + "flux": {"units": "flux", "shape": (), "dynamic": True}, + } + usable = True + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if any(self[key].initialized for key in GaussianEllipsoid._parameter_specs): + return + + self.sigma_b = self.sigma_a + self.sigma_c = lambda p: 0.2 * p.sigma_a.value + self.sigma_c.link(self.sigma_a) + self.alpha = 0.0 + + target_area = self.target[self.window] + dat = backend.to_numpy(target_area._data).copy() + mask = backend.to_numpy(target_area._mask).copy() + dat[mask] = np.median(dat[~mask]) + + edge = np.concatenate((dat[:, 0], dat[:, -1], dat[0, :], dat[-1, :])) + edge_average = np.nanmedian(edge) + dat -= edge_average + + x, y = target_area.coordinate_center_meshgrid() + center = self.center.value + x = x - center[0] + y = y - center[1] + r = backend.to_numpy(self.radius_metric(x, y, params=())) + self.sigma_a.value = np.sqrt(np.sum((r * dat) ** 2) / np.sum(r**2)) + + x = backend.to_numpy(x) + y = backend.to_numpy(y) + + mu20 = np.median(dat * np.abs(x)) + mu02 = np.median(dat * np.abs(y)) + mu11 = np.median(dat * x * y / np.sqrt(np.abs(x * y) + self.softening**2)) + M = np.array([[mu20, mu11], [mu11, mu02]]) + if np.any(np.iscomplex(M)) or np.any(~np.isfinite(M)): + PA = np.pi / 2 + l = (0.7, 1.0) + else: + PA = (0.5 * np.arctan2(2 * mu11, mu20 - mu02) - np.pi / 2) % np.pi + l = np.sort(np.linalg.eigvals(M)) + q = np.clip(np.sqrt(l[0] / l[1]), 0.1, 0.9) + self.beta.value = np.arccos(q) + self.gamma.value = PA + self.flux.value = np.sum(dat) + + @forward + def brightness( + self, + x: ArrayLike, + y: ArrayLike, + sigma_a: ArrayLike, + sigma_b: ArrayLike, + sigma_c: ArrayLike, + alpha: ArrayLike, + beta: ArrayLike, + gamma: ArrayLike, + flux: ArrayLike, + ) -> ArrayLike: + """Brightness of the Gaussian ellipsoid.""" + D = backend.diag(backend.stack((sigma_a, sigma_b, sigma_c)) ** 2) + R = func.euler_rotation_matrix(alpha, beta, gamma) + Sigma = R @ D @ R.T + Sigma2D = Sigma[:2, :2] + inv_Sigma = backend.linalg.inv(Sigma2D) + v = backend.stack(self.transform_coordinates(x, y), dim=0).reshape(2, -1) + return ( + flux + * backend.exp(-0.5 * backend.sum(v * (inv_Sigma @ v), dim=0)) + / (2 * np.pi * backend.sqrt(backend.linalg.det(Sigma2D))) + ).reshape(x.shape) diff --git a/astrophot/models/gaussian_model.py b/astrophot/models/gaussian_model.py deleted file mode 100644 index 8213dc8a..00000000 --- a/astrophot/models/gaussian_model.py +++ /dev/null @@ -1,378 +0,0 @@ -import torch - -from .galaxy_model_object import Galaxy_Model -from .warp_model import Warp_Galaxy -from .superellipse_model import SuperEllipse_Galaxy, SuperEllipse_Warp -from .foureirellipse_model import FourierEllipse_Galaxy, FourierEllipse_Warp -from .ray_model import Ray_Galaxy -from .wedge_model import Wedge_Galaxy -from .psf_model_object import PSF_Model -from ._shared_methods import ( - parametric_initialize, - parametric_segment_initialize, - select_target, -) -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..utils.parametric_profiles import gaussian_np - -__all__ = [ - "Gaussian_Galaxy", - "Gaussian_SuperEllipse", - "Gaussian_SuperEllipse_Warp", - "Gaussian_FourierEllipse", - "Gaussian_FourierEllipse_Warp", - "Gaussian_Warp", - "Gaussian_PSF", -] - - -def _x0_func(model_params, R, F): - return R[4], F[0] - - -def _wrap_gauss(R, sig, flu): - return gaussian_np(R, sig, 10**flu) - - -class Gaussian_Galaxy(Galaxy_Model): - """Basic galaxy model with Gaussian as the radial light profile. The - gaussian radial profile is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {Galaxy_Model.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = Galaxy_Model._parameter_order + ("sigma", "flux") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_gauss, ("sigma", "flux"), _x0_func) - - from ._shared_methods import gaussian_radial_model as radial_model - - -class Gaussian_SuperEllipse(SuperEllipse_Galaxy): - """Super ellipse galaxy model with Gaussian as the radial light - profile.The gaussian radial profile is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {SuperEllipse_Galaxy.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = SuperEllipse_Galaxy._parameter_order + ("sigma", "flux") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_gauss, ("sigma", "flux"), _x0_func) - - from ._shared_methods import gaussian_radial_model as radial_model - - -class Gaussian_SuperEllipse_Warp(SuperEllipse_Warp): - """super ellipse warp galaxy model with a gaussian profile for the - radial light profile. The gaussian radial profile is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {SuperEllipse_Warp.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = SuperEllipse_Warp._parameter_order + ("sigma", "flux") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_gauss, ("sigma", "flux"), _x0_func) - - from ._shared_methods import gaussian_radial_model as radial_model - - -class Gaussian_FourierEllipse(FourierEllipse_Galaxy): - """fourier mode perturbations to ellipse galaxy model with a gaussian - profile for the radial light profile. The gaussian radial profile - is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {FourierEllipse_Galaxy.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = FourierEllipse_Galaxy._parameter_order + ("sigma", "flux") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_gauss, ("sigma", "flux"), _x0_func) - - from ._shared_methods import gaussian_radial_model as radial_model - - -class Gaussian_FourierEllipse_Warp(FourierEllipse_Warp): - """fourier mode perturbations to ellipse galaxy model with a gaussian - profile for the radial light profile. The gaussian radial profile - is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {FourierEllipse_Warp.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = FourierEllipse_Warp._parameter_order + ("sigma", "flux") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_gauss, ("sigma", "flux"), _x0_func) - - from ._shared_methods import gaussian_radial_model as radial_model - - -class Gaussian_Warp(Warp_Galaxy): - """Coordinate warped galaxy model with Gaussian as the radial light - profile. The gaussian radial profile is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {Warp_Galaxy.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = Warp_Galaxy._parameter_order + ("sigma", "flux") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_gauss, ("sigma", "flux"), _x0_func) - - from ._shared_methods import gaussian_radial_model as radial_model - - -class Gaussian_PSF(PSF_Model): - """Basic point source model with a Gaussian as the radial light profile. The - gaussian radial profile is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {PSF_Model.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)", "value": 0.0, "locked": True}, - } - _parameter_order = PSF_Model._parameter_order + ("sigma", "flux") - usable = True - model_integrated = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_gauss, ("sigma", "flux"), _x0_func) - - from ._shared_methods import gaussian_radial_model as radial_model - from ._shared_methods import radial_evaluate_model as evaluate_model - - -class Gaussian_Ray(Ray_Galaxy): - """ray galaxy model with a gaussian profile for the radial light - model. The gaussian radial profile is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {Ray_Galaxy.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = Ray_Galaxy._parameter_order + ("sigma", "flux") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_segment_initialize( - model=self, - parameters=parameters, - target=target, - prof_func=_wrap_gauss, - params=("sigma", "flux"), - x0_func=_x0_func, - segments=self.rays, - ) - - from ._shared_methods import gaussian_iradial_model as iradial_model - - -class Gaussian_Wedge(Wedge_Galaxy): - """wedge galaxy model with a gaussian profile for the radial light - model. The gaussian radial profile is defined as: - - I(R) = F * exp(-0.5 R^2/S^2) / sqrt(2pi*S^2) - - where I(R) is the prightness as a function of semi-major axis - length, F is the total flux in the model, R is the semi-major - axis, and S is the standard deviation. - - Parameters: - sigma: standard deviation of the gaussian profile, must be a positive value - flux: the total flux in the gaussian model, represented as the log of the total - - """ - - model_type = f"gaussian {Wedge_Galaxy.model_type}" - parameter_specs = { - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = Wedge_Galaxy._parameter_order + ("sigma", "flux") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_segment_initialize( - self, - parameters, - target, - _wrap_gauss, - ("sigma", "flux"), - _x0_func, - self.wedges, - ) - - from ._shared_methods import gaussian_iradial_model as iradial_model diff --git a/astrophot/models/group_model_object.py b/astrophot/models/group_model_object.py index 01bf77c4..9a85ed38 100644 --- a/astrophot/models/group_model_object.py +++ b/astrophot/models/group_model_object.py @@ -1,30 +1,31 @@ -from typing import Optional, Sequence -from collections import OrderedDict +from typing import Optional, Sequence, Union import torch +import numpy as np +from caskade import forward -from .core_model import AstroPhot_Model -from .. import AP_config +from .base import Model from ..image import ( Image, - Target_Image, - Target_Image_List, - Image_List, + TargetImage, + TargetImageList, + ModelImage, + ModelImageList, + ImageList, Window, - Window_List, - Model_Image, - Model_Image_List, - Jacobian_Image, + WindowList, + JacobianImage, + JacobianImageList, ) -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ._shared_methods import select_target -from ..param import Parameter_Node -from ..errors import InvalidTarget +from .. import config +from ..backend_obj import backend, ArrayLike +from ..utils.decorators import ignore_numpy_warnings +from ..errors import InvalidTarget, InvalidWindow -__all__ = ["Group_Model"] +__all__ = ["GroupModel"] -class Group_Model(AstroPhot_Model): +class GroupModel(Model): """Model object which represents a list of other models. For each general AstroPhot model method, this calls all the appropriate models from its list and combines their output into a single @@ -40,88 +41,60 @@ class Group_Model(AstroPhot_Model): """ - model_type = f"group {AstroPhot_Model.model_type}" + _model_type = "group" usable = True def __init__( self, *, name: Optional[str] = None, - models: Optional[Sequence[AstroPhot_Model]] = None, + models: Optional[Sequence[Model]] = None, **kwargs, ): - if "model" in kwargs: - AP_config.ap_logger.warning( - "kwarg `model` is not used in Group_Model, did you mean `models` instead?" - ) - self._psf_mode = "none" - self._param_tuple = None - self.models = OrderedDict() super().__init__(name=name, **kwargs) - if models is not None: - self.add_model(models) - self.update_window() - if "filename" in kwargs: - self.load(kwargs["filename"], new_name=name) - - def add_model(self, model): - """Adds a new model to the group model list. Ensures that the same - model isn't added a second time. - - Parameters: - model: a model object to add to the model list. + for model in models: + if not isinstance(model, Model): + raise TypeError(f"Expected a Model instance in 'models', got {type(model)}") + self.models = models + self._update_window() - """ - if isinstance(model, (tuple, list)): - for mod in model: - self.add_model(mod) - return - if model.name in self.models: - if model is self.models[model.name]: - return - raise KeyError( - f"{self.name} already has model with name {model.name}, every model must have a unique name." - ) - - self.models[model.name] = model - self.parameters.link(model.parameters) - self.psf_mode = self.psf_mode - self.target = self.target - self.update_window() - - def update_window(self, include_locked: bool = False): + def _update_window(self): """Makes a new window object which encloses all the windows of the sub models in this group model object. """ - if isinstance(self.target, Image_List): # Window_List if target is a Target_Image_List - new_window = [None] * len(self.target.image_list) - for model in self.models.values(): - if model.locked and not include_locked: - continue - if isinstance(model.target, Image_List): + if isinstance(self.target, ImageList): # WindowList if target is a TargetImageList + new_window = list(target.window.copy() for target in self.target) + n_windows = [0] * len(self.target.images) + for model in self.models: + if isinstance(model.target, ImageList): for target, window in zip(model.target, model.window): index = self.target.index(target) - if new_window[index] is None: - new_window[index] = window.copy() + if n_windows[index] == 0: + new_window[index] &= window else: new_window[index] |= window - elif isinstance(model.target, Target_Image): + n_windows[index] += 1 + elif isinstance(model.target, TargetImage): index = self.target.index(model.target) - if new_window[index] is None: - new_window[index] = model.window.copy() + if n_windows[index] == 0: + new_window[index] &= model.window else: new_window[index] |= model.window + n_windows[index] += 1 else: raise NotImplementedError( f"Group_Model cannot construct a window for itself using {type(model.target)} object. Must be a Target_Image" ) - new_window = Window_List(new_window) + new_window = WindowList(new_window) + for i, n in enumerate(n_windows): + if n == 0: + config.logger.warning( + f"Model {self.name} has no sub models in target '{self.target.images[i].name}', this may cause issues with fitting." + ) else: new_window = None - for model in self.models.values(): - if model.locked and not include_locked: - continue + for model in self.models: if new_window is None: new_window = model.window.copy() else: @@ -130,26 +103,15 @@ def update_window(self, include_locked: bool = False): @torch.no_grad() @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target: Optional[Image] = None, parameters=None, **kwargs): + def initialize(self): """ Initialize each model in this group. Does this by iteratively initializing a model then subtracting it from a copy of the target. - - Args: - target (Optional["Target_Image"]): A Target_Image instance to use as the source for initializing the model parameters on this image. """ - self._param_tuple = None - super().initialize(target=target, parameters=parameters) - - target_copy = target.copy() - for model in self.models.values(): - if not model.is_initialized: - print("Initializing: ", model.name) - model.initialize(target=target_copy, parameters=parameters[model.name]) - target_copy -= model(parameters=parameters[model.name]) + for model in self.models: + config.logger.info(f"Initializing model {model.name}") + model.initialize() - def fit_mask(self) -> torch.Tensor: + def _fit_mask(self) -> torch.Tensor: """Returns a mask for the target image which is the combination of all the fit masks of the sub models. This mask is used when the multiple models in the group model do not completely overlap with each other, thus @@ -157,209 +119,279 @@ def fit_mask(self) -> torch.Tensor: reason to be fit. """ - if isinstance(self.target, Image_List): - mask = tuple(torch.ones_like(submask) for submask in self.target[self.window].mask) - for model in self.models.values(): - model_flat_mask = model.fit_mask() - if isinstance(model.target, Image_List): - for target, window, submask in zip(model.target, model.window, model_flat_mask): - index = self.target.index(target) - group_indices = self.window.window_list[index].get_self_indices(window) - model_indices = window.get_self_indices(self.window.window_list[index]) - mask[index][group_indices] &= submask[model_indices] + subtarget = self.target[self.window] + if isinstance(subtarget, ImageList): + mask = list(backend.ones_like(submask) for submask in subtarget._mask) + for model in self.models: + model_subtarget = model.target[model.window] + model_fit_mask = model._fit_mask() + if isinstance(model_subtarget, ImageList): + for target, submask in zip(model_subtarget, model_fit_mask): + index = subtarget.index(target) + group_indices = subtarget.images[index].get_indices(target.window) + model_indices = target.get_indices(subtarget.images[index].window) + mask[index] = backend.and_at_indices( + mask[index], group_indices, submask[model_indices] + ) else: - index = self.target.index(model.target) - group_indices = self.window.window_list[index].get_self_indices(model.window) - model_indices = model.window.get_self_indices(self.window.window_list[index]) - mask[index][group_indices] &= model_flat_mask[model_indices] + index = subtarget.index(model_subtarget) + group_indices = subtarget.images[index].get_indices(model_subtarget.window) + model_indices = model_subtarget.get_indices(subtarget.images[index].window) + mask[index] = backend.and_at_indices( + mask[index], group_indices, model_fit_mask[model_indices] + ) + mask = tuple(mask) else: - mask = torch.ones_like(self.target[self.window].mask) - for model in self.models.values(): - group_indices = self.window.get_self_indices(model.window) - model_indices = model.window.get_self_indices(self.window) - mask[group_indices] &= model.fit_mask()[model_indices] + mask = backend.ones_like(subtarget._mask) + for model in self.models: + model_subtarget = model.target[model.window] + group_indices = subtarget.get_indices(model.window) + model_indices = model_subtarget.get_indices(subtarget.window) + mask = backend.and_at_indices(mask, group_indices, model._fit_mask()[model_indices]) return mask + def fit_mask(self) -> torch.Tensor: + mask = self._fit_mask() + if isinstance(mask, tuple): + return tuple(backend.transpose(m, 1, 0) for m in mask) + return backend.transpose(mask, 1, 0) + + def match_window(self, image: Union[Image, ImageList], window: Window, model: Model) -> Window: + if isinstance(image, ImageList) and isinstance(model.target, ImageList): + indices = image.match_indices(model.target) + if len(indices) == 0: + raise IndexError + use_window = WindowList(windows=list(image.images[i].window for i in indices)) + elif isinstance(image, ImageList) and isinstance(model.target, Image): + try: + image.index(model.target) + except ValueError: + raise IndexError + use_window = model.window + elif isinstance(image, Image) and isinstance(model.target, ImageList): + try: + i = model.target.index(image) + except ValueError: + raise IndexError + use_window = model.window[i] + elif isinstance(image, Image) and isinstance(model.target, Image): + if image.identity != model.target.identity: + raise IndexError + use_window = window + else: + raise NotImplementedError( + f"Group_Model cannot sample with {type(image)} and {type(model.target)}" + ) + return use_window + + def _ensure_vmap_compatible( + self, image: Union[Image, ImageList], other: Union[Image, ImageList] + ): + if isinstance(image, ImageList): + for img in image.images: + self._ensure_vmap_compatible(img, other) + return + if isinstance(other, ImageList): + for img in other.images: + self._ensure_vmap_compatible(image, img) + return + if image.identity == other.identity: + image += backend.zeros_like(other._data[0, 0]) + + @forward def sample( self, - image: Optional[Image] = None, window: Optional[Window] = None, - parameters: Optional["Parameter_Node"] = None, - ): + ) -> Union[ModelImage, ModelImageList]: """Sample the group model on an image. Produces the flux values for each pixel associated with the models in this group. Each model is called individually and the results are added together in one larger image. - Args: - image (Optional["Model_Image"]): Image to sample on, overrides the windows for each sub model, they will all be evaluated over this entire image. If left as none then each sub model will be evaluated in its window. + **Args:** + - `image` (Optional[ModelImage]): Image to sample on, overrides the windows for each sub model, they will all be evaluated over this entire image. If left as none then each sub model will be evaluated in its window. """ - self._param_tuple = None - if image is None: - sample_window = True - image = self.make_model_image(window=window) - else: - sample_window = False if window is None: - window = image.window - if parameters is None: - parameters = self.parameters - - working_image = image[window].blank_copy() - - for model in self.models.values(): - if window is not None and isinstance(window, Window_List): - indices = self.target.match_indices(model.target) - if isinstance(indices, (tuple, list)): - use_window = Window_List( - window_list=list(window.window_list[ind] for ind in indices) - ) - else: - use_window = window.window_list[indices] - else: - use_window = window - if sample_window: - # Will sample the model fit window then add to the image - working_image += model(window=use_window, parameters=parameters[model.name]) - else: - # Will sample the entire image - model(working_image, window=use_window, parameters=parameters[model.name]) + image = self.target[self.window].model_image() + else: + image = self.target[window].model_image() - image += working_image + for model in self.models: + if window is None: + use_window = model.window + else: + try: + use_window = self.match_window(image, window, model) + except IndexError: + # If the model target is not in the image, skip it + continue + model_image = model(window=model.window & use_window) + self._ensure_vmap_compatible(image, model_image) + image += model_image return image @torch.no_grad() def jacobian( self, - parameters: Optional[torch.Tensor] = None, - as_representation: bool = False, - pass_jacobian: Optional[Jacobian_Image] = None, - window: Optional[Window] = None, - **kwargs, - ): + pass_jacobian: Optional[Union[JacobianImage, JacobianImageList]] = None, + window: Optional[Union[Window, WindowList]] = None, + params=None, + ) -> JacobianImage: """Compute the jacobian for this model. Done by first constructing a full jacobian (Npixels * Nparameters) of zeros then call the jacobian method of each sub model and add it in to the total. - Args: - parameters (Optional[torch.Tensor]): 1D parameter vector to overwrite current values - as_representation (bool): Indicates if the "parameters" argument is in the form of the real values, or as representations in the (-inf,inf) range. Default False - pass_jacobian (Optional["Jacobian_Image"]): A Jacobian image pre-constructed to be passed along instead of constructing new Jacobians + **Args:** + - `pass_jacobian` (Optional[JacobianImage]): A Jacobian image pre-constructed to be passed along instead of constructing new Jacobians + - `window` (Optional[Window]): A window within which to evaluate the jacobian. If not provided, the model's window will be used. + - `params` (Optional[Sequence[Param]]): Parameters to use for the jacobian. If not provided, the model's parameters will be used. """ if window is None: window = self.window - self._param_tuple = None - if parameters is not None: - if as_representation: - self.parameters.vector_set_representation(parameters) - else: - self.parameters.vector_set_values(parameters) + if params is not None: + self.set_values(params) if pass_jacobian is None: jac_img = self.target[window].jacobian_image( - parameters=self.parameters.vector_identities() + parameters=self.build_params_array_identities() ) else: jac_img = pass_jacobian - for model in self.models.values(): - if isinstance(model, Group_Model): - model.jacobian( - as_representation=as_representation, - pass_jacobian=jac_img, - window=window, - ) - else: # fixme, maybe make pass_jacobian be filled internally to each model - jac_img += model.jacobian( - as_representation=as_representation, - pass_jacobian=jac_img, - window=window, - ) + for model in reversed(self.models): + try: + use_window = self.match_window(jac_img, window, model) + except IndexError: + # If the model target is not in the image, skip it + continue + jac_img = model.jacobian( + pass_jacobian=jac_img, + window=use_window & model.window, + ) return jac_img def __iter__(self): - return (mod for mod in self.models.values()) + return (mod for mod in self.models) @property - def psf_mode(self): - return self._psf_mode - - @psf_mode.setter - def psf_mode(self, value): - self._psf_mode = value - if hasattr(self, "models"): - for model in self.models.values(): - model.psf_mode = value - - @property - def target(self): + def target(self) -> Optional[Union[TargetImage, TargetImageList]]: try: return self._target except AttributeError: return None @target.setter - def target(self, tar): - if not (tar is None or isinstance(tar, Target_Image)): + def target(self, tar: Optional[Union[TargetImage, TargetImageList]]): + if not (tar is None or isinstance(tar, (TargetImage, TargetImageList))): raise InvalidTarget("Group_Model target must be a Target_Image instance.") - self._target = tar + try: + del self._target # Remove old target if it exists + except AttributeError: + pass - if hasattr(self, "models"): - if not isinstance(tar, Image_List): - for model in self.models.values(): - if model.target is None: - model.target = tar - elif ( - isinstance(model.target, Image_List) - or model.target.identity != tar.identity - ): - AP_config.ap_logger.warning( - f"Group_Model target does not match model {model.name} target. This may cause issues. Use the same Target_Image object for all relevant models." - ) + self._target = tar - def get_state(self, save_params=True): - """Returns a dictionary with information about the state of the model - and its parameters. + @property + def window(self) -> Optional[Union[Window, WindowList]]: + """The window defines a region on the sky in which this model will be + optimized and typically evaluated. Two models with + non-overlapping windows are in effect independent of each + other. If there is another model with a window that spans both + of them, then they are tenuously connected. + + If not provided, the model will assume a window equal to the + target it is fitting. Note that in this case the window is not + explicitly set to the target window, so if the model is moved + to another target then the fitting window will also change. """ - state = super().get_state(save_params=save_params) - if save_params: - state["parameters"] = self.parameters.get_state() - if "models" not in state: - state["models"] = {} - for model in self.models.values(): - state["models"][model.name] = model.get_state(save_params=False) - return state - - def load(self, filename="AstroPhot.yaml", new_name=None): - """Loads an AstroPhot state file and updates this model with the - loaded parameters. + if self._window is None: + if self.target is None: + raise ValueError( + "This model has no target or window, these must be provided by the user" + ) + return self.target.window + return self._window + + @window.setter + def window(self, window): + if window is None: + self._window = None + elif isinstance(window, (Window, WindowList)): + self._window = window + elif len(window) in [2, 4]: + self._window = Window(window, image=self.target) + else: + raise InvalidWindow(f"Unrecognized window format: {str(window)}") + + def segmentation_map(self) -> ArrayLike: + """Generate a segmentation map for this group model. Each pixel in the + segmentation map is assigned an integer value corresponding to the index + of the sub-model that corresponds to that pixel. The pixels are assigned + based on "relative importance", meaning that for each pixel, the + sub-model which contributes the largest fraction of its own total flux to that + pixel is assigned to it. + + Returns: + ArrayLike: Segmentation map with the same shape as the target image as windowed by the group model window. """ - state = AstroPhot_Model.load(filename) + subtarget = self.target[self.window] + if isinstance(subtarget, ImageList): + raise NotImplementedError( + "Segmentation maps are not currently supported for ImageList targets. Please apply one target at a time." + ) + else: + seg_map = backend.zeros_like(subtarget._data, dtype=backend.int32) - 1 + max_flux_frac = ( + 0.0 * backend.ones_like(subtarget._data) / np.prod(subtarget._data.shape) + ) + for idx, model in enumerate(self.models): + model_image = model() + model_flux_frac = backend.abs(model_image._data) / backend.sum( + backend.abs(model_image._data) + ) + indices = subtarget.get_indices(model.window) + model_flux_frac_full = backend.zeros_like(subtarget._data) + model_flux_frac_full = backend.fill_at_indices( + model_flux_frac_full, indices, model_flux_frac + ) + update_mask = model_flux_frac_full >= max_flux_frac + seg_map = backend.where(update_mask, idx, seg_map) + max_flux_frac = backend.where(update_mask, model_flux_frac_full, max_flux_frac) + return seg_map.T - if new_name is None: - new_name = state["name"] - self.name = new_name + def deblend(self) -> Sequence[TargetImage]: + """Generate deblended images for each sub-model in this group model. + Each deblended image contains for each pixel, the fraction of the total + flux at that pixel which is contributed by that sub-model. - if isinstance(state["parameters"], Parameter_Node): - self.parameters = state["parameters"] + Returns: + Sequence[TargetImage]: List of deblended TargetImage objects for each sub-model. + + """ + deblended_images = [] + subtarget = self.target[self.window] + full_model = self() + if isinstance(subtarget, ImageList): + raise NotImplementedError( + "Deblending is not currently supported for ImageList targets. Please apply one target at a time." + ) else: - self.parameters = Parameter_Node(self.name, state=state["parameters"]) - - for model in state["models"]: - state["models"][model]["parameters"] = self.parameters[model] - for own_model in self.models.values(): - if model == own_model.name: - own_model.load(state["models"][model]) - break - else: - self.add_model( - AstroPhot_Model(name=model, filename=state["models"][model], target=self.target) + for model in self.models: + model_image = model() + subfull_model = full_model[model.window] + subsubtarget = subtarget[model.window].copy( + name=f"deblend_{model.name}_{subtarget.name}" ) - self.update_window() + deblend_data = subsubtarget.data * model_image.data / subfull_model.data + deblend_variance = subsubtarget.variance * model_image.data / subfull_model.data + subsubtarget.data = deblend_data + subsubtarget.variance = deblend_variance + deblended_images.append(subsubtarget) + return deblended_images diff --git a/astrophot/models/group_psf_model.py b/astrophot/models/group_psf_model.py index 54fc93b6..2d861200 100644 --- a/astrophot/models/group_psf_model.py +++ b/astrophot/models/group_psf_model.py @@ -1,27 +1,21 @@ -from typing import Optional - -from .group_model_object import Group_Model -from ..image import PSF_Image -from ..image import PSF_Image, Image, Window, Model_Image, Model_Image_List, Window_List +from .group_model_object import GroupModel +from ..image import PSFImage from ..errors import InvalidTarget -from ..param import Parameter_Node +from ..param import forward -__all__ = ["PSF_Group_Model"] +__all__ = ["PSFGroupModel"] -class PSF_Group_Model(Group_Model): +class PSFGroupModel(GroupModel): + """ + A group of PSF models. Behaves similarly to a `GroupModel`, but specifically designed for PSF models. + """ - model_type = f"psf {Group_Model.model_type}" + _model_type = "psf" usable = True normalize_psf = True - @property - def psf_mode(self): - return "none" - - @psf_mode.setter - def psf_mode(self, value): - pass + _options = ("normalize_psf",) @property def target(self): @@ -31,55 +25,20 @@ def target(self): return None @target.setter - def target(self, tar): - if not (tar is None or isinstance(tar, PSF_Image)): + def target(self, target): + if not (target is None or isinstance(target, PSFImage)): raise InvalidTarget("Group_Model target must be a PSF_Image instance.") - self._target = tar - - if hasattr(self, "models"): - for model in self.models.values(): - model.target = tar - - def sample( - self, - image: Optional[Image] = None, - window: Optional[Window] = None, - parameters: Optional[Parameter_Node] = None, - ): - # Note: same as group model except working_image is normalized at the end - self._param_tuple = None - if image is None: - sample_window = True - image = self.make_model_image(window=window) - else: - sample_window = False - if window is None: - window = image.window - - working_image = image[window].blank_copy() - - if parameters is None: - parameters = self.parameters + try: + del self._target # Remove old target if it exists + except AttributeError: + pass - for model in self.models.values(): - if window is not None and isinstance(window, Window_List): - indices = self.target.match_indices(model.target) - if isinstance(indices, (tuple, list)): - use_window = Window_List( - window_list=list(window.window_list[ind] for ind in indices) - ) - else: - use_window = window.window_list[indices] - else: - use_window = window - if sample_window: - # Will sample the model fit window then add to the image - working_image += model(window=use_window, parameters=parameters[model.name]) - else: - # Will sample the entire image - model(working_image, window=use_window, parameters=parameters[model.name]) + self._target = target + @forward + def sample(self, *args, **kwargs): + """Sample the PSF group model on the target image.""" + psf_img = super().sample(*args, **kwargs) if self.normalize_psf: - working_image.data /= working_image.data.sum() - image += working_image - return image + psf_img.normalize() + return psf_img diff --git a/astrophot/models/king.py b/astrophot/models/king.py new file mode 100644 index 00000000..a565d406 --- /dev/null +++ b/astrophot/models/king.py @@ -0,0 +1,60 @@ +from .galaxy_model_object import GalaxyModel +from .psf_model_object import PSFModel +from .mixins import ( + KingMixin, + RadialMixin, + WedgeMixin, + RayMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, + iKingMixin, +) +from ..utils.decorators import combine_docstrings + + +__all__ = ( + "KingGalaxy", + "KingPSF", + "KingSuperEllipse", + "KingFourierEllipse", + "KingWarp", + "KingRay", + "KingWedge", +) + + +@combine_docstrings +class KingGalaxy(KingMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class KingPSF(KingMixin, RadialMixin, PSFModel): + _parameter_specs = {"I0": {"units": "flux/arcsec^2", "value": 1.0, "dynamic": False}} + usable = True + + +@combine_docstrings +class KingSuperEllipse(KingMixin, SuperEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class KingFourierEllipse(KingMixin, FourierEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class KingWarp(KingMixin, WarpMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class KingRay(iKingMixin, RayMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class KingWedge(iKingMixin, WedgeMixin, GalaxyModel): + usable = True diff --git a/astrophot/models/mixins/__init__.py b/astrophot/models/mixins/__init__.py new file mode 100644 index 00000000..884033d5 --- /dev/null +++ b/astrophot/models/mixins/__init__.py @@ -0,0 +1,45 @@ +from .brightness import RadialMixin, WedgeMixin, RayMixin +from .transform import ( + InclinedMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, + TruncationMixin, +) +from .sersic import SersicMixin, iSersicMixin +from .exponential import ExponentialMixin, iExponentialMixin +from .moffat import MoffatMixin, iMoffatMixin +from .ferrer import FerrerMixin, iFerrerMixin +from .king import KingMixin, iKingMixin +from .gaussian import GaussianMixin, iGaussianMixin +from .nuker import NukerMixin, iNukerMixin +from .spline import SplineMixin, iSplineMixin +from .sample import SampleMixin + +__all__ = ( + "RadialMixin", + "WedgeMixin", + "RayMixin", + "SuperEllipseMixin", + "FourierEllipseMixin", + "WarpMixin", + "TruncationMixin", + "InclinedMixin", + "SersicMixin", + "iSersicMixin", + "ExponentialMixin", + "iExponentialMixin", + "MoffatMixin", + "iMoffatMixin", + "FerrerMixin", + "iFerrerMixin", + "KingMixin", + "iKingMixin", + "GaussianMixin", + "iGaussianMixin", + "NukerMixin", + "iNukerMixin", + "SplineMixin", + "iSplineMixin", + "SampleMixin", +) diff --git a/astrophot/models/mixins/brightness.py b/astrophot/models/mixins/brightness.py new file mode 100644 index 00000000..168ab77c --- /dev/null +++ b/astrophot/models/mixins/brightness.py @@ -0,0 +1,120 @@ +from torch import Tensor +from ...backend_obj import backend, ArrayLike +import numpy as np + +from ...param import forward + + +class RadialMixin: + """This model defines its `brightness(x,y)` function using a radial model. + Thus the brightness is instead defined as`radial_model(R)` + + More specifically the function is: + + $$x, y = {\\rm transform\\_coordinates}(x, y)$$ + $$R = {\\rm radius\\_metric}(x, y)$$ + $$I(x, y) = {\\rm radial\\_model}(R)$$ + + The `transform_coordinates` function depends on the model. In its simplest + form it simply subtracts the center of the model to re-center the coordinates. + + The `radius_metric` function is also model dependent, in its simplest form + this is just $R = \\sqrt{x^2 + y^2}$. + """ + + @forward + def brightness(self, x: ArrayLike, y: ArrayLike) -> ArrayLike: + """ + Calculate the brightness at a given point (x, y) based on radial distance from the center. + """ + x, y = self.transform_coordinates(x, y) + return self.radial_model(self.radius_metric(x, y)) + + +class WedgeMixin: + """Defines a model with multiple profiles that form wedges projected from the center. + + model which defines multiple radial models separately along some number of + wedges projected from the center. These wedges have sharp transitions along boundary angles theta. + + **Options:** + - `symmetric`: If True, the model will have symmetry for rotations of pi radians + and each ray will appear twice on the sky on opposite sides of the model. + If False, each ray is independent. + - `segments`: The number of segments to divide the model into. This controls + how many rays are used in the model. The default is 2 + """ + + _model_type = "wedge" + _options = ("segments", "symmetric") + + def __init__(self, *args, symmetric: bool = True, segments: int = 2, **kwargs): + super().__init__(*args, **kwargs) + self.symmetric = symmetric + self.segments = segments + + def polar_model(self, R: ArrayLike, T: ArrayLike) -> ArrayLike: + model = backend.zeros_like(R) + cycle = np.pi if self.symmetric else 2 * np.pi + w = cycle / self.segments + angles = (T + w / 2) % cycle + v = w * np.arange(self.segments) + for s in range(self.segments): + indices = (angles >= v[s]) & (angles < (v[s] + w)) + model = backend.add_at_indices(model, indices, self.iradial_model(s, R[indices])) + return model + + def brightness(self, x: Tensor, y: Tensor) -> Tensor: + x, y = self.transform_coordinates(x, y) + return self.polar_model(self.radius_metric(x, y), self.angular_metric(x, y)) + + +class RayMixin: + """Defines a model with multiple profiles along rays projected from the center. + + model which defines multiple radial models separately along some number of + rays projected from the center. These rays smoothly transition from one to + another along angles theta. The ray transition uses a cosine smoothing + function which depends on the number of rays, for example with two rays the + brightness would be: + + $$I(R,\\theta) = I_1(R)*\\cos(\\theta \\% \\pi) + I_2(R)*\\cos((\\theta + \\pi/2) \\% \\pi)$$ + + For $\\theta = 0$ the brightness comes entirely from `I_1` while for $\\theta = \\pi/2$ + the brightness comes entirely from `I_2`. + + **Options:** + - `symmetric`: If True, the model will have symmetry for rotations of pi radians + and each ray will appear twice on the sky on opposite sides of the model. + If False, each ray is independent. + - `segments`: The number of segments to divide the model into. This controls + how many rays are used in the model. The default is 2 + """ + + _model_type = "ray" + _options = ("symmetric", "segments") + + def __init__(self, *args, symmetric: bool = True, segments: int = 2, **kwargs): + super().__init__(*args, **kwargs) + self.symmetric = symmetric + self.segments = segments + + def polar_model(self, R: ArrayLike, T: ArrayLike) -> ArrayLike: + model = backend.zeros_like(R) + weight = backend.zeros_like(R) + cycle = np.pi if self.symmetric else 2 * np.pi + w = cycle / self.segments + v = w * np.arange(self.segments) + for s in range(self.segments): + angles = (T + cycle / 2 - v[s]) % cycle - cycle / 2 + indices = (angles >= -w) & (angles < w) + weights = (backend.cos(angles[indices] * self.segments) + 1) / 2 + model = backend.add_at_indices( + model, indices, weights * self.iradial_model(s, R[indices]) + ) + weight = backend.add_at_indices(weight, indices, weights) + return model / weight + + def brightness(self, x: ArrayLike, y: ArrayLike) -> ArrayLike: + x, y = self.transform_coordinates(x, y) + return self.polar_model(self.radius_metric(x, y), self.angular_metric(x, y)) diff --git a/astrophot/models/mixins/exponential.py b/astrophot/models/mixins/exponential.py new file mode 100644 index 00000000..3e578d0e --- /dev/null +++ b/astrophot/models/mixins/exponential.py @@ -0,0 +1,96 @@ +import torch + +from ...param import forward +from ...backend_obj import ArrayLike +from ...utils.decorators import ignore_numpy_warnings +from .._shared_methods import parametric_initialize, parametric_segment_initialize +from ...utils.parametric_profiles import exponential_np +from .. import func + + +def _x0_func(model_params, R, F): + return R[4], 10 ** F[4] + + +class ExponentialMixin: + """Exponential radial light profile. + + An exponential is a classical radial model used in many contexts. The + functional form of the exponential profile is defined as: + + $$I(R) = I_e * \\exp\\left(- b_1\\left(\\frac{R}{R_e} - 1\\right)\\right)$$ + + Ie is the brightness at the effective radius, and Re is the effective + radius. $b_1$ is a constant that ensures $I_e$ is the brightness at $R_e$. + + **Parameters:** + - `Re`: effective radius in arcseconds + - `Ie`: effective surface density in flux/arcsec^2 + """ + + _model_type = "exponential" + _parameter_specs = { + "Re": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "Ie": {"units": "flux/arcsec^2", "valid": (0, None), "shape": (), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_initialize( + self, + self.target[self.window], + exponential_np, + ("Re", "Ie"), + _x0_func, + ) + + @forward + def radial_model(self, R: ArrayLike, Re: ArrayLike, Ie: ArrayLike) -> ArrayLike: + return func.exponential(R, Re, Ie) + + +class iExponentialMixin: + """Exponential radial light profile. + + An exponential is a classical radial model used in many contexts. The + functional form of the exponential profile is defined as: + + $$I(R) = I_e * \\exp\\left(- b_1\\left(\\frac{R}{R_e} - 1\\right)\\right)$$ + + $I_e$ is the brightness at the effective radius, and $R_e$ is the effective + radius. $b_1$ is a constant that ensures $I_e$ is the brightness at $R_e$. + + `Re` and `Ie` are batched by their first dimension, allowing for multiple + exponential profiles to be defined at once. + + **Parameters:** + - `Re`: effective radius in arcseconds + - `Ie`: effective surface density in flux/arcsec^2 + """ + + _model_type = "exponential" + _parameter_specs = { + "Re": {"units": "arcsec", "valid": (0, None), "dynamic": True}, + "Ie": {"units": "flux/arcsec^2", "valid": (0, None), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_segment_initialize( + model=self, + target=self.target[self.window], + prof_func=exponential_np, + params=("Re", "Ie"), + x0_func=_x0_func, + segments=self.segments, + ) + + @forward + def iradial_model(self, i: int, R: ArrayLike, Re: ArrayLike, Ie: ArrayLike) -> ArrayLike: + return func.exponential(R, Re[i], Ie[i]) diff --git a/astrophot/models/mixins/ferrer.py b/astrophot/models/mixins/ferrer.py new file mode 100644 index 00000000..47cb87f3 --- /dev/null +++ b/astrophot/models/mixins/ferrer.py @@ -0,0 +1,118 @@ +import torch + +from ...param import forward +from ...backend_obj import ArrayLike +from ...utils.decorators import ignore_numpy_warnings +from ...utils.parametric_profiles import ferrer_np +from .._shared_methods import parametric_initialize, parametric_segment_initialize +from .. import func + + +def x0_func(model_params, R, F): + return R[5], 1, 1, 10 ** F[0] + + +class FerrerMixin: + """Modified Ferrer radial light profile (Binney & Tremaine 1987). + + This model has a relatively flat brightness core and then a truncation. It + is used in specialized circumstances such as fitting the bar of a galaxy. + The functional form of the Modified Ferrer profile is defined as: + + $$I(R) = I_0 \\left(1 - \\left(\\frac{R}{r_{\\rm out}}\\right)^{2-\\beta}\\right)^{\\alpha}$$ + + where `rout` is the outer truncation radius, `alpha` controls the steepness + of the truncation, `beta` controls the shape, and `I0` is the intensity at + the center of the profile. + + **Parameters:** + - `rout`: Outer truncation radius in arcseconds. + - `alpha`: Inner slope parameter. + - `beta`: Outer slope parameter. + - `I0`: Intensity at the center of the profile in flux/arcsec^2 + """ + + _model_type = "ferrer" + _parameter_specs = { + "rout": {"units": "arcsec", "valid": (0.0, None), "shape": (), "dynamic": True}, + "alpha": {"units": "unitless", "valid": (0, 10), "shape": (), "dynamic": True}, + "beta": {"units": "unitless", "valid": (0, 2), "shape": (), "dynamic": True}, + "I0": {"units": "flux/arcsec^2", "valid": (0, None), "shape": (), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_initialize( + self, + self.target[self.window], + ferrer_np, + ("rout", "alpha", "beta", "I0"), + x0_func, + ) + + @forward + def radial_model( + self, R: ArrayLike, rout: ArrayLike, alpha: ArrayLike, beta: ArrayLike, I0: ArrayLike + ) -> ArrayLike: + return func.ferrer(R, rout, alpha, beta, I0) + + +class iFerrerMixin: + """Modified Ferrer radial light profile (Binney & Tremaine 1987). + + This model has a relatively flat brightness core and then a truncation. It + is used in specialized circumstances such as fitting the bar of a galaxy. + The functional form of the Modified Ferrer profile is defined as: + + $$I(R) = I_0 \\left(1 - \\left(\\frac{R}{r_{\\rm out}}\\right)^{2-\\beta}\\right)^{\\alpha}$$ + + where `rout` is the outer truncation radius, `alpha` controls the steepness + of the truncation, `beta` controls the shape, and `I0` is the intensity at + the center of the profile. + + `rout`, `alpha`, `beta`, and `I0` are batched by their first dimension, + allowing for multiple Ferrer profiles to be defined at once. + + **Parameters:** + - `rout`: Outer truncation radius in arcseconds. + - `alpha`: Inner slope parameter. + - `beta`: Outer slope parameter. + - `I0`: Intensity at the center of the profile in flux/arcsec^2 + """ + + _model_type = "ferrer" + _parameter_specs = { + "rout": {"units": "arcsec", "valid": (0.0, None), "shape": (), "dynamic": True}, + "alpha": {"units": "unitless", "valid": (0, 10), "shape": (), "dynamic": True}, + "beta": {"units": "unitless", "valid": (0, 2), "shape": (), "dynamic": True}, + "I0": {"units": "flux/arcsec^2", "valid": (0.0, None), "shape": (), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_segment_initialize( + model=self, + target=self.target[self.window], + prof_func=ferrer_np, + params=("rout", "alpha", "beta", "I0"), + x0_func=x0_func, + segments=self.segments, + ) + + @forward + def iradial_model( + self, + i: int, + R: ArrayLike, + rout: ArrayLike, + alpha: ArrayLike, + beta: ArrayLike, + I0: ArrayLike, + ) -> ArrayLike: + return func.ferrer(R, rout[i], alpha[i], beta[i], I0[i]) diff --git a/astrophot/models/mixins/gaussian.py b/astrophot/models/mixins/gaussian.py new file mode 100644 index 00000000..f6b57921 --- /dev/null +++ b/astrophot/models/mixins/gaussian.py @@ -0,0 +1,97 @@ +import torch + +from ...param import forward +from ...backend_obj import ArrayLike +from ...utils.decorators import ignore_numpy_warnings +from .._shared_methods import parametric_initialize, parametric_segment_initialize +from ...utils.parametric_profiles import gaussian_np +from .. import func + + +def _x0_func(model_params, R, F): + return R[4], 10 ** F[0] + + +class GaussianMixin: + """Gaussian radial light profile. + + The Gaussian profile is a simple and widely used model for extended objects. + The functional form of the Gaussian profile is defined as: + + $$I(R) = \\frac{{\\rm flux}}{\\sqrt{2\\pi}\\sigma} \\exp(-R^2 / (2 \\sigma^2))$$ + + where `I_0` is the intensity at the center of the profile and `sigma` is the + standard deviation which controls the width of the profile. + + **Parameters:** + - `sigma`: Standard deviation of the Gaussian profile in arcseconds. + - `flux`: Total flux of the Gaussian profile. + """ + + _model_type = "gaussian" + _parameter_specs = { + "sigma": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "flux": {"units": "flux", "valid": (0, None), "shape": (), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_initialize( + self, + self.target[self.window], + gaussian_np, + ("sigma", "flux"), + _x0_func, + ) + + @forward + def radial_model(self, R: ArrayLike, sigma: ArrayLike, flux: ArrayLike) -> ArrayLike: + return func.gaussian(R, sigma, flux) + + +class iGaussianMixin: + """Gaussian radial light profile. + + The Gaussian profile is a simple and widely used model for extended objects. + The functional form of the Gaussian profile is defined as: + + $$I(R) = \\frac{{\\rm flux}}{\\sqrt{2\\pi}\\sigma} \\exp(-R^2 / (2 \\sigma^2))$$ + + where `sigma` is the standard deviation which controls the width of the + profile and `flux` gives the total flux of the profile (assuming no + perturbations). + + `sigma` and `flux` are batched by their first dimension, allowing for + multiple Gaussian profiles to be defined at once. + + **Parameters:** + - `sigma`: Standard deviation of the Gaussian profile in arcseconds. + - `flux`: Total flux of the Gaussian profile. + """ + + _model_type = "gaussian" + _parameter_specs = { + "sigma": {"units": "arcsec", "valid": (0, None), "dynamic": True}, + "flux": {"units": "flux", "valid": (0, None), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_segment_initialize( + model=self, + target=self.target[self.window], + prof_func=gaussian_np, + params=("sigma", "flux"), + x0_func=_x0_func, + segments=self.segments, + ) + + @forward + def iradial_model(self, i: int, R: ArrayLike, sigma: ArrayLike, flux: ArrayLike) -> ArrayLike: + return func.gaussian(R, sigma[i], flux[i]) diff --git a/astrophot/models/mixins/king.py b/astrophot/models/mixins/king.py new file mode 100644 index 00000000..eea81306 --- /dev/null +++ b/astrophot/models/mixins/king.py @@ -0,0 +1,122 @@ +import torch +import numpy as np + +from ...param import forward +from ...backend_obj import ArrayLike +from ...utils.decorators import ignore_numpy_warnings +from ...utils.parametric_profiles import king_np +from .._shared_methods import parametric_initialize, parametric_segment_initialize +from .. import func + + +def x0_func(model_params, R, F): + return R[2], R[5], 2, 10 ** F[0] + + +class KingMixin: + """Empirical King radial light profile (Elson 1999). + + Often used for star clusters. By default the profile has `alpha = 2` but we + allow the parameter to vary freely for fitting. The functional form of the + Empirical King profile is defined as: + + $$I(R) = I_0\\left[\\frac{1}{(1 + (R/R_c)^2)^{1/\\alpha}} - \\frac{1}{(1 + (R_t/R_c)^2)^{1/\\alpha}}\\right]^{\\alpha}\\left[1 - \\frac{1}{(1 + (R_t/R_c)^2)^{1/\\alpha}}\\right]^{-\\alpha}$$ + + where `R_c` is the core radius, `R_t` is the truncation radius, and `I_0` is + the intensity at the center of the profile. `alpha` is the concentration + index which controls the shape of the profile. + + **Parameters:** + - `Rc`: core radius + - `Rt`: truncation radius + - `alpha`: concentration index which controls the shape of the brightness profile + - `I0`: intensity at the center of the profile + """ + + _model_type = "king" + _parameter_specs = { + "Rc": {"units": "arcsec", "valid": (0.0, None), "shape": (), "dynamic": True}, + "Rt": {"units": "arcsec", "valid": (0.0, None), "shape": (), "dynamic": True}, + "alpha": { + "units": "unitless", + "valid": (0, 10), + "shape": (), + "value": 2.0, + "dynamic": False, + }, + "I0": {"units": "flux/arcsec^2", "valid": (0, None), "shape": (), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_initialize( + self, + self.target[self.window], + lambda r, *x: king_np(r, x[0], x[1], 2.0, x[2]), + ("Rc", "Rt", "I0"), + x0_func, + ) + + @forward + def radial_model( + self, R: ArrayLike, Rc: ArrayLike, Rt: ArrayLike, alpha: ArrayLike, I0: ArrayLike + ) -> ArrayLike: + return func.king(R, Rc, Rt, alpha, I0) + + +class iKingMixin: + """Empirical King radial light profile (Elson 1999). + + Often used for star clusters. By default the profile has `alpha = 2` but we + allow the parameter to vary freely for fitting. The functional form of the + Empirical King profile is defined as: + + $$I(R) = I_0\\left[\\frac{1}{(1 + (R/R_c)^2)^{1/\\alpha}} - \\frac{1}{(1 + (R_t/R_c)^2)^{1/\\alpha}}\\right]^{\\alpha}\\left[1 - \\frac{1}{(1 + (R_t/R_c)^2)^{1/\\alpha}}\\right]^{-\\alpha}$$ + + where `R_c` is the core radius, `R_t` is the truncation radius, and `I_0` is + the intensity at the center of the profile. `alpha` is the concentration + index which controls the shape of the profile. + + `Rc`, `Rt`, `alpha`, and `I0` are batched by their first dimension, allowing + for multiple King profiles to be defined at once. + + **Parameters:** + - `Rc`: core radius + - `Rt`: truncation radius + - `alpha`: concentration index which controls the shape of the brightness profile + - `I0`: intensity at the center of the profile + """ + + _model_type = "king" + _parameter_specs = { + "Rc": {"units": "arcsec", "valid": (0.0, None), "dynamic": True}, + "Rt": {"units": "arcsec", "valid": (0.0, None), "dynamic": True}, + "alpha": {"units": "unitless", "valid": (0, 10), "dynamic": False}, + "I0": {"units": "flux/arcsec^2", "valid": (0, None), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if not self.alpha.initialized: + self.alpha.value = 2.0 * np.ones(self.segments) + + parametric_segment_initialize( + model=self, + target=self.target[self.window], + prof_func=lambda r, *x: king_np(r, x[0], x[1], 2.0, x[2]), + params=("Rc", "Rt", "I0"), + x0_func=x0_func, + segments=self.segments, + ) + + @forward + def iradial_model( + self, i: int, R: ArrayLike, Rc: ArrayLike, Rt: ArrayLike, alpha: ArrayLike, I0: ArrayLike + ) -> ArrayLike: + return func.king(R, Rc[i], Rt[i], alpha[i], I0[i]) diff --git a/astrophot/models/mixins/moffat.py b/astrophot/models/mixins/moffat.py new file mode 100644 index 00000000..eef7f2b6 --- /dev/null +++ b/astrophot/models/mixins/moffat.py @@ -0,0 +1,102 @@ +import torch + +from ...param import forward +from ...backend_obj import ArrayLike +from ...utils.decorators import ignore_numpy_warnings +from .._shared_methods import parametric_initialize, parametric_segment_initialize +from ...utils.parametric_profiles import moffat_np +from .. import func + + +def _x0_func(model_params, R, F): + return 2.0, R[4], 10 ** F[0] + + +class MoffatMixin: + """Moffat radial light profile (Moffat 1969). + + The moffat profile gives a good representation of the general structure of + PSF functions for ground based data. It can also be used to fit extended + objects. The functional form of the Moffat profile is defined as: + + $$I(R) = \\frac{I_0}{(1 + (R/R_d)^2)^n}$$ + + `n` is the concentration index which controls the shape of the profile. + + **Parameters:** + - `n`: Concentration index which controls the shape of the brightness profile + - `Rd`: Scale length radius + - `I0`: Intensity at the center of the profile + """ + + _model_type = "moffat" + _parameter_specs = { + "n": {"units": "none", "valid": (0.1, 10), "shape": (), "dynamic": True}, + "Rd": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "I0": {"units": "flux/arcsec^2", "valid": (0, None), "shape": (), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_initialize( + self, + self.target[self.window], + moffat_np, + ("n", "Rd", "I0"), + _x0_func, + ) + + @forward + def radial_model(self, R: ArrayLike, n: ArrayLike, Rd: ArrayLike, I0: ArrayLike) -> ArrayLike: + return func.moffat(R, n, Rd, I0) + + +class iMoffatMixin: + """Moffat radial light profile (Moffat 1969). + + The moffat profile gives a good representation of the general structure of + PSF functions for ground based data. It can also be used to fit extended + objects. The functional form of the Moffat profile is defined as: + + $$I(R) = \\frac{I_0}{(1 + (R/R_d)^2)^n}$$ + + `n` is the concentration index which controls the shape of the profile. + + `n`, `Rd`, and `I0` are batched by their first dimension, allowing for + multiple Moffat profiles to be defined at once. + + **Parameters:** + - `n`: Concentration index which controls the shape of the brightness profile + - `Rd`: Scale length radius + - `I0`: Intensity at the center of the profile + """ + + _model_type = "moffat" + _parameter_specs = { + "n": {"units": "none", "valid": (0.1, 10), "dynamic": True}, + "Rd": {"units": "arcsec", "valid": (0, None), "dynamic": True}, + "I0": {"units": "flux/arcsec^2", "valid": (0, None), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_segment_initialize( + model=self, + target=self.target[self.window], + prof_func=moffat_np, + params=("n", "Rd", "I0"), + x0_func=_x0_func, + segments=self.segments, + ) + + @forward + def iradial_model( + self, i: int, R: ArrayLike, n: ArrayLike, Rd: ArrayLike, I0: ArrayLike + ) -> ArrayLike: + return func.moffat(R, n[i], Rd[i], I0[i]) diff --git a/astrophot/models/mixins/nuker.py b/astrophot/models/mixins/nuker.py new file mode 100644 index 00000000..36d26994 --- /dev/null +++ b/astrophot/models/mixins/nuker.py @@ -0,0 +1,127 @@ +import torch + +from ...param import forward +from ...backend_obj import ArrayLike +from ...utils.decorators import ignore_numpy_warnings +from .._shared_methods import parametric_initialize, parametric_segment_initialize +from ...utils.parametric_profiles import nuker_np +from .. import func + + +def _x0_func(model_params, R, F): + return R[4], 10 ** F[4], 1.0, 2.0, 0.5 + + +class NukerMixin: + """Nuker radial light profile (Lauer et al. 1995). + + This is a classic profile used widely in galaxy modelling. The functional + form of the Nuker profile is defined as: + + $$I(R) = I_b2^{\\frac{\\beta - \\gamma}{\\alpha}}\\left(\\frac{R}{R_b}\\right)^{-\\gamma}\\left[1 + \\left(\\frac{R}{R_b}\\right)^{\\alpha}\\right]^{\\frac{\\gamma-\\beta}{\\alpha}}$$ + + It is effectively a double power law profile. $\\gamma$ gives the inner + slope, $\\beta$ gives the outer slope, $\\alpha$ is somewhat degenerate with + the other slopes. + + **Parameters:** + - `Rb`: scale length radius + - `Ib`: intensity at the scale length + - `alpha`: sharpness of transition between power law slopes + - `beta`: outer power law slope + - `gamma`: inner power law slope + """ + + _model_type = "nuker" + _parameter_specs = { + "Rb": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "Ib": {"units": "flux/arcsec^2", "valid": (0, None), "shape": (), "dynamic": True}, + "alpha": {"units": "none", "valid": (0, None), "shape": (), "dynamic": True}, + "beta": {"units": "none", "valid": (0, None), "shape": (), "dynamic": True}, + "gamma": {"units": "none", "shape": (), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_initialize( + self, + self.target[self.window], + nuker_np, + ("Rb", "Ib", "alpha", "beta", "gamma"), + _x0_func, + ) + + @forward + def radial_model( + self, + R: ArrayLike, + Rb: ArrayLike, + Ib: ArrayLike, + alpha: ArrayLike, + beta: ArrayLike, + gamma: ArrayLike, + ) -> ArrayLike: + return func.nuker(R, Rb, Ib, alpha, beta, gamma) + + +class iNukerMixin: + """Nuker radial light profile (Lauer et al. 1995). + + This is a classic profile used widely in galaxy modelling. The functional + form of the Nuker profile is defined as: + + $$I(R) = I_b2^{\\frac{\\beta - \\gamma}{\\alpha}}\\left(\\frac{R}{R_b}\\right)^{-\\gamma}\\left[1 + \\left(\\frac{R}{R_b}\\right)^{\\alpha}\\right]^{\\frac{\\gamma-\\beta}{\\alpha}}$$ + + It is effectively a double power law profile. $\\gamma$ gives the inner + slope, $\\beta$ gives the outer slope, $\\alpha$ is somewhat degenerate with + the other slopes. + + `Rb`, `Ib`, `alpha`, `beta`, and `gamma` are batched by their first + dimension, allowing for multiple Nuker profiles to be defined at once. + + **Parameters:** + - `Rb`: scale length radius + - `Ib`: intensity at the scale length + - `alpha`: sharpness of transition between power law slopes + - `beta`: outer power law slope + - `gamma`: inner power law slope + """ + + _model_type = "nuker" + _parameter_specs = { + "Rb": {"units": "arcsec", "valid": (0, None), "dynamic": True}, + "Ib": {"units": "flux/arcsec^2", "valid": (0, None), "dynamic": True}, + "alpha": {"units": "none", "valid": (0, None), "dynamic": True}, + "beta": {"units": "none", "valid": (0, None), "dynamic": True}, + "gamma": {"units": "none", "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_segment_initialize( + model=self, + target=self.target[self.window], + prof_func=nuker_np, + params=("Rb", "Ib", "alpha", "beta", "gamma"), + x0_func=_x0_func, + segments=self.segments, + ) + + @forward + def iradial_model( + self, + i: int, + R: ArrayLike, + Rb: ArrayLike, + Ib: ArrayLike, + alpha: ArrayLike, + beta: ArrayLike, + gamma: ArrayLike, + ) -> ArrayLike: + return func.nuker(R, Rb[i], Ib[i], alpha[i], beta[i], gamma[i]) diff --git a/astrophot/models/mixins/sample.py b/astrophot/models/mixins/sample.py new file mode 100644 index 00000000..c33e9dcf --- /dev/null +++ b/astrophot/models/mixins/sample.py @@ -0,0 +1,243 @@ +from typing import Optional, Literal + +import numpy as np + +from ...param import forward +from ...backend_obj import backend, ArrayLike +from ... import config +from ...image import Image, Window, JacobianImage +from .. import func +from ...errors import SpecificationConflict + + +class SampleMixin: + """ + **Options:** + - `sampling_mode`: The method used to sample the model in image pixels. Options are: + - `auto`: Automatically choose the sampling method based on the image size. + - `midpoint`: Use midpoint sampling, evaluate the brightness at the center of each pixel. + - `simpsons`: Use Simpson's rule for sampling integrating each pixel. + - `quad:x`: Use quadrature sampling with order x, where x is a positive integer to integrate each pixel. + - `jacobian_maxparams`: The maximum number of parameters before the Jacobian will be broken into smaller chunks. This is helpful for limiting the memory requirements to build a model. + - `jacobian_maxpixels`: The maximum number of pixels before the Jacobian will be broken into smaller chunks. This is helpful for limiting the memory requirements to build a model. + - `integrate_mode`: The method used to select pixels to integrate further where the model varies significantly. Options are: + - `none`: No extra integration is performed (beyond the sampling_mode). + - `bright`: Select the brightest pixels for further integration. + - `threshold`: Select pixels which show signs of significant higher order derivatives. + - `integrate_tolerance`: The tolerance for selecting a pixel in the integration method. This is the total flux fraction that is integrated over the image. + - `integrate_fraction`: The fraction of the pixels to super sample during integration. + - `integrate_max_depth`: The maximum depth of the integration method. + - `integrate_gridding`: The gridding used for the integration method to super-sample a pixel at each iteration. + - `integrate_quad_order`: The order of the quadrature used for the integration method on the super sampled pixels. + """ + + # Method for initial sampling of model + sampling_mode = "auto" # auto (choose based on image size), midpoint, simpsons, quad:x (where x is a positive integer) + + # Maximum size of parameter list before jacobian will be broken into smaller chunks, this is helpful for limiting the memory requirements to build a model, lower jacobian_chunksize is slower but uses less memory + jacobian_maxparams = 10 + jacobian_maxpixels = 1000**2 + integrate_mode = "bright" # none, bright, curvature + integrate_fraction = 0.05 # fraction of the pixels to super sample + integrate_max_depth = 2 + integrate_gridding = 5 + integrate_quad_order = 3 + + _options = ( + "sampling_mode", + "jacobian_maxparams", + "jacobian_maxpixels", + "integrate_mode", + "integrate_fraction", + "integrate_max_depth", + "integrate_gridding", + "integrate_quad_order", + ) + + @forward + def _bright_integrate(self, sample: ArrayLike, image: Image) -> ArrayLike: + i, j = image.pixel_center_meshgrid() + sample = func.bright_integrate( + sample, + i, + j, + lambda i, j: self.brightness(*image.pixel_to_plane(i, j)), + scale=image.base_scale, + bright_frac=self.integrate_fraction, + quad_order=self.integrate_quad_order, + gridding=self.integrate_gridding, + max_depth=self.integrate_max_depth, + ) + return sample + + @forward + def _curvature_integrate(self, sample: ArrayLike, image: Image) -> ArrayLike: + i, j = image.pixel_center_meshgrid() + kernel = func.curvature_kernel(config.DTYPE, config.DEVICE) + curvature = ( + backend.abs( + backend.pad( + backend.conv2d( + sample.reshape(1, 1, *sample.shape), + kernel.reshape(1, 1, *kernel.shape), + padding="valid", + ), + (0, 0, 0, 0, 1, 1, 1, 1), + mode="replicate", + ) + ) + .squeeze(0) + .squeeze(0) + ) + N = max(1, int(np.prod(image.data.shape) * self.integrate_fraction)) + select = backend.topk(curvature.flatten(), N)[1] + + sample_flat = sample.flatten() + sample_flat = backend.fill_at_indices( + sample_flat, + select, + func.recursive_quad_integrate( + i.flatten()[select], + j.flatten()[select], + lambda i, j: self.brightness(*image.pixel_to_plane(i, j)), + scale=image.base_scale, + curve_frac=self.integrate_fraction, + quad_order=self.integrate_quad_order, + gridding=self.integrate_gridding, + max_depth=self.integrate_max_depth, + ), + ) + return sample_flat.reshape(sample.shape) + + @forward + def sample_image(self, image: Image) -> ArrayLike: + if self.sampling_mode == "auto": + N = np.prod(image._data.shape[:2]) + if N <= 100: + sampling_mode = "quad:5" + elif N <= 10000: + sampling_mode = "simpsons" + else: + sampling_mode = "midpoint" + else: + sampling_mode = self.sampling_mode + if sampling_mode == "midpoint": + x, y = image.coordinate_center_meshgrid() + res = self.brightness(x, y) + sample = func.pixel_center_integrator(res) + elif sampling_mode == "simpsons": + x, y = image.coordinate_simpsons_meshgrid() + res = self.brightness(x, y) + sample = func.pixel_simpsons_integrator(res) + elif sampling_mode.startswith("quad:"): + order = int(self.sampling_mode.split(":")[1]) + i, j, w = image.pixel_quad_meshgrid(order=order) + x, y = image.pixel_to_plane(i, j) + res = self.brightness(x, y) + sample = func.pixel_quad_integrator(res, w) + else: + raise SpecificationConflict( + f"Unknown sampling mode {self.sampling_mode} for model {self.name}" + ) + if self.integrate_mode == "curvature": + sample = self._curvature_integrate(sample, image) + elif self.integrate_mode == "bright": + sample = self._bright_integrate(sample, image) + elif self.integrate_mode != "none": + raise SpecificationConflict( + f"Unknown integrate mode {self.integrate_mode} for model {self.name}" + ) + return sample + + def _jacobian( + self, window: Window, params_pre: ArrayLike, params: ArrayLike, params_post: ArrayLike + ) -> ArrayLike: + # return jacfwd( # this should be more efficient, but the trace overhead is too high + # lambda x: self.sample( + # window=window, params=torch.cat((params_pre, x, params_post), dim=-1) + # ).data + # )(params) + return backend.jacobian( + lambda x: self.sample( + window=window, params=backend.concatenate((params_pre, x, params_post), dim=-1) + )._data, + params, + ) + + def jacobian( + self, + window: Optional[Window] = None, + pass_jacobian: Optional[JacobianImage] = None, + params: Optional[ArrayLike] = None, + ) -> JacobianImage: + if window is None: + window = self.window + + if pass_jacobian is None: + jac_img = self.target[window].jacobian_image( + parameters=self.build_params_array_identities() + ) + else: + jac_img = pass_jacobian + + # No dynamic params + if params is None: + params = self.get_values() + if params.shape[-1] == 0: + return jac_img + + # handle large images + n_pixels = np.prod(window.shape) + if n_pixels > self.jacobian_maxpixels: + for chunk in window.chunk(self.jacobian_maxpixels): + jac_img = self.jacobian(window=chunk, pass_jacobian=jac_img, params=params) + return jac_img + + identities = self.build_params_array_identities() + if len(jac_img.match_parameters(identities)[0]) == 0: + return jac_img + + target = self.target[window] + if len(params) > self.jacobian_maxparams: # handle large number of parameters + chunksize = len(params) // self.jacobian_maxparams + 1 + for i in range(0, len(params), chunksize): + params_pre = params[:i] + params_chunk = params[i : i + chunksize] + params_post = params[i + chunksize :] + jac_chunk = self._jacobian(window, params_pre, params_chunk, params_post) + jac_img += target.jacobian_image( + parameters=identities[i : i + chunksize], + data=jac_chunk, + ) + else: + jac = self._jacobian(window, params[:0], params, params[0:0]) + jac_img += target.jacobian_image(parameters=identities, data=jac) + + return jac_img + + def gradient( + self, + window: Optional[Window] = None, + params: Optional[ArrayLike] = None, + likelihood: Literal["gaussian", "poisson"] = "gaussian", + ) -> ArrayLike: + """Compute the gradient of the model with respect to its parameters.""" + if window is None: + window = self.window + + jacobian_image = self.jacobian(window=window, params=params) + + data = self.target[window]._data + model = self.sample(window=window)._data + if likelihood == "gaussian": + weight = self.target[window]._weight + gradient = backend.sum( + jacobian_image._data * ((data - model) * weight)[..., None], dim=(0, 1) + ) + elif likelihood == "poisson": + gradient = backend.sum( + jacobian_image._data * (1 - data / model)[..., None], + dim=(0, 1), + ) + + return gradient diff --git a/astrophot/models/mixins/sersic.py b/astrophot/models/mixins/sersic.py new file mode 100644 index 00000000..11730f1e --- /dev/null +++ b/astrophot/models/mixins/sersic.py @@ -0,0 +1,104 @@ +import torch + +from ...param import forward +from ...backend_obj import ArrayLike +from ...utils.decorators import ignore_numpy_warnings +from .._shared_methods import parametric_initialize, parametric_segment_initialize +from ...utils.parametric_profiles import sersic_np +from .. import func + + +def _x0_func(model, R, F): + return 2.0, R[4], 10 ** F[4] + + +class SersicMixin: + """Sersic radial light profile (Sersic 1963). + + This is a classic profile used widely in galaxy modelling. It can be a good + starting point for many extended objects. The functional form of the Sersic + profile is defined as: + + $$I(R) = I_e * \\exp(- b_n((R/R_e)^{1/n} - 1))$$ + + It is a generalization of a gaussian, exponential, and de-Vaucouleurs + profile. The Sersic index `n` controls the shape of the profile, with `n=1` + being an exponential profile, `n=4` being a de-Vaucouleurs profile, and + `n=0.5` being a Gaussian profile. + + **Parameters:** + - `n`: Sersic index which controls the shape of the brightness profile + - `Re`: half light radius [arcsec] + - `Ie`: intensity at the half light radius [flux/arcsec^2] + """ + + _model_type = "sersic" + _parameter_specs = { + "n": {"units": "none", "valid": (0.36, 8), "shape": (), "dynamic": True}, + "Re": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "Ie": {"units": "flux/arcsec^2", "valid": (0, None), "shape": (), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_initialize( + self, self.target[self.window], sersic_np, ("n", "Re", "Ie"), _x0_func + ) + + @forward + def radial_model(self, R: ArrayLike, n: ArrayLike, Re: ArrayLike, Ie: ArrayLike) -> ArrayLike: + return func.sersic(R, n, Re, Ie) + + +class iSersicMixin: + """Sersic radial light profile (Sersic 1963). + + This is a classic profile used widely in galaxy modelling. It can be a good + starting point for many extended objects. The functional form of the Sersic + profile is defined as: + + $$I(R) = I_e * \\exp(- b_n((R/R_e)^{1/n} - 1))$$ + + It is a generalization of a gaussian, exponential, and de-Vaucouleurs + profile. The Sersic index `n` controls the shape of the profile, with `n=1` + being an exponential profile, `n=4` being a de-Vaucouleurs profile, and + `n=0.5` being a Gaussian profile. + + `n`, `Re`, and `Ie` are batched by their first dimension, allowing for + multiple Sersic profiles to be defined at once. + + **Parameters:** + - `n`: Sersic index which controls the shape of the brightness profile + - `Re`: half light radius [arcsec] + - `Ie`: intensity at the half light radius [flux/arcsec^2] + """ + + _model_type = "sersic" + _parameter_specs = { + "n": {"units": "none", "valid": (0.36, 8), "dynamic": True}, + "Re": {"units": "arcsec", "valid": (0, None), "dynamic": True}, + "Ie": {"units": "flux/arcsec^2", "valid": (0, None), "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + parametric_segment_initialize( + model=self, + target=self.target[self.window], + prof_func=sersic_np, + params=("n", "Re", "Ie"), + x0_func=_x0_func, + segments=self.segments, + ) + + @forward + def iradial_model( + self, i: int, R: ArrayLike, n: ArrayLike, Re: ArrayLike, Ie: ArrayLike + ) -> ArrayLike: + return func.sersic(R, n[i], Re[i], Ie[i]) diff --git a/astrophot/models/mixins/spline.py b/astrophot/models/mixins/spline.py new file mode 100644 index 00000000..4b95dffb --- /dev/null +++ b/astrophot/models/mixins/spline.py @@ -0,0 +1,116 @@ +import torch +import numpy as np + +from ...param import forward +from ...backend_obj import ArrayLike +from ...utils.decorators import ignore_numpy_warnings +from .._shared_methods import _sample_image +from ...utils.interpolate import default_prof +from .. import func + + +class SplineMixin: + """Spline radial model for brightness. + + The `radial_model` function for this model is defined as a spline + interpolation from the parameter `I_R`. The `I_R` parameter is a tensor + that contains the radial profile of the brightness in units of + flux/arcsec^2. The radius of each node is determined from `I_R.prof`. + + **Parameters:** + - `I_R`: Tensor of radial brightness values in units of flux/arcsec^2. + """ + + _model_type = "spline" + _parameter_specs = {"I_R": {"units": "flux/arcsec^2", "valid": (0, None), "dynamic": True}} + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if self.I_R.initialized: + return + + target_area = self.target[self.window] + # Create the I_R profile radii if needed + if self.I_R.prof is None: + prof = default_prof(self.window.shape, target_area.pixelscale, 2, 0.2) + prof = np.append(prof, prof[-1] * 10) + self.I_R.prof = prof + else: + prof = self.I_R.prof + + R, I, S = _sample_image( + target_area, + self.transform_coordinates, + self.radius_metric, + rad_bins=[0] + list((prof[:-1] + prof[1:]) / 2) + [prof[-1] * 100], + ) + self.I_R.value = 10**I + + @forward + def radial_model(self, R: ArrayLike, I_R: ArrayLike) -> ArrayLike: + ret = func.spline(R, self.I_R.prof, I_R) + return ret + + +class iSplineMixin: + """Batched spline radial model for brightness. + + The `radial_model` function for this model is defined as a spline + interpolation from the parameter `I_R`. The `I_R` parameter is a tensor that + contains the radial profile of the brightness in units of flux/arcsec^2. The + radius of each node is determined from `I_R.prof`. + + Both `I_R` and `I_R.prof` are batched by their first dimension, allowing for + multiple spline profiles to be defined at once. Each individual spline model + is then `I_R[i]` and `I_R.prof[i]` where `i` indexes the profiles. + + **Parameters:** + - `I_R`: Tensor of radial brightness values in units of flux/arcsec^2. + """ + + _model_type = "spline" + _parameter_specs = {"I_R": {"units": "flux/arcsec^2", "valid": (0, None), "dynamic": True}} + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if self.I_R.initialized: + return + + target_area = self.target[self.window] + # Create the I_R profile radii if needed + if self.I_R.prof is None: + prof = default_prof(self.window.shape, target_area.pixelscale, 2, 0.2) + prof = np.append(prof, prof[-1] * 10) + prof = np.stack([prof] * self.segments) + self.I_R.prof = prof + else: + prof = self.I_R.prof + + value = np.zeros(prof.shape) + cycle = np.pi if self.symmetric else 2 * np.pi + w = cycle / self.segments + v = w * np.arange(self.segments) + for s in range(self.segments): + angle_range = (v[s] - w / 2, v[s] + w / 2) + R, I, S = _sample_image( + target_area, + self.transform_coordinates, + self.radius_metric, + angle=self.angular_metric, + rad_bins=[0] + list((prof[s][:-1] + prof[s][1:]) / 2) + [prof[s][-1] * 100], + angle_range=angle_range, + cycle=cycle, + ) + value[s] = I + + self.I_R.value = 10**value + + @forward + def iradial_model(self, i: int, R: ArrayLike, I_R: ArrayLike) -> ArrayLike: + return func.spline(R, self.I_R.prof[i], I_R[i]) diff --git a/astrophot/models/mixins/transform.py b/astrophot/models/mixins/transform.py new file mode 100644 index 00000000..7d335cc5 --- /dev/null +++ b/astrophot/models/mixins/transform.py @@ -0,0 +1,310 @@ +from typing import Tuple +import numpy as np +import torch + +from ...utils.decorators import ignore_numpy_warnings +from ...utils.interpolate import default_prof +from ...backend_obj import backend, ArrayLike +from ...param import forward +from .. import func +from ... import config + + +class InclinedMixin: + """A model which defines a position angle and axis ratio. + + PA and q operate on the coordinates to transform the model. Given some x,y + the updated values are: + + $$x', y' = {\\rm rotate}(-PA + \\pi/2, x, y)$$ + $$y'' = y' / q$$ + + where x' and y'' are the final transformed coordinates. The $\\pi/2$ is included + such that the position angle is defined with 0 at north. The -PA is such + that the position angle increases to the East. Thus, the position angle is a + standard East of North definition assuming the WCS of the image is correct. + + Note that this means radii are defined with $R = \\sqrt{x^2 + + \\left(\\frac{y}{q}\\right)^2}$ rather than the common alternative which is $R = + \\sqrt{qx^2 + \\frac{y^2}{q}}$ + + **Parameters:** + - `q`: Axis ratio of the model, defined as the ratio of the + semi-minor axis to the semi-major axis. A value of 1.0 is + circular. + - `PA`: Position angle of the model, defined as the angle + between the semi-major axis and North, measured East of North. + A value of 0.0 is North, a value of pi/2 is East. + """ + + _parameter_specs = { + "q": {"units": "b/a", "valid": (0.01, 1), "shape": (), "dynamic": True}, + "PA": { + "units": "radians", + "valid": (0, np.pi), + "cyclic": True, + "shape": (), + "dynamic": True, + }, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if self.PA.initialized and self.q.initialized: + return + target_area = self.target[self.window] + dat = backend.to_numpy(backend.copy(target_area._data)) + mask = backend.to_numpy(backend.copy(target_area._mask)) + dat[mask] = np.median(dat[~mask]) + + edge = np.concatenate((dat[:, 0], dat[:, -1], dat[0, :], dat[-1, :])) + edge_average = np.nanmedian(edge) + dat -= edge_average + + x, y = target_area.coordinate_center_meshgrid() + x = backend.to_numpy(x - self.center.value[0]) + y = backend.to_numpy(y - self.center.value[1]) + mu20 = np.mean(dat * np.abs(x)) + mu02 = np.mean(dat * np.abs(y)) + mu11 = np.mean(dat * x * y / np.sqrt(np.abs(x * y) + self.softening**2)) + M = np.array([[mu20, mu11], [mu11, mu02]]) + if not self.PA.initialized: + if np.any(np.iscomplex(M)) or np.any(~np.isfinite(M)): + self.PA.value = np.pi / 2 + else: + self.PA.value = (0.5 * np.arctan2(2 * mu11, mu20 - mu02) - np.pi / 2) % np.pi + if not self.q.initialized: + if np.any(np.iscomplex(M)) or np.any(~np.isfinite(M)): + l = (0.7, 1.0) + else: + l = np.sort(np.linalg.eigvals(M)) + self.q.value = np.clip(np.sqrt(np.abs(l[0] / l[1])), 0.1, 0.9) + + @forward + def transform_coordinates( + self, x: ArrayLike, y: ArrayLike, PA: ArrayLike, q: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + x, y = super().transform_coordinates(x, y) + x, y = func.rotate(-PA + np.pi / 2, x, y) + return x, y / q + + +class SuperEllipseMixin: + """Generalizes the definition of radius and so modifies the evaluation of radial models. + + A superellipse transformation allows for the expression of "boxy" and + "disky" modifications to traditional elliptical isophotes. This is a common + extension of the standard elliptical representation, especially for + early-type galaxies. The functional form for this is: + + $$R = (|x|^C + |y|^C)^{1/C}$$ + + where $R$ is the new distance metric, $X$ and $Y$ are the coordinates, and $C$ is the + coefficient for the superellipse. $C$ can take on any value greater than zero + where $C = 2$ is the standard distance metric, $0 < C < 2$ creates disky or + pointed perturbations to an ellipse, and $C > 2$ transforms an ellipse to be + more boxy. + + **Parameters:** + - `C`: Superellipse distance metric parameter, controls the shape of the isophotes. + A value of 2.0 is a standard elliptical distance metric, values + less than 2.0 create disky or pointed perturbations to an ellipse, + and values greater than 2.0 create boxy perturbations to an ellipse. + + """ + + _model_type = "superellipse" + _parameter_specs = { + "C": {"units": "none", "value": 2.0, "valid": (0, 10), "dynamic": True}, + } + + @forward + def radius_metric(self, x: ArrayLike, y: ArrayLike, C: ArrayLike) -> ArrayLike: + return (backend.abs(x) ** C + backend.abs(y) ** C + self.softening**C) ** (1.0 / C) + + +class FourierEllipseMixin: + """Sine wave perturbation of the elliptical radius metric. + + This allows for the expression of arbitrarily complex isophotes instead of + pure ellipses. This is a common extension of the standard elliptical + representation. The form of the Fourier perturbations is: + + $$R' = R * \\exp\\left(\\sum_m(a_m * \\cos(m * \\theta + \\phi_m))\\right)$$ + + where R' is the new radius value, R is the original radius (typically + computed as $\\sqrt{x^2+y^2}$), m is the index of the Fourier mode, a_m is + the amplitude of the m'th Fourier mode, theta is the angle around the + ellipse (typically $\\arctan(y/x)$), and phi_m is the phase of the m'th + fourier mode. + + One can create extremely complex shapes using different Fourier modes, + however usually it is only low order modes that are of interest. For + intuition, the first Fourier mode is roughly equivalent to a lopsided + galaxy, one side will be compressed and the opposite side will be expanded. + The second mode is almost never used as it is nearly degenerate with + ellipticity. The third mode is an alternate kind of lopsidedness for a + galaxy which makes it somewhat triangular, meaning that it is wider on one + side than the other. The fourth mode is similar to a boxyness/diskyness + parameter of a superelllipse which tends to make more pronounced peanut + shapes since it is more rounded than a superellipse representation. Modes + higher than 4 are only useful in very specialized situations. In general one + should consider carefully why the Fourier modes are being used for the + science case at hand. + + **Parameters:** + - `am`: Tensor of amplitudes for the Fourier modes, indicates the strength + of each mode. + - `phim`: Tensor of phases for the Fourier modes, adjusts the + orientation of the mode perturbation relative to the major axis. It + is cyclically defined in the range [0,2pi) + + **Options:** + - `modes`: Tuple of integers indicating which Fourier modes to use. + """ + + _model_type = "fourier" + _parameter_specs = { + "am": {"units": "none", "dynamic": True}, + "phim": {"units": "radians", "valid": (0, 2 * np.pi), "cyclic": True, "dynamic": False}, + } + _options = ("modes",) + + def __init__(self, *args, modes: Tuple[int] = (3, 4), **kwargs): + super().__init__(*args, **kwargs) + self.modes = backend.as_array(modes, dtype=config.DTYPE, device=config.DEVICE) + + @forward + def radius_metric( + self, x: ArrayLike, y: ArrayLike, am: ArrayLike, phim: ArrayLike + ) -> ArrayLike: + R = super().radius_metric(x, y) + theta = self.angular_metric(x, y) + return R * backend.exp( + backend.sum( + am[..., None] + * backend.cos(self.modes[..., None] * theta.flatten() + phim[..., None]), + 0, + ).reshape(x.shape) + ) + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if not self.am.initialized: + self.am.value = np.zeros(len(self.modes)) + if not self.phim.initialized: + self.phim.value = np.zeros(len(self.modes)) + + +class WarpMixin: + """Warped model with varying PA and q as a function of radius. + + This works by warping the coordinates using the same transform for a global + PA, q except applied to each pixel individually based on its unwarped radius + value. In the limit that PA and q are a constant, this recovers a basic + model with global PA, q. However, a linear PA profile will give a spiral + appearance, variations of PA, q profiles can create complex galaxy models. + The form of the coordinate transformation for each pixel looks like: + + $$R = \\sqrt{x^2 + y^2}$$ + $$x', y' = \\rm{rotate}(-PA(R) + \\pi/2, x, y)$$ + $$y'' = y' / q(R)$$ + + Note that now PA and q are functions of radius R, which is computed from the + original coordinates X, Y. This is achieved by making PA and q a spline + profile. + + **Parameters:** + - `q_R`: Tensor of axis ratio values for axis ratio spline + - `PA_R`: Tensor of position angle values as input to the spline + + """ + + _model_type = "warp" + _parameter_specs = { + "q_R": {"units": "b/a", "valid": (0, 1), "dynamic": True}, + "PA_R": {"units": "radians", "valid": (0, np.pi), "cyclic": True, "dynamic": True}, + } + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if not self.PA_R.initialized: + if self.PA_R.prof is None: + self.PA_R.prof = default_prof(self.window.shape, self.target.pixelscale, 2, 0.2) + self.PA_R.value = np.zeros(len(self.PA_R.prof)) + np.pi / 2 + if not self.q_R.initialized: + if self.q_R.prof is None: + self.q_R.prof = default_prof(self.window.shape, self.target.pixelscale, 2, 0.2) + self.q_R.value = np.ones(len(self.q_R.prof)) * 0.8 + + @forward + def transform_coordinates( + self, x: ArrayLike, y: ArrayLike, q_R: ArrayLike, PA_R: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + x, y = super().transform_coordinates(x, y) + R = self.radius_metric(x, y) + PA = func.spline(R, self.PA_R.prof, PA_R, extend="const") + q = func.spline(R, self.q_R.prof, q_R, extend="const") + x, y = func.rotate(-PA + np.pi / 2, x, y) + return x, y / q + + +class TruncationMixin: + """Truncated model with radial brightness profile. + + This model will smoothly truncate the radial brightness profile at Rt. The + truncation is centered on Rt and thus two identical models with the same Rt + (and St) where one is inner truncated and the other is outer truncated will + reproduce nearly the same as a single un-truncated model. + + By default the St parameter is set fixed to 1.0, giving a relatively smooth + truncation. This can be set to a smaller value for sharper truncations or a + larger value for even more gradual truncation. It can be set dynamic to be + optimized in a model, though it is possible for this parameter to be + unstable if there isn't a clear truncation signal in the data. + + **Parameters:** + - `Rt`: The truncation radius in arcseconds. + - `St`: The steepness of the truncation profile, controlling how quickly + the brightness drops to zero at the truncation radius. + + **Options:** + - `outer_truncation`: If True, the model will truncate the brightness beyond + the truncation radius. If False, the model will truncate the + brightness within the truncation radius. + """ + + _model_type = "truncated" + _parameter_specs = { + "Rt": {"units": "arcsec", "valid": (0, None), "shape": (), "dynamic": True}, + "St": {"units": "none", "valid": (0, None), "shape": (), "value": 1.0, "dynamic": False}, + } + _options = ("outer_truncation",) + + def __init__(self, *args, outer_truncation: bool = True, **kwargs): + super().__init__(*args, **kwargs) + self.outer_truncation = outer_truncation + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + if not self.Rt.initialized: + prof = default_prof(self.window.shape, self.target.pixelscale, 2, 0.2) + self.Rt.value = prof[len(prof) // 2] + + @forward + def radial_model(self, R: ArrayLike, Rt: ArrayLike, St: ArrayLike) -> ArrayLike: + I = super().radial_model(R) + if self.outer_truncation: + return I * (1 - backend.tanh(St * (R - Rt))) / 2 + return I * (backend.tanh(St * (R - Rt)) + 1) / 2 diff --git a/astrophot/models/model_object.py b/astrophot/models/model_object.py index 919481d5..8c17a94b 100644 --- a/astrophot/models/model_object.py +++ b/astrophot/models/model_object.py @@ -3,265 +3,157 @@ import numpy as np import torch -from .core_model import AstroPhot_Model +from ..param import forward +from .base import Model +from . import func from ..image import ( - Model_Image, + TargetImage, Window, - PSF_Image, - Target_Image, - Target_Image_List, - Image, + PSFImage, ) -from ..param import Parameter_Node, Param_Unlock, Param_SoftLimits -from ..utils.initialize import center_of_mass -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ._shared_methods import select_target -from .. import AP_config +from ..utils.initialize import recursive_center_of_mass +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from .. import config +from ..backend_obj import backend from ..errors import InvalidTarget +from .mixins import SampleMixin -__all__ = ["Component_Model"] - - -class Component_Model(AstroPhot_Model): - """Component_Model(name, target, window, locked, **kwargs) - - Component_Model is a base class for models that represent single - objects or parametric forms. It provides the basis for subclassing - models and requires the definition of parameters, initialization, - and model evaluation functions. This class also handles - integration, PSF convolution, and computing the Jacobian matrix. - - Attributes: - parameter_specs (dict): Specifications for the model parameters. - _parameter_order (tuple): Fixed order of parameters. - psf_mode (str): Technique and scope for PSF convolution. - sampling_mode (str): Method for initial sampling of model. Can be one of midpoint, trapezoid, simpson. Default: midpoint - sampling_tolerance (float): accuracy to which each pixel should be evaluated. Default: 1e-2 - integrate_mode (str): Integration scope for the model. One of none, threshold, full where threshold will select which pixels to integrate while full (in development) will integrate all pixels. Default: threshold - integrate_max_depth (int): Maximum recursion depth when performing sub pixel integration. - integrate_gridding (int): Amount by which to subdivide pixels when doing recursive pixel integration. - integrate_quad_level (int): The initial quadrature level for sub pixel integration. Please always choose an odd number 3 or higher. - softening (float): Softening length used for numerical stability and integration stability to avoid discontinuities (near R=0). Effectively has units of arcsec. Default: 1e-5 - jacobian_chunksize (int): Maximum size of parameter list before jacobian will be broken into smaller chunks. - special_kwargs (list): Parameters which are treated specially by the model object and should not be updated directly. - usable (bool): Indicates if the model is usable. - - Methods: - initialize: Determine initial values for the center coordinates. - sample: Evaluate the model on the space covered by an image object. - jacobian: Compute the Jacobian matrix for this model. +__all__ = ("ComponentModel",) - """ - - # Specifications for the model parameters including units, value, uncertainty, limits, locked, and cyclic - parameter_specs = { - "center": {"units": "arcsec", "uncertainty": [0.1, 0.1]}, - } - # Fixed order of parameters for all methods that interact with the list of parameters - _parameter_order = ("center",) - - # Scope for PSF convolution - psf_mode = "none" # none, full - # Technique for PSF convolution - psf_convolve_mode = "fft" # fft, direct - # Method to use when performing subpixel shifts. bilinear set by default for stability around pixel edges, though lanczos:3 is also fairly stable, and all are stable when away from pixel edges - psf_subpixel_shift = "bilinear" # bilinear, lanczos:2, lanczos:3, lanczos:5, none - - # Method for initial sampling of model - sampling_mode = ( - "midpoint" # midpoint, trapezoid, simpsons, quad:x (where x is a positive integer) - ) - - # Level to which each pixel should be evaluated - sampling_tolerance = 1e-2 - - # Integration scope for model - integrate_mode = "threshold" # none, threshold - - # Maximum recursion depth when performing sub pixel integration - integrate_max_depth = 3 - - # Amount by which to subdivide pixels when doing recursive pixel integration - integrate_gridding = 5 - - # The initial quadrature level for sub pixel integration. Please always choose an odd number 3 or higher - integrate_quad_level = 3 - - # Maximum size of parameter list before jacobian will be broken into smaller chunks, this is helpful for limiting the memory requirements to build a model, lower jacobian_chunksize is slower but uses less memory - jacobian_chunksize = 10 - image_chunksize = 1000 - - # Softening length used for numerical stability and/or integration stability to avoid discontinuities (near R=0) - softening = 1e-3 - - # Parameters which are treated specially by the model object and should not be updated directly when initializing - track_attrs = [ - "psf_mode", - "psf_convolve_mode", - "psf_subpixel_shift", - "sampling_mode", - "sampling_tolerance", - "integrate_mode", - "integrate_max_depth", - "integrate_gridding", - "integrate_quad_level", - "jacobian_chunksize", - "image_chunksize", - "softening", - ] - usable = False - def __init__(self, *, name=None, **kwargs): - self._target_identity = None +@combine_docstrings +class ComponentModel(SampleMixin, Model): + """Component of a model for an object in an image. - self.psf = None - self.psf_aux_image = None + This is a single component of an image model. It has a position on the sky + determined by `center` and may or may not be convolved with a PSF to represent some data. - super().__init__(name=name, **kwargs) - - # If loading from a file, get model configuration then exit __init__ - if "filename" in kwargs: - self.load(kwargs["filename"], new_name=name) - return + **Parameters:** + - `center`: The center of the component in arcseconds [x, y] defined on the tangent plane. - self.parameter_specs = self.build_parameter_specs(kwargs.get("parameters", None)) - with torch.no_grad(): - self.build_parameters() - if isinstance(kwargs.get("parameters", None), torch.Tensor): - self.parameters.value = kwargs["parameters"] + **Options:** + - `psf_convolve`: Whether to convolve the model with a PSF. (bool) - def set_aux_psf(self, aux_psf, add_parameters=True): - """Set the PSF for this model as an auxiliary psf model. This psf - model will be resampled as part of the model sampling step to - track changes made during fitting. + """ - Args: - aux_psf: The auxiliary psf model - add_parameters: if true, the parameters of the auxiliary psf model will become model parameters for this model as well. + _parameter_specs = {"center": {"units": "arcsec", "shape": (2,), "dynamic": True}} - """ + _options = ("psf_convolve",) - self._psf = aux_psf + usable = False - if add_parameters: - self.parameters.link(aux_psf.parameters) + def __init__(self, *args, psf=None, psf_convolve: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.psf = psf + self.psf_convolve = psf_convolve @property def psf(self): if self._psf is None: - try: - return self.target.psf - except AttributeError: - return None + return self.target.psf return self._psf @psf.setter def psf(self, val): + try: + del self._psf # Remove old PSF if it exists + except AttributeError: + pass if val is None: self._psf = None - elif isinstance(val, PSF_Image): + elif isinstance(val, PSFImage): self._psf = val - elif isinstance(val, AstroPhot_Model): - self.set_aux_psf(val) + self.psf_convolve = True + elif isinstance(val, Model): + self._psf = val + self.psf_convolve = True + else: + self._psf = self.target.psf_image(data=val) + self.psf_convolve = True + self._update_psf_upscale() + + def _update_psf_upscale(self): + """Update the PSF upscale factor based on the current target pixel length.""" + if self.psf is None: + self.psf_upscale = 1 + elif isinstance(self.psf, PSFImage): + self.psf_upscale = int(np.round((self.target.pixelscale / self.psf.pixelscale).item())) + elif isinstance(self.psf, Model): + self.psf_upscale = int( + np.round((self.target.pixelscale / self.psf.target.pixelscale).item()) + ) else: - self._psf = PSF_Image(data=val, pixelscale=self.target.pixelscale) - AP_config.ap_logger.warning( - "Setting PSF with pixel matrix, assuming target pixelscale is the same as " - "PSF pixelscale. To remove this warning, set PSFs as an ap.image.PSF_Image " - "or ap.models.AstroPhot_Model object instead." + raise TypeError( + f"PSF must be a PSFImage or Model instance, got {type(self.psf)} instead." ) + @property + def target(self): + return self._target + + @target.setter + def target(self, tar): + if tar is None: + self._target = None + return + elif not isinstance(tar, TargetImage): + raise InvalidTarget("AstroPhot Model target must be a TargetImage instance.") + try: + del self._target # Remove old target if it exists + except AttributeError: + pass + self._target = tar + try: + self._update_psf_upscale() + except AttributeError: + pass + # Initialization functions ###################################################################### @torch.no_grad() @ignore_numpy_warnings - @select_target - @default_internal - def initialize( - self, - target: Optional["Target_Image"] = None, - parameters: Optional[Parameter_Node] = None, - **kwargs, - ): + def initialize(self): """Determine initial values for the center coordinates. This is done with a local center of mass search which iterates by finding the center of light in a window, then iteratively updates until the iterations move by less than a pixel. - - Args: - target (Optional[Target_Image]): A target image object to use as a reference when setting parameter values - """ - super().initialize(target=target, parameters=parameters) - # Get the sub-image area corresponding to the model image - target_area = target[self.window] + if self.psf is not None and isinstance(self.psf, Model): + self.psf.initialize() # Use center of window if a center hasn't been set yet - if parameters["center"].value is None: - with ( - Param_Unlock(parameters["center"]), - Param_SoftLimits(parameters["center"]), - ): - parameters["center"].value = self.window.center - else: + if self.center.initialized: return - if parameters["center"].locked: - return - - # Convert center coordinates to target area array indices - init_icenter = target_area.plane_to_pixel(parameters["center"].value) + target_area = self.target[self.window] + dat = np.copy(backend.to_numpy(target_area._data)) + mask = backend.to_numpy(target_area._mask) + dat[mask] = np.nanmedian(dat[~mask]) - # Compute center of mass in window - COM = center_of_mass( - ( - init_icenter[1].detach().cpu().item(), - init_icenter[0].detach().cpu().item(), - ), - target_area.data.detach().cpu().numpy(), - ) - if np.any(np.array(COM) < 0) or np.any(np.array(COM) >= np.array(target_area.data.shape)): - AP_config.ap_logger.warning("center of mass failed, using center of window") + COM = recursive_center_of_mass(dat) + if not np.all(np.isfinite(COM)): return - COM = (COM[1], COM[0]) - # Convert center of mass indices to coordinates COM_center = target_area.pixel_to_plane( - torch.tensor(COM, dtype=AP_config.ap_dtype, device=AP_config.ap_device) + *backend.as_array(COM, dtype=config.DTYPE, device=config.DEVICE) ) + self.center.value = COM_center - # Set the new coordinates as the model center - parameters["center"].value = COM_center - - # Fit loop functions - ###################################################################### - def evaluate_model( - self, - X: Optional[torch.Tensor] = None, - Y: Optional[torch.Tensor] = None, - image: Optional[Image] = None, - parameters: Parameter_Node = None, - **kwargs, - ): - """Evaluate the model on every pixel in the given image. The - basemodel object simply returns zeros, this function should be - overloaded by subclasses. + def fit_mask(self): + return backend.zeros_like(self.target[self.window].mask, dtype=backend.bool) - Args: - image (Image): The image defining the set of pixels on which to evaluate the model + def _fit_mask(self): + return backend.zeros_like(self.target[self.window]._mask, dtype=backend.bool) - """ - if X is None or Y is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - return torch.zeros_like(X) # do nothing in base model + @forward + def transform_coordinates(self, x, y, center): + return x - center[0], y - center[1] + @forward def sample( self, - image: Optional[Image] = None, window: Optional[Window] = None, - parameters: Optional[Parameter_Node] = None, ): - """Evaluate the model on the space covered by an image object. This + """Evaluate the model on the pixels defined in an image. This function properly calls integration methods and PSF convolution. This should not be overloaded except in special cases. @@ -273,172 +165,33 @@ def sample( with the original pixel grid. The final model is then added to the requested image. - Args: - image (Optional[Image]): An AstroPhot Image object (likely a Model_Image) - on which to evaluate the model values. If not - provided, a new Model_Image object will be created. - window (Optional[Window]): A window within which to evaluate the model. - Should only be used if a subset of the full image - is needed. If not provided, the entire image will - be used. + **Args:** + - `window` (Optional[Window]): A window within which to evaluate the model. + By default this is the model's window. - Returns: - Image: The image with the computed model values. + **Returns:** + - `Image` (ModelImage): The image with the computed model values. """ - # Image on which to evaluate model - if image is None: - image = self.make_model_image(window=window) - # Window within which to evaluate model if window is None: - working_window = image.window.copy() - else: - working_window = window.copy() - - # Parameters with which to evaluate the model - if parameters is None: - parameters = self.parameters - - if "window" in self.psf_mode: - raise NotImplementedError("PSF convolution in sub-window not available yet") - - if "full" in self.psf_mode: - if isinstance(self.psf, AstroPhot_Model): - psf = self.psf( - parameters=parameters[self.psf.name], - ) - else: - psf = self.psf - psf_upscale = torch.round(working_window.pixel_length / psf.pixel_length).int() - working_window = working_window.rescale_pixel(1 / psf_upscale) - # Add border for psf convolution edge effects, will be cropped out later - working_window.pad_pixel(psf.psf_border_int) - # Make the image object to which the samples will be tracked - working_image = Model_Image(window=working_window) - # Sub pixel shift to align the model with the center of a pixel - if self.psf_subpixel_shift != "none": - pixel_center = working_image.plane_to_pixel(parameters["center"].value) - center_shift = pixel_center - torch.round(pixel_center) - working_image.header.pixel_shift(center_shift) - else: - center_shift = None - - # Evaluate the model at the current resolution - reference, deep = self._sample_init( - image=working_image, - parameters=parameters, - center=parameters["center"].value, - ) - # If needed, super-resolve the image in areas of high curvature so pixels are properly sampled - deep = self._sample_integrate( - deep, reference, working_image, parameters, parameters["center"].value - ) + window = self.window - # update the image with the integrated pixels - working_image.data += deep + if self.psf_convolve: + psf = self.psf() if isinstance(self.psf, Model) else self.psf - # Convolve the PSF - self._sample_convolve(working_image, center_shift, psf, self.psf_subpixel_shift) - - # Shift image back to align with original pixel grid - if self.psf_subpixel_shift != "none": - working_image.header.pixel_shift(-center_shift) - # Add the sampled/integrated/convolved pixels to the requested image - working_upscale = torch.round(image.pixel_length / working_window.pixel_length).int() - working_image = working_image.crop(psf.psf_border_int).reduce(working_upscale) - else: - # Create an image to store pixel samples - working_image = Model_Image(pixelscale=image.pixelscale, window=working_window) - # Evaluate the model on the image - reference, deep = self._sample_init( - image=working_image, - parameters=parameters, - center=parameters["center"].value, + working_image = self.target[window].model_image( + upsample=self.psf_upscale, pad=psf.psf_pad ) - # Super-resolve and integrate where needed - deep = self._sample_integrate( - deep, - reference, - working_image, - parameters, - center=parameters["center"].value, - ) - # Add the sampled/integrated pixels to the requested image - working_image.data += deep - - if self.mask is not None: - working_image.data = working_image.data * torch.logical_not(self.mask) + sample = self.sample_image(working_image) + working_image._data = func.convolve(sample, psf._data) + working_image = working_image.crop(psf.psf_pad).reduce(self.psf_upscale) - image += working_image - - return image - - @property - def target(self): - return self._target - - @target.setter - def target(self, tar): - if not (tar is None or isinstance(tar, Target_Image)): - raise InvalidTarget("AstroPhot_Model target must be a Target_Image instance.") - - # If a target image list is assigned, pick out the target appropriate for this model - if isinstance(tar, Target_Image_List) and self._target_identity is not None: - for subtar in tar: - if subtar.identity == self._target_identity: - usetar = subtar - break - else: - raise InvalidTarget( - f"Could not find target in Target_Image_List with matching identity " - f"to {self.name}: {self._target_identity}" - ) else: - usetar = tar + working_image = self.target[window].model_image() + working_image._data = self.sample_image(working_image) - self._target = usetar + # Units from flux/arcsec^2 to flux, multiply by pixel area + working_image.fluxdensity_to_flux() - # Remember the target identity to use - try: - self._target_identity = self._target.identity - except AttributeError: - pass - - def get_state(self, save_params=True): - """Returns a dictionary with a record of the current state of the - model. - - Specifically, the current parameter settings and the window for - this model. From this information it is possible for the model to - re-build itself lated when loading from disk. Note that the target - image is not saved, this must be reset when loading the model. - - """ - state = super().get_state() - state["window"] = self.window.get_state() - if save_params: - state["parameters"] = self.parameters.get_state() - state["target_identity"] = self._target_identity - if isinstance(self._psf, PSF_Image) or isinstance(self._psf, AstroPhot_Model): - state["psf"] = self._psf.get_state() - for key in self.track_attrs: - if getattr(self, key) != getattr(self.__class__, key): - state[key] = getattr(self, key) - return state - - # Extra background methods for the basemodel - ###################################################################### - from ._model_methods import radius_metric - from ._model_methods import angular_metric - from ._model_methods import _sample_init - from ._model_methods import _sample_integrate - from ._model_methods import _sample_convolve - from ._model_methods import _integrate_reference - from ._model_methods import _shift_psf - from ._model_methods import build_parameter_specs - from ._model_methods import build_parameters - from ._model_methods import jacobian - from ._model_methods import _chunk_jacobian - from ._model_methods import _chunk_image_jacobian - from ._model_methods import load + return working_image diff --git a/astrophot/models/moffat.py b/astrophot/models/moffat.py new file mode 100644 index 00000000..65be477c --- /dev/null +++ b/astrophot/models/moffat.py @@ -0,0 +1,69 @@ +from .galaxy_model_object import GalaxyModel +from .psf_model_object import PSFModel +from .mixins import ( + MoffatMixin, + InclinedMixin, + RadialMixin, + WedgeMixin, + RayMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, + iMoffatMixin, +) +from ..utils.decorators import combine_docstrings + + +__all__ = ( + "MoffatGalaxy", + "MoffatPSF", + "Moffat2DPSF", + "MoffatSuperEllipse", + "MoffatFourierEllipse", + "MoffatWarp", + "MoffatRay", + "MoffatWedge", +) + + +@combine_docstrings +class MoffatGalaxy(MoffatMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class MoffatPSF(MoffatMixin, RadialMixin, PSFModel): + _parameter_specs = {"I0": {"units": "flux/arcsec^2", "value": 1.0, "dynamic": False}} + usable = True + + +@combine_docstrings +class Moffat2DPSF(MoffatMixin, InclinedMixin, RadialMixin, PSFModel): + _model_type = "2d" + _parameter_specs = {"I0": {"units": "flux/arcsec^2", "value": 1.0, "dynamic": False}} + usable = True + + +@combine_docstrings +class MoffatSuperEllipse(MoffatMixin, SuperEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class MoffatFourierEllipse(MoffatMixin, FourierEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class MoffatWarp(MoffatMixin, WarpMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class MoffatRay(iMoffatMixin, RayMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class MoffatWedge(iMoffatMixin, WedgeMixin, GalaxyModel): + usable = True diff --git a/astrophot/models/moffat_model.py b/astrophot/models/moffat_model.py deleted file mode 100644 index 51122628..00000000 --- a/astrophot/models/moffat_model.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import numpy as np - -from .galaxy_model_object import Galaxy_Model -from .psf_model_object import PSF_Model -from ._shared_methods import parametric_initialize, select_target -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..utils.parametric_profiles import moffat_np -from ..utils.conversions.functions import moffat_I0_to_flux -from ..param import Param_Unlock, Param_SoftLimits - -__all__ = ["Moffat_Galaxy", "Moffat_PSF"] - - -def _x0_func(model_params, R, F): - return 2.0, R[4], F[0] - - -def _wrap_moffat(R, n, rd, i0): - return moffat_np(R, n, rd, 10 ** (i0)) - - -class Moffat_Galaxy(Galaxy_Model): - """basic galaxy model with a Moffat profile for the radial light - profile. The functional form of the Moffat profile is defined as: - - I(R) = I0 / (1 + (R/Rd)^2)^n - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, I0 is the central flux - density, Rd is the scale length for the profile, and n is the - concentration index which controls the shape of the profile. - - Parameters: - n: Concentration index which controls the shape of the brightness profile - I0: brightness at the center of the profile, represented as the log of the brightness divided by pixel scale squared. - Rd: scale length radius - - """ - - model_type = f"moffat {Galaxy_Model.model_type}" - parameter_specs = { - "n": {"units": "none", "limits": (0.1, 10), "uncertainty": 0.05}, - "Rd": {"units": "arcsec", "limits": (0, None)}, - "I0": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = Galaxy_Model._parameter_order + ("n", "Rd", "I0") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_moffat, ("n", "Rd", "I0"), _x0_func) - - @default_internal - def total_flux(self, parameters=None, window=None): - return moffat_I0_to_flux( - 10 ** parameters["I0"].value, - parameters["n"].value, - parameters["Rd"].value, - parameters["q"].value, - ) - - from ._shared_methods import moffat_radial_model as radial_model - - -class Moffat_PSF(PSF_Model): - """basic point source model with a Moffat profile for the radial light - profile. The functional form of the Moffat profile is defined as: - - I(R) = I0 / (1 + (R/Rd)^2)^n - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, I0 is the central flux - density, Rd is the scale length for the profile, and n is the - concentration index which controls the shape of the profile. - - Parameters: - n: Concentration index which controls the shape of the brightness profile - I0: brightness at the center of the profile, represented as the log of the brightness divided by pixel scale squared. - Rd: scale length radius - - """ - - model_type = f"moffat {PSF_Model.model_type}" - parameter_specs = { - "n": {"units": "none", "limits": (0.1, 10), "uncertainty": 0.05}, - "Rd": {"units": "arcsec", "limits": (0, None)}, - "I0": {"units": "log10(flux/arcsec^2)", "value": 0.0, "locked": True}, - } - _parameter_order = PSF_Model._parameter_order + ("n", "Rd", "I0") - usable = True - model_integrated = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_moffat, ("n", "Rd", "I0"), _x0_func) - - from ._shared_methods import moffat_radial_model as radial_model - - @default_internal - def total_flux(self, parameters=None, window=None): - return moffat_I0_to_flux( - 10 ** parameters["I0"].value, - parameters["n"].value, - parameters["Rd"].value, - torch.ones_like(parameters["n"].value), - ) - - from ._shared_methods import radial_evaluate_model as evaluate_model - - -class Moffat2D_PSF(Moffat_PSF): - - model_type = f"moffat2d {PSF_Model.model_type}" - parameter_specs = { - "q": {"units": "b/a", "limits": (0, 1), "uncertainty": 0.03}, - "PA": { - "units": "radians", - "limits": (0, np.pi), - "cyclic": True, - "uncertainty": 0.06, - }, - } - _parameter_order = Moffat_PSF._parameter_order + ("q", "PA") - usable = True - model_integrated = False - - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - with Param_Unlock(parameters["q"]), Param_SoftLimits(parameters["q"]): - if parameters["q"].value is None: - parameters["q"].value = 0.9 - - with Param_Unlock(parameters["PA"]), Param_SoftLimits(parameters["PA"]): - if parameters["PA"].value is None: - parameters["PA"].value = 0.1 - super().initialize(target=target, parameters=parameters) - - from ._shared_methods import inclined_transform_coordinates as transform_coordinates - from ._shared_methods import transformed_evaluate_model as evaluate_model diff --git a/astrophot/models/multi_gaussian_expansion.py b/astrophot/models/multi_gaussian_expansion.py new file mode 100644 index 00000000..5f50980c --- /dev/null +++ b/astrophot/models/multi_gaussian_expansion.py @@ -0,0 +1,130 @@ +from typing import Optional, Tuple +import torch +import numpy as np + +from .model_object import ComponentModel +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from . import func +from .. import config +from ..backend_obj import backend, ArrayLike +from ..param import forward + +__all__ = ["MultiGaussianExpansion"] + + +@combine_docstrings +class MultiGaussianExpansion(ComponentModel): + """Model that represents a galaxy as a sum of multiple Gaussian + profiles. The model is defined as: + + $$I(R) = \\sum_i {\\rm flux}_i * \\exp(-0.5*(R_i / \\sigma_i)^2) / (2 * \\pi * q_i * \\sigma_i^2)$$ + + where $R_i$ is a radius computed using $q_i$ and $PA_i$ for that component. All components share the same center. + + **Parameters:** + - `q`: axis ratio to scale minor axis from the ratio of the minor/major axis b/a, this parameter is unitless, it is restricted to the range (0,1) + - `PA`: position angle of the semi-major axis relative to the image positive x-axis in radians, it is a cyclic parameter in the range [0,pi) + - `sigma`: standard deviation of each Gaussian + - `flux`: amplitude of each Gaussian + """ + + _model_type = "mge" + _parameter_specs = { + "q": {"units": "b/a", "valid": (0, 1), "dynamic": True}, + "PA": {"units": "radians", "valid": (0, np.pi), "cyclic": True, "dynamic": True}, + "sigma": {"units": "arcsec", "valid": (0, None), "dynamic": True}, + "flux": {"units": "flux", "dynamic": True}, + } + usable = True + + def __init__(self, *args, n_components: Optional[int] = None, **kwargs): + super().__init__(*args, **kwargs) + if n_components is None: + for key in ("q", "sigma", "flux"): + if self[key].value is not None: + self.n_components = self[key].value.shape[0] + break + else: + self.n_components = 1 + else: + self.n_components = int(n_components) + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + target_area = self.target[self.window] + dat = backend.to_numpy(target_area._data).copy() + mask = backend.to_numpy(target_area._mask) + dat[mask] = np.median(dat[~mask]) + + edge = np.concatenate((dat[:, 0], dat[:, -1], dat[0, :], dat[-1, :])) + edge_average = np.nanmedian(edge) + dat -= edge_average + + if not self.sigma.initialized: + self.sigma.value = np.logspace( + np.log10(target_area.pixelscale.item() * 3), + max(target_area.data.shape) * target_area.pixelscale.item() * 0.7, + self.n_components, + ) + if not self.flux.initialized: + self.flux.value = (np.sum(dat) / self.n_components) * np.ones(self.n_components) + + if self.PA.initialized or self.q.initialized: + return + + x, y = target_area.coordinate_center_meshgrid() + x = backend.to_numpy(x - self.center.value[0]) + y = backend.to_numpy(y - self.center.value[1]) + mu20 = np.median(dat * np.abs(x)) + mu02 = np.median(dat * np.abs(y)) + mu11 = np.median(dat * x * y / np.sqrt(np.abs(x * y) + self.softening**2)) + # mu20 = np.median(dat * x**2) + # mu02 = np.median(dat * y**2) + # mu11 = np.median(dat * x * y) + M = np.array([[mu20, mu11], [mu11, mu02]]) + ones = np.ones(self.n_components) + if not self.PA.initialized: + if np.any(np.iscomplex(M)) or np.any(~np.isfinite(M)): + self.PA.value = ones * np.pi / 2 + else: + self.PA.value = ones * (0.5 * np.arctan2(2 * mu11, mu20 - mu02) - np.pi / 2) % np.pi + if not self.q.initialized: + l = np.sort(np.linalg.eigvals(M)) + if np.any(np.iscomplex(l)) or np.any(~np.isfinite(l)): + l = (0.7, 1.0) + self.q.value = ones * np.clip(np.sqrt(l[0] / l[1]), 0.1, 0.9) + + @forward + def transform_coordinates( + self, x: ArrayLike, y: ArrayLike, q: ArrayLike, PA: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + x, y = super().transform_coordinates(x, y) + if np.prod(PA.shape) == 1: + x, y = func.rotate(-(PA + np.pi / 2), x, y) + x = x * backend.ones( + (q.shape[0], *[1] * x.ndim), dtype=config.DTYPE, device=config.DEVICE + ) + y = y * backend.ones( + (q.shape[0], *[1] * y.ndim), dtype=config.DTYPE, device=config.DEVICE + ) + else: + x, y = backend.vmap(lambda pa: func.rotate(-(pa + np.pi / 2), x, y))(PA) + y = backend.vmap(lambda q, y: y / q)(q, y) + return x, y + + @forward + def brightness( + self, x: ArrayLike, y: ArrayLike, flux: ArrayLike, sigma: ArrayLike, q: ArrayLike + ) -> ArrayLike: + x, y = self.transform_coordinates(x, y) + R = self.radius_metric(x, y) + return backend.sum( + backend.vmap( + lambda A, r, sig, _q: (A / backend.sqrt(2 * np.pi * _q * sig**2)) + * backend.exp(-0.5 * (r / sig) ** 2) + )(flux, R, sigma, q), + dim=0, + ) diff --git a/astrophot/models/multi_gaussian_expansion_model.py b/astrophot/models/multi_gaussian_expansion_model.py deleted file mode 100644 index d6f42da8..00000000 --- a/astrophot/models/multi_gaussian_expansion_model.py +++ /dev/null @@ -1,171 +0,0 @@ -import torch -import numpy as np -from scipy.stats import iqr - -from .psf_model_object import PSF_Model -from .model_object import Component_Model -from ._shared_methods import ( - select_target, -) -from ..utils.initialize import isophotes -from ..utils.angle_operations import Angle_COM_PA -from ..utils.conversions.coordinates import ( - Rotate_Cartesian, -) -from ..param import Param_Unlock, Param_SoftLimits, Parameter_Node -from ..utils.decorators import ignore_numpy_warnings, default_internal - -__all__ = ["Multi_Gaussian_Expansion"] - - -class Multi_Gaussian_Expansion(Component_Model): - """Model that represents a galaxy as a sum of multiple Gaussian - profiles. The model is defined as: - - I(R) = sum_i flux_i * exp(-0.5*(R_i / sigma_i)^2) / (2 * pi * q_i * sigma_i^2) - - where $R_i$ is a radius computed using $q_i$ and $PA_i$ for that component. All components share the same center. - - Parameters: - q: axis ratio to scale minor axis from the ratio of the minor/major axis b/a, this parameter is unitless, it is restricted to the range (0,1) - PA: position angle of the semi-major axis relative to the image positive x-axis in radians, it is a cyclic parameter in the range [0,pi) - sigma: standard deviation of each Gaussian - flux: amplitude of each Gaussian - """ - - model_type = f"mge {Component_Model.model_type}" - parameter_specs = { - "q": {"units": "b/a", "limits": (0, 1)}, - "PA": {"units": "radians", "limits": (0, np.pi), "cyclic": True}, - "sigma": {"units": "arcsec", "limits": (0, None)}, - "flux": {"units": "log10(flux)"}, - } - _parameter_order = Component_Model._parameter_order + ("q", "PA", "sigma", "flux") - usable = True - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # determine the number of components - for key in ("q", "sigma", "flux"): - if self[key].value is not None: - self.n_components = self[key].value.shape[0] - break - else: - self.n_components = kwargs.get("n_components", 3) - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - target_area = target[self.window] - target_dat = target_area.data.detach().cpu().numpy().copy() - if target_area.has_mask: - mask = target_area.mask.detach().cpu().numpy() - target_dat[mask] = np.median(target_dat[np.logical_not(mask)]) - if parameters["sigma"].value is None: - with Param_Unlock(parameters["sigma"]), Param_SoftLimits(parameters["sigma"]): - parameters["sigma"].value = np.logspace( - np.log10(target_area.pixel_length.item() * 3), - max(target_area.shape.detach().cpu().numpy()) * 0.7, - self.n_components, - ) - parameters["sigma"].uncertainty = ( - self.default_uncertainty * parameters["sigma"].value - ) - if parameters["flux"].value is None: - with Param_Unlock(parameters["flux"]), Param_SoftLimits(parameters["flux"]): - parameters["flux"].value = np.log10( - np.sum(target_dat[~mask]) / self.n_components - ) * np.ones(self.n_components) - parameters["flux"].uncertainty = 0.1 * parameters["flux"].value - - if not (parameters["PA"].value is None or parameters["q"].value is None): - return - edge = np.concatenate( - ( - target_dat[:, 0], - target_dat[:, -1], - target_dat[0, :], - target_dat[-1, :], - ) - ) - edge_average = np.nanmedian(edge) - edge_scatter = iqr(edge[np.isfinite(edge)], rng=(16, 84)) / 2 - icenter = target_area.plane_to_pixel(parameters["center"].value) - - if parameters["PA"].value is None: - weights = target_dat - edge_average - Coords = target_area.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - X, Y = X.detach().cpu().numpy(), Y.detach().cpu().numpy() - if target_area.has_mask: - seg = np.logical_not(target_area.mask.detach().cpu().numpy()) - PA = Angle_COM_PA(weights[seg], X[seg], Y[seg]) - else: - PA = Angle_COM_PA(weights, X, Y) - - with Param_Unlock(parameters["PA"]), Param_SoftLimits(parameters["PA"]): - parameters["PA"].value = ((PA + target_area.north) % np.pi) * np.ones( - self.n_components - ) - if parameters["PA"].uncertainty is None: - parameters["PA"].uncertainty = (5 * np.pi / 180) * torch.ones_like( - parameters["PA"].value - ) # default uncertainty of 5 degrees is assumed - if parameters["q"].value is None: - q_samples = np.linspace(0.2, 0.9, 15) - try: - pa = parameters["PA"].value.item() - except: - pa = parameters["PA"].value[0].item() - iso_info = isophotes( - target_area.data.detach().cpu().numpy() - edge_average, - (icenter[1].detach().cpu().item(), icenter[0].detach().cpu().item()), - threshold=3 * edge_scatter, - pa=(pa - target.north), - q=q_samples, - ) - with Param_Unlock(parameters["q"]), Param_SoftLimits(parameters["q"]): - parameters["q"].value = q_samples[ - np.argmin(list(iso["amplitude2"] for iso in iso_info)) - ] * torch.ones(self.n_components) - if parameters["q"].uncertainty is None: - parameters["q"].uncertainty = parameters["q"].value * self.default_uncertainty - - @default_internal - def total_flux(self, parameters=None): - return torch.sum(10 ** parameters["flux"].value) - - @default_internal - def evaluate_model(self, X=None, Y=None, image=None, parameters=None, **kwargs): - if X is None or Y is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - - if parameters["PA"].value.numel() == 1: - X, Y = Rotate_Cartesian(-(parameters["PA"].value - image.north), X, Y) - X = X.repeat(parameters["q"].value.shape[0], *[1] * X.ndim) - Y = torch.vmap(lambda q: Y / q)(parameters["q"].value) - else: - X, Y = torch.vmap(lambda pa: Rotate_Cartesian(-(pa - image.north), X, Y))( - parameters["PA"].value - ) - Y = torch.vmap(lambda q, y: y / q)(parameters["q"].value, Y) - - R = self.radius_metric(X, Y, image, parameters) - return torch.sum( - torch.vmap( - lambda A, R, sigma, q: (A / (2 * np.pi * q * sigma**2)) - * torch.exp(-0.5 * (R / sigma) ** 2) - )( - image.pixel_area * 10 ** parameters["flux"].value, - R, - parameters["sigma"].value, - parameters["q"].value, - ), - dim=0, - ) diff --git a/astrophot/models/nuker.py b/astrophot/models/nuker.py new file mode 100644 index 00000000..6e9f55f6 --- /dev/null +++ b/astrophot/models/nuker.py @@ -0,0 +1,60 @@ +from .galaxy_model_object import GalaxyModel +from .psf_model_object import PSFModel +from .mixins import ( + NukerMixin, + RadialMixin, + iNukerMixin, + RayMixin, + WedgeMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, +) +from ..utils.decorators import combine_docstrings + + +__all__ = [ + "NukerGalaxy", + "NukerPSF", + "NukerSuperEllipse", + "NukerFourierEllipse", + "NukerWarp", + "NukerWedge", + "NukerRay", +] + + +@combine_docstrings +class NukerGalaxy(NukerMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class NukerPSF(NukerMixin, RadialMixin, PSFModel): + _parameter_specs = {"Ib": {"units": "flux/arcsec^2", "value": 1.0, "dynamic": False}} + usable = True + + +@combine_docstrings +class NukerSuperEllipse(NukerMixin, SuperEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class NukerFourierEllipse(NukerMixin, FourierEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class NukerWarp(NukerMixin, WarpMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class NukerRay(iNukerMixin, RayMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class NukerWedge(iNukerMixin, WedgeMixin, GalaxyModel): + usable = True diff --git a/astrophot/models/nuker_model.py b/astrophot/models/nuker_model.py deleted file mode 100644 index 8911ca99..00000000 --- a/astrophot/models/nuker_model.py +++ /dev/null @@ -1,557 +0,0 @@ -import torch - -from .galaxy_model_object import Galaxy_Model -from .psf_model_object import PSF_Model -from .warp_model import Warp_Galaxy -from .ray_model import Ray_Galaxy -from .wedge_model import Wedge_Galaxy -from .superellipse_model import SuperEllipse_Galaxy, SuperEllipse_Warp -from .foureirellipse_model import FourierEllipse_Galaxy, FourierEllipse_Warp -from ._shared_methods import ( - parametric_initialize, - parametric_segment_initialize, - select_target, -) -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..utils.parametric_profiles import nuker_np - -__all__ = [ - "Nuker_Galaxy", - "Nuker_PSF", - "Nuker_SuperEllipse", - "Nuker_SuperEllipse_Warp", - "Nuker_FourierEllipse", - "Nuker_FourierEllipse_Warp", - "Nuker_Warp", - "Nuker_Ray", -] - - -def _x0_func(model_params, R, F): - return R[4], F[4], 1.0, 2.0, 0.5 - - -def _wrap_nuker(R, rb, ib, a, b, g): - return nuker_np(R, rb, 10 ** (ib), a, b, g) - - -class Nuker_Galaxy(Galaxy_Model): - """basic galaxy model with a Nuker profile for the radial light - profile. The functional form of the Nuker profile is defined as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - """ - - model_type = f"nuker {Galaxy_Model.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)"}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = Galaxy_Model._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize( - self, - parameters, - target, - _wrap_nuker, - ("Rb", "Ib", "alpha", "beta", "gamma"), - _x0_func, - ) - - from ._shared_methods import nuker_radial_model as radial_model - - -class Nuker_PSF(PSF_Model): - """basic point source model with a Nuker profile for the radial light - profile. The functional form of the Nuker profile is defined as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - """ - - model_type = f"nuker {PSF_Model.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)", "value": 0.0, "locked": True}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = PSF_Model._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - model_integrated = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize( - self, - parameters, - target, - _wrap_nuker, - ("Rb", "Ib", "alpha", "beta", "gamma"), - _x0_func, - ) - - from ._shared_methods import nuker_radial_model as radial_model - from ._shared_methods import radial_evaluate_model as evaluate_model - - -class Nuker_SuperEllipse(SuperEllipse_Galaxy): - """super ellipse galaxy model with a Nuker profile for the radial - light profile. The functional form of the Nuker profile is defined as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - """ - - model_type = f"nuker {SuperEllipse_Galaxy.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)"}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = SuperEllipse_Galaxy._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize( - self, - parameters, - target, - _wrap_nuker, - ("Rb", "Ib", "alpha", "beta", "gamma"), - _x0_func, - ) - - from ._shared_methods import nuker_radial_model as radial_model - - -class Nuker_SuperEllipse_Warp(SuperEllipse_Warp): - """super ellipse warp galaxy model with a Nuker profile for the - radial light profile. The functional form of the Nuker profile is - defined as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - - """ - - model_type = f"nuker {SuperEllipse_Warp.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)"}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = SuperEllipse_Warp._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize( - self, - parameters, - target, - _wrap_nuker, - ("Rb", "Ib", "alpha", "beta", "gamma"), - _x0_func, - ) - - from ._shared_methods import nuker_radial_model as radial_model - - -class Nuker_FourierEllipse(FourierEllipse_Galaxy): - """fourier mode perturbations to ellipse galaxy model with a Nuker - profile for the radial light profile. The functional form of the - Nuker profile is defined as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - """ - - model_type = f"nuker {FourierEllipse_Galaxy.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)"}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = FourierEllipse_Galaxy._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize( - self, - parameters, - target, - _wrap_nuker, - ("Rb", "Ib", "alpha", "beta", "gamma"), - _x0_func, - ) - - from ._shared_methods import nuker_radial_model as radial_model - - -class Nuker_FourierEllipse_Warp(FourierEllipse_Warp): - """fourier mode perturbations to ellipse galaxy model with a Nuker - profile for the radial light profile. The functional form of the - Nuker profile is defined as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - """ - - model_type = f"nuker {FourierEllipse_Warp.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)"}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = FourierEllipse_Warp._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize( - self, - parameters, - target, - _wrap_nuker, - ("Rb", "Ib", "alpha", "beta", "gamma"), - _x0_func, - ) - - from ._shared_methods import nuker_radial_model as radial_model - - -class Nuker_Warp(Warp_Galaxy): - """warped coordinate galaxy model with a Nuker profile for the radial - light model. The functional form of the Nuker profile is defined - as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - """ - - model_type = f"nuker {Warp_Galaxy.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)"}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = Warp_Galaxy._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize( - self, - parameters, - target, - _wrap_nuker, - ("Rb", "Ib", "alpha", "beta", "gamma"), - _x0_func, - ) - - from ._shared_methods import nuker_radial_model as radial_model - - -class Nuker_Ray(Ray_Galaxy): - """ray galaxy model with a nuker profile for the radial light - model. The functional form of the Sersic profile is defined as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - """ - - model_type = f"nuker {Ray_Galaxy.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)"}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = Ray_Galaxy._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_segment_initialize( - model=self, - parameters=parameters, - target=target, - prof_func=_wrap_nuker, - params=("Rb", "Ib", "alpha", "beta", "gamma"), - x0_func=_x0_func, - segments=self.rays, - ) - - from ._shared_methods import nuker_iradial_model as iradial_model - - -class Nuker_Wedge(Wedge_Galaxy): - """wedge galaxy model with a nuker profile for the radial light - model. The functional form of the Sersic profile is defined as: - - I(R) = Ib * 2^((beta-gamma)/alpha) * (R / Rb)^(-gamma) * (1 + (R/Rb)^alpha)^((gamma - beta)/alpha) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ib is the flux density at - the scale radius Rb, Rb is the scale length for the profile, beta - is the outer power law slope, gamma is the iner power law slope, - and alpha is the sharpness of the transition. - - Parameters: - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - - """ - - model_type = f"nuker {Wedge_Galaxy.model_type}" - parameter_specs = { - "Rb": {"units": "arcsec", "limits": (0, None)}, - "Ib": {"units": "log10(flux/arcsec^2)"}, - "alpha": {"units": "none", "limits": (0, None)}, - "beta": {"units": "none", "limits": (0, None)}, - "gamma": {"units": "none"}, - } - _parameter_order = Wedge_Galaxy._parameter_order + ( - "Rb", - "Ib", - "alpha", - "beta", - "gamma", - ) - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_segment_initialize( - model=self, - parameters=parameters, - target=target, - prof_func=_wrap_nuker, - params=("Rb", "Ib", "alpha", "beta", "gamma"), - x0_func=_x0_func, - segments=self.wedges, - ) - - from ._shared_methods import nuker_iradial_model as iradial_model diff --git a/astrophot/models/pixelated_psf.py b/astrophot/models/pixelated_psf.py new file mode 100644 index 00000000..e0821aed --- /dev/null +++ b/astrophot/models/pixelated_psf.py @@ -0,0 +1,64 @@ +import torch + +from .psf_model_object import PSFModel +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from ..utils.interpolate import interp2d +from caskade import OverrideParam +from ..param import forward +from ..backend_obj import backend, ArrayLike + +__all__ = ["PixelatedPSF"] + + +@combine_docstrings +class PixelatedPSF(PSFModel): + """point source model which uses an image of the PSF as its + representation for point sources. Using bilinear interpolation it + will shift the PSF within a pixel to accurately represent the + center location of a point source. There is no functional form for + this object type as any image can be supplied. The image pixels + will be optimized as individual parameters. This can very quickly + result in a large number of parameters and a near impossible + fitting task, ideally this should be restricted to a very small + area likely at the center of the PSF. + + To initialize the PSF image will by default be set to the target + PSF_Image values, thus one can use an empirical PSF as a starting + point. Since only bilinear interpolation is performed, it is + recommended to provide the PSF at a higher resolution than the + image if it is near the nyquist sampling limit. Bilinear + interpolation is very fast and accurate for smooth models, so this + way it is possible to do the expensive interpolation before + optimization and save time. Note that if you do this you must + provide the PSF as a PSF_Image object with the correct pixelscale + (essentially just divide the pixelscale by the upsampling factor + you used). + + **Parameters:** + - `pixels`: the total flux within each pixel, represented as the log of the flux. + + """ + + _model_type = "pixelated" + _parameter_specs = {"pixels": {"units": "flux/arcsec^2", "dynamic": True}} + usable = True + sampling_mode = "midpoint" + integrate_mode = "none" + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + if self.pixels.initialized: + return + target_area = self.target[self.window] + self.pixels.value = backend.copy(target_area._data) / target_area.pixel_area + + @forward + def brightness( + self, x: ArrayLike, y: ArrayLike, pixels: ArrayLike, center: ArrayLike + ) -> ArrayLike: + with OverrideParam(self.target.crtan, center): + i, j = self.target.plane_to_pixel(x, y) + result = interp2d(pixels, i, j) + return result diff --git a/astrophot/models/pixelated_psf_model.py b/astrophot/models/pixelated_psf_model.py deleted file mode 100644 index 5169a3b0..00000000 --- a/astrophot/models/pixelated_psf_model.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch - -from .psf_model_object import PSF_Model -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..utils.interpolate import interp2d -from ._shared_methods import select_target -from ..param import Param_Unlock, Param_SoftLimits - -__all__ = ["Pixelated_PSF"] - - -class Pixelated_PSF(PSF_Model): - """point source model which uses an image of the PSF as its - representation for point sources. Using bilinear interpolation it - will shift the PSF within a pixel to accurately represent the - center location of a point source. There is no functional form for - this object type as any image can be supplied. The image pixels - will be optimized as individual parameters. This can very quickly - result in a large number of parameters and a near impossible - fitting task, ideally this should be restricted to a very small - area likely at the center of the PSF. - - To initialize the PSF image will by default be set to the target - PSF_Image values, thus one can use an empirical PSF as a starting - point. Since only bilinear interpolation is performed, it is - recommended to provide the PSF at a higher resolution than the - image if it is near the nyquist sampling limit. Bilinear - interpolation is very fast and accurate for smooth models, so this - way it is possible to do the expensive interpolation before - optimization and save time. Note that if you do this you must - provide the PSF as a PSF_Image object with the correct pixelscale - (essentially just divide the pixelscale by the upsampling factor - you used). - - Parameters: - pixels: the total flux within each pixel, represented as the log of the flux. - - """ - - model_type = f"pixelated {PSF_Model.model_type}" - parameter_specs = { - "pixels": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = PSF_Model._parameter_order + ("pixels",) - usable = True - model_integrated = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - target_area = target[self.window] - with Param_Unlock(parameters["pixels"]), Param_SoftLimits(parameters["pixels"]): - if parameters["pixels"].value is None: - dat = torch.abs(target_area.data) - dat[dat == 0] = torch.median(dat) * 1e-7 - parameters["pixels"].value = torch.log10(dat / target.pixel_area) - if parameters["pixels"].uncertainty is None: - parameters["pixels"].uncertainty = ( - torch.abs(parameters["pixels"].value) * self.default_uncertainty - ) - - @default_internal - def evaluate_model(self, X=None, Y=None, image=None, parameters=None, **kwargs): - if X is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - - # Convert coordinates into pixel locations in the psf image - pX, pY = self.target.plane_to_pixel(X, Y) - - # Select only the pixels where the PSF image is defined - select = torch.logical_and( - torch.logical_and(pX > -0.5, pX < parameters["pixels"].shape[1] - 0.5), - torch.logical_and(pY > -0.5, pY < parameters["pixels"].shape[0] - 0.5), - ) - - # Zero everywhere outside the psf - result = torch.zeros_like(X) - - # Use bilinear interpolation of the PSF at the requested coordinates - result[select] = interp2d(parameters["pixels"].value, pX[select], pY[select]) - - return image.pixel_area * 10**result diff --git a/astrophot/models/planesky.py b/astrophot/models/planesky.py new file mode 100644 index 00000000..b8d4f251 --- /dev/null +++ b/astrophot/models/planesky.py @@ -0,0 +1,50 @@ +import numpy as np +import torch + +from .sky_model_object import SkyModel +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from ..param import forward +from ..backend_obj import backend, ArrayLike + +__all__ = ["PlaneSky"] + + +@combine_docstrings +class PlaneSky(SkyModel): + """Sky background model using a tilted plane for the sky flux. The brightness for each pixel is defined as: + + $$I(X, Y) = I_0 + X*\\delta_x + Y*\\delta_y$$ + + where $I(X,Y)$ is the brightness as a function of image position $X, Y$, + $I_0$ is the central sky brightness value, and $\\delta_x, \\delta_y$ are the slopes of + the sky brightness plane. + + **Parameters:** + - `I0`: central sky brightness value + - `delta`: Tensor for slope of the sky brightness in each image dimension + + """ + + _model_type = "plane" + _parameter_specs = { + "I0": {"units": "flux/arcsec^2", "dynamic": True}, + "delta": {"units": "flux/arcsec", "dynamic": True}, + } + usable = True + + @torch.no_grad() + @ignore_numpy_warnings + def initialize(self): + super().initialize() + + if not self.I0.initialized: + dat = backend.to_numpy(self.target[self.window]._data).copy() + mask = backend.to_numpy(self.target[self.window]._mask) + dat[mask] = np.median(dat[~mask]) + self.I0.value = np.median(dat) / self.target.pixel_area.item() + if not self.delta.initialized: + self.delta.value = [0.0, 0.0] + + @forward + def brightness(self, x: ArrayLike, y: ArrayLike, I0: ArrayLike, delta: ArrayLike) -> ArrayLike: + return I0 + x * delta[0] + y * delta[1] diff --git a/astrophot/models/planesky_model.py b/astrophot/models/planesky_model.py deleted file mode 100644 index 31b0ace7..00000000 --- a/astrophot/models/planesky_model.py +++ /dev/null @@ -1,74 +0,0 @@ -import numpy as np -from scipy.stats import iqr -import torch - -from .sky_model_object import Sky_Model -from ._shared_methods import select_target -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..param import Param_Unlock, Param_SoftLimits - -__all__ = ["Plane_Sky"] - - -class Plane_Sky(Sky_Model): - """Sky background model using a tilted plane for the sky flux. The brightness for each pixel is defined as: - - I(X, Y) = S + X*dx + Y*dy - - where I(X,Y) is the brightness as a function of image position X Y, - S is the central sky brightness value, and dx dy are the slopes of - the sky brightness plane. - - Parameters: - sky: central sky brightness value - delta: Tensor for slope of the sky brightness in each image dimension - - """ - - model_type = f"plane {Sky_Model.model_type}" - parameter_specs = { - "F": {"units": "flux/arcsec^2"}, - "delta": {"units": "flux/arcsec"}, - } - _parameter_order = Sky_Model._parameter_order + ("F", "delta") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - with Param_Unlock(parameters["F"]), Param_SoftLimits(parameters["F"]): - if parameters["F"].value is None: - parameters["F"].value = ( - np.median(target[self.window].data.detach().cpu().numpy()) - / target.pixel_area.item() - ) - if parameters["F"].uncertainty is None: - parameters["F"].uncertainty = ( - iqr( - target[self.window].data.detach().cpu().numpy(), - rng=(31.731 / 2, 100 - 31.731 / 2), - ) - / (2.0) - ) / np.sqrt(np.prod(self.window.shape.detach().cpu().numpy())) - with Param_Unlock(parameters["delta"]), Param_SoftLimits(parameters["delta"]): - if parameters["delta"].value is None: - parameters["delta"].value = [0.0, 0.0] - parameters["delta"].uncertainty = [ - self.default_uncertainty, - self.default_uncertainty, - ] - - @default_internal - def evaluate_model(self, X=None, Y=None, image=None, parameters=None, **kwargs): - if X is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - return ( - image.pixel_area * parameters["F"].value - + X * parameters["delta"].value[0] - + Y * parameters["delta"].value[1] - ) diff --git a/astrophot/models/point_source.py b/astrophot/models/point_source.py index 1153f506..90faec52 100644 --- a/astrophot/models/point_source.py +++ b/astrophot/models/point_source.py @@ -3,31 +3,37 @@ import torch import numpy as np -from ..param import Param_Unlock, Param_SoftLimits, Parameter_Node -from .model_object import Component_Model -from .core_model import AstroPhot_Model -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..image import PSF_Image, Window, Model_Image, Image -from ._shared_methods import select_target +from .base import Model +from .model_object import ComponentModel +from ..image import ModelImage +from ..utils.decorators import ignore_numpy_warnings, combine_docstrings +from ..utils.interpolate import interp2d +from ..image import Window, PSFImage from ..errors import SpecificationConflict +from ..param import forward +from ..backend_obj import backend, ArrayLike +from . import func -__all__ = ("Point_Source",) +__all__ = ("PointSource",) -class Point_Source(Component_Model): +@combine_docstrings +class PointSource(ComponentModel): """Describes a point source in the image, this is a delta function at some position in the sky. This is typically used to describe stars, supernovae, very small galaxies, quasars, asteroids or any other object which can essentially be entirely described by a position and total flux (no structure). + **Parameters:** + - `flux`: The total flux of the point source + """ - model_type = f"point {Component_Model.model_type}" - parameter_specs = { - "flux": {"units": "log10(flux)"}, + _model_type = "point" + _parameter_specs = { + "flux": {"units": "flux", "valid": (0, None), "shape": (), "dynamic": True}, } - _parameter_order = Component_Model._parameter_order + ("flux",) usable = True def __init__(self, *args, **kwargs): @@ -35,50 +41,48 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.psf is None: - raise ValueError("Point_Source needs psf information") + raise SpecificationConflict("Point_Source needs a psf!") @torch.no_grad() @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) + def initialize(self): + super().initialize() - if parameters["flux"].value is not None: + if self.flux.initialized: return - target_area = target[self.window] - target_dat = target_area.data.detach().cpu().numpy().copy() - with Param_Unlock(parameters["flux"]), Param_SoftLimits(parameters["flux"]): - icenter = target_area.plane_to_pixel(parameters["center"].value) - edge = np.concatenate( - ( - target_dat[:, 0], - target_dat[:, -1], - target_dat[0, :], - target_dat[-1, :], - ) - ) - edge_average = np.median(edge) - parameters["flux"].value = np.log10(np.abs(np.sum(target_dat - edge_average))) - parameters["flux"].uncertainty = torch.std(target_area.data) / ( - np.log(10) * 10 ** parameters["flux"].value - ) + target_area = self.target[self.window] + dat = backend.to_numpy(target_area._data).copy() + mask = backend.to_numpy(target_area._mask) + dat[mask] = np.median(dat[~mask]) + + edge = np.concatenate((dat[:, 0], dat[:, -1], dat[0, :], dat[-1, :])) + edge_average = np.median(edge) + self.flux.value = np.abs(np.sum(dat - edge_average)) # Psf convolution should be on by default since this is a delta function @property - def psf_mode(self): - return "full" + def psf_convolve(self): + return True + + @psf_convolve.setter + def psf_convolve(self, value): + pass + + @property + def integrate_mode(self): + return "none" - @psf_mode.setter - def psf_mode(self, value): + @integrate_mode.setter + def integrate_mode(self, value): pass + @forward def sample( self, - image: Optional[Image] = None, window: Optional[Window] = None, - parameters: Optional[Parameter_Node] = None, - ): + center: ArrayLike = None, + flux: ArrayLike = None, + ) -> ModelImage: """Evaluate the model on the space covered by an image object. This function properly calls integration methods and PSF convolution. This should not be overloaded except in special @@ -104,86 +108,28 @@ def sample( Image: The image with the computed model values. """ - # Image on which to evaluate model - if image is None: - image = self.make_model_image(window=window) - # Window within which to evaluate model if window is None: - working_window = image.window.copy() - else: - working_window = window.copy() - - # Parameters with which to evaluate the model - if parameters is None: - parameters = self.parameters - - # Sample the PSF pixels - if isinstance(self.psf, AstroPhot_Model): - # Adjust for supersampled PSF - psf_upscale = torch.round( - working_window.pixel_length / self.psf.target.pixel_length - ).int() - working_window = working_window.rescale_pixel(1 / psf_upscale) - working_window.shift(-parameters["center"].value) - - # Make the image object to which the samples will be tracked - working_image = Model_Image(window=working_window) - - # Fill the image using the PSF model - psf = self.psf( - image=working_image, - parameters=parameters[self.psf.name], - ) - - # Scale for point source flux - working_image.data *= 10 ** parameters["flux"].value - - # Return to original coordinates - working_image.header.shift(parameters["center"].value) - - elif isinstance(self.psf, PSF_Image): - psf = self.psf.copy() - - # Adjust for supersampled PSF - psf_upscale = torch.round(working_window.pixel_length / psf.pixel_length).int() - working_window = working_window.rescale_pixel(1 / psf_upscale) + window = self.window - # Make the image object to which the samples will be tracked - working_image = Model_Image(window=working_window) - - # Compute the center offset - pixel_center = working_image.plane_to_pixel(parameters["center"].value) - center_shift = pixel_center - torch.round(pixel_center) - # working_image.header.pixel_shift(center_shift) - psf.window.shift(working_image.pixel_to_plane(torch.round(pixel_center))) - psf.data = self._shift_psf( - psf=psf.data, - shift=center_shift, - shift_method=self.psf_subpixel_shift, - keep_pad=False, + if isinstance(self.psf, PSFImage): + psf = self.psf._data + elif isinstance(self.psf, Model): + psf = self.psf()._data + else: + raise TypeError( + f"PSF must be a PSFImage or Model instance, got {type(self.psf)} instead." ) - psf.data /= torch.sum(psf.data) - - # Scale for psf flux - psf.data *= 10 ** parameters["flux"].value - # Fill pixels with the PSF image - working_image += psf + # Make the image object to which the samples will be tracked + working_image = self.target[window].model_image(upsample=self.psf_upscale) - # Shift image back to align with original pixel grid - # working_image.header.pixel_shift(-center_shift) + i, j, w = working_image.pixel_quad_meshgrid() + i0, j0 = working_image.plane_to_pixel(*center) + z = interp2d(psf, i - i0 + (psf.shape[0] // 2), j - j0 + (psf.shape[1] // 2)) - else: - raise SpecificationConflict( - f"Point_Source must have a psf that is either an AstroPhot_Model or a PSF_Image. not {type(self.psf)}" - ) + working_image._data = flux * func.pixel_quad_integrator(z, w) - # Return to image pixelscale - working_image = working_image.reduce(psf_upscale) - if self.mask is not None: - working_image.data = working_image.data * torch.logical_not(self.mask) + working_image = working_image.reduce(self.psf_upscale) - # Add the sampled/integrated/convolved pixels to the requested image - image += working_image - return image + return working_image diff --git a/astrophot/models/psf_model_object.py b/astrophot/models/psf_model_object.py index 18614e68..3ba42dfc 100644 --- a/astrophot/models/psf_model_object.py +++ b/astrophot/models/psf_model_object.py @@ -1,25 +1,16 @@ -from typing import Optional +from typing import Optional, Tuple +from caskade import forward -import torch +from .base import Model +from ..image import ModelImage, PSFImage, Window +from ..errors import InvalidTarget +from .mixins import SampleMixin +from ..backend_obj import backend, ArrayLike -from .core_model import AstroPhot_Model -from ..image import ( - Image, - Model_Image, - Window, - PSF_Image, - Image_List, -) -from ._shared_methods import select_target -from ..utils.decorators import default_internal, ignore_numpy_warnings -from ..param import Parameter_Node -from ..errors import SpecificationConflict +__all__ = ["PSFModel"] -__all__ = ["PSF_Model"] - - -class PSF_Model(AstroPhot_Model): +class PSFModel(SampleMixin, Model): """Prototype point source (typically a star) model, to be subclassed by other point source models which define specific behavior. @@ -32,157 +23,31 @@ class PSF_Model(AstroPhot_Model): """ - # Specifications for the model parameters including units, value, uncertainty, limits, locked, and cyclic - parameter_specs = { - "center": { - "units": "arcsec", - "value": (0.0, 0.0), - "uncertainty": (0.1, 0.1), - "locked": True, - }, + _parameter_specs = { + "center": {"units": "arcsec", "value": (0.0, 0.0), "shape": (2,), "dynamic": False}, } - # Fixed order of parameters for all methods that interact with the list of parameters - _parameter_order = ("center",) - model_type = f"psf {AstroPhot_Model.model_type}" + _model_type = "psf" usable = False - model_integrated = None # The sampled PSF will be normalized to a total flux of 1 within the window normalize_psf = True - # Method for initial sampling of model - sampling_mode = "simpsons" # midpoint, trapezoid, simpson - - # Level to which each pixel should be evaluated - sampling_tolerance = 1e-3 - - # Integration scope for model - integrate_mode = "threshold" # none, threshold, full* - - # Maximum recursion depth when performing sub pixel integration - integrate_max_depth = 3 - - # Amount by which to subdivide pixels when doing recursive pixel integration - integrate_gridding = 5 - - # The initial quadrature level for sub pixel integration. Please always choose an odd number 3 or higher - integrate_quad_level = 3 - - # Maximum size of parameter list before jacobian will be broken into smaller chunks, this is helpful for limiting the memory requirements to build a model, lower jacobian_chunksize is slower but uses less memory - jacobian_chunksize = 10 - image_chunksize = 1000 - - # Softening length used for numerical stability and/or integration stability to avoid discontinuities (near R=0) - softening = 1e-3 - # Parameters which are treated specially by the model object and should not be updated directly when initializing - special_kwargs = ["parameters", "filename", "model_type"] - track_attrs = [ - "sampling_mode", - "sampling_tolerance", - "integrate_mode", - "integrate_max_depth", - "integrate_gridding", - "integrate_quad_level", - "jacobian_chunksize", - "softening", - ] - - def __init__(self, *, name=None, **kwargs): - self._target_identity = None - super().__init__(name=name, **kwargs) - - # Set any user defined attributes for the model - for kwarg in kwargs: # fixme move to core model? - # Skip parameters with special behaviour - if kwarg in self.special_kwargs: - continue - # Set the model parameter - setattr(self, kwarg, kwargs[kwarg]) - - # If loading from a file, get model configuration then exit __init__ - if "filename" in kwargs: - self.load(kwargs["filename"], new_name=name) - return - - self.parameter_specs = self.build_parameter_specs(kwargs.get("parameters", None)) - with torch.no_grad(): - self.build_parameters() - if isinstance(kwargs.get("parameters", None), torch.Tensor): - self.parameters.value = kwargs["parameters"] - assert torch.allclose( - self.window.center, torch.zeros_like(self.window.center) - ), "PSF models must always be centered at (0,0)" + _options = ("normalize_psf",) - # Initialization functions - ###################################################################### - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize( - self, - target: Optional["PSF_Image"] = None, - parameters: Optional[Parameter_Node] = None, - **kwargs, - ): - """Determine initial values for the center coordinates. This is done - with a local center of mass search which iterates by finding - the center of light in a window, then iteratively updates - until the iterations move by less than a pixel. - - Args: - target (Optional[Target_Image]): A target image object to use as a reference when setting parameter values - - """ - super().initialize(target=target, parameters=parameters) + def initialize(self): + pass - @default_internal - def transform_coordinates(self, X, Y, image=None, parameters=None): - return X, Y + @forward + def transform_coordinates( + self, x: ArrayLike, y: ArrayLike, center: ArrayLike + ) -> Tuple[ArrayLike, ArrayLike]: + return x - center[0], y - center[1] # Fit loop functions ###################################################################### - def evaluate_model( - self, - X: Optional[torch.Tensor] = None, - Y: Optional[torch.Tensor] = None, - image: Optional[Image] = None, - parameters: "Parameter_Node" = None, - **kwargs, - ): - """Evaluate the model on every pixel in the given image. The - basemodel object simply returns zeros, this function should be - overloaded by subclasses. - - Args: - image (Image): The image defining the set of pixels on which to evaluate the model - - """ - if X is None or Y is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - return torch.zeros_like(X) # do nothing in base model - - def make_model_image(self, window: Optional[Window] = None): - """This is called to create a blank `Model_Image` object of the - correct format for this model. This is typically used - internally to construct the model image before filling the - pixel values with the model. - - """ - if window is None: - window = self.window - else: - window = self.window & window - return self.target[window].blank_copy() - - def sample( - self, - image: Optional[Image] = None, - window: Optional[Window] = None, - parameters: Optional[Parameter_Node] = None, - ): + @forward + def sample(self, window: Optional[Window] = None) -> PSFImage: """Evaluate the model on the space covered by an image object. This function properly calls integration methods. This should not be overloaded except in special cases. @@ -194,74 +59,28 @@ def sample( pixel grid. The final model is then added to the requested image. - Args: - image (Optional[Image]): An AstroPhot Image object (likely a Model_Image) - on which to evaluate the model values. If not - provided, a new Model_Image object will be created. - window (Optional[Window]): A window within which to evaluate the model. + **Args:** + - `window` (Optional[Window]): A window within which to evaluate the model. Should only be used if a subset of the full image is needed. If not provided, the entire image will be used. - Returns: - Image: The image with the computed model values. + **Returns:** + - `PSFImage`: The image with the computed model values. """ - # Image on which to evaluate model - if image is None: - image = self.make_model_image(window=window) - - # Window within which to evaluate model - if window is None: - working_window = image.window.copy() - else: - working_window = window.copy() - - # Parameters with which to evaluate the model - if parameters is None: - parameters = self.parameters - # Create an image to store pixel samples - working_image = Model_Image(window=working_window) - if self.model_integrated is True: - # Evaluate the model on the image - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - working_image.data = self.evaluate_model( - X=X, Y=Y, image=working_image, parameters=parameters - ) - elif self.model_integrated is False: - # Evaluate the model on the image - reference, deep = self._sample_init( - image=working_image, - parameters=parameters, - center=parameters["center"].value, - ) - # Super-resolve and integrate where needed - deep = self._sample_integrate( - deep, - reference, - working_image, - parameters, - center=torch.zeros_like(working_image.center), - ) - # Add the sampled/integrated pixels to the requested image - working_image.data += deep - else: - raise SpecificationConflict( - "PSF model 'model_integrated' should be either True or False" - ) + working_image = self.target[self.window].model_image() + working_image._data = self.sample_image(working_image) # normalize to total flux 1 if self.normalize_psf: - working_image.data /= torch.sum(working_image.data) - - if self.mask is not None: - working_image.data = working_image.data * torch.logical_not(self.mask) + working_image.normalize() - image += working_image + return working_image - return image + def fit_mask(self) -> ArrayLike: + return backend.zeros_like(self.target[self.window].mask, dtype=backend.bool) @property def target(self): @@ -271,60 +90,18 @@ def target(self): return None @target.setter - def target(self, tar): - assert tar is None or isinstance(tar, PSF_Image) - - # If a target image list is assigned, pick out the target appropriate for this model - if isinstance(tar, Image_List) and self._target_identity is not None: - for subtar in tar: - if subtar.identity == self._target_identity: - usetar = subtar - break - else: - raise KeyError( - f"Could not find target in Target_Image_List with matching identity to {self.name}: {self._target_identity}" - ) - else: - usetar = tar - - self._target = usetar - - # Remember the target identity to use + def target(self, target): + if target is None: + self._target = None + elif not isinstance(target, PSFImage): + raise InvalidTarget(f"Target for PSF_Model must be a PSF_Image, not {type(target)}") try: - self._target_identity = self._target.identity + del self._target # Remove old target if it exists except AttributeError: pass - def get_state(self, save_params=True): - """Returns a dictionary with a record of the current state of the - model. + self._target = target - Specifically, the current parameter settings and the window for - this model. From this information it is possible for the model to - re-build itself lated when loading from disk. Note that the target - image is not saved, this must be reset when loading the model. - - """ - state = super().get_state() - state["window"] = self.window.get_state() - if save_params: - state["parameters"] = self.parameters.get_state() - state["target_identity"] = self._target_identity - for key in self.track_attrs: - if getattr(self, key) != getattr(self.__class__, key): - state[key] = getattr(self, key) - return state - - # Extra background methods for the basemodel - ###################################################################### - from ._model_methods import radius_metric - from ._model_methods import angular_metric - from ._model_methods import _sample_init - from ._model_methods import _sample_integrate - from ._model_methods import _integrate_reference - from ._model_methods import build_parameter_specs - from ._model_methods import build_parameters - from ._model_methods import jacobian - from ._model_methods import _chunk_jacobian - from ._model_methods import _chunk_image_jacobian - from ._model_methods import load + @forward + def __call__(self, window: Optional[Window] = None) -> ModelImage: + return self.sample(window=window) diff --git a/astrophot/models/ray_model.py b/astrophot/models/ray_model.py deleted file mode 100644 index 965e9ae9..00000000 --- a/astrophot/models/ray_model.py +++ /dev/null @@ -1,105 +0,0 @@ -import numpy as np -import torch - -from .galaxy_model_object import Galaxy_Model -from ..utils.decorators import default_internal - -__all__ = ["Ray_Galaxy"] - - -class Ray_Galaxy(Galaxy_Model): - """Variant of a galaxy model which defines multiple radial models - seprarately along some number of rays projected from the galaxy - center. These rays smoothly transition from one to another along - angles theta. The ray transition uses a cosine smoothing function - which depends on the number of rays, for example with two rays the - brightness would be: - - I(R,theta) = I1(R)*cos(theta % pi) + I2(R)*cos((theta + pi/2) % pi) - - Where I(R,theta) is the brightness function in polar coordinates, - R is the semi-major axis, theta is the polar angle (defined after - galaxy axis ratio is applied), I1(R) is the first brightness - profile, % is the modulo operator, and I2 is the second brightness - profile. The ray model defines no extra parameters, though now - every model parameter related to the brightness profile gains an - extra dimension for the ray number. Also a new input can be given - when instantiating the ray model: "rays" which is an integer for - the number of rays. - - """ - - model_type = f"ray {Galaxy_Model.model_type}" - special_kwargs = Galaxy_Model.special_kwargs + ["rays"] - rays = 2 - track_attrs = Galaxy_Model.track_attrs + ["rays"] - usable = False - - def __init__(self, *args, **kwargs): - self.symmetric_rays = True - super().__init__(*args, **kwargs) - self.rays = kwargs.get("rays", Ray_Galaxy.rays) - - @default_internal - def polar_model(self, R, T, image=None, parameters=None): - model = torch.zeros_like(R) - if self.rays % 2 == 0 and self.symmetric_rays: - for r in range(self.rays): - angles = (T - (r * np.pi / self.rays)) % np.pi - indices = torch.logical_or( - angles < (np.pi / self.rays), - angles >= (np.pi * (1 - 1 / self.rays)), - ) - weight = (torch.cos(angles[indices] * self.rays) + 1) / 2 - model[indices] += weight * self.iradial_model(r, R[indices], image) - elif self.rays % 2 == 1 and self.symmetric_rays: - for r in range(self.rays): - angles = (T - (r * np.pi / self.rays)) % (2 * np.pi) - indices = torch.logical_or( - angles < (np.pi / self.rays), - angles >= (np.pi * (2 - 1 / self.rays)), - ) - weight = (torch.cos(angles[indices] * self.rays) + 1) / 2 - model[indices] += weight * self.iradial_model(r, R[indices], image) - angles = (T - (np.pi + r * np.pi / self.rays)) % (2 * np.pi) - indices = torch.logical_or( - angles < (np.pi / self.rays), - angles >= (np.pi * (2 - 1 / self.rays)), - ) - weight = (torch.cos(angles[indices] * self.rays) + 1) / 2 - model[indices] += weight * self.iradial_model(r, R[indices], image) - elif self.rays % 2 == 0 and not self.symmetric_rays: - for r in range(self.rays): - angles = (T - (r * 2 * np.pi / self.rays)) % (2 * np.pi) - indices = torch.logical_or( - angles < (2 * np.pi / self.rays), - angles >= (2 * np.pi * (1 - 1 / self.rays)), - ) - weight = (torch.cos(angles[indices] * self.rays) + 1) / 2 - model[indices] += weight * self.iradial_model(r, R[indices], image) - else: - for r in range(self.rays): - angles = (T - (r * 2 * np.pi / self.rays)) % (2 * np.pi) - indices = torch.logical_or( - angles < (2 * np.pi / self.rays), - angles >= (np.pi * (2 - 1 / self.rays)), - ) - weight = (torch.cos(angles[indices] * self.rays) + 1) / 2 - model[indices] += weight * self.iradial_model(r, R[indices], image) - return model - - def evaluate_model(self, X=None, Y=None, image=None, parameters=None, **kwargs): - if X is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - XX, YY = self.transform_coordinates(X, Y, image, parameters) - - return self.polar_model( - self.radius_metric(XX, YY, image=image, parameters=parameters), - self.angular_metric(XX, YY, image=image, parameters=parameters), - image=image, - parameters=parameters, - ) - - -# class SingleRay_Galaxy(Galaxy_Model): diff --git a/astrophot/models/relspline_model.py b/astrophot/models/relspline_model.py deleted file mode 100644 index a7eb5f05..00000000 --- a/astrophot/models/relspline_model.py +++ /dev/null @@ -1,78 +0,0 @@ -from .galaxy_model_object import Galaxy_Model -from .psf_model_object import PSF_Model -from ..utils.decorators import default_internal - -__all__ = [ - "RelSpline_Galaxy", - "RelSpline_PSF", -] - - -# First Order -###################################################################### -class RelSpline_Galaxy(Galaxy_Model): - """Basic galaxy model with a spline radial light profile. The - light profile is defined as a cubic spline interpolation of the - stored brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I0: Central brightness - dI(R): Tensor of brighntess values relative to central brightness, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"relspline {Galaxy_Model.model_type}" - parameter_specs = { - "I0": {"units": "log10(flux/arcsec^2)"}, - "dI(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = Galaxy_Model._parameter_order + ("I0", "dI(R)") - usable = True - extend_profile = True - - from ._shared_methods import relspline_initialize as initialize - from ._shared_methods import relspline_radial_model as radial_model - - -class RelSpline_PSF(PSF_Model): - """point source model with a spline radial light profile. The light - profile is defined as a cubic spline interpolation of the stored - brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I0: Central brightness - dI(R): Tensor of brighntess values relative to central brightness, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"relspline {PSF_Model.model_type}" - parameter_specs = { - "I0": {"units": "log10(flux/arcsec^2)", "value": 0.0, "locked": True}, - "dI(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = PSF_Model._parameter_order + ("I0", "dI(R)") - usable = True - extend_profile = True - model_integrated = False - - @default_internal - def transform_coordinates(self, X=None, Y=None, image=None, parameters=None): - return X, Y - - from ._shared_methods import relspline_initialize as initialize - from ._shared_methods import relspline_radial_model as radial_model - from ._shared_methods import radial_evaluate_model as evaluate_model diff --git a/astrophot/models/sersic.py b/astrophot/models/sersic.py new file mode 100644 index 00000000..6d68f1a8 --- /dev/null +++ b/astrophot/models/sersic.py @@ -0,0 +1,76 @@ +from ..param import forward +from .galaxy_model_object import GalaxyModel +from .psf_model_object import PSFModel +from ..utils.conversions.functions import sersic_Ie_to_flux_torch +from ..utils.decorators import combine_docstrings +from .mixins import ( + SersicMixin, + RadialMixin, + WedgeMixin, + iSersicMixin, + RayMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, + TruncationMixin, +) + +__all__ = [ + "SersicGalaxy", + "TSersicGalaxy", + "SersicPSF", + "Sersic_Warp", + "Sersic_SuperEllipse", + "Sersic_FourierEllipse", + "Sersic_Ray", + "Sersic_Wedge", +] + + +@combine_docstrings +class SersicGalaxy(SersicMixin, RadialMixin, GalaxyModel): + usable = True + + @forward + def total_flux(self, Ie, n, Re, q, window=None): + return sersic_Ie_to_flux_torch(Ie, n, Re, q) + + +@combine_docstrings +class TSersicGalaxy(TruncationMixin, SersicMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SersicPSF(SersicMixin, RadialMixin, PSFModel): + _parameter_specs = {"Ie": {"units": "flux/arcsec^2", "value": 1.0, "dynamic": False}} + usable = True + + @forward + def total_flux(self, Ie, n, Re): + return sersic_Ie_to_flux_torch(Ie, n, Re, 1.0) + + +@combine_docstrings +class SersicSuperEllipse(SersicMixin, RadialMixin, SuperEllipseMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SersicFourierEllipse(SersicMixin, RadialMixin, FourierEllipseMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SersicWarp(SersicMixin, RadialMixin, WarpMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SersicRay(iSersicMixin, RayMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SersicWedge(iSersicMixin, WedgeMixin, GalaxyModel): + usable = True diff --git a/astrophot/models/sersic_model.py b/astrophot/models/sersic_model.py deleted file mode 100644 index 3bd1ae90..00000000 --- a/astrophot/models/sersic_model.py +++ /dev/null @@ -1,438 +0,0 @@ -import torch - -from .galaxy_model_object import Galaxy_Model -from .warp_model import Warp_Galaxy -from .ray_model import Ray_Galaxy -from .wedge_model import Wedge_Galaxy -from .psf_model_object import PSF_Model -from .superellipse_model import SuperEllipse_Galaxy, SuperEllipse_Warp -from .foureirellipse_model import FourierEllipse_Galaxy, FourierEllipse_Warp -from ._shared_methods import ( - parametric_initialize, - parametric_segment_initialize, - select_target, -) -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..utils.parametric_profiles import sersic_np -from ..utils.conversions.functions import sersic_Ie_to_flux_torch - - -__all__ = [ - "Sersic_Galaxy", - "Sersic_PSF", - "Sersic_Warp", - "Sersic_SuperEllipse", - "Sersic_FourierEllipse", - "Sersic_Ray", - "Sersic_Wedge", - "Sersic_SuperEllipse_Warp", - "Sersic_FourierEllipse_Warp", -] - - -def _x0_func(model, R, F): - return 2.0, R[4], F[4] - - -def _wrap_sersic(R, n, r, i): - return sersic_np(R, n, r, 10 ** (i)) - - -class Sersic_Galaxy(Galaxy_Model): - """basic galaxy model with a sersic profile for the radial light - profile. The functional form of the Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {Galaxy_Model.model_type}" - parameter_specs = { - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - "Ie": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = Galaxy_Model._parameter_order + ("n", "Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_sersic, ("n", "Re", "Ie"), _x0_func) - - @default_internal - def total_flux(self, parameters=None, window=None): - return sersic_Ie_to_flux_torch( - 10 ** parameters["Ie"].value, - parameters["n"].value, - parameters["Re"].value, - parameters["q"].value, - ) - - def _integrate_reference(self, image_data, image_header, parameters): - tot = self.total_flux(parameters) - return tot / image_data.numel() - - from ._shared_methods import sersic_radial_model as radial_model - - -class Sersic_PSF(PSF_Model): - """basic point source model with a sersic profile for the radial light - profile. The functional form of the Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {PSF_Model.model_type}" - parameter_specs = { - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - "Ie": { - "units": "log10(flux/arcsec^2)", - "value": 0.0, - "uncertainty": 0.0, - "locked": True, - }, - } - _parameter_order = PSF_Model._parameter_order + ("n", "Re", "Ie") - usable = True - model_integrated = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_sersic, ("n", "Re", "Ie"), _x0_func) - - from ._shared_methods import sersic_radial_model as radial_model - from ._shared_methods import radial_evaluate_model as evaluate_model - - -class Sersic_SuperEllipse(SuperEllipse_Galaxy): - """super ellipse galaxy model with a sersic profile for the radial - light profile. The functional form of the Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {SuperEllipse_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = SuperEllipse_Galaxy._parameter_order + ("n", "Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_sersic, ("n", "Re", "Ie"), _x0_func) - - from ._shared_methods import sersic_radial_model as radial_model - - -class Sersic_SuperEllipse_Warp(SuperEllipse_Warp): - """super ellipse warp galaxy model with a sersic profile for the - radial light profile. The functional form of the Sersic profile is - defined as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {SuperEllipse_Warp.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = SuperEllipse_Warp._parameter_order + ("n", "Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_sersic, ("n", "Re", "Ie"), _x0_func) - - from ._shared_methods import sersic_radial_model as radial_model - - -class Sersic_FourierEllipse(FourierEllipse_Galaxy): - """fourier mode perturbations to ellipse galaxy model with a sersic - profile for the radial light profile. The functional form of the - Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {FourierEllipse_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = FourierEllipse_Galaxy._parameter_order + ("n", "Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_sersic, ("n", "Re", "Ie"), _x0_func) - - from ._shared_methods import sersic_radial_model as radial_model - - -class Sersic_FourierEllipse_Warp(FourierEllipse_Warp): - """fourier mode perturbations to ellipse galaxy model with a sersic - profile for the radial light profile. The functional form of the - Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {FourierEllipse_Warp.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = FourierEllipse_Warp._parameter_order + ("n", "Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_sersic, ("n", "Re", "Ie"), _x0_func) - - from ._shared_methods import sersic_radial_model as radial_model - - -class Sersic_Warp(Warp_Galaxy): - """warped coordinate galaxy model with a sersic profile for the radial - light model. The functional form of the Sersic profile is defined - as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {Warp_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Warp_Galaxy._parameter_order + ("n", "Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_initialize(self, parameters, target, _wrap_sersic, ("n", "Re", "Ie"), _x0_func) - - from ._shared_methods import sersic_radial_model as radial_model - - -class Sersic_Ray(Ray_Galaxy): - """ray galaxy model with a sersic profile for the radial light - model. The functional form of the Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {Ray_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Ray_Galaxy._parameter_order + ("n", "Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_segment_initialize( - model=self, - target=target, - parameters=parameters, - prof_func=_wrap_sersic, - params=("n", "Re", "Ie"), - x0_func=_x0_func, - segments=self.rays, - ) - - from ._shared_methods import sersic_iradial_model as iradial_model - - -class Sersic_Wedge(Wedge_Galaxy): - """wedge galaxy model with a sersic profile for the radial light - model. The functional form of the Sersic profile is defined as: - - I(R) = Ie * exp(- bn((R/Re)^(1/n) - 1)) - - where I(R) is the brightness profile as a function of semi-major - axis, R is the semi-major axis length, Ie is the brightness as the - half light radius, bn is a function of n and is not involved in - the fit, Re is the half light radius, and n is the sersic index - which controls the shape of the profile. - - Parameters: - n: Sersic index which controls the shape of the brightness profile - Ie: brightness at the half light radius, represented as the log of the brightness divided by pixel scale squared. - Re: half light radius - - """ - - model_type = f"sersic {Wedge_Galaxy.model_type}" - parameter_specs = { - "Ie": {"units": "log10(flux/arcsec^2)"}, - "n": {"units": "none", "limits": (0.36, 8), "uncertainty": 0.05}, - "Re": {"units": "arcsec", "limits": (0, None)}, - } - _parameter_order = Wedge_Galaxy._parameter_order + ("n", "Re", "Ie") - usable = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - parametric_segment_initialize( - model=self, - parameters=parameters, - target=target, - prof_func=_wrap_sersic, - params=("n", "Re", "Ie"), - x0_func=_x0_func, - segments=self.wedges, - ) - - from ._shared_methods import sersic_iradial_model as iradial_model diff --git a/astrophot/models/sky_model_object.py b/astrophot/models/sky_model_object.py index a0c345c3..6d9f5cca 100644 --- a/astrophot/models/sky_model_object.py +++ b/astrophot/models/sky_model_object.py @@ -1,9 +1,11 @@ -from .model_object import Component_Model +from .model_object import ComponentModel +from ..utils.decorators import combine_docstrings -__all__ = ["Sky_Model"] +__all__ = ["SkyModel"] -class Sky_Model(Component_Model): +@combine_docstrings +class SkyModel(ComponentModel): """prototype class for any sky background model. This simply imposes that the center is a locked parameter, not involved in the fit. Also, a sky model object has no psf mode or integration mode @@ -12,24 +14,32 @@ class Sky_Model(Component_Model): """ - model_type = f"sky {Component_Model.model_type}" - parameter_specs = { - "center": {"units": "arcsec", "locked": True, "uncertainty": 0.0}, - } + _model_type = "sky" usable = False + def initialize(self): + """Initialize the sky model, this is called after the model is + created and before it is used. This is where we can set the + center to be a locked parameter. + """ + if not self.center.initialized: + target_area = self.target[self.window] + self.center.to_static(target_area.center) + super().initialize() + self.center.to_static() + @property - def psf_mode(self): - return "none" + def psf_convolve(self) -> bool: + return False - @psf_mode.setter - def psf_mode(self, val): + @psf_convolve.setter + def psf_convolve(self, val: bool): pass @property - def integrate_mode(self): + def integrate_mode(self) -> str: return "none" @integrate_mode.setter - def integrate_mode(self, val): + def integrate_mode(self, val: str): pass diff --git a/astrophot/models/spline.py b/astrophot/models/spline.py new file mode 100644 index 00000000..0a011e7d --- /dev/null +++ b/astrophot/models/spline.py @@ -0,0 +1,59 @@ +from .galaxy_model_object import GalaxyModel +from .psf_model_object import PSFModel +from .mixins import ( + SplineMixin, + RadialMixin, + iSplineMixin, + RayMixin, + WedgeMixin, + SuperEllipseMixin, + FourierEllipseMixin, + WarpMixin, +) +from ..utils.decorators import combine_docstrings + + +__all__ = [ + "SplineGalaxy", + "SplinePSF", + "SplineWarp", + "SplineSuperEllipse", + "SplineFourierEllipse", + "SplineRay", + "SplineWedge", +] + + +@combine_docstrings +class SplineGalaxy(SplineMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SplinePSF(SplineMixin, RadialMixin, PSFModel): + usable = True + + +@combine_docstrings +class SplineSuperEllipse(SplineMixin, SuperEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SplineFourierEllipse(SplineMixin, FourierEllipseMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SplineWarp(SplineMixin, WarpMixin, RadialMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SplineRay(iSplineMixin, RayMixin, GalaxyModel): + usable = True + + +@combine_docstrings +class SplineWedge(iSplineMixin, WedgeMixin, GalaxyModel): + usable = True diff --git a/astrophot/models/spline_model.py b/astrophot/models/spline_model.py deleted file mode 100644 index 76711326..00000000 --- a/astrophot/models/spline_model.py +++ /dev/null @@ -1,319 +0,0 @@ -import torch - -from .galaxy_model_object import Galaxy_Model -from .warp_model import Warp_Galaxy -from .superellipse_model import SuperEllipse_Galaxy, SuperEllipse_Warp -from .foureirellipse_model import FourierEllipse_Galaxy, FourierEllipse_Warp -from .psf_model_object import PSF_Model -from .ray_model import Ray_Galaxy -from .wedge_model import Wedge_Galaxy -from ._shared_methods import spline_segment_initialize, select_target -from ..utils.decorators import ignore_numpy_warnings, default_internal - -__all__ = [ - "Spline_Galaxy", - "Spline_PSF", - "Spline_Warp", - "Spline_SuperEllipse", - "Spline_FourierEllipse", - "Spline_Ray", - "Spline_SuperEllipse_Warp", - "Spline_FourierEllipse_Warp", -] - - -# First Order -###################################################################### -class Spline_Galaxy(Galaxy_Model): - """Basic galaxy model with a spline radial light profile. The - light profile is defined as a cubic spline interpolation of the - stored brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): Tensor of brighntess values, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {Galaxy_Model.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = Galaxy_Model._parameter_order + ("I(R)",) - usable = True - extend_profile = True - - from ._shared_methods import spline_initialize as initialize - from ._shared_methods import spline_radial_model as radial_model - - -class Spline_PSF(PSF_Model): - """star model with a spline radial light profile. The light - profile is defined as a cubic spline interpolation of the stored - brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): Tensor of brighntess values, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {PSF_Model.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = PSF_Model._parameter_order + ("I(R)",) - usable = True - extend_profile = True - model_integrated = False - - @default_internal - def transform_coordinates(self, X=None, Y=None, image=None, parameters=None): - return X, Y - - from ._shared_methods import spline_initialize as initialize - from ._shared_methods import spline_radial_model as radial_model - from ._shared_methods import radial_evaluate_model as evaluate_model - - -class Spline_Warp(Warp_Galaxy): - """warped coordinate galaxy model with a spline light - profile. The light profile is defined as a cubic spline - interpolation of the stored brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): Tensor of brighntess values, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {Warp_Galaxy.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = Warp_Galaxy._parameter_order + ("I(R)",) - usable = True - extend_profile = True - - from ._shared_methods import spline_initialize as initialize - from ._shared_methods import spline_radial_model as radial_model - - -# Second Order -###################################################################### -class Spline_SuperEllipse(SuperEllipse_Galaxy): - """The light profile is defined as a cubic spline interpolation of - the stored brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): Tensor of brighntess values, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {SuperEllipse_Galaxy.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = SuperEllipse_Galaxy._parameter_order + ("I(R)",) - usable = True - extend_profile = True - - from ._shared_methods import spline_initialize as initialize - from ._shared_methods import spline_radial_model as radial_model - - -class Spline_FourierEllipse(FourierEllipse_Galaxy): - """The light profile is defined as a cubic spline interpolation of the - stored brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): Tensor of brighntess values, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {FourierEllipse_Galaxy.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = FourierEllipse_Galaxy._parameter_order + ("I(R)",) - usable = True - extend_profile = True - - from ._shared_methods import spline_initialize as initialize - from ._shared_methods import spline_radial_model as radial_model - - -class Spline_Ray(Ray_Galaxy): - """ray galaxy model with a spline light profile. The light - profile is defined as a cubic spline interpolation of the stored - brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): 2D Tensor of brighntess values for each ray, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {Ray_Galaxy.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = Ray_Galaxy._parameter_order + ("I(R)",) - usable = True - extend_profile = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - spline_segment_initialize( - self, - target=target, - parameters=parameters, - segments=self.rays, - symmetric=self.symmetric_rays, - ) - - from ._shared_methods import spline_iradial_model as iradial_model - - -class Spline_Wedge(Wedge_Galaxy): - """wedge galaxy model with a spline light profile. The light - profile is defined as a cubic spline interpolation of the stored - brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): 2D Tensor of brighntess values for each wedge, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {Wedge_Galaxy.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = Wedge_Galaxy._parameter_order + ("I(R)",) - usable = True - extend_profile = True - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - spline_segment_initialize( - self, - target=target, - parameters=parameters, - segments=self.wedges, - symmetric=self.symmetric_wedges, - ) - - from ._shared_methods import spline_iradial_model as iradial_model - - -# Third Order -###################################################################### -class Spline_SuperEllipse_Warp(SuperEllipse_Warp): - """The light profile is defined as a cubic spline interpolation of the - stored brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): Tensor of brighntess values, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {SuperEllipse_Warp.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = SuperEllipse_Warp._parameter_order + ("I(R)",) - usable = True - extend_profile = True - - from ._shared_methods import spline_initialize as initialize - from ._shared_methods import spline_radial_model as radial_model - - -class Spline_FourierEllipse_Warp(FourierEllipse_Warp): - """The light profile is defined as a cubic spline interpolation of the - stored brightness values: - - I(R) = interp(R, profR, I) - - where I(R) is the brightness along the semi-major axis, interp is - a cubic spline function, R is the semi-major axis length, profR is - a list of radii for the spline, I is a corresponding list of - brightnesses at each profR value. - - Parameters: - I(R): Tensor of brighntess values, represented as the log of the brightness divided by pixelscale squared - - """ - - model_type = f"spline {FourierEllipse_Warp.model_type}" - parameter_specs = { - "I(R)": {"units": "log10(flux/arcsec^2)"}, - } - _parameter_order = FourierEllipse_Warp._parameter_order + ("I(R)",) - usable = True - extend_profile = True - - from ._shared_methods import spline_initialize as initialize - from ._shared_methods import spline_radial_model as radial_model diff --git a/astrophot/models/superellipse_model.py b/astrophot/models/superellipse_model.py deleted file mode 100644 index 64e3b6d3..00000000 --- a/astrophot/models/superellipse_model.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch - -from .galaxy_model_object import Galaxy_Model -from .warp_model import Warp_Galaxy -from ..utils.decorators import default_internal - -__all__ = ["SuperEllipse_Galaxy", "SuperEllipse_Warp"] - - -class SuperEllipse_Galaxy(Galaxy_Model): - """Expanded galaxy model which includes a superellipse transformation - in its radius metric. This allows for the expression of "boxy" and - "disky" isophotes instead of pure ellipses. This is a common - extension of the standard elliptical representation, especially - for early-type galaxies. The functional form for this is: - - R = (|X|^C + |Y|^C)^(1/C) - - where R is the new distance metric, X Y are the coordinates, and C - is the coefficient for the superellipse. C can take on any value - greater than zero where C = 2 is the standard distance metric, 0 < - C < 2 creates disky or pointed perturbations to an ellipse, and C - > 2 transforms an ellipse to be more boxy. - - Parameters: - C0: superellipse distance metric parameter where C0 = C-2 so that a value of zero is now a standard ellipse. - - """ - - model_type = f"superellipse {Galaxy_Model.model_type}" - parameter_specs = { - "C0": {"units": "C-2", "value": 0.0, "uncertainty": 1e-2, "limits": (-2, None)}, - } - _parameter_order = Galaxy_Model._parameter_order + ("C0",) - usable = False - - @default_internal - def radius_metric(self, X, Y, image=None, parameters=None): - return torch.pow( - torch.pow(torch.abs(X), parameters["C0"].value + 2.0) - + torch.pow(torch.abs(Y), parameters["C0"].value + 2.0), - 1.0 / (parameters["C0"].value + 2.0), - ) - - -class SuperEllipse_Warp(Warp_Galaxy): - """Expanded warp model which includes a superellipse transformation - in its radius metric. This allows for the expression of "boxy" and - "disky" isophotes instead of pure ellipses. This is a common - extension of the standard elliptical representation, especially - for early-type galaxies. The functional form for this is: - - R = (|X|^C + |Y|^C)^(1/C) - - where R is the new distance metric, X Y are the coordinates, and C - is the coefficient for the superellipse. C can take on any value - greater than zero where C = 2 is the standard distance metric, 0 < - C < 2 creates disky or pointed perturbations to an ellipse, and C - > 2 transforms an ellipse to be more boxy. - - Parameters: - C0: superellipse distance metric parameter where C0 = C-2 so that a value of zero is now a standard ellipse. - - - """ - - model_type = f"superellipse {Warp_Galaxy.model_type}" - parameter_specs = { - "C0": {"units": "C-2", "value": 0.0, "uncertainty": 1e-2, "limits": (-2, None)}, - } - _parameter_order = Warp_Galaxy._parameter_order + ("C0",) - usable = False - - @default_internal - def radius_metric(self, X, Y, image=None, parameters=None): - return torch.pow( - torch.pow(torch.abs(X), parameters["C0"].value + 2.0) - + torch.pow(torch.abs(Y), parameters["C0"].value + 2.0), - 1.0 / (parameters["C0"].value + 2.0), - ) # epsilon added for numerical stability of gradient diff --git a/astrophot/models/warp_model.py b/astrophot/models/warp_model.py deleted file mode 100644 index 664dc561..00000000 --- a/astrophot/models/warp_model.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np -import torch - -from .galaxy_model_object import Galaxy_Model -from ..utils.interpolate import cubic_spline_torch -from ..utils.conversions.coordinates import Rotate_Cartesian -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ..param import Param_Unlock, Param_SoftLimits -from ._shared_methods import select_target - -__all__ = ["Warp_Galaxy"] - - -class Warp_Galaxy(Galaxy_Model): - """Galaxy model which includes radially varrying PA and q - profiles. This works by warping the coordinates using the same - transform for a global PA/q except applied to each pixel - individually. In the limit that PA and q are a constant, this - recovers a basic galaxy model with global PA/q. However, a linear - PA profile will give a spiral appearance, variations of PA/q - profiles can create complex galaxy models. The form of the - coordinate transformation looks like: - - X, Y = meshgrid(image) - R = sqrt(X^2 + Y^2) - X', Y' = Rot(theta(R), X, Y) - Y'' = Y' / q(R) - - where the definitions are the same as for a regular galaxy model, - except now the theta is a function of radius R (before - transformation) and the axis ratio q is also a function of radius - (before the transformation). - - Parameters: - q(R): Tensor of axis ratio values for axis ratio spline - PA(R): Tensor of position angle values as input to the spline - - """ - - model_type = f"warp {Galaxy_Model.model_type}" - parameter_specs = { - "q(R)": {"units": "b/a", "limits": (0.05, 1), "uncertainty": 0.04}, - "PA(R)": { - "units": "rad", - "limits": (0, np.pi), - "cyclic": True, - "uncertainty": 0.08, - }, - } - _parameter_order = Galaxy_Model._parameter_order + ("q(R)", "PA(R)") - usable = False - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - # create the PA(R) and q(R) profile radii if needed - for prof_param in ["PA(R)", "q(R)"]: - if parameters[prof_param].prof is None: - if parameters[prof_param].value is None: # from scratch - new_prof = [0, 2 * target.pixel_length] - while new_prof[-1] < torch.min(self.window.shape / 2): - new_prof.append( - new_prof[-1] + torch.max(2 * target.pixel_length, new_prof[-1] * 0.2) - ) - new_prof.pop() - new_prof.pop() - new_prof.append(torch.sqrt(torch.sum((self.window.shape / 2) ** 2))) - parameters[prof_param].prof = new_prof - else: # matching length of a provided profile - # create logarithmically spaced profile radii - new_prof = [0] + list( - np.logspace( - np.log10(2 * target.pixel_length), - np.log10(torch.max(self.window.shape / 2).item()), - len(parameters[prof_param].value) - 1, - ) - ) - # ensure no step is smaller than a pixelscale - for i in range(1, len(new_prof)): - if new_prof[i] - new_prof[i - 1] < target.pixel_length.item(): - new_prof[i] = new_prof[i - 1] + target.pixel_length.item() - parameters[prof_param].prof = new_prof - - if not (parameters["PA(R)"].value is None or parameters["q(R)"].value is None): - return - - with Param_Unlock(parameters["PA(R)"]), Param_SoftLimits(parameters["PA(R)"]): - if parameters["PA(R)"].value is None: - parameters["PA(R)"].value = np.zeros(len(parameters["PA(R)"].prof)) + target.north - if parameters["PA(R)"].uncertainty is None: - parameters["PA(R)"].uncertainty = (5 * np.pi / 180) * torch.ones_like( - parameters["PA(R)"].value - ) - if parameters["q(R)"].value is None: - # If no initial value is provided for q(R) a heuristic initial value is assumed. - # The most neutral initial position would be 1, but the boundaries of q are (0,1) non-inclusive - # so that is not allowed. A value like 0.999 may get stuck since it is near the very edge of - # the (0,1) range. So 0.9 is chosen to be mostly passive, but still some signal for the optimizer. - parameters["q(R)"].value = np.ones(len(parameters["q(R)"].prof)) * 0.9 - if parameters["q(R)"].uncertainty is None: - parameters["q(R)"].uncertainty = self.default_uncertainty * parameters["q(R)"].value - - @default_internal - def transform_coordinates(self, X, Y, image=None, parameters=None): - X, Y = super().transform_coordinates(X, Y, image, parameters) - R = self.radius_metric(X, Y, image, parameters) - PA = cubic_spline_torch( - parameters["PA(R)"].prof, - -(parameters["PA(R)"].value - image.north), - R.view(-1), - ).view(*R.shape) - q = cubic_spline_torch(parameters["q(R)"].prof, parameters["q(R)"].value, R.view(-1)).view( - *R.shape - ) - X, Y = Rotate_Cartesian(PA, X, Y) - return X, Y / q diff --git a/astrophot/models/wedge_model.py b/astrophot/models/wedge_model.py deleted file mode 100644 index 31ee5b74..00000000 --- a/astrophot/models/wedge_model.py +++ /dev/null @@ -1,84 +0,0 @@ -import numpy as np -import torch - -from .galaxy_model_object import Galaxy_Model -from ..utils.decorators import default_internal - -__all__ = ["Wedge_Galaxy"] - - -class Wedge_Galaxy(Galaxy_Model): - """Variant of the ray model where no smooth transition is performed - between regions as a function of theta, instead there is a sharp - trnasition boundary. This may be desirable as it cleanly - separates where the pixel information is going. Due to the sharp - transition though, it may cause unusual behaviour when fitting. If - problems occur, try fitting a ray model first then fix the center, - PA, and q and then fit the wedge model. Essentially this breaks - down the structure fitting and the light profile fitting into two - steps. The wedge model, like the ray model, defines no extra - parameters, however a new option can be supplied on instantiation - of the wedge model which is "wedges" or the number of wedges in - the model. - - """ - - model_type = f"wedge {Galaxy_Model.model_type}" - special_kwargs = Galaxy_Model.special_kwargs + ["wedges"] - wedges = 2 - track_attrs = Galaxy_Model.track_attrs + ["wedges"] - usable = False - - def __init__(self, *args, **kwargs): - self.symmetric_wedges = True - super().__init__(*args, **kwargs) - self.wedges = kwargs.get("wedges", 2) - - @default_internal - def polar_model(self, R, T, image=None, parameters=None): - model = torch.zeros_like(R) - if self.wedges % 2 == 0 and self.symmetric_wedges: - for w in range(self.wedges): - angles = (T - (w * np.pi / self.wedges)) % np.pi - indices = torch.logical_or( - angles < (np.pi / (2 * self.wedges)), - angles >= (np.pi * (1 - 1 / (2 * self.wedges))), - ) - model[indices] += self.iradial_model(w, R[indices], image, parameters) - elif self.wedges % 2 == 1 and self.symmetric_wedges: - for w in range(self.wedges): - angles = (T - (w * np.pi / self.wedges)) % (2 * np.pi) - indices = torch.logical_or( - angles < (np.pi / (2 * self.wedges)), - angles >= (np.pi * (2 - 1 / (2 * self.wedges))), - ) - model[indices] += self.iradial_model(w, R[indices], image, parameters) - angles = (T - (np.pi + w * np.pi / self.wedges)) % (2 * np.pi) - indices = torch.logical_or( - angles < (np.pi / (2 * self.wedges)), - angles >= (np.pi * (2 - 1 / (2 * self.wedges))), - ) - model[indices] += self.iradial_model(w, R[indices], image, parameters) - else: - for w in range(self.wedges): - angles = (T - (w * 2 * np.pi / self.wedges)) % (2 * np.pi) - indices = torch.logical_or( - angles < (np.pi / self.wedges), - angles >= (np.pi * (2 - 1 / self.wedges)), - ) - model[indices] += self.iradial_model(w, R[indices], image, parameters) - return model - - @default_internal - def evaluate_model(self, X=None, Y=None, image=None, parameters=None, **kwargs): - if X is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - XX, YY = self.transform_coordinates(X, Y, image, parameters) - - return self.polar_model( - self.radius_metric(XX, YY, image=image, parameters=parameters), - self.angular_metric(XX, YY, image=image, parameters=parameters), - image=image, - parameters=parameters, - ) diff --git a/astrophot/models/zernike_model.py b/astrophot/models/zernike_model.py deleted file mode 100644 index 73b2ebb8..00000000 --- a/astrophot/models/zernike_model.py +++ /dev/null @@ -1,138 +0,0 @@ -from functools import lru_cache - -import torch -from scipy.special import binom - -from ..utils.decorators import ignore_numpy_warnings, default_internal -from ._shared_methods import select_target -from .psf_model_object import PSF_Model -from ..param import Param_Unlock, Param_SoftLimits -from ..errors import SpecificationConflict - -__all__ = ("Zernike_PSF",) - - -class Zernike_PSF(PSF_Model): - - model_type = f"zernike {PSF_Model.model_type}" - parameter_specs = { - "Anm": {"units": "flux/arcsec^2"}, - } - _parameter_order = PSF_Model._parameter_order + ("Anm",) - usable = True - model_integrated = False - - def __init__(self, *, name=None, order_n=2, r_scale=None, **kwargs): - super().__init__(name=name, **kwargs) - - self.order_n = int(order_n) - self.r_scale = r_scale - self.nm_list = self.iter_nm(self.order_n) - - @torch.no_grad() - @ignore_numpy_warnings - @select_target - @default_internal - def initialize(self, target=None, parameters=None, **kwargs): - super().initialize(target=target, parameters=parameters) - - # List the coefficients to use - self.nm_list = self.iter_nm(self.order_n) - # Set the scale radius for the Zernike area - if self.r_scale is None: - self.r_scale = torch.max(self.window.shape) / 2 - - # Check if user has already set the coefficients - if parameters["Anm"].value is not None: - if len(self.nm_list) != len(parameters["Anm"].value): - raise SpecificationConflict( - f"nm_list length ({len(self.nm_list)}) must match coefficients ({len(parameters['Anm'].value)})" - ) - return - - # Set the default coefficients to zeros - with Param_Unlock(parameters["Anm"]), Param_SoftLimits(parameters["Anm"]): - parameters["Anm"].value = torch.zeros(len(self.nm_list)) - if parameters["Anm"].uncertainty is None: - parameters["Anm"].uncertainty = self.default_uncertainty * torch.ones_like( - parameters["Anm"].value - ) - # Set the zero order zernike polynomial to the average in the image - if self.nm_list[0] == (0, 0): - parameters["Anm"].value[0] = ( - torch.median(target[self.window].data) / target.pixel_area - ) - - def iter_nm(self, n): - nm = [] - for n_i in range(n + 1): - for m_i in range(-n_i, n_i + 1, 2): - nm.append((n_i, m_i)) - return nm - - @staticmethod - @lru_cache(maxsize=1024) - def coefficients(n, m): - C = [] - for k in range(int((n - abs(m)) / 2) + 1): - C.append( - ( - k, - (-1) ** k * binom(n - k, k) * binom(n - 2 * k, (n - abs(m)) / 2 - k), - ) - ) - return C - - def Z_n_m(self, rho, phi, n, m, efficient=True): - Z = torch.zeros_like(rho) - if efficient: - T_cache = {0: None} - R_cache = {} - for k, c in self.coefficients(n, m): - if efficient: - if (n - 2 * k) not in R_cache: - R_cache[n - 2 * k] = rho ** (n - 2 * k) - R = R_cache[n - 2 * k] - if m not in T_cache: - if m < 0: - T_cache[m] = torch.sin(abs(m) * phi) - elif m > 0: - T_cache[m] = torch.cos(m * phi) - T = T_cache[m] - else: - R = rho ** (n - 2 * k) - if m < 0: - T = torch.sin(abs(m) * phi) - elif m > 0: - T = torch.cos(m * phi) - - if m == 0: - Z += c * R - elif m < 0: - Z += c * R * T - else: - Z += c * R * T - return Z - - @default_internal - def evaluate_model(self, X=None, Y=None, image=None, parameters=None): - if X is None: - Coords = image.get_coordinate_meshgrid() - X, Y = Coords - parameters["center"].value[..., None, None] - - phi = self.angular_metric(X, Y, image, parameters) - - r = self.radius_metric(X, Y, image, parameters) - r = r / self.r_scale - - G = torch.zeros_like(X) - - i = 0 - A = image.pixel_area * parameters["Anm"].value - for n, m in self.nm_list: - G += A[i] * self.Z_n_m(r, phi, n, m) - i += 1 - - G[r > 1] = 0.0 - - return G diff --git a/astrophot/param/__init__.py b/astrophot/param/__init__.py index fd67d0b9..1b780893 100644 --- a/astrophot/param/__init__.py +++ b/astrophot/param/__init__.py @@ -1,3 +1,5 @@ -from .parameter import * -from .param_context import * -from .base import * +from caskade import forward, ValidContext, OverrideParam +from .module import Module +from .param import Param + +__all__ = ["Module", "Param", "forward", "ValidContext", "OverrideParam"] diff --git a/astrophot/param/base.py b/astrophot/param/base.py deleted file mode 100644 index 3bea49a7..00000000 --- a/astrophot/param/base.py +++ /dev/null @@ -1,201 +0,0 @@ -from collections import OrderedDict -from abc import ABC, abstractmethod -from ..errors import InvalidParameter - -__all__ = ["Node"] - - -class Node(ABC): - """Base node object in the Directed Acyclic Graph (DAG). - - The base Node object handles storing the DAG nodes and links - between them. An important part of the DAG system is to be able to - find all the leaf nodes, which is done using the `flat` function. - - Args: - name (str): The name of the node, this should identify it uniquely in the local context it will be used in. - locked (bool): Records if the node is locked, this is relevant for some other operations which only act on unlocked nodes. - link (tuple[Node]): A tuple of node objects which this node will be linked to on initialization. - - """ - - global_unlock = False - - def __init__(self, name, **kwargs): - if ":" in name: - raise ValueError(f"Node names must not have ':' character. Cannot use name: {name}") - self.name = name - self.nodes = OrderedDict() - if "state" in kwargs: - self.set_state(kwargs["state"]) - return - if "link" in kwargs: - self.link(*kwargs["link"]) - self.locked = kwargs.get("locked", False) - - def link(self, *nodes): - """Creates a directed link from the current node to the provided - node(s) in the input. This function will also check that the - linked node does not exist higher up in the DAG to the current - node, if that is the case then a cycle has formed which breaks - the DAG structure and could cause problems. An error will be - thrown in this case. - - The linked node is added to a ``nodes`` dictionary that each - node stores. This makes it easy to check which nodes are - linked to each other. - - """ - for node in nodes: - for subnode_id in node.flat(include_locked=True, include_links=True).keys(): - if self.identity == subnode_id: - raise InvalidParameter( - "Parameter structure must be Directed Acyclic Graph! Adding this node would create a cycle" - ) - self.nodes[node.name] = node - - def unlink(self, *nodes): - """Undoes the linking of two nodes. Note that this could sever the - connection of many nodes to each other if the current node was - the only link between two branches. - - """ - for node in nodes: - del self.nodes[node.name] - - def dump(self): - """Simply unlinks all nodes that the current node is linked with.""" - self.unlink(*self.nodes.values()) - - @property - def leaf(self): - """Returns True when the current node is a leaf node.""" - return len(self.nodes) == 0 - - @property - def branch(self): - """Returns True when the current node is a branch node (not a leaf node, is linked to more nodes).""" - return len(self.nodes) > 0 - - def __getitem__(self, key): - """Used to get a node from the DAG relative to the current node. It - is possible to collect nodes from deeper in the DAG by - separating the names of the nodes along the path with a colon - (:). For example:: - - first_node["second_node:third_node"] - - returns a node that is actually linked to ``second_node`` - without needing to first get ``second_node`` then call - ``second_node['third_node']``. - - """ - if key == self.name: - return self - if key in self.nodes: - return self.nodes[key] - if isinstance(key, str) and ":" in key: - base, stem = key.split(":", 1) - return self.nodes[base][stem] - if isinstance(key, int): - for node in self.nodes.values(): - if key == node.identity: - return node - raise KeyError(f"Unrecognized key for '{self.name}': {key}") - - def __contains__(self, key): - """Check if a node has a link directly to another node. A check like - ``"second_node" in first_node`` would return true only if - ``first_node`` was linked to ``second_node``. - - """ - return key in self.nodes - - def __eq__(self, other): - """Equality check for nodes only returns true if they are in fact the - same node. - - """ - return self is other - - @property - def identity(self): - """A read only property of the node which does not change over it's - lifetime that uniquely identifies it relative to other - nodes. By default this just uses the ``id(self)`` though for - the purpose of saving/loading it may not always be this way. - - """ - try: - return self._identity - except AttributeError: - return id(self) - - def get_state(self): - """Returns a dictionary with state information about this node. From - that dictionary the node can reconstruct itself, or form - another node which is a copy of this one. - - """ - state = { - "name": self.name, - "identity": self.identity, - } - if self.locked: - state["locked"] = self.locked - if len(self.nodes) > 0: - state["nodes"] = list(node.get_state() for node in self.nodes.values()) - return state - - def set_state(self, state): - """Used to set the state of the node for the purpose of - loading/copying. This uses the dictionary produced by - ``get_state`` to re-create itself. - - """ - self.name = state["name"] - self._identity = state["identity"] - if "nodes" in state: - for node in state["nodes"]: - self.link(self.__class__(name=node["name"], state=node)) - self.locked = state.get("locked", False) - - def __iter__(self): - return filter(lambda n: not n.locked, self.nodes.values()) - - @property - @abstractmethod - def value(self): ... - - def flat(self, include_locked=True, include_links=False): - """Searches the DAG from this node and collects other nodes in the - graph. By default it will include all leaf nodes only, however - it can be directed to only collect leaf nodes that are not - locked, it can also be directed to collect all nodes instead - of just leaf nodes. - - """ - flat = OrderedDict() - if self.leaf and self.value is not None: - if (not self.locked) or include_locked or Node.global_unlock: - flat[self.identity] = self - for node in self.nodes.values(): - if node.locked and not (include_locked or Node.global_unlock): - continue - if node.leaf and node.value is not None: - flat[node.identity] = node - else: - if include_links and ((not node.locked) or include_locked or Node.global_unlock): - flat[node.identity] = node - flat.update(node.flat(include_locked)) - return flat - - def __str__(self): - return f"Node: {self.name}" - - def __repr__(self): - return ( - f"Node: {self.name} " - + ("locked" if self.locked else "unlocked") - + ("" if self.leaf else " {" + ";".join(repr(node) for node in self.nodes) + "}") - ) diff --git a/astrophot/param/module.py b/astrophot/param/module.py new file mode 100644 index 00000000..a29e0337 --- /dev/null +++ b/astrophot/param/module.py @@ -0,0 +1,88 @@ +import numpy as np +from math import prod +from caskade import ( + Module as CModule, + ActiveStateError, + ParamConfigurationError, + FillParamsArrayError, +) +from ..backend_obj import backend + + +class Module(CModule): + + def build_params_array_identities(self): + identities = [] + for param in self.dynamic_params: + numel = max(1, np.prod(param.shape)) + for i in range(numel): + identities.append(f"{id(param)}_{i}") + return identities + + def build_params_array_uncertainty(self): + uncertainties = [] + for param in self.dynamic_params: + if param.uncertainty is None: + uncertainties.append(backend.zeros_like(param.value.flatten())) + else: + uncertainties.append(param.uncertainty.flatten()) + return backend.concatenate(tuple(uncertainties), dim=-1) + + def build_params_array_names(self): + names = [] + for param in self.dynamic_params: + numel = max(1, np.prod(param.shape)) + if numel == 1: + names.append(param.name) + else: + for i in range(numel): + names.append(f"{param.name}_{i}") + return names + + def build_params_array_units(self): + units = [] + for param in self.dynamic_params: + numel = max(1, np.prod(param.shape)) + for _ in range(numel): + units.append(param.units) + return units + + def fill_dynamic_value_uncertainties(self, uncertainty): + if self.active: + raise ActiveStateError(f"Cannot fill dynamic values when Module {self.name} is active") + + dynamic_params = self.dynamic_params + + if uncertainty.shape[-1] == 0: + return # No parameters to fill + # check for batch dimension + pos = 0 + for param in dynamic_params: + if not isinstance(param.shape, tuple): + raise ParamConfigurationError( + f"Param {param.name} has no shape. dynamic parameters must have a shape to use Tensor input." + ) + # Handle scalar parameters + size = max(1, prod(param.shape)) + try: + val = uncertainty[..., pos : pos + size].reshape(param.shape) + param.uncertainty = val + except (RuntimeError, IndexError, ValueError, TypeError): + raise FillParamsArrayError(self.name, uncertainty, dynamic_params) + + pos += size + if pos != uncertainty.shape[-1]: + raise FillParamsArrayError(self.name, uncertainty, dynamic_params) + + def dynamic_params_array_index(self, param): + i = 0 + for p in self.dynamic_params: + if p is param: + return list(range(i, i + max(1, prod(p.shape)))) + i += max(1, prod(p.shape)) + try: + raise ValueError( + f"Param {param.name} not found in dynamic_params of Module {self.name}" + ) + except: + raise ValueError(f"Param {param} not found in dynamic_params of Module {self.name}") diff --git a/astrophot/param/param.py b/astrophot/param/param.py new file mode 100644 index 00000000..4df33cdf --- /dev/null +++ b/astrophot/param/param.py @@ -0,0 +1,73 @@ +from math import prod +import numpy as np + +from caskade import Param as CParam +from ..backend_obj import backend + + +class Param(CParam): + """ + A class that extends the Caskade Param class to include additional functionality. + This class is used to define parameters for models in the AstroPhot package. + """ + + def __init__(self, *args, uncertainty=None, prof=None, **kwargs): + super().__init__(*args, **kwargs) + self.uncertainty = uncertainty + self.saveattrs.add("uncertainty") + self.prof = prof + self.saveattrs.add("prof") + + @property + def uncertainty(self): + return self._uncertainty + + @uncertainty.setter + def uncertainty(self, uncertainty): + if uncertainty is None: + self._uncertainty = None + else: + self._uncertainty = backend.as_array(uncertainty) + + @property + def prof(self): + return self._prof + + @prof.setter + def prof(self, prof): + if prof is None: + self._prof = None + else: + self._prof = backend.as_array(prof) + + @property + def name_array(self): + numel = max(1, prod(self.shape)) + if numel == 1: + return np.array(self.name) + names = [f"{self.name}_{i}" for i in range(numel)] + return np.array(names).reshape(self.shape) + + @property + def initialized(self): + """Check if the parameter is initialized.""" + if self.pointer: + return True + if self.value is not None: + return True + return False + + def soft_valid(self, value): + if self.valid[0] is None and self.valid[1] is None: + return value + if self.valid[0] is not None and self.valid[1] is not None: + vrange = 0.1 * (self.valid[1] - self.valid[0]) + smin = self.valid[0] + 0.1 * vrange + smax = self.valid[1] - 0.1 * vrange + elif self.valid[0] is not None: + smin = self.valid[0] + 0.1 + smax = None + elif self.valid[1] is not None: + smin = None + smax = self.valid[1] - 0.1 + return backend.clamp(value, min=smin, max=smax) diff --git a/astrophot/param/param_context.py b/astrophot/param/param_context.py deleted file mode 100644 index 6a217486..00000000 --- a/astrophot/param/param_context.py +++ /dev/null @@ -1,102 +0,0 @@ -from .base import Node - -__all__ = ("Param_Unlock", "Param_SoftLimits", "Param_Mask") - - -class Param_Unlock: - """Temporarily unlock a parameter. - - Context manager to unlock a parameter temporarily. Inside the - context, the parameter will behave as unlocked regardless of its - initial condition. Upon exiting the context, the parameter will - return to its previous locked state regardless of any changes - made by the user to the lock state. - - """ - - def __init__(self, param=None): - self.param = param - - def __enter__(self): - if self.param is None: - Node.global_unlock = True - else: - self.original_locked = self.param.locked - self.param.locked = False - - def __exit__(self, *args, **kwargs): - if self.param is None: - Node.global_unlock = False - else: - self.param.locked = self.original_locked - - -class Param_SoftLimits: - """Temporarily allow writing parameter values outside limits. - - Values outside the limits will be quietly (no error/warning - raised) shifted until they are within the boundaries of the - parameter limits. Since the limits are non-inclusive, the soft - limits will actually move a parameter by 0.001 into the parameter - range. For example the axis ratio ``q`` has limits from (0,1) so - if one were to write: ``q.value = 2`` then the actual value that - gets written would be ``0.999``. - - Cyclic parameters are not affected by this, any value outside the - range is always (Param_SoftLimits context or not) wrapped back - into the range using modulo arithmetic. - - """ - - def __init__(self, param): - self.param = param - - def __enter__(self, *args, **kwargs): - self.original_setter = self.param._set_val_self - self.param._set_val_self = self.param._soft_set_val_self - - def __exit__(self, *args, **kwargs): - self.param._set_val_self = self.original_setter - - -class Param_Mask: - """Temporarily mask parameters. - - Select a subset of parameters to be used through the "vector" - interface of the DAG. The context is initialized with a - Parameter_Node object (``P``) and a torch tensor (``M``) where the - size of the mask should be equal to the current vector - representation of the parameter (``M.numel() == - P.vector_values().numel()``). The mask tensor should be of - ``torch.bool`` dtype where ``True`` indicates to keep using that - parameter and ``False`` indicates to hide that parameter value. - - Note that ``Param_Mask`` contexts can be nested and will behave - accordingly (the mask tensor will need to match the vector size - within the previous context). As an example, imagine there is a - parameter node ``P`` which has five sub-nodes each with a single - value, one could nest contexts like:: - - M1 = torch.tensor((1,1,0,1,0), dtype = torch.bool) - with Param_Mask(P, M1): - # Now P behaves as if it only has 3 elements - M2 = torch.tensor([0,1,1], dtype = torch.bool) - with Param_Mask(P, M2): - # Now P behaves as if it only has 2 elements - P.vector_values() # returns tensor with 2 elements - - """ - - def __init__(self, param, new_mask): - self.param = param - self.new_mask = new_mask - - def __enter__(self): - - self.old_mask = self.param.vector_mask() - self.mask = self.param.vector_mask() - self.mask[self.mask.clone()] = self.new_mask - self.param.vector_set_mask(self.mask) - - def __exit__(self, *args, **kwargs): - self.param.vector_set_mask(self.old_mask) diff --git a/astrophot/param/parameter.py b/astrophot/param/parameter.py deleted file mode 100644 index 7c772ab0..00000000 --- a/astrophot/param/parameter.py +++ /dev/null @@ -1,742 +0,0 @@ -from types import FunctionType - -import torch -import numpy as np - -from ..utils.conversions.optimization import ( - boundaries, - inv_boundaries, - cyclic_boundaries, -) -from .. import AP_config -from .base import Node -from ..errors import InvalidParameter - -__all__ = ["Parameter_Node"] - - -class Parameter_Node(Node): - """A node representing parameters and their relative structure. - - The Parameter_Node object stores all information relevant for the - parameters of a model. At a high level the Parameter_Node - accomplishes two tasks. The first task is to store the actual - parameter values, these are represented as pytorch tensors which - can have any shape; these are leaf nodes. The second task is to - store the relationship between parameters in a graph structure; - these are branch nodes. The two tasks are handled by the same type - of object since there is some overlap between them where a branch - node acts like a leaf node in certain contexts. - - There are various quantities that a Parameter_Node tracks which - can be provided as arguments or updated later. - - Args: - value: The value of a node represents the tensor which will be used by models to compute their projection into the pixels of an image. These can be quite complex, see further down for more details. - cyclic (bool): Records if the value of a node is cyclic, meaning that if it is updated outside it's limits it should be wrapped back into the limits. - limits (Tuple[Tensor or None, Tensor or None]): Tracks if a parameter has constraints on the range of values it can take. The first element is the lower limit, the second element is the upper limit. The two elements should either be None (no limit) or tensors with the same shape as the value. - units (str): The units of the parameter value. - uncertainty (Tensor or None): represents the uncertainty of the parameter value. This should be None (no uncertainty) or a Tensor with the same shape as the value. - prof (Tensor or None): This is a profile of values which has no explicit meaning, but can be used to store information which should be kept alongside the value. For example in a spline model the position of the spline points may be a ``prof`` while the flux at each node is the value to be optimized. - shape (Tuple or None): Can be used to set the shape of the value (number of elements/dimensions). If not provided then the shape will be set by the first time a value is given. Once a shape has been set, if a value is given which cannot be coerced into that shape, then an error will be thrown. - - The ``value`` of a Parameter_Node is somewhat complicated, there - are a number of states it can take on. The most straightforward is - just a Tensor, if a Tensor (or just an iterable like a list or - numpy.ndarray) is provided then the node is required to be a leaf - node and it will store the value to be accessed later by other - parts of AstroPhot. Another option is to set the value as another - node (they will automatically be linked), in this case the node's - ``value`` is just a wrapper to call for the ``value`` of the - linked node. Finally, the value may be a function which allows for - arbitrarily complex values to be computed from other node's - values. The function must take as an argument the current - Parameter_Node instance and return a Tensor. Here are some - examples of the various ways of interacting with the ``value`` for a hypothetical parameter ``P``:: - - P.value = 1. # Will create a tensor with value 1. - P.value = P2 # calling P.value will actually call P2.value - def compute_value(param): - return param["P2"].value**2 - P.value = compute_value # calling P.value will call the function as: compute_value(P) which will return P2.value**2 - - """ - - def __init__(self, name, **kwargs): - - super().__init__(name, **kwargs) - if "state" in kwargs: - return - temp_locked = self.locked - self.locked = False - self._value = None - self.prof = kwargs.get("prof", None) - self.limits = kwargs.get("limits", [None, None]) - self.cyclic = kwargs.get("cyclic", False) - self.shape = kwargs.get("shape", None) - self.value = kwargs.get("value", None) - self.units = kwargs.get("units", "none") - self.uncertainty = kwargs.get("uncertainty", None) - self.to() - self.locked = temp_locked - - @property - def value(self): - """The ``value`` of a Parameter_Node is somewhat complicated, there - are a number of states it can take on. The most - straightforward is just a Tensor, if a Tensor (or just an - iterable like a list or numpy.ndarray) is provided then the - node is required to be a leaf node and it will store the value - to be accessed later by other parts of AstroPhot. Another - option is to set the value as another node (they will - automatically be linked), in this case the node's ``value`` is - just a wrapper to call for the ``value`` of the linked - node. Finally, the value may be a function which allows for - arbitrarily complex values to be computed from other node's - values. The function must take as an argument the current - Parameter_Node instance and return a Tensor. Here are some - examples of the various ways of interacting with the ``value`` - for a hypothetical parameter ``P``:: - - P.value = 1. # Will create a tensor with value 1. - P.value = P2 # calling P.value will actually call P2.value - def compute_value(param): - return param["P2"].value**2 - P.value = compute_value # calling P.value will call the function as: compute_value(P) which will return P2.value**2 - - """ - if isinstance(self._value, Parameter_Node): - return self._value.value - if isinstance(self._value, FunctionType): - return self._value(self) - - return self._value - - @property - def mask(self): - """The mask tensor is stored internally and it cuts out some values - from the parameter. This is used by the ``vector`` methods in - the class to give the parameter DAG a dynamic shape. - - """ - if not self.leaf: - return self.vector_mask() - try: - return self._mask - except AttributeError: - return torch.ones(self.shape, dtype=torch.bool, device=AP_config.ap_device) - - @property - def identities(self): - """This creates a numpy array of strings which uniquely identify - every element in the parameter vector. For example a - ``center`` parameter with two components [x,y] would have - identities be ``np.array(["123456:0", "123456:1"])`` where the - first part is the unique id for the Parameter_Node object and - the second number indexes where in the value tensor it refers - to. - - """ - if self.leaf: - idstr = str(self.identity) - return np.array(tuple(f"{idstr}:{i}" for i in range(self.size))) - flat = self.flat(include_locked=False, include_links=False) - vec = tuple(node.identities for node in flat.values()) - if len(vec) > 0: - return np.concatenate(vec) - return np.array(()) - - @property - def names(self): - """Returns a numpy array of names for all the elements of the - ``vector`` representation where the name is determined by the - name of the parameters. Note that this does not create a - unique name for each element and this should only be used for - graphical purposes on small parameter DAGs. - - """ - if self.leaf: - S = self.size - if S == 1: - return np.array((self.name,)) - return np.array(tuple(f"{self.name}:{i}" for i in range(self.size))) - flat = self.flat(include_locked=False, include_links=False) - vec = tuple(node.names for node in flat.values()) - if len(vec) > 0: - return np.concatenate(vec) - return np.array(()) - - def vector_values(self): - """The vector representation is for values which correspond to - fundamental inputs to the parameter DAG. Since the DAG may - have linked nodes, or functions which produce values derived - from other node values, the collection of all "values" is not - necessarily of use for some methods such as fitting - algorithms. The vector representation is useful for optimizers - as it gives a fundamental representation of the parameter - DAG. The vector_values function returns a vector of the - ``value`` for each leaf node. - - """ - - if self.leaf: - return self.value[self.mask].flatten() - - flat = self.flat(include_locked=False, include_links=False) - vec = tuple(node.vector_values() for node in flat.values()) - if len(vec) > 0: - return torch.cat(vec) - return torch.tensor((), dtype=AP_config.ap_dtype, device=AP_config.ap_device) - - def vector_uncertainty(self): - """This returns a vector (see vector_values) with the uncertainty for - each leaf node. - - """ - if self.leaf: - if self._uncertainty is None: - self.uncertainty = torch.ones_like(self.value) - return self.uncertainty[self.mask].flatten() - - flat = self.flat(include_locked=False, include_links=False) - vec = tuple(node.vector_uncertainty() for node in flat.values()) - if len(vec) > 0: - return torch.cat(vec) - return torch.tensor((), dtype=AP_config.ap_dtype, device=AP_config.ap_device) - - def vector_representation(self): - """This returns a vector (see vector_values) with the representation - for each leaf node. The representation is an alternative view - of each value which is mapped into the (-inf, inf) range where - optimization is more stable. - - """ - return self.vector_transform_val_to_rep(self.vector_values()) - - def vector_mask(self): - """This returns a vector (see vector_values) with the mask for each - leaf node. Note however that the mask is not itself masked, - this vector is always the full size of the unmasked parameter - DAG. - - """ - if self.leaf: - return self.mask.flatten() - - flat = self.flat(include_locked=False, include_links=False) - vec = tuple(node.vector_mask() for node in flat.values()) - if len(vec) > 0: - return torch.cat(vec) - return torch.tensor((), dtype=AP_config.ap_dtype, device=AP_config.ap_device) - - def vector_identities(self): - """This returns a vector (see vector_values) with the identities for - each leaf node. - - """ - if self.leaf: - return self.identities[self.vector_mask().detach().cpu().numpy()].flatten() - flat = self.flat(include_locked=False, include_links=False) - vec = tuple(node.vector_identities() for node in flat.values()) - if len(vec) > 0: - return np.concatenate(vec) - return np.array(()) - - def vector_names(self): - """This returns a vector (see vector_values) with the names for each - leaf node. - - """ - if self.leaf: - return self.names[self.vector_mask().detach().cpu().numpy()].flatten() - flat = self.flat(include_locked=False, include_links=False) - vec = tuple(node.vector_names() for node in flat.values()) - if len(vec) > 0: - return np.concatenate(vec) - return np.array(()) - - def vector_set_values(self, values): - """This function allows one to update the full vector of values in a - single call by providing a tensor of the appropriate size. The - input will be separated so that the correct elements are - passed to the correct leaf nodes. - - """ - values = torch.as_tensor( - values, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ).flatten() - if self.leaf: - self._value[self.mask] = values - return - - mask = self.vector_mask() - flat = self.flat(include_locked=False, include_links=False) - - loc = 0 - for node in flat.values(): - node.vector_set_values( - values[mask[:loc].sum().int() : mask[: loc + node.size].sum().int()] - ) - loc += node.size - - def vector_set_uncertainty(self, uncertainty): - """Update the uncertainty vector for this parameter DAG (see - vector_set_values). - - """ - uncertainty = torch.as_tensor( - uncertainty, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - if self.leaf: - if self._uncertainty is None: - self._uncertainty = torch.ones_like(self.value) - self._uncertainty[self.mask] = uncertainty - return - - mask = self.vector_mask() - flat = self.flat(include_locked=False, include_links=False) - - loc = 0 - for node in flat.values(): - node.vector_set_uncertainty( - uncertainty[mask[:loc].sum().int() : mask[: loc + node.size].sum().int()] - ) - loc += node.size - - def vector_set_mask(self, mask): - """Update the mask vector for this parameter DAG (see - vector_set_values). Note again that the mask vector is always - the full size of the DAG. - - """ - mask = torch.as_tensor(mask, dtype=torch.bool, device=AP_config.ap_device) - if self.leaf: - self._mask = mask.reshape(self.shape) - return - flat = self.flat(include_locked=False, include_links=False) - - loc = 0 - for node in flat.values(): - node.vector_set_mask(mask[loc : loc + node.size]) - loc += node.size - - def vector_set_representation(self, rep): - """Update the representation vector for this parameter DAG (see - vector_set_values). - - """ - self.vector_set_values(self.vector_transform_rep_to_val(rep)) - - def vector_transform_rep_to_val(self, rep): - """Used to transform between the ``vector_values`` and - ``vector_representation`` views of the elements in the DAG - leafs. This transforms from representation to value. - - The transformation is done based on the limits of each - parameter leaf. If no limits are provided then the - representation and value are equivalent. If both are given - then a ``tan`` and ``arctan`` are used to convert between the - finite range and the infinite range. If the limits are - one-sided then the transformation: ``newvalue = value - 1 / - (value - limit)`` is used. - - """ - rep = torch.as_tensor(rep, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - if self.leaf: - if self.cyclic: - val = cyclic_boundaries(rep, (self.limits[0], self.limits[1])) - elif self.limits[0] is None and self.limits[1] is None: - val = rep - else: - val = inv_boundaries( - rep, - ( - None if self.limits[0] is None else self.limits[0], - None if self.limits[1] is None else self.limits[1], - ), - ) - return val - - mask = self.vector_mask() - flat = self.flat(include_locked=False, include_links=False) - - loc = 0 - vals = [] - for node in flat.values(): - vals.append( - node.vector_transform_rep_to_val( - rep[mask[:loc].sum().int() : mask[: loc + node.size].sum().int()] - ) - ) - loc += node.size - if len(vals) > 0: - return torch.cat(vals) - return torch.tensor((), dtype=AP_config.ap_dtype, device=AP_config.ap_device) - - def vector_transform_val_to_rep(self, val): - """Used to transform between the ``vector_values`` and - ``vector_representation`` views of the elements in the DAG - leafs. This transforms from value to representation. - - The transformation is done based on the limits of each - parameter leaf. If no limits are provided then the - representation and value are equivalent. If both are given - then a ``tan`` and ``arctan`` are used to convert between the - finite range and the infinite range. If the limits are - one-sided then the transformation: ``newvalue = value - 1 / - (value - limit)`` is used. - - """ - val = torch.as_tensor(val, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - if self.leaf: - if self.cyclic: - rep = cyclic_boundaries(val, (self.limits[0], self.limits[1])) - elif self.limits[0] is None and self.limits[1] is None: - rep = val - else: - rep = boundaries( - val, - ( - None if self.limits[0] is None else self.limits[0], - None if self.limits[1] is None else self.limits[1], - ), - ) - return rep - - mask = self.vector_mask() - flat = self.flat(include_locked=False, include_links=False) - - loc = 0 - reps = [] - for node in flat.values(): - reps.append( - node.vector_transform_val_to_rep( - val[mask[:loc].sum().int() : mask[: loc + node.size].sum().int()] - ) - ) - loc += node.size - if len(reps) > 0: - return torch.cat(reps) - return torch.tensor((), dtype=AP_config.ap_dtype, device=AP_config.ap_device) - - def _set_val_self(self, val): - """Handles the setting of the value for a leaf node. Ensures the - value is a Tensor and that it has the right shape. Will also - check the limits of the value which has different behaviour - depending on if it is cyclic, one sided, or two sided. - - """ - val = torch.as_tensor(val, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - if self.shape is not None: - self._value = val.reshape(self.shape) - else: - self._value = val - self.shape = self._value.shape - - if self.cyclic: - self._value = self.limits[0] + ( - (self._value - self.limits[0]) % (self.limits[1] - self.limits[0]) - ) - return - if self.limits[0] is not None: - if not torch.all(self._value > self.limits[0]): - raise InvalidParameter( - f"{self.name} has lower limit {self.limits[0].detach().cpu().tolist()}" - ) - if self.limits[1] is not None: - if not torch.all(self._value < self.limits[1]): - raise InvalidParameter( - f"{self.name} has upper limit {self.limits[1].detach().cpu().tolist()}" - ) - - def _soft_set_val_self(self, val): - """The same as ``_set_val_self`` except that it doesn't raise an - error when the values are set outside their range, instead it - will push the values into the range defined by the limits. - - """ - val = torch.as_tensor(val, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - if self.shape is not None: - self._value = val.reshape(self.shape) - else: - self._value = val - self.shape = self._value.shape - - if self.cyclic: - self._value = self.limits[0] + ( - (self._value - self.limits[0]) % (self.limits[1] - self.limits[0]) - ) - return - if self.limits[0] is not None: - self._value = torch.maximum( - self._value, self.limits[0] + torch.ones_like(self._value) * 1e-3 - ) - if self.limits[1] is not None: - self._value = torch.minimum( - self._value, self.limits[1] - torch.ones_like(self._value) * 1e-3 - ) - - @value.setter - def value(self, val): - if self.locked and not Node.global_unlock: - return - if val is None: - self._value = None - self.shape = None - return - if isinstance(val, str): - self._value = val - return - if isinstance(val, Parameter_Node): - self._value = val - self.shape = None - # Link only to the pointed node - self.dump() - self.link(val) - return - if isinstance(val, FunctionType): - self._value = val - self.shape = None - return - if len(self.nodes) > 0: - self.vector_set_values(val) - self.shape = None - return - self._set_val_self(val) - self.dump() - - @property - def shape(self): - try: - if isinstance(self._value, Parameter_Node): - return self._value.shape - if isinstance(self._value, FunctionType): - return self.value.shape - if self.leaf: - return self._shape - except AttributeError: - pass - return None - - @shape.setter - def shape(self, shape): - self._shape = shape - - @property - def prof(self): - return self._prof - - @prof.setter - def prof(self, prof): - if self.locked and not Node.global_unlock: - return - if prof is None: - self._prof = None - return - self._prof = torch.as_tensor(prof, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - - @property - def uncertainty(self): - return self._uncertainty - - @uncertainty.setter - def uncertainty(self, unc): - if self.locked and not Node.global_unlock: - return - if unc is None: - self._uncertainty = None - return - - self._uncertainty = torch.as_tensor( - unc, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - # Ensure that the uncertainty tensor has the same shape as the data - if self.shape is not None: - if self._uncertainty.shape != self.shape: - self._uncertainty = self._uncertainty * torch.ones( - self.shape, dtype=AP_config.ap_dtype, device=AP_config.ap_device - ) - - @property - def limits(self): - return self._limits - - @limits.setter - def limits(self, limits): - if self.locked and not Node.global_unlock: - return - if limits[0] is None: - low = None - else: - low = torch.as_tensor(limits[0], dtype=AP_config.ap_dtype, device=AP_config.ap_device) - if limits[1] is None: - high = None - else: - high = torch.as_tensor(limits[1], dtype=AP_config.ap_dtype, device=AP_config.ap_device) - self._limits = (low, high) - - def to(self, dtype=None, device=None): - """ - updates the datatype or device of this parameter - """ - if dtype is not None: - dtype = AP_config.ap_dtype - if device is not None: - device = AP_config.ap_device - - if isinstance(self._value, torch.Tensor): - self._value = self._value.to(dtype=dtype, device=device) - elif len(self.nodes) > 0: - for node in self.nodes.values(): - node.to(dtype, device) - if isinstance(self._uncertainty, torch.Tensor): - self._uncertainty = self._uncertainty.to(dtype=dtype, device=device) - if isinstance(self.prof, torch.Tensor): - self.prof = self.prof.to(dtype=dtype, device=device) - return self - - def get_state(self): - """Return the values representing the current state of the parameter, - this can be used to re-load the state later from memory. - - """ - state = super().get_state() - if self.value is not None: - if isinstance(self._value, Node): - state["value"] = "NODE:" + str(self._value.identity) - elif isinstance(self._value, FunctionType): - state["value"] = "FUNCTION:" + self._value.__name__ - else: - state["value"] = self.value.detach().cpu().numpy().tolist() - if self.shape is not None: - state["shape"] = list(self.shape) - if self.units is not None: - state["units"] = self.units - if self.uncertainty is not None: - state["uncertainty"] = self.uncertainty.detach().cpu().numpy().tolist() - if not (self.limits[0] is None and self.limits[1] is None): - save_lim = [] - for i in [0, 1]: - if self.limits[i] is None: - save_lim.append(None) - else: - save_lim.append(self.limits[i].detach().cpu().tolist()) - state["limits"] = save_lim - if self.cyclic: - state["cyclic"] = self.cyclic - if self.prof is not None: - state["prof"] = self.prof.detach().cpu().tolist() - - return state - - def set_state(self, state): - """Update the state of the parameter given a state variable which - holds all information about a variable. - - """ - - super().set_state(state) - save_locked = self.locked - self.locked = False - self.units = state.get("units", None) - self.limits = state.get("limits", (None, None)) - self.cyclic = state.get("cyclic", False) - self.value = state.get("value", None) - self.uncertainty = state.get("uncertainty", None) - self.prof = state.get("prof", None) - self.locked = save_locked - - def flat_detach(self): - """Due to the system used to track and update values in the DAG, some - parts of the computational graph used to determine gradients - may linger after calling .backward on a model using the - parameters. This function essentially resets all the leaf - values so that the full computational graph is freed. - - """ - for P in self.flat().values(): - P.value = P.value.detach() - if P.uncertainty is not None: - P.uncertainty = P.uncertainty.detach() - if P.prof is not None: - P.prof = P.prof.detach() - - @property - def size(self): - if self.leaf: - return self.value.numel() - return self.vector_values().numel() - - def __len__(self): - """The number of elements required to fully describe the DAG. This is - the number of elements in the vector_values tensor. - - """ - return self.size - - def print_params(self, include_locked=True, include_prof=True, include_id=True): - if self.leaf: - return ( - f"{self.name}" - + (f" (id-{self.identity})" if include_id else "") - + f": {self.value.detach().cpu().tolist()}" - + ( - "" - if self.uncertainty is None - else f" +- {self.uncertainty.detach().cpu().tolist()}" - ) - + f" [{self.units}]" - + ( - "" - if self.limits[0] is None and self.limits[1] is None - else f", limits: ({None if self.limits[0] is None else self.limits[0].detach().cpu().tolist()}, {None if self.limits[1] is None else self.limits[1].detach().cpu().tolist()})" - ) - + (", cyclic" if self.cyclic else "") - + (", locked" if self.locked else "") - + ( - f", prof: {self.prof.detach().cpu().tolist()}" - if include_prof and self.prof is not None - else "" - ) - ) - elif isinstance(self._value, Parameter_Node): - return ( - self.name - + (f" (id-{self.identity})" if include_id else "") - + " points to: " - + self._value.print_params( - include_locked=include_locked, - include_prof=include_prof, - include_id=include_id, - ) - ) - return ( - self.name - + ( - f" (id-{self.identity}, {('function node, '+self._value.__name__) if isinstance(self._value, FunctionType) else 'branch node'})" - if include_id - else "" - ) - + ":\n" - ) - - def __str__(self): - reply = self.print_params(include_locked=True, include_prof=False, include_id=False) - if self.leaf or isinstance(self._value, Parameter_Node): - return reply - reply += "\n".join( - node.print_params(include_locked=True, include_prof=False, include_id=False) - for node in self.flat(include_locked=True, include_links=False).values() - ) - return reply - - def __repr__(self, level=0, indent=" "): - reply = indent * level + self.print_params( - include_locked=True, include_prof=False, include_id=True - ) - if self.leaf or isinstance(self._value, Parameter_Node): - return reply - reply += "\n".join( - node.__repr__(level=level + 1, indent=indent) for node in self.nodes.values() - ) - return reply diff --git a/astrophot/parse_config/__init__.py b/astrophot/parse_config/__init__.py deleted file mode 100644 index 1a1aaec3..00000000 --- a/astrophot/parse_config/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .basic_config import * -from .galfit_config import * diff --git a/astrophot/parse_config/basic_config.py b/astrophot/parse_config/basic_config.py deleted file mode 100644 index 72da3256..00000000 --- a/astrophot/parse_config/basic_config.py +++ /dev/null @@ -1,121 +0,0 @@ -import sys -import os -import importlib -import numpy as np -from astropy.io import fits -from ..image import Target_Image -from ..models import AstroPhot_Model -from ..fit import LM -from .. import AP_config - -__all__ = ["basic_config"] - - -def GetOptions(c): - newoptions = {} - for var in dir(c): - if var.startswith("ap_"): - val = getattr(c, var) - if val is not None: - newoptions[var] = val - return newoptions - - -def import_configfile(config_file): - if "/" in config_file: - startat = config_file.rfind("/") + 1 - else: - startat = 0 - if "." in config_file: - use_config = config_file[startat : config_file.rfind(".")] - else: - use_config = config_file[startat:] - if startat > 0: - sys.path.append(os.path.abspath(config_file[: config_file.rfind("/")])) - else: - sys.path.append(os.getcwd()) - c = importlib.import_module(use_config) - return c - - -def basic_config(config_file): - c = import_configfile(config_file) # importlib.import_module(config_file) - config = GetOptions(c) - - # Parse Target - ###################################################################### - AP_config.ap_logger.info("Collecting target information") - target = config.get("ap_target", None) - if target is None: - target_file = config.get("ap_target_file", None) - target_hdu = config.get("ap_target_hdu", 0) - variance_file = config.get("ap_variance_file", None) - variance_hdu = config.get("ap_variance_hdu", 0) - target_pixelscale = config.get("ap_target_pixelscale", None) - target_zeropoint = config.get("ap_target.zeropoint", None) - target_origin = config.get("ap_target_origin", None) - - if variance_file is not None: - var_data = np.array(fits.open(target_file)[target_hdu].data, dtype=np.float64) - else: - var_data = None - if target_file is not None: - data = np.array(fits.open(target_file)[target_hdu].data, dtype=np.float64) - target = Target_Image( - data=data, - pixelscale=target_pixelscale, - zeropoint=target_zeropoint, - variance=var_data, - origin=target_origin, - ) - - # Parse Models - ###################################################################### - AP_config.ap_logger.info("Constructing models") - model_info_list = config.get("ap_models", []) - name_order = config.get( - "ap_model_name_order", - list(n[9:] for n in filter(lambda k: k.startswith("ap_model_"), config.keys())), - ) - for name in name_order: - key_name = "ap_model_" + name - model_info_list.append(config[key_name]) - if "name" not in model_info_list[-1]: - model_info_list[-1]["name"] = name - model_list = [] - for model in model_info_list: - model_list.append(AstroPhot_Model(target=target, **model)) - - MODEL = AstroPhot_Model( - name="AstroPhot", - model_type="group model", - models=model_list, - target=target, - ) - - # Parse Optimize - ###################################################################### - AP_config.ap_logger.info("Running optimization") - MODEL.initialize() - - optim_type = config.get("ap_optimizer", "LM") - optim_kwargs = config.get("ap_optimizer_kwargs", {}) - if optim_type is None: - # perform no optimization, simply write the astrophot model and the requested images - pass - elif optim_type == "LM": - result = LM(MODEL, **optim_kwargs).fit() - - # Parse Save - ###################################################################### - AP_config.ap_logger.info("Saving model") - model_save = config.get("ap_saveto_model", "AstroPhot.yaml") - MODEL.save(model_save) - - model_image_save = config.get("ap_saveto_model_image", None) - if model_image_save is not None: - MODEL().save(model_image_save) - - model_residual_save = config.get("ap_saveto_model_residual", None) - if model_residual_save is not None: - (target - MODEL()).save(model_residual_save) diff --git a/astrophot/parse_config/galfit_config.py b/astrophot/parse_config/galfit_config.py deleted file mode 100644 index 2248043c..00000000 --- a/astrophot/parse_config/galfit_config.py +++ /dev/null @@ -1,127 +0,0 @@ -__all__ = ["galfit_config"] - -galfit_object_type_map = { - "sersic": "sersic galaxy model", - "sky": "flat sky model", -} - -galfit_parameter_map = { - "sersic galaxy model": { - "1": ["centerpix", 2], - "3": ["totalmag", 1], - "4": ["Repix", 1], - "5": ["n", 1], - "9": ["q", 1], - "10": ["PAdeg", 1], - } -} - - -def space_split(l): - items = list(ls.strip() for ls in l.split(" ")) - index = 0 - while index < len(items): - if items[index] == "": - items.pop(index) - else: - index += 1 - return items - - -def galfit_config(config_file): - if True: - raise NotImplementedError("galfit configuration file interface under construction") - with open(config_file, "r") as f: - config_lines = f.readlines() - # Header info - headerinfo = {} - for line in config_lines: - # remove comment from line and strip whitespace - comment = line.find("#") - if comment >= 0: - line = line[:comment].strip() - if line == "": - continue - if line.startswith("A)"): - headerinfo["target_file"] = line[2:].strip() - if line.startswith("B)"): - headerinfo["saveto_model"] = line[2:].strip() - if line.startswith("C)"): - headerinfo["varaince_file"] = line[2:].strip() - if line.startswith("D)"): - headerinfo["psf_file"] = line[2:].strip() - if line.startswith("E)"): - headerinfo["psf_upample"] = line[2:].strip() - if line.startswith("F)"): - headerinfo["mask_file"] = line[2:].strip() - if line.startswith("G)"): - headerinfo["constraints_file"] = line[2:].strip() - if line.startswith("H)"): - headerinfo["fit_window"] = line[2:].strip() - if line.startswith("I)"): - headerinfo["convolution_window"] = line[2:].strip() - if line.startswith("J)"): - headerinfo["target_zeropoint"] = line[2:].strip() - if line.startswith("K)"): - headerinfo["target_pixelscale"] = line[2:].strip() - - # Object info - objects = [] - in_object = False - for line in config_lines: - # remove comment from line and strip whitespace - comment = line.find("#") - if comment >= 0: - linem = line[:comment].strip() - if linem == "": - continue - - # New model added to the fit - if linem.startswith("0)"): - objects.append({"model_type": galfit_object_type_map[linem[2:].strip()]}) - in_object = True - # Model finished adding - if linem.startswith("Z)"): - in_object = False - - # Collect the parameters - if in_object: - param = linem[: linem.find(")")] - objects[-1][galfit_parameter_map[objects[-1]["model_type"]][param][0]] = space_split( - linem[linem.find(")") + 1 :] - ) - if len(objects[-1][galfit_parameter_map[objects[-1]["model_type"]][param][0]]) != ( - 2 * galfit_parameter_map[objects[-1]["model_type"]][param][1] - ): - raise ValueError(f"Incorrectly formatted line in GALFIT config file:\n{line}") - - # Format parameters - for i in range(len(objects)): - astrophot_object = { - "model_type": objects[i]["model_type"], - } - - # common params - if "centerpix" in objects[i]: - astrophot_object["center"] = { - "value": [ - float(objects[i]["centerpix"][0]) * headerinfo["target_pixelscale"], - float(objects[i]["centerpix"][1]) * headerinfo["target_pixelscale"], - ], - "locked": bool(objects[i]["centerpix"][2]), - } - if "Repix" in objects[i]: - astrophot_object["Re"] = { - "value": float(objects[i]["Repix"][0]) * headerinfo["target_pixelscale"], - "locked": bool(objects[i]["Repix"][1]), - } - if "q" in objects[i]: - astrophot_object["q"] = { - "value": float(objects[i]["q"][0]), - "locked": bool(objects[i]["q"][1]), - } - if "PAdeg" in objects[i]: - astrophot_object["PA"] = { - "value": float(objects[i]["PAdeg"][0]) * np.pi / 180, - "locked": bool(objects[i]["PAdeg"][1]), - } diff --git a/astrophot/parse_config/shared_methods.py b/astrophot/parse_config/shared_methods.py deleted file mode 100644 index e69de29b..00000000 diff --git a/astrophot/plots/__init__.py b/astrophot/plots/__init__.py index e5799f23..2981a510 100644 --- a/astrophot/plots/__init__.py +++ b/astrophot/plots/__init__.py @@ -1,4 +1,25 @@ -from .profile import * -from .image import * -from .visuals import * -from .diagnostic import * +from .profile import ( + radial_light_profile, + radial_median_profile, + ray_light_profile, + warp_phase_profile, +) +from .image import target_image, model_image, residual_image, model_window, psf_image +from .visuals import main_pallet, cmap_div, cmap_grad +from .diagnostic import covariance_matrix + +__all__ = ( + "radial_light_profile", + "radial_median_profile", + "ray_light_profile", + "warp_phase_profile", + "target_image", + "model_image", + "residual_image", + "model_window", + "psf_image", + "main_pallet", + "cmap_div", + "cmap_grad", + "covariance_matrix", +) diff --git a/astrophot/plots/diagnostic.py b/astrophot/plots/diagnostic.py index a9c161ba..1e0df730 100644 --- a/astrophot/plots/diagnostic.py +++ b/astrophot/plots/diagnostic.py @@ -3,6 +3,7 @@ from matplotlib.patches import Ellipse from matplotlib import pyplot as plt from scipy.stats import norm +from .visuals import main_pallet __all__ = ("covariance_matrix",) @@ -13,10 +14,24 @@ def covariance_matrix( labels=None, figsize=(10, 10), reference_values=None, - ellipse_colors="g", + ellipse_colors=main_pallet["primary1"], showticks=True, **kwargs, ): + """ + Create a covariance matrix plot. Creates a corner plot with ellipses representing the covariance between parameters. + + **Args:** + - `covariance_matrix` (np.ndarray): Covariance matrix of shape (n_params, n_params). + - `mean` (np.ndarray): Mean values of the parameters, shape (n_params,). + - `labels` (list, optional): Labels for the parameters. + - `figsize` (tuple, optional): Size of the figure. Default is (10, 10). + - `reference_values` (np.ndarray, optional): Reference values for the parameters, used to draw vertical and horizontal lines. Typically these are the true values of the parameters. + - `ellipse_colors` (str or list, optional): Color for the ellipses. Default is `main_pallet["primary1"]`. + - `showticks` (bool, optional): Whether to show ticks on the axes. Default is True. + + returns the fig and ax objects created to allow further customization by the user. + """ num_params = covariance_matrix.shape[0] fig, axes = plt.subplots(num_params, num_params, figsize=figsize) plt.subplots_adjust(wspace=0.0, hspace=0.0) @@ -32,13 +47,13 @@ def covariance_matrix( 100, ) y = norm.pdf(x, mean[i], np.sqrt(covariance_matrix[i, i])) - ax.plot(x, y, color="g") + ax.plot(x, y, color=ellipse_colors, lw=1.5) ax.set_xlim( mean[i] - 3 * np.sqrt(covariance_matrix[i, i]), mean[i] + 3 * np.sqrt(covariance_matrix[i, i]), ) if reference_values is not None: - ax.axvline(reference_values[i], color="red", linestyle="-", lw=1) + ax.axvline(reference_values[i], color=main_pallet["pop"], linestyle="-", lw=1) elif j < i: cov = covariance_matrix[np.ix_([j, i], [j, i])] lambda_, v = np.linalg.eig(cov) @@ -52,6 +67,7 @@ def covariance_matrix( angle=angle, edgecolor=ellipse_colors, facecolor="none", + lw=1.5, ) ax.add_artist(ellipse) @@ -67,8 +83,8 @@ def covariance_matrix( ) if reference_values is not None: - ax.axvline(reference_values[j], color="red", linestyle="-", lw=1) - ax.axhline(reference_values[i], color="red", linestyle="-", lw=1) + ax.axvline(reference_values[j], color=main_pallet["pop"], linestyle="-", lw=1) + ax.axhline(reference_values[i], color=main_pallet["pop"], linestyle="-", lw=1) if j > i: ax.axis("off") diff --git a/astrophot/plots/image.py b/astrophot/plots/image.py index 9a3cb89d..fc0aba8a 100644 --- a/astrophot/plots/image.py +++ b/astrophot/plots/image.py @@ -1,3 +1,4 @@ +from typing import Literal, Optional, Union import numpy as np import torch @@ -6,31 +7,33 @@ import matplotlib from scipy.stats import iqr -from ..models import Group_Model, PSF_Model, AstroPhot_Model -from ..image import Image_List, Window_List -from .. import AP_config +from ..models import GroupModel, PSFModel, PSFGroupModel +from ..image import ImageList, WindowList, PSFImage +from .. import config +from ..backend_obj import backend from ..utils.conversions.units import flux_to_sb +from ..utils.decorators import ignore_numpy_warnings from .visuals import * -__all__ = ["target_image", "psf_image", "model_image", "residual_image", "model_window"] +__all__ = ("target_image", "psf_image", "model_image", "residual_image", "model_window") +@ignore_numpy_warnings def target_image(fig, ax, target, window=None, **kwargs): """ - This function is used to display a target image using the provided figure and axes. - - Args: - fig (matplotlib.figure.Figure): The figure object in which the target image will be displayed. - ax (matplotlib.axes.Axes): The axes object on which the target image will be plotted. - target (Image or Image_List): The image or list of images to be displayed. - window (Window, optional): The window through which the image is viewed. If `None`, the window of the - provided `target` is used. Defaults to `None`. - **kwargs: Arbitrary keyword arguments. - - Returns: - fig (matplotlib.figure.Figure): The figure object containing the displayed target image. - ax (matplotlib.axes.Axes): The axes object containing the displayed target image. + This function is used to display a target image using the provided figure + and axes. The target is plotted using histogram equalization for better + visibility of the image data for the faint areas of the image, while it uses + log scale normalization for the bright areas. + + **Args:** + - `fig` (matplotlib.figure.Figure): The figure object in which the target image will be displayed. + - `ax` (matplotlib.axes.Axes): The axes object on which the target image will be plotted. + - `target` (Image or Image_List): The image or list of images to be displayed. + - `window` (Window, optional): The window through which the image is viewed. If `None`, the window of the + provided `target` is used. Defaults to `None`. + - **kwargs: Arbitrary keyword arguments. Note: If the `target` is an `Image_List`, this function will recursively call itself for each image in the list. @@ -38,27 +41,23 @@ def target_image(fig, ax, target, window=None, **kwargs): """ # recursive call for target image list - if isinstance(target, Image_List): - for i in range(len(target.image_list)): - target_image(fig, ax[i], target.image_list[i], window=window, **kwargs) + if isinstance(target, ImageList): + for i in range(len(target.images)): + target_image(fig, ax[i], target.images[i], window=window, **kwargs) return fig, ax if window is None: window = target.window - if kwargs.get("flipx", False): - ax.invert_xaxis() target_area = target[window] - dat = np.copy(target_area.data.detach().cpu().numpy()) - if target_area.has_mask: - dat[target_area.mask.detach().cpu().numpy()] = np.nan - X, Y = target_area.get_coordinate_corner_meshgrid() - X = X.detach().cpu().numpy() - Y = Y.detach().cpu().numpy() + + dat = np.copy(backend.to_numpy(target_area._data)) + dat[backend.to_numpy(target_area._mask)] = np.nan + X, Y = target_area.coordinate_corner_meshgrid() + X = backend.to_numpy(X) + Y = backend.to_numpy(Y) sky = np.nanmedian(dat) noise = iqr(dat[np.isfinite(dat)], rng=(16, 84)) / 2 if noise == 0: noise = np.nanstd(dat) - vmin = sky - 5 * noise - vmax = sky + 5 * noise if kwargs.get("linear", False): im = ax.pcolormesh( @@ -72,7 +71,7 @@ def target_image(fig, ax, target, window=None, **kwargs): X, Y, dat, - cmap="Greys", + cmap="gray_r", norm=ImageNormalize( stretch=HistEqStretch( dat[np.logical_and(dat <= (sky + 3 * noise), np.isfinite(dat))] @@ -93,6 +92,8 @@ def target_image(fig, ax, target, window=None, **kwargs): clim=[sky + 3 * noise, None], ) + if np.linalg.det(target.CD.npvalue) < 0: + ax.invert_xaxis() ax.axis("equal") ax.set_xlabel("Tangent Plane X [arcsec]") ax.set_ylabel("Tangent Plane Y [arcsec]") @@ -101,54 +102,57 @@ def target_image(fig, ax, target, window=None, **kwargs): @torch.no_grad() +@ignore_numpy_warnings def psf_image( fig, ax, - psf, - window=None, - cmap_levels=None, - flipx=False, + psf: Union[PSFImage, PSFModel, PSFGroupModel], + cmap_levels: Optional[int] = None, + vmin: Optional[float] = None, + vmax: Optional[float] = None, **kwargs, ): - if isinstance(psf, AstroPhot_Model): + """For plotting PSF images, or the output of a PSF model. + + **Args:** + - `fig` (matplotlib.figure.Figure): The figure object in which the PSF image will be displayed. + - `ax` (matplotlib.axes.Axes): The axes object on which the PSF image will be plotted. + - `psf` (PSFImage or PSFModel or PSFGroupModel): The PSF model or group model to be displayed. + - `cmap_levels` (int, optional): The number of discrete levels to convert the continuous color map to. If not `None`, the color map is converted to a ListedColormap with the specified number of levels. Defaults to `None`. + - `vmin` (float, optional): The minimum value for the color scale. Defaults to `None`. + - `vmax` (float, optional): The maximum value for the color scale. Defaults to `None`. + """ + if isinstance(psf, (PSFModel, PSFGroupModel)): psf = psf() # recursive call for target image list - if isinstance(psf, Image_List): - for i in range(len(psf.image_list)): - psf_image(fig, ax[i], psf.image_list[i], window=window, **kwargs) + if isinstance(psf, ImageList): + for i in range(len(psf.images)): + psf_image(fig, ax[i], psf.images[i], **kwargs) return fig, ax - if window is None: - window = psf.window - if flipx: - ax.invert_xaxis() - - # cut out the requested window - psf = psf[window] - # Evaluate the model image - X, Y = psf.get_coordinate_corner_meshgrid() - X = X.detach().cpu().numpy() - Y = Y.detach().cpu().numpy() - psf = psf.data.detach().cpu().numpy() + x, y = psf.coordinate_corner_meshgrid() + x = backend.to_numpy(x) + y = backend.to_numpy(y) + psf = backend.to_numpy(psf._data) # Default kwargs for image - imshow_kwargs = { + kwargs = { "cmap": cmap_grad, - "norm": matplotlib.colors.LogNorm(), # "norm": ImageNormalize(stretch=LogStretch(), clip=False), + "norm": matplotlib.colors.LogNorm( + vmin=vmin, vmax=vmax + ), # "norm": ImageNormalize(stretch=LogStretch(), clip=False), + **kwargs, } - # Update with user provided kwargs - imshow_kwargs.update(kwargs) - # if requested, convert the continuous colourmap into discrete levels if cmap_levels is not None: - imshow_kwargs["cmap"] = matplotlib.colors.ListedColormap( - list(imshow_kwargs["cmap"](c) for c in np.linspace(0.0, 1.0, cmap_levels)) + kwargs["cmap"] = matplotlib.colors.ListedColormap( + list(kwargs["cmap"](c) for c in np.linspace(0.0, 1.0, cmap_levels)) ) # Plot the image - im = ax.pcolormesh(X, Y, psf, **imshow_kwargs) + ax.pcolormesh(x, y, psf, **kwargs) # Enforce equal spacing on x y ax.axis("equal") @@ -159,6 +163,7 @@ def psf_image( @torch.no_grad() +@ignore_numpy_warnings def model_image( fig, ax, @@ -166,39 +171,31 @@ def model_image( sample_image=None, window=None, target=None, - showcbar=True, - target_mask=False, - cmap_levels=None, - flipx=False, - magunits=True, - sample_full_image=False, + showcbar: bool = True, + target_mask: bool = False, + cmap_levels: Optional[int] = None, + magunits: bool = True, + vmin: Optional[float] = None, + vmax: Optional[float] = None, **kwargs, ): """ This function is used to generate a model image and display it using the provided figure and axes. - Args: - fig (matplotlib.figure.Figure): The figure object in which the image will be displayed. - ax (matplotlib.axes.Axes): The axes object on which the image will be plotted. - model (Model): The model object used to generate a model image if `sample_image` is not provided. - sample_image (Image or Image_List, optional): The image or list of images to be displayed. - If `None`, a model image is generated using the provided `model`. Defaults to `None`. - window (Window, optional): The window through which the image is viewed. If `None`, the window of the - provided `model` is used. Defaults to `None`. - target (Target, optional): The target or list of targets for the image or image list. - If `None`, the target of the `model` is used. Defaults to `None`. - showcbar (bool, optional): Whether to show the color bar. Defaults to `True`. - target_mask (bool, optional): Whether to apply the mask of the target. If `True` and if the target has a mask, - the mask is applied to the image. Defaults to `False`. - cmap_levels (int, optional): The number of discrete levels to convert the continuous color map to. - If not `None`, the color map is converted to a ListedColormap with the specified number of levels. - Defaults to `None`. - sample_full_image: If True, every model will be sampled on the full image window. If False (default) each model will only be sampled in its fitting window. - **kwargs: Arbitrary keyword arguments. These are used to override the default imshow_kwargs. - - Returns: - fig (matplotlib.figure.Figure): The figure object containing the displayed image. - ax (matplotlib.axes.Axes): The axes object containing the displayed image. + **Args:** + - `fig` (matplotlib.figure.Figure): The figure object in which the image will be displayed. + - `ax` (matplotlib.axes.Axes): The axes object on which the image will be plotted. + - `model` (Model): The model object used to generate a model image if `sample_image` is not provided. + - `sample_image` (Image or Image_List, optional): The image or list of images to be displayed. If `None`, a model image is generated using the provided `model`. Defaults to `None`. + - `window` (Window, optional): The window through which the image is viewed. If `None`, the window of the provided `model` is used. Defaults to `None`. + - `target` (Target, optional): The target or list of targets for the image or image list. If `None`, the target of the `model` is used. Defaults to `None`. + - `showcbar` (bool, optional): Whether to show the color bar. Defaults to `True`. + - `target_mask` (bool, optional): Whether to apply the mask of the target. If `True` and if the target has a mask, the mask is applied to the image. Defaults to `False`. + - `cmap_levels` (int, optional): The number of discrete levels to convert the continuous color map to. If not `None`, the color map is converted to a ListedColormap with the specified number of levels. Defaults to `None`. + - `magunits` (bool, optional): Whether to convert the image to surface brightness units. If `True`, the zeropoint of the target is used to convert the image to surface brightness units. Defaults to `True`. + - `vmin` (float, optional): The minimum value for the color scale. Defaults to `None`. + - `vmax` (float, optional): The maximum value for the color scale. Defaults to `None`. + - **kwargs: Arbitrary keyword arguments. These are used to override the default imshow_kwargs. Note: If the `sample_image` is an `Image_List`, this function will recursively call itself for each image in the list, @@ -206,11 +203,7 @@ def model_image( """ if sample_image is None: - if sample_full_image: - sample_image = model.make_model_image() - sample_image = model(sample_image) - else: - sample_image = model() + sample_image = model() # Use model target if not given if target is None: @@ -221,63 +214,68 @@ def model_image( window = model.window # Handle image lists - if isinstance(sample_image, Image_List): - for i, images in enumerate(zip(sample_image, target, window)): + if isinstance(sample_image, ImageList): + for i, (images, targets, windows) in enumerate(zip(sample_image, target, window)): model_image( fig, ax[i], model, - sample_image=images[0], - window=images[2], - target=images[1], + sample_image=images, + window=windows, + target=targets, showcbar=showcbar, target_mask=target_mask, cmap_levels=cmap_levels, - flipx=flipx, magunits=magunits, + vmin=vmin, + vmax=vmax, **kwargs, ) return fig, ax - if flipx: - ax.invert_xaxis() - # cut out the requested window sample_image = sample_image[window] # Evaluate the model image - X, Y = sample_image.get_coordinate_corner_meshgrid() - X = X.detach().cpu().numpy() - Y = Y.detach().cpu().numpy() - sample_image = sample_image.data.detach().cpu().numpy() + X, Y = sample_image.coordinate_corner_meshgrid() + X = backend.to_numpy(X) + Y = backend.to_numpy(Y) + sample_image = backend.to_numpy(sample_image._data) # Default kwargs for image - imshow_kwargs = { + kwargs = { "cmap": cmap_grad, - "norm": matplotlib.colors.LogNorm(), # "norm": ImageNormalize(stretch=LogStretch(), clip=False), + **kwargs, } - # Update with user provided kwargs - imshow_kwargs.update(kwargs) - # if requested, convert the continuous colourmap into discrete levels if cmap_levels is not None: - imshow_kwargs["cmap"] = matplotlib.colors.ListedColormap( - list(imshow_kwargs["cmap"](c) for c in np.linspace(0.0, 1.0, cmap_levels)) + kwargs["cmap"] = matplotlib.colors.ListedColormap( + list(kwargs["cmap"](c) for c in np.linspace(0.0, 1.0, cmap_levels)) ) # If zeropoint is available, convert to surface brightness units if target.zeropoint is not None and magunits: sample_image = flux_to_sb(sample_image, target.pixel_area.item(), target.zeropoint.item()) - del imshow_kwargs["norm"] - imshow_kwargs["cmap"] = imshow_kwargs["cmap"].reversed() + kwargs["cmap"] = kwargs["cmap"].reversed() + kwargs["vmin"] = vmin + kwargs["vmax"] = vmax + else: + kwargs = { + "norm": matplotlib.colors.LogNorm( + vmin=vmin, vmax=vmax + ), # "norm": ImageNormalize(stretch=LogStretch(), clip=False), + **kwargs, + } # Apply the mask if available - if target_mask and target.has_mask: - sample_image[target.mask.detach().cpu().numpy()] = np.nan + sample_image[backend.to_numpy(target[window]._mask)] = np.nan # Plot the image - im = ax.pcolormesh(X, Y, sample_image, **imshow_kwargs) + im = ax.pcolormesh(X, Y, sample_image, **kwargs) + + if np.linalg.det(target.CD.npvalue) < 0: + ax.invert_xaxis() # Enforce equal spacing on x y ax.axis("equal") @@ -296,6 +294,7 @@ def model_image( @torch.no_grad() +@ignore_numpy_warnings def residual_image( fig, ax, @@ -304,40 +303,27 @@ def residual_image( sample_image=None, showcbar=True, window=None, - center_residuals=False, clb_label=None, normalize_residuals=False, - flipx=False, - sample_full_image=False, + scaling: Literal["arctan", "clip", "none"] = "arctan", **kwargs, ): """ This function is used to calculate and display the residuals of a model image with respect to a target image. - The residuals are calculated as the difference between the target image and the sample image. - - Args: - fig (matplotlib.figure.Figure): The figure object in which the residuals will be displayed. - ax (matplotlib.axes.Axes): The axes object on which the residuals will be plotted. - model (Model): The model object used to generate a model image if `sample_image` is not provided. - target (Target or Image_List, optional): The target or list of targets for the image or image list. - If `None`, the target of the `model` is used. Defaults to `None`. - sample_image (Image or Image_List, optional): The image or list of images from which residuals will be calculated. - If `None`, a model image is generated using the provided `model`. Defaults to `None`. - showcbar (bool, optional): Whether to show the color bar. Defaults to `True`. - window (Window or Window_List, optional): The window through which the image is viewed. If `None`, the window of the - provided `model` is used. Defaults to `None`. - center_residuals (bool, optional): Whether to subtract the median of the residuals. If `True`, the median is subtracted - from the residuals. Defaults to `False`. - clb_label (str, optional): The label for the colorbar. If `None`, a default label is used based on the normalization of the - residuals. Defaults to `None`. - normalize_residuals (bool, optional): Whether to normalize the residuals. If `True`, residuals are divided by the square root - of the variance of the target. Defaults to `False`. - sample_full_image: If True, every model will be sampled on the full image window. If False (default) each model will only be sampled in its fitting window. - **kwargs: Arbitrary keyword arguments. These are used to override the default imshow_kwargs. - - Returns: - fig (matplotlib.figure.Figure): The figure object containing the displayed residuals. - ax (matplotlib.axes.Axes): The axes object containing the displayed residuals. + The residuals are calculated as the difference between the target image and the sample image and may be normalized by the standard deviation. + + **Args:** + - `fig` (matplotlib.figure.Figure): The figure object in which the residuals will be displayed. + - `ax` (matplotlib.axes.Axes): The axes object on which the residuals will be plotted. + - `model` (Model): The model object used to generate a model image if `sample_image` is not provided. + - `target` (Target or Image_List, optional): The target or list of targets for the image or image list. If `None`, the target of the `model` is used. Defaults to `None`. + - `sample_image` (Image or Image_List, optional): The image or list of images from which residuals will be calculated. If `None`, a model image is generated using the provided `model`. Defaults to `None`. + - `showcbar` (bool, optional): Whether to show the color bar. Defaults to `True`. + - `window` (Window or Window_List, optional): The window through which the image is viewed. If `None`, the window of the provided `model` is used. Defaults to `None`. + - `clb_label` (str, optional): The label for the colorbar. If `None`, a default label is used based on the normalization of the residuals. Defaults to `None`. + - `normalize_residuals` (bool, optional): Whether to normalize the residuals. If `True`, residuals are divided by the square root of the variance of the target. Defaults to `False`. + - `scaling` (str, optional): The scaling method for the residuals. Options are "arctan", "clip", or "none". arctan will show all residuals, though squish high values to make the fainter residuals more visible, clip will show the residuals in linear space but remove any values above/below 5 sigma, none does no scaling and simply shows the residuals in linear space. Defaults to "arctan". + - `**kwargs`: Arbitrary keyword arguments. These are used to override the default imshow_kwargs. Note: If the `window`, `target`, or `sample_image` are lists, this function will recursively call itself for each element in the list, @@ -350,12 +336,8 @@ def residual_image( if target is None: target = model.target if sample_image is None: - if sample_full_image: - sample_image = model.make_model_image() - sample_image = model(sample_image) - else: - sample_image = model() - if isinstance(window, Window_List) or isinstance(target, Image_List): + sample_image = model() + if isinstance(window, WindowList) or isinstance(target, ImageList): for i_ax, win, tar, sam in zip(ax, window, target, sample_image): residual_image( fig, @@ -365,89 +347,117 @@ def residual_image( sample_image=sam, window=win, showcbar=showcbar, - center_residuals=center_residuals, clb_label=clb_label, normalize_residuals=normalize_residuals, - flipx=flipx, + scaling=scaling, **kwargs, ) return fig, ax - if flipx: - ax.invert_xaxis() - X, Y = sample_image[window].get_coordinate_corner_meshgrid() - X = X.detach().cpu().numpy() - Y = Y.detach().cpu().numpy() - residuals = (target[window] - sample_image[window]).data - if isinstance(normalize_residuals, bool) and normalize_residuals: - residuals = residuals / torch.sqrt(target[window].variance) - elif isinstance(normalize_residuals, torch.Tensor): - residuals = residuals / torch.sqrt(normalize_residuals) + sample_image = sample_image[window] + target = target[window] + X, Y = sample_image.coordinate_corner_meshgrid() + X = backend.to_numpy(X) + Y = backend.to_numpy(Y) + residuals = (target - sample_image)._data + + if normalize_residuals is True: + residuals = residuals / backend.sqrt(target._variance) + elif isinstance(normalize_residuals, backend.array_type): + residuals = residuals / backend.sqrt(normalize_residuals) normalize_residuals = True - residuals = residuals.detach().cpu().numpy() - - if target.has_mask: - residuals[target[window].mask.detach().cpu().numpy()] = np.nan - if center_residuals: - residuals -= np.nanmedian(residuals) - residuals = np.arctan(residuals / (iqr(residuals[np.isfinite(residuals)], rng=[10, 90]) * 2)) - extreme = np.max(np.abs(residuals[np.isfinite(residuals)])) + residuals = backend.to_numpy(residuals) + residuals[backend.to_numpy(target._mask)] = np.nan + + if scaling == "clip": + if normalize_residuals is not True: + config.logger.warning( + "Using clipping scaling without normalizing residuals. This may lead to confusing results." + ) + residuals = np.clip(residuals, -5, 5) + vmax = 5 + default_label = ( + f"(Target - {model.name}) / $\\sigma$" + if normalize_residuals + else f"(Target - {model.name})" + ) + elif scaling == "arctan": + residuals = np.arctan( + residuals / (iqr(residuals[np.isfinite(residuals)], rng=[10, 90]) * 2) + ) + vmax = np.pi / 2 + if normalize_residuals: + default_label = f"tan$^{{-1}}$((Target - {model.name}) / $\\sigma$)" + else: + default_label = f"tan$^{{-1}}$(Target - {model.name})" + elif scaling == "none": + vmax = np.max(np.abs(residuals[np.isfinite(residuals)])) + default_label = ( + f"(Target - {model.name}) / $\\sigma$" + if normalize_residuals + else f"(Target - {model.name})" + ) + else: + raise ValueError(f"Unknown scaling type {scaling}. Use 'clip', 'arctan', or 'none'.") imshow_kwargs = { "cmap": cmap_div, - "vmin": -extreme, - "vmax": extreme, + "vmin": -vmax, + "vmax": vmax, } imshow_kwargs.update(kwargs) im = ax.pcolormesh(X, Y, residuals, **imshow_kwargs) + if np.linalg.det(target.CD.npvalue) < 0: + ax.invert_xaxis() ax.axis("equal") ax.set_xlabel("Tangent Plane X [arcsec]") ax.set_ylabel("Tangent Plane Y [arcsec]") if showcbar: - if normalize_residuals: - default_label = f"tan$^{{-1}}$((Target - {model.name}) / $\\sigma$)" - else: - default_label = f"tan$^{{-1}}$(Target - {model.name})" clb = fig.colorbar(im, ax=ax, label=default_label if clb_label is None else clb_label) clb.ax.set_yticks([]) clb.ax.set_yticklabels([]) return fig, ax +@ignore_numpy_warnings def model_window(fig, ax, model, target=None, rectangle_linewidth=2, **kwargs): + """Used for plotting the window(s) of a model on a target image. These + windows bound the region that a model will be evaluated/fit to. + + **Args:** + - `fig` (matplotlib.figure.Figure): The figure object in which the model window will be displayed. + - `ax` (matplotlib.axes.Axes): The axes object on which the model window will be plotted. + - `model` (Model): The model object whose window will be displayed. + - `target` (Target or Image_List, optional): The target or list of targets for the image or image list. If `None`, the target of the `model` is used. Defaults to `None`. + - `rectangle_linewidth` (int, optional): The linewidth of the rectangle drawn around the model window. Defaults to 2. + - **kwargs: Arbitrary keyword arguments. These are used to override the default rectangle properties. + """ + if target is None: + target = model.target if isinstance(ax, np.ndarray): for i, axitem in enumerate(ax): - model_window(fig, axitem, model, target=model.target.image_list[i], **kwargs) + model_window(fig, axitem, model, target=target.images[i], **kwargs) return fig, ax - if isinstance(model, Group_Model): - for m in model.models.values(): - if isinstance(m.window, Window_List): - use_window = m.window.window_list[m.target.index(target)] + if isinstance(model, GroupModel): + for m in model.models: + if isinstance(m.window, WindowList): + use_window = m.window.windows[m.target.index(target)] else: use_window = m.window - lowright = use_window.pixel_shape.clone().to(dtype=AP_config.ap_dtype) - lowright[1] = 0.0 - lowright = use_window.origin + use_window.pixel_to_plane_delta(lowright) - lowright = lowright.detach().cpu().numpy() - upleft = use_window.pixel_shape.clone().to(dtype=AP_config.ap_dtype) - upleft[0] = 0.0 - upleft = use_window.origin + use_window.pixel_to_plane_delta(upleft) - upleft = upleft.detach().cpu().numpy() - end = use_window.origin + use_window.end - end = end.detach().cpu().numpy() + corners = target[use_window].corners() x = [ - use_window.origin[0].detach().cpu().numpy(), - lowright[0], - end[0], - upleft[0], + corners[0][0].item(), + corners[1][0].item(), + corners[2][0].item(), + corners[3][0].item(), ] y = [ - use_window.origin[1].detach().cpu().numpy(), - lowright[1], - end[1], - upleft[1], + corners[0][1].item(), + corners[1][1].item(), + corners[2][1].item(), + corners[3][1].item(), ] ax.add_patch( Polygon( @@ -455,34 +465,23 @@ def model_window(fig, ax, model, target=None, rectangle_linewidth=2, **kwargs): fill=False, linewidth=rectangle_linewidth, edgecolor=main_pallet["secondary1"], + **kwargs, ) ) else: - if isinstance(model.window, Window_List): - use_window = model.window.window_list[model.target.index(target)] - else: - use_window = model.window - lowright = use_window.pixel_shape.clone().to(dtype=AP_config.ap_dtype) - lowright[1] = 0.0 - lowright = use_window.origin + use_window.pixel_to_plane_delta(lowright) - lowright = lowright.detach().cpu().numpy() - upleft = use_window.pixel_shape.clone().to(dtype=AP_config.ap_dtype) - upleft[0] = 0.0 - upleft = use_window.origin + use_window.pixel_to_plane_delta(upleft) - upleft = upleft.detach().cpu().numpy() - end = use_window.origin + use_window.end - end = end.detach().cpu().numpy() + use_window = model.window + corners = target[use_window].corners() x = [ - use_window.origin[0].detach().cpu().numpy(), - lowright[0], - end[0], - upleft[0], + corners[0][0].item(), + corners[1][0].item(), + corners[2][0].item(), + corners[3][0].item(), ] y = [ - use_window.origin[1].detach().cpu().numpy(), - lowright[1], - end[1], - upleft[1], + corners[0][1].item(), + corners[1][1].item(), + corners[2][1].item(), + corners[3][1].item(), ] ax.add_patch( Polygon( @@ -490,6 +489,7 @@ def model_window(fig, ax, model, target=None, rectangle_linewidth=2, **kwargs): fill=False, linewidth=rectangle_linewidth, edgecolor=main_pallet["secondary1"], + **kwargs, ) ) diff --git a/astrophot/plots/profile.py b/astrophot/plots/profile.py index dceb9ef4..adcdd83b 100644 --- a/astrophot/plots/profile.py +++ b/astrophot/plots/profile.py @@ -5,11 +5,12 @@ import torch from scipy.stats import binned_statistic, iqr -from .. import AP_config -from ..models import Warp_Galaxy +from .. import config +from ..backend_obj import backend +from ..models import Model + from ..utils.conversions.units import flux_to_sb from .visuals import * -from ..errors import InvalidModel __all__ = [ "radial_light_profile", @@ -23,24 +24,36 @@ def radial_light_profile( fig, ax, - model, + model: Model, rad_unit="arcsec", extend_profile=1.0, R0=0.0, resolution=1000, - doassert=True, plot_kwargs={}, ): - xx = torch.linspace( + """ + Used to plot the brightness profile as a function of radius for models which define a `radial_model`. + + **Args:** + - `fig`: matplotlib figure object + - `ax`: matplotlib axis object + - `model` (Model): Model object from which to plot the radial profile. + - `rad_unit` (str): The name of the radius units to plot. If you select "pixel" then the plot will work in pixel units (physical radii divided by pixelscale) if you choose any other string then it will remain in the physical units of the image and the axis label will be whatever you set the value to. Default: "arcsec". Options: "arcsec", "pixel" + - `extend_profile` (float): The factor by which to extend the profile beyond the maximum radius of the model's window. Default: 1.0 + - `R0` (float): The starting radius for the profile. Default: 0.0 + - `resolution` (int): The number of points to use in the profile. Default: 1000 + - `plot_kwargs` (dict): Additional keyword arguments to pass to the plot function, such as `linewidth`, `color`, etc. + """ + xx = backend.linspace( R0, - torch.max(model.window.shape / 2) * extend_profile, + max(model.window.shape) * backend.to_numpy(model.target.pixelscale) * extend_profile / 2, int(resolution), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, + dtype=config.DTYPE, + device=config.DEVICE, ) - flux = model.radial_model(xx).detach().cpu().numpy() + flux = backend.to_numpy(model.radial_model(xx, params=())) if model.target.zeropoint is not None: - yy = flux_to_sb(flux, model.target.pixel_area.item(), model.target.zeropoint.item()) + yy = flux_to_sb(flux, 1.0, model.target.zeropoint.item()) else: yy = np.log10(flux) @@ -50,12 +63,11 @@ def radial_light_profile( "label": f"{model.name} profile", } kwargs.update(plot_kwargs) - with torch.no_grad(): - ax.plot( - xx.detach().cpu().numpy(), - yy, - **kwargs, - ) + ax.plot( + backend.to_numpy(xx), + yy, + **kwargs, + ) if model.target.zeropoint is not None: ax.set_ylabel("Surface Brightness [mag/arcsec$^2$]") @@ -71,16 +83,14 @@ def radial_light_profile( def radial_median_profile( fig, ax, - model: "AstroPhot_Model", + model: Model, count_limit: int = 10, return_profile: bool = False, - rad_unit: Literal["arcsec", "pixel"] = "arcsec", - bin_scale: float = 0.1, - min_bin_width: float = 2, - doassert: bool = True, + rad_unit: str = "arcsec", plot_kwargs: dict = {}, ): - """Plot an SB profile by taking flux median at each radius. + """ + Plot an SB profile by taking flux median at each radius. Using the coordinate transforms defined by the model object, assigns a radius to each pixel then bins the pixel-radii and @@ -88,45 +98,48 @@ def radial_median_profile( representation of the image data if one were to simply average the pixels along isophotes. - Args: - fig: matplotlib figure object - ax: matplotlib axis object - model (AstroPhot_Model): Model object from which to determine the radial binning. Also provides the target image to extract the data - count_limit (int): The limit of pixels in a bin, below which uncertainties are not computed. Default: 10 - return_profile (bool): Instead of just returning the fig and ax object, will return the extracted profile formatted as: Rbins (the radial bin edges), medians (the median in each bin), scatter (the 16-84 quartile range / 2), count (the number of pixels in each bin). Default: False - rad_unit (str): The name of the radius units to plot. If you select "pixel" then the plot will work in pixel units (physical radii divided by pixelscale) if you choose any other string then it will remain in the physical units of the image and the axis label will be whatever you set the value to. Default: "arcsec". Options: "arcsec", "pixel" - bin_scale (float): The geometric scaling factor for the binning, each bin will be this much larger than the previous. Default: 0.1 - min_bin_width (float): The minimum width of a bin in pixel units, default is 2 so that each bin will have some data to compute the median with. Default: 2 - doassert (bool): If any requirements are imposed on which kind of profile can be plotted, this activates them. Default: True + **Args:** + - `fig`: matplotlib figure object + - `ax`: matplotlib axis object + - `model` (AstroPhot_Model): Model object from which to determine the radial binning. Also provides the target image to extract the data + - `count_limit` (int): The limit of pixels in a bin, below which uncertainties are not computed. Default: 10 + - `return_profile` (bool): Instead of just returning the fig and ax object, will return the extracted profile formatted as: Rbins (the radial bin edges), medians (the median in each bin), scatter (the 16-84 quartile range / 2), count (the number of pixels in each bin). Default: False + - `rad_unit` (str): The name of the radius units to plot. If you select "pixel" then the plot will work in pixel units (physical radii divided by pixelscale) if you choose any other string then it will remain in the physical units of the image and the axis label will be whatever you set the value to. Default: "arcsec". Options: "arcsec", "pixel" + - `plot_kwargs` (dict): Additional keyword arguments to pass to the plot function, such as `linewidth`, `color`, etc. """ - Rlast_phys = torch.max(model.window.shape / 2).item() - Rlast_pix = Rlast_phys / model.target.pixel_length.item() + Rlast_pix = max(model.window.shape) / 2 + Rlast_phys = Rlast_pix * model.target.pixelscale.item() Rbins = [0.0] - while Rbins[-1] < Rlast_pix: - Rbins.append(Rbins[-1] + max(min_bin_width, Rbins[-1] * bin_scale)) + while Rbins[-1] < Rlast_phys: + Rbins.append(Rbins[-1] + max(2 * model.target.pixelscale.item(), Rbins[-1] * 0.1)) Rbins = np.array(Rbins) - Rbins = Rbins * model.target.pixel_length.item() # back to physical units with torch.no_grad(): image = model.target[model.window] - X, Y = image.get_coordinate_meshgrid() - model["center"].value[..., None, None] - X, Y = model.transform_coordinates(X, Y) - R = model.radius_metric(X, Y) - R = R.detach().cpu().numpy() + x, y = image.coordinate_center_meshgrid() + x, y = model.transform_coordinates(x, y, params=()) + R = backend.sqrt(x**2 + y**2) + R = backend.to_numpy(R) + + dat = backend.to_numpy(image._data) + # remove masked pixels + mask = backend.to_numpy(image._mask) + dat = dat[~mask] + R = R[~mask] count, bins, binnum = binned_statistic( R.ravel(), - image.data.detach().cpu().numpy().ravel(), + dat.ravel(), statistic="count", bins=Rbins, ) stat, bins, binnum = binned_statistic( R.ravel(), - image.data.detach().cpu().numpy().ravel(), + dat.ravel(), statistic="median", bins=Rbins, ) @@ -134,7 +147,7 @@ def radial_median_profile( scat, bins, binnum = binned_statistic( R.ravel(), - image.data.detach().cpu().numpy().ravel(), + dat.ravel(), statistic=partial(iqr, rng=(16, 84)), bins=Rbins, ) @@ -155,10 +168,8 @@ def radial_median_profile( "elinewidth": 1, "color": main_pallet["primary2"], "label": "data profile", + **plot_kwargs, } - kwargs.update(plot_kwargs) - if rad_unit == "pixel": - Rbins = Rbins / model.target.pixel_length.item() ax.errorbar( (Rbins[:-1] + Rbins[1:]) / 2, stat, @@ -176,63 +187,38 @@ def radial_median_profile( def ray_light_profile( fig, ax, - model, + model: Model, rad_unit="arcsec", extend_profile=1.0, resolution=1000, - doassert=True, ): - xx = torch.linspace( - 0, - torch.max(model.window.shape / 2) * extend_profile, - int(resolution), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, - ) - for r in range(model.rays): - if model.rays <= 5: - col = main_pallet[f"primary{r+1}"] - else: - col = cmap_grad(r / model.rays) - with torch.no_grad(): - ax.plot( - xx.detach().cpu().numpy(), - np.log10(model.iradial_model(r, xx).detach().cpu().numpy()), - linewidth=2, - color=col, - label=f"{model.name} profile {r}", - ) - ax.set_ylabel("log$_{10}$(flux)") - ax.set_xlabel(f"Radius [{rad_unit}]") - - return fig, ax - - -def wedge_light_profile( - fig, - ax, - model, - rad_unit="arcsec", - extend_profile=1.0, - resolution=1000, - doassert=True, -): - xx = torch.linspace( + """ + Used for plotting ray (wedge) type models which define a `iradial_model` method. These have multiple radial profiles. + + **Args:** + - `fig`: matplotlib figure object + - `ax`: matplotlib axis object + - `model` (Model): Model object from which to plot the radial profile. + - `rad_unit` (str): The name of the radius units to plot. + - `extend_profile` (float): The factor by which to extend the profile beyond the maximum radius of the model's window. Default: 1.0 + - `resolution` (int): The number of points to use in the profile. Default: 1000 + """ + xx = backend.linspace( 0, - torch.max(model.window.shape / 2) * extend_profile, + max(model.window.shape) * model.target.pixelscale * extend_profile / 2, int(resolution), - dtype=AP_config.ap_dtype, - device=AP_config.ap_device, + dtype=config.DTYPE, + device=config.DEVICE, ) - for r in range(model.wedges): - if model.wedges <= 5: + for r in range(model.segments): + if model.segments <= 3: col = main_pallet[f"primary{r+1}"] else: - col = cmap_grad(r / model.wedges) + col = cmap_grad(r / model.segments) with torch.no_grad(): ax.plot( - xx.detach().cpu().numpy(), - np.log10(model.iradial_model(r, xx).detach().cpu().numpy()), + backend.to_numpy(xx), + np.log10(backend.to_numpy(model.iradial_model(r, xx, params=()))), linewidth=2, color=col, label=f"{model.name} profile {r}", @@ -243,26 +229,21 @@ def wedge_light_profile( return fig, ax -def warp_phase_profile(fig, ax, model, rad_unit="arcsec", doassert=True): - if doassert: - if not isinstance(model, Warp_Galaxy): - raise InvalidModel( - f"warp_phase_profile must be given a 'Warp_Galaxy' object. Not {type(model)}" - ) - +def warp_phase_profile(fig, ax, model: Model, rad_unit="arcsec"): + """Used to plot the phase profile of a warp model. This gives the axis ratio and position angle as a function of radius.""" ax.plot( - model.profR, - model["q(R)"].value.detach().cpu().numpy(), + backend.to_numpy(model.q_R.prof), + model.q_R.npvalue, linewidth=2, color=main_pallet["primary1"], label=f"{model.name} axis ratio", ) ax.plot( - model.profR, - model["PA(R)"].detach().cpu().numpy() / np.pi, + backend.to_numpy(model.PA_R.prof), + model.PA_R.npvalue / np.pi, linewidth=2, - color=main_pallet["secondary1"], - label=f"{model.name} position angle", + color=main_pallet["primary2"], + label=f"{model.name} position angle/$\\pi$", ) ax.set_ylim([0, 1]) ax.set_ylabel("q [b/a], PA [rad/$\\pi$]") diff --git a/astrophot/plots/shared_elements.py b/astrophot/plots/shared_elements.py deleted file mode 100644 index 9751f757..00000000 --- a/astrophot/plots/shared_elements.py +++ /dev/null @@ -1,111 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -from astropy.visualization.mpl_normalize import ImageNormalize -from astropy.visualization import LogStretch, HistEqStretch - - -def LSBImage(dat, noise): - plt.figure(figsize=(6, 6)) - plt.imshow( - dat, - origin="lower", - cmap="Greys", - norm=ImageNormalize( - stretch=HistEqStretch(dat[dat <= 3 * noise]), - clip=False, - vmax=3 * noise, - vmin=np.min(dat), - ), - ) - my_cmap = copy(cm.Greys_r) - my_cmap.set_under("k", alpha=0) - - plt.imshow( - np.ma.masked_where(dat < 3 * noise, dat), - origin="lower", - cmap=my_cmap, - norm=ImageNormalize(stretch=LogStretch(), clip=False), - clim=[3 * noise, None], - interpolation="none", - ) - plt.xticks([]) - plt.yticks([]) - plt.subplots_adjust(left=0.03, right=0.97, top=0.97, bottom=0.05) - plt.xlim([0, dat.shape[1]]) - plt.ylim([0, dat.shape[0]]) - - -def _display_time(seconds): - intervals = ( - ("hours", 3600), # 60 * 60 - ("arcminutes", 60), - ("arcseconds", 1), - ) - result = [] - - for name, count in intervals: - value = seconds // count - if value: - seconds -= value * count - if value == 1: - name = name.rstrip("s") - result.append("{} {}".format(value, name)) - return ", ".join(result) - - -def AddScale(ax, img_width, loc="lower right"): - """ - ax: figure axis object - img_width: image width in arcseconds - loc: location to put the scale bar - """ - scale_width = int(img_width / 6) - - if scale_width > 60 and scale_width % 60 <= 15: - scale_width -= scale_width % 60 - if scale_width > 45 and scale_width % 60 >= 45: - scale_width += 60 - (scale_width % 60) - if 15 < scale_width % 60 < 45: - scale_width += 30 - (scale_width % 60) - - label = _display_time(scale_width) - - xloc = 0.05 if "left" in loc else 0.95 - yloc = 0.95 if "upper" in loc else 0.05 - - ax.text( - xloc - 0.5 * scale_width / img_width, - yloc + 0.005, - label, - horizontalalignment="center", - verticalalignment="bottom", - transform=ax.transAxes, - fontsize="x-small" if len(label) < 20 else "xx-small", - weight="bold", - color=autocolours["red1"], - ) - ax.plot( - [xloc - scale_width / img_width, xloc], - [yloc, yloc], - transform=ax.transAxes, - color=autocolours["red1"], - ) - - -def AddLogo(fig, loc=[0.8, 0.01, 0.844 / 5, 0.185 / 5], white=False): - im = plt.imread( - get_sample_data( - os.path.join( - os.environ["AUTOPROF"], - "_static/", - ("AP_logo_white.png" if white else "AP_logo.png"), - ) - ) - ) - newax = fig.add_axes(loc, zorder=1000) - if white: - newax.imshow(np.zeros(im.shape) + np.array([0, 0, 0, 1])) - else: - newax.imshow(np.ones(im.shape)) - newax.imshow(im) - newax.axis("off") diff --git a/astrophot/plots/visuals.py b/astrophot/plots/visuals.py index 8ebc913f..37af1d89 100644 --- a/astrophot/plots/visuals.py +++ b/astrophot/plots/visuals.py @@ -1,373 +1,21 @@ -import numpy as np -from matplotlib.colors import LinearSegmentedColormap +from matplotlib.pyplot import get_cmap + +# from matplotlib.colors import ListedColormap +# import numpy as np __all__ = ["main_pallet", "cmap_grad", "cmap_div"] main_pallet = { - "primary1": "#5FAD41", - "primary2": "#46A057", - "primary3": "#2D936C", - "secondary1": "#595122", - "secondary2": "#BFAE48", - "pop": "#391463", + "primary1": "tab:blue", + "primary2": "tab:orange", + "primary3": "tab:red", + "secondary1": "tab:green", + "secondary2": "tab:purple", + "pop": "tab:pink", } -# grad_list = [ -# "#000000", -# "#1A1F16", -# "#1E3F20", -# "#335E31", # "#294C28", -# "#477641", # "#345830", -# "#5D986D", # "#4A7856", -# "#88BF9E", # "#6FB28A", -# "#94ECBE", -# "#FFFFFF", -# ] - -# grad_list = np.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), "rgb_colours.npy")) -# not proud of this but it works -grad_list = [ - [0.02352941176470601, 0.05490196078431372, 0.03137254901960787], - [0.025423221664412132, 0.057920953380312966, 0.033516620406216086], - [0.027376785284830882, 0.06093685701603565, 0.035720426678078974], - [0.02938891743876659, 0.0639504745699993, 0.03798295006291389], - [0.03145841869575705, 0.06696255038243151, 0.04030317185736432], - [0.033584075048444885, 0.06997377604547542, 0.04260833195845542], - [0.03576465765226276, 0.07298479546414544, 0.044888919486002675], - [0.03799892263356237, 0.0759962092973116, 0.0471477168479796], - [0.04028561096212802, 0.07900857886908796, 0.049385189300118405], - [0.04255435348091496, 0.08202242962578685, 0.05160177112762838], - [0.04479522750797262, 0.0850382542012795, 0.053797868796876404], - [0.047011290960205, 0.08805651514356005, 0.05597386374386699], - [0.04920301153517722, 0.09107764734708199, 0.05813011485045571], - [0.05137081101697757, 0.09410206022865564, 0.06026696065107289], - [0.05351507020340438, 0.09713013967907827, 0.062384721306060466], - [0.05563613318418619, 0.1001622498180042, 0.06448370037222773], - [0.057734311075703586, 0.10319873457564618, 0.06656418639667772], - [0.0598098852979892, 0.10623991912163352, 0.0686264543561699], - [0.06186311046428467, 0.10928611115857814, 0.07067076696111699], - [0.06389421694104419, 0.11233760209556845, 0.07269737584065203], - [0.06590341312637987, 0.11539466811482368, 0.07470652262295904], - [0.06789088748697142, 0.11845757114304345, 0.07669843992315992], - [0.06985681038697011, 0.12152655973754478, 0.07867335224943414], - [0.07180133573715267, 0.12460186989603428, 0.08063147683667268], - [0.07372460248826435, 0.12768372579779083, 0.08257302441578668], - [0.07562673598889558, 0.13077234048311664, 0.08449819992578214], - [0.07750784922530482, 0.1338679164771084, 0.0864072031748393], - [0.0793680439581338, 0.13697064636311304, 0.08830022945588598], - [0.08120741176889373, 0.14008071331062474, 0.0901774701215026], - [0.08302603502740363, 0.1431982915618533, 0.09203911312243368], - [0.08482398778987404, 0.14632354688073884, 0.0938853435134896], - [0.08660133663613406, 0.14945663696777384, 0.09571634393019346], - [0.08835814145343088, 0.15259771184365092, 0.0975322950391592], - [0.09009445617335096, 0.15574691420443426, 0.09933337596485287], - [0.09181032946766504, 0.1589043797506794, 0.10111976469511005], - [0.09350580540821188, 0.16207023749268656, 0.1028916384675223], - [0.0951809240954313, 0.16524461003384922, 0.1046491741385914], - [0.09683572225960263, 0.1684276138338816, 0.10639254853734492], - [0.09847023383851092, 0.17161935945352033, 0.10812193880494256], - [0.10008449053481877, 0.17481995178216395, 0.10983752272163866], - [0.10167852235616967, 0.1780294902497653, 0.11153947902233724], - [0.10325235814074307, 0.18124806902417712, 0.11322798770185102], - [0.10480602607076109, 0.18447577719504085, 0.1149032303108651], - [0.10633955417621349, 0.18771269894521686, 0.11656539024350993], - [0.10785297083092218, 0.19095891371065837, 0.11821465301736292], - [0.10934630524287259, 0.19421449632956456, 0.11985120654661363], - [0.11081958794061691, 0.19747951718156898, 0.12147524140906224], - [0.11227285125743197, 0.20075404231766003, 0.12308695110755152], - [0.113706129814793, 0.2040381335814737, 0.12468653232637883], - [0.11511946100664691, 0.20733184872254534, 0.12627418518317554], - [0.1165128854858628, 0.21063524150205992, 0.12785011347669853], - [0.11788644765418299, 0.21394836179160054, 0.12941452493092973], - [0.11924019615691589, 0.21727125566535316, 0.13096763143583748], - [0.12057418438355699, 0.2206039654861931, 0.13250964928512182], - [0.12188847097547184, 0.2239465299860438, 0.13404079941121932], - [0.12318312034172044, 0.2272989843408748, 0.13556130761782226], - [0.12445820318406359, 0.23066136024067158, 0.13707140481012495], - [0.12571379703214872, 0.2340336859546926, 0.13857132722298834], - [0.12694998678982505, 0.23741598639230338, 0.14006131664718174], - [0.1281668652935146, 0.2408082831596569, 0.14154162065383566], - [0.12936453388351488, 0.24421059461247346, 0.14301249281721415], - [0.13054310298908373, 0.24762293590515277, 0.14447419293589053], - [0.13170269272811447, 0.25104531903643956, 0.1459269872523864], - [0.1328434335221802, 0.25447775289184116, 0.14737114867131162], - [0.1339654667276644, 0.2579202432829991, 0.14880695697601956], - [0.13506894528368824, 0.26137279298418187, 0.15023469904377038], - [0.1361540343774794, 0.26483540176607334, 0.15165466905937425], - [0.13722091212776352, 0.2683080664270149, 0.15306716872726295], - [0.1382697702867517, 0.271790780821844, 0.15447250748191957], - [0.13930081496119553, 0.2752835358884738, 0.1558710026965733], - [0.14031426735293703, 0.2787863196723416, 0.15726297989004664], - [0.1413103645193169, 0.2822991173488483, 0.15864877293161908], - [0.142289360153694, 0.28582191124391093, 0.16002872424375406], - [0.1432515253862969, 0.2893546808527299, 0.16140318500251255], - [0.1441971496054517, 0.29289740285688415, 0.16277251533545473], - [0.14512654129921948, 0.29645005113984313, 0.16413708451681472], - [0.1460400289172567, 0.3000125968009987, 0.1654972711597065], - [0.14693796175266738, 0.30358500816829653, 0.16685346340510454], - [0.14782071084339368, 0.3071672508095593, 0.16820605910731407], - [0.14868866989260401, 0.31075928754257476, 0.16955546601563484], - [0.14954225620729, 0.3143610784440312, 0.1709021019518895], - [0.1503819116541449, 0.31797258085736924, 0.1722463949834796], - [0.15120810363154275, 0.32159374939962293, 0.17358878359160151], - [0.1520213260562527, 0.32522453596731304, 0.17492971683424066], - [0.15282210036320543, 0.32886488974146083, 0.1762696545035385], - [0.15361097651642713, 0.33251475719178275, 0.17760906727710982], - [0.15438853402891373, 0.3361740820801234, 0.17894843686287198], - [0.15515538298892081, 0.3398428054631865, 0.18028825613691796], - [0.15591216508980174, 0.34352086569461787, 0.18162902927396304], - [0.15665955466015927, 0.3472081984264961, 0.18297127186986703], - [0.15739825969072788, 0.3509047366102775, 0.1843155110557227], - [0.1581290228539065, 0.35461041049725495, 0.18566228560298584], - [0.15885262251157475, 0.35832514763856854, 0.18701214601910962], - [0.1595698737062211, 0.36204887288482673, 0.1883656546331343], - [0.16028162913006866, 0.365781508385379, 0.18972338567067223], - [0.1609887800663473, 0.3695229735872851, 0.19108592531772042], - [0.1616922572963582, 0.3732731852340314, 0.19245387177272855], - [0.1623930319654953, 0.3770320573640352, 0.19382783528634243], - [0.1630921164008896, 0.38079950130898077, 0.1952084381882439], - [0.16379056487278745, 0.38457542569203246, 0.19659631490050877], - [0.16448947429132857, 0.38835973642596816, 0.19799211193690508], - [0.16518998482987113, 0.39215233671127353, 0.19939648788756642], - [0.1658932804655635, 0.3959531270342421, 0.20081011338847304], - [0.166600589427405, 0.39976200516512433, 0.20223367107520046], - [0.16731318454171665, 0.4035788661563645, 0.20366785552039823], - [0.1680323834645103, 0.4074036023409737, 0.20511337315448636], - [0.16875954879008526, 0.4112361033310768, 0.20657094216908256], - [0.16949608802490335, 0.4150762560166792, 0.20804129240268987], - [0.17024345341575386, 0.4189239445646968, 0.20952516520821696], - [0.17100314162122876, 0.42277905041829333, 0.21102331330192792], - [0.17177669321562006, 0.4266414522965662, 0.21253650059345497], - [0.1725656920146934, 0.43051102619463094, 0.21406550199656194], - [0.17337176421315095, 0.43438764538414765, 0.2156111032203738], - [0.174196577324217, 0.4382711804143329, 0.21717410054084768], - [0.1750418389125363, 0.4421614991135102, 0.21875530055231346], - [0.1759092951125111, 0.44605846659124265, 0.22035551989895852], - [0.17680072892538157, 0.4499619452410963, 0.22197558498620257], - [0.17771795828963766, 0.4538717947440886, 0.2236163316719602], - [0.178662833920936, 0.45778787207287114, 0.22527860493786028], - [0.1796372369194738, 0.46171003149669404, 0.22696325854055394], - [0.18064307614457878, 0.46563812458721304, 0.22867115464331578], - [0.18168228535855538, 0.4695720002251926, 0.23040316342821293], - [0.1827568201439772, 0.47351150460815455, 0.23216016268918455], - [0.1838686546011176, 0.477456481259039, 0.233943037406457], - [0.18501977783468404, 0.48140677103593116, 0.23575267930278093], - [0.18621219024170063, 0.4853622121429166, 0.23758998638206003], - [0.18744789961499414, 0.4893226401421262, 0.23945586245100942], - [0.1887289170794073, 0.49328788796703654, 0.24135121662454895], - [0.19005725288048025, 0.4972577859370923, 0.24327696281571337], - [0.19143491204781923, 0.5012321617737128, 0.24523401921092142], - [0.19286388995773657, 0.505210840617762, 0.2472233077315104], - [0.1943461678218979, 0.509193645048545, 0.24924575348250605], - [0.1958837081305836, 0.5131803951044065, 0.2513022841896479], - [0.1974784500806806, 0.5171709083050143, 0.25339382962574203], - [0.1991323050198725, 0.5211649996753929, 0.2555213210274589], - [0.20084715193916947, 0.5251624817718009, 0.2576856905037352], - [0.20262483304633028, 0.5291631647095255, 0.2598878704369637], - [0.204467149452764, 0.5331668561926818, 0.26212879287818813], - [0.206375857005755, 0.5371733615461064, 0.2644093889375338], - [0.20835266229704946, 0.5411824837494301, 0.26673058817112216], - [0.2103992188771397, 0.5451940234734288, 0.26909331796570707], - [0.21251712370282538, 0.5492077791187413, 0.2714985029222856], - [0.21470791384316384, 0.5532235468570559, 0.2739470642399031], - [0.21697306346622774, 0.5572411206748591, 0.27643991910086557], - [0.21931398112597572, 0.5612602924198622, 0.2789779800585389], - [0.22173200736531357, 0.5652808518501962, 0.28156215442887567], - [0.22422841264772203, 0.5693025866864954, 0.28419334368676863], - [0.2268043956263523, 0.5733252826669735, 0.2868724428682829], - [0.22946108175552465, 0.5773487236056106, 0.28960033997974594], - [0.23219952224604845, 0.5813726914535655, 0.2923779154146281], - [0.23502069336196157, 0.5853969663639336, 0.29520604137906364], - [0.23792549605282798, 0.5894213267599773, 0.2980855813267848], - [0.24091475591248918, 0.5934455494069475, 0.30101738940417244], - [0.2439892234519782, 0.5974694094876285, 0.3040023099060248], - [0.2471495746716238, 0.6014926806817401, 0.3070411767425731], - [0.2503964119149868, 0.6055151352493269, 0.31013481291816875], - [0.25373026498509416, 0.6095365441182762, 0.3132840300219835], - [0.25715159250188796, 0.6135566769761028, 0.3164896277309624], - [0.26066078347841287, 0.6175753023661407, 0.31975239332517896], - [0.26425815909238537, 0.6215921877882963, 0.3230731012156454], - [0.26794397462927927, 0.6256070998045035, 0.32645251248454027], - [0.27171842157278525, 0.6296198041490337, 0.32989137443772115], - [0.2755816298187335, 0.6336300658438219, 0.3333904201693042], - [0.27953366998896156, 0.6376376493189533, 0.3369503681380034], - [0.2835745558223273, 0.6416423185384807, 0.34057192175484835], - [0.2877042466209889, 0.6456438371317278, 0.3442557689818123], - [0.2919226497312409, 0.6496419685302498, 0.34800258194081274], - [0.2962296230395252, 0.6536364761106029, 0.3518130165324921], - [0.30062497746549316, 0.6576271233431101, 0.3556877120641009], - [0.3051084794357496, 0.6616136739467778, 0.3596272908857725], - [0.3096798533232689, 0.6655958920505426, 0.3636323580344192], - [0.3143387838391119, 0.6695735423610237, 0.3677035008844355], - [0.3190849183647877, 0.6735463903369483, 0.37184128880436657], - [0.3239178692149385, 0.677514202370436, 0.37604627281865705], - [0.3288372158217989, 0.6814767459753105, 0.38031898527358976], - [0.33384250683409117, 0.6854337899826277, 0.3846599395064855], - [0.33893326212463143, 0.68938510474359, 0.38906962951724855], - [0.344108974702066, 0.6933304623400306, 0.39354852964131404], - [0.34936911252340047, 0.697269636802651, 0.398097094223074], - [0.354713120205116, 0.7012024043371861, 0.40271575728885467], - [0.3601404206316701, 0.7051285435586824, 0.407404932218529], - [0.365650416461051, 0.7090478357340677, 0.4121650114148717], - [0.3712424915278934, 0.7129600650331871, 0.4169963659697734], - [0.3769160121453231, 0.7168650187884974, 0.4218993453264664], - [0.3826703283074528, 0.7207624877635708, 0.426874276936929], - [0.38850477479467865, 0.724652266430624, 0.4319214659136745], - [0.39441867218475135, 0.7285341532572063, 0.4370411946751511], - [0.4004113277726398, 0.7324079510022498, 0.4422337225840276], - [0.40648203640264285, 0.7362734670216392, 0.4474992855776477], - [0.41263008121642264, 0.7401305135834749, 0.4528380957899934], - [0.41885473432075093, 0.7439789081931996, 0.4582503411645124], - [0.42515525737890464, 0.74781847392875, 0.4637361850571962], - [0.4315309021296915, 0.7516490397859019, 0.4692957658293329], - [0.4379809108381271, 0.755470441033972, 0.4749291964293609], - [0.44450451668174257, 0.7592825195820302, 0.4806365639632952], - [0.4511009440764859, 0.7630851243557921, 0.48641792925318267], - [0.4577694089460426, 0.766878111685348, 0.49227332638307436], - [0.46450911893842234, 0.7706613457038759, 0.49820276223198207], - [0.47131927359337594, 0.7744346987575222, 0.5042062159932884], - [0.47819906446422616, 0.7781980518265826, 0.5102836386800519], - [0.4851476751974689, 0.7819512949581628, 0.5164349526156181], - [0.4921642815733072, 0.7856943277104902, 0.5226600509088999], - [0.49924805151023793, 0.7894270596090324, 0.528958796913626], - [0.506398145036514, 0.7931494106146145, 0.5353310236707783], - [0.5136137142311699, 0.7968613116037284, 0.5417765333333355], - [0.5208939031371336, 0.8005627048612254, 0.5482950965723069], - [0.528237847648764, 0.8042535445856207, 0.5548864519629064], - [0.5356446753758648, 0.8079337974072546, 0.5615503053495099], - [0.5431135054862372, 0.8116034429195591, 0.5682863291878316], - [0.5506434485283842, 0.8152624742237489, 0.5750941618624987], - [0.558233606235997, 0.8189108984872628, 0.5819734069778754], - [0.5658830713155681, 0.8225487375163245, 0.588923632619649], - [0.5735909272182081, 0.8261760283430954, 0.5959443705842361], - [0.5813562478966682, 0.8297928238278868, 0.6030351155725882], - [0.5891780975482949, 0.8333991932770318, 0.6101953243443609], - [0.5970555303443447, 0.8369952230771135, 0.6174244148277456], - [0.6049875901460008, 0.8405810173463196, 0.6247217651794131], - [0.6129733102071075, 0.8441566986038703, 0.6320867127880987], - [0.6210117128633796, 0.8477224084586106, 0.6395185532141892], - [0.629101809207598, 0.8512783083180583, 0.6470165390563735], - [0.6372425987500262, 0.8548245801194263, 0.6545798787348152], - [0.6454330690629193, 0.8583614270844092, 0.6622077351784434], - [0.6536721954076695, 0.861889074499868, 0.6698992244017301], - [0.6619589403427637, 0.8654077705269353, 0.6776534139536473], - [0.670292253310307, 0.868917787041518, 0.6854693212183214], - [0.6786710701982759, 0.8724194205097873, 0.6933459115430326], - [0.6870943128752733, 0.8759129929029018, 0.7012820961645896], - [0.6955608886937588, 0.8793988526560546, 0.7092767298994272], - [0.7040696899570745, 0.8828773756779759, 0.7173286085559195], - [0.7126195933446154, 0.8863489664182553, 0.7254364660189113], - [0.7212094592884835, 0.8898140590014009, 0.7335989709460631], - [0.7298381312936156, 0.8932731184384776, 0.7418147230026438], - [0.738504435191736, 0.8967266419295534, 0.7500822485452499], - [0.7472071783175026, 0.900175160273203, 0.7583999956446169], - [0.755945148592551, 0.9036192394031619, 0.7667663283120103], - [0.7647171134997374, 0.9070594820771404, 0.7751795197609306], - [0.7735218189254033, 0.9104965297491636, 0.783637744493875], - [0.7823579878412776, 0.9139310646651907, 0.7921390689495337], - [0.79122431878925, 0.9173638122327802, 0.8006814403748957], - [0.8001194841201199, 0.9207955437304958, 0.8092626734934449], - [0.8090421279203749, 0.9242270794429251, 0.817880434416763], - [0.8179908635354565, 0.9276592923352225, 0.8265322210808196], - [0.8269642705600306, 0.9310931124203533, 0.8352153392634705], - [0.8359608911072839, 0.934529532028408, 0.8439268729323351], - [0.844979225077973, 0.9379696122690698, 0.852663647247446], - [0.8540177240040997, 0.9414144910996233, 0.8614221819497985], - [0.863074782804233, 0.9448653935946189, 0.8701986320294973], - [0.8721487283913724, 0.9483236452977989, 0.87898871137326], - [0.8812378033992858, 0.9517906899876957, 0.8877875933734316], - [0.890340142117086, 0.955268113920236, 0.8965897799935405], - [0.8994537336219597, 0.9587576798305327, 0.9053889271758421], - [0.908576363256908, 0.9622613760595616, 0.9141776092709066], - [0.9177055163843063, 0.9657814898283584, 0.9229469978386258], - [0.9268382144426769, 0.9693207202671751, 0.9316864204851563], - [0.9359707258804478, 0.9728823589346071, 0.9403827547526942], - [0.9450980392156855, 0.9764705882352952, 0.9490196078431373], -] - -cmap_grad = LinearSegmentedColormap.from_list("cmap_grad", grad_list) - -# # grad_list = ["#000000", "#1A1F16", "#1E3F20", "#294C28", "#345830", "#4A7856", "#6FB28A", "#94ECBE", "#FFFFFF"] -# grad_cdict = {"red": [], "green": [], "blue": []} -# cpoints = np.linspace(0, 1, len(grad_list)) -# for i in range(len(grad_list)): -# grad_cdict["red"].append( -# [cpoints[i], int(grad_list[i][1:3], 16) / 256, int(grad_list[i][1:3], 16) / 256] -# ) -# grad_cdict["green"].append( -# [cpoints[i], int(grad_list[i][3:5], 16) / 256, int(grad_list[i][3:5], 16) / 256] -# ) -# grad_cdict["blue"].append( -# [cpoints[i], int(grad_list[i][5:7], 16) / 256, int(grad_list[i][5:7], 16) / 256] -# ) -# cmap_grad = LinearSegmentedColormap("cmap_grad", grad_cdict) - -div_list = [ - "#332A1F", - "#514129", - "#7C6527", - "#A2862A", - "#DAB944", - "#FFFFFF", - "#7EC87E", - "#3EA343", - "#267D2F", - "#0D5D09", - "#073805", -] -# div_list = ["#083D77", "#7E886B", "#B9AE65", "#FFFFFF", "#F1B555", "#EE964B", "#F95738"] -div_cdict = {"red": [], "green": [], "blue": []} -cpoints = np.linspace(0, 1, len(div_list)) -for i in range(len(div_list)): - div_cdict["red"].append( - [cpoints[i], int(div_list[i][1:3], 16) / 256, int(div_list[i][1:3], 16) / 256] - ) - div_cdict["green"].append( - [cpoints[i], int(div_list[i][3:5], 16) / 256, int(div_list[i][3:5], 16) / 256] - ) - div_cdict["blue"].append( - [cpoints[i], int(div_list[i][5:7], 16) / 256, int(div_list[i][5:7], 16) / 256] - ) -cmap_div = LinearSegmentedColormap("cmap_div", div_cdict) - -# P = plt.cm.plasma_r -# C = plt.cm.cividis -# N = 3 -# cmap_div = ListedColormap(["#083D77", "#7E886B", "#B9AE65", "#FFFFFF", "#F1B555", "#EE964B", "#F95738"]) - -# main_pallet = { -# "primary1": "g", -# "primary2": "r", -# "primary3": "b", -# "primary4": "ornnge", -# "primary5": "cyan", -# "secondary1": "purple", -# "secondary2": "salmon", -# "secondary3": "k", -# "pop": "yellow", -# } - -# cmap_grad = plt.cm.magma -# cmap_div = plt.cm.seismic - -# from matplotlib.colors import LinearSegmentedColormap -# cmaplist = ["#000000", "#720026", "#A0213F", "#ce4257", "#E76154", "#ff9b54", "#ffd1b1"] -# cdict = {"red": [], "green": [], "blue": []} -# cpoints = np.linspace(0, 1, len(cmaplist)) -# for i in range(len(cmaplist)): -# cdict["red"].append( -# [cpoints[i], int(cmaplist[i][1:3], 16) / 256, int(cmaplist[i][1:3], 16) / 256] -# ) -# cdict["green"].append( -# [cpoints[i], int(cmaplist[i][3:5], 16) / 256, int(cmaplist[i][3:5], 16) / 256] -# ) -# cdict["blue"].append( -# [cpoints[i], int(cmaplist[i][5:7], 16) / 256, int(cmaplist[i][5:7], 16) / 256] -# ) -# autocmap = LinearSegmentedColormap("autocmap", cdict) -# autocolours = { -# "red1": "#c33248", -# "blue1": "#84DCCF", -# "blue2": "#6F8AB7", -# "redrange": ["#720026", "#A0213F", "#ce4257", "#E76154", "#ff9b54", "#ffd1b1"], -# } # '#D95D39' +cmap_grad = get_cmap("inferno") +cmap_div = get_cmap("seismic") # twilight RdBu_r +# print(__file__) +# colors = np.load(f"{__file__[:-10]}/managua_cmap.npy") +# cmap_div = ListedColormap(list(reversed(colors)), name="mangua") diff --git a/astrophot/utils/__init__.py b/astrophot/utils/__init__.py index dec9f641..33925367 100644 --- a/astrophot/utils/__init__.py +++ b/astrophot/utils/__init__.py @@ -1,23 +1,19 @@ from . import ( - optimization, - angle_operations, + conversions, + initialize, decorators, + integration, interpolate, - operations, parametric_profiles, - isophote, - initialize, - conversions, ) +from .fitsopen import ls_open __all__ = [ - "optimization", - "angle_operations", "decorators", "interpolate", - "operations", + "integration", "parametric_profiles", - "isophote", "initialize", "conversions", + "ls_open", ] diff --git a/astrophot/utils/angle_operations.py b/astrophot/utils/angle_operations.py deleted file mode 100644 index e4119e64..00000000 --- a/astrophot/utils/angle_operations.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -from scipy.stats import iqr - - -def Angle_Average(a): - """ - Compute the average for a list of angles, which may wrap around a cyclic boundary. - - a: list of angles in the range [0,2pi] - """ - i = np.cos(a) + 1j * np.sin(a) - return np.angle(np.mean(i)) - - -def Angle_Median(a): - """ - Compute the median for a list of angles, which may wrap around a cyclic boundary. - - a: list of angles in the range [0,2pi] - """ - i = np.median(np.cos(a)) + 1j * np.median(np.sin(a)) - return np.angle(i) - - -def Angle_Scatter(a): - """ - Compute the scatter for a list of angles, which may wrap around a cyclic boundary. - - a: list of angles in the range [0,2pi] - """ - i = np.cos(a) + 1j * np.sin(a) - return iqr(np.angle(1j * i / np.mean(i)), rng=[16, 84]) - - -def Angle_COM_PA(flux, X=None, Y=None): - """Performs a center of angular mass calculation by using the flux as - weights to compute a position angle which accounts for the general - "direction" of the light. This PA is computed mod pi since these - are 180 degree rotation symmetric. - - Args: - flux: the weight values for each element (by assumption, pixel fluxes) in a 2D array - X: x coordinate of the flux points. Assumed centered pixel indices if not given - Y: y coordinate of the flux points. Assumed centered pixel indices if not given - - """ - if X is None: - S = flux.shape - X, Y = np.meshgrid(np.arange(S[1]) - S[1] / 2, np.arange(S[0]) - S[0] / 2, indexing="xy") - - theta = np.arctan2(Y, X) - - ang_com_cos = np.sum(flux * np.cos(2 * theta)) / np.sum(flux) - ang_com_sin = np.sum(flux * np.sin(2 * theta)) / np.sum(flux) - - return np.arctan2(ang_com_sin, ang_com_cos) / 2 % np.pi diff --git a/astrophot/utils/conversions/__init__.py b/astrophot/utils/conversions/__init__.py index e69de29b..e3b9a8f4 100644 --- a/astrophot/utils/conversions/__init__.py +++ b/astrophot/utils/conversions/__init__.py @@ -0,0 +1,47 @@ +from .functions import ( + sersic_n_to_b, + sersic_I0_to_flux_np, + sersic_flux_to_I0_np, + sersic_Ie_to_flux_np, + sersic_flux_to_Ie_np, + sersic_I0_to_flux_torch, + sersic_flux_to_I0_torch, + sersic_Ie_to_flux_torch, + sersic_flux_to_Ie_torch, + sersic_inv_np, + sersic_inv_torch, + moffat_I0_to_flux, +) +from .units import ( + deg_to_arcsec, + arcsec_to_deg, + flux_to_sb, + flux_to_mag, + sb_to_flux, + mag_to_flux, + magperarcsec2_to_mag, + mag_to_magperarcsec2, +) + +__all__ = ( + "sersic_n_to_b", + "sersic_I0_to_flux_np", + "sersic_flux_to_I0_np", + "sersic_Ie_to_flux_np", + "sersic_flux_to_Ie_np", + "sersic_I0_to_flux_torch", + "sersic_flux_to_I0_torch", + "sersic_Ie_to_flux_torch", + "sersic_flux_to_Ie_torch", + "sersic_inv_np", + "sersic_inv_torch", + "moffat_I0_to_flux", + "deg_to_arcsec", + "arcsec_to_deg", + "flux_to_sb", + "flux_to_mag", + "sb_to_flux", + "mag_to_flux", + "magperarcsec2_to_mag", + "mag_to_magperarcsec2", +) diff --git a/astrophot/utils/conversions/coordinates.py b/astrophot/utils/conversions/coordinates.py deleted file mode 100644 index 30deb64d..00000000 --- a/astrophot/utils/conversions/coordinates.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -import numpy as np - - -def Rotate_Cartesian(theta, X, Y=None): - """ - Applies a rotation matrix to the X,Y coordinates - """ - s = torch.sin(theta) - c = torch.cos(theta) - if Y is None: - return c * X[0] - s * X[1], s * X[0] + c * X[1] - return c * X - s * Y, s * X + c * Y - - -def Rotate_Cartesian_np(theta, X, Y): - """ - Applies a rotation matrix to the X,Y coordinates - """ - s = np.sin(theta) - c = np.cos(theta) - return c * X - s * Y, c * Y + s * X - - -def Axis_Ratio_Cartesian(q, X, Y, theta=0.0, inv_scale=False): - """ - Applies the transformation: R(theta) Q R(-theta) - where R is the rotation matrix and Q is the matrix which scales the y component by 1/q. - This effectively counter-rotates the coordinates so that the angle theta is along the x-axis - then applies the y-axis scaling, then re-rotates everything back to where it was. - """ - if inv_scale: - scale = (1 / q) - 1 - else: - scale = q - 1 - ss = 1 + scale * torch.pow(torch.sin(theta), 2) - cc = 1 + scale * torch.pow(torch.cos(theta), 2) - s2 = scale * torch.sin(2 * theta) - return ss * X - s2 * Y / 2, -s2 * X / 2 + cc * Y - - -def Axis_Ratio_Cartesian_np(q, X, Y, theta=0.0, inv_scale=False): - """ - Applies the transformation: R(theta) Q R(-theta) - where R is the rotation matrix and Q is the matrix which scales the y component by 1/q. - This effectively counter-rotates the coordinates so that the angle theta is along the x-axis - then applies the y-axis scaling, then re-rotates everything back to where it was. - """ - if inv_scale: - scale = (1 / q) - 1 - else: - scale = q - 1 - ss = 1 + scale * np.sin(theta) ** 2 - cc = 1 + scale * np.cos(theta) ** 2 - s2 = scale * np.sin(2 * theta) - return ss * X - s2 * Y / 2, -s2 * X / 2 + cc * Y diff --git a/astrophot/utils/conversions/dict_to_hdf5.py b/astrophot/utils/conversions/dict_to_hdf5.py deleted file mode 100644 index d1b02354..00000000 --- a/astrophot/utils/conversions/dict_to_hdf5.py +++ /dev/null @@ -1,38 +0,0 @@ -def to_hdf5_has_None(l): - for i in range(len(l)): - if hasattr(l[i], "__iter__") and not isinstance(l[i], str): - l[i] = to_hdf5_has_None(l[i]) - elif l[i] is None: - return True - return False - - -def dict_to_hdf5(h, D): - for key in D: - if isinstance(D[key], dict): - n = h.create_group(key) - dict_to_hdf5(n, D[key]) - else: - if hasattr(D[key], "__iter__") and not isinstance(D[key], str): - if to_hdf5_has_None(D[key]): - h[key] = str(D[key]) - else: - h.create_dataset(key, data=D[key]) - elif D[key] is not None: - h[key] = D[key] - else: - h[key] = "None" - - -def hdf5_to_dict(h): - import h5py - - D = {} - for key in h.keys(): - if isinstance(h[key], h5py.Group): - D[key] = hdf5_to_dict(h[key]) - elif isinstance(h[key], str) and "None" in h[key]: - D[key] = eval(h[key]) - else: - D[key] = h[key] - return D diff --git a/astrophot/utils/conversions/functions.py b/astrophot/utils/conversions/functions.py index 98540df4..1a2f1c60 100644 --- a/astrophot/utils/conversions/functions.py +++ b/astrophot/utils/conversions/functions.py @@ -1,12 +1,29 @@ +from typing import Union import numpy as np -import torch from scipy.special import gamma -from torch.special import gammaln - - -def sersic_n_to_b(n): +from ...backend_obj import backend, ArrayLike + +__all__ = ( + "sersic_n_to_b", + "sersic_I0_to_flux_np", + "sersic_flux_to_I0_np", + "sersic_Ie_to_flux_np", + "sersic_flux_to_Ie_np", + "sersic_I0_to_flux_torch", + "sersic_flux_to_I0_torch", + "sersic_Ie_to_flux_torch", + "sersic_flux_to_Ie_torch", + "sersic_inv_np", + "sersic_inv_torch", + "moffat_I0_to_flux", +) + + +def sersic_n_to_b( + n: Union[float, np.ndarray, ArrayLike], +) -> Union[float, np.ndarray, ArrayLike]: """Compute the `b(n)` for a sersic model. This factor ensures that - the :math:`R_e` and :math:`I_e` parameters do in fact correspond + the $R_e$ and $I_e$ parameters do in fact correspond to the half light values and not some other scale radius/intensity. @@ -22,95 +39,90 @@ def sersic_n_to_b(n): ) -def sersic_I0_to_flux_np(I0, n, R, q): +def sersic_I0_to_flux_np(I0: np.ndarray, n: np.ndarray, R: np.ndarray, q: np.ndarray) -> np.ndarray: """Compute the total flux integrated to infinity for a 2D elliptical - sersic given the :math:`I_0,n,R_s,q` parameters which uniquely - define the profile (:math:`I_0` is the central intensity in - flux/arcsec^2). Note that :math:`R_s` is not the effective radius, + sersic given the $I_0,n,R_s,q$ parameters which uniquely + define the profile ($I_0$ is the central intensity in + flux/arcsec^2). Note that $R_s$ is not the effective radius, but in fact the scale radius in the more straightforward sersic representation: - .. math:: - - I(R) = I_0e^{-(R/R_s)^{1/n}} + $$I(R) = I_0e^{-(R/R_s)^{1/n}}$$ - Args: - I0: central intensity (flux/arcsec^2) - n: sersic index - R: Scale radius - q: axis ratio (b/a) + **Args:** + - `I0`: central intensity (flux/arcsec^2) + - `n`: sersic index + - `R`: Scale radius + - `q`: axis ratio (b/a) """ return 2 * np.pi * I0 * q * n * R**2 * gamma(2 * n) -def sersic_flux_to_I0_np(flux, n, R, q): +def sersic_flux_to_I0_np( + flux: np.ndarray, n: np.ndarray, R: np.ndarray, q: np.ndarray +) -> np.ndarray: """Compute the central intensity (flux/arcsec^2) for a 2D elliptical - sersic given the :math:`F,n,R_s,q` parameters which uniquely - define the profile (:math:`F` is the total flux integrated to - infinity). Note that :math:`R_s` is not the effective radius, but + sersic given the $F,n,R_s,q$ parameters which uniquely + define the profile ($F$ is the total flux integrated to + infinity). Note that $R_s$ is not the effective radius, but in fact the scale radius in the more straightforward sersic representation: - .. math:: + $$I(R) = I_0e^{-(R/R_s)^{1/n}}$$ - I(R) = I_0e^{-(R/R_s)^{1/n}} - - Args: - flux: total flux integrated to infinity (flux) - n: sersic index - R: Scale radius - q: axis ratio (b/a) + **Args:** + - `flux`: total flux integrated to infinity (flux) + - `n`: sersic index + - `R`: Scale radius + - `q`: axis ratio (b/a) """ return flux / (2 * np.pi * q * n * R**2 * gamma(2 * n)) -def sersic_Ie_to_flux_np(Ie, n, R, q): +def sersic_Ie_to_flux_np(Ie: np.ndarray, n: np.ndarray, R: np.ndarray, q: np.ndarray) -> np.ndarray: """Compute the total flux integrated to infinity for a 2D elliptical - sersic given the :math:`I_e,n,R_e,q` parameters which uniquely - define the profile (:math:`I_e` is the intensity at :math:`R_e` in - flux/arcsec^2). Note that :math:`R_e` is the effective radius in + sersic given the $I_e,n,R_e,q$ parameters which uniquely + define the profile ($I_e$ is the intensity at $R_e$ in + flux/arcsec^2). Note that $R_e$ is the effective radius in the sersic representation: - .. math:: - - I(R) = I_ee^{-b_n[(R/R_e)^{1/n}-1]} - - Args: - Ie: intensity at the effective radius (flux/arcsec^2) - n: sersic index - R: Scale radius - q: axis ratio (b/a) + $$I(R) = I_ee^{-b_n[(R/R_e)^{1/n}-1]}$$ + **Args:** + - `Ie`: intensity at the effective radius (flux/arcsec^2) + - `n`: sersic index + - `R`: Scale radius + - `q`: axis ratio (b/a) """ bn = sersic_n_to_b(n) return 2 * np.pi * Ie * R**2 * q * n * (np.exp(bn) * bn ** (-2 * n)) * gamma(2 * n) -def sersic_flux_to_Ie_np(flux, n, R, q): - """Compute the intensity at :math:`R_e` (flux/arcsec^2) for a 2D - elliptical sersic given the :math:`F,n,R_e,q` parameters which - uniquely define the profile (:math:`F` is the total flux - integrated to infinity). Note that :math:`R_e` is the effective +def sersic_flux_to_Ie_np( + flux: np.ndarray, n: np.ndarray, R: np.ndarray, q: np.ndarray +) -> np.ndarray: + """Compute the intensity at $R_e$ (flux/arcsec^2) for a 2D + elliptical sersic given the $F,n,R_e,q$ parameters which + uniquely define the profile ($F$ is the total flux + integrated to infinity). Note that $R_e$ is the effective radius in the sersic representation: - .. math:: + $$I(R) = I_ee^{-b_n[(R/R_e)^{1/n}-1]}$$ - I(R) = I_ee^{-b_n[(R/R_e)^{1/n}-1]} - - Args: - flux: flux integrated to infinity (flux) - n: sersic index - R: Scale radius - q: axis ratio (b/a) + **Args:** + - `flux`: flux integrated to infinity (flux) + - `n`: sersic index + - `R`: Scale radius + - `q`: axis ratio (b/a) """ bn = sersic_n_to_b(n) return flux / (2 * np.pi * R**2 * q * n * (np.exp(bn) * bn ** (-2 * n)) * gamma(2 * n)) -def sersic_inv_np(I, n, Re, Ie): +def sersic_inv_np(I: np.ndarray, n: np.ndarray, Re: np.ndarray, Ie: np.ndarray) -> np.ndarray: """Invert the sersic profile. Compute the radius corresponding to a given intensity for a pure sersic profile. @@ -119,119 +131,123 @@ def sersic_inv_np(I, n, Re, Ie): return Re * ((1 - (1 / bn) * np.log(I / Ie)) ** (n)) -def sersic_I0_to_flux_torch(I0, n, R, q): +def sersic_I0_to_flux_torch(I0: ArrayLike, n: ArrayLike, R: ArrayLike, q: ArrayLike) -> ArrayLike: """Compute the total flux integrated to infinity for a 2D elliptical - sersic given the :math:`I_0,n,R_s,q` parameters which uniquely - define the profile (:math:`I_0` is the central intensity in - flux/arcsec^2). Note that :math:`R_s` is not the effective radius, + sersic given the $I_0,n,R_s,q$ parameters which uniquely + define the profile ($I_0$ is the central intensity in + flux/arcsec^2). Note that $R_s$ is not the effective radius, but in fact the scale radius in the more straightforward sersic representation: - .. math:: - - I(R) = I_0e^{-(R/R_s)^{1/n}} + $$I(R) = I_0e^{-(R/R_s)^{1/n}}$$ - Args: - I0: central intensity (flux/arcsec^2) - n: sersic index - R: Scale radius - q: axis ratio (b/a) + **Args:** + - `I0`: central intensity (flux/arcsec^2) + - `n`: sersic index + - `R`: Scale radius + - `q`: axis ratio (b/a) """ - return 2 * np.pi * I0 * q * n * R**2 * torch.exp(gammaln(2 * n)) + return 2 * np.pi * I0 * q * n * R**2 * backend.exp(backend.gammaln(2 * n)) -def sersic_flux_to_I0_torch(flux, n, R, q): +def sersic_flux_to_I0_torch(flux: ArrayLike, n: ArrayLike, R: ArrayLike, q: ArrayLike) -> ArrayLike: """Compute the central intensity (flux/arcsec^2) for a 2D elliptical - sersic given the :math:`F,n,R_s,q` parameters which uniquely - define the profile (:math:`F` is the total flux integrated to - infinity). Note that :math:`R_s` is not the effective radius, but + sersic given the $F,n,R_s,q$ parameters which uniquely + define the profile ($F$ is the total flux integrated to + infinity). Note that $R_s$ is not the effective radius, but in fact the scale radius in the more straightforward sersic representation: - .. math:: - - I(R) = I_0e^{-(R/R_s)^{1/n}} - - Args: - flux: total flux integrated to infinity (flux) - n: sersic index - R: Scale radius - q: axis ratio (b/a) + $$I(R) = I_0e^{-(R/R_s)^{1/n}}$$ + **Args:** + - `flux`: total flux integrated to infinity (flux) + - `n`: sersic index + - `R`: Scale radius + - `q`: axis ratio (b/a) """ - return flux / (2 * np.pi * q * n * R**2 * torch.exp(gammaln(2 * n))) + return flux / (2 * np.pi * q * n * R**2 * backend.exp(backend.gammaln(2 * n))) -def sersic_Ie_to_flux_torch(Ie, n, R, q): +def sersic_Ie_to_flux_torch(Ie: ArrayLike, n: ArrayLike, R: ArrayLike, q: ArrayLike) -> ArrayLike: """Compute the total flux integrated to infinity for a 2D elliptical - sersic given the :math:`I_e,n,R_e,q` parameters which uniquely - define the profile (:math:`I_e` is the intensity at :math:`R_e` in - flux/arcsec^2). Note that :math:`R_e` is the effective radius in + sersic given the $I_e,n,R_e,q$ parameters which uniquely + define the profile ($I_e$ is the intensity at $R_e$ in + flux/arcsec^2). Note that $R_e$ is the effective radius in the sersic representation: - .. math:: - - I(R) = I_ee^{-b_n[(R/R_e)^{1/n}-1]} + $$I(R) = I_ee^{-b_n[(R/R_e)^{1/n}-1]}$$ - Args: - Ie: intensity at the effective radius (flux/arcsec^2) - n: sersic index - R: Scale radius - q: axis ratio (b/a) + **Args:** + - `Ie`: intensity at the effective radius (flux/arcsec^2) + - `n`: sersic index + - `R`: Scale radius + - `q`: axis ratio (b/a) """ bn = sersic_n_to_b(n) return ( - 2 * np.pi * Ie * R**2 * q * n * (torch.exp(bn) * bn ** (-2 * n)) * torch.exp(gammaln(2 * n)) + 2 + * np.pi + * Ie + * R**2 + * q + * n + * (backend.exp(bn) * bn ** (-2 * n)) + * backend.exp(backend.gammaln(2 * n)) ) -def sersic_flux_to_Ie_torch(flux, n, R, q): - """Compute the intensity at :math:`R_e` (flux/arcsec^2) for a 2D - elliptical sersic given the :math:`F,n,R_e,q` parameters which - uniquely define the profile (:math:`F` is the total flux - integrated to infinity). Note that :math:`R_e` is the effective +def sersic_flux_to_Ie_torch(flux: ArrayLike, n: ArrayLike, R: ArrayLike, q: ArrayLike) -> ArrayLike: + """Compute the intensity at $R_e$ (flux/arcsec^2) for a 2D + elliptical sersic given the $F,n,R_e,q$ parameters which + uniquely define the profile ($F$ is the total flux + integrated to infinity). Note that $R_e$ is the effective radius in the sersic representation: - .. math:: - - I(R) = I_ee^{-b_n[(R/R_e)^{1/n}-1]} + $$I(R) = I_ee^{-b_n[(R/R_e)^{1/n}-1]}$$ - Args: - flux: flux integrated to infinity (flux) - n: sersic index - R: Scale radius - q: axis ratio (b/a) + **Args:** + - `flux`: flux integrated to infinity (flux) + - `n`: sersic index + - `R`: Scale radius + - `q`: axis ratio (b/a) """ bn = sersic_n_to_b(n) return flux / ( - 2 * np.pi * R**2 * q * n * (torch.exp(bn) * bn ** (-2 * n)) * torch.exp(gammaln(2 * n)) + 2 + * np.pi + * R**2 + * q + * n + * (backend.exp(bn) * bn ** (-2 * n)) + * backend.exp(backend.gammaln(2 * n)) ) -def sersic_inv_torch(I, n, Re, Ie): +def sersic_inv_torch(I: ArrayLike, n: ArrayLike, Re: ArrayLike, Ie: ArrayLike) -> ArrayLike: """Invert the sersic profile. Compute the radius corresponding to a given intensity for a pure sersic profile. """ bn = sersic_n_to_b(n) - return Re * ((1 - (1 / bn) * torch.log(I / Ie)) ** (n)) + return Re * ((1 - (1 / bn) * backend.log(I / Ie)) ** (n)) -def moffat_I0_to_flux(I0, n, rd, q): +def moffat_I0_to_flux(I0: float, n: float, rd: float, q: float) -> float: """ Compute the total flux integrated to infinity for a moffat profile. - Args: - I0: central intensity (flux/arcsec^2) - n: moffat curvature parameter (unitless) - rd: scale radius - q: axis ratio + **Args:** + - `I0`: central intensity (flux/arcsec^2) + - `n`: moffat curvature parameter (unitless) + - `rd`: scale radius + - `q`: axis ratio """ return I0 * np.pi * rd**2 * q / (n - 1) diff --git a/astrophot/utils/conversions/optimization.py b/astrophot/utils/conversions/optimization.py deleted file mode 100644 index ca3696a6..00000000 --- a/astrophot/utils/conversions/optimization.py +++ /dev/null @@ -1,75 +0,0 @@ -import numpy as np -import torch -from ... import AP_config - - -def boundaries(val, limits): - """val in limits expanded to range -inf to inf""" - tval = torch.as_tensor(val, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - - if limits[0] is None: - return tval - 1.0 / (tval - limits[1]) - elif limits[1] is None: - return tval - 1.0 / (tval - limits[0]) - return torch.tan((tval - limits[0]) * np.pi / (limits[1] - limits[0]) - np.pi / 2) - - -def inv_boundaries(val, limits): - """val in range -inf to inf compressed to within the limits""" - tval = torch.as_tensor(val, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - - if limits[0] is None: - return (tval + limits[1] - torch.sqrt(torch.pow(tval - limits[1], 2) + 4)) * 0.5 - elif limits[1] is None: - return (tval + limits[0] + torch.sqrt(torch.pow(tval - limits[0], 2) + 4)) * 0.5 - return (torch.arctan(tval) + np.pi / 2) * (limits[1] - limits[0]) / np.pi + limits[0] - - -def d_boundaries_dval(val, limits): - """derivative of: val in limits expanded to range -inf to inf""" - tval = torch.as_tensor(val, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - if limits[0] is None: - return 1.0 + 1.0 / (tval - limits[1]) ** 2 - elif limits[1] is None: - return 1.0 - 1.0 / (tval - limits[0]) ** 2 - return (np.pi / (limits[1] - limits[0])) / torch.cos( - (tval - limits[0]) * np.pi / (limits[1] - limits[0]) - np.pi / 2 - ) ** 2 - - -def d_inv_boundaries_dval(val, limits): - """derivative of: val in range -inf to inf compressed to within the limits""" - tval = torch.as_tensor(val, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - if limits[0] is None: - return 0.5 - 0.5 * (tval - limits[1]) / torch.sqrt(torch.pow(tval - limits[1], 2) + 4) - elif limits[1] is None: - return 0.5 + 0.5 * (tval - limits[0]) / torch.sqrt(torch.pow(tval - limits[0], 2) + 4) - return (limits[1] - limits[0]) / (np.pi * (tval**2 + 1)) - - -def cyclic_boundaries(val, limits): - """Applies cyclic boundary conditions to the input value.""" - tval = torch.as_tensor(val, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - - return limits[0] + ((tval - limits[0]) % (limits[1] - limits[0])) - - -def cyclic_difference_torch(val1, val2, period): - """Applies the difference operation between two values with cyclic - boundary conditions. - - """ - tval1 = torch.as_tensor(val1, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - tval2 = torch.as_tensor(val2, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - - return torch.arcsin(torch.sin((tval1 - tval2) * np.pi / period)) * period / np.pi - - -def cyclic_difference_np(val1, val2, period): - """Applies the difference operation between two values with cyclic - boundary conditions. - - """ - tval1 = torch.as_tensor(val1, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - tval2 = torch.as_tensor(val2, device=AP_config.ap_device, dtype=AP_config.ap_dtype) - return np.arcsin(np.sin((tval1 - tval2) * np.pi / period)) * period / np.pi diff --git a/astrophot/utils/conversions/units.py b/astrophot/utils/conversions/units.py index 64961906..3e1a3026 100644 --- a/astrophot/utils/conversions/units.py +++ b/astrophot/utils/conversions/units.py @@ -1,30 +1,40 @@ +from typing import Optional import numpy as np +__all__ = ( + "deg_to_arcsec", + "arcsec_to_deg", + "flux_to_sb", + "flux_to_mag", + "sb_to_flux", + "mag_to_flux", + "magperarcsec2_to_mag", + "mag_to_magperarcsec2", + "PA_shift_convention", +) + deg_to_arcsec = 3600.0 +arcsec_to_deg = 1.0 / deg_to_arcsec -def flux_to_sb(flux, pixel_area, zeropoint): +def flux_to_sb(flux: float, pixel_area: float, zeropoint: float) -> float: """Conversion from flux units to logarithmic surface brightness units. - .. math:: - - \\mu = -2.5\\log_{10}(flux) + z.p. + 2.5\\log_{10}(A) + $$\\mu = -2.5\\log_{10}(flux) + z.p. + 2.5\\log_{10}(A)$$ - where :math:`z.p.` is the zeropoint and :math:`A` is the area of a pixel. + where $z.p.$ is the zeropoint and $A$ is the area of a pixel. """ return -2.5 * np.log10(flux) + zeropoint + 2.5 * np.log10(pixel_area) -def flux_to_mag(flux, zeropoint, fluxe=None): +def flux_to_mag(flux: float, zeropoint: float, fluxe: Optional[float] = None) -> float: """Converts a flux total into logarithmic magnitude units. - .. math:: - - m = -2.5\\log_{10}(flux) + z.p. + $$m = -2.5\\log_{10}(flux) + z.p.$$ - where :math:`z.p.` is the zeropoint. + where $z.p.$ is the zeropoint. """ if fluxe is None: @@ -33,27 +43,23 @@ def flux_to_mag(flux, zeropoint, fluxe=None): return -2.5 * np.log10(flux) + zeropoint, 2.5 * fluxe / (np.log(10) * flux) -def sb_to_flux(sb, pixel_area, zeropoint): +def sb_to_flux(sb: float, pixel_area: float, zeropoint: float) -> float: """Converts logarithmic surface brightness units into flux units. - .. math:: + $$flux = A 10^{-(\\mu - z.p.)/2.5}$$ - flux = A 10^{-(\\mu - z.p.)/2.5} - - where :math:`z.p.` is the zeropoint and :math:`A` is the area of a pixel. + where $z.p.$ is the zeropoint and $A$ is the area of a pixel. """ return pixel_area * 10 ** (-(sb - zeropoint) / 2.5) -def mag_to_flux(mag, zeropoint, mage=None): +def mag_to_flux(mag: float, zeropoint: float, mage: Optional[float] = None) -> float: """converts logarithmic magnitude units into a flux total. - .. math:: - - flux = 10^{-(m - z.p.)/2.5} + $$flux = 10^{-(m - z.p.)/2.5}$$ - where :math:`z.p.` is the zeropoint. + where $z.p.$ is the zeropoint. """ if mage is None: @@ -63,21 +69,22 @@ def mag_to_flux(mag, zeropoint, mage=None): return I, np.log(10) * I * mage / 2.5 -def magperarcsec2_to_mag(mu, a=None, b=None, A=None): +def magperarcsec2_to_mag( + mu: float, a: Optional[float] = None, b: Optional[float] = None, A: Optional[float] = None +) -> float: """ Converts mag/arcsec^2 to mag - mu: mag/arcsec^2 - a: semi major axis radius (arcsec) - b: semi minor axis radius (arcsec) - A: pre-calculated area (arcsec^2) - returns: mag + **Args:** + - `mu`: mag/arcsec^2 + - `a`: semi major axis radius (arcsec) + - `b`: semi minor axis radius (arcsec) + - `A`: pre-calculated area (arcsec^2) - .. math:: - m = \\mu -2.5\\log_{10}(A) + $$m = \\mu -2.5\\log_{10}(A)$$ - where :math:`A` is an area in arcsec^2. + where $A$ is an area in arcsec^2. """ assert (A is not None) or (a is not None and b is not None) @@ -88,20 +95,26 @@ def magperarcsec2_to_mag(mu, a=None, b=None, A=None): ) # https://en.wikipedia.org/wiki/Surface_brightness#Calculating_surface_brightness -def mag_to_magperarcsec2(m, a=None, b=None, R=None, A=None): +def mag_to_magperarcsec2( + m: float, + a: Optional[float] = None, + b: Optional[float] = None, + R: Optional[float] = None, + A: Optional[float] = None, +) -> float: """ Converts mag to mag/arcsec^2 - m: mag - a: semi major axis radius (arcsec) - b: semi minor axis radius (arcsec) - A: pre-calculated area (arcsec^2) - returns: mag/arcsec^2 - .. math:: + **Args:** + - `m`: mag + - `a`: semi major axis radius (arcsec) + - `b`: semi minor axis radius (arcsec) + - `A`: pre-calculated area (arcsec^2) + - \\mu = m + 2.5\\log_{10}(A) + $$\\mu = m + 2.5\\log_{10}(A)$$ - where :math:`A` is an area in arcsec^2. + where $A$ is an area in arcsec^2. """ assert (A is not None) or (a is not None and b is not None) or (R is not None) if R is not None: @@ -111,18 +124,3 @@ def mag_to_magperarcsec2(m, a=None, b=None, R=None, A=None): return m + 2.5 * np.log10( A ) # https://en.wikipedia.org/wiki/Surface_brightness#Calculating_surface_brightness - - -def PA_shift_convention(pa, unit="rad"): - """ - Alternates between standard mathematical convention for angles, and astronomical position angle convention. - The standard convention is to measure angles counter-clockwise relative to the positive x-axis - The astronomical convention is to measure angles counter-clockwise relative to the positive y-axis - """ - - if unit == "rad": - shift = np.pi - elif unit == "deg": - shift = 180.0 - - return (pa - (shift / 2)) % shift diff --git a/astrophot/utils/decorators.py b/astrophot/utils/decorators.py index b1596ce1..ec556f60 100644 --- a/astrophot/utils/decorators.py +++ b/astrophot/utils/decorators.py @@ -1,9 +1,19 @@ from functools import wraps -import inspect import warnings +from inspect import cleandoc import numpy as np +__all__ = ("classproperty", "ignore_numpy_warnings", "combine_docstrings") + + +class classproperty: + def __init__(self, fget): + self.fget = fget + + def __get__(self, instance, owner): + return self.fget(owner) + def ignore_numpy_warnings(func): """This decorator is used to turn off numpy warnings. This should @@ -27,26 +37,13 @@ def wrapped(*args, **kwargs): return wrapped -def default_internal(func): - """This decorator inspects the input parameters for a function which - expects to receive `image` and `parameters` arguments. If either - of these are not given, then the model can use its default values - for the parameters assuming the `image` is the internal `target` - object and the `parameters` are the internally stored parameters. - - """ - sig = inspect.signature(func) - - @wraps(func) - def wrapper(self, *args, **kwargs): - bound = sig.bind(self, *args, **kwargs) - bound.apply_defaults() - - if bound.arguments.get("image") is None: - bound.arguments["image"] = self.target - if bound.arguments.get("parameters") is None: - bound.arguments["parameters"] = self.parameters - - return func(*bound.args, **bound.kwargs) - - return wrapper +def combine_docstrings(cls): + try: + combined_docs = [cleandoc(cls.__doc__)] + except AttributeError: + combined_docs = [] + for base in cls.__bases__: + if base.__doc__: + combined_docs.append(f"\n\n> SUBUNIT {base.__name__}\n\n{cleandoc(base.__doc__)}") + cls.__doc__ = "\n".join(combined_docs).strip() + return cls diff --git a/astrophot/utils/fitsopen.py b/astrophot/utils/fitsopen.py new file mode 100644 index 00000000..5c9a8d70 --- /dev/null +++ b/astrophot/utils/fitsopen.py @@ -0,0 +1,117 @@ +import numpy as np +import warnings +from astropy.utils.data import download_file +from astropy.io import fits +from astropy.utils.exceptions import AstropyWarning +from numpy.core.defchararray import startswith + +try: + from pyvo.dal import sia +except: + sia = None +import os + +# Suppress common Astropy warnings that can clutter CI logs +warnings.simplefilter("ignore", category=AstropyWarning) + + +def flip_hdu(hdu): + """ + Flips the image data in the FITS HDU on the RA axis to match the expected orientation. + + Args: + hdu (astropy.io.fits.HDUList): The FITS HDU to be flipped. + + Returns: + astropy.io.fits.HDUList: The flipped FITS HDU. + """ + assert "CD1_1" in hdu[0].header, "HDU does not contain WCS information." + assert "CD2_1" in hdu[0].header, "HDU does not contain WCS information." + assert "CRPIX1" in hdu[0].header, "HDU does not contain WCS information." + assert "NAXIS1" in hdu[0].header, "HDU does not contain WCS information." + hdu[0].data = hdu[0].data[:, ::-1].copy() + hdu[0].header["CD1_1"] = -hdu[0].header["CD1_1"] + hdu[0].header["CD2_1"] = -hdu[0].header["CD2_1"] + hdu[0].header["CRPIX1"] = int(hdu[0].header["NAXIS1"] / 2) + 1 + hdu[0].header["CRPIX2"] = int(hdu[0].header["NAXIS2"] / 2) + 1 + return hdu + + +def ls_open(ra, dec, size_arcsec, band="r", release="ls_dr9"): + """ + Retrieves and opens a FITS cutout from the deepest stacked image in the + specified Legacy Survey data release using the Astro Data Lab SIA service. + + Args: + ra (float): Right Ascension in decimal degrees. + dec (float): Declination in decimal degrees. + size_arcsec (float): Size of the square cutout (side length) in arcseconds. + band (str): The filter band (e.g., 'g', 'r', 'z'). Case-insensitive. + release (str): The Legacy Survey Data Release (e.g., 'DR9'). + + Returns: + astropy.io.fits.HDUList: The opened FITS file object. + """ + + if sia is None: + raise ImportError( + "Cannot use ls_open without pyvo. Please install pyvo (pip install pyvo) before continuing." + ) + + # 1. Set the specific SIA service endpoint for the desired release + # SIA endpoints for specific surveys are listed in the notebook. + service_url = f"https://datalab.noirlab.edu/sia/{release.lower()}" + svc = sia.SIAService(service_url) + + # 2. Convert size from arcseconds to degrees (FOV) for the SIA query + # and apply the cosine correction for RA. + fov_deg = size_arcsec / 3600.0 + + # The search method takes the position (RA, Dec) and the square FOV. + imgTable = svc.search( + (ra, dec), (fov_deg / np.cos(dec * np.pi / 180.0), fov_deg), verbosity=2 + ).to_table() + + # 3. Filter the table for stacked images in the specified band + target_band = band.lower() + + sel = ( + (imgTable["proctype"] == "Stack") + & (imgTable["prodtype"] == "image") + & (startswith(imgTable["obs_bandpass"].astype(str), target_band)) + ) + + Table = imgTable[sel] + + if len(Table) == 0: + raise ValueError( + f"No stacked FITS image found for {release} band '{band}' at the requested RA {ra} and Dec {dec}." + ) + + # 4. Pick the "deepest" image (longest exposure time) + # Note: 'exptime' data needs explicit float conversion for np.argmax + max_exptime_index = np.argmax(Table["exptime"].data.data.astype("float")) + row = Table[max_exptime_index] + + # 5. Download the file and open it + url = row["access_url"] # get the download URL + + # Use astropy's download_file, which handles the large data transfer + # and automatically uses a long timeout (120s in the notebook example) + filename = download_file(url, cache=False, show_progress=False, timeout=120) + + # Open the downloaded FITS file + hdu = fits.open(filename) + + try: + hdu = flip_hdu(hdu) + except AssertionError: + pass # If WCS info is missing, skip flipping + + # Clean up the temporary file created by download_file + try: + os.remove(filename) + except OSError: + pass # Ignore if cleanup fails + + return hdu diff --git a/astrophot/utils/initialize/PA.py b/astrophot/utils/initialize/PA.py new file mode 100644 index 00000000..59af6acc --- /dev/null +++ b/astrophot/utils/initialize/PA.py @@ -0,0 +1,13 @@ +from scipy.linalg import sqrtm +import numpy as np + + +def polar_decomposition(A): + # Step 1: Compute symmetric positive-definite matrix P + M = A.T @ A + P = sqrtm(M) # Principal square root of A^T A + + # Step 2: Compute rotation matrix R + P_inv = np.linalg.inv(P) + R = A @ P_inv + return R, P diff --git a/astrophot/utils/initialize/__init__.py b/astrophot/utils/initialize/__init__.py index d634daa1..9708041a 100644 --- a/astrophot/utils/initialize/__init__.py +++ b/astrophot/utils/initialize/__init__.py @@ -1,17 +1,22 @@ -from .segmentation_map import * -from .initialize import isophotes -from .center import center_of_mass, GaussianDensity_Peak, Lanczos_peak -from .construct_psf import gaussian_psf, moffat_psf, construct_psf +from .segmentation_map import ( + centroids_from_segmentation_map, + PA_from_segmentation_map, + q_from_segmentation_map, + windows_from_segmentation_map, + scale_windows, + filter_windows, + transfer_windows, +) +from .center import center_of_mass, recursive_center_of_mass +from .construct_psf import gaussian_psf, moffat_psf from .variance import auto_variance +from .PA import polar_decomposition __all__ = ( - "isophotes", "center_of_mass", - "GaussianDensity_Peak", - "Lanczos_peak", + "recursive_center_of_mass", "gaussian_psf", "moffat_psf", - "construct_psf", "centroids_from_segmentation_map", "PA_from_segmentation_map", "q_from_segmentation_map", @@ -20,4 +25,5 @@ "filter_windows", "transfer_windows", "auto_variance", + "polar_decomposition", ) diff --git a/astrophot/utils/initialize/center.py b/astrophot/utils/initialize/center.py index c895339f..0977d42b 100644 --- a/astrophot/utils/initialize/center.py +++ b/astrophot/utils/initialize/center.py @@ -1,88 +1,35 @@ import numpy as np -from scipy.optimize import minimize -from ..interpolate import point_Lanczos -from ... import AP_config - - -def center_of_mass(center, image, window=None): - """Iterative light weighted center of mass optimization. Each step - determines the light weighted center of mass within a small - window. The new center is used to create a new window. This - continues until the center no longer updates or an image boundary - is reached. - - """ - if window is None: - window = max(min(int(min(image.shape) / 10), 30), 6) - init_center = center - window += window % 2 - xx, yy = np.meshgrid(np.arange(window), np.arange(window)) - for iteration in range(100): - # Determine the image window to calculate COM - ranges = [ - [int(round(center[0]) - window / 2), int(round(center[0]) + window / 2)], - [int(round(center[1]) - window / 2), int(round(center[1]) + window / 2)], - ] - # Avoid edge of image - if ( - ranges[0][0] < 0 - or ranges[1][0] < 0 - or ranges[0][1] >= image.shape[0] - or ranges[1][1] >= image.shape[1] - ): - AP_config.ap_logger.warning("Image edge!") - return init_center - - # Compute COM - denom = np.sum(image[ranges[0][0] : ranges[0][1], ranges[1][0] : ranges[1][1]]) - new_center = [ - ranges[0][0] - + np.sum(image[ranges[0][0] : ranges[0][1], ranges[1][0] : ranges[1][1]] * yy) / denom, - ranges[1][0] - + np.sum(image[ranges[0][0] : ranges[0][1], ranges[1][0] : ranges[1][1]] * xx) / denom, - ] - new_center = np.array(new_center) - # Check for convergence - if np.sum(np.abs(np.array(center) - new_center)) < 0.1: - break - - center = new_center +def center_of_mass(image): + """Determines the light weighted center of mass""" + ii, jj = np.meshgrid(np.arange(image.shape[0]), np.arange(image.shape[1]), indexing="ij") + center = np.array((np.sum(image * ii), np.sum(image * jj))) / np.sum(image) return center -def GaussianDensity_Peak(center, image, window=10, std=0.5): - init_center = center - window += window % 2 +def recursive_center_of_mass(image, max_iter=10, tol=1e-1): + """Determines the light weighted center of mass in a progressively smaller window each time centered on the previous center.""" - def _add_flux(c): - r = np.round(center) - xx, yy = np.meshgrid( - np.arange(r[0] - window / 2, r[0] + window / 2 + 1) - c[0], - np.arange(r[1] - window / 2, r[1] + window / 2 + 1) - c[1], + center = center_of_mass(image) + for i in range(max_iter): + width = (image.shape[0] / (3 + i), image.shape[1] / (3 + i)) + ranges = ( + slice( + max(0, int(center[0] - width[0])), min(image.shape[0], int(center[0] + width[0])) + ), + slice( + max(0, int(center[1] - width[1])), min(image.shape[1], int(center[1] + width[1])) + ), ) - rr2 = xx**2 + yy**2 - f = image[ - int(r[1] - window / 2) : int(r[1] + window / 2 + 1), - int(r[0] - window / 2) : int(r[0] + window / 2 + 1), - ] - return -np.sum(np.exp(-rr2 / (2 * std)) * f) - - res = minimize(_add_flux, x0=center) - return res.x + subimage = image[ranges] + if subimage.size < 9: + return center + new_center = center_of_mass(subimage) + new_center += np.array((ranges[0].start, ranges[1].start)) + if np.linalg.norm(new_center - center) < tol: + return new_center -def Lanczos_peak(center, image, Lanczos_scale=3): - best = [np.inf, None] - for dx in np.arange(-3, 4): - for dy in np.arange(-3, 4): - res = minimize( - lambda x: -point_Lanczos(image, x[0], x[1], scale=Lanczos_scale), - x0=(center[0] + dx, center[1] + dy), - method="Nelder-Mead", - ) - if res.fun < best[0]: - best[0] = res.fun - best[1] = res.x - return best[1] + center = new_center + return center diff --git a/astrophot/utils/initialize/construct_psf.py b/astrophot/utils/initialize/construct_psf.py index 24ed6df3..c05bc88e 100644 --- a/astrophot/utils/initialize/construct_psf.py +++ b/astrophot/utils/initialize/construct_psf.py @@ -1,11 +1,17 @@ import numpy as np -from .center import GaussianDensity_Peak -from ..interpolate import shift_Lanczos_np -from ... import AP_config - def gaussian_psf(sigma, img_width, pixelscale, upsample=4, normalize=True): + """ + create a gaussian point spread function (PSF) image. + + **Args:** + - `sigma`: Standard deviation of the Gaussian in arcseconds. + - `img_width`: Width of the PSF image in pixels. + - `pixelscale`: Pixel scale in arcseconds per pixel. + - `upsample`: Upsampling factor to more accurately create the PSF (the outputted PSF is not upsampled). + - `normalize`: Whether to normalize the PSF so that the sum of all pixels equals 1. If False, the PSF will not be normalized. + """ assert img_width % 2 == 1, "psf images should have an odd shape" # Number of super sampled pixels @@ -36,6 +42,17 @@ def gaussian_psf(sigma, img_width, pixelscale, upsample=4, normalize=True): def moffat_psf(n, Rd, img_width, pixelscale, upsample=4, normalize=True): + """ + Create a Moffat point spread function (PSF) image. + + **Args:** + - `n`: Moffat index (power-law index). + - `Rd`: Scale radius of the Moffat profile in arcseconds. + - `img_width`: Width of the PSF image in pixels. + - `pixelscale`: Pixel scale in arcseconds per pixel. + - `upsample`: Upsampling factor to more accurately create the PSF (the outputted PSF is not upsampled). + - `normalize`: Whether to normalize the PSF so that the sum of all pixels equals 1. If False, the PSF will not be normalized. + """ assert img_width % 2 == 1, "psf images should have an odd shape" # Number of super sampled pixels @@ -63,70 +80,3 @@ def moffat_psf(n, Rd, img_width, pixelscale, upsample=4, normalize=True): if normalize: return ZZ / np.sum(ZZ) return ZZ - - -def construct_psf(stars, image, sky_est, size=51, mask=None, keep_init=False, Lanczos_scale=3): - """Given a list of initial guesses for star center locations, finds - the interpolated flux peak, re-centers the stars such that they - are exactly on a pixel center, then median stacks the normalized - stars to determine an average PSF. - - Note that all coordinates in this function are pixel - coordinates. That is, the image[0][0] pixel is at location (0,0) - and the image[2][7] pixel is at location (2,7) in this coordinate - system. - """ - size += 1 - (size % 2) - star_centers = [] - # determine exact (sub-pixel) center for each star - - for star in stars: - if keep_init: - star_centers = list(np.array(s) for s in stars) - break - try: - peak = GaussianDensity_Peak(star, image) - except Exception as e: - AP_config.ap_logger.warning("issue finding star center") - AP_config.ap_logger.warning(e) - AP_config.ap_logger.warning("skipping") - continue - pixel_cen = np.round(peak) - if ( - pixel_cen[0] < ((size - 1) / 2) - or pixel_cen[0] > (image.shape[1] - ((size - 1) / 2) - 1) - or pixel_cen[1] < ((size - 1) / 2) - or pixel_cen[1] > (image.shape[0] - ((size - 1) / 2) - 1) - ): - AP_config.ap_logger.debug("skipping star near edge at: {peak}") - continue - star_centers.append(peak) - - stacking = [] - # Extract the star from the image, and shift to align exactly with pixel grid - for star in star_centers: - center = np.round(star) - border = int((size - 1) / 2 + Lanczos_scale) - I = image[ - int(center[1] - border) : int(center[1] + border + 1), - int(center[0] - border) : int(center[0] + border + 1), - ] - shift = center - star - I = shift_Lanczos_np(I - sky_est, shift[0], shift[1], scale=Lanczos_scale) - I = I[Lanczos_scale:-Lanczos_scale, Lanczos_scale:-Lanczos_scale] - border = (size - 1) / 2 - if mask is not None: - I[ - mask[ - int(center[1] - border) : int(center[1] + border + 1), - int(center[0] - border) : int(center[0] + border + 1), - ] - ] = np.nan - # Add the normalized star image to the list - stacking.append(I / np.sum(I)) - - # Median stack the pixel images - stacked_psf = np.nanmedian(stacking, axis=0) - stacked_psf[stacked_psf < 0] = 0 - - return stacked_psf / np.sum(stacked_psf) diff --git a/astrophot/utils/initialize/initialize.py b/astrophot/utils/initialize/initialize.py deleted file mode 100644 index 3f03ca5f..00000000 --- a/astrophot/utils/initialize/initialize.py +++ /dev/null @@ -1,113 +0,0 @@ -import numpy as np -from scipy.stats import iqr -from scipy.fftpack import fft - -from ..isophote.extract import _iso_extract - - -def isophotes(image, center, threshold=None, pa=None, q=None, R=None, n_isophotes=3, more=False): - """Method for quickly extracting a small number of elliptical - isophotes for the sake of initializing other models. - - """ - - if pa is None: - pa = 0.0 - - if q is None: - q = 1.0 - - if R is None: - # Determine basic threshold if none given - if threshold is None: - threshold = np.nanmedian(image) + 3 * iqr(image[np.isfinite(image)], rng=(16, 84)) / 2 - - # Sample growing isophotes until threshold is reached - ellipse_radii = [1.0] - while ellipse_radii[-1] < (max(image.shape) / 2): - ellipse_radii.append(ellipse_radii[-1] * (1 + 0.2)) - isovals = _iso_extract( - image, - ellipse_radii[-1], - { - "q": q if isinstance(q, float) else np.max(q), - "pa": pa if isinstance(pa, float) else np.min(pa), - }, - {"x": center[0], "y": center[1]}, - more=False, - sigmaclip=True, - sclip_nsigma=3, - ) - if len(isovals) < 3: - continue - # Stop when at 3 time background noise - if (np.quantile(isovals, 0.8) < threshold) and len(ellipse_radii) > 4: - break - R = ellipse_radii[-1] - - # Determine which radii to sample based on input R, pa, and q - if isinstance(pa, float) and isinstance(q, float) and isinstance(R, float): - if n_isophotes == 1: - isophote_radii = [R] - else: - isophote_radii = np.linspace(0, R, n_isophotes) - elif hasattr(R, "__len__"): - isophote_radii = R - elif hasattr(pa, "__len__"): - isophote_radii = np.ones(len(pa)) * R - elif hasattr(q, "__len__"): - isophote_radii = np.ones(len(q)) * R - - # Sample the requested isophotes and record desired info - iso_info = [] - for i, r in enumerate(isophote_radii): - iso_info.append({"R": r}) - isovals = _iso_extract( - image, - r, - { - "q": q if isinstance(q, float) else q[i], - "pa": pa if isinstance(pa, float) else pa[i], - }, - {"x": center[0], "y": center[1]}, - more=more, - sigmaclip=True, - sclip_nsigma=3, - interp_mask=True, - ) - if more: - angles = isovals[1] - isovals = isovals[0] - if len(isovals) < 3: - iso_info[-1] = None - continue - coefs = fft(isovals) - iso_info[-1]["phase1"] = np.angle(coefs[1]) - iso_info[-1]["phase2"] = np.angle(coefs[2]) - iso_info[-1]["flux"] = np.median(isovals) - iso_info[-1]["noise"] = iqr(isovals, rng=(16, 84)) / 2 - iso_info[-1]["amplitude1"] = np.abs(coefs[1]) / ( - len(isovals) * (max(0, iso_info[-1]["flux"]) + iso_info[-1]["noise"]) - ) - iso_info[-1]["amplitude2"] = np.abs(coefs[2]) / ( - len(isovals) * (max(0, iso_info[-1]["flux"]) + iso_info[-1]["noise"]) - ) - iso_info[-1]["N"] = len(isovals) - if more: - iso_info[-1]["isovals"] = isovals - iso_info[-1]["angles"] = angles - - # recover lost isophotes just to keep code moving - for i in reversed(range(len(iso_info))): - if iso_info[i] is not None: - good_index = i - break - else: - raise ValueError( - "Unable to recover any isophotes, try on a better band or manually provide values" - ) - for i in range(len(iso_info)): - if iso_info[i] is None: - iso_info[i] = iso_info[good_index] - iso_info[i]["R"] = isophote_radii[i] - return iso_info diff --git a/astrophot/utils/initialize/segmentation_map.py b/astrophot/utils/initialize/segmentation_map.py index f81cf9c3..053d257b 100644 --- a/astrophot/utils/initialize/segmentation_map.py +++ b/astrophot/utils/initialize/segmentation_map.py @@ -1,11 +1,10 @@ from copy import deepcopy -from typing import Union +from typing import Optional, Union import numpy as np -import torch from astropy.io import fits -from ..angle_operations import Angle_COM_PA -from ..operations import axis_ratio_com +from ...backend_obj import backend +from ... import config __all__ = ( "centroids_from_segmentation_map", @@ -32,9 +31,9 @@ def _select_img(img, hduli): def centroids_from_segmentation_map( seg_map: Union[np.ndarray, str], - image: Union[np.ndarray, str], + image: "Image", + sky_level: Optional[float] = None, hdul_index_seg: int = 0, - hdul_index_img: int = 0, skip_index: tuple = (0,), ): """identify centroid centers for all segments in a segmentation map @@ -43,102 +42,131 @@ def centroids_from_segmentation_map( pixel space. A dictionary of pixel centers is produced where the keys of the dictionary correspond to the segment id's. - Parameters: - ---------- - seg_map (Union[np.ndarray, str]): A segmentation map which gives the object identity for each pixel - image (Union[np.ndarray, str]): An Image which will be used in the light weighted center of mass calculation - hdul_index_seg (int): If reading from a fits file this is the hdu list index at which the map is found. Default: 0 - hdul_index_img (int): If reading from a fits file this is the hdu list index at which the image is found. Default: 0 - skip_index (tuple): Lists which identities (if any) in the segmentation map should be ignored. Default (0,) - - Returns: - centroids (dict): dictionary of centroid positions matched to each segment ID. The centroids are in pixel coordinates + **Args:** + - `seg_map` (Union[np.ndarray, str]): A segmentation map which gives the object identity for each pixel + - `image` (Union[np.ndarray, str]): An Image which will be used in the light weighted center of mass calculation + - `sky_level` (float): The sky level to subtract from the image data before calculating centroids. Default: None, which uses the median of the image data. + - `hdul_index_seg` (int): If reading from a fits file this is the hdu list index at which the map is found. Default: 0 + - `skip_index` (tuple): Lists which identities (if any) in the segmentation map should be ignored. Default (0,) """ seg_map = _select_img(seg_map, hdul_index_seg) - image = _select_img(image, hdul_index_img) + seg_map = seg_map.T + if sky_level is None: + sky_level = np.nanmedian(backend.to_numpy(image.data)) + + data = backend.to_numpy(image._data) - sky_level centroids = {} - XX, YY = np.meshgrid(np.arange(seg_map.shape[1]), np.arange(seg_map.shape[0])) + II, JJ = np.meshgrid(np.arange(seg_map.shape[0]), np.arange(seg_map.shape[1]), indexing="ij") for index in np.unique(seg_map): if index is None or index in skip_index: continue N = seg_map == index - xcentroid = np.sum(XX[N] * image[N]) / np.sum(image[N]) - ycentroid = np.sum(YY[N] * image[N]) / np.sum(image[N]) - centroids[index] = [xcentroid, ycentroid] + icentroid = np.sum(II[N] * data[N]) / np.sum(data[N]) + jcentroid = np.sum(JJ[N] * data[N]) / np.sum(data[N]) + xcentroid, ycentroid = image.pixel_to_plane( + backend.as_array(icentroid, dtype=config.DTYPE, device=config.DEVICE), + backend.as_array(jcentroid, dtype=config.DTYPE, device=config.DEVICE), + params=(), + ) + centroids[index] = [xcentroid.item(), ycentroid.item()] return centroids def PA_from_segmentation_map( seg_map: Union[np.ndarray, str], - image: Union[np.ndarray, str], + image: "Image", centroids=None, + sky_level: Optional[float] = None, hdul_index_seg: int = 0, - hdul_index_img: int = 0, skip_index: tuple = (0,), - north=np.pi / 2, + softening: float = 1e-3, ): seg_map = _select_img(seg_map, hdul_index_seg) - image = _select_img(image, hdul_index_img) + + # reverse to match numpy indexing + seg_map = seg_map.T + if sky_level is None: + sky_level = np.nanmedian(backend.to_numpy(image.data)) + + data = backend.to_numpy(image._data) - sky_level if centroids is None: centroids = centroids_from_segmentation_map( seg_map=seg_map, image=image, skip_index=skip_index ) - XX, YY = np.meshgrid(np.arange(image.shape[1]), np.arange(image.shape[0])) - + x, y = image.coordinate_center_meshgrid() + x = backend.to_numpy(x) + y = backend.to_numpy(y) PAs = {} for index in np.unique(seg_map): if index is None or index in skip_index: continue N = seg_map == index - PA = ( - Angle_COM_PA(image[N], XX[N] - centroids[index][0], YY[N] - centroids[index][1]) + north - ) - PAs[index] = PA + xx = x[N] - centroids[index][0] + yy = y[N] - centroids[index][1] + mu20 = np.median(data[N] * np.abs(xx)) + mu02 = np.median(data[N] * np.abs(yy)) + mu11 = np.median(data[N] * xx * yy / np.sqrt(np.abs(xx * yy) + softening**2)) + M = np.array([[mu20, mu11], [mu11, mu02]]) + if np.any(np.iscomplex(M)) or np.any(~np.isfinite(M)): + PAs[index] = np.pi / 2 + else: + PAs[index] = (0.5 * np.arctan2(2 * mu11, mu20 - mu02) - np.pi / 2) % np.pi return PAs def q_from_segmentation_map( seg_map: Union[np.ndarray, str], - image: Union[np.ndarray, str], + image: "Image", centroids=None, - PAs=None, + sky_level: Optional[float] = None, hdul_index_seg: int = 0, - hdul_index_img: int = 0, skip_index: tuple = (0,), - north=np.pi / 2, + softening: float = 1e-3, ): seg_map = _select_img(seg_map, hdul_index_seg) - image = _select_img(image, hdul_index_img) + + # reverse to match numpy indexing + seg_map = seg_map.T + + if sky_level is None: + sky_level = np.nanmedian(backend.to_numpy(image.data)) + + data = backend.to_numpy(image._data) - sky_level if centroids is None: centroids = centroids_from_segmentation_map( seg_map=seg_map, image=image, skip_index=skip_index ) - if PAs is None: - PAs = PA_from_segmentation_map( - seg_map=seg_map, image=image, centroids=centroids, skip_index=skip_index - ) - - XX, YY = np.meshgrid(np.arange(image.shape[1]), np.arange(image.shape[0])) + x, y = image.coordinate_center_meshgrid() + x = backend.to_numpy(x) + y = backend.to_numpy(y) qs = {} for index in np.unique(seg_map): if index is None or index in skip_index: continue N = seg_map == index - qs[index] = axis_ratio_com( - image[N], PAs[index] + north, XX[N] - centroids[index][0], YY[N] - centroids[index][1] - ) + xx = x[N] - centroids[index][0] + yy = y[N] - centroids[index][1] + mu20 = np.median(data[N] * np.abs(xx)) + mu02 = np.median(data[N] * np.abs(yy)) + mu11 = np.median(data[N] * xx * yy / np.sqrt(np.abs(xx * yy) + softening**2)) + M = np.array([[mu20, mu11], [mu11, mu02]]) + if np.any(np.iscomplex(M)) or np.any(~np.isfinite(M)): + qs[index] = 0.7 + else: + l = np.abs(np.sort(np.linalg.eigvals(M))) + qs[index] = np.clip(np.sqrt(l[0] / l[1]), 0.1, 0.9) return qs @@ -151,7 +179,7 @@ def windows_from_segmentation_map(seg_map, hdul_index=0, skip_index=(0,)): boxes according to given factors and returns the coordinates. each window is formatted as a list of lists with: - window = [[xmin,xmax],[ymin,ymax]] + window = [[xmin,ymin],[xmax,ymax]] expand_scale changes the base window by the given factor. expand_border is added afterwards on all sides (so an @@ -159,56 +187,54 @@ def windows_from_segmentation_map(seg_map, hdul_index=0, skip_index=(0,)): """ - if isinstance(seg_map, str): - if seg_map.endswith(".fits"): - hdul = fits.open(seg_map) - seg_map = hdul[hdul_index].data - elif seg_map.endswith(".npy"): - seg_map = np.load(seg_map) - else: - raise ValueError(f"unrecognized file type, should be one of: fits, npy\n{seg_map}") + seg_map = _select_img(seg_map, hdul_index) + + seg_map = seg_map.T windows = {} for index in np.unique(seg_map): if index is None or index in skip_index: continue - Yid, Xid = np.where(seg_map == index) + Iid, Jid = np.where(seg_map == index) # Get window from segmap - windows[index] = [[np.min(Xid), np.max(Xid)], [np.min(Yid), np.max(Yid)]] + windows[index] = [[np.min(Iid), np.min(Jid)], [np.max(Iid), np.max(Jid)]] return windows -def scale_windows(windows, image_shape=None, expand_scale=1.0, expand_border=0.0): +def scale_windows(windows, image: "Image" = None, expand_scale=1.0, expand_border=0.0): new_windows = {} for index in list(windows.keys()): new_window = deepcopy(windows[index]) # Get center and shape of the window center = ( - (new_window[0][0] + new_window[0][1]) / 2, - (new_window[1][0] + new_window[1][1]) / 2, + (new_window[0][0] + new_window[1][0]) / 2, + (new_window[0][1] + new_window[1][1]) / 2, ) shape = ( - new_window[0][1] - new_window[0][0], - new_window[1][1] - new_window[1][0], + new_window[1][0] - new_window[0][0], + new_window[1][1] - new_window[0][1], ) # Update the window with any expansion coefficients new_window = [ [ int(center[0] - expand_scale * shape[0] / 2 - expand_border), - int(center[0] + expand_scale * shape[0] / 2 + expand_border), + int(center[1] - expand_scale * shape[1] / 2 - expand_border), ], [ - int(center[1] - expand_scale * shape[1] / 2 - expand_border), + int(center[0] + expand_scale * shape[0] / 2 + expand_border), int(center[1] + expand_scale * shape[1] / 2 + expand_border), ], ] # Ensure the window does not exceed the borders of the image - if image_shape is not None: + if image is not None: new_window = [ - [max(0, new_window[0][0]), min(image_shape[1], new_window[0][1])], - [max(0, new_window[1][0]), min(image_shape[0], new_window[1][1])], + [max(0, new_window[0][0]), max(0, new_window[0][1])], + [ + min(image._data.shape[0], new_window[1][0]), + min(image._data.shape[1], new_window[1][1]), + ], ] new_windows[index] = new_window return new_windows @@ -216,34 +242,34 @@ def scale_windows(windows, image_shape=None, expand_scale=1.0, expand_border=0.0 def filter_windows( windows, - min_size=None, - max_size=None, - min_area=None, - max_area=None, - min_flux=None, - max_flux=None, - image=None, + min_size: Optional[float] = None, + max_size: Optional[float] = None, + min_area: Optional[float] = None, + max_area: Optional[float] = None, + min_flux: Optional[float] = None, + max_flux: Optional[float] = None, + image: "Image" = None, ): """ Filter a set of windows based on a set of criteria. - Parameters - ---------- - min_size: minimum size of the window in pixels - max_size: maximum size of the window in pixels - min_area: minimum area of the window in pixels - max_area: maximum area of the window in pixels - min_flux: minimum flux of the window in ADU - max_flux: maximum flux of the window in ADU - image: the image from which the flux is calculated for min_flux and max_flux + **Args:** + - `windows`: A dictionary of windows to filter. Each window is formatted as a list of lists with: window = [[xmin,ymin],[xmax,ymax]] + - `min_size`: minimum size of the window in pixels + - `max_size`: maximum size of the window in pixels + - `min_area`: minimum area of the window in pixels + - `max_area`: maximum area of the window in pixels + - `min_flux`: minimum flux of the window in ADU + - `max_flux`: maximum flux of the window in ADU + - `image`: the image from which the flux is calculated for min_flux and max_flux """ new_windows = {} for w in list(windows.keys()): if min_size is not None: if ( min( - windows[w][0][1] - windows[w][0][0], - windows[w][1][1] - windows[w][1][0], + windows[w][1][0] - windows[w][0][0], + windows[w][1][1] - windows[w][0][1], ) < min_size ): @@ -251,28 +277,28 @@ def filter_windows( if max_size is not None: if ( max( - windows[w][0][1] - windows[w][0][0], - windows[w][1][1] - windows[w][1][0], + windows[w][1][0] - windows[w][0][0], + windows[w][1][1] - windows[w][0][1], ) > max_size ): continue if min_area is not None: if ( - (windows[w][0][1] - windows[w][0][0]) * (windows[w][1][1] - windows[w][1][0]) + (windows[w][1][0] - windows[w][0][0]) * (windows[w][1][1] - windows[w][0][1]) ) < min_area: continue if max_area is not None: if ( - (windows[w][0][1] - windows[w][0][0]) * (windows[w][1][1] - windows[w][1][0]) + (windows[w][1][0] - windows[w][0][0]) * (windows[w][1][1] - windows[w][0][1]) ) > max_area: continue if min_flux is not None: if ( np.sum( - image[ - windows[w][1][0] : windows[w][1][1], - windows[w][0][0] : windows[w][0][1], + backend.to_numpy(image._data)[ + windows[w][0][0] : windows[w][1][0], + windows[w][0][1] : windows[w][1][1], ] ) < min_flux @@ -281,9 +307,9 @@ def filter_windows( if max_flux is not None: if ( np.sum( - image[ - windows[w][1][0] : windows[w][1][1], - windows[w][0][0] : windows[w][0][1], + backend.to_numpy(image._data)[ + windows[w][0][0] : windows[w][1][0], + windows[w][0][1] : windows[w][1][1], ] ) > max_flux @@ -299,44 +325,35 @@ def transfer_windows(windows, base_image, new_image): for the relative adjustments in origin, pixelscale, and rotation between the two images. - Parameters - ---------- - windows : dict - A dictionary of windows to be transferred. Each window is formatted as a list of lists with: - window = [[xmin,xmax],[ymin,ymax]] - base_image : Image - The image object from which the windows are being transferred. - new_image : Image - The image object to which the windows are being transferred. + **Args:** + - `windows`: A dictionary of windows to be transferred. Each window is formatted as a list of lists with: window = [[xmin,ymin],[xmax,ymax]] + - `base_image`: The image object from which the windows are being transferred. + - `new_image`: The image object to which the windows are being transferred. """ new_windows = {} for w in list(windows.keys()): - bottom_corner = np.clip( - np.floor( - new_image.plane_to_pixel( - base_image.pixel_to_plane(torch.tensor([windows[w][0][0], windows[w][1][0]])) - ) - .detach() - .cpu() - .numpy() - ), - a_min=0, - a_max=np.array(new_image.shape) - 1, - ) - top_corner = np.clip( - np.ceil( - new_image.plane_to_pixel( - base_image.pixel_to_plane(torch.tensor([windows[w][0][1], windows[w][1][1]])) - ) - .detach() - .cpu() - .numpy() - ), - a_min=0, - a_max=np.array(new_image.shape) - 1, - ) + four_corners_base = backend.as_array( + [ + windows[w][0], + windows[w][1], + [windows[w][0][0], windows[w][1][1]], + [windows[w][1][0], windows[w][0][1]], + ], + dtype=base_image.data.dtype, + device=base_image.data.device, + ) # (4,2) + four_corners_new = backend.to_numpy( + backend.stack( + new_image.plane_to_pixel(*base_image.pixel_to_plane(*four_corners_base.T)), dim=-1 + ) + ) # (4,2) + + bottom_corner = np.floor(np.min(four_corners_new, axis=0)).astype(int) + bottom_corner = np.clip(bottom_corner, 0, np.array(new_image._data.shape)) + top_corner = np.ceil(np.max(four_corners_new, axis=0)).astype(int) + top_corner = np.clip(top_corner, 0, np.array(new_image._data.shape)) new_windows[w] = [ - [bottom_corner[0], top_corner[0]], - [bottom_corner[1], top_corner[1]], + [int(bottom_corner[0]), int(bottom_corner[1])], + [int(top_corner[0]), int(top_corner[1])], ] return new_windows diff --git a/astrophot/utils/initialize/variance.py b/astrophot/utils/initialize/variance.py index 9b8b65e9..68f881bd 100644 --- a/astrophot/utils/initialize/variance.py +++ b/astrophot/utils/initialize/variance.py @@ -2,16 +2,15 @@ from scipy.ndimage import gaussian_filter from scipy.stats import binned_statistic import torch -from ...errors import InvalidData -import matplotlib.pyplot as plt +from ...backend_obj import backend, ArrayLike def auto_variance(data, mask=None): - if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() - if isinstance(mask, torch.Tensor): - mask = mask.detach().cpu().numpy() + if isinstance(data, backend.array_type): + data = backend.to_numpy(data) + if isinstance(mask, backend.array_type): + mask = backend.to_numpy(mask) if mask is None: mask = np.zeros(data.shape, dtype=int) @@ -46,9 +45,7 @@ def auto_variance(data, mask=None): # Check if the variance is increasing with flux if p[0] < 0: - raise InvalidData( - "Variance appears to be decreasing with flux! Cannot accurately estimate variance." - ) + return np.ones_like(data) * var # Compute the approximate variance map variance = np.clip(p[0] * data + p[1], np.min(std) ** 2, None) variance[np.logical_not(mask)] = np.inf diff --git a/astrophot/utils/integration.py b/astrophot/utils/integration.py new file mode 100644 index 00000000..e765a3c8 --- /dev/null +++ b/astrophot/utils/integration.py @@ -0,0 +1,36 @@ +from functools import lru_cache + +from scipy.special import roots_legendre +import torch +from ..backend_obj import backend + +__all__ = ("quad_table",) + + +@lru_cache(maxsize=32) +def quad_table(order, dtype, device): + """ + Generate a meshgrid for quadrature points using Legendre-Gauss quadrature. + + Parameters + ---------- + n : int + The number of quadrature points in each dimension. + dtype : torch.dtype + The desired data type of the tensor. + device : torch.device + The device on which to create the tensor. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + The generated meshgrid as a tuple of Tensors. + """ + abscissa, weights = roots_legendre(order) + + w = backend.as_array(weights, dtype=dtype, device=device) + a = backend.as_array(abscissa, dtype=dtype, device=device) / 2.0 + di, dj = backend.meshgrid(a, a, indexing="ij") + + w = backend.outer(w, w) / 4.0 + return di, dj, w diff --git a/astrophot/utils/interpolate.py b/astrophot/utils/interpolate.py index bbfe8335..b142e66d 100644 --- a/astrophot/utils/interpolate.py +++ b/astrophot/utils/interpolate.py @@ -1,355 +1,72 @@ -from functools import lru_cache - -import numpy as np import torch -from astropy.convolution import convolve_fft - -from .operations import fft_convolve_torch - - -def _h_poly(t): - """Helper function to compute the 'h' polynomial matrix used in the - cubic spline. - - Args: - t (Tensor): A 1D tensor representing the normalized x values. - - Returns: - Tensor: A 2D tensor of size (4, len(t)) representing the 'h' polynomial matrix. - - """ - - tt = t[None, :] ** (torch.arange(4, device=t.device)[:, None]) - A = torch.tensor( - [[1, 0, -3, 2], [0, 1, -2, 1], [0, 0, 3, -2], [0, 0, -1, 1]], - dtype=t.dtype, - device=t.device, - ) - return A @ tt - - -def cubic_spline_torch( - x: torch.Tensor, y: torch.Tensor, xs: torch.Tensor, extend: str = "const" -) -> torch.Tensor: - """Compute the 1D cubic spline interpolation for the given data points - using PyTorch. - - Args: - x (Tensor): A 1D tensor representing the x-coordinates of the known data points. - y (Tensor): A 1D tensor representing the y-coordinates of the known data points. - xs (Tensor): A 1D tensor representing the x-coordinates of the positions where - the cubic spline function should be evaluated. - extend (str, optional): The method for handling extrapolation, either "const" or "linear". - Default is "const". - "const": Use the value of the last known data point for extrapolation. - "linear": Use linear extrapolation based on the last two known data points. - - Returns: - Tensor: A 1D tensor representing the interpolated values at the specified positions (xs). - - """ - m = (y[1:] - y[:-1]) / (x[1:] - x[:-1]) - m = torch.cat([m[[0]], (m[1:] + m[:-1]) / 2, m[[-1]]]) - idxs = torch.searchsorted(x[:-1], xs) - 1 - dx = x[idxs + 1] - x[idxs] - hh = _h_poly((xs - x[idxs]) / dx) - ret = hh[0] * y[idxs] + hh[1] * m[idxs] * dx + hh[2] * y[idxs + 1] + hh[3] * m[idxs + 1] * dx - if extend == "const": - ret[xs > x[-1]] = y[-1] - elif extend == "linear": - indices = xs > x[-1] - ret[indices] = y[-1] + (xs[indices] - x[-1]) * (y[-1] - y[-2]) / (x[-1] - x[-2]) - return ret - - -def interpolate_bicubic(img, X, Y): - """ - wrapper for scipy bivariate spline interpolation - """ - f_interp = RectBivariateSpline( - np.arange(dat.shape[0], dtype=np.float32), - np.arange(dat.shape[1], dtype=np.float32), - dat, - ) - return f_interp(Y, X, grid=False) - - -def Lanczos_kernel_np(dx, dy, scale): - """convolution kernel for shifting all pixels in a grid by some - sub-pixel length. - - """ - xx = np.arange(-scale, scale + 1) - dx - if dx < 0: - xx *= -1 - Lx = np.sinc(xx) * np.sinc(xx / scale) - if dx > 0: - Lx[0] = 0 - else: - Lx[-1] = 0 - - yy = np.arange(-scale, scale + 1) - dy - if dy < 0: - yy *= -1 - Ly = np.sinc(yy) * np.sinc(yy / scale) - if dx > 0: - Ly[0] = 0 - else: - Ly[-1] = 0 - - LXX, LYY = np.meshgrid(Lx, Ly, indexing="xy") - LL = LXX * LYY - w = np.sum(LL) - LL /= w - # plt.imshow(LL.detach().numpy(), origin = "lower") - # plt.show() - return LL - - -def Lanczos_kernel(dx, dy, scale): - """Kernel function for Lanczos interpolation, defines the - interpolation behavior between pixels. - - """ - xx = np.arange(-scale + 1, scale + 1) + dx - yy = np.arange(-scale + 1, scale + 1) + dy - Lx = np.sinc(xx) * np.sinc(xx / scale) - Ly = np.sinc(yy) * np.sinc(yy / scale) - LXX, LYY = np.meshgrid(Lx, Ly) - LL = LXX * LYY - w = np.sum(LL) - LL /= w - return LL - - -def point_Lanczos(I, X, Y, scale): - """ - Apply Lanczos interpolation to evaluate a single point. - """ - ranges = [ - [int(np.floor(X) - scale + 1), int(np.floor(X) + scale + 1)], - [int(np.floor(Y) - scale + 1), int(np.floor(Y) + scale + 1)], - ] - LL = Lanczos_kernel(np.floor(X) - X, np.floor(Y) - Y, scale) - LL = LL[ - max(0, -ranges[1][0]) : LL.shape[0] + min(0, I.shape[0] - ranges[1][1]), - max(0, -ranges[0][0]) : LL.shape[1] + min(0, I.shape[1] - ranges[0][1]), - ] - F = I[ - max(0, ranges[1][0]) : min(I.shape[0], ranges[1][1]), - max(0, ranges[0][0]) : min(I.shape[1], ranges[0][1]), - ] - return np.sum(F * LL) - - -def _shift_Lanczos_kernel_torch(dx, dy, scale, dtype, device): - """convolution kernel for shifting all pixels in a grid by some - sub-pixel length. - - """ - xsign = 1 - 2 * (dx < 0).to(dtype=torch.int32) # flips the kernel if the shift is negative - xx = xsign * (torch.arange(int(-scale), int(scale + 1), dtype=dtype, device=device) - dx) - Lx = torch.sinc(xx) * torch.sinc(xx / scale) - - ysign = 1 - 2 * (dy < 0).to(dtype=torch.int32) - yy = ysign * (torch.arange(int(-scale), int(scale + 1), dtype=dtype, device=device) - dy) - Ly = torch.sinc(yy) * torch.sinc(yy / scale) - - LXX, LYY = torch.meshgrid(Lx, Ly, indexing="xy") - LL = LXX * LYY - w = torch.sum(LL) - # plt.imshow(LL.detach().numpy(), origin = "lower") - # plt.show() - return LL / w - - -def shift_Lanczos_torch(I, dx, dy, scale, dtype, device, img_prepadded=True): - """Apply Lanczos interpolation to shift by less than a pixel in x and - y. - - """ - LL = _shift_Lanczos_kernel_torch(dx, dy, scale, dtype, device) - ret = fft_convolve_torch(I, LL, img_prepadded=img_prepadded) - return ret - - -def shift_Lanczos_np(I, dx, dy, scale): - """Apply Lanczos interpolation to shift by less than a pixel in x and - y. - - I: the image - dx: amount by which the grid will be moved in the x-axis (the "data" is fixed and the grid moves). Should be a value from (-0.5,0.5) - dy: amount by which the grid will be moved in the y-axis (the "data" is fixed and the grid moves). Should be a value from (-0.5,0.5) - scale: dictates size of the Lanczos kernel. Full kernel size is 2*scale+1 - """ - LL = Lanczos_kernel_np(dx, dy, scale) - return convolve_fft(I, LL, boundary="fill") - - -def interpolate_Lanczos_grid(img, X, Y, scale): - """ - Perform Lanczos interpolation at a grid of points. - https://pixinsight.com/doc/docs/InterpolationAlgorithms/InterpolationAlgorithms.html - """ - - sinc_X = list( - np.sinc(np.arange(-scale + 1, scale + 1) - X[i] + np.floor(X[i])) - * np.sinc((np.arange(-scale + 1, scale + 1) - X[i] + np.floor(X[i])) / scale) - for i in range(len(X)) - ) - sinc_Y = list( - np.sinc(np.arange(-scale + 1, scale + 1) - Y[i] + np.floor(Y[i])) - * np.sinc((np.arange(-scale + 1, scale + 1) - Y[i] + np.floor(Y[i])) / scale) - for i in range(len(Y)) - ) - - # Extract an image which has the required dimensions - use_img = np.take( - np.take( - img, - np.arange(int(np.floor(Y[0]) - step + 1), int(np.floor(Y[-1]) + step + 1)), - 0, - mode="clip", - ), - np.arange(int(np.floor(X[0]) - step + 1), int(np.floor(X[-1]) + step + 1)), - 1, - mode="clip", - ) - - # Create a sliding window view of the image with the dimensions of the lanczos scale grid - # window = np.lib.stride_tricks.sliding_window_view(use_img, (2*scale, 2*scale)) - - # fixme going to need some broadcasting magic - XX = np.ones((2 * scale, 2 * scale)) - res = np.zeros((len(Y), len(X))) - for x, lowx, highx in zip(range(len(X)), np.floor(X) - step + 1, np.floor(X) + step + 1): - for y, lowy, highy in zip(range(len(Y)), np.floor(Y) - step + 1, np.floor(Y) + step + 1): - L = XX * sinc_X[x] * sinc_Y[y].reshape((sinc_Y[y].size, -1)) - res[y, x] = np.sum(use_img[lowy:highy, lowx:highx] * L) / np.sum(L) - return res +import numpy as np +from ..backend_obj import backend, ArrayLike -def interpolate_Lanczos(img, X, Y, scale): - """ - Perform Lanczos interpolation on an image at a series of specified points. - https://pixinsight.com/doc/docs/InterpolationAlgorithms/InterpolationAlgorithms.html - """ - flux = [] +__all__ = ("default_prof", "interp2d") - for i in range(len(X)): - box = [ - [ - max(0, int(round(np.floor(X[i]) - scale + 1))), - min(img.shape[1], int(round(np.floor(X[i]) + scale + 1))), - ], - [ - max(0, int(round(np.floor(Y[i]) - scale + 1))), - min(img.shape[0], int(round(np.floor(Y[i]) + scale + 1))), - ], - ] - chunk = img[box[1][0] : box[1][1], box[0][0] : box[0][1]] - XX = np.ones(chunk.shape) - Lx = ( - np.sinc(np.arange(-scale + 1, scale + 1) - X[i] + np.floor(X[i])) - * np.sinc((np.arange(-scale + 1, scale + 1) - X[i] + np.floor(X[i])) / scale) - )[ - box[0][0] - - int(round(np.floor(X[i]) - scale + 1)) : 2 * scale - + box[0][1] - - int(round(np.floor(X[i]) + scale + 1)) - ] - Ly = ( - np.sinc(np.arange(-scale + 1, scale + 1) - Y[i] + np.floor(Y[i])) - * np.sinc((np.arange(-scale + 1, scale + 1) - Y[i] + np.floor(Y[i])) / scale) - )[ - box[1][0] - - int(round(np.floor(Y[i]) - scale + 1)) : 2 * scale - + box[1][1] - - int(round(np.floor(Y[i]) + scale + 1)) - ] - L = XX * Lx * Ly.reshape((Ly.size, -1)) - w = np.sum(L) - flux.append(np.sum(chunk * L) / w) - return np.array(flux) - -def interp1d_torch(x_in, y_in, x_out): - indices = torch.searchsorted(x_in[:-1], x_out) - 1 - weights = (y_in[1:] - y_in[:-1]) / (x_in[1:] - x_in[:-1]) - return y_in[indices] + weights[indices] * (x_out - x_in[indices]) +def default_prof( + shape: tuple[int, int], pixelscale: float, min_pixels: int = 2, scale: float = 0.2 +) -> np.ndarray: + prof = [0, min_pixels * pixelscale] + imagescale = max(shape) # np.sqrt(np.sum(np.array(shape) ** 2)) + while prof[-1] < (imagescale * pixelscale / 2): + prof.append(prof[-1] + max(min_pixels * pixelscale, prof[-1] * scale)) + return np.array(prof) def interp2d( - im: torch.Tensor, - x: torch.Tensor, - y: torch.Tensor, -) -> torch.Tensor: + im: ArrayLike, + i: ArrayLike, + j: ArrayLike, + padding_mode: str = "zeros", +) -> ArrayLike: """ Interpolates a 2D image at specified coordinates. Similar to `torch.nn.functional.grid_sample` with `align_corners=False`. Args: im (Tensor): A 2D tensor representing the image. - x (Tensor): A tensor of x coordinates (in pixel space) at which to interpolate. - y (Tensor): A tensor of y coordinates (in pixel space) at which to interpolate. + i (Tensor): A tensor of i coordinates (in pixel space) at which to interpolate. + j (Tensor): A tensor of j coordinates (in pixel space) at which to interpolate. Returns: - Tensor: Tensor with the same shape as `x` and `y` containing the interpolated values. + Tensor: Tensor with the same shape as `i` and `j` containing the interpolated values. """ # Convert coordinates to pixel indices h, w = im.shape # reshape for indexing purposes - start_shape = x.shape - x = x.view(-1) - y = y.view(-1) - - x0 = x.floor().long() - y0 = y.floor().long() - x1 = x0 + 1 - y1 = y0 + 1 - x0 = x0.clamp(0, w - 2) - x1 = x1.clamp(1, w - 1) - y0 = y0.clamp(0, h - 2) - y1 = y1.clamp(1, h - 1) - - fa = im[y0, x0] - fb = im[y1, x0] - fc = im[y0, x1] - fd = im[y1, x1] - - wa = (x1 - x) * (y1 - y) - wb = (x1 - x) * (y - y0) - wc = (x - x0) * (y1 - y) - wd = (x - x0) * (y - y0) + start_shape = i.shape + i = i.flatten() + j = j.flatten() + + # valid + valid = (i >= -0.5) & (i <= (h - 0.5)) & (j >= -0.5) & (j <= (w - 0.5)) + + i0 = backend.long(backend.floor(i)) + j0 = backend.long(backend.floor(j)) + i0 = backend.clamp(i0, 0, h - 2) + i1 = i0 + 1 + j0 = backend.clamp(j0, 0, w - 2) + j1 = j0 + 1 + + fa = im[i0, j0] + fb = im[i0, j1] + fc = im[i1, j0] + fd = im[i1, j1] + + wa = (i1 - i) * (j1 - j) + wb = (i1 - i) * (j - j0) + wc = (i - i0) * (j1 - j) + wd = (i - i0) * (j - j0) result = fa * wa + fb * wb + fc * wc + fd * wd - return result.view(*start_shape) - - -@lru_cache(maxsize=32) -def curvature_kernel(dtype, device): - kernel = torch.tensor( - [ - [0.0, 1.0, 0.0], - [1.0, -4, 1.0], - [0.0, 1.0, 0.0], - ], # [[1., -2.0, 1.], [-2.0, 4, -2.0], [1.0, -2.0, 1.0]], - device=device, - dtype=dtype, - ) - return kernel - - -@lru_cache(maxsize=32) -def simpsons_kernel(dtype, device): - kernel = torch.ones(1, 1, 3, 3, dtype=dtype, device=device) - kernel[0, 0, 1, 1] = 16.0 - kernel[0, 0, 1, 0] = 4.0 - kernel[0, 0, 0, 1] = 4.0 - kernel[0, 0, 1, 2] = 4.0 - kernel[0, 0, 2, 1] = 4.0 - kernel = kernel / 36.0 - return kernel + if padding_mode == "zeros": + return (result * valid).reshape(start_shape) + elif padding_mode == "border": + return result.reshape(start_shape) + raise ValueError(f"Unsupported padding mode: {padding_mode}") diff --git a/astrophot/utils/isophote/__init__.py b/astrophot/utils/isophote/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/astrophot/utils/isophote/ellipse.py b/astrophot/utils/isophote/ellipse.py deleted file mode 100644 index 279ab618..00000000 --- a/astrophot/utils/isophote/ellipse.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np - - -def Rscale_Fmodes(theta, modes, Am, Phim): - """Factor to scale radius values given a set of fourier mode - amplitudes. - - """ - return np.exp(sum(Am[m] * np.cos(modes[m] * (theta + Phim[m])) for m in range(len(modes)))) - - -def parametric_Fmodes(theta, modes, Am, Phim): - """determines a number of scaled radius samples with fourier mode - perturbations for a unit circle. - - """ - x = np.cos(theta) - y = np.sin(theta) - Rscale = Rscale_Fmodes(theta, modes, Am, Phim) - return x * Rscale, y * Rscale - - -def Rscale_SuperEllipse(theta, ellip, C=2): - """Scale factor for radius values given a super ellipse coefficient.""" - res = (1 - ellip) / ( - np.abs((1 - ellip) * np.cos(theta)) ** (C) + np.abs(np.sin(theta)) ** (C) - ) ** (1.0 / C) - return res - - -def parametric_SuperEllipse(theta, ellip, C=2): - """determines a number of scaled radius samples with super ellipse - perturbations for a unit circle. - - """ - rs = Rscale_SuperEllipse(theta, ellip, C) - return rs * np.cos(theta), rs * np.sin(theta) diff --git a/astrophot/utils/isophote/extract.py b/astrophot/utils/isophote/extract.py deleted file mode 100644 index 5dbcf2ee..00000000 --- a/astrophot/utils/isophote/extract.py +++ /dev/null @@ -1,249 +0,0 @@ -import numpy as np -import logging -from scipy.stats import iqr - -from .ellipse import parametric_SuperEllipse, Rscale_SuperEllipse -from ..conversions.coordinates import Rotate_Cartesian_np -from ..interpolate import interpolate_Lanczos - - -def Sigma_Clip_Upper(v, iterations=10, nsigma=5): - """ - Perform sigma clipping on the "v" array. Each iteration involves - computing the median and 16-84 range, these are used to clip beyond - "nsigma" number of sigma above the median. This is repeated for - "iterations" number of iterations, or until convergence if None. - """ - - v2 = np.sort(v) - i = 0 - old_lim = 0 - lim = np.inf - while i < iterations and old_lim != lim: - med = np.median(v2[v2 < lim]) - rng = iqr(v2[v2 < lim], rng=[16, 84]) / 2 - old_lim = lim - lim = med + rng * nsigma - i += 1 - return lim - - -def _iso_between( - IMG, - sma_low, - sma_high, - PARAMS, - c, - more=False, - mask=None, - sigmaclip=False, - sclip_iterations=10, - sclip_nsigma=5, -): - if "m" not in PARAMS: - PARAMS["m"] = None - if "C" not in PARAMS: - PARAMS["C"] = None - Rlim = sma_high * ( - 1.0 - if PARAMS["m"] is None - else np.exp(sum(np.abs(PARAMS["Am"][m]) for m in range(len(PARAMS["m"])))) - ) - ranges = [ - [max(0, int(c["x"] - Rlim - 2)), min(IMG.shape[1], int(c["x"] + Rlim + 2))], - [max(0, int(c["y"] - Rlim - 2)), min(IMG.shape[0], int(c["y"] + Rlim + 2))], - ] - XX, YY = np.meshgrid( - np.arange(ranges[0][1] - ranges[0][0], dtype=float) - c["x"] + float(ranges[0][0]), - np.arange(ranges[1][1] - ranges[1][0], dtype=float) - c["y"] + float(ranges[1][0]), - ) - theta = np.arctan(YY / XX) + np.pi * (XX < 0) - RR = np.sqrt(XX**2 + YY**2) - Fmode_Rscale = ( - 1.0 - if PARAMS["m"] is None - else Rscale_Fmodes(theta - PARAMS["pa"], PARAMS["m"], PARAMS["Am"], PARAMS["Phim"]) - ) - SuperEllipse_Rscale = Rscale_SuperEllipse( - theta - PARAMS["pa"], PARAMS["ellip"], 2 if PARAMS["C"] is None else PARAMS["C"] - ) - RR /= SuperEllipse_Rscale * Fmode_Rscale - rselect = np.logical_and(RR < sma_high, RR > sma_low) - fluxes = IMG[ranges[1][0] : ranges[1][1], ranges[0][0] : ranges[0][1]][rselect] - CHOOSE = None - if mask is not None and sma_high > 5: - CHOOSE = np.logical_not( - mask[ranges[1][0] : ranges[1][1], ranges[0][0] : ranges[0][1]][rselect] - ) - # Perform sigma clipping if requested - if sigmaclip: - sclim = Sigma_Clip_Upper(fluxes, sclip_iterations, sclip_nsigma) - if CHOOSE is None: - CHOOSE = fluxes < sclim - else: - CHOOSE = np.logical_or(CHOOSE, fluxes < sclim) - if CHOOSE is not None and np.sum(CHOOSE) < 5: - logging.warning( - "Entire Isophote is Masked! R_l: %.3f, R_h: %.3f, PA: %.3f, ellip: %.3f" - % (sma_low, sma_high, PARAMS["pa"] * 180 / np.pi, PARAMS["ellip"]) - ) - CHOOSE = np.ones(CHOOSE.shape).astype(bool) - if CHOOSE is not None: - countmasked = np.sum(np.logical_not(CHOOSE)) - else: - countmasked = 0 - if more: - if CHOOSE is not None and sma_high > 5: - return fluxes[CHOOSE], theta[rselect][CHOOSE], countmasked - else: - return fluxes, theta[rselect], countmasked - else: - if CHOOSE is not None and sma_high > 5: - return fluxes[CHOOSE] - else: - return fluxes - - -def _iso_extract( - IMG, - sma, - PARAMS, - c, - more=False, - minN=None, - mask=None, - interp_mask=False, - rad_interp=30, - interp_method="lanczos", - interp_window=5, - sigmaclip=False, - sclip_iterations=10, - sclip_nsigma=5, -): - """ - Internal, basic function for extracting the pixel fluxes along an isophote - """ - if "m" not in PARAMS: - PARAMS["m"] = None - if "C" not in PARAMS: - PARAMS["C"] = None - N = max(15, int(0.9 * 2 * np.pi * sma)) - if minN is not None: - N = max(minN, N) - # points along ellipse to evaluate - theta = np.linspace(0, 2 * np.pi * (1.0 - 1.0 / N), N) - theta = np.arctan(PARAMS["q"] * np.tan(theta)) + np.pi * (np.cos(theta) < 0) - Fmode_Rscale = ( - 1.0 - if PARAMS["m"] is None - else Rscale_Fmodes(theta, PARAMS["m"], PARAMS["Am"], PARAMS["Phim"]) - ) - R = sma * Fmode_Rscale - # Define ellipse - X, Y = parametric_SuperEllipse( - theta, 1.0 - PARAMS["q"], 2 if PARAMS["C"] is None else PARAMS["C"] - ) - X, Y = R * X, R * Y - # rotate ellipse by PA - X, Y = Rotate_Cartesian_np(PARAMS["pa"], X, Y) - theta = (theta + PARAMS["pa"]) % (2 * np.pi) - # shift center - X, Y = X + c["x"], Y + c["y"] - - # Reject samples from outside the image - BORDER = np.logical_and( - np.logical_and(X >= 0, X < (IMG.shape[1] - 1)), - np.logical_and(Y >= 0, Y < (IMG.shape[0] - 1)), - ) - if not np.all(BORDER): - X = X[BORDER] - Y = Y[BORDER] - theta = theta[BORDER] - - Rlim = np.max(R) - if Rlim < rad_interp: - box = [ - [max(0, int(c["x"] - Rlim - 5)), min(IMG.shape[1], int(c["x"] + Rlim + 5))], - [max(0, int(c["y"] - Rlim - 5)), min(IMG.shape[0], int(c["y"] + Rlim + 5))], - ] - if interp_method == "bicubic": - flux = interpolate_bicubic( - IMG[box[1][0] : box[1][1], box[0][0] : box[0][1]], - X - box[0][0], - Y - box[1][0], - ) - elif interp_method == "lanczos": - flux = interpolate_Lanczos(IMG, X, Y, interp_window) - else: - raise ValueError( - "Unknown interpolate method %s. Should be one of lanczos or bicubic" % interp_method - ) - else: - # round to integers and sample pixels values - flux = IMG[np.rint(Y).astype(np.int32), np.rint(X).astype(np.int32)] - # CHOOSE holds boolean array for which flux values to keep, initialized as None for no clipping - CHOOSE = None - # Mask pixels if a mask is given - if mask is not None: - CHOOSE = np.logical_not(mask[np.rint(Y).astype(np.int32), np.rint(X).astype(np.int32)]) - # Perform sigma clipping if requested - if sigmaclip and len(flux) > 30: - sclim = Sigma_Clip_Upper(flux, sclip_iterations, sclip_nsigma) - if CHOOSE is None: - CHOOSE = flux < sclim - else: - CHOOSE = np.logical_or(CHOOSE, flux < sclim) - # Dont clip pixels if that removes all of the pixels - countmasked = 0 - if CHOOSE is not None and np.sum(CHOOSE) <= 0: - logging.warning( - "Entire Isophote was Masked! R: %.3f, PA: %.3f, q: %.3f" - % (sma, PARAMS["pa"] * 180 / np.pi, PARAMS["q"]) - ) - # Interpolate clipped flux values if requested - elif CHOOSE is not None and interp_mask: - flux[np.logical_not(CHOOSE)] = np.interp( - theta[np.logical_not(CHOOSE)], theta[CHOOSE], flux[CHOOSE], period=2 * np.pi - ) - # simply remove all clipped pixels if user doesn't request another option - elif CHOOSE is not None: - flux = flux[CHOOSE] - theta = theta[CHOOSE] - countmasked = np.sum(np.logical_not(CHOOSE)) - - # Return just the flux values, or flux and angle values - if more: - return flux, theta, countmasked - else: - return flux - - -def _iso_line(IMG, length, width, pa, c, more=False): - start = np.array([c["x"], c["y"]]) - end = start + length * np.array([np.cos(pa), np.sin(pa)]) - - ranges = [ - [ - max(0, int(min(start[0], end[0]) - 2)), - min(IMG.shape[1], int(max(start[0], end[0]) + 2)), - ], - [ - max(0, int(min(start[1], end[1]) - 2)), - min(IMG.shape[0], int(max(start[1], end[1]) + 2)), - ], - ] - XX, YY = np.meshgrid( - np.arange(ranges[0][1] - ranges[0][0], dtype=float), - np.arange(ranges[1][1] - ranges[1][0], dtype=float), - ) - XX -= c["x"] - float(ranges[0][0]) - YY -= c["y"] - float(ranges[1][0]) - XX, YY = (XX * np.cos(-pa) - YY * np.sin(-pa), XX * np.sin(-pa) + YY * np.cos(-pa)) - - lselect = np.logical_and.reduce((XX >= -0.5, XX <= length, np.abs(YY) <= (width / 2))) - flux = IMG[ranges[1][0] : ranges[1][1], ranges[0][0] : ranges[0][1]][lselect] - - if more: - return flux, XX[lselect], YY[lselect] - else: - return flux, XX[lselect] diff --git a/astrophot/utils/isophote/integrate.py b/astrophot/utils/isophote/integrate.py deleted file mode 100644 index eb3490b1..00000000 --- a/astrophot/utils/isophote/integrate.py +++ /dev/null @@ -1,210 +0,0 @@ -import numpy as np - - -def fluxdens_to_fluxsum(R, I, axisratio): - """ - Integrate a flux density profile - - R: semi-major axis length (arcsec) - I: flux density (flux/arcsec^2) - axisratio: b/a profile - """ - - S = np.zeros(len(R)) - S[0] = I[0] * np.pi * axisratio[0] * (R[0] ** 2) - for i in range(1, len(R)): - S[i] = trapz(2 * np.pi * I[: i + 1] * R[: i + 1] * axisratio[: i + 1], R[: i + 1]) + S[0] - return S - - -def fluxdens_to_fluxsum_errorprop( - R, I, IE, axisratio, axisratioE=None, N=100, symmetric_error=True -): - """ - Integrate a flux density profile - - R: semi-major axis length (arcsec) - I: flux density (flux/arcsec^2) - axisratio: b/a profile - """ - if axisratioE is None: - axisratioE = np.zeros(len(R)) - - # Create container for the monte-carlo iterations - sum_results = np.zeros((N, len(R))) - 99.999 - I_CHOOSE = np.logical_and(np.isfinite(I), I > 0) - if np.sum(I_CHOOSE) < 5: - return (None, None) if symmetric_error else (None, None, None) - sum_results[0][I_CHOOSE] = fluxdens_to_fluxsum(R[I_CHOOSE], I[I_CHOOSE], axisratio[I_CHOOSE]) - for i in range(1, N): - # Randomly sampled SB profile - tempI = np.random.normal(loc=I, scale=np.abs(IE)) - # Randomly sampled axis ratio profile - tempq = np.clip( - np.random.normal(loc=axisratio, scale=np.abs(axisratioE)), - a_min=1e-3, - a_max=1 - 1e-3, - ) - # Compute COG with sampled data - sum_results[i][I_CHOOSE] = fluxdens_to_fluxsum( - R[I_CHOOSE], tempI[I_CHOOSE], tempq[I_CHOOSE] - ) - - # Condense monte-carlo evaluations into profile and uncertainty envelope - sum_lower = sum_results[0] - np.quantile(sum_results, 0.317310507863 / 2, axis=0) - sum_upper = np.quantile(sum_results, 1.0 - 0.317310507863 / 2, axis=0) - sum_results[0] - - # Return requested uncertainty format - if symmetric_error: - return sum_results[0], np.abs(sum_lower + sum_upper) / 2 - else: - return sum_results[0], sum_lower, sum_upper - - -def _Fmode_integrand(t, parameters): - fsum = sum( - parameters["Am"][m] * np.cos(parameters["m"][m] * (t + parameters["Phim"][m])) - for m in range(len(parameters["m"])) - ) - dfsum = sum( - parameters["m"][m] - * parameters["Am"][m] - * np.sin(parameters["m"][m] * (t + parameters["Phim"][m])) - for m in range(len(parameters["m"])) - ) - return (np.sin(t) ** 2) * np.exp(2 * fsum) + np.sin(t) * np.cos(t) * np.exp(fsum) * dfsum - - -def Fmode_Areas(R, parameters): - A = [] - for i in range(len(R)): - A.append((R[i] ** 2) * quad(_Fmode_integrand, 0, 2 * np.pi, args=(parameters[i],))[0]) - return np.array(A) - - -def Fmode_fluxdens_to_fluxsum(R, I, parameters, A=None): - """ - Integrate a flux density profile, with isophotes including Fourier perturbations. - - Arguments - --------- - R: arcsec - semi-major axis length - - I: flux/arcsec^2 - flux density - - parameters: list of dictionaries - list of dictionary of isophote shape parameters for each radius. - formatted as - - .. code-block:: python - - { - "ellip": "ellipticity", - "m": "list of modes used", - "Am": "list of mode powers", - "Phim": "list of mode phases", - } - - entries for each radius. - """ - if all(parameters[p]["m"] is None for p in range(len(parameters))): - return fluxdens_to_fluxsum( - R, - I, - 1.0 - np.array(list(parameters[p]["ellip"] for p in range(len(parameters)))), - ) - - S = np.zeros(len(R)) - if A is None: - A = Fmode_Areas(R, parameters) - # update the Area calculation to be scaled by the ellipticity - Aq = A * np.array(list((1 - parameters[i]["ellip"]) for i in range(len(R)))) - S[0] = I[0] * Aq[0] - Adiff = np.array([Aq[0]] + list(Aq[1:] - Aq[:-1])) - for i in range(1, len(R)): - S[i] = trapz(I[: i + 1] * Adiff[: i + 1], R[: i + 1]) + S[0] - return S - - -def Fmode_fluxdens_to_fluxsum_errorprop(R, I, IE, parameters, N=100, symmetric_error=True): - """ - Integrate a flux density profile, with isophotes including Fourier perturbations. - - Arguments - --------- - R: arcsec - semi-major axis length - - I: flux/arcsec^2 - flux density - - parameters: list of dictionaries - list of dictionary of isophote shape parameters for each radius. - formatted as - - .. code-block:: python - - { - "ellip": "ellipticity", - "m": "list of modes used", - "Am": "list of mode powers", - "Phim": "list of mode phases", - } - - entries for each radius. - """ - - for i in range(len(R)): - if "ellip err" not in parameters[i]: - parameters[i]["ellip err"] = np.zeros(len(R)) - if all(parameters[p]["m"] is None for p in range(len(parameters))): - return fluxdens_to_fluxsum_errorprop( - R, - I, - IE, - 1.0 - np.array(list(parameters[p]["ellip"] for p in range(len(parameters)))), - np.array(list(parameters[p]["ellip err"] for p in range(len(parameters)))), - N=N, - symmetric_error=symmetric_error, - ) - - # Create container for the monte-carlo iterations - sum_results = np.zeros((N, len(R))) - 99.999 - I_CHOOSE = np.logical_and(np.isfinite(I), I > 0) - if np.sum(I_CHOOSE) < 5: - return (None, None) if symmetric_error else (None, None, None) - cut_parameters = list(compress(parameters, I_CHOOSE)) - A = Fmode_Areas(R[I_CHOOSE], cut_parameters) - sum_results[0][I_CHOOSE] = Fmode_fluxdens_to_fluxsum( - R[I_CHOOSE], I[I_CHOOSE], cut_parameters, A - ) - for i in range(1, N): - # Randomly sampled SB profile - tempI = np.random.normal(loc=I, scale=np.abs(IE)) - # Randomly sampled axis ratio profile - temp_parameters = deepcopy(cut_parameters) - for p in range(len(cut_parameters)): - temp_parameters[p]["ellip"] = np.clip( - np.random.normal( - loc=cut_parameters[p]["ellip"], - scale=np.abs(cut_parameters[p]["ellip err"]), - ), - a_min=1e-3, - a_max=1 - 1e-3, - ) - # Compute COG with sampled data - sum_results[i][I_CHOOSE] = Fmode_fluxdens_to_fluxsum( - R[I_CHOOSE], tempI[I_CHOOSE], temp_parameters, A - ) - - # Condense monte-carlo evaluations into profile and uncertainty envelope - sum_lower = sum_results[0] - np.quantile(sum_results, 0.317310507863 / 2, axis=0) - sum_upper = np.quantile(sum_results, 1.0 - 0.317310507863 / 2, axis=0) - sum_results[0] - - # Return requested uncertainty format - if symmetric_error: - return sum_results[0], np.abs(sum_lower + sum_upper) / 2 - else: - return sum_results[0], sum_lower, sum_upper diff --git a/astrophot/utils/operations.py b/astrophot/utils/operations.py deleted file mode 100644 index 9f403726..00000000 --- a/astrophot/utils/operations.py +++ /dev/null @@ -1,247 +0,0 @@ -from functools import lru_cache - -import torch -from scipy.fft import next_fast_len -from scipy.special import roots_legendre -import numpy as np - - -def fft_convolve_torch(img, psf, psf_fft=False, img_prepadded=False): - # Ensure everything is tensor - img = torch.as_tensor(img) - psf = torch.as_tensor(psf) - - if img_prepadded: - s = img.size() - else: - s = tuple( - next_fast_len(int(d + (p + 1) / 2), real=True) for d, p in zip(img.size(), psf.size()) - ) # list(int(d + (p + 1) / 2) for d, p in zip(img.size(), psf.size())) - - img_f = torch.fft.rfft2(img, s=s) - - if not psf_fft: - psf_f = torch.fft.rfft2(psf, s=s) - else: - psf_f = psf - - conv_f = img_f * psf_f - conv = torch.fft.irfft2(conv_f, s=s) - - # Roll the tensor to correct centering and crop to original image size - return torch.roll( - conv, - shifts=(-int((psf.size()[0] - 1) / 2), -int((psf.size()[1] - 1) / 2)), - dims=(0, 1), - )[: img.size()[0], : img.size()[1]] - - -def fft_convolve_multi_torch( - img, kernels, kernel_fft=False, img_prepadded=False, dtype=None, device=None -): - # Ensure everything is tensor - img = torch.as_tensor(img, dtype=dtype, device=device) - for k in range(len(kernels)): - kernels[k] = torch.as_tensor(kernels[k], dtype=dtype, device=device) - - if img_prepadded: - s = img.size() - else: - s = list(int(d + (p + 1) / 2) for d, p in zip(img.size(), kernels[0].size())) - - img_f = torch.fft.rfft2(img, s=s) - - if not kernel_fft: - kernels_f = list(torch.fft.rfft2(kernel, s=s) for kernel in kernels) - else: - psf_f = psf - - conv_f = img_f - - for kernel_f in kernels_f: - conv_f *= kernel_f - - conv = torch.fft.irfft2(conv_f, s=s) - - # Roll the tensor to correct centering and crop to original image size - return torch.roll( - conv, - shifts=( - -int((sum(kernel.size()[0] for kernel in kernels) - 1) / 2), - -int((sum(kernel.size()[1] for kernel in kernels) - 1) / 2), - ), - dims=(0, 1), - )[: img.size()[0], : img.size()[1]] - - -def axis_ratio_com(data, PA, X=None, Y=None, mask=None): - """get center of mass like quantity for axis ratio""" - if X is None: - S = data.shape - X, Y = np.meshgrid(np.arange(S[1]) - S[1] / 2, np.arange(S[0]) - S[0] / 2, indexing="xy") - if mask is None: - mask = np.zeros_like(data, dtype=bool) - mask = np.logical_not(mask) - - theta = np.arctan2(Y, X) - PA - theta = theta[mask] - data = data[mask] - ang_com_cos = np.sum(data * np.cos(theta) ** 2) / np.sum(data) - ang_com_sin = np.sum(data * np.sin(theta) ** 2) / np.sum(data) - return ang_com_sin / max(ang_com_sin, ang_com_cos) - - -def displacement_spacing(N, dtype=torch.float64, device="cpu"): - return torch.linspace(-(N - 1) / (2 * N), (N - 1) / (2 * N), N, dtype=dtype, device=device) - - -def displacement_grid(Nx, Ny, pixelscale=None, dtype=torch.float64, device="cpu"): - px = displacement_spacing(Nx, dtype=dtype, device=device) - py = displacement_spacing(Ny, dtype=dtype, device=device) - PX, PY = torch.meshgrid(px, py, indexing="xy") - return (pixelscale @ torch.stack((PX, PY)).view(2, -1)).reshape((2, *PX.shape)) - - -@lru_cache(maxsize=32) -def quad_table(n, p, dtype, device): - """ - from: https://pomax.github.io/bezierinfo/legendre-gauss.html - """ - abscissa, weights = roots_legendre(n) - - w = torch.tensor(weights, dtype=dtype, device=device) - a = torch.tensor(abscissa, dtype=dtype, device=device) - X, Y = torch.meshgrid(a, a, indexing="xy") - - W = torch.outer(w, w) / 4.0 - - X, Y = p @ (torch.stack((X, Y)).view(2, -1) / 2.0) - - return X, Y, W.reshape(-1) - - -def single_quad_integrate( - X, Y, image_header, eval_brightness, eval_parameters, dtype, device, quad_level=3 -): - - # collect gaussian quadrature weights - abscissaX, abscissaY, weight = quad_table(quad_level, image_header.pixelscale, dtype, device) - # Specify coordinates at which to evaluate function - Xs = torch.repeat_interleave(X[..., None], quad_level**2, -1) + abscissaX - Ys = torch.repeat_interleave(Y[..., None], quad_level**2, -1) + abscissaY - - # Evaluate the model at the quadrature points - res = eval_brightness( - X=Xs, - Y=Ys, - image=image_header, - parameters=eval_parameters, - ) - - # Reference flux for pixel is simply the mean of the evaluations - ref = res[..., (quad_level**2) // 2] # res.mean(axis=-1) # # alternative, use midpoint - - # Apply the weights and reduce to original pixel space - res = (res * weight).sum(axis=-1) - - return res, ref - - -def grid_integrate( - X, - Y, - image_header, - eval_brightness, - eval_parameters, - dtype, - device, - quad_level=3, - gridding=5, - _current_depth=1, - max_depth=2, - reference=None, -): - """The grid_integrate function performs adaptive quadrature - integration over a given pixel grid, offering precision control - where it is needed most. - - Args: - X (torch.Tensor): A 2D tensor representing the x-coordinates of the grid on which the function will be integrated. - Y (torch.Tensor): A 2D tensor representing the y-coordinates of the grid on which the function will be integrated. - image_header (ImageHeader): An object containing meta-information about the image. - eval_brightness (callable): A function that evaluates the brightness at each grid point. This function should be compatible with PyTorch tensor operations. - eval_parameters (Parameter_Group): An object containing parameters that are passed to the eval_brightness function. - dtype (torch.dtype): The data type of the output tensor. The dtype argument should be a valid PyTorch data type. - device (torch.device): The device on which to perform the computations. The device argument should be a valid PyTorch device. - quad_level (int, optional): The initial level of quadrature used in the integration. Defaults to 3. - gridding (int, optional): The factor by which the grid is subdivided when the integration error for a pixel is above the allowed threshold. Defaults to 5. - _current_depth (int, optional): The current depth level of the grid subdivision. Used for recursive calls to the function. Defaults to 1. - max_depth (int, optional): The maximum depth level of grid subdivision. Once this level is reached, no further subdivision is performed. Defaults to 2. - reference (torch.Tensor or None, optional): A scalar value that represents the allowed threshold for the integration error. - - Returns: - torch.Tensor: A tensor of the same shape as X and Y that represents the result of the integration on the grid. - - This function operates by first performing a quadrature - integration over the given pixels. If the maximum depth level has - been reached, it simply returns the result. Otherwise, it - calculates the integration error for each pixel and selects those - that have an error above the allowed threshold. For pixels that - have low error, the result is set as computed. For those with high - error, it sets up a finer sampling grid and recursively evaluates - the quadrature integration on it. Finally, it integrates the - results from the finer sampling grid back to the current - resolution. - - """ - # perform quadrature integration on the given pixels - res, ref = single_quad_integrate( - X, - Y, - image_header, - eval_brightness, - eval_parameters, - dtype, - device, - quad_level=quad_level, - ) - - # if the max depth is reached, simply return the integrated pixels - if _current_depth >= max_depth: - return res - - # Begin integral - integral = torch.zeros_like(X) - - # Select pixels which have errors above the allowed threshold - select = torch.abs((res - ref)) > reference - - # For pixels with low error, set the results as computed - integral[torch.logical_not(select)] = res[torch.logical_not(select)] - - # Set up sub-gridding to super resolve problem pixels - stepx, stepy = displacement_grid(gridding, gridding, image_header.pixelscale, dtype, device) - # Write out the coordinates for the super resolved pixels - subgridX = torch.repeat_interleave(X[select].unsqueeze(-1), gridding**2, -1) + stepx.reshape(-1) - subgridY = torch.repeat_interleave(Y[select].unsqueeze(-1), gridding**2, -1) + stepy.reshape(-1) - - # Recursively evaluate the quadrature integration on the finer sampling grid - subgridres = grid_integrate( - subgridX, - subgridY, - image_header.rescale_pixel(1 / gridding), - eval_brightness, - eval_parameters, - dtype, - device, - quad_level=quad_level, - gridding=gridding, - _current_depth=_current_depth + 1, - max_depth=max_depth, - reference=reference * gridding**2, - ) - - # Integrate the finer sampling grid back to current resolution - integral[select] = subgridres.sum(axis=(-1,)) - - return integral diff --git a/astrophot/utils/optimization.py b/astrophot/utils/optimization.py deleted file mode 100644 index 03edc409..00000000 --- a/astrophot/utils/optimization.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch - -from .. import AP_config - - -def chi_squared(target, model, mask=None, variance=None): - if mask is None: - if variance is None: - return torch.sum((target - model) ** 2) - else: - return torch.sum(((target - model) ** 2) / variance) - else: - mask = torch.logical_not(mask) - if variance is None: - return torch.sum((target[mask] - model[mask]) ** 2) - else: - return torch.sum(((target[mask] - model[mask]) ** 2) / variance[mask]) - - -def reduced_chi_squared(target, model, params, mask=None, variance=None): - if mask is None: - ndf = ( - torch.prod( - torch.tensor(target.shape, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - ) - - params - ) - else: - ndf = torch.sum(torch.logical_not(mask)) - params - return chi_squared(target, model, mask, variance) / ndf diff --git a/astrophot/utils/parametric_profiles.py b/astrophot/utils/parametric_profiles.py index bce0d7a5..9e945d9c 100644 --- a/astrophot/utils/parametric_profiles.py +++ b/astrophot/utils/parametric_profiles.py @@ -1,34 +1,29 @@ -import torch import numpy as np +from numpy import ndarray from .conversions.functions import sersic_n_to_b -from .interpolate import cubic_spline_torch +__all__ = ( + "sersic_np", + "gaussian_np", + "exponential_np", + "moffat_np", + "nuker_np", + "ferrer_np", + "king_np", +) -def sersic_torch(R, n, Re, Ie): - """Seric 1d profile function, specifically designed for pytorch - operations - Parameters: - R: Radii tensor at which to evaluate the sersic function - n: sersic index restricted to n > 0.36 - Re: Effective radius in the same units as R - Ie: Effective surface density - """ - bn = sersic_n_to_b(n) - return Ie * torch.exp(-bn * (torch.pow(R / Re, 1 / n) - 1)) - - -def sersic_np(R, n, Re, Ie): +def sersic_np(R: ndarray, n: ndarray, Re: ndarray, Ie: ndarray) -> ndarray: """Sersic 1d profile function, works more generally with numpy operations. In the event that impossible values are passed to the function it returns large values to guide optimizers away from such values. - Parameters: - R: Radii array at which to evaluate the sersic function - n: sersic index restricted to n > 0.36 - Re: Effective radius in the same units as R - Ie: Effective surface density + **Args:** + - `R`: Radii array at which to evaluate the sersic function + - `n`: sersic index restricted to n > 0.36 + - `Re`: Effective radius in the same units as R + - `Ie`: Effective surface density """ if np.any(np.array([n, Re, Ie]) <= 0): return np.ones(len(R)) * 1e6 @@ -36,94 +31,54 @@ def sersic_np(R, n, Re, Ie): return Ie * np.exp(-bn * ((R / Re) ** (1 / n) - 1)) -def gaussian_torch(R, sigma, I0): - """Gaussian 1d profile function, specifically designed for pytorch - operations. - - Parameters: - R: Radii tensor at which to evaluate the sersic function - sigma: standard deviation of the gaussian in the same units as R - I0: central surface density - """ - return (I0 / torch.sqrt(2 * np.pi * sigma**2)) * torch.exp(-0.5 * torch.pow(R / sigma, 2)) - - -def gaussian_np(R, sigma, I0): +def gaussian_np(R: ndarray, sigma: ndarray, I0: ndarray) -> ndarray: """Gaussian 1d profile function, works more generally with numpy operations. - Parameters: - R: Radii array at which to evaluate the sersic function - sigma: standard deviation of the gaussian in the same units as R - I0: central surface density + **Args:** + - `R`: Radii array at which to evaluate the gaussian function + - `sigma`: standard deviation of the gaussian in the same units as R + - `I0`: central surface density """ return (I0 / np.sqrt(2 * np.pi * sigma**2)) * np.exp(-0.5 * ((R / sigma) ** 2)) -def exponential_torch(R, Re, Ie): - """Exponential 1d profile function, specifically designed for pytorch - operations. - - Parameters: - R: Radii tensor at which to evaluate the sersic function - Re: Effective radius in the same units as R - Ie: Effective surface density - """ - return Ie * torch.exp( - -sersic_n_to_b(torch.tensor(1.0, dtype=R.dtype, device=R.device)) * ((R / Re) - 1.0) - ) - - -def exponential_np(R, Ie, Re): +def exponential_np(R: ndarray, Ie: ndarray, Re: ndarray) -> ndarray: """Exponential 1d profile function, works more generally with numpy operations. - Parameters: - R: Radii array at which to evaluate the sersic function - Re: Effective radius in the same units as R - Ie: Effective surface density + **Args:** + - `R`: Radii array at which to evaluate the exponential function + - `Ie`: Effective surface density + - `Re`: Effective radius in the same units as R """ return Ie * np.exp(-sersic_n_to_b(1.0) * (R / Re - 1.0)) -def moffat_torch(R, n, Rd, I0): - """Moffat 1d profile function, specifically designed for pytorch - operations - - Parameters: - R: Radii tensor at which to evaluate the moffat function - n: concentration index - Rd: scale length in the same units as R - I0: central surface density - - """ - return I0 / (1 + (R / Rd) ** 2) ** n - - -def moffat_np(R, n, Rd, I0): +def moffat_np(R: ndarray, n: ndarray, Rd: ndarray, I0: ndarray) -> ndarray: """Moffat 1d profile function, works with numpy operations. - Parameters: - R: Radii tensor at which to evaluate the moffat function - n: concentration index - Rd: scale length in the same units as R - I0: central surface density - + **Args:** + - `R`: Radii array at which to evaluate the moffat function + - `n`: concentration index + - `Rd`: scale length in the same units as R + - `I0`: central surface density """ return I0 / (1 + (R / Rd) ** 2) ** n -def nuker_torch(R, Rb, Ib, alpha, beta, gamma): - """Nuker 1d profile function, specifically designed for pytorch - operations +def nuker_np( + R: ndarray, Rb: ndarray, Ib: ndarray, alpha: ndarray, beta: ndarray, gamma: ndarray +) -> ndarray: + """Nuker 1d profile function, works with numpy functions - Parameters: - R: Radii tensor at which to evaluate the nuker function - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope + **Args:** + - `R`: Radii tensor at which to evaluate the nuker function + - `Ib`: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. + - `Rb`: scale length radius + - `alpha`: sharpness of transition between power law slopes + - `beta`: outer power law slope + - `gamma`: inner power law slope """ return ( @@ -134,44 +89,31 @@ def nuker_torch(R, Rb, Ib, alpha, beta, gamma): ) -def nuker_np(R, Rb, Ib, alpha, beta, gamma): - """Nuker 1d profile function, works with numpy functions - - Parameters: - R: Radii tensor at which to evaluate the nuker function - Ib: brightness at the scale length, represented as the log of the brightness divided by pixel scale squared. - Rb: scale length radius - alpha: sharpness of transition between power law slopes - beta: outer power law slope - gamma: inner power law slope - +def ferrer_np(R: ndarray, rout: ndarray, alpha: ndarray, beta: ndarray, I0: ndarray) -> ndarray: """ - return ( - Ib - * (2 ** ((beta - gamma) / alpha)) - * ((R / Rb) ** (-gamma)) - * ((1 + (R / Rb) ** alpha) ** ((gamma - beta) / alpha)) - ) - + Modified Ferrer profile. + + **Args:** + - `R`: Radial distance from the center. + - `rout`: Outer radius of the profile. + - `alpha`: Power-law index. + - `beta`: Exponent for the modified Ferrer function. + - `I0`: Central intensity. + """ + return (R < rout) * I0 * ((1 - (np.clip(R, 0, rout) / rout) ** (2 - beta)) ** alpha) -def spline_torch(R, profR, profI, extend): - """Spline 1d profile function, cubic spline between points up - to second last point beyond which is linear, specifically designed - for pytorch. - Parameters: - R: Radii tensor at which to evaluate the sersic function - profR: radius values for the surface density profile in the same units as R - profI: surface density values for the surface density profile +def king_np(R: ndarray, Rc: ndarray, Rt: ndarray, alpha: ndarray, I0: ndarray) -> ndarray: + """ + Empirical King profile. + + **Args:** + - `R`: The radial distance from the center. + - `Rc`: The core radius of the profile. + - `Rt`: The truncation radius of the profile. + - `alpha`: The power-law index of the profile. + - `I0`: The central intensity of the profile. """ - I = cubic_spline_torch(profR, profI, R.view(-1), extend="none").view(*R.shape) - res = torch.zeros_like(I) - res[R <= profR[-1]] = 10 ** (I[R <= profR[-1]]) - if extend: - res[R > profR[-1]] = 10 ** ( - profI[-2] - + (R[R > profR[-1]] - profR[-2]) * ((profI[-1] - profI[-2]) / (profR[-1] - profR[-2])) - ) - else: - res[R > profR[-1]] = 0 - return res + beta = 1 / (1 + (Rt / Rc) ** 2) ** (1 / alpha) + gamma = 1 / (1 + (R / Rc) ** 2) ** (1 / alpha) + return (R < Rt) * I0 * ((np.clip(gamma, 0, 1) - beta) / (1 - beta)) ** alpha diff --git a/docs/requirements.txt b/docs/requirements.txt index e32d2be5..07b09906 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,8 +1,16 @@ +caustics +corner +emcee +graphviz ipywidgets -jupyter-book +jax<=0.7.0 +jupyter-book<2.0 matplotlib +nbformat nbsphinx photutils +pyvo scikit-image sphinx sphinx-rtd-theme +tqdm diff --git a/docs/source/_config.yml b/docs/source/_config.yml index d72b8966..635dc983 100644 --- a/docs/source/_config.yml +++ b/docs/source/_config.yml @@ -38,12 +38,12 @@ sphinx: extra_extensions: - "sphinx.ext.autodoc" - "sphinx.ext.autosummary" - - "sphinx.ext.napoleon" - - "sphinx.ext.doctest" - - "sphinx.ext.coverage" - - "sphinx.ext.mathjax" - - "sphinx.ext.ifconfig" - "sphinx.ext.viewcode" + # - "sphinx.ext.napoleon" + # - "sphinx.ext.doctest" + # - "sphinx.ext.coverage" + # - "sphinx.ext.mathjax" + # - "sphinx.ext.ifconfig" config: html_theme_options: logo: diff --git a/docs/source/_toc.yml b/docs/source/_toc.yml index ebd38f80..e35a6cd2 100644 --- a/docs/source/_toc.yml +++ b/docs/source/_toc.yml @@ -15,4 +15,5 @@ chapters: - file: contributing - file: citation - file: license + - file: astrophotdocs/index - file: modules diff --git a/docs/source/astrophotdocs/index.rst b/docs/source/astrophotdocs/index.rst new file mode 100644 index 00000000..6d12dec0 --- /dev/null +++ b/docs/source/astrophotdocs/index.rst @@ -0,0 +1,21 @@ +==================== +AstroPhot Docstrings +==================== + +Here you will find all of the AstroPhot class and method docstrings, built using +markdown formatting. These are useful for understanding the details of a given +model and can also be accessed via the python help command +```help(ap.object)```. For the AstroPhot ``ap.Model`` objects, the docstrings are a +combination of the various base-classes and mixins that make them up. They are +very detailed, but can be a bit awkward in their formatting, the good news is +that a lot of useful information is available there! + +.. toctree:: + :maxdepth: 2 + + models + image + fit + plots + utils + errors diff --git a/docs/source/coordinates.rst b/docs/source/coordinates.rst index f87377c7..95c22907 100644 --- a/docs/source/coordinates.rst +++ b/docs/source/coordinates.rst @@ -6,228 +6,103 @@ Coordinate systems in astronomy can be complicated, AstroPhot is no different. Here we explain how coordinate systems are handled to help you avoid possible pitfalls. -Basics ------- +For the most part, AstroPhot follows the FITS standard for coordinates, though +limited to the types of images that AstroPhot can model. -There are three main coordinate systems to think about. +Three Coordinate Systems +------------------------ + +There are three coordinate systems to think about. #. ``world`` coordinates are the classic (RA, DEC) that many astronomical sources are represented in. These should always be used in degree units as far as AstroPhot is concerned. -#. ``plane`` coordinates are the tangent plane on which AstroPhot - performs its calculations. Working on a plane makes everything - linear and does not introduce a noticeable effect for small enough - images. In the tangent plane everything should be represented in - arcsecond units. -#. ``pixel`` coordinates are specific to each image, they start at - (0,0) in the center of the [0,0] indexed pixel. These are - effectively unitless, a step of 1 in pixel coordinates is the same - as changing an index by 1. Though image array indexing is flipped - so pixel coordinate (3,10) represents the center of the index - [10,3] pixel. It is a convention for most images that the first - axis indexes vertically and the second axis indexis horizontally, - if this is not the case for your images you can apply a transpose - before passing the data to AstroPhot. Also, in the pixel coordinate - system the values are represented by floating point numbers and so - (1.3,2.8) is a valid pixel coordinate that is just partway between - pixel centers. +#. ``plane`` coordinates are the tangent plane on which AstroPhot performs its + calculations. Working on a plane makes everything linear and does not + introduce a noticeable projection effect for small enough images. In the + tangent plane everything should be represented in arcsecond units. +#. ``pixel`` coordinates are specific to each image, they start at (0,0) in the + center of the [0,0] indexed pixel. These are effectively unitless, a step of + 1 in pixel coordinates is the same as changing an index by 1. AstroPhot + adopts an indexing scheme standard to FITS files meaning the pixel coordinate + (5,9) corresponds to the pixel indexed at [5,9]. Normally for numpy arrays + and PyTorch tensors, the indexing would be flipped as [9,5] so AstroPhot + applies a transpose on any image it receives in an Image object. Also, in + the pixel coordinate system the values are represented by floating point + numbers, so (1.3,2.8) is a valid pixel coordinate that is just partway + between pixel centers. Transformations exist in AstroPhot for converting ``world`` to/from ``plane`` and for converting ``plane`` to/from ``pixel``. The best way -to interface with these is to use the ``image.window.world_to_plane`` +to interface with these is to use the ``image.world_to_plane`` for any AstroPhot image object (you may similarly swap ``world``, ``plane``, and ``pixel``). One gotcha to keep in mind with regards to ``world_to_plane`` and -``plane_to_world`` is that AstroPhot needs to know the reference -(RA_0, DEC_0) where the tangent plane meets with the celestial -sphere. You can set this by including ``reference_radec = (RA_0, -DEC_0)`` as an argument in an image you create. If a reference is not -given, then one will be assumed based on available information. Note -that if you are doing simultaneous multi-image analysis you should -ensure that the ``reference_radec`` is same for all images! +``plane_to_world`` is that AstroPhot needs to know the reference (RA, DEC) where +the tangent plane meets with the celestial sphere. AstroPhot now adopts the FITS +standard for this using ``image.crval`` to store the reference world +coordinates. Note that if you are doing simultaneous multi-image analysis you +should ensure that the ``crval`` is same for all images! Projection Systems ------------------ -AstroPhot currently implements three coordinate reference systems: -Gnomonic, Orthographic, and Steriographic. The default projection is -the Gnomonic, which represents the perspective of an observer at the -center of a sphere projected onto a plane. For the exact -implementation by AstroPhot see the `Wolfram MathWorld -`_ page. - -On small scales the choice of projection doesn't matter. For very -large images the effect may be detectable, though it is likely -insignificant compared to other effects in an image. Just like the -``reference_radec`` you can choose your projection system in an image -you construct by passing ``projection = 'gnomonic'`` as an argument. -Just like with the reference coordinate, for images to "talk" to each -other they should have the same projection. - -If you really want to change the projection after an image has -been created (warning, this may cause serious missalignments between -images), you can force it to update with:: - - image.window.projection = 'steriographic' - -which would change the projection to steriographic. The image won't -recompute its position in the new projection system, it will just use -new equations going forward. Hence the potential to seriously mess up -your image alignment if this is done after some calculations have -already been performed. - -Talking to the world --------------------- - -If you have images with WCS information then you will want to use this -to map images onto the same tangent plane. Often this will take the -form of information in a FITS file, which can easily be accessed using -Astropy like:: - - from astropy.io import fits - from astropy.wcs import WCS - hdu = fits.open("myimage.fits") - data = hdu[0].data - wcs = WCS(hdu[0].header) - -That is somewhat described in the basics section, however there are -some more features you can take advantage of. When creating an image -in AstroPhot, you need to tell it some basic properties so that the -image knows how to place itself in the tangent plane. Using the -Astropy WCS object above you can recover the reference coordinates -of the image in (RA, DEC), for an example Astropy wcs object you could -accomplish this with: - - ra, dec = wcs.wcs.crval - -meaning that you know the world position of the reference RA, Dec -of the image WCS. To have -AstroPhot place the image at the right location in the tangent plane -you can use the ``wcs`` argument when constructing the image:: - - image = ap.image.Target_Image( - data = data, - reference_radec = (ra, dec), - wcs = wcs, - ) - -AstroPhot will set the reference RA, DEC to these coordinates and also -set the image in the correct position. A more explicit alternative is -to just say what the reference coordinate should be. That would look -something like:: - - image = ap.image.Target_Image( - data = data, - pixelscale = pixelscale, - reference_radec = (ra,dec), - reference_imagexy = (x, y), - ) - -which uniquely defines the position of the image in the coordinate -system. Remember that the ``reference_radec`` should be the same for -all images in a multi-image analysis, while ``reference_imagexy`` -specifies the position of a particular image. Another similar option is to set -``center_radec`` like:: - - image = ap.image.Target_Image( - data = data, - pixelscale = pixelscale, - reference_radec = (ra,dec), - center_radec = (c_ra, c_dec), - ) - -You may also have a catalogue of objects that you would like to -project into the image. The easiest way to do this if you already have -an image object is to call the ``world_to_plane`` functions -manually. Say for example that you know the object position as an -Astropy ``SkyCoord`` object, and you want to use this to set the -center position of a sersic model. That would look like:: - - model = ap.models.AstroPhot_Model( - name = "knowloc", - model_type = "sersic galaxy model", - target = image, - parameters = { - "center": image.window.world_to_plane(obj_pos.ra.deg, obj_pos.dec.deg), - } - ) - -Which will start the object at the correct position in the image given -its world coordinates. As you can see, the ``center`` and in fact all -parameters for AstroPhot models are defined in the tangent plane. This -means that if you have optimized a model and you would like to present -its position in world coordinates that can be compared with other -sources, you will need to do the opposite operation:: - - world_position = image.window.plane_to_world(model["center"].value) - -That should assign ``world_position`` the coordinates in RA and DEC -(degrees), assuming that you initialized the image with a WCS or by -other means ensured that the world coordinates being used are -correct. If you never gave AstroPhot the information it needs, then it -likely assumed a reference position of (0,0) in the world coordinate -system. +AstroPhot currently only supports the Gnomonic projection system. This means +that the tangent plane is defined as "contacting" the celestial sphere at a +single point, the reference (crval) coordinates. The tangent plane coordinates +correspond to the world coordinates as viewed from the center of the celestial +sphere. This is the most common projection system used in astronomy and commonly +used in the FITS standard. It is also the one that Astropy usually uses for its +WCS objects. Coordinate reference points --------------------------- -As stated earlier, there are essentially three coordinate systems in -AstroPhot: ``world``, ``plane``, and ``pixel``. To uniquely specify -the transformation from ``world`` to ``plane`` AstroPhot keeps track -of two vectors: ``reference_radec`` and ``reference_planexy``. These -variables are stored in all ``Image_Header`` objects and essentially -pin down the mapping such that one coordinate will get mapped to the -other. All other coordinates follow from the projection system assumed -(i.e., Gnomonic). It is possible to specify these variables directly -when constructing an image, or implicitly if you give some other -relevant information (e.g., an Astropy WCS). AstroPhot Window objects -also keep track of two more vectors: ``reference_imageij`` and -``reference_imagexy``. These variables control where an image is -placed in the tangent plane and represent a fixed point between the -pixel coordinates and the tangent plane coordinates. If your pixel -scale matrix includes a rotation then the rotation will be performed -about this position. - -All together, these reference positions define how pixels are mapped -in AstroPhot. This level of generality is overkill for analyzing a -single image, so AstroPhot makes reasonable assumptions about these -reference points if you don't specify them all. This makes it easy to -do single image analysis without thinking too much about the -coordinate systems. However, for multi-band or multi-epoch imaging it -is critical to be absolutely clear about these coordinate -transformations so that images can be aligned properly on the sky. As -an intuitive explanation, think of ``reference_radec`` and -``reference_planexy`` as defining the coordinate system that is shared -between images, while ``reference_imageij`` and ``reference_imagexy`` -specify where a single image is located. As such, in multi-image -analysis if you wish to use world coordinates, you should explitcitly -pass the same ``reference_radec`` and ``reference_planexy`` to every -image so that the same coordinate system is defined for all of them -(the same tangent plane at the same point on the celestial sphere). If -you aren't going to interact with world coordinates, you can ignore -those reference points entirely and it won't affect your images. - -Below is a summary of the reference coordinates and their meaning: - -#. ``reference_radec`` world coordinates on the celestial sphere (RA, - DEC in degrees) where the tangent plane makes contact. This should - be the same for every image in a multi-image analysis. -#. ``reference_planexy`` tangent plane coordinates (arcsec) where it - makes contact with the celesial sphere. This should typically be - (0,0) though that is not stricktly enforced (it is assumed if not - given). This reference coordinate should be the same for all - images in a multi-image analysis. -#. ``reference_imageij`` pixel coordinates about which the image is - defined. For example in an Astropy WCS object the wcs.wcs.crpix - array gives the pixel coordinate reference point for which the - world coordinate mapping (wcs.wcs.crval) is defined. One may think - of the referenced pixel location as being "pinned" to the tangent - plane. This may be different for each image in a multi-image - analysis. -#. ``reference_imagexy`` tangent plane coordinates (arcsec) about - which the image is defined. This is the pivot point about which the - pixelscale matrix operates, therefore if the pixelscale matrix - defines a rotation then this is the coordinate about which the - rotation will be performed. This may be different for each image in - a multi-image analysis. +There are three coordinate systems in AstroPhot: ``world``, ``plane``, and +``pixel``. AstroPhot tracks a reference point in each coordinate system used to +connect each system. Below is a summary of the reference coordinates and their +meaning: + +#. ``crval`` world coordinates on the celestial sphere (RA, DEC in degrees) + where the tangent plane makes contact. crval always contacts the tangent + plane at (0,0) in the tangent plane coordinates. This should be the same for + every image in a multi-image analysis. +#. ``crtan`` tangent plane coordinates (arcsec) where the pixel grid makes + contact with the tangent plane. This is the pivot point about which the + pixelscale matrix operates, therefore if the pixelscale matrix defines a + rotation then this is the coordinate about which the rotation will be + performed. This may be different for each image in a multi-image analysis. +#. ``crpix`` pixel coordinates where the pixel grid makes contact with the + tangent plane. One may think of the referenced pixel location as being + "pinned" to the tangent plane. This may be different for each image in a + multi-image analysis. + +Thinking of the celestial sphere, tangent plane, and pixel grid as three +interconnected coordinate systems is crucial for understanding how AstroPhot +operates in a multi-image context. While the transformations may get +complicated, try to remember these contact points: + +* ``crval`` is in the world coordinates and contacts the tangent plane at + (0,0) in the tangent plane coordinates. +* ``crtan`` is in the tangent plane coordinates and contacts the pixel grid at + ``crpix`` in the pixel coordinates. + +What parts go where? +-------------------- + +Since AstroPhot works in multiple reference frames it can be easy to get lost. +Keep these basics in mind. The world coordinates are where catalogues exist, so +this is the coordinate system you should use when interfacing with external +resources. The tangent plane coordinates are where the models exist. So when +creating a model and considering factors like the position angle, you should +think in the tangent plane coordinates. The pixel coordinates are where the data +exists. So when you create a TargetImage object it is in pixel coordinates, but +so too is a ModelImage object since it is intended to be compared against a +TargetImage. This means that any distortions in the TargetImage (i.e. SIP +distortions) will show up in the ModelImage, but aren't actually part of the +model. This can manifest for example as a round Gaussian model looking +elliptical in its ModelImage because there is a skew in the CD matrix in the +TargetImage it is matching. In general this is a good thing because we care +about how our models look on the sky (tangent plane), not strictly how they look +in the pixel grid. diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index a07cc2fa..d7ea2c57 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -5,8 +5,8 @@ Getting Started First follow the installation instructions, then come here to learn how to use AstroPhot for the first time. -Basic AstroPhot code philosophy ------------------------------- +Basic AstroPhot code organization +--------------------------------- AstroPhot is a modular and object oriented astronomical image modelling package. Modularity means that it is relatively simple to change or replace one aspect of @@ -22,14 +22,13 @@ would expect. This makes the experience more user friendly hopefully meaning that you can quickly take advantage of the powerful features available. One of the core components of AstroPhot is the model objects, these are -organized in a class hierarchy with several layers of inheritance. While this is -not considered best programming practice for many situations, in AstroPhot it is -very intentional and we think helpful to users. With this hierarchy it is very -easy to customize a model to suit your needs without needing to rewrite a great -deal of code. Simply access the point in the hierarchy which most closely -matches your desired result and make minor modifications. In the tutorials you -can see how detailed models can be implemented with only a few lines of code -even though the user has complete freedom to change any aspect of the model. +organized in a class hierarchy with several layers of inheritance. With this +hierarchy it is very easy to customize a model to suit your needs without +needing to rewrite a great deal of code. Simply access the point in the +hierarchy which most closely matches your desired result and make minor +modifications. In the tutorials you can see how detailed models can be +implemented with only a few lines of code even though the user has complete +freedom to change any aspect of the model. Install ------- @@ -59,40 +58,15 @@ tutorials then run the:: command to download the AstroPhot tutorials. If you run into difficulty with this, you can also access the tutorials directly at :doc:`tutorials` to download -as PDFs. Once you have the tutorials, start a jupyter session and run through -them. The recommended order is: - -#. :doc:`tutorials/GettingStarted` -#. :doc:`tutorials/GroupModels` -#. :doc:`tutorials/ModelZoo` -#. :doc:`tutorials/FittingMethods` -#. :doc:`tutorials/BasicPSFModels` -#. :doc:`tutorials/JointModels` -#. :doc:`tutorials/AdvancedPSFModels` -#. :doc:`tutorials/CustomModels` - -When downloading the tutorials, you will also get a file called -``simple_config.py``, this is an example AstroPhot config file. Configuration -files are an alternate interface to the AstroPhot functionality. They are -somewhat more limited in capacity, but very easy to interface with. See the -guide on configuration files here: :doc:`configfile_interface` . - -Model Org Chart ---------------- - -As a quick reference for what kinds of models are available in AstroPhot, the -org chart shows you the class hierarchy where the leaf nodes at the bottom are -the models that can actually be used. Following different paths through the -hierarchy gives models with different properties. Just use the second line at -each step in the flow chart to construct the name. For example one could follow -a fairly direct path to get a ``sersic galaxy model``, or a more complex path to -get a ``nuker fourier warp galaxy model``. Note that the ``Component_Model`` -object doesn't have an identifier, it is really meant to hide in the background -while its subclasses do the work. - -.. image:: https://github.com/Autostronomy/AstroPhot/blob/main/media/AstroPhotModelOrgchart.png?raw=true - :alt: AstroPhot Model Org Chart - :width: 100 % +as PDFs or jupyter notebooks. Once you have the tutorials, start a jupyter +session and run through them. + +Model Zoo +--------- + +The best way to see what models are available in AstroPhot is to peruse the +:doc:`tutorials/ModelZoo`. Here you can see the models evaluated on a regular +grid, and play around with the values if you are running the tutorial locally. Detailed Documentation ---------------------- diff --git a/docs/source/prebuilt/segmap_models_fit.py b/docs/source/prebuilt/segmap_models_fit.py index dd9f0e61..ad1d819b 100644 --- a/docs/source/prebuilt/segmap_models_fit.py +++ b/docs/source/prebuilt/segmap_models_fit.py @@ -25,10 +25,7 @@ name = "field_name" # used for saving files target_file = ".fits" # can be a numpy array instead segmap_file = ".fits" # can be a numpy array instead -mask_file = None # ".fits" # can be a numpy array instead psf_file = None # ".fits" # can be a numpy array instead -variance_file = None # ".fits" # or numpy array or "auto" -pixelscale = 0.1 # arcsec/pixel zeropoint = 22.5 # mag initial_sky = None # If None, sky will be estimated. Recommended to set manually sky_locked = False @@ -46,8 +43,6 @@ save_residual_image = True target_hdu = 0 # FITS file index for image data segmap_hdu = 0 -mask_hdu = 0 -variance_hdu = 0 psf_hdu = 0 window_expand_scale = 2 # Windows from segmap will be expanded by this factor window_expand_border = 10 # Windows from segmap will be expanded by this number of pixels @@ -58,11 +53,11 @@ # load target and segmentation map # --------------------------------------------------------------------- print("loading target and segmentation map") -if isinstance(target_file, str): - hdu = fits.open(target_file) - target_data = np.array(hdu[target_hdu].data, dtype=np.float64) -else: - target_data = target_file +target = ap.TargetImage( + filename=target_file, + hduext=target_hdu, + zeropoint=zeropoint, +) if isinstance(segmap_file, str): hdu = fits.open(segmap_file) @@ -70,53 +65,18 @@ else: segmap_data = segmap_file -# load mask, variance, and psf +# load psf # --------------------------------------------------------------------- -# Mask -if isinstance(mask_file, str): - print("loading mask") - hdu = fits.open(mask_file) - mask_data = np.array(hdu[mask_hdu].data, dtype=bool) -elif mask_file is None: - mask_data = None -else: - mask_data = mask_file -# Variance -if isinstance(variance_file, str) and not variance_file == "auto": - print("loading variance") - hdu = fits.open(variance_file) - variance_data = np.array(hdu[variance_hdu].data, dtype=np.float64) -elif variance_file is None: - variance_data = None -else: - variance_data = variance_file # PSF if isinstance(psf_file, str): print("loading psf") hdu = fits.open(psf_file) psf_data = np.array(hdu[psf_hdu].data, dtype=np.float64) - psf = ap.image.PSF_Image( - data=psf_data, - pixelscale=pixelscale, - ) + target.psf = target.psf_image(data=psf_data) elif psf_file is None: psf = None else: - psf = ap.image.PSF_Image( - data=psf_file, - pixelscale=pixelscale, - ) - -# Create target object -# --------------------------------------------------------------------- -target = ap.image.Target_Image( - data=target_data, - pixelscale=pixelscale, - zeropoint=zeropoint, - mask=mask_data, - psf=psf, - variance=variance_data, -) + target.psf = target.psf_image(data=psf_file) # Initialization from segmap # --------------------------------------------------------------------- @@ -126,23 +86,21 @@ windows = ap.utils.initialize.filter_windows( windows, **segmap_filter, - image=target_data, + image=target, ) for ids in segmap_filter_ids: del windows[ids] -centers = ap.utils.initialize.centroids_from_segmentation_map(segmap_data, target_data) +centers = ap.utils.initialize.centroids_from_segmentation_map(segmap_data, target) if "galaxy" in model_type: - PAs = ap.utils.initialize.PA_from_segmentation_map(segmap_data, target_data, centers) - qs = ap.utils.initialize.q_from_segmentation_map(segmap_data, target_data, centers, PAs) + PAs = ap.utils.initialize.PA_from_segmentation_map(segmap_data, target, centers) + qs = ap.utils.initialize.q_from_segmentation_map(segmap_data, target, centers) else: PAs = None qs = None init_params = {} for window in windows: - init_params[window] = { - "center": np.array(centers[window]) * pixelscale, - } + init_params[window] = {"center": centers[window]} if "galaxy" in model_type: init_params[window]["PA"] = PAs[window] init_params[window]["q"] = qs[window] @@ -153,14 +111,15 @@ print("Creating models") models = [] models.append( - ap.models.AstroPhot_Model( + ap.Model( name="sky", model_type=sky_model_type, target=target, - parameters={"F": initial_sky} if initial_sky is not None else {}, - locked=sky_locked, + I=initial_sky if initial_sky is not None else {}, ) ) +if sky_locked: + models[0].to_static() primary_model = None for window in windows: if primary_key is not None and window == primary_key: @@ -175,25 +134,25 @@ primary_initial_params["PA"] = PAs[window] if "q" not in primary_initial_params and qs is not None and "galaxy" in primary_model_type: primary_initial_params["q"] = qs[window] - model = ap.models.AstroPhot_Model( + model = ap.Model( name=primary_name, model_type=primary_model_type, target=target, - parameters=primary_initial_params, + **primary_initial_params, window=windows[window], ) primary_model = model else: print(window) - model = ap.models.AstroPhot_Model( + model = ap.Model( name=f"{model_type} {window}", model_type=model_type, target=target, window=windows[window], - parameters=init_params[window], + **init_params[window], ) models.append(model) -model = ap.models.AstroPhot_Model( +model = ap.Model( name=f"{name} model", model_type="group model", target=target, @@ -204,12 +163,12 @@ # --------------------------------------------------------------------- print("Initializing model") model.initialize() -print("Fitting model") +print("Fitting model round 1") result = ap.fit.Iter(model, verbose=1).fit() print("expanding windows") windows = ap.utils.initialize.scale_windows( windows, - image_shape=target_data.shape, + image=target, expand_scale=window_expand_scale, expand_border=window_expand_border, ) @@ -217,44 +176,44 @@ models[i + 1].window = windows[window] print("Fitting round 2") result = ap.fit.Iter(model, verbose=1).fit() -# result.update_uncertainty() coming soon # Report Results # ---------------------------------------------------------------------- if not sky_locked: - print(models[0].parameters) + print(models[0]) if not primary_model is None: - print(primary_model.parameters) - totflux = primary_model.total_flux().detach().cpu().numpy() - print(f"Total Magnitude: {zeropoint - 2.5 * np.log10(totflux)}") + print(primary_model) + totmag = primary_model.total_magnitude().detach().cpu().numpy() + print(f"Total Magnitude: {totmag}") if hasattr(primary_model, "radial_model"): fig, ax = plt.subplots(figsize=(8, 8)) ap.plots.radial_light_profile(fig, ax, primary_model) plt.savefig(f"{name}_radial_light_profile.jpg") plt.close() + with open(f"{name}_primary_params.csv", "w") as f: + f.write("Name,Total Magnitude," + ",".join(primary_model.build_params_array_names()) + "\n") + f.write("string,mag," + ",".join(primary_model.build_params_array_units()) + "\n") + params = primary_model.build_params_array().detach().cpu().numpy() + f.write(",".join([str(x) for x in params]) + "\n") if print_all_models: + print(model) segmap_params = [] for segmodel in models[1:]: if segmodel.name == primary_name: continue - print(segmodel.parameters) - totflux = segmodel.total_flux().detach().cpu().numpy() + totmag = segmodel.total_magnitude().detach().cpu().numpy() segmap_params.append( - [segmodel.name, totflux] - + list(segmodel.parameters.vector_values().detach().cpu().numpy()) + [segmodel.name, totmag] + list(segmodel.build_params_array().detach().cpu().numpy()) ) with open(f"{name}_segmap_params.csv", "w") as f: - f.write("Name,Total Flux," + ",".join(segmodel.parameters.vector_names()) + "\n") - flat_params = segmodel.parameters.flat(False, False).values() - f.write( - "string,mag," + ",".join(p.units for p in flat_params for _ in range(p.size)) + "\n" - ) + f.write("Name,Total Magnitude," + ",".join(segmodel.build_params_array_names()) + "\n") + f.write("string,mag," + ",".join(segmodel.build_params_array_units()) + "\n") for row in segmap_params: f.write(",".join([str(x) for x in row]) + "\n") -model.save(f"{name}_parameters.yaml") +model.save_state(f"{name}_parameters.hdf5") if save_model_image: model().save(f"{name}_model_image.fits") fig, ax = plt.subplots() diff --git a/docs/source/prebuilt/single_model_fit.py b/docs/source/prebuilt/single_model_fit.py index acdfb17e..6529b011 100644 --- a/docs/source/prebuilt/single_model_fit.py +++ b/docs/source/prebuilt/single_model_fit.py @@ -22,13 +22,10 @@ ###################################################################### name = "object_name" # used for saving files target_file = ".fits" # can be a numpy array instead -mask_file = None # ".fits" # can be a numpy array instead psf_file = None # ".fits" # can be a numpy array instead -variance_file = None # ".fits" # or numpy array or "auto" -pixelscale = 0.1 # arcsec/pixel zeropoint = 22.5 # mag initial_params = None # e.g. {"center": [3, 3], "q": {"value": 0.8, "locked": True}} -window = None # None to fit whole image, otherwise ((xmin,xmax),(ymin,ymax)) pixels +window = None # None to fit whole image, otherwise (xmin,xmax,ymin,ymax) pixels initial_sky = None # If None, sky will be estimated sky_locked = False model_type = "sersic galaxy model" @@ -38,8 +35,6 @@ save_residual_image = True save_covariance_matrix = True target_hdu = 0 # FITS file index for image data -mask_hdu = 0 -variance_hdu = 0 psf_hdu = 0 sky_model_type = "flat sky model" ###################################################################### @@ -47,79 +42,43 @@ # load target # --------------------------------------------------------------------- print("loading target") -if isinstance(target_file, str): - hdu = fits.open(target_file) - target_data = np.array(hdu[target_hdu].data, dtype=np.float64) -else: - target_data = target_file +target = ap.TargetImage( + filename=target_file, + hduext=target_hdu, + zeropoint=zeropoint, +) -# load mask, variance, and psf -# --------------------------------------------------------------------- -# Mask -if isinstance(mask_file, str): - print("loading mask") - hdu = fits.open(mask_file) - mask_data = np.array(hdu[mask_hdu].data, dtype=bool) -elif mask_file is None: - mask_data = None -else: - mask_data = mask_file -# Variance -if isinstance(variance_file, str) and not variance_file == "auto": - print("loading variance") - hdu = fits.open(variance_file) - variance_data = np.array(hdu[variance_hdu].data, dtype=np.float64) -elif variance_file is None: - variance_data = None -else: - variance_data = variance_file # PSF if isinstance(psf_file, str): print("loading psf") hdu = fits.open(psf_file) psf_data = np.array(hdu[psf_hdu].data, dtype=np.float64) - psf = ap.image.PSF_Image( - data=psf_data, - pixelscale=pixelscale, - ) + target.psf = target.psf_image(data=psf_data) elif psf_file is None: psf = None else: - psf = ap.image.PSF_Image( - data=psf_file, - pixelscale=pixelscale, - ) - -# Create target object -# --------------------------------------------------------------------- -target = ap.image.Target_Image( - data=target_data, - pixelscale=pixelscale, - zeropoint=zeropoint, - mask=mask_data, - psf=psf, - variance=variance_data, -) + target.psf = target.psf_image(data=psf_file) # Create Model # --------------------------------------------------------------------- -model_object = ap.models.AstroPhot_Model( +model_object = ap.Model( name=name, model_type=model_type, target=target, - psf_mode="full" if psf_file is not None else "none", - parameters=initial_params, + psf_convolve=True if psf_file is not None else False, + **initial_params, window=window, ) -model_sky = ap.models.AstroPhot_Model( +model_sky = ap.Model( name="sky", model_type=sky_model_type, target=target, - parameters={"F": initial_sky} if initial_sky is not None else {}, + I=initial_sky if initial_sky is not None else {}, window=window, - locked=sky_locked, ) -model = ap.models.AstroPhot_Model( +if sky_locked: + model_sky.to_static() +model = ap.Model( name="astrophot model", model_type="group model", target=target, @@ -132,26 +91,15 @@ model.initialize() print("Fitting model") result = ap.fit.LM(model, verbose=1).fit() -print("Update uncertainty") -result.update_uncertainty() # Report Results # ---------------------------------------------------------------------- -if not sky_locked: - print(model_sky.parameters) -print(model_object.parameters) -totflux = model_object.total_flux().detach().cpu().numpy() -try: - totflux_err = model_object.total_flux_uncertainty().detach().cpu().numpy() -except AttributeError: - print( - "sorry, total flux uncertainty not available yet for this model. You are welcome to contribute! :)" - ) - totflux_err = 0 -print( - f"Total Magnitude: {zeropoint - 2.5 * np.log10(totflux)} +- {2.5 * totflux_err / (totflux * np.log(10))}" -) -model.save(f"{name}_parameters.yaml") +print(model) +totmag = model_object.total_magnitude().detach().cpu().numpy() +totmag_err = model_object.total_magnitude_uncertainty().detach().cpu().numpy() +print(f"Total Magnitude: {totmag} +- {totmag_err}") + +model.save_state(f"{name}_parameters.hdf5") if save_model_image: model().save(f"{name}_model_image.fits") fig, ax = plt.subplots() diff --git a/docs/source/tutorials/AdvancedPSFModels.ipynb b/docs/source/tutorials/AdvancedPSFModels.ipynb index a3080d09..287b7f7d 100644 --- a/docs/source/tutorials/AdvancedPSFModels.ipynb +++ b/docs/source/tutorials/AdvancedPSFModels.ipynb @@ -17,13 +17,11 @@ "metadata": {}, "outputs": [], "source": [ + "%matplotlib inline\n", "import astrophot as ap\n", "import numpy as np\n", "import torch\n", - "from astropy.io import fits\n", - "import matplotlib.pyplot as plt\n", - "\n", - "%matplotlib inline" + "import matplotlib.pyplot as plt" ] }, { @@ -44,15 +42,15 @@ "outputs": [], "source": [ "# First make a mock empirical PSF image\n", - "# np.random.seed(124)\n", + "np.random.seed(124)\n", "psf = ap.utils.initialize.moffat_psf(2.0, 3.0, 101, 0.5)\n", "variance = psf**2 / 100\n", "psf += np.random.normal(scale=np.sqrt(variance))\n", - "# psf[psf < 0] = 0 #ap.utils.initialize.moffat_psf(2.0, 3.0, 101, 0.5)[psf < 0]\n", "\n", - "psf_target = ap.image.PSF_Image(\n", + "psf_target = ap.PSFImage(\n", " data=psf,\n", " pixelscale=0.5,\n", + " variance=variance,\n", ")\n", "\n", "# To ensure the PSF has a normalized flux of 1, we call\n", @@ -72,7 +70,7 @@ "outputs": [], "source": [ "# Now we initialize on the image\n", - "psf_model = ap.models.AstroPhot_Model(\n", + "psf_model = ap.Model(\n", " name=\"init psf\",\n", " model_type=\"moffat psf model\",\n", " target=psf_target,\n", @@ -82,12 +80,12 @@ "\n", "# PSF model can be fit to it's own target for good initial values\n", "# Note we provide the weight map (1/variance) since a PSF_Image can't store that information.\n", - "ap.fit.LM(psf_model, verbose=1, W=1 / variance).fit()\n", + "ap.fit.LM(psf_model, verbose=1).fit()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(13, 5))\n", "ap.plots.psf_image(fig, ax[0], psf_model)\n", "ax[0].set_title(\"PSF model fit to mock empirical PSF\")\n", - "ap.plots.residual_image(fig, ax[1], psf_model, normalize_residuals=torch.tensor(variance))\n", + "ap.plots.residual_image(fig, ax[1], psf_model, normalize_residuals=True)\n", "ax[1].set_title(\"residuals\")\n", "plt.show()" ] @@ -104,6 +102,59 @@ "cell_type": "markdown", "id": "6", "metadata": {}, + "source": [ + "## Group PSF Model\n", + "\n", + "Just like group models for regular models, it is possible to make a `psf group model` to combine multiple psf models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "psf_model1 = ap.Model(\n", + " name=\"psf1\",\n", + " model_type=\"moffat psf model\",\n", + " n=2,\n", + " Rd=10,\n", + " I0=20, # essentially controls relative flux of this component\n", + " normalize_psf=False, # sub components shouldnt be individually normalized\n", + " target=psf_target,\n", + ")\n", + "psf_model2 = ap.Model(\n", + " name=\"psf2\",\n", + " model_type=\"sersic psf model\",\n", + " n=4,\n", + " Re=5,\n", + " Ie=1,\n", + " normalize_psf=False,\n", + " target=psf_target,\n", + ")\n", + "psf_group_model = ap.Model(\n", + " name=\"psf group\",\n", + " model_type=\"psf group model\",\n", + " target=psf_target,\n", + " models=[psf_model1, psf_model2],\n", + " normalize_psf=True, # group model should normalize the combined PSF\n", + ")\n", + "psf_group_model.initialize()\n", + "fig, ax = plt.subplots(1, 3, figsize=(15, 5))\n", + "ap.plots.psf_image(fig, ax[0], psf_group_model)\n", + "ax[0].set_title(\"PSF group model with two PSF models\")\n", + "ap.plots.psf_image(fig, ax[1], psf_group_model.models[0])\n", + "ax[1].set_title(\"PSF model component 1\")\n", + "ap.plots.psf_image(fig, ax[2], psf_group_model.models[1])\n", + "ax[2].set_title(\"PSF model component 2\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, "source": [ "## PSF modeling without stars\n", "\n", @@ -113,45 +164,49 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "9", "metadata": {}, "outputs": [], "source": [ "# Lets make some data that we need to fit\n", + "psf_target = ap.PSFImage(\n", + " data=np.zeros((51, 51)),\n", + " pixelscale=1.0,\n", + ")\n", "\n", - "true_psf = ap.utils.initialize.moffat_psf(\n", - " 2.0, # n !!!!! Take note, we want to get n = 2. !!!!!!\n", - " 3.0, # Rd !!!!! Take note, we want to get Rd = 3.!!!!!!\n", - " 51, # pixels\n", - " 1.0, # pixelscale\n", + "true_psf_model = ap.Model(\n", + " name=\"true psf\",\n", + " model_type=\"moffat psf model\",\n", + " target=psf_target,\n", + " n=2,\n", + " Rd=3,\n", ")\n", + "true_psf = true_psf_model().data\n", "\n", - "target = ap.image.Target_Image(\n", + "target = ap.TargetImage(\n", " data=torch.zeros(100, 100),\n", " pixelscale=1.0,\n", " psf=true_psf,\n", ")\n", "\n", - "true_model = ap.models.AstroPhot_Model(\n", + "true_model = ap.Model(\n", " name=\"true model\",\n", " model_type=\"sersic galaxy model\",\n", " target=target,\n", - " parameters={\n", - " \"center\": [50.0, 50.0],\n", - " \"q\": 0.4,\n", - " \"PA\": np.pi / 3,\n", - " \"n\": 2,\n", - " \"Re\": 25,\n", - " \"Ie\": 1,\n", - " },\n", - " psf_mode=\"full\",\n", + " center=[50.0, 50.0],\n", + " q=0.4,\n", + " PA=np.pi / 3,\n", + " n=2,\n", + " Re=25,\n", + " Ie=10,\n", + " psf_convolve=True,\n", ")\n", "\n", "# use the true model to make some data\n", "sample = true_model()\n", "torch.manual_seed(61803398)\n", - "target.data = sample.data + torch.normal(torch.zeros_like(sample.data), 0.1)\n", - "target.variance = 0.01 * torch.ones_like(sample.data)\n", + "target._data = sample.data + torch.normal(torch.zeros_like(sample.data), 0.1)\n", + "target.variance = 0.01 * torch.ones_like(sample.data.T)\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(16, 7))\n", "ap.plots.model_image(fig, ax[0], true_model)\n", @@ -164,14 +219,14 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "10", "metadata": {}, "outputs": [], "source": [ "# Now we will try and fit the data using just a plain sersic\n", "\n", "# Here we set up a sersic model for the galaxy\n", - "plain_galaxy_model = ap.models.AstroPhot_Model(\n", + "plain_galaxy_model = ap.Model(\n", " name=\"galaxy model\",\n", " model_type=\"sersic galaxy model\",\n", " target=target,\n", @@ -188,7 +243,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -206,68 +261,63 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "12", "metadata": {}, "outputs": [], "source": [ "# Now we will try and fit the data with a sersic model and a \"live\" psf\n", "\n", "# Here we create a target psf model which will determine the specs of our live psf model\n", - "psf_target = ap.image.PSF_Image(\n", + "psf_target = ap.PSFImage(\n", " data=np.zeros((51, 51)),\n", " pixelscale=target.pixelscale,\n", ")\n", "\n", - "# Here we create a moffat model for the PSF. Note that this is just a regular AstroPhot model that we have chosen\n", - "# to be a moffat, really any model can be used. To make it suitable as a PSF we will need to apply some very\n", - "# specific settings.\n", - "live_psf_model = ap.models.AstroPhot_Model(\n", + "live_psf_model = ap.Model(\n", " name=\"psf\",\n", " model_type=\"moffat psf model\",\n", " target=psf_target,\n", - " parameters={\n", - " \"n\": 1.0, # True value is 2.\n", - " \"Rd\": 2.0, # True value is 3.\n", - " },\n", + " n=1.0, # True value is 2.\n", + " Rd=2.0, # True value is 3.\n", ")\n", "\n", "# Here we set up a sersic model for the galaxy\n", - "live_galaxy_model = ap.models.AstroPhot_Model(\n", + "live_galaxy_model = ap.Model(\n", " name=\"galaxy model\",\n", " model_type=\"sersic galaxy model\",\n", " target=target,\n", - " psf_mode=\"full\",\n", + " psf_convolve=True,\n", " psf=live_psf_model, # Here we bind the PSF model to the galaxy model, this will add the psf_model parameters to the galaxy_model\n", ")\n", - "\n", - "live_psf_model.initialize()\n", "live_galaxy_model.initialize()\n", "\n", - "result = ap.fit.LM(live_galaxy_model, verbose=1).fit()\n", - "result.update_uncertainty()" + "result = ap.fit.LM(live_galaxy_model, verbose=3).fit()" ] }, { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "13", "metadata": {}, "outputs": [], "source": [ "print(\n", - " \"fitted n for moffat PSF: \", live_galaxy_model[\"psf:n\"].value.item(), \"we were hoping to get 2!\"\n", + " f\"fitted n for moffat PSF: {live_psf_model.n.value.item():.6f} +- {live_psf_model.n.uncertainty.item():.6f} we were hoping to get 2!\"\n", ")\n", "print(\n", - " \"fitted Rd for moffat PSF: \",\n", - " live_galaxy_model[\"psf:Rd\"].value.item(),\n", - " \"we were hoping to get 3!\",\n", + " f\"fitted Rd for moffat PSF: {live_psf_model.Rd.value.item():.6f} +- {live_psf_model.Rd.uncertainty.item():.6f} we were hoping to get 3!\"\n", + ")\n", + "fig, ax = ap.plots.covariance_matrix(\n", + " result.covariance_matrix.detach().cpu().numpy(),\n", + " live_galaxy_model.get_values().detach().cpu().numpy(),\n", + " live_galaxy_model.build_params_array_names(),\n", ")\n", - "print(live_galaxy_model.parameters)" + "plt.show()" ] }, { "cell_type": "markdown", - "id": "12", + "id": "14", "metadata": {}, "source": [ "This is truly remarkable! With no stars available we were still able to extract an accurate PSF from the image! To be fair, this example is essentially perfect for this kind of fitting and we knew the true model types (sersic and moffat) from the start. Still, this is a powerful capability in certain scenarios. For many applications (e.g. weak lensing) it is essential to get the absolute best PSF model possible. Here we have shown that not only stars, but galaxies in the field can be useful tools for measuring the PSF!" @@ -276,7 +326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -290,7 +340,7 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "16", "metadata": {}, "source": [ "There are regions of parameter space that are degenerate and so even in this idealized scenario the PSF model can get stuck. If you rerun the notebook with different random number seeds for pytorch you may find some where the optimizer \"fails by immobility\" this is when it gets stuck in the parameter space and can't find any way to improve the likelihood. In fact most of these \"fail\" fits do return really good values for the PSF model, so keep in mind that the \"fail\" flag only means the possibility of a truly failed fit. Unfortunately, detecting convergence is hard." @@ -298,181 +348,8 @@ }, { "cell_type": "markdown", - "id": "15", - "metadata": {}, - "source": [ - "## PSF fitting with a faint star\n", - "\n", - "Fitting a PSF to a galaxy is perhaps not the most stable way to get a good model. However, there is a very common situation where this kind of fitting is quite helpful. Consider the scenario that there is a star, but it is not very bright and it is right next to a galaxy. Now we need to model the galaxy and the star simultaneously, but the galaxy should be convolved with the PSF for the fit to be stable (otherwise you'll have to do several iterations to converge). If there were many stars you could perhaps just stack a bunch of them and hope the average is close enough, but in this case we don't have many to work with so we need to squeeze out as much statistical power as possible. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": {}, - "outputs": [], - "source": [ - "# Lets make some data that we need to fit\n", - "\n", - "true_psf2 = ap.utils.initialize.moffat_psf(\n", - " 2.0, # n !!!!! Take note, we want to get n = 2. !!!!!!\n", - " 3.0, # Rd !!!!! Take note, we want to get Rd = 3.!!!!!!\n", - " 51, # pixels\n", - " 1.0, # pixelscale\n", - ")\n", - "\n", - "target2 = ap.image.Target_Image(\n", - " data=torch.zeros(100, 100),\n", - " pixelscale=1.0,\n", - " psf=true_psf,\n", - ")\n", - "\n", - "true_galaxy2 = ap.models.AstroPhot_Model(\n", - " name=\"true galaxy\",\n", - " model_type=\"sersic galaxy model\",\n", - " target=target2,\n", - " parameters={\n", - " \"center\": [50.0, 50.0],\n", - " \"q\": 0.4,\n", - " \"PA\": np.pi / 3,\n", - " \"n\": 2,\n", - " \"Re\": 25,\n", - " \"Ie\": 1,\n", - " },\n", - " psf_mode=\"full\",\n", - ")\n", - "true_star2 = ap.models.AstroPhot_Model(\n", - " name=\"true star\",\n", - " model_type=\"point model\",\n", - " target=target2,\n", - " parameters={\n", - " \"center\": [70, 70],\n", - " \"flux\": 2.0,\n", - " },\n", - ")\n", - "true_model2 = ap.models.AstroPhot_Model(\n", - " name=\"true model\",\n", - " model_type=\"group model\",\n", - " target=target2,\n", - " models=[true_galaxy2, true_star2],\n", - ")\n", - "\n", - "# use the true model to make some data\n", - "sample2 = true_model2()\n", - "torch.manual_seed(1618033988)\n", - "target2.data = sample2.data + torch.normal(torch.zeros_like(sample2.data), 0.1)\n", - "target2.variance = 0.01 * torch.ones_like(sample2.data)\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(16, 7))\n", - "ap.plots.model_image(fig, ax[0], true_model2)\n", - "ap.plots.target_image(fig, ax[1], target2)\n", - "ax[0].set_title(\"true model\")\n", - "ax[1].set_title(\"mock observed data\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, "id": "17", "metadata": {}, - "outputs": [], - "source": [ - "# Now we will try and fit the data\n", - "\n", - "psf_model2 = ap.models.AstroPhot_Model(\n", - " name=\"psf\",\n", - " model_type=\"moffat psf model\",\n", - " target=psf_target,\n", - " parameters={\n", - " \"n\": 1.0, # True value is 2.\n", - " \"Rd\": 2.0, # True value is 3.\n", - " },\n", - ")\n", - "\n", - "# Here we set up a sersic model for the galaxy\n", - "galaxy_model2 = ap.models.AstroPhot_Model(\n", - " name=\"galaxy model\",\n", - " model_type=\"sersic galaxy model\",\n", - " target=target,\n", - " psf_mode=\"full\",\n", - " psf=psf_model2,\n", - ")\n", - "\n", - "# Let AstroPhot determine its own initial parameters, so it has to start with whatever it decides automatically,\n", - "# just like a real fit.\n", - "galaxy_model2.initialize()\n", - "\n", - "star_model2 = ap.models.AstroPhot_Model(\n", - " name=\"star model\",\n", - " model_type=\"point model\",\n", - " target=target2,\n", - " psf=psf_model2,\n", - " parameters={\n", - " \"center\": [70, 70], # start the star in roughly the right location\n", - " \"flux\": 2.5,\n", - " },\n", - ")\n", - "\n", - "star_model2.initialize()\n", - "\n", - "full_model2 = ap.models.AstroPhot_Model(\n", - " name=\"full model\",\n", - " model_type=\"group model\",\n", - " models=[galaxy_model2, star_model2],\n", - " target=target2,\n", - ")\n", - "\n", - "result = ap.fit.LM(full_model2, verbose=1).fit()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 2, figsize=(16, 7))\n", - "ap.plots.model_image(fig, ax[0], full_model2)\n", - "ap.plots.residual_image(fig, ax[1], full_model2)\n", - "ax[0].set_title(\"fitted sersic+star model\")\n", - "ax[1].set_title(\"residuals\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"fitted n for moffat PSF: \", galaxy_model2[\"psf:n\"].value.item(), \"we were hoping to get 2!\")\n", - "print(\n", - " \"fitted Rd for moffat PSF: \", galaxy_model2[\"psf:Rd\"].value.item(), \"we were hoping to get 3!\"\n", - ")\n", - "\n", - "print(\n", - " \"---Note that we can just as well look at the star model parameters since they are the same---\"\n", - ")\n", - "print(\"fitted n for moffat PSF: \", psf_model2[\"n\"].value.item(), \"we were hoping to get 2!\")\n", - "print(\"fitted Rd for moffat PSF: \", psf_model2[\"Rd\"].value.item(), \"we were hoping to get 3!\")" - ] - }, - { - "cell_type": "markdown", - "id": "20", - "metadata": {}, - "source": [ - "Note that the fitted moffat parameters aren't much better than they were earlier when we just fit the galaxy alone. This shows us that extended objects have plenty of constraining power when it comes to PSF fitting, all this information has previously been left on the table! It makes sense that the galaxy dominates the PSF fit here, while the star is very simple to fit, it has much less light than the galaxy in this scenario so the S/N for the galaxy dominates. The reason this works really well is of course that the true data is in fact a sersic model, so we are working in a very idealized scenario. Real world galaxies are not necessarily well described by a sersic, so it is worthwhile to be cautious when doing this kind of fitting. Always make sure the results make sense before storming ahead with galaxy based PSF models, that said the payoff can be well worth it." - ] - }, - { - "cell_type": "markdown", - "id": "21", - "metadata": {}, "source": [ "## PSF fitting for faint stars\n", "\n", @@ -482,7 +359,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -491,7 +368,7 @@ }, { "cell_type": "markdown", - "id": "23", + "id": "19", "metadata": {}, "source": [ "## PSF fitting for saturated stars\n", @@ -502,7 +379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "20", "metadata": {}, "outputs": [], "source": [ diff --git a/docs/source/tutorials/BasicPSFModels.ipynb b/docs/source/tutorials/BasicPSFModels.ipynb index d11cac0e..2b328687 100644 --- a/docs/source/tutorials/BasicPSFModels.ipynb +++ b/docs/source/tutorials/BasicPSFModels.ipynb @@ -22,7 +22,6 @@ "\n", "import astrophot as ap\n", "import numpy as np\n", - "import torch\n", "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline" @@ -35,10 +34,10 @@ "source": [ "## PSF Images\n", "\n", - "A `PSF_Image` is an AstroPhot object which stores the data for a PSF. It records the pixel values for the PSF as well as meta-data like the pixelscale at which it was taken. The point source function (PSF) is a description of how light is distributed into pixels when the light source is a delta function. In Astronomy we are blessed/cursed with many delta function like sources in our images and so PSF modelling is a major component of astronomical image analysis. Here are some points to keep in mind about a PSF.\n", + "A `PSFImage` is an AstroPhot object which stores the data for a PSF. It records the pixel values for the PSF as well as meta-data like the pixelscale at which it was taken. The point source function (PSF) is a description of how light is distributed into pixels when the light source is a delta function. In Astronomy we are blessed/cursed with many delta function like sources in our images and so PSF modelling is a major component of astronomical image analysis. Here are some points to keep in mind about a PSF.\n", "\n", "- PSF images are always odd in shape (e.g. 25x25 pixels, not 24x24 pixels), at the center pixel, in the center of that pixel is where the delta function point source is located by definition\n", - "- In AstroPhot, the coordinates of the center of the center pixel in a `PSF_Image` are always (0,0). \n", + "- In AstroPhot, the coordinates of the center of the center pixel in a `PSFImage` are always (0,0). \n", "- The light in each pixel of a PSF image is already integrated. That is to say, the flux value for a pixel does not represent some model evaluated at the center of the pixel, it instead represents an integral over the whole area of the pixel" ] }, @@ -56,7 +55,7 @@ "psf += np.random.normal(scale=psf / 4)\n", "psf[psf < 0] = ap.utils.initialize.gaussian_psf(2.0, 101, 0.5)[psf < 0]\n", "\n", - "psf_target = ap.image.PSF_Image(\n", + "psf_target = ap.PSFImage(\n", " data=psf,\n", " pixelscale=0.5,\n", ")\n", @@ -70,7 +69,7 @@ "plt.show()\n", "\n", "# Dummy target for sampling purposes\n", - "target = ap.image.Target_Image(data=np.zeros((300, 300)), pixelscale=0.5, psf=psf_target)" + "target = ap.TargetImage(data=np.zeros((300, 300)), pixelscale=0.5, psf=psf_target)" ] }, { @@ -90,17 +89,18 @@ "metadata": {}, "outputs": [], "source": [ - "pointsource = ap.models.AstroPhot_Model(\n", + "pointsource = ap.Model(\n", " model_type=\"point model\",\n", " target=target,\n", - " parameters={\"center\": [75, 75], \"flux\": 1},\n", + " center=[75.25, 75.9],\n", + " flux=1,\n", " psf=psf_target,\n", ")\n", "pointsource.initialize()\n", "# With a convolved sersic the center is much more smoothed out\n", "fig, ax = plt.subplots(figsize=(6, 6))\n", - "ap.plots.model_image(fig, ax, pointsource)\n", - "ax.set_title(\"Point source, convolved with empirical PSF\")\n", + "ap.plots.model_image(fig, ax, pointsource, showcbar=False)\n", + "ax.set_title(\"Point source, with empirical PSF\")\n", "plt.show()" ] }, @@ -108,6 +108,14 @@ "cell_type": "markdown", "id": "6", "metadata": {}, + "source": [ + "Don't worry about the \"fuzz\" of values outside the PSF model. These values are of order 1e-18 and are an artefact of the sub-pixel shift using the FFT shift theorem. They may be treated as zero for numerical purposes." + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, "source": [ "## Extended model PSF convolution\n", "\n", @@ -117,38 +125,53 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": {}, "outputs": [], "source": [ - "model_nopsf = ap.models.AstroPhot_Model(\n", + "model_nopsf = ap.Model(\n", " model_type=\"sersic galaxy model\",\n", " target=target,\n", - " parameters={\"center\": [75, 75], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"n\": 3, \"Re\": 10, \"Ie\": 1},\n", - " psf_mode=\"none\", # no PSF convolution will be done\n", + " center=[75, 75],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=3,\n", + " Re=10,\n", + " Ie=10,\n", + " psf_convolve=False, # no PSF convolution will be done\n", ")\n", "model_nopsf.initialize()\n", - "model_psf = ap.models.AstroPhot_Model(\n", + "model_psf = ap.Model(\n", " model_type=\"sersic galaxy model\",\n", " target=target,\n", - " parameters={\"center\": [75, 75], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"n\": 3, \"Re\": 10, \"Ie\": 1},\n", - " psf_mode=\"full\", # now the full window will be PSF convolved using the PSF from the target\n", + " center=[75, 75],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=3,\n", + " Re=10,\n", + " Ie=10,\n", + " psf_convolve=True, # now the full window will be PSF convolved using the PSF from the target\n", ")\n", "model_psf.initialize()\n", "\n", "psf = psf.copy()\n", "psf[49:51] += 4 * np.mean(psf)\n", "psf[:, 49:51] += 4 * np.mean(psf)\n", - "psf_target_2 = ap.image.PSF_Image(\n", + "psf_target_2 = ap.PSFImage(\n", " data=psf,\n", " pixelscale=0.5,\n", ")\n", "psf_target_2.normalize()\n", - "model_selfpsf = ap.models.AstroPhot_Model(\n", + "model_selfpsf = ap.Model(\n", " model_type=\"sersic galaxy model\",\n", " target=target,\n", - " parameters={\"center\": [75, 75], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"n\": 3, \"Re\": 10, \"Ie\": 1},\n", - " psf_mode=\"full\",\n", + " center=[75, 75],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=3,\n", + " Re=10,\n", + " Ie=10,\n", + " psf_convolve=True,\n", " psf=psf_target_2, # Now this model has its own PSF, instead of using the target psf\n", ")\n", "model_selfpsf.initialize()\n", @@ -166,7 +189,48 @@ }, { "cell_type": "markdown", - "id": "8", + "id": "9", + "metadata": {}, + "source": [ + "## Supersampled PSF models\n", + "\n", + "It is generally best practice to use a PSF model that has been determined at a higher resolution than the image you are analyzing. In AstroPhot this can be easily handled by ensuring that the `PSFImage` has an appropriate pixelscale that shows how it is upsampled. For example if our target has a pixelscale of 0.5 and the PSFImage has a pixelscale of 0.25 then AstroPhot will automatically infer that it should work at 2x higher resolution. Note that AstroPhot assumes the PSF has been determined at an integer level of upsampling, so in the example if you set the PSFImage pixelscale to 0.3 then strange things would likely happen to your images!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "upsample_psf_target = ap.PSFImage(\n", + " data=ap.utils.initialize.gaussian_psf(2.0, 51, 0.25),\n", + " pixelscale=0.25, # This PSF is at a higher resolution than the target\n", + ")\n", + "target.psf = upsample_psf_target\n", + "\n", + "model_upsamplepsf = ap.Model(\n", + " model_type=\"sersic galaxy model\",\n", + " target=target,\n", + " center=[75, 75],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=3,\n", + " Re=10,\n", + " Ie=10,\n", + " psf_convolve=True,\n", + ")\n", + "model_upsamplepsf.initialize()\n", + "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", + "ap.plots.model_image(fig, ax, model_upsamplepsf)\n", + "ax.set_title(\"With PSF convolution (upsampled PSF)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "11", "metadata": {}, "source": [ "That covers the basics of adding PSF convolution kernels to AstroPhot models! These techniques assume you already have a model for the PSF that you got with some other algorithm (ie PSFEx), however AstroPhot also has the ability to model the PSF live along with the rest of the models in an image. If you are interested in extracting the PSF from an image using AstroPhot, check out the `AdvancedPSFModels` tutorial. " @@ -175,7 +239,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "12", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/source/tutorials/ConstrainedModels.ipynb b/docs/source/tutorials/ConstrainedModels.ipynb index 67297a59..599df83e 100644 --- a/docs/source/tutorials/ConstrainedModels.ipynb +++ b/docs/source/tutorials/ConstrainedModels.ipynb @@ -29,7 +29,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Range limits\n", + "## Valid Range\n", "\n", "The simplest form of constraint on a parameter is to restrict its range to within some limit. This is done at creation of the variable and you simply indicate the endpoints (non-inclusive) of the limits." ] @@ -40,24 +40,25 @@ "metadata": {}, "outputs": [], "source": [ - "target = ap.image.Target_Image(data=np.zeros((100, 100)), center=[0, 0], pixelscale=1)\n", - "gal1 = ap.models.AstroPhot_Model(\n", + "target = ap.TargetImage(data=np.zeros((100, 100)), crpix=[49.5, 49.5], pixelscale=1)\n", + "gal1 = ap.Model(\n", " name=\"galaxy1\",\n", " model_type=\"sersic galaxy model\",\n", - " parameters={\n", - " \"center\": {\n", - " \"value\": [0, 0],\n", - " \"limits\": [[-10, -20], [10, 20]],\n", - " }, # here we set the limits, note it can be different for each value\n", + " # here we set the limits, note it can be different for each value of center.\n", + " # The valid range is a tuple with two elements, the lower limit and the\n", + " # upper limit, either can be None\n", + " center={\n", + " \"value\": [0, 0],\n", + " \"valid\": ([-10, -20], [10, 20]),\n", " },\n", + " # One sided limits can be used for example to ensure a value is positive\n", + " Re={\"valid\": (0, None)},\n", " target=target,\n", ")\n", "\n", - "# Now if we try to set a value outside the range we get an error\n", - "try:\n", - " gal1[\"center\"].value = [25, 25]\n", - "except ap.errors.InvalidParameter as e:\n", - " print(\"got an AssertionError with message: \", e)" + "# Now if we try to set a value outside the range we get a warning\n", + "gal1.center.value = [25, 25]\n", + "gal1.center.value = [0, 0] # set back to good value" ] }, { @@ -82,37 +83,44 @@ "metadata": {}, "outputs": [], "source": [ - "target = ap.image.Target_Image(data=np.zeros((100, 100)), center=[0, 0], pixelscale=1)\n", - "gal1 = ap.models.AstroPhot_Model(\n", + "gal1 = ap.Model(\n", " name=\"galaxy1\",\n", " model_type=\"sersic galaxy model\",\n", - " parameters={\"center\": [-25, -25], \"PA\": 0, \"q\": 0.9, \"n\": 2, \"Re\": 5, \"Ie\": 1.0},\n", + " center=[-25, -25],\n", + " PA=0,\n", + " q=0.9,\n", + " n=2,\n", + " Re=5,\n", + " Ie=1.0,\n", " target=target,\n", ")\n", - "gal2 = ap.models.AstroPhot_Model(\n", + "gal2 = ap.Model(\n", " name=\"galaxy2\",\n", " model_type=\"sersic galaxy model\",\n", - " parameters={\"center\": [25, 25], \"PA\": 0, \"q\": 0.9, \"Ie\": 1.0},\n", + " center=[25, 25],\n", + " PA=0,\n", + " q=0.9,\n", + " Ie=1.0,\n", " target=target,\n", ")\n", "\n", "# here we set the equality constraint, setting the values for gal2 equal to the parameters of gal1\n", - "gal2[\"n\"].value = gal1[\"n\"]\n", - "gal2[\"Re\"].value = gal1[\"Re\"]\n", + "gal2.n = gal1.n\n", + "gal2.Re = gal1.Re\n", "\n", "# we make a group model to use both star models together\n", - "gals = ap.models.AstroPhot_Model(\n", + "gals = ap.Model(\n", " name=\"gals\",\n", " model_type=\"group model\",\n", " models=[gal1, gal2],\n", " target=target,\n", ")\n", "\n", - "print(gals.parameters)\n", - "\n", "fig, ax = plt.subplots()\n", "ap.plots.model_image(fig, ax, gals)\n", - "plt.show()" + "plt.show()\n", + "\n", + "gals.graphviz()" ] }, { @@ -122,7 +130,7 @@ "outputs": [], "source": [ "# We can now change a parameter value and both models will change\n", - "gal1[\"n\"].value = 1\n", + "gal1.n.value = 1\n", "\n", "fig, ax = plt.subplots()\n", "ap.plots.model_image(fig, ax, gals)\n", @@ -146,7 +154,7 @@ "\n", "- A spatially varying PSF can be forced to obey some smoothing function such as a plane or spline\n", "- The SED of a multiband fit may be constrained to follow some pre-determined form\n", - "- An astrometry correction in multi-image fitting can be included for each image to ensure precise alignment\n", + "- A light curve model could be used to constrain the brightness in a multi-epoch analysis\n", "\n", "The possibilities with this kind of constraint capability are quite extensive. If you do something creative with these functional constraints please let us know!" ] @@ -158,67 +166,54 @@ "outputs": [], "source": [ "# Here we will demo a spatially varying PSF where the moffat \"n\" parameter changes across the image\n", - "target = ap.image.Target_Image(data=np.zeros((100, 100)), center=[0, 0], pixelscale=1)\n", + "target = ap.TargetImage(data=np.zeros((100, 100)), crpix=[49.5, 49.5], pixelscale=1)\n", + "\n", + "psf_target = ap.PSFImage(data=np.zeros((55, 55)), pixelscale=1)\n", + "\n", + "# We make parameters and a function to control the moffat n parameter\n", + "intercept = ap.Param(\"intercept\", 3)\n", + "slope = ap.Param(\"slope\", [1 / 50, -1 / 50])\n", + "\n", + "\n", + "def constrained_moffat_n(n_param):\n", + " return n_param.intercept.value + torch.sum(n_param.slope.value * n_param.center.value)\n", "\n", - "psf_target = ap.image.PSF_Image(data=np.zeros((25, 25)), pixelscale=1)\n", "\n", - "# First we make all the star objects\n", + "# Next we make all the star and PSF objects\n", "allstars = []\n", "allpsfs = []\n", "for x in [-30, 0, 30]:\n", " for y in [-30, 0, 30]:\n", - " allpsfs.append(\n", - " ap.models.AstroPhot_Model(\n", - " name=\"psf\",\n", - " model_type=\"moffat psf model\",\n", - " parameters={\"Rd\": 2},\n", - " target=psf_target,\n", - " )\n", + " psf = ap.Model(\n", + " name=\"psf\",\n", + " model_type=\"moffat psf model\",\n", + " Rd=2,\n", + " n={\"value\": constrained_moffat_n},\n", + " target=psf_target,\n", " )\n", + " if len(allstars) > 0:\n", + " psf.Rd = allstars[0].psf.Rd\n", " allstars.append(\n", - " ap.models.AstroPhot_Model(\n", + " ap.Model(\n", " name=f\"star {x} {y}\",\n", " model_type=\"point model\",\n", - " parameters={\"center\": [x, y], \"flux\": 1},\n", + " center=[x, y],\n", + " flux=1,\n", " target=target,\n", - " psf=allpsfs[-1],\n", + " psf=psf,\n", " )\n", " )\n", - " allpsfs[-1][\"n\"].link(\n", - " allstars[-1][\"center\"]\n", - " ) # see we need to link the center as well so that it can be used in the function\n", - "\n", - "# we link the Rd parameter for all the PSFs so that they are the same\n", - "for psf in allpsfs[1:]:\n", - " psf[\"Rd\"].value = allpsfs[0][\"Rd\"]\n", - "\n", - "# next we need the parameters for the spatially varying PSF plane\n", - "P_intercept = ap.param.Parameter_Node(\n", - " name=\"intercept\",\n", - " value=3,\n", - ")\n", - "P_slope = ap.param.Parameter_Node(\n", - " name=\"slope\",\n", - " value=[1 / 50, -1 / 50],\n", - ")\n", - "\n", - "\n", - "# next we define the function which takes the parameters as input and returns the value for n\n", - "def constrained_moffat_n(params):\n", - " return params[\"intercept\"].value + torch.sum(params[\"slope\"].value * params[\"center\"].value)\n", "\n", + " # see we need to link the center as well so that it can be used in the function\n", + " psf.n.link((intercept, slope, allstars[-1].center))\n", "\n", - "# finally we assign this parameter function to the \"n\" parameter for each moffat\n", - "for psf in allpsfs:\n", - " psf[\"n\"].value = constrained_moffat_n\n", - " psf[\"n\"].link(P_intercept)\n", - " psf[\"n\"].link(P_slope)\n", "\n", "# A group model holds all the stars together\n", - "MODEL = ap.models.AstroPhot_Model(\n", + "sky = ap.Model(name=\"sky\", model_type=\"flat sky model\", I=1e-5, target=target)\n", + "MODEL = ap.Model(\n", " name=\"spatial PSF\",\n", " model_type=\"group model\",\n", - " models=allstars,\n", + " models=[sky] + allstars,\n", " target=target,\n", ")\n", "\n", diff --git a/docs/source/tutorials/CustomModels.ipynb b/docs/source/tutorials/CustomModels.ipynb index f8d64c5f..9ff0dfd6 100644 --- a/docs/source/tutorials/CustomModels.ipynb +++ b/docs/source/tutorials/CustomModels.ipynb @@ -6,13 +6,44 @@ "source": [ "# Custom model objects\n", "\n", - "Here we will go over some of the core functionality of AstroPhot models so that you can make your own custom models with arbitrary behavior. This is an advanced tutorial and likely not needed for most users. However, the flexibility of AstroPhot can be a real lifesaver for some niche applications! If you get stuck trying to make your own models, please contact Connor Stone (see GitHub), he can help you get the model working and maybe even help add it to the core AstroPhot model list!\n", + "Here we will go over some of the core functionality of AstroPhot models so that\n", + "you can make your own custom models with arbitrary behavior. This is an advanced\n", + "tutorial and likely not needed for most users. However, the flexibility of\n", + "AstroPhot can be a real lifesaver for some niche applications! If you get stuck\n", + "trying to make your own models, please contact Connor Stone (see GitHub), he can\n", + "help you get the model working and maybe even help add it to the core AstroPhot\n", + "model list!\n", "\n", "### AstroPhot model hierarchy\n", "\n", - "AstroPhot models are very much object oriented and inheritance driven. Every AstroPhot model inherits from `AstroPhot_Model` and so if you wish to make something truly original then this is where you would need to start. However, it is almost certain that is the wrong way to go. Further down the hierarchy is the `Component_Model` object, this is what you will likely use to construct a custom model as it represents a single \"unit\" in the astronomical image. Spline, Sersic, Exponential, Gaussian, PSF, Sky, etc. all of these inherit from `Component_Model` so likely that's what you will want. At its core, a `Component_Model` object defines a center location for the model, but it doesn't know anything else yet. At the same level as `Component_Model` is `Group_Model` which represents a collection of model objects (typically but not always `Component_Model` objects). A `Group_Model` is how you construct more complex models by composing several simpler models. It's unlikely you'll need to inherit from `Group_Model` so we won't discuss this any further (contact the developers if you're thinking about that). \n", + "AstroPhot models are very much object oriented and inheritance driven. Every\n", + "AstroPhot model inherits from `Model` and so if you wish to make something truly\n", + "original then this is where you would need to start. However, it is almost\n", + "certain that is the wrong way to go. Further down the hierarchy is the\n", + "`ComponentModel` object, this is what you will likely use to construct a custom\n", + "model as it represents a single \"unit\" in the astronomical image. Spline,\n", + "Sersic, Exponential, Gaussian, PSF, Sky, etc. all of these inherit from\n", + "`ComponentModel` so likely that's what you will want. At its core, a\n", + "`ComponentModel` object defines a center location for the model, but it doesn't\n", + "know anything else yet. At the same level as `ComponentModel` is `GroupModel`\n", + "which represents a collection of model objects (typically but not always\n", + "`ComponentModel` objects). A `GroupModel` is how you construct more complex\n", + "models by composing several simpler models. It's unlikely you'll need to inherit\n", + "from `GroupModel` so we won't discuss this any further (contact the developers\n", + "if you're thinking about that). \n", "\n", - "Inheriting from `Component_Model` are a few general classes which make it easier to build typical cases. There is the `Galaxy_Model` which adds a position angle and axis ratio to the model; also `Star_Model` which simply enforces no psf convolution on the object since that will be handled internally for anything star like; `Sky_Model` should be used for anything low resolution defined over the entire image, in this model psf convolution and integration are turned off since they shouldn't be needed. Based on these low level classes, you can \"jump in\" where it makes sense to define your model. Of course, you can take any AstroPhot model as a starting point and modify it to suit a given task, however we will not list all models here. See the documentation for a more complete list." + "Inheriting from `ComponentModel` are a few general classes which make it easier\n", + "to build typical cases. There is the `GalaxyModel` which adds a position angle\n", + "and axis ratio to the model; also `PointSource` which simply enforces some\n", + "restrictions that make more sense for a delta function model; `SkyModel` should\n", + "be used for anything low resolution defined over the entire image, in this model\n", + "psf convolution and sub-pixel integration are turned off since they shouldn't be\n", + "needed. Based on these low level classes, you can \"jump in\" where it makes sense\n", + "to define your model. If you are looking to define a sersic that has some\n", + "slightly different behaviour you may be able to take the `SersicGalaxy` class\n", + "and directly make your modification. Of course, you can take any AstroPhot model\n", + "as a starting point and modify it to suit a given task, however we will not list\n", + "all models here. See the documentation for a more complete list." ] }, { @@ -35,10 +66,9 @@ "from astropy.io import fits\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import socket\n", "\n", - "ap.AP_config.set_logging_output(\n", - " stdout=True, filename=None\n", - ") # see GettingStarted tutorial for what this does" + "socket.setdefaulttimeout(120)" ] }, { @@ -47,38 +77,26 @@ "metadata": {}, "outputs": [], "source": [ - "class My_Sersic(ap.models.Galaxy_Model):\n", + "class My_Sersic(ap.models.RadialMixin, ap.models.GalaxyModel):\n", " \"\"\"Let's make a sersic model!\"\"\"\n", "\n", - " model_type = f\"mysersic {ap.models.Galaxy_Model.model_type}\" # here we give a name to the model, the convention is to lead with a new identifier then include the name of the inheritance model\n", - " parameter_specs = {\n", - " \"my_n\": {\n", - " \"limits\": (0.36, 8)\n", - " }, # our sersic index will have some default limits so it doesn't produce weird results\n", - " \"my_Re\": {\n", - " \"limits\": (0, None)\n", - " }, # our effective radius must be positive, otherwise it is fair game\n", - " \"my_Ie\": {}, # our effective surface density could be any real number\n", + " _model_type = \"mysersic\" # here we give a name to the model, since we inherit from GalaxyModel the full model_type will be \"mysersic galaxy model\"\n", + " _parameter_specs = {\n", + " # our sersic index will have some default limits so it doesn't produce\n", + " # weird results We also indicate the expected shapeof the parameter, in\n", + " # this case a scalar. This isn't necessary but it gives AstroPhot more\n", + " # information to work with. if e.g. you accidentaly provide multiple\n", + " # values, you'll now get an error rather than confusing behavior later.\n", + " \"my_n\": {\"valid\": (0.36, 8), \"shape\": (), \"dynamic\": True},\n", + " \"my_Re\": {\"units\": \"arcsec\", \"valid\": (0, None), \"shape\": (), \"dynamic\": True},\n", + " \"my_Ie\": {\"units\": \"flux/arcsec^2\", \"dynamic\": True},\n", " }\n", - " _parameter_order = ap.models.Galaxy_Model._parameter_order + (\n", - " \"my_n\",\n", - " \"my_Re\",\n", - " \"my_Ie\",\n", - " ) # we have to tell AstroPhot what order to access these parameters, this is used in several underlying methods\n", "\n", - " def radial_model(\n", - " self, R, image=None, parameters=None\n", - " ): # by default a Galaxy_Model object will call radial_model to determine the flux at each pixel\n", - " bn = ap.utils.conversions.functions.sersic_n_to_b(\n", - " parameters[\"my_n\"].value\n", - " ) # AstroPhot has a number of useful util functions, though you are welcome to use your own\n", - " return (\n", - " parameters[\"my_Ie\"].value\n", - " * (image.pixel_area)\n", - " * torch.exp(\n", - " -bn * ((R / parameters[\"my_Re\"].value) ** (1.0 / parameters[\"my_n\"].value) - 1)\n", - " )\n", - " ) # this is simply the classic sersic profile. more details later." + " # a GalaxyModel object will determine the radius for each pixel then call radial_model to determine the brightness\n", + " @ap.forward\n", + " def radial_model(self, R, my_n, my_Re, my_Ie):\n", + " bn = ap.models.func.sersic_n_to_b(my_n)\n", + " return my_Ie * torch.exp(-bn * ((R / my_Re) ** (1.0 / my_n) - 1))" ] }, { @@ -99,15 +117,8 @@ ")\n", "target_data = np.array(hdu[0].data, dtype=np.float64)\n", "\n", - "# Create a target object with specified pixelscale and zeropoint\n", - "target = ap.image.Target_Image(\n", - " data=target_data,\n", - " pixelscale=0.262,\n", - " zeropoint=22.5,\n", - " variance=np.ones(target_data.shape) / 1e3,\n", - ")\n", + "target = ap.TargetImage(data=target_data, pixelscale=0.262, zeropoint=22.5, variance=\"auto\")\n", "\n", - "# The default AstroPhot target plotting method uses log scaling in bright areas and histogram scaling in faint areas\n", "fig, ax = plt.subplots(figsize=(8, 8))\n", "ap.plots.target_image(fig, ax, target)\n", "plt.show()" @@ -122,11 +133,10 @@ "my_model = My_Sersic( # notice we are now using the custom class\n", " name=\"wow I made a model\",\n", " target=target, # now the model knows what its trying to match\n", - " parameters={\n", - " \"my_n\": 1.0,\n", - " \"my_Re\": 50,\n", - " \"my_Ie\": 1.0,\n", - " }, # note we have to give initial values for our new parameters. We'll see what can be done for this later\n", + " # note we have to give initial values for our new parameters. AstroPhot doesn't know how to auto-initialize them because they are custom\n", + " my_n=1.0,\n", + " my_Re=50,\n", + " my_Ie=1.0,\n", ")\n", "\n", "# We gave it parameters for our new variables, but initialize will get starting values for everything else\n", @@ -167,11 +177,24 @@ "source": [ "Success! Our \"custom\" sersic model behaves exactly as expected. While going through the tutorial so far there may have been a few things that stood out to you. Lets discuss them now:\n", "\n", - "- What was \"sample_image\" in the radial_model function? This is an object for the image that we are currently sampling. You shouldn't need to do anything with it except get the pixelscale.\n", - "- what else is in \"ap.utils\"? Lots of stuff used in the background by AstroPhot. For now the organization of these is not very good and sometimes changes, so you may wish to just make your own functions for the time being.\n", - "- Why the weird way to access the parameters? The self\\[\"variable\"\\].value format was settled on for simplicity and generality. it's not perfect, but it works.\n", - "- Why is \"sample_image.pixel_area\" in the sersic evaluation? it is important for AstroPhot to know the size of the pixels it is evaluating, multiplying by this value will normalize the flux evaluation regardless of the pixel sizes.\n", - "- When making the model, why did we have to provide values for the parameters? Every model can define an \"initialize\" function which sets the values for its parameters. Since we didn't add that function to our custom class, it doesn't know how to set those variables. All the other variables can be auto-initialized though." + "- What is `ap.models.RadialMixin`? Think of \"Mixin's\" as power ups for classes,\n", + " this power up makes a `brightness` function which calls `radial_model` to\n", + " determine the flux density, that way you only need to define a radial function\n", + " rather than a more general `brightness(x,y)` 2D function.\n", + "- what else is in \"ap.models.func\"? Lots of stuff used in the background by\n", + " AstroPhot models. There is a similar `ap.image.func` for image specific\n", + " functions. You can use these, or write your own functions.\n", + "- How did the `radial_model` function accept the parameters I defined in\n", + " `_parameter_specs`? That's the work of `caskade` a powerful parameter\n", + " management tool.\n", + "- When making the model, why did we have to provide values for the parameters?\n", + " Every model can define an \"initialize\" function which sets the values for its\n", + " parameters. Since we didn't add that function to our custom class, it doesn't\n", + " know how to set those variables. All the other variables can be\n", + " auto-initialized though.\n", + "- Why is `radial_model` decorated with `@ap.forward`? This is part of the\n", + " `caskade` system, the `@ap.forward` here does a lot of heavily lifting\n", + " automatically to fill in values for `my_n`, `my_Re`, and `my_Ie`" ] }, { @@ -189,49 +212,34 @@ "metadata": {}, "outputs": [], "source": [ - "class My_Super_Sersic(\n", - " My_Sersic\n", - "): # note we're inheriting everything from the My_Sersic model since its not making any new parameters\n", - " model_type = \"super awesome sersic model\" # you can make the name anything you like, but the one above follows the normal convention\n", + "# note we're inheriting everything from the My_Sersic model since its not making any new parameters\n", + "class My_Super_Sersic(My_Sersic):\n", + " _model_type = \"super\" # the new name will be \"super mysersic galaxy model\"\n", "\n", - " def initialize(self, target=None, parameters=None):\n", - " if target is None: # good to just use the model target if none given\n", - " target = self.target\n", - " if parameters is None:\n", - " parameters = self.parameters\n", - " super().initialize(\n", - " target=target, parameters=parameters\n", - " ) # typically you want all the lower level parameters determined first\n", + " def initialize(self):\n", + " # typically you want all the lower level parameters determined first\n", + " super().initialize()\n", "\n", - " target_area = target[\n", - " self.window\n", - " ] # this gets the part of the image that the user actually wants us to analyze\n", + " # this gets the part of the image that the user actually wants us to analyze\n", + " target_area = target[self.window]\n", "\n", - " if self[\"my_n\"].value is None: # only do anything if the user didn't provide a value\n", - " with ap.param.Param_Unlock(parameters[\"my_n\"]):\n", - " parameters[\"my_n\"].value = (\n", - " 2.0 # make an initial value for my_n. Override locked since this is the beginning\n", - " )\n", - " parameters[\"my_n\"].uncertainty = (\n", - " 0.1 # make sure there is a starting point for the uncertainty too\n", - " )\n", + " # only initialize if the user didn't already provide a value\n", + " if not self.my_n.initialized:\n", + " # make an initial value for my_n. It's a \"dynamic_value\" so it can be optimized later\n", + " self.my_n.value = 2.0\n", "\n", - " if (\n", - " self[\"my_Re\"].value is None\n", - " ): # same as my_n, though in general you should try to do something smart to get a good starting point\n", - " with ap.param.Param_Unlock(parameters[\"my_Re\"]):\n", - " parameters[\"my_Re\"].value = 20.0\n", - " parameters[\"my_Re\"].uncertainty = 0.1\n", + " if not self.my_Re.initialized:\n", + " self.my_Re.value = 20.0\n", "\n", - " if self[\"my_Ie\"].value is None: # lets try to be a bit clever here\n", - " small_window = self.window.copy().crop_pixel(\n", - " (250,)\n", - " ) # This creates a window much smaller, but still centered on the same point\n", - " with ap.param.Param_Unlock(parameters[\"my_Ie\"]):\n", - " parameters[\"my_Ie\"].value = (\n", - " torch.median(target_area[small_window].data) / target_area.pixel_area\n", - " ) # this will be an average in the window, should at least get us within an order of magnitude\n", - " parameters[\"my_Ie\"].uncertainty = 0.1" + " # lets try to be a bit clever here. This will be an average in the\n", + " # window, should at least get us within an order of magnitude\n", + " if not self.my_Ie.initialized:\n", + " center = target_area.plane_to_pixel(*self.center.value)\n", + " i, j = int(center[0].item()), int(center[1].item())\n", + " self.my_Ie.value = (\n", + " torch.median(target_area.data[i - 100 : i + 100, j - 100 : j + 100])\n", + " / target_area.pixel_area\n", + " )" ] }, { @@ -240,10 +248,11 @@ "metadata": {}, "outputs": [], "source": [ - "my_super_model = My_Super_Sersic( # notice we switched the custom class\n", + "my_super_model = ap.Model(\n", " name=\"goodness I made another one\",\n", + " model_type=\"super mysersic galaxy model\", # this is the type we defined above\n", " target=target,\n", - ") # no longer need to provide initial values!\n", + ")\n", "\n", "my_super_model.initialize()\n", "\n", @@ -290,7 +299,13 @@ "source": [ "## Models from scratch\n", "\n", - "By inheriting from `Galaxy_Model` we got to start with some methods already available. In this section we will see how to create a model essentially from scratch by inheriting from the `Component_Model` object. Below is an example model which uses a $\\frac{I_0}{R}$ model, this is a weird model but it will work. To demonstrate the basics for a `Component_Model` is actually simpler than a `Galaxy_Model` we really only need the `evaluate_model` function, it's what you do with that function where the complexity arises." + "By inheriting from `GalaxyModel` we got to start with some methods already\n", + "available. In this section we will see how to create a model essentially from\n", + "scratch by inheriting from the `ComponentModel` object. Below is an example\n", + "model which uses a $\\frac{I_0}{R}$ model, this is a weird model but it will\n", + "work. To demonstrate the basics for a `ComponentModel` is actually simpler than\n", + "a `GalaxyModel` we really only need the `brightness(x,y)` function, it's what\n", + "you do with that function where the complexity arises." ] }, { @@ -299,34 +314,35 @@ "metadata": {}, "outputs": [], "source": [ - "class My_InvR(ap.models.Component_Model):\n", - " model_type = \"InvR model\"\n", + "class My_InvR(ap.models.ComponentModel):\n", + " _model_type = \"InvR\"\n", "\n", - " parameter_specs = {\n", - " \"my_Rs\": {\"limits\": (0, None)}, # This will be the scale length\n", - " \"my_I0\": {}, # This will be the central brightness\n", + " _parameter_specs = {\n", + " # scale length\n", + " \"my_Rs\": {\"units\": \"arcsec\", \"valid\": (0, None)},\n", + " \"my_I0\": {\"units\": \"flux/arcsec^2\"}, # central brightness\n", " }\n", - " _parameter_order = ap.models.Component_Model._parameter_order + (\n", - " \"my_Rs\",\n", - " \"my_I0\",\n", - " ) # we have to tell AstroPhot what order to access these parameters, this is used in several underlying methods\n", "\n", - " epsilon = 1e-4 # this can be set with model.epsilon, but will not be fit during optimization\n", + " def __init__(self, *args, epsilon=1e-4, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.epsilon = epsilon\n", "\n", - " def evaluate_model(self, X=None, Y=None, image=None, parameters=None):\n", - " if X is None or Y is None:\n", - " Coords = image.get_coordinate_meshgrid()\n", - " X, Y = Coords - parameters[\"center\"].value[..., None, None]\n", - " return parameters[\"my_I0\"].value * image.pixel_area / torch.sqrt(X**2 + Y**2 + self.epsilon)" + " @ap.forward\n", + " def brightness(self, x, y, my_Rs, my_I0):\n", + " x, y = self.transform_coordinates(\n", + " x, y\n", + " ) # basically just subtracts the center from the coordinates\n", + " R = torch.sqrt(x**2 + y**2 + self.epsilon) / my_Rs\n", + " return my_I0 / R" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "See now that we must define a `evaluate_model` method. This takes coordinates, an image object, and parameters and returns the model evaluated at the coordinates. No need to worry about integrating the model within a pixel, this will be handled internally, just evaluate the model at the center of each pixel. For most situations this is made easier with the `get_coordinate_meshgrid_torch` method that all AstroPhot `Target_Image` objects have. We also add a new value `epsilon` which is a core radius in arcsec. This parameter will not be fit, it is set as part of the model creation. You can now also provide epsilon when creating the model, or do nothing and the default value will be used.\n", + "See now that we must define a `brightness` method. This takes general tangent plane coordinates and returns the model evaluated at those coordinates. No need to worry about integrating the model within a pixel, this will be handled internally, just evaluate the model at exactly the coordinates requested. We also add a new value `epsilon` which is a core radius in arcsec and stops numerical divide by zero errors at the center. This parameter will not be fit, it is set as part of the model creation. You can now also provide epsilon when creating the model, or do nothing and the default value will be used.\n", "\n", - "From here you have complete freedom, it need only provide a value for each pixel in the given image. Just make sure that it accounts for pixel size (proportional to pixelscale^2). Also make sure to use only pytorch functions, since that way it is possible to run on GPU and propagate derivatives." + "From here you have complete freedom, make sure to use only pytorch functions, since that way it is possible to run on GPU and propagate derivatives." ] }, { @@ -335,16 +351,20 @@ "metadata": {}, "outputs": [], "source": [ - "simpletarget = ap.image.Target_Image(data=np.zeros([100, 100]), pixelscale=1)\n", - "newmodel = My_InvR(\n", + "simpletarget = ap.TargetImage(data=np.zeros([100, 100]), pixelscale=1)\n", + "newmodel = ap.Model(\n", " name=\"newmodel\",\n", + " model_type=\"InvR model\", # this is the type we defined above\n", " epsilon=1,\n", - " parameters={\"center\": [50, 50], \"my_Rs\": 10, \"my_I0\": 1.0},\n", + " center=[50, 50],\n", + " my_Rs=10,\n", + " my_I0=1.0,\n", " target=simpletarget,\n", ")\n", "\n", "fig, ax = plt.subplots(1, 1, figsize=(8, 7))\n", "ap.plots.model_image(fig, ax, newmodel)\n", + "ax.set_title(\"Observe parental-figure, no hands!\")\n", "plt.show()" ] }, diff --git a/docs/source/tutorials/FittingMethods.ipynb b/docs/source/tutorials/FittingMethods.ipynb index 30a0ece8..7795fa18 100644 --- a/docs/source/tutorials/FittingMethods.ipynb +++ b/docs/source/tutorials/FittingMethods.ipynb @@ -17,6 +17,7 @@ "source": [ "%load_ext autoreload\n", "%autoreload 2\n", + "%matplotlib inline\n", "\n", "import torch\n", "import numpy as np\n", @@ -24,15 +25,20 @@ "from matplotlib.patches import Ellipse\n", "from scipy.stats import gaussian_kde as kde\n", "from scipy.stats import norm\n", + "from tqdm import tqdm\n", + "from corner import corner\n", "\n", - "%matplotlib inline\n", "import astrophot as ap" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# Setup a fitting problem. You can ignore this cell to start, it just makes some test data to fit\n", @@ -41,7 +47,7 @@ "def true_params():\n", "\n", " # just some random parameters to use for fitting. Feel free to play around with these to see what happens!\n", - " sky_param = np.array([1.5])\n", + " sky_param = np.array([10**1.5])\n", " sersic_params = np.array(\n", " [\n", " [\n", @@ -51,7 +57,7 @@ " 37.19794926 * np.pi / 180,\n", " 2.14513004,\n", " 22.05219055,\n", - " 2.45583024,\n", + " 10**2.45583024,\n", " ],\n", " [\n", " 44.00353786,\n", @@ -60,7 +66,7 @@ " 172.03862521 * np.pi / 180,\n", " 2.88613347,\n", " 12.095631,\n", - " 2.76711163,\n", + " 10**2.76711163,\n", " ],\n", " ]\n", " )\n", @@ -70,11 +76,11 @@ "\n", "def init_params():\n", "\n", - " sky_param = np.array([1.4])\n", + " sky_param = np.array([10**1.4])\n", " sersic_params = np.array(\n", " [\n", - " [57.0, 56.0, 0.6, 40.0 * np.pi / 180, 1.5, 25.0, 2.0],\n", - " [45.0, 30.0, 0.5, 170.0 * np.pi / 180, 2.0, 10.0, 3.0],\n", + " [57.0, 56.0, 0.6, 40.0 * np.pi / 180, 1.5, 25.0, 10**2.0],\n", + " [45.0, 30.0, 0.5, 170.0 * np.pi / 180, 2.0, 10.0, 10**3.0],\n", " ]\n", " )\n", "\n", @@ -91,36 +97,33 @@ "\n", " # List of models, starting with the sky\n", " model_list = [\n", - " ap.models.AstroPhot_Model(\n", + " ap.Model(\n", " name=\"sky\",\n", " model_type=\"flat sky model\",\n", " target=target,\n", - " parameters={\"F\": sky_param[0]},\n", + " I=sky_param[0],\n", " )\n", " ]\n", " # Add models to the list\n", " for i, params in enumerate(sersic_params):\n", " model_list.append(\n", - " [\n", - " ap.models.AstroPhot_Model(\n", - " name=f\"sersic {i}\",\n", - " model_type=\"sersic galaxy model\",\n", - " target=target,\n", - " parameters={\n", - " \"center\": [params[0], params[1]],\n", - " \"q\": params[2],\n", - " \"PA\": params[3],\n", - " \"n\": params[4],\n", - " \"Re\": params[5],\n", - " \"Ie\": params[6],\n", - " },\n", - " # psf_mode = \"full\", # uncomment to try everything with PSF blurring (takes longer)\n", - " )\n", - " ]\n", + " ap.Model(\n", + " name=f\"sersic {i}\",\n", + " model_type=\"sersic galaxy model\",\n", + " target=target,\n", + " center=[params[0], params[1]],\n", + " q=params[2],\n", + " PA=params[3],\n", + " n=params[4],\n", + " Re=params[5],\n", + " Ie=params[6],\n", + " # psf_convolve = True, # uncomment to try everything with PSF blurring (takes longer)\n", + " )\n", " )\n", "\n", - " MODEL = ap.models.Group_Model(\n", + " MODEL = ap.Model(\n", " name=\"group\",\n", + " model_type=\"group model\",\n", " models=model_list,\n", " target=target,\n", " )\n", @@ -140,7 +143,7 @@ " PSF = ap.utils.initialize.gaussian_psf(2, 21, pixelscale)\n", " PSF /= np.sum(PSF)\n", "\n", - " target = ap.image.Target_Image(\n", + " target = ap.TargetImage(\n", " data=np.zeros((N, N)),\n", " pixelscale=pixelscale,\n", " psf=PSF,\n", @@ -230,81 +233,6 @@ " plt.show()\n", "\n", "\n", - "def corner_plot_covariance(\n", - " cov_matrix, mean, labels=None, figsize=(10, 10), true_values=None, ellipse_colors=\"g\"\n", - "):\n", - " num_params = cov_matrix.shape[0]\n", - " fig, axes = plt.subplots(num_params, num_params, figsize=figsize)\n", - " plt.subplots_adjust(wspace=0.0, hspace=0.0)\n", - "\n", - " for i in range(num_params):\n", - " for j in range(num_params):\n", - " ax = axes[i, j]\n", - "\n", - " if i == j:\n", - " x = np.linspace(\n", - " mean[i] - 3 * np.sqrt(cov_matrix[i, i]),\n", - " mean[i] + 3 * np.sqrt(cov_matrix[i, i]),\n", - " 100,\n", - " )\n", - " y = norm.pdf(x, mean[i], np.sqrt(cov_matrix[i, i]))\n", - " ax.plot(x, y, color=\"g\")\n", - " ax.set_xlim(\n", - " mean[i] - 3 * np.sqrt(cov_matrix[i, i]), mean[i] + 3 * np.sqrt(cov_matrix[i, i])\n", - " )\n", - " if true_values is not None:\n", - " ax.axvline(true_values[i], color=\"red\", linestyle=\"-\", lw=1)\n", - " elif j < i:\n", - " cov = cov_matrix[np.ix_([j, i], [j, i])]\n", - " lambda_, v = np.linalg.eig(cov)\n", - " lambda_ = np.sqrt(lambda_)\n", - " angle = np.rad2deg(np.arctan2(v[1, 0], v[0, 0]))\n", - " for k in [1, 2]:\n", - " ellipse = Ellipse(\n", - " xy=(mean[j], mean[i]),\n", - " width=lambda_[0] * k * 2,\n", - " height=lambda_[1] * k * 2,\n", - " angle=angle,\n", - " edgecolor=ellipse_colors,\n", - " facecolor=\"none\",\n", - " )\n", - " ax.add_artist(ellipse)\n", - "\n", - " # Set axis limits\n", - " margin = 3\n", - " ax.set_xlim(\n", - " mean[j] - margin * np.sqrt(cov_matrix[j, j]),\n", - " mean[j] + margin * np.sqrt(cov_matrix[j, j]),\n", - " )\n", - " ax.set_ylim(\n", - " mean[i] - margin * np.sqrt(cov_matrix[i, i]),\n", - " mean[i] + margin * np.sqrt(cov_matrix[i, i]),\n", - " )\n", - "\n", - " if true_values is not None:\n", - " ax.axvline(true_values[j], color=\"red\", linestyle=\"-\", lw=1)\n", - " ax.axhline(true_values[i], color=\"red\", linestyle=\"-\", lw=1)\n", - "\n", - " if j > i:\n", - " ax.axis(\"off\")\n", - "\n", - " if i < num_params - 1:\n", - " ax.set_xticklabels([])\n", - " else:\n", - " if labels is not None:\n", - " ax.set_xlabel(labels[j])\n", - " ax.yaxis.set_major_locator(plt.NullLocator())\n", - "\n", - " if j > 0:\n", - " ax.set_yticklabels([])\n", - " else:\n", - " if labels is not None:\n", - " ax.set_ylabel(labels[i])\n", - " ax.xaxis.set_major_locator(plt.NullLocator())\n", - "\n", - " plt.show()\n", - "\n", - "\n", "target = generate_target()" ] }, @@ -332,16 +260,29 @@ "outputs": [], "source": [ "MODEL = initialize_model(target, False)\n", + "\n", + "res_lm = ap.fit.LM(MODEL, verbose=1).fit()\n", + "print(res_lm.message)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "MODEL_init = initialize_model(target, False)\n", "fig, axarr = plt.subplots(1, 4, figsize=(24, 5))\n", "plt.subplots_adjust(wspace=0.1)\n", - "ap.plots.model_image(fig, axarr[0], MODEL)\n", + "ap.plots.model_image(fig, axarr[0], MODEL_init)\n", "axarr[0].set_title(\"Model before optimization\")\n", - "ap.plots.residual_image(fig, axarr[1], MODEL, normalize_residuals=True)\n", + "ap.plots.residual_image(fig, axarr[1], MODEL_init, normalize_residuals=True)\n", "axarr[1].set_title(\"Residuals before optimization\")\n", "\n", - "res_lm = ap.fit.LM(MODEL, verbose=1).fit()\n", - "print(res_lm.message)\n", - "\n", "ap.plots.model_image(fig, axarr[2], MODEL)\n", "axarr[2].set_title(\"Model after optimization\")\n", "ap.plots.residual_image(fig, axarr[3], MODEL, normalize_residuals=True)\n", @@ -362,22 +303,14 @@ "metadata": {}, "outputs": [], "source": [ - "param_names = list(MODEL.parameters.vector_names())\n", - "i = 0\n", - "while i < len(param_names):\n", - " param_names[i] = param_names[i].replace(\" \", \"\")\n", - " if \"center\" in param_names[i]:\n", - " center_name = param_names.pop(i)\n", - " param_names.insert(i, center_name.replace(\"center\", \"y\"))\n", - " param_names.insert(i, center_name.replace(\"center\", \"x\"))\n", - " i += 1\n", + "param_names = list(MODEL.build_params_array_names())\n", "set, sky = true_params()\n", - "corner_plot_covariance(\n", + "fig, ax = ap.plots.covariance_matrix(\n", " res_lm.covariance_matrix.detach().cpu().numpy(),\n", - " MODEL.parameters.vector_values().detach().cpu().numpy(),\n", + " MODEL.get_values().detach().cpu().numpy(),\n", " labels=param_names,\n", " figsize=(20, 20),\n", - " true_values=np.concatenate((sky, set.ravel())),\n", + " reference_values=np.concatenate((sky, set.ravel())),\n", ")" ] }, @@ -387,9 +320,9 @@ "source": [ "## Iterative Fit (models)\n", "\n", - "An iterative fitter is identified as `ap.fit.Iter`, this method is generally employed for large models where it is not feasible to hold all the relevant data in memory at once. The iterative fitter will cycle through the models in a `Group_Model` object and fit them one at a time to the image, using the residuals from the previous cycle. This can be a very robust way to deal with some fits, especially if the overlap between models is not too strong. It is however more dependent on good initialization than other methods like the Levenberg-Marquardt. Also, it is possible for the Iter method to get stuck in a local minimum under certain circumstances.\n", + "This iterative fitter is identified as `ap.fit.Iter`, this method is generally employed for large models where it is not feasible to hold all the relevant data in memory at once. The iterative fitter will cycle through the models in a `GroupModel` object and fit them one at a time to the image, using the residuals from the previous cycle. This can be a very robust way to deal with some fits, especially if the overlap between models is not too strong. It is however more dependent on good initialization than other methods like the Levenberg-Marquardt. Also, it is possible for the Iter method to get stuck in a local minimum under certain circumstances.\n", "\n", - "Note that while the Iterative fitter needs a `Group_Model` object to iterate over, it is not necessarily true that the sub models are `Component_Model` objects, they could be `Group_Model` objects as well. In this way it is possible to cycle through and fit \"clusters\" of objects that are nearby, so long as it doesn't consume too much memory.\n", + "Note that while the Iterative fitter needs a `GroupModel` object to iterate over, it is not necessarily true that the sub models are `ComponentModel` objects, they could be `GroupModel` objects as well. In this way it is possible to cycle through and fit \"clusters\" of objects that are nearby, so long as it doesn't consume too much memory.\n", "\n", "By only fitting one model at a time it is possible to get caught in a local minimum, or to get out of a local minimum that a different fitter was stuck in. For this reason it can be good to mix-and-match the iterative optimizers so they can help each other get unstuck if a fit is very challenging. " ] @@ -397,19 +330,36 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide-output" + ] + }, "outputs": [], "source": [ "MODEL = initialize_model(target, False)\n", + "\n", + "res_iter = ap.fit.Iter(MODEL, verbose=1).fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "MODEL_init = initialize_model(target, False)\n", "fig, axarr = plt.subplots(1, 4, figsize=(24, 5))\n", "plt.subplots_adjust(wspace=0.1)\n", - "ap.plots.model_image(fig, axarr[0], MODEL)\n", + "ap.plots.model_image(fig, axarr[0], MODEL_init)\n", "axarr[0].set_title(\"Model before optimization\")\n", - "ap.plots.residual_image(fig, axarr[1], MODEL, normalize_residuals=True)\n", + "ap.plots.residual_image(fig, axarr[1], MODEL_init, normalize_residuals=True)\n", "axarr[1].set_title(\"Residuals before optimization\")\n", "\n", - "res_iter = ap.fit.Iter(MODEL, verbose=1).fit()\n", - "\n", "ap.plots.model_image(fig, axarr[2], MODEL)\n", "axarr[2].set_title(\"Model after optimization\")\n", "ap.plots.residual_image(fig, axarr[3], MODEL, normalize_residuals=True)\n", @@ -421,13 +371,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Iterative Fit (parameters)\n", - "\n", - "This is an iterative fitter identified as `ap.fit.Iter_LM` and is generally employed for large models where it is not feasible to hold all the relevant data in memory at once. This iterative fitter will cycle through chunks of parameters and fit them one at a time to the image. This can be a very robust way to deal with some fits, especially if the overlap between models is not too strong. This is very similar to the other iterative fitter, however it is necessary for certain fitting circumstances when the problem can't be broken down into individual component models. This occurs, for example, when the models have many shared (constrained) parameters and there is no obvious way to break down sub-groups of models (an example of this is discussed in the AstroPhot paper).\n", + "## Iterative Fit (Param)\n", "\n", - "Note that this is iterating over the parameters, not the models. This allows it to handle parameter covariances even for very large models (if they happen to land in the same chunk). However, for this to work it must evaluate the whole model at each iteration making it somewhat slower than the regular `Iter` fitter, though it can make up for it by fitting larger chunks at a time which makes the whole optimization faster.\n", + "This iterative fitter is identified as `ap.fit.IterParam`, this is generally employed for large and interconnected models where it is not feasible to hold all the relevant data in memory at once. Unlike `ap.fit.Iter` which is intended to cycle through the sub models in a group model, this fitter iterates through parameters. The set of parameters which make up the model is broken into chunks and then fitting proceeds only on those chunks, rather than on all parameters simultaneously. For large models that have lots of interconnected/shared parameters, it doesn't really make sense to cycle through one sub-model at a time as optimizing that model may throw another model that is sharing a parameter into a bad part of parameter space. Thus `ap.fit.IterParam` is safe to use on any AstroPhot model without concern for this issue, the fitter will industriously proceed to high likelihood solutions monotonically. \n", "\n", - "By only fitting a subset of parameters at a time it is possible to get caught in a local minimum, or to get out of a local minimum that a different fitter was stuck in. For this reason it can be good to mix-and-match the iterative optimizers so they can help each other get unstuck. Since this iterative fitter chooses parameters randomly, it can sometimes get itself unstuck if it gets a lucky combination of parameters. Generally giving it more parameters to work with at a time is better." + "The tradeoff for this fitter is the same as for the other iterative fitter, if there are strong covariances in the likelihood structure then this fitter can take a long time to converge. The advantage here is that as the user you may take greater control over the combinations if you wish. The `chunks` argument can be set to an integer like `6` in which case, `6` parameters at a time will be fit (the last chunk may be smaller). Alternatively, the `chunks` parameter may be set to a tuple of numpy arrays, these should be boolean arrays that select the parameters for each chunk. For example, here is a possible `chunks` setup for a 7 parameter sersic model: `([1,1,0,0,0,0,0], [0,0,1,1,0,0,0], [0,0,0,0,1,1,1])` which makes three chunks to fit the `x,y` then `q, PA` then `n, Re, Ie` parameters. Note that you do not need to make the chunks exclusive, it is totally fine to have a parameter pop up in multiple chunks! Finally, there's the order the chunks are fit in. This can either `chunk_order=\"sequential\"` the default the chunks are fit in the order given, or `chunk_order=\"random\"` where each iteration a new random order is decided for the chunks to be evaluated." ] }, { @@ -437,15 +385,24 @@ "outputs": [], "source": [ "MODEL = initialize_model(target, False)\n", + "\n", + "res_iterparam = ap.fit.IterParam(MODEL, chunks=5, verbose=1).fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_init = initialize_model(target, False)\n", "fig, axarr = plt.subplots(1, 4, figsize=(24, 5))\n", "plt.subplots_adjust(wspace=0.1)\n", - "ap.plots.model_image(fig, axarr[0], MODEL)\n", + "ap.plots.model_image(fig, axarr[0], MODEL_init)\n", "axarr[0].set_title(\"Model before optimization\")\n", - "ap.plots.residual_image(fig, axarr[1], MODEL, normalize_residuals=True)\n", + "ap.plots.residual_image(fig, axarr[1], MODEL_init, normalize_residuals=True)\n", "axarr[1].set_title(\"Residuals before optimization\")\n", "\n", - "res_iterlm = ap.fit.Iter_LM(MODEL, chunks=11, verbose=1).fit()\n", - "\n", "ap.plots.model_image(fig, axarr[2], MODEL)\n", "axarr[2].set_title(\"Model after optimization\")\n", "ap.plots.residual_image(fig, axarr[3], MODEL, normalize_residuals=True)\n", @@ -457,11 +414,47 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Gradient Descent\n", - "\n", - "A gradient descent fitter is identified as `ap.fit.Grad` and uses standard first order derivative methods as provided by PyTorch. These gradient descent methods include Adam, SGD, and LBFGS to name a few. The first order gradient is faster to evaluate and uses less memory, however it is considerably slower to converge than Levenberg-Marquardt. The gradient descent method with a small learning rate will reliably converge towards a local minimum, it will just do so slowly. \n", - "\n", - "In the example below we let it run for 1000 steps and even still it has not converged. In general you should not use gradient descent to optimize a model. However, in a challenging fitting scenario the small step size of gradient descent can actually be an advantage as it will not take any unedpectedly large steps which could mix up some models, or hop over the $\\chi^2$ minimum into impossible parameter space. Just make sure to finish with LM after using Grad so that it fully converges to a reliable minimum." + "The `ap.fit.IterParam` fitter can also generate a covariance matrix of uncertainties, just keep in mind that it only evaluates the covariances for parameters in the same chunk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "param_names = list(MODEL.build_params_array_names())\n", + "set, sky = true_params()\n", + "fig, ax = ap.plots.covariance_matrix(\n", + " res_iterparam.covariance_matrix.detach().cpu().numpy(),\n", + " MODEL.get_values().detach().cpu().numpy(),\n", + " labels=param_names,\n", + " figsize=(20, 20),\n", + " reference_values=np.concatenate((sky, set.ravel())),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scipy Minimize\n", + "\n", + "Any AstroPhot model becomes a function `model(x)` where `x` is a 1D tensor of\n", + "all the current dynamic parameters. This functional format is common for\n", + "external packages to use. AstroPhot includes a wrapper to access the\n", + "`scipy.optimize.minimize` minimizer list. AstroPhot will ensure the minimizers\n", + "respect the valid ranges set for each parameter.\n", + "\n", + "Typically, the AstroPhot LM optimizer is faster and more accurate than the Scipy\n", + "ones. The exact reason is unclear, but the Scipy minimizers are intended for\n", + "very general use, while the LM optimizer is specifically optimized for gaussian\n", + "log likelihoods.\n", + "\n", + "In the case below, the minimizer thinks it has terminated successfully, although\n", + "in fact it is quite far from the minimum. Consider this a lesson in trusting the\n", + "\"success\" message from an optimizer. It turns out to be very challenging to\n", + "identify if an optimizer is at a minimum, let alone the global minimum." ] }, { @@ -471,15 +464,29 @@ "outputs": [], "source": [ "MODEL = initialize_model(target, False)\n", + "\n", + "res_scipy = ap.fit.ScipyFit(MODEL, method=\"Powell\", verbose=1).fit()\n", + "print(res_scipy.scipy_res)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "MODEL_init = initialize_model(target, False)\n", "fig, axarr = plt.subplots(1, 4, figsize=(24, 5))\n", "plt.subplots_adjust(wspace=0.1)\n", - "ap.plots.model_image(fig, axarr[0], MODEL)\n", + "ap.plots.model_image(fig, axarr[0], MODEL_init)\n", "axarr[0].set_title(\"Model before optimization\")\n", - "ap.plots.residual_image(fig, axarr[1], MODEL, normalize_residuals=True)\n", + "ap.plots.residual_image(fig, axarr[1], MODEL_init, normalize_residuals=True)\n", "axarr[1].set_title(\"Residuals before optimization\")\n", "\n", - "res_grad = ap.fit.Grad(MODEL, verbose=1, max_iter=1000, optim_kwargs={\"lr\": 5e-3}).fit()\n", - "\n", "ap.plots.model_image(fig, axarr[2], MODEL)\n", "axarr[2].set_title(\"Model after optimization\")\n", "ap.plots.residual_image(fig, axarr[3], MODEL, normalize_residuals=True)\n", @@ -491,16 +498,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## No U-Turn Sampler (NUTS)\n", + "## Gradient Descent (Slalom)\n", "\n", - "Unlike the above methods, `ap.fit.NUTS` does not stricktly seek a minimum $\\chi^2$, instead it is an MCMC method which seeks to explore the likelihood space and provide a full posterior in the form of random samples. The NUTS method in AstroPhot is actually just a wrapper for the Pyro implementation (__[link here](https://docs.pyro.ai/en/stable/index.html)__). Most of the functionality can be accessed this way, though for very advanced applications it may be necessary to manually interface with Pyro (this is not very challenging as AstroPhot is fully differentiable).\n", + "A gradient descent fitter uses local gradient information to determine the direction of increased likelihood in parameter space. The challenge with gradient descent is choosing a step size. The `Slalom` algorithm developed for AstroPhot uses a few samples along the gradient direction to determine a parabola which it can then jump to the minimum of. In some sense this is like a 1D version of the Levenberg-Marquardt algorithm and the 1 dimension it choses is that along the gradient (plus momentum).\n", "\n", - "The first iteration of NUTS is always very slow since it compiles the forward method on the fly, after that each sample is drawn much faster. The warmup iterations take longer as the method is exploring the space and determining the ideal step size and mass matrix for fast integration with minimal numerical error (we only do 20 warmup steps here, if something goes wrong just try rerunning). Once the algorithm begins sampling it is able to move quickly (for an MCMC) through the parameter space. For many models, the NUTS sampler is able to collect nearly completely uncorrelated samples, meaning that even 100 is enough to get a good estimate of the posterior.\n", + "It is also possible to access the PyTorch gradient descent algorithms like `Adam` through the AstroPhot wrapper `ap.fit.Grad` which perform gradient descent using various algorithm designed for machine learning. In general though, those algorithms perform better on stochastic gradient descent problems, not static problems like seen by AstroPhot. So `Slalom` tends to perform better.\n", "\n", - "NUTS is far faster than other MCMC implementations such as a standard Metropolis Hastings MCMC. However, it is still a lot slower than the other optimizers (LM) since it is doing more than seeking a single high likelihood point, it is fully exploring the likelihood space. In simple cases, the automatic covariance matrix from LM is likely good enough, but if one really needs access to the full posterior of a complex model then NUTS is the best way to get it.\n", - "\n", - "For an excellent introduction to the Hamiltonian Monte-Carlo and a high level explanation of NUTS see this review:\n", - "__[Betancourt 2018](https://arxiv.org/pdf/1701.02434.pdf)__" + "As you see below, `Slalom` ends with a decent fit, though not good enough for perfect residuals like some other methods (Levenberg-Marquardt). This is typically the case. However, gradient descent can be very helpful for complex optimization tasks, because it is a slower optimization algorithm, it can be more stable in some circumstances. Try using it in cases where LM fails to get things back on track. Just make sure to finish off with an LM round to ensure you have settled into the minimum." ] }, { @@ -511,52 +515,84 @@ "source": [ "MODEL = initialize_model(target, False)\n", "\n", - "# Use LM to start the sampler at a high likelihood location, no burn-in needed!\n", - "# In general, NUTS is quite fast to do burn-in so this is often not needed\n", - "res1 = ap.fit.LM(MODEL).fit()\n", + "res_grad = ap.fit.Slalom(MODEL, verbose=1, momentum=0.1).fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "MODEL_init = initialize_model(target, False)\n", + "fig, axarr = plt.subplots(1, 4, figsize=(24, 5))\n", + "plt.subplots_adjust(wspace=0.1)\n", + "ap.plots.model_image(fig, axarr[0], MODEL_init)\n", + "axarr[0].set_title(\"Model before optimization\")\n", + "ap.plots.residual_image(fig, axarr[1], MODEL_init, normalize_residuals=True)\n", + "axarr[1].set_title(\"Residuals before optimization\")\n", "\n", - "# Run the NUTS sampler\n", - "res_nuts = ap.fit.NUTS(\n", - " MODEL,\n", - " warmup=20,\n", - " max_iter=100,\n", - " inv_mass=res1.covariance_matrix,\n", - ").fit()" + "ap.plots.model_image(fig, axarr[2], MODEL)\n", + "axarr[2].set_title(\"Model after optimization\")\n", + "ap.plots.residual_image(fig, axarr[3], MODEL, normalize_residuals=True)\n", + "axarr[3].set_title(\"Residuals after optimization\")\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note that there is no \"after optimization\" image above, because optimization was not done, it was full likelihood exploration. We can now create a corner plot with 2D projections of the 22 dimensional space that NUTS was exploring. The resulting corner plot is about what you would expect to get with 100 samples drawn from the multivariate gaussian found by LM above. If you run it again with more samples then the results will get even smoother." + "## Metropolis Adjusted Langevin Algorithm (MALA)\n", + "\n", + "This is one of the simplest gradient based samplers, and is very powerful. The standard Metropolis Hastings algorithm will use a gaussian proposal distribution then use the Metropolis Hastings accept/reject stage. MALA uses gradient information to determine a better proposal distribution locally (while maintaining detailed balance) and then uses the Metropolis Hastings accept/reject stage. The `ap.fit.MALA` fitter object is just a basic wrapper over the `ap.fit.func.mala` function, so feel free to check it out if you want more details on this simple and powerful sampler!" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide-output" + ] + }, "outputs": [], "source": [ - "# corner plot of the posterior\n", - "# observe that it is very similar to the corner plot from the LM optimization since this case can be roughly\n", - "# approximated as a multivariate gaussian centered on the maximum likelihood point\n", - "param_names = list(MODEL.parameters.vector_names())\n", - "i = 0\n", - "while i < len(param_names):\n", - " param_names[i] = param_names[i].replace(\" \", \"\")\n", - " if \"center\" in param_names[i]:\n", - " center_name = param_names.pop(i)\n", - " param_names.insert(i, center_name.replace(\"center\", \"y\"))\n", - " param_names.insert(i, center_name.replace(\"center\", \"x\"))\n", - " i += 1\n", + "MODEL = initialize_model(target, False)\n", + "\n", + "# Use LM to start the sampler at a high likelihood location, no burn-in needed!\n", + "res1 = ap.fit.LM(MODEL, verbose=0).fit()\n", + "\n", + "res_mala = ap.fit.MALA(\n", + " model=MODEL,\n", + " chains=4,\n", + " max_iter=300,\n", + " epsilon=8e-1,\n", + " likelihood=\"poisson\",\n", + " mass_matrix=res1.covariance_matrix.detach().cpu().numpy(),\n", + ").fit()\n", + "chain_mala = res_mala.chain.reshape(-1, res_mala.chain.shape[-1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "# # corner plot of the posterior\n", + "param_names = list(MODEL.build_params_array_names())\n", "\n", "set, sky = true_params()\n", - "corner_plot(\n", - " res_nuts.chain.detach().cpu().numpy(),\n", - " labels=param_names,\n", - " figsize=(20, 20),\n", - " true_values=np.concatenate((sky, set.ravel())),\n", - ")" + "fig = corner(chain_mala, labels=param_names, truths=np.concatenate((sky, set.ravel())))" ] }, { @@ -565,19 +601,23 @@ "source": [ "## Hamiltonian Monte-Carlo (HMC)\n", "\n", - "The `ap.fit.HMC` is a simpler variant of the NUTS sampler. HMC takes a fixed number of steps at a fixed step size following Hamiltonian dynamics. This is in contrast to NUTS which attempts to optimally choose these parameters. HMC may be suitable in some cases where NUTS is unable to find ideal parameters. Also in some cases where you already know the pretty good step parameters HMC may run faster. If you don't want to fiddle around with parameters then stick with NUTS, HMC results will still have autocorrelation which will depend on the problem and choice of step parameters." + "The `ap.fit.HMC` takes a fixed number of steps at a fixed step size following Hamiltonian dynamics. This is in contrast to NUTS which attempts to optimally choose these parameters. The simplest way to think of HMC is as performing a number of MALA steps all in one go, so if `leapfrog_steps = 10` then HMC is very similar to running MALA then taking every tenth step and adding it to the chain. HMC results will still have autocorrelation which will depend on the problem and choice of step parameters." ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide-output" + ] + }, "outputs": [], "source": [ "MODEL = initialize_model(target, False)\n", "\n", "# Use LM to start the sampler at a high likelihood location, no burn-in needed!\n", - "res1 = ap.fit.LM(MODEL).fit()\n", + "res1 = ap.fit.LM(MODEL, verbose=0).fit()\n", "\n", "# Run the HMC sampler\n", "res_hmc = ap.fit.HMC(\n", @@ -593,26 +633,23 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# corner plot of the posterior\n", - "param_names = list(MODEL.parameters.vector_names())\n", - "i = 0\n", - "while i < len(param_names):\n", - " param_names[i] = param_names[i].replace(\" \", \"\")\n", - " if \"center\" in param_names[i]:\n", - " center_name = param_names.pop(i)\n", - " param_names.insert(i, center_name.replace(\"center\", \"y\"))\n", - " param_names.insert(i, center_name.replace(\"center\", \"x\"))\n", - " i += 1\n", + "param_names = list(MODEL.build_params_array_names())\n", "\n", "set, sky = true_params()\n", - "corner_plot(\n", + "fig = corner(\n", " res_hmc.chain.detach().cpu().numpy(),\n", " labels=param_names,\n", - " figsize=(20, 20),\n", - " true_values=np.concatenate((sky, set.ravel())),\n", + " truths=np.concatenate((sky, set.ravel())),\n", + " plot_contours=False,\n", + " smooth=0.8,\n", ")" ] }, @@ -622,7 +659,7 @@ "source": [ "## Metropolis Hastings\n", "\n", - "This is the classic MCMC algorithm using the Metropolis Hastngs accept step identified with `ap.fit.MHMCMC`. One can set the gaussian random step scale and then explore the posterior. While this technically always works, in practice it can take exceedingly long to actually converge to the posterior. This is because the step size must be set very small to have a reasonable likelihood of accepting each step, so it never moves very far in parameter space. With each subsequent sample being very close to the previous sample it can take a long time for it to wander away from its starting point. In the example below it would take an extremely long time for the chain to converge. Instead of waiting that long, we demonstrate the functionality with 1000 steps, but suggest using NUTS for any real world problem. Still, if there is something NUTS can't handle (a function that isn't differentiable) then MHMCMC can save the day (even if it takes all day to do it)." + "This is the more standard MCMC algorithm using the Metropolis Hastngs accept step identified with `ap.fit.MHMCMC`. Under the hood, this is just a wrapper for the excellent `emcee` package, if you want to take advantage of more `emcee` features you can very easily use `ap.fit.MHMCMC` as a starting point. However, one should keep in mind that for large models it can take exceedingly long to actually converge to the posterior. Instead of waiting that long, we demonstrate the functionality with 100 steps (and 30 chains), but suggest using MALA for any real world problem. Still, if there is something MALA can't handle (a function that isn't differentiable) then MHMCMC can save the day (even if it takes all day to do it)." ] }, { @@ -634,10 +671,12 @@ "MODEL = initialize_model(target, False)\n", "\n", "# Use LM to start the sampler at a high likelihood location, no burn-in needed!\n", - "res1 = ap.fit.LM(MODEL).fit()\n", + "print(\"running LM fit\")\n", + "res1 = ap.fit.LM(MODEL, verbose=0).fit()\n", "\n", - "# Run the HMC sampler\n", - "res_mh = ap.fit.MHMCMC(MODEL, verbose=1, max_iter=1000, epsilon=1e-4, report_after=np.inf).fit()" + "# Run the MHMCMC sampler\n", + "print(\"running MHMCMC sampling\")\n", + "res_mh = ap.fit.MHMCMC(MODEL, verbose=1, max_iter=100).fit()" ] }, { @@ -647,26 +686,15 @@ "outputs": [], "source": [ "# corner plot of the posterior\n", - "# note that, even 1000 samples is not enough to overcome the autocorrelation so the posterior has not converged.\n", - "# In fact it is not even close to convergence as can be seen by the multi-modal blobs in the posterior since this\n", - "# problem is unimodal (except the modes where models are swapped). It is almost never worthwhile to use this\n", - "# sampler except as a sanity check on very simple models.\n", - "param_names = list(MODEL.parameters.vector_names())\n", - "i = 0\n", - "while i < len(param_names):\n", - " param_names[i] = param_names[i].replace(\" \", \"\")\n", - " if \"center\" in param_names[i]:\n", - " center_name = param_names.pop(i)\n", - " param_names.insert(i, center_name.replace(\"center\", \"y\"))\n", - " param_names.insert(i, center_name.replace(\"center\", \"x\"))\n", - " i += 1\n", + "# note that, even 3000 samples is not enough to overcome the autocorrelation so the posterior has not converged.\n", + "param_names = list(MODEL.build_params_array_names())\n", "\n", "set, sky = true_params()\n", - "corner_plot(\n", - " res_mh.chain[::10], # thin by a factor 10 so the plot works in reasonable time\n", + "fig = corner(\n", + " res_mh.chain,\n", " labels=param_names,\n", - " figsize=(20, 20),\n", - " true_values=np.concatenate((sky, set.ravel())),\n", + " truths=np.concatenate((sky, set.ravel())),\n", + " smooth=0.8,\n", ")" ] }, diff --git a/docs/source/tutorials/FunctionalInterface.ipynb b/docs/source/tutorials/FunctionalInterface.ipynb new file mode 100644 index 00000000..7cc69e31 --- /dev/null +++ b/docs/source/tutorials/FunctionalInterface.ipynb @@ -0,0 +1,483 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Functional AstroPhot interface\n", + "\n", + "AstroPhot is an object oriented code, meaning that it is build on python objects that behave in intuitively meaningful ways. For example it is possible to add two model images together to get a new model image, even if one of them only fills a subwindow of pixels, this is because the model images are aware of what part of the scene they represent and can behave accordingly. This is all very nice so long as you are building the kinds of models that AstroPhot is designed for, and when you are not trying to squeeze out every last bit of performance. For most cases, AstroPhot objects can handle complex configurations and perform very quickly. Still, you may need to push things with highly specific customization. Let's consider a case where some specialization can give a big performance boost, a supernova light curve." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import astrophot as ap\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import socket\n", + "from corner import corner\n", + "\n", + "socket.setdefaulttimeout(120)\n", + "ap.backend.backend = \"jax\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": { + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def CD_rot(theta):\n", + " return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])\n", + "\n", + "\n", + "def sn_flux(t):\n", + " return 5 * np.exp(-0.5 * ((t - 10) / 5) ** 2)" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Generate Mock data\n", + "\n", + "Here we will use the usual AstroPhot object oriented interface to generate some mock SN data. There is a fixed host Sersic galaxy, and a Gaussian point source with variable flux as the SN. Every observation is a new pointing of the telescope, so the images are not all aligned and are rotated randomly. The AstroPhot object oriented framework handles this by having target images aware of the WCS that connects the pixels to their location on the sky. We will see in the functional version that everything has to be more explicit, but is more or less the same." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "psf = jnp.array(ap.utils.initialize.gaussian_psf(0.1, 21, 0.1))\n", + "target = ap.TargetImageList(\n", + " list(\n", + " ap.TargetImage(\n", + " name=f\"epoch_{i}\",\n", + " data=np.zeros((32, 32)),\n", + " crpix=(16, 16),\n", + " crtan=0.1 * np.random.normal(size=(2,)),\n", + " CD=0.1 * CD_rot(2 * np.pi * np.random.normal()),\n", + " psf=psf,\n", + " )\n", + " for i in range(10)\n", + " )\n", + ")\n", + "T = np.linspace(-10, 30, 10)\n", + "dataset = {\n", + " \"image\": jnp.zeros((10, 32, 32)),\n", + " \"variance\": jnp.zeros((10, 32, 32)),\n", + " \"crpix\": jnp.zeros((10, 2)),\n", + " \"crtan\": jnp.zeros((10, 2)),\n", + " \"CD\": jnp.zeros((10, 2, 2)),\n", + "}\n", + "models = []\n", + "for i, img in enumerate(target.images):\n", + " host = ap.Model(\n", + " name=f\"host_{i}\",\n", + " target=img,\n", + " model_type=\"sersic galaxy model\",\n", + " center=(0.0, 0.0),\n", + " q=0.7,\n", + " PA=np.pi / 4,\n", + " n=2,\n", + " Re=1,\n", + " Ie=1,\n", + " psf_convolve=True,\n", + " )\n", + " host.initialize()\n", + " models.append(host)\n", + " sn = ap.Model(\n", + " name=f\"supernova_{i}\",\n", + " target=img,\n", + " model_type=\"point model\",\n", + " psf=psf,\n", + " center=(0.4, 0.0),\n", + " flux=sn_flux(T[i]),\n", + " )\n", + " sn.initialize()\n", + " models.append(sn)\n", + " sky = ap.Model(name=f\"sky_{i}\", target=img, model_type=\"flat sky model\", I=0.1 / 0.1**2)\n", + " sky.initialize()\n", + " models.append(sky)\n", + " img.data = np.array(host().data + sn().data + sky().data).T\n", + " img.variance = 0.0001 * np.array(img.data).T\n", + " img.data = img.data.T + np.random.normal(scale=0.01 * np.sqrt(np.array(img.data))).T\n", + " dataset[\"image\"] = dataset[\"image\"].at[i].set(img.data.T)\n", + " dataset[\"variance\"] = dataset[\"variance\"].at[i].set(img.variance.T)\n", + " dataset[\"crpix\"] = dataset[\"crpix\"].at[i].set(jnp.array(img.crpix))\n", + " dataset[\"crtan\"] = dataset[\"crtan\"].at[i].set(img.crtan.value)\n", + " dataset[\"CD\"] = dataset[\"CD\"].at[i].set(img.CD.value)\n", + "apmodel = ap.Model(name=\"AstroPhotModel\", model_type=\"group model\", target=target, models=models)\n", + "fig, axarr = plt.subplots(2, 5, figsize=(15, 6))\n", + "for ax, img in zip(axarr.flatten(), target.images):\n", + " ap.plots.target_image(fig, ax, img)\n", + " ax.set_title(img.name)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## Build the functional model\n", + "\n", + "Below we build a functional version of the AstroPhot model which generated the data. The end result is an identical sampling algorithm which strips away all the object oriented layers of the AstroPhot model to give a pure function to compute pixel values. This is a very insightful exercise to learn exactly what AstroPhot does under the hood. As you can see, there are a number of subtle effects to account for which AstroPhot does automatically, but at a high level it is all very straightforward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "def model_img(\n", + " sersic_x,\n", + " sersic_y,\n", + " sersic_q,\n", + " sersic_PA,\n", + " sersic_n,\n", + " sersic_Re,\n", + " sersic_Ie,\n", + " psf,\n", + " sn_x,\n", + " sn_y,\n", + " sn_flux,\n", + " sky,\n", + " crpix,\n", + " crtan,\n", + " CD,\n", + "):\n", + " # Sample sersic\n", + " pixel_area = 0.1 * 0.1\n", + " # Pad by 20 pixels to avoid edge effects from convolution\n", + " i, j, w = ap.image.func.pixel_quad_meshgrid(\n", + " (32 + 20, 32 + 20), ap.config.DTYPE, ap.config.DEVICE, order=3\n", + " )\n", + " #\n", + " x, y = ap.image.func.pixel_to_plane_linear(j, i, *(crpix + 10), CD, *crtan)\n", + " sx, sy = x - sersic_x, y - sersic_y\n", + " sx, sy = ap.models.func.rotate(-sersic_PA + np.pi / 2, sx, sy)\n", + " sy = sy / sersic_q\n", + " sr = jnp.sqrt(sx**2 + sy**2)\n", + " z = ap.models.func.sersic(sr, n=sersic_n, Re=sersic_Re, Ie=sersic_Ie)\n", + " sample = ap.models.func.pixel_quad_integrator(z, w)\n", + " sample = ap.models.func.convolve(sample, psf)\n", + " sample = sample[10:-10, 10:-10] * pixel_area\n", + "\n", + " # Sample point source (empirical PSF)\n", + " i, j, w = ap.image.func.pixel_quad_meshgrid(\n", + " (32, 32), ap.config.DTYPE, ap.config.DEVICE, order=3\n", + " )\n", + " gj, gi = ap.image.func.plane_to_pixel_linear(sn_x, sn_y, *crpix, CD, *crtan)\n", + " z = ap.utils.interpolate.interp2d(\n", + " psf, j - gj + (psf.shape[1] // 2), i - gi + (psf.shape[0] // 2)\n", + " )\n", + " sample = sample + sn_flux * ap.models.func.pixel_quad_integrator(z, w)\n", + "\n", + " # add sky level\n", + " return sample + sky\n", + "\n", + "\n", + "# fixed: sersic_x, sersic_y, psf, crpix, CD\n", + "# global: sersic_q, sersic_PA, sersic_n, sersic_Re, sersic_Ie, sn_x, sn_y\n", + "# per image: sky, sn_sigma, sn_flux, crtan\n", + "\n", + "\n", + "@jax.jit\n", + "def full_model(\n", + " sersic_x,\n", + " sersic_y,\n", + " sersic_q,\n", + " sersic_PA,\n", + " sersic_n,\n", + " sersic_Re,\n", + " sersic_Ie,\n", + " psf,\n", + " sn_x,\n", + " sn_y,\n", + " sn_flux,\n", + " sky,\n", + " crpix,\n", + " crtan,\n", + " CD,\n", + "):\n", + " return jax.vmap(\n", + " model_img,\n", + " in_axes=(None, None, None, None, None, None, None, None, None, None, 0, 0, 0, 0, 0),\n", + " )(\n", + " sersic_x,\n", + " sersic_y,\n", + " sersic_q,\n", + " sersic_PA,\n", + " sersic_n,\n", + " sersic_Re,\n", + " sersic_Ie,\n", + " psf,\n", + " sn_x,\n", + " sn_y,\n", + " sn_flux,\n", + " sky,\n", + " crpix,\n", + " crtan,\n", + " CD,\n", + " )\n", + "\n", + "\n", + "def model(params, sersic_x, sersic_y, psf, crpix, CD):\n", + " return full_model(\n", + " sersic_x,\n", + " sersic_y,\n", + " params[0],\n", + " params[1],\n", + " params[2],\n", + " params[3],\n", + " params[4],\n", + " psf,\n", + " params[5],\n", + " params[6],\n", + " params[7:17],\n", + " params[17:27],\n", + " crpix,\n", + " params[27:47].reshape(10, 2),\n", + " CD,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "And to see the model in action we can sample it using the true parameter values. As expected, this produces a perfect set of residuals which look like pure random noise." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "params_true = jnp.array(\n", + " np.concatenate(\n", + " [\n", + " [0.7], # sersic_q\n", + " [np.pi / 4], # sersic_PA\n", + " [2.0], # sersic_n\n", + " [1.0], # sersic_Re\n", + " [1.0], # sersic_Ie\n", + " [0.4], # sn_x\n", + " [0.0], # sn_y\n", + " sn_flux(T), # sn_flux\n", + " np.array([0.1] * 10), # sky\n", + " np.array(dataset[\"crtan\"].flatten()), # crtan\n", + " ]\n", + " )\n", + ")\n", + "extra = (jnp.array(0.0), jnp.array(0.0), psf, dataset[\"crpix\"], dataset[\"CD\"])\n", + "sample = model(params_true, *extra)\n", + "residuals = (dataset[\"image\"] - sample) / jnp.sqrt(dataset[\"variance\"])\n", + "fig, axarr = plt.subplots(3, 10, figsize=(18, 6))\n", + "for i, (img, samp, resid) in enumerate(zip(dataset[\"image\"], sample, residuals)):\n", + " axarr[0, i].imshow(img.T, origin=\"lower\", cmap=\"viridis\")\n", + " axarr[0, i].set_title(f\"obs {i}\")\n", + " axarr[1, i].imshow(samp.T, origin=\"lower\", cmap=\"viridis\")\n", + " axarr[1, i].set_title(f\"model {i}\")\n", + " axarr[2, i].imshow(resid.T, origin=\"lower\", cmap=\"seismic\", vmin=-5, vmax=5)\n", + " axarr[2, i].set_title(f\"residual {i}\")\n", + "for ax in axarr.flatten():\n", + " ax.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axarr = plt.subplots(3, 10, figsize=(18, 6))\n", + "ap.plots.target_image(fig, axarr[0], apmodel.target)\n", + "ap.plots.model_image(fig, axarr[1], apmodel, showcbar=False)\n", + "ap.plots.residual_image(\n", + " fig, axarr[2], apmodel, scaling=\"clip\", normalize_residuals=True, showcbar=False\n", + ")\n", + "for ax in axarr.flatten():\n", + " ax.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's compare how fast the two code are\n", + "print(\"Functional interface timings:\")\n", + "%timeit model(params_true, *extra)\n", + "print(\"AstroPhot model timings:\")\n", + "%timeit apmodel()" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "This is quite a striking result, the functional implementation is ~100x faster than the AstroPhot model! However, it is important to put this speed comparison in context. The AstroPhot model is much easier, less error prone, and more intuitive to put together. If we are only going to run the model a few times then we will save much more than 500ms by getting the code written faster. The cutout size of 32x32 is very small, while AstroPhot is built to scale to very large images. For larger images, the Python overhead is negligible and the two codes will have near identical runtime. In fact, if the images get a lot larger the functional version as written will run out of memory while the AstroPhot model could carry on easily because of how it chunks the data. Also, note that the plots are quite different, AstroPhot plots all the images properly oriented in the sky, while for the functional version we don't have that capability. AstroPhot has a more complete understanding of the data and can perform a lot more operations on the results. AstroPhot could also combine in data at different resolutions and sizes, while our functional version is predicated on the idea that all the images will be 32x32 pixels, we would need to completely rewrite it to change that. If we wanted to change the model to fix some parameter or to turn one of the fixed parameters into a free parameter, we would have to trace it through the whole functional implementation and make updates accordingly. This goes for any change really, what if we needed to add in a mask, a second sersic model, or start modelling the PSF (rather than taking it as fixed); all of these would require painful changes to the functional version while they would be trivial additions to the AstroPhot model.\n", + "\n", + "For these reasons and more, it is highly recommended to do lots of prototyping with object oriented AstroPhot models **before** ever considering the functional interface." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "# Make 8 chains, starting at the true parameters\n", + "params = np.stack(list(np.array(params_true) for _ in range(4)))\n", + "\n", + "# Compute a mass matrix using the Fisher information matrix\n", + "J = jax.jacfwd(model, argnums=0)(params_true, *extra).reshape(-1, params_true.shape[-1])\n", + "V = dataset[\"variance\"].reshape(-1)\n", + "H = J.T @ (J / V[:, None])\n", + "M = jnp.linalg.inv(H)\n", + "\n", + "\n", + "def log_likelihood(params, sersic_x, sersic_y, psf, crpix, CD):\n", + " model_sample = model(params, sersic_x, sersic_y, psf, crpix, CD)\n", + " residuals = (dataset[\"image\"] - model_sample) ** 2 / dataset[\"variance\"]\n", + " return -0.5 * jnp.sum(residuals)\n", + "\n", + "\n", + "# Vectorized log likelihood and gradient functions\n", + "vmodel = jax.jit(jax.vmap(log_likelihood, in_axes=(0, None, None, None, None, None)))\n", + "vgmodel = jax.jit(\n", + " jax.vmap(jax.grad(log_likelihood, argnums=0), in_axes=(0, None, None, None, None, None))\n", + ")\n", + "\n", + "# Run MALA sampling\n", + "chain, logp = ap.fit.func.mala(\n", + " params,\n", + " lambda p: np.array(vmodel(jnp.array(p), *extra)),\n", + " lambda p: np.array(vgmodel(jnp.array(p), *extra)),\n", + " num_samples=400,\n", + " epsilon=5e-1,\n", + " mass_matrix=np.array(M),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "Now lets plot the likelihood distributions for the flux parameters compared to their true value. As you can see, the distributions do a good job of covering the ground truth! This means we have accurately extracted the light curve for the supernova data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "figure = corner(\n", + " chain.reshape(-1, chain.shape[-1])[:, 7:17],\n", + " labels=list(f\"flux at epoch {i}\" for i in range(10)),\n", + " truths=params_true[7:17],\n", + ")\n", + "figure.suptitle(\"Likelihood distributions for supernova fluxes at each epoch\", fontsize=20)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "Below we show the likelihood distribution for the sersic host parameters. We can see that there is some non-linearity and certainly lots of correlation in these parameters. This makes the sampling a bit trickier, but MALA is up to the task." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "figure = corner(\n", + " chain.reshape(-1, chain.shape[-1])[:, :5],\n", + " labels=[\"sersic_q\", \"sersic_PA\", \"sersic_n\", \"sersic_Re\", \"sersic_Ie\"],\n", + " truths=params_true[:5],\n", + ")\n", + "figure.suptitle(\"Likelihood distributions for host parameters\", fontsize=20)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/GettingStarted.ipynb b/docs/source/tutorials/GettingStarted.ipynb index f2512b90..c1e680ac 100644 --- a/docs/source/tutorials/GettingStarted.ipynb +++ b/docs/source/tutorials/GettingStarted.ipynb @@ -15,19 +15,19 @@ "metadata": {}, "outputs": [], "source": [ + "%matplotlib inline\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "import os\n", "import astrophot as ap\n", "import numpy as np\n", "import torch\n", "from astropy.io import fits\n", "from astropy.wcs import WCS\n", "import matplotlib.pyplot as plt\n", - "from time import time\n", + "import socket\n", "\n", - "%matplotlib inline" + "socket.setdefaulttimeout(120)" ] }, { @@ -45,25 +45,26 @@ "metadata": {}, "outputs": [], "source": [ - "model1 = ap.models.AstroPhot_Model(\n", - " name=\"model1\", # every model must have a unique name\n", + "model1 = ap.Model(\n", + " name=\"model1\",\n", " model_type=\"sersic galaxy model\", # this specifies the kind of model\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"n\": 2,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " }, # here we set initial values for each parameter\n", - " target=ap.image.Target_Image(\n", - " data=np.zeros((100, 100)), zeropoint=22.5, pixelscale=1.0\n", - " ), # every model needs a target, more on this later\n", + " # here we set initial values for each parameter\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=2,\n", + " Re=10,\n", + " Ie=1,\n", + " # every model needs a target, more on this later\n", + " target=ap.TargetImage(data=np.zeros((100, 100)), zeropoint=22.5),\n", ")\n", - "model1.initialize() # before using the model it is good practice to call initialize so the model can get itself ready\n", + "\n", + "# models must/should be initialized before doing anything with them.\n", + "# This makes sure all the parameters and metadata are ready to go.\n", + "model1.initialize()\n", "\n", "# We can print the model's current state\n", - "model1.parameters" + "print(model1)" ] }, { @@ -72,10 +73,9 @@ "metadata": {}, "outputs": [], "source": [ - "# AstroPhot has built in methods to plot relevant information. We didn't specify the region on the sky for\n", - "# this model to focus on, so we just made a 100x100 window. Unless you are very lucky this won't\n", - "# line up with what you're trying to fit, so next we'll see how to give the model a target.\n", - "\n", + "# AstroPhot has built in methods to plot relevant information. This plots the model\n", + "# as projected into the \"target\" image. Thus it has the same pixelscale, orientation\n", + "# and (optionally) PSF as the model's target.\n", "fig, ax = plt.subplots(figsize=(8, 7))\n", "ap.plots.model_image(fig, ax, model1)\n", "plt.show()" @@ -102,12 +102,11 @@ ")\n", "target_data = np.array(hdu[0].data, dtype=np.float64)\n", "\n", - "# Create a target object with specified pixelscale and zeropoint\n", - "target = ap.image.Target_Image(\n", + "target = ap.TargetImage(\n", " data=target_data,\n", - " pixelscale=0.262, # Every target image needs to know it's pixelscale in arcsec/pixel\n", - " zeropoint=22.5, # optionally, you can give a zeropoint to tell AstroPhot what the pixel flux units are\n", - " variance=\"auto\", # Automatic variance estimate for testing and demo purposes, in real analysis use weight maps, counts, gain, etc to compute variance!\n", + " pixelscale=0.262,\n", + " zeropoint=22.5, # optionally, a zeropoint tells AstroPhot the pixel flux units\n", + " variance=\"auto\", # Automatic variance estimate for testing and demo purposes only! In real analysis use weight maps, counts, gain, etc to compute variance!\n", ")\n", "\n", "# The default AstroPhot target plotting method uses log scaling in bright areas and histogram scaling in faint areas\n", @@ -123,19 +122,21 @@ "outputs": [], "source": [ "# This model now has a target that it will attempt to match\n", - "model2 = ap.models.AstroPhot_Model(\n", + "model2 = ap.Model(\n", " name=\"model with target\",\n", - " model_type=\"sersic galaxy model\", # feel free to swap out sersic with other profile types\n", - " target=target, # now the model knows what its trying to match\n", + " model_type=\"sersic galaxy model\",\n", + " target=target,\n", ")\n", "\n", - "# Instead of giving initial values for all the parameters, it is possible to simply call \"initialize\" and AstroPhot\n", - "# will try to guess initial values for every parameter assuming the galaxy is roughly centered. It is also possible\n", - "# to set just a few parameters and let AstroPhot try to figure out the rest. For example you could give it an initial\n", + "# Instead of giving initial values for all the parameters, it is possible to\n", + "# simply call \"initialize\" and AstroPhot will try to guess initial values for\n", + "# every parameter. It is also possible to set just a few parameters and let\n", + "# AstroPhot try to figure out the rest. For example you could give it an initial\n", "# Guess for the center and it will work from there.\n", "model2.initialize()\n", "\n", - "# Plotting the initial parameters and residuals, we see it gets the rough shape of the galaxy right, but still has some fitting to do\n", + "# Plotting the initial parameters and residuals, we see it gets the rough shape\n", + "# of the galaxy right, but still has some fitting to do\n", "fig4, ax4 = plt.subplots(1, 2, figsize=(16, 6))\n", "ap.plots.model_image(fig4, ax4[0], model2)\n", "ap.plots.residual_image(fig4, ax4[1], model2)\n", @@ -149,13 +150,11 @@ "outputs": [], "source": [ "# Now that the model has been set up with a target and initialized with parameter values, it is time to fit the image\n", - "result = ap.fit.LM(model2, verbose=1).fit()\n", + "result = ap.fit.LMfast(model2, verbose=1).fit()\n", "\n", "# See that we use ap.fit.LM, this is the Levenberg-Marquardt Chi^2 minimization method, it is the recommended technique\n", - "# for most least-squares problems. However, there are situations in which different optimizers may be more desirable\n", - "# so the ap.fit package includes a few options to pick from. The various fitting methods will be described in a\n", - "# different tutorial.\n", - "print(\"Fit message:\", result.message) # the fitter will return a message about its convergence" + "# for most least-squares problems. See the Fitting Methods tutorial for more on fitters!\n", + "print(\"Fit message:\", result.message) # the fitter will store a message about its convergence" ] }, { @@ -164,6 +163,7 @@ "metadata": {}, "outputs": [], "source": [ + "print(model2)\n", "# we now plot the fitted model and the image residuals\n", "fig5, ax5 = plt.subplots(1, 2, figsize=(16, 6))\n", "ap.plots.model_image(fig5, ax5[0], model2)\n", @@ -227,8 +227,8 @@ "# can still see how the covariance of the parameters plays out in a given fit.\n", "fig, ax = ap.plots.covariance_matrix(\n", " result.covariance_matrix.detach().cpu().numpy(),\n", - " model2.parameters.vector_values().detach().cpu().numpy(),\n", - " model2.parameters.vector_names(),\n", + " model2.get_values().detach().cpu().numpy(),\n", + " model2.build_params_array_names(),\n", ")\n", "plt.show()" ] @@ -272,15 +272,11 @@ "outputs": [], "source": [ "# note, we don't provide a name here. A unique name will automatically be generated using the model type\n", - "model3 = ap.models.AstroPhot_Model(\n", + "model3 = ap.Model(\n", " model_type=\"sersic galaxy model\",\n", " target=target,\n", - " window=[\n", - " [480, 595],\n", - " [555, 665],\n", - " ], # this is a region in pixel coordinates ((xmin,xmax),(ymin,ymax))\n", + " window=[480, 595, 555, 665], # this is a region in pixel coordinates (imin,imax,jmin,jmax)\n", ")\n", - "\n", "print(f\"automatically generated name: '{model3.name}'\")\n", "\n", "# We can plot the \"model window\" to show us what part of the image will be analyzed by that model\n", @@ -297,9 +293,7 @@ "outputs": [], "source": [ "model3.initialize()\n", - "\n", - "result = ap.fit.LM(model3, verbose=1).fit()\n", - "print(result.message)" + "result = ap.fit.LMfast(model3, verbose=1).fit()" ] }, { @@ -309,6 +303,7 @@ "outputs": [], "source": [ "# Note that when only a window is fit, the default plotting methods will only show that window\n", + "print(model3)\n", "fig7, ax7 = plt.subplots(1, 2, figsize=(16, 6))\n", "ap.plots.model_image(fig7, ax7[0], model3)\n", "ap.plots.residual_image(fig7, ax7[1], model3, normalize_residuals=True)\n", @@ -334,14 +329,13 @@ "source": [ "# here we make a sersic model that can only have q and n in a narrow range\n", "# Also, we give PA and initial value and lock that so it does not change during fitting\n", - "constrained_param_model = ap.models.AstroPhot_Model(\n", + "constrained_param_model = ap.Model(\n", " name=\"constrained parameters\",\n", " model_type=\"sersic galaxy model\",\n", - " parameters={\n", - " \"q\": {\"limits\": [0.4, 0.6]},\n", - " \"n\": {\"limits\": [2, 3]},\n", - " \"PA\": {\"value\": 60 * np.pi / 180, \"locked\": True},\n", - " },\n", + " q={\"valid\": (0.4, 0.6)},\n", + " n={\"valid\": (2, 3)},\n", + " PA={\"value\": 60 * np.pi / 180},\n", + " target=target,\n", ")" ] }, @@ -359,56 +353,32 @@ "outputs": [], "source": [ "# model 1 is a sersic model\n", - "model_1 = ap.models.AstroPhot_Model(\n", - " model_type=\"sersic galaxy model\", parameters={\"center\": [50, 50], \"PA\": np.pi / 4}\n", - ")\n", + "model_1 = ap.Model(model_type=\"sersic galaxy model\", center=[50, 50], PA=np.pi / 4, target=target)\n", "# model 2 is an exponential model\n", - "model_2 = ap.models.AstroPhot_Model(\n", - " model_type=\"exponential galaxy model\",\n", - ")\n", + "model_2 = ap.Model(model_type=\"exponential galaxy model\", target=target)\n", "\n", "# Here we add the constraint for \"PA\" to be the same for each model.\n", "# In doing so we provide the model and parameter name which should\n", "# be connected.\n", - "model_2[\"PA\"].value = model_1[\"PA\"]\n", + "model_2.PA = model_1.PA\n", "\n", "# Here we can see how the two models now both can modify this parameter\n", "print(\n", " \"initial values: model_1 PA\",\n", - " model_1[\"PA\"].value.item(),\n", + " model_1.PA.value.item(),\n", " \"model_2 PA\",\n", - " model_2[\"PA\"].value.item(),\n", + " model_2.PA.value.item(),\n", ")\n", "# Now we modify the PA for model_1\n", - "model_1[\"PA\"].value = np.pi / 3\n", + "model_1.PA.value = np.pi / 3\n", "print(\n", " \"change model_1: model_1 PA\",\n", - " model_1[\"PA\"].value.item(),\n", - " \"model_2 PA\",\n", - " model_2[\"PA\"].value.item(),\n", - ")\n", - "# Similarly we modify the PA for model_2\n", - "model_2[\"PA\"].value = np.pi / 2\n", - "print(\n", - " \"change model_2: model_1 PA\",\n", - " model_1[\"PA\"].value.item(),\n", + " model_1.PA.value.item(),\n", " \"model_2 PA\",\n", - " model_2[\"PA\"].value.item(),\n", + " model_2.PA.value.item(),\n", ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Keep in mind that both models have full control over the parameter, it is listed in both of\n", - "# their \"parameter_order\" tuples.\n", - "print(\"model_1 parameters: \", model_1.parameter_order)\n", - "print(\"model_2 parameters: \", model_2.parameter_order)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -424,12 +394,14 @@ "metadata": {}, "outputs": [], "source": [ - "# Save the model to a file\n", + "# Save the model state to a file\n", "\n", - "model2.save() # will default to save as AstroPhot.yaml\n", - "\n", - "with open(\"AstroPhot.yaml\", \"r\") as f:\n", - " print(f.read()) # show what the saved file looks like" + "model2.save_state(\"current_spot.hdf5\", appendable=True) # save as it is\n", + "model2.q = 0.1 # do some updates to the model\n", + "model2.PA = 0.1\n", + "model2.n = 0.9\n", + "model2.Re = 0.1\n", + "model2.append_state(\"current_spot.hdf5\") # save the updated model state as often as you like" ] }, { @@ -438,13 +410,10 @@ "metadata": {}, "outputs": [], "source": [ - "# load a model from a file\n", + "# load a model state from a file\n", "\n", - "# note that the target still must be specified, only the parameters are saved\n", - "model4 = ap.models.AstroPhot_Model(name=\"new name\", filename=\"AstroPhot.yaml\", target=target)\n", - "print(\n", - " model4\n", - ") # can see that it has been constructed with all the same parameters as the saved model2." + "model2.load_state(\"current_spot.hdf5\", index=0) # load the first state from the file\n", + "print(model2) # see that the values are back to where they started" ] }, { @@ -463,7 +432,7 @@ "ax.imshow(\n", " np.log10(saved_image_hdu[0].data),\n", " origin=\"lower\",\n", - " cmap=\"plasma\",\n", + " cmap=\"viridis\",\n", ")\n", "plt.show()" ] @@ -493,42 +462,14 @@ "\n", "target.save(\"target.fits\")\n", "\n", - "new_target = ap.image.Target_Image(filename=\"target.fits\")\n", + "# Note that it is often also possible to load from regular FITS files\n", + "new_target = ap.TargetImage(filename=\"target.fits\")\n", "\n", "fig, ax = plt.subplots(figsize=(8, 8))\n", "ap.plots.target_image(fig, ax, new_target)\n", "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Give the model new parameter values manually\n", - "\n", - "print(\n", - " \"parameter input order: \", model4.parameter_order\n", - ") # use this to see what order you have to give the parameters as input\n", - "\n", - "# plot the old model\n", - "fig9, ax9 = plt.subplots(1, 2, figsize=(16, 6))\n", - "ap.plots.model_image(fig9, ax9[0], model4)\n", - "T = ax9[0].set_title(\"parameters as loaded\")\n", - "\n", - "# update and plot the new parameters\n", - "new_parameters = torch.tensor(\n", - " [75, 110, 0.4, 20 * np.pi / 180, 3, 25, 0.12]\n", - ") # note that the center parameter needs two values as input\n", - "model4.initialize() # initialize must be called before optimization, or any other activity in which parameters are updated\n", - "model4.parameters.vector_set_values(\n", - " new_parameters\n", - ") # full_sample will update the parameters, then run sample and return the model image\n", - "ap.plots.model_image(fig9, ax9[1], model4)\n", - "T = ax9[1].set_title(\"new parameter values\")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -539,9 +480,7 @@ "\n", "fig2, ax2 = plt.subplots(figsize=(8, 8))\n", "\n", - "pixels = (\n", - " model4().data.detach().cpu().numpy()\n", - ") # model4.model_image.data is the pytorch stored model image pixel values. Calling detach().cpu().numpy() is needed to get the data out of pytorch and in a usable form\n", + "pixels = model2().data.detach().cpu().numpy()\n", "\n", "im = plt.imshow(\n", " np.log10(pixels), # take log10 for better dynamic range\n", @@ -566,44 +505,46 @@ "outputs": [], "source": [ "# first let's download an image to play with\n", - "hdu = fits.open(\n", - " \"https://www.legacysurvey.org/viewer/fits-cutout?ra=36.3684&dec=-25.6389&size=700&layer=ls-dr9&pixscale=0.262&bands=r\"\n", - ")\n", + "filename = \"https://www.legacysurvey.org/viewer/fits-cutout?ra=36.3684&dec=-25.6389&size=700&layer=ls-dr9&pixscale=0.262&bands=r\"\n", + "hdu = fits.open(filename)\n", "target_data = np.array(hdu[0].data, dtype=np.float64)\n", "\n", "wcs = WCS(hdu[0].header)\n", "\n", "# Create a target object with WCS which will specify the pixelscale and origin for us!\n", - "target = ap.image.Target_Image(\n", + "target = ap.TargetImage(\n", " data=target_data,\n", " zeropoint=22.5,\n", " wcs=wcs,\n", ")\n", "\n", "fig3, ax3 = plt.subplots(figsize=(8, 8))\n", - "ap.plots.target_image(\n", - " fig3, ax3, target, flipx=True\n", - ") # note we flip the x-axis since RA coordinates are backwards\n", + "ap.plots.target_image(fig3, ax3, target)\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Even better, just load directly from a FITS file\n", + "\n", + "AstroPhot recognizes standard FITS keywords to extract a target image. Note that this wont work for all FITS files, just ones that define the following keywords: `CTYPE1`, `CTYPE2`, `CRVAL1`, `CRVAL2`, `CRPIX1`, `CRPIX2`, `CD1_1`, `CD1_2`, `CD2_1`, `CD2_2`, and `MAGZP` with the usual meanings. AstroPhot can also handle SIP, see the SIP tutorial for details there.\n", + "\n", + "Further keywords specific to AstroPhot that it uses for some advanced features like multi-band fitting are: `CRTAN1`, `CRTAN2` used for aligning images, and `IDNTY` used for identifying when two images are actually cutouts of the same image. And AstroPhot also will store the `PSF`, `WEIGHT`, and `MASK` in extra extensions of the FITS file when it makes one." + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# Models can be constructed by providing model_type, or by creating the desired class directly\n", + "target = ap.TargetImage(filename=filename)\n", "\n", - "# notice this is no longer \"AstroPhot_Model\"\n", - "model1_v2 = ap.models.Sersic_Galaxy(\n", - " parameters={\"center\": [50, 50], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"n\": 2, \"Re\": 10, \"Ie\": 1},\n", - " target=ap.image.Target_Image(data=np.zeros((100, 100)), pixelscale=1),\n", - " psf_mode=\"full\", # only change is the psf_mode\n", - ")\n", - "\n", - "# This will be the same as model1, except note that the \"psf_mode\" keyword is now tracked since it isn't a default value\n", - "print(model1_v2)" + "fig3, ax3 = plt.subplots(figsize=(8, 8))\n", + "ap.plots.target_image(fig3, ax3, target)\n", + "plt.show()" ] }, { @@ -614,14 +555,12 @@ "source": [ "# List all the available model names\n", "\n", - "# AstroPhot keeps track of all the subclasses of the AstroPhot_Model object, this list will\n", + "# AstroPhot keeps track of all the subclasses of the AstroPhot Model object, this list will\n", "# include all models even ones added by the user\n", - "print(\n", - " ap.models.AstroPhot_Model.List_Model_Names(usable=True)\n", - ") # set usable = None for all models, or usable = False for only base classes\n", + "print(ap.Model.List_Models(usable=True, types=True))\n", "print(\"---------------------------\")\n", "# It is also possible to get all sub models of a specific Type\n", - "print(\"only warp models: \", ap.models.Warp_Galaxy.List_Model_Names())" + "print(\"only galaxy models: \", ap.models.GalaxyModel.List_Models(types=True))" ] }, { @@ -640,7 +579,7 @@ "outputs": [], "source": [ "# check if AstroPhot has detected your GPU\n", - "print(ap.AP_config.ap_device) # most likely this will say \"cpu\" unless you already have a cuda GPU,\n", + "print(ap.config.DEVICE) # most likely this will say \"cpu\" unless you already have a cuda GPU,\n", "# in which case it should say \"cuda:0\"" ] }, @@ -651,7 +590,7 @@ "outputs": [], "source": [ "# If you have a GPU but want to use the cpu for some reason, just set:\n", - "ap.AP_config.ap_device = \"cpu\"\n", + "ap.config.DEVICE = \"cpu\"\n", "# BEFORE creating anything else (models, images, etc.)" ] }, @@ -671,17 +610,19 @@ "outputs": [], "source": [ "# Again do this BEFORE creating anything else\n", - "ap.AP_config.ap_dtype = torch.float32\n", + "ap.config.DTYPE = torch.float32\n", "\n", "# Now new AstroPhot objects will be made with single bit precision\n", - "W1 = ap.image.Window(origin=[0, 0], pixel_shape=[1, 1], pixelscale=1)\n", - "print(\"now a single:\", W1.origin.dtype)\n", + "T1 = ap.TargetImage(data=np.zeros((100, 100)))\n", + "T1.to()\n", + "print(\"now a single:\", T1.data.dtype)\n", "\n", "# Here we switch back to double precision\n", - "ap.AP_config.ap_dtype = torch.float64\n", - "W2 = ap.image.Window(origin=[0, 0], pixel_shape=[1, 1], pixelscale=1)\n", - "print(\"back to double:\", W2.origin.dtype)\n", - "print(\"old window is still single:\", W1.origin.dtype)" + "ap.config.DTYPE = torch.float64\n", + "T2 = ap.TargetImage(data=np.zeros((100, 100)))\n", + "T2.to()\n", + "print(\"back to double:\", T2.data.dtype)\n", + "print(\"old image is still single!:\", T1.data.dtype)" ] }, { @@ -697,7 +638,7 @@ "source": [ "## Tracking output\n", "\n", - "The AstroPhot optimizers, and occasionally the other AstroPhot objects, will provide status updates about themselves which can be very useful for debugging problems or just keeping tabs on progress. There are a number of use cases for AstroPhot, each having different desired output behaviors. To accommodate all users, AstroPhot implements a general logging system. The object `ap.AP_config.ap_logger` is a logging object which by default writes to AstroPhot.log in the local directory. As the user, you can set that logger to be any logging object you like for arbitrary complexity. Most users will, however, simply want to control the filename, or have it output to screen instead of a file. Below you can see examples of how to do that." + "The AstroPhot optimizers, and occasionally the other AstroPhot objects, will provide status updates about themselves which can be very useful for debugging problems or just keeping tabs on progress. There are a number of use cases for AstroPhot, each having different desired output behaviors. To accommodate all users, AstroPhot implements a general logging system. The object `ap.config.logger` is a logging object which by default writes to AstroPhot.log in the local directory. As the user, you can set that logger to be any logging object you like for arbitrary complexity. Most users will, however, simply want to control the filename, or have it output to screen instead of a file. Below you can see examples of how to do that." ] }, { @@ -709,23 +650,23 @@ "# note that the log file will be where these tutorial notebooks are in your filesystem\n", "\n", "# Here we change the settings so AstroPhot only prints to a log file\n", - "ap.AP_config.set_logging_output(stdout=False, filename=\"AstroPhot.log\")\n", - "ap.AP_config.ap_logger.info(\"message 1: this should only appear in the AstroPhot log file\")\n", + "ap.config.set_logging_output(stdout=False, filename=\"AstroPhot.log\")\n", + "ap.config.logger.info(\"message 1: this should only appear in the AstroPhot log file\")\n", "\n", "# Here we change the settings so AstroPhot only prints to console\n", - "ap.AP_config.set_logging_output(stdout=True, filename=None)\n", - "ap.AP_config.ap_logger.info(\"message 2: this should only print to the console\")\n", + "ap.config.set_logging_output(stdout=True, filename=None)\n", + "ap.config.logger.info(\"message 2: this should only print to the console\")\n", "\n", "# Here we change the settings so AstroPhot prints to both, which is the default\n", - "ap.AP_config.set_logging_output(stdout=True, filename=\"AstroPhot.log\")\n", - "ap.AP_config.ap_logger.info(\"message 3: this should appear in both the console and the log file\")" + "ap.config.set_logging_output(stdout=True, filename=\"AstroPhot.log\")\n", + "ap.config.logger.info(\"message 3: this should appear in both the console and the log file\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can also change the logging level and/or formatter for the stdout and filename options (see `help(ap.AP_config.set_logging_output)` for details). However, at that point you may want to simply make your own logger object and assign it to the `ap.AP_config.ap_logger` variable." + "You can also change the logging level and/or formatter for the stdout and filename options (see `help(ap.config.set_logging_output)` for details). However, at that point you may want to simply make your own logger object and assign it to the `ap.config.logger` variable." ] }, { diff --git a/docs/source/tutorials/GettingStartedJAX.ipynb b/docs/source/tutorials/GettingStartedJAX.ipynb new file mode 100644 index 00000000..717cabcd --- /dev/null +++ b/docs/source/tutorials/GettingStartedJAX.ipynb @@ -0,0 +1,708 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using AstroPhot with JAX\n", + "\n", + "In this notebook we will run through the same \"getting started\" tutorial, except this time using JAX!\n", + "\n", + "You'll notice right away that basically everything is the same. The only difference is that now all the data and parameters are stored as ``jax.numpy`` arrays. So if that's how you prefer to interact with AstroPhot then forge on! AstroPhot should integrate with a JAX workflow very easily. If you want to treat AstroPhot in a functional way, then simply build the model you want then use ``f = lambda x: model(x).data`` and now ``f(x)`` returns the model image and you can do all the usual, vmap, autograd, etc stuff of JAX on this. Similarly, making ``l = lambda x: model.gaussian_log_likelihood(x)`` will return a scalar log likelihood function (Poisson also works). One note though, JAX has a reputation for being fast, this is true of JIT compiled JAX but not necessarily \"eager\" JAX where we simply define functions and evaluate them. This is the mode that AstroPhot mostly works in since it is so dynamic in the number of options it has and the freedom users have to change them. For this reason, you will find that AstroPhot is faster in PyTorch than JAX (uncompiled). For now we provide this API so JAX users can take advantage of AstroPhot in their workflow. So long as you work in a JAX-oriented way (JIT compile before expecting anything to be fast) then everything should work well and fast. There are only a handful of AstroPhot models that don't work yet in JAX (notably the isothermal edgeon galaxy model since JAX doesn't have the K1 Bessel function)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import astrophot as ap\n", + "import numpy as np\n", + "import jax\n", + "from astropy.io import fits\n", + "from astropy.wcs import WCS\n", + "import matplotlib.pyplot as plt\n", + "import socket\n", + "\n", + "socket.setdefaulttimeout(120)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting the backend to JAX\n", + "\n", + "The first thing we need to do is tell AstroPhot to start using JAX. The easiest way to do this is by setting the environment variable `CASAKDE_BACKEND=\"jax\"` which will update the caskade parameter manager and AstroPhot to now use JAX. If you want to control the backend inside a script so that you can easily mix and match between scripts, then just make sure to set the backend at the beginning and don't change it within one script!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ap.backend.backend = \"jax\"\n", + "# and that's it!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Your first model\n", + "\n", + "The basic format for making an AstroPhot model is given below. Once a model object is constructed, it can be manipulated and updated in various ways." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model1 = ap.Model(\n", + " name=\"model1\",\n", + " model_type=\"sersic galaxy model\", # this specifies the kind of model\n", + " # here we set initial values for each parameter\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=2,\n", + " Re=10,\n", + " Ie=1,\n", + " # every model needs a target, more on this later\n", + " target=ap.TargetImage(data=np.zeros((100, 100)), zeropoint=22.5),\n", + ")\n", + "\n", + "# models must/should be initialized before doing anything with them.\n", + "# This makes sure all the parameters and metadata are ready to go.\n", + "model1.initialize()\n", + "\n", + "# We can print the model's current state\n", + "print(model1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# AstroPhot has built in methods to plot relevant information. This plots the model\n", + "# as projected into the \"target\" image. Thus it has the same pixelscale, orientation\n", + "# and (optionally) PSF as the model's target.\n", + "fig, ax = plt.subplots(figsize=(8, 7))\n", + "ap.plots.model_image(fig, ax, model1)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Giving the model a Target\n", + "\n", + "Typically, the main goal when constructing an AstroPhot model is to fit to an image. We need to give the model access to the image and some information about it to get started." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# first let's download an image to play with\n", + "hdu = fits.open(\n", + " \"https://www.legacysurvey.org/viewer/fits-cutout?ra=36.3684&dec=-25.6389&size=700&layer=ls-dr9&pixscale=0.262&bands=r\"\n", + ")\n", + "target_data = np.array(hdu[0].data, dtype=np.float64)\n", + "\n", + "target = ap.TargetImage(\n", + " data=target_data,\n", + " pixelscale=0.262,\n", + " zeropoint=22.5, # optionally, a zeropoint tells AstroPhot the pixel flux units\n", + " variance=\"auto\", # Automatic variance estimate for testing and demo purposes only! In real analysis use weight maps, counts, gain, etc to compute variance!\n", + ")\n", + "\n", + "# The default AstroPhot target plotting method uses log scaling in bright areas and histogram scaling in faint areas\n", + "fig3, ax3 = plt.subplots(figsize=(8, 8))\n", + "ap.plots.target_image(fig3, ax3, target)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This model now has a target that it will attempt to match\n", + "model2 = ap.Model(\n", + " name=\"model with target\",\n", + " model_type=\"sersic galaxy model\",\n", + " target=target,\n", + ")\n", + "\n", + "# Instead of giving initial values for all the parameters, it is possible to\n", + "# simply call \"initialize\" and AstroPhot will try to guess initial values for\n", + "# every parameter. It is also possible to set just a few parameters and let\n", + "# AstroPhot try to figure out the rest. For example you could give it an initial\n", + "# Guess for the center and it will work from there.\n", + "model2.initialize()\n", + "\n", + "# Plotting the initial parameters and residuals, we see it gets the rough shape\n", + "# of the galaxy right, but still has some fitting to do\n", + "fig4, ax4 = plt.subplots(1, 2, figsize=(16, 6))\n", + "ap.plots.model_image(fig4, ax4[0], model2)\n", + "ap.plots.residual_image(fig4, ax4[1], model2)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Now that the model has been set up with a target and initialized with parameter values, it is time to fit the image\n", + "result = ap.fit.LM(model2, verbose=1).fit()\n", + "\n", + "# See that we use ap.fit.LM, this is the Levenberg-Marquardt Chi^2 minimization method, it is the recommended technique\n", + "# for most least-squares problems. See the Fitting Methods tutorial for more on fitters!\n", + "print(\"Fit message:\", result.message) # the fitter will store a message about its convergence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(model2)\n", + "# we now plot the fitted model and the image residuals\n", + "fig5, ax5 = plt.subplots(1, 2, figsize=(16, 6))\n", + "ap.plots.model_image(fig5, ax5[0], model2)\n", + "ap.plots.residual_image(fig5, ax5[1], model2, normalize_residuals=True)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot surface brightness profile\n", + "\n", + "# we now plot the model profile and a data profile. The model profile is determined from the model parameters\n", + "# the data profile is determined by taking the median of pixel values at a given radius. Notice that the model\n", + "# profile is slightly higher than the data profile? This is because there are other objects in the image which\n", + "# are not being modelled, the data profile uses a median so they are ignored, but for the model we fit all pixels.\n", + "fig10, ax10 = plt.subplots(figsize=(8, 8))\n", + "ap.plots.radial_light_profile(fig10, ax10, model2)\n", + "ap.plots.radial_median_profile(fig10, ax10, model2)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Update uncertainty estimates\n", + "\n", + "After running a fit, the `ap.fit.LM` optimizer can update the uncertainty for each parameter. In fact it can return the full covariance matrix if needed. For a demo of what can be done with the covariance matrix see the `FittingMethods` tutorial. One important note is that the variance image needs to be correct for the uncertainties to be meaningful!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result.update_uncertainty()\n", + "print(model2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that these uncertainties are pure statistical uncertainties that come from evaluating the structure of the $\\chi^2$ minimum. Systematic uncertainties are not included and these often significantly outweigh the standard errors. As can be seen in the residual plot above, there is certainly plenty of unmodelled structure there. Use caution when interpreting the errors from these fits." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the uncertainty matrix\n", + "\n", + "# While the scale of the uncertainty may not be meaningful if the image variance is not accurate, we\n", + "# can still see how the covariance of the parameters plays out in a given fit.\n", + "fig, ax = ap.plots.covariance_matrix(\n", + " result.covariance_matrix,\n", + " model2.get_values(),\n", + " model2.build_params_array_names(),\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Record the total flux/magnitude\n", + "\n", + "Often the parameter of interest is the total flux or magnitude, even if this isn't one of the core parameters of the model, it can be computed. For Sersic and Moffat models with analytic total fluxes it will be integrated to infinity, for most other models it will simply be the total flux in the window." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " f\"Total Flux: {model2.total_flux().item():.1f} +- {model2.total_flux_uncertainty().item():.1f}\"\n", + ")\n", + "print(\n", + " f\"Total Magnitude: {model2.total_magnitude().item():.4f} +- {model2.total_magnitude_uncertainty().item():.4f}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Giving the model a specific target window\n", + "\n", + "Sometimes an object isn't nicely centered in the image, and may not even be the dominant object in the image. It is therefore nice to be able to specify what part of the image we should analyze." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# note, we don't provide a name here. A unique name will automatically be generated using the model type\n", + "model3 = ap.Model(\n", + " model_type=\"sersic galaxy model\",\n", + " target=target,\n", + " window=[480, 595, 555, 665], # this is a region in pixel coordinates (imin,imax,jmin,jmax)\n", + ")\n", + "print(f\"automatically generated name: '{model3.name}'\")\n", + "\n", + "# We can plot the \"model window\" to show us what part of the image will be analyzed by that model\n", + "fig6, ax6 = plt.subplots(figsize=(8, 8))\n", + "ap.plots.target_image(fig6, ax6, model3.target)\n", + "ap.plots.model_window(fig6, ax6, model3)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model3.initialize()\n", + "result = ap.fit.LM(model3, verbose=1).fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Note that when only a window is fit, the default plotting methods will only show that window\n", + "print(model3)\n", + "fig7, ax7 = plt.subplots(1, 2, figsize=(16, 6))\n", + "ap.plots.model_image(fig7, ax7[0], model3)\n", + "ap.plots.residual_image(fig7, ax7[1], model3, normalize_residuals=True)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting parameter constraints\n", + "\n", + "A common feature of fitting parameters is that they have some constraint on their behaviour and cannot be sampled at any value from (-inf, inf). AstroPhot circumvents this by remapping any constrained parameter to a space where it can take any real value, at least for the sake of fitting. For most parameters these constraints are applied by default; for example the axis ratio q is required to be in the range (0,1). Other parameters, such as the position angle (PA) are cyclic, they can be in the range (0,pi) but also can wrap around. It is possible to manually set these constraints while constructing a model.\n", + "\n", + "In general adding constraints makes fitting more difficult. There is a chance that the fitting process runs up against a constraint boundary and gets stuck. However, sometimes adding constraints is necessary and so the capability is included." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# here we make a sersic model that can only have q and n in a narrow range\n", + "# Also, we give PA and initial value and lock that so it does not change during fitting\n", + "constrained_param_model = ap.Model(\n", + " name=\"constrained parameters\",\n", + " model_type=\"sersic galaxy model\",\n", + " q={\"valid\": (0.4, 0.6)},\n", + " n={\"valid\": (2, 3)},\n", + " PA={\"value\": 60 * np.pi / 180},\n", + " target=target,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Aside from constraints on an individual parameter, it is sometimes desirable to have different models share parameter values. For example you may wish to combine multiple simple models into a more complex model (more on that in a different tutorial), and you may wish for them all to have the same center. This can be accomplished with \"equality constraints\" as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# model 1 is a sersic model\n", + "model_1 = ap.Model(model_type=\"sersic galaxy model\", center=[50, 50], PA=np.pi / 4, target=target)\n", + "# model 2 is an exponential model\n", + "model_2 = ap.Model(model_type=\"exponential galaxy model\", target=target)\n", + "\n", + "# Here we add the constraint for \"PA\" to be the same for each model.\n", + "# In doing so we provide the model and parameter name which should\n", + "# be connected.\n", + "model_2.PA = model_1.PA\n", + "\n", + "# Here we can see how the two models now both can modify this parameter\n", + "print(\n", + " \"initial values: model_1 PA\",\n", + " model_1.PA.value.item(),\n", + " \"model_2 PA\",\n", + " model_2.PA.value.item(),\n", + ")\n", + "# Now we modify the PA for model_1\n", + "model_1.PA.value = np.pi / 3\n", + "print(\n", + " \"change model_1: model_1 PA\",\n", + " model_1.PA.value.item(),\n", + " \"model_2 PA\",\n", + " model_2.PA.value.item(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic things to do with a model\n", + "\n", + "Now that we know how to create a model and fit it to an image, lets get to know the model a bit better." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save the model state to a file\n", + "\n", + "model2.save_state(\"current_spot.hdf5\", appendable=True) # save as it is\n", + "model2.q = 0.1 # do some updates to the model\n", + "model2.PA = 0.1\n", + "model2.n = 0.9\n", + "model2.Re = 0.1\n", + "model2.append_state(\"current_spot.hdf5\") # save the updated model state as often as you like" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load a model state from a file\n", + "\n", + "model2.load_state(\"current_spot.hdf5\", index=0) # load the first state from the file\n", + "print(model2) # see that the values are back to where they started" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save the model image to a file\n", + "\n", + "model_image_sample = model2()\n", + "model_image_sample.save(\"model2.fits\")\n", + "\n", + "saved_image_hdu = fits.open(\"model2.fits\")\n", + "fig, ax = plt.subplots(figsize=(8, 8))\n", + "ax.imshow(\n", + " np.log10(saved_image_hdu[0].data),\n", + " origin=\"lower\",\n", + " cmap=\"viridis\",\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot model image with discrete levels\n", + "\n", + "# this is very useful for visualizing subtle features and for eyeballing the brightness at a given location.\n", + "# just add the \"cmap_levels\" keyword to the model_image call and tell it how many levels you want\n", + "fig11, ax11 = plt.subplots(figsize=(8, 8))\n", + "ap.plots.model_image(fig11, ax11, model2, cmap_levels=15)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save and load a target image\n", + "\n", + "target.save(\"target.fits\")\n", + "\n", + "# Note that it is often also possible to load from regular FITS files\n", + "new_target = ap.TargetImage(filename=\"target.fits\")\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 8))\n", + "ap.plots.target_image(fig, ax, new_target)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Access the model image pixels directly\n", + "\n", + "fig2, ax2 = plt.subplots(figsize=(8, 8))\n", + "\n", + "# Transpose because AstroPhot indexes with (i,j) while numpy uses (j,i)\n", + "pixels = model2().data.T\n", + "\n", + "im = plt.imshow(\n", + " np.log10(pixels), # take log10 for better dynamic range\n", + " origin=\"lower\",\n", + " cmap=ap.plots.visuals.cmap_grad, # gradient colourmap default for AstroPhot\n", + ")\n", + "plt.colorbar(im)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load target with WCS information" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# first let's download an image to play with\n", + "filename = \"https://www.legacysurvey.org/viewer/fits-cutout?ra=36.3684&dec=-25.6389&size=700&layer=ls-dr9&pixscale=0.262&bands=r\"\n", + "hdu = fits.open(filename)\n", + "target_data = np.array(hdu[0].data, dtype=np.float64)\n", + "\n", + "wcs = WCS(hdu[0].header)\n", + "\n", + "# Create a target object with WCS which will specify the pixelscale and origin for us!\n", + "target = ap.TargetImage(\n", + " data=target_data,\n", + " zeropoint=22.5,\n", + " wcs=wcs,\n", + ")\n", + "\n", + "fig3, ax3 = plt.subplots(figsize=(8, 8))\n", + "ap.plots.target_image(fig3, ax3, target)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Even better, just load directly from a FITS file\n", + "\n", + "AstroPhot recognizes standard FITS keywords to extract a target image. Note that this wont work for all FITS files, just ones that define the following keywords: `CTYPE1`, `CTYPE2`, `CRVAL1`, `CRVAL2`, `CRPIX1`, `CRPIX2`, `CD1_1`, `CD1_2`, `CD2_1`, `CD2_2`, and `MAGZP` with the usual meanings. AstroPhot can also handle SIP, see the SIP tutorial for details there.\n", + "\n", + "Further keywords specific to AstroPhot that it uses for some advanced features like multi-band fitting are: `CRTAN1`, `CRTAN2` used for aligning images, and `IDNTY` used for identifying when two images are actually cutouts of the same image. And AstroPhot also will store the `PSF`, `WEIGHT`, and `MASK` in extra extensions of the FITS file when it makes one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target = ap.TargetImage(filename=filename)\n", + "\n", + "fig3, ax3 = plt.subplots(figsize=(8, 8))\n", + "ap.plots.target_image(fig3, ax3, target)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# List all the available model names\n", + "\n", + "# AstroPhot keeps track of all the subclasses of the AstroPhot Model object, this list will\n", + "# include all models even ones added by the user\n", + "print(ap.Model.List_Models(usable=True, types=True))\n", + "print(\"---------------------------\")\n", + "# It is also possible to get all sub models of a specific Type\n", + "print(\"only galaxy models: \", ap.models.GalaxyModel.List_Models(types=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using GPU acceleration\n", + "\n", + "This one is easy! If you have a cuda enabled GPU available, JAX will just automatically detect it and use that device. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# this is different for the JAX version, JAX automatically handles device placement\n", + "# So AstroPhot just gives None as the device to let JAX to its thing\n", + "print(ap.config.DEVICE)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Boost GPU acceleration with single precision float32\n", + "\n", + "If you are using a GPU you can get significant performance increases in both memory and speed by switching from double precision (float64, the AstroPhot default) to single precision (float32) floating point numbers. The trade off is reduced precision, this can cause some unexpected behaviors. For example an optimizer may keep iterating forever if it is trying to optimize down to a precision below what the float32 will track. Typically, numbers with float32 are good down to 6 places and AstroPhot by default only attempts to minimize the Chi^2 to 3 places. However, to ensure the fit is secure to 3 places it often checks what is happening down at 4 or 5 places. Hence, issues can arise. For the most part you can go ahead with float32 and if you run into a weird bug, try on float64 before looking further.\n", + "\n", + "JAX has a global automatic type, so its not always a good idea to try and specify the type. By default, AstroPhot enables the ``jax.config.update(\"jax_enable_x64\", True)`` option so JAX will automatically use float64. You can switch this flag in the JAX config if you's like to use float32. That said, it is still possible to use the global AstroPhot config to set the data type." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Again do this BEFORE creating anything else\n", + "ap.config.DTYPE = jax.numpy.float32\n", + "\n", + "# Now new AstroPhot objects will be made with single bit precision\n", + "T1 = ap.TargetImage(data=np.zeros((100, 100)))\n", + "T1.to()\n", + "print(\"now a single:\", T1.data.dtype)\n", + "\n", + "# Here we switch back to double precision\n", + "ap.config.DTYPE = jax.numpy.float64\n", + "T2 = ap.TargetImage(data=np.zeros((100, 100)))\n", + "T2.to()\n", + "print(\"back to double:\", T2.data.dtype)\n", + "print(\"old image is still single!:\", T1.data.dtype)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See how the window created as a float32 stays that way? That's really bad to have lying around! Make sure to change the data type before creating anything! " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tracking output\n", + "\n", + "The AstroPhot optimizers, and occasionally the other AstroPhot objects, will provide status updates about themselves which can be very useful for debugging problems or just keeping tabs on progress. There are a number of use cases for AstroPhot, each having different desired output behaviors. To accommodate all users, AstroPhot implements a general logging system. The object `ap.config.logger` is a logging object which by default writes to AstroPhot.log in the local directory. As the user, you can set that logger to be any logging object you like for arbitrary complexity. Most users will, however, simply want to control the filename, or have it output to screen instead of a file. Below you can see examples of how to do that." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# note that the log file will be where these tutorial notebooks are in your filesystem\n", + "\n", + "# Here we change the settings so AstroPhot only prints to a log file\n", + "ap.config.set_logging_output(stdout=False, filename=\"AstroPhot.log\")\n", + "ap.config.logger.info(\"message 1: this should only appear in the AstroPhot log file\")\n", + "\n", + "# Here we change the settings so AstroPhot only prints to console\n", + "ap.config.set_logging_output(stdout=True, filename=None)\n", + "ap.config.logger.info(\"message 2: this should only print to the console\")\n", + "\n", + "# Here we change the settings so AstroPhot prints to both, which is the default\n", + "ap.config.set_logging_output(stdout=True, filename=\"AstroPhot.log\")\n", + "ap.config.logger.info(\"message 3: this should appear in both the console and the log file\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also change the logging level and/or formatter for the stdout and filename options (see `help(ap.config.set_logging_output)` for details). However, at that point you may want to simply make your own logger object and assign it to the `ap.config.logger` variable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/source/tutorials/GravitationalLensing.ipynb b/docs/source/tutorials/GravitationalLensing.ipynb new file mode 100644 index 00000000..b39f810c --- /dev/null +++ b/docs/source/tutorials/GravitationalLensing.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Gravitational Lensing\n", + "\n", + "AstroPhot is now part of the caskade ecosystem. caskade simulators can interface\n", + "very easily since the parameter management is handled automatically. Here we\n", + "demonstrate how the caustics package, which is also written in caskade, can be\n", + "used to add gravitational lensing to AstroPhot models. This is similar to the\n", + "Custom Models tutorial although more specific." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import astrophot as ap\n", + "import matplotlib.pyplot as plt\n", + "import caustics\n", + "import numpy as np\n", + "import torch\n", + "import socket\n", + "\n", + "socket.setdefaulttimeout(120)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "class LensSersic(ap.models.SersicGalaxy):\n", + " _model_type = \"lensed\"\n", + "\n", + " def __init__(self, *args, lens, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.lens = lens\n", + "\n", + " def transform_coordinates(self, x, y):\n", + " x, y = self.lens.raytrace(x, y)\n", + " x, y = super().transform_coordinates(x, y)\n", + " return x, y" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "target = ap.TargetImage(\n", + " filename=\"https://www.legacysurvey.org/viewer/fits-cutout?ra=177.1380&dec=19.5008&size=150&layer=ls-dr9&pixscale=0.262&bands=g\",\n", + " name=\"horseshoe\",\n", + " variance=\"auto\",\n", + " zeropoint=22.5,\n", + ")\n", + "target.psf = target.psf_image(data=ap.utils.initialize.gaussian_psf(0.5, 51, 0.262))\n", + "\n", + "cosmology = caustics.FlatLambdaCDM(name=\"cosmology\")\n", + "lens = caustics.SIE(\n", + " name=\"lens\",\n", + " x0=0.28,\n", + " y0=0.79,\n", + " q=0.9,\n", + " phi=2.5 * np.pi / 10,\n", + " Rein=5.5,\n", + " z_l=0.4457,\n", + " z_s=2.379,\n", + " cosmology=cosmology,\n", + ")\n", + "lens.to_dynamic()\n", + "lens.z_l.to_static()\n", + "lens.z_s.to_static()\n", + "source = ap.Model(\n", + " name=\"source\",\n", + " model_type=\"lensed sersic galaxy model\",\n", + " lens=lens,\n", + " center=[0.2, 0.42],\n", + " q=0.6,\n", + " PA=np.pi / 3,\n", + " n=1,\n", + " Re=0.1,\n", + " Ie=1.5,\n", + " target=target,\n", + " psf_convolve=True,\n", + ")\n", + "lenslight = ap.Model(\n", + " name=\"lenslight\",\n", + " model_type=\"sersic galaxy model\",\n", + " center=lambda p: torch.stack((p.x0.value, p.y0.value)),\n", + " q=lens.q,\n", + " PA=0,\n", + " n=4.7,\n", + " Re=1,\n", + " Ie=0.2,\n", + " target=target,\n", + " psf_convolve=True,\n", + ")\n", + "lenslight.center.link((lens.x0, lens.y0))\n", + "\n", + "model = ap.Model(\n", + " name=\"horseshoe\",\n", + " model_type=\"group model\",\n", + " models=[source, lenslight],\n", + " target=target,\n", + ")\n", + "model.initialize()\n", + "\n", + "fig, axarr = plt.subplots(1, 3, figsize=(15, 4))\n", + "ap.plots.target_image(fig, axarr[0], target)\n", + "axarr[0].set_title(\"Target Image\")\n", + "ap.plots.model_image(fig, axarr[1], model)\n", + "axarr[1].set_title(\"Model Image\")\n", + "ap.plots.residual_image(fig, axarr[2], model)\n", + "axarr[2].set_title(\"Residual Image\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "Note that we give reasonable starting parameters for the lensing model. Gravitational lensing is notoriously hard to model, so we need to start near the correct minimum otherwise we may easily fall to some poor local minimum." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model.graphviz()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "res = ap.fit.LM(model, verbose=1).fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axarr = plt.subplots(1, 3, figsize=(15, 4))\n", + "ap.plots.target_image(fig, axarr[0], target)\n", + "axarr[0].set_title(\"Target Image\")\n", + "ap.plots.model_image(fig, axarr[1], model, vmax=32)\n", + "axarr[1].set_title(\"Model Image\")\n", + "ap.plots.residual_image(fig, axarr[2], model)\n", + "axarr[2].set_title(\"Residual Image\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "This is not an exceptionally good fit, but it is well known that the horseshoe requires a more detailed model than an SIE lens. The cool result here is that we were able to link AstroPhot and caustics very easily to create a detailed lensing model!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/GroupModels.ipynb b/docs/source/tutorials/GroupModels.ipynb index d2d5ac85..d43feb28 100644 --- a/docs/source/tutorials/GroupModels.ipynb +++ b/docs/source/tutorials/GroupModels.ipynb @@ -23,10 +23,11 @@ "source": [ "import astrophot as ap\n", "import numpy as np\n", - "import torch\n", "from astropy.io import fits\n", "import matplotlib.pyplot as plt\n", - "from scipy.stats import iqr" + "import socket\n", + "\n", + "socket.setdefaulttimeout(120)" ] }, { @@ -36,11 +37,8 @@ "outputs": [], "source": [ "# first let's download an image to play with\n", - "hdu = fits.open(\n", - " \"https://www.legacysurvey.org/viewer/fits-cutout?ra=155.7720&dec=15.1494&size=150&layer=ls-dr9&pixscale=0.262&bands=r\"\n", - ")\n", + "hdu = ap.utils.ls_open(155.7720, 15.1494, 150 * 0.262, band=\"r\")\n", "target_data = np.array(hdu[0].data, dtype=np.float64)\n", - "\n", "fig1, ax1 = plt.subplots(figsize=(8, 8))\n", "plt.imshow(np.arctan(target_data / 0.05), origin=\"lower\", cmap=\"inferno\")\n", "plt.axis(\"off\")\n", @@ -60,7 +58,7 @@ "#########################################\n", "from photutils.segmentation import detect_sources, deblend_sources\n", "\n", - "initsegmap = detect_sources(target_data, threshold=0.02, npixels=5)\n", + "initsegmap = detect_sources(target_data, threshold=0.02, npixels=6)\n", "segmap = deblend_sources(target_data, initsegmap, npixels=5).data\n", "fig8, ax8 = plt.subplots(figsize=(8, 8))\n", "ax8.imshow(segmap, origin=\"lower\", cmap=\"inferno\")\n", @@ -74,11 +72,11 @@ "outputs": [], "source": [ "pixelscale = 0.262\n", - "target = ap.image.Target_Image(\n", - " data=target_data,\n", + "target = ap.TargetImage(\n", + " data=target_data + 0.01, # add fake sky level back in\n", " pixelscale=pixelscale,\n", " zeropoint=22.5,\n", - " variance=\"auto\", # np.ones_like(target_data) * np.std(target_data[segmap == 0]) ** 2,\n", + " variance=\"auto\", # this will estimate the variance from the data\n", ")\n", "fig2, ax2 = plt.subplots(figsize=(8, 8))\n", "ap.plots.target_image(fig2, ax2, target)\n", @@ -105,13 +103,11 @@ "# This will convert the segmentation map into boxes that enclose the identified pixels\n", "windows = ap.utils.initialize.windows_from_segmentation_map(segmap)\n", "# Next we scale up the windows so that AstroPhot can fit the faint parts of each object as well\n", - "windows = ap.utils.initialize.scale_windows(\n", - " windows, image_shape=target_data.shape, expand_scale=2, expand_border=10\n", - ")\n", + "windows = ap.utils.initialize.scale_windows(windows, image=target, expand_scale=2, expand_border=10)\n", "# Here we get some basic starting parameters for the galaxies (center, position angle, axis ratio)\n", - "centers = ap.utils.initialize.centroids_from_segmentation_map(segmap, target_data)\n", - "PAs = ap.utils.initialize.PA_from_segmentation_map(segmap, target_data, centers)\n", - "qs = ap.utils.initialize.q_from_segmentation_map(segmap, target_data, centers, PAs)" + "centers = ap.utils.initialize.centroids_from_segmentation_map(segmap, target)\n", + "PAs = ap.utils.initialize.PA_from_segmentation_map(segmap, target, centers)\n", + "qs = ap.utils.initialize.q_from_segmentation_map(segmap, target, centers)" ] }, { @@ -124,32 +120,45 @@ "seg_models = []\n", "for win in windows:\n", " seg_models.append(\n", - " ap.models.AstroPhot_Model(\n", + " ap.Model(\n", " name=f\"object {win:02d}\",\n", " window=windows[win],\n", " model_type=\"sersic galaxy model\",\n", " target=target,\n", - " parameters={\n", - " \"center\": np.array(centers[win]) * pixelscale,\n", - " \"PA\": PAs[win],\n", - " \"q\": qs[win],\n", - " },\n", + " center=centers[win],\n", + " PA=PAs[win],\n", + " q=qs[win],\n", " )\n", " )\n", - "sky = ap.models.AstroPhot_Model(\n", + "sky = ap.Model(\n", " name=f\"sky level\",\n", " model_type=\"flat sky model\",\n", " target=target,\n", + " I={\"valid\": (0, None)},\n", ")\n", "\n", "# We build the group model just like any other, except we pass a list of other models\n", - "groupmodel = ap.models.AstroPhot_Model(\n", + "groupmodel = ap.Model(\n", " name=\"group\", models=[sky] + seg_models, target=target, model_type=\"group model\"\n", ")\n", "\n", "groupmodel.initialize()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "x = groupmodel.get_values()\n", + "x = x.repeat(5, 1)\n", + "imgs = torch.vmap(lambda x: groupmodel(x).data)(x)\n", + "print(imgs.shape)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -179,7 +188,8 @@ "source": [ "# This is now a very complex model composed of 9 sub-models! In total 57 parameters!\n", "# Here we will limit it to 1 iteration so that it runs quickly. In general you should let it run to convergence\n", - "result = ap.fit.Iter(groupmodel, verbose=1, max_iter=1).fit()" + "result = ap.fit.Iter(groupmodel, verbose=1, max_iter=2).fit()\n", + "result = ap.fit.LM(groupmodel, verbose=0, max_iter=2).fit()" ] }, { @@ -190,7 +200,7 @@ "source": [ "# Now we can see what the fitting has produced\n", "fig10, ax10 = plt.subplots(1, 2, figsize=(16, 7))\n", - "ap.plots.model_image(fig10, ax10[0], groupmodel)\n", + "ap.plots.model_image(fig10, ax10[0], groupmodel, vmax=25)\n", "ap.plots.residual_image(fig10, ax10[1], groupmodel, normalize_residuals=True)\n", "plt.show()" ] @@ -201,6 +211,76 @@ "source": [ "Which is a pretty good fit! We haven't accounted for the PSF yet, so some of the central regions are not very well fit. It is very easy to add a PSF model to AstroPhot for fitting. Check out the Basic PSF Models tutorial for more information." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Segmentation maps\n", + "\n", + "AstroPhot can produce a model based segmentation map. Essentially, once the models are fit it can compute the \"importance\" of each pixel to a given model. For each pixel and for each model it is possible to compute what fraction of the model's total flux is placed in that pixel. Whichever model assigns the highest fraction of all its flux to a given pixel, is the \"winner\" for that pixel and so the segmentation map assigns the pixel to its index. Note that this is only done at the first level of a group model, since group models can contain group models, it is possible to have a complex multi-component model still act as one index in the segmentation map. \n", + "\n", + "Also note that this means AstroPhot can perform segmentation even for images with non-zero sky levels, there is no need to do background subtraction before segmenting (though you do need to fit the models)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(groupmodel.segmentation_map(), origin=\"lower\", cmap=\"inferno\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Deblending\n", + "\n", + "AstroPhot can perform a basic deblending based on the fitted model. A new target image is created for each object which for each pixel holds the fraction of signal from the original target corresponding to the fraction of light coming from that individual model (compared to the full group model). This can create some patches of zero pixel values where the model falls to zero in its own window, or where other models are much brighter. \n", + "\n", + "Note that this works even when the sky level is not subtracted. Though for very bright sky levels, the deblended objects tend to just look like their model images.\n", + "\n", + "AstroPhot doesn't use deblending, it's forward modelling approach means that it simultaneously models all objects using a principled Gaussian (or Poisson) likelihood. That said, other analyses may make use of deblended stamps. It is also a good systematic check of the flux estimates. A flux estimate that varies wildly from the deblend total flux might be cause for concern." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "subtargets = groupmodel.deblend()\n", + "fig, axarr = plt.subplots(2, int(np.ceil(len(subtargets) / 2)), figsize=(16, 7))\n", + "for i, subtarget in enumerate(subtargets):\n", + " ax = axarr.flatten()[i]\n", + " ap.plots.target_image(fig, ax, subtarget)\n", + " ax.set_title(subtarget.name, fontsize=10)\n", + " ax.axis(\"off\")\n", + "axarr.flatten()[-1].axis(\"off\")\n", + "plt.show()\n", + "\n", + "for submodel, subtarget in zip(groupmodel.models, subtargets):\n", + " print(\n", + " f\"{submodel.name}: total model flux = {submodel.total_flux().item():.2f} ± {submodel.total_flux_uncertainty().item():.2f}, deblend total flux = {subtarget.data.sum().item():.2f}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Observe that all the models (except the sky, which we fudged anyway) are within one sigma between the model flux and the deblended flux. This is a good sign! If there had been any major deviations that would be very suspicious." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/docs/source/tutorials/ImageAlignment.ipynb b/docs/source/tutorials/ImageAlignment.ipynb new file mode 100644 index 00000000..84aa9f12 --- /dev/null +++ b/docs/source/tutorials/ImageAlignment.ipynb @@ -0,0 +1,296 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Aligning Images\n", + "\n", + "In AstroPhot, the image WCS is part of the model and so can be optimized alongside other model parameters. Here we will demonstrate a basic example of image alignment, but the sky is the limit, you can perform highly detailed image alignment with AstroPhot!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import astrophot as ap\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import socket\n", + "\n", + "socket.setdefaulttimeout(120)" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Relative shift\n", + "\n", + "Often the WCS solution is already really good, we just need a local shift in x and/or y to get things just right. Lets start by optimizing a translation in the WCS that improves the fit for our models!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "target_r = ap.TargetImage(\n", + " filename=\"https://www.legacysurvey.org/viewer/fits-cutout?ra=329.2715&dec=13.6483&size=150&layer=ls-dr9&pixscale=0.262&bands=r\",\n", + " name=\"target_r\",\n", + " variance=\"auto\",\n", + ")\n", + "target_g = ap.TargetImage(\n", + " filename=\"https://www.legacysurvey.org/viewer/fits-cutout?ra=329.2715&dec=13.6483&size=150&layer=ls-dr9&pixscale=0.262&bands=g\",\n", + " name=\"target_g\",\n", + " variance=\"auto\",\n", + ")\n", + "\n", + "# Uh-oh! our images are misaligned by 1 pixel, this will cause problems!\n", + "target_g.crpix = target_g.crpix + 1\n", + "\n", + "fig, axarr = plt.subplots(1, 2, figsize=(15, 7))\n", + "ap.plots.target_image(fig, axarr[0], target_r)\n", + "axarr[0].set_title(\"Target Image (r-band)\")\n", + "ap.plots.target_image(fig, axarr[1], target_g)\n", + "axarr[1].set_title(\"Target Image (g-band)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "# fmt: off\n", + "# r-band model\n", + "psfr = ap.Model(name=\"psfr\", model_type=\"moffat psf model\", n=2, Rd=1.0, target=target_r.psf_image(data=np.zeros((51, 51))))\n", + "star1r = ap.Model(name=\"star1-r\", model_type=\"point model\", window=[0, 60, 80, 135], center=[12, 9], psf=psfr, target=target_r)\n", + "star2r = ap.Model(name=\"star2-r\", model_type=\"point model\", window=[40, 90, 20, 70], center=[3, -7], psf=psfr, target=target_r)\n", + "star3r = ap.Model(name=\"star3-r\", model_type=\"point model\", window=[109, 150, 40, 90], center=[-15, -3], psf=psfr, target=target_r)\n", + "modelr = ap.Model(name=\"model-r\", model_type=\"group model\", models=[star1r, star2r, star3r], target=target_r)\n", + "\n", + "# g-band model\n", + "psfg = ap.Model(name=\"psfg\", model_type=\"moffat psf model\", n=2, Rd=1.0, target=target_g.psf_image(data=np.zeros((51, 51))))\n", + "star1g = ap.Model(name=\"star1-g\", model_type=\"point model\", window=[0, 60, 80, 135], center=star1r.center, psf=psfg, target=target_g)\n", + "star2g = ap.Model(name=\"star2-g\", model_type=\"point model\", window=[40, 90, 20, 70], center=star2r.center, psf=psfg, target=target_g)\n", + "star3g = ap.Model(name=\"star3-g\", model_type=\"point model\", window=[109, 150, 40, 90], center=star3r.center, psf=psfg, target=target_g)\n", + "modelg = ap.Model(name=\"model-g\", model_type=\"group model\", models=[star1g, star2g, star3g], target=target_g)\n", + "\n", + "# total model\n", + "target_full = ap.TargetImageList([target_r, target_g])\n", + "model = ap.Model(name=\"model\", model_type=\"group model\", models=[modelr, modelg], target=target_full)\n", + "\n", + "# fmt: on\n", + "fig, axarr = plt.subplots(1, 2, figsize=(15, 7))\n", + "ap.plots.target_image(fig, axarr, target_full)\n", + "axarr[0].set_title(\"Target Image (r-band)\")\n", + "axarr[1].set_title(\"Target Image (g-band)\")\n", + "ap.plots.model_window(fig, axarr[0], modelr)\n", + "ap.plots.model_window(fig, axarr[1], modelg)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model.initialize()\n", + "res = ap.fit.LM(model, verbose=1).fit()\n", + "fig, axarr = plt.subplots(2, 2, figsize=(15, 10))\n", + "ap.plots.model_image(fig, axarr[0], model)\n", + "axarr[0, 0].set_title(\"Model Image (r-band)\")\n", + "axarr[0, 1].set_title(\"Model Image (g-band)\")\n", + "ap.plots.residual_image(fig, axarr[1], model)\n", + "axarr[1, 0].set_title(\"Residual Image (r-band)\")\n", + "axarr[1, 1].set_title(\"Residual Image (g-band)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "Here we see a clear signal of an image misalignment, in the g-band all of the residuals have a dipole in the same direction! Lets free up the position of the g-band image and optimize a shift. This only requires a single line of code!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "target_g.crtan.to_dynamic()" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "Now we can optimize the model again, notice how it now has two more parameters. These are the x,y position of the image in the tangent plane. See the AstroPhot coordinate description on the website for more details on why this works." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "res = ap.fit.LM(model, verbose=1).fit()\n", + "fig, axarr = plt.subplots(2, 2, figsize=(15, 10))\n", + "ap.plots.model_image(fig, axarr[0], model)\n", + "axarr[0, 0].set_title(\"Model Image (r-band)\")\n", + "axarr[0, 1].set_title(\"Model Image (g-band)\")\n", + "ap.plots.residual_image(fig, axarr[1], model)\n", + "axarr[1, 0].set_title(\"Residual Image (r-band)\")\n", + "axarr[1, 1].set_title(\"Residual Image (g-band)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "Yay! no more dipole. The fits aren't the best, clearly these objects aren't super well described by a single moffat model. But the main goal today was to show that we could align the images very easily. Note, its probably best to start with a reasonably good WCS from the outset, and this two stage approach where we optimize the models and then optimize the models plus a shift might be more stable than just fitting everything at once from the outset. Often for more complex models it is best to start with a simpler model and fit each time you introduce more complexity." + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## Shift and rotation\n", + "\n", + "Lets say we really don't trust our WCS, we think something has gone wrong and we want freedom to fully shift and rotate the relative positions of the images relative to each other. How can we do this?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "def rotate(phi):\n", + " \"\"\"Create a 2D rotation matrix for a given angle in radians.\"\"\"\n", + " return torch.stack(\n", + " [\n", + " torch.stack([torch.cos(phi), -torch.sin(phi)]),\n", + " torch.stack([torch.sin(phi), torch.cos(phi)]),\n", + " ]\n", + " )\n", + "\n", + "\n", + "# Uh-oh! Our image is misaligned by some small angle\n", + "target_g.CD = target_g.CD.value @ rotate(torch.tensor(np.pi / 32, dtype=torch.float64))\n", + "# Uh-oh! our alignment from before has been erased\n", + "target_g.crtan.value = (0, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axarr = plt.subplots(2, 2, figsize=(15, 10))\n", + "ap.plots.model_image(fig, axarr[0], model)\n", + "axarr[0, 0].set_title(\"Model Image (r-band)\")\n", + "axarr[0, 1].set_title(\"Model Image (g-band)\")\n", + "ap.plots.residual_image(fig, axarr[1], model)\n", + "axarr[1, 0].set_title(\"Residual Image (r-band)\")\n", + "axarr[1, 1].set_title(\"Residual Image (g-band)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "Notice that there is not a universal dipole like in the shift example. Most of the offset is caused by the rotation in this example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "# this will control the relative rotation of the g-band image\n", + "phi = ap.Param(name=\"phi\", value=0.0, dynamic=True, dtype=torch.float64)\n", + "\n", + "# Set the target_g CD matrix to be a function of the rotation angle\n", + "# The CD matrix can encode rotation, skew, and rectangular pixels. We\n", + "# are only interested in the rotation here.\n", + "init_CD = target_g.CD.value.clone()\n", + "target_g.CD = lambda p: init_CD @ rotate(p.phi.value)\n", + "target_g.CD.link(phi)\n", + "\n", + "# also optimize the shift of the g-band image\n", + "target_g.crtan.to_dynamic()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "res = ap.fit.LM(model, verbose=1).fit()\n", + "fig, axarr = plt.subplots(2, 2, figsize=(15, 10))\n", + "ap.plots.model_image(fig, axarr[0], model)\n", + "axarr[0, 0].set_title(\"Model Image (r-band)\")\n", + "axarr[0, 1].set_title(\"Model Image (g-band)\")\n", + "ap.plots.residual_image(fig, axarr[1], model)\n", + "axarr[1, 0].set_title(\"Residual Image (r-band)\")\n", + "axarr[1, 1].set_title(\"Residual Image (g-band)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/ImageTypes.ipynb b/docs/source/tutorials/ImageTypes.ipynb new file mode 100644 index 00000000..229d9e97 --- /dev/null +++ b/docs/source/tutorials/ImageTypes.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Alternate Image Types\n", + "\n", + "AstroPhot operates in the tangent plane space and so must have a mapping between the pixels and the sky that it can use to properly perform integration within every pixel. Aside from the standard `ap.TargetImage` used to store regular data with a linear mapping between pixel space and the tangent plane, there are two more image types `ap.SIPTargetImage` and `ap.CMOSTargetImage` which are explained below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import astrophot as ap\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Rectangle" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## SIP Target Image\n", + "\n", + "The `ap.SIPTargetImage` object stores data for a pixel array that is distorted using Simple-Image-Polynomial distortions. This is a non-linear polynomial transformation that is used to account for optical effects in images that result in the sky being non-linearly projected onto the pixel grid used to collect data. AstroPhot follows the WCS standard when it comes to SIP distortions and can read the SIP coefficients directly from an image. AstroPhot can also save a SIP distortion model to a FITS image. Internally the SIP coefficients are stored in `image.sipA`, `image.sipB`, `image.sipAP` and `image.SIPBP` which are formatted as dictionaries with the keys as tuples of two integers giving the powers and the value as the coefficient. For example in a FITS file the header line `A_1_2 = 0.01` will translate to `image.sipA = {(1,2): 0.01}`. \n", + "\n", + "Some particulars of the AstroPhot implementation. For the sake of efficiency, when a SIP image is created AstroPhot evaluates the SIP distortion at every pixel and stores that in a distortion map with the same size as the image. Afterwards, calling `image.pixel_to_plane` will not evaluate the SIP polynomial, but instead a bilinear interpolation of the distortion model will be used. This massively increases speed, but means that the distortion model is only accurate up to the bilinear interpolation accuracy, since most SIP distortions are quite smooth, this interpolation is extremely accurate. For queries beyond the borders of the image, AstroPhot will not extrapolate the SIP polynomials, instead the distortion amount at the pixel border is simply carried onwards. As second element of the AstroPhot implementation is that if a backwards model (`AP` and `BP`) is not provided, then AstroPhot will use linear algebra to determine the backwards model. This is taken from the very clever code written by Shu Liu and Lei Hi that you [can find here](https://github.com/Roman-Supernova-PIT/sfft/blob/master/sfft/utils/CupyWCSTransform.py).\n", + "\n", + "For the most part, once you define a `ap.SIPTargetImage` you can use it like a regular `ap.TargetImage` object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "target = ap.SIPTargetImage(\n", + " data=torch.randn(128, 256),\n", + " sipA={(0, 1): 1e-3, (1, 0): -1e-3, (1, 1): 1e-4, (2, 0): -5e-5, (0, 2): -5e-4},\n", + " sipB={(0, 1): 1e-3, (1, 0): -1e-3, (1, 1): -1e-3, (2, 0): 1e-4, (0, 2): 2e-3},\n", + ")\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 5))\n", + "ap.plots.target_image(fig, ax, target)\n", + "ax.set_title(\"SIP Target Image\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "Because the pixels are distorted on the sky, this means that the amount of area on the sky for each pixel is different. One would expect a pixel that projects to a larger area to collect more light than one that gets squished smaller. A uniform source observed through a telescope with SIP distortions will therefore produce a non-uniform image. As such, AstroPhot tracks the projected area of each pixel to ensure its calculations are accurate. Here is what that pixel area map looks like for the above image. As you can see, the parts which get stretched out then correspond to larger areas, and the parts that get squished correspond to smaller areas." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(target.pixel_area_map.T, cmap=\"inferno\", origin=\"lower\")\n", + "plt.colorbar(label=\"Pixel Area (arcsec$^2$)\")\n", + "plt.title(\"Pixel Area Map\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## CMOS Target Image\n", + "\n", + "A CMOS sensor is an alternative technology from a CCD for collecting light in an optical system. While it has certain advantages, one challenge with CMOS sensors is that only a sub region of each pixel is actually sensitive to light, the rest holding per-pixel electronics. This means there are gaps in the true placement of the CMOS pixels on the sky. Currently AstroPhot implements this by ensuring that the models are only sampled and integrated in the appropriate pixel areas. However, this treatment is not appropriate for certain PSF convolution modes and so the `ap.CMOSTargetImage` is under active development. Expect some changes in the future as we ensure it is viable for all model types. Currently, sky models, point source models, and un-convolved galaxy models should all work accurately. Adding convolved galaxy models is set for future work." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "target = ap.CMOSTargetImage(\n", + " data=torch.randn(128, 256),\n", + " subpixel_loc=(-0.1, -0.1),\n", + " subpixel_scale=0.8,\n", + ")\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 5))\n", + "ap.plots.target_image(fig, ax, target)\n", + "ax.set_title(\"CMOS Target Image\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "There is no visible difference when plotting the data as compressing every pixel in an image like above would make it hard to see what is happening. Below we plot what a single pixel truly looks like in the CMOS target representation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(5, 5))\n", + "r1 = Rectangle((-0.5, -0.5), 1, 1, facecolor=\"grey\", label=\"Pixel Area\")\n", + "ax.add_patch(r1)\n", + "r2 = Rectangle((-0.5, -0.5), 0.8, 0.8, facecolor=\"blue\", label=\"Subpixel Area\")\n", + "ax.add_patch(r2)\n", + "ax.set_xlim(-0.5, 0.5)\n", + "ax.set_ylim(-0.5, 0.5)\n", + "ax.set_title(\"CMOS Pixel Representation\")\n", + "ax.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "Where the blue subpixel area is actually sensitive to light. Note that pixel indexing places (0,0) at the center of the pixel and every pixel has size 1, so for the first pixel show here the pixel coordinates range from -0.5 to +0.5 on both axes. This is also the representation used to define a `ap.CMOSTargetImage` where `subpixel_loc` gives the pixel coordinates of the center of the subpixel and `subpixel_scale` gives the side length of the subpixel." + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/JointModels.ipynb b/docs/source/tutorials/JointModels.ipynb index 18aaa1b1..8b1eee03 100644 --- a/docs/source/tutorials/JointModels.ipynb +++ b/docs/source/tutorials/JointModels.ipynb @@ -6,7 +6,7 @@ "source": [ "# Joint Modelling\n", "\n", - "In this tutorial you will learn how to set up a joint modelling fit which encoporates the data from multiple images. These use `Group_Model` objects just like in the `GroupModels.ipynb` tutorial, the main difference being how the `Target_Image` object is constructed and that more care must be taken when assigning targets to models. \n", + "In this tutorial you will learn how to set up a joint modelling fit which encoporates the data from multiple images. These use `GroupModel` objects just like in the `GroupModels.ipynb` tutorial, the main difference being how the `TargetImage` object is constructed and that more care must be taken when assigning targets to models. \n", "\n", "It is, of course, more work to set up a fit across multiple target images. However, the tradeoff can be well worth it. Perhaps there is space-based data with high resolution, but groundbased data has better S/N. Or perhaps each band individually does not have enough signal for a confident fit, but all three together just might. Perhaps colour information is of paramount importance for a science goal, one would hope that both bands could be treated on equal footing but in a consistent way when extracting profile information. There are a number of reasons why one might wish to try and fit a multi image picture of a galaxy simultaneously. \n", "\n", @@ -20,12 +20,10 @@ "outputs": [], "source": [ "import astrophot as ap\n", - "import numpy as np\n", - "import torch\n", - "from astropy.io import fits\n", - "from astropy.wcs import WCS\n", "import matplotlib.pyplot as plt\n", - "from scipy.stats import iqr" + "import socket\n", + "\n", + "socket.setdefaulttimeout(120)" ] }, { @@ -37,60 +35,46 @@ "# First we need some data to work with, let's use LEDA 41136 as our example galaxy\n", "\n", "# The images must be aligned to a common coordinate system. From the DESI Legacy survey we are extracting\n", - "# each image from a common center coordinate, so we set the center as (0,0) for all the images and they\n", - "# should be aligned.\n", + "# each image using its RA and DEC coordinates, the WCS in the FITS header will ensure a common coordinate system.\n", "\n", "# It is also important to have a good estimate of the variance and the PSF for each image since these\n", "# affect the relative weight of each image. For the tutorial we use simple approximations, but in\n", "# science level analysis one should endeavor to get the best measure available for these.\n", "\n", "# Our first image is from the DESI Legacy-Survey r-band. This image has a pixelscale of 0.262 arcsec/pixel and is 500 pixels across\n", - "lrimg = fits.open(\n", - " \"https://www.legacysurvey.org/viewer/fits-cutout?ra=187.3119&dec=12.9783&size=500&layer=ls-dr9&pixscale=0.262&bands=r\"\n", - ")\n", - "target_r = ap.image.Target_Image(\n", - " data=np.array(lrimg[0].data, dtype=np.float64),\n", + "target_r = ap.TargetImage(\n", + " filename=\"https://www.legacysurvey.org/viewer/fits-cutout?ra=187.3119&dec=12.9783&size=500&layer=ls-dr9&pixscale=0.262&bands=r\",\n", " zeropoint=22.5,\n", " variance=\"auto\", # auto variance gets it roughly right, use better estimate for science!\n", - " psf=ap.utils.initialize.gaussian_psf(\n", - " 1.12 / 2.355, 51, 0.262\n", - " ), # we construct a basic gaussian psf for each image by giving the simga (arcsec), image width (pixels), and pixelscale (arcsec/pixel)\n", - " wcs=WCS(lrimg[0].header), # note pixelscale and origin not needed when we have a WCS object!\n", + " psf=ap.utils.initialize.gaussian_psf(1.12 / 2.355, 51, 0.262),\n", + " name=\"rband\",\n", ")\n", "\n", "\n", "# The second image is a unWISE W1 band image. This image has a pixelscale of 2.75 arcsec/pixel and is 52 pixels across\n", - "lw1img = fits.open(\n", - " \"https://www.legacysurvey.org/viewer/fits-cutout?ra=187.3119&dec=12.9783&size=52&layer=unwise-neo7&pixscale=2.75&bands=1\"\n", - ")\n", - "target_W1 = ap.image.Target_Image(\n", - " data=np.array(lw1img[0].data, dtype=np.float64),\n", + "target_W1 = ap.TargetImage(\n", + " filename=\"https://www.legacysurvey.org/viewer/fits-cutout?ra=187.3119&dec=12.9783&size=52&layer=unwise-neo7&pixscale=2.75&bands=1\",\n", " zeropoint=25.199,\n", " variance=\"auto\",\n", " psf=ap.utils.initialize.gaussian_psf(6.1 / 2.355, 21, 2.75),\n", - " wcs=WCS(lw1img[0].header),\n", - " reference_radec=target_r.window.reference_radec,\n", + " name=\"W1band\",\n", ")\n", "\n", "# The third image is a GALEX NUV band image. This image has a pixelscale of 1.5 arcsec/pixel and is 90 pixels across\n", - "lnuvimg = fits.open(\n", - " \"https://www.legacysurvey.org/viewer/fits-cutout?ra=187.3119&dec=12.9783&size=90&layer=galex&pixscale=1.5&bands=n\"\n", - ")\n", - "target_NUV = ap.image.Target_Image(\n", - " data=np.array(lnuvimg[0].data, dtype=np.float64),\n", + "target_NUV = ap.TargetImage(\n", + " filename=\"https://www.legacysurvey.org/viewer/fits-cutout?ra=187.3119&dec=12.9783&size=90&layer=galex&pixscale=1.5&bands=n\",\n", " zeropoint=20.08,\n", " variance=\"auto\",\n", " psf=ap.utils.initialize.gaussian_psf(5.4 / 2.355, 21, 1.5),\n", - " wcs=WCS(lnuvimg[0].header),\n", - " reference_radec=target_r.window.reference_radec,\n", + " name=\"NUVband\",\n", ")\n", "\n", "fig1, ax1 = plt.subplots(1, 3, figsize=(18, 6))\n", - "ap.plots.target_image(fig1, ax1[0], target_r, flipx=True)\n", + "ap.plots.target_image(fig1, ax1[0], target_r)\n", "ax1[0].set_title(\"r-band image\")\n", - "ap.plots.target_image(fig1, ax1[1], target_W1, flipx=True)\n", + "ap.plots.target_image(fig1, ax1[1], target_W1)\n", "ax1[1].set_title(\"W1-band image\")\n", - "ap.plots.target_image(fig1, ax1[2], target_NUV, flipx=True)\n", + "ap.plots.target_image(fig1, ax1[2], target_NUV)\n", "ax1[2].set_title(\"NUV-band image\")\n", "plt.show()" ] @@ -103,7 +87,7 @@ "source": [ "# The joint model will need a target to try and fit, but now that we have multiple images the \"target\" is\n", "# a Target_Image_List object which points to all three.\n", - "target_full = ap.image.Target_Image_List((target_r, target_W1, target_NUV))\n", + "target_full = ap.TargetImageList((target_r, target_W1, target_NUV))\n", "# It doesn't really need any other information since everything is already available in the individual targets" ] }, @@ -116,23 +100,29 @@ "# To make things easy to start, lets just fit a sersic model to all three. In principle one can use arbitrary\n", "# group models designed for each band individually, but that would be unnecessarily complex for a tutorial\n", "\n", - "model_r = ap.models.AstroPhot_Model(\n", + "model_r = ap.Model(\n", " name=\"rband model\",\n", " model_type=\"sersic galaxy model\",\n", " target=target_r,\n", - " psf_mode=\"full\",\n", + " psf_convolve=True,\n", ")\n", - "model_W1 = ap.models.AstroPhot_Model(\n", + "\n", + "model_W1 = ap.Model(\n", " name=\"W1band model\",\n", " model_type=\"sersic galaxy model\",\n", " target=target_W1,\n", - " psf_mode=\"full\",\n", + " center=[0, 0],\n", + " PA=-2.3,\n", + " psf_convolve=True,\n", ")\n", - "model_NUV = ap.models.AstroPhot_Model(\n", + "\n", + "model_NUV = ap.Model(\n", " name=\"NUVband model\",\n", " model_type=\"sersic galaxy model\",\n", " target=target_NUV,\n", - " psf_mode=\"full\",\n", + " center=[0, 0],\n", + " PA=-2.3,\n", + " psf_convolve=True,\n", ")\n", "\n", "# At this point we would just be fitting three separate models at the same time, not very interesting. Next\n", @@ -141,7 +131,7 @@ "for p in [\"center\", \"q\", \"PA\", \"n\", \"Re\"]:\n", " model_W1[p].value = model_r[p]\n", " model_NUV[p].value = model_r[p]\n", - "# Now every model will have a unique Ie, but every other parameter is shared for all three" + "# Now every model will have a unique Ie, but every other parameter is shared" ] }, { @@ -152,7 +142,7 @@ "source": [ "# We can now make the joint model object\n", "\n", - "model_full = ap.models.AstroPhot_Model(\n", + "model_full = ap.Model(\n", " name=\"LEDA 41136\",\n", " model_type=\"group model\",\n", " models=[model_r, model_W1, model_NUV],\n", @@ -160,7 +150,7 @@ ")\n", "\n", "model_full.initialize()\n", - "model_full.parameters" + "model_full.graphviz()" ] }, { @@ -182,28 +172,15 @@ "# here we plot the results of the fitting, notice that each band has a different PSF and pixelscale. Also, notice\n", "# that the colour bars represent significantly different ranges since each model was allowed to fit its own Ie.\n", "# meanwhile the center, PA, q, and Re is the same for every model.\n", - "fig1, ax1 = plt.subplots(1, 3, figsize=(18, 6))\n", - "ap.plots.model_image(fig1, ax1, model_full, flipx=True)\n", - "ax1[0].set_title(\"r-band model image\")\n", - "ax1[1].set_title(\"W1-band model image\")\n", - "ax1[2].set_title(\"NUV-band model image\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# We can also plot the residual images. As can be seen, the galaxy is fit in all three bands simultaneously\n", - "# with the majority of the light removed in all bands. A residual can be seen in the r band. This is likely\n", - "# due to there being more structure in the r-band than just a sersic. The W1 and NUV bands look excellent though\n", - "fig1, ax1 = plt.subplots(1, 3, figsize=(18, 6))\n", - "ap.plots.residual_image(fig1, ax1, model_full, flipx=True, normalize_residuals=True)\n", - "ax1[0].set_title(\"r-band residual image\")\n", - "ax1[1].set_title(\"W1-band residual image\")\n", - "ax1[2].set_title(\"NUV-band residual image\")\n", + "fig1, ax1 = plt.subplots(2, 3, figsize=(18, 12))\n", + "ap.plots.model_image(fig1, ax1[0], model_full)\n", + "ax1[0][0].set_title(\"r-band model image\")\n", + "ax1[0][1].set_title(\"W1-band model image\")\n", + "ax1[0][2].set_title(\"NUV-band model image\")\n", + "ap.plots.residual_image(fig1, ax1[1], model_full, normalize_residuals=True)\n", + "ax1[1][0].set_title(\"r-band residual image\")\n", + "ax1[1][1].set_title(\"W1-band residual image\")\n", + "ax1[1][2].set_title(\"NUV-band residual image\")\n", "plt.show()" ] }, @@ -232,58 +209,39 @@ "DEC = 15.5512\n", "# Our first image is from the DESI Legacy-Survey r-band. This image has a pixelscale of 0.262 arcsec/pixel\n", "rsize = 90\n", - "rimg = fits.open(\n", - " f\"https://www.legacysurvey.org/viewer/fits-cutout?ra={RA}&dec={DEC}&size={rsize}&layer=ls-dr9&pixscale=0.262&bands=r\"\n", - ")\n", - "rimg_data = np.array(rimg[0].data, dtype=np.float64)\n", - "rwcs = WCS(rimg[0].header)\n", - "\n", - "# dont do this unless you've read and understand the coordinates explainer in the docs!\n", - "ref_loc = rwcs.pixel_to_world(0, 0)\n", - "target_r.header.reference_radec = (ref_loc.ra.deg, ref_loc.dec.deg)\n", "\n", "# Now we make our targets\n", - "target_r = ap.image.Target_Image(\n", - " data=rimg_data,\n", + "target_r = ap.image.TargetImage(\n", + " filename=f\"https://www.legacysurvey.org/viewer/fits-cutout?ra={RA}&dec={DEC}&size={rsize}&layer=ls-dr9&pixscale=0.262&bands=r\",\n", " zeropoint=22.5,\n", - " variance=\"auto\", # Note that the variance is important to ensure all images are compared with proper statistical weight. Use better estimate than auto for science!\n", - " psf=ap.utils.initialize.gaussian_psf(\n", - " 1.12 / 2.355, 51, 0.262\n", - " ), # we construct a basic gaussian psf for each image by giving the simga (arcsec), image width (pixels), and pixelscale (arcsec/pixel)\n", - " wcs=rwcs,\n", + " variance=\"auto\",\n", + " psf=ap.utils.initialize.gaussian_psf(1.12 / 2.355, 51, 0.262),\n", + " name=\"rband\",\n", ")\n", "\n", "# The second image is a unWISE W1 band image. This image has a pixelscale of 2.75 arcsec/pixel\n", "wsize = int(rsize * 0.262 / 2.75)\n", - "w1img = fits.open(\n", - " f\"https://www.legacysurvey.org/viewer/fits-cutout?ra={RA}&dec={DEC}&size={wsize}&layer=unwise-neo7&pixscale=2.75&bands=1\"\n", - ")\n", - "target_W1 = ap.image.Target_Image(\n", - " data=np.array(w1img[0].data, dtype=np.float64),\n", + "target_W1 = ap.image.TargetImage(\n", + " filename=f\"https://www.legacysurvey.org/viewer/fits-cutout?ra={RA}&dec={DEC}&size={wsize}&layer=unwise-neo7&pixscale=2.75&bands=1\",\n", " zeropoint=25.199,\n", " variance=\"auto\",\n", " psf=ap.utils.initialize.gaussian_psf(6.1 / 2.355, 21, 2.75),\n", - " wcs=WCS(w1img[0].header),\n", - " reference_radec=target_r.window.reference_radec,\n", + " name=\"W1band\",\n", ")\n", "\n", "# The third image is a GALEX NUV band image. This image has a pixelscale of 1.5 arcsec/pixel\n", "gsize = int(rsize * 0.262 / 1.5)\n", - "nuvimg = fits.open(\n", - " f\"https://www.legacysurvey.org/viewer/fits-cutout?ra={RA}&dec={DEC}&size={gsize}&layer=galex&pixscale=1.5&bands=n\"\n", - ")\n", - "target_NUV = ap.image.Target_Image(\n", - " data=np.array(nuvimg[0].data, dtype=np.float64),\n", + "target_NUV = ap.image.TargetImage(\n", + " filename=f\"https://www.legacysurvey.org/viewer/fits-cutout?ra={RA}&dec={DEC}&size={gsize}&layer=galex&pixscale=1.5&bands=n\",\n", " zeropoint=20.08,\n", " variance=\"auto\",\n", " psf=ap.utils.initialize.gaussian_psf(5.4 / 2.355, 21, 1.5),\n", - " wcs=WCS(nuvimg[0].header),\n", - " reference_radec=target_r.window.reference_radec,\n", + " name=\"NUVband\",\n", ")\n", - "target_full = ap.image.Target_Image_List((target_r, target_W1, target_NUV))\n", + "target_full = ap.image.TargetImageList((target_r, target_W1, target_NUV))\n", "\n", "fig1, ax1 = plt.subplots(1, 3, figsize=(18, 6))\n", - "ap.plots.target_image(fig1, ax1, target_full, flipx=True)\n", + "ap.plots.target_image(fig1, ax1, target_full)\n", "ax1[0].set_title(\"r-band image\")\n", "ax1[1].set_title(\"W1-band image\")\n", "ax1[2].set_title(\"NUV-band image\")\n", @@ -303,8 +261,9 @@ "#########################################\n", "from photutils.segmentation import detect_sources, deblend_sources\n", "\n", - "initsegmap = detect_sources(rimg_data, threshold=0.01, npixels=10)\n", - "segmap = deblend_sources(rimg_data, initsegmap, npixels=5).data\n", + "rdata = target_r.data.detach().cpu().numpy()\n", + "initsegmap = detect_sources(rdata, threshold=0.01, npixels=10)\n", + "segmap = deblend_sources(rdata, initsegmap, npixels=5).data\n", "fig8, ax8 = plt.subplots(figsize=(8, 8))\n", "ax8.imshow(segmap, origin=\"lower\", cmap=\"inferno\")\n", "plt.show()\n", @@ -312,17 +271,15 @@ "rwindows = ap.utils.initialize.windows_from_segmentation_map(segmap)\n", "# Next we scale up the windows so that AstroPhot can fit the faint parts of each object as well\n", "rwindows = ap.utils.initialize.scale_windows(\n", - " rwindows, image_shape=rimg_data.shape, expand_scale=1.5, expand_border=10\n", + " rwindows, image=target_r, expand_scale=1.5, expand_border=10\n", ")\n", "w1windows = ap.utils.initialize.transfer_windows(rwindows, target_r, target_W1)\n", - "w1windows = ap.utils.initialize.scale_windows(\n", - " w1windows, image_shape=w1img[0].data.shape, expand_border=1\n", - ")\n", + "w1windows = ap.utils.initialize.scale_windows(w1windows, image=target_W1, expand_border=1)\n", "nuvwindows = ap.utils.initialize.transfer_windows(rwindows, target_r, target_NUV)\n", "# Here we get some basic starting parameters for the galaxies (center, position angle, axis ratio)\n", - "centers = ap.utils.initialize.centroids_from_segmentation_map(segmap, rimg_data)\n", - "PAs = ap.utils.initialize.PA_from_segmentation_map(segmap, rimg_data, centers)\n", - "qs = ap.utils.initialize.q_from_segmentation_map(segmap, rimg_data, centers, PAs)" + "centers = ap.utils.initialize.centroids_from_segmentation_map(segmap, target_r)\n", + "PAs = ap.utils.initialize.PA_from_segmentation_map(segmap, target_r, centers)\n", + "qs = ap.utils.initialize.q_from_segmentation_map(segmap, target_r, centers)" ] }, { @@ -346,35 +303,33 @@ " # create the submodels for this object\n", " sub_list = []\n", " sub_list.append(\n", - " ap.models.AstroPhot_Model(\n", + " ap.Model(\n", " name=f\"rband model {i}\",\n", " model_type=\"sersic galaxy model\", # we could use spline models for the r-band since it is well resolved\n", " target=target_r,\n", " window=rwindows[window],\n", - " psf_mode=\"full\",\n", - " parameters={\n", - " \"center\": target_r.pixel_to_plane(torch.tensor(centers[window])),\n", - " \"PA\": -PAs[window],\n", - " \"q\": qs[window],\n", - " },\n", + " psf_convolve=True,\n", + " center=centers[window],\n", + " PA=PAs[window],\n", + " q=qs[window],\n", " )\n", " )\n", " sub_list.append(\n", - " ap.models.AstroPhot_Model(\n", + " ap.Model(\n", " name=f\"W1band model {i}\",\n", " model_type=\"sersic galaxy model\",\n", " target=target_W1,\n", " window=w1windows[window],\n", - " psf_mode=\"full\",\n", + " psf_convolve=True,\n", " )\n", " )\n", " sub_list.append(\n", - " ap.models.AstroPhot_Model(\n", + " ap.Model(\n", " name=f\"NUVband model {i}\",\n", " model_type=\"sersic galaxy model\",\n", " target=target_NUV,\n", " window=nuvwindows[window],\n", - " psf_mode=\"full\",\n", + " psf_convolve=True,\n", " )\n", " )\n", " # ensure equality constraints\n", @@ -385,7 +340,7 @@ "\n", " # Make the multiband model for this object\n", " model_list.append(\n", - " ap.models.AstroPhot_Model(\n", + " ap.Model(\n", " name=f\"model {i}\",\n", " model_type=\"group model\",\n", " target=target_full,\n", @@ -393,14 +348,14 @@ " )\n", " )\n", "# Make the full model for this system of objects\n", - "MODEL = ap.models.AstroPhot_Model(\n", + "MODEL = ap.Model(\n", " name=f\"full model\",\n", " model_type=\"group model\",\n", " target=target_full,\n", " models=model_list,\n", ")\n", "fig, ax = plt.subplots(1, 3, figsize=(16, 5))\n", - "ap.plots.target_image(fig, ax, MODEL.target, flipx=True)\n", + "ap.plots.target_image(fig, ax, MODEL.target)\n", "ap.plots.model_window(fig, ax, MODEL)\n", "ax[0].set_title(\"r-band image\")\n", "ax[1].set_title(\"W1-band image\")\n", @@ -415,7 +370,15 @@ "outputs": [], "source": [ "MODEL.initialize()\n", - "\n", + "MODEL.graphviz()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "# We give it only one iteration for runtime/demo purposes, you should let these algorithms run to convergence\n", "result = ap.fit.Iter(MODEL, verbose=1, max_iter=1).fit()" ] @@ -426,11 +389,15 @@ "metadata": {}, "outputs": [], "source": [ - "fig1, ax1 = plt.subplots(1, 3, figsize=(18, 4))\n", - "ap.plots.model_image(fig1, ax1, MODEL, flipx=True, vmax=30)\n", - "ax1[0].set_title(\"r-band model image\")\n", - "ax1[1].set_title(\"W1-band model image\")\n", - "ax1[2].set_title(\"NUV-band model image\")\n", + "fig1, ax1 = plt.subplots(2, 3, figsize=(18, 11))\n", + "ap.plots.model_image(fig1, ax1[0], MODEL, vmax=30)\n", + "ax1[0][0].set_title(\"r-band model image\")\n", + "ax1[0][1].set_title(\"W1-band model image\")\n", + "ax1[0][2].set_title(\"NUV-band model image\")\n", + "ap.plots.residual_image(fig1, ax1[1], MODEL, normalize_residuals=True)\n", + "ax1[1][0].set_title(\"r-band residual image\")\n", + "ax1[1][1].set_title(\"W1-band residual image\")\n", + "ax1[1][2].set_title(\"NUV-band residual image\")\n", "plt.show()" ] }, @@ -443,20 +410,6 @@ "An important note here is that the SB levels for the W1 and NUV data are quire reasonable. While the structure (center, PA, q, n, Re) was shared between bands and therefore mostly driven by the r-band, the brightness is entirely independent between bands meaning the Ie (and therefore SB) values are right from the W1 and NUV data!" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 3, figsize=(18, 6))\n", - "ap.plots.residual_image(fig, ax, MODEL, flipx=True, normalize_residuals=True)\n", - "ax[0].set_title(\"r-band residual image\")\n", - "ax[1].set_title(\"W1-band residual image\")\n", - "ax[2].set_title(\"NUV-band residual image\")\n", - "plt.show()" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/docs/source/tutorials/ModelZoo.ipynb b/docs/source/tutorials/ModelZoo.ipynb index cc8a5307..0dbaec62 100644 --- a/docs/source/tutorials/ModelZoo.ipynb +++ b/docs/source/tutorials/ModelZoo.ipynb @@ -21,14 +21,15 @@ "source": [ "%load_ext autoreload\n", "%autoreload 2\n", + "%matplotlib inline\n", "\n", "import astrophot as ap\n", "import numpy as np\n", - "import torch\n", "import matplotlib.pyplot as plt\n", + "import matplotlib.animation as animation\n", + "from IPython.display import HTML\n", "\n", - "%matplotlib inline\n", - "basic_target = ap.image.Target_Image(data=np.zeros((100, 100)), pixelscale=1, zeropoint=20)" + "basic_target = ap.TargetImage(data=np.zeros((100, 100)), pixelscale=1, zeropoint=20)" ] }, { @@ -51,11 +52,7 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"flat sky model\", parameters={\"center\": [50, 50], \"F\": 1}, target=basic_target\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", + "M = ap.Model(model_type=\"flat sky model\", center=[50, 50], I=1, target=basic_target)\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(figsize=(7, 6))\n", @@ -77,13 +74,42 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"plane sky model\",\n", - " parameters={\"center\": [50, 50], \"F\": 10, \"delta\": [1e-2, 2e-2]},\n", + " center=[50, 50],\n", + " I0=10,\n", + " delta=[1e-2, 2e-2],\n", + " target=basic_target,\n", + ")\n", + "M.initialize()\n", + "\n", + "fig, ax = plt.subplots(figsize=(7, 6))\n", + "ap.plots.model_image(fig, ax, M)\n", + "ax.set_title(M.name)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Bilinear Sky Model\n", + "\n", + "This allows for a complex sky model which can vary arbitrarily as a function of position. Here we plot a sky that is just noise, but one would typically make it smoothly varying. The noise sky makes the nature of bilinear interpolation very clear, large flux changes can create sharp edges in the reconstruction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "M = ap.Model(\n", + " model_type=\"bilinear sky model\",\n", + " I=np.random.uniform(0, 1, (5, 5)) + 1,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(figsize=(7, 6))\n", @@ -100,12 +126,12 @@ "\n", "These models are well suited to describe stars or any other point like source of light, they may also be used to convolve with other models during optimization. Some things to keep in mind about PSF models:\n", "\n", - "- Their \"target\" should be a PSF_Image\n", + "- Their \"target\" should be a `PSFImage` object\n", "- They are always centered at (0,0) so there is no need to optimize the center position\n", "- Their total flux is typically normalized to 1, so no need to optimize any normalization parameters\n", - "- They can be used in a lot of places that a PSF_Image can be used, such as the convolution kernel for a model\n", + "- They can be used in a lot of places that a `PSFImage` can be used, such as the convolution kernel for a model\n", "\n", - "They behave a bit differently than other models, see the point source model further down. A PSF describes the abstract point source light distribution, to actually model a star in a field you will need a point source object (further down) which is convolved by a PSF model." + "They behave a bit differently than other models, see the point source model further down. A PSF describes the abstract point source light distribution, to actually model a star in a field you will need a `point model` object (further down) to represent a delta function of brightness with some total flux." ] }, { @@ -122,7 +148,7 @@ "psf += np.random.normal(scale=psf / 3)\n", "psf[psf < 0] = ap.utils.initialize.gaussian_psf(3.0, 101, 1.0)[psf < 0] + 1e-10\n", "\n", - "psf_target = ap.image.PSF_Image(\n", + "psf_target = ap.PSFImage(\n", " data=psf / np.sum(psf),\n", " pixelscale=1,\n", ")\n", @@ -155,15 +181,13 @@ "wgt = np.array((0.0001, 0.01, 1.0, 0.01, 0.0001))\n", "PSF[48:53] += (sinc(x[48:53]) ** 2) * wgt.reshape((-1, 1))\n", "PSF[:, 48:53] += (sinc(x[:, 48:53]) ** 2) * wgt\n", - "PSF = ap.image.PSF_Image(data=PSF, pixelscale=psf_target.pixelscale)\n", + "PSF = ap.PSFImage(data=PSF, pixelscale=psf_target.pixelscale)\n", "\n", - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"pixelated psf model\",\n", " target=psf_target,\n", - " parameters={\"pixels\": np.log10(PSF.data / psf_target.pixel_area)},\n", + " pixels=PSF.data / psf_target.pixel_area,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -190,13 +214,8 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"gaussian psf model\", parameters={\"sigma\": 10}, target=psf_target\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", + "M = ap.Model(model_type=\"gaussian psf model\", sigma=10, target=psf_target)\n", "M.initialize()\n", - "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", "ap.plots.psf_image(fig, ax[0], M)\n", "ap.plots.radial_light_profile(fig, ax[1], M)\n", @@ -217,11 +236,7 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"moffat psf model\", parameters={\"n\": 2.0, \"Rd\": 10.0}, target=psf_target\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", + "M = ap.Model(model_type=\"moffat psf model\", n=2.0, Rd=10.0, target=psf_target)\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -246,13 +261,14 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"moffat2d psf model\",\n", - " parameters={\"n\": 2.0, \"Rd\": 10.0, \"q\": 0.7, \"PA\": 3.14 / 3},\n", + "M = ap.Model(\n", + " model_type=\"2d moffat psf model\",\n", + " n=2.0,\n", + " Rd=10.0,\n", + " q=0.7,\n", + " PA=3.14 / 3,\n", " target=psf_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -275,13 +291,11 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"airy psf model\",\n", - " parameters={\"aRL\": 1.0 / 20},\n", + " aRL=1.0 / 20,\n", " target=psf_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -295,40 +309,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Zernike Polynomial PSF" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"zernike psf model\", order_n=4, integrate_mode=\"none\", target=psf_target\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", + "### Basis PSF\n", "\n", - "fig, axarr = plt.subplots(3, 5, figsize=(18, 10))\n", - "for i, ax in enumerate(axarr.flatten()):\n", - " Anm = torch.zeros_like(M[\"Anm\"].value)\n", - " Anm[0] = 1.0\n", - " Anm[i] = 1.0\n", - " M[\"Anm\"].value = Anm\n", - " ax.set_title(f\"n: {M.nm_list[i][0]} m: {M.nm_list[i][1]}\")\n", - " ap.plots.psf_image(fig, ax, M, norm=None)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Eigen basis PSF point source\n", + "A basis psf model allows one to provide a series of images such as an Eigen decomposition or a Zernike polynomial (or any other basis one likes). The weight of each component is fit to determine the final model. If a suitable basis is chosen then it is possible to encode highly complex models with only a few free parameters as the weights. \n", "\n", - "An eigen basis is a set of images which can be combined to form a PSF model. The eigen basis model makes it possible to fit the coefficients for the basis as model parameters. In fact the zernike polynomials are a kind of basis, so we will use them as input to the eigen psf model." + "For the `basis` argument one may provide the basis manually (N imgs, H, W) or simply provide `\"zernike:n\"` where `n` gives the Zernike order up to which will be fit.\n", + "\n", + "As the basis may be provided manually, one can even provide a base PSF model as the first component and then use the Zernike coefficients as perturbations." ] }, { @@ -337,39 +324,23 @@ "metadata": {}, "outputs": [], "source": [ - "super_basic_target = ap.image.Target_Image(data=np.zeros((101, 101)), pixelscale=1)\n", - "Z = ap.models.AstroPhot_Model(\n", - " model_type=\"zernike psf model\", order_n=4, integrate_mode=\"none\", target=psf_target\n", - ")\n", - "Z.initialize()\n", - "basis = []\n", - "for i in range(10):\n", - " Anm = torch.zeros_like(Z[\"Anm\"].value)\n", - " Anm[0] = 1.0\n", - " Anm[i] = 1.0\n", - " Z[\"Anm\"].value = Anm\n", - " basis.append(Z().data)\n", - "basis = torch.stack(basis)\n", - "\n", - "W = np.linspace(1, 0.1, 10)\n", - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"eigen psf model\",\n", - " eigen_basis=basis,\n", - " eigen_pixelscale=1,\n", - " parameters={\"weights\": W},\n", - " target=psf_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", + "w = [1.5, 0, 0, 0.0, -0.5, 0, 0.5, 0, 0, 0, 0.0, 0, 1, 0, 0]\n", + "M = ap.Model(model_type=\"basis psf model\", basis=\"zernike:4\", weights=w, target=psf_target)\n", "M.initialize()\n", - "\n", + "nm_list = ap.models.func.zernike_n_m_list(4)\n", + "fig, axarr = plt.subplots(3, 5, figsize=(18, 10))\n", + "for i, ax in enumerate(axarr.flatten()):\n", + " ax.set_title(f\"n: {nm_list[i][0]} m: {nm_list[i][1]}\")\n", + " ax.imshow(M.basis[i], cmap=\"RdBu_r\", origin=\"lower\")\n", + " plt.colorbar(ax.images[0], ax=ax, fraction=0.046, pad=0.04)\n", + " ax.axis(\"off\")\n", + "plt.show()\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.psf_image(fig, ax[0], M, norm=None)\n", - "W = np.random.rand(10)\n", - "M[\"weights\"].value = W\n", - "ap.plots.psf_image(fig, ax[1], M, norm=None)\n", - "ax[0].set_title(M.name)\n", - "ax[1].set_title(\"random weights\")\n", + "ap.plots.psf_image(fig, ax[0], M, vmin=5e-5)\n", + "ax[1].plot(np.arange(1, 16), M.weights.value.numpy(), marker=\"o\")\n", + "ax[1].set_xlabel(\"Zernike mode index\")\n", + "ax[1].set_ylabel(\"Weight\")\n", + "ax[0].set_title(\"Zernike basis PSF model\")\n", "plt.show()" ] }, @@ -379,14 +350,14 @@ "source": [ "## The Point Source Model\n", "\n", - "This model is used to represent point sources in the sky. It is effectively a delta function at a given position with a given flux. Otherwise it has no structure. You must provide it a PSF model so that it can project into the sky." + "This model is used to represent point sources in the sky such as stars, supernovae, asteroids, small galaxies, quasars, and more. It is effectively a delta function at a given position with a given flux. Otherwise it has no structure. You must provide it a PSF model so that it can project into the sky. That PSF model may take the form of an image (`PSFImage` object) or may itself be a psf model with its own parameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Point Source using PSF_Image" + "### Point Source using PSFImage" ] }, { @@ -395,15 +366,15 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"point model\",\n", - " parameters={\"center\": [50, 50], \"flux\": 1},\n", + " center=[50, 50],\n", + " flux=10,\n", " psf=psf_target,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", + "M.to()\n", "\n", "fig, ax = plt.subplots(figsize=(7, 6))\n", "ap.plots.model_image(fig, ax, M)\n", @@ -424,24 +395,18 @@ "metadata": {}, "outputs": [], "source": [ - "psf = ap.models.AstroPhot_Model(\n", - " model_type=\"moffat psf model\", parameters={\"n\": 2.0, \"Rd\": 10.0}, target=psf_target\n", - ")\n", + "psf = ap.Model(model_type=\"moffat psf model\", n=2.0, Rd=10.0, target=psf_target)\n", "psf.initialize()\n", "\n", - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"point model\",\n", - " parameters={\"center\": [50, 50], \"flux\": 1},\n", + " center=[50, 50],\n", + " flux=1,\n", " psf=psf,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", - "\n", - "# Note that the PSF model now shows up as a \"parameter\" for the point model. In fact this is just a pointer to the PSF parameter graph which you can see by printing the parameters\n", - "print(M.parameters)\n", - "\n", + "print(M)\n", "fig, ax = plt.subplots(figsize=(7, 6))\n", "ap.plots.model_image(fig, ax, M)\n", "ax.set_title(M.name)\n", @@ -452,7 +417,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Galaxy Models" + "## Primary Galaxy Models\n", + "\n", + "These models are represented mostly by their radial profile and are numerically straightforward to work with. All of these models also have perturbative extensions described below in the SuperEllipse, Fourier, Warp, Ray, and Wedge sections." ] }, { @@ -472,24 +439,20 @@ "source": [ "# Here we make an arbitrary spline profile out of a sine wave and a line\n", "x = np.linspace(0, 10, 14)\n", - "spline_profile = np.sin(x * 2 + 2) / 20 + 1 - x / 20\n", + "spline_profile = np.array(list((np.sin(x * 2 + 2) / 20 + 1 - x / 20)) + [-4])\n", "# Here we write down some corresponding radii for the points in the non-parametric profile. AstroPhot will make\n", "# radii to match an input profile, but it is generally better to manually provide values so you have some control\n", "# over their placement. Just note that it is assumed the first point will be at R = 0.\n", - "NP_prof = [0] + list(np.logspace(np.log10(2), np.log10(50), 13))\n", + "NP_prof = [0] + list(np.logspace(np.log10(2), np.log10(50), 13)) + [200]\n", "\n", - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"spline galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"I(R)\": {\"value\": spline_profile, \"prof\": NP_prof},\n", - " },\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " I_R={\"value\": 10**spline_profile, \"prof\": NP_prof},\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -512,13 +475,16 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"sersic galaxy model\",\n", - " parameters={\"center\": [50, 50], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"n\": 3, \"Re\": 10, \"Ie\": 1},\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=3,\n", + " Re=10,\n", + " Ie=10,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -541,13 +507,15 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"exponential galaxy model\",\n", - " parameters={\"center\": [50, 50], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"Re\": 10, \"Ie\": 1},\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " Re=10,\n", + " Ie=1,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -570,13 +538,15 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"gaussian galaxy model\",\n", - " parameters={\"center\": [50, 50], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"sigma\": 20, \"flux\": 1},\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " sigma=20,\n", + " flux=10,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -599,22 +569,18 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", + "M = ap.Model(\n", " model_type=\"nuker galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"Rb\": 10.0,\n", - " \"Ib\": 1.0,\n", - " \"alpha\": 4.0,\n", - " \"beta\": 3.0,\n", - " \"gamma\": -0.2,\n", - " },\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " Rb=10.0,\n", + " Ib=10.0,\n", + " alpha=4.0,\n", + " beta=3.0,\n", + " gamma=-0.2,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -628,9 +594,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Edge on model\n", - "\n", - "Currently there is only one dedicared edge on model, the self gravitating isothermal disk from van der Kruit & Searle 1981. If you know of another common edge on model, feel free to let us know and we can add it in!" + "### Ferrer Model" ] }, { @@ -639,17 +603,21 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"isothermal sech2 edgeon model\",\n", - " parameters={\"center\": [50, 50], \"PA\": 60 * np.pi / 180, \"I0\": 0.0, \"hs\": 3.0, \"rs\": 5.0},\n", + "M = ap.Model(\n", + " model_type=\"ferrer galaxy model\",\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " rout=40.0,\n", + " alpha=2.0,\n", + " beta=1.0,\n", + " I0=10.0,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", + "ap.plots.model_image(fig, ax[0], M, vmax=30)\n", "ap.plots.radial_light_profile(fig, ax[1], M)\n", "ax[0].set_title(M.name)\n", "plt.show()" @@ -659,54 +627,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Multi Gaussian Expansion\n", + "### King Model\n", "\n", - "A multi gaussian expansion is essentially a model made of overlapping gaussian models that share the same center. However, they are combined into a single model for computational efficiency. Another advantage of the MGE is that it is possible to determine a deprojection of the model from 2D into a 3D shape since the projection of a 3D gaussian is a 2D gaussian. Note however, that in some configurations this deprojection is not unique. See Cappellari 2002 for more details.\n", - "\n", - "Note: The ``PA`` can be either a single value (same for all components) or an array with values for each component." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"mge model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": [0.9, 0.8, 0.6, 0.5],\n", - " \"PA\": 30 * np.pi / 180,\n", - " \"sigma\": [4.0, 8.0, 16.0, 32.0],\n", - " \"flux\": np.ones(4) / 4,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n", - "ap.plots.model_image(fig, ax, M)\n", - "ax.set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Super Ellipse Models\n", - "\n", - "A super ellipse is a regular ellipse, except the radius metric changes from R = sqrt(x^2 + y^2) to the more general: R = (x^C + y^C)^1/C. The parameter C = 2 for a regular ellipse, for 0 2 the shape becomes more \"boxy.\" In AstroPhot we use the parameter C0 = C-2 for simplicity." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Spline SuperEllipse" + "This is the Empirical King model with the extra free parameter $\\alpha$" ] }, { @@ -715,23 +638,21 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"spline superellipse galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"C0\": 2,\n", - " \"I(R)\": {\"value\": spline_profile, \"prof\": NP_prof},\n", - " },\n", + "M = ap.Model(\n", + " model_type=\"king galaxy model\",\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " Rc=10.0,\n", + " Rt=40.0,\n", + " alpha=2.01,\n", + " I0=10.0,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", + "ap.plots.model_image(fig, ax[0], M, vmax=30)\n", "ap.plots.radial_light_profile(fig, ax[1], M)\n", "ax[0].set_title(M.name)\n", "plt.show()" @@ -741,44 +662,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Sersic SuperEllipse" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"sersic superellipse galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"C0\": 2,\n", - " \"n\": 3,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" + "## Special Galaxy Models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Exponential SuperEllipse" + "### Edge on model\n", + "\n", + "Currently there is only one dedicared edge on model, the self gravitating isothermal disk from van der Kruit & Searle 1981. If you know of another common edge on model, feel free to let us know and we can add it in!" ] }, { @@ -787,13 +680,15 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"exponential superellipse galaxy model\",\n", - " parameters={\"center\": [50, 50], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"C0\": 2, \"Re\": 10, \"Ie\": 1},\n", + "M = ap.Model(\n", + " model_type=\"isothermal sech2 edgeon model\",\n", + " center=[50, 50],\n", + " PA=60 * np.pi / 180,\n", + " I0=1.0,\n", + " hs=3.0,\n", + " rs=5.0,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -807,7 +702,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Gaussian SuperEllipse" + "### Multi Gaussian Expansion\n", + "\n", + "A multi gaussian expansion is essentially a model made of overlapping gaussian models that share the same center. However, they are combined into a single model for computational efficiency. Another advantage of the MGE is that it is possible to determine a deprojection of the model from 2D into a 3D shape since the projection of a 3D gaussian is a 2D gaussian. Note however, that in some configurations this deprojection is not unique. See Cappellari 2002 for more details.\n", + "\n", + "Note: The ``PA`` can be either a single value (same for all components) or an array with values for each component." ] }, { @@ -816,26 +715,20 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"gaussian superellipse galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"C0\": 2,\n", - " \"sigma\": 20,\n", - " \"flux\": 1,\n", - " },\n", + "M = ap.Model(\n", + " model_type=\"mge model\",\n", + " center=[50, 50],\n", + " q=[0.9, 0.8, 0.6, 0.5],\n", + " PA=30 * np.pi / 180,\n", + " sigma=[4.0, 8.0, 16.0, 32.0],\n", + " flux=np.ones(4) / 4,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", + "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n", + "ap.plots.model_image(fig, ax, M)\n", + "ax.set_title(M.name)\n", "plt.show()" ] }, @@ -843,7 +736,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Nuker SuperEllipse" + "### Gaussian Ellipsoid\n", + "\n", + "This model is an intrinsically 3D gaussian ellipsoid shape, which is projected to 2D for imaging. " ] }, { @@ -852,122 +747,69 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"nuker superellipse galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"C0\": 2,\n", - " \"Rb\": 10.0,\n", - " \"Ib\": 1.0,\n", - " \"alpha\": 4.0,\n", - " \"beta\": 3.0,\n", - " \"gamma\": -0.2,\n", - " },\n", + "M = ap.Model(\n", + " model_type=\"gaussianellipsoid model\",\n", + " center=[50, 50],\n", + " sigma_a=20.0, # disk radius\n", + " sigma_b=20.0, # also disk radius\n", + " sigma_c=2.0, # disk thickness\n", + " alpha=0.0, # disk spin\n", + " beta=np.arccos(0.6), # disk inclination\n", + " gamma=30 * np.pi / 180, # disk position angle\n", + " flux=10.0,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Fourier Ellipse Models\n", - "\n", - "A Fourier ellipse is a scaling on the radius values as a function of theta. It takes the form: $R' = R * exp(\\sum_m am*cos(m*theta + phim))$, where am and phim are the parameters which describe the Fourier perturbations. Using the \"modes\" argument as a tuple, users can select which Fourier modes are used. As a rough intuition: mode 1 acts like a shift of the model; mode 2 acts like ellipticity; mode 3 makes a lopsided model (triangular in the extreme); and mode 4 makes peanut/diamond perturbations. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Spline Fourier" + "M.initialize()" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ - "fourier_am = np.array([0.1, 0.3, -0.2])\n", - "fourier_phim = np.array([10 * np.pi / 180, 0, 40 * np.pi / 180])\n", - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"spline fourier galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"I(R)\": {\"value\": spline_profile, \"prof\": NP_prof},\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", + "beta = np.linspace(0, np.pi, 50)\n", + "M.beta = beta[0]\n", + "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n", + "ap.plots.model_image(fig, ax, M, showcbar=False)\n", "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" + "\n", + "def update(frame):\n", + " M.beta = beta[frame]\n", + " ax.clear()\n", + " ap.plots.model_image(fig, ax, M, showcbar=False, vmin=24, vmax=30)\n", + " ax.set_title(f\"{M.name} beta = {beta[frame]:.2f} rad\")\n", + " return ax\n", + "\n", + "\n", + "ani = animation.FuncAnimation(fig, update, frames=50, interval=60)\n", + "plt.close()\n", + "# Save animation as gif\n", + "# ani.save(\"microlensing_animation.gif\", writer='pillow', fps=16) # Adjust 'fps' for the speed\n", + "# Or display the animation inline\n", + "HTML(ani.to_jshtml())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Sersic Fourier" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"sersic fourier galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"n\": 3,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", + "## Super Ellipse Models\n", "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" + "A super ellipse is a regular ellipse, except the radius metric changes from $R = \\sqrt{x^2 + y^2}$ to the more general: $R = |x^C + y^C|^{1/C}$. The parameter $C = 2$ for a regular ellipse, for $0 2$ the shape becomes more \"boxy.\" \n", + "\n", + "There are superellipse versions of all the primary galaxy models: `sersic`, `exponential`, `gaussian`, `moffat`, `spline`, `ferrer`, `king`, and `nuker`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Exponential Fourier" + "### Sersic SuperEllipse" ] }, { @@ -976,21 +818,17 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"exponential fourier galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", + "M = ap.Model(\n", + " model_type=\"sersic superellipse galaxy model\",\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " C=4,\n", + " n=3,\n", + " Re=10,\n", + " Ie=1,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -1004,44 +842,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Gaussian Fourier" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"gaussian fourier galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"sigma\": 20,\n", - " \"flux\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", + "## Fourier Ellipse Models\n", "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" + "A Fourier ellipse is a scaling on the radius values as a function of theta. It takes the form: $R' = R * \\exp(\\sum_m a_m*\\cos(m*\\theta + \\phi_m))$, where am and phim are the parameters which describe the Fourier perturbations. Using the \"modes\" argument as a tuple, users can select which Fourier modes are used. As a rough intuition: mode 1 acts like a shift of the model; mode 2 acts like ellipticity; mode 3 makes a lopsided model (triangular in the extreme); and mode 4 makes peanut/diamond perturbations. \n", + "\n", + "There are Fourier Ellipse versions of all the primary galaxy models: `sersic`, `exponential`, `gaussian`, `moffat`, `spline`, `ferrer`, `king`, and `nuker`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Nuker Fourier" + "### Sersic Fourier" ] }, { @@ -1050,24 +862,22 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"nuker fourier galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"Rb\": 10.0,\n", - " \"Ib\": 1.0,\n", - " \"alpha\": 4.0,\n", - " \"beta\": 3.0,\n", - " \"gamma\": -0.2,\n", - " },\n", + "fourier_am = np.array([0.1, 0.3, -0.2])\n", + "fourier_phim = np.array([10 * np.pi / 180, 0, 40 * np.pi / 180])\n", + "\n", + "M = ap.Model(\n", + " model_type=\"sersic fourier galaxy model\",\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " am=fourier_am,\n", + " phim=fourier_phim,\n", + " modes=(2, 3, 4),\n", + " n=3,\n", + " Re=10,\n", + " Ie=1,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -1083,22 +893,24 @@ "source": [ "## Warp Model\n", "\n", - "A warp model performs a radially varying coordinate transform. Essentially instead of applying a rotation matrix **Rot** on all coordinates X,Y we instead construct a unique rotation matrix for each coordinate pair **Rot(R)** where $R = \\sqrt(X^2 + Y^2)$. We also apply a radially dependent axis ratio **q(R)** to all the coordinates:\n", + "A warp model performs a radially varying coordinate transform. Essentially instead of applying a rotation matrix **Rot** on all coordinates X,Y we instead construct a unique rotation matrix for each coordinate pair **Rot(R)** where $R = \\sqrt{X^2 + Y^2}$. We also apply a radially dependent axis ratio **q(R)** to all the coordinates:\n", "\n", - "$R = \\sqrt(X^2 + Y^2)$\n", + "$R = \\sqrt{X^2 + Y^2}$\n", "\n", "$X, Y = Rotate(X, Y, PA(R))$\n", "\n", - "$Y = Y / q(R)$\n", + "$Y = \\frac{Y}{q(R)}$\n", "\n", - "The net effect is a radially varying PA and axis ratio which allows the model to represent spiral arms, bulges, or other features that change the apparent shape of a galaxy in a radially varying way." + "The net effect is a radially varying PA and axis ratio which allows the model to represent spiral arms, bulges, or other features that change the apparent shape of a galaxy in a radially varying way.\n", + "\n", + "There are warp versions of all the primary galaxy models: `sersic`, `exponential`, `gaussian`, `moffat`, `spline`, `ferrer`, `king`, and `nuker`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Spline Warp" + "### Sersic Warp" ] }, { @@ -1109,25 +921,26 @@ "source": [ "warp_q = np.linspace(0.1, 0.4, 14)\n", "warp_pa = np.linspace(0, np.pi - 0.2, 14)\n", - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"spline warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"I(R)\": {\"value\": spline_profile, \"prof\": NP_prof},\n", - " },\n", + "prof = np.linspace(0.0, 50, 14)\n", + "M = ap.Model(\n", + " model_type=\"sersic warp galaxy model\",\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " q_R={\"value\": warp_q, \"dynamic\": True, \"prof\": prof},\n", + " PA_R={\"value\": warp_pa, \"dynamic\": True, \"prof\": prof},\n", + " n=3,\n", + " Re=10,\n", + " Ie=1,\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", + "fig, ax = plt.subplots(1, 3, figsize=(20, 6))\n", "ap.plots.model_image(fig, ax[0], M)\n", "ap.plots.radial_light_profile(fig, ax[1], M)\n", + "ap.plots.warp_phase_profile(fig, ax[2], M)\n", + "ax[2].legend()\n", "ax[0].set_title(M.name)\n", "plt.show()" ] @@ -1136,45 +949,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Sersic Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"sersic warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"n\": 3,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", + "## Ray Model\n", "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" + "A ray model allows the user to break the galaxy up into regions that can be fit separately. There are two basic kinds of ray model: symmetric and asymmetric. A symmetric ray model (symmetric_rays = True) assumes 180 degree symmetry of the galaxy and so each ray is reflected through the center. This means that essentially the major axes and the minor axes are being fit separately. For an asymmetric ray model (symmetric_rays = False) each ray is it's own profile to be fit separately. \n", + "\n", + "In a ray model there is a smooth boundary between the rays. This smoothness is accomplished by applying a $(\\cos(r*theta)+1)/2$ weight to each profile, where r is dependent on the number of rays and theta is shifted to center on each ray in turn. The exact cosine weighting is dependent on if the rays are symmetric and if there is an even or odd number of rays. \n", + "\n", + "There are ray versions of all the primary galaxy models: `sersic`, `exponential`, `gaussian`, `moffat`, `spline`, `ferrer`, `king`, and `nuker`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Exponential Warp" + "### Sersic Ray" ] }, { @@ -1183,26 +971,23 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"exponential warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", + "M = ap.Model(\n", + " model_type=\"sersic ray galaxy model\",\n", + " symmetric=True,\n", + " segments=2,\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=[1, 3],\n", + " Re=[10, 5],\n", + " Ie=[1, 0.5],\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", + "ap.plots.ray_light_profile(fig, ax[1], M)\n", "ax[0].set_title(M.name)\n", "plt.show()" ] @@ -1211,7 +996,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Gaussian Warp" + "## Wedge Model\n", + "\n", + "A wedge model behaves just like a ray model, except the boundaries are sharp. This has the advantage that the wedges can be very different in brightness without the \"smoothing\" from the ray model washing out the dimmer one. It also has the advantage of less \"mixing\" of information between the rays, each one can be counted on to have fit only the pixels in it's wedge without any influence from a neighbor. However, it has the disadvantage that the discontinuity at the boundary makes fitting behave strangely when a bright spot lays near the boundary.\n", + "\n", + "There are wedge versions of all the primary galaxy models: `sersic`, `exponential`, `gaussian`, `moffat`, `spline`, `ferrer`, `king`, and `nuker`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sersic Wedge" ] }, { @@ -1220,177 +1016,18 @@ "metadata": {}, "outputs": [], "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"gaussian warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"sigma\": 30,\n", - " \"flux\": 1,\n", - " },\n", + "M = ap.Model(\n", + " model_type=\"sersic wedge galaxy model\",\n", + " symmetric=True,\n", + " segments=2,\n", + " center=[50, 50],\n", + " q=0.6,\n", + " PA=60 * np.pi / 180,\n", + " n=[1, 3],\n", + " Re=[10, 5],\n", + " Ie=[1, 0.5],\n", " target=basic_target,\n", ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Nuker Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"nuker warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"Rb\": 10.0,\n", - " \"Ib\": 1.0,\n", - " \"alpha\": 4.0,\n", - " \"beta\": 3.0,\n", - " \"gamma\": -0.2,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Ray Model\n", - "\n", - "A ray model allows the user to break the galaxy up into regions that can be fit separately. There are two basic kinds of ray model: symmetric and asymmetric. A symmetric ray model (symmetric_rays = True) assumes 180 degree symmetry of the galaxy and so each ray is reflected through the center. This means that essentially the major axes and the minor axes are being fit separately. For an asymmetric ray model (symmetric_rays = False) each ray is it's own profile to be fit separately. \n", - "\n", - "In a ray model there is a smooth boundary between the rays. This smoothness is accomplished by applying a $(\\cos(r*theta)+1)/2$ weight to each profile, where r is dependent on the number of rays and theta is shifted to center on each ray in turn. The exact cosine weighting is dependent on if the rays are symmetric and if there is an even or odd number of rays. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Spline Ray" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"spline ray galaxy model\",\n", - " symmetric_rays=True,\n", - " rays=2,\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"I(R)\": {\"value\": np.array([spline_profile * 2, spline_profile]), \"prof\": NP_prof},\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.ray_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Sersic Ray" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"sersic ray galaxy model\",\n", - " symmetric_rays=True,\n", - " rays=2,\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"n\": [1, 3],\n", - " \"Re\": [10, 5],\n", - " \"Ie\": [1, 0.5],\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.ray_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Exponential Ray" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"exponential ray galaxy model\",\n", - " symmetric_rays=True,\n", - " rays=2,\n", - " parameters={\"center\": [50, 50], \"q\": 0.6, \"PA\": 60 * np.pi / 180, \"Re\": [10, 5], \"Ie\": [1, 2]},\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", "M.initialize()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", @@ -1400,448 +1037,6 @@ "plt.show()" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Gaussian Ray" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"gaussian ray galaxy model\",\n", - " symmetric_rays=True,\n", - " rays=2,\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"sigma\": [10, 20],\n", - " \"flux\": [1.5, 1.0],\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.ray_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Nuker Ray" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"nuker ray galaxy model\",\n", - " symmetric_rays=True,\n", - " rays=2,\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"Rb\": [10.0, 1.0],\n", - " \"Ib\": [1.0, 0.0],\n", - " \"alpha\": [4.0, 1.0],\n", - " \"beta\": [3.0, 1.0],\n", - " \"gamma\": [-0.2, 0.2],\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.ray_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Wedge Model\n", - "\n", - "A wedge model behaves just like a ray model, except the boundaries are sharp. This has the advantage that the wedges can be very different in brightness without the \"smoothing\" from the ray model washing out the dimmer one. It also has the advantage of less \"mixing\" of information between the rays, each one can be counted on to have fit only the pixels in it's wedge without any influence from a neighbor. However, it has the disadvantage that the discontinuity at the boundary makes fitting behave strangely when a bright spot lays near the boundary." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Spline Wedge" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"spline wedge galaxy model\",\n", - " symmetric_wedges=True,\n", - " wedges=2,\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"I(R)\": {\"value\": np.array([spline_profile, spline_profile * 2]), \"prof\": NP_prof},\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.wedge_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## High Order Warp Models\n", - "\n", - "The models below combine the Warp coordinate transform with radial behaviour transforms: SuperEllipse and Fourier. These higher order models can create highly complex shapes, though their scientific use-case is less clear. They are included for completeness as they may be useful in some specific instances. These models are also included to demonstrate the flexibility in making AstroPhot models, in a future tutorial we will discuss how to make your own model types." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Spline SuperEllipse Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"spline superellipse warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"C0\": 2,\n", - " \"I(R)\": {\"value\": spline_profile, \"prof\": NP_prof},\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Sersic SuperEllipse Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"sersic superellipse warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"C0\": 2,\n", - " \"n\": 3,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Exponential SuperEllipse Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"exponential superellipse warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"C0\": 2,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Gaussian SuperEllipse Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"gaussian superellipse warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"C0\": 2,\n", - " \"sigma\": 30,\n", - " \"flux\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Spline Fourier Warp\n", - "\n", - "not sure how this abomination would fit a galaxy, but you are welcome to try" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"spline fourier warp galaxy model\",\n", - " modes=(1, 3, 4),\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"I(R)\": {\"value\": spline_profile, \"prof\": NP_prof},\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Sersic Fourier Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"sersic fourier warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"n\": 3,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Exponential Fourier Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"exponential fourier warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"Re\": 10,\n", - " \"Ie\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Gassian Fourier Warp" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "M = ap.models.AstroPhot_Model(\n", - " model_type=\"gaussian fourier warp galaxy model\",\n", - " parameters={\n", - " \"center\": [50, 50],\n", - " \"q\": 0.6,\n", - " \"PA\": 60 * np.pi / 180,\n", - " \"q(R)\": warp_q,\n", - " \"PA(R)\": warp_pa,\n", - " \"am\": fourier_am,\n", - " \"phim\": fourier_phim,\n", - " \"sigma\": 20,\n", - " \"flux\": 1,\n", - " },\n", - " target=basic_target,\n", - ")\n", - "print(M.parameter_order)\n", - "print(tuple(P.units for P in M.parameters))\n", - "M.initialize()\n", - "\n", - "fig, ax = plt.subplots(1, 2, figsize=(14, 6))\n", - "ap.plots.model_image(fig, ax[0], M)\n", - "ap.plots.radial_light_profile(fig, ax[1], M)\n", - "ax[0].set_title(M.name)\n", - "plt.show()" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/docs/source/tutorials/PoissonLikelihood.ipynb b/docs/source/tutorials/PoissonLikelihood.ipynb new file mode 100644 index 00000000..d271b9d5 --- /dev/null +++ b/docs/source/tutorials/PoissonLikelihood.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Poisson Noise Model\n", + "\n", + "For the most part, astronomical images are modelled assuming an independent Gaussian uncertainty on every pixel resulting in a negative log likelihood of the form: $\\sum_i\\frac{(d_i-m_i)^2}{2\\sigma_i^2}$ where $d_i$ is the pixel value, $m_i$ is the model value for that pixel, and $\\sigma_i$ is the uncertainty on that pixel. However, in truth the best model for an astronomical image is the Poisson distribution with negative log likelihood of: $\\sum_i m_i + \\log(d_i!) - d_i\\log(m_i)$ with the same definitions, except specifying that $d_i$ is in counts (number of photons or electrons). For large enough $d_i$ these likelihoods are essentially identical and Gaussian is easier to work with. When signal-to-noise ratios get very low, the differences between Poisson and Gaussian distributions can become apparent and so it is important to treat the data with a Poisson likelihood. These conditions regularly occur for gamma ray, x-ray, and low SNR UV data, but are less common for longer wavelengths. AstroPhot can model Poisson likelihood data, here we will demo an example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "import astrophot as ap\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Make some mock data\n", + "\n", + "Lets create some mock low SNR data. Notice that poisson noise isn't additive like gaussian noise. To sample the image, out true model acts as a photon rate and the `np.random.poisson` samples some number of counts based on that rate. Our goal will be to recover the rate of every pixel and ultimately the sersic parameters that produce the correct rate model. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# make some mock data\n", + "target = ap.TargetImage(data=np.zeros((128, 128)))\n", + "true_model = ap.Model(\n", + " name=\"truth\",\n", + " model_type=\"sersic galaxy model\",\n", + " center=(64, 64),\n", + " q=0.7,\n", + " PA=0.5,\n", + " n=1,\n", + " Re=32,\n", + " Ie=1,\n", + " target=target,\n", + ")\n", + "img = true_model().data.detach().cpu().numpy()\n", + "np.random.seed(42) # for reproducibility\n", + "target.data = np.random.poisson(img) # sample poisson distribution\n", + "true_params = true_model.get_values()\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ap.plots.model_image(fig, ax[0], true_model)\n", + "ax[0].set_title(\"True Model\")\n", + "ap.plots.target_image(fig, ax[1], target)\n", + "ax[1].set_title(\"Target Image (Poisson Sampled)\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "Indeed this is some noisy data. The AstroPhot target_image plotting routine struggles a bit with this image, but it kind of looks neat anyway.\n", + "\n", + "## Model the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model = ap.Model(name=\"model\", model_type=\"sersic galaxy model\", target=target)\n", + "model.initialize()" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "While the Levenberg-Marquardt algorithm is traditionally considered as a least squares algorithm, that is actually just its most common application. LM naturally generalizes to a broad class of problems, including the Poisson Likelihood (see [Fowler 2014](https://ui.adsabs.harvard.edu/abs/2014JLTP..176..414F/abstract)). Here we see the AstroPhot automatic initialization does well on this image and recovers decent starting parameters, LM has an easy time finishing the job to find the maximum likelihood.\n", + "\n", + "Note that the idea of a $\\chi^2/{\\rm dof}$ is not as clearly defined for a Poisson likelihood. We take the closest analogue by taking 2 times the negative log likelihood divided by the DoF. This doesn't have any strict statistical meaning but is somewhat intuitive to work with for those used to $\\chi^2/{\\rm dof}$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "res = ap.fit.LM(model, likelihood=\"poisson\", verbose=1).fit()\n", + "\n", + "fig, ax = plt.subplots()\n", + "ap.plots.model_image(fig, ax, model)\n", + "ax.set_title(\"Fitted Model\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "Plotting the model parameters and uncertainty, we see that we have indeed recovered very close to the true values for all parameters! Note that the true values are, however, not where we expect with respect to the 1-2 sigma uncertainty contours. There are two reasons for this, one is that this is a Poisson likelihood and so a Gaussian approximation is only so good, the other is that the model is non-linear so again the Gaussian approximation at the maximum likelihood will not exactly describe the PDF (which actually affects model uncertainties even for a Gaussian likelihood)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = ap.plots.covariance_matrix(\n", + " res.covariance_matrix.detach().cpu().numpy(),\n", + " model.get_values().detach().cpu().numpy(),\n", + " reference_values=true_params.detach().cpu().numpy(),\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "If you encounter a problem where LM struggles to fit the poisson data, the `Slalom` optimizer is also quite efficient in these settings. See the fitting methods tutorial for more details." + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst index c1fa8b91..2d6deef0 100644 --- a/docs/source/tutorials/index.rst +++ b/docs/source/tutorials/index.rst @@ -10,11 +10,17 @@ version of each tutorial is available here. :maxdepth: 1 GettingStarted + GettingStartedJAX GroupModels FittingMethods ModelZoo BasicPSFModels JointModels + ImageAlignment + PoissonLikelihood CustomModels + FunctionalInterface + GravitationalLensing AdvancedPSFModels + ImageTypes ConstrainedModels diff --git a/make_docs.py b/make_docs.py new file mode 100644 index 00000000..f9670b26 --- /dev/null +++ b/make_docs.py @@ -0,0 +1,103 @@ +import astrophot as ap +import nbformat +from nbformat.v4 import new_notebook, new_markdown_cell +import pkgutil +from types import ModuleType, FunctionType +import os +from textwrap import dedent +from inspect import cleandoc, getmodule, signature + +skip_methods = [ + "to_valid", + "topological_ordering", + "to_static", + "to_dynamic", + "unlink", + "update_graph", + "save_state", + "load_state", + "append_state", + "link", + "graphviz", + "graph_print", + "graph_dict", + "from_valid", + "fill_params", + "fill_kwargs", + "fill_dynamic_values", + "clear_params", + "build_params_list", + "build_params_dict", + "build_params_array", +] + + +def dot_path(path): + i = path.rfind("AstroPhot") + path = path[i + 10 :] + path = path.replace("/", ".") + return path[:-3] + + +def gather_docs(module, module_only=False): + docs = {} + for name in module.__all__: + obj = getattr(module, name) + if module_only and not isinstance(obj, ModuleType): + continue + if isinstance(obj, type): + if obj.__doc__ is None: + continue + docs[name] = cleandoc(obj.__doc__) + subfuncs = [docs[name]] + for attr in dir(obj): + if attr.startswith("_"): + continue + if attr in skip_methods: + continue + attrobj = getattr(obj, attr) + if not isinstance(attrobj, FunctionType): + continue + if attrobj.__doc__ is None: + continue + sig = str(signature(attrobj)).replace("self,", "").replace("self", "") + subfuncs.append(f"**method:** {attr}{sig}\n\n" + cleandoc(attrobj.__doc__)) + if len(subfuncs) > 1: + docs[name] = "\n\n".join(subfuncs) + elif isinstance(obj, FunctionType): + if obj.__doc__ is None: + continue + sig = str(signature(obj)) + docs[name] = "**signature:** " + name + sig + "\n\n" + cleandoc(obj.__doc__) + elif isinstance(obj, ModuleType): + docs[name] = gather_docs(obj) + else: + print(f"!!!unexpected type {type(obj)}!!!") + return docs + + +def make_cells(mod_dict, path, depth=2): + print(mod_dict.keys()) + cells = [] + for k in mod_dict: + if isinstance(mod_dict[k], str): + cells.append(new_markdown_cell(f"{'#'*depth} {path}.{k}\n\n" + mod_dict[k])) + elif isinstance(mod_dict[k], dict): + print(k) + cells += make_cells(mod_dict[k], path=path + "." + k, depth=depth + 1) + return cells + + +output_dir = "docs/source/astrophotdocs" +all_ap = gather_docs(ap, True) + +for submodule in all_ap: + nb = new_notebook() + nb.cells = [new_markdown_cell(f"# {submodule}")] + make_cells( + all_ap[submodule], f"astrophot.{submodule}" + ) + + filename = f"{submodule}.ipynb" + path = os.path.join(output_dir, filename) + with open(path, "w", encoding="utf-8") as f: + nbformat.write(nb, f) diff --git a/pyproject.toml b/pyproject.toml index 5beaae94..faaf81cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,9 @@ Documentation = "https://autostronomy.github.io/AstroPhot/" Repository = "https://github.com/Autostronomy/AstroPhot" Issues = "https://github.com/Autostronomy/AstroPhot/issues" +[project.optional-dependencies] +dev = ["pre-commit", "nbval", "nbconvert", "graphviz", "ipywidgets", "jupyter-book", "matplotlib", "photutils", "scikit-image", "caustics", "emcee", "corner", "jax<=0.7.0", "pyvo"] + [project.scripts] astrophot = "astrophot:run_from_terminal" diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 416634f5..00000000 --- a/requirements-dev.txt +++ /dev/null @@ -1 +0,0 @@ -pre-commit diff --git a/requirements.txt b/requirements.txt index efd85c11..1a4dfb24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ astropy>=5.3 +caskade>=0.6.0 h5py>=3.8.0 matplotlib>=3.7 numpy>=1.24.0,<2.0.0 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..92081514 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +import matplotlib +import matplotlib.pyplot as plt +import pytest + + +@pytest.fixture(autouse=True) +def no_block_show(monkeypatch): + def close_show(*args, **kwargs): + # plt.savefig("/dev/null") # or do nothing + plt.close("all") + + monkeypatch.setattr(plt, "show", close_show) + + # Also ensure we are in a non-GUI backend + matplotlib.use("Agg") diff --git a/tests/test_cmos_image.py b/tests/test_cmos_image.py new file mode 100644 index 00000000..4cfb5123 --- /dev/null +++ b/tests/test_cmos_image.py @@ -0,0 +1,92 @@ +import astrophot as ap +import torch +import numpy as np + +import pytest + +###################################################################### +# Image Objects +###################################################################### + + +@pytest.fixture() +def cmos_target(): + arr = ap.backend.zeros((10, 15)) + return ap.CMOSTargetImage( + data=arr, + pixelscale=0.7, + zeropoint=1.0, + variance=ap.backend.ones_like(arr), + mask=ap.backend.zeros_like(arr), + subpixel_loc=(-0.25, -0.25), + subpixel_scale=0.5, + ) + + +def test_cmos_image_creation(cmos_target): + cmos_copy = cmos_target.copy() + assert cmos_copy.pixelscale == 0.7, "image should track pixelscale" + assert cmos_copy.zeropoint == 1.0, "image should track zeropoint" + assert cmos_copy.crpix[0] == 0, "image should track crpix" + assert cmos_copy.crpix[1] == 0, "image should track crpix" + assert cmos_copy.subpixel_loc == (-0.25, -0.25), "image should track subpixel location" + assert cmos_copy.subpixel_scale == 0.5, "image should track subpixel scale" + + print(cmos_target.data.shape) + i, j = cmos_target.pixel_center_meshgrid() + assert i.shape == (15, 10), "meshgrid should have correct shape" + assert j.shape == (15, 10), "meshgrid should have correct shape" + + x, y = cmos_target.coordinate_center_meshgrid() + assert x.shape == (15, 10), "coordinate meshgrid should have correct shape" + assert y.shape == (15, 10), "coordinate meshgrid should have correct shape" + + +def test_cmos_model_sample(cmos_target): + model = ap.Model( + name="test cmos", + model_type="sersic galaxy model", + target=cmos_target, + center=(3, 5), + q=0.7, + PA=np.pi / 3, + n=2.5, + Re=4, + Ie=1.0, + sampling_mode="midpoint", + integrate_mode="bright", + ) + model.initialize() + img = model.sample() + + assert isinstance(img, ap.CMOSModelImage), "sampled image should be a CMOSModelImage" + assert img.pixelscale == cmos_target.pixelscale, "sampled image should have the same pixelscale" + assert img.zeropoint == cmos_target.zeropoint, "sampled image should have the same zeropoint" + assert ( + img.subpixel_loc == cmos_target.subpixel_loc + ), "sampled image should have the same subpixel location" + + +def test_cmos_image_save_load(cmos_target): + # Save the image + cmos_target.save("cmos_image.fits") + + # Load the image + loaded_image = ap.CMOSTargetImage(filename="cmos_image.fits") + + # Check if the loaded image matches the original + assert ap.backend.allclose( + cmos_target.data, loaded_image.data + ), "Loaded image data should match original" + assert ap.backend.allclose( + cmos_target.pixelscale, loaded_image.pixelscale + ), "Loaded image pixelscale should match original" + assert ap.backend.allclose( + cmos_target.zeropoint, loaded_image.zeropoint + ), "Loaded image zeropoint should match original" + assert np.allclose( + cmos_target.subpixel_loc, loaded_image.subpixel_loc + ), "Loaded image subpixel location should match original" + assert np.allclose( + cmos_target.subpixel_scale, loaded_image.subpixel_scale + ), "Loaded image subpixel scale should match original" diff --git a/tests/test_fit.py b/tests/test_fit.py index 3f0f43f8..bfb2ad13 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -1,575 +1,297 @@ -import unittest - import torch import numpy as np import astrophot as ap from utils import make_basic_sersic +import pytest ###################################################################### # Fit Objects ###################################################################### -class TestComponentModelFits(unittest.TestCase): - def test_sersic_fit_grad(self): - """ - Simply test that the gradient optimizer changes the parameters - """ - np.random.seed(12345) - N = 50 - Width = 20 - shape = (N + 10, N) - true_params = [2, 5, 10, -3, 5, 0.7, np.pi / 4] - IXX, IYY = np.meshgrid( - np.linspace(-Width, Width, shape[1]), np.linspace(-Width, Width, shape[0]) - ) - QPAXX, QPAYY = ap.utils.conversions.coordinates.Axis_Ratio_Cartesian_np( - true_params[5], IXX - true_params[3], IYY - true_params[4], true_params[6] - ) - Z0 = ap.utils.parametric_profiles.sersic_np( - np.sqrt(QPAXX**2 + QPAYY**2), - true_params[0], - true_params[1], - true_params[2], - ) + np.random.normal(loc=0, scale=0.1, size=shape) - tar = ap.image.Target_Image( - data=Z0, - pixelscale=0.8, - variance=np.ones(Z0.shape) * (0.1**2), - ) - - mod = ap.models.Sersic_Galaxy( - name="sersic model", - target=tar, - parameters={ - "center": [-3.2 + N / 2, 5.1 + (N + 10) / 2], - "q": 0.6, - "PA": np.pi / 4, - "n": 2, - "Re": 5, - "Ie": 10, - }, - ) - - self.assertFalse(mod.locked, "default model should not be locked") - - mod.initialize() - - mod_initparams = {} - for p in mod.parameters: - mod_initparams[p.name] = np.copy(p.vector_representation().detach().cpu().numpy()) - - res = ap.fit.Grad(model=mod, max_iter=10).fit() - - for p in mod.parameters: - self.assertFalse( - np.any(p.vector_representation().detach().cpu().numpy() == mod_initparams[p.name]), - f"parameter {p.name} should update with optimization", - ) - - def test_sersic_fit_lm(self): - """ - Test sersic fitting with entirely independent sersic sampling at 10x resolution. - """ - N = 50 - pixelscale = 0.8 - shape = (N + 10, N) - true_params = { - "center": [ - shape[0] * pixelscale / 2 - 3.35, - shape[1] * pixelscale / 2 + 5.35, - ], - "n": 1, - "Re": 20, - "Ie": 0.0, - "q": 0.7, - "PA": np.pi / 4, - } - tar = make_basic_sersic( - N=shape[0], - M=shape[1], - pixelscale=pixelscale, - x=true_params["center"][0], - y=true_params["center"][1], - n=true_params["n"], - Re=true_params["Re"], - Ie=true_params["Ie"], - q=true_params["q"], - PA=true_params["PA"], - ) - mod = ap.models.AstroPhot_Model( - name="sersic model", - model_type="sersic galaxy model", - target=tar, - sampling_mode="simpsons", - ) - - mod.initialize() - ap.AP_config.set_logging_output(stdout=True, filename="AstroPhot.log") - res = ap.fit.LM(model=mod, verbose=2).fit() - res.update_uncertainty() - - self.assertAlmostEqual( - mod["center"].value[0].item() / true_params["center"][0], - 1, - 2, - "LM should accurately recover parameters in simple cases", - ) - self.assertAlmostEqual( - mod["center"].value[1].item() / true_params["center"][1], - 1, - 2, - "LM should accurately recover parameters in simple cases", - ) - self.assertAlmostEqual( - mod["n"].value.item(), - true_params["n"], - 1, - msg="LM should accurately recover parameters in simple cases", - ) - self.assertAlmostEqual( - (mod["Re"].value.item()) / true_params["Re"], - 1, - delta=1, - msg="LM should accurately recover parameters in simple cases", - ) - self.assertAlmostEqual( - mod["Ie"].value.item(), - true_params["Ie"], - 1, - "LM should accurately recover parameters in simple cases", - ) - self.assertAlmostEqual( - mod["PA"].value.item() / true_params["PA"], - 1, - delta=0.5, - msg="LM should accurately recover parameters in simple cases", - ) - self.assertAlmostEqual( - mod["q"].value.item(), - true_params["q"], - 1, - "LM should accurately recover parameters in simple cases", - ) - res.covariance_matrix - - # check for crash - mod.total_flux() - mod.total_flux_uncertainty() - mod.total_magnitude() - mod.total_magnitude_uncertainty() - - -class TestGroupModelFits(unittest.TestCase): - def test_groupmodel_fit(self): - """ - Simply test that fitting a group model changes the parameter values - """ - np.random.seed(12345) - N = 50 - Width = 20 - shape = (N + 10, N) - true_params1 = [2, 4, 10, -3, 5, 0.7, np.pi / 4] - true_params2 = [1.2, 6, 8, 2, -3, 0.5, -np.pi / 4] - IXX, IYY = np.meshgrid( - np.linspace(-Width, Width, shape[1]), np.linspace(-Width, Width, shape[0]) - ) - QPAXX, QPAYY = ap.utils.conversions.coordinates.Axis_Ratio_Cartesian_np( - true_params1[5], - IXX - true_params1[3], - IYY - true_params1[4], - true_params1[6], - ) - Z0 = ap.utils.parametric_profiles.sersic_np( - np.sqrt(QPAXX**2 + QPAYY**2), - true_params1[0], - true_params1[1], - true_params1[2], - ) - QPAXX, QPAYY = ap.utils.conversions.coordinates.Axis_Ratio_Cartesian_np( - true_params2[5], - IXX - true_params2[3], - IYY - true_params2[4], - true_params2[6], - ) - Z0 += ap.utils.parametric_profiles.sersic_np( - np.sqrt(QPAXX**2 + QPAYY**2), - true_params2[0], - true_params2[1], - true_params2[2], - ) - Z0 += np.random.normal(loc=0, scale=0.1, size=shape) - tar = ap.image.Target_Image( - data=Z0, - pixelscale=0.8, - variance=np.ones(Z0.shape) * (0.1**2), - ) - - mod1 = ap.models.Sersic_Galaxy( - name="sersic model 1", - target=tar, - parameters={"center": {"value": [-3.2 + N / 2, 5.1 + (N + 10) / 2]}}, - ) - mod2 = ap.models.Sersic_Galaxy( - name="sersic model 2", - target=tar, - parameters={"center": {"value": [2.1 + N / 2, -3.1 + (N + 10) / 2]}}, - ) - - smod = ap.models.Group_Model(name="group model", models=[mod1, mod2], target=tar) - - self.assertFalse(smod.locked, "default model should not be locked") - - smod.initialize() - - mod1_initparams = {} - for p in mod1.parameters: - mod1_initparams[p.name] = np.copy(p.vector_representation().detach().cpu().numpy()) - mod2_initparams = {} - for p in mod2.parameters: - mod2_initparams[p.name] = np.copy(p.vector_representation().detach().cpu().numpy()) - - res = ap.fit.Grad(model=smod, max_iter=10).fit() - - for p in mod1.parameters: - self.assertFalse( - np.any(p.vector_representation().detach().cpu().numpy() == mod1_initparams[p.name]), - f"mod1 parameter {p.name} should update with optimization", - ) - for p in mod2.parameters: - self.assertFalse( - np.any(p.vector_representation().detach().cpu().numpy() == mod2_initparams[p.name]), - f"mod2 parameter {p.name} should update with optimization", - ) - - -class TestLM(unittest.TestCase): - def test_lm_creation(self): - target = make_basic_sersic() - new_model = ap.models.AstroPhot_Model( - name="test sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - - LM = ap.fit.LM(new_model, max_iter=10) - - LM.fit() - - def test_chunk_parameter_jacobian(self): - target = make_basic_sersic() - new_model = ap.models.AstroPhot_Model( - name="test sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - jacobian_chunksize=3, - ) - - LM = ap.fit.LM(new_model, max_iter=10) - - LM.fit() - - def test_chunk_image_jacobian(self): - target = make_basic_sersic() - new_model = ap.models.AstroPhot_Model( - name="test sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - image_chunksize=15, - ) - - LM = ap.fit.LM(new_model, max_iter=10) - - LM.fit() - - def test_group_fit_step(self): - np.random.seed(123456) - tar = make_basic_sersic(N=51, M=51) - mod1 = ap.models.Sersic_Galaxy( - name="base model 1", - target=tar, - window=[[0, 25], [0, 25]], - parameters={ - "center": [5, 5], - "PA": 0, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - ) - mod2 = ap.models.Sersic_Galaxy( - name="base model 2", - target=tar, - window=[[25, 51], [25, 51]], - parameters={ - "center": [5, 5], - "PA": 0, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - ) - - smod = ap.models.AstroPhot_Model( - name="group model", - model_type="group model", - models=[mod1, mod2], - target=tar, - ) - vec_init = smod.parameters.vector_values().detach().clone() - LM = ap.fit.LM(smod, max_iter=1).fit() - vec_final = smod.parameters.vector_values().detach().clone() - self.assertFalse( - torch.all(vec_init == vec_final), - "LM should update parameters in LM step", - ) - - -class TestMiniFit(unittest.TestCase): - def test_minifit(self): - target = make_basic_sersic() - new_model = ap.models.AstroPhot_Model( - name="test sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, +@pytest.mark.parametrize("center", [[20, 20], [25.1, 17.324567]]) +@pytest.mark.parametrize("PA", [0, 60 * np.pi / 180]) +@pytest.mark.parametrize("q", [0.2, 0.8]) +@pytest.mark.parametrize("n", [1, 4]) +@pytest.mark.parametrize("Re", [10, 25.1]) +def test_chunk_jacobian(center, PA, q, n, Re): + target = make_basic_sersic() + model = ap.Model( + name="test sersic", + model_type="sersic galaxy model", + center=center, + PA=PA, + q=q, + n=n, + Re=Re, + Ie=10.0, + target=target, + integrate_mode="none", + ) + + Jtrue = model.jacobian() + + model.jacobian_maxparams = 3 + + Jchunked = model.jacobian() + assert ap.backend.allclose( + Jtrue.data, Jchunked.data + ), "Param chunked Jacobian should match full Jacobian" + + model.jacobian_maxparams = 10 + model.jacobian_maxpixels = 20**2 + + Jchunked = model.jacobian() + + assert ap.backend.allclose( + Jtrue.data, Jchunked.data + ), "Pixel chunked Jacobian should match full Jacobian" + + +@pytest.fixture +def sersic_model(): + target = make_basic_sersic() + model = ap.Model( + name="test sersic", + model_type="sersic galaxy model", + center=[20, 20], + PA=np.pi, + q=0.7, + n=2, + Re=15, + Ie=10.0, + target=target, + ) + model.initialize() + return model + + +@pytest.mark.parametrize( + "fitter,extra", + [ + (ap.fit.LM, {}), + (ap.fit.LM, {"likelihood": "poisson"}), + (ap.fit.LMfast, {}), + (ap.fit.IterParam, {"chunks": 3, "chunk_order": "sequential", "verbose": 2}), + ( + ap.fit.IterParam, + {"chunks": 3, "chunk_order": "random", "verbose": 2, "likelihood": "poisson"}, + ), + (ap.fit.Grad, {}), + (ap.fit.ScipyFit, {}), + (ap.fit.MHMCMC, {}), + (ap.fit.HMC, {}), + (ap.fit.MALA, {"epsilon": 1e-3}), + ( + ap.fit.MALA, + { + "epsilon": 1e-3, + "likelihood": "poisson", + "initial_state": [[20, 20, 0.7, np.pi, 2, 15, 10]], }, - target=target, - ) - - MF = ap.fit.MiniFit( - new_model, downsample_factor=2, method_quargs={"max_iter": 10}, verbose=1 - ) - - MF.fit() - - -class TestIter(unittest.TestCase): - def test_iter_basic(self): - target = make_basic_sersic() - model_list = [] - model_list.append( - ap.models.AstroPhot_Model( - name="basic sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - ) - model_list.append( - ap.models.AstroPhot_Model( - name="basic sky", - model_type="flat sky model", - parameters={"F": -1}, - target=target, - ) - ) - - MODEL = ap.models.AstroPhot_Model( - name="model", - model_type="group model", - target=target, - models=model_list, - ) - - MODEL.initialize() - - res = ap.fit.Iter(MODEL, method=ap.fit.LM) - - res.fit() - - -class TestIterLM(unittest.TestCase): - def test_iter_basic(self): - target = make_basic_sersic() - model_list = [] - model_list.append( - ap.models.AstroPhot_Model( - name="basic sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - ) - model_list.append( - ap.models.AstroPhot_Model( - name="basic sky", - model_type="flat sky model", - parameters={"F": -1}, - target=target, - ) - ) - - MODEL = ap.models.AstroPhot_Model( - name="model", - model_type="group model", - target=target, - models=model_list, - ) - - MODEL.initialize() - - res = ap.fit.Iter_LM(MODEL) - - res.fit() - - -class TestHMC(unittest.TestCase): - def test_hmc_sample(self): - np.random.seed(12345) - N = 50 - pixelscale = 0.8 - true_params = { - "n": 2, - "Re": 10, - "Ie": 1, - "center": [-3.3, 5.3], - "q": 0.7, - "PA": np.pi / 4, - } - target = ap.image.Target_Image( - data=np.zeros((N, N)), - pixelscale=pixelscale, - ) - - MODEL = ap.models.Sersic_Galaxy( - name="sersic model", - target=target, - parameters=true_params, - ) - img = MODEL().data.detach().cpu().numpy() - target.data = torch.Tensor( - img - + np.random.normal(scale=0.1, size=img.shape) - + np.random.normal(scale=np.sqrt(img) / 10) - ) - target.variance = torch.Tensor(0.1**2 + img / 100) - - HMC = ap.fit.HMC(MODEL, epsilon=1e-5, max_iter=5, warmup=2) - HMC.fit() - - -class TestNUTS(unittest.TestCase): - def test_nuts_sample(self): - np.random.seed(12345) - N = 50 - pixelscale = 0.8 - true_params = { - "n": 2, - "Re": 10, - "Ie": 1, - "center": [-3.3, 5.3], - "q": 0.7, - "PA": np.pi / 4, - } - target = ap.image.Target_Image( - data=np.zeros((N, N)), - pixelscale=pixelscale, - ) - - MODEL = ap.models.Sersic_Galaxy( - name="sersic model", - target=target, - parameters=true_params, - ) - img = MODEL().data.detach().cpu().numpy() - target.data = torch.Tensor( - img - + np.random.normal(scale=0.1, size=img.shape) - + np.random.normal(scale=np.sqrt(img) / 10) - ) - target.variance = torch.Tensor(0.1**2 + img / 100) - - NUTS = ap.fit.NUTS(MODEL, max_iter=5, warmup=2) - NUTS.fit() - - -class TestMHMCMC(unittest.TestCase): - def test_singlesersic(self): - np.random.seed(12345) - N = 50 - pixelscale = 0.8 - true_params = { - "n": 2, - "Re": 10, - "Ie": 1, - "center": [-3.3, 5.3], - "q": 0.7, - "PA": np.pi / 4, - } - target = ap.image.Target_Image( - data=np.zeros((N, N)), - pixelscale=pixelscale, - ) - - MODEL = ap.models.Sersic_Galaxy( - name="sersic model", - target=target, - parameters=true_params, - ) - img = MODEL().data.detach().cpu().numpy() - target.data = torch.Tensor( - img - + np.random.normal(scale=0.1, size=img.shape) - + np.random.normal(scale=np.sqrt(img) / 10) - ) - target.variance = torch.Tensor(0.1**2 + img / 100) - - MHMCMC = ap.fit.MHMCMC(MODEL, epsilon=1e-4, max_iter=100) - MHMCMC.fit() - - self.assertGreater( - MHMCMC.acceptance, - 0.1, - "MHMCMC should have nonzero acceptance for simple fits", - ) - - -if __name__ == "__main__": - unittest.main() + ), + (ap.fit.MiniFit, {}), + (ap.fit.Slalom, {}), + ], +) +@pytest.mark.parametrize("fit_valid", [True, False]) +def test_fitters(fitter, extra, sersic_model, fit_valid): + if ap.backend.backend == "jax" and fitter in [ap.fit.Grad, ap.fit.HMC]: + pytest.skip("Grad and HMC not implemented for JAX backend") + model = sersic_model + model.initialize() + ll_init = model.gaussian_log_likelihood() + pll_init = model.poisson_log_likelihood() + result = fitter(model, max_iter=100, fit_valid=fit_valid, **extra).fit() + ll_final = model.gaussian_log_likelihood() + pll_final = model.poisson_log_likelihood() + assert ll_final > ll_init, f"{fitter.__name__} should improve the log likelihood" + assert pll_final > pll_init, f"{fitter.__name__} should improve the poisson log likelihood" + + +def test_fitters_iter(): + target = make_basic_sersic() + model1 = ap.Model( + name="test1", + model_type="sersic galaxy model", + center=[20, 20], + PA=np.pi, + q=0.7, + n=2, + Re=15, + Ie=10.0, + target=target, + ) + model2 = ap.Model( + name="test2", + model_type="sersic galaxy model", + center=[20.5, 21], + PA=1.5 * np.pi, + q=0.9, + n=1, + Re=10, + Ie=8.0, + target=target, + ) + model = ap.Model( + name="test group", + model_type="group model", + models=[model1, model2], + target=target, + ) + model.initialize() + ll_init = model.gaussian_log_likelihood() + pll_init = model.poisson_log_likelihood() + result = ap.fit.Iter(model, max_iter=10).fit() + ll_final = model.gaussian_log_likelihood() + pll_final = model.poisson_log_likelihood() + assert ll_final > ll_init, f"Iter should improve the log likelihood" + assert pll_final > pll_init, f"Iter should improve the poisson log likelihood" + + # test hessian + Hgauss = model.hessian(likelihood="gaussian") + assert ap.backend.all( + ap.backend.isfinite(Hgauss) + ), "Hessian should be finite for Gaussian likelihood" + Hpoisson = model.hessian(likelihood="poisson") + assert ap.backend.all( + ap.backend.isfinite(Hpoisson) + ), "Hessian should be finite for Poisson likelihood" + + +def test_hessian(sersic_model): + model = sersic_model + model.initialize() + Hgauss = model.hessian(likelihood="gaussian") + assert ap.backend.all( + ap.backend.isfinite(Hgauss) + ), "Hessian should be finite for Gaussian likelihood" + Hpoisson = model.hessian(likelihood="poisson") + assert ap.backend.all( + ap.backend.isfinite(Hpoisson) + ), "Hessian should be finite for Poisson likelihood" + assert Hgauss is not None, "Hessian should be computed for Gaussian likelihood" + assert Hpoisson is not None, "Hessian should be computed for Poisson likelihood" + with pytest.raises(ValueError): + model.hessian(likelihood="unknown") + + +def test_gradient(sersic_model): + if ap.backend.backend == "jax": + pytest.skip("JAX backend does not support backward function") + model = sersic_model + target = model.target + target.weight = 1 / (10 + target.variance) + model.initialize() + x = model.get_values() + grad = model.gradient() + assert ap.backend.all(ap.backend.isfinite(grad)), "Gradient should be finite" + assert grad.shape == x.shape, "Gradient shape should match parameters shape" + x.requires_grad = True + ll = model.gaussian_log_likelihood(x) + ll.backward() + autograd = x.grad + assert ap.backend.allclose(grad, autograd, rtol=1e-4), "Gradient should match autograd gradient" + + funcgrad = ap.backend.grad(model.gaussian_log_likelihood)(x) + assert ap.backend.allclose( + grad, funcgrad, rtol=1e-4 + ), "Gradient should match functional gradient" + + +def test_options(sersic_model): + model = sersic_model + model.initialize() + + with pytest.raises(ValueError): + ap.fit.LM(model, likelihood="unknown") + with pytest.raises(ValueError): + ap.fit.IterParam(model, likelihood="unknown") + with pytest.raises(ap.errors.OptimizeStopSuccess): + model.target.mask = ap.backend.ones_like(model.target.mask, dtype=bool) + ap.fit.IterParam(model) + model.target.mask = ap.backend.zeros_like(model.target.mask, dtype=bool) + + fitter = ap.fit.IterParam( + model=model, + W=model.target.weight, + ndf=np.prod(model.target.data.shape), + chunk_order="invalid", + ) + with pytest.raises(ValueError): + fitter.fit() + + model.to_static(False) + res = ap.fit.IterParam(model).fit() + assert "No parameters to optimize" in res.message, "Should exit if no dynamic parameters" + + +# class TestHMC(unittest.TestCase): +# def test_hmc_sample(self): +# np.random.seed(12345) +# N = 50 +# pixelscale = 0.8 +# true_params = { +# "n": 2, +# "Re": 10, +# "Ie": 1, +# "center": [-3.3, 5.3], +# "q": 0.7, +# "PA": np.pi / 4, +# } +# target = ap.image.Target_Image( +# data=np.zeros((N, N)), +# pixelscale=pixelscale, +# ) + +# MODEL = ap.models.Sersic_Galaxy( +# name="sersic model", +# target=target, +# parameters=true_params, +# ) +# img = MODEL().data.detach().cpu().numpy() +# target.data = torch.Tensor( +# img +# + np.random.normal(scale=0.1, size=img.shape) +# + np.random.normal(scale=np.sqrt(img) / 10) +# ) +# target.variance = torch.Tensor(0.1**2 + img / 100) + +# HMC = ap.fit.HMC(MODEL, epsilon=1e-5, max_iter=5, warmup=2) +# HMC.fit() + + +# class TestNUTS(unittest.TestCase): +# def test_nuts_sample(self): +# np.random.seed(12345) +# N = 50 +# pixelscale = 0.8 +# true_params = { +# "n": 2, +# "Re": 10, +# "Ie": 1, +# "center": [-3.3, 5.3], +# "q": 0.7, +# "PA": np.pi / 4, +# } +# target = ap.image.Target_Image( +# data=np.zeros((N, N)), +# pixelscale=pixelscale, +# ) + +# MODEL = ap.models.Sersic_Galaxy( +# name="sersic model", +# target=target, +# parameters=true_params, +# ) +# img = MODEL().data.detach().cpu().numpy() +# target.data = torch.Tensor( +# img +# + np.random.normal(scale=0.1, size=img.shape) +# + np.random.normal(scale=np.sqrt(img) / 10) +# ) +# target.variance = torch.Tensor(0.1**2 + img / 100) + +# NUTS = ap.fit.NUTS(MODEL, max_iter=5, warmup=2) +# NUTS.fit() diff --git a/tests/test_group_models.py b/tests/test_group_models.py index 4477933c..9285c0ac 100644 --- a/tests/test_group_models.py +++ b/tests/test_group_models.py @@ -1,320 +1,129 @@ -import unittest - +import astrophot as ap import numpy as np -import torch import astrophot as ap from utils import make_basic_sersic, make_basic_gaussian_psf -class TestGroup(unittest.TestCase): - def test_groupmodel_creation(self): - np.random.seed(12345) - shape = (10, 15) - tar = ap.image.Target_Image( - data=np.random.normal(loc=0, scale=1.4, size=shape), - pixelscale=0.8, - variance=np.ones(shape) * (1.4**2), - ) - - mod1 = ap.models.Component_Model( - name="base model 1", - target=tar, - parameters={"center": {"value": [5, 5], "locked": True}}, - ) - mod2 = ap.models.Component_Model( - name="base model 2", - target=tar, - parameters={"center": {"value": [5, 5], "locked": True}}, - ) - - smod = ap.models.AstroPhot_Model( - name="group model", - model_type="group model", - models=[mod1, mod2], - psf_mode="none", - target=tar, - ) - - self.assertFalse(smod.locked, "default model state should not be locked") - - smod.initialize() - - self.assertTrue(torch.all(smod().data == 0), "model_image should be zeros") - - # add existing model does nothing - smod.add_model(mod1) - self.assertEqual(len(smod.models), 2, "Adding existing model should not change model count") - - # error for adding mdoels with the same name - mod3 = ap.models.Component_Model( - name="base model 2", - target=tar, - parameters={"center": {"value": [5, 5], "locked": True}}, - ) - with self.assertRaises(KeyError): - smod.add_model(mod3) - - # Warning for wrong kwarg name - with self.assertLogs(ap.AP_config.ap_logger.name, level="WARNING"): - ap.models.AstroPhot_Model( - name="group model", - model_type="group model", - model=[mod1, mod2], - psf_mode="none", - target=tar, - ) - - def test_jointmodel_creation(self): - np.random.seed(12345) - shape = (10, 15) - tar1 = ap.image.Target_Image( - data=np.random.normal(loc=0, scale=1.4, size=shape), - pixelscale=0.8, - variance=np.ones(shape) * (1.4**2), - ) - shape2 = (33, 42) - tar2 = ap.image.Target_Image( - data=np.random.normal(loc=0, scale=1.4, size=shape2), - pixelscale=0.3, - origin=(43.2, 78.01), - variance=np.ones(shape2) * (1.4**2), - ) - - tar = ap.image.Target_Image_List([tar1, tar2]) - - mod1 = ap.models.Flat_Sky( - name="base model 1", - target=tar1, - ) - mod2 = ap.models.Flat_Sky( - name="base model 2", - target=tar2, - ) - - smod = ap.models.AstroPhot_Model( - name="group model", - model_type="group model", - models=[mod1, mod2], - target=tar, - ) - - self.assertFalse(smod.locked, "default model state should not be locked") - - smod.initialize() - self.assertTrue( - torch.all(torch.isfinite(smod().flatten("data"))).item(), "model_image should be real" - ) - - fm = smod.fit_mask() - for fmi in fm: - self.assertTrue(torch.sum(fmi).item() == 0, "this fit_mask should not mask any pixels") - - def test_groupmodel_saveload(self): - np.random.seed(12345) - tar = make_basic_sersic(N=51, M=51) - - psf = ap.models.Moffat_PSF( - name="psf model 1", - target=make_basic_gaussian_psf(N=11), - parameters={ - "center": {"value": [5, 5], "locked": True}, - "n": 2.0, - "Rd": 3.0, - "I0": {"value": 0.0, "locked": True}, - }, - ) - - mod1 = ap.models.Sersic_Galaxy( - name="base model 1", - target=tar, - parameters={"center": {"value": [5, 5], "locked": False}}, - psf=psf, - psf_mode="full", - ) - mod2 = ap.models.Sersic_Galaxy( - name="base model 2", - target=tar, - parameters={"center": {"value": [5, 5], "locked": False}}, - ) - - smod = ap.models.AstroPhot_Model( - name="group model", - model_type="group model", - models=[mod1, mod2], - target=tar, - ) - - self.assertFalse(smod.locked, "default model state should not be locked") - - smod.initialize() - - self.assertTrue(torch.all(torch.isfinite(smod().data)), "model_image should be real values") - - smod.save("test_save_group_model.yaml") - - newmod = ap.models.AstroPhot_Model( - name="group model", - filename="test_save_group_model.yaml", - ) - self.assertEqual(len(smod.models), len(newmod.models), "Group model should load sub models") - - self.assertEqual(newmod.parameters.size, 16, "Group model size should sum all parameters") - - self.assertTrue( - torch.all(newmod.parameters.vector_values() == smod.parameters.vector_values()), - "Save/load should extract all parameters", - ) - - -class TestPSFGroup(unittest.TestCase): - def test_psfgroupmodel_creation(self): - tar = make_basic_gaussian_psf() - - mod1 = ap.models.AstroPhot_Model( - name="base model 1", - model_type="moffat psf model", - target=tar, - ) - - mod2 = ap.models.AstroPhot_Model( - name="base model 2", - model_type="moffat psf model", - target=tar, - ) - - smod = ap.models.AstroPhot_Model( - name="group model", - model_type="psf group model", - models=[mod1, mod2], - target=tar, - ) - - smod.initialize() - - self.assertTrue( - torch.all(smod().data >= 0), - "PSF group sample should be greater than or equal to zero", - ) - - def test_psfgroupmodel_saveload(self): - np.random.seed(12345) - tar = make_basic_gaussian_psf() - - psf1 = ap.models.Moffat_PSF( - name="psf model 1", - target=tar, - parameters={ - "n": 2.0, - "Rd": 3.0, - }, - ) - - psf2 = ap.models.Sersic_PSF( - name="psf model 2", - target=tar, - parameters={ - "n": 2.0, - "Re": 3.0, - }, - ) - - smod = ap.models.AstroPhot_Model( - name="group model", - model_type="psf group model", - models=[psf1, psf2], - target=tar, - ) - - smod.initialize() - - self.assertTrue(torch.all(torch.isfinite(smod().data)), "psf_image should be real values") - - smod.save("test_save_psfgroup_model.yaml") - - newmod = ap.models.AstroPhot_Model( - name="group model", - filename="test_save_psfgroup_model.yaml", - ) - self.assertEqual(len(smod.models), len(newmod.models), "Group model should load sub models") - - self.assertEqual(newmod.parameters.size, 4, "Group model size should sum all parameters") - - self.assertTrue( - torch.all(newmod.parameters.vector_values() == smod.parameters.vector_values()), - "Save/load should extract all parameters", - ) - - def test_psfgroupmodel_fitting(self): - - np.random.seed(124) - pixelscale = 1.0 - psf1 = ap.utils.initialize.moffat_psf(1.0, 4.0, 101, pixelscale, normalize=False) - psf2 = ap.utils.initialize.moffat_psf(3.0, 2.0, 101, pixelscale, normalize=False) - psf = psf1 + 0.5 * psf2 - psf /= psf.sum() - star = psf * 10 # flux of 10 - variance = star / 1e5 - star += np.random.normal(scale=np.sqrt(variance)) - - psf_target2 = ap.image.PSF_Image( - data=star.copy() / star.sum(), # empirical PSF from cutout - pixelscale=pixelscale, - ) - psf_target2.normalize() - - point_target = ap.image.Target_Image( - data=star, # cutout of star - pixelscale=pixelscale, - variance=variance, - ) - - moffat_component1 = ap.models.AstroPhot_Model( - name="psf part1", - model_type="moffat psf model", - target=psf_target2, - parameters={ - "n": 1.5, - "Rd": 4.5, - "I0": {"value": -3.0, "locked": False}, - }, - normalize_psf=False, - ) - - moffat_component2 = ap.models.AstroPhot_Model( - name="psf part2", - model_type="moffat psf model", - target=psf_target2, - parameters={ - "n": 2.6, - "Rd": 1.7, - "I0": {"value": -2.3, "locked": False}, - }, - normalize_psf=False, - ) - - full_psf_model = ap.models.AstroPhot_Model( - name="full psf", - model_type="psf group model", - target=psf_target2, - models=[moffat_component1, moffat_component2], - normalize_psf=True, - ) - full_psf_model.initialize() - - model = ap.models.AstroPhot_Model( - name="star", - model_type="point model", - target=point_target, - psf=full_psf_model, - ) - model.initialize() - - ap.fit.LM(model, verbose=1).fit() - - self.assertTrue( - abs(model["flux"].value.item() - 1.0) < 1e-2, "Star flux should be accurate" - ) - self.assertTrue( - model["flux"].uncertainty.item() < 1e-2, "Star flux uncertainty should be small" - ) +def test_jointmodel_creation(): + np.random.seed(12345) + shape = (10, 15) + tar1 = ap.TargetImage( + name="target1", + data=np.random.normal(loc=0, scale=1.4, size=shape), + pixelscale=0.8, + variance=np.ones(shape) * (1.4**2), + ) + shape2 = (33, 42) + tar2 = ap.TargetImage( + name="target2", + data=np.random.normal(loc=0, scale=1.4, size=shape2), + pixelscale=0.3, + variance=np.ones(shape2) * (1.4**2), + ) + + tar = ap.TargetImageList([tar1, tar2]) + + mod1 = ap.models.FlatSky( + name="base model 1", + target=tar1, + ) + mod2 = ap.models.FlatSky( + name="base model 2", + target=tar2, + ) + + smod = ap.Model( + name="group model", + model_type="group model", + models=[mod1, mod2], + target=tar, + ) + + smod.initialize() + assert ap.backend.all( + ap.backend.isfinite(smod().flatten("data")) + ).item(), "model_image should be real" + + fm = smod.fit_mask() + for fmi in fm: + assert ap.backend.sum(fmi).item() == 0, "this fit_mask should not mask any pixels" + + +def test_psfgroupmodel_creation(): + tar = make_basic_gaussian_psf() + + mod1 = ap.Model( + name="base model 1", + model_type="moffat psf model", + target=tar, + ) + + mod2 = ap.Model( + name="base model 2", + model_type="moffat psf model", + target=tar, + ) + + smod = ap.Model( + name="group model", + model_type="psf group model", + models=[mod1, mod2], + target=tar, + ) + + smod.initialize() + + assert ap.backend.all( + smod().data >= 0 + ), "PSF group sample should be greater than or equal to zero" + + +def test_joint_multi_band_multi_object(): + target1 = make_basic_sersic(52, 53, name="target1") + target2 = make_basic_sersic(48, 65, name="target2") + target3 = make_basic_sersic(60, 49, name="target3") + target4 = make_basic_sersic(60, 49, name="target4") + + # fmt: off + model11 = ap.Model(name="model11", model_type="sersic galaxy model", window=(0, 50, 5, 52), target=target1) + model12 = ap.Model(name="model12", model_type="sersic galaxy model", window=(3, 53, 0, 49), target=target1) + model1 = ap.Model(name="model1", model_type="group model", models=[model11, model12], target=target1) + + model21 = ap.Model(name="model21", model_type="sersic galaxy model", window=(1, 62, 10, 48), target=target2) + model22 = ap.Model(name="model22", model_type="sersic galaxy model", window=(2, 60, 5, 49), target=target2) + model2 = ap.Model(name="model2", model_type="group model", models=[model21, model22], target=target2) + + model31 = ap.Model(name="model31", model_type="sersic galaxy model", window=(1, 62, 10, 48), target=target3) + model32 = ap.Model(name="model32", model_type="sersic galaxy model", window=(2, 60, 5, 49), target=target3) + model3 = ap.Model(name="model3", model_type="group model", models=[model31, model32], target=target3) + + model4 = ap.Model(name="model4", model_type="sersic galaxy model", window=(0, 53, 0, 52), target=target1) + + model51 = ap.Model(name="model51", model_type="sersic galaxy model", window=(0, 65, 0, 48), target=target2) + model52 = ap.Model(name="model52", model_type="sersic galaxy model", window=(0, 49, 0, 60), target=target3) + model5 = ap.Model(name="model5", model_type="group model", models=[model51, model52], target=ap.TargetImageList([target2, target3])) + + model = ap.Model(name="joint model", model_type="group model", models=[model1, model2, model3, model4, model5], target=ap.TargetImageList([target1, target2, target3, target4])) + # fmt: on + + model.initialize() + mask = model.fit_mask() + assert len(mask) == 4, "There should be 4 fit masks for the 4 targets" + for m in mask: + assert ap.backend.all(ap.backend.isfinite(m)), "this fit_mask should be finite" + sample = model.sample(window=ap.WindowList([target1.window, target2.window, target3.window])) + assert isinstance(sample, ap.ImageList), "Sample should be an ImageList" + for image in sample: + assert ap.backend.all(ap.backend.isfinite(image.data)), "Sample image data should be finite" + assert ap.backend.all(image.data >= 0), "Sample image data should be non-negative" + + jacobian = model.jacobian() + assert isinstance(jacobian, ap.ImageList), "Jacobian should be an ImageList" + for image in jacobian: + assert ap.backend.all( + ap.backend.isfinite(image.data) + ), "Jacobian image data should be finite" + + window = model.window + assert isinstance(window, ap.WindowList), "Window should be a WindowList" diff --git a/tests/test_image.py b/tests/test_image.py index 02919ecc..50e03415 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -1,684 +1,365 @@ -import unittest -from astrophot import image import astrophot as ap -import torch import numpy as np -from utils import get_astropy_wcs, make_basic_sersic +from utils import make_basic_sersic, get_astropy_wcs +import pytest ###################################################################### # Image Objects ###################################################################### -class TestImage(unittest.TestCase): - def test_image_creation(self): - arr = torch.zeros((10, 15)) - base_image = image.Image( - data=arr, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - metadata={"note": "test image"}, - ) - - self.assertEqual(base_image.pixel_length, 1.0, "image should track pixelscale") - self.assertEqual(base_image.zeropoint, 1.0, "image should track zeropoint") - self.assertEqual(base_image.origin[0], 0, "image should track origin") - self.assertEqual(base_image.origin[1], 0, "image should track origin") - self.assertEqual(base_image.metadata["note"], "test image", "image should track note") - - slicer = image.Window(origin=(3, 2), pixel_shape=(4, 5)) - sliced_image = base_image[slicer] - self.assertEqual(sliced_image.origin[0], 3, "image should track origin") - self.assertEqual(sliced_image.origin[1], 2, "image should track origin") - self.assertEqual(base_image.origin[0], 0, "subimage should not change image origin") - self.assertEqual(base_image.origin[1], 0, "subimage should not change image origin") - - second_base_image = image.Image(data=arr, pixelscale=1.0, metadata={"note": "test image"}) - self.assertEqual(base_image.pixel_length, 1.0, "image should track pixelscale") - self.assertIsNone(second_base_image.zeropoint, "image should track zeropoint") - self.assertEqual(second_base_image.origin[0], 0, "image should track origin") - self.assertEqual(second_base_image.origin[1], 0, "image should track origin") - self.assertEqual( - second_base_image.metadata["note"], "test image", "image should track note" - ) - - def test_copy(self): - - new_image = image.Image( - data=torch.zeros((10, 15)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - - copy_image = new_image.copy() - self.assertEqual( - new_image.pixel_length, - copy_image.pixel_length, - "copied image should have same pixelscale", - ) - self.assertEqual( - new_image.zeropoint, - copy_image.zeropoint, - "copied image should have same zeropoint", - ) - self.assertEqual( - new_image.window, copy_image.window, "copied image should have same window" - ) - copy_image += 1 - self.assertEqual( - new_image.data[0][0], - 0.0, - "copied image should not share data with original", - ) - - blank_copy_image = new_image.blank_copy() - self.assertEqual( - new_image.pixel_length, - blank_copy_image.pixel_length, - "copied image should have same pixelscale", - ) - self.assertEqual( - new_image.zeropoint, - blank_copy_image.zeropoint, - "copied image should have same zeropoint", - ) - self.assertEqual( - new_image.window, - blank_copy_image.window, - "copied image should have same window", - ) - blank_copy_image += 1 - self.assertEqual( - new_image.data[0][0], - 0.0, - "copied image should not share data with original", - ) - - def test_image_arithmetic(self): - - arr = torch.zeros((10, 12)) - base_image = image.Image( - data=arr, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.ones(2), - ) - slicer = image.Window(origin=(0, 0), pixel_shape=(5, 5)) - sliced_image = base_image[slicer] - sliced_image += 1 - - self.assertEqual(base_image.data[1][1], 1, "slice should update base image") - self.assertEqual(base_image.data[5][5], 0, "slice should only update its region") - - second_image = image.Image( - data=torch.ones((5, 5)), - pixelscale=1.0, - zeropoint=1.0, - origin=[3, 3], - ) - - # Test iadd - base_image += second_image - self.assertEqual(base_image.data[1][1], 1, "image addition should only update its region") - self.assertEqual(base_image.data[3][3], 2, "image addition should update its region") - self.assertEqual(base_image.data[5][5], 1, "image addition should update its region") - self.assertEqual(base_image.data[8][8], 0, "image addition should only update its region") - - # Test isubtract - base_image -= second_image - self.assertEqual( - base_image.data[1][1], 1, "image subtraction should only update its region" - ) - self.assertEqual(base_image.data[3][3], 1, "image subtraction should update its region") - self.assertEqual(base_image.data[5][5], 0, "image subtraction should update its region") - self.assertEqual( - base_image.data[8][8], 0, "image subtraction should only update its region" - ) - - base_image.data[6:, 6:] += 1.0 - - self.assertEqual(base_image.data[1][1], 1, "array addition should only update its region") - self.assertEqual(base_image.data[6][6], 1, "array addition should update its region") - self.assertEqual(base_image.data[8][8], 1, "array addition should update its region") - - def test_excersize_arithmatic(self): - - arr = torch.zeros((10, 12)) - base_image = image.Image( - data=arr, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.ones(2), - ) - second_image = image.Image( - data=torch.ones((5, 5)), - pixelscale=1.0, - zeropoint=1.0, - origin=[3, 3], - ) - - new_img = base_image + second_image - new_img = new_img - second_image - - self.assertTrue( - torch.allclose(new_img.data, torch.zeros_like(new_img.data)), - "addition and subtraction should produce no change", - ) - - base_image += second_image - base_image -= second_image - - self.assertTrue( - torch.allclose(base_image.data, torch.zeros_like(base_image.data)), - "addition and subtraction should produce no change", - ) - - new_img = base_image + 10.0 - new_img = new_img - 10.0 - - self.assertTrue( - torch.allclose(new_img.data, torch.zeros_like(new_img.data)), - "addition and subtraction should produce no change", - ) - - base_image += 10.0 - base_image -= 10.0 - - self.assertTrue( - torch.allclose(base_image.data, torch.zeros_like(base_image.data)), - "addition and subtraction should produce no change", - ) - - def test_image_manipulation(self): - - new_image = image.Image( - data=torch.ones((16, 32)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - - # image reduction - for scale in [2, 4, 8, 16]: - reduced_image = new_image.reduce(scale) - - self.assertEqual( - reduced_image.data[0][0], - scale**2, - "reduced image should sum sub pixels", - ) - self.assertEqual( - reduced_image.pixel_length, - scale, - "pixelscale should increase with reduced image", - ) - self.assertEqual( - reduced_image.origin[0], - new_image.origin[0], - "origin should not change with reduced image", - ) - self.assertEqual( - reduced_image.shape[0], - new_image.shape[0], - "shape should not change with reduced image", - ) - - # image cropping - new_image.crop( - [torch.tensor(1, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device)] - ) - self.assertEqual( - new_image.data.shape[0], 14, "crop should cut 1 pixel from both sides here" - ) - new_image.crop( - torch.tensor([3, 2], dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device) - ) - self.assertEqual( - new_image.data.shape[1], - 24, - "previous crop and current crop should have cut from this axis", - ) - new_image.crop( - torch.tensor([3, 2, 1, 0], dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device) - ) - self.assertEqual( - new_image.data.shape[0], - 9, - "previous crop and current crop should have cut from this axis", - ) - - def test_image_save_load(self): - - new_image = image.Image( - data=torch.ones((16, 32)), - pixelscale=0.76, - zeropoint=21.4, - origin=torch.zeros(2) + 0.1, - ) - - new_image.save("Test_AstroPhot.fits") - - loaded_image = ap.image.Image(filename="Test_AstroPhot.fits") - - self.assertTrue( - torch.all(new_image.data == loaded_image.data), - "Loaded image should have same pixel values", - ) - self.assertTrue( - torch.all(new_image.origin == loaded_image.origin), - "Loaded image should have same origin", - ) - self.assertEqual( - new_image.pixel_length, - loaded_image.pixel_length, - "Loaded image should have same pixel scale", - ) - self.assertEqual( - new_image.zeropoint, - loaded_image.zeropoint, - "Loaded image should have same zeropoint", - ) - - def test_image_wcs_roundtrip(self): - - wcs = get_astropy_wcs() - # Minimal input - I = ap.image.Image( - data=torch.zeros((20, 20)), - zeropoint=22.5, - wcs=wcs, - ) - - self.assertTrue( - torch.allclose( - I.world_to_plane(I.plane_to_world(torch.zeros_like(I.window.reference_radec))), - torch.zeros_like(I.window.reference_radec), - ), - "WCS world/plane roundtrip should return input value", - ) - self.assertTrue( - torch.allclose( - I.pixel_to_plane(I.plane_to_pixel(torch.zeros_like(I.window.reference_radec))), - torch.zeros_like(I.window.reference_radec), - ), - "WCS pixel/plane roundtrip should return input value", - ) - self.assertTrue( - torch.allclose( - I.world_to_pixel(I.pixel_to_world(torch.zeros_like(I.window.reference_radec))), - torch.zeros_like(I.window.reference_radec), - atol=1e-6, - ), - "WCS world/pixel roundtrip should return input value", - ) - - self.assertTrue( - torch.allclose( - I.pixel_to_plane_delta( - I.plane_to_pixel_delta(torch.ones_like(I.window.reference_radec)) - ), - torch.ones_like(I.window.reference_radec), - ), - "WCS pixel/plane delta roundtrip should return input value", - ) - - def test_image_display(self): - new_image = image.Image( - data=torch.ones((16, 32)), - pixelscale=0.76, - zeropoint=21.4, - origin=torch.zeros(2) + 0.1, - ) - - self.assertIsInstance(str(new_image), str, "String representation should be a string!") - self.assertIsInstance(repr(new_image), str, "Repr should be a string!") - - def test_image_errors(self): - - new_image = image.Image( - data=torch.ones((16, 32)), - pixelscale=0.76, - zeropoint=21.4, - origin=torch.zeros(2) + 0.1, - ) - - # Change data badly - with self.assertRaises(ap.errors.SpecificationConflict): - new_image.data = np.zeros((5, 5)) - - # Fractional image reduction - with self.assertRaises(ap.errors.SpecificationConflict): - reduced = new_image.reduce(0.2) - - # Negative expand image - with self.assertRaises(ap.errors.SpecificationConflict): - unexpanded = new_image.expand((-2, 3)) - - -class TestTargetImage(unittest.TestCase): - def test_variance(self): - - new_image = image.Target_Image( - data=torch.ones((16, 32)), - variance=torch.ones((16, 32)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - - self.assertTrue(new_image.has_variance, "target image should store variance") - - reduced_image = new_image.reduce(2) - self.assertEqual(reduced_image.variance[0][0], 4, "reduced image should sum sub pixels") - - new_image.to() - new_image.variance = None - self.assertFalse(new_image.has_variance, "target image update to no variance") - - def test_mask(self): - - new_image = image.Target_Image( - data=torch.ones((16, 32)), - mask=torch.ones((16, 32)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - self.assertTrue(new_image.has_mask, "target image should store mask") - - reduced_image = new_image.reduce(2) - self.assertEqual(reduced_image.mask[0][0], 1, "reduced image should mask appropriately") - - new_image.mask = None - self.assertFalse(new_image.has_mask, "target image update to no mask") - - data = torch.ones((16, 32)) - data[1, 1] = torch.nan - data[5, 5] = torch.nan - - new_image = image.Target_Image( - data=data, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - self.assertTrue(new_image.has_mask, "target image with nans should create mask") - self.assertEqual(new_image.mask[1][1].item(), True, "nan should be masked") - self.assertEqual(new_image.mask[5][5].item(), True, "nan should be masked") - - def test_psf(self): - - new_image = image.Target_Image( - data=torch.ones((15, 33)), - psf=torch.ones((9, 9)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - self.assertTrue(new_image.has_psf, "target image should store variance") - self.assertEqual( - new_image.psf.psf_border_int[0], - 5, - "psf border should be half psf size, rounded up ", - ) - - reduced_image = new_image.reduce(3) - self.assertEqual( - reduced_image.psf.data[0][0], - 9, - "reduced image should sum sub pixels in psf", - ) - - new_image.psf = None - self.assertFalse(new_image.has_psf, "target image update to no variance") - - def test_reduce(self): - new_image = image.Target_Image( - data=torch.ones((30, 36)), - psf=torch.ones((9, 9)), - variance="auto", - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - smaller_image = new_image.reduce(3) - self.assertEqual(smaller_image.data[0][0], 9, "reduction should sum flux") - self.assertEqual( - tuple(smaller_image.data.shape), - (10, 12), - "reduction should decrease image size", - ) - self.assertEqual(smaller_image.psf.data[0][0], 9, "reduction should sum psf flux") - self.assertEqual( - tuple(smaller_image.psf.data.shape), - (3, 3), - "reduction should decrease psf image size", - ) - - def test_target_save_load(self): - new_image = image.Target_Image( - data=torch.ones((16, 32)), - variance="auto", - mask=torch.zeros((16, 32)), - psf=torch.ones((9, 9)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - - new_image.save("Test_target_AstroPhot.fits") - - loaded_image = ap.image.Target_Image(filename="Test_target_AstroPhot.fits") - - self.assertTrue( - torch.all(new_image.variance == loaded_image.variance), - "Loaded image should have same variance", - ) - self.assertTrue( - torch.all(new_image.psf.data == loaded_image.psf.data), - "Loaded image should have same psf", - ) - - def test_auto_var(self): - target = make_basic_sersic() - target.variance = "auto" - - def test_target_errors(self): - new_image = image.Target_Image( - data=torch.ones((16, 32)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - - # bad variance - with self.assertRaises(ap.errors.SpecificationConflict): - new_image.variance = np.ones((5, 5)) - - # bad mask - with self.assertRaises(ap.errors.SpecificationConflict): - new_image.mask = np.zeros((5, 5)) - - -class TestPSFImage(unittest.TestCase): - def test_copying(self): - psf_image = image.PSF_Image( - data=torch.ones((15, 15)), - pixelscale=1.0, - ) - - copy_psf = psf_image.copy() - self.assertEqual( - psf_image.data[0][0], - copy_psf.data[0][0], - "copied image should have same data", - ) - blank_psf = psf_image.blank_copy() - self.assertNotEqual( - psf_image.data[0][0], - blank_psf.data[0][0], - "blank copied image should not have same data", - ) - - psf_image.to(dtype=torch.float32) - - def test_reducing(self): - psf_image = image.PSF_Image( - data=torch.ones((15, 15)), - pixelscale=1.0, - ) - new_image = image.Target_Image( - data=torch.ones((36, 45)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - psf=psf_image, - ) - - reduce_image = new_image.reduce(3) - self.assertEqual( - tuple(reduce_image.psf.data.shape), - (5, 5), - "reducing image should reduce psf", - ) - self.assertEqual( - reduce_image.psf.pixel_length, - 3, - "reducing image should update pixelscale factor", - ) - - def test_psf_errors(self): - with self.assertRaises(ap.errors.SpecificationConflict): - psf_image = image.PSF_Image( - data=torch.ones((18, 15)), - pixelscale=1.0, - ) - - -class TestModelImage(unittest.TestCase): - def test_replace(self): - new_image = image.Model_Image( - data=torch.ones((16, 32)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - other_image = image.Model_Image( - data=5 * torch.ones((4, 4)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 4 + 0.1, - ) - - new_image.replace(other_image) - new_image.replace(other_image.window, other_image.data) - - self.assertEqual( - new_image.data[0][0], - 1, - "image replace should occur at proper location in image, this data should be untouched", - ) - self.assertEqual( - new_image.data[5][5], 5, "image replace should update values in its window" - ) - - def test_shift(self): - - new_image = image.Model_Image( - data=torch.ones((16, 32)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - new_image.shift_origin( - torch.tensor((-0.1, -0.1), dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - is_prepadded=False, - ) - - self.assertAlmostEqual( - torch.sum(new_image.data).item(), - 16 * 32, - delta=1, - msg="Shifting field of ones should give field of ones", - ) - - def test_errors(self): - - with self.assertRaises(ap.errors.InvalidData): - new_image = image.Model_Image() - - -class TestJacobianImage(unittest.TestCase): - def test_jacobian_add(self): - - new_image = ap.image.Jacobian_Image( - parameters=["a", "b", "c"], - target_identity="target1", - data=torch.ones((16, 32, 3)), - pixelscale=1.0, - zeropoint=1.0, - window=ap.image.Window(origin=torch.zeros(2) + 0.1, pixel_shape=torch.tensor((32, 16))), - ) - other_image = ap.image.Jacobian_Image( - parameters=["b", "d"], - target_identity="target1", - data=5 * torch.ones((4, 4, 2)), - pixelscale=1.0, - zeropoint=1.0, - window=ap.image.Window( - origin=torch.zeros(2) + 4 + 0.1, pixel_shape=torch.tensor((4, 4)) - ), - ) - - new_image += other_image - - self.assertEqual( - tuple(new_image.data.shape), - (16, 32, 4), - "Jacobian addition should manage parameter identities", - ) - self.assertEqual( - tuple(new_image.flatten("data").shape), - (512, 4), - "Jacobian should flatten to Npix*Nparams tensor", - ) - - def test_jacobian_error(self): - - # Create parameter list with multiple same entries - with self.assertRaises(ap.errors.SpecificationConflict): - new_image = ap.image.Jacobian_Image( - parameters=["a", "b", "c", "a"], - target_identity="target1", - data=torch.ones((16, 32, 3)), - pixelscale=1.0, - zeropoint=1.0, - window=ap.image.Window( - origin=torch.zeros(2) + 0.1, pixel_shape=torch.tensor((32, 16)) - ), - ) - - # Adding a model image to a jacobian image - new_image = ap.image.Jacobian_Image( - parameters=["a", "b", "c"], - target_identity="target1", - data=torch.ones((16, 32, 3)), - pixelscale=1.0, - zeropoint=1.0, - window=ap.image.Window(origin=torch.zeros(2) + 0.1, pixel_shape=torch.tensor((32, 16))), - ) - bad_image = image.Model_Image( - data=torch.ones((16, 32)), - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2) + 0.1, - ) - with self.assertRaises(ap.errors.InvalidImage): - new_image += bad_image - - -if __name__ == "__main__": - unittest.main() +@pytest.fixture() +def base_image(): + arr = np.zeros((10, 15)) + return ap.Image( + data=arr, + pixelscale=1.0, + zeropoint=1.0, + ) + + +def test_image_creation(base_image): + base_image.to() + assert base_image.pixelscale == 1.0, "image should track pixelscale" + assert base_image.zeropoint == 1.0, "image should track zeropoint" + assert base_image.crpix[0] == 0, "image should track crpix" + assert base_image.crpix[1] == 0, "image should track crpix" + + base_image.to(dtype=ap.backend.float64) + slicer = ap.Window((7, 13, 4, 7), base_image) + sliced_image = base_image[slicer] + assert sliced_image.crpix[0] == -7, "crpix of subimage should give relative position" + assert sliced_image.crpix[1] == -4, "crpix of subimage should give relative position" + assert sliced_image._data.shape == (6, 3), "sliced image should have correct shape" + + +def test_copy(base_image): + copy_image = base_image.copy() + assert ( + base_image.pixelscale == copy_image.pixelscale + ), "copied image should have same pixelscale" + assert base_image.zeropoint == copy_image.zeropoint, "copied image should have same zeropoint" + assert ( + base_image.window.extent == copy_image.window.extent + ), "copied image should have same window" + copy_image += 1 + assert base_image._data[0][0] == 0.0, "copied image should not share data with original" + + blank_copy_image = base_image.blank_copy() + assert ( + base_image.pixelscale == blank_copy_image.pixelscale + ), "copied image should have same pixelscale" + assert ( + base_image.zeropoint == blank_copy_image.zeropoint + ), "copied image should have same zeropoint" + assert ( + base_image.window.extent == blank_copy_image.window.extent + ), "copied image should have same window" + blank_copy_image += 1 + assert base_image._data[0][0] == 0.0, "copied image should not share data with original" + + +def test_image_arithmetic(base_image): + slicer = ap.Window((-1, 5, 6, 15), base_image) + sliced_image = base_image[slicer] + sliced_image += 1 + + assert base_image._data[1][8] == 0, "slice should not update base image" + assert base_image._data[5][5] == 0, "slice should not update base image" + + second_image = ap.Image( + data=np.ones((5, 5)), + pixelscale=1.0, + zeropoint=1.0, + crpix=(-1, 1), + ) + + # Test iadd + base_image += second_image + assert base_image._data[0][0] == 0, "image addition should only update its region" + assert base_image._data[3][3] == 1, "image addition should update its region" + assert base_image._data[3][4] == 0, "image addition should only update its region" + assert base_image._data[5][3] == 1, "image addition should update its region" + + # Test isubtract + base_image -= second_image + assert ap.backend.allclose( + base_image.data, ap.backend.zeros_like(base_image.data) + ), "image subtraction should only update its region" + + +def test_image_manipulation(): + new_image = ap.Image( + data=np.ones((16, 32)), + pixelscale=1.0, + zeropoint=1.0, + ) + + # image reduction + for scale in [2, 4, 8, 16]: + reduced_image = new_image.reduce(scale) + + assert reduced_image._data[0][0] == scale**2, "reduced image should sum sub pixels" + assert reduced_image.pixelscale == scale, "pixelscale should increase with reduced image" + + # image cropping + crop_image = new_image.crop([1]) + assert crop_image._data.shape[1] == 14, "crop should cut 1 pixel from both sides here" + crop_image = new_image.crop([3, 2]) + assert ( + crop_image._data.shape[0] == 26 + ), "crop should have cut 3 pixels from both sides of this axis" + crop_image = new_image.crop([3, 2, 1, 0]) + assert ( + crop_image._data.shape[0] == 27 + ), "crop should have cut 3 pixels from left, 2 from right, 1 from top, and 0 from bottom" + + +def test_image_save_load(): + new_image = ap.Image( + data=np.ones((16, 32)), + pixelscale=0.76, + zeropoint=21.4, + crtan=(8.0, 1.2), + crpix=(2, 3), + crval=(100.0, -32.1), + ) + + new_image.save("Test_AstroPhot.fits") + + loaded_image = ap.Image(filename="Test_AstroPhot.fits") + + assert ap.backend.allclose( + new_image.data, loaded_image.data + ), "Loaded image should have same pixel values" + assert ap.backend.allclose( + new_image.crtan.value, loaded_image.crtan.value + ), "Loaded image should have same tangent plane origin" + assert np.all( + new_image.crpix == loaded_image.crpix + ), "Loaded image should have same reference pixel" + assert ap.backend.allclose( + new_image.crval.value, loaded_image.crval.value + ), "Loaded image should have same reference world coordinates" + assert ap.backend.allclose( + new_image.pixelscale, loaded_image.pixelscale + ), "Loaded image should have same pixel scale" + assert ap.backend.allclose( + new_image.CD.value, loaded_image.CD.value + ), "Loaded image should have same pixel scale" + assert new_image.zeropoint == loaded_image.zeropoint, "Loaded image should have same zeropoint" + + +def test_image_wcs_roundtrip(): + # Minimal input + I = ap.Image( + data=np.zeros((21, 21)), + zeropoint=22.5, + crpix=(10, 10), + crtan=(1.0, -10.0), + crval=(160.0, 45.0), + CD=0.05 + * np.array( + [[np.cos(np.pi / 4), -np.sin(np.pi / 4)], [np.sin(np.pi / 4), np.cos(np.pi / 4)]] + ), + ) + + assert ap.backend.allclose( + ap.backend.stack(I.world_to_plane(*I.plane_to_world(*I.center))), + I.center, + ), "WCS world/plane roundtrip should return input value" + assert ap.backend.allclose( + ap.backend.stack(I.pixel_to_plane(*I.plane_to_pixel(*I.center))), + I.center, + ), "WCS pixel/plane roundtrip should return input value" + assert ap.backend.allclose( + ap.backend.stack(I.world_to_pixel(*I.pixel_to_world(*ap.backend.zeros_like(I.center)))), + ap.backend.zeros_like(I.center), + atol=1e-6, + ), "WCS world/pixel roundtrip should return input value" + + +def test_target_image_variance(): + new_image = ap.TargetImage( + data=np.ones((16, 32)), + variance=2 * np.ones((16, 32)), + pixelscale=1.0, + zeropoint=1.0, + ) + + assert new_image.variance[0][0] == 2, "target image should store variance" + + reduced_image = new_image.reduce(2) + assert reduced_image.variance[0][0] == 8, "reduced image should sum sub pixels" + + new_image.variance = None + assert new_image.variance[0][0] == 1, "target image update to neutral variance" + + +def test_target_image_mask(): + new_image = ap.TargetImage( + data=np.ones((16, 32)), + mask=np.arange(16 * 32).reshape((16, 32)) % 4 == 0, + pixelscale=1.0, + zeropoint=1.0, + ) + assert ap.backend.sum(new_image.mask) > 0, "target image should store mask" + + reduced_image = new_image.reduce(2) + assert reduced_image._mask[0][0] == 1, "reduced image should mask appropriately" + assert reduced_image._mask[1][0] == 0, "reduced image should mask appropriately" + + new_image.mask = None + assert ap.backend.sum(new_image.mask) == 0, "target image update to no mask" + + data = np.ones((16, 32)) + data[1, 1] = np.nan + data[5, 5] = np.nan + + new_image = ap.TargetImage( + data=data, + pixelscale=1.0, + zeropoint=1.0, + ) + assert ap.backend.sum(new_image.mask) > 0, "target image with nans should create mask" + assert new_image._mask[1][1].item() == True, "nan should be masked" + assert new_image._mask[5][5].item() == True, "nan should be masked" + + +def test_target_image_psf(): + new_image = ap.TargetImage( + data=np.ones((15, 33)), + psf=np.ones((9, 9)), + pixelscale=1.0, + zeropoint=1.0, + ) + assert new_image.has_psf, "target image should store variance" + assert new_image.psf.psf_pad == 4, "psf border should be half psf size" + + reduced_image = new_image.reduce(3) + assert reduced_image.psf._data[0][0] == 9, "reduced image should sum sub pixels in psf" + + new_image.psf = None + assert not new_image.has_psf, "target image update to no psf" + + +def test_target_image_reduce(): + new_image = ap.TargetImage( + data=np.ones((30, 36)), + psf=np.ones((9, 9)), + variance="auto", + pixelscale=1.0, + zeropoint=1.0, + ) + smaller_image = new_image.reduce(3) + assert smaller_image._data[0][0] == 9, "reduction should sum flux" + assert tuple(smaller_image._data.shape) == (12, 10), "reduction should decrease image size" + + +def test_target_image_save_load(): + new_image = ap.TargetImage( + data=np.ones((16, 32)), + variance=np.ones((16, 32)), + mask=np.zeros((16, 32)), + psf=np.ones((9, 9)), + CD=[[1.0, 0.0], [0.0, 1.5]], + zeropoint=1.0, + ) + + new_image.save("Test_target_AstroPhot.fits") + + loaded_image = ap.TargetImage(filename="Test_target_AstroPhot.fits") + + assert ap.backend.allclose( + new_image.data, loaded_image.data + ), "Loaded image should have same pixel values" + assert ap.backend.allclose( + new_image.mask, loaded_image.mask + ), "Loaded image should have same mask" + assert ap.backend.allclose( + new_image.variance, loaded_image.variance + ), "Loaded image should have same variance" + assert ap.backend.allclose( + new_image.psf.data, loaded_image.psf.data + ), "Loaded image should have same psf" + assert ap.backend.allclose( + new_image.CD.value, loaded_image.CD.value + ), "Loaded image should have same pixel scale" + + +def test_target_image_auto_var(): + target = make_basic_sersic() + target.variance = "auto" + + +def test_target_image_errors(): + new_image = ap.TargetImage( + data=np.ones((16, 32)), + pixelscale=1.0, + zeropoint=1.0, + ) + + # bad variance + with pytest.raises(ap.errors.SpecificationConflict): + new_image.variance = np.ones((5, 5)) + + # bad mask + with pytest.raises(ap.errors.SpecificationConflict): + new_image.mask = np.zeros((5, 5)) + + +def test_psf_image_copying(): + psf_image = ap.PSFImage( + data=np.ones((15, 15)), + ) + + assert psf_image.psf_pad == 7, "psf image should have correct psf_pad" + psf_image.normalize() + assert np.allclose( + ap.backend.to_numpy(psf_image._data), 1 / 15**2 + ), "psf image should normalize to sum to 1" + + +def test_jacobian_add(): + new_image = ap.JacobianImage( + parameters=["a", "b", "c"], + data=np.ones((16, 32, 3)), + ) + other_image = ap.JacobianImage( + parameters=["b", "d"], + data=5 * np.ones((4, 4, 2)), + ) + + new_image += other_image + + assert tuple(new_image._data.shape) == ( + 32, + 16, + 3, + ), "Jacobian addition should manage parameter identities" + assert tuple(new_image.flatten("data").shape) == ( + 512, + 3, + ), "Jacobian should flatten to Npix*Nparams tensor" + assert new_image._data[0, 0, 0].item() == 1, "Jacobian addition should not change original data" + assert new_image._data[0, 0, 1].item() == 6, " Jacobian addition should add correctly" + + +def test_image_with_wcs(): + WCS = get_astropy_wcs() + image = ap.TargetImage( + data=np.ones((170, 180)), + wcs=WCS, + ) + assert image._data.shape[0] == WCS.pixel_shape[0], "Image should have correct shape from WCS" + assert image._data.shape[1] == WCS.pixel_shape[1], "Image should have correct shape from WCS" + assert np.allclose( + image.CD.value * ap.utils.conversions.units.arcsec_to_deg, WCS.pixel_scale_matrix + ), "Image should have correct CD from WCS" + assert np.allclose( + image.crpix, WCS.wcs.crpix[::-1] - 1 + ), "Image should have correct CRPIX from WCS" + assert np.allclose( + image.crval.npvalue, WCS.wcs.crval + ), "Image should have correct CRVAL from WCS" diff --git a/tests/test_image_header.py b/tests/test_image_header.py deleted file mode 100644 index 55e7357f..00000000 --- a/tests/test_image_header.py +++ /dev/null @@ -1,144 +0,0 @@ -import unittest -import astrophot as ap -import torch - -from utils import get_astropy_wcs - -###################################################################### -# Image_Header Objects -###################################################################### - - -class TestImageHeader(unittest.TestCase): - def test_image_header_creation(self): - - # Minimal input - H = ap.image.Image_Header( - data_shape=(20, 20), - zeropoint=22.5, - pixelscale=0.2, - ) - - self.assertTrue(torch.all(H.origin == 0), "Origin should be assumed zero if not given") - - # Center - H = ap.image.Image_Header( - data_shape=(20, 20), - pixelscale=0.2, - center=(10, 10), - ) - - self.assertTrue( - torch.all(H.origin == 8), - "Center provided, origin should be adjusted accordingly", - ) - - # Origin - H = ap.image.Image_Header( - data_shape=(20, 20), - pixelscale=0.2, - origin=(10, 10), - ) - - self.assertTrue(torch.all(H.origin == 10), "Origin provided, origin should be as given") - - # Center radec - H = ap.image.Image_Header( - data_shape=(20, 20), - pixelscale=0.2, - center_radec=(10, 10), - ) - - self.assertTrue( - torch.allclose(H.plane_to_world(H.center), torch.ones_like(H.center) * 10), - "Center_radec provided, center should be as given in world coordinates", - ) - - # Origin radec - H = ap.image.Image_Header( - data_shape=(20, 20), - pixelscale=0.2, - origin_radec=(10, 10), - ) - - self.assertTrue( - torch.allclose(H.plane_to_world(H.origin), torch.ones_like(H.center) * 10), - "Origin_radec provided, origin should be as given in world coordinates", - ) - - # Astropy WCS - wcs = get_astropy_wcs() - H = ap.image.Image_Header( - data_shape=(180, 180), - wcs=wcs, - ) - - sky_coord = wcs.pixel_to_world(*wcs.wcs.crpix) - wcs_world = torch.tensor((sky_coord.ra.deg, sky_coord.dec.deg)) - self.assertTrue( - torch.allclose( - torch.tensor( - wcs.wcs.crpix, - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - H.world_to_pixel(wcs_world), - ), - "Astropy WCS initialization should map crval crpix coordinates", - ) - - def test_image_header_wcs_roundtrip(self): - - wcs = get_astropy_wcs() - # Minimal input - H = ap.image.Image_Header( - data_shape=(20, 20), - zeropoint=22.5, - wcs=wcs, - ) - - self.assertTrue( - torch.allclose( - H.world_to_plane(H.plane_to_world(torch.zeros_like(H.window.reference_radec))), - torch.zeros_like(H.window.reference_radec), - ), - "WCS world/plane roundtrip should return input value", - ) - self.assertTrue( - torch.allclose( - H.pixel_to_plane(H.plane_to_pixel(torch.zeros_like(H.window.reference_radec))), - torch.zeros_like(H.window.reference_radec), - ), - "WCS pixel/plane roundtrip should return input value", - ) - self.assertTrue( - torch.allclose( - H.world_to_pixel(H.pixel_to_world(torch.zeros_like(H.window.reference_radec))), - torch.zeros_like(H.window.reference_radec), - atol=1e-6, - ), - "WCS world/pixel roundtrip should return input value", - ) - - self.assertTrue( - torch.allclose( - H.pixel_to_plane_delta( - H.plane_to_pixel_delta(torch.ones_like(H.window.reference_radec)) - ), - torch.ones_like(H.window.reference_radec), - ), - "WCS pixel/plane delta roundtrip should return input value", - ) - - def test_iamge_header_repr(self): - - wcs = get_astropy_wcs() - # Minimal input - H = ap.image.Image_Header( - data_shape=(20, 20), - zeropoint=22.5, - wcs=wcs, - ) - - S = str(H) - R = repr(H) diff --git a/tests/test_image_list.py b/tests/test_image_list.py index b4f2bcd0..fa9c0c88 100644 --- a/tests/test_image_list.py +++ b/tests/test_image_list.py @@ -1,468 +1,183 @@ -import unittest import astrophot as ap -import torch +import numpy as np +import pytest ###################################################################### # Image List Object ###################################################################### -class TestImageList(unittest.TestCase): - def test_image_creation(self): - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - metadata={"note": "test image 1"}, - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - metadata={"note": "test image 2"}, - ) - - test_image = ap.image.Image_List((base_image1, base_image2)) - - for image, original_image in zip(test_image, (base_image1, base_image2)): - self.assertEqual( - image.pixel_length, - original_image.pixel_length, - "image should track pixelscale", - ) - self.assertEqual( - image.zeropoint, - original_image.zeropoint, - "image should track zeropoint", - ) - self.assertEqual(image.origin[0], original_image.origin[0], "image should track origin") - self.assertEqual(image.origin[1], original_image.origin[1], "image should track origin") - self.assertEqual( - image.metadata["note"], - original_image.metadata["note"], - "image should track note", - ) - - slicer = ap.image.Window_List( - ( - ap.image.Window(origin=(3, 2), pixel_shape=(4, 5)), - ap.image.Window(origin=(3, 2), pixel_shape=(4, 5)), - ) - ) - sliced_image = test_image[slicer] - - self.assertEqual(sliced_image[0].origin[0], 3, "image should track origin") - self.assertEqual(sliced_image[0].origin[1], 2, "image should track origin") - self.assertEqual(sliced_image[1].origin[0], 3, "image should track origin") - self.assertEqual(sliced_image[1].origin[1], 2, "image should track origin") - self.assertEqual(base_image1.origin[0], 0, "subimage should not change image origin") - self.assertEqual(base_image1.origin[1], 0, "subimage should not change image origin") - - def test_copy(self): - - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - ) - - test_image = ap.image.Image_List((base_image1, base_image2)) - - copy_image = test_image.copy() - for ti, ci in zip(test_image, copy_image): - self.assertEqual( - ti.pixel_length, - ci.pixel_length, - "copied image should have same pixelscale", - ) - self.assertEqual(ti.zeropoint, ci.zeropoint, "copied image should have same zeropoint") - self.assertEqual(ti.window, ci.window, "copied image should have same window") - preval = ti.data[0][0].item() - ci += 1 - self.assertEqual( - ti.data[0][0], - preval, - "copied image should not share data with original", - ) - - blank_copy_image = test_image.blank_copy() - for ti, ci in zip(test_image, blank_copy_image): - self.assertEqual( - ti.pixel_length, - ci.pixel_length, - "copied image should have same pixelscale", - ) - self.assertEqual(ti.zeropoint, ci.zeropoint, "copied image should have same zeropoint") - self.assertEqual(ti.window, ci.window, "copied image should have same window") - preval = ti.data[0][0].item() - ci += 1 - self.assertEqual( - ti.data[0][0], - preval, - "copied image should not share data with original", - ) - - def test_image_arithmetic(self): - - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - ) - test_image = ap.image.Image_List((base_image1, base_image2)) - - arr3 = torch.ones((10, 15)) - base_image3 = ap.image.Image( - data=arr3, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.ones(2), - ) - arr4 = torch.zeros((15, 10)) - base_image4 = ap.image.Image( - data=arr4, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.zeros(2), - ) - second_image = ap.image.Image_List((base_image3, base_image4)) - - # Test iadd - test_image += second_image - - self.assertEqual( - test_image[0].data[0][0], 0, "image addition should only update its region" - ) - self.assertEqual(test_image[0].data[3][3], 1, "image addition should update its region") - self.assertEqual(test_image[1].data[0][0], 1, "image addition should update its region") - self.assertEqual(test_image[1].data[1][1], 1, "image addition should update its region") - - # Test iadd - test_image -= second_image - - self.assertEqual( - test_image[0].data[0][0], 0, "image addition should only update its region" - ) - self.assertEqual(test_image[0].data[3][3], 0, "image addition should update its region") - self.assertEqual(test_image[1].data[0][0], 1, "image addition should update its region") - self.assertEqual(test_image[1].data[1][1], 1, "image addition should update its region") - - def test_image_list_display(self): - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - ) - test_image = ap.image.Image_List((base_image1, base_image2)) - - self.assertIsInstance(str(test_image), str, "String representation should be a string!") - self.assertIsInstance(repr(test_image), str, "Repr should be a string!") - - def test_image_list_windowset(self): - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - note="test image 1", - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - note="test image 2", - ) - test_image = ap.image.Image_List((base_image1, base_image2)) - arr3 = torch.ones((10, 15)) - base_image3 = ap.image.Image( - data=arr3, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.ones(2), - note="test image 3", - ) - arr4 = torch.zeros((15, 10)) - base_image4 = ap.image.Image( - data=arr4, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.zeros(2), - note="test image 4", - ) - second_image = ap.image.Image_List((base_image3, base_image4), window=test_image.window) - - def test_image_list_errors(self): - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - ) - test_image = ap.image.Image_List((base_image1, base_image2)) - # Bad ra dec reference point - bad_base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - reference_radec=torch.ones(2), - ) - with self.assertRaises(ap.errors.ConflicingWCS): - test_image = ap.image.Image_List((base_image1, bad_base_image2)) - - # Bad tangent plane x y reference point - bad_base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - reference_planexy=torch.ones(2), - ) - with self.assertRaises(ap.errors.ConflicingWCS): - test_image = ap.image.Image_List((base_image1, bad_base_image2)) - - # Bad WCS projection - bad_base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - projection="orthographic", - ) - with self.assertRaises(ap.errors.ConflicingWCS): - test_image = ap.image.Image_List((base_image1, bad_base_image2)) - - -class TestModelImageList(unittest.TestCase): - def test_model_image_list_creation(self): - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Model_Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Model_Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - ) - - test_image = ap.image.Model_Image_List((base_image1, base_image2)) - - save_image = test_image.copy() - second_image = test_image.copy() - - second_image += (2, 2) - second_image -= (1, 1) - - test_image += second_image - - test_image -= second_image - - self.assertTrue( - torch.all(test_image[0].data == save_image[0].data), - "adding then subtracting should give the same image", - ) - self.assertTrue( - torch.all(test_image[1].data == save_image[1].data), - "adding then subtracting should give the same image", - ) - - print(test_image.data) - test_image.clear_image() - print(test_image.data) - test_image.replace(second_image) - print(test_image.data) - - test_image -= (1, 1) - print(test_image.data) - - self.assertTrue( - torch.all(test_image[0].data == save_image[0].data), - "adding then subtracting should give the same image", - ) - self.assertTrue( - torch.all(test_image[1].data == save_image[1].data), - "adding then subtracting should give the same image", - ) - - self.assertIsNone( - test_image.target_identity, - "Targets have not been assigned so target identity should be None", - ) - - def test_errors(self): - - # Model_Image_List with non Model_Image object - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Model_Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Target_Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - ) - - with self.assertRaises(ap.errors.InvalidImage): - test_image = ap.image.Model_Image_List((base_image1, base_image2)) - - -class TestTargetImageList(unittest.TestCase): - def test_target_image_list_creation(self): - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Target_Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - variance=torch.ones_like(arr1), - mask=torch.zeros_like(arr1), - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Target_Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - variance=torch.ones_like(arr2), - mask=torch.zeros_like(arr2), - ) - - test_image = ap.image.Target_Image_List((base_image1, base_image2)) - - save_image = test_image.copy() - second_image = test_image.copy() - - second_image += (2, 2) - second_image -= (1, 1) - - test_image += second_image - - test_image -= second_image - - self.assertTrue( - torch.all(test_image[0].data == save_image[0].data), - "adding then subtracting should give the same image", - ) - self.assertTrue( - torch.all(test_image[1].data == save_image[1].data), - "adding then subtracting should give the same image", - ) - - test_image += (1, 1) - test_image -= (1, 1) - - self.assertTrue( - torch.all(test_image[0].data == save_image[0].data), - "adding then subtracting should give the same image", - ) - self.assertTrue( - torch.all(test_image[1].data == save_image[1].data), - "adding then subtracting should give the same image", - ) - - def test_targetlist_errors(self): - arr1 = torch.zeros((10, 15)) - base_image1 = ap.image.Target_Image( - data=arr1, - pixelscale=1.0, - zeropoint=1.0, - origin=torch.zeros(2), - variance=torch.ones_like(arr1), - mask=torch.zeros_like(arr1), - ) - arr2 = torch.ones((15, 10)) - base_image2 = ap.image.Image( - data=arr2, - pixelscale=0.5, - zeropoint=2.0, - origin=torch.ones(2), - ) - with self.assertRaises(ap.errors.InvalidImage): - test_image = ap.image.Target_Image_List((base_image1, base_image2)) - - -class TestJacobianImageList(unittest.TestCase): - def test_jacobian_image_list_creation(self): - arr1 = torch.zeros((10, 15, 3)) - base_image1 = ap.image.Jacobian_Image( - data=arr1, - parameters=["a", "b", "c"], - target_identity="target1", - pixelscale=1.0, - zeropoint=1.0, - window=ap.image.Window(origin=torch.zeros(2) + 0.1, pixel_shape=torch.tensor((15, 10))), - ) - arr2 = torch.ones((15, 10, 3)) - base_image2 = ap.image.Jacobian_Image( - data=arr2, - parameters=["a", "b", "c"], - target_identity="target2", - pixelscale=0.5, - zeropoint=2.0, - window=ap.image.Window(origin=torch.zeros(2) + 0.2, pixel_shape=torch.tensor((10, 15))), - ) - - test_image = ap.image.Jacobian_Image_List((base_image1, base_image2)) - - second_image = test_image.copy() - - test_image += second_image - - self.assertEqual( - test_image.flatten("data").shape, - (300, 3), - "flattened jacobian should include all pixels and merge parameters", - ) - - -if __name__ == "__main__": - unittest.main() +def test_image_creation(): + arr1 = ap.backend.zeros((10, 15)) + base_image1 = ap.Image(data=arr1, pixelscale=1.0, zeropoint=1.0, name="image1") + arr2 = ap.backend.ones((15, 10)) + base_image2 = ap.Image(data=arr2, pixelscale=0.5, zeropoint=2.0, name="image2") + + test_image = ap.ImageList((base_image1, base_image2)) + + slicer = ap.WindowList( + (ap.Window((3, 12, 5, 8), base_image1), ap.Window((4, 8, 3, 13), base_image2)) + ) + sliced_image = test_image[slicer] + print(sliced_image[0]._data.shape, sliced_image[1]._data.shape) + assert sliced_image[0]._data.shape == (9, 3), "image slice incorrect shape" + assert sliced_image[1]._data.shape == (4, 10), "image slice incorrect shape" + assert np.all(sliced_image[0].crpix == np.array([-3, -5])), "image should track origin" + assert np.all(sliced_image[1].crpix == np.array([-4, -3])), "image should track origin" + + +def test_copy(): + arr1 = np.zeros((10, 15)) + 2 + base_image1 = ap.Image(data=arr1, pixelscale=1.0, zeropoint=1.0, name="image1") + arr2 = np.ones((15, 10)) + base_image2 = ap.Image(data=arr2, pixelscale=0.5, zeropoint=2.0, name="image2") + + test_image = ap.ImageList((base_image1, base_image2)) + + copy_image = test_image.copy() + copy_image.images[0] += 5 + copy_image.images[1] += 5 + + for ti, ci in zip(test_image, copy_image): + assert ti.pixelscale == ci.pixelscale, "copied image should have same pixelscale" + assert ti.zeropoint == ci.zeropoint, "copied image should have same zeropoint" + assert ap.backend.all(ti.data != ci.data), "copied image should not modify original data" + + blank_copy_image = test_image.blank_copy() + for ti, ci in zip(test_image, blank_copy_image): + assert ti.pixelscale == ci.pixelscale, "copied image should have same pixelscale" + assert ti.zeropoint == ci.zeropoint, "copied image should have same zeropoint" + + +def test_image_arithmetic(): + arr1 = np.zeros((10, 15)) + base_image1 = ap.Image(data=arr1, pixelscale=1.0, zeropoint=1.0, name="image1") + arr2 = np.ones((15, 10)) + base_image2 = ap.Image(data=arr2, pixelscale=0.5, zeropoint=2.0, name="image2") + test_image = ap.ImageList((base_image1, base_image2)) + + base_image3 = base_image1.copy() + base_image3 += 1 + base_image4 = base_image2.copy() + base_image4 -= 2 + second_image = ap.ImageList((base_image3, base_image4)) + + # Test iadd + test_image += second_image + + assert ap.backend.allclose( + test_image[0].data, ap.backend.ones_like(base_image1.data) + ), "image addition should update its region" + assert ap.backend.allclose( + base_image1.data, ap.backend.ones_like(base_image1.data) + ), "image addition should update its region" + assert ap.backend.allclose( + test_image[1].data, ap.backend.zeros_like(base_image2.data) + ), "image addition should update its region" + assert ap.backend.allclose( + base_image2.data, ap.backend.zeros_like(base_image2.data) + ), "image addition should update its region" + + # Test isub + test_image -= second_image + + assert ap.backend.allclose( + test_image[0].data, ap.backend.zeros_like(base_image1.data) + ), "image addition should update its region" + assert ap.backend.allclose( + base_image1.data, ap.backend.zeros_like(base_image1.data) + ), "image addition should update its region" + assert ap.backend.allclose( + test_image[1].data, ap.backend.ones_like(base_image2.data) + ), "image addition should update its region" + assert ap.backend.allclose( + base_image2.data, ap.backend.ones_like(base_image2.data) + ), "image addition should update its region" + + new_image = test_image + second_image + new_image = test_image - second_image + new_image = new_image.to(dtype=ap.backend.float32, device="cpu") + assert isinstance(new_image, ap.ImageList), "new image should be an ImageList" + + new_image += base_image1 + new_image -= base_image2 + + +def test_model_image_list_error(): + arr1 = np.zeros((10, 15)) + base_image1 = ap.ModelImage(data=arr1, pixelscale=1.0, zeropoint=1.0) + arr2 = np.ones((15, 10)) + base_image2 = ap.Image(data=arr2, pixelscale=0.5, zeropoint=2.0) + + with pytest.raises(ap.errors.InvalidImage): + ap.ModelImageList((base_image1, base_image2)) + + +def test_target_image_list_creation(): + arr1 = np.zeros((10, 15)) + base_image1 = ap.TargetImage( + data=arr1, + pixelscale=1.0, + zeropoint=1.0, + variance=np.ones_like(arr1), + mask=np.zeros_like(arr1), + name="image1", + ) + arr2 = np.ones((15, 10)) + base_image2 = ap.TargetImage( + data=arr2, + pixelscale=0.5, + zeropoint=2.0, + variance=np.ones_like(arr2), + mask=np.zeros_like(arr2), + name="image2", + ) + + test_image = ap.TargetImageList((base_image1, base_image2)) + + save_image = test_image.copy() + second_image = test_image.copy() + + second_image[0].data += 1 + second_image[1].data += 1 + + test_image += second_image + test_image -= second_image + + assert ap.backend.all( + test_image[0].data == save_image[0].data + ), "adding then subtracting should give the same image" + assert ap.backend.all( + test_image[1].data == save_image[1].data + ), "adding then subtracting should give the same image" + + +def test_targetlist_errors(): + arr1 = np.zeros((10, 15)) + base_image1 = ap.TargetImage( + data=arr1, + pixelscale=1.0, + zeropoint=1.0, + variance=np.ones_like(arr1), + mask=np.zeros_like(arr1), + ) + arr2 = np.ones((15, 10)) + base_image2 = ap.Image( + data=arr2, + pixelscale=0.5, + zeropoint=2.0, + ) + with pytest.raises(ap.errors.InvalidImage): + ap.TargetImageList((base_image1, base_image2)) + + +def test_jacobian_image_list_error(): + arr1 = np.zeros((10, 15, 3)) + base_image1 = ap.JacobianImage( + parameters=["a", "1", "zz"], data=arr1, pixelscale=1.0, zeropoint=1.0 + ) + arr2 = np.ones((15, 10)) + base_image2 = ap.Image(data=arr2, pixelscale=0.5, zeropoint=2.0) + + with pytest.raises(ap.errors.InvalidImage): + ap.JacobianImageList((base_image1, base_image2)) diff --git a/tests/test_model.py b/tests/test_model.py index 524f8705..a349e137 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,8 +1,7 @@ -import unittest import astrophot as ap -import torch import numpy as np from utils import make_basic_sersic, make_basic_gaussian_psf +import pytest # torch.autograd.set_detect_anomaly(True) ###################################################################### @@ -10,286 +9,261 @@ ###################################################################### -class TestModel(unittest.TestCase): - def test_AstroPhot_Model(self): - - model = ap.models.AstroPhot_Model(name="test model") - - self.assertIsNone(model.target, "model should not have a target at this point") - - target = ap.image.Target_Image(data=torch.zeros((16, 32)), pixelscale=1.0) - - model.target = target - - model.window = target.window - - model.locked = True - model.locked = False - - state = model.get_state() - - def test_initialize_does_not_recurse(self): - "Test case for error where missing parameter name triggered print that triggered missing parameter name ..." - target = make_basic_sersic() - model = ap.models.AstroPhot_Model( +def test_model_sampling_modes(): + + target = make_basic_sersic(90, 100) + model = ap.Model( + name="test sersic", + model_type="sersic galaxy model", + center=[40, 41.9], + PA=60 * np.pi / 180, + q=0.8, + n=0.5, + Re=20, + Ie=1, + target=target, + ) + + # With subpixel integration + model.integrate_mode = "bright" + auto = ap.backend.to_numpy(model().data) + model.sampling_mode = "midpoint" + midpoint = ap.backend.to_numpy(model().data) + midpoint_bright = midpoint.copy() + model.sampling_mode = "simpsons" + simpsons = ap.backend.to_numpy(model().data) + model.sampling_mode = "quad:5" + quad5 = ap.backend.to_numpy(model().data) + assert np.allclose(midpoint, auto, rtol=1e-2), "Midpoint sampling should match auto sampling" + assert np.allclose(midpoint, simpsons, rtol=1e-2), "Simpsons sampling should match midpoint" + assert np.allclose(midpoint, quad5, rtol=1e-2), "Quad5 sampling should match midpoint sampling" + assert np.allclose(simpsons, quad5, rtol=1e-6), "Quad5 sampling should match Simpsons sampling" + + # Without subpixel integration + model.integrate_mode = "none" + auto = ap.backend.to_numpy(model().data) + model.sampling_mode = "midpoint" + midpoint = ap.backend.to_numpy(model().data) + model.sampling_mode = "simpsons" + simpsons = ap.backend.to_numpy(model().data) + model.sampling_mode = "quad:5" + quad5 = ap.backend.to_numpy(model().data) + assert np.allclose( + midpoint, midpoint_bright, rtol=1e-2 + ), "no integrate sampling should match bright sampling" + assert np.allclose(midpoint, auto, rtol=1e-2), "Midpoint sampling should match auto sampling" + assert np.allclose(midpoint, simpsons, rtol=1e-2), "Simpsons sampling should match midpoint" + assert np.allclose(midpoint, quad5, rtol=1e-2), "Quad5 sampling should match midpoint sampling" + assert np.allclose(simpsons, quad5, rtol=1e-6), "Quad5 sampling should match Simpsons sampling" + + # curvature based subpixel integration + model.integrate_mode = "curvature" + auto = ap.backend.to_numpy(model().data) + model.sampling_mode = "midpoint" + midpoint = ap.backend.to_numpy(model().data) + model.sampling_mode = "simpsons" + simpsons = ap.backend.to_numpy(model().data) + model.sampling_mode = "quad:5" + quad5 = ap.backend.to_numpy(model().data) + assert np.allclose( + midpoint, midpoint_bright, rtol=1e-2 + ), "curvature integrate sampling should match bright sampling" + assert np.allclose(midpoint, auto, rtol=1e-2), "Midpoint sampling should match auto sampling" + assert np.allclose(midpoint, simpsons, rtol=1e-2), "Simpsons sampling should match midpoint" + assert np.allclose(midpoint, quad5, rtol=1e-2), "Quad5 sampling should match midpoint sampling" + assert np.allclose(simpsons, quad5, rtol=1e-6), "Quad5 sampling should match Simpsons sampling" + + model.integrate_mode = "should raise" + with pytest.raises(ap.errors.SpecificationConflict): + model() + model.integrate_mode = "none" + model.sampling_mode = "should raise" + with pytest.raises(ap.errors.SpecificationConflict): + model() + model.sampling_mode = "midpoint" + model.integrate_mode = "none" + + # test PSF modes + model.psf = np.array([[0.05, 0.1, 0.05], [0.1, 0.4, 0.1], [0.05, 0.1, 0.05]]) + model.psf_convolve = True + model() + + +def test_model_errors(): + + # Target that is not a target image + arr = np.zeros((10, 15)) + target = ap.image.Image(data=arr, pixelscale=1.0, zeropoint=1.0) + + with pytest.raises(ap.errors.InvalidTarget): + ap.Model( name="test model", model_type="sersic galaxy model", target=target, ) - # Define a function that accesses a parameter that doesn't exist - def calc(params): - return params["A"].value - - model["center"].value = calc - - with self.assertRaises(KeyError) as context: - model.initialize() - - def test_basic_model_methods(self): - - target = make_basic_sersic() - model = ap.models.AstroPhot_Model( - name="test sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - rep = model.parameters.vector_representation() - nat = model.parameters.vector_values() - self.assertTrue( - torch.all(torch.isclose(rep, model.parameters.vector_transform_val_to_rep(nat))), - "transform should map between parameter natural and representation", - ) - self.assertTrue( - torch.all(torch.isclose(nat, model.parameters.vector_transform_rep_to_val(rep))), - "transform should map between parameter representation and natural", - ) - - def test_model_sampling_modes(self): - - target = make_basic_sersic(100, 100) - model = ap.models.AstroPhot_Model( - name="test sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - res = model() - model.sampling_mode = "trapezoid" - res = model() - model.sampling_mode = "simpsons" - res = model() - model.sampling_mode = "quad:3" - res = model() - model.integrate_mode = "none" - res = model() - model.integrate_mode = "should raise" - self.assertRaises(ap.errors.SpecificationConflict, model) - model.integrate_mode = "none" - model.sampling_mode = "should raise" - self.assertRaises(ap.errors.SpecificationConflict, model) - model.sampling_mode = "midpoint" - - # test PSF modes - model.psf = np.array([[0.05, 0.1, 0.05], [0.1, 0.4, 0.1], [0.05, 0.1, 0.05]]) - model.integrate_mode = "none" - model.psf_mode = "full" - model.psf_convolve_mode = "direct" - res = model() - model.psf_convolve_mode = "fft" - res = model() - - def test_model_creation(self): - np.random.seed(12345) - shape = (10, 15) - tar = ap.image.Target_Image( - data=np.random.normal(loc=0, scale=1.4, size=shape), - pixelscale=0.8, - variance=np.ones(shape) * (1.4**2), - psf=np.array([[0.05, 0.1, 0.05], [0.1, 0.4, 0.1], [0.05, 0.1, 0.05]]), - ) - - mod = ap.models.Component_Model( - name="base model", - target=tar, - parameters={"center": {"value": [5, 5], "locked": True}}, - ) - - mod.initialize() - - self.assertFalse(mod.locked, "default model should not be locked") - - self.assertTrue(torch.all(mod().data == 0), "Component_Model model_image should be zeros") - - def test_mask(self): - - target = make_basic_sersic() - mask = torch.zeros_like(target.data) - mask[10, 13] = 1 - model = ap.models.AstroPhot_Model( - name="test sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, + # model that doesn't exist + target = make_basic_sersic() + with pytest.raises(ap.errors.UnrecognizedModel): + ap.Model( + name="test model", + model_type="sersic gaaxy model", target=target, - mask=mask, - ) - - sample = model() - self.assertEqual(sample.data[10, 13].item(), 0.0, "masked values should be zero") - self.assertNotEqual(sample.data[11, 12].item(), 0.0, "unmasked values should NOT be zero") - - def test_model_errors(self): - - # Invalid name - self.assertRaises(ap.errors.NameNotAllowed, ap.models.AstroPhot_Model, name="my|model") - - # Target that is not a target image - arr = torch.zeros((10, 15)) - target = ap.image.Image(data=arr, pixelscale=1.0, zeropoint=1.0, origin=torch.zeros(2)) - - with self.assertRaises(ap.errors.InvalidTarget): - model = ap.models.AstroPhot_Model( - name="test model", - model_type="sersic galaxy model", - target=target, - ) - - # model that doesn't exist - target = make_basic_sersic() - with self.assertRaises(ap.errors.UnrecognizedModel): - model = ap.models.AstroPhot_Model( - name="test model", - model_type="sersic gaaxy model", - target=target, - ) - - # invalid window - with self.assertRaises(ap.errors.InvalidWindow): - model = ap.models.AstroPhot_Model( - name="test model", - model_type="sersic galaxy model", - target=target, - window=(1, 2, 3), - ) - - -class TestAllModelBasics(unittest.TestCase): - def test_all_model_sample(self): - - target = make_basic_sersic() - for model_type in ap.models.Component_Model.List_Model_Names(usable=True): - print(model_type) - MODEL = ap.models.AstroPhot_Model( - name="test model", - model_type=model_type, - target=target, - ) - MODEL.initialize() - for P in MODEL.parameter_order: - self.assertIsNotNone( - MODEL[P].value, - f"Model type {model_type} parameter {P} should not be None after initialization", - ) - img = MODEL() - self.assertTrue( - torch.all(torch.isfinite(img.data)), - "Model should evaluate a real number for the full image", - ) - self.assertIsInstance(str(MODEL), str, "String representation should return string") - self.assertIsInstance(repr(MODEL), str, "Repr should return string") - - -class TestSersic(unittest.TestCase): - def test_sersic_creation(self): - np.random.seed(12345) - N = 50 - Width = 20 - shape = (N + 10, N) - true_params = [2, 5, 10, -3, 5, 0.7, np.pi / 4] - IXX, IYY = np.meshgrid( - np.linspace(-Width, Width, shape[1]), np.linspace(-Width, Width, shape[0]) ) - QPAXX, QPAYY = ap.utils.conversions.coordinates.Axis_Ratio_Cartesian_np( - true_params[5], IXX - true_params[3], IYY - true_params[4], true_params[6] - ) - Z0 = ap.utils.parametric_profiles.sersic_np( - np.sqrt(QPAXX**2 + QPAYY**2), - true_params[0], - true_params[1], - true_params[2], - ) + np.random.normal(loc=0, scale=0.1, size=shape) - tar = ap.image.Target_Image( - data=Z0, - pixelscale=0.8, - variance=np.ones(Z0.shape) * (0.1**2), - ) - - mod = ap.models.Sersic_Galaxy( - name="sersic model", - target=tar, - parameters={"center": [-3.2 + N / 2, 5.1 + (N + 10) / 2]}, - ) - - self.assertFalse(mod.locked, "default model should not be locked") - - mod.initialize() - def test_sersic_save_load(self): - target = make_basic_sersic() - psf = make_basic_gaussian_psf() - model = ap.models.AstroPhot_Model( - name="test sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - psf=psf, - psf_mode="full", - target=target, +@pytest.mark.parametrize( + "model_type", ap.models.ComponentModel.List_Models(usable=True, types=True) +) +def test_all_model_sample(model_type): + + if ap.backend.backend == "jax" and np.random.randint(0, 3) > 0: + pytest.skip("JAX is very slow, randomly reducing the number of tests") + if model_type == "isothermal sech2 edgeon model" and ap.backend.backend == "jax": + pytest.skip("JAX doesnt have bessel function k1 yet") + + if ( + model_type in ["ferrer warp galaxy model", "king warp galaxy model"] + and ap.backend.backend == "jax" + ): + pytest.skip("JAX version doesnt support these models yet, difficulty with gradients") + + target = make_basic_sersic() + target.zeropoint = 22.5 + MODEL = ap.Model( + name="test model", + model_type=model_type, + target=target, + integrate_mode=( + "none" if ap.backend.backend == "jax" else "bright" + ), # JAX JIT is reallly slow for any integration + ) + MODEL.initialize() + MODEL.to() + for P in MODEL.dynamic_params: + assert ( + P.value is not None + ), f"Model type {model_type} parameter {P.name} should not be None after initialization" + img = MODEL() + assert ap.backend.all( + ap.backend.isfinite(img.data) + ), "Model should evaluate a real number for the full image" + + res = ap.fit.LM(MODEL, max_iter=10, verbose=1).fit() + print(res.loss_history) + + print(MODEL) # test printing + + # sky has little freedom to fit, some more complex models need extra + # attention to get a good fit so here we just check that they can improve + if ( + "sky" in model_type + or "king" in model_type + or "spline" in model_type + or model_type + in [ + "exponential warp galaxy model", + "ferrer warp galaxy model", + "ferrer ray galaxy model", + "isothermal sech2 edgeon model", + ] + ): + assert res.loss_history[0] > res.loss_history[-1], ( + f"Model {model_type} should fit to the target image, but did not. " + f"Initial loss: {res.loss_history[0]}, Final loss: {res.loss_history[-1]}" ) - - model.initialize() - model.save("test_AstroPhot_sersic.yaml") - model2 = ap.models.AstroPhot_Model( - name="load model", - filename="test_AstroPhot_sersic.yaml", + else: # Most models should get significantly better after just a few iterations + assert res.loss_history[0] > (1.5 * res.loss_history[-1]), ( + f"Model {model_type} should fit to the target image, but did not. " + f"Initial loss: {res.loss_history[0]}, Final loss: {res.loss_history[-1]}" ) - for P in model.parameter_order: - self.assertAlmostEqual( - model[P].value.detach().cpu().tolist(), - model2[P].value.detach().cpu().tolist(), - msg="loaded model should have same parameters", - ) - - -if __name__ == "__main__": - unittest.main() + F = MODEL.total_flux() + assert ap.backend.isfinite(F), "Model total flux should be finite after fitting" + assert F > 0, "Model total flux should be positive after fitting" + U = MODEL.total_flux_uncertainty() + assert ap.backend.isfinite(U), "Model total flux uncertainty should be finite after fitting" + assert U >= 0, "Model total flux uncertainty should be non-negative after fitting" + M = MODEL.total_magnitude() + assert ap.backend.isfinite(M), "Model total magnitude should be finite after fitting" + U_M = MODEL.total_magnitude_uncertainty() + assert ap.backend.isfinite( + U_M + ), "Model total magnitude uncertainty should be finite after fitting" + assert U_M >= 0, "Model total magnitude uncertainty should be non-negative after fitting" + + allnames = set() + for name in MODEL.build_params_array_names(): + assert name not in allnames, f"Duplicate parameter name found: {name}" + allnames.add(name) + + +def test_sersic_save_load(): + + target = make_basic_sersic() + model = ap.Model( + name="test sersic", + model_type="sersic galaxy model", + center=[20, 20], + PA=60 * np.pi / 180, + q=0.5, + n=2, + Re=5, + Ie=1, + target=target, + ) + + model.initialize() + model.save_state("test_AstroPhot_sersic.hdf5", appendable=True) + model.center = [30, 30] + model.PA = 30 * np.pi / 180 + model.q = 0.8 + model.n = 3 + model.Re = 10 + model.Ie = 2 + target.crtan = [1.0, 2.0] + model.append_state("test_AstroPhot_sersic.hdf5") + model.load_state("test_AstroPhot_sersic.hdf5", index=0) + + assert model.center.value[0].item() == 20, "Model center should be loaded correctly" + assert model.center.value[1].item() == 20, "Model center should be loaded correctly" + assert model.PA.value.item() == 60 * np.pi / 180, "Model PA should be loaded correctly" + assert model.q.value.item() == 0.5, "Model q should be loaded correctly" + assert model.n.value.item() == 2, "Model n should be loaded correctly" + assert model.Re.value.item() == 5, "Model Re should be loaded correctly" + assert model.Ie.value.item() == 1, "Model Ie should be loaded correctly" + assert model.target.crtan.value[0] == 0.0, "Model target crtan should be loaded correctly" + assert model.target.crtan.value[1] == 0.0, "Model target crtan should be loaded correctly" + + +@pytest.mark.parametrize("center", [[20, 20], [25.1, 17.324567]]) +@pytest.mark.parametrize("PA", [0, 60 * np.pi / 180]) +@pytest.mark.parametrize("q", [0.2, 0.8]) +@pytest.mark.parametrize("n", [1, 4]) +@pytest.mark.parametrize("Re", [10, 25.1]) +def test_chunk_sample(center, PA, q, n, Re): + target = make_basic_sersic() + model = ap.Model( + name="test sersic", + model_type="sersic galaxy model", + center=center, + PA=PA, + q=q, + n=n, + Re=Re, + Ie=10.0, + target=target, + integrate_mode="none", + ) + + full_img = model.sample() + + chunk_img = target.model_image() + + for chunk in model.window.chunk(20**2): + sample = model.sample(window=chunk) + chunk_img += sample + + assert ap.backend.allclose( + full_img.data, chunk_img.data + ), "Chunked sample should match full sample within tolerance" diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py new file mode 100644 index 00000000..aaa7d40a --- /dev/null +++ b/tests/test_notebooks.py @@ -0,0 +1,57 @@ +import platform +import glob +import pytest +import runpy +import subprocess +import os +import caskade as ck +import astrophot as ap + +pytestmark = pytest.mark.skipif( + platform.system() in ["Windows", "Darwin"], + reason="Graphviz not installed on Windows runner", +) + + +notebooks = glob.glob( + os.path.join( + os.path.split(os.path.dirname(__file__))[0], "docs", "source", "tutorials", "*.ipynb" + ) +) + + +def convert_notebook_to_py(nbpath): + subprocess.run( + ["jupyter", "nbconvert", "--to", "python", nbpath], + check=True, + ) + pypath = nbpath.replace(".ipynb", ".py") + with open(pypath, "r") as f: + content = f.readlines() + with open(pypath, "w") as f: + for line in content: + if line.startswith("get_ipython()"): + # Remove get_ipython() lines to avoid errors in script execution + continue + f.write(line) + + +def cleanup_py_scripts(nbpath): + try: + os.remove(nbpath.replace(".ipynb", ".py")) + os.remove(nbpath.replace(".ipynb", ".pyc")) + except FileNotFoundError: + pass + + +@pytest.mark.parametrize("nb_path", notebooks) +def test_notebook(nb_path): + if ap.backend.backend == "jax": + pytest.skip("Requires torch backend") + convert_notebook_to_py(nb_path) + try: + runpy.run_path(nb_path.replace(".ipynb", ".py"), run_name="__main__") + finally: + ck.backend.backend = "torch" + ap.backend.backend = "torch" + cleanup_py_scripts(nb_path) diff --git a/tests/test_param.py b/tests/test_param.py new file mode 100644 index 00000000..7740dc1b --- /dev/null +++ b/tests/test_param.py @@ -0,0 +1,59 @@ +import pytest + +import astrophot as ap +from astrophot.param import Param + +from utils import make_basic_sersic + + +def test_param(): + + a = Param("a", value=1.0, uncertainty=0.1, valid=(0, 2), prof=1.0) + assert isinstance(a.uncertainty, ap.backend.array_type), "uncertainty should be a tensor" + assert isinstance(a.prof, ap.backend.array_type), "prof should be a tensor" + assert a.initialized, "parameter should be marked as initialized" + assert a.soft_valid(a.value) == a.value, "soft valid should return the value if not near limits" + assert ( + a.soft_valid(-1 * ap.backend.ones_like(a.value)) > a.valid[0] + ), "soft valid should push values inside the limits" + assert ( + a.soft_valid(3 * ap.backend.ones_like(a.value)) < a.valid[1] + ), "soft valid should push values inside the limits" + + b = Param("b", value=[2.0, 3.0], uncertainty=[0.1, 0.1], valid=(1, None)) + assert ap.backend.all( + b.soft_valid(-1 * ap.backend.ones_like(b.value)) > b.valid[0] + ), "soft valid should push values inside the limits" + assert b.prof is None + + c = Param("c", value=lambda P: P.a.value, valid=(None, 4.0)) + c.link(a) + assert c.initialized, "pointer should be marked as initialized" + assert c.uncertainty is None + + +def test_module(): + + target = make_basic_sersic() + model1 = ap.Model(name="test model 1", model_type="sersic galaxy model", target=target) + model2 = ap.Model(name="test model 2", model_type="sersic galaxy model", target=target) + model = ap.Model(name="test", model_type="group model", target=target, models=[model1, model2]) + model.initialize() + + U = ap.backend.ones_like(model.get_values()) * 0.1 + model.fill_dynamic_value_uncertainties(U) + + paramsu = model.get_values(attribute="uncertainty") + assert ap.backend.all(ap.backend.isfinite(paramsu)), "All parameters should be finite" + + paramsn = model.build_params_array_names() + assert all(isinstance(name, str) for name in paramsn), "All parameter names should be strings" + + paramsun = model.build_params_array_units() + assert all(isinstance(unit, str) for unit in paramsun), "All parameter units should be strings" + + index = model.dynamic_params_array_index(model2.q) + assert index == [9], "Parameter index should be correct" + + with pytest.raises(ValueError): + model.dynamic_params_array_index(5.0) # Not a Param instance diff --git a/tests/test_parameter.py b/tests/test_parameter.py deleted file mode 100644 index bfa9b4cd..00000000 --- a/tests/test_parameter.py +++ /dev/null @@ -1,570 +0,0 @@ -import unittest -from astrophot.param import ( - Node as BaseNode, - Parameter_Node, - Param_Mask, - Param_Unlock, -) -import astrophot as ap -import torch -import numpy as np - - -class Node(BaseNode): - """ - Dummy class for testing purposes - """ - - def value(self): - return None - - -class TestNode(unittest.TestCase): - - def test_node_init(self): - node1 = Node("node1") - node2 = Node("node2", locked=True) - - # Check for bad naming - with self.assertRaises(ValueError): - node_bad = Node("node:bad") - - def test_node_link(self): - node1 = Node("node1") - node2 = Node("node2") - node3 = Node("node3", locked=True) - - node1.link(node2, node3) - - self.assertTrue(node1.branch, "node1 is a branch") - self.assertFalse(node3.branch, "node1 is not a branch") - self.assertIs(node1["node2"], node2, "node getitem should fetch correct node") - - for Na, Nb in zip(node1.flat().values(), (node2, node3)): - self.assertIs(Na, Nb, "node flat should produce correct order") - - node4 = Node("node4") - - node2.link(node4) - - for Na, Nb in zip(node1.flat(include_locked=False).values(), (node4,)): - self.assertIs(Na, Nb, "node flat should produce correct order") - - # Check for cycle in DAG - with self.assertRaises(ap.errors.InvalidParameter): - node4.link(node1) - - node1.dump() - - self.assertEqual(len(node1.nodes), 0, "dump should clear all nodes") - - def test_node_access(self): - node1 = Node("node1") - node2 = Node("node2") - node3 = Node("node3", locked=True) - - node1.link(node2, node3) - node4 = Node("node4") - - node2.link(node4) - - self.assertIs(node1["node2:node4"], node4, "node getitem should fetch correct node") - self.assertEqual( - node1["node1"], - node1, - "node should get itself when getter called with its name", - ) - - # Check that error is raised when requesting non existent node - with self.assertRaises(KeyError): - badnode = node1[1.2] - - def test_state(self): - - node1 = Node("node1") - node2 = Node("node2") - node3 = Node("node3", locked=True) - - node1.link(node2, node3) - - state = node1.get_state() - - S = str(node1) - R = repr(node1) - - -class TestParameter(unittest.TestCase): - @torch.no_grad() - def test_parameter_setting(self): - base_param = Parameter_Node("base param") - base_param.value = 1.0 - self.assertEqual(base_param.value, 1, msg="Value should be set to 1") - - base_param.value = 2.0 - self.assertEqual(base_param.value, 2, msg="Value should update to 2") - - base_param.value += 2.0 - self.assertEqual(base_param.value, 4, msg="Value should update to 4") - - # Test a locked parameter that it does not change - locked_param = Parameter_Node("locked param", value=1.0, locked=True) - locked_param.value = 2.0 - self.assertEqual(locked_param.value, 1, msg="Locked value should remain at 1") - - locked_param.value = 2.0 - self.assertEqual(locked_param.value, 1, msg="Locked value should remain at 1") - - def test_parameter_limits(self): - - # Lower limit parameter - lowlim_param = Parameter_Node("lowlim param", limits=(1, None)) - lowlim_param.value = 100.0 - self.assertEqual( - lowlim_param.value, - 100, - msg="lower limit variable should not have upper limit", - ) - with self.assertRaises(ap.errors.InvalidParameter): - lowlim_param.value = -100.0 - - # Upper limit parameter - uplim_param = Parameter_Node("uplim param", limits=(None, 1)) - uplim_param.value = -100.0 - self.assertEqual( - uplim_param.value, - -100, - msg="upper limit variable should not have lower limit", - ) - with self.assertRaises(ap.errors.InvalidParameter): - uplim_param.value = 100.0 - - # Range limit parameter - range_param = Parameter_Node("range param", limits=(-1, 1)) - with self.assertRaises(ap.errors.InvalidParameter): - range_param.value = 100.0 - with self.assertRaises(ap.errors.InvalidParameter): - range_param.value = -100.0 - - # Cyclic Range limit parameter - cyrange_param = Parameter_Node("cyrange param", limits=(-1, 1), cyclic=True) - cyrange_param.value = 2.0 - self.assertEqual( - cyrange_param.value, - 0, - msg="cyclic variable should loop in range (upper)", - ) - cyrange_param.value = -2.0 - self.assertEqual( - cyrange_param.value, - 0, - msg="cyclic variable should loop in range (lower)", - ) - - def test_parameter_array(self): - - param_array1 = Parameter_Node("array1", value=list(float(3 + i) for i in range(5))) - param_array2 = Parameter_Node("array2", value=list(float(i) for i in range(5))) - - param_array2.value = list(float(3) for i in range(5)) - self.assertTrue( - torch.all(param_array2.value == 3), - msg="parameter array value should be updated", - ) - - self.assertEqual(len(param_array2), 5, "parameter array should have length attribute") - - def test_parameter_gradients(self): - V = torch.ones(3) - V.requires_grad = True - params = Parameter_Node("input params", value=V) - X = torch.sum(params.value * 3) - X.backward() - self.assertTrue(torch.all(V.grad == 3), "Parameters should track gradient") - - def test_parameter_state(self): - - P = Parameter_Node( - "state", value=1.0, uncertainty=0.5, limits=(-2, 2), locked=True, prof=1.0 - ) - - P2 = Parameter_Node("v2") - P2.set_state(P.get_state()) - - self.assertEqual(P.value, P2.value, "state should preserve value") - self.assertEqual(P.uncertainty, P2.uncertainty, "state should preserve uncertainty") - self.assertEqual(P.prof, P2.prof, "state should preserve prof") - self.assertEqual(P.locked, P2.locked, "state should preserve locked") - self.assertEqual( - P.limits[0].tolist(), P2.limits[0].tolist(), "state should preserve limits" - ) - self.assertEqual( - P.limits[1].tolist(), P2.limits[1].tolist(), "state should preserve limits" - ) - - S = str(P) - - def test_parameter_value(self): - - P1 = Parameter_Node( - "test1", value=0.5, uncertainty=0.5, limits=(-1, 1), locked=False, prof=1.0 - ) - - P2 = Parameter_Node("test2", value=P1) - - P3 = Parameter_Node("test3", value=lambda P: P["test1"].value ** 2, link=(P1,)) - - self.assertEqual(P1.value.item(), 0.5, "Parameter should store value") - self.assertEqual(P2.value.item(), 0.5, "Pointing parameter should fetch value") - self.assertEqual(P3.value.item(), 0.25, "Function parameter should compute value") - - self.assertEqual(P2.shape, P1.shape, "reference node should map shape") - self.assertEqual(P3.shape, P1.shape, "reference node should map shape") - - -class TestParamContext(unittest.TestCase): - def test_unlock(self): - locked_param = Parameter_Node("locked param", value=1.0, locked=True) - locked_param.value = 2.0 - self.assertEqual( - locked_param.value.item(), - 1.0, - "locked parameter should not be updated out of context", - ) - with Param_Unlock(locked_param): - locked_param.value = 2.0 - self.assertEqual( - locked_param.value.item(), - 2.0, - "locked parameter should be updated in context", - ) - with Param_Unlock(): - locked_param.value = 3.0 - self.assertEqual( - locked_param.value.item(), - 3.0, - "locked parameter should be updated in global unlock context", - ) - - -class TestParameterVector(unittest.TestCase): - def test_param_vector_creation(self): - - P1 = Parameter_Node( - "test1", value=0.5, uncertainty=0.5, limits=(-1, 1), locked=False, prof=1.0 - ) - P2 = Parameter_Node("test2", value=2.0, uncertainty=5.0, locked=False) - P3 = Parameter_Node("test3", value=[4.0, 5.0], uncertainty=[5.0, 5.0], locked=False) - P4 = Parameter_Node("test4", value=P2) - P5 = Parameter_Node("test5", value=lambda P: P["test1"].value ** 2, link=(P1,)) - PG = Parameter_Node("testgroup", link=(P1, P2, P3, P4, P5)) - - self.assertTrue( - torch.all( - PG.vector_values() - == torch.tensor([0.5, 2.0, 4.0, 5.0], dtype=P1.value.dtype, device=P1.value.device) - ), - "Vector store all leaf node values", - ) - self.assertEqual(PG.mask.numel(), 4, "Vector should take all/only leaf node masks") - self.assertEqual( - PG.vector_identities().size, - 4, - "Vector should take all/only leaf node identities", - ) - self.assertEqual(PG.identities.size, 4, "Vector should take all/only leaf node identities") - self.assertEqual(PG.names.size, 4, "Vector should take all/only leaf node names") - self.assertEqual(PG.vector_names().size, 4, "Vector should take all/only leaf node names") - - PG.value = [1.0, 2.0, 3.0, 4.0] - self.assertTrue( - torch.all( - PG.vector_values() - == torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=P1.value.dtype, device=P1.value.device) - ), - "Vector store all leaf node values", - ) - - def test_vector_masking(self): - - P1 = Parameter_Node( - "test1", value=0.5, uncertainty=0.3, limits=(-1, 1), locked=False, prof=1.0 - ) - P2 = Parameter_Node("test2", value=2.0, uncertainty=1.0, locked=False) - P3 = Parameter_Node("test3", value=[4.0, 5.0], uncertainty=[5.0, 3.0], locked=False) - P4 = Parameter_Node("test4", value=P2) - P5 = Parameter_Node("test5", value=lambda P: P["test1"].value ** 2, link=(P1,)) - PG = Parameter_Node("testgroup", link=(P1, P2, P3, P4, P5)) - - mask = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=P1.value.device) - - with Param_Mask(PG, mask): - self.assertTrue( - torch.all( - PG.vector_values() - == torch.tensor([0.5, 5.0], dtype=P1.value.dtype, device=P1.value.device) - ), - "Vector store all leaf node values", - ) - self.assertTrue( - torch.all( - PG.vector_uncertainty() - == torch.tensor([0.3, 3.0], dtype=P1.value.dtype, device=P1.value.device) - ), - "Vector store all leaf node uncertainty", - ) - self.assertEqual( - PG.vector_mask().numel(), - 4, - "Vector should take all/only leaf node masks", - ) - self.assertEqual( - PG.vector_identities().size, - 2, - "Vector should take all/only leaf node identities", - ) - - # Nested masking - new_mask = torch.tensor([1, 0], dtype=torch.bool, device=P1.value.device) - with Param_Mask(PG, new_mask): - self.assertTrue( - torch.all( - PG.vector_values() - == torch.tensor([0.5], dtype=P1.value.dtype, device=P1.value.device) - ), - "Vector store all leaf node values", - ) - self.assertTrue( - torch.all( - PG.vector_uncertainty() - == torch.tensor([0.3], dtype=P1.value.dtype, device=P1.value.device) - ), - "Vector store all leaf node uncertainty", - ) - self.assertEqual( - PG.vector_mask().numel(), - 4, - "Vector should take all/only leaf node masks", - ) - self.assertEqual( - PG.vector_identities().size, - 1, - "Vector should take all/only leaf node identities", - ) - - self.assertTrue( - torch.all( - PG.vector_values() - == torch.tensor([0.5, 5.0], dtype=P1.value.dtype, device=P1.value.device) - ), - "Vector store all leaf node values", - ) - self.assertTrue( - torch.all( - PG.vector_uncertainty() - == torch.tensor([0.3, 3.0], dtype=P1.value.dtype, device=P1.value.device) - ), - "Vector store all leaf node uncertainty", - ) - self.assertEqual( - PG.vector_mask().numel(), - 4, - "Vector should take all/only leaf node masks", - ) - self.assertEqual( - PG.vector_identities().size, - 2, - "Vector should take all/only leaf node identities", - ) - - def test_vector_representation(self): - - P1 = Parameter_Node( - "test1", value=0.5, uncertainty=0.3, limits=(-1, 1), locked=False, prof=1.0 - ) - P2 = Parameter_Node("test2", value=2.0, uncertainty=1.0, locked=False) - P3 = Parameter_Node( - "test3", - value=[4.0, 5.0], - uncertainty=[5.0, 3.0], - limits=(1.0, None), - locked=False, - ) - P4 = Parameter_Node("test4", value=P2) - P5 = Parameter_Node("test5", value=lambda P: P["test1"].value ** 2, link=(P1,)) - P6 = Parameter_Node( - "test6", - value=((5, 6), (7, 8)), - uncertainty=0.1 * np.zeros((2, 2)), - limits=(None, 10.0), - ) - PG = Parameter_Node("testgroup", link=(P1, P2, P3, P4, P5, P6)) - - mask = torch.tensor([1, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool, device=P1.value.device) - - self.assertEqual( - len(PG.vector_representation()), - 8, - "representation should collect all values", - ) - with Param_Mask(PG, mask): - # round trip - vec = PG.vector_values().clone() - rep = PG.vector_representation() - PG.vector_set_representation(rep) - self.assertTrue( - torch.all(vec == PG.vector_values()), - "representation should be reversible", - ) - self.assertEqual(PG.vector_values().numel(), 5, "masked values shouldn't be shown") - - def test_printing(self): - - def node_func_sqr(P): - return P["test1"].value ** 2 - - P1 = Parameter_Node( - "test1", value=0.5, uncertainty=0.3, limits=(-1, 1), locked=False, prof=1.0 - ) - P2 = Parameter_Node("test2", value=2.0, uncertainty=1.0, locked=False) - P3 = Parameter_Node( - "test3", - value=[4.0, 5.0], - uncertainty=[5.0, 3.0], - limits=((0.0, 1.0), None), - locked=False, - ) - P4 = Parameter_Node("test4", value=P2) - P5 = Parameter_Node("test5", value=node_func_sqr, link=(P1,)) - P6 = Parameter_Node( - "test6", - value=((5, 6), (7, 8)), - uncertainty=0.1 * np.zeros((2, 2)), - limits=(None, 10 * np.ones((2, 2))), - ) - PG = Parameter_Node("testgroup", link=(P1, P2, P3, P4, P5, P6)) - - self.assertEqual( - str(PG), - """testgroup: -test1: 0.5 +- 0.3 [none], limits: (-1.0, 1.0) -test2: 2.0 +- 1.0 [none] -test3: [4.0, 5.0] +- [5.0, 3.0] [none], limits: ([0.0, 1.0], None) -test6: [[5.0, 6.0], [7.0, 8.0]] +- [[0.0, 0.0], [0.0, 0.0]] [none], limits: (None, [[10.0, 10.0], [10.0, 10.0]])""", - "String representation should return specific string", - ) - - ref_string = """testgroup (id-140071931416000, branch node): - test1 (id-140071931414752): 0.5 +- 0.3 [none], limits: (-1.0, 1.0) - test2 (id-140071931415376): 2.0 +- 1.0 [none] - test3 (id-140071931415472): [4.0, 5.0] +- [5.0, 3.0] [none], limits: ([0.0, 1.0], None) - test4 (id-140071931414272) points to: test2 (id-140071931415376): 2.0 +- 1.0 [none] - test5 (id-140071931414992, function node, node_func_sqr): - test1 (id-140071931414752): 0.5 +- 0.3 [none], limits: (-1.0, 1.0) - test6 (id-140071931415616): [[5.0, 6.0], [7.0, 8.0]] +- [[0.0, 0.0], [0.0, 0.0]] [none], limits: (None, [[10.0, 10.0], [10.0, 10.0]])""" - # Remove ids since they change every time - while "(id-" in ref_string: - start = ref_string.find("(id-") - end = ref_string.find(")", start) + 1 - ref_string = ref_string[:start] + ref_string[end:] - - repr_string = repr(PG) - # Remove ids since they change every time - count = 0 - while "(id-" in repr_string: - start = repr_string.find("(id-") - end = repr_string.find(")", start) + 1 - repr_string = repr_string[:start] + repr_string[end:] - count += 1 - if count > 100: - raise RuntimeError("infinite loop! Something very wrong with parameter repr") - self.assertEqual(repr_string, ref_string, "Repr should return specific string") - - def test_empty_vector(self): - def node_func_sqr(P): - return P["test1"].value ** 2 - - P1 = Parameter_Node( - "test1", value=0.5, uncertainty=0.3, limits=(-1, 1), locked=True, prof=1.0 - ) - P2 = Parameter_Node("test2", value=2.0, uncertainty=1.0, locked=True) - P3 = Parameter_Node( - "test3", - value=[4.0, 5.0], - uncertainty=[5.0, 3.0], - limits=((0.0, 1.0), None), - locked=True, - ) - P4 = Parameter_Node("test4", value=P2) - P5 = Parameter_Node("test5", value=node_func_sqr, link=(P1,)) - P6 = Parameter_Node( - "test6", - value=((5, 6), (7, 8)), - uncertainty=0.1 * np.zeros((2, 2)), - limits=(None, 10 * np.ones((2, 2))), - locked=True, - ) - PG = Parameter_Node("testgroup", link=(P1, P2, P3, P4, P5, P6)) - - self.assertEqual(PG.names.shape, (0,), "all locked parameter should have empty names") - self.assertEqual( - PG.identities.shape, - (0,), - "all locked parameter should have empty identities", - ) - self.assertEqual( - PG.vector_names().shape, - (0,), - "all locked parameter should have empty names", - ) - self.assertEqual( - PG.vector_identities().shape, - (0,), - "all locked parameter should have empty identities", - ) - - self.assertEqual( - PG.vector_values().shape, - (0,), - "all locked parameter should have empty values", - ) - self.assertEqual( - PG.vector_uncertainty().shape, - (0,), - "all locked parameter should have empty uncertainty", - ) - self.assertEqual( - PG.vector_mask().shape, (0,), "all locked parameter should have empty mask" - ) - self.assertEqual( - PG.vector_representation().shape, - (0,), - "all locked parameter should have empty representation", - ) - - def test_none_uncertainty(self): - - P1 = Parameter_Node( - "test1", value=0.5, uncertainty=0.3, limits=(-1, 1), locked=False, prof=1.0 - ) - P2 = Parameter_Node("test2", value=2.0, locked=True) - P3 = Parameter_Node("test3", value=[4.0, 5.0], limits=((0.0, 1.0), None), locked=False) - P4 = Parameter_Node("test4", link=(P1, P2, P3)) - - self.assertEqual( - tuple(P4.vector_uncertainty().detach().cpu().tolist()), - (0.3, 1.0, 1.0), - "None uncertainty should be filled with ones", - ) - - P3.uncertainty = None - P4.vector_set_uncertainty((0.1, 0.1, 0.1)) - - self.assertEqual( - tuple(P4.vector_uncertainty().detach().cpu().tolist()), - (0.1, 0.1, 0.1), - "None uncertainty should be filled using vector_set_uncertainty", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_plots.py b/tests/test_plots.py index 6d78aadd..4d6a59c7 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,212 +1,165 @@ -import unittest - import numpy as np import matplotlib.pyplot as plt import astrophot as ap from utils import make_basic_sersic, make_basic_gaussian_psf +import pytest -class TestPlots(unittest.TestCase): - """ - Can't test visuals, so this only tests that the code runs - """ - - def test_target_image(self): - target = make_basic_sersic() +""" +Can't test visuals, so this only tests that the code runs +""" - try: - fig, ax = plt.subplots() - except Exception: - print("skipping test_target_image because matplotlib is not installed properly") - return - ap.plots.target_image(fig, ax, target) - plt.close() +def test_target_image(): + target = make_basic_sersic() + try: + fig, ax = plt.subplots() + except Exception: + pytest.skip("skipping test_target_image because matplotlib is not installed properly") + ap.plots.target_image(fig, ax, target) + plt.close() - def test_psf_image(self): - target = make_basic_gaussian_psf() +def test_psf_image(): + target = make_basic_gaussian_psf() + try: + fig, ax = plt.subplots() + except Exception: + pytest.skip("skipping test_target_image because matplotlib is not installed properly") + ap.plots.psf_image(fig, ax, target) + plt.close() + + +def test_target_image_list(): + target1 = make_basic_sersic(name="target1") + target2 = make_basic_sersic(name="target2") + target = ap.TargetImageList([target1, target2]) + try: + fig, ax = plt.subplots(2) + except Exception: + pytest.skip("skipping test_target_image_list because matplotlib is not installed properly") + ap.plots.target_image(fig, ax, target) + plt.close() + + +def test_model_image(): + target = make_basic_sersic() + new_model = ap.Model( + name="constrained sersic", + model_type="sersic galaxy model", + center=[20, 20], + PA=60 * np.pi / 180, + q=0.5, + n=2, + Re=5, + Ie=1, + target=target, + ) + new_model.initialize() + try: + fig, ax = plt.subplots() + except Exception: + pytest.skip("skipping test because matplotlib is not installed properly") + ap.plots.model_image(fig, ax, new_model) + plt.close() + + +def test_residual_image(): + target = make_basic_sersic() + new_model = ap.Model( + name="constrained sersic", + model_type="sersic galaxy model", + center=[20, 20], + PA=60 * np.pi / 180, + q=0.5, + n=2, + Re=5, + Ie=1, + target=target, + ) + new_model.initialize() + try: fig, ax = plt.subplots() + except Exception: + pytest.skip("skipping test because matplotlib is not installed properly") + ap.plots.residual_image(fig, ax, new_model) + plt.close() + + +def test_model_windows(): + target = make_basic_sersic() + new_model = ap.Model( + name="constrained sersic", + model_type="sersic galaxy model", + center=[20, 20], + PA=60 * np.pi / 180, + q=0.5, + n=2, + Re=5, + Ie=1, + window=(10, 10, 30, 30), + target=target, + ) + new_model.initialize() + try: + fig, ax = plt.subplots() + except Exception: + pytest.skip("skipping test because matplotlib is not installed properly") + ap.plots.model_window(fig, ax, new_model) + plt.close() + - ap.plots.psf_image(fig, ax, target) - plt.close() - - def test_target_image_list(self): - target1 = make_basic_sersic() - target2 = make_basic_sersic() - target = ap.image.Target_Image_List([target1, target2]) - - try: - fig, ax = plt.subplots(2) - except Exception: - print("skipping test_target_image_list because matplotlib is not installed properly") - return - - ap.plots.target_image(fig, ax, target) - plt.close() - - def test_model_image(self): - target = make_basic_sersic() - - new_model = ap.models.AstroPhot_Model( - name="constrained sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - new_model.initialize() - - try: - fig, ax = plt.subplots() - except Exception: - print("skipping test because matplotlib is not installed properly") - return - - ap.plots.model_image(fig, ax, new_model) - - plt.close() - - def test_residual_image(self): - target = make_basic_sersic() - - new_model = ap.models.AstroPhot_Model( - name="constrained sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - new_model.initialize() - - try: - fig, ax = plt.subplots() - except Exception: - print("skipping test because matplotlib is not installed properly") - return - - ap.plots.residual_image(fig, ax, new_model) - - plt.close() - - def test_model_windows(self): - - target = make_basic_sersic() - - new_model = ap.models.AstroPhot_Model( - name="constrained sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - new_model.initialize() - - try: - fig, ax = plt.subplots() - except Exception: - print("skipping test because matplotlib is not installed properly") - return - - ap.plots.model_window(fig, ax, new_model) - - plt.close() - - def test_radial_median_profile(self): - - target = make_basic_sersic() - - new_model = ap.models.AstroPhot_Model( - name="constrained sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - new_model.initialize() - - try: - fig, ax = plt.subplots() - except Exception: - print("skipping test because matplotlib is not installed properly") - return - - ap.plots.radial_median_profile(fig, ax, new_model) - - plt.close() - - try: - fig, ax = plt.subplots() - except Exception: - print("skipping test because matplotlib is not installed properly") - return - - target.header.zeropoint = 22.5 - ap.plots.radial_median_profile(fig, ax, new_model, rad_unit="pixel", return_profile=True) - - plt.close() - - def test_radial_light_profile(self): - - target = make_basic_sersic() - - new_model = ap.models.AstroPhot_Model( - name="constrained sersic", - model_type="sersic galaxy model", - parameters={ - "center": [20, 20], - "PA": 60 * np.pi / 180, - "q": 0.5, - "n": 2, - "Re": 5, - "Ie": 1, - }, - target=target, - ) - new_model.initialize() - - try: - fig, ax = plt.subplots() - except Exception: - print("skipping test because matplotlib is not installed properly") - return - - ap.plots.radial_light_profile(fig, ax, new_model) - - plt.close() - - try: - fig, ax = plt.subplots() - except Exception: - print("skipping test because matplotlib is not installed properly") - return - - target.header.zeropoint = 22.5 - ap.plots.radial_light_profile(fig, ax, new_model) - - plt.close() +def test_covariance_matrix(): + covariance_matrix = np.array([[1, 0.5], [0.5, 1]]) + mean = np.array([0, 0]) + try: + fig, ax = plt.subplots() + except Exception: + pytest.skip("skipping test because matplotlib is not installed properly") + fig, ax = ap.plots.covariance_matrix(covariance_matrix, mean, labels=["x", "y"]) + plt.close() + + +def test_radial_profile(): + target = make_basic_sersic() + new_model = ap.Model( + name="constrained sersic", + model_type="sersic galaxy model", + center=[20, 20], + PA=60 * np.pi / 180, + q=0.5, + n=2, + Re=5, + Ie=1, + target=target, + ) + new_model.initialize() + try: + fig, ax = plt.subplots() + except Exception: + pytest.skip("skipping test because matplotlib is not installed properly") + ap.plots.radial_light_profile(fig, ax, new_model) + plt.close() + + +def test_radial_median_profile(): + target = make_basic_sersic() + new_model = ap.Model( + name="constrained sersic", + model_type="sersic galaxy model", + center=[20, 20], + PA=60 * np.pi / 180, + q=0.5, + n=2, + Re=5, + Ie=1, + target=target, + ) + new_model.initialize() + try: + fig, ax = plt.subplots() + except Exception: + pytest.skip("skipping test because matplotlib is not installed properly") + ap.plots.radial_median_profile(fig, ax, new_model) + plt.close() diff --git a/tests/test_psfmodel.py b/tests/test_psfmodel.py index 967f138a..586672ed 100644 --- a/tests/test_psfmodel.py +++ b/tests/test_psfmodel.py @@ -1,8 +1,7 @@ -import unittest import astrophot as ap -import torch import numpy as np from utils import make_basic_gaussian_psf +import pytest # torch.autograd.set_detect_anomaly(True) ###################################################################### @@ -10,73 +9,58 @@ ###################################################################### -class TestAllPSFModelBasics(unittest.TestCase): - def test_all_psfmodel_sample(self): - - target = make_basic_gaussian_psf() - for model_type in ap.models.PSF_Model.List_Model_Names(usable=True): - print(model_type) - MODEL = ap.models.AstroPhot_Model( - name="test model", - model_type=model_type, - target=target, - ) - MODEL.initialize() - for P in MODEL.parameter_order: - self.assertIsNotNone( - MODEL[P].value, - f"Model type {model_type} parameter {P} should not be None after initialization", - ) - print(MODEL.parameters) - img = MODEL() - self.assertTrue( - torch.all(torch.isfinite(img.data)), - "Model should evaluate a real number for the full image", - ) - self.assertIsInstance(str(MODEL), str, "String representation should return string") - self.assertIsInstance(repr(MODEL), str, "Repr should return string") - - -class TestEigenPSF(unittest.TestCase): - def test_init(self): - target = make_basic_gaussian_psf(N=51, rand=666) - dat = target.data.detach() - dat[dat < 0] = 0 - target = ap.image.PSF_Image(data=dat, pixelscale=target.pixelscale) - basis = np.stack( - list( - make_basic_gaussian_psf(N=51, sigma=s, rand=int(4923 * s)).data - for s in np.linspace(8, 1, 5) - ) - ) - # basis = np.random.rand(10,51,51) - EM = ap.models.AstroPhot_Model( - model_type="eigen psf model", - eigen_basis=basis, - eigen_pixelscale=1, - target=target, +@pytest.mark.parametrize("model_type", ap.models.PSFModel.List_Models(usable=True, types=True)) +def test_all_psfmodel_sample(model_type): + if model_type == "airy psf model" and ap.backend.backend == "jax": + pytest.skip( + "Skipping airy psf model, JAX does not support bessel_j1 with finite derivatives it seems" ) - EM.initialize() - - res = ap.fit.LM(EM, verbose=1).fit() + if "nuker" in model_type: + kwargs = {"Ib": {"value": None, "dynamic": True}} + elif "gaussian" in model_type: + kwargs = {"flux": {"value": None, "dynamic": True}} + elif "exponential" in model_type: + kwargs = {"Ie": {"value": None, "dynamic": True}} + else: + kwargs = {} + target = make_basic_gaussian_psf(pixelscale=0.8) + MODEL = ap.Model( + name="test model", + model_type=model_type, + target=target, + normalize_psf=False, + **kwargs, + ) + MODEL.initialize() + print(MODEL) + for P in MODEL.dynamic_params: + assert P.value is not None, ( + f"Model type {model_type} parameter {P} should not be None after initialization", + ) + img = MODEL() - self.assertEqual(res.message, "success") + assert ap.backend.all( + ap.backend.isfinite(img.data) + ), "Model should evaluate a real number for the full image" + if model_type == "pixelated psf model": + psf = ap.utils.initialize.gaussian_psf(3 * 0.8, 25, 0.8) + MODEL.pixels.value = psf / np.sum(psf) -class TestPixelPSF(unittest.TestCase): - def test_init(self): - target = make_basic_gaussian_psf(N=11) - target.data[target.data < 0] = 0 - target = ap.image.PSF_Image( - data=target.data / torch.sum(target.data), pixelscale=target.pixelscale - ) + assert ap.backend.all( + ap.backend.isfinite(MODEL.jacobian().data) + ), "Model should evaluate a real number for the jacobian" - PM = ap.models.AstroPhot_Model( - model_type="pixelated psf model", - target=target, - ) + res = ap.fit.LM(MODEL, max_iter=10).fit() - PM.initialize() + assert len(res.loss_history) > 2, "Optimizer must be able to find steps to improve the model" - self.assertTrue(torch.allclose(PM().data, target.data)) + if "pixelated" in model_type: # fixme pixelated having difficulties + return + assert ((res.loss_history[0] - 1) > (2 * (res.loss_history[-1] - 1))) or ( + res.loss_history[-1] < 1.0 + ), ( + f"Model {model_type} should fit to the target image, but did not. " + f"Initial loss: {res.loss_history[0]}, Final loss: {res.loss_history[-1]}" + ) diff --git a/tests/test_sip_image.py b/tests/test_sip_image.py new file mode 100644 index 00000000..f79570be --- /dev/null +++ b/tests/test_sip_image.py @@ -0,0 +1,136 @@ +import astrophot as ap +import numpy as np + +import pytest + +###################################################################### +# Image Objects +###################################################################### + + +@pytest.fixture() +def sip_target(): + arr = np.zeros((10, 15)) + return ap.SIPTargetImage( + data=arr, + pixelscale=1.0, + zeropoint=1.0, + variance=np.ones_like(arr), + mask=np.zeros_like(arr), + sipA={(1, 0): 1e-4, (0, 1): 1e-4, (2, 3): -1e-5}, + sipB={(1, 0): -1e-4, (0, 1): 5e-5, (2, 3): 2e-6}, + # sipAP={(1, 0): -1e-4, (0, 1): -1e-4, (2, 3): 1e-5}, + # sipBP={(1, 0): 1e-4, (0, 1): -5e-5, (2, 3): -2e-6}, + ) + + +def test_sip_image_creation(sip_target): + assert sip_target.pixelscale == 1.0, "image should track pixelscale" + assert sip_target.zeropoint == 1.0, "image should track zeropoint" + assert sip_target.crpix[0] == 0, "image should track crpix" + assert sip_target.crpix[1] == 0, "image should track crpix" + + slicer = ap.Window((7, 13, 4, 7), sip_target) + sliced_image = sip_target[slicer] + assert sliced_image.crpix[0] == -7, "crpix of subimage should give relative position" + assert sliced_image.crpix[1] == -4, "crpix of subimage should give relative position" + assert sliced_image._data.shape == (6, 3), "sliced image should have correct shape" + assert sliced_image.pixel_area_map.shape == ( + 6, + 3, + ), "sliced image should have correct pixel area map shape" + assert sliced_image.distortion_ij.shape == ( + 2, + 6, + 3, + ), "sliced image should have correct distortion shape" + assert sliced_image.distortion_IJ.shape == ( + 2, + 6, + 3, + ), "sliced image should have correct distortion shape" + + sip_model_image = sip_target.model_image(upsample=2, pad=1) + assert sip_model_image._data.shape == (32, 22), "model image should have correct shape" + assert sip_model_image.pixel_area_map.shape == ( + 32, + 22, + ), "model image pixel area map should have correct shape" + assert sip_model_image.distortion_ij.shape == ( + 2, + 32, + 22, + ), "model image distortion model should have correct shape" + assert sip_model_image.distortion_IJ.shape == ( + 2, + 32, + 22, + ), "model image distortion model should have correct shape" + + # reduce + sip_model_reduce = sip_model_image.reduce(scale=1) + assert sip_model_reduce is sip_model_image, "reduce should return the same image if scale is 1" + sip_model_reduce = sip_model_image.reduce(scale=2) + assert sip_model_reduce._data.shape == (16, 11), "reduced model image should have correct shape" + + # crop + sip_model_crop = sip_model_image.crop(1) + assert sip_model_crop._data.shape == (30, 20), "cropped model image should have correct shape" + sip_model_crop = sip_model_image.crop([1]) + assert sip_model_crop._data.shape == (30, 20), "cropped model image should have correct shape" + sip_model_crop = sip_model_image.crop([1, 2]) + assert sip_model_crop._data.shape == (30, 18), "cropped model image should have correct shape" + sip_model_crop = sip_model_image.crop([1, 2, 3, 4]) + assert sip_model_crop._data.shape == (29, 15), "cropped model image should have correct shape" + + sip_model_crop.fluxdensity_to_flux() + assert ap.backend.all( + sip_model_crop.data >= 0 + ), "cropped model image data should be non-negative after flux density to flux conversion" + + +def test_sip_image_wcs_roundtrip(sip_target): + """ + Test that the WCS roundtrip works correctly for SIP images. + """ + i, j = sip_target.pixel_center_meshgrid() + x, y = sip_target.pixel_to_plane(i, j) + i2, j2 = sip_target.plane_to_pixel(x, y) + + assert ap.backend.allclose(i, i2, atol=0.05), "i coordinates should match after WCS roundtrip" + assert ap.backend.allclose(j, j2, atol=0.05), "j coordinates should match after WCS roundtrip" + + +def test_sip_image_save_load(sip_target): + """ + Test that SIP images can be saved and loaded correctly. + """ + # Save the SIP image to a file + sip_target.save("test_sip_image.fits") + + # Load the SIP image from the file + loaded_image = ap.SIPTargetImage(filename="test_sip_image.fits") + + # Check that the loaded image matches the original + assert ap.backend.allclose( + sip_target.data, loaded_image.data + ), "Loaded image data should match original" + assert ap.backend.allclose( + sip_target.pixelscale, loaded_image.pixelscale + ), "Loaded image pixelscale should match original" + assert ap.backend.allclose( + sip_target.zeropoint, loaded_image.zeropoint + ), "Loaded image zeropoint should match original" + print(loaded_image.sipA) + assert all( + np.allclose(sip_target.sipA[key], loaded_image.sipA[key]) for key in sip_target.sipA + ), "Loaded image sipA should match original" + assert all( + np.allclose(sip_target.sipB[key], loaded_image.sipB[key]) for key in sip_target.sipB + ), "Loaded image sipB should match original" + assert all( + np.allclose(sip_target.sipAP[key], loaded_image.sipAP[key]) for key in sip_target.sipAP + ), "Loaded image sipAP should match original" + assert all( + np.allclose(sip_target.sipBP[key], loaded_image.sipBP[key]) for key in sip_target.sipBP + ), "Loaded image sipBP should match original" diff --git a/tests/test_utils.py b/tests/test_utils.py index 3ef9c9e6..79c1c43a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,4 @@ -import unittest import numpy as np -import torch -import h5py -from scipy.signal import fftconvolve from scipy.special import gamma import astrophot as ap from utils import make_basic_sersic, make_basic_gaussian @@ -12,475 +8,147 @@ ###################################################################### -class TestFFT(unittest.TestCase): - def test_fft(self): - - target = make_basic_sersic() - - convolved = ap.utils.operations.fft_convolve_torch( - target.data, - target.psf.data, - ) - scipy_convolve = fftconvolve( - target.data.detach().cpu().numpy(), - target.psf.data.detach().cpu().numpy(), - mode="same", - ) - self.assertLess( - torch.std(convolved), - torch.std(target.data), - "Convolved image should be smoothed", - ) - - self.assertTrue( - np.all(np.isclose(convolved.detach().cpu().numpy(), scipy_convolve)), - "Should reproduce scipy convolve", - ) - - def test_fft_multi(self): - - target = make_basic_sersic() - - convolved = ap.utils.operations.fft_convolve_multi_torch( - target.data, [target.psf.data, target.psf.data] - ) - self.assertLess( - torch.std(convolved), - torch.std(target.data), - "Convolved image should be smoothed", - ) - - -class TestOptimize(unittest.TestCase): - def test_chi2(self): - - # with variance - # with mask - mask = torch.zeros(10, dtype=torch.bool, device=ap.AP_config.ap_device) - mask[2] = 1 - chi2 = ap.utils.optimization.chi_squared( - torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - torch.zeros(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - mask=mask, - variance=2 * torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - ) - self.assertEqual(chi2, 4.5, "Chi squared calculation incorrect") - chi2_red = ap.utils.optimization.reduced_chi_squared( - torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - torch.zeros(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - params=3, - mask=mask, - variance=2 * torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - ) - self.assertEqual(chi2_red.item(), 0.75, "Chi squared calculation incorrect") - - # no mask - chi2 = ap.utils.optimization.chi_squared( - torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - torch.zeros(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - variance=2 * torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - ) - self.assertEqual(chi2, 5, "Chi squared calculation incorrect") - chi2_red = ap.utils.optimization.reduced_chi_squared( - torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - torch.zeros(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - params=3, - variance=2 * torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - ) - self.assertEqual(chi2_red.item(), 5 / 7, "Chi squared calculation incorrect") - - # no variance - # with mask - mask = torch.zeros(10, dtype=torch.bool, device=ap.AP_config.ap_device) - mask[2] = 1 - chi2 = ap.utils.optimization.chi_squared( - torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - torch.zeros(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - mask=mask, - ) - self.assertEqual(chi2.item(), 9, "Chi squared calculation incorrect") - chi2_red = ap.utils.optimization.reduced_chi_squared( - torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - torch.zeros(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - params=3, - mask=mask, - ) - self.assertEqual(chi2_red.item(), 1.5, "Chi squared calculation incorrect") - - # no mask - chi2 = ap.utils.optimization.chi_squared( - torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - torch.zeros(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - ) - self.assertEqual(chi2.item(), 10, "Chi squared calculation incorrect") - chi2_red = ap.utils.optimization.reduced_chi_squared( - torch.ones(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - torch.zeros(10, dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device), - params=3, - ) - self.assertEqual(chi2_red.item(), 10 / 7, "Chi squared calculation incorrect") - - -class TestPSF(unittest.TestCase): - def test_make_psf(self): - - target = make_basic_gaussian(x=10, y=10) - target += make_basic_gaussian(x=40, y=40, rand=54321) - - psf = ap.utils.initialize.construct_psf( - [[10, 10], [40, 40]], - target.data.detach().cpu().numpy(), - sky_est=0.0, - size=5, - ) - - self.assertTrue(np.all(np.isfinite(psf))) - - -class TestSegtoWindow(unittest.TestCase): - def test_segtowindow(self): - - segmap = np.zeros((100, 100), dtype=int) - - segmap[5:9, 20:30] = 1 - segmap[50:90, 17:35] = 2 - segmap[26:34, 80:85] = 3 - - centroids = ap.utils.initialize.centroids_from_segmentation_map(segmap, image=segmap) - - PAs = ap.utils.initialize.PA_from_segmentation_map( - segmap, - image=segmap, - centroids=centroids, - ) - qs = ap.utils.initialize.q_from_segmentation_map( - segmap, - image=segmap, - centroids=centroids, - ) - - windows = ap.utils.initialize.windows_from_segmentation_map(segmap) - - self.assertEqual(len(windows), 3, "should ignore zero index, but find all three windows") - self.assertEqual(len(centroids), 3, "should ignore zero index, but find all three windows") - self.assertEqual(len(PAs), 3, "should ignore zero index, but find all three windows") - self.assertEqual(len(qs), 3, "should ignore zero index, but find all three windows") - - self.assertEqual(windows[1], [[20, 29], [5, 8]], "Windows should be identified by index") - - # transfer windows - old_image = ap.image.Target_Image( - data=np.zeros((100, 100)), - pixelscale=1.0, - ) - new_image = ap.image.Target_Image( - data=np.zeros((100, 100)), - pixelscale=0.9, - origin=(0.1, 1.2), - ) - new_windows = ap.utils.initialize.transfer_windows(windows, old_image, new_image) - self.assertEqual( - windows.keys(), - new_windows.keys(), - "Transferred windows should have the same set of windows", - ) - - # scale windows - - new_windows = ap.utils.initialize.scale_windows( - windows, image_shape=(100, 100), expand_scale=2, expand_border=3 - ) - - self.assertEqual(new_windows[2], [[5, 45], [27, 100]], "Windows should scale appropriately") - - filtered_windows = ap.utils.initialize.filter_windows( - new_windows, min_size=10, max_size=80, min_area=30, max_area=1000 - ) - filtered_windows = ap.utils.initialize.filter_windows( - new_windows, min_flux=10, max_flux=1000, image=np.ones(segmap.shape) - ) - - self.assertEqual(len(filtered_windows), 2, "windows should have been filtered") - - # check original - self.assertEqual( - windows[3], [[80, 84], [26, 33]], "Original windows should not have changed" - ) - - -class TestConversions(unittest.TestCase): - def test_conversions_units(self): - - # flux to sb - self.assertEqual( - ap.utils.conversions.units.flux_to_sb(1.0, 1.0, 0.0), - 0, - "flux incorrectly converted to sb", - ) - - # sb to flux - self.assertEqual( - ap.utils.conversions.units.sb_to_flux(1.0, 1.0, 0.0), - (10 ** (-1 / 2.5)), - "sb incorrectly converted to flux", - ) - - # flux to mag no error - self.assertEqual( - ap.utils.conversions.units.flux_to_mag(1.0, 0.0), - 0, - "flux incorrectly converted to mag (no error)", - ) - - # flux to mag with error - self.assertEqual( - ap.utils.conversions.units.flux_to_mag(1.0, 0.0, fluxe=1.0), - (0.0, 2.5 / np.log(10)), - "flux incorrectly converted to mag (with error)", - ) - - # mag to flux no error: - self.assertEqual( - ap.utils.conversions.units.mag_to_flux(1.0, 0.0, mage=None), - (10 ** (-1 / 2.5)), - "mag incorrectly converted to flux (no error)", - ) - - # mag to flux with error: - [ - self.assertAlmostEqual( - ap.utils.conversions.units.mag_to_flux(1.0, 0.0, mage=1.0)[i], - (10 ** (-1.0 / 2.5), np.log(10) * (1.0 / 2.5) * 10 ** (-1.0 / 2.5))[i], - msg="mag incorrectly converted to flux (with error)", - ) - for i in range(1) - ] - - # magperarcsec2 to mag with area A defined - self.assertAlmostEqual( - ap.utils.conversions.units.magperarcsec2_to_mag(1.0, a=None, b=None, A=1.0), - (1.0 - 2.5 * np.log10(1.0)), - msg="mag/arcsec^2 incorrectly converted to mag (area A given, a and b not defined)", - ) - - # magperarcsec2 to mag with semi major and minor axes defined (a, and b) - self.assertAlmostEqual( - ap.utils.conversions.units.magperarcsec2_to_mag(1.0, a=1.0, b=1.0, A=None), - (1.0 - 2.5 * np.log10(np.pi)), - msg="mag/arcsec^2 incorrectly converted to mag (semi major/minor axes defined)", - ) - - # mag to magperarcsec2 with area A defined - self.assertAlmostEqual( - ap.utils.conversions.units.mag_to_magperarcsec2(1.0, a=None, b=None, A=1.0, R=None), - (1.0 + 2.5 * np.log10(1.0)), - msg="mag incorrectly converted to mag/arcsec^2 (area A given)", - ) - - # mag to magperarcsec2 with radius R given (assumes circular) - self.assertAlmostEqual( - ap.utils.conversions.units.mag_to_magperarcsec2(1.0, a=None, b=None, A=None, R=1.0), - (1.0 + 2.5 * np.log10(np.pi)), - msg="mag incorrectly converted to mag/arcsec^2 (radius R given)", - ) - - # mag to magperarcsec2 with semi major and minor axes defined (a, and b) - self.assertAlmostEqual( - ap.utils.conversions.units.mag_to_magperarcsec2(1.0, a=1.0, b=1.0, A=None, R=None), - (1.0 + 2.5 * np.log10(np.pi)), - msg="mag incorrectly converted to mag/arcsec^2 (area A given)", - ) - - # position angle PA to radians - self.assertAlmostEqual( - ap.utils.conversions.units.PA_shift_convention(1.0, unit="rad"), - ((1.0 - (np.pi / 2)) % np.pi), - msg="PA incorrectly converted to radians", - ) - - # position angle PA to degrees - self.assertAlmostEqual( - ap.utils.conversions.units.PA_shift_convention(1.0, unit="deg"), - ((1.0 - (180 / 2)) % 180), - msg="PA incorrectly converted to degrees", - ) - - def test_conversion_dict_to_hdf5(self): - - # convert string to hdf5 - self.assertEqual( - ap.utils.conversions.dict_to_hdf5.to_hdf5_has_None(l="test"), - (False), - "Failed to properly identify string object while converting to hdf5", - ) - - # convert __iter__ to hdf5 - self.assertEqual( - ap.utils.conversions.dict_to_hdf5.to_hdf5_has_None(l="__iter__"), - (False), - "Attempted to convert '__iter__' to hdf5 key", - ) - - # convert hdf5 file to dict - h = h5py.File("mytestfile.hdf5", "w") - dset = h.create_dataset("mydataset", (1,), dtype="i") - dset[...] = np.array([1.0]) - self.assertEqual( - ap.utils.conversions.dict_to_hdf5.hdf5_to_dict(h=h), - ({"mydataset": h["mydataset"]}), - "Failed to convert hdf5 file to dict", - ) - - # convert dict to hdf5 - target = make_basic_sersic().data.detach().cpu().numpy()[0] - d = {"sersic": target.tolist()} - ap.utils.conversions.dict_to_hdf5.dict_to_hdf5(h=h5py.File("mytestfile2.hdf5", "w"), D=d) - self.assertEqual( - (list(h5py.File("mytestfile2.hdf5", "r"))), - (list(d)), - "Failed to convert dict of strings to hdf5", - ) - - def test_conversion_functions(self): - - sersic_n = ap.utils.conversions.functions.sersic_n_to_b(1.0) - # sersic I0 to flux - numpy - self.assertAlmostEqual( - ap.utils.conversions.functions.sersic_I0_to_flux_np(1.0, 1.0, 1.0, 1.0), - (2 * np.pi * gamma(2)), - msg="Error converting sersic central intensity to flux (np)", - ) - - # sersic flux to I0 - numpy - self.assertAlmostEqual( - ap.utils.conversions.functions.sersic_flux_to_I0_np(1.0, 1.0, 1.0, 1.0), - (1.0 / (2 * np.pi * gamma(2))), - msg="Error converting sersic flux to central intensity (np)", - ) - - # sersic Ie to flux - numpy - self.assertAlmostEqual( - ap.utils.conversions.functions.sersic_Ie_to_flux_np(1.0, 1.0, 1.0, 1.0), - (2 * np.pi * gamma(2) * np.exp(sersic_n) * sersic_n ** (-2)), - msg="Error converting sersic effective intensity to flux (np)", - ) - - # sersic flux to Ie - numpy - self.assertAlmostEqual( - ap.utils.conversions.functions.sersic_flux_to_Ie_np(1.0, 1.0, 1.0, 1.0), - (1 / (2 * np.pi * gamma(2) * np.exp(sersic_n) * sersic_n ** (-2))), - msg="Error converting sersic flux to effective intensity (np)", - ) - - # inverse sersic - numpy - self.assertAlmostEqual( - ap.utils.conversions.functions.sersic_inv_np(1.0, 1.0, 1.0, 1.0), - (1.0 - (1.0 / sersic_n) * np.log(1.0)), - msg="Error computing inverse sersic function (np)", - ) - - # sersic I0 to flux - torch - tv = torch.tensor([[1.0]], dtype=torch.float64) - self.assertEqual( - torch.round( - ap.utils.conversions.functions.sersic_I0_to_flux_np(tv, tv, tv, tv), - decimals=7, - ), - torch.round(torch.tensor([[2 * np.pi * gamma(2)]]), decimals=7), - msg="Error converting sersic central intensity to flux (torch)", - ) - - # sersic flux to I0 - torch - self.assertEqual( - torch.round( - ap.utils.conversions.functions.sersic_flux_to_I0_np(tv, tv, tv, tv), - decimals=7, - ), - torch.round(torch.tensor([[1.0 / (2 * np.pi * gamma(2))]]), decimals=7), - msg="Error converting sersic flux to central intensity (torch)", - ) - - # sersic Ie to flux - torch - self.assertEqual( - torch.round( - ap.utils.conversions.functions.sersic_Ie_to_flux_np(tv, tv, tv, tv), - decimals=7, - ), - torch.round( - torch.tensor([[2 * np.pi * gamma(2) * np.exp(sersic_n) * sersic_n ** (-2)]]), - decimals=7, - ), - msg="Error converting sersic effective intensity to flux (torch)", - ) - - # sersic flux to Ie - torch - self.assertEqual( - torch.round( - ap.utils.conversions.functions.sersic_flux_to_Ie_np(tv, tv, tv, tv), - decimals=7, - ), - torch.round( - torch.tensor([[1 / (2 * np.pi * gamma(2) * np.exp(sersic_n) * sersic_n ** (-2))]]), - decimals=7, - ), - msg="Error converting sersic flux to effective intensity (torch)", - ) - - # inverse sersic - torch - self.assertEqual( - torch.round(ap.utils.conversions.functions.sersic_inv_np(tv, tv, tv, tv), decimals=7), - torch.round(torch.tensor([[1.0 - (1.0 / sersic_n) * np.log(1.0)]]), decimals=7), - msg="Error computing inverse sersic function (torch)", - ) - - -class TestInterpolate(unittest.TestCase): - def test_interpolate_functions(self): - - # Lanczos kernel interpolation on the center point of a gaussian (10., 10.) - model = make_basic_gaussian(x=10.0, y=10.0).data.detach().cpu().numpy() - lanczos_interp = ap.utils.interpolate.point_Lanczos(model, 10.0, 10.0, scale=0.8) - self.assertTrue(np.all(np.isfinite(model)), msg="gaussian model returning nonfinite values") - self.assertLess(lanczos_interp, 1.0, msg="Lanczos interpolation greater than total flux") - self.assertTrue( - np.isfinite(lanczos_interp), - msg="Lanczos interpolate returning nonfinite values", - ) - - -class TestAngleOperations(unittest.TestCase): - def test_angle_operation_functions(self): - - test_angles = np.array([np.pi, 2 * np.pi, 3 * np.pi, 4 * np.pi]) - # angle median - self.assertAlmostEqual( - ap.utils.angle_operations.Angle_Median(test_angles), - -np.pi / 2, - msg="incorrectly calculating median of list of angles", - ) - - # angle scatter (iqr) - self.assertAlmostEqual( - ap.utils.angle_operations.Angle_Scatter(test_angles), - np.pi, - msg="incorrectly calculating iqr of list of angles", - ) - - def test_angle_com(self): - pixelscale = 0.8 - tar = make_basic_sersic( - N=50, - M=50, - pixelscale=pixelscale, - x=24.5 * pixelscale, - y=24.5 * pixelscale, - PA=115 * np.pi / 180, - ) - - res = ap.utils.angle_operations.Angle_COM_PA(tar.data.detach().cpu().numpy()) - - self.assertAlmostEqual(res + np.pi / 2, 115 * np.pi / 180, delta=0.1) - - -if __name__ == "__main__": - unittest.main() +def test_make_psf(): + + target = make_basic_gaussian(x=10, y=10) + target += make_basic_gaussian(x=40, y=40, rand=54321) + + assert np.all( + np.isfinite(ap.backend.to_numpy(target.data)) + ), "Target image should be finite after creation" + + +def test_conversions_units(): + + # flux to sb + # flux to sb + assert ( + ap.utils.conversions.units.flux_to_sb(1.0, 1.0, 0.0) == 0 + ), "flux incorrectly converted to sb" + + # sb to flux + assert ap.utils.conversions.units.sb_to_flux(1.0, 1.0, 0.0) == ( + 10 ** (-1 / 2.5) + ), "sb incorrectly converted to flux" + + # flux to mag no error + assert ( + ap.utils.conversions.units.flux_to_mag(1.0, 0.0) == 0 + ), "flux incorrectly converted to mag (no error)" + + # flux to mag with error + assert ap.utils.conversions.units.flux_to_mag(1.0, 0.0, fluxe=1.0) == ( + 0.0, + 2.5 / np.log(10), + ), "flux incorrectly converted to mag (with error)" + + # mag to flux no error: + assert ap.utils.conversions.units.mag_to_flux(1.0, 0.0, mage=None) == ( + 10 ** (-1 / 2.5) + ), "mag incorrectly converted to flux (no error)" + + # mag to flux with error: + for i in range(1): + assert np.isclose( + ap.utils.conversions.units.mag_to_flux(1.0, 0.0, mage=1.0)[i], + (10 ** (-1.0 / 2.5), np.log(10) * (1.0 / 2.5) * 10 ** (-1.0 / 2.5))[i], + ), "mag incorrectly converted to flux (with error)" + + # magperarcsec2 to mag with area A defined + assert np.isclose( + ap.utils.conversions.units.magperarcsec2_to_mag(1.0, a=None, b=None, A=1.0), + (1.0 - 2.5 * np.log10(1.0)), + ), "mag/arcsec^2 incorrectly converted to mag (area A given, a and b not defined)" + + # magperarcsec2 to mag with semi major and minor axes defined (a, and b) + assert np.isclose( + ap.utils.conversions.units.magperarcsec2_to_mag(1.0, a=1.0, b=1.0, A=None), + (1.0 - 2.5 * np.log10(np.pi)), + ), "mag/arcsec^2 incorrectly converted to mag (semi major/minor axes defined)" + + # mag to magperarcsec2 with area A defined + assert np.isclose( + ap.utils.conversions.units.mag_to_magperarcsec2(1.0, a=None, b=None, A=1.0, R=None), + (1.0 + 2.5 * np.log10(1.0)), + ), "mag incorrectly converted to mag/arcsec^2 (area A given)" + + # mag to magperarcsec2 with radius R given (assumes circular) + assert np.isclose( + ap.utils.conversions.units.mag_to_magperarcsec2(1.0, a=None, b=None, A=None, R=1.0), + (1.0 + 2.5 * np.log10(np.pi)), + ), "mag incorrectly converted to mag/arcsec^2 (radius R given)" + + # mag to magperarcsec2 with semi major and minor axes defined (a, and b) + assert np.isclose( + ap.utils.conversions.units.mag_to_magperarcsec2(1.0, a=1.0, b=1.0, A=None, R=None), + (1.0 + 2.5 * np.log10(np.pi)), + ), "mag incorrectly converted to mag/arcsec^2 (area A given)" + + +def test_conversion_functions(): + + sersic_n = ap.utils.conversions.functions.sersic_n_to_b(1.0) + # sersic I0 to flux - numpy + assert np.isclose( + ap.utils.conversions.functions.sersic_I0_to_flux_np(1.0, 1.0, 1.0, 1.0), + (2 * np.pi * gamma(2)), + ), "Error converting sersic central intensity to flux (np)" + # sersic flux to I0 - numpy + assert np.isclose( + ap.utils.conversions.functions.sersic_flux_to_I0_np(1.0, 1.0, 1.0, 1.0), + (1.0 / (2 * np.pi * gamma(2))), + ), "Error converting sersic flux to central intensity (np)" + + # sersic Ie to flux - numpy + assert np.isclose( + ap.utils.conversions.functions.sersic_Ie_to_flux_np(1.0, 1.0, 1.0, 1.0), + (2 * np.pi * gamma(2) * np.exp(sersic_n) * sersic_n ** (-2)), + ), "Error converting sersic effective intensity to flux (np)" + + # sersic flux to Ie - numpy + assert np.isclose( + ap.utils.conversions.functions.sersic_flux_to_Ie_np(1.0, 1.0, 1.0, 1.0), + (1 / (2 * np.pi * gamma(2) * np.exp(sersic_n) * sersic_n ** (-2))), + ), "Error converting sersic flux to effective intensity (np)" + + # inverse sersic - numpy + assert np.isclose( + ap.utils.conversions.functions.sersic_inv_np(1.0, 1.0, 1.0, 1.0), + (1.0 - (1.0 / sersic_n) * np.log(1.0)), + ), "Error computing inverse sersic function (np)" + + # sersic I0 to flux - torch + tv = ap.backend.as_array([[1.0]], dtype=ap.backend.float64) + assert ap.backend.allclose( + ap.utils.conversions.functions.sersic_I0_to_flux_np(tv, tv, tv, tv), + ap.backend.as_array([[2 * np.pi * gamma(2)]]), + rtol=1e-7, + ), "Error converting sersic central intensity to flux (torch)" + + # sersic flux to I0 - torch + assert ap.backend.allclose( + ap.utils.conversions.functions.sersic_flux_to_I0_np(tv, tv, tv, tv), + ap.backend.as_array([[1.0 / (2 * np.pi * gamma(2))]]), + rtol=1e-7, + ), "Error converting sersic flux to central intensity (torch)" + + # sersic Ie to flux - torch + assert ap.backend.allclose( + ap.utils.conversions.functions.sersic_Ie_to_flux_np(tv, tv, tv, tv), + ap.backend.as_array([[2 * np.pi * gamma(2) * np.exp(sersic_n) * sersic_n ** (-2)]]), + rtol=1e-7, + ), "Error converting sersic effective intensity to flux (torch)" + + # sersic flux to Ie - torch + assert ap.backend.allclose( + ap.utils.conversions.functions.sersic_flux_to_Ie_np(tv, tv, tv, tv), + ap.backend.as_array([[1 / (2 * np.pi * gamma(2) * np.exp(sersic_n) * sersic_n ** (-2))]]), + rtol=1e-7, + ), "Error converting sersic flux to effective intensity (torch)" + + # inverse sersic - torch + assert ap.backend.allclose( + ap.utils.conversions.functions.sersic_inv_np(tv, tv, tv, tv), + ap.backend.as_array([[1.0 - (1.0 / sersic_n) * np.log(1.0)]]), + rtol=1e-7, + ), "Error computing inverse sersic function (torch)" diff --git a/tests/test_wcs.py b/tests/test_wcs.py deleted file mode 100644 index 8c0d930b..00000000 --- a/tests/test_wcs.py +++ /dev/null @@ -1,553 +0,0 @@ -import unittest -import astrophot as ap -import numpy as np -import torch - - -class TestWPCS(unittest.TestCase): - def test_wpcs_creation(self): - - # Blank startup - wcs_blank = ap.image.WPCS() - - self.assertEqual(wcs_blank.projection, "gnomonic", "Default projection should be Gnomonic") - self.assertTrue( - torch.all(wcs_blank.reference_radec == 0), - "default reference world coordinates should be zeros", - ) - self.assertTrue( - torch.all(wcs_blank.reference_planexy == 0), - "default reference plane coordinates should be zeros", - ) - - # Provided parameters - wcs_set = ap.image.WPCS( - projection="orthographic", - reference_radec=(90, 10), - ) - - self.assertEqual(wcs_set.projection, "orthographic", "Provided projection was Orthographic") - self.assertTrue( - torch.all( - wcs_set.reference_radec - == torch.tensor( - (90, 10), dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device - ) - ), - "World coordinates should be as provided", - ) - self.assertNotEqual( - wcs_blank.projection, - "orthographic", - "Not all WCS objects should be updated", - ) - self.assertFalse( - torch.all( - wcs_blank.reference_radec - == torch.tensor( - (90, 10), dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device - ) - ), - "Not all WCS objects should be updated", - ) - - wcs_set = wcs_set.copy() - - self.assertEqual(wcs_set.projection, "orthographic", "Provided projection was Orthographic") - self.assertTrue( - torch.all( - wcs_set.reference_radec - == torch.tensor( - (90, 10), dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device - ) - ), - "World coordinates should be as provided", - ) - self.assertNotEqual( - wcs_blank.projection, - "orthographic", - "Not all WCS objects should be updated", - ) - self.assertFalse( - torch.all( - wcs_blank.reference_radec - == torch.tensor( - (90, 10), dtype=ap.AP_config.ap_dtype, device=ap.AP_config.ap_device - ) - ), - "Not all WCS objects should be updated", - ) - - def test_wpcs_round_trip(self): - - for projection in ["gnomonic", "orthographic", "steriographic"]: - print(projection) - for ref_coords in [(20.3, 79), (120.2, -19), (300, -50), (0, 0)]: - print(ref_coords) - wcs = ap.image.WPCS( - projection=projection, - reference_radec=ref_coords, - ) - - test_grid_RA, test_grid_DEC = torch.meshgrid( - torch.linspace( - ref_coords[0] - 10, - ref_coords[0] + 10, - 10, - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), # RA - torch.linspace( - ref_coords[1] - 10, - ref_coords[1] + 10, - 10, - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), # DEC - indexing="xy", - ) - - project_x, project_y = wcs.world_to_plane( - test_grid_RA, - test_grid_DEC, - ) - - reproject_RA, reproject_DEC = wcs.plane_to_world( - project_x, - project_y, - ) - - self.assertTrue( - torch.allclose(reproject_RA, test_grid_RA), - "Round trip RA should map back to itself", - ) - self.assertTrue( - torch.allclose(reproject_DEC, test_grid_DEC), - "Round trip DEC should map back to itself", - ) - - def test_wpcs_errors(self): - with self.assertRaises(ap.errors.InvalidWCS): - wcs = ap.image.WPCS( - projection="connor", - ) - - -class TestPPCS(unittest.TestCase): - - def test_ppcs_creation(self): - # Blank startup - wcs_blank = ap.image.PPCS() - - self.assertTrue( - np.all( - wcs_blank.pixelscale.detach().cpu().numpy() == np.array([[1.0, 0.0], [0.0, 1.0]]) - ), - "Default pixelscale should be 1", - ) - self.assertTrue( - torch.all(wcs_blank.reference_imageij == -0.5), - "default reference pixel coordinates should be -0.5", - ) - self.assertTrue( - torch.all(wcs_blank.reference_imagexy == 0.0), - "default reference plane coordinates should be zeros", - ) - - # Provided parameters - wcs_set = ap.image.PPCS( - pixelscale=[[-0.173205, 0.1], [0.15, 0.259808]], - reference_imageij=(5, 10), - reference_imagexy=(0.12, 0.45), - ) - - self.assertTrue( - torch.allclose( - wcs_set.pixelscale, - torch.tensor( - [[-0.173205, 0.1], [0.15, 0.259808]], - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "Provided pixelscale should be used", - ) - self.assertTrue( - torch.allclose( - wcs_set.reference_imageij, - torch.tensor( - (5.0, 10.0), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "pixel reference coordinates should be as provided", - ) - self.assertTrue( - torch.allclose( - wcs_set.reference_imagexy, - torch.tensor( - (0.12, 0.45), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "plane reference coordinates should be as provided", - ) - self.assertTrue( - torch.allclose( - wcs_set.plane_to_pixel( - torch.tensor( - (0.12, 0.45), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ) - ), - torch.tensor( - (5.0, 10.0), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "plane reference coordinates should map to pixel reference coordinates", - ) - self.assertTrue( - torch.allclose( - wcs_set.pixel_to_plane( - torch.tensor( - (5.0, 10.0), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ) - ), - torch.tensor( - (0.12, 0.45), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "pixel reference coordinates should map to plane reference coordinates", - ) - - wcs_set = wcs_set.copy() - - self.assertTrue( - torch.allclose( - wcs_set.pixelscale, - torch.tensor( - [[-0.173205, 0.1], [0.15, 0.259808]], - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "Provided pixelscale should be used", - ) - self.assertTrue( - torch.allclose( - wcs_set.reference_imageij, - torch.tensor( - (5.0, 10.0), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "pixel reference coordinates should be as provided", - ) - self.assertTrue( - torch.allclose( - wcs_set.reference_imagexy, - torch.tensor( - (0.12, 0.45), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "plane reference coordinates should be as provided", - ) - self.assertTrue( - torch.allclose( - wcs_set.plane_to_pixel( - torch.tensor( - (0.12, 0.45), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ) - ), - torch.tensor( - (5.0, 10.0), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "plane reference coordinates should map to pixel reference coordinates", - ) - self.assertTrue( - torch.allclose( - wcs_set.pixel_to_plane( - torch.tensor( - (5.0, 10.0), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ) - ), - torch.tensor( - (0.12, 0.45), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "pixel reference coordinates should map to plane reference coordinates", - ) - - wcs_set.pixelscale = None - - def test_ppcs_round_trip(self): - - for pixelscale in [ - 0.2, - [[0.6, 0.0], [0.0, 0.4]], - [[-0.173205, 0.1], [0.15, 0.259808]], - ]: - print(pixelscale) - for ref_coords in [(20.3, 79), (120.2, -19), (300, -50), (0, 0)]: - print(ref_coords) - wcs = ap.image.PPCS( - pixelscale=pixelscale, - reference_imagexy=ref_coords, - ) - - test_grid_x, test_grid_y = torch.meshgrid( - torch.linspace( - ref_coords[0] - 10, - ref_coords[0] + 10, - 10, - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), # x - torch.linspace( - ref_coords[1] - 10, - ref_coords[1] + 10, - 10, - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), # y - indexing="xy", - ) - - project_i, project_j = wcs.plane_to_pixel( - test_grid_x, - test_grid_y, - ) - - reproject_x, reproject_y = wcs.pixel_to_plane( - project_i, - project_j, - ) - - self.assertTrue( - torch.allclose(reproject_x, test_grid_x), - "Round trip x should map back to itself", - ) - self.assertTrue( - torch.allclose(reproject_y, test_grid_y), - "Round trip y should map back to itself", - ) - - -class TestWCS(unittest.TestCase): - def test_wcs_creation(self): - - wcs = ap.image.WCS( - projection="orthographic", - pixelscale=[[-0.173205, 0.1], [0.15, 0.259808]], - reference_radec=(120.2, -19), - reference_imagexy=(33.0, 123.0), - ) - - wcs2 = wcs.copy() - - self.assertEqual(wcs2.projection, "orthographic", "Provided projection was Orthographic") - self.assertTrue( - torch.allclose(wcs2.reference_radec, wcs.reference_radec), - "World coordinates should be as provided", - ) - self.assertTrue( - torch.allclose(wcs2.reference_planexy, wcs.reference_planexy), - "Plane coordinates should be as provided", - ) - self.assertTrue( - torch.allclose(wcs2.reference_imagexy, wcs.reference_imagexy), - "imagexy coordinates should be as provided", - ) - self.assertTrue( - torch.allclose(wcs2.reference_imageij, wcs.reference_imageij), - "imageij coordinates should be as provided", - ) - self.assertTrue( - torch.allclose(wcs2.pixelscale, wcs.pixelscale), - "pixelscale should be as provided", - ) - - def test_wcs_roundtrip(self): - for pixelscale in [ - 0.2, - [[0.6, 0.0], [0.0, 0.4]], - [[-0.173205, 0.1], [0.15, 0.259808]], - ]: - print(pixelscale) - for ref_coords_xy in [(33.0, 123.0), (-430.2, -11), (-97.0, 5), (0, 0)]: - for projection in ["gnomonic", "orthographic", "steriographic"]: - print(projection) - for ref_coords_radec in [ - (20.3, 79), - (120.2, -19), - (300, -50), - (0, 0), - ]: - print(ref_coords_radec) - wcs = ap.image.WCS( - projection=projection, - pixelscale=pixelscale, - reference_radec=ref_coords_radec, - reference_imagexy=ref_coords_xy, - ) - - test_grid_RA, test_grid_DEC = torch.meshgrid( - torch.linspace( - ref_coords_radec[0] - 10, - ref_coords_radec[0] + 10, - 10, - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), # RA - torch.linspace( - ref_coords_radec[1] - 10, - ref_coords_radec[1] + 10, - 10, - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), # DEC - indexing="xy", - ) - - project_i, project_j = wcs.world_to_pixel( - test_grid_RA, - test_grid_DEC, - ) - - reproject_RA, reproject_DEC = wcs.pixel_to_world( - project_i, - project_j, - ) - - self.assertTrue( - torch.allclose(reproject_RA, test_grid_RA), - "Round trip RA should map back to itself", - ) - self.assertTrue( - torch.allclose(reproject_DEC, test_grid_DEC), - "Round trip DEC should map back to itself", - ) - - def test_wcs_state(self): - wcs = ap.image.WCS( - projection="orthographic", - pixelscale=[[-0.173205, 0.1], [0.15, 0.259808]], - reference_radec=(120.2, -19), - reference_imagexy=(33.0, 123.0), - ) - - wcs_state = wcs.get_state() - - new_wcs = ap.image.WCS(state=wcs_state) - - self.assertEqual( - wcs.projection, new_wcs.projection, "WCS projection should be set by state" - ) - self.assertTrue( - torch.allclose( - wcs.pixelscale, - torch.tensor( - [[-0.173205, 0.1], [0.15, 0.259808]], - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "WCS pixelscale should be set by state", - ) - self.assertTrue( - torch.allclose( - wcs.reference_radec, - torch.tensor( - (120.2, -19), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "WCS reference RA DEC should be set by state", - ) - self.assertTrue( - torch.allclose( - wcs.reference_imagexy, - torch.tensor( - (33.0, 123.0), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "WCS reference image position should be set by state", - ) - - wcs_state = wcs.get_fits_state() - - new_wcs = ap.image.WCS() - new_wcs.set_fits_state(state=wcs_state) - - self.assertEqual( - wcs.projection, new_wcs.projection, "WCS projection should be set by state" - ) - self.assertTrue( - torch.allclose( - wcs.pixelscale, - torch.tensor( - [[-0.173205, 0.1], [0.15, 0.259808]], - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "WCS pixelscale should be set by state", - ) - self.assertTrue( - torch.allclose( - wcs.reference_radec, - torch.tensor( - (120.2, -19), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "WCS reference RA DEC should be set by state", - ) - self.assertTrue( - torch.allclose( - wcs.reference_imagexy, - torch.tensor( - (33.0, 123.0), - dtype=ap.AP_config.ap_dtype, - device=ap.AP_config.ap_device, - ), - ), - "WCS reference image position should be set by state", - ) - - def test_wcs_repr(self): - - wcs = ap.image.WCS( - projection="orthographic", - pixelscale=[[-0.173205, 0.1], [0.15, 0.259808]], - reference_radec=(120.2, -19), - reference_imagexy=(33.0, 123.0), - ) - - S = str(wcs) - R = repr(wcs) diff --git a/tests/test_window.py b/tests/test_window.py index 3e51f079..98c7a679 100644 --- a/tests/test_window.py +++ b/tests/test_window.py @@ -1,412 +1,75 @@ -import unittest import astrophot as ap import numpy as np -import torch -class TestWindow(unittest.TestCase): - def test_window_creation(self): - - window1 = ap.image.Window(origin=(0, 6), pixel_shape=(100, 110)) - - window1.to(dtype=torch.float64, device="cpu") - - self.assertEqual(window1.origin[0], 0, "Window should store origin") - self.assertEqual(window1.origin[1], 6, "Window should store origin") - self.assertEqual(window1.shape[0], 100, "Window should store shape") - self.assertEqual(window1.shape[1], 110, "Window should store shape") - self.assertEqual(window1.center[0], 50.0, "Window should determine center") - self.assertEqual(window1.center[1], 61.0, "Window should determine center") - - self.assertRaises(Exception, ap.image.Window) - - x = str(window1) - x = repr(window1) - - wcs = window1.get_astropywcs() - - def test_window_crop(self): - - window1 = ap.image.Window(origin=(0, 6), pixel_shape=(100, 110)) - - window1.crop_to_pixel([[10, 90], [15, 105]]) - self.assertTrue( - np.all(window1.origin.detach().cpu().numpy() == np.array([10.0, 21])), - "crop pixels should move origin", - ) - self.assertTrue( - np.all(window1.pixel_shape.detach().cpu().numpy() == np.array([80, 90])), - "crop pixels should change shape", - ) - - window2 = ap.image.Window(origin=(0, 6), pixel_shape=(100, 110)) - window2.crop_pixel((5,)) - self.assertTrue( - np.all(window2.origin.detach().cpu().numpy() == np.array([5.0, 11.0])), - "crop pixels should move origin", - ) - self.assertTrue( - np.all(window2.pixel_shape.detach().cpu().numpy() == np.array([90, 100])), - "crop pixels should change shape", - ) - window2.pad_pixel((5,)) - - window2 = ap.image.Window(origin=(0, 6), pixel_shape=(100, 110)) - window2.crop_pixel((5, 6)) - self.assertTrue( - np.all(window2.origin.detach().cpu().numpy() == np.array([5.0, 12.0])), - "crop pixels should move origin", - ) - self.assertTrue( - np.all(window2.pixel_shape.detach().cpu().numpy() == np.array([90, 98])), - "crop pixels should change shape", - ) - window2.pad_pixel((5, 6)) - - window2 = ap.image.Window(origin=(0, 6), pixel_shape=(100, 110)) - window2.crop_pixel((5, 6, 7, 8)) - self.assertTrue( - np.all(window2.origin.detach().cpu().numpy() == np.array([5.0, 12.0])), - "crop pixels should move origin", - ) - self.assertTrue( - np.all(window2.pixel_shape.detach().cpu().numpy() == np.array([88, 96])), - "crop pixels should change shape", - ) - window2.pad_pixel((5, 6, 7, 8)) - - self.assertTrue( - np.all(window2.origin.detach().cpu().numpy() == np.array([0.0, 6.0])), - "pad pixels should move origin", - ) - self.assertTrue( - np.all(window2.pixel_shape.detach().cpu().numpy() == np.array([100, 110])), - "pad pixels should change shape", - ) - - def test_window_get_indices(self): - - window1 = ap.image.Window(origin=(0, 6), pixel_shape=(100, 110)) - xstep, ystep = np.meshgrid(range(100), range(110), indexing="xy") - zstep = xstep + ystep - window2 = ap.image.Window(origin=(15, 15), pixel_shape=(30, 200)) - - zsliced = zstep[window1.get_self_indices(window2)] - self.assertTrue( - np.all(zsliced == zstep[9:110, 15:45]), - "window slices should get correct part of image", - ) - zsliced = zstep[window2.get_other_indices(window1)] - self.assertTrue( - np.all(zsliced == zstep[9:110, 15:45]), - "window slices should get correct part of image", - ) - - def test_window_arithmetic(self): - - windowbig = ap.image.Window(origin=(0, 0), pixel_shape=(100, 110)) - windowsmall = ap.image.Window(origin=(40, 40), pixel_shape=(20, 30)) - - # Logical or, size - ###################################################################### - big_or_small = windowbig | windowsmall - self.assertEqual( - big_or_small.origin[0], - 0, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_small.shape[0], - 100, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowbig.origin[0], - 0, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowsmall.origin[0], - 40, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowsmall.shape[0], - 20, - "logical or of images should not affect initial images", - ) - - # Logical and, size - ###################################################################### - big_and_small = windowbig & windowsmall - self.assertEqual( - big_and_small.origin[0], - 40, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_small.shape[0], - 20, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_small.shape[1], - 30, - "logical and of images should take overlap region", - ) - self.assertEqual( - windowbig.origin[0], - 0, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowsmall.origin[0], - 40, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowsmall.shape[0], - 20, - "logical and of images should not affect initial images", - ) - - # Logical or, offset - ###################################################################### - windowoffset = ap.image.Window(origin=(40, -20), pixel_shape=(100, 90)) - big_or_offset = windowbig | windowoffset - self.assertEqual( - big_or_offset.origin[0], - 0, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_offset.origin[1], - -20, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_offset.shape[0], - 140, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_offset.shape[1], - 130, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowbig.origin[0], - 0, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowoffset.origin[0], - 40, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowoffset.shape[0], - 100, - "logical or of images should not affect initial images", - ) - - # Logical and, offset - ###################################################################### - big_and_offset = windowbig & windowoffset - self.assertEqual( - big_and_offset.origin[0], - 40, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_offset.origin[1], - 0, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_offset.shape[0], - 60, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_offset.shape[1], - 70, - "logical and of images should take overlap region", - ) - self.assertEqual( - windowbig.origin[0], - 0, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowoffset.origin[0], - 40, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowoffset.shape[0], - 100, - "logical and of images should not affect initial images", - ) - - # Logical ior, size - ###################################################################### - windowbig |= windowsmall - self.assertEqual( - windowbig.origin[0], - 0, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowsmall.origin[0], - 40, - "logical or of images should not affect input image", - ) - self.assertEqual( - windowsmall.shape[0], - 20, - "logical or of images should not affect input image", - ) - - # Logical ior, offset - ###################################################################### - windowbig |= windowoffset - self.assertEqual( - windowbig.origin[0], - 0, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowbig.origin[1], - -20, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowbig.shape[0], - 140, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowbig.shape[1], - 130, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowoffset.origin[0], - 40, - "logical or of images should not affect input image", - ) - self.assertEqual( - windowoffset.shape[0], - 100, - "logical or of images should not affect input image", - ) - - # Logical iand, offset - ###################################################################### - windowbig = ap.image.Window(origin=(0, 0), pixel_shape=(100, 110)) - windowbig &= windowoffset - self.assertEqual( - windowbig.origin[0], 40, "logical and of images should take overlap region" - ) - self.assertEqual(windowbig.origin[1], 0, "logical and of images should take overlap region") - self.assertEqual(windowbig.shape[0], 60, "logical and of images should take overlap region") - self.assertEqual(windowbig.shape[1], 70, "logical and of images should take overlap region") - self.assertEqual( - windowoffset.origin[0], - 40, - "logical and of images should not affect input image", - ) - self.assertEqual( - windowoffset.shape[0], - 100, - "logical and of images should not affect input image", - ) - - windowbig &= windowsmall - - self.assertEqual( - windowbig, - windowsmall, - "logical and of images should take overlap region, equality should be internally determined", - ) - - def test_window_state(self): - window_init = ap.image.Window( - origin=[1.0, 2.0], - pixel_shape=[10, 15], - pixelscale=1, - projection="orthographic", - reference_radec=(0, 0), - ) - window = ap.image.Window(state=window_init.get_state()) - self.assertEqual(window.origin[0].item(), 1.0, "Window initialization should read state") - self.assertEqual(window.shape[0].item(), 10.0, "Window initialization should read state") - self.assertEqual( - window.pixelscale[0][0].item(), - 1.0, - "Window initialization should read state", - ) - - state = window.get_state() - self.assertEqual( - state["reference_imagexy"][1], 2.0, "Window get state should collect values" - ) - self.assertEqual(state["pixel_shape"][1], 15.0, "Window get state should collect values") - self.assertEqual(state["pixelscale"][1][0], 0.0, "Window get state should collect values") - self.assertEqual( - state["projection"], - "orthographic", - "Window get state should collect values", - ) - self.assertEqual( - tuple(state["reference_radec"]), - (0.0, 0.0), - "Window get state should collect values", - ) - - def test_window_logic(self): - - window1 = ap.image.Window(origin=[0.0, 1.0], pixel_shape=[10.0, 11.0]) - window2 = ap.image.Window(origin=[0.0, 1.0], pixel_shape=[10.0, 11.0]) - window3 = ap.image.Window(origin=[-0.6, 0.4], pixel_shape=[15.0, 18.0]) - - self.assertEqual(window1, window2, "same origin, shape windows should evaluate equal") - self.assertNotEqual(window1, window3, "Different windows should not evaluate equal") - - def test_window_errors(self): - - # Initialize with conflicting information - with self.assertRaises(ap.errors.SpecificationConflict): - window = ap.image.Window( - origin=[0.0, 1.0], origin_radec=[5.0, 6.0], pixel_shape=[10.0, 11.0] - ) - - -if __name__ == "__main__": - unittest.main() +def test_window_creation(): + + image = ap.Image( + data=np.zeros((100, 110)), + pixelscale=0.3, + zeropoint=1.0, + name="test_image", + ) + window = ap.Window((2, 107, 3, 97), image) + + assert np.all(window.crpix == image.crpix), "Window should inherit crpix from image" + assert window.identity == image.identity, "Window should inherit identity from image" + assert window.shape == (105, 94), "Window should have correct shape" + assert window.extent == (2, 107, 3, 97), "Window should have correct extent" + assert str(window) == "Window(2, 107, 3, 97)", "String representation should match" + + +def test_window_chunk(): + + image = ap.Image( + data=np.zeros((100, 110)), + pixelscale=0.3, + zeropoint=1.0, + name="test_image", + ) + window1 = ap.Window((2, 107, 3, 97), image) + + subwindows = window1.chunk(10**2) + reconstitute = subwindows[0] + for subwindow in subwindows: + reconstitute |= subwindow + assert ( + reconstitute.i_low == window1.i_low + ), "chunked windows should reconstitute to original window" + assert ( + reconstitute.i_high == window1.i_high + ), "chunked windows should reconstitute to original window" + assert ( + reconstitute.j_low == window1.j_low + ), "chunked windows should reconstitute to original window" + assert ( + reconstitute.j_high == window1.j_high + ), "chunked windows should reconstitute to original window" + + +def test_window_arithmetic(): + + image = ap.Image( + data=np.zeros((100, 110)), + pixelscale=0.3, + zeropoint=1.0, + name="test_image", + ) + windowbig = ap.Window((2, 107, 3, 97), image) + windowsmall = ap.Window((20, 45, 30, 90), image) + + # Logical or, size + ###################################################################### + big_or_small = windowbig | windowsmall + assert big_or_small.i_low == 2, "logical or of images should take largest bounding box" + assert big_or_small.i_high == 107, "logical or of images should take largest bounding box" + assert big_or_small.j_low == 3, "logical or of images should take largest bounding box" + assert big_or_small.j_high == 97, "logical or of images should take largest bounding box" + + # Logical and, size + ###################################################################### + big_and_small = windowbig & windowsmall + assert big_and_small.i_low == 20, "logical and of images should take overlap region" + assert big_and_small.i_high == 45, "logical and of images should take overlap region" + assert big_and_small.j_low == 30, "logical and of images should take overlap region" + assert big_and_small.j_high == 90, "logical and of images should take overlap region" diff --git a/tests/test_window_list.py b/tests/test_window_list.py index c1b88d98..7c983e73 100644 --- a/tests/test_window_list.py +++ b/tests/test_window_list.py @@ -1,7 +1,5 @@ -import unittest import astrophot as ap import numpy as np -import torch ###################################################################### @@ -9,243 +7,37 @@ ###################################################################### -class TestWindowList(unittest.TestCase): - def test_windowlist_creation(self): - - window1 = ap.image.Window(origin=(0, 6), pixel_shape=(100, 110)) - window2 = ap.image.Window(origin=(0, 6), pixel_shape=(100, 110)) - windowlist = ap.image.Window_List([window1, window2]) - - windowlist.to(dtype=torch.float64, device="cpu") - - # under review - self.assertEqual(windowlist.origin[0][0], 0, "Window list should capture origin") - self.assertEqual(windowlist.origin[1][1], 6, "Window list should capture origin") - self.assertEqual(windowlist.shape[0][0], 100, "Window list should capture shape") - self.assertEqual(windowlist.shape[1][1], 110, "Window list should capture shape") - self.assertEqual(windowlist.center[1][0], 50.0, "Window should determine center") - self.assertEqual(windowlist.center[0][1], 61.0, "Window should determine center") - - x = str(windowlist) - x = repr(windowlist) - - def test_window_arithmetic(self): - - windowbig = ap.image.Window(origin=(0, 0), pixel_shape=(100, 110)) - windowsmall = ap.image.Window(origin=(40, 40), pixel_shape=(20, 30)) - windowlistbs = ap.image.Window_List([windowbig, windowsmall]) - windowlistbb = ap.image.Window_List([windowbig, windowbig]) - windowlistsb = ap.image.Window_List([windowsmall, windowbig]) - - # Logical or, size - ###################################################################### - big_or_small = windowlistbs | windowlistsb - - self.assertEqual( - big_or_small.origin[0][0], - 0.0, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_small.origin[1][0], - 0.0, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_small.shape[0][0], - 100, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowbig.origin[0], - 0, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowsmall.origin[0], - 40, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowsmall.shape[0], - 20, - "logical or of images should not affect initial images", - ) - - # Logical and, size - ###################################################################### - big_and_small = windowlistbs & windowlistsb - self.assertEqual( - big_and_small.origin[0][0], - 40, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_small.shape[0][0], - 20, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_small.shape[0][1], - 30, - "logical and of images should take overlap region", - ) - self.assertEqual( - windowbig.origin[0], - 0, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowsmall.origin[0], - 40, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowsmall.shape[0], - 20, - "logical and of images should not affect initial images", - ) - - # Logical or, offset - ###################################################################### - windowoffset = ap.image.Window(origin=(40, -20), pixel_shape=(100, 90)) - windowlistoffset = ap.image.Window_List([windowoffset, windowoffset]) - big_or_offset = windowlistbb | windowlistoffset - self.assertEqual( - big_or_offset.origin[0][0], - 0, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_offset.origin[1][1], - -20, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_offset.shape[0][0], - 140, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - big_or_offset.shape[1][1], - 130, - "logical or of images should take largest bounding box", - ) - self.assertEqual( - windowbig.origin[0], - 0, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowoffset.origin[0], - 40, - "logical or of images should not affect initial images", - ) - self.assertEqual( - windowoffset.shape[0], - 100, - "logical or of images should not affect initial images", - ) - - # Logical and, offset - ###################################################################### - big_and_offset = windowlistbb & windowlistoffset - self.assertEqual( - big_and_offset.origin[0][0], - 40, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_offset.origin[0][1], - 0, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_offset.shape[0][0], - 60, - "logical and of images should take overlap region", - ) - self.assertEqual( - big_and_offset.shape[0][1], - 70, - "logical and of images should take overlap region", - ) - self.assertEqual( - windowbig.origin[0], - 0, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowbig.shape[0], - 100, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowoffset.origin[0], - 40, - "logical and of images should not affect initial images", - ) - self.assertEqual( - windowoffset.shape[0], - 100, - "logical and of images should not affect initial images", - ) - - def test_windowlist_logic(self): - - window1 = ap.image.Window(origin=[0.0, 1.0], pixel_shape=[10.2, 11.8]) - window2 = ap.image.Window(origin=[0.0, 1.0], pixel_shape=[10.2, 11.8]) - window3 = ap.image.Window(origin=[-0.6, 0.4], pixel_shape=[15.2, 18.0]) - windowlist1 = ap.image.Window_List([window1, window1.copy()]) - windowlist2 = ap.image.Window_List([window2, window2.copy()]) - windowlist3 = ap.image.Window_List([window3, window3.copy()]) - - self.assertEqual( - windowlist1, windowlist2, "same origin, shape windows should evaluate equal" - ) - self.assertNotEqual(windowlist1, windowlist3, "Different windows should not evaluate equal") - - def test_image_list_errors(self): - window1 = ap.image.Window(origin=[0.0, 1.0], pixel_shape=[10.2, 11.8]) - window2 = ap.image.Window(origin=[0.0, 1.0], pixel_shape=[10.2, 11.8]) - windowlist1 = ap.image.Window_List([window1, window2]) - - # Bad ra dec reference point - window2 = ap.image.Window( - origin=[0.0, 1.0], reference_radec=np.ones(2), pixel_shape=[10.2, 11.8] - ) - with self.assertRaises(ap.errors.ConflicingWCS): - test_image = ap.image.Window_List((window1, window2)) - - # Bad tangent plane x y reference point - window2 = ap.image.Window( - origin=[0.0, 1.0], reference_planexy=np.ones(2), pixel_shape=[10.2, 11.8] - ) - with self.assertRaises(ap.errors.ConflicingWCS): - test_image = ap.image.Window_List((window1, window2)) - - # Bad WCS projection - window2 = ap.image.Window( - origin=[0.0, 1.0], projection="orthographic", pixel_shape=[10.2, 11.8] - ) - with self.assertRaises(ap.errors.ConflicingWCS): - test_image = ap.image.Window_List((window1, window2)) - - -if __name__ == "__main__": - unittest.main() +def test_windowlist_creation(): + + image1 = ap.Image( + data=np.zeros((10, 15)), + pixelscale=1.0, + zeropoint=1.0, + name="image1", + ) + image2 = ap.Image( + data=np.ones((15, 10)), + pixelscale=0.5, + zeropoint=2.0, + name="image2", + ) + window1 = ap.Window([4, 13, 5, 9], image1) + window2 = ap.Window([0, 7, 1, 8], image2) + windowlist = ap.WindowList([window1, window2]) + + window3 = ap.Window([3, 12, 5, 8], image1) + assert windowlist.index(window3) == 0, "WindowList should find window by index" + assert len(windowlist) == 2, "WindowList should have two windows" + + window21 = ap.Window([5, 10, 6, 9], image1) + window22 = ap.Window([0, 9, 0, 8], image2) + windowlist2 = ap.WindowList([window21, window22]) + + windowlist_and = windowlist & windowlist2 + assert len(windowlist_and) == 2, "WindowList should have two windows after intersection" + assert windowlist_and[0].image is image1, "First window should be from image1" + assert windowlist_and[1].image is image2, "Second window should be from image2" + assert windowlist_and[0].i_low == 5, "First window should have i_low of 5" + assert windowlist_and[0].i_high == 10, "First window should have i_high of 10" + assert windowlist_and[0].j_low == 6, "First window should have j_low of 6" + assert windowlist_and[0].j_high == 9, "First window should have j_high of 9" diff --git a/tests/utils.py b/tests/utils.py index 2fd94f8f..7bbbb9df 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,69 +8,69 @@ def get_astropy_wcs(): "SIMPLE": "T", "NAXIS": 2, "NAXIS1": 180, - "NAXIS2": 180, + "NAXIS2": 170, "CTYPE1": "RA---TAN", "CTYPE2": "DEC--TAN", "CRVAL1": 195.0588, "CRVAL2": 28.0608, "CRPIX1": 90.5, - "CRPIX2": 90.5, + "CRPIX2": 85.5, "CD1_1": -0.000416666666666667, "CD1_2": 0.0, "CD2_1": 0.0, "CD2_2": 0.000416666666666667, - "IMAGEW": 180.0, - "IMAGEH": 180.0, + # "IMAGEW": 180.0, + # "IMAGEH": 170.0, } return WCS(hdr) def make_basic_sersic( - N=50, + N=52, M=50, pixelscale=0.8, - x=24.5, - y=25.4, + x=20.5, + y=21.4, PA=45 * np.pi / 180, - q=0.6, - n=2, - Re=7.1, - Ie=0, + q=0.7, + n=1.5, + Re=15.1, + Ie=10.0, rand=12345, + **kwargs, ): np.random.seed(rand) mask = np.zeros((N, M), dtype=bool) mask[0][0] = True - target = ap.image.Target_Image( + target = ap.TargetImage( data=np.zeros((N, M)), pixelscale=pixelscale, psf=ap.utils.initialize.gaussian_psf(2 / pixelscale, 11, pixelscale), mask=mask, zeropoint=21.5, + **kwargs, ) - MODEL = ap.models.Sersic_Galaxy( + MODEL = ap.models.SersicGalaxy( name="basic sersic model", target=target, - parameters={ - "center": [x, y], - "PA": PA, - "q": q, - "n": n, - "Re": Re, - "Ie": Ie, - }, + center=[x, y], + PA=PA, + q=q, + n=n, + Re=Re, + Ie=Ie, sampling_mode="quad:5", ) - img = MODEL().data.detach().cpu().numpy() + img = ap.backend.to_numpy(MODEL().data) target.data = ( img - + np.random.normal(scale=0.1, size=img.shape) + + np.random.normal(scale=0.5, size=img.shape) + np.random.normal(scale=np.sqrt(img) / 10) ) - target.variance = 0.1**2 + img / 100 + target.variance = 0.5**2 + img / 100 return target @@ -88,25 +88,23 @@ def make_basic_gaussian( ): np.random.seed(rand) - target = ap.image.Target_Image( + target = ap.TargetImage( data=np.zeros((N, M)), pixelscale=pixelscale, psf=ap.utils.initialize.gaussian_psf(2 / pixelscale, 11, pixelscale), ) - MODEL = ap.models.Gaussian_Galaxy( + MODEL = ap.models.GaussianGalaxy( name="basic gaussian source", target=target, - parameters={ - "center": [x, y], - "sigma": sigma, - "flux": flux, - "PA": {"value": 0.0, "locked": True}, - "q": {"value": 0.99, "locked": True}, - }, + center=[x, y], + sigma=sigma, + flux=flux, + PA=0.0, + q=0.99, ) - img = MODEL().data.detach().cpu().numpy() + img = ap.backend.to_numpy(MODEL().data) target.data = ( img + np.random.normal(scale=0.1, size=img.shape) @@ -120,17 +118,17 @@ def make_basic_gaussian( def make_basic_gaussian_psf( N=25, pixelscale=0.8, - sigma=3, + sigma=4, rand=12345, ): np.random.seed(rand) - psf = ap.utils.initialize.gaussian_psf(sigma / pixelscale, N, pixelscale) - psf += np.random.normal(scale=psf / 2) - psf[psf < 0] = 0 - target = ap.image.PSF_Image( - data=psf / np.sum(psf), + psf = ap.utils.initialize.gaussian_psf(sigma * pixelscale, N, pixelscale) + target = ap.PSFImage( + data=psf + np.random.normal(scale=np.sqrt(psf) / 20), pixelscale=pixelscale, + variance=psf / 400, ) + target.normalize() return target