diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index f35ab4c523..825deab3b1 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -37,7 +37,9 @@ import abc from collections.abc import Sequence +from enum import Enum, auto from functools import singledispatch +from typing import Union from pytensor.graph.op import Op from pytensor.graph.utils import MetaType @@ -131,14 +133,33 @@ def _icdf_helper(rv, value, **kwargs): return rv_icdf +class MeasureType(Enum): + Discrete = auto() + Continuous = auto() + Mixed = auto() + + class MeasurableVariable(abc.ABC): """A variable that can be assigned a measure/log-probability""" + def __init__( + self, + *args, + ndim_supp: Union[int, tuple], + supp_axes: tuple, + measure_type: Union[MeasureType, tuple], + **kwargs, + ): + self.ndim_supp = ndim_supp + self.supp_axes = supp_axes + self.measure_type = measure_type + super().__init__(*args, **kwargs) + MeasurableVariable.register(RandomVariable) -class MeasurableElemwise(Elemwise): +class MeasurableElemwise(MeasurableVariable, Elemwise): """Base class for Measurable Elemwise variables""" valid_scalar_types: tuple[MetaType, ...] 
= () @@ -152,4 +173,36 @@ def __init__(self, scalar_op, *args, **kwargs): super().__init__(scalar_op, *args, **kwargs) -MeasurableVariable.register(MeasurableElemwise) +def get_measure_type_info( + base_var, +): + from pymc.logprob.utils import DiracDelta + + if not isinstance(base_var, MeasurableVariable): + base_op = base_var.owner.op + index = base_var.owner.outputs.index(base_var) + else: + base_op = base_var + if not isinstance(base_op, MeasurableVariable): + raise TypeError("base_op must be a RandomVariable or MeasurableVariable") + + if isinstance(base_op, DiracDelta): + ndim_supp = 0 + supp_axes = () + measure_type = MeasureType.Discrete + return ndim_supp, supp_axes, measure_type + + if isinstance(base_op, RandomVariable): + ndim_supp = base_op.ndim_supp + supp_axes = tuple(range(-ndim_supp, 0)) + measure_type = ( + MeasureType.Continuous if base_op.dtype.startswith("float") else MeasureType.Discrete + ) + return base_op.ndim_supp, supp_axes, measure_type + else: + # We'll need this for operators like scan and IfElse + if isinstance(base_op.ndim_supp, tuple): + if len(base_var.owner.outputs) != len(base_op.ndim_supp): + raise NotImplementedError("length of outputs and meta-properties is different") + return base_op.ndim_supp[index], base_op.supp_axes, base_op.measure_type + return base_op.ndim_supp, base_op.supp_axes, base_op.measure_type diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index f5d8cf848c..5a3ab8ecfa 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -25,9 +25,11 @@ from pymc.logprob.abstract import ( MeasurableElemwise, + MeasureType, _logcdf_helper, _logprob, _logprob_helper, + get_measure_type_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import check_potential_measurability @@ -81,7 +83,14 @@ def find_measurable_comparisons( elif isinstance(node_scalar_op, LE): node_scalar_op = GE() - compared_op = MeasurableComparison(node_scalar_op) + 
ndim_supp, supp_axes, _ = get_measure_type_info(measurable_var) + + compared_op = MeasurableComparison( + scalar_op=node_scalar_op, + ndim_supp=ndim_supp, + supp_axes=supp_axes, + measure_type=MeasureType.Discrete, + ) compared_rv = compared_op.make_node(measurable_var, const).default_output() return [compared_rv] @@ -148,7 +157,13 @@ def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[ return None node_scalar_op = node.op.scalar_op - bitwise_op = MeasurableBitwise(node_scalar_op) + ndim_supp, supp_axis, measure_type = get_measure_type_info(base_var) + bitwise_op = MeasurableBitwise( + scalar_op=node_scalar_op, + ndim_supp=ndim_supp, + supp_axes=supp_axis, + measure_type=MeasureType.Discrete, + ) bitwise_rv = bitwise_op.make_node(base_var).default_output() return [bitwise_rv] diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index b9221e08db..5aeefd253f 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -48,7 +48,13 @@ from pytensor.tensor.math import ceil, clip, floor, round_half_to_even from pytensor.tensor.variable import TensorConstant -from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob +from pymc.logprob.abstract import ( + MeasurableElemwise, + MeasureType, + _logcdf, + _logprob, + get_measure_type_info, +) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import CheckParameterValue @@ -59,9 +65,6 @@ class MeasurableClip(MeasurableElemwise): valid_scalar_types = (Clip,) -measurable_clip = MeasurableClip(scalar_clip) - - @node_rewriter(tracks=[clip]) def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]: # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub) @@ -81,6 +84,15 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[Te lower_bound = lower_bound if (lower_bound is not base_var) else pt.constant(-np.inf) upper_bound = upper_bound if (upper_bound 
is not base_var) else pt.constant(np.inf) + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var) + + if measure_type == MeasureType.Continuous: + measure_type = MeasureType.Mixed + + measurable_clip = MeasurableClip( + scalar_clip, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) + clipped_rv = measurable_clip.make_node(base_var, lower_bound, upper_bound).outputs[0] return [clipped_rv] @@ -167,7 +179,11 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[lis return None [base_var] = node.inputs - rounded_op = MeasurableRound(node.op.scalar_op) + ndim_supp, supp_axis, _ = get_measure_type_info(base_var) + measure_type = MeasureType.Discrete + rounded_op = MeasurableRound( + node.op.scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axis, measure_type=measure_type + ) rounded_rv = rounded_op.make_node(base_var).default_output() rounded_rv.name = node.outputs[0].name return [rounded_rv] diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 1cf202ec5e..2e8cb34caf 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -43,18 +43,20 @@ from pytensor.tensor import TensorVariable from pytensor.tensor.shape import SpecifyShape -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import ( + MeasurableVariable, + _logprob, + _logprob_helper, + get_measure_type_info, +) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import replace_rvs_by_values -class MeasurableSpecifyShape(SpecifyShape): +class MeasurableSpecifyShape(MeasurableVariable, SpecifyShape): """A placeholder used to specify a log-likelihood for a specify-shape sub-graph.""" -MeasurableVariable.register(MeasurableSpecifyShape) - - @_logprob.register(MeasurableSpecifyShape) def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): (value,) = values @@ -86,7 +88,11 @@ def find_measurable_specify_shapes(fgraph, 
node) -> Optional[list[TensorVariable ): return None # pragma: no cover - new_op = MeasurableSpecifyShape() + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_rv) + + new_op = MeasurableSpecifyShape( + ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) new_rv = new_op.make_node(base_rv, *shape).default_output() return [new_rv] @@ -100,13 +106,10 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable ) -class MeasurableCheckAndRaise(CheckAndRaise): +class MeasurableCheckAndRaise(MeasurableVariable, CheckAndRaise): """A placeholder used to specify a log-likelihood for an assert sub-graph.""" -MeasurableVariable.register(MeasurableCheckAndRaise) - - @_logprob.register(MeasurableCheckAndRaise) def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): (value,) = values @@ -133,7 +136,14 @@ def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariabl return None op = node.op - new_op = MeasurableCheckAndRaise(exc_type=op.exc_type, msg=op.msg) + ndim_supp, supp_axis, d_type = get_measure_type_info(base_rv) + new_op = MeasurableCheckAndRaise( + exc_type=op.exc_type, + msg=op.msg, + ndim_supp=ndim_supp, + supp_axes=supp_axis, + measure_type=d_type, + ) new_rv = new_op.make_node(base_rv, *conds).default_output() return [new_rv] diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index 810f226c8b..4b7fb10a13 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -42,17 +42,19 @@ from pytensor.tensor import TensorVariable from pytensor.tensor.extra_ops import CumOp -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import ( + MeasurableVariable, + _logprob, + _logprob_helper, + get_measure_type_info, +) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -class MeasurableCumsum(CumOp): +class MeasurableCumsum(MeasurableVariable, CumOp): """A placeholder used to specify a 
log-likelihood for a cumsum sub-graph.""" -MeasurableVariable.register(MeasurableCumsum) - - @_logprob.register(MeasurableCumsum) def logprob_cumsum(op, values, base_rv, **kwargs): """Compute the log-likelihood graph for a `Cumsum`.""" @@ -101,7 +103,14 @@ def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]: if not rv_map_feature.request_measurable(node.inputs): return None - new_op = MeasurableCumsum(axis=node.op.axis or 0, mode="add") + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_rv) + new_op = MeasurableCumsum( + axis=node.op.axis or 0, + mode="add", + ndim_supp=ndim_supp, + supp_axes=supp_axes, + measure_type=measure_type, + ) new_rv = new_op.make_node(base_rv).default_output() return [new_rv] diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 011ce5e5fe..e1a90bd72b 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -68,8 +68,10 @@ from pymc.logprob.abstract import ( MeasurableElemwise, MeasurableVariable, + MeasureType, _logprob, _logprob_helper, + get_measure_type_info, ) from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -217,16 +219,17 @@ def rv_pull_down(x: TensorVariable) -> TensorVariable: return fgraph.outputs[0] -class MixtureRV(Op): +class MixtureRV(MeasurableVariable, Op): """A placeholder used to specify a log-likelihood for a mixture sub-graph.""" __props__ = ("indices_end_idx", "out_dtype", "out_broadcastable") - def __init__(self, indices_end_idx, out_dtype, out_broadcastable): - super().__init__() + def __init__(self, *args, indices_end_idx, out_dtype, out_broadcastable, **kwargs): + # super().__init__(*args, **kwargs) self.indices_end_idx = indices_end_idx self.out_dtype = out_dtype self.out_broadcastable = out_broadcastable + super().__init__(*args, **kwargs) def make_node(self, *inputs): return Apply(self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]) @@ -235,9 +238,6 @@ def perform(self, node, inputs, outputs): raise 
NotImplementedError("This is a stand-in Op.") # pragma: no cover -MeasurableVariable.register(MixtureRV) - - def get_stack_mixture_vars( node: Apply, ) -> tuple[Optional[list[TensorVariable]], Optional[int]]: @@ -304,11 +304,28 @@ def find_measurable_index_mixture(fgraph, node): if rv_map_feature.request_measurable(mixture_rvs) != mixture_rvs: return None + all_ndim_supp = [] + all_supp_axes = [] + all_measure_type = [] + for i in range(0, len(mixture_rvs)): + ndim_supp, supp_axes, measure_type = get_measure_type_info(mixture_rvs[i]) + all_ndim_supp.append(ndim_supp) + all_supp_axes.append(supp_axes) + all_measure_type.append(measure_type) + + if all_measure_type[1:] == all_measure_type[:-1]: + m_type = all_measure_type[0] + else: + m_type = MeasureType.Mixed + # Replace this sub-graph with a `MixtureRV` mix_op = MixtureRV( - 1 + len(mixing_indices), - old_mixture_rv.dtype, - old_mixture_rv.broadcastable, + ndim_supp=all_ndim_supp[0], + supp_axes=all_supp_axes[0], + measure_type=m_type, + indices_end_idx=1 + len(mixing_indices), + out_dtype=old_mixture_rv.dtype, + out_broadcastable=old_mixture_rv.broadcastable, ) new_node = mix_op.make_node(*([join_axis, *mixing_indices, *mixture_rvs])) @@ -403,9 +420,6 @@ class MeasurableSwitchMixture(MeasurableElemwise): valid_scalar_types = (Switch,) -measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch) - - @node_rewriter([switch]) def find_measurable_switch_mixture(fgraph, node): rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) @@ -430,6 +444,24 @@ def find_measurable_switch_mixture(fgraph, node): if rv_map_feature.request_measurable(components) != components: return None + all_ndim_supp = [] + all_supp_axes = [] + all_measure_type = [] + for i in range(0, len(components)): + ndim_supp, supp_axes, measure_type = get_measure_type_info(components[i]) + all_ndim_supp.append(ndim_supp) + all_supp_axes.append(supp_axes) + all_measure_type.append(measure_type) + + if
all_measure_type[1:] == all_measure_type[:-1]: + m_type = all_measure_type[0] + else: + m_type = MeasureType.Mixed + + measurable_switch_mixture = MeasurableSwitchMixture( + scalar_switch, ndim_supp=all_ndim_supp[0], supp_axes=all_supp_axes[0], measure_type=m_type + ) + return [measurable_switch_mixture(switch_cond, *components)] @@ -459,13 +491,10 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa ) -class MeasurableIfElse(IfElse): +class MeasurableIfElse(MeasurableVariable, IfElse): """Measurable subclass of IfElse operator.""" -MeasurableVariable.register(MeasurableIfElse) - - @node_rewriter([IfElse]) def useless_ifelse_outputs(fgraph, node): """Remove outputs that are shared across the IfElse branches.""" @@ -517,7 +546,32 @@ def find_measurable_ifelse_mixture(fgraph, node): if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs): return None - return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs + ndim_supp_all = () + supp_axes_all = () + measure_type_all = () + + half_len = int(len(base_rvs) / 2) + length = len(base_rvs) + + for base_rv1, base_rv2 in zip(base_rvs[0:half_len], base_rvs[half_len:length]): + meta_info = get_measure_type_info(base_rv1) + if meta_info != get_measure_type_info(base_rv2): + return None + ndim_supp, supp_axes, measure_type = meta_info + ndim_supp_all += (ndim_supp,) + supp_axes_all += (supp_axes,) + measure_type_all += (measure_type,) + + return ( + MeasurableIfElse( + n_outs=op.n_outs, + ndim_supp=ndim_supp_all, + supp_axes=supp_axes_all, + measure_type=measure_type_all, + ) + .make_node(if_var, *base_rvs) + .outputs + ) measurable_ir_rewrites_db.register( diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 0dc78d0b0d..06e5127e29 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -51,6 +51,7 @@ _logcdf_helper, _logprob, _logprob_helper, + get_measure_type_info, ) from pymc.logprob.rewriting import
measurable_ir_rewrites_db from pymc.logprob.utils import find_negated_var @@ -58,18 +59,30 @@ from pymc.pytensorf import constant_fold -class MeasurableMax(Max): +class MeasurableMax(MeasurableVariable, Max): """A placeholder used to specify a log-likelihood for a max sub-graph.""" + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + ndim_supp = kwargs.get("ndim_supp", self.ndim_supp) + supp_axes = kwargs.get("supp_axes", self.supp_axes) + measure_type = kwargs.get("measure_type", self.measure_type) + return type(self)( + axis=axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) -MeasurableVariable.register(MeasurableMax) - -class MeasurableMaxDiscrete(Max): +class MeasurableMaxDiscrete(MeasurableVariable, Max): """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables""" - -MeasurableVariable.register(MeasurableMaxDiscrete) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + ndim_supp = kwargs.get("ndim_supp", self.ndim_supp) + supp_axes = kwargs.get("supp_axes", self.supp_axes) + measure_type = kwargs.get("measure_type", self.measure_type) + return type(self)( + axis=axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) @node_rewriter([Max]) @@ -104,11 +117,17 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[list[Tens if axis != base_var_dims: return None + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var) + # distinguish measurable discrete and continuous (because logprob is different) if base_var.owner.op.dtype.startswith("int"): - measurable_max = MeasurableMaxDiscrete(list(axis)) + measurable_max = MeasurableMaxDiscrete( + axis=list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) else: - measurable_max = MeasurableMax(list(axis)) + measurable_max = MeasurableMax( + axis=list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) max_rv_node 
= measurable_max.make_node(base_var) max_rv = max_rv_node.outputs @@ -158,17 +177,32 @@ def max_logprob_discrete(op, values, base_rv, **kwargs): return logprob -class MeasurableMaxNeg(Max): +class MeasurableMaxNeg(MeasurableVariable, Max): """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph. This shows up in the graph of min, which is (neg(max(neg(x))).""" + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + ndim_supp = kwargs.get("ndim_supp", self.ndim_supp) + supp_axes = kwargs.get("supp_axes", self.supp_axes) + measure_type = kwargs.get("measure_type", self.measure_type) + return type(self)( + axis=axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) -MeasurableVariable.register(MeasurableMaxNeg) - -class MeasurableDiscreteMaxNeg(Max): +class MeasurableDiscreteMaxNeg(MeasurableVariable, Max): """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables""" + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + ndim_supp = kwargs.get("ndim_supp", self.ndim_supp) + supp_axes = kwargs.get("supp_axes", self.supp_axes) + measure_type = kwargs.get("measure_type", self.measure_type) + return type(self)( + axis=axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) + MeasurableVariable.register(MeasurableDiscreteMaxNeg) @@ -213,11 +247,17 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[ if not rv_map_feature.request_measurable([base_rv]): return None + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_rv) + # distinguish measurable discrete and continuous (because logprob is different) if base_rv.owner.op.dtype.startswith("int"): - measurable_min = MeasurableDiscreteMaxNeg(list(axis)) + measurable_min = MeasurableDiscreteMaxNeg( + list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) else: - measurable_min = MeasurableMaxNeg(list(axis)) + 
measurable_min = MeasurableMaxNeg( + list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) return measurable_min.make_node(base_rv).outputs diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 44ac31a0c3..f1e9985104 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -54,7 +54,7 @@ from pytensor.tensor.variable import TensorVariable from pytensor.updates import OrderedUpdates -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableVariable, _logprob, get_measure_type_info from pymc.logprob.basic import conditional_logp from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -66,16 +66,13 @@ from pymc.logprob.utils import replace_rvs_by_values -class MeasurableScan(Scan): +class MeasurableScan(MeasurableVariable, Scan): """A placeholder used to specify a log-likelihood for a scan sub-graph.""" def __str__(self): return f"Measurable({super().__str__()})" -MeasurableVariable.register(MeasurableScan) - - def convert_outer_out_to_in( input_scan_args: ScanArgs, outer_out_vars: Iterable[TensorVariable], @@ -469,10 +466,31 @@ def find_measurable_scans(fgraph, node): # Replace the mapping rv_map_feature.update_rv_maps(rv_var, new_val_var, full_out) + clients: dict[Variable, list[Variable]] = {} + local_fgraph_topo = pytensor.graph.basic.io_toposort( + curr_scanargs.inner_inputs, + [o for o in curr_scanargs.inner_outputs if not isinstance(o.type, RandomType)], + clients=clients, + ) + all_ndim_supp = () + all_supp_axes = () + all_measure_type = () + for var in curr_scanargs.inner_outputs: + if var.owner.op is None: + continue + if isinstance(var.owner.op, MeasurableVariable): + ndim_supp, supp_axes, measure_type = get_measure_type_info(var) + all_ndim_supp += (ndim_supp,) + all_supp_axes += (supp_axes,) + all_measure_type += (measure_type,) + op = MeasurableScan( curr_scanargs.inner_inputs, curr_scanargs.inner_outputs, curr_scanargs.info, + ndim_supp=all_ndim_supp, + 
supp_axes=all_supp_axes, + measure_type=all_measure_type, mode=node.op.mode, ) new_node = op.make_node(*curr_scanargs.outer_inputs) diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 9cbf456b7b..bbb28d47e9 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -50,7 +50,12 @@ local_rv_size_lift, ) -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import ( + MeasurableVariable, + _logprob, + _logprob_helper, + get_measure_type_info, +) from pymc.logprob.rewriting import ( PreserveRVMappings, assume_measured_ir_outputs, @@ -62,7 +67,7 @@ @node_rewriter([Alloc]) def naive_bcast_rv_lift(fgraph, node): - """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``. + """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``. XXX: This implementation simply broadcasts the ``RandomVariable``'s parameters, which won't always work (e.g. multivariate distributions). @@ -122,13 +127,10 @@ def naive_bcast_rv_lift(fgraph, node): return [bcasted_node.outputs[1]] -class MeasurableMakeVector(MakeVector): +class MeasurableMakeVector(MeasurableVariable, MakeVector): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" -MeasurableVariable.register(MeasurableMakeVector) - - @_logprob.register(MeasurableMakeVector) def logprob_make_vector(op, values, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `MeasurableMakeVector`.""" @@ -149,13 +151,10 @@ def logprob_make_vector(op, values, *base_rvs, **kwargs): return pt.stack(logps) -class MeasurableJoin(Join): +class MeasurableJoin(MeasurableVariable, Join): """A placeholder used to specify a log-likelihood for a join sub-graph.""" -MeasurableVariable.register(MeasurableJoin) - - @_logprob.register(MeasurableJoin) def logprob_join(op, values, axis, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `Join`.""" @@ -221,15 +220,21 @@ def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]: if
not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_vars): return None + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_vars[0]) + if is_join: - measurable_stack = MeasurableJoin()(axis, *base_vars) + measurable_stack = MeasurableJoin( + ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + )(axis, *base_vars) else: - measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars) + measurable_stack = MeasurableMakeVector( + node.op.dtype, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + )(*base_vars) return [measurable_stack] -class MeasurableDimShuffle(DimShuffle): +class MeasurableDimShuffle(MeasurableVariable, DimShuffle): """A placeholder used to specify a log-likelihood for a dimshuffle sub-graph.""" # Need to get the absolute path of `c_func_file`, otherwise it tries to @@ -237,9 +242,6 @@ class MeasurableDimShuffle(DimShuffle): c_func_file = DimShuffle.get_path(DimShuffle.c_func_file) -MeasurableVariable.register(MeasurableDimShuffle) - - @_logprob.register(MeasurableDimShuffle) def logprob_dimshuffle(op, values, base_var, **kwargs): """Compute the log-likelihood graph for a `MeasurableDimShuffle`.""" @@ -296,10 +298,27 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]: if not isinstance(base_var.owner.op, RandomVariable): return None # pragma: no cover - measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)( - base_var - ) + ref = list(range(0, base_var.type.ndim)) + + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var) + new_supp_axes = list(supp_axes) + + for x in supp_axes: + if base_var.type.ndim + x not in node.op.new_order: + return None + + for x in new_supp_axes: + new_supp_axes[x] = node.op.new_order.index(ref[x]) - node.outputs[0].type.ndim + + supp_axes = tuple(new_supp_axes) + measurable_dimshuffle = MeasurableDimShuffle( + node.op.input_broadcastable, + 
node.op.new_order, + ndim_supp=ndim_supp, + supp_axes=supp_axes, + measure_type=measure_type, + )(base_var) return [measurable_dimshuffle] diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 3702a97550..1610e3b3b9 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -115,6 +115,7 @@ _logcdf_helper, _logprob, _logprob_helper, + get_measure_type_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import ( @@ -500,10 +501,15 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[li transform = ScaleTransform( transform_args_fn=lambda *inputs: inputs[-1], ) + + ndim_supp, supp_axes, measure_type = get_measure_type_info(measurable_input) transform_op = MeasurableTransform( scalar_op=scalar_op, transform=transform, measurable_input_idx=measurable_input_idx, + ndim_supp=ndim_supp, + supp_axes=supp_axes, + measure_type=measure_type, ) transform_out = transform_op.make_node(*transform_inputs).default_output() return [transform_out] diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 49827f7a61..c72d13f9d3 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -261,7 +261,7 @@ def local_check_parameter_to_ninf_switch(fgraph, node): class DiracDelta(Op): - """An `Op` that represents a Dirac-delta distribution.""" + """An `Op` that represents a Dirac-Delta distribution.""" __props__ = ("rtol", "atol") diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index ada8f71b2e..bd663c8d73 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -32,6 +32,7 @@ pymc/initial_point.py pymc/logprob/binary.py pymc/logprob/censoring.py +pymc/logprob/checks.py pymc/logprob/basic.py pymc/logprob/mixture.py pymc/logprob/order.py diff --git a/tests/logprob/test_abstract.py b/tests/logprob/test_abstract.py index 7a0bc61e78..c7a408fbed 100644 --- a/tests/logprob/test_abstract.py +++ b/tests/logprob/test_abstract.py @@ -58,13 +58,20 @@ 
def assert_equal_hash(classA, classB): def test_measurable_elemwise(): # Default does not accept any scalar_op + + ndim_supp = 0 + support_axis = None + d_type = "mixed" + with pytest.raises(TypeError, match=re.escape("scalar_op exp is not valid")): - MeasurableElemwise(exp) + MeasurableElemwise(exp, ndim_supp=ndim_supp, supp_axes=support_axis, measure_type=d_type) class TestMeasurableElemwise(MeasurableElemwise): valid_scalar_types = (Exp,) - measurable_exp_op = TestMeasurableElemwise(scalar_op=exp) + measurable_exp_op = TestMeasurableElemwise( + ndim_supp=ndim_supp, supp_axes=support_axis, measure_type=d_type, scalar_op=exp + ) measurable_exp = measurable_exp_op(0.0) assert isinstance(measurable_exp.owner.op, MeasurableVariable) diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index e56a248741..cbdd5e6902 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -21,7 +21,9 @@ from pymc import logp from pymc.logprob import conditional_logp +from pymc.logprob.abstract import MeasureType, get_measure_type_info from pymc.testing import assert_no_rvs +from tests.logprob.utils import measure_type_info_helper @pytest.mark.parametrize( @@ -60,6 +62,38 @@ def test_continuous_rv_comparison_bitwise(comparison_op, exp_logp_true, exp_logp assert np.isclose(logp_fn_not(1), getattr(ref_scipy, exp_logp_false)(0.5)) +@pytest.mark.parametrize( + "comparison_op, exp_logp_true, exp_logp_false, inputs", + [ + ((pt.lt, pt.le), "logcdf", "logsf", (pt.random.normal(0, 1), 0.5)), + ((pt.gt, pt.ge), "logsf", "logcdf", (pt.random.normal(0, 1), 0.5)), + ((pt.lt, pt.le), "logsf", "logcdf", (0.5, pt.random.normal(0, 1))), + ((pt.gt, pt.ge), "logcdf", "logsf", (0.5, pt.random.normal(0, 1))), + ], +) +def test_measure_type_info(comparison_op, exp_logp_true, exp_logp_false, inputs): + for op in comparison_op: + comp_x_rv = op(*inputs) + + if inputs[0] == 0.5: + base_rv = inputs[1] + else: + base_rv = inputs[0] + + comp_x_vv = comp_x_rv.clone() + 
ndim_supp, supp_axes, measure_type = measure_type_info_helper(comp_x_rv, comp_x_vv) + + ndim_supp_base, supp_axes_base, _ = get_measure_type_info(base_rv) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type == MeasureType.Discrete + + @pytest.mark.parametrize( "comparison_op, exp_logp_true, exp_logp_false, inputs", [ diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index de407fd579..ee535c96b1 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -43,9 +43,11 @@ from pymc import logp from pymc.logprob import conditional_logp +from pymc.logprob.abstract import MeasureType, get_measure_type_info from pymc.logprob.transform_value import TransformValuesRewrite from pymc.logprob.transforms import LogTransform from pymc.testing import assert_no_rvs +from tests.logprob.utils import measure_type_info_helper @pytensor.config.change_flags(compute_test_value="raise") @@ -70,6 +72,36 @@ def test_continuous_rv_clip(): assert np.isclose(logp_fn(0), ref_scipy.logpdf(0)) +@pytest.mark.parametrize( + "measure_type", + [("Discrete"), ("Continuous")], +) +def test_clip_measure_type_info(measure_type): + if measure_type == "Continuous": + base_rv = pt.random.normal(0.5, 1) + rv = pt.clip(base_rv, -2, 2) + else: + base_rv = pt.random.poisson(2) + rv = pt.clip(base_rv, 1, 4) + + vv = rv.clone() + + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(base_rv) + + ndim_supp, supp_axes, measure_type = measure_type_info_helper(rv, vv) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + if measure_type_base == MeasureType.Continuous: + assert measure_type_base != measure_type + else: + assert measure_type_base == measure_type + + def test_discrete_rv_clip(): x_rv = pt.random.poisson(2) cens_x_rv = pt.clip(x_rv, 1, 4) @@ -262,3 +294,27 @@ def test_rounding(rounding_op): logprob.eval({xr_vv: 
test_value}), expected_logp, ) + + +@pytest.mark.parametrize("rounding_op", (pt.round, pt.floor, pt.ceil)) +def test_round_measure_type_info(rounding_op): + loc = 1 + scale = 2 + test_value = np.arange(-3, 4) + + x = pt.random.normal(loc, scale, size=test_value.shape, name="x") + xr = rounding_op(x) + xr.name = "xr" + + xr_vv = xr.clone() + ndim_supp, supp_axes, measure_type = measure_type_info_helper(xr, xr_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(x) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type == MeasureType.Discrete diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index db60c573e1..607352d6f9 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -44,8 +44,10 @@ from scipy import stats from pymc.distributions import Dirichlet +from pymc.logprob.abstract import get_measure_type_info from pymc.logprob.basic import conditional_logp from tests.distributions.test_multivariate import dirichlet_logpdf +from tests.logprob.utils import measure_type_info_helper def test_specify_shape_logprob(): @@ -77,6 +79,27 @@ def test_specify_shape_logprob(): x_logp_fn(last_dim=1, x=x_vv_test_invalid) +def test_shape_measure_type_info(): + last_dim = pt.scalar(name="last_dim", dtype="int64") + x_base = Dirichlet.dist(pt.ones((last_dim,)), shape=(5, last_dim)) + x_base.name = "x" + x_rv = pt.specify_shape(x_base, shape=(5, 3)) + x_rv.name = "x" + + x_vv = x_rv.clone() + ndim_supp, supp_axes, measure_type = measure_type_info_helper(x_rv, x_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(x_base) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + def test_assert_logprob(): rv = pt.random.normal() assert_op = Assert("Test assert") @@ -99,3 +122,24 @@ def test_assert_logprob(): # Since here the value to 
the rv is negative, an exception is raised as the condition is not met with pytest.raises(AssertionError, match="Test assert"): assert_logp.eval({assert_vv: -5.0}) + + +def test_assert_measure_type_info(): + rv = pt.random.normal() + assert_op = Assert("Test assert") + # Example: Add assert that rv must be positive + assert_rv = assert_op(rv, rv > 0) + assert_rv.name = "assert_rv" + + assert_vv = assert_rv.clone() + ndim_supp, supp_axes, measure_type = measure_type_info_helper(assert_rv, assert_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(rv) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type diff --git a/tests/logprob/test_cumsum.py b/tests/logprob/test_cumsum.py index 552cea92d0..ddfba864dc 100644 --- a/tests/logprob/test_cumsum.py +++ b/tests/logprob/test_cumsum.py @@ -41,8 +41,10 @@ import scipy.stats as st from pymc import logp +from pymc.logprob.abstract import get_measure_type_info from pymc.logprob.basic import conditional_logp from pymc.testing import assert_no_rvs +from tests.logprob.utils import measure_type_info_helper @pytest.mark.parametrize( @@ -69,6 +71,36 @@ def test_normal_cumsum(size, axis): ) +@pytest.mark.parametrize( + "size, axis", + [ + (10, None), + (10, 0), + ((2, 10), 0), + ((2, 10), 1), + ((3, 2, 10), 0), + ((3, 2, 10), 1), + ((3, 2, 10), 2), + ], +) +def test_measure_type_info(size, axis): + rv = pt.random.normal(0, 1, size=size).cumsum(axis) + vv = rv.clone() + base_rv = pt.random.normal(0, 1, size=size) + base_vv = base_rv.clone() + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(base_rv) + + ndim_supp, supp_axes, measure_type = measure_type_info_helper(rv, vv) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + @pytest.mark.parametrize( "size, axis", [ diff --git a/tests/logprob/test_mixture.py 
b/tests/logprob/test_mixture.py index b3e5c5656e..538a127644 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,13 +52,13 @@ as_index_constant, ) -from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.abstract import MeasurableVariable, MeasureType, get_measure_type_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph -from pymc.logprob.utils import dirac_delta +from pymc.logprob.utils import dirac_delta as diracdelta from pymc.testing import assert_no_rvs -from tests.logprob.utils import scipy_logprob +from tests.logprob.utils import measure_type_info_helper, scipy_logprob def test_mixture_basics(): @@ -881,7 +881,7 @@ def test_mixture_with_DiracDelta(): srng = pt.random.RandomStream(29833) X_rv = srng.normal(0, 1, name="X") - Y_rv = dirac_delta(0.0) + Y_rv = diracdelta(0.0) Y_rv.name = "Y" I_rv = srng.categorical([0.5, 0.5], size=1) @@ -900,7 +900,7 @@ def test_mixture_with_DiracDelta(): assert m_vv in logp_res -def test_scalar_switch_mixture(): +def test_switch_mixture(): srng = pt.random.RandomStream(29833) X_rv = srng.normal(-10.0, 0.1, name="X") @@ -937,6 +937,36 @@ def test_scalar_switch_mixture(): np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1})) +def test_measure_type_info_switch_mixture(): + srng = pt.random.RandomStream(29833) + + X_rv = srng.normal(-10.0, 0.1, name="X") + Y_rv = srng.normal(10.0, 0.1, name="Y") + + I_rv = srng.bernoulli(0.5, name="I") + i_vv = I_rv.clone() + i_vv.name = "i" + + # When I_rv == True, X_rv flows through otherwise Y_rv does + Z1_rv = pt.switch(I_rv, X_rv, Y_rv) + Z1_rv.name = "Z1" + + assert Z1_rv.eval({I_rv: 0}) > 5 + assert Z1_rv.eval({I_rv: 1}) < -5 + + z_vv = Z1_rv.clone() + z_vv.name = "z1" + + fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) + assert 
isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) + + ndim_supp = fgraph.outputs[0].owner.op.ndim_supp + measure_type = fgraph.outputs[0].owner.op.measure_type + + np.testing.assert_almost_equal(0, ndim_supp) + assert measure_type == MeasureType.Continuous + + @pytest.mark.parametrize("switch_cond_scalar", (True, False)) def test_switch_mixture_vector(switch_cond_scalar): if switch_cond_scalar: @@ -1031,6 +1061,26 @@ def test_ifelse_mixture_one_component(): ) + +def test_measure_type_ifelse(): + if_rv = pt.random.bernoulli(0.5, name="if") + scale_rv = pt.random.halfnormal(name="scale") + comp_then = pt.random.normal(0, scale_rv, size=(2,), name="comp_then") + comp_else = pt.random.halfnormal(0, scale_rv, size=(4,), name="comp_else") + mix_rv = ifelse(if_rv, comp_then, comp_else, name="mix") + + if_vv = if_rv.clone() + scale_vv = scale_rv.clone() + mix_vv = mix_rv.clone() + + ndim_supp, supp_axes, measure_type = measure_type_info_helper(mix_rv, mix_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(comp_then) + + assert ndim_supp_base == 0 + assert supp_axes_base == () + assert measure_type_base == MeasureType.Continuous + + def test_ifelse_mixture_multiple_components(): rng = np.random.default_rng(968) diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 4d15240375..dff4fe2273 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -45,7 +45,9 @@ import pymc as pm from pymc import logp +from pymc.logprob.abstract import get_measure_type_info from pymc.testing import assert_no_rvs +from tests.logprob.utils import measure_type_info_helper def test_argmax(): @@ -185,6 +187,43 @@ def test_max_logprob(shape, value, axis): ) + +def test_measure_type_info_order(): + """Test that the measure type info for ```pt.max``` is correctly inferred. + + The fact that order statistics of i.i.d.
uniform RVs ~ Beta is used here: + U_1, \\dots, U_n \\stackrel{\\text{i.i.d.}}{\\sim} \\text{Uniform}(0, 1) \\Rightarrow U_{(k)} \\sim \\text{Beta}(k, n + 1 - k) + for all 1<=k<=n + """ + x = pt.random.uniform(0, 1, size=(3,)) + x.name = "x" + x_max = pt.max(x, axis=-1) + x_max_vv = x_max.clone() + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(x) + + ndim_supp, supp_axes, measure_type = measure_type_info_helper(x_max, x_max_vv) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + x_min = pt.min(x, axis=-1) + x_min_vv = x_min.clone() + + ndim_supp_min, supp_axes_min, measure_type_min = measure_type_info_helper(x_min, x_min_vv) + + assert np.isclose( + ndim_supp_base, + ndim_supp_min, + ) + assert supp_axes_base == supp_axes_min + + assert measure_type_base == measure_type_min + + @pytest.mark.parametrize( + "shape, value, axis", + [ diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 30a76680e7..f051b497bd 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -44,14 +44,16 @@ from pytensor.scan.utils import ScanArgs from scipy import stats -from pymc.logprob.abstract import _logprob_helper +from pymc.logprob.abstract import MeasureType, _logprob_helper from pymc.logprob.basic import conditional_logp, logp +from pymc.logprob.rewriting import construct_ir_fgraph from pymc.logprob.scan import ( construct_scan, convert_outer_out_to_in, get_random_outer_outputs, ) from pymc.testing import assert_no_rvs +from tests.logprob.utils import measure_type_info_helper def create_inner_out_logp(value_map): @@ -504,6 +506,52 @@ def test_scan_over_seqs(): ) + +def test_measure_type_scan_non_pure_rv_output(): + grw, _ = pytensor.scan( + fn=lambda xtm1: pt.random.normal() + xtm1, + outputs_info=[pt.zeros(())], + n_steps=10, + name="grw1", + ) + + grw1_vv = grw[0].clone() + + fgraph, _, _ = construct_ir_fgraph({grw[0]: grw1_vv}) +
node = fgraph.outputs[0].owner + ndim_supp = node.inputs[0].owner.op.ndim_supp + supp_axes = node.inputs[0].owner.op.supp_axes + measure_type = node.inputs[0].owner.op.measure_type + assert ndim_supp == (0,) and supp_axes == ((),) and measure_type[0] is MeasureType.Continuous + + +def test_measure_type_scan_over_seqs(): + """Test that measure type info is correctly inferred for scans based on sequences (mapping).""" + rng = np.random.default_rng(543) + n_steps = 10 + + xs = pt.random.normal(size=(n_steps,), name="xs") # use vector with a fixed size + ys, _ = pytensor.scan( + fn=lambda x, x1: ( + pt.random.multinomial(x, np.ones(4) / 4), + pt.random.poisson(x1), + ), # use multinomial and poisson + sequences=[xs, xs], + outputs_info=[None, None], + name=("ys1", "ys2"), + ) + ys1_vv = ys[0].clone() + + ndim_supp, supp_axes, measure_type = measure_type_info_helper(ys[0], ys1_vv) + + if ( + not ndim_supp == (1, 0) + and not supp_axes == ((-1,), ()) + and not isinstance(measure_type[0], MeasureType.Discrete) + and not isinstance(measure_type[1], MeasureType.Discrete) + ): + assert 0 + + +def test_scan_carried_deterministic_state(): """Test logp of scans with carried states downstream of measured variables.
diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index e61e0d1700..72158d9b1a 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -45,10 +45,14 @@ from pytensor.tensor.basic import Alloc from scipy import stats as st +import pymc as pm + +from pymc.logprob.abstract import MeasureType, get_measure_type_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.rewriting import logprob_rewrites_db from pymc.logprob.tensor import naive_bcast_rv_lift from pymc.testing import assert_no_rvs +from tests.logprob.utils import measure_type_info_helper def test_naive_bcast_rv_lift(): @@ -134,6 +138,35 @@ def test_measurable_make_vector(): assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval) +def test_measure_type_make_vector(): + base1_rv = pt.random.normal(name="base1") + base2_rv = pt.random.halfnormal(name="base2") + base3_rv = pt.random.exponential(name="base3") + y_rv = pt.stack((base1_rv, base2_rv, base3_rv)) + y_rv.name = "y" + + ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measure_type_info(base1_rv) + ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measure_type_info(base2_rv) + ndim_supp_base_3, supp_axes_base_3, measure_type_base_3 = get_measure_type_info(base3_rv) + + base1_vv = base1_rv.clone() + base2_vv = base2_rv.clone() + base3_vv = base3_rv.clone() + y_vv = y_rv.clone() + + ndim_supp, supp_axes, measure_type = measure_type_info_helper(y_rv, y_vv) + + assert np.isclose( + ndim_supp_base_1, + ndim_supp_base_2, + ndim_supp_base_3, + ndim_supp, + ) + assert supp_axes_base_1 == supp_axes_base_2 == supp_axes_base_3 == supp_axes + + assert measure_type_base_1 == measure_type_base_2 == measure_type_base_3 == measure_type + + @pytest.mark.parametrize("reverse", (False, True)) def test_measurable_make_vector_interdependent(reverse): """Test that we can obtain a proper graph when stacked RVs depend on each other""" @@ -268,6 +301,43 @@ def 
test_measurable_join_univariate(size1, size2, axis, concatenate): ) +@pytest.mark.parametrize( + "size1, size2, axis, concatenate", + [ + ((5,), (3,), 0, True), + ((5,), (3,), -1, True), + ((5, 2), (3, 2), 0, True), + ((2, 5), (2, 3), 1, True), + ((2, 5), (2, 5), 0, False), + ((2, 5), (2, 5), 1, False), + ((2, 5), (2, 5), 2, False), + ], +) +def test_measure_type_join_univariate(size1, size2, axis, concatenate): + base1_rv = pt.random.normal(size=size1, name="base1") + base2_rv = pt.random.exponential(size=size2, name="base2") + if concatenate: + y_rv = pt.concatenate((base1_rv, base2_rv), axis=axis) + else: + y_rv = pt.stack((base1_rv, base2_rv), axis=axis) + y_rv.name = "y" + + ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measure_type_info(base1_rv) + ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measure_type_info(base2_rv) + + y_vv = y_rv.clone() + ndim_supp, supp_axes, measure_type = measure_type_info_helper(y_rv, y_vv) + + assert np.isclose( + ndim_supp_base_1, + ndim_supp_base_2, + ndim_supp, + ) + assert supp_axes_base_1 == supp_axes_base_2 == supp_axes + + assert measure_type_base_1 == measure_type_base_2 == measure_type + + @pytest.mark.parametrize( "size1, supp_size1, size2, supp_size2, axis, concatenate", [ @@ -401,6 +471,43 @@ def test_measurable_dimshuffle(ds_order, multivariate): np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value)) +@pytensor.config.change_flags(cxx="") +@pytest.mark.parametrize( + "multivariate, ds_order, ans_ndim_supp, ans_supp_axes, ans_measure_type", + [ + (True, (0, 2), 1, (-1,), MeasureType.Continuous), # Drop + (True, (2, 0), 1, (-2,), MeasureType.Continuous), # Swap and drop + (True, (2, 1, "x", 0), 1, (-4,), MeasureType.Continuous), # Swap and expand + (True, ("x", 0, 2), 1, (-1,), MeasureType.Continuous), # Expand and drop + (True, (2, "x", 0), 1, (-3,), MeasureType.Continuous), # Swap, expand and drop + (False, (0, 2), 0, (), MeasureType.Continuous), + 
(False, (2, 1, "x", 0), 0, (), MeasureType.Continuous), + ], +) +def test_measure_type_dimshuffle( + multivariate, ds_order, ans_ndim_supp, ans_supp_axes, ans_measure_type +): + if multivariate: + base_rv = pm.Dirichlet.dist([1, 1, 1], shape=(7, 1, 3)) + ds_rv = base_rv.dimshuffle(ds_order) + base_vv = base_rv.clone() + + else: + base_rv = pt.random.beta(1, 2, size=(2, 1, 3)) + base_rv_1 = pt.exp(base_rv) + ds_rv = base_rv_1.dimshuffle(ds_order) + + ds_vv = ds_rv.clone() + + ndim_supp, supp_axes, measure_type = measure_type_info_helper(ds_rv, ds_vv) + + assert ndim_supp == ans_ndim_supp + + assert supp_axes == ans_supp_axes + + assert measure_type == ans_measure_type + + def test_unmeargeable_dimshuffles(): # Test that graphs with DimShuffles that cannot be lifted/merged fail diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index acf7296f47..4eb3a6b7c0 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -45,6 +45,7 @@ from pymc.distributions.continuous import Cauchy, ChiSquared from pymc.distributions.discrete import Bernoulli +from pymc.logprob.abstract import get_measure_type_info from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp from pymc.logprob.transforms import ( ArccoshTransform, @@ -66,6 +67,7 @@ from pymc.logprob.utils import ParameterValueError from pymc.testing import Rplusbig, Vector, assert_no_rvs from tests.distributions.test_transform import check_jacobian_det +from tests.logprob.utils import measure_type_info_helper class DirichletScipyDist: @@ -229,6 +231,26 @@ def test_exp_transform_rv(): ) +def test_measure_type_exp_transform_rv(): + base_rv = pt.random.normal(0, 1, size=3, name="base_rv") + y_rv = pt.exp(base_rv) + y_rv.name = "y" + + y_vv = y_rv.clone() + + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(base_rv) + + ndim_supp, supp_axes, measure_type = measure_type_info_helper(y_rv, y_vv) + + assert np.isclose( + ndim_supp_base, + 
ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + def test_log_transform_rv(): base_rv = pt.random.lognormal(0, 1, size=2, name="base_rv") y_rv = pt.log(base_rv) diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index e5aa36b830..db43e4fea3 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -41,6 +41,7 @@ from scipy import stats as stats from pymc.logprob import icdf, logcdf, logp +from pymc.logprob.rewriting import construct_ir_fgraph def scipy_logprob(obs, p): @@ -112,3 +113,14 @@ def scipy_logprob_tester( np.testing.assert_array_equal(pytensor_res_val.shape, numpy_res.shape) np.testing.assert_array_almost_equal(pytensor_res_val, numpy_res, 4) + + +def measure_type_info_helper(rv, vv): + """Extract measurable information from rv""" + fgraph, _, _ = construct_ir_fgraph({rv: vv}) + node = fgraph.outputs[0].owner + ndim_supp = node.op.ndim_supp + supp_axes = node.op.supp_axes + measure_type = node.op.measure_type + + return ndim_supp, supp_axes, measure_type