Skip to content

Commit

Permalink
More typing of GA mappers
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 13, 2024
1 parent a365dbf commit ce1b948
Showing 1 changed file with 44 additions and 25 deletions.
69 changes: 44 additions & 25 deletions pymbolic/geometric_algebra/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@

# This is experimental, undocumented, and could go away any second.
# Consider yourself warned.

from collections.abc import Set
from typing import ClassVar

import pymbolic.geometric_algebra.primitives as prim
from pymbolic.geometric_algebra import MultiVector
from pymbolic.mapper import (
CachedMapper,
CollectedT,
Collector as CollectorBase,
CombineMapper as CombineMapperBase,
IdentityMapper as IdentityMapperBase,
P,
ResultT,
WalkMapper as WalkMapperBase,
)
from pymbolic.mapper.constant_folder import (
Expand All @@ -46,50 +49,66 @@
PREC_NONE,
StringifyMapper as StringifyMapperBase,
)
from pymbolic.primitives import Expression


class IdentityMapper(IdentityMapperBase):
def map_multivector_variable(self, expr):
class IdentityMapper(IdentityMapperBase[P]):
def map_nabla(
self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> Expression:
return expr

map_nabla = map_multivector_variable
map_nabla_component = map_multivector_variable
def map_nabla_component(self,
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs) -> Expression:
return expr

def map_derivative_source(self, expr):
operand = self.rec(expr.operand)
def map_derivative_source(self,
expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
) -> Expression:
operand = self.rec(expr.operand, *args, **kwargs)
if operand is expr.operand:
return expr

return type(expr)(operand, expr.nabla_id)


class CombineMapper(CombineMapperBase):
def map_derivative_source(self, expr):
return self.rec(expr.operand)
class CombineMapper(CombineMapperBase[ResultT, P]):
def map_derivative_source(
self, expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
return self.rec(expr.operand, *args, **kwargs)


class Collector(CollectorBase):
def map_nabla(self, expr):
class Collector(CollectorBase[CollectedT, P]):
def map_nabla(self,
expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
) -> Set[CollectedT]:
return set()

map_nabla_component = map_nabla
def map_nabla_component(self,
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> Set[CollectedT]:
return set()


class WalkMapper(WalkMapperBase):
def map_nabla(self, expr, *args):
self.visit(expr, *args)
self.post_visit(expr)
class WalkMapper(WalkMapperBase[P]):
def map_nabla(self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> None:
self.visit(expr, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)

def map_nabla_component(self, expr, *args):
self.visit(expr, *args)
self.post_visit(expr)
def map_nabla_component(
self, expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> None:
self.visit(expr, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)

def map_derivative_source(self, expr, *args):
if not self.visit(expr, *args):
def map_derivative_source(
self, expr, *args: P.args, **kwargs: P.kwargs
) -> None:
if not self.visit(expr, *args, **kwargs):
return

self.rec(expr.operand)
self.post_visit(expr)
self.rec(expr.operand, *args, **kwargs)
self.post_visit(expr, *args, **kwargs)


class EvaluationMapper(EvaluationMapperBase):
Expand All @@ -106,7 +125,7 @@ def map_derivative_source(self, expr):
return type(expr)(operand, expr.nabla_id)


class StringifyMapper(StringifyMapperBase):
class StringifyMapper(StringifyMapperBase[[]]):
AXES: ClassVar[dict[int, str]] = {0: "x", 1: "y", 2: "z"}

def map_nabla(self, expr, enclosing_prec):
Expand Down

0 comments on commit ce1b948

Please sign in to comment.