Skip to content

Commit 0d56936

Browse files
committed
Fix downcasts during addition normalization
- Allow Conditions in pattern restrictions
1 parent 0f90f44 commit 0d56936

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

src/api/python.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -3638,7 +3638,7 @@ impl PythonExpression {
36383638
filter_fn
36393639
.call(py, (data,), None)
36403640
.expect("Bad callback function")
3641-
.extract::<bool>(py)
3641+
.is_truthy(py)
36423642
.expect("Pattern filter does not return a boolean")
36433643
})
36443644
})),
@@ -3801,7 +3801,7 @@ impl PythonExpression {
38013801
cmp_fn
38023802
.call(py, (data1, data2), None)
38033803
.expect("Bad callback function")
3804-
.extract::<bool>(py)
3804+
.is_truthy(py)
38053805
.expect("Pattern comparison does not return a boolean")
38063806
})
38073807
}),
@@ -11591,7 +11591,7 @@ impl PythonGraph {
1159111591
Python::with_gil(|py| {
1159211592
match filter_fn.call(py, (Self { graph: g.clone() }, v), None) {
1159311593
Ok(r) => r
11594-
.extract::<bool>(py)
11594+
.is_truthy(py)
1159511595
.expect("Match map does not return a boolean"),
1159611596
Err(e) => {
1159711597
if e.is_instance_of::<exceptions::PyKeyboardInterrupt>(py) {

src/normalize.rs

+16-2
Original file line numberDiff line numberDiff line change
@@ -1429,7 +1429,15 @@ impl<'a> AtomView<'a> {
14291429
}
14301430
}
14311431

1432-
a.set_normalized(true);
1432+
if a.get_nargs() == 0 {
1433+
out.to_num(Coefficient::zero());
1434+
} else if a.get_nargs() == 1 {
1435+
let mut b = ws.new_atom();
1436+
b.set_from_view(&a.to_add_view().iter().next().unwrap());
1437+
out.set_from_view(&b.as_view());
1438+
} else {
1439+
a.set_normalized(true);
1440+
}
14331441
return;
14341442
}
14351443
}
@@ -1517,7 +1525,13 @@ impl<'a> AtomView<'a> {
15171525
}
15181526
}
15191527

1520-
a.set_normalized(true);
1528+
if a.get_nargs() == 1 {
1529+
let mut b = ws.new_atom();
1530+
b.set_from_view(&a.to_add_view().iter().next().unwrap());
1531+
out.set_from_view(&b.as_view());
1532+
} else {
1533+
a.set_normalized(true);
1534+
}
15211535
} else if let AtomView::Add(_) = rhs {
15221536
rhs.add_normalized(*self, ws, out);
15231537
} else {

symbolica.pyi

+9-7
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ class Expression:
331331
"""
332332

333333
@classmethod
334-
def load(_cls, filename: str, conflict_fn: Callable[[str], str]) -> Expression:
334+
def load(_cls, filename: str, conflict_fn: Optional[Callable[[str], str]] = None) -> Expression:
335335
"""Load an expression and its state from a file. The state will be merged
336336
with the current one. If a symbol has conflicting attributes, the conflict
337337
can be resolved using the renaming function `conflict_fn`.
@@ -587,7 +587,7 @@ class Expression:
587587

588588
def req(
589589
self,
590-
filter_fn: Callable[[Expression], bool],
590+
filter_fn: Callable[[Expression], bool | Condition],
591591
) -> PatternRestriction:
592592
"""
593593
Create a new pattern restriction that calls the function `filter_fn` with the matched
@@ -605,7 +605,7 @@ class Expression:
605605
def req_cmp(
606606
self,
607607
other: Expression | int | float | Decimal,
608-
cmp_fn: Callable[[Expression, Expression], bool],
608+
cmp_fn: Callable[[Expression, Expression], bool | Condition],
609609
) -> PatternRestriction:
610610
"""
611611
Create a new pattern restriction that calls the function `cmp_fn` with another the matched
@@ -1445,24 +1445,25 @@ class Expression:
14451445
"""
14461446

14471447
def canonize_tensors(self,
1448-
contracted_indices: Sequence[Tuple[Expression | int,Expression | int]]) -> Expression:
1448+
contracted_indices: Sequence[Tuple[Expression | int, Expression | int]]) -> Expression:
14491449
"""Canonize (products of) tensors in the expression by relabeling repeated indices.
14501450
The tensors must be written as functions, with its indices as the arguments.
14511451
Subexpressions, constants and open indices are supported.
1452-
1452+
14531453
If the contracted indices are distinguishable (for example in their dimension),
14541454
you can provide a group marker as the second element in the tuple of the index
14551455
specification.
14561456
This makes sure that an index will not be renamed to an index from a different group.
1457-
1457+
14581458
Examples
14591459
--------
14601460
>>> g = Expression.symbol('g', is_symmetric=True)
14611461
>>> fc = Expression.symbol('fc', is_cyclesymmetric=True)
14621462
>>> mu1, mu2, mu3, mu4, k1 = Expression.symbol('mu1', 'mu2', 'mu3', 'mu4', 'k1')
14631463
>>> e = g(mu2, mu3)*fc(mu4, mu2, k1, mu4, k1, mu3)
14641464
>>> print(e.canonize_tensors([(mu1, 0), (mu2, 0), (mu3, 0), (mu4, 0)]))
1465-
yields `g(mu1,mu2)*fc(mu1,mu3,mu2,k1,mu3,k1)`.
1465+
1466+
yields `g(mu1, mu2)*fc(mu1, mu3, mu2, k1, mu3, k1)`.
14661467
"""
14671468

14681469

@@ -2157,6 +2158,7 @@ class Transformer:
21572158
square_brackets_for_function: bool = False,
21582159
num_exp_as_superscript: bool = True,
21592160
latex: bool = False,
2161+
show_namespaces: bool = False,
21602162
) -> Transformer:
21612163
"""
21622164
Create a transformer that prints the expression.

0 commit comments

Comments
 (0)