diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2c44c4d44f..a7ec2f6e64 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -84,13 +84,13 @@ jobs: install-mlx: [0] install-xarray: [0] part: - - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor" + - "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor" - "tests/scan" - - "tests/tensor --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/conv --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py" - - "tests/tensor/rewriting" + - "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting" + - "tests/tensor/test_basic.py tests/tensor/test_elemwise.py" - "tests/tensor/test_math.py" - - "tests/tensor/test_basic.py tests/tensor/test_inplace.py tests/tensor/conv" - - "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py" + - "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv" + - "tests/tensor/rewriting" exclude: - python-version: "3.11" fast-compile: 1 @@ -167,7 +167,7 @@ jobs: install-numba: 0 install-jax: 0 install-torch: 0 - part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py" + part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py" steps: - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 6adc16ec59..af0b0b7173 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1101,30 +1101,6 @@ def same_out_float_only(type) -> tuple[ScalarType]: return (type,) -class transfer_type(MetaObject): - __props__ = ("transfer",) - - def __init__(self, *transfer): - assert all(isinstance(x, int | str) or x is None for x in transfer) - self.transfer = transfer - - def __str__(self): - return f"transfer_type{self.transfer}" - - def __call__(self, *types): - upcast = upcast_out(*types) - retval = [] - for i in self.transfer: - if i is None: - retval += [upcast] - elif isinstance(i, str): - retval += [i] - else: - retval += [types[i]] - return retval - # return [upcast if i is None else types[i] for i in self.transfer] - - class specific_out(MetaObject): __props__ = ("spec",) @@ -2446,6 +2422,10 @@ def handle_int(v): class Second(BinaryScalarOp): + @staticmethod + def output_types_preference(_first_type, second_type): + return [second_type] + def impl(self, x, y): return y @@ -2474,7 +2454,7 @@ def grad(self, inputs, gout): return DisconnectedType()(), y.zeros_like(dtype=config.floatX) -second = Second(transfer_type(1), name="second") +second = Second(name="second") class Identity(UnaryScalarOp): @@ -2515,18 +2495,6 @@ def clone_float32(self): return convert_to_float32 return self - def make_new_inplace(self, output_types_preference=None, name=None): - """ - This op.__init__ fct don't have the same parameter as other scalar op. - This breaks the insert_inplace_optimizer optimization. - This function is a fix to patch this, by ignoring the - output_types_preference passed by the optimization, and replacing it - by the current output type. This should only be triggered when - both input and output have the same dtype anyway. - - """ - return self.__class__(self.o_type, name) - def impl(self, input): return self.ctor(input) @@ -4322,22 +4290,6 @@ def __str__(self): return self._name - def make_new_inplace(self, output_types_preference=None, name=None): - """ - This op.__init__ fct don't have the same parameter as other scalar op. - This break the insert_inplace_optimizer optimization. - This fct allow fix patch this. - - """ - d = {k: getattr(self, k) for k in self.init_param} - out = self.__class__(**d) - if name: - out.name = name - else: - name = out.name - super(Composite, out).__init__(output_types_preference, name) - return out - @property def fgraph(self): if hasattr(self, "_fgraph"): diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index f23c4e1c42..e4bfc871fc 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -136,9 +136,6 @@ def clone(self, name=None, **kwargs): def fn(self): raise NotImplementedError - def make_new_inplace(self, output_types_preference=None, name=None): - return self.clone(output_types_preference=output_types_preference, name=name) - def make_node(self, n_steps, *inputs): assert len(inputs) == self.nin - 1 diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 4388c110c8..f1d8bc09df 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -20,7 +20,7 @@ from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type from pytensor.scalar.basic import identity as scalar_identity -from pytensor.scalar.basic import int64, transfer_type, upcast +from pytensor.scalar.basic import int64, upcast from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable @@ -1634,17 +1634,12 @@ def construct(symbol): symbolname = symbolname or symbol.__name__ if symbolname.endswith("_inplace"): - base_symbol_name = symbolname[: -len("_inplace")] - scalar_op = getattr(scalar, base_symbol_name) - inplace_scalar_op = scalar_op.__class__(transfer_type(0)) - rval = Elemwise( - inplace_scalar_op, - {0: 0}, - nfunc_spec=(nfunc and (nfunc, nin, nout)), + raise ValueError( + "Creation of automatic inplace elemwise operations deprecated" ) - else: - scalar_op = getattr(scalar, symbolname) - rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout))) + + scalar_op = getattr(scalar, symbolname) + rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout))) if getattr(symbol, "__doc__"): rval.__doc__ = symbol.__doc__ diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py deleted file mode 100644 index 8c0df0e2e0..0000000000 --- a/pytensor/tensor/inplace.py +++ /dev/null @@ -1,427 +0,0 @@ -from pytensor import printing -from pytensor.printing import pprint -from pytensor.tensor.elemwise import scalar_elemwise - - -@scalar_elemwise -def lt_inplace(a, b): - """a < b (inplace on a)""" - - -@scalar_elemwise -def gt_inplace(a, b): - """a > b (inplace on a)""" - - -@scalar_elemwise -def le_inplace(a, b): - """a <= b (inplace on a)""" - - -@scalar_elemwise -def ge_inplace(a, b): - """a >= b (inplace on a)""" - - -@scalar_elemwise -def eq_inplace(a, b): - """a == b (inplace on a)""" - - -@scalar_elemwise -def neq_inplace(a, b): - """a != b (inplace on a)""" - - -@scalar_elemwise -def and__inplace(a, b): - """bitwise a & b (inplace on a)""" - - -@scalar_elemwise -def or__inplace(a, b): - """bitwise a | b (inplace on a)""" - - -@scalar_elemwise -def xor_inplace(a, b): - """bitwise a ^ b (inplace on a)""" - - -@scalar_elemwise -def invert_inplace(a): - """bitwise ~a (inplace on a)""" - - -@scalar_elemwise -def abs_inplace(a): - """|`a`| (inplace on `a`)""" - - -@scalar_elemwise -def exp_inplace(a): - """e^`a` (inplace on `a`)""" - - -@scalar_elemwise -def exp2_inplace(a): - """2^`a` (inplace on `a`)""" - - -@scalar_elemwise -def expm1_inplace(a): - """e^`a` - 1 (inplace on `a`)""" - - -@scalar_elemwise -def neg_inplace(a): - """-a (inplace on a)""" - - -@scalar_elemwise -def reciprocal_inplace(a): - """1.0/a (inplace on a)""" - - -@scalar_elemwise -def log_inplace(a): - """base e logarithm of a (inplace on a)""" - - -@scalar_elemwise -def log1p_inplace(a): - """log(1+a)""" - - -@scalar_elemwise -def log2_inplace(a): - """base 2 logarithm of a (inplace on a)""" - - -@scalar_elemwise -def log10_inplace(a): - """base 10 logarithm of a (inplace on a)""" - - -@scalar_elemwise -def sign_inplace(a): - """sign of `a` (inplace on `a`)""" - - -@scalar_elemwise -def ceil_inplace(a): - """ceil of `a` (inplace on `a`)""" - - -@scalar_elemwise -def floor_inplace(a): - """floor of `a` (inplace on `a`)""" - - -@scalar_elemwise -def trunc_inplace(a): - """trunc of `a` (inplace on `a`)""" - - -@scalar_elemwise -def round_half_to_even_inplace(a): - """round_half_to_even_inplace(a) (inplace on `a`)""" - - -@scalar_elemwise -def round_half_away_from_zero_inplace(a): - """round_half_away_from_zero_inplace(a) (inplace on `a`)""" - - -@scalar_elemwise -def sqr_inplace(a): - """square of `a` (inplace on `a`)""" - - -@scalar_elemwise -def sqrt_inplace(a): - """square root of `a` (inplace on `a`)""" - - -@scalar_elemwise -def deg2rad_inplace(a): - """convert degree `a` to radian(inplace on `a`)""" - - -@scalar_elemwise -def rad2deg_inplace(a): - """convert radian `a` to degree(inplace on `a`)""" - - -@scalar_elemwise -def cos_inplace(a): - """cosine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arccos_inplace(a): - """arccosine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def sin_inplace(a): - """sine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arcsin_inplace(a): - """arcsine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def tan_inplace(a): - """tangent of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arctan_inplace(a): - """arctangent of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arctan2_inplace(a, b): - """arctangent of `a` / `b` (inplace on `a`)""" - - -@scalar_elemwise -def cosh_inplace(a): - """hyperbolic cosine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arccosh_inplace(a): - """hyperbolic arc cosine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def sinh_inplace(a): - """hyperbolic sine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arcsinh_inplace(a): - """hyperbolic arc sine of `a` (inplace on `a`)""" - - -@scalar_elemwise -def tanh_inplace(a): - """hyperbolic tangent of `a` (inplace on `a`)""" - - -@scalar_elemwise -def arctanh_inplace(a): - """hyperbolic arc tangent of `a` (inplace on `a`)""" - - -@scalar_elemwise -def erf_inplace(a): - """error function""" - - -@scalar_elemwise -def erfc_inplace(a): - """complementary error function""" - - -@scalar_elemwise -def erfcx_inplace(a): - """scaled complementary error function""" - - -@scalar_elemwise -def owens_t_inplace(h, a): - """owens t function""" - - -@scalar_elemwise -def gamma_inplace(a): - """gamma function""" - - -@scalar_elemwise -def gammaln_inplace(a): - """log gamma function""" - - -@scalar_elemwise -def psi_inplace(a): - """derivative of log gamma function""" - - -@scalar_elemwise -def tri_gamma_inplace(a): - """second derivative of the log gamma function""" - - -@scalar_elemwise -def gammainc_inplace(k, x): - """regularized lower gamma function (P)""" - - -@scalar_elemwise -def gammaincc_inplace(k, x): - """regularized upper gamma function (Q)""" - - -@scalar_elemwise -def gammau_inplace(k, x): - """upper incomplete gamma function""" - - -@scalar_elemwise -def gammal_inplace(k, x): - """lower incomplete gamma function""" - - -@scalar_elemwise -def gammaincinv_inplace(k, x): - """Inverse to the regularized lower incomplete gamma function""" - - -@scalar_elemwise -def gammainccinv_inplace(k, x): - """Inverse of the regularized upper incomplete gamma function""" - - -@scalar_elemwise -def j0_inplace(x): - """Bessel function of the first kind of order 0.""" - - -@scalar_elemwise -def j1_inplace(x): - """Bessel function of the first kind of order 1.""" - - -@scalar_elemwise -def jv_inplace(v, x): - """Bessel function of the first kind of order v (real).""" - - -@scalar_elemwise -def i0_inplace(x): - """Modified Bessel function of the first kind of order 0.""" - - -@scalar_elemwise -def i1_inplace(x): - """Modified Bessel function of the first kind of order 1.""" - - -@scalar_elemwise -def iv_inplace(v, x): - """Modified Bessel function of the first kind of order v (real).""" - - -@scalar_elemwise -def ive_inplace(v, x): - """Exponentially scaled modified Bessel function of the first kind of order v (real).""" - - -@scalar_elemwise -def sigmoid_inplace(x): - """Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit""" - - -@scalar_elemwise -def softplus_inplace(x): - """Compute log(1 + exp(x)), also known as softplus or log1pexp""" - - -@scalar_elemwise -def log1mexp_inplace(x): - """Compute log(1 - exp(x)), also known as log1mexp""" - - -@scalar_elemwise -def betainc_inplace(a, b, x): - """Regularized incomplete beta function""" - - -@scalar_elemwise -def betaincinv_inplace(a, b, x): - """Inverse of the regularized incomplete beta function""" - - -@scalar_elemwise -def second_inplace(a): - """Fill `a` with `b`""" - - -fill_inplace = second_inplace -pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="])) - - -@scalar_elemwise -def maximum_inplace(a, b): - """elementwise addition (inplace on `a`)""" - - -@scalar_elemwise -def minimum_inplace(a, b): - """elementwise addition (inplace on `a`)""" - - -@scalar_elemwise -def add_inplace(a, b): - """elementwise addition (inplace on `a`)""" - - -@scalar_elemwise -def sub_inplace(a, b): - """elementwise subtraction (inplace on `a`)""" - - -@scalar_elemwise -def mul_inplace(a, b): - """elementwise multiplication (inplace on `a`)""" - - -@scalar_elemwise -def true_div_inplace(a, b): - """elementwise division (inplace on `a`)""" - - -@scalar_elemwise -def int_div_inplace(a, b): - """elementwise division (inplace on `a`)""" - - -@scalar_elemwise -def mod_inplace(a, b): - """elementwise modulo (inplace on `a`)""" - - -@scalar_elemwise -def pow_inplace(a, b): - """elementwise power (inplace on `a`)""" - - -@scalar_elemwise -def conj_inplace(a): - """elementwise conjugate (inplace on `a`)""" - - -@scalar_elemwise -def hyp2f1_inplace(a, b, c, z): - """gaussian hypergeometric function""" - - -pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either")) -pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either")) -pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left")) -pprint.assign(neg_inplace, printing.OperatorPrinter("-=", 0, "either")) -pprint.assign(true_div_inplace, printing.OperatorPrinter("/=", -1, "left")) -pprint.assign(int_div_inplace, printing.OperatorPrinter("//=", -1, "left")) -pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right")) - - -def transpose_inplace(x, **kwargs): - "Perform a transpose on a tensor without copying the underlying storage" - dims = list(range(x.ndim - 1, -1, -1)) - return x.dimshuffle(dims) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 1bcaa8624d..dc30beedf3 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -35,7 +35,6 @@ Mul, ScalarOp, get_scalar_type, - transfer_type, upcast_out, upgrade_to_float, ) @@ -287,22 +286,17 @@ def create_inplace_node(self, node, inplace_pattern): op = node.op scalar_op = op.scalar_op inplace_pattern = {i: o for i, [o] in inplace_pattern.items()} - if hasattr(scalar_op, "make_new_inplace"): - new_scalar_op = scalar_op.make_new_inplace( - transfer_type( - *[ - inplace_pattern.get(i, o.dtype) - for i, o in enumerate(node.outputs) - ] + try: + return type(op)(scalar_op, inplace_pattern).make_node(*node.inputs) + except TypeError: + # Elemwise raises TypeError if we try to inplace an output on an input of a different dtype + if config.optimizer_verbose: + print( # noqa: T201 + f"InplaceElemwise failed because the output dtype of {node} changed when rebuilt. " + "Perhaps due to a change in config.floatX or config.cast_policy" ) - ) - else: - new_scalar_op = type(scalar_op)( - transfer_type( - *[inplace_pattern.get(i, None) for i in range(len(node.outputs))] - ) - ) - return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs) + # InplaceGraphOptimizer will chug along fine if we return the original node + return node optdb.register( diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 954656cebe..a068335d5b 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -6,13 +6,13 @@ import pytensor import pytensor.tensor as pt -import pytensor.tensor.inplace as pti import pytensor.tensor.math as ptm from pytensor import config, function from pytensor.compile import get_mode from pytensor.compile.ops import deep_copy_op from pytensor.gradient import grad from pytensor.scalar import Composite, float64 +from pytensor.scalar import add as scalar_add from pytensor.tensor import blas, tensor from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum @@ -30,6 +30,8 @@ rng = np.random.default_rng(42849) +add_inplace = Elemwise(scalar_add, {0: 0}) + @pytest.mark.parametrize( "inputs, input_vals, output_fn", @@ -80,7 +82,7 @@ np.array(1.0, dtype=config.floatX), np.array(1.0, dtype=config.floatX), ], - lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), + lambda x, y: add_inplace(deep_copy_op(x), deep_copy_op(y)), ), ( [pt.vector(), pt.vector()], @@ -88,7 +90,7 @@ rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX), ], - lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), + lambda x, y: add_inplace(deep_copy_op(x), deep_copy_op(y)), ), ( [pt.vector(), pt.vector()], diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 303cf970d4..0380f997e3 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -31,7 +31,6 @@ from pytensor.graph.traversal import ancestors from pytensor.printing import debugprint from pytensor.scalar import PolyGamma, Psi, TriGamma -from pytensor.tensor import inplace from pytensor.tensor.basic import Alloc, constant, join, second, switch from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas_c import CGemv @@ -1134,15 +1133,15 @@ def test_log1p(): f = function([x], log(1 + (x)), mode=m) assert [node.op for node in f.maker.fgraph.toposort()] == [log1p] f = function([x], log(1 + (-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] == [ - neg, - inplace.log1p_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [ + ps.neg, + ps.log1p, ] f = function([x], -log(1 + (-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] == [ - neg, - inplace.log1p_inplace, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [ + ps.neg, + ps.log1p, + ps.neg, ] # check trickier cases (and use different dtype) @@ -4035,27 +4034,27 @@ def test_exp_over_1_plus_exp(self): # todo: solve issue #4589 first # assert check_stack_trace( # f, ops_to_check=[sigmoid, neg_inplace]) - assert [node.op for node in f.maker.fgraph.toposort()] == [ - sigmoid, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [ + ps.sigmoid, + ps.neg, ] f(data) f = pytensor.function([x], pt.fill(x, -1.0) / (1 - exp(-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.neg, ] f(data) f = pytensor.function([x], pt.fill(x, -1.0) / (2 + exp(-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.neg, ] f(data) f = pytensor.function([x], pt.fill(x, -1.1) / (1 + exp(-x)), mode=m) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.neg, ] f(data) @@ -4077,10 +4076,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.1) * exp(x)) / ((1 + exp(x)) * (1 + exp(-x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) f = pytensor.function( @@ -4088,10 +4087,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.0) * exp(x)) / ((2 + exp(x)) * (1 + exp(-x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) f = pytensor.function( @@ -4099,10 +4098,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) f = pytensor.function( @@ -4110,10 +4109,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (1 + exp(x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) f = pytensor.function( @@ -4121,10 +4120,10 @@ def test_exp_over_1_plus_exp(self): (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))), mode=m, ) - assert [node.op for node in f.maker.fgraph.toposort()] != [ - sigmoid, - mul, - inplace.neg_inplace, + assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [ + ps.sigmoid, + ps.mul, + ps.neg, ] f(data) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 0390fbbac8..46a5e2e4fa 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -2797,7 +2797,6 @@ def test_infer_shape(self, cast_policy): out = arange(start, stop, 1) f = function([start, stop], out.shape, mode=mode) assert len(f.maker.fgraph.toposort()) == 5 - # 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)] if config.cast_policy == "custom": assert out.dtype == "int64" elif config.cast_policy == "numpy+floatX": diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 6d1e843a9e..60592d1b31 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -17,7 +17,6 @@ from pytensor.gradient import grad from pytensor.graph.rewriting.basic import in2out from pytensor.graph.utils import InconsistencyError -from pytensor.tensor import inplace from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.blas import ( BatchedDot, @@ -40,6 +39,7 @@ ger, ger_destructive, ) +from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger from pytensor.tensor.type import ( @@ -258,16 +258,20 @@ def test_destroy_map1(self): rng = np.random.default_rng(seed=utt.fetch_seed()) Z = as_tensor_variable(rng.random((2, 2))) A = as_tensor_variable(rng.random((2, 2))) + Zt = Z.transpose() + assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]} with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq): - gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0) + gemm_inplace(Z, 1.0, A, Zt, 1.0) def test_destroy_map2(self): # test that only first input can be overwritten. rng = np.random.default_rng(seed=utt.fetch_seed()) Z = as_tensor_variable(rng.random((2, 2))) A = as_tensor_variable(rng.random((2, 2))) + Zt = Z.transpose() + assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]} with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq): - gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0) + gemm_inplace(Z, 1.0, Zt, A, 1.0) def test_destroy_map3(self): # test that only first input can be overwritten diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 5d20bf837b..5a61bf8f8a 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -20,6 +20,9 @@ from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.scalar import ScalarOp, float32, float64, int32, int64 +from pytensor.scalar import add as scalar_add +from pytensor.scalar import exp as scalar_exp +from pytensor.scalar import xor as scalar_xor from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import get_scalar_constant_value, second from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -43,6 +46,16 @@ ) from tests import unittest_tools from tests.link.test_link import make_function +from tests.tensor.utils import ( + _bad_runtime_broadcast_binary_normal, + inplace_func, + integers, + integers_uint16, + integers_uint32, + makeBroadcastTester, + random, + random_complex, +) def reduce_bitwise_and(x, axis=-1, dtype="int8"): @@ -334,7 +347,7 @@ def with_linker_inplace(self, linker, op, type, rand_val): x = x_type("x") y = y_type("y") - e = op(ps.Add(ps.transfer_type(0)), {0: 0})(x, y) + e = op(ps.add, {0: 0})(x, y) f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) xv = rand_val(xsh) yv = rand_val(ysh) @@ -348,7 +361,7 @@ def with_linker_inplace(self, linker, op, type, rand_val): if isinstance(linker, PerformLinker): x = x_type("x") y = y_type("y") - e = op(ps.Add(ps.transfer_type(0)), {0: 0})(x, y) + e = op(ps.add, {0: 0})(x, y) f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape]))) xv = rand_val(xsh) yv = rand_val(ysh) @@ -390,7 +403,10 @@ def test_fill(self): ): x = t(pytensor.config.floatX, shape=(None, None))("x") y = t(pytensor.config.floatX, shape=(1, 1))("y") - e = op(ps.Second(ps.transfer_type(0)), {0: 0})(x, y) + op1 = op(ps.second, {0: 0}) + op2 = op(ps.second, {0: 0}) + assert op1 == op2 + e = op(ps.Second(), {0: 0})(x, y) f = make_function(linker().accept(FunctionGraph([x, y], [e]))) xv = rval((5, 5)) yv = rval((1, 1)) @@ -1113,3 +1129,99 @@ def test_numpy_warning_suppressed(): y = pt.log(x) fn = pytensor.function([x], y, mode=Mode(linker="py")) assert fn(0) == -np.inf + + +rng = np.random.default_rng(18) +_good_add_inplace = dict( + same_shapes=(random(2, 3, rng=rng), random(2, 3, rng=rng)), + not_same_dimensions=(random(2, 2, rng=rng), random(2, rng=rng)), + scalar=(random(2, 3, rng=rng), random(1, 1, rng=rng)), + row=(random(2, 3, rng=rng), random(1, 3, rng=rng)), + column=(random(2, 3, rng=rng), random(2, 1, rng=rng)), + integers=(integers(2, 3, rng=rng), integers(2, 3, rng=rng)), + uint32=(integers_uint32(2, 3, rng=rng), integers_uint32(2, 3, rng=rng)), + uint16=(integers_uint16(2, 3, rng=rng), integers_uint16(2, 3, rng=rng)), + # (float32, >int16) upcasts to float64 by default + dtype_valid_mixup=( + random(2, 3, rng=rng), + integers(2, 3, rng=rng).astype( + "int16" if config.floatX == "float32" else "int64" + ), + ), + complex1=(random_complex(2, 3, rng=rng), random_complex(2, 3, rng=rng)), + complex2=(random_complex(2, 3, rng=rng), random(2, 3, rng=rng)), + empty=(np.asarray([], dtype=config.floatX), np.asarray([1], dtype=config.floatX)), +) +TestAddInplaceBroadcast = makeBroadcastTester( + op=Elemwise(scalar_add, {0: 0}), + expected=lambda x, y: x + y, + good=_good_add_inplace, + # Cannot inplace on first input if it doesn't match output dtype (upcast of inputs) + bad_build=dict(dtype_invalid_mixup=_good_add_inplace["dtype_valid_mixup"][::-1]), + bad_runtime=_bad_runtime_broadcast_binary_normal, + inplace=True, +) + + +@pytest.mark.xfail( + config.cycle_detection == "fast" and config.mode != "FAST_COMPILE", + reason="Cycle detection is fast and mode is FAST_COMPILE", +) +def test_exp_inplace_grad_1(): + utt.verify_grad( + Elemwise(scalar_exp, {0: 0}), + [ + np.asarray( + [ + [1.5089518, 1.48439076, -4.7820262], + [2.04832468, 0.50791564, -1.58892269], + ] + ) + ], + ) + + +def test_XOR_inplace(): + dtype = [ + "int8", + "int16", + "int32", + "int64", + ] + xor_inplace = Elemwise(scalar_xor, {0: 0}) + + for dtype in dtype: + x, y = vector(dtype=dtype), vector(dtype=dtype) + l = np.asarray([0, 0, 1, 1], dtype=dtype) + r = np.asarray([0, 1, 0, 1], dtype=dtype) + ix = x + ix = xor_inplace(ix, y) + gn = inplace_func([x, y], ix) + _ = gn(l, r) + # test the in-place stuff + assert np.all(l == np.asarray([0, 1, 1, 0])), l + + +def test_inplace_dtype_changed(): + with pytensor.config.change_flags(cast_policy="numpy+floatX", floatX="float64"): + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="int32") + with pytensor.config.change_flags(floatX="float32"): + out = pt.add(x, y) + + assert out.dtype == "float32" + with pytensor.config.change_flags(floatX="float32"): + fn32 = pytensor.function( + [In(x, mutable=True), In(y, mutable=True)], + out, + mode="fast_run", + ) + assert fn32.maker.fgraph.outputs[0].owner.op.destroy_map == {0: [0]} + + with pytensor.config.change_flags(floatX="float64"): + fn64 = pytensor.function( + [In(x, mutable=True), In(y, mutable=True)], + out, + mode="fast_run", + ) + assert fn64.maker.fgraph.outputs[0].owner.op.destroy_map == {} diff --git a/tests/tensor/test_inplace.py b/tests/tensor/test_inplace.py deleted file mode 100644 index a31a26df07..0000000000 --- a/tests/tensor/test_inplace.py +++ /dev/null @@ -1,465 +0,0 @@ -import numpy as np -import pytest - -from pytensor import config -from pytensor.scalar.basic import round_half_away_from_zero_vec, upcast -from pytensor.tensor.inplace import ( - abs_inplace, - add_inplace, - arccos_inplace, - arccosh_inplace, - arcsin_inplace, - arcsinh_inplace, - arctan2_inplace, - arctan_inplace, - arctanh_inplace, - ceil_inplace, - conj_inplace, - cos_inplace, - cosh_inplace, - deg2rad_inplace, - exp2_inplace, - exp_inplace, - expm1_inplace, - floor_inplace, - int_div_inplace, - log1p_inplace, - log2_inplace, - log10_inplace, - log_inplace, - maximum_inplace, - minimum_inplace, - mod_inplace, - mul_inplace, - neg_inplace, - pow_inplace, - rad2deg_inplace, - reciprocal_inplace, - round_half_away_from_zero_inplace, - round_half_to_even_inplace, - sign_inplace, - sin_inplace, - sinh_inplace, - sqr_inplace, - sqrt_inplace, - sub_inplace, - tan_inplace, - tanh_inplace, - true_div_inplace, - trunc_inplace, - xor_inplace, -) -from pytensor.tensor.type import vector -from tests import unittest_tools as utt -from tests.tensor.utils import ( - _bad_build_broadcast_binary_normal, - _bad_runtime_broadcast_binary_normal, - _bad_runtime_reciprocal, - _good_broadcast_binary_arctan2, - _good_broadcast_binary_normal, - _good_broadcast_div_mod_normal_float_inplace, - _good_broadcast_pow_normal_float_pow, - _good_broadcast_unary_arccosh, - _good_broadcast_unary_arcsin_float, - _good_broadcast_unary_arctanh, - _good_broadcast_unary_normal, - _good_broadcast_unary_normal_abs, - _good_broadcast_unary_normal_float, - _good_broadcast_unary_normal_float_no_complex, - _good_broadcast_unary_normal_float_no_empty_no_complex, - _good_broadcast_unary_normal_no_complex, - _good_broadcast_unary_positive_float, - _good_broadcast_unary_tan, - _good_broadcast_unary_wide_float, - _good_reciprocal_inplace, - _numpy_true_div, - angle_eps, - check_floatX, - copymod, - div_grad_rtol, - ignore_isfinite_mode, - inplace_func, - makeBroadcastTester, - upcast_float16_ufunc, -) - - -TestAddInplaceBroadcast = makeBroadcastTester( - op=add_inplace, - expected=lambda x, y: x + y, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestSubInplaceBroadcast = makeBroadcastTester( - op=sub_inplace, - expected=lambda x, y: x - y, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestMaximumInplaceBroadcast = makeBroadcastTester( - op=maximum_inplace, - expected=np.maximum, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestMinimumInplaceBroadcast = makeBroadcastTester( - op=minimum_inplace, - expected=np.minimum, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestMulInplaceBroadcast = makeBroadcastTester( - op=mul_inplace, - expected=lambda x, y: x * y, - good=_good_broadcast_binary_normal, - bad_build=_bad_build_broadcast_binary_normal, - bad_runtime=_bad_runtime_broadcast_binary_normal, - inplace=True, -) - -TestTrueDivInplaceBroadcast = makeBroadcastTester( - op=true_div_inplace, - expected=_numpy_true_div, - good=copymod( - _good_broadcast_div_mod_normal_float_inplace, - # The output is now in float, we cannot work inplace on an int. - without=["integer", "uint8", "uint16", "int8"], - ), - grad_rtol=div_grad_rtol, - inplace=True, -) - -TestReciprocalInplaceBroadcast = makeBroadcastTester( - op=reciprocal_inplace, - expected=lambda x: _numpy_true_div(np.int8(1), x), - good=_good_reciprocal_inplace, - bad_runtime=_bad_runtime_reciprocal, - grad_rtol=div_grad_rtol, - inplace=True, -) - -TestModInplaceBroadcast = makeBroadcastTester( - op=mod_inplace, - expected=lambda x, y: np.asarray(x % y, dtype=upcast(x.dtype, y.dtype)), - good=copymod( - _good_broadcast_div_mod_normal_float_inplace, ["complex1", "complex2"] - ), - grad_eps=1e-5, - inplace=True, -) - -TestPowInplaceBroadcast = makeBroadcastTester( - op=pow_inplace, - expected=lambda x, y: x**y, - good=_good_broadcast_pow_normal_float_pow, - inplace=True, - mode=ignore_isfinite_mode, -) - -TestNegInplaceBroadcast = makeBroadcastTester( - op=neg_inplace, - expected=lambda x: -x, - good=_good_broadcast_unary_normal, - inplace=True, -) - -TestSgnInplaceBroadcast = makeBroadcastTester( - op=sign_inplace, - expected=np.sign, - good=_good_broadcast_unary_normal_no_complex, - inplace=True, -) - -TestAbsInplaceBroadcast = makeBroadcastTester( - op=abs_inplace, - expected=lambda x: np.abs(x), - good=_good_broadcast_unary_normal_abs, - inplace=True, -) - -TestIntDivInplaceBroadcast = makeBroadcastTester( - op=int_div_inplace, - expected=lambda x, y: check_floatX((x, y), x // y), - good=_good_broadcast_div_mod_normal_float_inplace, - # I don't test the grad as the output is always an integer - # (this is not a continuous output). - # grad=_grad_broadcast_div_mod_normal, - inplace=True, -) - -TestCeilInplaceBroadcast = makeBroadcastTester( - op=ceil_inplace, - expected=upcast_float16_ufunc(np.ceil), - good=copymod( - _good_broadcast_unary_normal_no_complex, - without=["integers", "int8", "uint8", "uint16"], - ), - # corner cases includes a lot of integers: points where Ceil is not - # continuous (not differentiable) - inplace=True, -) - -TestFloorInplaceBroadcast = makeBroadcastTester( - op=floor_inplace, - expected=upcast_float16_ufunc(np.floor), - good=copymod( - _good_broadcast_unary_normal_no_complex, - without=["integers", "int8", "uint8", "uint16"], - ), - inplace=True, -) - -TestTruncInplaceBroadcast = makeBroadcastTester( - op=trunc_inplace, - expected=upcast_float16_ufunc(np.trunc), - good=_good_broadcast_unary_normal_no_complex, - inplace=True, -) - -TestRoundHalfToEvenInplaceBroadcast = makeBroadcastTester( - op=round_half_to_even_inplace, - expected=np.round, - good=_good_broadcast_unary_normal_float_no_complex, - inplace=True, -) - -TestRoundHalfAwayFromZeroInplaceBroadcast = makeBroadcastTester( - op=round_half_away_from_zero_inplace, - expected=lambda a: round_half_away_from_zero_vec(a), - good=_good_broadcast_unary_normal_float_no_empty_no_complex, - inplace=True, -) - -TestSqrInplaceBroadcast = makeBroadcastTester( - op=sqr_inplace, - expected=np.square, - good=_good_broadcast_unary_normal, - inplace=True, -) - -TestExpInplaceBroadcast = makeBroadcastTester( - op=exp_inplace, - expected=np.exp, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestExp2InplaceBroadcast = makeBroadcastTester( - op=exp2_inplace, - expected=np.exp2, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestExpm1InplaceBroadcast = makeBroadcastTester( - op=expm1_inplace, - expected=np.expm1, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestLogInplaceBroadcast = makeBroadcastTester( - op=log_inplace, - expected=np.log, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestLog2InplaceBroadcast = makeBroadcastTester( - op=log2_inplace, - expected=np.log2, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestLog10InplaceBroadcast = makeBroadcastTester( - op=log10_inplace, - expected=np.log10, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestLog1pInplaceBroadcast = makeBroadcastTester( - op=log1p_inplace, - expected=np.log1p, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestSqrtInplaceBroadcast = makeBroadcastTester( - op=sqrt_inplace, - expected=np.sqrt, - good=_good_broadcast_unary_positive_float, - inplace=True, -) - -TestDeg2radInplaceBroadcast = makeBroadcastTester( - op=deg2rad_inplace, - expected=np.deg2rad, - good=_good_broadcast_unary_normal_float_no_complex, - inplace=True, - eps=angle_eps, -) - -TestRad2degInplaceBroadcast = makeBroadcastTester( - op=rad2deg_inplace, - expected=np.rad2deg, - good=_good_broadcast_unary_normal_float_no_complex, - inplace=True, - eps=angle_eps, -) - -TestSinInplaceBroadcast = makeBroadcastTester( - op=sin_inplace, - expected=np.sin, - good=_good_broadcast_unary_wide_float, - inplace=True, -) - -TestArcsinInplaceBroadcast = makeBroadcastTester( - op=arcsin_inplace, - expected=np.arcsin, - good=_good_broadcast_unary_arcsin_float, - inplace=True, -) - -TestCosInplaceBroadcast = makeBroadcastTester( - op=cos_inplace, - expected=np.cos, - good=_good_broadcast_unary_wide_float, - inplace=True, -) - -TestArccosInplaceBroadcast = makeBroadcastTester( - op=arccos_inplace, - expected=np.arccos, - good=_good_broadcast_unary_arcsin_float, - inplace=True, -) - -TestTanInplaceBroadcast = makeBroadcastTester( - op=tan_inplace, - expected=np.tan, - good=copymod( - _good_broadcast_unary_tan, without=["integers", "int8", "uint8", "uint16"] - ), - inplace=True, -) - -TestArctanInplaceBroadcast = makeBroadcastTester( - op=arctan_inplace, - expected=np.arctan, - good=_good_broadcast_unary_wide_float, - inplace=True, -) - -TestArctan2InplaceBroadcast = makeBroadcastTester( - op=arctan2_inplace, - expected=np.arctan2, - good=copymod( - _good_broadcast_binary_arctan2, - without=["integers", "int8", "uint8", "uint16", "dtype_mixup_2"], - ), - inplace=True, -) - -TestCoshInplaceBroadcast = makeBroadcastTester( - op=cosh_inplace, - expected=np.cosh, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestArccoshInplaceBroadcast = makeBroadcastTester( - op=arccosh_inplace, - expected=np.arccosh, - good=copymod(_good_broadcast_unary_arccosh, without=["integers", "uint8"]), - inplace=True, -) - -TestSinhInplaceBroadcast = makeBroadcastTester( - op=sinh_inplace, - expected=np.sinh, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestArcsinhInplaceBroadcast = makeBroadcastTester( - op=arcsinh_inplace, - expected=np.arcsinh, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestTanhInplaceBroadcast = makeBroadcastTester( - op=tanh_inplace, - expected=np.tanh, - good=_good_broadcast_unary_normal_float, - inplace=True, -) - -TestArctanhInplaceBroadcast = makeBroadcastTester( - op=arctanh_inplace, - expected=np.arctanh, - good=copymod( - _good_broadcast_unary_arctanh, without=["integers", "int8", "uint8", "uint16"] - ), - inplace=True, -) - -TestConjInplaceBroadcast = makeBroadcastTester( - op=conj_inplace, - expected=np.conj, - good=_good_broadcast_unary_normal, - inplace=True, -) - - -@pytest.mark.xfail( - config.cycle_detection == "fast" and config.mode != "FAST_COMPILE", - reason="Cycle detection is fast and mode is FAST_COMPILE", -) -def test_exp_inplace_grad_1(): - utt.verify_grad( - exp_inplace, - [ - np.asarray( - [ - [1.5089518, 1.48439076, -4.7820262], - [2.04832468, 0.50791564, -1.58892269], - ] - ) - ], - ) - - -def test_XOR_inplace(): - dtype = [ - "int8", - "int16", - "int32", - "int64", - ] - - for dtype in dtype: - x, y = vector(dtype=dtype), vector(dtype=dtype) - l = np.asarray([0, 0, 1, 1], dtype=dtype) - r = np.asarray([0, 1, 0, 1], dtype=dtype) - ix = x - ix = xor_inplace(ix, y) - gn = inplace_func([x, y], ix) - _ = gn(l, r) - # test the in-place stuff - assert np.all(l == np.asarray([0, 1, 1, 0])), l diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index fbfa5fb77e..d4c2e3463f 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -12,13 +12,12 @@ from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, verify_grad from pytensor.scalar import ScalarLoop -from pytensor.tensor import gammaincc, inplace, kn, kv, kve, vector +from pytensor.tensor import gammaincc, kn, kv, kve, vector from pytensor.tensor.elemwise import Elemwise from tests import unittest_tools as utt from tests.tensor.utils import ( _good_broadcast_unary_chi2sf, _good_broadcast_unary_normal, - _good_broadcast_unary_normal_float, _good_broadcast_unary_normal_float_no_complex, _good_broadcast_unary_normal_float_no_complex_small_neg_range, _good_broadcast_unary_normal_no_complex, @@ -85,14 +84,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestErfInplaceBroadcast = makeBroadcastTester( - op=inplace.erf_inplace, - expected=expected_erf, - good=_good_broadcast_unary_normal_float, - mode=mode_no_scipy, - eps=2e-10, - inplace=True, -) TestErfcBroadcast = makeBroadcastTester( op=pt.erfc, @@ -102,14 +93,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestErfcInplaceBroadcast = makeBroadcastTester( - op=inplace.erfc_inplace, - expected=expected_erfc, - good=_good_broadcast_unary_normal_float_no_complex, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) TestErfcxBroadcast = makeBroadcastTester( op=pt.erfcx, @@ -119,14 +102,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestErfcxInplaceBroadcast = makeBroadcastTester( - op=inplace.erfcx_inplace, - expected=expected_erfcx, - good=_good_broadcast_unary_normal_float_no_complex_small_neg_range, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) TestErfinvBroadcast = makeBroadcastTester( op=pt.erfinv, @@ -192,14 +167,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestOwensTInplaceBroadcast = makeBroadcastTester( - op=inplace.owens_t_inplace, - expected=expected_owenst, - good=_good_broadcast_binary_owenst, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_unary_gammaln = dict( @@ -223,14 +190,6 @@ def scipy_special_gammal(k, x): mode=mode_no_scipy, eps=1e-5, ) -TestGammaInplaceBroadcast = makeBroadcastTester( - op=inplace.gamma_inplace, - expected=expected_gamma, - good=_good_broadcast_unary_gammaln, - mode=mode_no_scipy, - eps=1e-5, - inplace=True, -) TestGammalnBroadcast = makeBroadcastTester( op=pt.gammaln, @@ -240,14 +199,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestGammalnInplaceBroadcast = makeBroadcastTester( - op=inplace.gammaln_inplace, - expected=expected_gammaln, - good=_good_broadcast_unary_gammaln, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_unary_psi = dict( @@ -265,14 +216,6 @@ def scipy_special_gammal(k, x): eps=2e-10, mode=mode_no_scipy, ) -TestPsiInplaceBroadcast = makeBroadcastTester( - op=inplace.psi_inplace, - expected=expected_psi, - good=_good_broadcast_unary_psi, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) _good_broadcast_unary_tri_gamma = _good_broadcast_unary_psi @@ -283,14 +226,6 @@ def scipy_special_gammal(k, x): eps=2e-8, mode=mode_no_scipy, ) -TestTriGammaInplaceBroadcast = makeBroadcastTester( - op=inplace.tri_gamma_inplace, - expected=expected_tri_gamma, - good=_good_broadcast_unary_tri_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) TestChi2SFBroadcast = makeBroadcastTester( op=pt.chi2sf, @@ -343,15 +278,6 @@ def scipy_special_gammal(k, x): mode=mode_no_scipy, ) -TestGammaIncInplaceBroadcast = makeBroadcastTester( - op=inplace.gammainc_inplace, - expected=expected_gammainc, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - TestGammaInccBroadcast = makeBroadcastTester( op=pt.gammaincc, expected=expected_gammaincc, @@ -361,15 +287,6 @@ def scipy_special_gammal(k, x): mode=mode_no_scipy, ) -TestGammaInccInplaceBroadcast = makeBroadcastTester( - op=inplace.gammaincc_inplace, - expected=expected_gammaincc, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - def test_gammainc_ddk_tabulated_values(): # This test replicates part of the old STAN test: @@ -447,15 +364,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestGammaUInplaceBroadcast = makeBroadcastTester( - op=inplace.gammau_inplace, - expected=expected_gammau, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - TestGammaLBroadcast = makeBroadcastTester( op=pt.gammal, expected=expected_gammal, @@ -464,15 +372,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestGammaLInplaceBroadcast = makeBroadcastTester( - op=inplace.gammal_inplace, - expected=expected_gammal, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_binary_gamma = dict( normal=( @@ -490,15 +389,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestGammaIncInvInplaceBroadcast = makeBroadcastTester( - op=inplace.gammaincinv_inplace, - expected=expected_gammaincinv, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - TestGammaInccInvBroadcast = makeBroadcastTester( op=pt.gammainccinv, expected=expected_gammainccinv, @@ -507,15 +397,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestGammaInccInvInplaceBroadcast = makeBroadcastTester( - op=inplace.gammainccinv_inplace, - expected=expected_gammainccinv, - good=_good_broadcast_binary_gamma, - eps=2e-8, - mode=mode_no_scipy, - inplace=True, -) - rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_unary_bessel = dict( normal=(random_ranged(-10, 10, (2, 3), rng=rng),), @@ -562,15 +443,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestJ0InplaceBroadcast = makeBroadcastTester( - op=inplace.j0_inplace, - expected=expected_j0, - good=_good_broadcast_unary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestJ1Broadcast = makeBroadcastTester( op=pt.j1, expected=expected_j1, @@ -580,15 +452,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestJ1InplaceBroadcast = makeBroadcastTester( - op=inplace.j1_inplace, - expected=expected_j1, - good=_good_broadcast_unary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestJvBroadcast = makeBroadcastTester( op=pt.jv, expected=expected_jv, @@ -597,15 +460,6 @@ def test_gammaincc_ddk_performance(benchmark): mode=mode_no_scipy, ) -TestJvInplaceBroadcast = makeBroadcastTester( - op=inplace.jv_inplace, - expected=expected_jv, - good=_good_broadcast_binary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - def test_verify_jv_grad(): # Verify Jv gradient. @@ -628,15 +482,6 @@ def fixed_first_input_jv(x): mode=mode_no_scipy, ) -TestI0InplaceBroadcast = makeBroadcastTester( - op=inplace.i0_inplace, - expected=expected_i0, - good=_good_broadcast_unary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestI1Broadcast = makeBroadcastTester( op=pt.i1, expected=expected_i1, @@ -646,15 +491,6 @@ def fixed_first_input_jv(x): mode=mode_no_scipy, ) -TestI1InplaceBroadcast = makeBroadcastTester( - op=inplace.i1_inplace, - expected=expected_i1, - good=_good_broadcast_unary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestIvBroadcast = makeBroadcastTester( op=pt.iv, expected=expected_iv, @@ -663,15 +499,6 @@ def fixed_first_input_jv(x): mode=mode_no_scipy, ) -TestIvInplaceBroadcast = makeBroadcastTester( - op=inplace.iv_inplace, - expected=expected_iv, - good=_good_broadcast_binary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - TestIveBroadcast = makeBroadcastTester( op=pt.ive, expected=expected_ive, @@ -680,15 +507,6 @@ def fixed_first_input_jv(x): mode=mode_no_scipy, ) -TestIveInplaceBroadcast = makeBroadcastTester( - op=inplace.ive_inplace, - expected=expected_ive, - good=_good_broadcast_binary_bessel, - eps=2e-10, - mode=mode_no_scipy, - inplace=True, -) - def test_verify_iv_grad(): # Verify Iv gradient. @@ -721,15 +539,6 @@ def fixed_first_input_ive(x): eps=1e-8, ) -TestSigmoidInplaceBroadcast = makeBroadcastTester( - op=inplace.sigmoid_inplace, - expected=expected_sigmoid, - good=_good_broadcast_unary_normal_no_complex, - grad=_grad_broadcast_unary_normal, - eps=1e-8, - inplace=True, -) - class TestSigmoid: def test_elemwise(self): @@ -758,15 +567,6 @@ def test_elemwise(self): eps=1e-8, ) -TestSoftplusInplaceBroadcast = makeBroadcastTester( - op=inplace.softplus_inplace, - expected=expected_sofplus, - good=_good_broadcast_unary_softplus, - grad=_grad_broadcast_unary_normal, - eps=1e-8, - inplace=True, -) - class TestSoftplus: def test_elemwise(self): @@ -805,14 +605,6 @@ def expected_log1mexp(x): eps=1e-8, ) -TestLog1mexpInplaceBroadcast = makeBroadcastTester( - op=inplace.log1mexp_inplace, - expected=expected_log1mexp, - good=_good_broadcast_unary_log1mexp, - eps=1e-8, - inplace=True, -) - _good_broadcast_ternary_betainc = dict( normal=( random_ranged(0, 1000, (2, 3)), @@ -828,14 +620,6 @@ def expected_log1mexp(x): grad=_good_broadcast_ternary_betainc, ) -TestBetaincInplaceBroadcast = makeBroadcastTester( - op=inplace.betainc_inplace, - expected=special.betainc, - good=_good_broadcast_ternary_betainc, - grad=_good_broadcast_ternary_betainc, - inplace=True, -) - class TestBetaIncGrad: def test_stan_grad_partial(self): @@ -926,13 +710,6 @@ def test_beta_inc_stan_grad_combined(self): good=_good_broadcast_ternary_betaincinv, ) -TestBetaincinvInplaceBroadcast = makeBroadcastTester( - op=inplace.betaincinv_inplace, - expected=special.betaincinv, - good=_good_broadcast_ternary_betaincinv, - inplace=True, -) - _good_broadcast_quaternary_hyp2f1 = dict( normal=( random_ranged(0, 20, (2, 3)), @@ -949,13 +726,6 @@ def test_beta_inc_stan_grad_combined(self): grad=_good_broadcast_quaternary_hyp2f1, ) -TestHyp2F1InplaceBroadcast = makeBroadcastTester( - op=inplace.hyp2f1_inplace, - expected=expected_hyp2f1, - good=_good_broadcast_quaternary_hyp2f1, - inplace=True, -) - class TestHyp2F1Grad: few_iters_case = ( diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 1a8b2455ec..8ebf25a1d9 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -672,7 +672,9 @@ def test_grad_none(self): return Checker -def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs): +def makeBroadcastTester( + op, expected, checks=None, name=None, *, inplace=False, **kwargs +): if checks is None: checks = {} if name is None: @@ -695,22 +697,20 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs): # cases we need to add it manually. if not name.endswith("Tester"): name += "Tester" - if "inplace" in kwargs: - if kwargs["inplace"]: - _expected = expected - if not isinstance(_expected, dict): - - def expected(*inputs): - return np.array(_expected(*inputs), dtype=inputs[0].dtype) - - def inplace_check(inputs, outputs): - # this used to be inputs[0] is output[0] - # I changed it so that it was easier to satisfy by the - # DebugMode - return np.all(inputs[0] == outputs[0]) - - checks = dict(checks, inplace_check=inplace_check) - del kwargs["inplace"] + if inplace: + _expected = expected + if not isinstance(_expected, dict): + + def expected(*inputs): + return np.array(_expected(*inputs), dtype=inputs[0].dtype) + + def inplace_check(inputs, outputs): + # this used to be inputs[0] is output[0] + # I changed it so that it was easier to satisfy by the + # DebugMode + return np.all(inputs[0] == outputs[0]) + + checks = dict(checks, inplace_check=inplace_check) return makeTester(name, op, expected, checks, **kwargs) @@ -815,6 +815,7 @@ def inplace_check(inputs, outputs): big_scalar=[np.arange(17.0, 29.0, 0.5, dtype=config.floatX)], ) +# FIXME: Why is this empty? _bad_build_broadcast_binary_normal = dict() _bad_runtime_broadcast_binary_normal = dict(