diff --git a/.travis.yml b/.travis.yml index 59178a1..76898ec 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,7 +46,6 @@ env: # Try all python versions with the latest numpy - SETUP_CMD='test' - matrix: include: @@ -73,7 +72,6 @@ matrix: - python: 3.5 env: NUMPY_VERSION=1.10 - install: # We now use the ci-helpers package to set up our testing environment. diff --git a/docs/gen_plots.py b/docs/gen_plots.py index fd81c71..b4de26d 100644 --- a/docs/gen_plots.py +++ b/docs/gen_plots.py @@ -1,6 +1,5 @@ from astropy.modeling.fitting import SherpaFitter from astropy.modeling.models import Gaussian1D, Gaussian2D - import numpy as np import matplotlib.pyplot as plt diff --git a/docs/index.rst b/docs/index.rst index 937ac8e..b9111a3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,6 +45,7 @@ To make use of the entry points plugin registry which automatically makes the |S Otherwise one can just use the latest stable ``astropy`` via:: conda install astropy + Next install Sherpa_ using the conda ``sherpa`` channel. Note that Sherpa currently needs to be installed after astropy on Mac OSX. @@ -236,4 +237,8 @@ API/Reference Credit ------ -The development of this package was made possible by the generous support of the `Google Summer of Code `_ program in 2016 under the `OpenAstronomy `_ by `Michele Costa `_ with the support and advice of mentors `Tom Aldcroft `_, `Omar Laurino `_, `Moritz Guenther `_, and `Doug Burke `_. +The development of this package was made possible by the generous support of the `Google Summer of Code `_ program in 2016 +under the `OpenAstronomy `_ +by `Michele Costa `_ with the support and advice of mentors +`Tom Aldcroft `_, `Omar Laurino `_, +`Moritz Guenther `_, and `Doug Burke `_. diff --git a/saba/main.py b/saba/main.py index afe9d83..e3daa00 100644 --- a/saba/main.py +++ b/saba/main.py @@ -1,18 +1,21 @@ from __future__ import (absolute_import, unicode_literals, division, print_function) -import numpy as np from collections import OrderedDict +import numpy as np +import warnings +import copy + from sherpa.fit import Fit from sherpa.data import Data1D, Data1DInt, Data2D, Data2DInt, DataSimulFit from sherpa.data import BaseData from sherpa.models import UserModel, Parameter, SimulFitModel +from sherpa.instrument import PSFModel from sherpa.stats import Chi2, Chi2ConstVar, Chi2DataVar, Chi2Gehrels from sherpa.stats import Chi2ModVar, Chi2XspecVar, LeastSq from sherpa.stats import CStat, WStat, Cash from sherpa.optmethods import GridSearch, LevMar, MonCar, NelderMead from sherpa.estmethods import Confidence, Covariance, Projection from sherpa.sim import MCMC -import warnings from astropy.utils import format_doc from astropy.utils.exceptions import AstropyUserWarning @@ -27,8 +30,6 @@ if "SherpaFitter" not in w.message.message: warnings.warn(w) -# from astropy.modeling - __all__ = ('SherpaFitter', 'SherpaMCMC') @@ -268,6 +269,50 @@ def stat_values(self): return self._stat_values +def make_rsp(data,rsp): + """ + Take an array as a response which is then convolved with the model output. + Parameters + ---------- + data: a sherpa dataset + rsp : an array which represets rsp + """ + def wrap_rsp(data, rsp): + rsp = np.asarray(rsp) + rdata = copy.deepcopy(data) + rdata.y = rsp + psf = PSFModel("user_rsp", rdata) + psf.fold(data) + return psf + + try: + ndims = len(data.data.datasets[0].get_dims()) + except AttributeError: + ndims = len(data.data.get_dims()) + + if ndims > 1: + return None + else: + rsp = np.asarray(rsp) + + if data.ndata > 1: + if rsp.ndim > 1 or rsp.dtype == np.object: + if rsp.shape[0] == data.ndata: + zipped = zip(data.data.datasets, rsp) + else: + raise AstropyUserWarning("There is more than 1 but not" + " ndata responses") + else: + zipped = zip(data.data.datasets, + [rsp for _ in xrange(data.ndata)]) + + rsp = [] + for da, rr in zipped: + rsp.append(wrap_rsp(da, rr)) + else: + return wrap_rsp(data.data, rsp) + + class SherpaFitter(Fitter): __doc__ = """ Sherpa Fitter for astropy models. @@ -321,7 +366,7 @@ def __init__(self, optimizer="levmar", statistic="leastsq", estmethod="covarianc setattr(self.__class__, 'est_config', property(lambda s: s._est_config, doc=self._est_method.__doc__)) - def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, **kwargs): + def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, rsp=None, **kwargs): """ Fit the astropy model with a the sherpa fit routines. @@ -347,6 +392,9 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg_sale : float or list of floats (optional) the scaling factor for the dataset if a single value is supplied it will be copied for each dataset + rsp: array or list of arrays + this is convolved with the model output when fitting the model + N.B only 1D is currently supported. **kwargs : keyword arguments will be passed on to sherpa fit routine @@ -364,10 +412,15 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, self._data = Dataset(n_inputs, x, y, z, xbinsize, ybinsize, err, bkg, bkg_scale) + if rsp is not None: + self._rsp = make_rsp(self._data, rsp) + else: + self._rsp = None + if self._data.ndata > 1: if len(models) == 1: - self._fitmodel = ConvertedModel([models.copy() for _ in xrange(self._data.ndata)], tie_list) + self._fitmodel = ConvertedModel([models.copy() for _ in xrange(self._data.ndata)], tie_list, rsp=self._rsp) # Copy the model so each data set has the same model! elif len(models) == self._data.ndata: self._fitmodel = ConvertedModel(models, tie_list) @@ -377,9 +430,10 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, else: if len(models) > 1: self._data.make_simfit(len(models)) - self._fitmodel = ConvertedModel(models, tie_list) + self._fitmodel = ConvertedModel(models, tie_list, + rsp=self._rsp) else: - self._fitmodel = ConvertedModel(models) + self._fitmodel = ConvertedModel(models, rsp=self._rsp) self._fitter = Fit(self._data.data, self._fitmodel.sherpa_model, self._stat_method, self._opt_method, self._est_method, **kwargs) self.fit_info = self._fitter.fit() @@ -633,16 +687,31 @@ class ConvertedModel(object): e.g. [(modelB.y, modelA.x)] will mean that y in modelB will be tied to x of modelA """ - def __init__(self, models, tie_list=None): + def __init__(self, models, tie_list=None, rsp=None): self.model_dict = OrderedDict() try: models.parameters # does it quack self.sherpa_model = self._astropy_to_sherpa_model(models) + self.rsp = rsp + if rsp is not None: + self.sherpa_model = rsp(self.sherpa_model) + self.model_dict[models] = self.sherpa_model except AttributeError: - for mod in models: + try: + n_rsp = len(rsp) + assert len(models) == n_rsp, AstropyUserWarning("The number of responses must be either 1 or the numeber of models %i" % len(models)) + zipped = zip(models, rsp) + + except TypeError: + zipped = zip(models, [rsp for _ in range(len(models))]) + + for mod, rsp in zipped: self.model_dict[mod] = self._astropy_to_sherpa_model(mod) + if rsp is not None: + self.sherpa_model[mod] = rsp(self.sherpa_model[mod]) + if tie_list is not None: for par1, par2 in tie_list: getattr(self.model_dict[par1._model], par1.name).link = getattr(self.model_dict[par2._model], par2.name) diff --git a/saba/tests/coveragerc b/saba/tests/coveragerc index 3a21984..3e5082c 100644 --- a/saba/tests/coveragerc +++ b/saba/tests/coveragerc @@ -18,7 +18,7 @@ exclude_lines = pragma: no cover # Don't complain about packages we have installed - # except ImportError + except ImportError # Don't complain if tests don't hit assertions raise AssertionError diff --git a/saba/tests/test_main.py b/saba/tests/test_main.py index 67489c9..db5cc63 100644 --- a/saba/tests/test_main.py +++ b/saba/tests/test_main.py @@ -289,6 +289,21 @@ def test_bkg_doesnt_explode(self): sfit(m, x, y, bkg=bkg) # TODO: Make this better! + + def test_rsp1d_doesnt_explode(self): + """ + Check this goes through the motions + """ + + self.fitter(self.model1d.copy(), self.x1, self.y1, err=self.dy1, rsp=self.rsp1) + + def test_rsp1d_multi_doesnt_explode(self): + """ + Check this goes through the motions + """ + + self.fitter([self.model1d.copy(), self.model1d_2.copy()], [self.x1, self.x2], [self.y1, self.y2], err=[self.dy1, self.dy2], rsp=[self.rsp1, self.rsp2]) + def test_entry_points(self): # a little to test that entry points can be loaded! from pkg_resources import iter_entry_points