Skip to content

Commit 703ceb4

Browse files
committed
Fix ARM ASM division
- Add documentation for S, N, E in Python API
1 parent b8372b3 commit 703ceb4

File tree

2 files changed

+186
-56
lines changed

2 files changed

+186
-56
lines changed

src/evaluate.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2022,13 +2022,14 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
20222022
(self.reserved_indices - self.param_count) * 8
20232023
);
20242024
*out += &format!(
2025-
"\t\t\"fdiv d{}, d{}, d31\\n\\t\"\n",
2025+
"\t\t\"fdiv d{}, d31, d{}\\n\\t\"\n",
20262026
out_reg, out_reg
20272027
);
20282028
}
20292029
InlineASM::None => unreachable!(),
20302030
}
20312031
} else {
2032+
// load 1 into out_reg
20322033
match asm_flavour {
20332034
InlineASM::X64 => {
20342035
*out += &format!(
@@ -2058,7 +2059,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
20582059
InlineASM::AArch64 => {
20592060
*out += &format!(
20602061
"\t\t\"fdiv d{}, d{}, d{}\\n\\t\"\n",
2061-
out_reg, j, out_reg
2062+
out_reg, out_reg, j
20622063
);
20632064
}
20642065
InlineASM::None => unreachable!(),
@@ -2074,7 +2075,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
20742075
InlineASM::AArch64 => {
20752076
let addr = asm_load!(*k);
20762077
*out +=
2077-
&format!("\t\t\"ldr d31, {}\\n\\t\"\n", addr,);
2078+
&format!("\t\t\"ldr d31, {}\\n\\t\"\n", addr);
20782079

20792080
*out += &format!(
20802081
"\t\t\"fdiv d{}, d{}, d31\\n\\t\"\n",
@@ -2117,7 +2118,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
21172118
InlineASM::AArch64 => {
21182119
*out += &format!(
21192120
"\t\t\"fdiv d{}, d{}, d{}\\n\\t\"\n",
2120-
out_reg, j, out_reg
2121+
out_reg, out_reg, j
21212122
);
21222123
}
21232124
InlineASM::None => unreachable!(),
@@ -2136,7 +2137,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
21362137
&format!("\t\t\"ldr d31, {}\\n\\t\"\n", addr,);
21372138

21382139
*out += &format!(
2139-
"\t\t\"fdiv d{}, d{}, d31\\n\\t\"\n",
2140+
"\t\t\"fdiv d{}, d31, d{}\\n\\t\"\n",
21402141
out_reg, out_reg
21412142
);
21422143
}
@@ -2266,10 +2267,10 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
22662267

22672268
match asm_flavour {
22682269
InlineASM::X64 => {
2269-
*out += &format!("\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"({}_CONSTANTS_double), \"r\"(params)\n\t\t: \"memory\", \"xmm0\");\n", function_name);
2270+
*out += &format!("\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"({}_CONSTANTS_double), \"r\"(params)\n\t\t: \"memory\", \"xmm0\", \"xmm1\", \"xmm2\", \"xmm3\", \"xmm4\", \"xmm5\", \"xmm6\", \"xmm7\", \"xmm8\", \"xmm9\", \"xmm10\", \"xmm11\", \"xmm12\", \"xmm13\", \"xmm14\", \"xmm15\");\n", function_name);
22702271
}
22712272
InlineASM::AArch64 => {
2272-
*out += &format!("\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"({}_CONSTANTS_double), \"r\"(params)\n\t\t: \"memory\", \"d0\");\n", function_name);
2273+
*out += &format!("\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"({}_CONSTANTS_double), \"r\"(params)\n\t\t: \"memory\", \"d0\", \"d1\", \"d2\", \"d3\", \"d4\", \"d5\", \"d6\", \"d7\", \"d8\", \"d9\", \"d10\", \"d11\", \"d12\", \"d13\", \"d14\", \"d15\", \"d16\", \"d17\", \"d18\", \"d19\", \"d20\", \"d21\", \"d22\", \"d23\", \"d24\", \"d25\", \"d26\", \"d27\", \"d28\", \"d29\", \"d30\", \"d31\");\n", function_name);
22732274
}
22742275
InlineASM::None => unreachable!(),
22752276
}
@@ -2570,10 +2571,10 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
25702571
);
25712572

25722573
if *o * 16 < 450 {
2573-
*out += &format!("\t\t\"stp d0, d1, {}\\n\\t\"", addr_o.0);
2574+
*out += &format!("\t\t\"stp d0, d1, {}\\n\\t\"\n", addr_o.0);
25742575
} else {
25752576
*out += &format!("\t\t\"str d0, {}\\n\\t\"", addr_o.0);
2576-
*out += &format!("\t\t\"str d1, {}\\n\\t\"", addr_o.1);
2577+
*out += &format!("\t\t\"str d1, {}\\n\\t\"\n", addr_o.1);
25772578
}
25782579
}
25792580
InlineASM::None => unreachable!(),
@@ -2674,10 +2675,10 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
26742675

26752676
match asm_flavour {
26762677
InlineASM::X64 => {
2677-
*out += &format!("\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"({}_CONSTANTS_complex), \"r\"(params)\n\t\t: \"memory\", \"xmm0\");\n", function_name);
2678+
*out += &format!("\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"({}_CONSTANTS_complex), \"r\"(params)\n\t\t: \"memory\", \"xmm0\", \"xmm1\", \"xmm2\", \"xmm3\", \"xmm4\", \"xmm5\", \"xmm6\", \"xmm7\", \"xmm8\", \"xmm9\", \"xmm10\", \"xmm11\", \"xmm12\", \"xmm13\", \"xmm14\", \"xmm15\");\n", function_name);
26782679
}
26792680
InlineASM::AArch64 => {
2680-
*out += &format!("\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"({}_CONSTANTS_complex), \"r\"(params)\n\t\t: \"memory\", \"d0\", \"d1\");\n", function_name);
2681+
*out += &format!("\t\t:\n\t\t: \"r\"(out), \"r\"(Z), \"r\"({}_CONSTANTS_complex), \"r\"(params)\n\t\t: \"memory\", \"d0\", \"d1\", \"d2\", \"d3\", \"d4\", \"d5\", \"d6\", \"d7\", \"d8\", \"d9\", \"d10\", \"d11\", \"d12\", \"d13\", \"d14\", \"d15\", \"d16\", \"d17\", \"d18\", \"d19\", \"d20\", \"d21\", \"d22\", \"d23\", \"d24\", \"d25\", \"d26\", \"d27\", \"d28\", \"d29\", \"d30\", \"d31\");\n", function_name);
26812682
}
26822683
InlineASM::None => unreachable!(),
26832684
}

symbolica.pyi

Lines changed: 174 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,61 @@ def S(name: str,
4444
is_cyclesymmetric: Optional[bool] = None,
4545
is_linear: Optional[bool] = None,
4646
custom_normalization: Optional[Transformer] = None) -> Expression:
47-
"""Shorthand notation for :func:`Expression.symbol`"""
47+
"""
48+
Create new symbols from `names`. Symbols can have attributes,
49+
such as symmetries. If no attributes
50+
are specified and the symbol was previously defined, the attributes are inherited.
51+
Once attributes are defined on a symbol, they cannot be redefined later.
52+
53+
Examples
54+
--------
55+
Define a regular symbol and use it as a variable:
56+
>>> x = S('x')
57+
>>> e = x**2 + 5
58+
>>> print(e)
59+
x**2 + 5
60+
61+
Define a regular symbol and use it as a function:
62+
>>> f = S('f')
63+
>>> e = f(1,2)
64+
>>> print(e)
65+
f(1,2)
66+
67+
68+
Define a symmetric function:
69+
>>> f = S('f', is_symmetric=True)
70+
>>> e = f(2,1)
71+
>>> print(e)
72+
f(1,2)
73+
74+
75+
Define a linear and symmetric function:
76+
>>> p1, p2, p3, p4 = ES('p1', 'p2', 'p3', 'p4')
77+
>>> dot = S('dot', is_symmetric=True, is_linear=True)
78+
>>> e = dot(p2+2*p3,p1+3*p2-p3)
79+
dot(p1,p2)+2*dot(p1,p3)+3*dot(p2,p2)-dot(p2,p3)+6*dot(p2,p3)-2*dot(p3,p3)
80+
81+
Define a custom normalization function:
82+
>>> e = S('real_log', custom_normalization=Transformer().replace_all(E("x_(exp(x1_))"), E("x1_")))
83+
>>> E("real_log(exp(x)) + real_log(5)")
84+
85+
Parameters
86+
----------
87+
name : str
88+
The name of the symbol
89+
is_symmetric : Optional[bool]
90+
Set to true if the symbol is symmetric.
91+
is_antisymmetric : Optional[bool]
92+
Set to true if the symbol is antisymmetric.
93+
is_cyclesymmetric : Optional[bool]
94+
Set to true if the symbol is cyclesymmetric.
95+
is_linear : Optional[bool]
96+
Set to true if the symbol is linear.
97+
custom_normalization : Optional[Transformer]
98+
A transformer that is called after every normalization. Note that the symbol
99+
name cannot be used in the transformer as this will lead to a definition of the
100+
symbol. Use a wildcard with the same attributes instead.
101+
"""
48102

49103

50104
@overload
@@ -54,15 +108,84 @@ def S(*names: str,
54108
is_cyclesymmetric: Optional[bool] = None,
55109
is_linear: Optional[bool] = None,
56110
custom_normalization: Optional[Transformer] = None) -> Sequence[Expression]:
57-
"""Shorthand notation for :func:`Expression.symbol`"""
111+
"""
112+
Create new symbols from `names`. Symbols can have attributes,
113+
such as symmetries. If no attributes
114+
are specified and the symbol was previously defined, the attributes are inherited.
115+
Once attributes are defined on a symbol, they cannot be redefined later.
116+
117+
Examples
118+
--------
119+
Define two regular symbols:
120+
>>> x, y = S('x', 'y')
121+
122+
Define two symmetric functions:
123+
>>> f, g = S('f', 'g', is_symmetric=True)
124+
>>> e = f(2,1)
125+
>>> print(e)
126+
f(1,2)
127+
128+
Parameters
129+
----------
130+
name : str
131+
The name of the symbol
132+
is_symmetric : Optional[bool]
133+
Set to true if the symbol is symmetric.
134+
is_antisymmetric : Optional[bool]
135+
Set to true if the symbol is antisymmetric.
136+
is_cyclesymmetric : Optional[bool]
137+
Set to true if the symbol is cyclesymmetric.
138+
is_linear : Optional[bool]
139+
Set to true if the symbol is multilinear.
140+
custom_normalization : Optional[Transformer]
141+
A transformer that is called after every normalization. Note that the symbol
142+
name cannot be used in the transformer as this will lead to a definition of the
143+
symbol. Use a wildcard with the same attributes instead.
144+
"""
58145

59146

60147
def N(num: int | float | str | Decimal, relative_error: Optional[float] = None) -> Expression:
61-
"""Shorthand notation for :func:`Expression.num`"""
148+
"""Create a new Symbolica number from an int, a float, or a string.
149+
A floating point number is kept as a float with the same precision as the input,
150+
but it can also be converted to the smallest rational number given a `relative_error`.
151+
152+
Examples
153+
--------
154+
>>> e = N(1) / 2
155+
>>> print(e)
156+
1/2
157+
158+
>>> print(N(1/3))
159+
>>> print(N(0.33, 0.1))
160+
>>> print(N('0.333`3'))
161+
>>> print(N(Decimal('0.1234')))
162+
3.3333333333333331e-1
163+
1/3
164+
3.33e-1
165+
1.2340e-1
166+
"""
62167

63168

64169
def E(input: str) -> Expression:
65-
"""Shorthand notation for :func:`Expression.parse`"""
170+
"""
171+
Parse a Symbolica expression from a string.
172+
173+
Parameters
174+
----------
175+
input: str
176+
An input string. UTF-8 character are allowed.
177+
178+
Examples
179+
--------
180+
>>> e = E('x^2+y+y*4')
181+
>>> print(e)
182+
x^2+5*y
183+
184+
Raises
185+
------
186+
ValueError
187+
If the input is not a valid Symbolica expression.
188+
"""
66189

67190

68191
class AtomType(Enum):
@@ -150,14 +273,9 @@ class Expression:
150273
is_linear: Optional[bool] = None,
151274
custom_normalization: Optional[Transformer] = None) -> Expression:
152275
"""
153-
Create a new symbol from a `name`. Symbols carry information about their attributes.
154-
The symbol can signal that it is symmetric if it is used as a function
155-
using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`,
156-
cyclesymmetric using `is_cyclesymmetric=True`, and
157-
multilinear using `is_linear=True`. If no attributes
158-
are specified, the attributes are inherited from the symbol if it was already defined,
159-
otherwise all attributes are set to `false`.
160-
276+
Create new symbols from `names`. Symbols can have attributes,
277+
such as symmetries. If no attributes
278+
are specified and the symbol was previously defined, the attributes are inherited.
161279
Once attributes are defined on a symbol, they cannot be redefined later.
162280
163281
Examples
@@ -187,6 +305,27 @@ class Expression:
187305
>>> dot = Expression.symbol('dot', is_symmetric=True, is_linear=True)
188306
>>> e = dot(p2+2*p3,p1+3*p2-p3)
189307
dot(p1,p2)+2*dot(p1,p3)+3*dot(p2,p2)-dot(p2,p3)+6*dot(p2,p3)-2*dot(p3,p3)
308+
309+
Define a custom normalization function:
310+
>>> e = S('real_log', custom_normalization=Transformer().replace_all(E("x_(exp(x1_))"), E("x1_")))
311+
>>> E("real_log(exp(x)) + real_log(5)")
312+
313+
Parameters
314+
----------
315+
name : str
316+
The name of the symbol
317+
is_symmetric : Optional[bool]
318+
Set to true if the symbol is symmetric.
319+
is_antisymmetric : Optional[bool]
320+
Set to true if the symbol is antisymmetric.
321+
is_cyclesymmetric : Optional[bool]
322+
Set to true if the symbol is cyclesymmetric.
323+
is_linear : Optional[bool]
324+
Set to true if the symbol is linear.
325+
custom_normalization : Optional[Transformer]
326+
A transformer that is called after every normalization. Note that the symbol
327+
name cannot be used in the transformer as this will lead to a definition of the
328+
symbol. Use a wildcard with the same attributes instead.
190329
"""
191330

192331
@overload
@@ -199,48 +338,38 @@ class Expression:
199338
is_linear: Optional[bool] = None,
200339
custom_normalization: Optional[Transformer] = None) -> Sequence[Expression]:
201340
"""
202-
Create new symbols from `names`. Symbols carry information about their attributes.
203-
The symbol can signal that it is symmetric if it is used as a function
204-
using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`,
205-
cyclesymmetric using `is_cyclesymmetric=True`, and
206-
multilinear using `is_linear=True`. If no attributes
207-
are specified, the attributes are inherited from the symbol if it was already defined,
208-
otherwise all attributes are set to `false`. A transformer that is executed
209-
after normalization can be defined with `custom_normalization`.
210-
341+
Create new symbols from `names`. Symbols can have attributes,
342+
such as symmetries. If no attributes
343+
are specified and the symbol was previously defined, the attributes are inherited.
211344
Once attributes are defined on a symbol, they cannot be redefined later.
212345
213346
Examples
214347
--------
215-
Define a regular symbol and use it as a variable:
216-
>>> x = Expression.symbol('x')
217-
>>> e = x**2 + 5
218-
>>> print(e)
219-
x**2 + 5
220-
221-
Define a regular symbol and use it as a function:
222-
>>> f = Expression.symbol('f')
223-
>>> e = f(1,2)
224-
>>> print(e)
225-
f(1,2)
226-
348+
Define two regular symbols:
349+
>>> x, y = Expression.symbol('x', 'y')
227350
228-
Define a symmetric function:
229-
>>> f = Expression.symbol('f', is_symmetric=True)
351+
Define two symmetric functions:
352+
>>> f, g = Expression.symbol('f', 'g', is_symmetric=True)
230353
>>> e = f(2,1)
231354
>>> print(e)
232355
f(1,2)
233356
234-
235-
Define a linear and symmetric function:
236-
>>> p1, p2, p3, p4 = Expression.symbol('p1', 'p2', 'p3', 'p4')
237-
>>> dot = Expression.symbol('dot', is_symmetric=True, is_linear=True)
238-
>>> e = dot(p2+2*p3,p1+3*p2-p3)
239-
dot(p1,p2)+2*dot(p1,p3)+3*dot(p2,p2)-dot(p2,p3)+6*dot(p2,p3)-2*dot(p3,p3)
240-
241-
Define a custom normalization function:
242-
>>> e = S('real_log', custom_normalization=Transformer().replace_all(E("x_(exp(x1_))"), E("x1_")))
243-
>>> E("real_log(exp(x)) + real_log(5)")
357+
Parameters
358+
----------
359+
name : str
360+
The name of the symbol
361+
is_symmetric : Optional[bool]
362+
Set to true if the symbol is symmetric.
363+
is_antisymmetric : Optional[bool]
364+
Set to true if the symbol is antisymmetric.
365+
is_cyclesymmetric : Optional[bool]
366+
Set to true if the symbol is cyclesymmetric.
367+
is_linear : Optional[bool]
368+
Set to true if the symbol is multilinear.
369+
custom_normalization : Optional[Transformer]
370+
A transformer that is called after every normalization. Note that the symbol
371+
name cannot be used in the transformer as this will lead to a definition of the
372+
symbol. Use a wildcard with the same attributes instead.
244373
"""
245374

246375
@overload

0 commit comments

Comments
 (0)