diff --git a/.travis.yml b/.travis.yml
index 59178a1..76898ec 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -46,7 +46,6 @@ env:
# Try all python versions with the latest numpy
- SETUP_CMD='test'
-
matrix:
include:
@@ -73,7 +72,6 @@ matrix:
- python: 3.5
env: NUMPY_VERSION=1.10
-
install:
# We now use the ci-helpers package to set up our testing environment.
diff --git a/docs/gen_plots.py b/docs/gen_plots.py
index fd81c71..b4de26d 100644
--- a/docs/gen_plots.py
+++ b/docs/gen_plots.py
@@ -1,6 +1,5 @@
from astropy.modeling.fitting import SherpaFitter
from astropy.modeling.models import Gaussian1D, Gaussian2D
-
import numpy as np
import matplotlib.pyplot as plt
diff --git a/docs/index.rst b/docs/index.rst
index 937ac8e..b9111a3 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -45,6 +45,7 @@ To make use of the entry points plugin registry which automatically makes the |S
Otherwise one can just use the latest stable ``astropy`` via::
conda install astropy
+
Next install Sherpa_ using the conda ``sherpa`` channel. Note that Sherpa
currently needs to be installed after astropy on Mac OSX.
@@ -236,4 +237,8 @@ API/Reference
Credit
------
-The development of this package was made possible by the generous support of the `Google Summer of Code `_ program in 2016 under the `OpenAstronomy `_ by `Michele Costa `_ with the support and advice of mentors `Tom Aldcroft `_, `Omar Laurino `_, `Moritz Guenther `_, and `Doug Burke `_.
+The development of this package was made possible by the generous support of the `Google Summer of Code `_ program in 2016
+under the `OpenAstronomy `_
+by `Michele Costa `_ with the support and advice of mentors
+`Tom Aldcroft `_, `Omar Laurino `_,
+`Moritz Guenther `_, and `Doug Burke `_.
diff --git a/saba/main.py b/saba/main.py
index afe9d83..e3daa00 100644
--- a/saba/main.py
+++ b/saba/main.py
@@ -1,18 +1,21 @@
from __future__ import (absolute_import, unicode_literals, division,
print_function)
-import numpy as np
from collections import OrderedDict
+import numpy as np
+import warnings
+import copy
+
from sherpa.fit import Fit
from sherpa.data import Data1D, Data1DInt, Data2D, Data2DInt, DataSimulFit
from sherpa.data import BaseData
from sherpa.models import UserModel, Parameter, SimulFitModel
+from sherpa.instrument import PSFModel
from sherpa.stats import Chi2, Chi2ConstVar, Chi2DataVar, Chi2Gehrels
from sherpa.stats import Chi2ModVar, Chi2XspecVar, LeastSq
from sherpa.stats import CStat, WStat, Cash
from sherpa.optmethods import GridSearch, LevMar, MonCar, NelderMead
from sherpa.estmethods import Confidence, Covariance, Projection
from sherpa.sim import MCMC
-import warnings
from astropy.utils import format_doc
from astropy.utils.exceptions import AstropyUserWarning
@@ -27,8 +30,6 @@
if "SherpaFitter" not in w.message.message:
warnings.warn(w)
-# from astropy.modeling
-
__all__ = ('SherpaFitter', 'SherpaMCMC')
@@ -268,6 +269,50 @@ def stat_values(self):
return self._stat_values
+def make_rsp(data,rsp):
+ """
+ Take an array as a response which is then convolved with the model output.
+ Parameters
+ ----------
+ data: a sherpa dataset
+ rsp : an array which represets rsp
+ """
+ def wrap_rsp(data, rsp):
+ rsp = np.asarray(rsp)
+ rdata = copy.deepcopy(data)
+ rdata.y = rsp
+ psf = PSFModel("user_rsp", rdata)
+ psf.fold(data)
+ return psf
+
+ try:
+ ndims = len(data.data.datasets[0].get_dims())
+ except AttributeError:
+ ndims = len(data.data.get_dims())
+
+ if ndims > 1:
+ return None
+ else:
+ rsp = np.asarray(rsp)
+
+ if data.ndata > 1:
+ if rsp.ndim > 1 or rsp.dtype == np.object:
+ if rsp.shape[0] == data.ndata:
+ zipped = zip(data.data.datasets, rsp)
+ else:
+ raise AstropyUserWarning("There is more than 1 but not"
+ " ndata responses")
+ else:
+ zipped = zip(data.data.datasets,
+ [rsp for _ in xrange(data.ndata)])
+
+ rsp = []
+ for da, rr in zipped:
+ rsp.append(wrap_rsp(da, rr))
+ else:
+ return wrap_rsp(data.data, rsp)
+
+
class SherpaFitter(Fitter):
__doc__ = """
Sherpa Fitter for astropy models.
@@ -321,7 +366,7 @@ def __init__(self, optimizer="levmar", statistic="leastsq", estmethod="covarianc
setattr(self.__class__, 'est_config', property(lambda s: s._est_config, doc=self._est_method.__doc__))
- def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, **kwargs):
+ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, rsp=None, **kwargs):
"""
Fit the astropy model with a the sherpa fit routines.
@@ -347,6 +392,9 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None,
bkg_sale : float or list of floats (optional)
the scaling factor for the dataset if a single value
is supplied it will be copied for each dataset
+ rsp: array or list of arrays
+ this is convolved with the model output when fitting the model
+ N.B only 1D is currently supported.
**kwargs :
keyword arguments will be passed on to sherpa fit routine
@@ -364,10 +412,15 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None,
self._data = Dataset(n_inputs, x, y, z, xbinsize, ybinsize, err, bkg, bkg_scale)
+ if rsp is not None:
+ self._rsp = make_rsp(self._data, rsp)
+ else:
+ self._rsp = None
+
if self._data.ndata > 1:
if len(models) == 1:
- self._fitmodel = ConvertedModel([models.copy() for _ in xrange(self._data.ndata)], tie_list)
+ self._fitmodel = ConvertedModel([models.copy() for _ in xrange(self._data.ndata)], tie_list, rsp=self._rsp)
# Copy the model so each data set has the same model!
elif len(models) == self._data.ndata:
self._fitmodel = ConvertedModel(models, tie_list)
@@ -377,9 +430,10 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None,
else:
if len(models) > 1:
self._data.make_simfit(len(models))
- self._fitmodel = ConvertedModel(models, tie_list)
+ self._fitmodel = ConvertedModel(models, tie_list,
+ rsp=self._rsp)
else:
- self._fitmodel = ConvertedModel(models)
+ self._fitmodel = ConvertedModel(models, rsp=self._rsp)
self._fitter = Fit(self._data.data, self._fitmodel.sherpa_model, self._stat_method, self._opt_method, self._est_method, **kwargs)
self.fit_info = self._fitter.fit()
@@ -633,16 +687,31 @@ class ConvertedModel(object):
e.g. [(modelB.y, modelA.x)] will mean that y in modelB will be tied to x of modelA
"""
- def __init__(self, models, tie_list=None):
+ def __init__(self, models, tie_list=None, rsp=None):
self.model_dict = OrderedDict()
try:
models.parameters # does it quack
self.sherpa_model = self._astropy_to_sherpa_model(models)
+ self.rsp = rsp
+ if rsp is not None:
+ self.sherpa_model = rsp(self.sherpa_model)
+
self.model_dict[models] = self.sherpa_model
except AttributeError:
- for mod in models:
+ try:
+ n_rsp = len(rsp)
+ assert len(models) == n_rsp, AstropyUserWarning("The number of responses must be either 1 or the numeber of models %i" % len(models))
+ zipped = zip(models, rsp)
+
+ except TypeError:
+ zipped = zip(models, [rsp for _ in range(len(models))])
+
+ for mod, rsp in zipped:
self.model_dict[mod] = self._astropy_to_sherpa_model(mod)
+ if rsp is not None:
+ self.sherpa_model[mod] = rsp(self.sherpa_model[mod])
+
if tie_list is not None:
for par1, par2 in tie_list:
getattr(self.model_dict[par1._model], par1.name).link = getattr(self.model_dict[par2._model], par2.name)
diff --git a/saba/tests/coveragerc b/saba/tests/coveragerc
index 3a21984..3e5082c 100644
--- a/saba/tests/coveragerc
+++ b/saba/tests/coveragerc
@@ -18,7 +18,7 @@ exclude_lines =
pragma: no cover
# Don't complain about packages we have installed
- # except ImportError
+ except ImportError
# Don't complain if tests don't hit assertions
raise AssertionError
diff --git a/saba/tests/test_main.py b/saba/tests/test_main.py
index 67489c9..db5cc63 100644
--- a/saba/tests/test_main.py
+++ b/saba/tests/test_main.py
@@ -289,6 +289,21 @@ def test_bkg_doesnt_explode(self):
sfit(m, x, y, bkg=bkg)
# TODO: Make this better!
+
+ def test_rsp1d_doesnt_explode(self):
+ """
+ Check this goes through the motions
+ """
+
+ self.fitter(self.model1d.copy(), self.x1, self.y1, err=self.dy1, rsp=self.rsp1)
+
+ def test_rsp1d_multi_doesnt_explode(self):
+ """
+ Check this goes through the motions
+ """
+
+ self.fitter([self.model1d.copy(), self.model1d_2.copy()], [self.x1, self.x2], [self.y1, self.y2], err=[self.dy1, self.dy2], rsp=[self.rsp1, self.rsp2])
+
def test_entry_points(self):
# a little to test that entry points can be loaded!
from pkg_resources import iter_entry_points