-
-
Notifications
You must be signed in to change notification settings - Fork 79
Enable save_metric=1 and sources MCMC metric info from new JSON file
#844
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
cdb9d9d
62fefae
c7fef6b
838b78c
f7a9ae9
c47791b
4823796
d23a9d6
66512d5
4a40ef7
cb493b5
84ae036
063dfb9
2d076af
b138308
29ddbfd
4502aef
d80461d
0f5dab8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,16 @@ | ||
| """Container for metadata parsed from the output of a CmdStan run""" | ||
|
|
||
| from __future__ import annotations | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this necessary for pydantic? I know there was a lot of discussion about annotation handling and pydantic, but I didn't keep up with it |
||
|
|
||
| import copy | ||
| import json | ||
| import math | ||
| import os | ||
| from typing import Any, Iterator | ||
| from typing import Any, Iterator, Literal | ||
|
|
||
| import numpy as np | ||
| import stanio | ||
| from pydantic import BaseModel, Field, field_validator, model_validator | ||
|
|
||
| from cmdstanpy.utils import stancsv | ||
|
|
||
|
|
@@ -79,3 +85,53 @@ def stan_vars(self) -> dict[str, stanio.Variable]: | |
| These are the user-defined variables in the Stan program. | ||
| """ | ||
| return self._stan_vars | ||
|
|
||
|
|
||
| class MetricInfo(BaseModel): | ||
| """Structured representation of HMC-NUTS metric information, | ||
| as output by CmdStan""" | ||
|
|
||
| chain_id: int = Field(gt=0) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe chain id can be equal to 0, just not greater than
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also -- is there a reason we need the chain ID in here? Since it isn't in the file is necessitates our own from_json rather than just using |
||
| stepsize: float | ||
| metric_type: Literal["diag_e", "dense_e", "unit_e"] | ||
| inv_metric: np.ndarray | ||
|
|
||
| model_config = {"arbitrary_types_allowed": True} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sounds like a scary config option -- what does it do for us here? |
||
|
|
||
| @field_validator("inv_metric", mode="before") | ||
| @classmethod | ||
| def convert_inv_metric(cls, v: Any) -> np.ndarray: | ||
| return np.asarray(v) | ||
|
|
||
| @field_validator("stepsize") | ||
| @classmethod | ||
| def validate_stepsize(cls, v: float) -> float: | ||
| if not math.isnan(v) and v <= 0: | ||
| raise ValueError("stepsize must be greater than 0 or NaN") | ||
| return v | ||
|
|
||
| @model_validator(mode="after") | ||
| def validate_inv_metric_shape(self) -> MetricInfo: | ||
| if ( | ||
| self.metric_type in ("diag_e", "unit_e") | ||
| and self.inv_metric.ndim != 1 | ||
| ): | ||
| raise ValueError( | ||
| "inv_metric must be 1D for diag_e and unit_e metric type" | ||
| ) | ||
| if self.metric_type == "dense_e": | ||
| if self.inv_metric.ndim != 2: | ||
| raise ValueError("Dense inv_metric must be 2D") | ||
| if self.inv_metric.shape[0] != self.inv_metric.shape[1]: | ||
| raise ValueError("Dense inv_metric must be square") | ||
|
|
||
| return self | ||
|
|
||
| @classmethod | ||
| def from_json(cls, file: str | os.PathLike, chain_id: int) -> MetricInfo: | ||
| """Parse and validate a metric json given a file path and chain_id""" | ||
| with open(file) as f: | ||
| info_dict = json.load(f) | ||
|
|
||
| info_dict['chain_id'] = chain_id | ||
| return cls.model_validate(info_dict) # type: ignore | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need this flag? It seems like it's always equivalent to
self._metric_type == None