Skip to content

Commit 0b5d49e

Browse files
authored
Merge pull request #517 from bjodah/dummy_index
Add attribute to Dummy: dummy_index
2 parents 630a2b0 + 0dd33bb commit 0b5d49e

File tree

5 files changed

+66
-13
lines changed

5 files changed

+66
-13
lines changed

symengine/lib/symengine.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
133133
RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast<const SymEngine::Rational>"(rcp_const_basic &b) nogil
134134
RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast<const SymEngine::Complex>"(rcp_const_basic &b) nogil
135135
RCP[const Number] rcp_static_cast_Number "SymEngine::rcp_static_cast<const SymEngine::Number>"(rcp_const_basic &b) nogil
136+
RCP[const Dummy] rcp_static_cast_Dummy "SymEngine::rcp_static_cast<const SymEngine::Dummy>"(rcp_const_basic &b) nogil
136137
RCP[const Add] rcp_static_cast_Add "SymEngine::rcp_static_cast<const SymEngine::Add>"(rcp_const_basic &b) nogil
137138
RCP[const Mul] rcp_static_cast_Mul "SymEngine::rcp_static_cast<const SymEngine::Mul>"(rcp_const_basic &b) nogil
138139
RCP[const Pow] rcp_static_cast_Pow "SymEngine::rcp_static_cast<const SymEngine::Pow>"(rcp_const_basic &b) nogil
@@ -180,7 +181,7 @@ cdef extern from "<symengine/symbol.h>" namespace "SymEngine":
180181
Symbol(string name) nogil
181182
string get_name() nogil
182183
cdef cppclass Dummy(Symbol):
183-
pass
184+
size_t get_index()
184185

185186
cdef extern from "<symengine/number.h>" namespace "SymEngine":
186187
cdef cppclass Number(Basic):
@@ -322,6 +323,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
322323
rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
323324
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"() nogil
324325
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name) nogil
326+
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string &name, size_t index) nogil
325327
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj, bool use_pickle) except +
326328
rcp_const_basic make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
327329
rcp_const_basic make_rcp_Infty "SymEngine::make_rcp<const SymEngine::Infty>"(RCP[const Number] i) nogil

symengine/lib/symengine_wrapper.in.pyx

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,10 @@ def sympy2symengine(a, raise_error=False):
278278
"""
279279
import sympy
280280
from sympy.core.function import AppliedUndef as sympy_AppliedUndef
281-
if isinstance(a, sympy.Symbol):
281+
if isinstance(a, sympy.Dummy):
282+
return Dummy(a.name, a.dummy_index)
283+
elif isinstance(a, sympy.Symbol):
282284
return Symbol(a.name)
283-
elif isinstance(a, sympy.Dummy):
284-
return Dummy(a.name)
285285
elif isinstance(a, sympy.Mul):
286286
return mul(*[sympy2symengine(x, raise_error) for x in a.args])
287287
elif isinstance(a, sympy.Add):
@@ -1304,10 +1304,10 @@ cdef class Symbol(Expr):
13041304
return sympy.Symbol(str(self))
13051305

13061306
def __reduce__(self):
1307-
if type(self) == Symbol:
1307+
if type(self) in (Symbol, Dummy):
13081308
return Basic.__reduce__(self)
13091309
else:
1310-
raise NotImplementedError("pickling for Symbol subclass not implemented")
1310+
raise NotImplementedError("pickling for subclass of Symbol or Dummy not implemented")
13111311

13121312
def _sage_(self):
13131313
import sage.all as sage
@@ -1340,15 +1340,20 @@ cdef class Symbol(Expr):
13401340

13411341
cdef class Dummy(Symbol):
13421342

1343-
def __init__(Basic self, name=None, *args, **kwargs):
1344-
if name is None:
1345-
self.thisptr = symengine.make_rcp_Dummy()
1343+
def __init__(Basic self, name=None, dummy_index=None, *args, **kwargs):
1344+
cdef size_t index
1345+
if dummy_index is None:
1346+
if name is None:
1347+
self.thisptr = symengine.make_rcp_Dummy()
1348+
else:
1349+
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
13461350
else:
1347-
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
1351+
index = dummy_index
1352+
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"), index)
13481353

13491354
def _sympy_(self):
13501355
import sympy
1351-
return sympy.Dummy(str(self)[1:])
1356+
return sympy.Dummy(name=self.name, dummy_index=self.dummy_index)
13521357

13531358
@property
13541359
def is_Dummy(self):
@@ -1358,6 +1363,12 @@ cdef class Dummy(Symbol):
13581363
def func(self):
13591364
return self.__class__
13601365

1366+
@property
1367+
def dummy_index(self):
1368+
cdef RCP[const symengine.Dummy] this = \
1369+
symengine.rcp_static_cast_Dummy(self.thisptr)
1370+
cdef size_t index = deref(this).get_index()
1371+
return index
13611372

13621373
def symarray(prefix, shape, **kwargs):
13631374
""" Creates an nd-array of symbols

symengine/tests/test_pickling.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
1+
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol, Dummy
22
from symengine.test_utilities import raises
33
import pickle
44
import unittest
@@ -57,3 +57,19 @@ def test_llvm_double():
5757
ll = pickle.loads(ss)
5858
inp = [1, 2, 3]
5959
assert np.allclose(l(inp), ll(inp))
60+
61+
62+
def _check_pickling_roundtrip(arg):
63+
s2 = pickle.dumps(arg)
64+
arg2 = pickle.loads(s2)
65+
assert arg == arg2
66+
s3 = pickle.dumps(arg2)
67+
arg3 = pickle.loads(s3)
68+
assert arg == arg3
69+
70+
71+
def test_pickling_roundtrip():
72+
x, y, z = symbols('x y z')
73+
_check_pickling_roundtrip(x+y)
74+
_check_pickling_roundtrip(Dummy('d'))
75+
_check_pickling_roundtrip(Dummy('d') - z)

symengine/tests/test_symbol.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ def test_dummy():
156156
x2 = Symbol('x')
157157
xdummy1 = Dummy('x')
158158
xdummy2 = Dummy('x')
159+
assert xdummy1.dummy_index != xdummy2.dummy_index # maybe test using "less than"?
160+
assert xdummy1.name == 'x'
161+
assert xdummy2.name == 'x'
159162

160163
assert x1 == x2
161164
assert x1 != xdummy1

symengine/tests/test_sympy_conv.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from symengine import (Symbol, Integer, sympify, SympifyError, log,
22
function_symbol, I, E, pi, oo, zoo, nan, true, false,
3-
exp, gamma, have_mpfr, have_mpc, DenseMatrix, sin, cos, tan, cot,
3+
exp, gamma, have_mpfr, have_mpc, DenseMatrix, Dummy, sin, cos, tan, cot,
44
csc, sec, asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth,
55
asinh, acosh, atanh, acoth, atan2, Add, Mul, Pow, diff, GoldenRatio,
66
Catalan, EulerGamma, UnevaluatedExpr, RealDouble)
@@ -833,3 +833,24 @@ def test_conv_large_integers():
833833
if have_sympy:
834834
c = a._sympy_()
835835
d = sympify(c)
836+
837+
838+
def _check_sympy_roundtrip(arg):
839+
arg_sy1 = sympy.sympify(arg)
840+
arg_se2 = sympify(arg_sy1)
841+
assert arg == arg_se2
842+
arg_sy2 = sympy.sympify(arg_se2)
843+
assert arg_sy2 == arg_sy1
844+
arg_se3 = sympify(arg_sy2)
845+
assert arg_se3 == arg
846+
847+
848+
@unittest.skipIf(not have_sympy, "SymPy not installed")
849+
def test_sympy_roundtrip():
850+
x = Symbol("x")
851+
y = Symbol("y")
852+
d = Dummy("d")
853+
_check_sympy_roundtrip(x)
854+
_check_sympy_roundtrip(x+y)
855+
_check_sympy_roundtrip(x**y)
856+
_check_sympy_roundtrip(d)

0 commit comments

Comments
 (0)