diff --git a/setup.py b/setup.py index 2fb647a5..3056a233 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ def read(filename): "tqdm", "versioneer", "xarray", + "sly", ], extras_require={ "dev": [ diff --git a/src/pymorize/cli.py b/src/pymorize/cli.py index 84008023..6c4aa902 100644 --- a/src/pymorize/cli.py +++ b/src/pymorize/cli.py @@ -17,6 +17,10 @@ from .cmorizer import CMORizer from .filecache import fc from .logging import add_report_logger, logger +from .prototype.cellmethods.cellmethods_parser import ( + parse_cell_methods, + translate_to_xarray, +) from .ssh_tunnel import ssh_tunnel_cli from .validate import GENERAL_VALIDATOR, PIPELINES_VALIDATOR, RULES_VALIDATOR @@ -248,6 +252,37 @@ def directory(config_file, output_dir, verbose, quiet, logfile, profile_mem): cmorizer.check_rules_for_output_dir(output_dir) +@validate.command() +@click_loguru.logging_options +@click_loguru.init_logger() +@click.argument("config_file", type=click.Path(exists=True)) +def cellmethods(config_file, verbose, quiet, logfile, profile_mem): + logger.info(f"Processing {config_file}") + with open(config_file, "r") as f: + cfg = yaml.safe_load(f) + cmorizer = CMORizer.from_dict(cfg) + seen_rules = set() + for rule in cmorizer.rules: + if rule.name in seen_rules: + continue + else: + seen_rules.add(rule.name) + cellmethod_text = rule.data_request_variable.cell_methods + if not cellmethod_text.strip(): + continue + else: + tokengroups = parse_cell_methods(cellmethod_text) + logger.info(f"{rule.cmor_variable!r}: Parsing cellmethods text...") + logger.info(f"{cellmethod_text}") + logger.info("Tokens:") + for tok in tokengroups: + logger.info(f" {tok}") + logger.info("xarray translation (Pseudo code)") + codelines = translate_to_xarray(cellmethod_text) + for line in codelines.splitlines(): + logger.info(f" {line}") + + ################################################################################ ################################################################################ ################################################################################ diff --git a/src/pymorize/prototype/__init__.py b/src/pymorize/prototype/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pymorize/prototype/cellmethods/__init__.py b/src/pymorize/prototype/cellmethods/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pymorize/prototype/cellmethods/cell_methods_xarray.py b/src/pymorize/prototype/cellmethods/cell_methods_xarray.py new file mode 100644 index 00000000..883313df --- /dev/null +++ b/src/pymorize/prototype/cellmethods/cell_methods_xarray.py @@ -0,0 +1,165 @@ +from typing import List, Optional, Tuple + +import numpy as np +import xarray as xr + +from .cellmethods_parser import parse_cell_methods + +""" +Prototype code only. Not sure if cellmethods are handled this way. Lot of ambiguity at many steps. +""" + + +class CellMethodsConverter: + def __init__(self): + self.function_map = { + "mean": xr.DataArray.mean, + "sum": xr.DataArray.sum, + "maximum": xr.DataArray.max, + "minimum": xr.DataArray.min, + "point": lambda x, dim: x.isel(**{dim: 0}), + } + + self.dimension_map = { + "area": "area", + "time": "time", + "depth": "depth", + "longitude": "lon", + "grid_longitude": "grid_lon", + } + + def apply_constraint( + self, da: xr.DataArray, constraint: str, value: str, scope: Optional[str] = None + ) -> xr.DataArray: + """Apply where/over constraints to the DataArray.""" + if constraint == "where": + # Handle special cases with mask variables + if scope and "(comment: mask=" in scope: + mask_var = scope.split("mask=")[1].rstrip(")") + # Assuming the mask variable is available in the same dataset + return da.where(da.coords[mask_var] > 0) + + # Handle basic area type constraints + area_types = [ + "land", + "sea", + "ice_sheet", + "sea_ice", + "crops", + "trees", + "vegetation", + "unfrozen_soil", + "cloud", + "natural_grasses", + "floating_ice_shelf", + "grounded_ice_sheet", + "ice_free_sea", + "sea_ice_melt_pond", + "sea_ice_ridges", + "snow", + "sector", + "shrubs", + "pastures", + ] + + if value in area_types: + # Use the mask from coordinates + mask_var = f"{value}_mask" + mask = da.coords[mask_var] + # Create a boolean mask array that matches the data dimensions + mask_data = mask.values > 0 + # Broadcast mask to match data dimensions + for _ in range(len(da.dims) - 1): + mask_data = mask_data[:, np.newaxis] + # Apply the mask + return da.where(mask_data) + + elif constraint == "over": + if value == "all_area_types": + # No filtering needed, already considering all areas + return da + elif value in ["days", "months", "years", "hours"]: + # This will be handled in the time aggregation + return da + + return da + + def process_cell_method( + self, da: xr.DataArray, method: List[Tuple[str, str]] + ) -> xr.DataArray: + """Process a single cell method (one group of operations).""" + result = da.copy() # Make a copy to preserve coordinates + dim = None + func = None + constraints = [] + scope = None + + for token_type, token_value in method: + if token_type == "DIMENSION": + dim = self.dimension_map.get(token_value, token_value) + elif token_type == "FUNCTION": + func = self.function_map[token_value] + elif token_type == "CONSTRAINT": + constraints.append(token_value) + elif token_type == "AREATYPE" or token_type == "SELECTION": + if constraints: + result = self.apply_constraint( + result, constraints[-1], token_value, scope + ) + elif token_type == "SCOPE": + scope = token_value + + if dim and func: + # Handle time-based selections before applying the function + if dim == "time" and constraints and constraints[-1] == "over": + # Get the appropriate time frequency + freq = {"hours": "h", "days": "D", "months": "M", "years": "Y"}.get( + token_value + ) + if freq: + result = result.resample(time=freq).mean() + else: + # Apply the main function + if func == self.function_map["point"]: + result = result.isel(**{dim: 0}) + else: + result = func(result, dim=dim) + + return result + + def apply_cell_methods( + self, da: xr.DataArray, cell_methods_str: str + ) -> xr.DataArray: + """Apply cell methods to a DataArray based on the cell_methods string.""" + parsed = parse_cell_methods(cell_methods_str) + if parsed is None: + raise ValueError(f"Failed to parse cell methods string: {cell_methods_str}") + + result = da + for method in parsed: + result = self.process_cell_method(result, method) + + return result + + +# Example usage: +def apply_cell_methods(da: xr.DataArray, cell_methods_str: str) -> xr.DataArray: + """ + Apply cell methods to a DataArray based on the cell_methods string. + + Args: + da: Input xarray DataArray + cell_methods_str: Cell methods string (e.g., "area: mean time: maximum") + + Returns: + Processed xarray DataArray + + Example: + >>> import xarray as xr + >>> import numpy as np + >>> data = np.random.rand(4, 3, 2) # time, area, depth + >>> da = xr.DataArray(data, dims=['time', 'area', 'depth']) + >>> result = apply_cell_methods(da, "area: mean time: maximum") + """ + converter = CellMethodsConverter() + return converter.apply_cell_methods(da, cell_methods_str) diff --git a/src/pymorize/prototype/cellmethods/cellmethods_parser.py b/src/pymorize/prototype/cellmethods/cellmethods_parser.py new file mode 100644 index 00000000..4ee193e1 --- /dev/null +++ b/src/pymorize/prototype/cellmethods/cellmethods_parser.py @@ -0,0 +1,329 @@ +from sly import Lexer, Parser + + +class CellMethodsLexer(Lexer): + # set of token names + tokens = { + DIMENSION, # noqa: F821 + FUNCTION, # noqa: F821 + CONSTRAINT, # noqa: F821 + AREATYPE, # noqa: F821 + SELECTION, # noqa: F821 + COMMENT, # noqa: F821 + } # noqa: F821 + + # string containing ignored characters between token + ignore = " \t" + + # Regular expression rules for tokens + # DIMENSION = r"area:|time:|grid_longitude:|longitude:|latitude:|depth:" + DIMENSION = r"[a-zA-Z_]+:" + FUNCTION = r"mean|minimum|maximum|sum|point" + CONSTRAINT = r"within|over|where" + AREATYPE = r"[a-zA-Z_]+" + SELECTION = r"[a-zA-Z_]+" + + def DIMENSION(self, t): # noqa: F811 + t.value = t.value[:-1] + return t + + _areatypes = set( + [ + "land", + "shrubs", + "pastures", + "crops", + "trees", + "vegetation", + "unfrozen_soil", + "cloud", + "natural_grasses", + "floating_ice_shelf", + "grounded_ice_sheet", + "ice_free_sea", + "ice_sheet", + "sea", + "sea_ice", + "sea_ice_melt_pond", + "sea_ice_ridges", + "snow", + "sector", + ] + ) + _selection = set(["hours", "days", "years", "months"]) + _selection.add("all_area_types") + + @_(r"[a-zA-Z_]+") # noqa: F821 + def AREATYPE(self, t): # noqa: F811 + if t.value in self._areatypes: + return t + if t.value in self._selection: + t.type = "SELECTION" + return t + + @_(r"\(.*?\)") # noqa: F821 + def COMMENT(self, t): + value = t.value[1:-1] + t.value = ( + value.replace("comment:", "").replace("[", "").replace("]", "").strip() + ) + return t + + @_(r"\n+") # noqa: F821 + def newline(self, t): + self.lineno += t.value.count("\n") + + +class CellMethodsParser(Parser): + tokens = CellMethodsLexer.tokens + debugfile = "parser.out" + + def __init__(self): + self.tmp = [] + + @_("statements") # noqa: F821 + def program(self, p): + return corrections(p.statements) + # return p.statements + + @_("statement") # noqa: F821 + def statements(self, p): + return p.statement + + @_("statements statement") # noqa: F821 + def statements(self, p): # noqa: F811 + return p.statements + p.statement + + @_("dimension function") # noqa: F821 + def statement(self, p): + return [p.dimension + p.function] + + @_("dimension function comment") # noqa: F821 + def statement(self, p): # noqa: F811 + return [p.dimension + p.function + p.comment] + + @_("dimension function expr") # noqa: F821 + def statement(self, p): # noqa: F811 + return [p.dimension + p.function + p.expr] + + @_("dimension function exprs") # noqa: F821 + def statement(self, p): # noqa: F811 + return [p.dimension + p.function + p.exprs] + + @_("dimensions function") # noqa: F821 + def statement(self, p): # noqa: F811 + return [dim + p.function for dim in p.dimensions] + + @_("dimensions function comment") # noqa: F821 + def statement(self, p): # noqa: F811 + return [dim + p.function + p.comment for dim in p.dimensions] + + @_("dimensions function expr") # noqa: F821 + def statement(self, p): # noqa: F811 + return [dim + p.function + p.expr for dim in p.dimensions] + + @_("dimensions function exprs") # noqa: F821 + def statement(self, p): # noqa: F811 + return [dim + p.function + expr for dim in p.dimensions for expr in p.exprs] + + @_("constraint areatype comment") # noqa: F821 + def expr(self, p): + return p.constraint + p.areatype + p.comment + + @_("constraint selection comment") # noqa: F821 + def expr(self, p): # noqa: F811 + return p.constraint + p.selection + p.comment + + @_("constraint areatype") # noqa: F821 + def expr(self, p): # noqa: F811 + return p.constraint + p.areatype + + @_("constraint selection") # noqa: F821 + def expr(self, p): # noqa: F811 + return p.constraint + p.selection + + @_("expr expr") # noqa: F821 + def exprs(self, p): + return p.expr0 + p.expr1 + + @_("exprs expr") # noqa: F821 + def exprs(self, p): # noqa: F811 + return p.exprs + p.expr + + @_("dimension dimension") # noqa: F821 + def dimensions(self, p): + return [p.dimension0, p.dimension1] + + @_("dimensions dimension") # noqa: F821 + def dimensions(self, p): # noqa: F811 + return p.dimensions + [p.dimension] + + @_("DIMENSION") # noqa: F821 + def dimension(self, p): + return [("DIMENSION", p.DIMENSION)] + + @_("FUNCTION") # noqa: F821 + def function(self, p): + return [("FUNCTION", p.FUNCTION)] + + @_("CONSTRAINT") # noqa: F821 + def constraint(self, p): + return [("CONSTRAINT", p.CONSTRAINT)] + + @_("AREATYPE") # noqa: F821 + def areatype(self, p): + return [("AREATYPE", p.AREATYPE)] + + @_("SELECTION") # noqa: F821 + def selection(self, p): + return [("SELECTION", p.SELECTION)] + + @_("COMMENT") # noqa: F821 + def comment(self, p): + return [("COMMENT", p.COMMENT)] + + +def corrections(groups): + result = [] + for group in groups: + grp = [] + tokens = iter(group) + tok = next(tokens) + tok_type, tok_value = tok + grp.append(tok) + if tok_type == "DIMENSION" and tok_value == "time": + while True: + try: + tok = next(tokens) + except StopIteration: + break + tok_type, tok_value = tok + # for `time` dimension, only SELECTION type is allowed as constraint + if tok_type == "AREATYPE": + grp.pop() + else: + grp.append(tok) + elif tok_type == "DIMENSION" and tok_value == "area": + while True: + try: + tok = next(tokens) + except StopIteration: + break + tok_type, tok_value = tok + if tok_type == "SELECTION" and tok_value != "all_area_types": + grp.pop() + else: + grp.append(tok) + else: + grp.extend(list(tokens)) + result.append(grp) + return result + + +class XArrayTranslator: + """ + Represent parsed tree as human readable (pseudo code) xarray operations. + Produces strings and not xarray objects. + """ + + def __init__(self, da_name="da"): + self.da_name = da_name + self.function_map = { + "maximum": "max", + "minimum": "min", + "point": "isel", + "within": "groupby", + "over": "groupby", + } + + def translate_group(self, group): + """Translate a single group of tokens into an xarray operation.""" + tokens = iter(group) + token_type, dim = next(tokens) + assert token_type == "DIMENSION" + token_type, function = next(tokens) + function = self.function_map.get(function, function) + assert token_type == "FUNCTION" + texts = [] + try: + token_type, tok_value = next(tokens) + except StopIteration: + return f"{self.da_name}.{function}(dim={dim})" + else: + if token_type == "COMMENT": + if "mask=" in tok_value: + mask = tok_value.split("=")[1] + return f"{self.da_name}.where({mask}){function}(dim={dim} # comment: {tok_value})" + else: + return ( + f"{self.da_name}.{function}(dim={dim}) # comment: {tok_value}" + ) + elif token_type == "CONSTRAINT": + _constraint = tok_value + constraint = self.function_map.get(tok_value, tok_value) + token_type, tok_value = next(tokens) + if constraint == "groupby": + text = f"{self.da_name}.{constraint}({tok_value}).{function}(dim={dim}) # {_constraint}" + else: + text = f"{self.da_name}.{function}(dim={dim}).{constraint}({tok_value})" + texts.append(text) + # if constraint == "over": + # token_type, tok_value = next(tokens) + # text = f"{self.da_name}.{function}(dim={dim}).{constraint}({tok_value})" + # texts.append(text) + while True: + try: + token_type, tok_value = next(tokens) + except StopIteration: + break + if token_type == "COMMENT": + text = f" # comment: {tok_value}" + texts.append(text) + elif token_type == "CONSTRAINT": + constraint = tok_value + token_type, tok_value = next(tokens) + text = f".{constraint}({tok_value})" + texts.append(text) + text = "".join(texts) + return text + + def translate(self, groups): + """Translate all groups into a sequence of xarray operations.""" + operations = [] + intermediate = self.da_name + + if len(groups) == 1: + # For single operations, just return the operation directly + return self.translate_group(groups[0]) + + for i, group in enumerate(groups): + if i > 0: + # Use the result of the previous operation + self.da_name = f"result_{i}" + operations.append(f"{self.da_name} = {intermediate}") + + intermediate = self.translate_group(group) + + if i == len(groups) - 1: + # Last operation should be assigned to final result + operations.append(f"result = {intermediate}") + + return "\n".join(operations) + + +lexer = CellMethodsLexer() +parser = CellMethodsParser() + + +def parse_cell_methods(text): + tokens = lexer.tokenize(text) + group = parser.parse(tokens) + return group + + +def translate_to_xarray(text): + """Convenience function to parse cell methods and translate to xarray operations.""" + translator = XArrayTranslator() + + parsed = parse_cell_methods(text) + return translator.translate(parsed) diff --git a/src/pymorize/prototype/cellmethods/test_cellmethods_parser.py b/src/pymorize/prototype/cellmethods/test_cellmethods_parser.py new file mode 100644 index 00000000..72c07c38 --- /dev/null +++ b/src/pymorize/prototype/cellmethods/test_cellmethods_parser.py @@ -0,0 +1,160 @@ +from .cellmethods_parser import parse_cell_methods + + +def test_single_statement_with_just_action(): + text = "area: mean" + result = parse_cell_methods(text) + expected = [[("DIMENSION", "area"), ("FUNCTION", "mean")]] + assert result == expected + + +def test_single_statement_with_action_and_constraint(): + text = "area: mean where land" + result = parse_cell_methods(text) + expected = [ + [ + ("DIMENSION", "area"), + ("FUNCTION", "mean"), + ("CONSTRAINT", "where"), + ("AREATYPE", "land"), + ] + ] + assert result == expected + + +def test_single_statement_with_action_and_constraint_and_comment(): + text = "area: mean where land (comment: mask=landFrac)" + result = parse_cell_methods(text) + expected = [ + [ + ("DIMENSION", "area"), + ("FUNCTION", "mean"), + ("CONSTRAINT", "where"), + ("AREATYPE", "land"), + ("COMMENT", "mask=landFrac"), + ] + ] + assert result == expected + + +def test_many_dimensions_map_to_single_function(): + text = "area: depth: time: mean" + result = parse_cell_methods(text) + expected = [ + [ + ("DIMENSION", "area"), + ("FUNCTION", "mean"), + ], + [ + ("DIMENSION", "depth"), + ("FUNCTION", "mean"), + ], + [ + ("DIMENSION", "time"), + ("FUNCTION", "mean"), + ], + ] + assert result == expected + + +def test_statements_with_comment_in_middle(): + text = "longitude: sum (comment: basin sum [along zig-zag grid path]) depth: sum time: mean" + result = parse_cell_methods(text) + expected = [ + [ + ("DIMENSION", "longitude"), + ("FUNCTION", "sum"), + ("COMMENT", "basin sum along zig-zag grid path"), + ], + [ + ("DIMENSION", "depth"), + ("FUNCTION", "sum"), + ], + [ + ("DIMENSION", "time"), + ("FUNCTION", "mean"), + ], + ] + assert result == expected + + +def test_time_dimension_constraint_omits_areatpye(): + text = "area: time: mean where cloud" + result = parse_cell_methods(text) + expected = [ + [ + ("DIMENSION", "area"), + ("FUNCTION", "mean"), + ("CONSTRAINT", "where"), + ("AREATYPE", "cloud"), + ], + [ + ("DIMENSION", "time"), + ("FUNCTION", "mean"), + ], + ] + assert result == expected + + +def test_multiple_contraints(): + text = "area: mean where land over all_area_types time: mean" + result = parse_cell_methods(text) + expected = [ + [ + ("DIMENSION", "area"), + ("FUNCTION", "mean"), + ("CONSTRAINT", "where"), + ("AREATYPE", "land"), + ("CONSTRAINT", "over"), + ("SELECTION", "all_area_types"), + ], + [ + ("DIMENSION", "time"), + ("FUNCTION", "mean"), + ], + ] + assert result == expected + + +def test_statements_with_repeated_dimensions(): + text = "area: mean where crops time: minimum within days time: mean over days" + result = parse_cell_methods(text) + expected = [ + [ + ("DIMENSION", "area"), + ("FUNCTION", "mean"), + ("CONSTRAINT", "where"), + ("AREATYPE", "crops"), + ], + [ + ("DIMENSION", "time"), + ("FUNCTION", "minimum"), + ("CONSTRAINT", "within"), + ("SELECTION", "days"), + ], + [ + ("DIMENSION", "time"), + ("FUNCTION", "mean"), + ("CONSTRAINT", "over"), + ("SELECTION", "days"), + ], + ] + assert result == expected + + +def test_area_dimension_contraint_omits_selection(): + text = "area: time: mean over days" + result = parse_cell_methods(text) + expected = [ + [ + ("DIMENSION", "area"), + ("FUNCTION", "mean"), + ], + [ + ("DIMENSION", "time"), + ("FUNCTION", "mean"), + ("CONSTRAINT", "over"), + ("SELECTION", "days"), + ], + ] + assert result == expected diff --git a/src/pymorize/prototype/cellmethods/test_xarray_translation.py b/src/pymorize/prototype/cellmethods/test_xarray_translation.py new file mode 100644 index 00000000..8f47eda4 --- /dev/null +++ b/src/pymorize/prototype/cellmethods/test_xarray_translation.py @@ -0,0 +1,46 @@ +from .cellmethods_parser import translate_to_xarray + +test_cases = [ + ("area: mean", "da.mean(dim=area)"), + ("area: mean where sea", "da.mean(dim=area).where(sea)"), + ( + "area: mean where sea time: mean", + "result_1 = da.mean(dim=area).where(sea)\n" "result = result_1.mean(dim=time)", + ), + ( + "area: mean time: maximum within days", + "result_1 = da.mean(dim=area)\nresult = result_1.groupby(days).max(dim=time) # within", + ), + ( + "area: mean time: mean within days time: mean over days", + "result_1 = da.mean(dim=area)\n" + "result_2 = result_1.groupby(days).mean(dim=time) # within\n" + "result = result_2.groupby(days).mean(dim=time) # over", + ), + ( + "area: mean (comment: over land and sea ice) time: point", + "result_1 = da.mean(dim=area) # comment: over land and sea ice\n" + "result = result_1.isel(dim=time)", + ), + ( + "area: depth: time: mean", + "result_1 = da.mean(dim=area)\n" + "result_2 = result_1.mean(dim=depth)\n" + "result = result_2.mean(dim=time)", + ), +] + + +def test_translations(): + for input_text, expected in test_cases: + result = translate_to_xarray(input_text) + assert ( + result == expected + ), f"\nInput: {input_text}\nExpected:\n{expected}\nGot:\n{result}" + print(f"\nInput: {input_text}") + print("Generated xarray code:") + print(result) + + +if __name__ == "__main__": + test_translations()