Skip to content

Commit

Permalink
Pass strict=True in calls to zip()
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 6, 2024
1 parent 8f91eb2 commit 13fb3d2
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 24 deletions.
3 changes: 2 additions & 1 deletion pymbolic/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,8 @@ def solve_affine_equations_for(unknowns, equations):
div = mat[nonz_row, j]

unknown_val = int(rhs_mat[nonz_row, -1]) // div
for parameter, coeff in zip(parameters_list, rhs_mat[nonz_row]):
for parameter, coeff in zip(
parameters_list, rhs_mat[nonz_row, :-1], strict=True):
unknown_val += (int(coeff) // div) * parameter

result[unknown] = unknown_val
Expand Down
5 changes: 3 additions & 2 deletions pymbolic/geometric_algebra/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ def map_product(self, expr):
if not has_d_source_nablas:
rec_children = [self.rec(child) for child in expr.children]
if all(rec_child is child
for rec_child, child in zip(rec_children, expr.children)):
for rec_child, child in zip(
rec_children, expr.children, strict=True)):
return expr

return type(expr)(tuple(rec_children))
Expand All @@ -296,7 +297,7 @@ def map_product(self, expr):
result = [list(expr.children)]

for child_idx, (d_source_nabla_ids, _child) in enumerate(
zip(d_source_nabla_ids_per_child, expr.children)):
zip(d_source_nabla_ids_per_child, expr.children, strict=True)):
if not d_source_nabla_ids:
continue

Expand Down
29 changes: 15 additions & 14 deletions pymbolic/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,8 @@ def map_call(self,
])
if (function is expr.function
and all(child is orig_child
for child, orig_child in zip(expr.parameters, parameters))):
for child, orig_child in zip(
expr.parameters, parameters, strict=True))):
return expr

return type(expr)(function, parameters)
Expand All @@ -767,7 +768,7 @@ def map_call_with_kwargs(self,

if (function is expr.function
and all(child is orig_child for child, orig_child in
zip(parameters, expr.parameters))
zip(parameters, expr.parameters, strict=True))
and all(kw_parameters[k] is v for k, v in
expr.kw_parameters.items())):
return expr
Expand Down Expand Up @@ -795,7 +796,7 @@ def map_sum(self,
) -> ExpressionT:
children = [self.rec(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand All @@ -805,7 +806,7 @@ def map_product(self,
) -> ExpressionT:
children = [self.rec(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand Down Expand Up @@ -877,7 +878,7 @@ def map_bitwise_or(self,
) -> ExpressionT:
children = [self.rec(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand All @@ -887,7 +888,7 @@ def map_bitwise_and(self,
) -> ExpressionT:
children = [self.rec(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand All @@ -897,7 +898,7 @@ def map_bitwise_xor(self,
) -> ExpressionT:
children = [self.rec(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand All @@ -915,7 +916,7 @@ def map_logical_or(self,
) -> ExpressionT:
children = [self.rec(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand All @@ -925,7 +926,7 @@ def map_logical_and(self,
) -> ExpressionT:
children = [self.rec(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand All @@ -952,7 +953,7 @@ def map_tuple(self,
) -> ExpressionT:
children = [self.rec(child, *args, **kwargs) for child in expr]
if all(child is orig_child
for child, orig_child in zip(children, expr)):
for child, orig_child in zip(children, expr, strict=True)):
return expr

return tuple(children)
Expand Down Expand Up @@ -994,7 +995,7 @@ def map_substitution(self,
child = self.rec(expr.child, *args, **kwargs)
values = tuple([self.rec(v, *args, **kwargs) for v in expr.values])
if child is expr.child and all(val is orig_val
for val, orig_val in zip(values, expr.values)):
for val, orig_val in zip(values, expr.values, strict=True)):
return expr

return type(expr)(child, expr.variables, values)
Expand All @@ -1016,7 +1017,7 @@ def map_slice(self,
for child in expr.children
]))
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(children)
Expand All @@ -1037,7 +1038,7 @@ def map_min(self, expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ExpressionT
self.rec(child, *args, **kwargs) for child in expr.children
])
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(children)
Expand All @@ -1047,7 +1048,7 @@ def map_max(self, expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ExpressionT
self.rec(child, *args, **kwargs) for child in expr.children
])
if all(child is orig_child
for child, orig_child in zip(children, expr.children)):
for child, orig_child in zip(children, expr.children, strict=True)):
return expr

return type(expr)(children)
Expand Down
4 changes: 2 additions & 2 deletions pymbolic/mapper/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def map_substitution(
) -> str:
substs = ", ".join(
"{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs))
for name, val in zip(expr.variables, expr.values)
for name, val in zip(expr.variables, expr.values, strict=True)
)

return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE, *args, **kwargs), substs)
Expand Down Expand Up @@ -1139,7 +1139,7 @@ def map_substitution(
) -> str:
substs = ", ".join(
"{}={}".format(name, self.rec(val, PREC_NONE, *args, **kwargs))
for name, val in zip(expr.variables, expr.values)
for name, val in zip(expr.variables, expr.values, strict=True)
)

return self.format(
Expand Down
6 changes: 3 additions & 3 deletions pymbolic/mapper/unifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def map_sum(self, expr, other, urecs):

for my_child, other_child in zip(
expr.children,
(other.children[i] for i in perm)):
(other.children[i] for i in perm), strict=True):
it_assignments = self.rec(my_child, other_child, it_assignments)
if not it_assignments:
break
Expand Down Expand Up @@ -302,7 +302,7 @@ def map_list(self, expr, other, urecs):
or len(expr) != len(other)):
return []

for my_child, other_child in zip(expr, other):
for my_child, other_child in zip(expr, other, strict=True):
urecs = self.rec(my_child, other_child, urecs)
if not urecs:
break
Expand Down Expand Up @@ -399,7 +399,7 @@ def partitions(s, k):
for partition in partitions(
other_leftovers, len(plain_var_candidates)):
result = urec
for subset, var in zip(partition, plain_var_candidates):
for subset, var in zip(partition, plain_var_candidates, strict=True):
rec = self.unification_record_from_equation(
var, factory(other.children[i] for i in subset))
result = result.unify(rec)
Expand Down
5 changes: 3 additions & 2 deletions pymbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def __getstate__(self) -> tuple[Any]:
def __setstate__(self, state) -> None:
# Can't use trivial pickling: _hash_value cache must stay unset
assert len(self.init_arg_names) == len(state), type(self)
for name, value in zip(self.init_arg_names, state):
for name, value in zip(self.init_arg_names, state, strict=True):
object.__setattr__(self, name, value)

# }}}
Expand Down Expand Up @@ -1791,7 +1791,8 @@ def flattened_sum(terms: Iterable[ArithmeticExpressionT]) -> ArithmeticExpressio

def linear_combination(coefficients, expressions):
return sum(coefficient * expression
for coefficient, expression in zip(coefficients, expressions)
for coefficient, expression
in zip(coefficients, expressions, strict=True)
if coefficient and expression)


Expand Down

0 comments on commit 13fb3d2

Please sign in to comment.