Skip to content

Commit

Permalink
allow reduced form prior spec
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Feb 14, 2022
1 parent 5dbfc79 commit 5ee2ed7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Installing the repo version is as simple as
Documentation
-------------

There is some `preliminary documentation <https://pydsge.readthedocs.io/en/latest/index.html>`_ out there.
There is some `documentation <https://pydsge.readthedocs.io/en/latest/index.html>`_ out there.

- `Installation Guide <https://pydsge.readthedocs.io/en/latest/installation_guide.html>`_
- `Getting Started <https://pydsge.readthedocs.io/en/latest/getting_started.html>`_
Expand Down
33 changes: 21 additions & 12 deletions pydsge/mpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def posterior_sampler(self, nsamples, seed=0, verbose=True):
import random

random.seed(seed)
sample = self.get_chain()[-self.get_tune :]
sample = self.get_chain()[-self.get_tune:]
sample = sample.reshape(-1, sample.shape[(-1)])
sample = random.choices(sample, k=nsamples)
return sample
Expand Down Expand Up @@ -118,7 +118,8 @@ def runner(locseed):
]

if test_lprob:
draw_prob = pelf.lprob(pdraw, linear=None, verbose=verbose > 1)
draw_prob = pelf.lprob(
pdraw, linear=None, verbose=verbose > 1)
done = not np.isinf(draw_prob)
else:
pelf.set_par(pdraw)
Expand Down Expand Up @@ -205,7 +206,8 @@ def get_par(
gen_sys(self, verbose=verbose, **args)
pfnames, pffunc = self.parafunc
pars_str = [str(p) for p in self.parameters]
pars = np.array(self.par) if hasattr(self, "par") else np.array(self.par_fix)
pars = np.array(self.par) if hasattr(
self, "par") else np.array(self.par_fix)
if npar is not None:
if len(npar) != len(self.par_fix):
pars[self.prior_arg] = npar
Expand Down Expand Up @@ -273,9 +275,10 @@ def get_par(
elif dummy == "prior_mean":
par_cand = []
for pp in self.prior.keys():
if self.prior[pp][3] == "uniform":
if "uniform" in self.prior[pp]:
par_cand.append(
0.5 * self.prior[pp][(-2)] + 0.5 * self.prior[pp][(-1)]
0.5 * self.prior[pp][(-2)] +
0.5 * self.prior[pp][(-1)]
)
else:
par_cand.append(self.prior[pp][(-2)])
Expand All @@ -286,9 +289,10 @@ def get_par(
if self.prior[pp][3] == "inv_gamma_dynare":
par_cand.append(self.prior[pp][(-2)] * 10)
else:
if self.prior[pp][3] == "uniform":
if "uniform" in self.prior[pp]:
par_cand.append(
0.5 * self.prior[pp][(-2)] + 0.5 * self.prior[pp][(-1)]
0.5 * self.prior[pp][(-2)] +
0.5 * self.prior[pp][(-1)]
)
else:
par_cand.append(self.prior[pp][(-2)])
Expand Down Expand Up @@ -320,11 +324,13 @@ def get_par(
return (pdict, pfdict)
if asdict:
return dict(
zip(np.array(pars_str)[self.prior_arg], np.round(par_cand, roundto))
zip(np.array(pars_str)[self.prior_arg],
np.round(par_cand, roundto))
)
if nsamples > 1:
if dummy not in ("prior", "post", "posterior"):
par_cand = par_cand * (1 + 0.001 * np.random.randn(nsamples, len(par_cand)))
par_cand = par_cand * \
(1 + 0.001 * np.random.randn(nsamples, len(par_cand)))
return par_cand


Expand Down Expand Up @@ -366,7 +372,8 @@ def set_par(

pfnames, pffunc = self.parafunc
pars_str = [str(p) for p in self.parameters]
par = np.array(self.par) if hasattr(self, "par") else np.array(self.par_fix)
par = np.array(self.par) if hasattr(
self, "par") else np.array(self.par_fix)

if setpar is None:
if dummy is None:
Expand Down Expand Up @@ -397,7 +404,8 @@ def set_par(
"Can not set parameter '%s' that is a function of other parameters." % dummy
)
else:
raise SyntaxError("Parameter '%s' is not defined for this model." % dummy)
raise SyntaxError(
"Parameter '%s' is not defined for this model." % dummy)

gen_sys(self, par=list(par), verbose=verbose, **args)

Expand All @@ -408,7 +416,8 @@ def set_par(
if verbose > 1:
pdict = dict(zip(pars_str, np.round(self.par, roundto)))
pfdict = dict(zip(pfnames, np.round(pffunc(self.par), roundto)))
print("[set_par:]".ljust(15, " ") + " Parameter(s):\n%s\n%s" % (pdict, pfdict))
print("[set_par:]".ljust(15, " ") +
" Parameter(s):\n%s\n%s" % (pdict, pfdict))

if return_vv:
return get_par(self), self.vv
Expand Down

0 comments on commit 5ee2ed7

Please sign in to comment.