From bf794729e69fd69730f4f8c12b01a5f44f665b65 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 3 Jun 2023 10:44:53 +0530 Subject: [PATCH 01/19] Adding metainfo --- pymc/logprob/abstract.py | 33 ++++++++++++++++++++++++++++++++- pymc/logprob/binary.py | 1 + pymc/logprob/censoring.py | 3 --- pymc/logprob/checks.py | 19 ++++++++++++++++++- pymc/logprob/cumsum.py | 13 ++++++++++++- pymc/logprob/mixture.py | 6 ++++++ pymc/logprob/scan.py | 23 +++++++++++++++++++++++ pymc/logprob/tensor.py | 24 ++++++++++++++++++++++++ pymc/logprob/transforms.py | 9 +++++++++ tests/logprob/test_abstract.py | 11 +++++++++-- tests/logprob/test_checks.py | 2 ++ 11 files changed, 136 insertions(+), 8 deletions(-) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index f35ab4c523..63af1e678b 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -143,13 +143,44 @@ class MeasurableElemwise(Elemwise): valid_scalar_types: tuple[MetaType, ...] = () - def __init__(self, scalar_op, *args, **kwargs): + def __init__(self, scalar_op, ndim_supp, support_axis, d_type, *args, **kwargs): if not isinstance(scalar_op, self.valid_scalar_types): raise TypeError( f"scalar_op {scalar_op} is not valid for class {self.__class__}. " f"Acceptable types are {self.valid_scalar_types}" ) super().__init__(scalar_op, *args, **kwargs) + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + + +from enum import Enum, auto + + +class MeasureType(Enum): + Discrete = auto() + Continuous = auto() + Mixed = auto() + + +def get_default_measurable_metainfo(base_op: Op, base_dtype) -> Tuple[int, Tuple[int], MeasureType]: + if not isinstance(base_op, MeasurableVariable): + raise TypeError("base_op must be a RandomVariable or MeasurableVariable") + + ndim_supp = base_op.ndim_supp + + supp_axes = getattr(base_op, "supp_axes", None) + if supp_axes is None: + supp_axes = tuple(range(-base_op.ndim_supp, 0)) + + measure_type = getattr(base_op, "measure_type", None) + if measure_type is None: + measure_type = ( + MeasureType.Discrete if base_dtype.dtype.startswith("int") else MeasureType.Continuous + ) + + return ndim_supp, supp_axes, measure_type MeasurableVariable.register(MeasurableElemwise) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index f5d8cf848c..a8865d7f51 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -28,6 +28,7 @@ _logcdf_helper, _logprob, _logprob_helper, + get_default_measurable_metainfo, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import check_potential_measurability diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index b9221e08db..4e590dad82 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -59,9 +59,6 @@ class MeasurableClip(MeasurableElemwise): valid_scalar_types = (Clip,) -measurable_clip = MeasurableClip(scalar_clip) - - @node_rewriter(tracks=[clip]) def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]: # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub) diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 1cf202ec5e..614a78778e 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -43,7 +43,12 @@ from pytensor.tensor import TensorVariable from pytensor.tensor.shape import SpecifyShape -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import ( + MeasurableVariable, + _logprob, + _logprob_helper, + 
get_default_measurable_metainfo, +) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import replace_rvs_by_values @@ -51,6 +56,12 @@ class MeasurableSpecifyShape(SpecifyShape): """A placeholder used to specify a log-likelihood for a specify-shape sub-graph.""" + def __init__(self, ndim_supp, support_axis, d_type, *args, **kwargs): + super().__init__() + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + MeasurableVariable.register(MeasurableSpecifyShape) @@ -103,6 +114,12 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable class MeasurableCheckAndRaise(CheckAndRaise): """A placeholder used to specify a log-likelihood for an assert sub-graph.""" + def __init__(self, exc_type, msg, ndim_supp, support_axis, d_type, *args, **kwargs): + super().__init__(exc_type, msg) + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + MeasurableVariable.register(MeasurableCheckAndRaise) diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index 810f226c8b..6d61927f11 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -42,13 +42,24 @@ from pytensor.tensor import TensorVariable from pytensor.tensor.extra_ops import CumOp -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import ( + MeasurableVariable, + _logprob, + _logprob_helper, + get_default_measurable_metainfo, +) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db class MeasurableCumsum(CumOp): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" + def __init__(self, ndim_supp, support_axis, d_type, axis, mode, *args, **kwargs): + super().__init__(axis, mode) + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + MeasurableVariable.register(MeasurableCumsum) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 011ce5e5fe..15afa4028e 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -462,6 +462,12 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa class MeasurableIfElse(IfElse): """Measurable subclass of IfElse operator.""" + def __init__(self, ndim_supp, support_axis, d_type, n_outs, *args, **kwargs): + super().__init__(n_outs) + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + MeasurableVariable.register(MeasurableIfElse) diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 44ac31a0c3..84de37ae04 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -72,6 +72,23 @@ class MeasurableScan(Scan): def __str__(self): return f"Measurable({super().__str__()})" + def __init__( + self, + inner_inputs, + inner_outputs, + info, + ndim_supp, + support_axis, + d_type, + mode, + *args, + **kwargs, + ): + super().__init__(inner_inputs, inner_outputs, info, mode) + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + MeasurableVariable.register(MeasurableScan) @@ -469,10 +486,16 @@ def find_measurable_scans(fgraph, node): # Replace the mapping rv_map_feature.update_rv_maps(rv_var, new_val_var, full_out) + for n in local_fgraph_topo: + if isinstance(n.op, MeasurableVariable): + ndim_supp, supp_axis, d_type = get_default_measurable_metainfo(n.op, node.inputs[0]) op = MeasurableScan( curr_scanargs.inner_inputs, curr_scanargs.inner_outputs, curr_scanargs.info, + ndim_supp=ndim_supp, + 
support_axis=supp_axis, + d_type=d_type, mode=node.op.mode, ) new_node = op.make_node(*curr_scanargs.outer_inputs) diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 9cbf456b7b..5cff2cb348 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -125,6 +125,12 @@ def naive_bcast_rv_lift(fgraph, node): class MeasurableMakeVector(MakeVector): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" + def __init__(self, data_type, ndim_supp, support_axis, d_type, *args, **kwargs): + super().__init__(data_type) + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + MeasurableVariable.register(MeasurableMakeVector) @@ -152,6 +158,12 @@ def logprob_make_vector(op, values, *base_rvs, **kwargs): class MeasurableJoin(Join): """A placeholder used to specify a log-likelihood for a join sub-graph.""" + def __init__(self, axis, ndim_supp, support_axis, d_type, *args, **kwargs): + super().__init__(axis) + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + MeasurableVariable.register(MeasurableJoin) @@ -221,6 +233,10 @@ def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]: if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_vars): return None + ndim_supp, supp_axis, d_type = get_default_measurable_metainfo( + base_vars[0].owner.op, base_vars[0] + ) + if is_join: measurable_stack = MeasurableJoin()(axis, *base_vars) else: @@ -236,6 +252,14 @@ class MeasurableDimShuffle(DimShuffle): # find it locally and fails when a new `Op` is initialized c_func_file = DimShuffle.get_path(DimShuffle.c_func_file) + def __init__( + self, input_broadcastable, new_order, ndim_supp, support_axis, d_type, *args, **kwargs + ): + super().__init__(input_broadcastable, new_order) + self.ndim_supp = ndim_supp + self.support_axis = support_axis + self.d_type = d_type + MeasurableVariable.register(MeasurableDimShuffle) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 3702a97550..6515ec79fd 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -500,10 +500,19 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[li transform = ScaleTransform( transform_args_fn=lambda *inputs: inputs[-1], ) + + ndim_supp = measurable_inputs[0].owner.op.ndim_supp + supp_axes = getattr(measurable_inputs[0].owner.op, "supp_axes", None) + if supp_axes is None: + supp_axes = tuple(range(-ndim_supp, 0)) + transform_op = MeasurableTransform( scalar_op=scalar_op, transform=transform, measurable_input_idx=measurable_input_idx, + ndim_supp=ndim_supp, + support_axis=supp_axes, + d_type=MeasureType.Continuous, ) transform_out = transform_op.make_node(*transform_inputs).default_output() return [transform_out] diff --git a/tests/logprob/test_abstract.py b/tests/logprob/test_abstract.py index 7a0bc61e78..9b05170a7c 100644 --- a/tests/logprob/test_abstract.py +++ b/tests/logprob/test_abstract.py @@ -58,13 +58,20 @@ def assert_equal_hash(classA, classB): def test_measurable_elemwise(): # Default does not accept any scalar_op + + ndim_supp = 0 + support_axis = None + d_type = "mixed" + with pytest.raises(TypeError, match=re.escape("scalar_op exp is not valid")): - MeasurableElemwise(exp) + MeasurableElemwise(exp, ndim_supp=ndim_supp, support_axis=support_axis, d_type=d_type) class TestMeasurableElemwise(MeasurableElemwise): valid_scalar_types = (Exp,) - measurable_exp_op = TestMeasurableElemwise(scalar_op=exp) + 
measurable_exp_op = TestMeasurableElemwise( + ndim_supp=ndim_supp, support_axis=support_axis, d_type=d_type, scalar_op=exp + ) measurable_exp = measurable_exp_op(0.0) assert isinstance(measurable_exp.owner.op, MeasurableVariable) diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index db60c573e1..cd0b3023bb 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -35,6 +35,8 @@ # SOFTWARE. import re +import re + import numpy as np import pytensor import pytensor.tensor as pt From bfe7a2f481d9e8065a2cd8bda0f03efdeb90cc2f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 5 Jun 2023 11:53:52 +0200 Subject: [PATCH 02/19] Use multiple inheritance for measurable meta info --- pymc/logprob/abstract.py | 59 ++++++++++++++++++++------------------ pymc/logprob/binary.py | 2 +- pymc/logprob/checks.py | 24 ++-------------- pymc/logprob/cumsum.py | 13 ++------- pymc/logprob/mixture.py | 20 ++++--------- pymc/logprob/scan.py | 36 +++++++---------------- pymc/logprob/tensor.py | 39 +++---------------------- pymc/logprob/transforms.py | 11 +++---- 8 files changed, 61 insertions(+), 143 deletions(-) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 63af1e678b..5b3e6a0050 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -131,9 +131,31 @@ def _icdf_helper(rv, value, **kwargs): return rv_icdf +from enum import Enum, auto + + +class MeasureType(Enum): + Discrete = auto() + Continuous = auto() + Mixed = auto() + + class MeasurableVariable(abc.ABC): """A variable that can be assigned a measure/log-probability""" + def __init__( + self, + *args, + ndim_supp: Union[int, Tuple[int]], + supp_axes: Tuple[Union[int, Tuple[int]]], + measure_type: Union[MeasureType, Tuple[MeasureType]], + **kwargs, + ): + self.ndim_supp = ndim_supp + self.supp_axes = supp_axes + self.measure_type = measure_type + super().__init__(*args, **kwargs) + MeasurableVariable.register(RandomVariable) @@ -143,44 +165,25 @@ class MeasurableElemwise(Elemwise): valid_scalar_types: tuple[MetaType, ...] = () - def __init__(self, scalar_op, ndim_supp, support_axis, d_type, *args, **kwargs): + def __init__(self, scalar_op, *args, **kwargs): if not isinstance(scalar_op, self.valid_scalar_types): raise TypeError( f"scalar_op {scalar_op} is not valid for class {self.__class__}. 
" f"Acceptable types are {self.valid_scalar_types}" ) super().__init__(scalar_op, *args, **kwargs) - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - -from enum import Enum, auto - - -class MeasureType(Enum): - Discrete = auto() - Continuous = auto() - Mixed = auto() - -def get_default_measurable_metainfo(base_op: Op, base_dtype) -> Tuple[int, Tuple[int], MeasureType]: +def get_measurable_meta_info(base_op: Op) -> Tuple[int, Tuple[int], MeasureType]: if not isinstance(base_op, MeasurableVariable): raise TypeError("base_op must be a RandomVariable or MeasurableVariable") - ndim_supp = base_op.ndim_supp - - supp_axes = getattr(base_op, "supp_axes", None) - if supp_axes is None: - supp_axes = tuple(range(-base_op.ndim_supp, 0)) - - measure_type = getattr(base_op, "measure_type", None) - if measure_type is None: + if isinstance(base_op, RandomVariable): + ndim_supp = base_op.ndim_supp + supp_axes = tuple(range(-ndim_supp, 0)) measure_type = ( - MeasureType.Discrete if base_dtype.dtype.startswith("int") else MeasureType.Continuous + MeasureType.Continuous if base_op.dtype.startswith("float") else MeasureType.Discrete ) - - return ndim_supp, supp_axes, measure_type - - -MeasurableVariable.register(MeasurableElemwise) + return base_op.ndim_supp, supp_axes, measure_type + else: + return base_op.ndim_supp, base_op.supp_axes, base_op.measure_type diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index a8865d7f51..b531381bd0 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -28,7 +28,7 @@ _logcdf_helper, _logprob, _logprob_helper, - get_default_measurable_metainfo, + get_measurable_meta_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import check_potential_measurability diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 614a78778e..3731999233 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -47,24 +47,15 @@ MeasurableVariable, _logprob, _logprob_helper, - get_default_measurable_metainfo, + get_measurable_meta_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import replace_rvs_by_values -class MeasurableSpecifyShape(SpecifyShape): +class MeasurableSpecifyShape(MeasurableVariable, SpecifyShape): """A placeholder used to specify a log-likelihood for a specify-shape sub-graph.""" - def __init__(self, ndim_supp, support_axis, d_type, *args, **kwargs): - super().__init__() - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - - -MeasurableVariable.register(MeasurableSpecifyShape) - @_logprob.register(MeasurableSpecifyShape) def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): @@ -111,18 +102,9 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable ) -class MeasurableCheckAndRaise(CheckAndRaise): +class MeasurableCheckAndRaise(MeasurableVariable, CheckAndRaise): """A placeholder used to specify a log-likelihood for an assert sub-graph.""" - def __init__(self, exc_type, msg, ndim_supp, support_axis, d_type, *args, **kwargs): - super().__init__(exc_type, msg) - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - - -MeasurableVariable.register(MeasurableCheckAndRaise) - @_logprob.register(MeasurableCheckAndRaise) def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index 
6d61927f11..ddad3592de 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -46,23 +46,14 @@ MeasurableVariable, _logprob, _logprob_helper, - get_default_measurable_metainfo, + get_measurable_meta_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -class MeasurableCumsum(CumOp): +class MeasurableCumsum(MeasurableVariable, CumOp): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" - def __init__(self, ndim_supp, support_axis, d_type, axis, mode, *args, **kwargs): - super().__init__(axis, mode) - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - - -MeasurableVariable.register(MeasurableCumsum) - @_logprob.register(MeasurableCumsum) def logprob_cumsum(op, values, base_rv, **kwargs): diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 15afa4028e..d7abcfcd01 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -70,6 +70,7 @@ MeasurableVariable, _logprob, _logprob_helper, + get_measurable_meta_info, ) from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -217,16 +218,17 @@ def rv_pull_down(x: TensorVariable) -> TensorVariable: return fgraph.outputs[0] -class MixtureRV(Op): +class MixtureRV(MeasurableVariable, Op): """A placeholder used to specify a log-likelihood for a mixture sub-graph.""" __props__ = ("indices_end_idx", "out_dtype", "out_broadcastable") - def __init__(self, indices_end_idx, out_dtype, out_broadcastable): + def __init__(self, *args, indices_end_idx, out_dtype, out_broadcastable, **kwargs): super().__init__() self.indices_end_idx = indices_end_idx self.out_dtype = out_dtype self.out_broadcastable = out_broadcastable + super().__init__(*args, **kwargs) def make_node(self, *inputs): return Apply(self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]) @@ -235,9 +237,6 @@ def perform(self, node, inputs, outputs): raise NotImplementedError("This is a stand-in Op.") # pragma: no cover -MeasurableVariable.register(MixtureRV) - - def get_stack_mixture_vars( node: Apply, ) -> tuple[Optional[list[TensorVariable]], Optional[int]]: @@ -459,18 +458,9 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa ) -class MeasurableIfElse(IfElse): +class MeasurableIfElse(MeasurableVariable, IfElse): """Measurable subclass of IfElse operator.""" - def __init__(self, ndim_supp, support_axis, d_type, n_outs, *args, **kwargs): - super().__init__(n_outs) - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - - -MeasurableVariable.register(MeasurableIfElse) - @node_rewriter([IfElse]) def useless_ifelse_outputs(fgraph, node): diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 84de37ae04..b1cc215717 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -66,32 +66,12 @@ from pymc.logprob.utils import replace_rvs_by_values -class MeasurableScan(Scan): +class MeasurableScan(MeasurableVariable, Scan): """A placeholder used to specify a log-likelihood for a scan sub-graph.""" def __str__(self): return f"Measurable({super().__str__()})" - def __init__( - self, - inner_inputs, - inner_outputs, - info, - ndim_supp, - support_axis, - d_type, - mode, - *args, - **kwargs, - ): - super().__init__(inner_inputs, inner_outputs, info, mode) - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - - -MeasurableVariable.register(MeasurableScan) - def convert_outer_out_to_in( input_scan_args: ScanArgs, @@ -486,16 +466,22 @@ def 
find_measurable_scans(fgraph, node): # Replace the mapping rv_map_feature.update_rv_maps(rv_var, new_val_var, full_out) + all_ndim_supp = [] + all_supp_axes = [] + all_measure_type = [] for n in local_fgraph_topo: if isinstance(n.op, MeasurableVariable): - ndim_supp, supp_axis, d_type = get_default_measurable_metainfo(n.op, node.inputs[0]) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(n.op) + all_ndim_supp.append(ndim_supp) + all_supp_axes.append(supp_axes) + all_measure_type.append(measure_type) op = MeasurableScan( curr_scanargs.inner_inputs, curr_scanargs.inner_outputs, curr_scanargs.info, - ndim_supp=ndim_supp, - support_axis=supp_axis, - d_type=d_type, + ndim_supp=all_ndim_supp, + support_axis=all_supp_axes, + measure_type=all_measure_type, mode=node.op.mode, ) new_node = op.make_node(*curr_scanargs.outer_inputs) diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 5cff2cb348..e78937d3be 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -122,18 +122,9 @@ def naive_bcast_rv_lift(fgraph, node): return [bcasted_node.outputs[1]] -class MeasurableMakeVector(MakeVector): +class MeasurableMakeVector(MeasurableVariable, MakeVector): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" - def __init__(self, data_type, ndim_supp, support_axis, d_type, *args, **kwargs): - super().__init__(data_type) - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - - -MeasurableVariable.register(MeasurableMakeVector) - @_logprob.register(MeasurableMakeVector) def logprob_make_vector(op, values, *base_rvs, **kwargs): @@ -155,18 +146,9 @@ def logprob_make_vector(op, values, *base_rvs, **kwargs): return pt.stack(logps) -class MeasurableJoin(Join): +class MeasurableJoin(MeasurableVariable, Join): """A placeholder used to specify a log-likelihood for a join sub-graph.""" - def __init__(self, axis, ndim_supp, support_axis, d_type, *args, **kwargs): - super().__init__(axis) - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - - -MeasurableVariable.register(MeasurableJoin) - @_logprob.register(MeasurableJoin) def logprob_join(op, values, axis, *base_rvs, **kwargs): @@ -233,9 +215,7 @@ def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]: if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_vars): return None - ndim_supp, supp_axis, d_type = get_default_measurable_metainfo( - base_vars[0].owner.op, base_vars[0] - ) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_vars[0].owner.op) if is_join: measurable_stack = MeasurableJoin()(axis, *base_vars) @@ -245,24 +225,13 @@ def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]: return [measurable_stack] -class MeasurableDimShuffle(DimShuffle): +class MeasurableDimShuffle(MeasurableVariable, DimShuffle): """A placeholder used to specify a log-likelihood for a dimshuffle sub-graph.""" # Need to get the absolute path of `c_func_file`, otherwise it tries to # find it locally and fails when a new `Op` is initialized c_func_file = DimShuffle.get_path(DimShuffle.c_func_file) - def __init__( - self, input_broadcastable, new_order, ndim_supp, support_axis, d_type, *args, **kwargs - ): - super().__init__(input_broadcastable, new_order) - self.ndim_supp = ndim_supp - self.support_axis = support_axis - self.d_type = d_type - - -MeasurableVariable.register(MeasurableDimShuffle) - @_logprob.register(MeasurableDimShuffle) def logprob_dimshuffle(op, 
values, base_var, **kwargs): diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 6515ec79fd..62ef34ee1f 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -115,6 +115,7 @@ _logcdf_helper, _logprob, _logprob_helper, + get_measurable_meta_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import ( @@ -501,18 +502,14 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[li transform_args_fn=lambda *inputs: inputs[-1], ) - ndim_supp = measurable_inputs[0].owner.op.ndim_supp - supp_axes = getattr(measurable_inputs[0].owner.op, "supp_axes", None) - if supp_axes is None: - supp_axes = tuple(range(-ndim_supp, 0)) - + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(measurable_input.owner.op) transform_op = MeasurableTransform( scalar_op=scalar_op, transform=transform, measurable_input_idx=measurable_input_idx, ndim_supp=ndim_supp, - support_axis=supp_axes, - d_type=MeasureType.Continuous, + supp_axes=supp_axes, + measure_type=measure_type, ) transform_out = transform_op.make_node(*transform_inputs).default_output() return [transform_out] From 4b7672d2512f9ef260295525840a3710b8d7d621 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 10 Jun 2023 13:06:45 +0530 Subject: [PATCH 03/19] Logprob tests passed --- pymc/logprob/abstract.py | 2 +- pymc/logprob/mixture.py | 49 +++++++++++++++++++++++++++++++--- pymc/logprob/scan.py | 2 +- tests/logprob/test_abstract.py | 4 +-- 4 files changed, 49 insertions(+), 8 deletions(-) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 5b3e6a0050..108b866b7d 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -174,7 +174,7 @@ def __init__(self, scalar_op, *args, **kwargs): super().__init__(scalar_op, *args, **kwargs) -def get_measurable_meta_info(base_op: Op) -> Tuple[int, Tuple[int], MeasureType]: +def get_measurable_meta_info(base_op: MeasurableVariable) -> Tuple[int, Tuple[int], MeasureType]: if not isinstance(base_op, MeasurableVariable): raise TypeError("base_op must be a RandomVariable or MeasurableVariable") diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index d7abcfcd01..9ca18a45f7 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -224,7 +224,7 @@ class MixtureRV(MeasurableVariable, Op): __props__ = ("indices_end_idx", "out_dtype", "out_broadcastable") def __init__(self, *args, indices_end_idx, out_dtype, out_broadcastable, **kwargs): - super().__init__() + # super().__init__(*args, **kwargs) self.indices_end_idx = indices_end_idx self.out_dtype = out_dtype self.out_broadcastable = out_broadcastable @@ -303,11 +303,16 @@ def find_measurable_index_mixture(fgraph, node): if rv_map_feature.request_measurable(mixture_rvs) != mixture_rvs: return None + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[0].owner.op) + # Replace this sub-graph with a `MixtureRV` mix_op = MixtureRV( - 1 + len(mixing_indices), - old_mixture_rv.dtype, - old_mixture_rv.broadcastable, + ndim_supp=ndim_supp, + supp_axes=supp_axes, + measure_type=measure_type, + indices_end_idx=1 + len(mixing_indices), + out_dtype=old_mixture_rv.dtype, + out_broadcastable=old_mixture_rv.broadcastable, ) new_node = mix_op.make_node(*([join_axis, *mixing_indices, *mixture_rvs])) @@ -324,6 +329,42 @@ def find_measurable_index_mixture(fgraph, node): return [new_mixture_rv] +@node_rewriter([switch]) +def find_measurable_switch_mixture(fgraph, node): + 
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + + if rv_map_feature is None: + return None # pragma: no cover + + old_mixture_rv = node.default_output() + idx, *components = node.inputs + + if rv_map_feature.request_measurable(components) != components: + return None + + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[0].owner.op) + + mix_op = MixtureRV( + indices_end_idx=2, + out_dtype=old_mixture_rv.dtype, + out_broadcastable=old_mixture_rv.broadcastable, + ndim_supp=ndim_supp, + supp_axes=supp_axes, + measure_type=measure_type, + ) + new_mixture_rv = mix_op.make_node( + *([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1]) + ).default_output() + + if pytensor.config.compute_test_value != "off": + if not hasattr(old_mixture_rv.tag, "test_value"): + compute_test_value(node) + + new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value + + return [new_mixture_rv] + + @_logprob.register(MixtureRV) def logprob_MixtureRV( op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index b1cc215717..3a1f060daf 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -480,7 +480,7 @@ def find_measurable_scans(fgraph, node): curr_scanargs.inner_outputs, curr_scanargs.info, ndim_supp=all_ndim_supp, - support_axis=all_supp_axes, + supp_axes=all_supp_axes, measure_type=all_measure_type, mode=node.op.mode, ) diff --git a/tests/logprob/test_abstract.py b/tests/logprob/test_abstract.py index 9b05170a7c..c7a408fbed 100644 --- a/tests/logprob/test_abstract.py +++ b/tests/logprob/test_abstract.py @@ -64,13 +64,13 @@ def test_measurable_elemwise(): d_type = "mixed" with pytest.raises(TypeError, match=re.escape("scalar_op exp is not valid")): - MeasurableElemwise(exp, ndim_supp=ndim_supp, support_axis=support_axis, d_type=d_type) + MeasurableElemwise(exp, ndim_supp=ndim_supp, supp_axes=support_axis, measure_type=d_type) class TestMeasurableElemwise(MeasurableElemwise): valid_scalar_types = (Exp,) measurable_exp_op = TestMeasurableElemwise( - ndim_supp=ndim_supp, support_axis=support_axis, d_type=d_type, scalar_op=exp + ndim_supp=ndim_supp, supp_axes=support_axis, measure_type=d_type, scalar_op=exp ) measurable_exp = measurable_exp_op(0.0) assert isinstance(measurable_exp.owner.op, MeasurableVariable) From cfd6fcf90dd67ba30b597fc1ce04731f6cee8ecf Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 11 Jun 2023 03:40:36 +0530 Subject: [PATCH 04/19] Acc to latest code --- pymc/logprob/abstract.py | 4 ++-- pymc/logprob/binary.py | 17 +++++++++++++++-- pymc/logprob/censoring.py | 17 +++++++++++++++-- pymc/logprob/checks.py | 21 +++++++++++++++++++-- pymc/logprob/cumsum.py | 12 +++++++++++- pymc/logprob/mixture.py | 15 +++++++++++++-- pymc/logprob/scan.py | 12 +++++++++++- pymc/logprob/tensor.py | 36 ++++++++++++++++++++++++++++++------ 8 files changed, 116 insertions(+), 18 deletions(-) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 108b866b7d..cf7202c6a3 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -160,7 +160,7 @@ def __init__( MeasurableVariable.register(RandomVariable) -class MeasurableElemwise(Elemwise): +class MeasurableElemwise(MeasurableVariable, Elemwise): """Base class for Measurable Elemwise variables""" valid_scalar_types: tuple[MetaType, ...] 
= () @@ -174,7 +174,7 @@ def __init__(self, scalar_op, *args, **kwargs): super().__init__(scalar_op, *args, **kwargs) -def get_measurable_meta_info(base_op: MeasurableVariable) -> Tuple[int, Tuple[int], MeasureType]: +def get_measurable_meta_info(base_op: Op) -> Tuple[int, Tuple[int], MeasureType]: if not isinstance(base_op, MeasurableVariable): raise TypeError("base_op must be a RandomVariable or MeasurableVariable") diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index b531381bd0..34765886c3 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -82,7 +82,14 @@ def find_measurable_comparisons( elif isinstance(node_scalar_op, LE): node_scalar_op = GE() - compared_op = MeasurableComparison(node_scalar_op) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(measurable_var.owner.op) + + compared_op = MeasurableComparison( + scalar_op=node_scalar_op, + ndim_supp=ndim_supp, + supp_axes=supp_axes, + measure_type=measure_type, + ) compared_rv = compared_op.make_node(measurable_var, const).default_output() return [compared_rv] @@ -149,7 +156,13 @@ def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[ return None node_scalar_op = node.op.scalar_op - bitwise_op = MeasurableBitwise(node_scalar_op) + ndim_supp, supp_axis, measure_type = get_measurable_meta_info(base_var.owner.op) + bitwise_op = MeasurableBitwise( + scalar_op=node_scalar_op, + ndim_supp=ndim_supp, + supp_axes=supp_axis, + measure_type=measure_type, + ) bitwise_rv = bitwise_op.make_node(base_var).default_output() return [bitwise_rv] diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 4e590dad82..904d78de91 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -48,7 +48,12 @@ from pytensor.tensor.math import ceil, clip, floor, round_half_to_even from pytensor.tensor.variable import TensorConstant -from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob +from pymc.logprob.abstract import ( + MeasurableElemwise, + _logcdf, + _logprob, + get_measurable_meta_info, +) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import CheckParameterValue @@ -78,6 +83,11 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[Te lower_bound = lower_bound if (lower_bound is not base_var) else pt.constant(-np.inf) upper_bound = upper_bound if (upper_bound is not base_var) else pt.constant(np.inf) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var.owner.op) + measurable_clip = MeasurableClip( + scalar_clip, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) + clipped_rv = measurable_clip.make_node(base_var, lower_bound, upper_bound).outputs[0] return [clipped_rv] @@ -164,7 +174,10 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[lis return None [base_var] = node.inputs - rounded_op = MeasurableRound(node.op.scalar_op) + ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_var.owner.op) + rounded_op = MeasurableRound( + node.op.scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axis, measure_type=d_type + ) rounded_rv = rounded_op.make_node(base_var).default_output() rounded_rv.name = node.outputs[0].name return [rounded_rv] diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 3731999233..b8ed2dad0c 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -57,6 +57,9 @@ class MeasurableSpecifyShape(MeasurableVariable, SpecifyShape): """A placeholder used to 
specify a log-likelihood for a specify-shape sub-graph.""" +MeasurableVariable.register(MeasurableSpecifyShape) + + @_logprob.register(MeasurableSpecifyShape) def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): (value,) = values @@ -88,7 +91,11 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable ): return None # pragma: no cover - new_op = MeasurableSpecifyShape() + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv.owner.op) + + new_op = MeasurableSpecifyShape( + ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) new_rv = new_op.make_node(base_rv, *shape).default_output() return [new_rv] @@ -106,6 +113,9 @@ class MeasurableCheckAndRaise(MeasurableVariable, CheckAndRaise): """A placeholder used to specify a log-likelihood for an assert sub-graph.""" +MeasurableVariable.register(MeasurableCheckAndRaise) + + @_logprob.register(MeasurableCheckAndRaise) def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): (value,) = values @@ -132,7 +142,14 @@ def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariabl return None op = node.op - new_op = MeasurableCheckAndRaise(exc_type=op.exc_type, msg=op.msg) + ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_rv.owner.op) + new_op = MeasurableCheckAndRaise( + exc_type=op.exc_type, + msg=op.msg, + ndim_supp=ndim_supp, + supp_axes=supp_axis, + measure_type=d_type, + ) new_rv = new_op.make_node(base_rv, *conds).default_output() return [new_rv] diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index ddad3592de..8a141fb940 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -55,6 +55,9 @@ class MeasurableCumsum(MeasurableVariable, CumOp): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" +MeasurableVariable.register(MeasurableCumsum) + + @_logprob.register(MeasurableCumsum) def logprob_cumsum(op, values, base_rv, **kwargs): """Compute the log-likelihood graph for a `Cumsum`.""" @@ -103,7 +106,14 @@ def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]: if not rv_map_feature.request_measurable(node.inputs): return None - new_op = MeasurableCumsum(axis=node.op.axis or 0, mode="add") + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv.owner.op) + new_op = MeasurableCumsum( + axis=node.op.axis or 0, + mode="add", + ndim_supp=ndim_supp, + supp_axes=supp_axes, + measure_type=measure_type, + ) new_rv = new_op.make_node(base_rv).default_output() return [new_rv] diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 9ca18a45f7..6ced52db49 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -237,6 +237,9 @@ def perform(self, node, inputs, outputs): raise NotImplementedError("This is a stand-in Op.") # pragma: no cover +MeasurableVariable.register(MixtureRV) + + def get_stack_mixture_vars( node: Apply, ) -> tuple[Optional[list[TensorVariable]], Optional[int]]: @@ -342,7 +345,7 @@ def find_measurable_switch_mixture(fgraph, node): if rv_map_feature.request_measurable(components) != components: return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[0].owner.op) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(idx.owner.op) mix_op = MixtureRV( indices_end_idx=2, @@ -554,7 +557,15 @@ def find_measurable_ifelse_mixture(fgraph, node): if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs): return None - return 
MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rvs[0].owner.op) + + return ( + MeasurableIfElse( + n_outs=op.n_outs, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) + .make_node(if_var, *base_rvs) + .outputs + ) measurable_ir_rewrites_db.register( diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 3a1f060daf..c954a82eeb 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -54,7 +54,7 @@ from pytensor.tensor.variable import TensorVariable from pytensor.updates import OrderedUpdates -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableVariable, _logprob, get_measurable_meta_info from pymc.logprob.basic import conditional_logp from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -73,6 +73,9 @@ def __str__(self): return f"Measurable({super().__str__()})" +MeasurableVariable.register(MeasurableScan) + + def convert_outer_out_to_in( input_scan_args: ScanArgs, outer_out_vars: Iterable[TensorVariable], @@ -466,6 +469,12 @@ def find_measurable_scans(fgraph, node): # Replace the mapping rv_map_feature.update_rv_maps(rv_var, new_val_var, full_out) + clients: Dict[Variable, List[Variable]] = {} + local_fgraph_topo = pytensor.graph.basic.io_toposort( + curr_scanargs.inner_inputs, + [o for o in curr_scanargs.inner_outputs if not isinstance(o.type, RandomType)], + clients=clients, + ) all_ndim_supp = [] all_supp_axes = [] all_measure_type = [] @@ -475,6 +484,7 @@ def find_measurable_scans(fgraph, node): all_ndim_supp.append(ndim_supp) all_supp_axes.append(supp_axes) all_measure_type.append(measure_type) + op = MeasurableScan( curr_scanargs.inner_inputs, curr_scanargs.inner_outputs, diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index e78937d3be..43e53c4df3 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -50,7 +50,12 @@ local_rv_size_lift, ) -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import ( + MeasurableVariable, + _logprob, + _logprob_helper, + get_measurable_meta_info, +) from pymc.logprob.rewriting import ( PreserveRVMappings, assume_measured_ir_outputs, @@ -126,6 +131,9 @@ class MeasurableMakeVector(MeasurableVariable, MakeVector): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" +MeasurableVariable.register(MeasurableMakeVector) + + @_logprob.register(MeasurableMakeVector) def logprob_make_vector(op, values, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `MeasurableMakeVector`.""" @@ -150,6 +158,9 @@ class MeasurableJoin(MeasurableVariable, Join): """A placeholder used to specify a log-likelihood for a join sub-graph.""" +MeasurableVariable.register(MeasurableJoin) + + @_logprob.register(MeasurableJoin) def logprob_join(op, values, axis, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `Join`.""" @@ -218,9 +229,13 @@ def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]: ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_vars[0].owner.op) if is_join: - measurable_stack = MeasurableJoin()(axis, *base_vars) + measurable_stack = MeasurableJoin( + ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + )(axis, *base_vars) else: - measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars) + measurable_stack = MeasurableMakeVector( + node.op.dtype, ndim_supp=ndim_supp, supp_axes=supp_axes, 
measure_type=measure_type + )(*base_vars) return [measurable_stack] @@ -233,6 +248,9 @@ class MeasurableDimShuffle(MeasurableVariable, DimShuffle): c_func_file = DimShuffle.get_path(DimShuffle.c_func_file) +MeasurableVariable.register(MeasurableDimShuffle) + + @_logprob.register(MeasurableDimShuffle) def logprob_dimshuffle(op, values, base_var, **kwargs): """Compute the log-likelihood graph for a `MeasurableDimShuffle`.""" @@ -289,9 +307,15 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]: if not isinstance(base_var.owner.op, RandomVariable): return None # pragma: no cover - measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)( - base_var - ) + ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_var.owner.op) + + measurable_dimshuffle = MeasurableDimShuffle( + node.op.input_broadcastable, + node.op.new_order, + ndim_supp=ndim_supp, + supp_axes=supp_axis, + measure_type=d_type, + )(base_var) return [measurable_dimshuffle] From d04bd40b45d75186e91a4f85f6cfa24f621e4f58 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 11 Jun 2023 04:08:48 +0530 Subject: [PATCH 05/19] Pytest formatting --- tests/logprob/test_checks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index cd0b3023bb..db60c573e1 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -35,8 +35,6 @@ # SOFTWARE. import re -import re - import numpy as np import pytensor import pytensor.tensor as pt From abb342d9d7f9f1855c6f7ec6461eb3615d753b70 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 2 Jul 2023 19:00:43 +0530 Subject: [PATCH 06/19] Tests for Meta_info added --- pymc/logprob/abstract.py | 6 +- scripts/run_mypy.py | 1 + tests/logprob/test_binary.py | 36 ++++++++++ tests/logprob/test_censoring.py | 44 ++++++++++++ tests/logprob/test_checks.py | 49 +++++++++++++ tests/logprob/test_cumsum.py | 32 +++++++++ tests/logprob/test_mixture.py | 59 +++++++++++++++- tests/logprob/test_rewriting.py | 1 + tests/logprob/test_scan.py | 29 +++++++- tests/logprob/test_tensor.py | 118 +++++++++++++++++++++++++++++++ tests/logprob/test_transforms.py | 20 ++++++ tests/logprob/utils.py | 12 ++++ 12 files changed, 403 insertions(+), 4 deletions(-) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index cf7202c6a3..4c4c4f42a3 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -174,7 +174,11 @@ def __init__(self, scalar_op, *args, **kwargs): super().__init__(scalar_op, *args, **kwargs) -def get_measurable_meta_info(base_op: Op) -> Tuple[int, Tuple[int], MeasureType]: +def get_measurable_meta_info( + base_op: Op, +) -> Tuple[ + Union[int, Tuple[int]], Tuple[Union[int, Tuple[int]]], Union[MeasureType, Tuple[MeasureType]] +]: if not isinstance(base_op, MeasurableVariable): raise TypeError("base_op must be a RandomVariable or MeasurableVariable") diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index ada8f71b2e..bd663c8d73 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -32,6 +32,7 @@ pymc/initial_point.py pymc/logprob/binary.py pymc/logprob/censoring.py +pymc/logprob/checks.py pymc/logprob/basic.py pymc/logprob/mixture.py pymc/logprob/order.py diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index e56a248741..ed6c9e6fe2 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -21,7 +21,9 @@ from pymc import logp from pymc.logprob import conditional_logp +from 
pymc.logprob.abstract import get_measurable_meta_info from pymc.testing import assert_no_rvs +from tests.logprob.utils import meta_info_helper @pytest.mark.parametrize( @@ -60,6 +62,40 @@ def test_continuous_rv_comparison_bitwise(comparison_op, exp_logp_true, exp_logp assert np.isclose(logp_fn_not(1), getattr(ref_scipy, exp_logp_false)(0.5)) +@pytest.mark.parametrize( + "comparison_op, exp_logp_true, exp_logp_false, inputs", + [ + ((pt.lt, pt.le), "logcdf", "logsf", (pt.random.normal(0, 1), 0.5)), + ((pt.gt, pt.ge), "logsf", "logcdf", (pt.random.normal(0, 1), 0.5)), + ((pt.lt, pt.le), "logsf", "logcdf", (0.5, pt.random.normal(0, 1))), + ((pt.gt, pt.ge), "logcdf", "logsf", (0.5, pt.random.normal(0, 1))), + ], +) +def test_meta_info(comparison_op, exp_logp_true, exp_logp_false, inputs): + for op in comparison_op: + comp_x_rv = op(*inputs) + + if inputs[0] == 0.5: + base_rv = inputs[1] + else: + base_rv = inputs[0] + + comp_x_vv = comp_x_rv.clone() + ndim_supp, supp_axes, measure_type = meta_info_helper(comp_x_rv, comp_x_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info( + base_rv.owner.op + ) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + @pytest.mark.parametrize( "comparison_op, exp_logp_true, exp_logp_false, inputs", [ diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index de407fd579..725c3edf26 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -46,6 +46,7 @@ from pymc.logprob.transform_value import TransformValuesRewrite from pymc.logprob.transforms import LogTransform from pymc.testing import assert_no_rvs +from tests.logprob.utils import meta_info_helper @pytensor.config.change_flags(compute_test_value="raise") @@ -70,6 +71,25 @@ def test_continuous_rv_clip(): assert np.isclose(logp_fn(0), ref_scipy.logpdf(0)) +def test_clip_meta_info(): + base_rv = pt.random.normal(0.5, 1) + rv = pt.clip(base_rv, -2, 2) + + vv = rv.clone() + + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv.owner.op) + + ndim_supp, supp_axes, measure_type = meta_info_helper(rv, vv) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + def test_discrete_rv_clip(): x_rv = pt.random.poisson(2) cens_x_rv = pt.clip(x_rv, 1, 4) @@ -262,3 +282,27 @@ def test_rounding(rounding_op): logprob.eval({xr_vv: test_value}), expected_logp, ) + + +@pytest.mark.parametrize("rounding_op", (pt.round, pt.floor, pt.ceil)) +def test_round_meta_info(rounding_op): + loc = 1 + scale = 2 + test_value = np.arange(-3, 4) + + x = pt.random.normal(loc, scale, size=test_value.shape, name="x") + xr = rounding_op(x) + xr.name = "xr" + + xr_vv = xr.clone() + ndim_supp, supp_axes, measure_type = meta_info_helper(xr, xr_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x.owner.op) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index db60c573e1..d5b9aa19cf 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -35,17 +35,23 @@ # SOFTWARE. 
import re +from collections import deque + import numpy as np import pytensor import pytensor.tensor as pt import pytest +from pytensor.graph.basic import io_toposort from pytensor.raise_op import Assert from scipy import stats from pymc.distributions import Dirichlet +from pymc.logprob.abstract import get_measurable_meta_info from pymc.logprob.basic import conditional_logp +from pymc.logprob.rewriting import construct_ir_fgraph from tests.distributions.test_multivariate import dirichlet_logpdf +from tests.logprob.utils import meta_info_helper def test_specify_shape_logprob(): @@ -77,6 +83,28 @@ def test_specify_shape_logprob(): x_logp_fn(last_dim=1, x=x_vv_test_invalid) +def test_shape_meta_info(): + last_dim = pt.scalar(name="last_dim", dtype="int64") + x_base = Dirichlet.dist(pt.ones((last_dim,)), shape=(5, last_dim)) + x_base.name = "x" + x_rv = pt.specify_shape(x_base, shape=(5, 3)) + x_rv.name = "x" + + # 2. Request logp + x_vv = x_rv.clone() + ndim_supp, supp_axes, measure_type = meta_info_helper(x_rv, x_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x_base.owner.op) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + def test_assert_logprob(): rv = pt.random.normal() assert_op = Assert("Test assert") @@ -99,3 +127,24 @@ def test_assert_logprob(): # Since here the value to the rv is negative, an exception is raised as the condition is not met with pytest.raises(AssertionError, match="Test assert"): assert_logp.eval({assert_vv: -5.0}) + + +def test_assert_meta_info(): + rv = pt.random.normal() + assert_op = Assert("Test assert") + # Example: Add assert that rv must be positive + assert_rv = assert_op(rv, rv > 0) + assert_rv.name = "assert_rv" + + assert_vv = assert_rv.clone() + ndim_supp, supp_axes, measure_type = meta_info_helper(assert_rv, assert_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(rv.owner.op) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type diff --git a/tests/logprob/test_cumsum.py b/tests/logprob/test_cumsum.py index 552cea92d0..3646043a8d 100644 --- a/tests/logprob/test_cumsum.py +++ b/tests/logprob/test_cumsum.py @@ -41,8 +41,10 @@ import scipy.stats as st from pymc import logp +from pymc.logprob.abstract import get_measurable_meta_info from pymc.logprob.basic import conditional_logp from pymc.testing import assert_no_rvs +from tests.logprob.utils import meta_info_helper @pytest.mark.parametrize( @@ -69,6 +71,36 @@ def test_normal_cumsum(size, axis): ) +@pytest.mark.parametrize( + "size, axis", + [ + (10, None), + (10, 0), + ((2, 10), 0), + ((2, 10), 1), + ((3, 2, 10), 0), + ((3, 2, 10), 1), + ((3, 2, 10), 2), + ], +) +def test_meta_info(size, axis): + rv = pt.random.normal(0, 1, size=size).cumsum(axis) + vv = rv.clone() + base_rv = pt.random.normal(0, 1, size=size) + base_vv = base_rv.clone() + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv.owner.op) + + ndim_supp, supp_axes, measure_type = meta_info_helper(rv, vv) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + @pytest.mark.parametrize( "size, axis", [ diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index b3e5c5656e..fdb88ee523 100644 --- a/tests/logprob/test_mixture.py +++ 
b/tests/logprob/test_mixture.py @@ -58,7 +58,7 @@ from pymc.logprob.rewriting import construct_ir_fgraph from pymc.logprob.utils import dirac_delta from pymc.testing import assert_no_rvs -from tests.logprob.utils import scipy_logprob +from tests.logprob.utils import meta_info_helper, scipy_logprob def test_mixture_basics(): @@ -900,7 +900,38 @@ def test_mixture_with_DiracDelta(): assert m_vv in logp_res -def test_scalar_switch_mixture(): +def test_meta_with_DiracDelta(): + srng = pt.random.RandomStream(29833) + + X_rv = srng.normal(0, 1, name="X") + Y_rv = dirac_delta(0.0) + Y_rv.name = "Y" + + I_rv = srng.categorical([0.5, 0.5], size=1) + + i_vv = I_rv.clone() + i_vv.name = "i" + + M_rv = pt.stack([X_rv, Y_rv])[I_rv] + M_rv.name = "M" + + m_vv = M_rv.clone() + m_vv.name = "m" + + ndim_supp, supp_axes, measure_type = meta_info_helper(M_rv, m_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(X_rv.owner.op) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base == measure_type + + +def test_switch_mixture(): srng = pt.random.RandomStream(29833) X_rv = srng.normal(-10.0, 0.1, name="X") @@ -1031,6 +1062,30 @@ def test_ifelse_mixture_one_component(): ) +def test_meta_ifelse(): + if_rv = pt.random.bernoulli(0.5, name="if") + scale_rv = pt.random.halfnormal(name="scale") + comp_then = pt.random.normal(0, scale_rv, size=(2,), name="comp_then") + comp_else = pt.random.halfnormal(0, scale_rv, size=(4,), name="comp_else") + mix_rv = ifelse(if_rv, comp_then, comp_else, name="mix") + + if_vv = if_rv.clone() + scale_vv = scale_rv.clone() + mix_vv = mix_rv.clone() + + ndim_supp, supp_axes, measure_type = meta_info_helper(mix_rv, mix_vv) + + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(if_rv.owner.op) + + assert np.isclose( + ndim_supp_base, + ndim_supp, + ) + assert supp_axes_base == supp_axes + + assert measure_type_base != measure_type + + def test_ifelse_mixture_multiple_components(): rng = np.random.default_rng(968) diff --git a/tests/logprob/test_rewriting.py b/tests/logprob/test_rewriting.py index 66c28b102d..a7a208866c 100644 --- a/tests/logprob/test_rewriting.py +++ b/tests/logprob/test_rewriting.py @@ -55,6 +55,7 @@ from pymc.logprob.rewriting import cleanup_ir, local_lift_DiracDelta from pymc.logprob.transform_value import TransformedValue, TransformValuesRewrite from pymc.logprob.utils import DiracDelta, dirac_delta +from tests.logprob.utils import scipy_logprob def test_local_lift_DiracDelta(): diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 30a76680e7..b07e3b5813 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -44,7 +44,7 @@ from pytensor.scan.utils import ScanArgs from scipy import stats -from pymc.logprob.abstract import _logprob_helper +from pymc.logprob.abstract import _logprob_helper, get_measurable_meta_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.scan import ( construct_scan, @@ -52,6 +52,7 @@ get_random_outer_outputs, ) from pymc.testing import assert_no_rvs +from tests.logprob.utils import meta_info_helper def create_inner_out_logp(value_map): @@ -504,6 +505,32 @@ def test_scan_over_seqs(): ) +def test_meta_scan_over_seqs(): + """Test that logprob inference for scans based on sequences (mapping).""" + rng = np.random.default_rng(543) + n_steps = 10 + + xs = pt.random.normal(size=(n_steps,), name="xs") + ys, _ = pytensor.scan( + fn=lambda x: pt.random.normal(x), 
sequences=[xs], outputs_info=[None], name="ys"
+    )
+
+    xs_vv = ys.clone()
+    ys_vv = ys.clone()
+
+    ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(xs.owner.op)
+
+    ndim_supp, supp_axes, measure_type = meta_info_helper(ys, ys_vv)
+
+    assert np.isclose(
+        ndim_supp_base,
+        ndim_supp,
+    )
+    assert supp_axes_base == supp_axes[0]
+
+    assert measure_type_base == measure_type[0]
+
+
 def test_scan_carried_deterministic_state():
     """Test logp of scans with carried states downstream of measured variables.
 
diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py
index e61e0d1700..f87d765db2 100644
--- a/tests/logprob/test_tensor.py
+++ b/tests/logprob/test_tensor.py
@@ -45,10 +45,12 @@
 from pytensor.tensor.basic import Alloc
 from scipy import stats as st
 
+from pymc.logprob.abstract import get_measurable_meta_info
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.rewriting import logprob_rewrites_db
 from pymc.logprob.tensor import naive_bcast_rv_lift
 from pymc.testing import assert_no_rvs
+from tests.logprob.utils import meta_info_helper, scipy_logprob
 
 
 def test_naive_bcast_rv_lift():
@@ -134,6 +136,41 @@ def test_measurable_make_vector():
     assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval)
 
 
+def test_meta_make_vector():
+    base1_rv = pt.random.normal(name="base1")
+    base2_rv = pt.random.halfnormal(name="base2")
+    base3_rv = pt.random.exponential(name="base3")
+    y_rv = pt.stack((base1_rv, base2_rv, base3_rv))
+    y_rv.name = "y"
+
+    ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measurable_meta_info(
+        base1_rv.owner.op
+    )
+    ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measurable_meta_info(
+        base2_rv.owner.op
+    )
+    ndim_supp_base_3, supp_axes_base_3, measure_type_base_3 = get_measurable_meta_info(
+        base3_rv.owner.op
+    )
+
+    base1_vv = base1_rv.clone()
+    base2_vv = base2_rv.clone()
+    base3_vv = base3_rv.clone()
+    y_vv = y_rv.clone()
+
+    ndim_supp, supp_axes, measure_type = meta_info_helper(y_rv, y_vv)
+
+    assert ndim_supp_base_1 == ndim_supp_base_2 == ndim_supp_base_3 == ndim_supp
+    assert supp_axes_base_1 == supp_axes_base_2 == supp_axes_base_3 == supp_axes
+
+    assert measure_type_base_1 == measure_type_base_2 == measure_type_base_3 == measure_type
+
+
 @pytest.mark.parametrize("reverse", (False, True))
 def test_measurable_make_vector_interdependent(reverse):
     """Test that we can obtain a proper graph when stacked RVs depend on each other"""
@@ -268,6 +305,47 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate):
     )
 
 
+@pytest.mark.parametrize(
+    "size1, size2, axis, concatenate",
+    [
+        ((5,), (3,), 0, True),
+        ((5,), (3,), -1, True),
+        ((5, 2), (3, 2), 0, True),
+        ((2, 5), (2, 3), 1, True),
+        ((2, 5), (2, 5), 0, False),
+        ((2, 5), (2, 5), 1, False),
+        ((2, 5), (2, 5), 2, False),
+    ],
+)
+def test_meta_join_univariate(size1, size2, axis, concatenate):
+    base1_rv = pt.random.normal(size=size1, name="base1")
+    base2_rv = pt.random.exponential(size=size2, name="base2")
+    if concatenate:
+        y_rv = pt.concatenate((base1_rv, base2_rv), axis=axis)
+    else:
+        y_rv = pt.stack((base1_rv, base2_rv), axis=axis)
+    y_rv.name = "y"
+
+    ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measurable_meta_info(
+        base1_rv.owner.op
+    )
+    ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measurable_meta_info(
+        base2_rv.owner.op
+    )
+
+    y_vv = y_rv.clone()
+    ndim_supp, supp_axes, measure_type = meta_info_helper(y_rv, y_vv)
+
+    assert ndim_supp_base_1 == ndim_supp_base_2 == ndim_supp
+    assert supp_axes_base_1 == supp_axes_base_2 == supp_axes
+
+    assert measure_type_base_1 == measure_type_base_2 == measure_type
+
+
 @pytest.mark.parametrize(
     "size1, supp_size1, size2, supp_size2, axis, concatenate",
     [
@@ -401,6 +479,46 @@ def test_measurable_dimshuffle(ds_order, multivariate):
     np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value))
 
 
+@pytensor.config.change_flags(cxx="")
+@pytest.mark.parametrize(
+    "ds_order",
+    [
+        (0, 2, 1),  # Swap
+        (2, 1, 0),  # Swap
+        (1, 2, 0),  # Swap
+        (0, 1, 2, "x"),  # Expand
+        (
+            0,
+            2,
+        ),  # Drop
+        (2, 0),  # Swap and drop
+        (2, 1, "x", 0),  # Swap and expand
+        ("x", 0, 2),  # Expand and drop
+        (2, "x", 0),  # Swap, expand and drop
+    ],
+)
+@pytest.mark.parametrize("multivariate", [True])
+def test_meta_measurable_dimshuffle(ds_order, multivariate):
+    if multivariate:
+        base_rv = pt.random.dirichlet([1, 2, 3], size=(2, 1))
+
+    ds_rv = base_rv.dimshuffle(ds_order)
+    base_vv = base_rv.clone()
+    ds_vv = ds_rv.clone()
+
+    ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv.owner.op)
+
+    ndim_supp, supp_axes, measure_type = meta_info_helper(ds_rv, ds_vv)
+
+    assert np.isclose(
+        ndim_supp_base,
+        ndim_supp,
+    )
+    assert supp_axes_base == supp_axes
+
+    assert measure_type_base == measure_type
+
+
 def test_unmeargeable_dimshuffles():
     # Test that graphs with DimShuffles that cannot be lifted/merged fail
 
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index acf7296f47..7ecd0ef851 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -229,6 +229,26 @@ def test_exp_transform_rv():
     )
 
 
+def test_meta_exp_transform_rv():
+    base_rv = pt.random.normal(0, 1, size=3, name="base_rv")
+    y_rv = pt.exp(base_rv)
+    y_rv.name = "y"
+
+    y_vv = y_rv.clone()
+
+    ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv.owner.op)
+
+    ndim_supp, supp_axes, measure_type = meta_info_helper(y_rv, y_vv)
+
+    assert np.isclose(
+        ndim_supp_base,
+        ndim_supp,
+    )
+    assert supp_axes_base == supp_axes
+
+    assert measure_type_base == measure_type
+
+
 def test_log_transform_rv():
     base_rv = pt.random.lognormal(0, 1, size=2, name="base_rv")
     y_rv = pt.log(base_rv)
diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py
index e5aa36b830..4978b5482e 100644
--- a/tests/logprob/utils.py
+++ b/tests/logprob/utils.py
@@ -41,6 +41,7 @@
 from scipy import stats as stats
 
 from pymc.logprob import icdf, logcdf, logp
+from pymc.logprob.rewriting import construct_ir_fgraph
 
 
 def scipy_logprob(obs, p):
@@ -112,3 +113,14 @@ def scipy_logprob_tester(
 
     np.testing.assert_array_equal(pytensor_res_val.shape, numpy_res.shape)
     np.testing.assert_array_almost_equal(pytensor_res_val, numpy_res, 4)
+
+
+def meta_info_helper(rv, vv):
+    fgraph, _, _ = construct_ir_fgraph({rv: vv})
+    node = fgraph.outputs[0].owner
+
+    ndim_supp = node.op.ndim_supp
+    supp_axes = node.op.supp_axes
+    measure_type = node.op.measure_type
+
+    return ndim_supp, supp_axes, measure_type

From 8ee30f654a31a62e6e5e0af00e4951b4c5752c6b Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi
Date: Sun, 2 Jul 2023 19:39:29 +0530
Subject: [PATCH 07/19] Tests for meta_info added and rebased

---
 pymc/logprob/mixture.py       | 45 +++++------------------------------
 tests/logprob/test_mixture.py |  2 +-
 2 files changed, 7 insertions(+), 40 deletions(-)

diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py
index 6ced52db49..7068c08599 100644
--- 
a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -332,42 +332,6 @@ def find_measurable_index_mixture(fgraph, node): return [new_mixture_rv] -@node_rewriter([switch]) -def find_measurable_switch_mixture(fgraph, node): - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - - old_mixture_rv = node.default_output() - idx, *components = node.inputs - - if rv_map_feature.request_measurable(components) != components: - return None - - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(idx.owner.op) - - mix_op = MixtureRV( - indices_end_idx=2, - out_dtype=old_mixture_rv.dtype, - out_broadcastable=old_mixture_rv.broadcastable, - ndim_supp=ndim_supp, - supp_axes=supp_axes, - measure_type=measure_type, - ) - new_mixture_rv = mix_op.make_node( - *([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1]) - ).default_output() - - if pytensor.config.compute_test_value != "off": - if not hasattr(old_mixture_rv.tag, "test_value"): - compute_test_value(node) - - new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value - - return [new_mixture_rv] - - @_logprob.register(MixtureRV) def logprob_MixtureRV( op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs @@ -446,9 +410,6 @@ class MeasurableSwitchMixture(MeasurableElemwise): valid_scalar_types = (Switch,) -measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch) - - @node_rewriter([switch]) def find_measurable_switch_mixture(fgraph, node): rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) @@ -473,6 +434,12 @@ def find_measurable_switch_mixture(fgraph, node): if rv_map_feature.request_measurable(components) != components: return None + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(components[0].owner.op) + + measurable_switch_mixture = MeasurableSwitchMixture( + scalar_switch, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) + return [measurable_switch_mixture(switch_cond, *components)] diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index fdb88ee523..81fc78f050 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,7 +52,7 @@ as_index_constant, ) -from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.abstract import MeasurableVariable, get_measurable_meta_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph From 32a9b2b122c7e84cc351acabf406f413bf14ae16 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 9 Jul 2023 22:32:17 +0530 Subject: [PATCH 08/19] practice changes --- pymc/logprob/tensor.py | 8 ++--- tests/logprob/test_mixture.py | 4 +-- tests/logprob/test_tensor.py | 61 +++++++++++++++++++++++++++++++++-- tests/logprob/utils.py | 13 ++++++++ 4 files changed, 77 insertions(+), 9 deletions(-) diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 43e53c4df3..6f8efcd25e 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -304,17 +304,17 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]: # lifted towards the base RandomVariable. # TODO: If we include the support axis as meta information in each # intermediate MeasurableVariable, we can lift this restriction. 
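     # Illustrative sketch of that meta information (assumed values): a
     # Dirichlet base variable has ndim_supp == 1, so it would carry
     # supp_axes == (-1,) and MeasureType.Continuous.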
- if not isinstance(base_var.owner.op, RandomVariable): - return None # pragma: no cover + # if not isinstance(base_var.owner.op, RandomVariable): + # return None # pragma: no cover - ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_var.owner.op) + ndim_supp, supp_axis, measure_type = get_measurable_meta_info(base_var.owner.op) measurable_dimshuffle = MeasurableDimShuffle( node.op.input_broadcastable, node.op.new_order, ndim_supp=ndim_supp, supp_axes=supp_axis, - measure_type=d_type, + measure_type=measure_type, )(base_var) return [measurable_dimshuffle] diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 81fc78f050..c0c5c9d2ca 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -1075,7 +1075,7 @@ def test_meta_ifelse(): ndim_supp, supp_axes, measure_type = meta_info_helper(mix_rv, mix_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(if_rv.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(comp_then.owner.op) assert np.isclose( ndim_supp_base, @@ -1083,7 +1083,7 @@ def test_meta_ifelse(): ) assert supp_axes_base == supp_axes - assert measure_type_base != measure_type + assert measure_type_base == measure_type def test_ifelse_mixture_multiple_components(): diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index f87d765db2..3e9a013092 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -50,7 +50,11 @@ from pymc.logprob.rewriting import logprob_rewrites_db from pymc.logprob.tensor import naive_bcast_rv_lift from pymc.testing import assert_no_rvs -from tests.logprob.utils import meta_info_helper, scipy_logprob +from tests.logprob.utils import ( + get_measurable_meta_infos, + meta_info_helper, + scipy_logprob, +) def test_naive_bcast_rv_lift(): @@ -497,12 +501,16 @@ def test_measurable_dimshuffle(ds_order, multivariate): (2, "x", 0), # Swap, expand and drop ], ) -@pytest.mark.parametrize("multivariate", [True]) +@pytest.mark.parametrize("multivariate", (False, True)) def test_meta_measurable_dimshuffle(ds_order, multivariate): if multivariate: base_rv = pt.random.dirichlet([1, 2, 3], size=(2, 1)) + ds_rv = base_rv.dimshuffle(ds_order) + else: + base_rv = pt.random.beta(1, 2, size=(2, 1, 3)) + base_rv_1 = pt.exp(base_rv) + ds_rv = base_rv_1.dimshuffle(ds_order) - ds_rv = base_rv.dimshuffle(ds_order) base_vv = base_rv.clone() ds_vv = ds_rv.clone() @@ -519,6 +527,53 @@ def test_meta_measurable_dimshuffle(ds_order, multivariate): assert measure_type_base == measure_type +def test_meta_unmeargeable_dimshuffles(): + # Test that graphs with DimShuffles that cannot be lifted/merged fail + + # Initial support axis is at axis=-1 + x = pt.random.dirichlet( + np.ones((3,)), + size=(4, 2), + ) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x.owner.op) + # pytensor.dprint(x.owner.inputs[0]) + # print(ndim_supp_base) + # print(supp_axes_base) + # print(measure_type_base) + # print(x.shape) + + # Support axis is now at axis=-2 + y = x.dimshuffle((0, 2, 1)) + y_vv = y.clone() + ndim_supp, supp_axes, measure_type = meta_info_helper(y, y_vv) + pytensor.dprint(y.owner.outputs[0]) + # print(ndim_supp) + # print(supp_axes) + # print(measure_type) + # print(y.shape) + # Downstream dimshuffle will not be lifted through cumsum. 
If it ever is, + # we will need a different measurable Op example + z = pt.cumsum(y, axis=-2) + z_vv = z.clone() + ndim_supp_base, supp_axes_base, measure_type_base = meta_info_helper(z, z_vv) + # print(ndim_supp_base) + # print(supp_axes_base) + # print(measure_type_base) + # print(z.shape) + # Support axis is now at axis=-3 + w = z.dimshuffle((1, 0, 2)) + w_vv = w.clone() + ndim_supp_base, supp_axes_base, measure_type_base = meta_info_helper(w, w_vv) + # print(ndim_supp_base) + # print(supp_axes_base) + # print(measure_type_base) + # print(w.shape) + # TODO: Check that logp is correct if this type of graphs is ever supported + with pytest.raises(RuntimeError, match="could not be derived"): + conditional_logp({w: w_vv}) + assert 0 + + def test_unmeargeable_dimshuffles(): # Test that graphs with DimShuffles that cannot be lifted/merged fail diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index 4978b5482e..23e02d934e 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -36,6 +36,7 @@ import numpy as np +import pytensor from pytensor import tensor as pt from scipy import stats as stats @@ -124,3 +125,15 @@ def meta_info_helper(rv, vv): measure_type = node.op.measure_type return ndim_supp, supp_axes, measure_type + + +def get_measurable_meta_infos( + base_op, +): + # if not isinstance(base_op, MeasurableVariable): + # raise TypeError("base_op must be a RandomVariable or MeasurableVariable") + + # if isinstance(base_op, RandomVariable): + ndim_supp = base_op.ndim_supp + supp_axes = tuple(range(-ndim_supp, 0)) + return base_op.ndim_supp, supp_axes From 68ec9fc503b933e85c7b626f3dfdb64d1464d6a7 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 17 Oct 2023 15:09:23 +0530 Subject: [PATCH 09/19] draft changes till mixture --- pymc/distributions/multivariate.py | 26 +++++++ pymc/logprob/abstract.py | 18 ++++- pymc/logprob/binary.py | 9 ++- pymc/logprob/censoring.py | 12 ++- pymc/logprob/checks.py | 4 +- pymc/logprob/cumsum.py | 2 +- pymc/logprob/mixture.py | 41 ++++++++-- pymc/logprob/tensor.py | 51 +++++++++++-- pymc/logprob/transforms.py | 2 +- tests/logprob/test_binary.py | 8 +- tests/logprob/test_censoring.py | 21 ++++-- tests/logprob/test_checks.py | 4 +- tests/logprob/test_cumsum.py | 2 +- tests/logprob/test_mixture.py | 85 +++++++++++++-------- tests/logprob/test_scan.py | 2 +- tests/logprob/test_tensor.py | 115 +++++++++-------------------- tests/logprob/test_transforms.py | 2 +- tests/logprob/utils.py | 4 +- 18 files changed, 251 insertions(+), 157 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 956bca276d..dd38f1eb98 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -153,6 +153,14 @@ def quaddist_chol(value, mu, cov): delta = value - mu chol_cov = nan_lower_cholesky(cov) + if mat_type != "tau": + dist, logdet, ok = quaddist_chol(delta, chol_cov) + else: + dist, logdet, ok = quaddist_tau(delta, chol_cov) + if onedim: + return dist[0], logdet, ok + + return dist, logdet, ok diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1) # Check if the covariance matrix is positive definite. 
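     # (illustrative note: nan_lower_cholesky is assumed to leave NaNs on the
     # factor when the decomposition fails, so a NaN diagonal marks a
     # non-positive-definite covariance)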
@@ -284,6 +292,22 @@ class MvStudentTRV(RandomVariable): dtype = "floatX" _print_name = ("MvStudentT", "\\operatorname{MvStudentT}") + def make_node(self, rng, size, dtype, nu, mu, cov): + nu = pt.as_tensor_variable(nu) + if not nu.ndim == 0: + raise ValueError("nu must be a scalar (ndim=0).") + + return super().make_node(rng, size, dtype, nu, mu, cov) + + def __call__(self, nu, mu=None, cov=None, size=None, **kwargs): + dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype + + if mu is None: + mu = np.array([0.0], dtype=dtype) + if cov is None: + cov = np.array([[1.0]], dtype=dtype) + return super().__call__(nu, mu, cov, size=size, **kwargs) + def _supp_shape_from_params(self, dist_params, param_shapes=None): return supp_shape_from_ref_param_shape( ndim_supp=self.ndim_supp, @@ -2463,6 +2487,8 @@ def logp(value, W, node1, node2, N, sigma, zero_sum_stdev): return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0") + return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0") + class StickBreakingWeightsRV(RandomVariable): name = "stick_breaking_weights" diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 4c4c4f42a3..631deb493e 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -175,13 +175,23 @@ def __init__(self, scalar_op, *args, **kwargs): def get_measurable_meta_info( - base_op: Op, + base_var, ) -> Tuple[ Union[int, Tuple[int]], Tuple[Union[int, Tuple[int]]], Union[MeasureType, Tuple[MeasureType]] ]: + # instead of taking base_op, take base_var as input + # Get base_op from base_var.owner.op + # index= base_var.owner.outputs.index(base_var) gives the output + if not isinstance(base_var, MeasurableVariable): + base_op = base_var.owner.op + index = base_var.owner.outputs.index(base_var) + else: + base_op = base_var if not isinstance(base_op, MeasurableVariable): raise TypeError("base_op must be a RandomVariable or MeasurableVariable") + # Add a test for pm.mixture, exponentiate it. Ask for logprob of this as this is not a rv and also does not have ndim_supp and properties. Such a test might exist in distributions. Do check. + # TODO: Handle Symbolic random variables if isinstance(base_op, RandomVariable): ndim_supp = base_op.ndim_supp supp_axes = tuple(range(-ndim_supp, 0)) @@ -190,4 +200,10 @@ def get_measurable_meta_info( ) return base_op.ndim_supp, supp_axes, measure_type else: + if isinstance(base_op.ndim_supp, tuple): + if len(base_var.owner.outputs) != len(base_op.ndim_supp): + raise NotImplementedError("length of outputs and meta-propertues is different") + return base_op.ndim_supp[index], base_op.supp_axes, base_op.measure_type + # check if base_var.owner.outputs length is same as length of each prop( length of the tuple). If not , raise an error. 
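+    # sketch of the intended behaviour (assumed, not final): a measurable Scan
+    # with two scalar inner outputs would carry ndim_supp == (0, 0), and the
+    # `index` computed above would select the entry belonging to `base_var`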
+ # We'll need this for scan or IfElse return base_op.ndim_supp, base_op.supp_axes, base_op.measure_type diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 34765886c3..370a6218e1 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -25,6 +25,7 @@ from pymc.logprob.abstract import ( MeasurableElemwise, + MeasureType, _logcdf_helper, _logprob, _logprob_helper, @@ -82,13 +83,13 @@ def find_measurable_comparisons( elif isinstance(node_scalar_op, LE): node_scalar_op = GE() - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(measurable_var.owner.op) + ndim_supp, supp_axes, _ = get_measurable_meta_info(measurable_var) compared_op = MeasurableComparison( scalar_op=node_scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axes, - measure_type=measure_type, + measure_type=MeasureType.Discrete, ) compared_rv = compared_op.make_node(measurable_var, const).default_output() return [compared_rv] @@ -156,12 +157,12 @@ def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[ return None node_scalar_op = node.op.scalar_op - ndim_supp, supp_axis, measure_type = get_measurable_meta_info(base_var.owner.op) + ndim_supp, supp_axis, measure_type = get_measurable_meta_info(base_var) bitwise_op = MeasurableBitwise( scalar_op=node_scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axis, - measure_type=measure_type, + measure_type=MeasureType.Discrete, ) bitwise_rv = bitwise_op.make_node(base_var).default_output() return [bitwise_rv] diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 904d78de91..5c5da5b46e 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -50,6 +50,7 @@ from pymc.logprob.abstract import ( MeasurableElemwise, + MeasureType, _logcdf, _logprob, get_measurable_meta_info, @@ -83,7 +84,11 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[Te lower_bound = lower_bound if (lower_bound is not base_var) else pt.constant(-np.inf) upper_bound = upper_bound if (upper_bound is not base_var) else pt.constant(np.inf) - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var.owner.op) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var) + + if measure_type == MeasureType.Continuous: + measure_type = MeasureType.Mixed + measurable_clip = MeasurableClip( scalar_clip, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type ) @@ -174,9 +179,10 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[lis return None [base_var] = node.inputs - ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_var.owner.op) + ndim_supp, supp_axis, _ = get_measurable_meta_info(base_var) + measure_type = MeasureType.Discrete rounded_op = MeasurableRound( - node.op.scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axis, measure_type=d_type + node.op.scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axis, measure_type=measure_type ) rounded_rv = rounded_op.make_node(base_var).default_output() rounded_rv.name = node.outputs[0].name diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index b8ed2dad0c..6a67a1b542 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -91,7 +91,7 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable ): return None # pragma: no cover - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv.owner.op) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv) new_op = MeasurableSpecifyShape( ndim_supp=ndim_supp, supp_axes=supp_axes, 
measure_type=measure_type @@ -142,7 +142,7 @@ def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariabl return None op = node.op - ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_rv.owner.op) + ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_rv) new_op = MeasurableCheckAndRaise( exc_type=op.exc_type, msg=op.msg, diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index 8a141fb940..bfed13e5fa 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -106,7 +106,7 @@ def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]: if not rv_map_feature.request_measurable(node.inputs): return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv.owner.op) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv) new_op = MeasurableCumsum( axis=node.op.axis or 0, mode="add", diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 7068c08599..f93f12f424 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -68,6 +68,7 @@ from pymc.logprob.abstract import ( MeasurableElemwise, MeasurableVariable, + MeasureType, _logprob, _logprob_helper, get_measurable_meta_info, @@ -306,13 +307,25 @@ def find_measurable_index_mixture(fgraph, node): if rv_map_feature.request_measurable(mixture_rvs) != mixture_rvs: return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[0].owner.op) + all_ndim_supp = [] + all_supp_axes = [] + all_measure_type = [] + for i in range(0, len(mixture_rvs)): + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[0]) + all_ndim_supp.append(ndim_supp) + all_supp_axes.append(supp_axes) + all_measure_type.append(measure_type) + + if all_measure_type[1:] == all_measure_type[:-1]: + m_type = all_measure_type[0] + else: + m_type = MeasureType.Mixed # Replace this sub-graph with a `MixtureRV` mix_op = MixtureRV( - ndim_supp=ndim_supp, - supp_axes=supp_axes, - measure_type=measure_type, + ndim_supp=all_ndim_supp[0], + supp_axes=all_supp_axes[0], + measure_type=all_measure_type, indices_end_idx=1 + len(mixing_indices), out_dtype=old_mixture_rv.dtype, out_broadcastable=old_mixture_rv.broadcastable, @@ -434,10 +447,24 @@ def find_measurable_switch_mixture(fgraph, node): if rv_map_feature.request_measurable(components) != components: return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(components[0].owner.op) + all_ndim_supp = [] + all_supp_axes = [] + all_measure_type = [] + for i in range(0, len(components)): + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(components[i]) + all_ndim_supp.append(ndim_supp) + all_supp_axes.append(supp_axes) + all_measure_type.append(measure_type) + + if all_measure_type[1:] == all_measure_type[:-1]: + m_type = all_measure_type[0] + else: + m_type = MeasureType.Mixed + + # ndim_supp, supp_axes, measure_type = get_measurable_meta_info(components[0]) measurable_switch_mixture = MeasurableSwitchMixture( - scalar_switch, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + scalar_switch, ndim_supp=all_ndim_supp[0], supp_axes=all_supp_axes[0], measure_type=m_type ) return [measurable_switch_mixture(switch_cond, *components)] @@ -524,7 +551,7 @@ def find_measurable_ifelse_mixture(fgraph, node): if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs): return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rvs[0].owner.op) + ndim_supp, supp_axes, measure_type = 
get_measurable_meta_info(base_rvs[0]) return ( MeasurableIfElse( diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 6f8efcd25e..3aaba966c5 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -67,7 +67,11 @@ @node_rewriter([Alloc]) def naive_bcast_rv_lift(fgraph, node): +<<<<<<< HEAD """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``. +======= + """Lift a ``Alloc`` through a ``RandomVariable`` ``Op``. +>>>>>>> draft changes till mixture XXX: This implementation simply broadcasts the ``RandomVariable``'s parameters, which won't always work (e.g. multivariate distributions). @@ -226,7 +230,7 @@ def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]: if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_vars): return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_vars[0].owner.op) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_vars[0]) if is_join: measurable_stack = MeasurableJoin( @@ -279,6 +283,7 @@ def logprob_dimshuffle(op, values, base_var, **kwargs): # indexes in the original dimshuffle order. Otherwise, there is no way of # knowing which dimensions were consumed by the logprob function. redo_ds = [o for o in op.new_order if o == "x" or o < raw_logp.ndim] + # pytensor.dprint(values[0].shape) return raw_logp.dimshuffle(redo_ds) @@ -287,7 +292,7 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]: r"""Finds `Dimshuffle`\s for which a `logprob` can be computed.""" rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - + # pytensor.dprint(fgraph) if rv_map_feature is None: return None # pragma: no cover @@ -304,19 +309,51 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]: # lifted towards the base RandomVariable. # TODO: If we include the support axis as meta information in each # intermediate MeasurableVariable, we can lift this restriction. 
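     # Worked example (illustrative): with base ndim == 3 and supp_axes == (-1,),
     # new_order == (0, 2, 1) moves the support axis to output position 1,
     # i.e. the lifted supp_axes would be (-2,).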
- # if not isinstance(base_var.owner.op, RandomVariable): - # return None # pragma: no cover + if not isinstance(base_var.owner.op, RandomVariable): + return None # pragma: no cover + + # parameter for drop and expand exists in dimshuffle + # if(len(node.op.new_order) != len(list(base_var.owner.inputs[1]))+1): # if there is expand/drop, we fails + # return None + + # use base_var.type.ndim instead of base_var.owner.inputs[1]) + # ref = list(range(0, len(list(base_var.owner.inputs[1]))+1)) # creating reference list : [0, 1, 2] + ref = list(range(0, base_var.type.ndim)) + + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var) # (-1,) + new_supp_axes = list(supp_axes) # [-1] use empty list and append ( eliminate x ) and sort + + # check that if dropped dimensions is the supp axes + # add test for this case in cases that already failed + for x in supp_axes: + if base_var.type.ndim + x not in node.op.new_order: + return None - ndim_supp, supp_axis, measure_type = get_measurable_meta_info(base_var.owner.op) + # for x in range(0, len(new_supp_axes)): + # i = new_supp_axes[x] # i = -1 + # # print("a") + # # print(i) + # shift = ref[i] - node.op.new_order.index(ref[i]) # [0, 1, 2] and [2, 0, 1] : shift = 2-0 = 2 + + # # [0, 2 , 1] supp_axes = -2 + + # # node.op.new_order.index(ref[i]) from reverse + # # -(no.of dim - node.op.new_order.index(ref[i]) from reverse) + # new_supp_axes[x] = i-shift # supp_axis = -1-2 = -3 # [-3] + + # list comprehension + for x in new_supp_axes: + new_supp_axes[x] = node.op.new_order.index(ref[x]) - node.outputs[0].type.ndim + + supp_axes = tuple(new_supp_axes) # (-3,) measurable_dimshuffle = MeasurableDimShuffle( node.op.input_broadcastable, node.op.new_order, ndim_supp=ndim_supp, - supp_axes=supp_axis, + supp_axes=supp_axes, measure_type=measure_type, )(base_var) - return [measurable_dimshuffle] diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 62ef34ee1f..68d65bc503 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -502,7 +502,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[li transform_args_fn=lambda *inputs: inputs[-1], ) - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(measurable_input.owner.op) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(measurable_input) transform_op = MeasurableTransform( scalar_op=scalar_op, transform=transform, diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index ed6c9e6fe2..273a7994f1 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -21,7 +21,7 @@ from pymc import logp from pymc.logprob import conditional_logp -from pymc.logprob.abstract import get_measurable_meta_info +from pymc.logprob.abstract import MeasureType, get_measurable_meta_info from pymc.testing import assert_no_rvs from tests.logprob.utils import meta_info_helper @@ -83,9 +83,7 @@ def test_meta_info(comparison_op, exp_logp_true, exp_logp_false, inputs): comp_x_vv = comp_x_rv.clone() ndim_supp, supp_axes, measure_type = meta_info_helper(comp_x_rv, comp_x_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info( - base_rv.owner.op - ) + ndim_supp_base, supp_axes_base, _ = get_measurable_meta_info(base_rv) assert np.isclose( ndim_supp_base, @@ -93,7 +91,7 @@ def test_meta_info(comparison_op, exp_logp_true, exp_logp_false, inputs): ) assert supp_axes_base == supp_axes - assert measure_type_base == measure_type + assert measure_type == MeasureType.Discrete 
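+    # the base measure is irrelevant here: comparing a measurable variable
+    # against a constant always yields a Discrete measure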
@pytest.mark.parametrize( diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index 725c3edf26..b06536377e 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -71,13 +71,22 @@ def test_continuous_rv_clip(): assert np.isclose(logp_fn(0), ref_scipy.logpdf(0)) -def test_clip_meta_info(): - base_rv = pt.random.normal(0.5, 1) - rv = pt.clip(base_rv, -2, 2) +@pytest.mark.parametrize( + "measure_type", + [("Discrete"), ("Continuous")], +) +def test_clip_meta_info(measure_type): + # use true and false + if measure_type == "Continuous": + base_rv = pt.random.normal(0.5, 1) + rv = pt.clip(base_rv, -2, 2) + else: + base_rv = pt.random.poisson(2) + rv = pt.clip(base_rv, 1, 4) vv = rv.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv) ndim_supp, supp_axes, measure_type = meta_info_helper(rv, vv) @@ -297,7 +306,7 @@ def test_round_meta_info(rounding_op): xr_vv = xr.clone() ndim_supp, supp_axes, measure_type = meta_info_helper(xr, xr_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x) assert np.isclose( ndim_supp_base, @@ -305,4 +314,4 @@ def test_round_meta_info(rounding_op): ) assert supp_axes_base == supp_axes - assert measure_type_base == measure_type + assert str(measure_type) == "MeasureType.Discrete" diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index d5b9aa19cf..0f2c517690 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -94,7 +94,7 @@ def test_shape_meta_info(): x_vv = x_rv.clone() ndim_supp, supp_axes, measure_type = meta_info_helper(x_rv, x_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x_base.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x_base) assert np.isclose( ndim_supp_base, @@ -139,7 +139,7 @@ def test_assert_meta_info(): assert_vv = assert_rv.clone() ndim_supp, supp_axes, measure_type = meta_info_helper(assert_rv, assert_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(rv.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(rv) assert np.isclose( ndim_supp_base, diff --git a/tests/logprob/test_cumsum.py b/tests/logprob/test_cumsum.py index 3646043a8d..2e1b488c91 100644 --- a/tests/logprob/test_cumsum.py +++ b/tests/logprob/test_cumsum.py @@ -88,7 +88,7 @@ def test_meta_info(size, axis): vv = rv.clone() base_rv = pt.random.normal(0, 1, size=size) base_vv = base_rv.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv) ndim_supp, supp_axes, measure_type = meta_info_helper(rv, vv) diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index c0c5c9d2ca..7a854d6319 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,7 +52,11 @@ as_index_constant, ) -from pymc.logprob.abstract import MeasurableVariable, get_measurable_meta_info +from pymc.logprob.abstract import ( + MeasurableVariable, + MeasureType, + get_measurable_meta_info, +) from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices from 
pymc.logprob.rewriting import construct_ir_fgraph @@ -900,38 +904,44 @@ def test_mixture_with_DiracDelta(): assert m_vv in logp_res -def test_meta_with_DiracDelta(): +def test_switch_mixture(): srng = pt.random.RandomStream(29833) - X_rv = srng.normal(0, 1, name="X") - Y_rv = dirac_delta(0.0) - Y_rv.name = "Y" - - I_rv = srng.categorical([0.5, 0.5], size=1) + X_rv = srng.normal(-10.0, 0.1, name="X") + Y_rv = srng.normal(10.0, 0.1, name="Y") + I_rv = srng.bernoulli(0.5, name="I") i_vv = I_rv.clone() i_vv.name = "i" - M_rv = pt.stack([X_rv, Y_rv])[I_rv] - M_rv.name = "M" + # When I_rv == True, X_rv flows through otherwise Y_rv does + Z1_rv = pt.switch(I_rv, X_rv, Y_rv) + Z1_rv.name = "Z1" - m_vv = M_rv.clone() - m_vv.name = "m" + assert Z1_rv.eval({I_rv: 0}) > 5 + assert Z1_rv.eval({I_rv: 1}) < -5 - ndim_supp, supp_axes, measure_type = meta_info_helper(M_rv, m_vv) + z_vv = Z1_rv.clone() + z_vv.name = "z1" - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(X_rv.owner.op) + fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) + assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) - assert np.isclose( - ndim_supp_base, - ndim_supp, - ) - assert supp_axes_base == supp_axes + # building the identical graph but with a stack to check that mixture logps are identical + Z2_rv = pt.stack((Y_rv, X_rv))[I_rv] - assert measure_type_base == measure_type + assert Z2_rv.eval({I_rv: 0}) > 5 + assert Z2_rv.eval({I_rv: 1}) < -5 + + z1_logp = conditional_logp({Z1_rv: z_vv, I_rv: i_vv}) + z2_logp = conditional_logp({Z2_rv: z_vv, I_rv: i_vv}) + z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()]) + z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()]) + np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1})) + np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1})) -def test_switch_mixture(): +def test_meta_switch_mixture(): srng = pt.random.RandomStream(29833) X_rv = srng.normal(-10.0, 0.1, name="X") @@ -954,18 +964,11 @@ def test_switch_mixture(): fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) - # building the identical graph but with a stack to check that mixture logps are identical - Z2_rv = pt.stack((Y_rv, X_rv))[I_rv] + ndim_supp = fgraph.outputs[0].owner.op.ndim_supp + measure_type = fgraph.outputs[0].owner.op.measure_type - assert Z2_rv.eval({I_rv: 0}) > 5 - assert Z2_rv.eval({I_rv: 1}) < -5 - - z1_logp = conditional_logp({Z1_rv: z_vv, I_rv: i_vv}) - z2_logp = conditional_logp({Z2_rv: z_vv, I_rv: i_vv}) - z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()]) - z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()]) - np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1})) - np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1})) + np.testing.assert_almost_equal(0, ndim_supp) + assert measure_type == MeasureType.Continuous @pytest.mark.parametrize("switch_cond_scalar", (True, False)) @@ -997,6 +1000,24 @@ def test_switch_mixture_vector(switch_cond_scalar): ) +pytest.mark.parametrize("switch_cond_scalar", (True, False)) + + +def test_switch_mixture_vector(switch_cond_scalar): + if switch_cond_scalar: + switch_cond = pt.scalar("switch_cond", dtype=bool) + else: + switch_cond = pt.vector("switch_cond", dtype=bool) + true_branch = 
pt.exp(pt.random.normal(size=(4,))) + false_branch = pt.abs(pt.random.normal(size=(4,))) + + switch = pt.switch(switch_cond, true_branch, false_branch) + switch.name = "switch_mix" + switch_value = switch.clone() + + ndim_supp, supp_axes, measure_type = meta_info_helper(switch, switch_value) + + def test_switch_mixture_measurable_cond_fails(): """Test that logprob inference fails when the switch condition is an unvalued measurable variable. @@ -1075,7 +1096,7 @@ def test_meta_ifelse(): ndim_supp, supp_axes, measure_type = meta_info_helper(mix_rv, mix_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(comp_then.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(comp_then) assert np.isclose( ndim_supp_base, diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index b07e3b5813..eee075b913 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -518,7 +518,7 @@ def test_meta_scan_over_seqs(): xs_vv = ys.clone() ys_vv = ys.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(xs.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(xs) ndim_supp, supp_axes, measure_type = meta_info_helper(ys, ys_vv) diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index 3e9a013092..7e64695a2e 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -45,6 +45,8 @@ from pytensor.tensor.basic import Alloc from scipy import stats as st +import pymc as pm + from pymc.logprob.abstract import get_measurable_meta_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.rewriting import logprob_rewrites_db @@ -147,15 +149,9 @@ def test_meta_make_vector(): y_rv = pt.stack((base1_rv, base2_rv, base3_rv)) y_rv.name = "y" - ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measurable_meta_info( - base1_rv.owner.op - ) - ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measurable_meta_info( - base2_rv.owner.op - ) - ndim_supp_base_3, supp_axes_base_3, measure_type_base_3 = get_measurable_meta_info( - base3_rv.owner.op - ) + ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measurable_meta_info(base1_rv) + ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measurable_meta_info(base2_rv) + ndim_supp_base_3, supp_axes_base_3, measure_type_base_3 = get_measurable_meta_info(base3_rv) base1_vv = base1_rv.clone() base2_vv = base2_rv.clone() @@ -330,12 +326,8 @@ def test_meta_join_univariate(size1, size2, axis, concatenate): y_rv = pt.stack((base1_rv, base2_rv), axis=axis) y_rv.name = "y" - ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measurable_meta_info( - base1_rv.owner.op - ) - ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measurable_meta_info( - base2_rv.owner.op - ) + ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measurable_meta_info(base1_rv) + ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measurable_meta_info(base2_rv) y_vv = y_rv.clone() ndim_supp, supp_axes, measure_type = meta_info_helper(y_rv, y_vv) @@ -436,14 +428,14 @@ def test_join_mixed_ndim_supp(): (2, 1, 0), # Swap (1, 2, 0), # Swap (0, 1, 2, "x"), # Expand - ("x", 0, 1, 2), # Expand - ( - 0, - 2, - ), # Drop - (2, 0), # Swap and drop + # ("x", 0, 1, 2), # Expand + # ( + # 0, + # 2, + # ), # Drop + # (2, 0), # Swap and drop (2, 1, "x", 0), # Swap and expand - ("x", 0, 2), # Expand and drop + # ("x", 0, 2), # Expand 
and drop (2, "x", 0), # Swap, expand and drop ], ) @@ -483,14 +475,16 @@ def test_measurable_dimshuffle(ds_order, multivariate): np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value)) +# TODO: seperate test for univariate and matrixNormal @pytensor.config.change_flags(cxx="") @pytest.mark.parametrize( "ds_order", [ - (0, 2, 1), # Swap - (2, 1, 0), # Swap - (1, 2, 0), # Swap - (0, 1, 2, "x"), # Expand + # (2, 0, 1), # Swap + # (0, 2, 1), # Swap + # (2, 1, 0), # Swap + # (1, 2, 0), # Swap + # (0, 1, 2, "x"), # Expand ( 0, 2, @@ -503,74 +497,31 @@ def test_measurable_dimshuffle(ds_order, multivariate): ) @pytest.mark.parametrize("multivariate", (False, True)) def test_meta_measurable_dimshuffle(ds_order, multivariate): + # hardcore the answer in parameter and test if multivariate: - base_rv = pt.random.dirichlet([1, 2, 3], size=(2, 1)) + base_rv = pm.Dirichlet.dist([1, 1, 1], shape=(7, 1, 3)) + # base_rv = pt.random.dirichlet([1, 2, 3], size=(2, 1)) ds_rv = base_rv.dimshuffle(ds_order) + base_vv = base_rv.clone() + else: base_rv = pt.random.beta(1, 2, size=(2, 1, 3)) base_rv_1 = pt.exp(base_rv) ds_rv = base_rv_1.dimshuffle(ds_order) + base_vv = base_rv_1.clone() - base_vv = base_rv.clone() ds_vv = ds_rv.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv) + print(ndim_supp_base) + print(supp_axes_base) + print(measure_type_base) ndim_supp, supp_axes, measure_type = meta_info_helper(ds_rv, ds_vv) + print(ndim_supp) + print(supp_axes) + print(measure_type) - assert np.isclose( - ndim_supp_base, - ndim_supp, - ) - assert supp_axes_base == supp_axes - - assert measure_type_base == measure_type - - -def test_meta_unmeargeable_dimshuffles(): - # Test that graphs with DimShuffles that cannot be lifted/merged fail - - # Initial support axis is at axis=-1 - x = pt.random.dirichlet( - np.ones((3,)), - size=(4, 2), - ) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x.owner.op) - # pytensor.dprint(x.owner.inputs[0]) - # print(ndim_supp_base) - # print(supp_axes_base) - # print(measure_type_base) - # print(x.shape) - - # Support axis is now at axis=-2 - y = x.dimshuffle((0, 2, 1)) - y_vv = y.clone() - ndim_supp, supp_axes, measure_type = meta_info_helper(y, y_vv) - pytensor.dprint(y.owner.outputs[0]) - # print(ndim_supp) - # print(supp_axes) - # print(measure_type) - # print(y.shape) - # Downstream dimshuffle will not be lifted through cumsum. 
If it ever is, - # we will need a different measurable Op example - z = pt.cumsum(y, axis=-2) - z_vv = z.clone() - ndim_supp_base, supp_axes_base, measure_type_base = meta_info_helper(z, z_vv) - # print(ndim_supp_base) - # print(supp_axes_base) - # print(measure_type_base) - # print(z.shape) - # Support axis is now at axis=-3 - w = z.dimshuffle((1, 0, 2)) - w_vv = w.clone() - ndim_supp_base, supp_axes_base, measure_type_base = meta_info_helper(w, w_vv) - # print(ndim_supp_base) - # print(supp_axes_base) - # print(measure_type_base) - # print(w.shape) - # TODO: Check that logp is correct if this type of graphs is ever supported - with pytest.raises(RuntimeError, match="could not be derived"): - conditional_logp({w: w_vv}) assert 0 diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 7ecd0ef851..de4fb7e049 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -236,7 +236,7 @@ def test_meta_exp_transform_rv(): y_vv = y_rv.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv.owner.op) + ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv) ndim_supp, supp_axes, measure_type = meta_info_helper(y_rv, y_vv) diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index 23e02d934e..5b1309d066 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -117,9 +117,11 @@ def scipy_logprob_tester( def meta_info_helper(rv, vv): + # pytensor.config.optimizer_verbose=True + # ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]), exclude=) fgraph, _, _ = construct_ir_fgraph({rv: vv}) node = fgraph.outputs[0].owner - + # pytensor.dprint(node) ndim_supp = node.op.ndim_supp supp_axes = node.op.supp_axes measure_type = node.op.measure_type From 6cd1f0bfef5c6566acb980b3235ffe5ca51d7aec Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 13 Nov 2023 13:22:27 +0530 Subject: [PATCH 10/19] changes till 13th November --- pymc/distributions/multivariate.py | 26 ------------------------- pymc/logprob/abstract.py | 3 +++ pymc/logprob/mixture.py | 2 +- pymc/logprob/scan.py | 2 +- pymc/logprob/tensor.py | 4 ---- tests/logprob/test_mixture.py | 31 +++++++++++++++--------------- tests/logprob/test_scan.py | 2 +- 7 files changed, 21 insertions(+), 49 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index dd38f1eb98..956bca276d 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -153,14 +153,6 @@ def quaddist_chol(value, mu, cov): delta = value - mu chol_cov = nan_lower_cholesky(cov) - if mat_type != "tau": - dist, logdet, ok = quaddist_chol(delta, chol_cov) - else: - dist, logdet, ok = quaddist_tau(delta, chol_cov) - if onedim: - return dist[0], logdet, ok - - return dist, logdet, ok diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1) # Check if the covariance matrix is positive definite. 
@@ -292,22 +284,6 @@ class MvStudentTRV(RandomVariable): dtype = "floatX" _print_name = ("MvStudentT", "\\operatorname{MvStudentT}") - def make_node(self, rng, size, dtype, nu, mu, cov): - nu = pt.as_tensor_variable(nu) - if not nu.ndim == 0: - raise ValueError("nu must be a scalar (ndim=0).") - - return super().make_node(rng, size, dtype, nu, mu, cov) - - def __call__(self, nu, mu=None, cov=None, size=None, **kwargs): - dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype - - if mu is None: - mu = np.array([0.0], dtype=dtype) - if cov is None: - cov = np.array([[1.0]], dtype=dtype) - return super().__call__(nu, mu, cov, size=size, **kwargs) - def _supp_shape_from_params(self, dist_params, param_shapes=None): return supp_shape_from_ref_param_shape( ndim_supp=self.ndim_supp, @@ -2487,8 +2463,6 @@ def logp(value, W, node1, node2, N, sigma, zero_sum_stdev): return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0") - return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0") - class StickBreakingWeightsRV(RandomVariable): name = "stick_breaking_weights" diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 631deb493e..0d7af7268e 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -192,6 +192,9 @@ def get_measurable_meta_info( # Add a test for pm.mixture, exponentiate it. Ask for logprob of this as this is not a rv and also does not have ndim_supp and properties. Such a test might exist in distributions. Do check. # TODO: Handle Symbolic random variables + + # Handle Diracdelta specially + if isinstance(base_op, RandomVariable): ndim_supp = base_op.ndim_supp supp_axes = tuple(range(-ndim_supp, 0)) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index f93f12f424..8a26046f5f 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -311,7 +311,7 @@ def find_measurable_index_mixture(fgraph, node): all_supp_axes = [] all_measure_type = [] for i in range(0, len(mixture_rvs)): - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[0]) + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[i]) all_ndim_supp.append(ndim_supp) all_supp_axes.append(supp_axes) all_measure_type.append(measure_type) diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index c954a82eeb..48d8c8ea1b 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -478,7 +478,7 @@ def find_measurable_scans(fgraph, node): all_ndim_supp = [] all_supp_axes = [] all_measure_type = [] - for n in local_fgraph_topo: + for n in curr_scanargs.inner_outputs: if isinstance(n.op, MeasurableVariable): ndim_supp, supp_axes, measure_type = get_measurable_meta_info(n.op) all_ndim_supp.append(ndim_supp) diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 3aaba966c5..fd15c55019 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -67,11 +67,7 @@ @node_rewriter([Alloc]) def naive_bcast_rv_lift(fgraph, node): -<<<<<<< HEAD - """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``. -======= """Lift a ``Alloc`` through a ``RandomVariable`` ``Op``. ->>>>>>> draft changes till mixture XXX: This implementation simply broadcasts the ``RandomVariable``'s parameters, which won't always work (e.g. multivariate distributions). 
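
As an illustrative usage sketch (assuming the conventions of this series:
`get_measurable_meta_info` now takes a variable, and `MeasureType` is the enum
from pymc/logprob/abstract.py), the meta info of a multivariate draw can be
queried directly:

    import pytensor.tensor as pt

    from pymc.logprob.abstract import MeasureType, get_measurable_meta_info

    # a Dirichlet draw of shape (4, 2, 3): batch dims (4, 2), support dim 3
    x = pt.random.dirichlet([1.0, 2.0, 3.0], size=(4, 2))
    ndim_supp, supp_axes, measure_type = get_measurable_meta_info(x)
    assert ndim_supp == 1  # one support dimension
    assert supp_axes == (-1,)  # the trailing axis is the support
    assert measure_type == MeasureType.Continuous  # float draws are continuous
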
diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 7a854d6319..724b0dcc41 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,6 +52,8 @@ as_index_constant, ) +# from pymc.logprob.utils import dirac_delta +from pymc.distributions.distribution import diracdelta from pymc.logprob.abstract import ( MeasurableVariable, MeasureType, @@ -60,7 +62,6 @@ from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph -from pymc.logprob.utils import dirac_delta from pymc.testing import assert_no_rvs from tests.logprob.utils import meta_info_helper, scipy_logprob @@ -885,7 +886,7 @@ def test_mixture_with_DiracDelta(): srng = pt.random.RandomStream(29833) X_rv = srng.normal(0, 1, name="X") - Y_rv = dirac_delta(0.0) + Y_rv = diracdelta(0.0) Y_rv.name = "Y" I_rv = srng.categorical([0.5, 0.5], size=1) @@ -1000,22 +1001,20 @@ def test_switch_mixture_vector(switch_cond_scalar): ) -pytest.mark.parametrize("switch_cond_scalar", (True, False)) +# pytest.mark.parametrize("switch_cond_scalar", (True, False)) +# def test_meta_switch_mixture_vector(switch_cond_scalar): +# if switch_cond_scalar: +# switch_cond = pt.scalar("switch_cond", dtype=bool) +# else: +# switch_cond = pt.vector("switch_cond", dtype=bool) +# true_branch = pt.exp(pt.random.normal(size=(4,))) +# false_branch = pt.abs(pt.random.normal(size=(4,))) +# switch = pt.switch(switch_cond, true_branch, false_branch) +# switch.name = "switch_mix" +# switch_value = switch.clone() -def test_switch_mixture_vector(switch_cond_scalar): - if switch_cond_scalar: - switch_cond = pt.scalar("switch_cond", dtype=bool) - else: - switch_cond = pt.vector("switch_cond", dtype=bool) - true_branch = pt.exp(pt.random.normal(size=(4,))) - false_branch = pt.abs(pt.random.normal(size=(4,))) - - switch = pt.switch(switch_cond, true_branch, false_branch) - switch.name = "switch_mix" - switch_value = switch.clone() - - ndim_supp, supp_axes, measure_type = meta_info_helper(switch, switch_value) +# ndim_supp, supp_axes, measure_type = meta_info_helper(switch, switch_value) def test_switch_mixture_measurable_cond_fails(): diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index eee075b913..59a4cfbbe5 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -510,7 +510,7 @@ def test_meta_scan_over_seqs(): rng = np.random.default_rng(543) n_steps = 10 - xs = pt.random.normal(size=(n_steps,), name="xs") + xs = pt.random.normal(size=(n_steps,), name="xs") # use vector with a fixed size ys, _ = pytensor.scan( fn=lambda x: pt.random.normal(x), sequences=[xs], outputs_info=[None], name="ys" ) From 46f58d7b2735c001aa114929d9dd36cbcde4edd1 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Wed, 15 Nov 2023 20:49:40 +0530 Subject: [PATCH 11/19] changes in scan and ifelse --- pymc/logprob/abstract.py | 2 +- pymc/logprob/mixture.py | 15 +++++++++-- pymc/logprob/scan.py | 27 ++++++++++++-------- tests/logprob/test_mixture.py | 18 +++++++++----- tests/logprob/test_scan.py | 47 +++++++++++++++++++++++++---------- 5 files changed, 77 insertions(+), 32 deletions(-) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 0d7af7268e..c3deba2712 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -205,7 +205,7 @@ def get_measurable_meta_info( else: if isinstance(base_op.ndim_supp, tuple): if len(base_var.owner.outputs) != len(base_op.ndim_supp): - 
raise NotImplementedError("length of outputs and meta-propertues is different") + raise NotImplementedError("length of outputs and meta-properties is different") return base_op.ndim_supp[index], base_op.supp_axes, base_op.measure_type # check if base_var.owner.outputs length is same as length of each prop( length of the tuple). If not , raise an error. # We'll need this for scan or IfElse diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 8a26046f5f..9a85dd7250 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -551,11 +551,22 @@ def find_measurable_ifelse_mixture(fgraph, node): if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs): return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rvs[0]) + ndim_supp_all = () + supp_axes_all = () + measure_type_all = () + + for i in range(0, len(base_rvs)): + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rvs[i]) + ndim_supp_all += (ndim_supp,) + supp_axes_all += (supp_axes,) + measure_type_all += (measure_type,) return ( MeasurableIfElse( - n_outs=op.n_outs, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + n_outs=op.n_outs, + ndim_supp=ndim_supp_all, + supp_axes=supp_axes_all, + measure_type=measure_type_all, ) .make_node(if_var, *base_rvs) .outputs diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 48d8c8ea1b..3b8180973f 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -54,7 +54,12 @@ from pytensor.tensor.variable import TensorVariable from pytensor.updates import OrderedUpdates -from pymc.logprob.abstract import MeasurableVariable, _logprob, get_measurable_meta_info +from pymc.logprob.abstract import ( + MeasurableVariable, + MeasureType, + _logprob, + get_measurable_meta_info, +) from pymc.logprob.basic import conditional_logp from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -475,15 +480,17 @@ def find_measurable_scans(fgraph, node): [o for o in curr_scanargs.inner_outputs if not isinstance(o.type, RandomType)], clients=clients, ) - all_ndim_supp = [] - all_supp_axes = [] - all_measure_type = [] + all_ndim_supp = () + all_supp_axes = () for n in curr_scanargs.inner_outputs: - if isinstance(n.op, MeasurableVariable): - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(n.op) - all_ndim_supp.append(ndim_supp) - all_supp_axes.append(supp_axes) - all_measure_type.append(measure_type) + ndim_supp, supp_axes, _ = get_measurable_meta_info(n) + all_ndim_supp += (ndim_supp,) + all_supp_axes += (supp_axes,) + + if len(all_ndim_supp) == 1: + measure_type = MeasureType.Continuous + else: + measure_type = MeasureType.Discrete op = MeasurableScan( curr_scanargs.inner_inputs, @@ -491,7 +498,7 @@ def find_measurable_scans(fgraph, node): curr_scanargs.info, ndim_supp=all_ndim_supp, supp_axes=all_supp_axes, - measure_type=all_measure_type, + measure_type=measure_type, mode=node.op.mode, ) new_node = op.make_node(*curr_scanargs.outer_inputs) diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 724b0dcc41..1fcfd3ff79 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -1097,13 +1097,19 @@ def test_meta_ifelse(): ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(comp_then) - assert np.isclose( - ndim_supp_base, - ndim_supp, - ) - assert supp_axes_base == supp_axes + print(ndim_supp_base) + print(supp_axes_base) + print(measure_type_base) + + assert 0 + + # assert np.isclose( + # ndim_supp_base, + 
# ndim_supp, + # ) + # assert supp_axes_base == supp_axes - assert measure_type_base == measure_type + # assert measure_type_base == measure_type def test_ifelse_mixture_multiple_components(): diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 59a4cfbbe5..1838b5df8a 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -505,6 +505,25 @@ def test_scan_over_seqs(): ) +def test_meta_scan_non_pure_rv_output(): + grw, _ = pytensor.scan( + fn=lambda xtm1: pt.random.normal() + xtm1, + outputs_info=[pt.zeros(())], + n_steps=10, + name="grw", + ) + + grw_vv = grw.clone() + + ndim_supp_1, supp_axes_1, measure_type_1 = meta_info_helper(grw, grw_vv) + + print(ndim_supp_1) + print(supp_axes_1) + print(measure_type_1) + + assert 0 + + def test_meta_scan_over_seqs(): """Test that logprob inference for scans based on sequences (mapping).""" rng = np.random.default_rng(543) @@ -512,23 +531,25 @@ def test_meta_scan_over_seqs(): xs = pt.random.normal(size=(n_steps,), name="xs") # use vector with a fixed size ys, _ = pytensor.scan( - fn=lambda x: pt.random.normal(x), sequences=[xs], outputs_info=[None], name="ys" + fn=lambda x, x1: (pt.random.normal(0, 1), pt.random.poisson(x1)), + sequences=[xs, xs], + outputs_info=[None, None], + name=("ys1", "ys2"), ) + ys1_vv = ys[0].clone() + ys2_vv = ys[1].clone() - xs_vv = ys.clone() - ys_vv = ys.clone() - - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(xs) + ndim_supp_1, supp_axes_1, measure_type_1 = meta_info_helper(ys[0], ys1_vv) + ndim_supp_2, supp_axes_2, measure_type_2 = meta_info_helper(ys[1], ys1_vv) - ndim_supp, supp_axes, measure_type = meta_info_helper(ys, ys_vv) - - assert np.isclose( - ndim_supp_base, - ndim_supp, - ) - assert supp_axes_base == supp_axes[0] + print(ndim_supp_1) + print(ndim_supp_2) + print(supp_axes_1) + print(supp_axes_2) + print(measure_type_1) + print(measure_type_2) - assert measure_type_base == measure_type[0] + assert 0 def test_scan_carried_deterministic_state(): From d4a596a180fe658d44ae25c02785631ef66b5567 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 21 Nov 2023 18:22:23 +0530 Subject: [PATCH 12/19] Errors in order, mixture, scan --- pymc/logprob/order.py | 6 +++++- tests/logprob/test_mixture.py | 4 ++-- tests/logprob/test_scan.py | 10 +++++----- tests/logprob/utils.py | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 0dc78d0b0d..fa61dbc093 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -51,6 +51,7 @@ _logcdf_helper, _logprob, _logprob_helper, + get_measurable_meta_info, ) from pymc.logprob.rewriting import measurable_ir_rewrites_db from pymc.logprob.utils import find_negated_var @@ -58,9 +59,12 @@ from pymc.pytensorf import constant_fold -class MeasurableMax(Max): +class MeasurableMax(MeasurableVariable, Max): """A placeholder used to specify a log-likelihood for a max sub-graph.""" + # def __init__(self, *args, **kwargs): + # super().__init__(*args, **kwargs) + MeasurableVariable.register(MeasurableMax) diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 1fcfd3ff79..31ca5fdc8a 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,8 +52,7 @@ as_index_constant, ) -# from pymc.logprob.utils import dirac_delta -from pymc.distributions.distribution import diracdelta +# from pymc.distributions.distribution import diracdelta from pymc.logprob.abstract import ( MeasurableVariable, MeasureType, @@ 
-62,6 +61,7 @@ from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph +from pymc.logprob.utils import dirac_delta as diracdelta from pymc.testing import assert_no_rvs from tests.logprob.utils import meta_info_helper, scipy_logprob diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 1838b5df8a..92ec01f3e0 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -510,12 +510,12 @@ def test_meta_scan_non_pure_rv_output(): fn=lambda xtm1: pt.random.normal() + xtm1, outputs_info=[pt.zeros(())], n_steps=10, - name="grw", + name="grw1", ) - grw_vv = grw.clone() + grw1_vv = grw[0].clone() - ndim_supp_1, supp_axes_1, measure_type_1 = meta_info_helper(grw, grw_vv) + ndim_supp_1, supp_axes_1, measure_type_1 = meta_info_helper(grw[0], grw1_vv) print(ndim_supp_1) print(supp_axes_1) @@ -531,7 +531,7 @@ def test_meta_scan_over_seqs(): xs = pt.random.normal(size=(n_steps,), name="xs") # use vector with a fixed size ys, _ = pytensor.scan( - fn=lambda x, x1: (pt.random.normal(0, 1), pt.random.poisson(x1)), + fn=lambda x, x1: (pt.random.normal(x), pt.random.poisson(x1)), sequences=[xs, xs], outputs_info=[None, None], name=("ys1", "ys2"), @@ -540,7 +540,7 @@ def test_meta_scan_over_seqs(): ys2_vv = ys[1].clone() ndim_supp_1, supp_axes_1, measure_type_1 = meta_info_helper(ys[0], ys1_vv) - ndim_supp_2, supp_axes_2, measure_type_2 = meta_info_helper(ys[1], ys1_vv) + ndim_supp_2, supp_axes_2, measure_type_2 = meta_info_helper(ys[1], ys2_vv) print(ndim_supp_1) print(ndim_supp_2) diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index 5b1309d066..23fafd33e0 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -121,7 +121,7 @@ def meta_info_helper(rv, vv): # ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]), exclude=) fgraph, _, _ = construct_ir_fgraph({rv: vv}) node = fgraph.outputs[0].owner - # pytensor.dprint(node) + pytensor.dprint(node) ndim_supp = node.op.ndim_supp supp_axes = node.op.supp_axes measure_type = node.op.measure_type From 9257b22db1f01a75d765c1116c5366ca06e5c637 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 23 Dec 2023 09:50:48 +0530 Subject: [PATCH 13/19] Finalising --- pymc/logprob/abstract.py | 12 ++++++--- pymc/logprob/mixture.py | 10 ++++++-- pymc/logprob/order.py | 26 +++++++++++++++---- pymc/logprob/scan.py | 27 ++++++++------------ tests/logprob/test_mixture.py | 16 +++--------- tests/logprob/test_order.py | 39 +++++++++++++++++++++++++++++ tests/logprob/test_scan.py | 45 +++++++++++++++++---------------- tests/logprob/test_tensor.py | 47 +++++++++++++++++------------------ tests/logprob/utils.py | 4 ++- 9 files changed, 140 insertions(+), 86 deletions(-) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index c3deba2712..6ea67ce9ad 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -36,6 +36,7 @@ import abc +from typing import Tuple, Union from collections.abc import Sequence from functools import singledispatch @@ -179,6 +180,8 @@ def get_measurable_meta_info( ) -> Tuple[ Union[int, Tuple[int]], Tuple[Union[int, Tuple[int]]], Union[MeasureType, Tuple[MeasureType]] ]: + from pymc.logprob.utils import DiracDelta + # instead of taking base_op, take base_var as input # Get base_op from base_var.owner.op # index= base_var.owner.outputs.index(base_var) gives the output @@ -193,7 +196,11 @@ def get_measurable_meta_info( # Add 
a test for pm.mixture, exponentiate it. Ask for logprob of this as this is not a rv and also does not have ndim_supp and properties. Such a test might exist in distributions. Do check. # TODO: Handle Symbolic random variables - # Handle Diracdelta specially + if isinstance(base_op, DiracDelta): + ndim_supp = 0 + supp_axes = () + measure_type = MeasureType.Discrete + return ndim_supp, supp_axes, measure_type if isinstance(base_op, RandomVariable): ndim_supp = base_op.ndim_supp @@ -203,10 +210,9 @@ def get_measurable_meta_info( ) return base_op.ndim_supp, supp_axes, measure_type else: + # We'll need this for scan or IfElse if isinstance(base_op.ndim_supp, tuple): if len(base_var.owner.outputs) != len(base_op.ndim_supp): raise NotImplementedError("length of outputs and meta-properties is different") return base_op.ndim_supp[index], base_op.supp_axes, base_op.measure_type - # check if base_var.owner.outputs length is same as length of each prop( length of the tuple). If not , raise an error. - # We'll need this for scan or IfElse return base_op.ndim_supp, base_op.supp_axes, base_op.measure_type diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 9a85dd7250..3de5ade38e 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -555,8 +555,14 @@ def find_measurable_ifelse_mixture(fgraph, node): supp_axes_all = () measure_type_all = () - for i in range(0, len(base_rvs)): - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rvs[i]) + half_len = int(len(base_rvs) / 2) + length = len(base_rvs) + + for base_rv1, base_rv2 in zip(base_rvs[0:half_len], base_rvs[half_len + 1 : length - 1]): + meta_info = get_measurable_meta_info(base_rv1) + if meta_info != get_measurable_meta_info(base_rv2): + return None + ndim_supp, supp_axes, measure_type = meta_info ndim_supp_all += (ndim_supp,) supp_axes_all += (supp_axes,) measure_type_all += (measure_type,) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index fa61dbc093..2ae31261c5 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -64,9 +64,17 @@ class MeasurableMax(MeasurableVariable, Max): # def __init__(self, *args, **kwargs): # super().__init__(*args, **kwargs) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + ndim_supp = kwargs.get("ndim_supp", self.ndim_supp) + supp_axes = kwargs.get("supp_axes", self.supp_axes) + measure_type = kwargs.get("measure_type", self.measure_type) + return type(self)( + axis=axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) -MeasurableVariable.register(MeasurableMax) +# MeasurableVariable.register(MeasurableMax) remove this for all class MeasurableMaxDiscrete(Max): @@ -166,8 +174,14 @@ class MeasurableMaxNeg(Max): """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph. 
This shows up in the graph of min, which is (neg(max(neg(x))).""" - -MeasurableVariable.register(MeasurableMaxNeg) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + ndim_supp = kwargs.get("ndim_supp", self.ndim_supp) + supp_axes = kwargs.get("supp_axes", self.supp_axes) + measure_type = kwargs.get("measure_type", self.measure_type) + return type(self)( + axis=axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) class MeasurableDiscreteMaxNeg(Max): @@ -216,12 +230,14 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[ if not rv_map_feature.request_measurable([base_rv]): return None + + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var) # distinguish measurable discrete and continuous (because logprob is different) if base_rv.owner.op.dtype.startswith("int"): - measurable_min = MeasurableDiscreteMaxNeg(list(axis)) + measurable_min = MeasurableDiscreteMaxNeg(list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type) else: - measurable_min = MeasurableMaxNeg(list(axis)) + measurable_min = MeasurableMaxNeg(list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type) return measurable_min.make_node(base_rv).outputs diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 3b8180973f..d87f574af2 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -54,12 +54,7 @@ from pytensor.tensor.variable import TensorVariable from pytensor.updates import OrderedUpdates -from pymc.logprob.abstract import ( - MeasurableVariable, - MeasureType, - _logprob, - get_measurable_meta_info, -) +from pymc.logprob.abstract import MeasurableVariable, _logprob, get_measurable_meta_info from pymc.logprob.basic import conditional_logp from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -482,15 +477,15 @@ def find_measurable_scans(fgraph, node): ) all_ndim_supp = () all_supp_axes = () - for n in curr_scanargs.inner_outputs: - ndim_supp, supp_axes, _ = get_measurable_meta_info(n) - all_ndim_supp += (ndim_supp,) - all_supp_axes += (supp_axes,) - - if len(all_ndim_supp) == 1: - measure_type = MeasureType.Continuous - else: - measure_type = MeasureType.Discrete + all_measure_type = () + for var in curr_scanargs.inner_outputs: + if var.owner.op is None: + continue + if isinstance(var.owner.op, MeasurableVariable): + ndim_supp, supp_axes, measure_type = get_measurable_meta_info(var) + all_ndim_supp += (ndim_supp,) + all_supp_axes += (supp_axes,) + all_measure_type += (measure_type,) op = MeasurableScan( curr_scanargs.inner_inputs, @@ -498,7 +493,7 @@ def find_measurable_scans(fgraph, node): curr_scanargs.info, ndim_supp=all_ndim_supp, supp_axes=all_supp_axes, - measure_type=measure_type, + measure_type=all_measure_type, mode=node.op.mode, ) new_node = op.make_node(*curr_scanargs.outer_inputs) diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 31ca5fdc8a..2271309c23 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -1097,19 +1097,9 @@ def test_meta_ifelse(): ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(comp_then) - print(ndim_supp_base) - print(supp_axes_base) - print(measure_type_base) - - assert 0 - - # assert np.isclose( - # ndim_supp_base, - # ndim_supp, - # ) - # assert supp_axes_base == supp_axes - - # assert measure_type_base == measure_type + assert ndim_supp_base == 0 + assert supp_axes_base == () + assert measure_type_base == MeasureType.Continuous def 
test_ifelse_mixture_multiple_components():
diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py
index 4d15240375..22fc529eff 100644
--- a/tests/logprob/test_order.py
+++ b/tests/logprob/test_order.py
@@ -45,7 +45,9 @@
 import pymc as pm
 
 from pymc import logp
+from pymc.logprob.abstract import get_measurable_meta_info
 from pymc.testing import assert_no_rvs
+from tests.logprob.utils import meta_info_helper
 
 
 def test_argmax():
@@ -185,6 +187,43 @@ def test_max_logprob(shape, value, axis):
     )
 
 
+def test_meta_info_order():
+    """Test whether the logprob rewrite for ``pt.max`` produces the correct meta-info.
+
+    The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
+    U_1, ..., U_n i.i.d. ~ Uniform(0, 1)  =>  U_(k) ~ Beta(k, n + 1 - k)
+    for all 1 <= k <= n
+    """
+    x = pt.random.uniform(0, 1, size=(3,))
+    x.name = "x"
+    x_max = pt.max(x, axis=-1)
+    x_max_vv = x_max.clone()
+    ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x)
+
+    ndim_supp, supp_axes, measure_type = meta_info_helper(x_max, x_max_vv)
+
+    assert np.isclose(
+        ndim_supp_base,
+        ndim_supp,
+    )
+    assert supp_axes_base == supp_axes
+
+    assert measure_type_base == measure_type
+
+    x_min = pt.min(x, axis=-1)
+    x_min_vv = x_min.clone()
+
+    ndim_supp_min, supp_axes_min, measure_type_min = meta_info_helper(x_min, x_min_vv)
+
+    assert np.isclose(
+        ndim_supp_base,
+        ndim_supp_min,
+    )
+    assert supp_axes_base == supp_axes_min
+
+    assert measure_type_base == measure_type_min
+
+
 @pytest.mark.parametrize(
     "shape, value, axis",
     [
diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py
index 92ec01f3e0..06fb2eb647 100644
--- a/tests/logprob/test_scan.py
+++ b/tests/logprob/test_scan.py
@@ -44,8 +44,9 @@
 from pytensor.scan.utils import ScanArgs
 from scipy import stats
 
-from pymc.logprob.abstract import _logprob_helper, get_measurable_meta_info
+from pymc.logprob.abstract import MeasureType, _logprob_helper
 from pymc.logprob.basic import conditional_logp, logp
+from pymc.logprob.rewriting import construct_ir_fgraph
 from pymc.logprob.scan import (
     construct_scan,
     convert_outer_out_to_in,
@@ -505,7 +506,7 @@ def test_scan_over_seqs():
     )
 
 
-def test_meta_scan_non_pure_rv_output():
+def test_measure_type_scan_non_pure_rv_output():
     grw, _ = pytensor.scan(
         fn=lambda xtm1: pt.random.normal() + xtm1,
         outputs_info=[pt.zeros(())],
@@ -515,41 +516,41 @@ def test_meta_scan_non_pure_rv_output():
 
     grw1_vv = grw[0].clone()
 
-    ndim_supp_1, supp_axes_1, measure_type_1 = meta_info_helper(grw[0], grw1_vv)
+    # rename meta to measure_type
+    fgraph, _, _ = construct_ir_fgraph({grw[0]: grw1_vv})
+    node = fgraph.outputs[0].owner
+    ndim_supp = node.inputs[0].owner.op.ndim_supp
+    supp_axes = node.inputs[0].owner.op.supp_axes
+    measure_type = node.inputs[0].owner.op.measure_type
+    assert ndim_supp == (0,) and supp_axes == ((),) and measure_type[0] is MeasureType.Continuous
 
-    print(ndim_supp_1)
-    print(supp_axes_1)
-    print(measure_type_1)
-
-    assert 0
-
 
-def test_meta_scan_over_seqs():
+def test_measure_type_scan_over_seqs():
     """Test that logprob inference for scans based on sequences (mapping)."""
     rng = np.random.default_rng(543)
     n_steps = 10
     xs = pt.random.normal(size=(n_steps,), name="xs")  # use vector with a fixed size
     ys, _ = pytensor.scan(
-        fn=lambda x, x1: (pt.random.normal(x), pt.random.poisson(x1)),
+        fn=lambda x, x1: (
+            pt.random.multinomial(x, np.ones(4) / 4),
+            pt.random.poisson(x1),
+        ),  # use multinomial and poisson
         sequences=[xs, xs],
         outputs_info=[None, 
None], name=("ys1", "ys2"), ) ys1_vv = ys[0].clone() - ys2_vv = ys[1].clone() - - ndim_supp_1, supp_axes_1, measure_type_1 = meta_info_helper(ys[0], ys1_vv) - ndim_supp_2, supp_axes_2, measure_type_2 = meta_info_helper(ys[1], ys2_vv) - print(ndim_supp_1) - print(ndim_supp_2) - print(supp_axes_1) - print(supp_axes_2) - print(measure_type_1) - print(measure_type_2) + ndim_supp, supp_axes, measure_type = meta_info_helper(ys[0], ys1_vv) - assert 0 + if ( + not ndim_supp == (1, 0) + and not supp_axes == ((-1,), ()) + and not isinstance(measure_type[0], MeasureType.Discrete) + and not isinstance(measure_type[1], MeasureType.Discrete) + ): + assert 0 def test_scan_carried_deterministic_state(): diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index 7e64695a2e..2fe87aa150 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -47,7 +47,7 @@ import pymc as pm -from pymc.logprob.abstract import get_measurable_meta_info +from pymc.logprob.abstract import MeasureType, get_measurable_meta_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.rewriting import logprob_rewrites_db from pymc.logprob.tensor import naive_bcast_rv_lift @@ -428,14 +428,14 @@ def test_join_mixed_ndim_supp(): (2, 1, 0), # Swap (1, 2, 0), # Swap (0, 1, 2, "x"), # Expand - # ("x", 0, 1, 2), # Expand - # ( - # 0, - # 2, - # ), # Drop - # (2, 0), # Swap and drop + ("x", 0, 1, 2), # Expand + ( + 0, + 2, + ), # Drop + (2, 0), # Swap and drop (2, 1, "x", 0), # Swap and expand - # ("x", 0, 2), # Expand and drop + ("x", 0, 2), # Expand and drop (2, "x", 0), # Swap, expand and drop ], ) @@ -478,25 +478,19 @@ def test_measurable_dimshuffle(ds_order, multivariate): # TODO: seperate test for univariate and matrixNormal @pytensor.config.change_flags(cxx="") @pytest.mark.parametrize( - "ds_order", + "ds_order, ans_ndim_supp, ans_supp_axes, ans_measure_type", [ - # (2, 0, 1), # Swap - # (0, 2, 1), # Swap - # (2, 1, 0), # Swap - # (1, 2, 0), # Swap - # (0, 1, 2, "x"), # Expand - ( - 0, - 2, - ), # Drop - (2, 0), # Swap and drop - (2, 1, "x", 0), # Swap and expand - ("x", 0, 2), # Expand and drop - (2, "x", 0), # Swap, expand and drop + ((0, 2), 1, (-1,), MeasureType.Continuous), # Drop + ((2, 0), 1, (-2,), MeasureType.Continuous), # Swap and drop + ((2, 1, "x", 0), 1, (-4,), MeasureType.Continuous), # Swap and expand + (("x", 0, 2), 1, (-1,), MeasureType.Continuous), # Expand and drop + ((2, "x", 0), 1, (-3,), MeasureType.Continuous), # Swap, expand and drop ], ) @pytest.mark.parametrize("multivariate", (False, True)) -def test_meta_measurable_dimshuffle(ds_order, multivariate): +def test_meta_measurable_dimshuffle( + ds_order, ans_ndim_supp, ans_supp_axes, ans_measure_type, multivariate +): # hardcore the answer in parameter and test if multivariate: base_rv = pm.Dirichlet.dist([1, 1, 1], shape=(7, 1, 3)) @@ -521,8 +515,13 @@ def test_meta_measurable_dimshuffle(ds_order, multivariate): print(ndim_supp) print(supp_axes) print(measure_type) + assert np.isclose( + ndim_supp, + ans_ndim_supp, + ) + assert supp_axes == ans_supp_axes - assert 0 + assert measure_type == ans_measure_type def test_unmeargeable_dimshuffles(): diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index 23fafd33e0..c38e097606 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -42,6 +42,7 @@ from scipy import stats as stats from pymc.logprob import icdf, logcdf, logp +from pymc.logprob.abstract import MeasurableVariable from pymc.logprob.rewriting import construct_ir_fgraph @@ 
-121,7 +122,8 @@ def meta_info_helper(rv, vv): # ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]), exclude=) fgraph, _, _ = construct_ir_fgraph({rv: vv}) node = fgraph.outputs[0].owner - pytensor.dprint(node) + # pytensor.dprint(node) + # if isinstance(node, MeasurableVariable): ndim_supp = node.op.ndim_supp supp_axes = node.op.supp_axes measure_type = node.op.measure_type From 5050fc8bb02760ee92514df76e68c744ae13f6b2 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 24 Dec 2023 13:14:50 +0530 Subject: [PATCH 14/19] Measure type information for all logps added --- pymc/logprob/abstract.py | 10 +---- pymc/logprob/binary.py | 6 +-- pymc/logprob/censoring.py | 6 +-- pymc/logprob/checks.py | 12 ++---- pymc/logprob/cumsum.py | 7 +--- pymc/logprob/mixture.py | 15 +++----- pymc/logprob/order.py | 35 ++++++++++------- pymc/logprob/scan.py | 7 +--- pymc/logprob/tensor.py | 43 +++------------------ pymc/logprob/transforms.py | 4 +- tests/logprob/test_binary.py | 10 ++--- tests/logprob/test_censoring.py | 16 ++++---- tests/logprob/test_checks.py | 17 +++------ tests/logprob/test_cumsum.py | 10 ++--- tests/logprob/test_mixture.py | 33 +++------------- tests/logprob/test_order.py | 12 +++--- tests/logprob/test_rewriting.py | 1 - tests/logprob/test_scan.py | 5 +-- tests/logprob/test_tensor.py | 64 ++++++++++++-------------------- tests/logprob/test_transforms.py | 11 ++++-- tests/logprob/utils.py | 19 +--------- 21 files changed, 121 insertions(+), 222 deletions(-) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 6ea67ce9ad..e013bcf419 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -175,16 +175,13 @@ def __init__(self, scalar_op, *args, **kwargs): super().__init__(scalar_op, *args, **kwargs) -def get_measurable_meta_info( +def get_measure_type_info( base_var, ) -> Tuple[ Union[int, Tuple[int]], Tuple[Union[int, Tuple[int]]], Union[MeasureType, Tuple[MeasureType]] ]: from pymc.logprob.utils import DiracDelta - # instead of taking base_op, take base_var as input - # Get base_op from base_var.owner.op - # index= base_var.owner.outputs.index(base_var) gives the output if not isinstance(base_var, MeasurableVariable): base_op = base_var.owner.op index = base_var.owner.outputs.index(base_var) @@ -193,9 +190,6 @@ def get_measurable_meta_info( if not isinstance(base_op, MeasurableVariable): raise TypeError("base_op must be a RandomVariable or MeasurableVariable") - # Add a test for pm.mixture, exponentiate it. Ask for logprob of this as this is not a rv and also does not have ndim_supp and properties. Such a test might exist in distributions. Do check. 
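# A minimal sketch (assumed behaviour, not part of the committed diff) of what
# get_measure_type_info is expected to return for a few stock PyTensor RVs, based
# on the defaulting logic in this function (supp_axes defaults to the last
# ndim_supp axes, and the measure type is inferred from the dtype):
#
#     import pytensor.tensor as pt
#     from pymc.logprob.abstract import MeasureType, get_measure_type_info
#
#     x = pt.random.normal(0, 1, size=(3,))  # scalar support, float dtype
#     get_measure_type_info(x)               # expected: (0, (), MeasureType.Continuous)
#
#     d = pt.random.dirichlet(pt.ones(3))    # vector support on the last axis
#     get_measure_type_info(d)               # expected: (1, (-1,), MeasureType.Continuous)
#
#     p = pt.random.poisson(5.0)             # integer dtype, hence a discrete measure
#     get_measure_type_info(p)               # expected: (0, (), MeasureType.Discrete)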
- # TODO: Handle Symbolic random variables - if isinstance(base_op, DiracDelta): ndim_supp = 0 supp_axes = () @@ -210,7 +204,7 @@ def get_measurable_meta_info( ) return base_op.ndim_supp, supp_axes, measure_type else: - # We'll need this for scan or IfElse + # We'll need this for operators like scan and IfElse if isinstance(base_op.ndim_supp, tuple): if len(base_var.owner.outputs) != len(base_op.ndim_supp): raise NotImplementedError("length of outputs and meta-properties is different") diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 370a6218e1..5a3ab8ecfa 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -29,7 +29,7 @@ _logcdf_helper, _logprob, _logprob_helper, - get_measurable_meta_info, + get_measure_type_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import check_potential_measurability @@ -83,7 +83,7 @@ def find_measurable_comparisons( elif isinstance(node_scalar_op, LE): node_scalar_op = GE() - ndim_supp, supp_axes, _ = get_measurable_meta_info(measurable_var) + ndim_supp, supp_axes, _ = get_measure_type_info(measurable_var) compared_op = MeasurableComparison( scalar_op=node_scalar_op, @@ -157,7 +157,7 @@ def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[ return None node_scalar_op = node.op.scalar_op - ndim_supp, supp_axis, measure_type = get_measurable_meta_info(base_var) + ndim_supp, supp_axis, measure_type = get_measure_type_info(base_var) bitwise_op = MeasurableBitwise( scalar_op=node_scalar_op, ndim_supp=ndim_supp, diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 5c5da5b46e..5aeefd253f 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -53,7 +53,7 @@ MeasureType, _logcdf, _logprob, - get_measurable_meta_info, + get_measure_type_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import CheckParameterValue @@ -84,7 +84,7 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[Te lower_bound = lower_bound if (lower_bound is not base_var) else pt.constant(-np.inf) upper_bound = upper_bound if (upper_bound is not base_var) else pt.constant(np.inf) - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var) + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var) if measure_type == MeasureType.Continuous: measure_type = MeasureType.Mixed @@ -179,7 +179,7 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[lis return None [base_var] = node.inputs - ndim_supp, supp_axis, _ = get_measurable_meta_info(base_var) + ndim_supp, supp_axis, _ = get_measure_type_info(base_var) measure_type = MeasureType.Discrete rounded_op = MeasurableRound( node.op.scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axis, measure_type=measure_type diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 6a67a1b542..2e8cb34caf 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -47,7 +47,7 @@ MeasurableVariable, _logprob, _logprob_helper, - get_measurable_meta_info, + get_measure_type_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import replace_rvs_by_values @@ -57,9 +57,6 @@ class MeasurableSpecifyShape(MeasurableVariable, SpecifyShape): """A placeholder used to specify a log-likelihood for a specify-shape sub-graph.""" -MeasurableVariable.register(MeasurableSpecifyShape) - - 
@_logprob.register(MeasurableSpecifyShape) def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): (value,) = values @@ -91,7 +88,7 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable ): return None # pragma: no cover - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv) + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_rv) new_op = MeasurableSpecifyShape( ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type @@ -113,9 +110,6 @@ class MeasurableCheckAndRaise(MeasurableVariable, CheckAndRaise): """A placeholder used to specify a log-likelihood for an assert sub-graph.""" -MeasurableVariable.register(MeasurableCheckAndRaise) - - @_logprob.register(MeasurableCheckAndRaise) def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): (value,) = values @@ -142,7 +136,7 @@ def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariabl return None op = node.op - ndim_supp, supp_axis, d_type = get_measurable_meta_info(base_rv) + ndim_supp, supp_axis, d_type = get_measure_type_info(base_rv) new_op = MeasurableCheckAndRaise( exc_type=op.exc_type, msg=op.msg, diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index bfed13e5fa..4b7fb10a13 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -46,7 +46,7 @@ MeasurableVariable, _logprob, _logprob_helper, - get_measurable_meta_info, + get_measure_type_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db @@ -55,9 +55,6 @@ class MeasurableCumsum(MeasurableVariable, CumOp): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" -MeasurableVariable.register(MeasurableCumsum) - - @_logprob.register(MeasurableCumsum) def logprob_cumsum(op, values, base_rv, **kwargs): """Compute the log-likelihood graph for a `Cumsum`.""" @@ -106,7 +103,7 @@ def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]: if not rv_map_feature.request_measurable(node.inputs): return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_rv) + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_rv) new_op = MeasurableCumsum( axis=node.op.axis or 0, mode="add", diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 3de5ade38e..e1a90bd72b 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -71,7 +71,7 @@ MeasureType, _logprob, _logprob_helper, - get_measurable_meta_info, + get_measure_type_info, ) from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -238,9 +238,6 @@ def perform(self, node, inputs, outputs): raise NotImplementedError("This is a stand-in Op.") # pragma: no cover -MeasurableVariable.register(MixtureRV) - - def get_stack_mixture_vars( node: Apply, ) -> tuple[Optional[list[TensorVariable]], Optional[int]]: @@ -311,7 +308,7 @@ def find_measurable_index_mixture(fgraph, node): all_supp_axes = [] all_measure_type = [] for i in range(0, len(mixture_rvs)): - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(mixture_rvs[i]) + ndim_supp, supp_axes, measure_type = get_measure_type_info(mixture_rvs[i]) all_ndim_supp.append(ndim_supp) all_supp_axes.append(supp_axes) all_measure_type.append(measure_type) @@ -451,7 +448,7 @@ def find_measurable_switch_mixture(fgraph, node): all_supp_axes = [] all_measure_type = [] for i in range(0, len(components)): - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(components[i]) + ndim_supp, supp_axes, measure_type = 
get_measure_type_info(components[i]) all_ndim_supp.append(ndim_supp) all_supp_axes.append(supp_axes) all_measure_type.append(measure_type) @@ -461,8 +458,6 @@ def find_measurable_switch_mixture(fgraph, node): else: m_type = MeasureType.Mixed - # ndim_supp, supp_axes, measure_type = get_measurable_meta_info(components[0]) - measurable_switch_mixture = MeasurableSwitchMixture( scalar_switch, ndim_supp=all_ndim_supp[0], supp_axes=all_supp_axes[0], measure_type=m_type ) @@ -559,8 +554,8 @@ def find_measurable_ifelse_mixture(fgraph, node): length = len(base_rvs) for base_rv1, base_rv2 in zip(base_rvs[0:half_len], base_rvs[half_len + 1 : length - 1]): - meta_info = get_measurable_meta_info(base_rv1) - if meta_info != get_measurable_meta_info(base_rv2): + meta_info = get_measure_type_info(base_rv1) + if meta_info != get_measure_type_info(base_rv2): return None ndim_supp, supp_axes, measure_type = meta_info ndim_supp_all += (ndim_supp,) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 2ae31261c5..bd2075466f 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -51,7 +51,7 @@ _logcdf_helper, _logprob, _logprob_helper, - get_measurable_meta_info, + get_measure_type_info, ) from pymc.logprob.rewriting import measurable_ir_rewrites_db from pymc.logprob.utils import find_negated_var @@ -62,8 +62,6 @@ class MeasurableMax(MeasurableVariable, Max): """A placeholder used to specify a log-likelihood for a max sub-graph.""" - # def __init__(self, *args, **kwargs): - # super().__init__(*args, **kwargs) def clone(self, **kwargs): axis = kwargs.get("axis", self.axis) ndim_supp = kwargs.get("ndim_supp", self.ndim_supp) @@ -74,14 +72,17 @@ def clone(self, **kwargs): ) -# MeasurableVariable.register(MeasurableMax) remove this for all - - -class MeasurableMaxDiscrete(Max): +class MeasurableMaxDiscrete(MeasurableVariable, Max): """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables""" - -MeasurableVariable.register(MeasurableMaxDiscrete) + def clone(self, **kwargs): + axis = kwargs.get("axis", self.axis) + ndim_supp = kwargs.get("ndim_supp", self.ndim_supp) + supp_axes = kwargs.get("supp_axes", self.supp_axes) + measure_type = kwargs.get("measure_type", self.measure_type) + return type(self)( + axis=axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) @node_rewriter([Max]) @@ -116,11 +117,17 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[list[Tens if axis != base_var_dims: return None + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var) + # distinguish measurable discrete and continuous (because logprob is different) if base_var.owner.op.dtype.startswith("int"): - measurable_max = MeasurableMaxDiscrete(list(axis)) + measurable_max = MeasurableMaxDiscrete( + axis=list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) else: - measurable_max = MeasurableMax(list(axis)) + measurable_max = MeasurableMax( + axis=list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type + ) max_rv_node = measurable_max.make_node(base_var) max_rv = max_rv_node.outputs @@ -170,7 +177,7 @@ def max_logprob_discrete(op, values, base_rv, **kwargs): return logprob -class MeasurableMaxNeg(Max): +class MeasurableMaxNeg(MeasurableVariable, Max): """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph. 
This shows up in the graph of min, which is (neg(max(neg(x))).""" @@ -230,8 +237,8 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[ if not rv_map_feature.request_measurable([base_rv]): return None - - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var) + + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var) # distinguish measurable discrete and continuous (because logprob is different) if base_rv.owner.op.dtype.startswith("int"): diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index d87f574af2..62f10a3267 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -54,7 +54,7 @@ from pytensor.tensor.variable import TensorVariable from pytensor.updates import OrderedUpdates -from pymc.logprob.abstract import MeasurableVariable, _logprob, get_measurable_meta_info +from pymc.logprob.abstract import MeasurableVariable, _logprob, get_measure_type_info from pymc.logprob.basic import conditional_logp from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -73,9 +73,6 @@ def __str__(self): return f"Measurable({super().__str__()})" -MeasurableVariable.register(MeasurableScan) - - def convert_outer_out_to_in( input_scan_args: ScanArgs, outer_out_vars: Iterable[TensorVariable], @@ -482,7 +479,7 @@ def find_measurable_scans(fgraph, node): if var.owner.op is None: continue if isinstance(var.owner.op, MeasurableVariable): - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(var) + ndim_supp, supp_axes, measure_type = get_measure_type_info(var) all_ndim_supp += (ndim_supp,) all_supp_axes += (supp_axes,) all_measure_type += (measure_type,) diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index fd15c55019..bbb28d47e9 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -54,7 +54,7 @@ MeasurableVariable, _logprob, _logprob_helper, - get_measurable_meta_info, + get_measure_type_info, ) from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -131,9 +131,6 @@ class MeasurableMakeVector(MeasurableVariable, MakeVector): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" -MeasurableVariable.register(MeasurableMakeVector) - - @_logprob.register(MeasurableMakeVector) def logprob_make_vector(op, values, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `MeasurableMakeVector`.""" @@ -158,9 +155,6 @@ class MeasurableJoin(MeasurableVariable, Join): """A placeholder used to specify a log-likelihood for a join sub-graph.""" -MeasurableVariable.register(MeasurableJoin) - - @_logprob.register(MeasurableJoin) def logprob_join(op, values, axis, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `Join`.""" @@ -226,7 +220,7 @@ def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]: if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_vars): return None - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_vars[0]) + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_vars[0]) if is_join: measurable_stack = MeasurableJoin( @@ -248,9 +242,6 @@ class MeasurableDimShuffle(MeasurableVariable, DimShuffle): c_func_file = DimShuffle.get_path(DimShuffle.c_func_file) -MeasurableVariable.register(MeasurableDimShuffle) - - @_logprob.register(MeasurableDimShuffle) def logprob_dimshuffle(op, values, base_var, **kwargs): """Compute the log-likelihood graph for a `MeasurableDimShuffle`.""" @@ -279,7 +270,6 @@ def logprob_dimshuffle(op, values, base_var, **kwargs): # indexes in the 
original dimshuffle order. Otherwise, there is no way of # knowing which dimensions were consumed by the logprob function. redo_ds = [o for o in op.new_order if o == "x" or o < raw_logp.ndim] - # pytensor.dprint(values[0].shape) return raw_logp.dimshuffle(redo_ds) @@ -288,7 +278,7 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]: r"""Finds `Dimshuffle`\s for which a `logprob` can be computed.""" rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - # pytensor.dprint(fgraph) + if rv_map_feature is None: return None # pragma: no cover @@ -308,40 +298,19 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]: if not isinstance(base_var.owner.op, RandomVariable): return None # pragma: no cover - # parameter for drop and expand exists in dimshuffle - # if(len(node.op.new_order) != len(list(base_var.owner.inputs[1]))+1): # if there is expand/drop, we fails - # return None - - # use base_var.type.ndim instead of base_var.owner.inputs[1]) - # ref = list(range(0, len(list(base_var.owner.inputs[1]))+1)) # creating reference list : [0, 1, 2] ref = list(range(0, base_var.type.ndim)) - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(base_var) # (-1,) - new_supp_axes = list(supp_axes) # [-1] use empty list and append ( eliminate x ) and sort + ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var) + new_supp_axes = list(supp_axes) - # check that if dropped dimensions is the supp axes - # add test for this case in cases that already failed for x in supp_axes: if base_var.type.ndim + x not in node.op.new_order: return None - # for x in range(0, len(new_supp_axes)): - # i = new_supp_axes[x] # i = -1 - # # print("a") - # # print(i) - # shift = ref[i] - node.op.new_order.index(ref[i]) # [0, 1, 2] and [2, 0, 1] : shift = 2-0 = 2 - - # # [0, 2 , 1] supp_axes = -2 - - # # node.op.new_order.index(ref[i]) from reverse - # # -(no.of dim - node.op.new_order.index(ref[i]) from reverse) - # new_supp_axes[x] = i-shift # supp_axis = -1-2 = -3 # [-3] - - # list comprehension for x in new_supp_axes: new_supp_axes[x] = node.op.new_order.index(ref[x]) - node.outputs[0].type.ndim - supp_axes = tuple(new_supp_axes) # (-3,) + supp_axes = tuple(new_supp_axes) measurable_dimshuffle = MeasurableDimShuffle( node.op.input_broadcastable, diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 68d65bc503..1610e3b3b9 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -115,7 +115,7 @@ _logcdf_helper, _logprob, _logprob_helper, - get_measurable_meta_info, + get_measure_type_info, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import ( @@ -502,7 +502,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[li transform_args_fn=lambda *inputs: inputs[-1], ) - ndim_supp, supp_axes, measure_type = get_measurable_meta_info(measurable_input) + ndim_supp, supp_axes, measure_type = get_measure_type_info(measurable_input) transform_op = MeasurableTransform( scalar_op=scalar_op, transform=transform, diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index 273a7994f1..cbdd5e6902 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -21,9 +21,9 @@ from pymc import logp from pymc.logprob import conditional_logp -from pymc.logprob.abstract import MeasureType, get_measurable_meta_info +from pymc.logprob.abstract import MeasureType, 
get_measure_type_info from pymc.testing import assert_no_rvs -from tests.logprob.utils import meta_info_helper +from tests.logprob.utils import measure_type_info_helper @pytest.mark.parametrize( @@ -71,7 +71,7 @@ def test_continuous_rv_comparison_bitwise(comparison_op, exp_logp_true, exp_logp ((pt.gt, pt.ge), "logcdf", "logsf", (0.5, pt.random.normal(0, 1))), ], ) -def test_meta_info(comparison_op, exp_logp_true, exp_logp_false, inputs): +def test_measure_type_info(comparison_op, exp_logp_true, exp_logp_false, inputs): for op in comparison_op: comp_x_rv = op(*inputs) @@ -81,9 +81,9 @@ def test_meta_info(comparison_op, exp_logp_true, exp_logp_false, inputs): base_rv = inputs[0] comp_x_vv = comp_x_rv.clone() - ndim_supp, supp_axes, measure_type = meta_info_helper(comp_x_rv, comp_x_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(comp_x_rv, comp_x_vv) - ndim_supp_base, supp_axes_base, _ = get_measurable_meta_info(base_rv) + ndim_supp_base, supp_axes_base, _ = get_measure_type_info(base_rv) assert np.isclose( ndim_supp_base, diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index b06536377e..daced03775 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -43,10 +43,11 @@ from pymc import logp from pymc.logprob import conditional_logp +from pymc.logprob.abstract import get_measure_type_info from pymc.logprob.transform_value import TransformValuesRewrite from pymc.logprob.transforms import LogTransform from pymc.testing import assert_no_rvs -from tests.logprob.utils import meta_info_helper +from tests.logprob.utils import measure_type_info_helper @pytensor.config.change_flags(compute_test_value="raise") @@ -75,8 +76,7 @@ def test_continuous_rv_clip(): "measure_type", [("Discrete"), ("Continuous")], ) -def test_clip_meta_info(measure_type): - # use true and false +def test_clip_measure_type_info(measure_type): if measure_type == "Continuous": base_rv = pt.random.normal(0.5, 1) rv = pt.clip(base_rv, -2, 2) @@ -86,9 +86,9 @@ def test_clip_meta_info(measure_type): vv = rv.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv) + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(base_rv) - ndim_supp, supp_axes, measure_type = meta_info_helper(rv, vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(rv, vv) assert np.isclose( ndim_supp_base, @@ -294,7 +294,7 @@ def test_rounding(rounding_op): @pytest.mark.parametrize("rounding_op", (pt.round, pt.floor, pt.ceil)) -def test_round_meta_info(rounding_op): +def test_round_measure_type_info(rounding_op): loc = 1 scale = 2 test_value = np.arange(-3, 4) @@ -304,9 +304,9 @@ def test_round_meta_info(rounding_op): xr.name = "xr" xr_vv = xr.clone() - ndim_supp, supp_axes, measure_type = meta_info_helper(xr, xr_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(xr, xr_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x) + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(x) assert np.isclose( ndim_supp_base, diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py index 0f2c517690..bc1948de88 100644 --- a/tests/logprob/test_checks.py +++ b/tests/logprob/test_checks.py @@ -35,23 +35,19 @@ # SOFTWARE. 
import re -from collections import deque - import numpy as np import pytensor import pytensor.tensor as pt import pytest -from pytensor.graph.basic import io_toposort from pytensor.raise_op import Assert from scipy import stats from pymc.distributions import Dirichlet -from pymc.logprob.abstract import get_measurable_meta_info +from pymc.logprob.abstract import get_measure_type_info from pymc.logprob.basic import conditional_logp -from pymc.logprob.rewriting import construct_ir_fgraph from tests.distributions.test_multivariate import dirichlet_logpdf -from tests.logprob.utils import meta_info_helper +from tests.logprob.utils import measure_type_info_helper def test_specify_shape_logprob(): @@ -90,11 +86,10 @@ def test_shape_meta_info(): x_rv = pt.specify_shape(x_base, shape=(5, 3)) x_rv.name = "x" - # 2. Request logp x_vv = x_rv.clone() - ndim_supp, supp_axes, measure_type = meta_info_helper(x_rv, x_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(x_rv, x_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x_base) + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(x_base) assert np.isclose( ndim_supp_base, @@ -137,9 +132,9 @@ def test_assert_meta_info(): assert_rv.name = "assert_rv" assert_vv = assert_rv.clone() - ndim_supp, supp_axes, measure_type = meta_info_helper(assert_rv, assert_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(assert_rv, assert_vv) - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(rv) + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(rv) assert np.isclose( ndim_supp_base, diff --git a/tests/logprob/test_cumsum.py b/tests/logprob/test_cumsum.py index 2e1b488c91..ddfba864dc 100644 --- a/tests/logprob/test_cumsum.py +++ b/tests/logprob/test_cumsum.py @@ -41,10 +41,10 @@ import scipy.stats as st from pymc import logp -from pymc.logprob.abstract import get_measurable_meta_info +from pymc.logprob.abstract import get_measure_type_info from pymc.logprob.basic import conditional_logp from pymc.testing import assert_no_rvs -from tests.logprob.utils import meta_info_helper +from tests.logprob.utils import measure_type_info_helper @pytest.mark.parametrize( @@ -83,14 +83,14 @@ def test_normal_cumsum(size, axis): ((3, 2, 10), 2), ], ) -def test_meta_info(size, axis): +def test_measure_type_info(size, axis): rv = pt.random.normal(0, 1, size=size).cumsum(axis) vv = rv.clone() base_rv = pt.random.normal(0, 1, size=size) base_vv = base_rv.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv) + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(base_rv) - ndim_supp, supp_axes, measure_type = meta_info_helper(rv, vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(rv, vv) assert np.isclose( ndim_supp_base, diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 2271309c23..538a127644 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,18 +52,13 @@ as_index_constant, ) -# from pymc.distributions.distribution import diracdelta -from pymc.logprob.abstract import ( - MeasurableVariable, - MeasureType, - get_measurable_meta_info, -) +from pymc.logprob.abstract import MeasurableVariable, MeasureType, get_measure_type_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph from 
pymc.logprob.utils import dirac_delta as diracdelta
 from pymc.testing import assert_no_rvs
-from tests.logprob.utils import meta_info_helper, scipy_logprob
+from tests.logprob.utils import measure_type_info_helper, scipy_logprob
 
 
 def test_mixture_basics():
@@ -942,7 +937,7 @@ def test_switch_mixture():
     np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1}))
 
 
-def test_meta_switch_mixture():
+def test_measure_type_info_switch_mixture():
     srng = pt.random.RandomStream(29833)
 
     X_rv = srng.normal(-10.0, 0.1, name="X")
@@ -1001,22 +996,6 @@ def test_switch_mixture_vector(switch_cond_scalar):
     )
 
 
-# pytest.mark.parametrize("switch_cond_scalar", (True, False))
-# def test_meta_switch_mixture_vector(switch_cond_scalar):
-#     if switch_cond_scalar:
-#         switch_cond = pt.scalar("switch_cond", dtype=bool)
-#     else:
-#         switch_cond = pt.vector("switch_cond", dtype=bool)
-#     true_branch = pt.exp(pt.random.normal(size=(4,)))
-#     false_branch = pt.abs(pt.random.normal(size=(4,)))
-
-#     switch = pt.switch(switch_cond, true_branch, false_branch)
-#     switch.name = "switch_mix"
-#     switch_value = switch.clone()
-
-#     ndim_supp, supp_axes, measure_type = meta_info_helper(switch, switch_value)
-
-
 def test_switch_mixture_measurable_cond_fails():
     """Test that logprob inference fails when the switch condition is an unvalued measurable variable.
@@ -1082,7 +1061,7 @@ def test_ifelse_mixture_one_component():
     )
 
 
-def test_meta_ifelse():
+def test_measure_type_ifelse():
     if_rv = pt.random.bernoulli(0.5, name="if")
     scale_rv = pt.random.halfnormal(name="scale")
     comp_then = pt.random.normal(0, scale_rv, size=(2,), name="comp_then")
@@ -1093,9 +1072,9 @@
     scale_vv = scale_rv.clone()
     mix_vv = mix_rv.clone()
 
-    ndim_supp, supp_axes, measure_type = meta_info_helper(mix_rv, mix_vv)
+    ndim_supp, supp_axes, measure_type = measure_type_info_helper(mix_rv, mix_vv)
 
-    ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(comp_then)
+    ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(comp_then)
 
     assert ndim_supp_base == 0
     assert supp_axes_base == ()
     assert measure_type_base == MeasureType.Continuous
diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py
index 22fc529eff..dff4fe2273 100644
--- a/tests/logprob/test_order.py
+++ b/tests/logprob/test_order.py
@@ -45,9 +45,9 @@
 import pymc as pm
 
 from pymc import logp
-from pymc.logprob.abstract import get_measurable_meta_info
+from pymc.logprob.abstract import get_measure_type_info
 from pymc.testing import assert_no_rvs
-from tests.logprob.utils import meta_info_helper
+from tests.logprob.utils import measure_type_info_helper
 
 
 def test_argmax():
@@ -187,7 +187,7 @@ def test_max_logprob(shape, value, axis):
     )
 
 
-def test_meta_info_order():
+def test_measure_type_info_order():
     """Test whether the logprob rewrite for ``pt.max`` produces the correct meta-info.
 
     The fact that order statistics of i.i.d. 
uniform RVs ~ Beta is used here: @@ -198,9 +198,9 @@ def test_meta_info_order(): x.name = "x" x_max = pt.max(x, axis=-1) x_max_vv = x_max.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(x) + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(x) - ndim_supp, supp_axes, measure_type = meta_info_helper(x_max, x_max_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(x_max, x_max_vv) assert np.isclose( ndim_supp_base, @@ -213,7 +213,7 @@ def test_meta_info_order(): x_min = pt.min(x, axis=-1) x_min_vv = x_min.clone() - ndim_supp_min, supp_axes_min, measure_type_min = meta_info_helper(x_min, x_min_vv) + ndim_supp_min, supp_axes_min, measure_type_min = measure_type_info_helper(x_min, x_min_vv) assert np.isclose( ndim_supp_base, diff --git a/tests/logprob/test_rewriting.py b/tests/logprob/test_rewriting.py index a7a208866c..66c28b102d 100644 --- a/tests/logprob/test_rewriting.py +++ b/tests/logprob/test_rewriting.py @@ -55,7 +55,6 @@ from pymc.logprob.rewriting import cleanup_ir, local_lift_DiracDelta from pymc.logprob.transform_value import TransformedValue, TransformValuesRewrite from pymc.logprob.utils import DiracDelta, dirac_delta -from tests.logprob.utils import scipy_logprob def test_local_lift_DiracDelta(): diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index 06fb2eb647..f051b497bd 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -53,7 +53,7 @@ get_random_outer_outputs, ) from pymc.testing import assert_no_rvs -from tests.logprob.utils import meta_info_helper +from tests.logprob.utils import measure_type_info_helper def create_inner_out_logp(value_map): @@ -516,7 +516,6 @@ def test_measure_type_scan_non_pure_rv_output(): grw1_vv = grw[0].clone() - # rename meta to measure_type fgraph, _, _ = construct_ir_fgraph({grw[0]: grw1_vv}) node = fgraph.outputs[0].owner ndim_supp = node.inputs[0].owner.op.ndim_supp @@ -542,7 +541,7 @@ def test_measure_type_scan_over_seqs(): ) ys1_vv = ys[0].clone() - ndim_supp, supp_axes, measure_type = meta_info_helper(ys[0], ys1_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(ys[0], ys1_vv) if ( not ndim_supp == (1, 0) diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index 2fe87aa150..0bc620fb86 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -47,16 +47,12 @@ import pymc as pm -from pymc.logprob.abstract import MeasureType, get_measurable_meta_info +from pymc.logprob.abstract import MeasureType, get_measure_type_info from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.rewriting import logprob_rewrites_db from pymc.logprob.tensor import naive_bcast_rv_lift from pymc.testing import assert_no_rvs -from tests.logprob.utils import ( - get_measurable_meta_infos, - meta_info_helper, - scipy_logprob, -) +from tests.logprob.utils import measure_type_info_helper, scipy_logprob def test_naive_bcast_rv_lift(): @@ -142,23 +138,23 @@ def test_measurable_make_vector(): assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval) -def test_meta_make_vector(): +def test_measure_type_make_vector(): base1_rv = pt.random.normal(name="base1") base2_rv = pt.random.halfnormal(name="base2") base3_rv = pt.random.exponential(name="base3") y_rv = pt.stack((base1_rv, base2_rv, base3_rv)) y_rv.name = "y" - ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measurable_meta_info(base1_rv) - ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = 
get_measurable_meta_info(base2_rv) - ndim_supp_base_3, supp_axes_base_3, measure_type_base_3 = get_measurable_meta_info(base3_rv) + ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measure_type_info(base1_rv) + ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measure_type_info(base2_rv) + ndim_supp_base_3, supp_axes_base_3, measure_type_base_3 = get_measure_type_info(base3_rv) base1_vv = base1_rv.clone() base2_vv = base2_rv.clone() base3_vv = base3_rv.clone() y_vv = y_rv.clone() - ndim_supp, supp_axes, measure_type = meta_info_helper(y_rv, y_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(y_rv, y_vv) assert np.isclose( ndim_supp_base_1, @@ -317,7 +313,7 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate): ((2, 5), (2, 5), 2, False), ], ) -def test_meta_join_univariate(size1, size2, axis, concatenate): +def test_measure_type_join_univariate(size1, size2, axis, concatenate): base1_rv = pt.random.normal(size=size1, name="base1") base2_rv = pt.random.exponential(size=size2, name="base2") if concatenate: @@ -326,11 +322,11 @@ def test_meta_join_univariate(size1, size2, axis, concatenate): y_rv = pt.stack((base1_rv, base2_rv), axis=axis) y_rv.name = "y" - ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measurable_meta_info(base1_rv) - ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measurable_meta_info(base2_rv) + ndim_supp_base_1, supp_axes_base_1, measure_type_base_1 = get_measure_type_info(base1_rv) + ndim_supp_base_2, supp_axes_base_2, measure_type_base_2 = get_measure_type_info(base2_rv) y_vv = y_rv.clone() - ndim_supp, supp_axes, measure_type = meta_info_helper(y_rv, y_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(y_rv, y_vv) assert np.isclose( ndim_supp_base_1, @@ -475,26 +471,24 @@ def test_measurable_dimshuffle(ds_order, multivariate): np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value)) -# TODO: seperate test for univariate and matrixNormal @pytensor.config.change_flags(cxx="") @pytest.mark.parametrize( - "ds_order, ans_ndim_supp, ans_supp_axes, ans_measure_type", + "multivariate, ds_order, ans_ndim_supp, ans_supp_axes, ans_measure_type", [ - ((0, 2), 1, (-1,), MeasureType.Continuous), # Drop - ((2, 0), 1, (-2,), MeasureType.Continuous), # Swap and drop - ((2, 1, "x", 0), 1, (-4,), MeasureType.Continuous), # Swap and expand - (("x", 0, 2), 1, (-1,), MeasureType.Continuous), # Expand and drop - ((2, "x", 0), 1, (-3,), MeasureType.Continuous), # Swap, expand and drop + (True, (0, 2), 1, (-1,), MeasureType.Continuous), # Drop + (True, (2, 0), 1, (-2,), MeasureType.Continuous), # Swap and drop + (True, (2, 1, "x", 0), 1, (-4,), MeasureType.Continuous), # Swap and expand + (True, ("x", 0, 2), 1, (-1,), MeasureType.Continuous), # Expand and drop + (True, (2, "x", 0), 1, (-3,), MeasureType.Continuous), # Swap, expand and drop + (False, (0, 2), 0, (), MeasureType.Continuous), + (False, (2, 1, "x", 0), 0, (), MeasureType.Continuous), ], ) -@pytest.mark.parametrize("multivariate", (False, True)) -def test_meta_measurable_dimshuffle( - ds_order, ans_ndim_supp, ans_supp_axes, ans_measure_type, multivariate +def test_measure_type_dimshuffle( + multivariate, ds_order, ans_ndim_supp, ans_supp_axes, ans_measure_type ): - # hardcore the answer in parameter and test if multivariate: base_rv = pm.Dirichlet.dist([1, 1, 1], shape=(7, 1, 3)) - # base_rv = pt.random.dirichlet([1, 2, 3], size=(2, 1)) ds_rv = base_rv.dimshuffle(ds_order) base_vv = base_rv.clone() 
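# A minimal sketch (assumed helper, not part of the committed diff) of how the
# expected supp_axes in the parametrization above can be derived: a support axis
# is remapped through a dimshuffle as new_order.index(old_axis) - output_ndim,
# mirroring the remapping in pymc/logprob/tensor.py. remap_supp_axis is a
# hypothetical name used only for illustration.
#
#     def remap_supp_axis(old_axis, old_ndim, new_order):
#         pos = old_axis + old_ndim  # e.g. -1 + 3 = 2 for the Dirichlet case
#         return new_order.index(pos) - len(new_order)
#
#     remap_supp_axis(-1, 3, (2, 0))          # -> -2 ("swap and drop" case)
#     remap_supp_axis(-1, 3, (2, 1, "x", 0))  # -> -4 ("swap and expand" case)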
@@ -502,23 +496,13 @@ def test_meta_measurable_dimshuffle( base_rv = pt.random.beta(1, 2, size=(2, 1, 3)) base_rv_1 = pt.exp(base_rv) ds_rv = base_rv_1.dimshuffle(ds_order) - base_vv = base_rv_1.clone() ds_vv = ds_rv.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv) - print(ndim_supp_base) - print(supp_axes_base) - print(measure_type_base) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(ds_rv, ds_vv) + + assert ndim_supp == ans_ndim_supp - ndim_supp, supp_axes, measure_type = meta_info_helper(ds_rv, ds_vv) - print(ndim_supp) - print(supp_axes) - print(measure_type) - assert np.isclose( - ndim_supp, - ans_ndim_supp, - ) assert supp_axes == ans_supp_axes assert measure_type == ans_measure_type diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index de4fb7e049..c026b40519 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -44,7 +44,11 @@ from pytensor.graph.basic import equal_computations from pymc.distributions.continuous import Cauchy, ChiSquared +<<<<<<< HEAD from pymc.distributions.discrete import Bernoulli +======= +from pymc.logprob.abstract import get_measure_type_info +>>>>>>> Measure type information for all logps added from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp from pymc.logprob.transforms import ( ArccoshTransform, @@ -66,6 +70,7 @@ from pymc.logprob.utils import ParameterValueError from pymc.testing import Rplusbig, Vector, assert_no_rvs from tests.distributions.test_transform import check_jacobian_det +from tests.logprob.utils import measure_type_info_helper class DirichletScipyDist: @@ -229,16 +234,16 @@ def test_exp_transform_rv(): ) -def test_meta_exp_transform_rv(): +def test_measure_type_exp_transform_rv(): base_rv = pt.random.normal(0, 1, size=3, name="base_rv") y_rv = pt.exp(base_rv) y_rv.name = "y" y_vv = y_rv.clone() - ndim_supp_base, supp_axes_base, measure_type_base = get_measurable_meta_info(base_rv) + ndim_supp_base, supp_axes_base, measure_type_base = get_measure_type_info(base_rv) - ndim_supp, supp_axes, measure_type = meta_info_helper(y_rv, y_vv) + ndim_supp, supp_axes, measure_type = measure_type_info_helper(y_rv, y_vv) assert np.isclose( ndim_supp_base, diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index c38e097606..e2640c672d 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -117,27 +117,12 @@ def scipy_logprob_tester( np.testing.assert_array_almost_equal(pytensor_res_val, numpy_res, 4) -def meta_info_helper(rv, vv): - # pytensor.config.optimizer_verbose=True - # ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]), exclude=) +def measure_type_info_helper(rv, vv): + """Extract measurable information from rv""" fgraph, _, _ = construct_ir_fgraph({rv: vv}) node = fgraph.outputs[0].owner - # pytensor.dprint(node) - # if isinstance(node, MeasurableVariable): ndim_supp = node.op.ndim_supp supp_axes = node.op.supp_axes measure_type = node.op.measure_type return ndim_supp, supp_axes, measure_type - - -def get_measurable_meta_infos( - base_op, -): - # if not isinstance(base_op, MeasurableVariable): - # raise TypeError("base_op must be a RandomVariable or MeasurableVariable") - - # if isinstance(base_op, RandomVariable): - ndim_supp = base_op.ndim_supp - supp_axes = tuple(range(-ndim_supp, 0)) - return base_op.ndim_supp, supp_axes From 52e9e6f72f971dddce01ae07cf1a6e92b85db3f8 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 24 Dec 2023 13:30:30 
From 52e9e6f72f971dddce01ae07cf1a6e92b85db3f8 Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi
Date: Sun, 24 Dec 2023 13:30:30 +0530
Subject: [PATCH 15/19] Measure type information for all logps added

---
 tests/logprob/test_censoring.py | 10 +++++++---
 tests/logprob/test_checks.py    |  4 ++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py
index daced03775..ee535c96b1 100644
--- a/tests/logprob/test_censoring.py
+++ b/tests/logprob/test_censoring.py
@@ -43,7 +43,7 @@
 from pymc import logp
 from pymc.logprob import conditional_logp
-from pymc.logprob.abstract import get_measure_type_info
+from pymc.logprob.abstract import MeasureType, get_measure_type_info
 from pymc.logprob.transform_value import TransformValuesRewrite
 from pymc.logprob.transforms import LogTransform
 from pymc.testing import assert_no_rvs
@@ -96,7 +96,11 @@ def test_clip_measure_type_info(measure_type):
     )
 
     assert supp_axes_base == supp_axes
-    assert measure_type_base == measure_type
+    if measure_type_base == MeasureType.Continuous:
+        # Clipping a continuous base adds point masses at the bounds, so its measure type changes
+        assert measure_type_base != measure_type
+    else:
+        assert measure_type_base == measure_type
 
 
 def test_discrete_rv_clip():
@@ -314,4 +318,4 @@ def test_round_measure_type_info(rounding_op):
     )
 
     assert supp_axes_base == supp_axes
-    assert str(measure_type) == "MeasureType.Discrete"
+    assert measure_type == MeasureType.Discrete
diff --git a/tests/logprob/test_checks.py b/tests/logprob/test_checks.py
index bc1948de88..607352d6f9 100644
--- a/tests/logprob/test_checks.py
+++ b/tests/logprob/test_checks.py
@@ -79,7 +79,7 @@ def test_specify_shape_logprob():
         x_logp_fn(last_dim=1, x=x_vv_test_invalid)
 
 
-def test_shape_meta_info():
+def test_shape_measure_type_info():
     last_dim = pt.scalar(name="last_dim", dtype="int64")
     x_base = Dirichlet.dist(pt.ones((last_dim,)), shape=(5, last_dim))
     x_base.name = "x"
@@ -124,7 +124,7 @@ def test_assert_logprob():
         assert_logp.eval({assert_vv: -5.0})
 
 
-def test_assert_meta_info():
+def test_assert_measure_type_info():
     rv = pt.random.normal()
     assert_op = Assert("Test assert")
     # Example: Add assert that rv must be positive

From 203ab89fa61d72bd124e9f0b102fd78708600b05 Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi
Date: Tue, 9 Jan 2024 17:09:20 +0530
Subject: [PATCH 16/19] Solved Return type Error

---
 pymc/logprob/abstract.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py
index e013bcf419..ffd517edbc 100644
--- a/pymc/logprob/abstract.py
+++ b/pymc/logprob/abstract.py
@@ -177,9 +177,7 @@ def __init__(self, scalar_op, *args, **kwargs):
 
 def get_measure_type_info(
     base_var,
-) -> Tuple[
-    Union[int, Tuple[int]], Tuple[Union[int, Tuple[int]]], Union[MeasureType, Tuple[MeasureType]]
-]:
+):
     from pymc.logprob.utils import DiracDelta
 
     if not isinstance(base_var, MeasurableVariable):
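
For orientation, a hedged sketch of the three values this now-unannotated function is expected to return, assuming the defaults established earlier in the series: supp_axes falls back to the trailing ndim_supp axes, and the measure type is inferred from the dtype. All names below come from this branch only.

import pymc as pm
import pytensor.tensor as pt

from pymc.logprob.abstract import MeasureType, get_measure_type_info

# Batched univariate draw: zero support dims, continuous measure.
assert get_measure_type_info(pt.random.normal(size=(2, 3))) == (0, (), MeasureType.Continuous)

# Multivariate draw: one support dim, located at the last axis.
dir_rv = pm.Dirichlet.dist([1, 1, 1], shape=(7, 1, 3))
assert get_measure_type_info(dir_rv) == (1, (-1,), MeasureType.Continuous)

# Integer dtype maps to a discrete measure.
assert get_measure_type_info(pt.random.poisson(5, size=4)) == (0, (), MeasureType.Discrete)
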
From 24c4d8bc1ef1a1e0974cdb152f969fb84c62528a Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi
Date: Fri, 22 Mar 2024 04:24:09 +0530
Subject: [PATCH 17/19] Solving merge conflicts

---
 pymc/logprob/order.py            | 23 ++++++++++++++++++-----
 tests/logprob/test_transforms.py |  3 ---
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index bd2075466f..06e5127e29 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -191,9 +191,18 @@ def clone(self, **kwargs):
         )
 
 
-class MeasurableDiscreteMaxNeg(Max):
+class MeasurableDiscreteMaxNeg(MeasurableVariable, Max):
     """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""
 
+    def clone(self, **kwargs):
+        axis = kwargs.get("axis", self.axis)
+        ndim_supp = kwargs.get("ndim_supp", self.ndim_supp)
+        supp_axes = kwargs.get("supp_axes", self.supp_axes)
+        measure_type = kwargs.get("measure_type", self.measure_type)
+        return type(self)(
+            axis=axis, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
+        )
+
 
 MeasurableVariable.register(MeasurableDiscreteMaxNeg)
@@ -237,14 +246,18 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[
 
     if not rv_map_feature.request_measurable([base_rv]):
         return None
-
-    ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var)
+
+    ndim_supp, supp_axes, measure_type = get_measure_type_info(base_rv)
 
     # distinguish measurable discrete and continuous (because logprob is different)
     if base_rv.owner.op.dtype.startswith("int"):
-        measurable_min = MeasurableDiscreteMaxNeg(list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type)
+        measurable_min = MeasurableDiscreteMaxNeg(
+            list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
+        )
     else:
-        measurable_min = MeasurableMaxNeg(list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type)
+        measurable_min = MeasurableMaxNeg(
+            list(axis), ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
+        )
 
     return measurable_min.make_node(base_rv).outputs
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index c026b40519..4eb3a6b7c0 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -44,11 +44,8 @@
 from pytensor.graph.basic import equal_computations
 
 from pymc.distributions.continuous import Cauchy, ChiSquared
-<<<<<<< HEAD
 from pymc.distributions.discrete import Bernoulli
-=======
 from pymc.logprob.abstract import get_measure_type_info
->>>>>>> Measure type information for all logps added
 from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
 from pymc.logprob.transforms import (
     ArccoshTransform,

From 8364f34500fa0baa765d0a39f459d3e154a594e1 Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi
Date: Fri, 22 Mar 2024 04:31:05 +0530
Subject: [PATCH 18/19] Solving merge conflicts

---
 pymc/logprob/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py
index 49827f7a61..c72d13f9d3 100644
--- a/pymc/logprob/utils.py
+++ b/pymc/logprob/utils.py
@@ -261,7 +261,7 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
 
 
 class DiracDelta(Op):
-    """An `Op` that represents a Dirac-delta distribution."""
+    """An `Op` that represents a Dirac-Delta distribution."""
 
     __props__ = ("rtol", "atol")

From ddcda2f8ec1151f9fc3ccf871517f5d0d5b57920 Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi
Date: Fri, 29 Mar 2024 16:29:35 +0530
Subject: [PATCH 19/19] Solve pre-commit issues

---
 pymc/logprob/abstract.py     | 12 +++++-------
 pymc/logprob/scan.py         |  2 +-
 tests/logprob/test_tensor.py |  2 +-
 tests/logprob/utils.py       |  2 --
 4 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py
index ffd517edbc..825deab3b1 100644
--- a/pymc/logprob/abstract.py
+++ b/pymc/logprob/abstract.py
@@ -36,9 +36,10 @@
 
 import abc
 
-from typing import Tuple, Union
 from collections.abc import Sequence
+from enum import Enum, auto
 from functools import singledispatch
+from typing import Union
 
 from pytensor.graph.op import Op
 from pytensor.graph.utils import MetaType
@@ -132,9 +133,6 @@ def _icdf_helper(rv, value, **kwargs):
     return rv_icdf
 
 
-from enum import Enum, auto
-
-
 class MeasureType(Enum):
     Discrete = auto()
     Continuous = auto()
     Mixed = auto()
@@ -147,9 +145,9 @@ class MeasurableVariable(abc.ABC):
     def __init__(
         self,
         *args,
-        ndim_supp: Union[int, Tuple[int]],
-        supp_axes: Tuple[Union[int, Tuple[int]]],
-        measure_type: Union[MeasureType, Tuple[MeasureType]],
+        ndim_supp: Union[int, tuple],
+        supp_axes: tuple,
+        measure_type: Union[MeasureType, tuple],
         **kwargs,
     ):
         self.ndim_supp = ndim_supp
diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py
index 62f10a3267..f1e9985104 100644
--- a/pymc/logprob/scan.py
+++ b/pymc/logprob/scan.py
@@ -466,7 +466,7 @@ def find_measurable_scans(fgraph, node):
             # Replace the mapping
             rv_map_feature.update_rv_maps(rv_var, new_val_var, full_out)
 
-    clients: Dict[Variable, List[Variable]] = {}
+    clients: dict[Variable, list[Variable]] = {}
     local_fgraph_topo = pytensor.graph.basic.io_toposort(
         curr_scanargs.inner_inputs,
         [o for o in curr_scanargs.inner_outputs if not isinstance(o.type, RandomType)],
diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py
index 0bc620fb86..72158d9b1a 100644
--- a/tests/logprob/test_tensor.py
+++ b/tests/logprob/test_tensor.py
@@ -52,7 +52,7 @@
 from pymc.logprob.rewriting import logprob_rewrites_db
 from pymc.logprob.tensor import naive_bcast_rv_lift
 from pymc.testing import assert_no_rvs
-from tests.logprob.utils import measure_type_info_helper, scipy_logprob
+from tests.logprob.utils import measure_type_info_helper
 
 
 def test_naive_bcast_rv_lift():
diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py
index e2640c672d..db43e4fea3 100644
--- a/tests/logprob/utils.py
+++ b/tests/logprob/utils.py
@@ -36,13 +36,11 @@
 
 import numpy as np
-import pytensor
 
 from pytensor import tensor as pt
 from scipy import stats as stats
 
 from pymc.logprob import icdf, logcdf, logp
-from pymc.logprob.abstract import MeasurableVariable
 from pymc.logprob.rewriting import construct_ir_fgraph
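
Taken together, the series leaves MeasurableVariable.__init__ as the single carrier of the measure metadata, with all three fields keyword-only. A hypothetical sketch of that contract; ToyMeasurable is illustrative only and appears nowhere in the patches above.

from pymc.logprob.abstract import MeasurableVariable, MeasureType


class ToyMeasurable(MeasurableVariable):
    """A stand-in that only records the measure metadata.

    MeasurableVariable declares no abstract methods, so a bare subclass
    can be instantiated for illustration.
    """


# ndim_supp: number of trailing support dims; supp_axes: which axes those are;
# measure_type: the kind of measure the logprob integrates against.
toy = ToyMeasurable(ndim_supp=1, supp_axes=(-1,), measure_type=MeasureType.Continuous)
assert (toy.ndim_supp, toy.supp_axes, toy.measure_type) == (
    1,
    (-1,),
    MeasureType.Continuous,
)
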