diff --git a/mathics/builtin/numbers/calculus.py b/mathics/builtin/numbers/calculus.py index d5fc19770..556c8e5db 100644 --- a/mathics/builtin/numbers/calculus.py +++ b/mathics/builtin/numbers/calculus.py @@ -10,7 +10,7 @@ """ from itertools import product -from typing import Optional, Union +from typing import Optional import numpy as np import sympy @@ -59,16 +59,19 @@ from mathics.core.systemsymbols import ( SymbolAnd, SymbolAutomatic, + SymbolComplex, SymbolConditionalExpression, SymbolD, SymbolDerivative, SymbolInfinity, SymbolInfix, + SymbolInteger, SymbolIntegrate, SymbolLeft, SymbolLog, SymbolNIntegrate, SymbolO, + SymbolReal, SymbolRule, SymbolSequence, SymbolSeries, @@ -76,6 +79,7 @@ SymbolSimplify, SymbolUndefined, ) +from mathics.eval.calculus import solve_sympy from mathics.eval.makeboxes import format_element from mathics.eval.nevaluator import eval_N from mathics.eval.numbers.calculus.integrators import ( @@ -2210,105 +2214,38 @@ class Solve(Builtin): messages = { "eqf": "`1` is not a well-formed equation.", "svars": 'Equations may not give solutions for all "solve" variables.', + "fulldim": "The solution set contains a full-dimensional component; use Reduce for complete solution information.", } - # FIXME: the problem with removing the domain parameter from the outside - # is that the we can't make use of this information inside - # the evaluation method where it is may be needed. rules = { - "Solve[eqs_, vars_, Complexes]": "Solve[eqs, vars]", - "Solve[eqs_, vars_, Reals]": ( - "Cases[Solve[eqs, vars], {Rule[x_,y_?RealValuedNumberQ]}]" - ), - "Solve[eqs_, vars_, Integers]": ( - "Cases[Solve[eqs, vars], {Rule[x_,y_Integer]}]" - ), + "Solve[eqs_, vars_]": "Solve[eqs, vars, Complexes]" } summary_text = "find generic solutions for variables" - def eval(self, eqs, vars, evaluation: Evaluation): - "Solve[eqs_, vars_]" + def eval(self, eqs, vars, domain, evaluation: Evaluation): + "Solve[eqs_, vars_, domain_]" - vars_original = vars - head_name = vars.get_head_name() + variables = vars + head_name = variables.get_head_name() if head_name == "System`List": - vars = vars.elements + variables = variables.elements else: - vars = [vars] - for var in vars: + variables = [variables] + for var in variables: if ( (isinstance(var, Atom) and not isinstance(var, Symbol)) or head_name in ("System`Plus", "System`Times", "System`Power") or # noqa A_CONSTANT & var.get_attributes(evaluation.definitions) ): - evaluation.message("Solve", "ivar", vars_original) + evaluation.message("Solve", "ivar", vars) return - vars_sympy = [var.to_sympy() for var in vars] - if None in vars_sympy: + sympy_variables = [var.to_sympy() for var in variables] + if None in sympy_variables: evaluation.message("Solve", "ivar") return - all_var_tuples = list(zip(vars, vars_sympy)) - - def cut_var_dimension(expressions: Union[Expression, list[Expression]]): - '''delete unused variables to avoid SymPy's PolynomialError - : Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]''' - if not isinstance(expressions, list): - expressions = [expressions] - subset_vars = set() - subset_vars_sympy = set() - for var, var_sympy in all_var_tuples: - pattern = Pattern.create(var) - for equation in expressions: - if not equation.is_free(pattern, evaluation): - subset_vars.add(var) - subset_vars_sympy.add(var_sympy) - return subset_vars, subset_vars_sympy - - def solve_sympy(equations: Union[Expression, list[Expression]]): - if not isinstance(equations, list): - equations = [equations] - equations_sympy = [] - denoms_sympy = [] - subset_vars, subset_vars_sympy = cut_var_dimension(equations) - for equation in equations: - if equation is SymbolTrue: - continue - elif equation is SymbolFalse: - return [] - elements = equation.elements - for left, right in [(elements[index], elements[index + 1]) for index in range(len(elements) - 1)]: - # ↑ to deal with things like a==b==c==d - left = left.to_sympy() - right = right.to_sympy() - if left is None or right is None: - return [] - equation_sympy = left - right - equation_sympy = sympy.together(equation_sympy) - equation_sympy = sympy.cancel(equation_sympy) - equations_sympy.append(equation_sympy) - numer, denom = equation_sympy.as_numer_denom() - denoms_sympy.append(denom) - try: - results = sympy.solve(equations_sympy, subset_vars_sympy, dict=True) # no transform_dict needed with dict=True - # Filter out results for which denominator is 0 - # (SymPy should actually do that itself, but it doesn't!) - results = [ - sol - for sol in results - if all(sympy.simplify(denom.subs(sol)) != 0 for denom in denoms_sympy) - ] - return results - except sympy.PolynomialError: - # raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting - # unused variables beforehand - return [] - except NotImplementedError: - return [] - except TypeError as exc: - if str(exc).startswith("expected Symbol, Function or Derivative"): - evaluation.message("Solve", "ivar", vars_original) + variable_tuples = list(zip(variables, sympy_variables)) def solve_recur(expression: Expression): '''solve And, Or and List within the scope of sympy, @@ -2336,7 +2273,7 @@ def solve_recur(expression: Expression): inequations.append(sub_condition) else: inequations.append(child.to_sympy()) - solutions.extend(solve_sympy(equations)) + solutions.extend(solve_sympy(evaluation, equations, variables, domain)) conditions = sympy.And(*inequations) result = [sol for sol in solutions if conditions.subs(sol)] return result, None if solutions else conditions @@ -2346,7 +2283,7 @@ def solve_recur(expression: Expression): conditions = [] for child in expression.elements: if child.has_form("Equal", 2): - solutions.extend(solve_sympy(child)) + solutions.extend(solve_sympy(evaluation, child, variables, domain)) elif child.get_head_name() in ('System`And', 'System`Or'): # I don't believe List would be in here sub_solution, sub_condition = solve_recur(child) solutions.extend(sub_solution) @@ -2365,8 +2302,8 @@ def solve_recur(expression: Expression): if conditions is not None: evaluation.message("Solve", "fulldim") else: - if eqs.has_form("Equal", 2): - solutions = solve_sympy(eqs) + if eqs.get_head_name() == "System`Equal": + solutions = solve_sympy(evaluation, eqs, variables, domain) else: evaluation.message("Solve", "fulldim") return ListExpression(ListExpression()) @@ -2376,7 +2313,7 @@ def solve_recur(expression: Expression): return ListExpression(ListExpression()) if any( - sol and any(var not in sol for var in vars_sympy) for sol in solutions + sol and any(var not in sol for var in sympy_variables) for sol in solutions ): evaluation.message("Solve", "svars") @@ -2385,7 +2322,7 @@ def solve_recur(expression: Expression): ListExpression( *( Expression(SymbolRule, var, from_sympy(sol[var_sympy])) - for var, var_sympy in all_var_tuples + for var, var_sympy in variable_tuples if var_sympy in sol ), ) diff --git a/mathics/core/atoms.py b/mathics/core/atoms.py index eadc66ad0..64249c2dd 100644 --- a/mathics/core/atoms.py +++ b/mathics/core/atoms.py @@ -774,9 +774,9 @@ def get_sort_key(self, pattern_sort=False) -> tuple: def sameQ(self, other) -> bool: """Mathics SameQ""" return ( - isinstance(other, Complex) - and self.real == other.real - and self.imag == other.imag + isinstance(other, Complex) and + self.real == other.real and + self.imag == other.imag ) def round(self, d=None) -> "Complex": diff --git a/mathics/core/convert/sympy.py b/mathics/core/convert/sympy.py index 843d679f3..fa333b906 100644 --- a/mathics/core/convert/sympy.py +++ b/mathics/core/convert/sympy.py @@ -5,6 +5,7 @@ Conversion to SymPy is handled directly in BaseElement descendants. """ +from collections.abc import Iterable from typing import Optional, Type, Union import sympy @@ -13,9 +14,6 @@ # Import the singleton class from sympy.core.numbers import S -BasicSympy = sympy.Expr - - from mathics.core.atoms import ( MATHICS3_COMPLEX_I, Complex, @@ -40,6 +38,7 @@ ) from mathics.core.list import ListExpression from mathics.core.number import FP_MANTISA_BINARY_DIGITS +from mathics.core.rules import Pattern from mathics.core.symbols import ( Symbol, SymbolFalse, @@ -62,16 +61,21 @@ SymbolGreater, SymbolGreaterEqual, SymbolIndeterminate, + SymbolIntegers, SymbolLess, SymbolLessEqual, SymbolMatrixPower, SymbolO, SymbolPi, SymbolPiecewise, + SymbolReals, SymbolSlot, SymbolUnequal, ) +BasicSympy = sympy.Expr + + SymbolPrime = Symbol("Prime") SymbolRoot = Symbol("Root") SymbolRootSum = Symbol("RootSum") @@ -130,6 +134,39 @@ def to_sympy_matrix(data, **kwargs) -> Optional[sympy.MutableDenseMatrix]: return None +def apply_domain_to_symbols(symbols: Iterable[sympy.Symbol], domain) -> dict[sympy.Symbol, sympy.Symbol]: + """Create new sympy symbols with domain applied. + Return a dict maps old to new. + """ + # FIXME: this substitute solution would break when Solve[Abs[x]==3, x],where x=-3 and x=3. + # However, substituting symbol prior to actual solving would cause sympy to have biased assumption, + # it would refuse to solve Abs() when symbol is in Complexes + result = {} + for symbol in symbols: + if domain == SymbolReals: + new_symbol = sympy.Symbol(repr(symbol), real=True) + elif domain == SymbolIntegers: + new_symbol = sympy.Symbol(repr(symbol), integer=True) + else: + new_symbol = symbol + result[symbol] = new_symbol + return result + + +def cut_dimension(evaluation, expressions: Union[Expression, list[Expression]], symbols: Iterable[sympy.Symbol]) -> set[sympy.Symbol]: + '''delete unused variables to avoid SymPy's PolynomialError + : Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]''' + if not isinstance(expressions, list): + expressions = [expressions] + subset = set() + for symbol in symbols: + pattern = Pattern.create(symbol) + for equation in expressions: + if not equation.is_free(pattern, evaluation): + subset.add(symbol) + return subset + + class SympyExpression(BasicSympy): is_Function = True nargs = None @@ -363,9 +400,9 @@ def old_from_sympy(expr) -> BaseElement: if is_Cn_expr(name): return Expression(SymbolC, Integer(int(name[1:]))) if name.startswith(sympy_symbol_prefix): - name = name[len(sympy_symbol_prefix) :] + name = name[len(sympy_symbol_prefix):] if name.startswith(sympy_slot_prefix): - index = name[len(sympy_slot_prefix) :] + index = name[len(sympy_slot_prefix):] return Expression(SymbolSlot, Integer(int(index))) elif expr.is_NumberSymbol: name = str(expr) @@ -517,7 +554,7 @@ def old_from_sympy(expr) -> BaseElement: *[from_sympy(arg) for arg in expr.args] ) if name.startswith(sympy_symbol_prefix): - name = name[len(sympy_symbol_prefix) :] + name = name[len(sympy_symbol_prefix):] args = [from_sympy(arg) for arg in expr.args] builtin = sympy_to_mathics.get(name) if builtin is not None: diff --git a/mathics/core/systemsymbols.py b/mathics/core/systemsymbols.py index fa388a92d..fb1ea6d79 100644 --- a/mathics/core/systemsymbols.py +++ b/mathics/core/systemsymbols.py @@ -56,6 +56,7 @@ SymbolCompile = Symbol("System`Compile") SymbolCompiledFunction = Symbol("System`CompiledFunction") SymbolComplex = Symbol("System`Complex") +SymbolComplexes = Symbol("System`Complexes") SymbolComplexInfinity = Symbol("System`ComplexInfinity") SymbolCondition = Symbol("System`Condition") SymbolConditionalExpression = Symbol("System`ConditionalExpression") @@ -124,6 +125,7 @@ SymbolInfix = Symbol("System`Infix") SymbolInputForm = Symbol("System`InputForm") SymbolInteger = Symbol("System`Integer") +SymbolIntegers = Symbol("System`Integers") SymbolIntegrate = Symbol("System`Integrate") SymbolLeft = Symbol("System`Left") SymbolLength = Symbol("System`Length") @@ -200,6 +202,7 @@ SymbolRational = Symbol("System`Rational") SymbolRe = Symbol("System`Re") SymbolReal = Symbol("System`Real") +SymbolReals = Symbol("System`Reals") SymbolRealAbs = Symbol("System`RealAbs") SymbolRealDigits = Symbol("System`RealDigits") SymbolRealSign = Symbol("System`RealSign") diff --git a/test/builtin/calculus/test_solve.py b/test/builtin/calculus/test_solve.py index dd190b8b7..234a641dc 100644 --- a/test/builtin/calculus/test_solve.py +++ b/test/builtin/calculus/test_solve.py @@ -43,14 +43,34 @@ def test_solve(): "Issue #1235", ), ( - "Solve[{x^2==4 && x < 0},{x}]", - "{x->-2}", - "", + "Solve[Abs[-2/3*(lambda + 2) + 8/3 + 4] == 4, lambda,Reals]", + "{{lambda -> 2}, {lambda -> 14}}", + "abs()", ), ( - "Solve[{x^2==4 && x < 0 && x > -4},{x}]", - "{x->-2}", - "", + "Solve[q^3 == (20-12)/(4-3), q,Reals]", + "{{q -> 2}}", + "domain check", + ), + ( + "Solve[x + Pi/3 == 2k*Pi + Pi/6 || x + Pi/3 == 2k*Pi + 5Pi/6, x,Reals]", + "{{x -> -Pi / 6 + 2 k Pi}, {x -> Pi / 2 + 2 k Pi}}", + "logics involved", + ), + ( + "Solve[m - 1 == 0 && -(m + 1) != 0, m,Reals]", + "{{m -> 1}}", + "logics and constraints", + ), + ( + "Solve[(lambda + 1)/6 == 1/(mu - 1) == lambda/4, {lambda, mu},Reals]", + "{{lambda -> 2, mu -> 3}}", + "chained equations", + ), + ( + "Solve[2*x0*Log[x0] + x0 - 2*a*x0 == -1 && x0^2*Log[x0] - a*x0^2 + b == b - x0, {x0, a, b},Reals]", + "{{x0 -> 1, a -> 1}}", + "excess variable b", ), ): session.evaluate("Clear[h]; Clear[g]; Clear[f];")