Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def compose(self, idx: int, cmd: list[str]) -> list[str]:
cmd.append(f'window={self.adapt_metric_window}')
if self.adapt_step_size is not None:
cmd.append('term_buffer={}'.format(self.adapt_step_size))
if self.adapt_engaged:
cmd.append('save_metric=1')
# End adapt subsection

if self.num_chains > 1:
cmd.append('num_chains={}'.format(self.num_chains))

Expand Down
69 changes: 47 additions & 22 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
stancsv,
)

from .metadata import InferenceMetadata
from .metadata import InferenceMetadata, MetricInfo
from .runset import RunSet


Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(
# info from CSV values, instantiated lazily
self._draws: np.ndarray = np.array(())
# only valid when not is_fixed_param
self._metric_type: str | None = None
self._metric: np.ndarray = np.array(())
self._step_size: np.ndarray = np.array(())
self._divergences: np.ndarray = np.zeros(self.runset.chains, dtype=int)
Expand All @@ -92,6 +93,8 @@ def __init__(
# info from CSV header and initial and final comment blocks
config = self._validate_csv_files()
self._metadata: InferenceMetadata = InferenceMetadata(config)
self._chain_metric_info: list[MetricInfo] = []
self._metric_info_parsed: bool = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this flag? It seems like it's always equivalent to self._metric_type == None

if not self._is_fixed_param:
self._check_sampler_diagnostics()

Expand Down Expand Up @@ -216,11 +219,13 @@ def metric_type(self) -> str | None:
to CmdStan arg 'metric'.
When sampler algorithm 'fixed_param' is specified, metric_type is None.
"""
return (
self._metadata.cmdstan_config['metric']
if not self._is_fixed_param
else None
)
if self._is_fixed_param:
return None

if not self._metric_info_parsed:
self._parse_metric_info()

return self._metric_type

@property
def inv_metric(self) -> np.ndarray | None:
Expand All @@ -230,10 +235,15 @@ def inv_metric(self) -> np.ndarray | None:
a ``nchains x nparams x nparams`` array when metric_type is 'dense_e',
or ``None`` when metric_type is 'unit_e' or algorithm is 'fixed_param'.
"""
if self._is_fixed_param or self.metric_type == 'unit_e':
if self._is_fixed_param:
return None

if not self._metric_info_parsed:
self._parse_metric_info()

if self.metric_type == 'unit_e':
return None

self._assemble_draws()
return self._metric

@property
Expand All @@ -242,8 +252,13 @@ def step_size(self) -> np.ndarray | None:
Step size used by sampler for each chain.
When sampler algorithm 'fixed_param' is specified, step size is None.
"""
self._assemble_draws()
return self._step_size if not self._is_fixed_param else None
if self._is_fixed_param:
return None

if not self._metric_info_parsed:
self._parse_metric_info()

return self._step_size

@property
def thin(self) -> int:
Expand Down Expand Up @@ -382,6 +397,27 @@ def _validate_csv_files(self) -> dict[str, Any]:
self._max_treedepths[i] = drest['ct_max_treedepth']
return dzero

def _parse_metric_info(self) -> None:
"""Extracts metric type, inv_metric, and step size information from the
parsed metric JSONs."""
self._chain_metric_info = [
MetricInfo.from_json(mf, chain_id)
for mf, chain_id in zip(
self.runset.metric_files, self.runset.chain_ids
)
]
metric_types = {cmi.metric_type for cmi in self._chain_metric_info}
if len(metric_types) != 1:
raise ValueError("Inconsistent metric types found across chains")
self._metric_type = self._chain_metric_info[0].metric_type
self._metric = np.asarray(
[cmi.inv_metric for cmi in self._chain_metric_info]
)
self._step_size = np.asarray(
[cmi.stepsize for cmi in self._chain_metric_info]
)
self._metric_info_parsed = True

def _check_sampler_diagnostics(self) -> None:
"""
Warn if any iterations ended in divergences or hit maxtreedepth.
Expand Down Expand Up @@ -424,13 +460,11 @@ def _assemble_draws(self) -> None:
dtype=np.float64,
order='F',
)
self._step_size = np.empty(self.chains, dtype=np.float64)

mass_matrix_per_chain = []
for chain in range(self.chains):
try:
(
comments,
_,
header,
draws,
) = stancsv.parse_comments_header_and_draws(
Expand All @@ -443,20 +477,11 @@ def _assemble_draws(self) -> None:
draws_np = np.empty((0, n_cols))

self._draws[:, chain, :] = draws_np
if not self._is_fixed_param:
(
self._step_size[chain],
mass_matrix,
) = stancsv.parse_hmc_adaptation_lines(comments)
mass_matrix_per_chain.append(mass_matrix)
except Exception as exc:
raise ValueError(
f"Parsing output from {self.runset.csv_files[chain]} failed"
) from exc

if all(mm is not None for mm in mass_matrix_per_chain):
self._metric = np.array(mass_matrix_per_chain)

assert self._draws is not None

def summary(
Expand Down
58 changes: 57 additions & 1 deletion cmdstanpy/stanfit/metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""Container for metadata parsed from the output of a CmdStan run"""

from __future__ import annotations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary for pydantic? I know there was a lot of discussion about annotation handling and pydantic, but I didn't keep up with it


import copy
import json
import math
import os
from typing import Any, Iterator
from typing import Any, Iterator, Literal

import numpy as np
import stanio
from pydantic import BaseModel, Field, field_validator, model_validator

from cmdstanpy.utils import stancsv

Expand Down Expand Up @@ -79,3 +85,53 @@ def stan_vars(self) -> dict[str, stanio.Variable]:
These are the user-defined variables in the Stan program.
"""
return self._stan_vars


class MetricInfo(BaseModel):
"""Structured representation of HMC-NUTS metric information,
as output by CmdStan"""

chain_id: int = Field(gt=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe chain id can be equal to 0, just not less than 0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also -- is there a reason we need the chain ID in here? Since it isn't in the file, it necessitates our own from_json rather than just using model_validate_json

stepsize: float
metric_type: Literal["diag_e", "dense_e", "unit_e"]
inv_metric: np.ndarray

model_config = {"arbitrary_types_allowed": True}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds like a scary config option -- what does it do for us here?


@field_validator("inv_metric", mode="before")
@classmethod
def convert_inv_metric(cls, v: Any) -> np.ndarray:
    """Coerce the incoming ``inv_metric`` value to a NumPy array."""
    as_array = np.asarray(v)
    return as_array

@field_validator("stepsize")
@classmethod
def validate_stepsize(cls, v: float) -> float:
    """Accept a step size that is NaN or strictly positive.

    :raises ValueError: if ``v`` is a number less than or equal to zero.
    """
    if math.isnan(v) or v > 0:
        return v
    raise ValueError("stepsize must be greater than 0 or NaN")

@model_validator(mode="after")
def validate_inv_metric_shape(self) -> MetricInfo:
    """Check that ``inv_metric`` has the shape implied by ``metric_type``.

    ``dense_e`` requires a square 2D array; ``diag_e`` and ``unit_e``
    require a 1D array. Returns ``self`` unchanged when the shape is valid.

    :raises ValueError: if the array's dimensionality or shape does not
        match the metric type.
    """
    kind = self.metric_type
    matrix = self.inv_metric

    if kind == "dense_e":
        if matrix.ndim != 2:
            raise ValueError("Dense inv_metric must be 2D")
        n_rows, n_cols = matrix.shape
        if n_rows != n_cols:
            raise ValueError("Dense inv_metric must be square")
    elif kind in ("diag_e", "unit_e") and matrix.ndim != 1:
        raise ValueError(
            "inv_metric must be 1D for diag_e and unit_e metric type"
        )

    return self

@classmethod
def from_json(cls, file: str | os.PathLike, chain_id: int) -> MetricInfo:
    """Parse and validate a metric JSON file for a single chain.

    :param file: Path to the metric JSON file.
    :param chain_id: Chain identifier supplied by the caller; it is added
        to the parsed data under the ``chain_id`` key before validation.
    :return: The validated instance produced by ``cls.model_validate``.
    """
    # Read explicitly as UTF-8 so parsing does not depend on the platform's
    # locale default encoding.
    with open(file, encoding="utf-8") as f:
        info_dict = json.load(f)

    info_dict['chain_id'] = chain_id
    return cls.model_validate(info_dict)  # type: ignore
24 changes: 24 additions & 0 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
self._stdout_files, self._profile_files = [], []
self._csv_files, self._diagnostic_files = [], []
self._config_files = []
self._metric_files = []

# per-process output files
if one_process_per_chain and chains > 1:
Expand Down Expand Up @@ -87,6 +88,10 @@ def __init__(
# per-chain output files
if chains == 1:
self._csv_files = [self.gen_file_name(".csv")]
if args.method == Method.SAMPLE:
self._metric_files = [
self.gen_file_name(".json", extra="metric")
]
if args.save_latent_dynamics:
self._diagnostic_files = [
self.gen_file_name(".csv", extra="diagnostic")
Expand All @@ -95,6 +100,20 @@ def __init__(
self._csv_files = [
self.gen_file_name(".csv", id=id) for id in self._chain_ids
]
if args.method == Method.SAMPLE:
if one_process_per_chain:
self._metric_files = [
os.path.join(
self._outdir,
f"{self._base_outfile}_{id}_metric.json",
)
for id in self._chain_ids
]
else:
self._metric_files = [
self.gen_file_name(".json", extra="metric", id=id)
for id in self._chain_ids
]
if args.save_latent_dynamics:
self._diagnostic_files = [
self.gen_file_name(".csv", extra="diagnostic", id=id)
Expand Down Expand Up @@ -222,6 +241,11 @@ def profile_files(self) -> list[str]:
"""List of paths to CmdStan profiler files."""
return self._profile_files

@property
def metric_files(self) -> list[str]:
    """Return the list of paths to CmdStan NUTS-HMC sampler metric files."""
    paths = self._metric_files
    return paths

def gen_file_name(
self, suffix: str, *, extra: str = "", id: int | None = None
) -> str:
Expand Down
100 changes: 0 additions & 100 deletions cmdstanpy/utils/stancsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,44 +103,6 @@ def csv_bytes_list_to_numpy(
return out


def parse_hmc_adaptation_lines(
    comment_lines: list[bytes],
) -> tuple[float | None, npt.NDArray[np.float64] | None]:
    """Pull the step size and mass matrix out of Stan CSV comment lines.

    The adaptation section of the comments is scanned. With the diag_e
    metric the mass matrix is returned as a 1D array of the diagonal
    elements; with dense_e it is a 2D array holding the entire matrix; with
    unit_e there is no matrix and None is returned in its place.

    Returns a (step_size, mass_matrix) tuple"""
    step_size = None
    saw_diag_e = False
    collecting_matrix = False
    matrix_rows = []

    for raw in comment_lines:
        text = raw.lstrip(b"# ")

        # Inside the matrix block every non-blank line is a matrix row,
        # until the timing block starts.
        if collecting_matrix and text.strip():
            if text.startswith(b"Elapsed Time"):
                break
            matrix_rows.append(text)
            continue

        if text.startswith(b"Step size"):
            _, raw_value = text.split(b" = ")
            step_size = float(raw_value)
        elif text.startswith((b"Diagonal", b"Elements")):
            collecting_matrix = True
        elif text.startswith(b"No free"):
            break
        elif b"diag_e" in text:
            saw_diag_e = True

    if not matrix_rows:
        return step_size, None

    mass_matrix = csv_bytes_list_to_numpy(matrix_rows)
    # A diag_e matrix is written as a single row; flatten it to 1D.
    if saw_diag_e and mass_matrix.shape[0] == 1:
        mass_matrix = mass_matrix[0]
    return step_size, mass_matrix


def extract_key_val_pairs(
comment_lines: list[bytes], remove_default_text: bool = True
) -> Iterator[tuple[str, str]]:
Expand Down Expand Up @@ -346,67 +308,6 @@ def column_count(ln: bytes) -> int:
)


def raise_on_invalid_adaptation_block(comment_lines: list[bytes]) -> None:
    """Throws ValueErrors if the parsed adaptation block is invalid, e.g.
    the metric information is not present, consistent with the rest of
    the file, or the step size info cannot be processed.

    Input that ends in the middle of the adaptation block is also reported
    as a ValueError rather than leaking StopIteration from the underlying
    iterator.

    :param comment_lines: Comment lines of a Stan CSV file, in file order.
    :raises ValueError: if the adaptation block is missing, truncated, or
        does not match the reported metric.
    """

    def column_count(ln: bytes) -> int:
        return ln.count(b",") + 1

    ln_iter = enumerate(comment_lines, start=2)

    def next_line(expected: str) -> tuple[int, bytes]:
        # Translate iterator exhaustion into the documented error type.
        try:
            return next(ln_iter)
        except StopIteration:
            raise ValueError(
                f"unexpected end of comments, expecting {expected}"
            ) from None

    metric = None
    for _, line in ln_iter:
        if b"metric =" in line:
            _, val = line.split(b" = ")
            metric = val.replace(b"(Default)", b"").strip().decode()
        if b"Adaptation terminated" in line:
            break
    else:  # No adaptation block found
        raise ValueError("No adaptation block found, expecting metric")

    if metric is None:
        raise ValueError("No reported metric found")
    # At this point iterator should be in the adaptation block

    # Ensure step size exists and is valid float
    num, line = next_line("step size")
    if not line.startswith(b"# Step size"):
        raise ValueError(
            f"line {num}: expecting step size, found:\n\t \"{line.decode()}\""
        )
    _, step_size = line.split(b" = ")
    try:
        float(step_size.strip())
    except ValueError as exc:
        raise ValueError(
            f"line {num}: invalid step size: {step_size.decode()}"
        ) from exc

    # unit_e has no mass matrix, so nothing further needs to be read
    if metric == "unit_e":
        return

    # Ensure mass matrix valid
    num, line = next_line("mass matrix specification")
    if not (
        (metric == "diag_e" and line.startswith(b"# Diagonal elements of "))
        or (metric == "dense_e" and line.startswith(b"# Elements of inverse"))
    ):
        raise ValueError(
            f"line {num}: invalid or missing mass matrix specification"
        )

    # Validating mass matrix shape
    _, line = next_line("mass matrix values")
    num_unconstrained_params = column_count(line)
    if metric == "diag_e":
        return
    # The first row is already consumed; a dense matrix needs n - 1 more
    # rows, each with the same number of columns as the first.
    for _ in range(1, num_unconstrained_params):
        num, line = next_line("mass matrix row")
        if column_count(line) != num_unconstrained_params:
            raise ValueError(
                f"line {num}: invalid or missing mass matrix specification"
            )


def parse_timing_lines(
comment_lines: list[bytes],
) -> dict[str, float]:
Expand Down Expand Up @@ -489,7 +390,6 @@ def parse_sampler_metadata_from_csv(
and header
and not is_sneaky_fixed_param(header)
):
raise_on_invalid_adaptation_block(comments)
max_depth: int = config["max_depth"] # type: ignore
max_tree_hits, divs = extract_max_treedepth_and_divergence_counts(
header, draws, max_depth, num_warmup
Expand Down
Loading