@@ -1336,6 +1336,99 @@ impl PythonTransformer {
1336
1336
return append_transformer ! ( self , Transformer :: Collect ( xs, key_map, coeff_map) ) ;
1337
1337
}
1338
1338
1339
+ /// Create a transformer that collects terms involving the same power of variables or functions with the name `x`.
1340
+ ///
1341
+ /// Both the key (the quantity collected in) and its coefficient can be mapped using
1342
+ /// `key_map` and `coeff_map` transformers respectively.
1343
+ ///
1344
+ /// Examples
1345
+ /// --------
1346
+ /// >>> from symbolica import Expression
1347
+ /// >>> x, f = Expression.symbol('x', 'f')
1348
+ /// >>> e = f(1,2) + x*f(1,2)
1349
+ /// >>>
1350
+ /// >>> print(e.transform().collect_symbol(x).execute())
1351
+ ///
1352
+ /// yields `(1+x)*f(1,2)`.
1353
+ ///
1354
+ /// Parameters
1355
+ /// ----------
1356
+ /// x: Expression
1357
+ /// The symbol to collect in
1358
+ /// key_map: Transformer
1359
+ /// A transformer to be applied to the quantity collected in
1360
+ /// coeff_map: Transformer
1361
+ /// A transformer to be applied to the coefficient
1362
+ #[ pyo3( signature = ( x, key_map = None , coeff_map = None ) ) ]
1363
+ pub fn collect_symbol (
1364
+ & self ,
1365
+ x : PythonExpression ,
1366
+ key_map : Option < PythonTransformer > ,
1367
+ coeff_map : Option < PythonTransformer > ,
1368
+ ) -> PyResult < PythonTransformer > {
1369
+ let Some ( x) = x. expr . get_symbol ( ) else {
1370
+ return Err ( exceptions:: PyValueError :: new_err (
1371
+ "Collect must be done wrt a variable or function" ,
1372
+ ) ) ;
1373
+ } ;
1374
+
1375
+ let key_map = if let Some ( key_map) = key_map {
1376
+ let Pattern :: Transformer ( p) = key_map. expr else {
1377
+ return Err ( exceptions:: PyValueError :: new_err (
1378
+ "Key map must be a transformer" ,
1379
+ ) ) ;
1380
+ } ;
1381
+
1382
+ if p. 0 . is_some ( ) {
1383
+ Err ( exceptions:: PyValueError :: new_err (
1384
+ "Key map must be an unbound transformer" ,
1385
+ ) ) ?;
1386
+ }
1387
+
1388
+ p. 1 . clone ( )
1389
+ } else {
1390
+ vec ! [ ]
1391
+ } ;
1392
+
1393
+ let coeff_map = if let Some ( coeff_map) = coeff_map {
1394
+ let Pattern :: Transformer ( p) = coeff_map. expr else {
1395
+ return Err ( exceptions:: PyValueError :: new_err (
1396
+ "Key map must be a transformer" ,
1397
+ ) ) ;
1398
+ } ;
1399
+
1400
+ if p. 0 . is_some ( ) {
1401
+ Err ( exceptions:: PyValueError :: new_err (
1402
+ "Key map must be an unbound transformer" ,
1403
+ ) ) ?;
1404
+ }
1405
+
1406
+ p. 1 . clone ( )
1407
+ } else {
1408
+ vec ! [ ]
1409
+ } ;
1410
+
1411
+ return append_transformer ! ( self , Transformer :: CollectSymbol ( x, key_map, coeff_map) ) ;
1412
+ }
1413
+
1414
+ /// Create a transformer that collects common factors from (nested) sums.
1415
+ ///
1416
+ /// Examples
1417
+ /// --------
1418
+ ///
1419
+ /// >>> from symbolica import *
1420
+ /// >>> e = E('x*(x+y*x+x^2+y*(x+x^2))')
1421
+ /// >>> e.transform().collect_factors().execute()
1422
+ ///
1423
+ /// yields
1424
+ ///
1425
+ /// ```log
1426
+ /// v1^2*(1+v1+v2+v2*(1+v1))
1427
+ /// ```
1428
+ pub fn collect_factors ( & self ) -> PyResult < PythonTransformer > {
1429
+ return append_transformer ! ( self , Transformer :: CollectFactors ) ;
1430
+ }
1431
+
1339
1432
/// Create a transformer that collects numerical factors by removing the numerical content from additions.
1340
1433
/// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`.
1341
1434
///
@@ -4055,6 +4148,102 @@ impl PythonExpression {
4055
4148
Ok ( b. into ( ) )
4056
4149
}
4057
4150
4151
+ /// Collect terms involving the same power of variables or functions with the name `x`, e.g.
4152
+ ///
4153
+ /// ```math
4154
+ /// collect_symbol(f(1,2) + x*f*(1,2), f) = (1+x)*f(1,2)
4155
+ /// ```
4156
+ ///
4157
+ ///
4158
+ /// Both the *key* (the quantity collected in) and its coefficient can be mapped using
4159
+ /// `key_map` and `coeff_map` respectively.
4160
+ ///
4161
+ /// Examples
4162
+ /// --------
4163
+ ///
4164
+ /// >>> from symbolica import Expression
4165
+ /// >>> x, f = Expression.symbol('x', 'f')
4166
+ /// >>> e = f(1,2) + x*f(1,2)
4167
+ /// >>>
4168
+ /// >>> print(e.collect_symbol(f))
4169
+ ///
4170
+ /// yields `(1+x)*f(1,2)`.
4171
+ #[ pyo3( signature = ( x, key_map = None , coeff_map = None ) ) ]
4172
+ pub fn collect_symbol (
4173
+ & self ,
4174
+ x : PythonExpression ,
4175
+ key_map : Option < PyObject > ,
4176
+ coeff_map : Option < PyObject > ,
4177
+ ) -> PyResult < PythonExpression > {
4178
+ let Some ( x) = x. expr . get_symbol ( ) else {
4179
+ return Err ( exceptions:: PyValueError :: new_err (
4180
+ "Collect must be done wrt a variable or function" ,
4181
+ ) ) ;
4182
+ } ;
4183
+
4184
+ let b = self . expr . collect_symbol :: < i16 > (
4185
+ x,
4186
+ if let Some ( key_map) = key_map {
4187
+ Some ( Box :: new ( move |key, out| {
4188
+ Python :: with_gil ( |py| {
4189
+ let key: PythonExpression = key. to_owned ( ) . into ( ) ;
4190
+
4191
+ out. set_from_view (
4192
+ & key_map
4193
+ . call ( py, ( key, ) , None )
4194
+ . expect ( "Bad callback function" )
4195
+ . extract :: < PythonExpression > ( py)
4196
+ . expect ( "Key map should return an expression" )
4197
+ . expr
4198
+ . as_view ( ) ,
4199
+ )
4200
+ } ) ;
4201
+ } ) )
4202
+ } else {
4203
+ None
4204
+ } ,
4205
+ if let Some ( coeff_map) = coeff_map {
4206
+ Some ( Box :: new ( move |coeff, out| {
4207
+ Python :: with_gil ( |py| {
4208
+ let coeff: PythonExpression = coeff. to_owned ( ) . into ( ) ;
4209
+
4210
+ out. set_from_view (
4211
+ & coeff_map
4212
+ . call ( py, ( coeff, ) , None )
4213
+ . expect ( "Bad callback function" )
4214
+ . extract :: < PythonExpression > ( py)
4215
+ . expect ( "Coeff map should return an expression" )
4216
+ . expr
4217
+ . as_view ( ) ,
4218
+ )
4219
+ } ) ;
4220
+ } ) )
4221
+ } else {
4222
+ None
4223
+ } ,
4224
+ ) ;
4225
+
4226
+ Ok ( b. into ( ) )
4227
+ }
4228
+
4229
+ /// Collect common factors from (nested) sums.
4230
+ ///
4231
+ /// Examples
4232
+ /// --------
4233
+ ///
4234
+ /// >>> from symbolica import *
4235
+ /// >>> e = E('x*(x+y*x+x^2+y*(x+x^2))')
4236
+ /// >>> e.collect_factors()
4237
+ ///
4238
+ /// yields
4239
+ ///
4240
+ /// ```log
4241
+ /// v1^2*(1+v1+v2+v2*(1+v1))
4242
+ /// ```
4243
+ pub fn collect_factors ( & self ) -> PythonExpression {
4244
+ self . expr . collect_factors ( ) . into ( )
4245
+ }
4246
+
4058
4247
/// Collect numerical factors by removing the numerical content from additions.
4059
4248
/// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`.
4060
4249
///
0 commit comments