Skip to content

Commit 9796f18

Browse files
committed
Add methods to collect in symbols and factors
- `replace_map` now takes a `FnMut` as an argument
1 parent e6d31c2 commit 9796f18

File tree

7 files changed

+588
-14
lines changed

7 files changed

+588
-14
lines changed

src/api/python.rs

+189
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,99 @@ impl PythonTransformer {
13361336
return append_transformer!(self, Transformer::Collect(xs, key_map, coeff_map));
13371337
}
13381338

1339+
/// Create a transformer that collects terms involving the same power of variables or functions with the name `x`.
1340+
///
1341+
/// Both the key (the quantity collected in) and its coefficient can be mapped using
1342+
/// `key_map` and `coeff_map` transformers respectively.
1343+
///
1344+
/// Examples
1345+
/// --------
1346+
/// >>> from symbolica import Expression
1347+
/// >>> x, f = Expression.symbol('x', 'f')
1348+
/// >>> e = f(1,2) + x*f(1,2)
1349+
/// >>>
1350+
/// >>> print(e.transform().collect_symbol(x).execute())
1351+
///
1352+
/// yields `(1+x)*f(1,2)`.
1353+
///
1354+
/// Parameters
1355+
/// ----------
1356+
/// x: Expression
1357+
/// The symbol to collect in
1358+
/// key_map: Transformer
1359+
/// A transformer to be applied to the quantity collected in
1360+
/// coeff_map: Transformer
1361+
/// A transformer to be applied to the coefficient
1362+
#[pyo3(signature = (x, key_map = None, coeff_map = None))]
1363+
pub fn collect_symbol(
1364+
&self,
1365+
x: PythonExpression,
1366+
key_map: Option<PythonTransformer>,
1367+
coeff_map: Option<PythonTransformer>,
1368+
) -> PyResult<PythonTransformer> {
1369+
let Some(x) = x.expr.get_symbol() else {
1370+
return Err(exceptions::PyValueError::new_err(
1371+
"Collect must be done wrt a variable or function",
1372+
));
1373+
};
1374+
1375+
let key_map = if let Some(key_map) = key_map {
1376+
let Pattern::Transformer(p) = key_map.expr else {
1377+
return Err(exceptions::PyValueError::new_err(
1378+
"Key map must be a transformer",
1379+
));
1380+
};
1381+
1382+
if p.0.is_some() {
1383+
Err(exceptions::PyValueError::new_err(
1384+
"Key map must be an unbound transformer",
1385+
))?;
1386+
}
1387+
1388+
p.1.clone()
1389+
} else {
1390+
vec![]
1391+
};
1392+
1393+
let coeff_map = if let Some(coeff_map) = coeff_map {
1394+
let Pattern::Transformer(p) = coeff_map.expr else {
1395+
return Err(exceptions::PyValueError::new_err(
1396+
"Key map must be a transformer",
1397+
));
1398+
};
1399+
1400+
if p.0.is_some() {
1401+
Err(exceptions::PyValueError::new_err(
1402+
"Key map must be an unbound transformer",
1403+
))?;
1404+
}
1405+
1406+
p.1.clone()
1407+
} else {
1408+
vec![]
1409+
};
1410+
1411+
return append_transformer!(self, Transformer::CollectSymbol(x, key_map, coeff_map));
1412+
}
1413+
1414+
/// Create a transformer that collects common factors from (nested) sums.
1415+
///
1416+
/// Examples
1417+
/// --------
1418+
///
1419+
/// >>> from symbolica import *
1420+
/// >>> e = E('x*(x+y*x+x^2+y*(x+x^2))')
1421+
/// >>> e.transform().collect_factors().execute()
1422+
///
1423+
/// yields
1424+
///
1425+
/// ```log
1426+
/// v1^2*(1+v1+v2+v2*(1+v1))
1427+
/// ```
1428+
pub fn collect_factors(&self) -> PyResult<PythonTransformer> {
1429+
return append_transformer!(self, Transformer::CollectFactors);
1430+
}
1431+
13391432
/// Create a transformer that collects numerical factors by removing the numerical content from additions.
13401433
/// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`.
13411434
///
@@ -4055,6 +4148,102 @@ impl PythonExpression {
40554148
Ok(b.into())
40564149
}
40574150

4151+
/// Collect terms involving the same power of variables or functions with the name `x`, e.g.
4152+
///
4153+
/// ```math
4154+
/// collect_symbol(f(1,2) + x*f*(1,2), f) = (1+x)*f(1,2)
4155+
/// ```
4156+
///
4157+
///
4158+
/// Both the *key* (the quantity collected in) and its coefficient can be mapped using
4159+
/// `key_map` and `coeff_map` respectively.
4160+
///
4161+
/// Examples
4162+
/// --------
4163+
///
4164+
/// >>> from symbolica import Expression
4165+
/// >>> x, f = Expression.symbol('x', 'f')
4166+
/// >>> e = f(1,2) + x*f(1,2)
4167+
/// >>>
4168+
/// >>> print(e.collect_symbol(f))
4169+
///
4170+
/// yields `(1+x)*f(1,2)`.
4171+
#[pyo3(signature = (x, key_map = None, coeff_map = None))]
4172+
pub fn collect_symbol(
4173+
&self,
4174+
x: PythonExpression,
4175+
key_map: Option<PyObject>,
4176+
coeff_map: Option<PyObject>,
4177+
) -> PyResult<PythonExpression> {
4178+
let Some(x) = x.expr.get_symbol() else {
4179+
return Err(exceptions::PyValueError::new_err(
4180+
"Collect must be done wrt a variable or function",
4181+
));
4182+
};
4183+
4184+
let b = self.expr.collect_symbol::<i16>(
4185+
x,
4186+
if let Some(key_map) = key_map {
4187+
Some(Box::new(move |key, out| {
4188+
Python::with_gil(|py| {
4189+
let key: PythonExpression = key.to_owned().into();
4190+
4191+
out.set_from_view(
4192+
&key_map
4193+
.call(py, (key,), None)
4194+
.expect("Bad callback function")
4195+
.extract::<PythonExpression>(py)
4196+
.expect("Key map should return an expression")
4197+
.expr
4198+
.as_view(),
4199+
)
4200+
});
4201+
}))
4202+
} else {
4203+
None
4204+
},
4205+
if let Some(coeff_map) = coeff_map {
4206+
Some(Box::new(move |coeff, out| {
4207+
Python::with_gil(|py| {
4208+
let coeff: PythonExpression = coeff.to_owned().into();
4209+
4210+
out.set_from_view(
4211+
&coeff_map
4212+
.call(py, (coeff,), None)
4213+
.expect("Bad callback function")
4214+
.extract::<PythonExpression>(py)
4215+
.expect("Coeff map should return an expression")
4216+
.expr
4217+
.as_view(),
4218+
)
4219+
});
4220+
}))
4221+
} else {
4222+
None
4223+
},
4224+
);
4225+
4226+
Ok(b.into())
4227+
}
4228+
4229+
/// Collect common factors from (nested) sums.
4230+
///
4231+
/// Examples
4232+
/// --------
4233+
///
4234+
/// >>> from symbolica import *
4235+
/// >>> e = E('x*(x+y*x+x^2+y*(x+x^2))')
4236+
/// >>> e.collect_factors()
4237+
///
4238+
/// yields
4239+
///
4240+
/// ```log
4241+
/// v1^2*(1+v1+v2+v2*(1+v1))
4242+
/// ```
4243+
pub fn collect_factors(&self) -> PythonExpression {
4244+
self.expr.collect_factors().into()
4245+
}
4246+
40584247
/// Collect numerical factors by removing the numerical content from additions.
40594248
/// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`.
40604249
///

src/atom.rs

+7
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,13 @@ impl<'a> From<&AtomView<'a>> for AtomOrView<'a> {
832832
}
833833

834834
impl<'a> AtomOrView<'a> {
835+
pub fn into_owned(self) -> Atom {
836+
match self {
837+
AtomOrView::Atom(a) => a,
838+
AtomOrView::View(a) => a.to_owned(),
839+
}
840+
}
841+
835842
pub fn as_view(&'a self) -> AtomView<'a> {
836843
match self {
837844
AtomOrView::Atom(a) => a.as_view(),

src/atom/core.rs

+46-2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ pub trait AtomCore {
7070
/// collect(x + x * y + x^2, x) = x * (1+y) + x^2
7171
/// ```
7272
///
73+
/// Use [collect_symbol](AtomCore::collect_symbol) to collect using the name of a function only.
74+
///
7375
/// Both the *key* (the quantity collected in) and its coefficient can be mapped using
7476
/// `key_map` and `coeff_map` respectively.
7577
///
@@ -91,6 +93,34 @@ pub trait AtomCore {
9193
self.as_atom_view().collect::<E, T>(x, key_map, coeff_map)
9294
}
9395

96+
/// Collect terms involving the same power of variables or functions with the name `x`, e.g.
97+
///
98+
/// ```math
99+
/// collect_symbol(f(1,2) + x*f*(1,2), f) = (1+x)*f(1,2)
100+
/// ```
101+
///
102+
///
103+
/// Both the *key* (the quantity collected in) and its coefficient can be mapped using
104+
/// `key_map` and `coeff_map` respectively.
105+
///
106+
/// # Example
107+
///
108+
/// ```
109+
/// use symbolica::{atom::AtomCore, parse, symbol};
110+
/// let expr = parse!("f(1,2) + x*f(1,2)").unwrap();
111+
/// let collected = expr.collect_symbol::<u8>(symbol!("f"), None, None);
112+
/// assert_eq!(collected, parse!("(1+x)*f(1,2)").unwrap());
113+
/// ```
114+
fn collect_symbol<E: Exponent>(
115+
&self,
116+
x: Symbol,
117+
key_map: Option<Box<dyn Fn(AtomView, &mut Atom)>>,
118+
coeff_map: Option<Box<dyn Fn(AtomView, &mut Atom)>>,
119+
) -> Atom {
120+
self.as_atom_view()
121+
.collect_symbol::<E>(x, key_map, coeff_map)
122+
}
123+
94124
/// Collect terms involving the same power of `x`, where `x` is a variable or function, e.g.
95125
///
96126
/// ```math
@@ -120,6 +150,20 @@ pub trait AtomCore {
120150
.collect_multiple::<E, T>(xs, key_map, coeff_map)
121151
}
122152

153+
/// Collect common factors from (nested) sums.
154+
///
155+
/// # Example
156+
///
157+
/// ```
158+
/// use symbolica::{atom::AtomCore, parse};
159+
/// let expr = parse!("x*(x+y*x+x^2+y*(x+x^2))").unwrap();
160+
/// let collected = expr.collect_factors();
161+
/// assert_eq!(collected, parse!("x^2*(1+x+y+y*(1+x))").unwrap());
162+
/// ```
163+
fn collect_factors(&self) -> Atom {
164+
self.as_atom_view().collect_factors()
165+
}
166+
123167
/// Collect terms involving the same power of `x` in `xs`, where `xs` is a list of indeterminates.
124168
/// Return the list of key-coefficient pairs
125169
///
@@ -1266,7 +1310,7 @@ pub trait AtomCore {
12661310
/// ```
12671311
/// use symbolica::{atom::AtomCore, parse};
12681312
/// let expr = parse!("x + y").unwrap();
1269-
/// let result = expr.replace_map(&|term, _ctx, out| {
1313+
/// let result = expr.replace_map(|term, _ctx, out| {
12701314
/// if term.to_string() == "symbolica::x" {
12711315
/// *out = parse!("z").unwrap();
12721316
/// true
@@ -1276,7 +1320,7 @@ pub trait AtomCore {
12761320
/// });
12771321
/// assert_eq!(result, parse!("z + y").unwrap());
12781322
/// ```
1279-
fn replace_map<F: Fn(AtomView, &Context, &mut Atom) -> bool>(&self, m: &F) -> Atom {
1323+
fn replace_map<F: FnMut(AtomView, &Context, &mut Atom) -> bool>(&self, m: F) -> Atom {
12801324
self.as_atom_view().replace_map(m)
12811325
}
12821326

0 commit comments

Comments
 (0)