From 13fb3d2d4057a04ec67e2d325a65177d391ff01b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 6 Nov 2024 15:42:54 -0600 Subject: [PATCH] Pass strict=True in calls to zip() --- pymbolic/algorithm.py | 3 ++- pymbolic/geometric_algebra/mapper.py | 5 +++-- pymbolic/mapper/__init__.py | 29 ++++++++++++++-------------- pymbolic/mapper/stringifier.py | 4 ++-- pymbolic/mapper/unifier.py | 6 +++--- pymbolic/primitives.py | 5 +++-- 6 files changed, 28 insertions(+), 24 deletions(-) diff --git a/pymbolic/algorithm.py b/pymbolic/algorithm.py index 4dac6f3..f66fc60 100644 --- a/pymbolic/algorithm.py +++ b/pymbolic/algorithm.py @@ -489,7 +489,8 @@ def solve_affine_equations_for(unknowns, equations): div = mat[nonz_row, j] unknown_val = int(rhs_mat[nonz_row, -1]) // div - for parameter, coeff in zip(parameters_list, rhs_mat[nonz_row]): + for parameter, coeff in zip( + parameters_list, rhs_mat[nonz_row, :-1], strict=True): unknown_val += (int(coeff) // div) * parameter result[unknown] = unknown_val diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index 61ca833..bf5b64a 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -285,7 +285,8 @@ def map_product(self, expr): if not has_d_source_nablas: rec_children = [self.rec(child) for child in expr.children] if all(rec_child is child - for rec_child, child in zip(rec_children, expr.children)): + for rec_child, child in zip( + rec_children, expr.children, strict=True)): return expr return type(expr)(tuple(rec_children)) @@ -296,7 +297,7 @@ def map_product(self, expr): result = [list(expr.children)] for child_idx, (d_source_nabla_ids, _child) in enumerate( - zip(d_source_nabla_ids_per_child, expr.children)): + zip(d_source_nabla_ids_per_child, expr.children, strict=True)): if not d_source_nabla_ids: continue diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 65dbfb9..34dc9ab 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -749,7 +749,8 @@ def map_call(self, ]) if (function is expr.function and all(child is orig_child - for child, orig_child in zip(expr.parameters, parameters))): + for child, orig_child in zip( + expr.parameters, parameters, strict=True))): return expr return type(expr)(function, parameters) @@ -767,7 +768,7 @@ def map_call_with_kwargs(self, if (function is expr.function and all(child is orig_child for child, orig_child in - zip(parameters, expr.parameters)) + zip(parameters, expr.parameters, strict=True)) and all(kw_parameters[k] is v for k, v in expr.kw_parameters.items())): return expr @@ -795,7 +796,7 @@ def map_sum(self, ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -805,7 +806,7 @@ def map_product(self, ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -877,7 +878,7 @@ def map_bitwise_or(self, ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -887,7 +888,7 @@ def map_bitwise_and(self, ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -897,7 +898,7 @@ def map_bitwise_xor(self, ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -915,7 +916,7 @@ def map_logical_or(self, ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -925,7 +926,7 @@ def map_logical_and(self, ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -952,7 +953,7 @@ def map_tuple(self, ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr] if all(child is orig_child - for child, orig_child in zip(children, expr)): + for child, orig_child in zip(children, expr, strict=True)): return expr return tuple(children) @@ -994,7 +995,7 @@ def map_substitution(self, child = self.rec(expr.child, *args, **kwargs) values = tuple([self.rec(v, *args, **kwargs) for v in expr.values]) if child is expr.child and all(val is orig_val - for val, orig_val in zip(values, expr.values)): + for val, orig_val in zip(values, expr.values, strict=True)): return expr return type(expr)(child, expr.variables, values) @@ -1016,7 +1017,7 @@ def map_slice(self, for child in expr.children ])) if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(children) @@ -1037,7 +1038,7 @@ def map_min(self, expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ExpressionT self.rec(child, *args, **kwargs) for child in expr.children ]) if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(children) @@ -1047,7 +1048,7 @@ def map_max(self, expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ExpressionT self.rec(child, *args, **kwargs) for child in expr.children ]) if all(child is orig_child - for child, orig_child in zip(children, expr.children)): + for child, orig_child in zip(children, expr.children, strict=True)): return expr return type(expr)(children) diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 46ba83f..a3ccd3f 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -698,7 +698,7 @@ def map_substitution( ) -> str: substs = ", ".join( "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) - for name, val in zip(expr.variables, expr.values) + for name, val in zip(expr.variables, expr.values, strict=True) ) return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE, *args, **kwargs), substs) @@ -1139,7 +1139,7 @@ def map_substitution( ) -> str: substs = ", ".join( "{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs)) - for name, val in zip(expr.variables, expr.values) + for name, val in zip(expr.variables, expr.values, strict=True) ) return self.format( diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index a7d1dad..c4eec56 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -214,7 +214,7 @@ def map_sum(self, expr, other, urecs): for my_child, other_child in zip( expr.children, - (other.children[i] for i in perm)): + (other.children[i] for i in perm), strict=True): it_assignments = self.rec(my_child, other_child, it_assignments) if not it_assignments: break @@ -302,7 +302,7 @@ def map_list(self, expr, other, urecs): or len(expr) != len(other)): return [] - for my_child, other_child in zip(expr, other): + for my_child, other_child in zip(expr, other, strict=True): urecs = self.rec(my_child, other_child, urecs) if not urecs: break @@ -399,7 +399,7 @@ def partitions(s, k): for partition in partitions( other_leftovers, len(plain_var_candidates)): result = urec - for subset, var in zip(partition, plain_var_candidates): + for subset, var in zip(partition, plain_var_candidates, strict=True): rec = self.unification_record_from_equation( var, factory(other.children[i] for i in subset)) result = result.unify(rec) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 0f92630..181f166 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -772,7 +772,7 @@ def __getstate__(self) -> tuple[Any]: def __setstate__(self, state) -> None: # Can't use trivial pickling: _hash_value cache must stay unset assert len(self.init_arg_names) == len(state), type(self) - for name, value in zip(self.init_arg_names, state): + for name, value in zip(self.init_arg_names, state, strict=True): object.__setattr__(self, name, value) # }}} @@ -1791,7 +1791,8 @@ def flattened_sum(terms: Iterable[ArithmeticExpressionT]) -> ArithmeticExpressio def linear_combination(coefficients, expressions): return sum(coefficient * expression - for coefficient, expression in zip(coefficients, expressions) + for coefficient, expression + in zip(coefficients, expressions, strict=True) if coefficient and expression)