Skip to content

Commit

Permalink
Fixes for fine tuning and input data loading.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpf committed Apr 4, 2024
1 parent 86f25f4 commit 9d9b9d6
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 24 deletions.
6 changes: 6 additions & 0 deletions gprof_nn/data/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,15 @@ def extract_finetuning_samples(
surface_precip="surface_precip_combined"
).drop(["latitude", "longitude"])

for ref_var in reference_variables:
var_data = ref_data[ref_var].data
var_data[var_data < -1_000] = np.nan


data = xr.merge(
[input_data, ref_data]
)
data.attrs["source"] = "collocs"

write_training_samples_1d(
output_path_1d,
Expand Down
9 changes: 5 additions & 4 deletions gprof_nn/data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,10 +555,10 @@ def load_tbs_1d_conical_other(
"""
tbs = training_data["brightness_temperatures"].data
tbs_full = np.nan * np.ones(tbs.shape[:-1] + (15,), dtype="float32")
tbs_full[:, sensor.gmi_channels] = tbs
tbs_full[:, sensor.gprof_channels] = tbs
angles = training_data["earth_incidence_angle"].data
angles_full = np.nan * np.ones(tbs.shape[:-1] + (15,), dtype="float32")
angles_full[:, sensor.gmi_channels] = angles
angles_full[:, sensor.gprof_channels] = angles
tbs = torch.tensor(tbs_full.astype("float32"))
angles = torch.tensor(angles_full.astype("float32"))
return tbs, angles
Expand Down Expand Up @@ -786,6 +786,7 @@ def load_training_data(self, dataset: xr.Dataset) -> Dict[str, torch.Tensor]:
anc = load_ancillary_data_1d(dataset)
targets = load_targets_1d(dataset, self.targets)


x = {
"brightness_temperatures": tbs,
"ancillary_data": anc,
Expand Down Expand Up @@ -1399,7 +1400,7 @@ def __getitem__(self, ind) -> Tuple[Dict[str, torch.Tensor], Path]:
}
return inpt_data, self.files[ind]

def save_results(self, results, output_path, input_file) -> None:
def finalize_results(self, results, input_file) -> Tuple[xr.Dataset, str]:
"""
Save simulator results to training file.
Expand Down Expand Up @@ -1431,7 +1432,7 @@ def save_results(self, results, output_path, input_file) -> None:
"_FillValue": 2 ** 16 - 1,
"zlib": True
}
output_data.to_netcdf(input_file)
return output_data, input_file.name



Expand Down
88 changes: 68 additions & 20 deletions gprof_nn/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tempfile import TemporaryDirectory
from pathlib import Path
import re
from typing import List, Union
from typing import Dict, List, Union, Tuple

import numpy as np
import xarray as xr
Expand All @@ -23,7 +23,7 @@
import pandas as pd

from gprof_nn import sensors
from gprof_nn.definitions import PROFILE_TARGETS, ALL_TARGETS
from gprof_nn.definitions import PROFILE_TARGETS, ALL_TARGETS, ANCILLARY_VARIABLES
from gprof_nn.data import get_profile_clusters
from gprof_nn.data.bin import BinFile
from gprof_nn.data.training_data import (
Expand Down Expand Up @@ -1689,11 +1689,22 @@ def finalize(self, data):
return self.dataset


GPROF_CHANNELS = {
""
}



class L1CLoader:
"""
Loads retrieval input from a L1C file.
Loader object to load retrieval input from one or multiple L1C files.
"""
def __init__(self, inputs: Union[str, Path, List[str], List[Path], List[Granule]], config):
def __init__(
self,
inputs: Union[str, Path, List[str], List[Path], List[Granule]],
config: str,
ancillary_data: bool = True
):

if isinstance(inputs, Path) and path.is_dir():
self.files = sorted(list(path.glob("**/*.HDF5")))
Expand All @@ -1702,48 +1713,85 @@ def __init__(self, inputs: Union[str, Path, List[str], List[Path], List[Granule]
self.files = inputs
else:
self.files = [inputs]
self.config = config
self.ancillary_data = ancillary_data


def load_data(self, file):
def load_data(self, file: Union[Path, Granule]) -> Dict[str, torch.tensor]:
"""
Load input data for a given file.
Runs the preprocessor on the given L1C file and loads the retrieval input data
from the results.
Args:
file: A path (or granule object) identifying (a subset of) a L1C file.
"""

print(file)
if isinstance(file, Granule):
with TemporaryDirectory() as tmp:
input_file = Path(tmp) / file.file_record.local_path.name
l1c_file = L1CFile(file.file_record.local_path)
sensor = l1c_file.sensor
l1c_file.extract_scan_range(*file.primary_index_range, input_file)
data_pp = run_preprocessor(input_file, L1CFile(input_file).sensor)
input_data = run_preprocessor(input_file, sensor)
else:
input_file = file
data_pp = run_preprocessor(input_file, L1CFile(input_file).sensor)
l1c_file = L1CFile(file.file_record.local_path)
sensor = l1c_file.sensor
input_data = run_preprocessor(input_file, sensor)

tbs = data_pp.brightness_temperatures.data
filename = input_file.name

tbs = input_data.brightness_temperatures.data
tbs[tbs < 0] = np.nan
tbs_full = np.nan * np.zeros((tbs.shape[:2] +(15,)), dtype=np.float32)
tbs_full[..., sensor.gprof_channels] = tbs

angs = data_pp.earth_incidence_angle.data
angs = input_data.earth_incidence_angle.data
angs[angs < -100] = np.nan
angs_full = np.nan * np.zeros((angs.shape[:2] +(15,)), dtype=np.float32)
angs_full[..., sensor.gmi_channels] = angs

anc = np.stack([data_pp[var] for var in ANCILLARY_VARIABLES], -1)
anc = np.stack([input_data[var] for var in ANCILLARY_VARIABLES], -1)

if self.config == "1d":
return {
"brightness_temperatures": torch.tensor(tbs.reshape(-1, tbs.shape[-1])),
"viewing_angles": torch.tensor(angs.reshape(-1, angs.shape[-1])),
"ancillary_data": torch.tensor(anc.reshape(-1, anc.shape[-1])),
}
"brightness_temperatures": torch.tensor(tbs_full.reshape(-1, 15)),
"viewing_angles": torch.tensor(angs_full.reshape(-1, 15)),
"ancillary_data": torch.tensor(anc.reshape(-1, 8)),
}, filename, input_data

tbs = np.transpose(tbs, (2, 0, 1))
angs = np.transpose(angs, (2, 0, 1))
tbs_full = np.transpose(tbs_full, (2, 0, 1))
angs_full = np.transpose(angs_full, (2, 0, 1))
anc = np.transpose(anc, (2, 0, 1))

return {
"brightness_temperatures": torch.tensor(tbs),
"viewing_angles": torch.tensor(angs),
"brightness_temperatures": torch.tensor(tbs_full),
"viewing_angles": torch.tensor(angs_full),
"ancillary_data": torch.tensor(anc),
}
}, filename, input_data

def __len__(self):
return len(self.files)

def __iter__(self):
for file in self.files:
yield self.load_data(file)

def finalize_results(
self,
results: xr.Dataset,
filename: Path,
preprocessor_data: xr.Dataset
) -> xr.Dataset:
"""
"""
data = preprocessor_data.copy()
shape = (data.scans.size, data.pixels.size)
for var in results:
var_data = results[var].data
data[var] = (("scans", "pixels"), var_data.reshape(shape))
return data
1 change: 1 addition & 0 deletions gprof_nn/sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,7 @@ def get_sensor(sensor, platform=None, date=None):
modeling_error=AMSR2_MODELING_ERROR,
correction=DATA_FOLDER / "corrections_amsr2.nc",
)
AMSR2.gprof_channels = np.arange(10)


###############################################################################
Expand Down

0 comments on commit 9d9b9d6

Please sign in to comment.