Skip to content

Commit

Permalink
add interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Nov 24, 2024
1 parent 0cb46b6 commit ca9b406
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 98 deletions.
5 changes: 5 additions & 0 deletions src/ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def _main(argv):
choices=sorted(available_outputs()),
)

parser.add_argument(
"--interpolate",
help="Should the results be interpolated",
)

parser.add_argument(
"--date",
default="-1",
Expand Down
89 changes: 0 additions & 89 deletions src/ai_models/inputs/interpolate.py

This file was deleted.

6 changes: 3 additions & 3 deletions src/ai_models/inputs/opendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from earthkit.data.indexing.fieldlist import FieldArray
from multiurl import download

from ..interpolate import Interpolate
from .base import RequestBasedInput
from .compute import make_z_from_gh
from .interpolate import Interpolate
from .recenter import recenter
from .transform import NewMetadataField

Expand Down Expand Up @@ -71,7 +71,7 @@ def _adjust(self, kwargs):
if isinstance(grid, list):
grid = tuple(grid)

kwargs["resol"], source, interp, oversampling, metadata = RESOLS[grid]
kwargs["resol"], source, interp, oversampling, _ = RESOLS[grid]
r = dict(**kwargs)
r.update(self.owner.retrieve)

Expand All @@ -80,7 +80,7 @@ def _adjust(self, kwargs):
logging.info("Interpolating input data from %s to %s.", source, grid)
if oversampling:
logging.warning("This will oversample the input data.")
return Interpolate(grid, source, metadata)
return Interpolate(source=source, target=grid)
else:
return _identity

Expand Down
2 changes: 1 addition & 1 deletion src/ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def collect_archive_requests(self, written):
self.archiving[path].add(handle.as_namespace("mars"))

def finalise(self):
self.output.flush()
self.output.close()

if self.archive_requests:
with open(self.archive_requests, "w") as f:
Expand Down
50 changes: 45 additions & 5 deletions src/ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Output:
def write(self, *args, **kwargs):
pass

def flush(self, *args, **kwargs):
def close(self):
pass


Expand Down Expand Up @@ -104,6 +104,9 @@ def write(self, data, *args, check=False, **kwargs):

return handle, path

def close(self):
self.output.close()


class FileOutput(GribOutputBase):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -167,8 +170,8 @@ def write(self, *args, **kwargs):

return self.output.write(*args, **kwargs)

def flush(self, *args, **kwargs):
return self.output.flush(*args, **kwargs)
def close(self):
return self.output.close()


class NoLabelling:
Expand All @@ -181,8 +184,40 @@ def write(self, *args, **kwargs):
kwargs["deleteLocalDefinition"] = 1
return self.output.write(*args, **kwargs)

def flush(self, *args, **kwargs):
return self.output.flush(*args, **kwargs)
def close(self):
return self.output.close()


class InterpolatedOutput:
def __init__(self, owner, output, interpolate, **kwargs):
self.owner = owner
self.output = output
try:
self.target = (float(interpolate), float(interpolate))
except ValueError:
self.target = interpolate.upper()

@cached_property
def interpolator(self):
from ..interpolate import Interpolate

return Interpolate(source=self.owner.grid, target=self.target)

def write(self, values, template, *args, **kwargs):

if values is None:
values = template.to_numpy(flatten=True)
# We need to extract a few more metadata from the template
for m in ("date", "time", "step", "param", "paramId", "shortName"):
kwargs[m] = template.metadata(m)

values, metadata = self.interpolator.interpolate(values)
kwargs.update(metadata)

return self.output.write(values, template, *args, **kwargs)

def close(self):
return self.output.close()


def get_output(name, owner, *args, **kwargs):
Expand All @@ -191,6 +226,11 @@ def get_output(name, owner, *args, **kwargs):
result = HindcastReLabel(owner, result, **kwargs)
if owner.expver is None:
result = NoLabelling(owner, result, **kwargs)

if kwargs.get("interpolate") is not None:
# Interpolate the output
result = InterpolatedOutput(owner, result, **kwargs)

return result


Expand Down

0 comments on commit ca9b406

Please sign in to comment.