diff --git a/gprof_nn/config.py b/gprof_nn/config.py index 4edcae5..927cb69 100644 --- a/gprof_nn/config.py +++ b/gprof_nn/config.py @@ -78,7 +78,7 @@ class DataConfig(ConfigBase): """ era5_path : Path = Path("/qdata2/archive/ERA5") model_path : Path = Path(user_data_dir("gprof_nn", "gprof_nn")) / "models" - mrms_path : Path = Path("/pdata4/mrms/") + mrms_path : Path = Path("/pdata4/veljko/") def print(self): txt = "[data]\n" diff --git a/gprof_nn/data/era5.py b/gprof_nn/data/era5.py index fdd2b56..dd9ef91 100644 --- a/gprof_nn/data/era5.py +++ b/gprof_nn/data/era5.py @@ -19,6 +19,7 @@ from gprof_nn.config import CONFIG from gprof_nn import sensors +from gprof_nn.definitions import DATA_SPLIT from gprof_nn.data.l1c import L1CFile from gprof_nn.data.preprocessor import run_preprocessor from gprof_nn.logging import get_console, log_messages @@ -174,8 +175,6 @@ def process_l1c_file( # Drop unneeded variables. drop = ["sunglint_angle", "quality_flag", "wet_bulb_temperature", "lapse_rate"] - if not isinstance(sensor, sensors.CrossTrackScanner): - drop.append("earth_incidence_angle") data_pp = data_pp.drop_vars(drop) start_time = data_pp["scan_time"].data[0] @@ -183,7 +182,7 @@ def process_l1c_file( era5_data = load_era5_data(start_time, end_time) add_era5_precip(data_pp, era5_data) - data_pp.attrs["source"] = 2 + data_pp.attrs["source"] = "era5" if output_path_1d is not None: write_training_samples_1d( output_path_1d, @@ -191,8 +190,8 @@ def process_l1c_file( data_pp, ) if output_path_3d is not None: - n_pixels = data_pp.pixels.size - n_scans = max(n_pixels, 128) + n_pixels = 64 + n_scans = 128 write_training_samples_3d( output_path_3d, "mrms", @@ -212,6 +211,7 @@ def process_l1c_files( end_time: np.datetime64, output_path_1d: Optional[Path] = None, output_path_3d: Optional[Path] = None, + split: str = None, n_processes: int = 4, log_queue: Optional[Queue] = None ): @@ -226,6 +226,8 @@ def process_l1c_files( the training samples for the GPROF-NN 1D retrieval. output_path_3d: Path pointing to the folder to which to write the training samples for the GPROF-NN 3D retrieval. + split: An optional string specifying whether to extract only data from + one of the three data splits ['training', 'validation', 'test']. n_processes: The number of processes to use for parallel processing. log_queue: Queue to use for logging from sub-processes. @@ -241,10 +243,25 @@ def process_l1c_files( LOGGER.info("Looking for files in %s.", l1c_path) while time < end_time: - l1c_files += L1CFile.find_files(time, l1c_path) + + l1c_files_day = L1CFile.find_files(time, l1c_path, sensor=sensor) + # Check if day of month should be skipped. 
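+        # Assumption (illustrative): DATA_SPLIT maps each split name to the
+        # set of calendar days of the month (1-based) assigned to that split.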
+        if split is not None:
+            days = DATA_SPLIT[split]
+            l1c_files_split = []
+            for l1c_file in l1c_files_day:
+                l1c_time = L1CFile(l1c_file).start_time
+                day_of_month = int(
+                    (l1c_time - l1c_time.astype("datetime64[M]")).astype("timedelta64[D]").astype("int64")
+                )
+                if day_of_month + 1 in days:
+                    l1c_files_split.append(l1c_file)
+            l1c_files_day = l1c_files_split
+
+        l1c_files += l1c_files_day
         time += np.timedelta64(24 * 60 * 60, "s")
-    LOGGER.info("Found %s L1C fiels to process", len(l1c_files))
+    LOGGER.info("Found %s L1C files to process", len(l1c_files))
 
     pool = ProcessPoolExecutor(max_workers=n_processes)
 
@@ -263,7 +280,7 @@ def process_l1c_files(
 
     with Progress(console=get_console()) as progress:
         pbar = progress.add_task(
-            "Extracting pretraining data:",
+            "Extracting ERA5 collocations:",
            total=len(tasks)
        )
        for task in as_completed(tasks):
diff --git a/gprof_nn/data/l1c.py b/gprof_nn/data/l1c.py
index 394fbc7..b1ba1d1 100644
--- a/gprof_nn/data/l1c.py
+++ b/gprof_nn/data/l1c.py
@@ -168,6 +168,8 @@ def find_files(cls, date, path, roi=None, sensor=sensors.GMI):
        month = date.month
        day = date.day
        data_path = Path(path) / f"{year:02}{month:02}" / f"{year:02}{month:02}{day:02}"
+        print("DATA PATH :: ", data_path)
+        print("PATTERN :: ", sensor.l1c_file_prefix + f"*{date.year:04}{month:02}{day:02}*{sensor.l1c_version}.HDF5")
        files = list(
            data_path.glob(
                sensor.l1c_file_prefix + f"*{date.year:04}{month:02}{day:02}*{sensor.l1c_version}.HDF5"
@@ -478,23 +480,44 @@ def to_xarray_dataset(self, roi=None):
 
            # Handle case that observations are split up.
            tbs = []
+            eia = []
            tbs.append(input[f"{swath}/Tc"][:][indices])
+            eia_s = input[f"{swath}/incidenceAngle"][:][indices]
+            eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
+            eia.append(eia_s)
            if "S2" in input.keys():
                tbs.append(input["S2/Tc"][:][indices])
+                eia_s = input["S2/incidenceAngle"][:][indices]
+                eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
+                eia.append(eia_s)
            if "S3" in input.keys():
                tbs.append(input["S3/Tc"][:][indices])
+                eia_s = input["S3/incidenceAngle"][:][indices]
+                eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
+                eia.append(eia_s)
            if "S4" in input.keys():
                tbs.append(input["S4/Tc"][:][indices])
+                eia_s = input["S4/incidenceAngle"][:][indices]
+                eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
+                eia.append(eia_s)
            if "S5" in input.keys():
                tbs_s = input["S5/Tc"][:][indices]
+                eia_s = input["S5/incidenceAngle"][:][indices]
                if tbs_s.shape[-2] > tbs[-1].shape[-2]:
                    tbs_s = tbs_s[..., ::2, :]
+                    eia_s = eia_s[..., ::2]
                tbs.append(tbs_s)
+                eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
+                eia.append(eia_s)
            if "S6" in input.keys():
                tbs_s = input["S6/Tc"][:][indices]
+                eia_s = input["S6/incidenceAngle"][:][indices]
                if tbs_s.shape[-2] > tbs[-1].shape[-2]:
                    tbs_s = tbs_s[..., ::2, :]
+                    eia_s = eia_s[..., ::2]
                tbs.append(tbs_s)
+                eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
+                eia.append(eia_s)
 
            n_pixels = max([array.shape[1] for array in tbs])
            tbs_r = []
@@ -546,11 +569,8 @@ def to_xarray_dataset(self, roi=None):
                "scan_time": (dims[:1], times),
            }
 
-            if "incidenceAngle" in input[f"{swath}"].keys():
-                data["incidence_angle"] = (
-                    dims,
-                    input[f"{swath}/incidenceAngle"][indices, :, 0],
-                )
+            eia = np.concatenate(eia, axis=-1)
+            data["incidence_angle"] = (dims + ("channels",), eia)
 
            if "SCorientation" in input[f"{swath}/SCstatus"]:
                data["sensor_orientation"] = (
diff --git a/gprof_nn/data/mrms.py b/gprof_nn/data/mrms.py
index 112a171..5f4c53b 100644
--- a/gprof_nn/data/mrms.py
+++ b/gprof_nn/data/mrms.py
@@ -15,15 +15,15 @@ 
 import click
 import numpy as np
 import xarray as xr
-import pandas as pd
 from pyresample import geometry, kd_tree
 from pykdtree.kdtree import KDTree
 from rich.progress import Progress
 from scipy.signal import convolve
 
 from gprof_nn import sensors
-from gprof_nn.logging import get_console, log_messages
 from gprof_nn.coordinates import latlon_to_ecef
+from gprof_nn.definitions import DATA_SPLIT
+from gprof_nn.logging import get_console, log_messages
 from gprof_nn.data.validation import unify_grid, calculate_angles
 from gprof_nn.data.preprocessor import run_preprocessor
 from gprof_nn.data.l1c import L1CFile
@@ -269,7 +269,7 @@ def extract_collocations(
 
    # Match targets
    match_file.match_targets(data_pp)
-    data_pp.attrs["source"] = 1
+    data_pp.attrs["source"] = "mrms"
 
    if output_path_1d is not None:
        write_training_samples_1d(
@@ -279,8 +279,8 @@ def extract_collocations(
        )
 
    if output_path_3d is not None:
-        n_pixels = data_pp.pixels.size
-        n_scans = max(n_pixels, 128)
+        n_pixels = 64
+        n_scans = 128
        write_training_samples_3d(
            output_path_3d,
            "mrms",
@@ -299,7 +299,8 @@ def process_match_file(
    l1c_path: Path,
    output_path_1d: Optional[Path] = None,
    output_path_3d: Optional[Path] = None,
-    n_processes: int = 4
+    n_processes: int = 4,
+    split: Optional[str] = None
 ):
    """
    Process a single MRMS match-up file.
@@ -311,21 +312,40 @@ def process_match_file(
        l1c_file: Path object pointing to the L1C file to collocate with
            the match ups.
        output_path_1d: Path pointing to the folder to which to write
            the GPROF-NN 1D training data.
        output_path_3d: Path pointing to the folder to which to write the
            GPROF-NN 3D training data.
+        n_processes: The number of processes to use for the data
+            extraction.
+        split: An optional string 'training', 'validation', 'test' specifying
+            which split of the dataset to extract.
    """
+    match_file = Path(match_file)
    year_month = match_file.name[:4]
+    l1c_path = Path(l1c_path)
    l1c_files = (l1c_path / year_month).glob(
        f"**/{sensor.l1c_file_prefix}*.HDF5"
    )
    l1c_files = sorted(list(l1c_files))
 
+    if split is not None:
+        l1c_files_split = []
+        days = DATA_SPLIT[split]
+        for l1c_file in l1c_files:
+            time = L1CFile(l1c_file).start_time
+            day_of_month = int(
+                (time - time.astype("datetime64[M]")).astype("timedelta64[D]").astype("int64")
+            )
+            if day_of_month + 1 in days:
+                l1c_files_split.append(l1c_file)
+        l1c_files = l1c_files_split
+
    LOGGER.info(
        f"Found {len(l1c_files)} L1C files matching MRMS match-up file "
        f"{match_file}."
) + pool = ProcessPoolExecutor(max_workers=n_processes) tasks = [] for l1c_file in l1c_files: @@ -369,6 +389,7 @@ def process_match_files( l1c_path: Path, output_path_1d: Path, output_path_3d: Path, + split: Optional[str] = None, n_processes: int = 4 ): """ @@ -391,6 +412,7 @@ def process_match_files( l1c_path, output_path_1d, output_path_3d, + split=split, n_processes=n_processes ) @@ -455,6 +477,7 @@ def cli( l1c_path, output_path_1d, output_path_3d, + split=split, n_processes=n_processes ) diff --git a/gprof_nn/data/preprocessor.py b/gprof_nn/data/preprocessor.py index 4606a98..c39e2f9 100644 --- a/gprof_nn/data/preprocessor.py +++ b/gprof_nn/data/preprocessor.py @@ -726,12 +726,12 @@ def has_preprocessor(): "AMSR2": "gprof2023pp_AMSR2_L1C", "AMSRE": "gprof2021pp_AMSRE_L1C", "ATMS": "gprof2021pp_ATMS_L1C", - ("GMI", "MHS"): "gprof2021pp_GMI_MHS_L1C", + ("GMI", "MHS"): "gprof2023pp_GMI_L1C", ("GMI", "TMIPR"): "gprof2021pp_GMI_TMI_L1C", ("GMI", "TMIPO"): "gprof2021pp_GMI_TMI_L1C", ("GMI", "SSMI"): "gprof2021pp_GMI_SSMI_L1C", ("GMI", "SSMIS"): "gprof2021pp_GMI_SSMIS_L1C", - ("GMI", "AMSR2"): "gprof2021pp_GMI_AMSR2_L1C", + ("GMI", "AMSR2"): "gprof2023pp_GMI_L1C", ("GMI", "AMSRE"): "gprof2021pp_GMI_AMSRE_L1C", ("GMI", "ATMS"): "gprof2021pp_GMI_ATMS_L1C", } diff --git a/gprof_nn/data/retrieval.py b/gprof_nn/data/retrieval.py index 1444ac3..5ee973b 100644 --- a/gprof_nn/data/retrieval.py +++ b/gprof_nn/data/retrieval.py @@ -11,7 +11,6 @@ from pathlib import Path import numpy as np -from quantnn.normalizer import Normalizer import xarray from gprof_nn.definitions import MISSING diff --git a/gprof_nn/data/sim.py b/gprof_nn/data/sim.py index f49e1e8..dfc6003 100644 --- a/gprof_nn/data/sim.py +++ b/gprof_nn/data/sim.py @@ -34,7 +34,7 @@ DATA_SPLIT, LEVELS, N_LAYERS, - PROFILE_NAMES, + PROFILE_TARGETS, SEAICE_YEARS, ) @@ -46,7 +46,9 @@ from gprof_nn.data.utils import ( compressed_pixel_range, N_PIXELS_CENTER, - save_scene + save_scene, + write_training_samples_1d, + write_training_samples_3d ) from gprof_nn.logging import get_console from gprof_nn.utils import CONUS @@ -72,16 +74,6 @@ ############################################################################### -CHANNEL_INDICES = { - "TMIPO": [0, 1, 2, 3, 4, 6, 7, 8, 9], - "TMIPR": [0, 1, 2, 3, 4, 6, 7, 8, 9], - "SSMI": [2, 3, 4, 6, 7, 8, 9], - "SSMIS": [2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14], - "AMSR2": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - "AMSRE": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -} - - class SimFile: """ Interface class to read GPROF .sim files. @@ -198,7 +190,10 @@ def match_targets(self, input_data, targets=None): n_angles = 0 if self.sensor.n_angles > 1: n_angles = self.sensor.n_angles + n_chans = self.sensor.n_chans + if isinstance(self.sensor, sensors.ConicalScanner): + n_chans = 15 if "tbs_simulated" in self.data.dtype.fields: if n_angles > 0: @@ -216,10 +211,6 @@ def match_targets(self, input_data, targets=None): indices = np.clip(indices, 0, matched.shape[0] - 1) tbs = self.data["tbs_simulated"] - if self.sensor.sensor_name in CHANNEL_INDICES: - ch_inds = CHANNEL_INDICES[self.sensor.sensor_name] - tbs = tbs[..., ch_inds] - # tbs = tbs.reshape((-1,) + shape[2:]) matched[indices, ...] = tbs matched[indices, ...][dists > 10e3] = np.nan @@ -242,9 +233,6 @@ def match_targets(self, input_data, targets=None): matched[:] = np.nan biases = self.data["tbs_bias"] - if self.sensor.sensor_name in CHANNEL_INDICES: - ch_inds = CHANNEL_INDICES[self.sensor.sensor_name] - biases = biases[..., ch_inds] matched[indices, ...] 
= biases matched[indices, ...][dists > 10e3] = np.nan @@ -261,7 +249,7 @@ def match_targets(self, input_data, targets=None): # Extract matching data for target in targets: - if target in PROFILE_NAMES: + if target in PROFILE_TARGETS: n = n_scans * w_c shape = (n_scans, w_c, 28) full_shape = (n_scans, n_pixels, 28) @@ -286,7 +274,7 @@ def match_targets(self, input_data, targets=None): matched_full[:] = np.nan matched_full[:, cmpr] = matched - if target in PROFILE_NAMES: + if target in PROFILE_TARGETS: data = matched_full[:, cmpr] input_data[target] = ( ("scans", "pixels_center", "levels"), @@ -326,23 +314,15 @@ def to_xarray_dataset(self): for key, _, *shape in record_type.descr: data = self.data[key] - if key in [ - "emissivity", - "tbs_observed", - "tbs_simulated", - "tbs_bias", - "d_tbs", - ]: - if self.sensor.sensor_name in CHANNEL_INDICES: - ch_inds = CHANNEL_INDICES[self.sensor.sensor_name] - data = data[..., ch_inds] - dims = ("samples",) if len(data.shape) > 1: dims = dims + tuple([dim_dict[s] for s in data.shape[1:]]) results[key] = dims, data + if isinstance(self.sensor, sensors.CrossTrackScanner): + results["angles"] = (("angles",), self.header["viewing_angles"][0]) + dataset = xr.Dataset(results) year = dataset["scan_time"].data["year"] - 1970 @@ -573,12 +553,7 @@ def collocate_targets( # and therefore renamed. sensor = sim_file.sensor if sensor != sensors.GMI: - data_pp = data_pp.rename( - { - "channels": "channels_gmi", - "brightness_temperatures": "brightness_temperatures_gmi", - } - ) + data_pp = data_pp.rename({"channels": "channels_gmi"}) # Match targets from sim file to preprocessor data. LOGGER.debug("Matching retrieval targets for file %s.", sim_filename) @@ -615,89 +590,10 @@ def collocate_targets( if subset is not None: subset.mask_surface_precip(data_pp) - return data_pp + data_pp.attrs["sensor"] = sim_file.sensor.name + data_pp.attrs["source"] = "sim" - -def write_training_samples_1d( - dataset: xr.Dataset, - output_path: Path, -) -> None: - """ - Write training data in GPROF-NN 1D format. - - Args: - dataset: An 'xarray.Dataset' containing collocated input - observations and reference data. - output_path: Path to which the training data will be written. - """ - subset = {} - dataset = dataset[{"pixels": slice(*compressed_pixel_range())}] - mask = np.isfinite(dataset.surface_precip.data) - - for var in dataset.variables: - arr = dataset[var] - if arr.data.ndim < 2: - arr_data = np.broadcast_to(arr.data[..., None], mask.shape) - else: - arr_data = arr.data - - subset[var] = ((("samples",) + arr.dims[2:]), arr_data[mask]) - - subset = xr.Dataset(subset) - start_time = pd.to_datetime(dataset.scan_time.data[0].item()) - start_time = start_time.strftime("%Y%m%d%H%M%S") - end_time = pd.to_datetime(dataset.scan_time.data[-1].item()) - end_time = end_time.strftime("%Y%m%d%H%M%S") - filename = f"sim_{start_time}_{end_time}.nc" - - save_scene(subset, output_path / filename) - - -def write_training_samples_3d( - dataset, - output_path, - min_valid=20 -): - """ - Write training data in GPROF-NN 3D format. - - Args: - dataset: - output_path: - min_valid: The minimum number of valid surface precipitation - pixels for a scene to be stored. 
- - """ - mask = np.any(np.isfinite(dataset.surface_precip.data), 1) - valid_scans = np.where(mask)[0] - n_scans = dataset.scans.size - - encodings = { - name: {"zlib": True} for name in dataset.variables - } - - while len(valid_scans) > 0: - - ind = np.random.randint(0, len(valid_scans)) - scan_start = min(max(valid_scans[ind] - 110, 0), n_scans - 221) - scan_end = scan_start + 221 - - scene = dataset[{"scans": slice(scan_start, scan_end)}] - start_time = pd.to_datetime(scene.scan_time.data[0].item()) - start_time = start_time.strftime("%Y%m%d%H%M%S") - end_time = pd.to_datetime(scene.scan_time.data[-1].item()) - end_time = end_time.strftime("%Y%m%d%H%M%S") - filename = f"sim_{start_time}_{end_time}.nc" - - valid_pixels = (scene.surface_precip.data >= 0.0).sum() - if valid_pixels > min_valid: - #scene.to_netcdf(output_path / filename, encoding=encodings) - save_scene(scene, output_path / filename) - - within_scene = (valid_scans >= scan_start) * (valid_scans < scan_end) - if within_scene.sum() == 0: - break - valid_scans = valid_scans[~within_scene] + return data_pp def process_sim_file( @@ -724,9 +620,9 @@ def process_sim_file( """ data = collocate_targets(sim_file, sensor, era5_path) if output_path_1d is not None: - write_training_samples_1d(data, output_path_1d) + write_training_samples_1d(output_path_1d, "sim", data) if output_path_3d is not None: - write_training_samples_3d(data, output_path_3d) + write_training_samples_3d(output_path_3d, "sim", data, n_scans=221, n_pixels=221) def process_files( @@ -762,7 +658,14 @@ def process_files( files = [] for path in sim_files: date = path.stem.split(".")[-2] - date = datetime.strptime(date, "%Y%m%d") + try: + date = datetime.strptime(date, "%Y%m%d") + except ValueError: + LOGGER.warning( + "Ignoring file not matching expected sim file name patter: %s", + path + ) + continue if start_time is not None and date < start_time: continue if end_time is not None and date > end_time: @@ -912,375 +815,3 @@ def cli(sensor: Sensor, -def add_brightness_temperatures(data, sensor): - """ - Add brightness temperatures variables to dataset. - - Simulated observations from *.sim files for sensors other than GMI lack - the 'brightness_temperature' variable. This function adds these as empty - variables to enable merging with MRMS- and L1C-derived datasets. - - Args: - data: 'xarray.Dataset' containing the matched data from the *.sim file. - sensor: Sensor object representing the sensor for which the data is - extracted. - - Return: - The 'xarray.Dataset' with the added 'brighness_temperatures' variable. - """ - if "brightness_temperatures" in data.variables.keys(): - return data - n_samples = data.samples.size - n_scans = data.scans.size - n_pixels = data.pixels.size - - n_channels = sensor.n_chans - shape = (n_samples, n_scans, n_pixels, n_channels) - bts = np.zeros(shape, dtype=np.float32) - bts[:] = np.nan - data["brightness_temperatures"] = (("samples", "scans", "pixels", "channels"), bts) - return data - - -############################################################################### -# File processor -############################################################################### - - -def get_l1c_files_for_seaice(sensor, day): - """ - Finds sensors L1C files that should be used to extract - ERA5 collocations. - - The function first checks whether there is a specific SEAICE year - defined for the given sensor in ``gprof_nn.definitions``. If that - is not the case it will look for L1C files for the current database - period. 
- - If the above doesn't produce any L1C files, then GMI collocations - with ERA5 are used. - - Args: - sensor: Sensor for which the data is to be extracted. - - Return: - List of L1C filenames to process. - """ - # Collect L1C files to process. - l1c_file_path = sensor.l1c_file_path - l1c_files = [] - - # Get L1C for specific year ... - if sensor.name in SEAICE_YEARS: - year = SEAICE_YEARS[sensor.name] - for month in range(1, 13): - try: - date = datetime(year, month, day) - l1c_files += list(L1CFile.find_files( - date, l1c_file_path, sensor=sensor - )) - except ValueError: - pass - else: - for year, month in DATABASE_MONTHS: - try: - date = datetime(year, month, day) - l1c_files += list(L1CFile.find_files( - date, l1c_file_path, sensor=sensor - )) - except ValueError: - pass - - # If no L1C files are found use GMI co-locations. - if len(l1c_files) < 1: - for year, month in DATABASE_MONTHS: - try: - date = datetime(year, month, day) - l1c_file_path = sensors.GMI.l1c_file_path - l1c_files += list(L1CFile.find_files( - date, l1c_file_path, sensor=sensors.GMI - )) - except ValueError: - pass - l1c_files = [f.filename for f in l1c_files] - l1c_files = np.random.permutation(l1c_files) - return l1c_files - - -@dataclass -class SubsetConfig: - tcwv_bounds: Optional[Tuple[float, float]] = None - t2m_bounds: Optional[Tuple[float, float]] = None - ocean_only: bool = False - land_only: bool = False - surface_types: Optional[Tuple[float, float]] = None - - - def mask_surface_precip(self, dataset): - """ - Sets surface precip in given dataset to nan for samples - outside certain ancillary data bounds. - - Args: - dataset: An xarray.Dataset containing a 'surface_precip' field - and GPROF anciallary data. - """ - surface_precip = dataset.surface_precip.data - - if self.tcwv_bounds is not None: - tcwv_min, tcwv_max = self.tcwv_bounds - tcwv = dataset.total_column_water_vapor.data - valid = (tcwv >= tcwv_min) * (tcwv <= tcwv_max) - surface_precip[~valid] = np.nan - - if self.t2m_bounds is not None: - t2m_min, t2m_max = self.t2m_bounds - t2m = dataset.two_meter_temperature.data - valid = (t2m >= t2m_min) * (t2m <= t2m_max) - surface_precip[~valid] = np.nan - - if self.ocean_only: - ocean_frac = dataset.ocean_fraction - valid = ocean_frac == 100 - surface_precip[~valid] = np.nan - - if self.land_only: - land_frac = dataset.land_fraction - valid = land_frac == 100 - surface_precip[~valid] = np.nan - - if self.surface_types is not None: - valid = np.zeros_like(surface_precip, dtype=bool) - for surface_type in self.surface_types: - valid += dataset.surface_type.data == surface_type - surface_precip[~valid] = np.nan - - -class SimFileProcessor: - """ - Processor class that manages the extraction of GPROF training data. A - single processor instance processes all *.sim, MRMRS matchup and L1C - files for a given day from each month of the database period. - """ - - def __init__( - self, - output_file, - sensor, - configuration, - era5_path=None, - n_workers=4, - day=None, - subset=None - ): - """ - Create retrieval driver. - - Args: - output_file: The file in which to store the extracted data. - sensor: Sensor object defining the sensor for which to extract - training data. - era5_path: Path to the root of the directory tree containing - ERA5 data. - n_workers: The number of worker processes to use. - day: Day of the month for which to extract the data. - subset: A SubsetConfig object specifying a subset of the - database to extract. 
- """ - self.output_file = output_file - self.sensor = sensor - self.configuration = configuration - - self.era5_path = era5_path - if self.era5_path is not None: - self.era5_path = Path(self.era5_path) - - self.pool = futures.ProcessPoolExecutor(max_workers=n_workers) - - if day is None: - self.day = 1 - else: - self.day = day - - if subset is None: - subset = SubsetConfig() - self.subset = subset - - def run(self): - """ - Start the processing. - - This will start processing all suitable input files that have been found and - stores the names of the produced result files in the ``processed`` attribute - of the driver. - """ - # Collect simulator files to process. - sim_file_path = self.sensor.sim_file_path - if self.sensor.sim_file_path is not None: - sim_files = SimFile.find_files( - sim_file_path, - sensor=self.sensor, - day=self.day - ) - sim_files = np.random.permutation(sim_files) - else: - sim_files = [] - - # Collect MRMS files to process. - if self.sensor.mrms_file_path is not None: - mrms_file_path = self.sensor.mrms_file_path - if mrms_file_path is None: - mrms_files = MRMSMatchFile.find_files( - sensors.GMI.mrms_file_path, sensor=sensors.GMI - ) - else: - if hasattr(self.sensor, "mrms_sensor"): - mrms_sensor = self.sensor.mrms_sensor - else: - mrms_sensor = self.sensor - mrms_files = MRMSMatchFile.find_files( - mrms_file_path, - sensor=mrms_sensor - ) - mrms_files = np.random.permutation(mrms_files) - else: - mrms_files = [] - - # Collect L1C files to process. - l1c_file_path = self.sensor.l1c_file_path - if self.era5_path is not None: - l1c_files = get_l1c_files_for_seaice(self.sensor, self.day)[:100] - else: - l1c_files = [] - - n_sim_files = len(sim_files) - LOGGER.info("Found %s SIM files.", n_sim_files) - n_mrms_files = len(mrms_files) - LOGGER.info("Found %s MRMS files.", n_mrms_files) - n_l1c_files = len(l1c_files) - LOGGER.info("Found %s L1C files.", n_l1c_files) - i = 0 - - # Submit tasks interleaving .sim and MRMS files. - log_queue = gprof_nn.logging.get_log_queue() - tasks = [] - files = [] - while i < max(n_sim_files, n_mrms_files, n_l1c_files): - if i < n_sim_files: - sim_file = sim_files[i] - files.append(sim_file) - tasks.append( - self.pool.submit( - process_sim_file, - sim_file, - self.sensor, - self.configuration, - self.era5_path, - self.subset, - log_queue=log_queue, - ) - ) - if i < n_mrms_files: - mrms_file = mrms_files[i] - files.append(mrms_file) - if hasattr(self.sensor, "mrms_sensor"): - sensor = self.sensor.mrms_sensor - else: - sensor = self.sensor - tasks.append( - self.pool.submit( - process_mrms_file, - sensor, - mrms_file, - self.configuration, - self.day, - log_queue=log_queue, - ) - ) - if i < n_l1c_files: - l1c_file = l1c_files[i] - files.append(l1c_file) - tasks.append( - self.pool.submit( - process_l1c_file, - l1c_file, - self.sensor, - self.configuration, - self.era5_path, - log_queue=log_queue, - ) - ) - i += 1 - - datasets = [] - output_path = Path(self.output_file).parent - output_file = Path(self.output_file).stem - - # Retrieve extracted observations and concatenate into - # single dataset. - - n_tasks = len(tasks) - n_chunks = 4 - chunk = 1 - - with Progress(console=get_console()) as progress: - pbar = progress.add_task("Extracting data:", total=len(tasks)) - for task, filename in zip(tasks, files): - # Log messages from processes. 
- task_done = False - dataset = None - while not task_done: - try: - gprof_nn.logging.log_messages() - dataset = task.result() - task_done = True - except futures.TimeoutError: - pass - except Exception as exc: - LOGGER.error( - "The following error was encountered while " - "processing file %s results: %s", - str(filename), - exc, - ) - get_console().print_exception() - task_done = True - progress.advance(pbar) - - if dataset is not None: - dataset = add_brightness_temperatures(dataset, self.sensor) - datasets.append(dataset) - if len(datasets) > n_tasks // n_chunks: - dataset = xr.concat(datasets, "samples") - filename = output_path / (output_file + f"_{chunk:02}.nc") - dataset.attrs["sensor"] = self.sensor.name - - encodings = {} - for var in dataset: - encodings[var] = {"zlib": True} - if dataset[var].dtype == np.float64: - encodings[var]["dtype"] = "float32" - dataset.to_netcdf(filename, encoding=encodings) - # subprocess.run(["lz4", "-f", "--rm", filename], check=True) - LOGGER.info("Finished writing file: %s", filename) - datasets = [] - chunk += 1 - - if len(datasets) > 0: - # Store dataset with sensor name as attribute. - dataset = xr.concat(datasets, "samples") - filename = output_path / (output_file + f"_{chunk:02}.nc") - dataset.attrs["sensor"] = self.sensor.name - dataset.attrs["configuration"] = self.configuration - LOGGER.info("Writing file: %s", filename) - - encodings = {} - for var in dataset: - encodings[var] = {"zlib": True} - if dataset[var].dtype == np.float64: - encodings[var]["dtype"] = "float32" - dataset.to_netcdf(filename, encoding=encodings) - - # Explicit clean up to avoid memory leak. - del datasets - del dataset diff --git a/gprof_nn/data/training_data.py b/gprof_nn/data/training_data.py index 5a6978b..1db7a06 100644 --- a/gprof_nn/data/training_data.py +++ b/gprof_nn/data/training_data.py @@ -7,34 +7,46 @@ the training data for the GPROF-NN retrievals. """ import io +import itertools import math import logging import os from pathlib import Path import subprocess from tempfile import TemporaryDirectory +from typing import Dict, List, Optional, Tuple import numpy as np from scipy.ndimage import rotate import torch -from torch.utils.data import Dataset +from torch.utils.data import Dataset, IterableDataset import xarray as xr from quantnn.normalizer import MinMaxNormalizer from gprof_nn import sensors -from gprof_nn.data.utils import (apply_limits, - compressed_pixel_range, - load_variable, - decompress_scene, - remap_scene, - upsample_scans) +from gprof_nn.utils import ( + calculate_interpolation_weights, + interpolate +) +from gprof_nn.data.utils import ( + apply_limits, + compressed_pixel_range, + load_variable, + decompress_scene, + remap_scene, + upsample_scans +) from gprof_nn.utils import expand_tbs -from gprof_nn.definitions import (MASKED_OUTPUT, - LAT_BINS, - TIME_BINS, - LIMITS, - ALL_TARGETS) +from gprof_nn.definitions import ( + ANCILLARY_VARIABLES, + MASKED_OUTPUT, + LAT_BINS, + TIME_BINS, + LIMITS, + ALL_TARGETS, + PROFILE_TARGETS +) from gprof_nn.data.preprocessor import PreprocessorFile from gprof_nn.augmentation import (get_transformation_coordinates, extract_domain) @@ -70,6 +82,11 @@ } +EIA_GMI = np.array([ + [52.98] * 10 + [49.16] * 5 +]) + + def calculate_resampling_indices(latitudes, time, sensor): """ Calculate scene indices based on latitude and local times. 
@@ -394,1265 +411,942 @@ def write_preprocessor_file(input_data, output_file): PreprocessorFile.write(output_file, new_data, sensor, template=template) -############################################################################### -# GPROF-NN 1D -############################################################################### - -class Dataset1DBase: +def load_tbs_1d_gmi( + training_data: xr.Dataset, +) -> torch.Tensor: """ - Base class for batched datasets providing generic implementations of batch - access and shuffling. + Load brightness temperatures for GMI training data. + + The training data for GMI contains the actual L1C observations and + thus doesn't need any additional modifications. + + Args: + training_data: The xarray.Dataset containing the training data. + + Return: + A torch tensor containing the loaded brightness temperatures. """ + tbs = training_data["brightness_temperatures"].data + return torch.tensor(tbs) - def __init__(self): - seed = int.from_bytes(os.urandom(4), "big") + os.getpid() - self._rng = np.random.default_rng(seed) - self.indices = None - def _shuffle(self): - if self.indices is None: - self.indices = np.arange(self.x.shape[0]) - if not self._shuffled: - self.indices = self._rng.permutation(self.indices) +def load_tbs_1d_xtrack_sim( + training_data: xr.Dataset, + angles: np.ndarray, + sensor: sensors.Sensor +) -> torch.Tensor: + """ + Load brightness temperatures for cross-track scanning sensors from simulator + collocations. - def __getitem__(self, i): - """ - Return element from the dataset. This is part of the - pytorch interface for datasets. + Args: + training_data: An xarray.Dataset containing training data extracted from + GPROF simulator files. + angles: A np.ndarray cotaining the viewing angle of the tbs to load. + sensor: The sensor from which the TBs are loaded - Args: - i(int): The index of the sample to return - """ - if i >= len(self): - LOGGER.info("Finished iterating through dataset %s.", self.filename.name) - raise IndexError() - if i == 0: - if self.shuffle: - self._shuffle() - if self.transform_zeros: - self._transform_zeros() - - if self.indices is None: - self.indices = np.arange(self.x.shape[0]) - - self._shuffled = False - if self.batch_size is None: - if isinstance(self.y, dict): - return ( - torch.tensor(self.x[[i], :]), - {k: torch.tensor(self.y[k][[i]]) for k in self.y}, - ) + Return: + A torch tensor containing the loaded brightness temperatures. + + """ + samples = np.arange(training_data.samples.size) + samples = xr.DataArray(samples, dims="samples") + angles = xr.DataArray(np.abs(angles), dims="samples") - i_start = self.batch_size * i - i_end = self.batch_size * (i + 1) - indices = self.indices[i_start:i_end] + training_data = training_data[ + ["simulated_brightness_temperatures", "brightness_temperature_biases"] + ] + training_data = training_data.interp(samples=samples, angles=angles) + tbs = training_data.simulated_brightness_temperatures.data - x = torch.tensor(self.x[indices, :]) - if isinstance(self.y, dict): - y = {k: torch.tensor(self.y[k][indices]) for k in self.y} - else: - y = torch.tensor(self.y[indices]) + tbs_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32) + tbs_full[:, sensor.gmi_channels] = tbs - return x, y + biases = training_data.brightness_temperature_biases.data + biases_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32) + biases_full[:, sensor.gmi_channels] = biases - def __len__(self): - """ - The number of samples in the dataset. 
- """ - if self.batch_size: - n = self.x.shape[0] // self.batch_size - if (self.x.shape[0] % self.batch_size) > 0: - n = n + 1 - return n - else: - return self.x.shape[0] + biases = ( + biases_full / + np.cos(np.deg2rad(EIA_GMI))[None] * + np.cos(np.deg2rad(angles.data[..., None])) + ) + return torch.tensor(tbs_full - biases) -class GPROF_NN_1D_Dataset(Dataset1DBase): - """ - Dataset class providing an interface for the single-pixel GPROF-NN 1D - retrieval algorithm. - - Attributes: - x: Rank-2 tensor containing the input data with - samples along first dimension. - y: The target values - filename: The filename from which the data is loaded. - targets: List of names of target variables. - batch_size: The size of data batches returned by __getitem__ method. - normalizer: The normalizer used to normalize the data. - shuffle: Whether or not the ordering of the data is shuffled. - augment: Whether or not high-frequency observations are randomly set to - missing to simulate observations at the edge of the swath. + +def load_tbs_1d_conical_sim( + training_data: xr.Dataset, + sensor: sensors.Sensor +) -> torch.Tensor: """ + Load brightness temperatures for cross-track scanning sensors from simulator + collocations. - def __init__( - self, - filename, - targets=None, - normalize=True, - normalizer=None, - transform_zeros=True, - batch_size=512, - shuffle=True, - augment=True, - sensor=None, - permute=None, - ): - """ - Create GPROF 1D dataset. + Args: + training_data: An xarray.Dataset containing training data extracted from + GPROF simulator files. + angles: A np.ndarray cotaining the viewing angle of the tbs to load. + sensor: The sensor from which the TBs are loaded - Args: - filename: Path to the NetCDF file containing the training data to - load. - targets: String or list of strings specifying the names of the - variables to use as retrieval targets. - normalize: Whether or not to normalize the input data. - normalizer: Normalizer object or class to use to normalize the - input data. If normalizer is a class object this object will - be initialized with the training input data. If 'None' a - ``quantnn.normalizer.MinMaxNormalizer`` will be used and - initialized with the loaded data. - transform_zeros: Whether or not to replace very small values with - random values. - batch_size: Number of samples in each training batch. - shuffle: Whether or not to shuffle the training data. - augment: Whether or not to randomly mask high-frequency channels - and to randomly permute ancillary data. - sensor: Sensor object corresponding to the training data. Only - necessary if the sensor cannot be inferred from the - corresponding sensor attribute of the dataset file. - permute: If not ``None`` the input feature corresponding to the - given index will be permuted in order to break correlation - between input and output. - """ - super().__init__() - self.filename = Path(filename) - self.dataset = decompress_and_load(self.filename) + Return: + A torch tensor containing the loaded brightness temperatures. 
- if targets is None: - targets = ["surface_precip"] - self.targets = targets - self.transform_zeros = transform_zeros - self.batch_size = batch_size - self.shuffle = shuffle - self.augment = augment + """ + training_data = training_data[ + [ + "simulated_brightness_temperatures", + "brightness_temperature_biases", + ] + ] + tbs = training_data.simulated_brightness_temperatures.data + biases = training_data.brightness_temperature_biases.data - # Determine sensor from dataset and compare to provided sensor - # argument. - # The following options are possible: - # - The 'sensor' argument is None so data is loaded following - # conventions of the generic sensor instance. - # - The 'sensor' argument is provided but corresponds to the same - # sensor class. In this case simply use the provided sensor object - # to load the data. - # - The 'senor' argument is provided but corresonds to a different - # sensor class. In this case we are dealing with pre-training using - # gmi data. - if "sensor" not in self.dataset.attrs: - raise Exception(f"Provided dataset lacks 'sensor' attribute.") - sensor_name = self.dataset.attrs["sensor"] - dataset_sensor = sensors.get_sensor(sensor_name) - - if sensor is None: - self.sensor = dataset_sensor - else: - if sensor_name == "GMI" and sensor != dataset_sensor: - self.sensor = dataset_sensor - else: - self.sensor = sensor - - kwargs = {} - if self.sensor.latitude_ratios is not None: - latitudes = self.dataset.latitude.mean(("scans", "pixels")).data - longitudes = self.dataset.longitude.mean(("scans", "pixels")).data - if "pixels" in self.dataset.scan_time.dims: - scan_time = self.dataset.scan_time.mean(("scans", "pixels")) - else: - scan_time = self.dataset.scan_time.mean(("scans",)) - local_time = ( - scan_time + (longitudes / 360 * 24 * 60 * 60).astype("timedelta64[s]") - ) - minutes = local_time.dt.hour.data * 60 + local_time.dt.minute.data - indices = calculate_resampling_indices(latitudes, minutes, self.sensor) - kwargs["indices"] = indices + tbs = tbs - biases + return torch.tensor(tbs) - x, y = self.sensor.load_training_data_1d( - self.dataset, self.targets, self.augment, self._rng, **kwargs - ) - # If this is pre-training, we need to extract the correct indices. - # For conical scanners we also replace the viewing angle feature - # with random values. 
- if sensor is not None and sensor != self.sensor: - LOGGER.info("Extracting channels %s for pre-training.", sensor.gmi_channels) - indices = list(sensor.gmi_channels) + list(range(15, 15 + 24)) - if isinstance(sensor, sensors.CrossTrackScanner): - indices.insert(sensor.n_chans, 0) - x = x[:, indices] - if isinstance(sensor, sensors.ConicalScanner): - shape = x[:, sensor.n_chans].shape - x[:, sensor.n_chans] = self._rng.uniform(-1, 1, size=shape) - - self.x = x - self.y = y - LOGGER.info("Loaded %s samples from %s", self.x.shape[0], self.filename.name) - - if normalizer is None: - self.normalizer = MinMaxNormalizer(self.x) - elif isinstance(normalizer, type): - self.normalizer = normalizer(self.x) - else: - self.normalizer = normalizer - - self.normalize = normalize - if normalize: - self.x = self.normalizer(self.x) - - if transform_zeros: - self._transform_zeros() - - if permute is not None: - n_features = self.sensor.n_chans + 2 - if isinstance(self.sensor, sensors.CrossTrackScanner): - n_features += 1 - if permute < n_features: - self.x[:, permute] = self._rng.permutation(self.x[:, permute]) - elif permute == n_features: - self.x[:, -24:-4] = self._rng.permutation(self.x[:, -24:-4]) - else: - self.x[:, -4:] = self._rng.permutation(self.x[:, -4:]) +def load_tbs_1d_xtrack_other( + training_data: xr.Dataset, + sensor: sensors.Sensor +) -> torch.Tensor: + """ + Load brightness temperatures for cross-track scanning sensors from collocations + with real observations, i.e., MRMS or ERA5 collocations. - self.x = self.x.astype(np.float32) - if isinstance(self.y, dict): - self.y = {k: self.y[k].astype(np.float32) for k in self.y} - else: - self.y = self.y.astype(np.float32) + Args: + training_data: An xarray.Dataset containing training data extracted from + GPROF simulator files. + sensor: The sensor from which the TBs are loaded - self._shuffled = False - if self.shuffle: - self._shuffle() + Return: + A tuple ``(tbs, angs)`` containing the brightness temperatures ``tbs`` + and corresponding earth incidence angles ``angs``. + """ + tbs = training_data["brightness_temperatures"].data + tbs_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32) + tbs_full[:, sensor.gmi_channels] = tbs + angles = training_data["earth_incidence_angle"].data + angles_full = np.broadcast_to(angles[..., None], tbs_full.shape) - def __repr__(self): - return f"GPROF_NN_1D_Dataset({self.filename.name}, n_batches={len(self)})" + tbs = torch.tensor(tbs_full.astype("float32")) + angles = torch.tensor(angles_full.astype("float32")) + return tbs, angles - def __str__(self): - return f"GPROF_NN_1D_Dataset({self.filename.name}, n_batches={len(self)})" - def _transform_zeros(self): - """ - Transforms target values that are zero to small, non-zero values. - """ - if isinstance(self.y, dict): - y = self.y - else: - y = {self.targets: self.y} - for k, y_k in y.items(): - if k not in _THRESHOLDS: - continue - threshold = _THRESHOLDS[k] - indices = (y_k <= threshold) * (y_k >= -threshold) - if indices.sum() > 0: - t_l = np.log10(threshold) - y_k[indices] = 10 ** self._rng.uniform(t_l - 4, t_l, indices.sum()) +def load_tbs_1d_conical_other( + training_data: xr.Dataset, + sensor: sensors.Sensor +) -> torch.Tensor: + """ + Load brightness temperatures for non-GMI conical scanner from collocations + with real observations, i.e., MRMS or ERA5 collocations. - def _load_data(self): - """ - Loads the data from the file into the ``x`` and ``y`` attributes. 
- """ + Args: + training_data: An xarray.Dataset containing training data extracted from + GPROF simulator files. + sensor: The sensor from which the TBs are loaded - def to_xarray_dataset(self, mask=None, batch=None): - """ - Convert training data to xarray dataset. + Return: + A tuple ``(tbs, angs)`` containing the brightness temperatures ``tbs`` + and corresponding earth incidence angles ``angs``. + """ + tbs = training_data["brightness_temperatures"].data + tbs_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32) + tbs_full[:, sensor.gmi_channels] = tbs + angles = training_data["earth_incidence_angle"].data + angles_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32) + angles_full[:, sensor.gmi_channels] = angles - Args: - mask: A mask to select samples to include in the dataset. + tbs = torch.tensor(tbs_full.astype("float32")) + angles = torch.tensor(angles_full.astype("float32")) + return tbs, angles - Return: - An 'xarray.Dataset' containing the training data but converted - back to the original format. - """ - if batch is None: - x = self.x - y = self.y - else: - x, y = batch - if isinstance(x, torch.Tensor): - x = x.numpy() - y = {k: t.numpy() for k, t in y.items()} - if mask is None: - mask = slice(0, None) - if self.normalize: - x = self.normalizer.invert(x[mask]) - else: - x = x[mask] - sensor = self.sensor +def load_ancillary_data_1d(training_data: xr.Dataset,) -> torch.Tensor: + """ + Load brightness temperatures for GMI training data. - n_samples = x.shape[0] - n_layers = 28 + Args: + training_data: The xarray.Dataset containing the training data. - tbs = x[:, : sensor.n_chans] - if sensor.n_angles > 1: - eia = x[:, sensor.n_chans] - anc_start = sensor.n_chans + 1 - else: - eia = None - anc_start = sensor.n_chans + Return: + A torch tensor containign the ancillary data concatenated along the + last dimension. + """ + data = [] + for var in ANCILLARY_VARIABLES: + data.append(training_data[var].data) + data = np.stack(data, -1) + return torch.tensor(data) - dims = ("samples", "channels") - new_dataset = { - "brightness_temperatures": (dims, tbs), - } - vars = [ - "two_meter_temperature", - "total_column_water_vapor", - "ocean_fraction", - "land_fraction", - "ice_fraction", - "snow_depth", - "leaf_area_index", - "orographic_wind", - "moisture_convergence" - ] - for offset, name in enumerate(vars): - new_dataset[name] = (dims[:-1], x[..., anc_start + offset]) +def load_targets_1d( + training_data: xr.Dataset, + targets: List[str] +) -> Dict[str, torch.Tensor]: + """ + Load retrieval target tensors from training data file. - if eia is not None: - new_dataset["earth_incidence_angle"] = (dims[:1], eia) + Args: + training_data: The xarray.Dataset containing the training data. + targets: List of the targets to load. 
+ """ + targs = {} + for var in targets: - dims = ("samples", "layers") - for k, v in y.items(): - n_dims = v.ndim - new_dataset[k] = (dims[:n_dims], v) + if var in training_data: + data_t = training_data[var].data + else: + n_samples = training_data.samples.size + if var in PROFILE_TARGETS: + shape = (n_samples, 1, 28) + else: + shape = (n_samples, 1, 28) + data_t = np.zeros(shape, dtype=np.float32) - new_dataset = xr.Dataset(new_dataset) - new_dataset.attrs = self.dataset.attrs - new_dataset.attrs["sensor"] = self.sensor.name - new_dataset.attrs["platform"] = self.sensor.platform.name - return new_dataset + if data_t.ndim == 1: + data_t = data_t[..., None] + targs[var] = torch.tensor(data_t.astype("float32")) + return targs - def save(self, filename): - """ - Store dataset as NetCDF file. - Args: - filename: The name of the file to which to write the dataset. - """ - new_dataset = self.to_xarray_dataset() - new_dataset.to_netcdf(filename) +def load_targets_1d_xtrack( + training_data: xr.Dataset, + angles: np.ndarray, + targets: List[str] +) -> Dict[str, torch.Tensor]: + """ + Load retrieval target tensors from training data file for x-track scanners. + Since the 'surface_precip' and 'convective_precip' variables are + + Args: + training_data: The xarray.Dataset containing the training data. + targets: List of the targets to load. + """ + samples = np.arange(training_data.samples.size) + samples = xr.DataArray(samples, dims="samples") + angles = xr.DataArray(np.abs(angles), dims="samples") + training_data = training_data[targets] + training_data = training_data.interp(samples=samples, angles=angles) -############################################################################### -# GPROF-NN 3D -############################################################################### + targs = {} + for var in targets: + data_t = training_data[var].data + if data_t.ndim == 1: + data_t = data_t[..., None] + targs[var] = torch.tensor(data_t.astype("float32")) + return targs -class GPROF_NN_3D_Dataset: +class GPROFNN1DDataset(IterableDataset): """ - Base class for GPROF-NN 3D-retrieval training data in which training - samples consist of 3D scenes of input data and corresponding target - fields. - - Objects of this class act as an iterator over batches in the training - data set. + Dataset class for loading the training data for GPROF-NN 1D retrieval. """ + combine_files = 4 def __init__( self, - filename, - targets=None, - batch_size=32, - normalize=True, - normalizer=None, - sensor=None, - transform_zeros=True, - shuffle=True, - augment=True, - input_dimensions=None, + path: Path, + targets: Optional[List[str]] = None, + transform_zeros: bool = True, + augment: bool = True, + validation: bool = False, ): """ + Create GPROF-NN 1D dataset. + + The GPROF-NN 1D data is split up into separate files by orbit. This + dataset loads the training data from all available files. And provides + an iterable over the samples in the dataset. + Args: - filename: Path of the NetCDF file containing the training data. - sensor: The sensor object to use to load the data. - targets: List of the targets to load from the data. - batch_size: The size of batches in the training data. - normalize: Whether or not to noramlize the input data. - normalizer: Normalizer object to use to normalize the input - data. May alternatively be a normalizer class that will - be used to instantiate a new normalizer object with the loaded - input data. 
If 'None', a new ``quantnn.normalizer.MinMaxNormalizer`` - will be created with the loaded input data. - sensor: Explicit sensor object that can be passed to override the - generic sensor information contained in the training data - file. - transform_zeros: Whether or not to transform target values that are - zero to small random values. - shuffle: Whether or not to shuffle the data. - augment: Whether or not to augment the training data. - input_dimensions: Tuple ``(width, height)`` specifying the width - and height of the input data. If ``None`` default setting for - each sensor are used. + path: The path containing the training data files. + targets: A list of the target variables to load. + transform_zeros: Whether or not to replace zeros in the output + with small random values. + augment: Whether or not to apply data augmentation to the loaded + data. + validation: If set to 'True', data loaded in consecutive iterations + over the dataset will be identical. """ - self.filename = Path(filename) - self.dataset = decompress_and_load(self.filename) + super().__init__() + if targets is None: - self.targets = ["surface_precip"] - else: - self.targets = targets + targets = ALL_TARGETS + self.targets = targets self.transform_zeros = transform_zeros - self.batch_size = batch_size - self.shuffle = shuffle + self.validation = validation self.augment = augment - if augment: - seed = int.from_bytes(os.urandom(4), "big") + os.getpid() - else: - seed = 111 - self._rng = np.random.default_rng(seed) - - if sensor is None: - sensor = self.dataset.attrs["sensor"] - sensor = getattr(sensors, sensor) - self.sensor = sensor - - # Determine sensor from dataset and compare to provided sensor - # argument. - # The following options are possible: - # - The 'sensor' argument is None so data is loaded following - # conventions of the generic sensor instance. - # - The 'sensor' argument is provided but corresponds to the same - # sensor class. In this case simply use the provided sensor object - # to load the data. - # - The 'senor' argument is provided but corresonds to a different - # sensor class. In this case we are dealing with pre-training using - # gmi data. - if "sensor" not in self.dataset.attrs: - raise Exception(f"Provided dataset lacks 'sensor' attribute.") - sensor_name = self.dataset.attrs["sensor"] - dataset_sensor = sensors.get_sensor(sensor_name) - - if sensor is None: - self.sensor = dataset_sensor - else: - if sensor_name == "GMI" and sensor != dataset_sensor: - self.sensor = dataset_sensor - else: - self.sensor = sensor - - if input_dimensions is None: - width, height = _INPUT_DIMENSIONS[self.sensor.name.upper()] - else: - width, height = input_dimensions - - latitudes = self.dataset.latitude.mean(("scans", "pixels")).data - longitudes = self.dataset.longitude.mean(("scans", "pixels")).data - scan_time = self.dataset.scan_time.mean(("scans")) - local_time = ( - scan_time + (longitudes / 360 * 24 * 60 * 60).astype("timedelta64[s]") - ) - minutes = local_time.dt.hour.data * 60 + local_time.dt.minute.data - indices = calculate_resampling_indices(latitudes, minutes, self.sensor) - if indices is None: - kwargs = {} - else: - kwargs = {"indices": indices} - - x, y = self.sensor.load_training_data_3d( - self.dataset, self.targets, augment, self._rng, width=width, height=height, - **kwargs - ) - - # If this is pre-training, we need to extract the correct indices. - # For conical scanners we also replace the viewing angle feature - # with random values. 
- if sensor is not None and sensor != self.sensor: - LOGGER.info("Extracting channels %s for pre-training.", sensor.gmi_channels) - indices = list(sensor.gmi_channels) + list(range(15, 15 + 24)) - if isinstance(sensor, sensors.CrossTrackScanner): - indices.insert(sensor.n_chans, 0) - x = x[:, indices] - if isinstance(sensor, sensors.ConicalScanner): - shape = x[:, sensor.n_chans].shape - x[:, sensor.n_chans] = self._rng.uniform(-1, 1, size=shape) - - self.x = x - self.y = y - LOGGER.info("Loaded %s samples from %s", self.x.shape[0], self.filename.name) - - if normalizer is None: - self.normalizer = MinMaxNormalizer(x) - elif isinstance(normalizer, type): - self.normalizer = normalizer(x) - else: - self.normalizer = normalizer - - self.normalize = normalize - if normalize: - x = self.normalizer(x) - - self.x = x - self.y = y - - if transform_zeros: - self._transform_zeros() - - self.x = self.x.astype(np.float32) - if isinstance(self.y, dict): - self.y = {k: self.y[k].astype(np.float32) for k in self.y} - else: - self.y = self.y.astype(np.float32) + self.path = Path(path) + if not self.path.exists(): + raise RuntimeError( + "The provided path does not exists." + ) - self._shuffled = False - self.indices = np.arange(self.x.shape[0]) - if self.shuffle: - self._shuffle() + files = sorted(list(self.path.glob("*_*_*.nc"))) + if len(files) == 0: + raise RuntimeError( + "Could not find any GPROF-NN 1D training data files " + f"in {self.path}." + ) + self.files = files - def __repr__(self): - return f"GPROF_NN_3D_Dataset({self.filename.name}, n_batches={len(self)})" + self.init_rng() + self.files = self.rng.permutation(self.files) - def __str__(self): - return self.__repr__() - def _transform_zeros(self): - """ - Transforms target values that are zero to small, non-zero values. - """ - if isinstance(self.y, dict): - y = self.y - else: - y = {self.target: self.y} - for k, y_k in y.items(): - if k not in _THRESHOLDS: - continue - threshold = _THRESHOLDS[k] - indices = (y_k <= threshold) * (y_k >= -threshold) - if indices.sum() > 0: - t_l = np.log10(threshold) - y_k[indices] = 10 ** self._rng.uniform(t_l - 4, t_l, indices.sum()) - - def _shuffle(self): - if not self._shuffled and self.shuffle: - LOGGER.info("Shuffling dataset %s.", self.filename.name) - self.indices = self._rng.permutation(self.indices) - self._shuffled = True - - def __getitem__(self, i): + def init_rng(self, w_id=0): """ - Return element from the dataset. This is part of the - pytorch interface for datasets. + Initialize random number generator. Args: - i(int): The index of the sample to return + w_id: The worker ID which of the worker process.. """ - if i >= len(self): - LOGGER.info("Finished iterating through dataset %s.", self.filename.name) - raise IndexError() - if i == 0: - self._shuffle() - if self.transform_zeros: - self._transform_zeros() - - self._shuffled = False - if self.batch_size is None: - if isinstance(self.y, dict): - return ( - torch.tensor(self.x[[i], :]), - {k: torch.tensor(self.y[k][[i]]) for k in self.y}, - ) - - i_start = self.batch_size * i - i_end = self.batch_size * (i + 1) - indices = self.indices[i_start:i_end] - - x = torch.tensor(self.x[indices]) - if isinstance(self.y, dict): - y = {k: torch.tensor(self.y[k][indices]) for k in self.y} + if self.validation: + seed = 42 else: - y = torch.tensor(self.y[indices]) - - return x, y + seed = int.from_bytes(os.urandom(4), "big") + w_id - def __len__(self): - """ - The number of samples in the dataset. 
- """ - if self.batch_size: - n = self.x.shape[0] // self.batch_size - if self.x.shape[0] % self.batch_size > 0: - n = n + 1 - return n - else: - return self.x.shape[0] + self.rng = np.random.default_rng(seed) - def to_xarray_dataset(self, mask=None, batch=None): + def worker_init_fn(self, w_id: int) -> None: """ - Convert training data to xarray dataset. + Initializes the worker state for parallel data loading. Args: - mask: A mask to select samples to include in the dataset. - - Return: - An 'xarray.Dataset' containing the training data but converted - back to the original format. + w_id: The ID of the worker. """ - if batch is None: - x = self.x - y = self.y - else: - x, y = batch - if isinstance(x, torch.Tensor): - x = x.numpy() - y = {k: t.numpy() for k, t in y.items()} - - if mask is None: - mask = slice(0, None) - if self.normalize: - x = self.normalizer.invert(x[mask]) - else: - x = x[mask] - sensor = self.sensor + self.init_rng(w_id) + winfo = torch.utils.data.get_worker_info() + n_workers = winfo.num_workers + + self.files = self.files[w_id::n_workers] + + def load_training_data(self, dataset: xr.Dataset) -> Dict[str, torch.Tensor]: + + sensor = sensors.get_sensor(dataset.attrs["sensor"]) + targets = self.targets + ref_target = targets[0] + + if sensor == sensors.GMI: + tbs = dataset["brightness_temperatures"].data + y_t = dataset[ref_target].data + valid_input = np.any(tbs > 0, -1) + valid_target = np.isfinite(y_t).any(tuple(range(1, y_t.ndim))) + mask = valid_input * valid_target + dataset = dataset[{"samples": mask}] + + tbs = load_tbs_1d_gmi(dataset) + anc = load_ancillary_data_1d(dataset) + targets = load_targets_1d(dataset, self.targets) + angs = torch.tensor(np.broadcast_to(EIA_GMI.astype("float32"), tbs.shape)) + + elif isinstance(sensor, sensors.CrossTrackScanner): + + if dataset.attrs["source"] == "sim": + tbs = dataset["brightness_temperatures"].data + y_t = dataset[ref_target].data + valid_input = np.any(tbs > 0, -1) + valid_target = np.isfinite(y_t).any(tuple(range(1, y_t.ndim))) + mask = valid_input * valid_target + dataset = dataset[{"samples": mask}] + angles = dataset["angles"].data + angs = self.rng.uniform( + angles.min(), + angles.max(), + size=dataset.samples.size, + ).astype(np.float32) + tbs = load_tbs_1d_xtrack_sim(dataset, angs, sensor) + angs = torch.tensor(angs) + else: + tbs = dataset["brightness_temperatures"].data + y_t = dataset[ref_target].data + valid_input = np.any(tbs > 0, -1) + valid_target = np.isfinite(y_t).any(tuple(range(1, y_t.ndim))) + mask = valid_input * valid_target + dataset = dataset[{"samples": mask}] + tbs, angs = load_tbs_1d_xtrack_other(dataset, sensor) - n_samples = x.shape[0] - n_layers = 28 + anc = load_ancillary_data_1d(dataset) + targets = load_targets_1d(dataset, self.targets) - tbs = np.transpose(x[:, : sensor.n_chans], (0, 2, 3, 1)) - if sensor.n_angles > 1: - eia = x[:, sensor.n_chans] - anc_start = sensor.n_chans + 1 - else: - eia = None - anc_start = sensor.n_chans + elif isinstance(sensor, sensors.ConicalScanner): - dims = ("samples", "scans", "pixels", "channels") - new_dataset = { - "brightness_temperatures": (dims, tbs), + if dataset.source == 0: + tbs = load_tbs_1d_conical_sim(dataset) + else: + tbs = load_tbs_1d_conical_other(dataset) + angs = torch.tensor(np.broadcast_to(EIA_GMI.astype("float32"), tbs.shape)) + anc = load_ancillary_data_1d(dataset) + targets = load_targets_1d(dataset) + + x = { + "brightness_temperatures": tbs, + "ancillary_data": anc, + "viewing_angles": angs } - vars = [ - 
"two_meter_temperature", - "total_column_water_vapor", - "ocean_fraction", - "land_fraction", - "ice_fraction", - "snow_depth", - "leaf_area_index", - "orographic_wind", - "moisture_convergence" - ] - for offset, name in enumerate(vars): - new_dataset[name] = (dims[:-1], x[:, anc_start + offset]) - - if eia is not None: - new_dataset["earth_incidence_angle"] = (dims[:-1], eia) - - dims = ("samples", "scans", "pixels", "layers") - for k, v in y.items(): - n_dims = v.ndim - if n_dims > 3: - v = np.transpose(v, (0, 2, 3, 1)) - new_dataset[k] = (dims[:n_dims], v) - - new_dataset = xr.Dataset(new_dataset) - new_dataset.attrs = self.dataset.attrs - new_dataset.attrs["sensor"] = self.sensor.name - new_dataset.attrs["platform"] = self.sensor.platform.name - return new_dataset - - def save(self, filename): - """ - Store dataset as NetCDF file. + return x, targets - Args: - filename: The name of the file to which to write the dataset. - """ - new_dataset = self.to_xarray_dataset() - new_dataset.to_netcdf(filename) + def __repr__(self): + return f"GPROFNN1DDataset(path={self.path}, targets={self.targets})" -class SimulatorDataset(GPROF_NN_3D_Dataset): - """ - Dataset to train a simulator network to predict simulated brightness - temperatures and brightness temperature biases. - """ + def __iter__(self): - def __init__( - self, - filename, - batch_size=32, - normalize=True, - normalizer=None, - shuffle=True, - augment=True, - ): - """ - Args: - filename: Path to the NetCDF file containing the training data. - normalize: Whether or not to normalize the input data. - batch_size: Number of samples in each training batch. - normalizer: The normalizer used to normalize the data. - shuffle: Whether or not to shuffle the training data. - augment: Whether or not to randomly mask high-frequency channels - and to randomly permute ancillary data. - """ - self.filename = Path(filename) - # Load and decompress data but keep only scenes for which - # contain simulated obs. 
- self.dataset = decompress_and_load(self.filename) - self.dataset = self.dataset[{"samples": self.dataset.source == 0}] - - targets = ["simulated_brightness_temperatures", "brightness_temperature_biases"] - self.transform_zeros = False - self.batch_size = batch_size - self.shuffle = shuffle - self.augment = augment + all_files = self.rng.permutation(self.files) + for ind in range(0, len(self.files), self.combine_files): + files = all_files[ind:ind + self.combine_files] - if augment: - seed = int.from_bytes(os.urandom(4), "big") + os.getpid() - else: - seed = 111 - self._rng = np.random.default_rng(seed) - - x, y = self.load_training_data_3d(self.dataset, targets, augment, self._rng) - indices_1h = list(range(17, 39)) - if normalizer is None: - self.normalizer = MinMaxNormalizer(x, exclude_indices=indices_1h) - elif isinstance(normalizer, type): - self.normalizer = normalizer(x, exclude_indices=indices_1h) - else: - self.normalizer = normalizer + inputs = {} + targets = {} - self.normalize = normalize - if normalize: - x = self.normalizer(x) + for path in files: + with xr.open_dataset(path) as input_file: - self.x = x - self.y = {} + inputs_f, targets_f = self.load_training_data(input_file) + for name, tensor in inputs_f.items(): + inputs.setdefault(name, []).append(tensor) + for name, tensor in targets_f.items(): + targets.setdefault(name, []).append(tensor) - sensor = getattr(sensors, self.dataset.attrs["sensor"]) - biases = y["brightness_temperature_biases"] - for i in range(biases.shape[1]): - key = f"brightness_temperature_biases_{i}" - self.y[key] = biases[:, [i]] + inputs = {name: np.concatenate(data) for name, data in inputs.items()} + targets = {name: np.concatenate(data) for name, data in targets.items()} - sims = y["simulated_brightness_temperatures"] - for i in range(biases.shape[1]): - key = f"simulated_brightness_temperatures_{i}" - if isinstance(sensor, sensors.ConicalScanner): - self.y[key] = sims[:, [i]] - else: - self.y[key] = sims[:, :, [i]] + n_samples = inputs["brightness_temperatures"].shape[0] + for ind in self.rng.permutation(n_samples): + yield ( + {name: data[ind] for name, data in inputs.items()}, + {name: data[ind] for name, data in targets.items()} + ) - self.x = self.x.astype(np.float32) - if isinstance(self.y, dict): - self.y = {k: self.y[k].astype(np.float32) for k in self.y} - else: - self.y = self.y.astype(np.float32) - self._shuffled = False - self.indices = np.arange(self.x.shape[0]) - if self.shuffle: - self._shuffle() +def load_training_data_3d_gmi( + scene: xr.Dataset, + targets: List[str], + augment: bool = False, + rng: np.random.Generator = None, +) -> Tuple[Dict[str, torch.Tensor]]: + """ + Load GPROF-NN 3D training scene for GMI. + + Args: + scene: An xarray.Dataset containing the scene from which to load + the training data. + targets: A list containing a list of the targets to load. + augment: Whether or not to augment the input data. + rng: A numpy random number generator to use for the augmentation. - def load_training_data_3d(self, dataset, targets, augment, rng): - """ - Load data for training a simulator data. + Return: + A tuple ``(x, y)`` of dictionaries ``x`` and ``y`` containing the + training input data in ``x`` and the training reference data in ``y``. 
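+
+    Example (an illustrative sketch only; the file name is a placeholder for a
+    GPROF-NN 3D training scene extracted for GMI):
+
+        >>> rng = np.random.default_rng(42)
+        >>> with xr.open_dataset("gmi_training_scene.nc") as scene:
+        ...     x, y = load_training_data_3d_gmi(
+        ...         scene, ["surface_precip"], augment=False, rng=rng
+        ...     )
+        >>> x["brightness_temperatures"].shape   # -> (15 channels, 128 scans, 64 pixels)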
+    """
+    variables = [
+        name for name in targets + ["latitude", "longitude"]
+        if name in scene
+    ]
+    scene = decompress_scene(scene, variables)
+
+    if augment:
+        p_x_o = rng.random()
+        p_x_i = rng.random()
+        p_y = rng.random()
+    else:
+        p_x_o = 0.5
+        p_x_i = 0.5
+        p_y = rng.random()
+
+    lats = scene.latitude.data
+    lons = scene.longitude.data
+    coords = get_transformation_coordinates(
+        lats, lons, sensors.GMI.viewing_geometry, 64, 128, p_x_i, p_x_o, p_y
+    )
+    scene = remap_scene(scene, coords, variables)
+
+    tbs = torch.tensor(scene.brightness_temperatures.data)
+    angs = torch.tensor(np.broadcast_to(EIA_GMI.astype("float32"), tbs.shape))
+    anc = torch.tensor(np.stack(
+        [scene[anc_var].data.astype("float32") for anc_var in ANCILLARY_VARIABLES]
+    ))
+    tbs = torch.permute(tbs, (2, 0, 1))
+    angs = torch.permute(angs, (2, 0, 1))
+
+    x = {
+        "brightness_temperatures": tbs,
+        "viewing_angles": angs,
+        "ancillary_data": anc
+    }
-        This function is a replacement for the ``load_training_data3d``
-        method of the sensor that is called by the other training data
-        objects to load the data. This is required because the data
-        the input for the simulator are always the GMI Tbs.
+    y = {}
+    for target in targets:
+        # MRMS collocations don't contain all targets.
+        if target not in scene:
+            if target in PROFILE_TARGETS:
+                empty = torch.nan * torch.zeros((28, 128, 64))
+            else:
+                empty = torch.nan * torch.zeros((1, 128, 64))
+            y[target] = empty
+            continue
-        Args:
-            dataset: The 'xarray.Dataset' from which to load the training
-                data.
-            targets: List of the targets to load.
-            augment: Whether or not to augment the training data.
-            rng: 'numpy.random.Generator' to use to generate random numbers.
+        data = torch.tensor(scene[target].data.astype("float32"))
+        dims = tuple(range(data.ndim))
+        data = torch.permute(data, dims[-2:] + dims[:-2])
+        y[target] = data
-        Return:
+    return x, y
-            Tuple ``(x, y)`` containing the training input ``x`` and a
-            dictionary of target data ``y``.
-        """
-        sensor = getattr(sensors, dataset.attrs["sensor"])
-        #
-        # Input data
-        #
+def load_training_data_3d_xtrack_sim(
+    sensor: sensors.Sensor,
+    scene: xr.Dataset,
+    targets: List[str],
+    augment: bool = False,
+    rng: np.random.Generator = None,
+) -> Tuple[Dict[str, torch.Tensor]]:
+    """
+    Load GPROF-NN 3D training scene for cross-track scanners from
+    sim-file training data.
-        # Brightness temperatures
-        n = dataset.samples.size
+    Args:
+        sensor: The sensor from which the training data was extracted.
+        scene: An xarray.Dataset containing the scene from which to load
+            the training data.
+        targets: A list of the names of the targets to load.
+        augment: Whether or not to augment the input data.
+        rng: A numpy random number generator to use for the augmentation.
-        x = []
-        y = {}
+    Return:
+        A tuple ``(x, y)`` of dictionaries ``x`` and ``y`` containing the
+        training input data in ``x`` and the training reference data in ``y``.
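+
+    Example (an illustrative sketch only; the file name is a placeholder and
+    MHS is used as an arbitrary cross-track scanner):
+
+        >>> rng = np.random.default_rng()
+        >>> with xr.open_dataset("mhs_sim_scene.nc") as scene:
+        ...     x, y = load_training_data_3d_xtrack_sim(
+        ...         sensors.MHS, scene, ["surface_precip"], augment=True, rng=rng
+        ...     )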
+ """ + required = [ + "latitude", + "longitude", + "simulated_brightness_temperatures", + "brightness_temperature_biases" + ] + variables = [ + name for name in targets + required + if name in scene + ] + scene = decompress_scene(scene, variables) + + if augment: + p_x_o = rng.random() + p_x_i = rng.random() + p_y = rng.random() + else: + p_x_o = 0.5 + p_x_i = 0.5 + p_y = rng.random() + + width = 64 + height = 128 + + lats = scene.latitude.data + lons = scene.longitude.data + coords = get_transformation_coordinates( + lats, lons, sensor.viewing_geometry, width, height, p_x_i, p_x_o, p_y + ) + scene = remap_scene(scene, coords, variables) + + center = sensor.viewing_geometry.get_window_center(p_x_o, width) + j_start = int(center[1, 0, 0] - width // 2) + j_end = int(center[1, 0, 0] + width // 2) + angs = sensor.viewing_geometry.get_earth_incidence_angles() + angs = angs[j_start:j_end] + angs = np.repeat(angs.reshape(1, -1), height, axis=0) + weights = calculate_interpolation_weights(np.abs(angs), sensor.angles) + weights = np.repeat(weights.reshape(1, -1, sensor.n_angles), height, axis=0) + weights = calculate_interpolation_weights(np.abs(angs), scene.angles.data) + + # Calculate brightness temperatures + tbs_sim = scene.simulated_brightness_temperatures.data + tbs_sim = interpolate(tbs_sim, weights) + tb_biases = scene.brightness_temperature_biases.data + tbs = tbs_sim - tb_biases + + full_shape = tbs_sim.shape[:2] + (15,) + tbs_full = np.nan * np.ones(full_shape, dtype="float32") + tbs_full[:, :, sensor.gmi_channels] = tbs + tbs_full = torch.permute(torch.tensor(tbs_full), (2, 0, 1)) + + angs_full = np.nan * np.ones(full_shape, dtype="float32") + angs_full[:, :, sensor.gmi_channels] = angs[..., None] + angs_full = torch.permute(torch.tensor(angs_full), (2, 0, 1)) + + anc = torch.tensor(np.stack( + [scene[anc_var].data.astype("float32") for anc_var in ANCILLARY_VARIABLES] + )) + + x = { + "brightness_temperatures": tbs_full, + "viewing_angles": angs_full, + "ancillary_data": anc + } - vs = ["latitude", "longitude"] - if sensor != sensors.GMI: - vs += ["brightness_temperatures_gmi"] + y = {} + for target in targets: + # MRMS collocations don't contain all targets. 
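+        # Targets that are absent from the scene are filled with NaN
+        # placeholders of the expected shape (28 layers for profile targets,
+        # a single layer otherwise) so that downstream code can mask them,
+        # e.g. in the loss, instead of failing on a missing variable.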
+ if target not in scene: + if target in PROFILE_TARGETS: + empty = torch.nan * torch.zeros((28, 128, 64)) + else: + empty = torch.nan * torch.zeros((1, 128, 64)) + y[target] = empty + continue - for i in range(n): - scene = decompress_scene(dataset[{"samples": i}], targets + vs) + data = scene[target].data.astype("float32") - if augment: - p_x_o = rng.random() - p_x_i = rng.random() - p_y = rng.random() - else: - p_x_o = 0.5 - p_x_i = 0.5 - p_y = 0.5 - - lats = scene.latitude.data - lons = scene.longitude.data - geometry = sensors.GMI.viewing_geometry - coords = get_transformation_coordinates( - lats, lons, geometry, 96, 128, p_x_i, p_x_o, p_y - ) + if target in ["surface_precip", "convective_precip"]: + data = interpolate(data, weights) - scene = remap_scene(scene, coords, targets + vs) + data = torch.tensor(data) + dims = tuple(range(data.ndim)) + data = torch.permute(data, dims[-2:] + dims[:-2]) + y[target] = data - # - # Input data - # + return x, y - if sensor == sensors.GMI: - tbs = sensor.load_brightness_temperatures(scene) - else: - tbs = load_variable(scene, "brightness_temperatures_gmi") - tbs = np.transpose(tbs, (2, 0, 1)) - if augment: - r = rng.random() - n_p = rng.integers(10, 30) - if r > 0.80: - tbs[10:15, :, :n_p] = np.nan - t2m = sensor.load_two_meter_temperature(scene)[np.newaxis] - tcwv = sensor.load_total_column_water_vapor(scene)[np.newaxis] - st = sensor.load_surface_type(scene) - st = np.transpose(st, (2, 0, 1)) - am = sensor.load_airmass_type(scene) - am = np.transpose(am, (2, 0, 1)) - x.append(np.concatenate([tbs, t2m, tcwv, st, am], axis=0)) - - # - # Output data - # - - for t in targets: - y_t = sensor.load_target(scene, t, None) - y_t = np.nan_to_num(y_t, nan=MASKED_OUTPUT) - dims_sp = tuple(range(2)) - dims_t = tuple(range(2, y_t.ndim)) - - y.setdefault(t, []).append(np.transpose(y_t, dims_t + dims_sp)) - - # Also flip data if requested. - if augment: - r = rng.random() - if r > 0.5: - x[i] = np.flip(x[i], -2) - for k in targets: - y[k][i] = np.flip(y[k][i], -2) - - r = rng.random() - if r > 0.5: - x[i] = np.flip(x[i], -1) - for k in targets: - y[k][i] = np.flip(y[k][i], -1) - - x = np.stack(x) - for k in targets: - y[k] = np.stack(y[k]) - - return x, y - - -HR_TARGETS = [ - "surface_precip", - "rain_water_path", - "ice_water_path", - "cloud_water_path", - "rain_water_content", - "snow_water_content", - "cloud_water_content", - "latent_heat", -] - - -def _remap_hr_scene(scene, coords, targets): - """ - Special remapping functions for GPROF-NN HR training scenes. - Since the HR training scenes contain retrieval outputs at higher - resolution, this function must be used for the remapping as it takes - into account the differences in the interpolation between input - and output data. +def load_training_data_3d_conical_sim( + sensor: sensors.Sensor, + scene: xr.Dataset, + targets: List[str], + augment: bool = False, + rng: np.random.Generator = None, +) -> Tuple[Dict[str, torch.Tensor]]: + """ + Load GPROF-NN 3D training scene for non-GMI conical scanners from + sim-file training data. Args: - scenes: An xarray.Dataset containing the data for the training - schene. - coords: A numpy array containing the coordinates defining the - remapping. - targets: A list defining the targets to load. + sensor: The sensor from which the training data was extracted. + scene: An xarray.Dataset containing the scene from which to load + the training data. + targets: A list containing a list of the targets to load. + augment: Whether or not to augment the input data. 
+ rng: A numpy random number generator to use for the augmentation. Return: - An xarray.Dataset containing the remapped scene. + A tuple ``(x, y)`` of dictionaries ``x`` and ``y`` containing the + training input data in ``x`` and the training reference data in ``y``. """ - variables = ["brightness_temperatures",] + targets - data = {} - - dims = ("scans", "pixels") - dims_3 = ("scans_3", "pixels") - - i_start, _ = compressed_pixel_range() - coords_3 = upsample_scans(coords, axis=1) - coords_3[1] -= (i_start + 1) - n_scans = scene.scans.size - coords_3[0] *= 3 - - for v in variables: - if v in HR_TARGETS: - remap_coords = coords_3 - dims = ("scans_3", "pixels") - else: - remap_coords = coords - dims = ("scans", "pixels") - - data_v = scene[v].data - if v in LIMITS: - data_v = apply_limits(data_v, *LIMITS[v]) - data_r = extract_domain(data_v, remap_coords, order=1) - data_r = data_r.astype(np.float32) - data[v] = (dims + scene[v].dims[2:], data_r) - else: - data[v] = (scene[v].dims, scene[v].data) + required = [ + "latitude", + "longitude", + "simulated_brightness_temperatures", + "brightness_temperature_biases" + ] + variables = [ + name for name in targets + required + if name in scene + ] + scene = decompress_scene(scene, variables) + + if augment: + p_x_o = rng.random() + p_x_i = rng.random() + p_y = rng.random() + else: + p_x_o = 0.5 + p_x_i = 0.5 + p_y = rng.random() + + width = 64 + height = 128 + + lats = scene.latitude.data + lons = scene.longitude.data + coords = get_transformation_coordinates( + lats, lons, sensor.viewing_geometry, width, height, p_x_i, p_x_o, p_y + ) + scene = remap_scene(scene, coords, variables) + + # Calculate brightness temperatures + tbs_sim = scene.simulated_brightness_temperatures.data + tb_biases = scene.brightness_temperature_biases.data + tbs = torch.tensor(tbs_sim - tb_biases, dtype=torch.float32) + tbs = torch.permute(tbs, (2, 0, 1)) + + angs_full = torch.tensor( + np.broadcast_to(EIA_GMI.astype("float32")[0][..., None, None], tbs.shape) + ) + for ind in range(15): + if ind not in sensor.gmi_channels: + angs_full[ind] = np.nan + angs_full = torch.tensor(angs_full) + + anc = torch.tensor(np.stack( + [scene[anc_var].data.astype("float32") for anc_var in ANCILLARY_VARIABLES] + )) + + x = { + "brightness_temperatures": tbs, + "viewing_angles": angs_full, + "ancillary_data": anc + } - return xr.Dataset(data) + y = {} + for target in targets: + # MRMS collocations don't contain all targets. + if target not in scene: + if target in PROFILE_TARGETS: + empty = torch.nan * torch.zeros((28, 128, 64)) + else: + empty = torch.nan * torch.zeros((1, 128, 64)) + y[target] = empty + continue -class GPROF_NN_HR_Dataset(GPROF_NN_3D_Dataset): - """ - Dataset to tran a neural network model on high resolution output - data. - """ - def __init__( - self, - filename, - batch_size=32, - normalize=True, - normalizer=None, - transform_zeros=True, - shuffle=True, - augment=True, - targets=None - ): - """ - Args: - filename: Path to the NetCDF file containing the training data. - batch_size: Number of samples in each training batch. - normalize: Whether or not to normalize the input data. - normalizer: The normalizer used to normalize the data. - transform_zeros: Whether or not to transform zeros to small - random values. - shuffle: Whether or not to shuffle the training data. - augment: Whether or not to randomly mask high-frequency channels - and to randomly permute ancillary data. - targets: List of targets to load. 
-        """
-        self.filename = Path(filename)
-        # Load and decompress data but keep only scenes for which
-        # contain simulated obs.
-        self.dataset = decompress_and_load(self.filename)
+        data = scene[target].data.astype("float32")
-        if targets is None:
-            targets = HR_TARGETS
+        data = torch.tensor(data)
+        dims = tuple(range(data.ndim))
+        data = torch.permute(data, dims[-2:] + dims[:-2])
+        y[target] = data
-        self.transform_zeros = transform_zeros
-        self.batch_size = batch_size
-        self.shuffle = shuffle
-        self.augment = augment
+    return x, y
-        if augment:
-            seed = int.from_bytes(os.urandom(4), "big") + os.getpid()
-        else:
-            seed = 111
-        self._rng = np.random.default_rng(seed)
-
-        x, y = self.load_training_data_3d(self.dataset, targets, augment, self._rng)
-        if normalizer is None:
-            self.normalizer = MinMaxNormalizer(x)
-        elif isinstance(normalizer, type):
-            self.normalizer = normalizer(x)
-        else:
-            self.normalizer = normalizer
-        self.normalize = normalize
-        if normalize:
-            x = self.normalizer(x)
+def load_training_data_3d_other(
+    sensor: sensors.Sensor,
+    scene: xr.Dataset,
+    targets: List[str],
+    augment: bool = False,
+    rng: np.random.Generator = None,
+) -> Tuple[Dict[str, torch.Tensor]]:
+    """
+    Load training data for non-GMI sensors from training scenes extracted
+    from actual observations, i.e., scenes not derived from .sim files.
-        self.x = x
-        self.y = y
+    Args:
+        sensor: The sensor object from which the training data was extracted.
+        scene: An xarray.Dataset containing the scene from which to load
+            the training data.
+        targets: A list of the names of the targets to load.
+        augment: Whether or not to augment the input data.
+        rng: A numpy random number generator to use for the augmentation.
-        self.x = self.x.astype(np.float32)
-        if isinstance(self.y, dict):
-            self.y = {k: self.y[k].astype(np.float32) for k in self.y}
-        else:
-            self.y = self.y.astype(np.float32)
+    Return:
+        A tuple ``(x, y)`` of dictionaries ``x`` and ``y`` containing the
+        training input data in ``x`` and the training reference data in ``y``.
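+
+    Note:
+        In contrast to the .sim-file based loaders above, this function does
+        not remap the scene using the viewing-geometry transformation; it
+        crops a 128 x 64 (scans x pixels) window from the scene, chosen
+        randomly when ``augment`` is set and centered otherwise.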
+    """
+    required = [
+        "latitude",
+        "longitude",
+        "simulated_brightness_temperatures",
+        "brightness_temperature_biases"
+    ]
+    variables = [
+        name for name in targets + required
+        if name in scene
+    ]
+
+    width = 64
+    height = 128
+
+    if augment:
+        pix_start = rng.integers(0, scene.pixels.size - width + 1)
+        scn_start = rng.integers(0, scene.scans.size - height + 1)
+    else:
+        pix_start = (scene.pixels.size - width) // 2
+        scn_start = (scene.scans.size - height) // 2
+    pix_end = pix_start + width
+    scn_end = scn_start + height
+    scene = scene[{"pixels": slice(pix_start, pix_end), "scans": slice(scn_start, scn_end)}]
+
+
+    # Calculate brightness temperatures
+    tbs = scene.brightness_temperatures.data
+    full_shape = tbs.shape[:2] + (15,)
+    if tbs.shape != full_shape:
+        tbs_full = np.nan * np.ones(full_shape, dtype="float32")
+        tbs_full[:, :, sensor.gmi_channels] = tbs
+    else:
+        tbs_full = tbs
+    tbs_full = torch.permute(torch.tensor(tbs_full), (2, 0, 1))
+
+    angs = scene.earth_incidence_angle.data
+    if angs.ndim == 2:
+        angs = angs[..., None]
+    if tbs.shape != full_shape:
+        angs_full = np.nan * np.ones(full_shape, dtype="float32")
+        angs_full[:, :, sensor.gmi_channels] = angs
+    else:
+        angs_full = angs
+    angs_full = torch.permute(torch.tensor(angs_full), (2, 0, 1))
-        self._shuffled = False
-        self.indices = np.arange(self.x.shape[0])
-        if self.shuffle:
-            self._shuffle()
+    anc = torch.tensor(np.stack(
+        [scene[anc_var].data.astype("float32") for anc_var in ANCILLARY_VARIABLES]
+    ))
-        # Delete decompressed raw input data.
-        del self.dataset
+    x = {
+        "brightness_temperatures": tbs_full,
+        "viewing_angles": angs_full,
+        "ancillary_data": anc
+    }
+    y = {}
+    for target in targets:
+        # MRMS collocations don't contain all targets.
+        if target not in scene:
+            if target in PROFILE_TARGETS:
+                empty = torch.nan * torch.zeros((28, 128, 64))
+            else:
+                empty = torch.nan * torch.zeros((1, 128, 64))
+            y[target] = empty
+            continue
-    def load_training_data_3d(self, dataset, targets, augment, rng):
-        """
-        Load data for training a simulator data.
-        This function is a replacement for the ``load_training_data3d``
-        method of the sensor that is called by the other training data
-        objects to load the data. This is required because the data
-        the input for the simulator are always the GMI Tbs.
+        data = scene[target].data.astype("float32")
-        Args:
-            dataset: The 'xarray.Dataset' from which to load the training
-                data.
-            targets: List of the targets to load.
-            augment: Whether or not to augment the training data.
-            rng: 'numpy.random.Generator' to use to generate random numbers.
+        data = torch.tensor(data)
+        dims = tuple(range(data.ndim))
+        data = torch.permute(data, dims[-2:] + dims[:-2])
+        y[target] = data
-        Return:
+    return x, y
-            Tuple ``(x, y)`` containing the training input ``x`` and a
-            dictionary of target data ``y``.
-        """
-        # Brightness temperatures
-        n = dataset.samples.size
-        x = []
-        y = {}
+class GPROFNN3DDataset(Dataset):
+    """
+    Dataset class for loading the training data for the GPROF-NN 3D retrieval.
+    """
-        vs = ["latitude", "longitude"]
+    def __init__(
+        self,
+        path: Path,
+        targets: Optional[List[str]] = None,
+        transform_zeros: bool = True,
+        augment: bool = True,
+        validation: bool = False
+    ):
+        """
+        Create GPROF-NN 3D dataset.
-        for i in range(n):
+        The training data for the GPROF-NN 3D retrieval consists of 2D scenes
+        in separate files.
-            scene = dataset[{"samples": i}]
+        Args:
+            path: The path containing the training data files.
+            targets: A list of the target variables to load.
+ transform_zeros: Whether or not to replace zeros in the output + with small random values. + augment: Whether or not to apply data augmentation to the loaded + data. + validation: If set to 'True', data loaded in consecutive iterations + over the dataset will be identical. + """ + super().__init__() - if augment: - p_x_o = rng.random() - p_x_i = rng.random() - p_y = rng.random() - else: - p_x_o = 0.5 - p_x_i = 0.5 - p_y = 0.5 - - lats = scene.latitude.data - lons = scene.longitude.data - geometry = sensors.GMI.viewing_geometry - coords = get_transformation_coordinates( - lats, lons, geometry, 96, 128, p_x_i, p_x_o, p_y + if targets is None: + targets = ALL_TARGETS + self.targets = targets + self.transform_zeros = transform_zeros + self.validation = validation + self.augment = augment and not validation + self.validation = validation + + self.path = Path(path) + if not self.path.exists(): + raise RuntimeError( + "The provided path does not exists." ) - if augment: - ang = rng.uniform(-90, 90) - coords = rotate( - coords, - angle=ang, - axes=(-1, -2), - reshape=False, - order=0, - cval=np.nan - ) + files = sorted(list(self.path.glob("*_*_*.nc"))) + if len(files) == 0: + raise RuntimeError( + "Could not find any GPROF-NN 3D training data files " + f"in {self.path}." + ) + self.files = files - scene = _remap_hr_scene(scene, coords, targets + vs) - - # - # Input data - # - - tbs = sensors.GMI.load_brightness_temperatures(scene) - tbs = expand_tbs(tbs) - tbs = np.transpose(tbs, (2, 0, 1)) - if augment: - r = rng.random() - n_p = rng.integers(10, 30) - if r > 0.80: - tbs[10:15, :, :n_p] = np.nan - x.append(tbs) - - # - # Output data - # - - for t in targets: - y_t = sensors.GMI.load_target(scene, t, None) - y_t = np.nan_to_num(y_t, nan=MASKED_OUTPUT) - dims_sp = tuple(range(2)) - dims_t = tuple(range(2, y_t.ndim)) - - y.setdefault(t, []).append(np.transpose(y_t, dims_t + dims_sp)) - - # Also flip data if requested. - if augment: - r = rng.random() - if r > 0.5: - x[i] = np.flip(x[i], -2) - for k in targets: - y[k][i] = np.flip(y[k][i], -2) - - r = rng.random() - if r > 0.5: - x[i] = np.flip(x[i], -1) - for k in targets: - y[k][i] = np.flip(y[k][i], -1) - - x = np.stack(x) - for k in targets: - y[k] = np.stack(y[k]) - - return x, y - - -NORMALIZER = MinMaxNormalizer(np.ones((23, 1, 1)), feature_axis=0) -NORMALIZER.stats = { - 0: (130, 350), # - 1: (70, 350), # - 2: (130, 350), # - 3: (80, 350), # - 4: (130, 310), # - 5: (110, 310), # - 6: (60, 310), # - 7: (100, 310), # - 8: (50, 310), # - 9: (50, 310), # - 10: (60, 310), # - 11: (60, 310), # - 12: (60, 310), # - 13: (60, 310), # - 14: (60, 310), # - 15: (-70, 70), # Earth incidence angle - 16: (220, 320), # Two-meter temperature - 17: (0, 85), # Total-column water vapor - 18: (0, 100), # Land fraction - 19: (0, 100), # Ice fraction - 20: (0, 7), # Leaf-area index - 21: (-4, 4), # Orographic wind - 22: (-5e-6, 5e-6) # Moisture convergence -} + self.init_rng() + self.files = self.rng.permutation(self.files) -class PretrainingDataset(Dataset): - """ - Dataset class to load pretraining data for GPROF retrievals. - The unsupervised pretraining tasks the GPROF-NN model to reproduce - all 15 GPM channels from as little as three input channels. Input - channels are dropped randomly. Corruption in the form of scaling by - a value between 0.8 and 1.2 is applied to the remaining channels - with a pre-defined probability. 
- """ - def __init__( - self, - path: Path, - normalize: bool = True, - ancillary_data: bool = True, - channel_corruption: float = 0.2 - ): - """ - Args: - path: The path containing the pretraining data. - normalize: Whether or not the input data should be normalized. - ancillary_data: Whether or not to include ancillary data - in the input. - channel_corruption: The probability of an input channel to be - corrupted by a random scaling between 0.8 and 1.2. + def init_rng(self, w_id=0): """ - self.normalize = normalize - self.ancillary_data = ancillary_data - self.files = np.array(sorted(list(Path(path).glob("**/*.nc")))) - self.rng = np.random.default_rng() - self.channel_corruption = channel_corruption + Initialize random number generator. - def seed(self, *args): - """ - Seed the data loader's random generator. + Args: + w_id: The worker ID which of the worker process.. """ - seed = int.from_bytes(os.urandom(4), "little") + os.getpid() + if self.validation: + seed = 42 + else: + seed = int.from_bytes(os.urandom(4), "big") + w_id self.rng = np.random.default_rng(seed) - def __len__(self): + def worker_init_fn(self, w_id: int): """ - The number of samples in the dataset. + Pytorch retrieve interface. """ - return len(self.files) + self.init_rng(w_id) + winfo = torch.utils.data.get_worker_info() + n_workers = winfo.num_workers - def __getitem__(self, index): - """ - Load one training sample. - """ - path = self.files[index] - with xr.open_dataset(path) as data: - - tbs = np.transpose(data.brightness_temperatures.data, (2, 0, 1)) - valid_chans = np.where(np.any(tbs >= 0.0, (1, 2)))[0] - n_valid = len(valid_chans) - n_drop = self.rng.integers(1, max(n_valid - 3, 0)) - drop = self.rng.permutation(valid_chans)[:n_drop] - n_valid = len(valid_chans) - - tbs_in = -1.5 * np.ones_like(tbs) - tbs_out = MASKED_OUTPUT * np.ones_like(tbs) - - for chan in valid_chans: - if chan not in drop: - tbs_in[chan] = tbs[chan] - - if self.rng.random() > (1.0 - self.channel_corruption): - fac = 0.8 + 0.4 * self.rng.random() - tbs_in[chan] *= fac - - tbs_out[chan] = tbs[chan] - - if self.ancillary_data: - - eia = data.earth_incidence_angle.data[..., 0][None] - t2m = data.two_meter_temperature.data[None] - tcwv = data.total_column_water_vapor.data[None] - f_land = data.land_fraction.data[None] - f_ice = data.ice_fraction.data[None] - snow = data.snow_depth.data[None] - lai = data.leaf_area_index.data[None] - wind = data.orographic_wind.data[None] - conv = data.moisture_convergence.data[None] - - x = np.concatenate([ - tbs_in, eia, t2m, tcwv, f_land, f_ice, snow, lai, wind, - conv - ]) - else: - x = tbs_in - - if self.rng.random() > 0.5: - x = np.flip(x, -1) - tbs_out = np.flip(tbs_out, -1) + def __repr__(self): + return f"GPROFNN3DDataset(path={self.path}, targets={self.targets})" - if self.rng.random() > 0.5: - x = np.flip(x, -2) - tbs_out = np.flip(tbs_out, -2) + def __len__(self): + return len(self.files) - if self.normalize: - x = torch.tensor(NORMALIZER(x)) + def __getitem__(self, ind): + with xr.open_dataset(self.files[ind]) as scene: + sensor = scene.attrs["sensor"] + sensor = getattr(sensors, sensor) - tbs_out = np.nan_to_num(tbs_out, nan=MASKED_OUTPUT, copy=True) - y = {f"tbs_{ind}": torch.tensor(tbs_out[ind]) for ind in range(15)} - return (x, y) + if sensor == sensors.GMI: + return load_training_data_3d_gmi( + scene, + targets=self.targets, + augment=self.augment, + rng=self.rng + ) + elif isinstance(sensor, sensors.CrossTrackScanner): + if scene.source == "sim": + return load_training_data_3d_xtrack_sim( 
+ sensor, + scene, + targets=self.targets, + augment=self.augment, + rng=self.rng + ) + return load_training_data_3d_other( + sensor, + scene, + targets=self.targets, + augment=self.augment, + rng=self.rng + ) + elif isinstance(sensor, sensors.ConicalScanner): + if scene.source == "sim": + return load_training_data_3d_conical_sim( + sensor, + scene, + targets=self.targets, + augment=self.augment, + rng=self.rng + ) + return load_training_data_3d_other( + sensor, + scene, + targets=self.targets, + augment=self.augment, + rng=self.rng + ) + raise RuntimeError( + "Invalid sensor/scene combination in training file %s.", + self.files[ind] + ) diff --git a/gprof_nn/data/utils.py b/gprof_nn/data/utils.py index 52dc192..82a40a1 100644 --- a/gprof_nn/data/utils.py +++ b/gprof_nn/data/utils.py @@ -150,14 +150,12 @@ def remap_scene(scene, coords, targets): "brightness_temperatures", "two_meter_temperature", "total_column_water_vapor", - "ocean_fraction", "land_fraction", "ice_fraction", "snow_depth", "leaf_area_index", "orographic_wind", "moisture_convergence", - "source", ] + targets data = {} @@ -234,6 +232,7 @@ def save_scene( tbs = scene.brightness_temperatures.data tbs[tbs < 0] = np.nan tbs[tbs > 400] = np.nan + if "simulated_brightness_temperatures" in scene: tbs = scene.simulated_brightness_temperatures.data tbs[tbs < 0] = np.nan @@ -306,6 +305,8 @@ def save_scene( "ice_water_path", "cloud_water_path", "rain_water_path", + "surface_precip", + "convective_precip" ]: encoding[var] = {"dtype": "float32", "zlib": True} @@ -415,6 +416,8 @@ def write_training_samples_1d( """ dataset = dataset[{"pixels": slice(*compressed_pixel_range())}] mask = np.isfinite(dataset.surface_precip.data) + if mask.ndim > 2: + mask = mask.all(-1) valid = {} for var in dataset.variables: @@ -424,8 +427,8 @@ def write_training_samples_1d( else: arr_data = arr.data valid[var] = ((("samples",) + arr.dims[2:]), arr_data[mask]) - valid = xr.Dataset(valid) + valid = xr.Dataset(valid, attrs=dataset.attrs) start_time = pd.to_datetime(dataset.scan_time.data[0].item()) start_time = start_time.strftime("%Y%m%d%H%M%S") end_time = pd.to_datetime(dataset.scan_time.data[-1].item()) diff --git a/gprof_nn/definitions.py b/gprof_nn/definitions.py index 7509a0f..188cb4e 100644 --- a/gprof_nn/definitions.py +++ b/gprof_nn/definitions.py @@ -24,13 +24,24 @@ "cloud_water_path", ] -PROFILE_NAMES = [ +PROFILE_TARGETS = [ "rain_water_content", "cloud_water_content", "snow_water_content", "latent_heat", ] +ANCILLARY_VARIABLES = [ + "land_fraction", + "ice_fraction", + "leaf_area_index", + "snow_depth", + "two_meter_temperature", + "total_column_water_vapor", + "orographic_wind", + "moisture_convergence", +] + SURFACE_TYPE_NAMES = [ "Ocean", "Sea-Ice", @@ -84,9 +95,11 @@ "AMSRE": 2004 } -TEST_DAYS = [1, 2, 3] -VALIDATION_DAYS = [4, 5] -TRAINING_DAYS = list(range(6, 32)) +DATA_SPLIT = { + "test": [1, 2, 3], + "validation": [4, 5], + "training": list(range(6, 32)) +} LIMITS = { "brightness_temperatures": (0, 400), diff --git a/gprof_nn/sensors.py b/gprof_nn/sensors.py index d272c39..3e030ca 100644 --- a/gprof_nn/sensors.py +++ b/gprof_nn/sensors.py @@ -439,27 +439,6 @@ def preprocessor_pixel_record(self): """ return self._preprocessor_pixel_record - @abstractmethod - def load_training_data_1d(self, dataset, targets, augment, rng): - """ - Load training data for GPROF-NN 1D algorithm from NetCDF file. - - Args: - filename: Path of the file from which to load the data. - targets: List of target names to load. 
- augment: Flag indicating whether or not to augment the training - data. - rng: Numpy random number generator to use for the data - augmentation. - - Return: - Tuple ``(x, y)`` consisting of rank-2 tensor ``x`` with input - features oriented along the last dimension and a dictionary - ``y`` containing the values of the retrieval targets for all - inputs in ``x``. - """ - pass - @property def latitude_ratios(self): """ @@ -512,12 +491,6 @@ def load_viewing_angle(self, data, mask=None): """ return load_variable(data, "viewing_angle", mask=mask) - def load_ocean_fraction(self, data, mask=None): - """ - Load ocean fraction from dataset and convert to 1-hot encoding. - """ - return load_variable(data, "ocean_fraction", mask=mask) - def load_land_fraction(self, data, mask=None): """ Load land fraction from dataset. @@ -627,128 +600,7 @@ def load_brightness_temperatures(self, data, angles=None, mask=None): return load_variable(data, "simulated_brightness_temperatures", mask=mask) return load_variable(data, "brightness_temperatures", mask=mask) - def _load_scene_1d(self, scene, targets, augment, rng): - """ - Helper function to parallelize loading of 1D training data. - """ - if "surface_precip" not in targets: - ts = targets + ["surface_precip"] - else: - ts = targets - if self.use_simulated_tbs: - ts = targets + ["simulated_brightness_temperatures"] - scene = decompress_scene(scene, ts) - - # - # Input data - # - - # Select only samples that have a finite surface precip value. - sp = self.load_target(scene, "surface_precip") - valid = sp >= 0 - - tbs = self.load_brightness_temperatures(scene, mask=valid) - if self.use_simulated_tbs and self.delta_tbs is not None: - noise = np.array(self.delta_tbs) * rng.normal(size=tbs.shape) - tbs += noise - - if augment: - r = rng.random(tbs.shape[0]) - tbs[r > 0.9, 10:15] = np.nan - t2m = self.load_two_meter_temperature(scene, valid)[..., np.newaxis] - tcwv = self.load_total_column_water_vapor(scene, valid) - tcwv = tcwv[..., np.newaxis] - ocean_frac = self.load_ocean_fraction(scene, valid)[..., None] - land_frac = self.load_land_fraction(scene, valid)[..., None] - ice_frac = self.load_ice_fraction(scene, valid)[..., None] - snow_depth = self.load_snow_depth(scene, valid)[..., None] - leaf_area_index = self.load_leaf_area_index(scene, valid)[..., None] - orographic_wind = self.load_orographic_wind(scene, valid)[..., None] - moisture_conv = self.load_moisture_convergence(scene, valid)[..., None] - - x = np.concatenate([ - tbs, - t2m, - tcwv, - ocean_frac, - land_frac, - ice_frac, - snow_depth, - leaf_area_index, - orographic_wind, - moisture_conv - ], axis=1) - - # - # Output data - # - - y = {} - for t in targets: - y_t = self.load_target(scene, t, valid) - y_t = np.nan_to_num(y_t, nan=MASKED_OUTPUT) - y[t] = y_t - - return x, y - - def load_training_data_1d( - self, dataset, targets, augment, rng, drop_inputs=None - ): - """ - Load training data for GPROF-NN 1D retrieval. This function will - only load pixels that with a finite surface precip value in order - to avoid training on samples that don't provide any information to - the 1D retrieval. - - Output values that may be missing for a given pixel are masked using - the 'MASKED_OUTPUT' value. - - Args: - filename: The filename of the NetCDF file containing the training - data. - targets: List of the targets to load. - augment: Whether or not to augment the training data. - rng: Numpy random number generator to use for augmentation. 
- drop_inputs: A probability with which to set all inputs randomly - to a missing value. - - Return: - Tuple ``(x, y)`` containing the un-batched, un-shuffled training - data as it is contained in the given NetCDF file. - """ - dataset.load() - - x = [] - y = {} - - if isinstance(dataset, (str, Path)): - dataset = xr.open_dataset(dataset) - loaded = True - else: - loaded = False - - n_scenes = dataset.samples.size - - for i in range(n_scenes): - scene = dataset[{"samples": i}] - x_i, y_i = self._load_scene_1d(scene, targets, augment, rng) - x.append(x_i) - for target in targets: - y.setdefault(target, []).append(y_i[target]) - - x = np.concatenate(x, axis=0) - for k in targets: - y[k] = np.concatenate(y[k], axis=0) - - if loaded: - dataset.close() - - if drop_inputs is not None: - drop_inputs_from_sample(x, drop_inputs, self, rng) - - return x, y - - def _load_scene_3d( + def load_scene( self, scene, targets, augment, variables, rng, width, height, drop_inputs=None ): """ @@ -795,7 +647,6 @@ def _load_scene_3d( t2m = self.load_two_meter_temperature(scene)[np.newaxis] tcwv = self.load_total_column_water_vapor(scene)[np.newaxis] - ocean_frac = self.load_ocean_fraction(scene)[None] land_frac = self.load_land_fraction(scene)[None] ice_frac = self.load_ice_fraction(scene)[None] snow_depth = self.load_snow_depth(scene)[None] @@ -807,7 +658,6 @@ def _load_scene_3d( tbs, t2m, tcwv, - ocean_frac, land_frac, ice_frac, snow_depth, @@ -823,7 +673,6 @@ def _load_scene_3d( for t in targets: y_t = self.load_target(scene, t) - y_t = np.nan_to_num(y_t, nan=MASKED_OUTPUT) dims_sp = tuple(range(2)) dims_t = tuple(range(2, y_t.ndim)) @@ -2061,12 +1910,12 @@ def __repr__(self): TRMM = Platform("TRMM", "/pdata4/archive/GPM/1C_TMI_ITE/", "1C.TRMM.TMI") -NOAA19 = Platform("NOAA19", "/pdata4/archive/GPM/1C_NOAA19_ITE/", "1C.NOAA19.MHS") +NOAA19 = Platform("NOAA19", "/pdata4/archive/GPM/1C_NOAA19_V7/", "1C.NOAA19.MHS") NPP = Platform("NPP", "/pdata4/archive/GPM/1C_ATMS_ITE/", "1C.NPP.ATMS") GPM = Platform("GPM-CO", "/pdata4/archive/GPM/1CR_GMI_V7/", "1C-R.GPM.GMI") F15 = Platform("F15", "/pdata4/archive/GPM/1C_F15_ITE/", "1C.F15.SSMI") F17 = Platform("F17", "/pdata4/archive/GPM/1C_F17_ITE/", "1C.F17.SSMIS") -GCOMW1 = Platform("GCOM-W1", "/pdata4/archive/GPM/1C_AMSR2_ITE/", "1C.GCOMW1.AMSR2") +GCOMW1 = Platform("GCOM-W1", "/pdata4/archive/GPM/1C_AMSR2_V7/", "1C.GCOMW1.AMSR2") AQUA = Platform("AQUA", "/pdata4/archive/GPM/1C_AMSRE/", "1C.AQUA.AMSRE") ############################################################################### diff --git a/gprof_nn/training.py b/gprof_nn/training.py index b984b7d..62ac418 100644 --- a/gprof_nn/training.py +++ b/gprof_nn/training.py @@ -2,420 +2,28 @@ gprof_nn.training ================= -Implements training routines for the different stages of the -training of the GPROF-NN retrievals. +Interface for training of the GPROF-NN retrievals. """ -from dataclasses import dataclass -from pathlib import Path -import re -from typing import Optional, Union, Tuple, List -import numpy as np -import pytorch_lightning as pl -from quantnn import metrics -from quantnn.mrnn import MRNN, Quantiles -import torch -from torch.utils.data import DataLoader - -from gprof_nn.definitions import MASKED_OUTPUT -from gprof_nn.models import GPROF_NN_3D_CONFIGS, GPROFNet3D -from gprof_nn.data.training_data import PretrainingDataset - - -@dataclass -class TrainingConfig: - """ - A description of a training regime. 
- """ - name: str - n_epochs: int - optimizer: str - optimizer_kwargs: Optional[dict] = None - scheduler: str = None - scheduler_kwargs: Optional[dict] = None - precision: str = "16-mixed" - batch_size: int = 8 - accelerator: str = "cuda" - data_loader_workers: int = 4 - minimum_lr: Optional[float] = None - reuse_optimizer: bool = False - stepwise_scheduling: bool = False - - -def parse_training_config(path: Union[str, Path]): - """ - Parse a training config file. - - Args: - path: Path pointing to the training config file. - - Return: - A list 'TrainingConfig' objects representing the training - passes to perform. - """ - path = Path(path) - parser = ConfigParser() - parser.read(path) - - training_configs = [] - - for section_name in parser.sections(): - - sec = parser[section_name] - - n_epochs = sec.getint("n_epochs", 1) - optimizer = sec.get("optimizer", "SGD") - optimizer_kwargs = eval(sec.get("optimizer_kwargs", "{}")) - scheduler = sec.get("scheduler", None) - scheduler_kwargs = eval(sec.get("scheduler_kwargs", "{}")) - precision = sec.get("precision", "16-mixed") - batch_size = sec.getint("batch_size", 8) - data_loader_workers = sec.getint("data_loader_workers", 8) - minimum_lr = sec.getfloat("minimum_lr", None) - reuse_optimizer = sec.getboolean("reuse_optimizer", False) - stepwise_scheduling = sec.getboolean("stepwise_scheduling", False) - - training_configs.append(TrainingConfig( - name=section_name, - n_epochs=n_epochs, - optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, - precision=precision, - sample_rate=sample_rate, - batch_size=batch_size, - data_loader_workers=data_loader_workers, - minimum_lr=minimum_lr, - reuse_optimizer=reuse_optimizer, - stepwise_scheduling=stepwise_scheduling - )) - - -def get_optimizer_and_scheduler( - training_config, - model, - previous_optimizer=None -): - """ - Return torch optimizer, learning-rate scheduler and callback objects - corresponding to this configuration. - - Args: - training_config: A TrainingConfig object specifying training - settings for one training stage. - model: The model to be trained as a torch.nn.Module object. - previous_optimizer: Optimizer from the previous stage in case - it is reused. - - Return: - A tuple ``(optimizer, scheduler, callbacks)`` containing a PyTorch - optimizer object ``optimizer``, the corresponding LR scheduler - ``scheduler`` and a list of callbacks. - - Raises: - Value error if training configuration specifies to reuse the optimizer - but 'previous_optimizer' is none. - - """ - if training_config.reuse_optimizer: - if previous_optimizer is None: - raise RuntimeError( - "Training stage '{training_config.name}' has 'reuse_optimizer' " - "set to 'True' but no previous optimizer is available." 
- ) - optimizer = previous_optimizer - - else: - optimizer_cls = getattr(torch.optim, training_config.optimizer) - optimizer = optimizer_cls( - model.parameters(), - **training_config.optimizer_kwargs - ) - - scheduler = training_config.scheduler - if scheduler is None: - return optimizer, None, [] - - if scheduler == "lr_search": - scheduler = torch.optim.lr_scheduler.ExponentialLR( - optimizer, - gamma=2.0 - ) - callbacks = [ - ResetParameters(), - ] - return optimizer, scheduler, callbacks - - scheduler = getattr(torch.optim.lr_scheduler, training_config.scheduler) - scheduler_kwargs = training_config.scheduler_kwargs - if scheduler_kwargs is None: - scheduler_kwargs = {} - scheduler = scheduler( - optimizer=optimizer, - **scheduler_kwargs, - ) - scheduler.stepwise = training_config.stepwise_scheduling - - if training_config.minimum_lr is not None: - callbacks = [ - EarlyStopping( - f"Learning rate", - stopping_threshold=training_config.minimum_lr * 1.001, - patience=training_config.n_epochs, - verbose=True, - strict=True - ) - ] - else: - callbacks = [] - - return optimizer, scheduler, callbacks - - -def create_data_loaders_pretraining( - training_config: TrainingConfig, - training_data_path: Path, - validation_data_path: Optional[Path] -) -> Tuple[DataLoader, Optional[DataLoader]]: - """ - Create pytorch Dataloaders for training and validation data. - - Args: - training_config: Dataclass specifying the training configuration, - which defines how many processes to use for the data loading. - training_data_path: The path pointing to the folder containing - the training data. - validation_data_path: The path pointing to the folder containing - the validation data. - """ - training_data = PretrainingDataset( - training_data_path, - normalize=True, - ancillary_data=True, - channel_corruption=0.2 - ) - training_loader = DataLoader( - training_data, - shuffle=True, - batch_size=training_config.batch_size, - num_workers=training_config.data_loader_workers, - worker_init_fn=training_data.seed, - pin_memory=True, - ) - if validation_data_path is None: - return training_loader, None - - validation_data = PretrainingDataset( - validation_data_path, - normalize=True, - ancillary_data=True, - channel_corruption=0.2 - ) - validation_loader = DataLoader( - validation_data, - shuffle=False, - batch_size=training_config.batch_size, - num_workers=training_config.data_loader_workers, - worker_init_fn=validation_data.seed, - pin_memory=True, - ) - return training_loader, validation_loader - - -def find_most_recent_checkpoint(path: Path, model_name: str) -> Path: - """ - Find most recente Pytorch lightning checkpoint files. - - Args: - path: A pathlib.Path object pointing to the folder containing the - checkpoints. - model_name: The model name as defined by the user. - - Return: - If a checkpoint was found, returns a object pointing to the - checkpoint file with the highest version number. Otherwise - returns 'None'. 
- """ - path = Path(path) - - checkpoint_files = list(path.glob(f"{model_name}*.ckpt")) - if len(checkpoint_files) == 0: - return None - if len(checkpoint_files) == 1: - return checkpoint_files[0] - - checkpoint_regexp = re.compile(rf"{model_name}(-v\d*)?.ckpt") - versions = [] - for checkpoint_file in checkpoint_files: - match = checkpoint_regexp.match(checkpoint_file.name) - if match is None: - return None - if match.group(1) is None: - versions.append(-1) - else: - versions.append(int(match.group(1)[2:])) - ind = np.argmax(versions) - return checkpoint_files[ind] - - -def compile_model( - model_config: str, - kind: str = "pretraining", - base_model: Optional[Path] = None -) -> MRNN: - """ - Compile quantnn.mrnn.MRNN model for the training. - - Args: - model_config: Name of the model config. - kind: The kind of training to be performed. - base_model: Path to a pretrained model to initialize the model - with. - """ - config = GPROF_NN_3D_CONFIGS[model_config] - if kind.lower() in ["pretraining", "pretrain"]: - targets = { - f"tbs_{ind}": (32,) for ind in range(15) - } - else: - targets = { - name: (16, 28) if name in PROFILE_NAMES else (32,) - for name in ALL_TARGETS - } - - model = GPROFNet3D( - config.n_channels, - config.n_blocks, - targets=targets, - ancillary_data=config.ancillary_data - ) - - losses = { - name: Quantiles(np.linspace(0, 1, shape[0] + 2)[1:-1]) - for name, shape in targets.items() - } - - return MRNN(losses=losses, model=model) - - -def run_pretraining( - output_path: Union[Path, str], - model_config: str, - training_configs: List[TrainingConfig], - training_data_path: Path, - validation_data_path: Optional[Path] = None, - continue_training: bool = False -): - """ - Run pretraining for GPROF-NN model. - - Args: - output_path: The path to which to write the trained model. - model_config: The configuration of the base model. Should be one - ['small', 'small_no_ancillary', 'large', 'large_no_ancillary'] - of for 'small' order 'large' model capacity and including - ancillary data or not. - training_configs: List of training configs specifying the training settings - for all training passes to perform. - validation_data: Optional path pointing to the validation data. 
- """ - output_path = Path(output_path) - output_path.mkdir(exist_ok=True, parents=True) - - mrnn = compile_model(model_config, "pretrain") - model_name = f"gprof_nn_3d_pre_{model_config}" - - mtrcs = [ - metrics.Bias(), - metrics.Correlation(), - metrics.CRPS(), - metrics.MeanSquaredError(), - ] - lightning_module = mrnn.lightning( - mask=MASKED_OUTPUT, - metrics=mtrcs, - name=model_name, - log_dir=output_path / "logs" - ) - - ckpt_path = None - if continue_training: - ckpt_path = find_most_recent_checkpoint(output_path, model_name) - ckpt_data = torch.load(ckpt_path) - stage = ckpt_data["stage"] - lightning_module.stage = stage - - devices = None - callbacks = [ - pl.callbacks.LearningRateMonitor(), - pl.callbacks.ModelCheckpoint( - dirpath=output_path, - filename=f"gprof_nn_{model_name}", - verbose=True - ) - ] - - all_optimizers = [] - all_schedulers = [] - all_callbacks = [] - opt_prev = None - for stage_ind, training_config in enumerate(training_configs): - opt_s, sch_s, cback_s = get_optimizer_and_scheduler( - training_config, - mrnn.model, - previous_optimizer=opt_prev - ) - opt_prev = opt_s - all_optimizers.append(opt_s) - all_schedulers.append(sch_s) - all_callbacks.append(cback_s) - - lightning_module.optimizer = all_optimizers - lightning_module.scheduler = all_schedulers - - - for stage_ind, training_config in enumerate(training_configs): - if stage_ind < lightning_module.stage: - continue - - # Restore LR if optimizer is reused. - if training_config.reuse_optimizer: - if "lr" in training_config.optimizer_kwargs: - optim = lightning_module.optimizer[stage_ind] - lr = training_config.optimizer_kwargs["lr"] - for group in optim.param_groups: - group["lr"] = lr - - stage_callbacks = callbacks + all_callbacks[stage_ind] - training_loader, validation_loader = create_data_loaders_pretraining( - training_config, - training_data_path, - validation_data_path - ) - - if training_config.accelerator in ["cuda", "gpu"]: - devices = -1 - else: - devices = 1 - lightning_module.stage_name = training_config.name - - trainer = pl.Trainer( - default_root_dir=output_path, - max_epochs=training_config.n_epochs, - accelerator=training_config.accelerator, - devices=devices, - precision=training_config.precision, - logger=lightning_module.tensorboard, - callbacks=stage_callbacks, - num_sanity_val_steps=0, - #strategy=pl.strategies.DDPStrategy(find_unused_parameters=True), - ) - trainer.fit( - model=lightning_module, - train_dataloaders=training_loader, - val_dataloaders=validation_loader, - ckpt_path=ckpt_path - ) - mrnn.save(output_path / f"cimr_{model_name}.pckl") - ckpt_path=None +def init(path: Path, + configuration: str, + training_data_path: Path, + validation_data_path: Optional[Path] = None, + targets: List[str] = ALL_TARGETS, + ancillary_data: bool = True, +) -> None: + + config_path = Path(__file__).parent / "config_files" + + training_config = config_path / f"gprof_nn_{config.lower()}_training.toml" + training_config = open(training_config, "r").read() + training.config = training_config.format({ + "training_dataset_args": f"{{path = '{training_data_path}'}}", + "validation_dataset_args": f"{{path = '{validation_data_path}'}}", + }) + with open(path / "training.toml") as output: + output.write(training_config) + + model_config = config_path / f"gprof_nn_{config.lower()}_training.toml" + model_config = open(training_config, "r").read() + for target in targets: diff --git a/notebooks/evaluate_ancillary_variables.ipynb b/notebooks/evaluate_ancillary_variables.ipynb index 2624619..c89afb4 
100644 --- a/notebooks/evaluate_ancillary_variables.ipynb +++ b/notebooks/evaluate_ancillary_variables.ipynb @@ -319,7 +319,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/test/conftest.py b/test/conftest.py index aad599c..e87ad62 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -23,6 +23,19 @@ not SIM_DATA.exists(), reason="Needs sim files." ) +HAS_TEST_DATA = "GPROF_NN_TEST_DATA" in os.environ +NEEDS_TEST_DATA = pytest.mark.skipif( + not HAS_TEST_DATA, reason="Test data not available." +) + +@pytest.fixture() +def test_data(): + """ + The test data path as set in the 'GPROF_NN_TEST_DATA' environment variable. + """ + return Path(os.environ["GPROF_NN_TEST_DATA"]) + + @pytest.fixture(scope="session") def sim_collocations_gmi() -> xr.Dataset: """ diff --git a/test/data/test_era5.py b/test/data/test_era5.py deleted file mode 100644 index 0533a73..0000000 --- a/test/data/test_era5.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Tests for the gprof_nn.data.era5 module. -""" -from datetime import datetime - -from conftest import NEEDS_ARCHIVES -import numpy as np - -from gprof_nn.data.era5 import ( - load_era5_data, - add_era5_precip -) - - -@NEEDS_ARCHIVES -def test_load_era5_data(): - """ - Tests adding ERA5 precip data to preprocessor data. - """ - - start_time = np.datetime64("2020-01-01T23:00:00") - end_time = np.datetime64("2020-01-02T01:00:00") - - era5_data = load_era5_data(start_time, end_time) - assert era5_data.time[0] < start_time - assert era5_data.time[-1] > end_time - - -@NEEDS_ARCHIVES -def test_add_era5_precip(preprocessor_data_gmi): - """ - Tests adding ERA5 precip data to preprocessor data. - """ - data_pp = preprocessor_data_gmi - start_time = preprocessor_data_gmi.scan_time.data[0] - end_time = preprocessor_data_gmi.scan_time.data[-1] - era5_data = load_era5_data(start_time, end_time) - - add_era5_precip(data_pp, era5_data) - - surface_type = data_pp.surface_type.data - surface_precip = data_pp.surface_precip.data - - sea_ice = (surface_type == 2) + (surface_type == 16) - assert np.all(np.isfinite(surface_precip[sea_ice])) - assert np.all(np.isnan(surface_precip[~sea_ice])) diff --git a/test/data/test_pretraining.py b/test/data/test_pretraining.py deleted file mode 100644 index b1717d6..0000000 --- a/test/data/test_pretraining.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Tests for the gprof_nn.data.pretraining module. -""" -from pathlib import Path - -from conftest import NEEDS_ARCHIVES, l1c_file_gmi - -import pytest -import xarray as xr -import numpy as np - -from gprof_nn import sensors -from gprof_nn.data.l1c import L1CFile -from gprof_nn.data.pretraining import process_l1c_file -from gprof_nn.data.training_data import PretrainingDataset -from gprof_nn.models import GPROFNet3D - - -@pytest.fixture(scope="module") -def gmi_pretraining_data(tmp_path_factory, l1c_file_gmi): - input_path = tmp_path_factory.mktemp("l1c") - input_file = input_path / l1c_file_gmi.name - L1CFile(l1c_file_gmi).extract_scan_range(0, 512, input_file) - - path = tmp_path_factory.mktemp("data") - scenes = process_l1c_file(sensors.GMI, input_file, path) - return path - - -@NEEDS_ARCHIVES -def test_process_l1c_file_gmi(gmi_pretraining_data): - """ - Assert that processing a l1c file produces sample files suitable for - pretraining. 
- """ - files = sorted(list(gmi_pretraining_data.glob("*.nc"))) - assert len(files) > 0 - - data = xr.load_dataset(files[0]) - data.brightness_temperatures[..., 0].min() > 0.0 - - -@NEEDS_ARCHIVES -def test_pretraining_dataset(gmi_pretraining_data): - """ - Assert that data loaded using the PretrainingDataset is valid. - """ - dataset = PretrainingDataset(gmi_pretraining_data) - for ind in range(len(dataset)): - x, y = dataset[ind] - x = x.numpy() - assert x.min() >= -1.5 - assert x.max() <= 1.0 - for key in y: - assert np.all(np.isfinite(y[key].numpy())) - - -@NEEDS_ARCHIVES -def test_pretraining_model(gmi_pretraining_data): - - dataset = PretrainingDataset(gmi_pretraining_data) - - x, y = dataset[0] - x = x[None] - - targets = {f"channel_{ind}": (32,) for ind in range(15)} - - model = GPROFNet3D( - [32, 64, 128, 256, 512], - [2, 2, 2, 2, 2], - targets=targets, - ancillary_data=True - ) - y = model(x) - - assert "channel_1" in y diff --git a/test/data/test_sim.py b/test/data/test_sim.py deleted file mode 100644 index b40f24a..0000000 --- a/test/data/test_sim.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -This file contains tests for the reading and processing of .sim files -defined in 'gprof_nn.data.sim.py'. -""" -from pathlib import Path - -import numpy as np -import pytest -import xarray as xr - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.l1c import L1CFile -from gprof_nn.data.sim import ( - SimFile, - SubsetConfig, - apply_orographic_enhancement, - collocate_targets, - write_training_samples_1d, - write_training_samples_3d, -) -from gprof_nn.data.preprocessor import PreprocessorFile - - -DATA_PATH = get_test_data_path() -HAS_ARCHIVES = Path(sensors.GMI.l1c_file_path).exists() - -SIM_DATA = Path("/qdata1/pbrown/dbaseV8") -NEEDS_SIM_DATA = pytest.mark.skipif( - not SIM_DATA.exists(), reason="Needs sim files." -) - -@NEEDS_SIM_DATA -def test_open_sim_file_gmi(): - """ - Tests reading simulator output file for GMI. 
- """ - input_file = SIM_DATA / "simV8/1810/GMI.dbsatTb.20181031.026559.sim" - - sim_file = SimFile(input_file) - data = sim_file.to_xarray_dataset() - - assert "surface_precip" in data.variables.keys() - assert "latent_heat" in data.variables.keys() - assert "snow_water_content" in data.variables.keys() - assert "rain_water_content" in data.variables.keys() - - valid = data.surface_precip.data > -9999 - assert valid.sum() > 0 - assert np.all(data.surface_precip[valid] >= 0.0) - assert np.all(data.surface_precip[valid] <= 1000.0) - assert np.all(data.latitude >= -90.0) - assert np.all(data.latitude <= 90.0) - assert np.all(data.longitude >= -180.0) - assert np.all(data.longitude <= 180.0) - - -@NEEDS_SIM_DATA -def test_collocate_targets(tmp_path): - - input_file = SIM_DATA / "simV8/1810/GMI.dbsatTb.20181031.026559.sim" - data = collocate_targets( - input_file, - sensors.GMI, - None, - ) - sp = data.surface_precip.data - assert (sp >= 0.0).sum() > 0 - - output_path = tmp_path / "1d" - output_path.mkdir() - write_training_samples_1d(data, output_path) - training_files = list(output_path.glob("*.nc")) - assert len(training_files) == 1 - - output_path = tmp_path / "3d" - output_path.mkdir() - write_training_samples_3d(data, output_path) - training_files = list(output_path.glob("*.nc")) - assert len(training_files) > 1 diff --git a/test/test_augmentation.py b/test/test_augmentation.py deleted file mode 100644 index b1f4634..0000000 --- a/test/test_augmentation.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Tests for the data augmentation methods in gprof_nn.augmentation. -""" -from pathlib import Path - -import numpy as np -import xarray as xr - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.augmentation import ( - Swath, - get_center_pixels, - get_transformation_coordinates, - get_center_pixel_input, - M, - N -) -from gprof_nn.data.training_data import decompress_and_load - - -DATA_PATH = get_test_data_path() - - -def test_gmi_geometry(): - """ - Assert that coordinate transformation function for GMI viewing - geometry are reversible. - """ - i = np.arange(0, 221, 10) - j = np.arange(0, 221, 10) - ij = np.stack(np.meshgrid(i, j)) - geometry = sensors.GMI.viewing_geometry - xy = geometry.pixel_coordinates_to_euclidean(ij) - ij_r = geometry.euclidean_to_pixel_coordinates(xy) - assert np.all(np.isclose(ij, ij_r)) - - -def test_mhs_geometry(): - """ - Assert that coordinate transformation function for GMI viewing - geometry are reversible. - """ - i = np.arange(0, 90, 10) - j = np.arange(0, 90, 10) - ij = np.stack(np.meshgrid(i, j)) - geometry = sensors.MHS.viewing_geometry - xy = geometry.pixel_coordinates_to_euclidean(ij) - ij_r = geometry.euclidean_to_pixel_coordinates(xy) - assert np.all(np.isclose(ij, ij_r)) - - -def test_swath_geometry(): - """ - Assert that coordinate transformation function for GMI viewing - geometry are reversible. - """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - input_data = decompress_and_load(input_file) - - lats = input_data.latitude.data[0] - lons = input_data.longitude.data[0] - - i = np.arange(0, 221, 10) - j = np.arange(0, 221, 10) - ij = np.stack(np.meshgrid(i, j)) - - swath = Swath(lats, lons) - - xy = swath.pixel_coordinates_to_euclidean(ij) - ij_r = swath.euclidean_to_pixel_coordinates(xy) - - assert np.all(np.isclose(ij, ij_r)) - - -def test_interpolation_weights(): - """ - Ensure that all interpolation weights are positive and sum to 1. 
- """ - geometry = sensors.MHS.viewing_geometry - weights = geometry.get_interpolation_weights(sensors.MHS.angles) - assert np.all(weights.sum(-1) == 1.0) - assert np.all(weights >= 0) - - -def test_inputer_center(): - """ - Ensures that the calculated window always contains the center of - the GMI swath. - """ - l = get_center_pixel_input(0.0, 64) - assert l + 32 == 110 - r = get_center_pixel_input(1.0, 64) - assert r - 32 == 110 - - -def test_transformation_coordinates(): - """ - Ensure that transformation coordinates correspond to identity - mapping for when input and output window are located at the - center of the swath. - """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - input_data = decompress_and_load(input_file) - - lats = input_data.latitude.data[0] - lons = input_data.longitude.data[0] - geometry = sensors.GMI.viewing_geometry - c = get_transformation_coordinates( - lats, lons, geometry, 64, - 64, 0.5, 0.5, 0.5 - ) - assert np.all(np.isclose(c[1, 32, :], np.arange(110 - 32, 110 + 32), atol=2.0)) diff --git a/test/test_bin.py b/test/test_bin.py deleted file mode 100644 index f484dfc..0000000 --- a/test/test_bin.py +++ /dev/null @@ -1,576 +0,0 @@ -""" -This file tests the 'gprof_nn.data.bin' module which provides functionality -to read and extract training data from the 'bin' files used by GPROF. -""" -from pathlib import Path - -import numpy as np -import xarray as xr - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.bin import FileProcessor, BinFile -from gprof_nn.data.training_data import GPROF_NN_1D_Dataset - - -DATA_PATH = get_test_data_path() - - -def test_bin_file_gmi(): - """ - Test reading the different types of bin files for standard surface types, - for sea ice and snow. - """ - # - # Simulator-derived bin files. - # - - input_file = DATA_PATH / "gmi" / "bin" / "gpm_275_14_03_17.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["rain_water_path"] >= 0) - assert np.all(input_data["two_meter_temperature"] > 275 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 275 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 14 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 14 + 0.5) - assert np.all(input_data["surface_type"] == 17) - assert np.all(input_data["airmass_type"] == 3) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # Seaice bin files. - # - - input_file = DATA_PATH / "gmi" / "bin" / "gpm_269_00_16.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 269 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 269 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 0 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 0 + 0.5) - assert np.all(input_data["surface_type"] == 16) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # MRMS bin files. 
- # - - input_file = DATA_PATH / "gmi" / "bin" / "gpm_298_28_11.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 298 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 298 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 28 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 28 + 0.5) - assert np.all(input_data["surface_type"] == 11) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - -def test_bin_file_mhs(): - """ - Test reading of MHS bin files and ensure all values are physical and match - given bin. - """ - # - # Simulator-derived bin files. - # - - input_file = DATA_PATH / "mhs" / "bin" / "gpm_289_52_04.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["rain_water_path"] >= 0) - assert np.all(input_data["two_meter_temperature"] > 289 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 289 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 52 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 52 + 0.5) - assert np.all(input_data["surface_type"] == 4) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # Seaice bin files. - # - - input_file = DATA_PATH / "mhs" / "bin" / "gpm_271_20_16.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 271 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 271 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 20 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 20 + 0.5) - assert np.all(input_data["surface_type"] == 16) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # MRMS bin files. - # - - input_file = DATA_PATH / "mhs" / "bin" / "gpm_292_25_11.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 292 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 292 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 25 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 25 + 0.5) - assert np.all(input_data["surface_type"] == 11) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - -def test_bin_file_tmi(): - """ - Test reading of TMI bin files and ensure all values are physical and match - given bin. - """ - # - # Simulator-derived bin files. 
- # - - DATA_PATH = Path(__file__).parent/ "data" - input_file = DATA_PATH / "tmi" / "bin" / "gpm_309_08_04.bin" - - input_data = BinFile(input_file).to_xarray_dataset() - - assert input_data.channels.size == 9 - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] >= 0) - assert np.all(input_data["two_meter_temperature"] > 309 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 309 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 8 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 8 + 0.5) - assert np.all(input_data["surface_type"] == 4) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # Seaice bin files. - # - - input_file = DATA_PATH / "tmi" / "bin" / "gpm_273_15_16.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 273 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 273 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 15 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 15 + 0.5) - assert np.all(input_data["surface_type"] == 16) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # MTN bin files. - # - - input_file = DATA_PATH / "tmi" / "bin" / "gpm_295_16_11.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 295 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 295 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 16 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 16 + 0.5) - assert np.all(input_data["surface_type"] == 11) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - -def test_bin_file_ssmi(): - """ - Test reading of SSMIS bin files and ensure all values are physical and - match given bin. - """ - # - # Simulator-derived bin files. 
- # - - DATA_PATH = Path(__file__).parent/ "data" - input_file = DATA_PATH / "ssmi" / "bin" / "gpm_287_45_04.bin" - - input_data = BinFile(input_file).to_xarray_dataset() - - assert input_data.channels.size == 7 - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] >= 0) - assert np.all(input_data["two_meter_temperature"] > 287 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 287 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 45 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 45 + 0.5) - assert np.all(input_data["surface_type"] == 4) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # Seaice bin files. - # - - input_file = DATA_PATH / "ssmi" / "bin" / "gpm_240_02_16.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 240 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 240 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 2 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 2 + 0.5) - assert np.all(input_data["surface_type"] == 16) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # MTN bin files. - # - - input_file = DATA_PATH / "ssmi" / "bin" / "gpm_272_06_03_17.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["two_meter_temperature"] > 272 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 272 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 6 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 6 + 0.5) - assert np.all(input_data["surface_type"] == 17) - assert np.all(input_data["airmass_type"] == 3) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - -def test_bin_file_ssmis(): - """ - Test reading of SSMIS bin files and ensure all values are physical and - match given bin. - """ - # - # Simulator-derived bin files. 
- # - - DATA_PATH = Path(__file__).parent/ "data" - input_file = DATA_PATH / "ssmis" / "bin" / "gpm_320_36_07.bin" - - input_data = BinFile(input_file).to_xarray_dataset() - - assert input_data.channels.size == 11 - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] >= 0) - assert np.all(input_data["two_meter_temperature"] > 320 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 320 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 36 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 36 + 0.5) - assert np.all(input_data["surface_type"] == 7) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - assert np.all(np.any(valid, axis=0)) - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # Seaice bin files. - # - - input_file = DATA_PATH / "ssmis" / "bin" / "gpm_297_26_02.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 297 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 297 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 26 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 26 + 0.5) - assert np.all(input_data["surface_type"] == 2) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - assert np.all(np.any(valid, axis=0)) - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # MTN bin files. - # - - input_file = DATA_PATH / "ssmis" / "bin" / "gpm_320_20_00_17.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["two_meter_temperature"] > 320 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 320 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 20 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 20 + 0.5) - assert np.all(input_data["surface_type"] == 17) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # MRMS bin files. 
- # - - input_file = DATA_PATH / "ssmis" / "bin" / "gpm_301_17_10.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["two_meter_temperature"] > 301 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 301 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 17 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 17 + 0.5) - assert np.all(input_data["surface_type"] == 10) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - -def test_bin_file_atms(): - """ - Test reading of ATMS bin files and ensure all values are physical and - match given bin. - """ - # - # Simulator-derived bin files. - # - - DATA_PATH = Path(__file__).parent/ "data" - input_file = DATA_PATH / "atms" / "bin" / "gpm_289_52_03.bin" - - input_data = BinFile(input_file).to_xarray_dataset() - - assert input_data.channels.size == 5 - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] >= 0) - assert np.all(input_data["two_meter_temperature"] > 289 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 289 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 52 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 52 + 0.5) - assert np.all(input_data["surface_type"] == 3) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - assert np.all(np.any(valid, axis=0)) - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # Seaice bin files. - # - - input_file = DATA_PATH / "atms" / "bin" / "gpm_264_09_16.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["rain_water_path"] < 0) - assert np.all(input_data["two_meter_temperature"] > 264 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 264 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 9 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 9 + 0.5) - assert np.all(input_data["surface_type"] == 16) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - assert np.all(np.any(valid, axis=0)) - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # MTN bin files. 
- # - - input_file = DATA_PATH / "atms" / "bin" / "gpm_305_31_01_17.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["two_meter_temperature"] > 305 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 305 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 31 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 31 + 0.5) - assert np.all(input_data["surface_type"] == 17) - assert np.all(input_data["airmass_type"] == 1) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - # - # MRMS bin files. - # - - input_file = DATA_PATH / "atms" / "bin" / "gpm_291_13_11.bin" - input_data = BinFile(input_file).to_xarray_dataset() - - assert np.all(input_data["surface_precip"] >= 0) - assert np.all(input_data["surface_precip"] <= 500) - assert np.all(input_data["convective_precip"] >= 0) - assert np.all(input_data["convective_precip"] <= 500) - assert np.all(input_data["two_meter_temperature"] > 291 - 0.5) - assert np.all(input_data["two_meter_temperature"] < 291 + 0.5) - assert np.all(input_data["total_column_water_vapor"] > 13 - 0.5) - assert np.all(input_data["total_column_water_vapor"] < 13 + 0.5) - assert np.all(input_data["surface_type"] == 11) - assert np.all(input_data["airmass_type"] == 0) - tbs = input_data.brightness_temperatures.data - valid = tbs > 0 - tbs = tbs[valid] - assert np.all(tbs > 20) - assert np.all(tbs < 400) - - -def test_file_processor_gmi(tmp_path): - """ - This tests the extraction of data from a bin file and ensures that - the extracted dataset matches the original data. - """ - path = Path(__file__).parent - processor = FileProcessor( - path / "data" / "gmi" / "bin" / "processor", include_profiles=True - ) - output_file = tmp_path / "test_file.nc" - processor.run_async(output_file, 0.0, 1.0, 1) - - input_file = BinFile(path / "data" / "gmi" / "bin" / "gpm_275_14_03_17.bin") - input_data = input_file.to_xarray_dataset() - - dataset = GPROF_NN_1D_Dataset( - output_file, normalize=False, shuffle=False, augment=False - ) - normalizer = dataset.normalizer - - tbs_input = input_data["brightness_temperatures"].data - tbs = dataset.x[:, : input_file.sensor.n_chans] - tbs = np.nan_to_num(tbs, nan=-9999.9) - valid = tbs[:, -1] > 0 - - assert np.all(np.isclose(tbs_input[valid].mean(), tbs[valid].mean())) - assert np.all( - np.isclose(input_data["two_meter_temperature"].mean(), dataset.x[:, 15].mean()) - ) - assert np.all( - np.isclose( - input_data["total_column_water_vapor"].mean(), dataset.x[:, 16].mean() - ) - ) - - surface_types = np.where(dataset.x[:, 17:35])[1] - assert np.all(np.isclose(input_data.surface_type, surface_types + 1)) - airmass_types = np.where(dataset.x[:, 35:])[1] - assert np.all(np.isclose(np.maximum(input_data.airmass_type, 1), airmass_types)) diff --git a/test/test_combined.py b/test/test_combined.py deleted file mode 100644 index 8b9fb98..0000000 --- a/test/test_combined.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Tests for the 'gprof_nn.data.combined' module. 
-""" -from pathlib import Path - -import numpy as np -import pytest - -from gprof_nn.data import get_test_data_path -from gprof_nn.data.combined import ( - GPMCMBFile, - calculate_smoothing_kernels, - load_combined_data_bin -) - - -DATA_PATH = get_test_data_path() - - -def test_read_gpm_cmb_file(): - """ - Test reading of GPM combined file. - """ - path = DATA_PATH / "cmb" - filename = ( - path / - "2B.GPM.DPRGMI.CORRA2022.20210829-S205206-E222439.042628.V07A.HDF5" - ) - data = GPMCMBFile(filename).to_xarray_dataset() - assert "latitude" in data.variables - assert "longitude" in data.variables - assert "surface_precip" in data.variables - - -def test_read_gpm_cmb_file_smoothed(): - """ - Test reading of GPM combined file with smoothing of surface precip data. - """ - path = DATA_PATH / "cmb" - filename = ( - path / - "2B.GPM.DPRGMI.CORRA2022.20210829-S205206-E222439.042628.V07A.HDF5" - ) - data = GPMCMBFile(filename).to_xarray_dataset(smooth=True) - assert "latitude" in data.variables - assert "longitude" in data.variables - assert "surface_precip" in data.variables - - -@pytest.mark.slow -def test_read_gpm_cmb_file_profiles_smoothed(): - """ - Test reading of GPM combined file with profiles and smoothing. - """ - path = DATA_PATH / "cmb" - filename = ( - path / - "2B.GPM.DPRGMI.CORRA2022.20210829-S205206-E222439.042628.V07A.HDF5" - ) - data = GPMCMBFile(filename).to_xarray_dataset(profiles=True, smooth=True) - assert "latitude" in data.variables - assert "longitude" in data.variables - assert "surface_precip" in data.variables - - -def test_smoothing_kernels(): - """ - Test that calculation of smoothing kernels works correctly. - """ - k = calculate_smoothing_kernels(2.0 * 4.9e3, 2.0 * 5.09e3) - - # Assert kernel has expected shape and is normalized. - assert k.shape == (9, 9) - assert np.isclose(k.sum(), 1.0) - - # Assert that full-width at half maximum is at the correct location. - k_max = k[4, 4] - assert np.isclose(k[4, 5] / k_max, 0.5) - assert np.isclose(k[4, 3] / k_max, 0.5) - assert np.isclose(k[5, 4] / k_max, 0.5) - assert np.isclose(k[3, 4] / k_max, 0.5) - - -def test_read_combined_bin(): - """ - Ensure that values read from combined data in .bin format - are consistent. - """ - data_path = Path(__file__).parent / "data" - filename = data_path / "CMB.20181001.026079.ITE768.bin.gz" - data = load_combined_data_bin(filename) - - lats = data.latitude.data - assert np.all((lats >= -90) * (lats <= 90)) - - lons = data.longitude.data - assert np.all((lons >= -180) * (lons <= 180)) - - sp = data.surface_precip.data - sp[sp >= 0.0] - assert np.all((sp >= 0.0) * (sp <= 500)) - - t = data.temperature.data - t = t[t >= 0] - assert np.all((t > 150) * (t <= 400)) - - tbs = data.simulated_brightness_temperatures.data - tbs = tbs[tbs >= 0.0] - assert np.all((tbs >= 20.0) * (tbs < 400)) - - scan_time = data.scan_time.data - assert np.all(np.isclose(scan_time[0], 2018)) - assert np.all(np.isclose(scan_time[1], 10)) diff --git a/test/test_config.py b/test/test_config.py deleted file mode 100644 index 03d9347..0000000 --- a/test/test_config.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -from pathlib import Path - -from gprof_nn.config import ( - parse_config_file, - get_config_file, - set_config -) - -CONFIG = """ -[preprocessor] -GMI=/gmi/preprocessor -MHS=/mhs/preprocessor -""" - -def test_config(tmp_path): - """ - Test parsing of config file. 
- """ - config_file = tmp_path / "config.ini" - with open(config_file, "w") as hndl: - hndl.write(CONFIG) - - config = parse_config_file(config_file) - assert config.preprocessor.GMI == Path("/gmi/preprocessor") - assert config.preprocessor.MHS == Path("/mhs/preprocessor") - - -def test_get_config_file(tmp_path): - """ - Test path of config file. - """ - config_file = tmp_path / "config.ini" - os.environ["GPROF_NN_CONFIG"] = str(tmp_path / "config.ini") - config_file_2 = get_config_file() - assert config_file == config_file_2 - - -def test_set_config(tmp_path): - """ - Test setting of config file. - """ - config_file = tmp_path / "config.ini" - os.environ["GPROF_NN_CONFIG"] = str(tmp_path / "config.ini") - set_config("data", "era5_path", tmp_path) - - config = parse_config_file(config_file) - assert config.data.era5_path == tmp_path diff --git a/test/test_coordinates.py b/test/test_coordinates.py deleted file mode 100644 index 23d1141..0000000 --- a/test/test_coordinates.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Test for the coordinates submodule. -""" -import numpy as np -from gprof_nn.coordinates import latlon_to_ecef - -def test_latlon_to_ecef(): - - lons = [0, 90, 180, 270, 360] - lats = [0, 0, 0, 0, 0] - - x, y, z = latlon_to_ecef(lons, lats) - print(x, y, z) - - assert np.all(np.isclose(x[[1, 3]], 0.0)) - assert np.all(np.isclose(y[[0, 2, 4]], 0.0)) - assert np.all(np.isclose(z, 0.0)) diff --git a/test/test_data.py b/test/test_data.py deleted file mode 100644 index 45bb844..0000000 --- a/test/test_data.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -This file tests the 'gprof_nn.data' module. -""" -from pathlib import Path - -import gprof_nn.data -from gprof_nn import sensors - -def test_get_model_path(tmpdir, monkeypatch): - """ - Test that retrieving a model file from the server downloads - the requested file if not present. - """ - path = tmpdir - monkeypatch.setattr(gprof_nn.data, "_DATA_DIR", tmpdir) - tmpdir.mkdir("models") - - path = gprof_nn.data.get_model_path("1D", sensors.GMI, "ERA5") - - assert path.exists() - assert (tmpdir / "models" / "gprof_nn_1d_gmi_era5.pckl").exists() - -def test_get_profile_clusters(tmpdir, monkeypatch): - """ - Test retrieving profile files from the data server downloads the - file if not present. - """ - tmpdir = Path(str(tmpdir)) - monkeypatch.setattr(gprof_nn.data, "_DATA_DIR", tmpdir) - (tmpdir / "profiles").mkdir() - - path = gprof_nn.data.get_profile_clusters() - - assert path.exists() - assert (tmpdir / "profiles" / "GPM_profile_clustersV7.dat").exists() diff --git a/test/test_kwaj.py b/test/test_kwaj.py deleted file mode 100644 index d90e5b1..0000000 --- a/test/test_kwaj.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Tests for gprof_nn.data.kwaj module. -""" -from pathlib import Path - -import numpy as np -import xarray as xr - -from gprof_nn.data.kwaj import RadarFile, get_overpasses - - -DATA_PATH = Path(__file__).parent / "data" - - -def test_radar_file_open_file(): - """ - Test opening of a file from daily archive as xarray.Datset. 
- """ - filename = DATA_PATH / "kwaj_data.tar.gz" - radar_file = RadarFile(filename) - - filename = "180620/KWAJ_2018_0620_000650.cf.gz" - data = radar_file.open_file(filename) - - assert "RR" in data.variables - assert "RC" in data.variables - - assert data.RR.min() >= 0.0 - assert data.RR.max() >= 0.0 - - assert data.latitude.max() >= 9.0 - assert data.latitude.min() <= 8.0 - assert data.longitude.max() >= 168 - assert data.longitude.min() <= 167 - - -def test_radar_file_open_files(): - """ - Test opening of a file from daily archive as xarray.Datset. - """ - filename = DATA_PATH / "kwaj_data.tar.gz" - radar_file = RadarFile(filename) - - start = np.datetime64("2018-06-20T00:15:00") - end = np.datetime64("2018-06-20T00:20:00") - data = radar_file.open_files(start, end) - - assert "RR" in data.variables - assert "RC" in data.variables - - assert data.RR.min() >= 0.0 - assert data.RR.max() >= 0.0 - - assert data.latitude.max() >= 9.0 - assert data.latitude.min() <= 8.0 - assert data.longitude.max() >= 168 - assert data.longitude.min() <= 167 - - -def test_get_overpasses_gmi(): - overpasses = get_overpasses("gmi") - assert len(overpasses) > 0 diff --git a/test/test_l1c.py b/test/test_l1c.py deleted file mode 100644 index 8744cbf..0000000 --- a/test/test_l1c.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Test reading and manipulation of L1C files. -""" -from pathlib import Path - -import numpy as np - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.l1c import L1CFile - - -DATA_PATH = get_test_data_path() - - -def test_open_granule_gmi(): - """ - Test finding of specific GMI L1C file and reading data into - xarray.Dataset. - """ - l1c_path = DATA_PATH / "gmi" / "l1c" - - l1c_file = L1CFile.open_granule(27510, l1c_path, sensors.GMI) - l1c_data = l1c_file.to_xarray_dataset() - - assert l1c_data.pixels.size == 221 - assert l1c_data.scans.size == 512 - assert l1c_file.sensor == sensors.GMI - - tbs = l1c_data.brightness_temperatures.data - valid = np.all(tbs > 0, axis=-1) - tbs = tbs[valid] - assert np.all((tbs > 20) * (tbs <= 350)) - -def test_open_granule_mhs(): - """ - Test finding of specific MHS L1C file and reading data into - xarray.Dataset. - """ - l1c_path = DATA_PATH / "mhs" / "l1c" - - l1c_file = L1CFile.open_granule(51010, l1c_path, sensors.MHS) - l1c_data = l1c_file.to_xarray_dataset() - - assert l1c_data.pixels.size == 90 - assert l1c_data.scans.size == 2295 - assert l1c_file.sensor == sensors.MHS - - tbs = l1c_data.brightness_temperatures.data - valid = np.all(tbs > 0, axis=-1) - tbs = tbs[valid] - assert np.all((tbs > 20) * (tbs <= 350)) - - -def test_open_granule_tmi(): - """ - Test finding of specific TMI L1C file and reading data into - xarray.Dataset. - """ - - DATA_PATH = Path(__file__).parent / "data" - l1c_path = DATA_PATH / "tmi" / "l1c" - - l1c_file = L1CFile.open_granule(69095, l1c_path, sensors.TMI) - l1c_data = l1c_file.to_xarray_dataset() - - assert l1c_data.pixels.size == 208 - assert l1c_file.sensor == sensors.TMIPO - - tbs = l1c_data.brightness_temperatures.data - valid = np.all(tbs > 0, axis=-1) - tbs = tbs[valid] - assert np.all((tbs > 20) * (tbs <= 350)) - - -def test_open_granule_ssmis(): - """ - Test finding of specific SSMIS L1C file and reading data into - xarray.Dataset. 
- """ - DATA_PATH = Path(__file__).parent / "data" - l1c_path = DATA_PATH / "ssmis" / "l1c" - - l1c_file = L1CFile.open_granule(61436, l1c_path, sensors.SSMIS) - l1c_data = l1c_file.to_xarray_dataset() - - assert l1c_data.pixels.size == 180 - assert l1c_file.sensor == sensors.SSMIS - - tbs = l1c_data.brightness_temperatures.data - valid = np.all(tbs > 0, axis=-1) - tbs = tbs[valid] - assert np.all((tbs > 20) * (tbs <= 350)) - - -def test_open_granule_atms(): - """ - Test finding of specific ATMS L1C file and reading data into - xarray.Dataset. - """ - DATA_PATH = Path(__file__).parent / "data" - l1c_path = DATA_PATH / "atms" / "l1c" - - l1c_file = L1CFile.open_granule(35889, l1c_path, sensors.ATMS) - l1c_data = l1c_file.to_xarray_dataset() - - assert l1c_data.pixels.size == 96 - assert l1c_file.sensor == sensors.ATMS - - tbs = l1c_data.brightness_temperatures.data - valid = np.all(tbs > 0, axis=-1) - tbs = tbs[valid] - assert np.all((tbs > 20) * (tbs <= 350)) - - - -def test_find_file_gmi(): - """ - Tests finding a GMI L1C file for a given date. - """ - l1c_path = DATA_PATH / "gmi" / "l1c" - date = np.datetime64("2019-01-01T00:30:00") - l1c_file = L1CFile.find_file(date, l1c_path) - l1c_data = l1c_file.to_xarray_dataset() - assert date > l1c_data.scan_time[0] - assert date < l1c_data.scan_time[-1] - assert l1c_file.sensor == sensors.GMI - - -def test_find_file_mhs(): - """ - Tests finding an MHS file for a given date. - """ - l1c_path = DATA_PATH / "mhs" / "l1c" - date = np.datetime64("2019-01-01T01:33:00") - l1c_file = L1CFile.find_file(date, l1c_path, sensor=sensors.MHS) - data = l1c_file.to_xarray_dataset() - - assert date > data.scan_time[0] - assert date < data.scan_time[-1] - assert l1c_file.sensor == sensors.MHS - assert "incidence_angle" in data.variables - - -def test_find_file_tmi(): - """ - Tests finding an TMI file for a given date. - """ - DATA_PATH = Path(__file__).parent / "data" - l1c_path = DATA_PATH / "tmi" / "l1c" - date = np.datetime64("2010-01-01T01:00:00") - l1c_file = L1CFile.find_file(date, l1c_path, sensor=sensors.TMI) - data = l1c_file.to_xarray_dataset() - - assert date > data.scan_time[0] - assert date < data.scan_time[-1] - assert l1c_file.sensor == sensors.TMIPO - - -def test_find_file_ssmis(): - """ - Tests finding a GMI L1C file for a given date. - """ - DATA_PATH = Path(__file__).parent / "data" - l1c_path = DATA_PATH / "ssmis" / "l1c" - date = np.datetime64("2018-10-01T00:30:00") - l1c_file = L1CFile.find_file(date, l1c_path, sensor=sensors.SSMIS) - l1c_data = l1c_file.to_xarray_dataset() - assert date > l1c_data.scan_time[0] - assert date < l1c_data.scan_time[-1] - assert l1c_file.sensor == sensors.SSMIS - - - -def test_find_files(): - """ - Ensure that finding a file covering a given ROI works as expected - as well the loading of observations covering only a certain - ROI. - """ - l1c_path = DATA_PATH / "gmi" / "l1c" - date = np.datetime64("2019-01-01T00:30:00") - - roi = (-37, -65, -35, -63) - files = list(L1CFile.find_files(date, l1c_path, roi=roi)) - assert len(files) == 1 - - data = files[0].to_xarray_dataset(roi=roi) - n_scans = data.scans.size - lats = data["latitude"].data - lons = data["longitude"].data - assert n_scans < 200 - - # Ensure each scan has at least one obs at a longitude larger than the - # minimum requested. 
- assert np.all(np.sum(lons >= roi[0], -1) > 1) - assert np.all(np.sum(lons < roi[2], -1) > 1) - assert np.all(np.sum(lats >= roi[1], -1) > 1) - assert np.all(np.sum(lats < roi[3], -1) > 1) - - roi = (-35, 60, -10, 62) - files = list(L1CFile.find_files(date, l1c_path, roi=roi)) - assert len(files) == 0 - -def test_extract_scans(tmpdir): - """ - Test finding of specific GMI L1C file and reading data into - xarray.Dataset. - """ - l1c_path = DATA_PATH / "gmi" / "l1c" - - l1c_file = L1CFile.open_granule(27510, l1c_path, sensors.GMI) - l1c_data = l1c_file.to_xarray_dataset() - - lat = l1c_data.latitude.data[250, :].mean() - lon = l1c_data.longitude.data[250, :].mean() - lon_0 = lon - 0.5 - lon_1 = lon + 0.5 - lat_0 = lat - 0.5 - lat_1 = lat + 0.5 - roi = [lon_0, lat_0, lon_1, lat_1] - - roi_path = Path(tmpdir) / "roi.HDF5" - l1c_file.extract_scans(roi, roi_path, min_scans=256) - roi_file = L1CFile(roi_path) - roi_data = roi_file.to_xarray_dataset() - - lats = roi_data.latitude.data - lons = roi_data.longitude.data - print(lats.mean()) - print(lons.mean()) - print(roi) - - assert roi_data.scans.size >= 256 - inside = ((lons >= lon_0) * - (lons < lon_1) * - (lats >= lat_0) * - (lats < lat_1)) - assert np.any(inside) - - -def test_extract_scan_range(tmpdir): - """ - Test finding of specific GMI L1C file and reading data into - xarray.Dataset. - """ - l1c_path = DATA_PATH / "gmi" / "l1c" - l1c_file = L1CFile.open_granule(27510, l1c_path, sensors.GMI) - l1c_data = l1c_file.to_xarray_dataset() - - roi_path = tmpdir / "l1c_file.HDF5" - l1c_file.extract_scan_range(0, 256, roi_path) - roi_file = L1CFile(roi_path) - roi_data = roi_file.to_xarray_dataset() - - assert roi_data.scans.size == 256 diff --git a/test/test_legacy.py b/test/test_legacy.py deleted file mode 100644 index be434e3..0000000 --- a/test/test_legacy.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Tests for the gprof_nn.legacy module. -""" -from pathlib import Path - -import numpy as np -import pandas as pd -import pytest - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.training_data import (GPROF_NN_1D_Dataset, - write_preprocessor_file) -from gprof_nn.legacy import (has_gprof, - write_sensitivity_file, - DEFAULT_SENSITIVITIES, - run_gprof_training_data, - run_gprof_standard) -from gprof_nn.data.preprocessor import PreprocessorFile - - -DATA_PATH = get_test_data_path() - - -HAS_GPROF = has_gprof() - - -def test_write_sensitivity_file(tmp_path): - """ - Write sensitivity file containing default sensitivities and ensure - that sensitivities in files match the original ones. - """ - nedts_ref = DEFAULT_SENSITIVITIES - sensitivity_file = tmp_path / "sensitivities.txt" - write_sensitivity_file(sensors.GMI, sensitivity_file, nedts=nedts_ref) - nedts = np.loadtxt(sensitivity_file) - assert np.all(np.isclose(nedts_ref, nedts)) - - -@pytest.mark.skipif(not HAS_GPROF, reason="GPROF executable missing.") -def test_run_gprof_training_data(): - """ - Test running the legacy GPROF algorithm on training data. 
- """ - path = Path(__file__).parent - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc" - - results = run_gprof_training_data(sensors.GMI, - "ERA5", - input_file, - "STANDARD", - False) - assert "surface_precip" in results.variables - assert "surface_precip_true" in results.variables - - -@pytest.mark.skipif(not HAS_GPROF, reason="GPROF executable missing.") -def test_run_gprof_training_data_preserve_structure(): - """ - Test running the legacy GPROF algorithm on training data while - preserving the spatial structure. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc" - - results = run_gprof_training_data(sensors.GMI, - "ERA5", - input_file, - "STANDARD", - False, - preserve_structure=True) - assert "surface_precip" in results.variables - assert "surface_precip_true" in results.variables - -@pytest.mark.slow -@pytest.mark.skipif(not HAS_GPROF, reason="GPROF executable missing.") -def test_run_gprof_standard(): - """ - Test running legacy GPROF on a preprocessor input file. - """ - input_file = DATA_PATH / "gmi" / "pp" / "GMIERA5_190101_027510.pp" - results = run_gprof_standard(sensors.GMI, - "ERA5", - input_file, - "STANDARD", - False) - assert "surface_precip" in results.variables diff --git a/test/test_models.py b/test/test_models.py deleted file mode 100644 index 0cd78a4..0000000 --- a/test/test_models.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -This file tests the neural network models defined in the -'gprof_nn.models' module. -""" -from pathlib import Path - -import numpy as np -import torch -from quantnn.transformations import LogLinear - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.definitions import ALL_TARGETS -from gprof_nn.data.training_data import (SimulatorDataset, - GPROF_NN_3D_Dataset) -from gprof_nn.models import ( - MLP, - ResidualMLP, - HyperResidualMLP, - MultiHeadMLP, - GPROF_NN_1D_QRNN, - GPROF_NN_1D_DRNN, - GPROF_NN_3D_QRNN, - SimulatorNet -) - - -DATA_PATH = Path(__file__).parent / "data" - - -def test_mlp(): - """ - Tests for MLP module. - """ - # Make sure 0-layer configuration does nothing. - network = MLP(39, 128, 128, 0) - x = torch.ones(1, 39) - y, acc = network(x, None) - assert np.all(np.isclose(y.detach().numpy(), x.detach().numpy())) - - -def test_residual_mlp(): - """ - Tests for MLP module with residual connections. - """ - # Make sure 0-layer configuration does nothing. - network = ResidualMLP(39, 128, 128, 0) - x = torch.ones(1, 39) - y, acc = network(x, None) - assert np.all(np.isclose(y.detach().numpy(), x.detach().numpy())) - - # Make sure residuals are forwarded through network in internal - # configuration. - network = ResidualMLP(39, 39, 39, 3, internal=True) - x = torch.ones(1, 39) - for p in network.parameters(): - p.data[:] = 0.0 - y, acc = network(x, None) - assert np.all(np.isclose(y.detach().numpy(), x.detach().numpy())) - - -def test_hyper_residual_mlp(): - """ - Tests for MLP module with hyper-residual connections. - """ - network = HyperResidualMLP(39, 128, 128, 0) - x = torch.ones(1, 39) - y, acc = network(x, None) - assert np.all(np.isclose(y.detach().numpy(), x.detach().numpy())) - - # Make sure residuals and hyperresiduals are forwarded through network - # in internal configuration. 
- network = HyperResidualMLP(39, 39, 39, 3, internal=True) - x = torch.ones(1, 39) - for p in network.parameters(): - p.data[:] = 0.0 - y, acc = network(x, None) - assert np.all(np.isclose(y.detach().numpy(), 3.0 * x.detach().numpy())) - assert np.all(np.isclose(acc.detach().numpy(), 4.0 * x.detach().numpy())) - - -def test_gprof_nn_1d(): - """ - Tests for GPROFNN1D classes module with hyper-residual connections. - """ - network = GPROF_NN_1D_QRNN(sensors.GMI, - 3, 128, 2, 64, - activation="GELU", - transformation=LogLinear) - x = torch.ones(1, 24) - y = network.predict(x) - assert all([t in y for t in ALL_TARGETS]) - network = GPROF_NN_1D_QRNN(sensors.GMI, - 3, 128, 2, 64, - activation="GELU", - residuals="hyper", - transformation=LogLinear) - x = torch.ones(1, 24) - y = network.predict(x) - assert all([t in y for t in ALL_TARGETS]) - - network = GPROF_NN_1D_DRNN(sensors.GMI, - 3, 128, 2, 64, - residuals="hyper", - activation="GELU") - x = torch.ones(1, 24) - y = network.predict(x) - assert all([t in y for t in ALL_TARGETS]) - - # - # Test dropping of inputs. - # - - network = GPROF_NN_1D_DRNN(sensors.GMI, - 3, 128, 2, 64, - residuals="hyper", - activation="GELU", - drop_inputs=[0, 14] - ) - x = torch.ones(1, 24) - x[:, 0] = np.nan - x[:, 14] = np.nan - y = network.predict(x) - assert all([np.all(np.isfinite(value.numpy())) for value in y.values()]) - - -def test_gprof_nn_3d_gmi(): - """ - Ensure that GPROF_NN_3D model is consistent with training data - for GMI. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc" - dataset = GPROF_NN_3D_Dataset(input_file) - network = GPROF_NN_3D_QRNN(sensors.GMI, 2, 128, 2, 64) - x, y = dataset[0] - y_pred = network.predict(x) - assert all([t in y_pred for t in y]) - - network = GPROF_NN_3D_QRNN(sensors.GMI, 2, 128, 2, 64, drop_inputs=[0, 14]) - x, y = dataset[0] - x[:, 0] = np.nan - x[:, 14] = np.nan - y_pred = network.predict(x) - assert all([np.all(np.isfinite(value.numpy())) for value in y_pred.values()]) - - -def test_gprof_nn_3d_mhs(): - """ - Ensure that GPROF_NN_3D model is consistent with training data - for MHS. - """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - dataset = GPROF_NN_3D_Dataset(input_file, sensor=sensors.MHS) - network = GPROF_NN_3D_QRNN(sensors.MHS, 2, 128, 2, 64) - x, y = dataset[0] - y_pred = network.predict(x) - assert all([t in y_pred for t in y]) - - -def test_simulator(): - """ - Test simulator network. - """ - path = Path(__file__).parent - file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc" - data = SimulatorDataset(file, batch_size=1) - - simulator = SimulatorNet(sensors.MHS, 64, 2, 32) - x, y = data[0] - print(x.shape) - y_pred = simulator(x) - for k in y_pred: - assert k in y - assert y_pred[k].shape[1] == y[k].shape[1] - assert y_pred[k].shape[2] == y[k].shape[2] diff --git a/test/test_mrms.py b/test/test_mrms.py deleted file mode 100644 index 909a4a8..0000000 --- a/test/test_mrms.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Test reading of MRMS-GMI match ups used for the surface precip -prediction over snow surfaces. -""" -from pathlib import Path - -import numpy as np -import pytest - -from gprof_nn import config -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.mrms import MRMSMatchFile, has_snowdas_ratios -from gprof_nn.data.l1c import L1CFile -from gprof_nn.utils import CONUS - - -MRMS_PATH = Path(config.CONFIG.data.mrms_path) -NEEDS_MRMS_DATA = pytest.mark.skipif( - not MRMS_PATH.exists(), - reason="MRMS collocations are not available." 
-) - - -TEST_FILE_GMI = "1801_MRMS2GMI_gprof_db_08all.bin.gz" -TEST_FILE_MHS = "1801_MRMS2MHS_DB1_01.bin.gz" -TEST_FILE_SSMIS = "1810_MRMS2SSMIS_01.bin.gz" -TEST_FILE_AMSR2 = "1810_MRMS2AMSR2_01.bin.gz" - -HAS_SNOWDAS_RATIOS = has_snowdas_ratios() - -############################################################################### -# GMI -############################################################################### - -@NEEDS_MRMS_DATA -def test_read_file_gmi(): - """ - Read GMI match file and ensure that all latitudes roughly match - CONUS coordinates. - """ - path = MRMS_PATH / "GMI2MRMS_match2019" / "db_mrms4GMI" / TEST_FILE_GMI - ms = MRMSMatchFile(path) - assert np.all(ms.data["latitude"] > 20.0) - assert np.all(ms.data["latitude"] < 70.0) - assert np.all(ms.data["longitude"] > -130.0) - assert np.all(ms.data["longitude"] < -50.0) - data = ms.to_xarray_dataset(day=23) - -@NEEDS_MRMS_DATA -def test_match_precip_gmi(): - """ - Match surface precip from MRMS file to observations in L1C file. - """ - path = MRMS_PATH / "GMI2MRMS_match2019" / "db_mrms4GMI" / TEST_FILE_GMI - date = np.datetime64("2018-01-24T00:00:00") - roi = CONUS - - mrms_file = MRMSMatchFile(path) - l1c_files = L1CFile.find_files(date, path, roi=roi) - for f in l1c_files: - data = mrms_file.match_targets(f.to_xarray_dataset(roi=CONUS)) - data.to_netcdf("test.nc") - -@NEEDS_MRMS_DATA -def test_find_files_gmi(): - """ - Ensure that exactly one GMI MRMS file is found in test data. - """ - path = MRMS_PATH / "GMI2MRMS_match2019" / "db_mrms4GMI" - files = MRMSMatchFile.find_files(path, sensor=sensors.GMI) - assert len(files) > 0 - -############################################################################### -# MHS -############################################################################### - -@NEEDS_MRMS_DATA -def test_read_file_mhs(): - """ - Read MHS match file and ensure that all latitudes roughly match - CONUS coordinates. - """ - path = MRMS_PATH / "MHS2MRMS_match2019" / "monthly_2021" / TEST_FILE_MHS - ms = MRMSMatchFile(path) - - assert np.all(ms.data["latitude"] > 20.0) - assert np.all(ms.data["latitude"] < 70.0) - assert np.all(ms.data["longitude"] > -130.0) - assert np.all(ms.data["longitude"] < -50.0) - - -@NEEDS_MRMS_DATA -def test_match_precip_mhs(): - """ - Match surface precip from MRMS file to observations in L1C file. - """ - path = MRMS_PATH / "MHS2MRMS_match2019" / "monthly_2021" / TEST_FILE_MHS - date = np.datetime64("2018-01-01T01:00:00") - roi = CONUS - - mrms_file = MRMSMatchFile(path, sensor=sensors.MHS) - l1c_files = L1CFile.find_files(date, path, roi=roi, sensor=sensors.MHS) - for f in l1c_files: - data = mrms_file.match_targets(f.to_xarray_dataset(roi=CONUS)) - data.to_netcdf("test.nc") - - -@NEEDS_MRMS_DATA -def test_find_files_mhs(): - """ - Ensure that exactly one GMI MRMS file is found in test data. - """ - path = MRMS_PATH / "MHS2MRMS_match2019" / "monthly_2021" - files = MRMSMatchFile.find_files(path, sensor=sensors.MHS) - assert len(files) == 58 - -############################################################################### -# SSMIS -############################################################################### - -@NEEDS_MRMS_DATA -def test_read_file_ssmis(): - """ - Read SSMIS match file and ensure that all latitudes roughly match - CONUS coordinates. 
- """ - path = MRMS_PATH / "SSMIS2MRMS_match2019" / "monthly_2021" / TEST_FILE_SSMIS - ms = MRMSMatchFile(path) - ms = MRMSMatchFile(path) - - assert np.all(ms.data["latitude"] > 20.0) - assert np.all(ms.data["latitude"] < 70.0) - assert np.all(ms.data["longitude"] > -130.0) - assert np.all(ms.data["longitude"] < -50.0) - - data = ms.to_xarray_dataset(day=23) - tbs = data.brightness_temperatures.data - valid = tbs >= 0 - tbs = tbs[valid] - assert np.all((tbs >= 0) * (tbs <= 400)) - - -############################################################################### -# AMSR2 -############################################################################### - -@NEEDS_MRMS_DATA -def test_read_file_amsr2(): - """ - Read AMSR2 match file and ensure that all latitudes roughly match - CONUS coordinates. - """ - path = MRMS_PATH / "AMSR22MRMS_match2019" / "monthly_2021" / TEST_FILE_AMSR2 - ms = MRMSMatchFile(path) - - assert np.all(ms.data["latitude"] > 20.0) - assert np.all(ms.data["latitude"] < 70.0) - assert np.all(ms.data["longitude"] > -130.0) - assert np.all(ms.data["longitude"] < -50.0) - - data = ms.to_xarray_dataset(day=23) - tbs = data.brightness_temperatures.data - valid = tbs >= 0 - tbs = tbs[valid] - assert np.all((tbs >= 0) * (tbs <= 400)) - - surface_precip = data.surface_precip.data - assert np.any(np.isfinite(surface_precip)) diff --git a/test/test_normalizer.py b/test/test_normalizer.py deleted file mode 100644 index 11b1d01..0000000 --- a/test/test_normalizer.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -Tests for the gprof_nn.normalizer module. -""" -from gprof_nn import sensors -from gprof_nn.normalizer import get_normalizer - - -def test_get_normalizer(): - - normalizer_gmi = get_normalizer(sensors.GMI) - assert len(normalizer_gmi.stats) == 15 + 9 - - normalizer_gmi_2 = get_normalizer(sensors.GMI, [0, 3]) - assert len(normalizer_gmi_2.stats) == 15 + 7 - assert normalizer_gmi.stats[1] == normalizer_gmi_2.stats[0] - - # For cross-track scanners stats should an entry for each - # channel, earth incidence angle, tcwv and t2m. - normalizer_mhs = get_normalizer(sensors.MHS) - assert len(normalizer_mhs.stats) == sensors.MHS.n_chans + 3 - - for index, gmi_index in enumerate(sensors.MHS.gmi_channels): - normalizer_mhs.stats[index] == normalizer_gmi.stats[gmi_index] - - # For conical scanners stats should an entry for each - # channel, tcwv and t2m. - normalizer_mhs = get_normalizer(sensors.SSMI) - assert len(normalizer_mhs.stats) == sensors.SSMI.n_chans + 2 diff --git a/test/test_plotting.py b/test/test_plotting.py deleted file mode 100644 index b7e0a38..0000000 --- a/test/test_plotting.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -Tests for the plotting module. -""" -from gprof_nn.plotting import set_style - - -def test_set_style(): - """ - Tests that setting the matplotlib style works. - """ - set_style() diff --git a/test/test_preprocessor.py b/test/test_preprocessor.py deleted file mode 100644 index 7b7a091..0000000 --- a/test/test_preprocessor.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Tests for reading the preprocessor format. 
-""" -from datetime import datetime -from pathlib import Path - -import numpy as np -import pytest -from quantnn.normalizer import Normalizer - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.preprocessor import ( - PreprocessorFile, - has_preprocessor, - run_preprocessor, - calculate_frozen_precip, - ERA5, - GANAL, -) -from gprof_nn.data.training_data import GPROF_NN_1D_Dataset, write_preprocessor_file -from gprof_nn.data.l1c import L1CFile - - -GPM_DATA = Path("/pdata4/archive/GPM") - -NEEDS_GPM_DATA = pytest.mark.skipif( - not GPM_DATA.exists(), reason="Needs GPM L1C data." -) - -@NEEDS_GPM_DATA -def test_preprocessor_gmi(tmp_path): - """ - Test running the preprocessor for GMI and loading the - results. - """ - l1c_file = L1CFile( - GPM_DATA / - "1CR_GMI_V7/1801/180101/1C-R.GPM.GMI.XCAL2016-C" - ".20180101-S010928-E024202.021833.V07A.HDF5" - ) - l1c_file.extract_scan_range(1000, 1005, tmp_path / "gmi_l1c.HDF5") - pp_data = run_preprocessor( - tmp_path / "gmi_l1c.HDF5", - sensors.GMI - ) - - assert np.all( - pp_data.scan_time > np.datetime64("2018-01-01T00:00:00") - ) - tbs = pp_data.brightness_temperatures.data - valid = tbs >= 0.0 - tbs = tbs[valid] - - assert tbs.size > 0 - assert tbs.max() < 320 - - -@NEEDS_GPM_DATA -def test_preprocessor_amsr2(tmp_path): - """ - Test running the preprocessor for AMSR2 and loading the - results. - """ - l1c_file = L1CFile( - GPM_DATA / - "1C_AMSR2_V7/1501/150101/1C.GCOMW1.AMSR2.XCAL2016-" - "V.20150101-S000954-E014846.013958.V07A.HDF5" - ) - l1c_file.extract_scan_range(1000, 1005, tmp_path / "amsr2_l1c.HDF5") - pp_data = run_preprocessor( - tmp_path / "amsr2_l1c.HDF5", - sensors.AMSR2 - ) - - assert np.all( - pp_data.scan_time > np.datetime64("2015-01-01T00:00:00") - ) - tbs = pp_data.brightness_temperatures.data - valid = tbs >= 0.0 - tbs = tbs[valid] - - assert tbs.size > 0 - assert tbs.max() < 320 diff --git a/test/test_profiles.py b/test/test_profiles.py deleted file mode 100644 index 98fd26e..0000000 --- a/test/test_profiles.py +++ /dev/null @@ -1,67 +0,0 @@ -from pathlib import Path - -import numpy as np - -from gprof_nn.data import get_profile_clusters -from gprof_nn.data.profiles import ProfileClusters - - -def test_load_clusters(): - """ - Ensures that profiles are loaded correctly. - """ - path = get_profile_clusters() - profiles = ProfileClusters(path, True) - rwc = profiles.get_profiles("rain_water_content", 280.0) - assert np.all(np.isclose(rwc[:, -1], 0.0)) - - profiles = ProfileClusters(path, False) - rwc = profiles.get_profiles("cloud_water_content", 280) - assert rwc.shape == (40, 28) - - -def test_get_profile_data(): - """ - Ensure that profile data has expected shape - """ - path = get_profile_clusters() - profiles = ProfileClusters(path, True) - data = profiles.get_profile_data("rain_water_content") - assert data.shape == (12, 28, 40) - - profiles = ProfileClusters(path, False) - data = profiles.get_profile_data("rain_water_content") - assert data.shape == (12, 28, 40) - -def test_get_scales_and_indices(): - """ - Ensures that cluster centers are matched to their respective - indices. 
- """ - # Raining - path = get_profile_clusters() - profiles = ProfileClusters(path, True) - cwc = profiles.get_profile_data("cloud_water_content") - - scales, indices = profiles.get_scales_and_indices( - "cloud_water_content", - 269.0, - cwc[0].transpose() - ) - assert np.all(np.isclose(indices, - np.arange(40))) - assert np.all(np.isclose(scales, 1.0, rtol=1e-3)) - - - # Non-raining - profiles = ProfileClusters(path, False) - cwc = profiles.get_profile_data("cloud_water_content") - - scales, indices = profiles.get_scales_and_indices( - "cloud_water_content", - 269.0, - cwc[0].transpose() - ) - assert np.all(np.isclose(indices, - np.arange(40))) - assert np.all(np.isclose(scales, 1.0, atol=0.1)) diff --git a/test/test_resolution.py b/test/test_resolution.py deleted file mode 100644 index 2153496..0000000 --- a/test/test_resolution.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import scipy as sp -from scipy.fft import fft2, ifft2 -import xarray as xr - -from gprof_nn.resolution import FourierAnalysis - - -def make_test_window_fourier(): - """ - Makes a test structure to test fourier analysis. - """ - - x = np.arange(64).reshape(1, -1) - y = np.arange(64).reshape(-1, 1) - # Signal at n = 2 in x and n = 8 in y direction. - z = np.cos(2 * np.pi * x / 64 * 2) + np.cos(2 * np.pi * y / 64 * 16) - - results = xr.Dataset({ - "surface_precip": (("along_track", "across_track"), z) - }) - - # Only low frequ signel in GPROF - z_gprof = np.cos(2 * np.pi * x / 64 * 2) + 0.0 * y - results_gprof = xr.Dataset({ - "surface_precip": (("along_track", "across_track"), z_gprof) - }) - - window = { - "reference": results, - "gprof": results_gprof - } - return window - - -def test_fourier_analysis(): - """ - Test calculation of energy spectra using Fourier analysis. - """ - window = make_test_window_fourier() - fa = FourierAnalysis(["gprof"]) - fa.process(window) - r = fa.get_statistics()["gprof"] - - print(r.energy_ret.data) - # WN 2 present in both datasets - assert not np.all(np.isclose(r.energy_ret[2 * 2], 0.0)) - assert not np.all(np.isclose(r.energy_ref[2 * 2], 0.0)) - - # WN 16 present in only one of them - assert np.all(np.isclose(r.energy_ret[2 * 16], 0.0)) - assert not np.all(np.isclose(r.energy_ref[2 * 16], 0.0)) - - diff --git a/test/test_retrieval.py b/test/test_retrieval.py deleted file mode 100644 index a0ad675..0000000 --- a/test/test_retrieval.py +++ /dev/null @@ -1,414 +0,0 @@ -""" -Tests for code running, writing and reading retrieval data. -""" -from pathlib import Path - -import numpy as np -import pytest -import torch -import xarray as xr - -from quantnn import QRNN -from quantnn.normalizer import Normalizer - -from gprof_nn import sensors -from gprof_nn.data import get_model_path, get_test_data_path -from gprof_nn.data import get_profile_clusters -from gprof_nn.data.training_data import GPROF_NN_1D_Dataset -from gprof_nn.data.preprocessor import PreprocessorFile -from gprof_nn.data.retrieval import (RetrievalFile, - ORBIT_HEADER_TYPES, - PROFILE_INFO_TYPES) -from gprof_nn.retrieval import (calculate_padding_dimensions, - RetrievalDriver, - RetrievalGradientDriver, - PreprocessorLoader1D, - PreprocessorLoader3D, - NetcdfLoader1D, - NetcdfLoader3D, - SimulatorLoader) - - -DATA_PATH = get_test_data_path() - - -def test_calculate_padding_dimensions(): - """ - Ensure that padding values match expected values and that - the order is inverse to that of tensor axes. 
- """ - x = torch.ones(32, 32) - padding = calculate_padding_dimensions(x) - assert padding == (0, 0, 0, 0) - - x = torch.ones(16, 32) - padding = calculate_padding_dimensions(x) - assert padding == (0, 0, 8, 8) - - x = torch.ones(32, 16) - padding = calculate_padding_dimensions(x) - assert padding == (8, 8, 0, 0) - - -def test_retrieval_read_and_write(tmp_path): - """ - Ensure that reading data from a retrieval file and writing that - data into a retrieval conserves data. - - This checks both the writing of the GPROF binary retrieval file - format including all headers as well as the parsing of the format. - """ - retrieval_file = (DATA_PATH / "gmi" / - "retrieval" / "GMIERA5_190101_027510.bin") - retrieval_file = RetrievalFile(retrieval_file, has_profiles=True) - retrieval_data = retrieval_file.to_xarray_dataset(full_profiles=False) - preprocessor_file = PreprocessorFile( - DATA_PATH / "gmi" / "pp" / "GMIERA5_190101_027510.pp" - ) - ancillary_data = get_profile_clusters() - output_file = preprocessor_file.write_retrieval_results( - tmp_path, - retrieval_data, - ancillary_data=ancillary_data) - output_file = RetrievalFile(output_file) - - # - # Orbit header. - # - exceptions = [ - "preprocessor", - "algorithm", - "creation_date", - "granule_end_date", - ] - - for k in ORBIT_HEADER_TYPES.fields: - if k not in exceptions: - assert retrieval_file.orbit_header[k] == output_file.orbit_header[k] - - # - # Check profile info. - # - - #for k in PROFILE_INFO_TYPES.fields: - # if not k == "species_description": - # assert(np.all(np.isclose(retrieval_file.profile_info[k], - # output_file.profile_info[k]))) - output_data = output_file.to_xarray_dataset() - - # - # Check retrieval data. - # - - for v in retrieval_data.variables: - if v in ["two_meter_temperature", "frozen_precip"]: - continue - if v not in ["surface_precip", "convective_precip"]: - continue - assert np.all(np.isclose(retrieval_data[v].data, - output_data[v].data, - rtol=1e-2)) - - # Ensure that scan date matches time stamp of file. - assert np.all(np.isclose( - retrieval_data.scan_time.dt.year.data, - 2019 - )) - assert np.all(np.isclose( - retrieval_data.scan_time.dt.month.data, - 1 - )) - assert np.all(np.isclose( - retrieval_data.scan_time.dt.day.data, - 1 - )) - - -def test_retrieval_preprocessor_1d_gmi(tmp_path): - """ - Ensure that GPROF-NN 1D retrieval works with preprocessor input. - """ - input_file = DATA_PATH / "gmi" / "pp" / "GMIERA5_190101_027510.pp" - - model_path = get_model_path("1D", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path) - output_file = driver.run() - data = RetrievalFile(output_file).to_xarray_dataset() - assert "rain_water_content" in data.variables - -@pytest.mark.xfail -def test_retrieval_l1c_1d_gmi(tmp_path): - """ - Ensure that GPROF-NN 1D retrieval works with preprocessor input. - """ - input_file = ( - DATA_PATH / "gmi" / - "1C-R.GPM.GMI.XCAL2016-C.20180124-S000358-E013632.022190.V05A.HDF5" - ) - qrnn = QRNN.load(DATA_PATH / "gmi" / "gprof_nn_1d_gmi_era5_na.pckl") - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path) - output_file = driver.run() - data = RetrievalFile(output_file).to_xarray_dataset() - assert "rain_water_content" in data.variables - - -def test_retrieval_preprocessor_1d_mhs(tmp_path): - """ - Ensure that GPROF-NN 1D retrieval works with preprocessor input. 
- """ - input_file = DATA_PATH / "mhs" / "pp" / "MHS.pp" - - model_path = get_model_path("1D", sensors.MHS_NOAA19, "ERA5") - qrnn = QRNN.load(model_path) - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path) - output_file = driver.run() - data = RetrievalFile(output_file).to_xarray_dataset() - assert "rain_water_content" in data.variables - - -def test_retrieval_preprocessor_3d(tmp_path): - """ - Ensure that GPROF-NN 3D retrieval works with preprocessor input. - """ - input_file = DATA_PATH / "gmi" / "pp" / "GMIERA5_190101_027510.pp" - - model_path = get_model_path("3D", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - qrnn.model.sensor = sensors.GMI - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path) - output_file = driver.run() - data = RetrievalFile(output_file).to_xarray_dataset() - assert "rain_water_content" in data.variables - - -def test_retrieval_preprocessor_3d_tiled(tmp_path): - """ - Ensure that running the 3D tiled retrieval yields the expected - output dimensions. - """ - input_file = DATA_PATH / "gmi" / "pp" / "GMIERA5_190101_027510.pp" - - model_path = get_model_path("3D", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - qrnn.model.sensor = sensors.GMI - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path) - output_file = driver.run() - data = RetrievalFile(output_file).to_xarray_dataset() - - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - tiling=(128, 128)) - output_file = driver.run() - data_tiled = RetrievalFile(output_file).to_xarray_dataset() - - assert data.scans.size == data_tiled.scans.size - - -def test_retrieval_hr_tiled(tmp_path): - """ - Ensure that running the 3D tiled retrieval yields the expected - output dimensions. - """ - input_file = ( - DATA_PATH / "gmi" / "l1c" / - "1C-R.GPM.GMI.XCAL2016-C.20190101-S001447-E014719.027510.V07A.HDF5" - ) - - model_path = get_model_path("HR", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path) - output_file = driver.run() - data = xr.load_dataset(output_file) - - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - tiling=(128, 128)) - output_file = driver.run() - data_tiled = xr.load_dataset(output_file) - - assert data.scans.size == data_tiled.scans.size - - -@pytest.mark.xfail -def test_retrieval_l1c_3d(tmp_path): - """ - Ensure that GPROF-NN 1D retrieval works with preprocessor input. - """ - input_file = ( - DATA_PATH / "gmi" / "l1c" - "1C-R.GPM.GMI.XCAL2016-C.20180124-S000358-E013632.022190.V05A.HDF5" - ) - - qrnn = QRNN.load(DATA_PATH / "gmi" / "gprof_nn_3d_gmi_era5_na.pckl") - qrnn.model.sensor = sensors.GMI - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - compress=False) - output_file = driver.run() - data = RetrievalFile(output_file).to_xarray_dataset() - assert "rain_water_content" in data.variables - - -def test_retrieval_bin_file_1d(tmp_path): - """ - Ensure that GPROF-NN 1D retrieval works with NetCDF input. 
- """ - input_file = DATA_PATH / "gmi" / "bin" / "gpm_275_14_03_17.bin" - - model_path = get_model_path("1D", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - qrnn.training_data_class = GPROF_NN_1D_Dataset - qrnn.preprocessor_class = PreprocessorLoader1D - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - compress=False) - output_file = driver.run() - data = xr.load_dataset(output_file) - assert "rain_water_content" in data.variables - assert "rain_water_content_true" in data.variables - - -def test_retrieval_netcdf_1d(tmp_path): - """ - Ensure that GPROF-NN 1D retrieval works with NetCDF input. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc" - - model_path = get_model_path("1D", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - qrnn.training_data_class = GPROF_NN_1D_Dataset - qrnn.preprocessor_class = PreprocessorLoader1D - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - compress=False) - output_file = driver.run() - data = xr.load_dataset(output_file) - assert "rain_water_content" in data.variables - assert "rain_water_content_true" in data.variables - - -def test_retrieval_netcdf_1d_full(tmp_path): - """ - Test running the 1D retrieval with the spatial structure retained. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc" - - model_path = get_model_path("1D", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - qrnn.training_data_class = GPROF_NN_1D_Dataset - qrnn.preprocessor_class = PreprocessorLoader1D - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - compress=False, - preserve_structure=True) - output_file = driver.run() - data = xr.load_dataset(output_file) - assert "rain_water_content" in data.variables - assert "rain_water_content_true" in data.variables - -def test_retrieval_netcdf_1d_gradients(tmp_path): - """ - Ensure that GPROF-NN 1D retrieval with NetCDF input and gradients - works. - """ - data_path = Path(__file__).parent / "data" - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc" - - model_path = get_model_path("1D", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - qrnn.training_data_class = GPROF_NN_1D_Dataset - qrnn.preprocessor_class = PreprocessorLoader1D - driver = RetrievalGradientDriver(input_file, - qrnn, - output_file=tmp_path, - compress=False) - output_file = driver.run() - data = xr.load_dataset(output_file) - assert "surface_precip_grad" in data.variables - - -def test_retrieval_netcdf_3d(tmp_path): - """ - Ensure that GPROF-NN 3D retrieval works with NetCDF input. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc" - - model_path = get_model_path("3D", sensors.GMI, "ERA5") - qrnn = QRNN.load(model_path) - qrnn.model.sensor = sensors.GMI - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - compress=False) - output_file = driver.run() - data = xr.load_dataset(output_file) - assert "rain_water_content" in data.variables - assert "pixels" in data.dims.keys() - assert "scans" in data.dims.keys() - - -@pytest.mark.xfail -def test_simulator_gmi(tmp_path): - """ - Ensure that GPROF-NN 3D retrieval works with NetCDF input. 
- """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc" - - qrnn = QRNN.load(DATA_PATH / "gmi" / "simulator_gmi.pckl") - qrnn.netcdf_class = SimulatorLoader - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - compress=False) - output_file = driver.run() - data = xr.load_dataset(output_file) - - assert "simulated_brightness_temperatures" in data.variables - assert "brightness_temperature_biases" in data.variables - - data_0 = data[{"samples": data.source == 0}] - tbs_sim = data_0["simulated_brightness_temperatures"].data - assert np.all(np.isfinite(tbs_sim)) - - -@pytest.mark.xfail -def test_simulator_mhs(tmp_path): - """ - Ensure that GPROF-NN 3D retrieval works with NetCDF input. - """ - input_file = DATA_PATH / "gprof_nn_mhs_era5_5.nc" - - qrnn = QRNN.load(DATA_PATH / "mhs" / "simulator_mhs.pckl") - driver = RetrievalDriver(input_file, - qrnn, - output_file=tmp_path, - compress=False) - output_file = driver.run() - data = xr.load_dataset(output_file) - - assert "simulated_brightness_temperatures" in data.variables - assert "brightness_temperature_biases" in data.variables - - data_0 = data[{"samples": data.source == 0}] - tbs_sim = data_0["simulated_brightness_temperatures"].data - assert np.all(np.isfinite(tbs_sim)) - diff --git a/test/test_retrieval_amsr2.sh b/test/test_retrieval_amsr2.sh deleted file mode 100755 index 3db189c..0000000 --- a/test/test_retrieval_amsr2.sh +++ /dev/null @@ -1,9 +0,0 @@ -#! /bin/bash -wget -q https://rain.atmos.colostate.edu/gprof_nn/test/gcomw1_amsr2.pp -echo "Running GPROF-NN 1D retrieval for AMSR 2." -gprof_nn 1d ERA5 gcomw1_amsr2.pp -o test.bin -python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.mean(RetrievalFile('test.bin').to_xarray_dataset().surface_precip.data >= 0.0) > 0.99 else 0;" - -#echo "Running GPROF-NN 3D retrieval for AMSR 2." -#gprof_nn 3d ERA5 gcomw1_amsr2.pp -#python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.mean(RetrievalFile('test.bin').to_xarray_dataset().surface_precip.data >= 0.0) > 0.99 else 0;" diff --git a/test/test_retrieval_amsre.sh b/test/test_retrieval_amsre.sh deleted file mode 100755 index 5110ee3..0000000 --- a/test/test_retrieval_amsre.sh +++ /dev/null @@ -1,9 +0,0 @@ -#! /bin/bash -wget -q https://rain.atmos.colostate.edu/gprof_nn/test/aqua_amsre.pp -echo "Running GPROF-NN 1D retrieval for AMSRE." -gprof_nn 1d ERA5 aqua_amsre.pp -o test.bin -python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" - -#echo "Running GPROF-NN 3D retrieval for AMSRE." -#gprof_nn 3d ERA5 aqua_amsre.pp -#python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" diff --git a/test/test_retrieval_f15.sh b/test/test_retrieval_f15.sh deleted file mode 100755 index f2be41a..0000000 --- a/test/test_retrieval_f15.sh +++ /dev/null @@ -1,9 +0,0 @@ -#! /bin/bash -wget -q https://rain.atmos.colostate.edu/gprof_nn/test/f15_era5.pp -echo "Running GPROF-NN 1D retrieval for SSMI." -gprof_nn 1d ERA5 f15_era5.pp -o test.bin -python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" - -#echo "Running GPROF-NN 3D retrieval for SSMI." 
-#gprof_nn 3d ERA5 f15_era5.pp -#python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" diff --git a/test/test_retrieval_f17.sh b/test/test_retrieval_f17.sh deleted file mode 100755 index 14d40f6..0000000 --- a/test/test_retrieval_f17.sh +++ /dev/null @@ -1,9 +0,0 @@ -#! /bin/bash -wget -q https://rain.atmos.colostate.edu/gprof_nn/test/f17_era5.pp -echo "Running GPROF-NN 1D retrieval for SSMIS." -gprof_nn 1d ERA5 f17_era5.pp -o test.bin -python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" - -#echo "Running GPROF-NN 3D retrieval for SSMIS." -#gprof_nn 3d ERA5 f17_era5.pp -#python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" diff --git a/test/test_retrieval_gmi.sh b/test/test_retrieval_gmi.sh deleted file mode 100755 index 59a1bdf..0000000 --- a/test/test_retrieval_gmi.sh +++ /dev/null @@ -1,9 +0,0 @@ -#! /bin/bash -wget -q https://rain.atmos.colostate.edu/gprof_nn/test/gmi_era5_harvey.pp -echo "Running GPROF-NN 1D retrieval." -gprof_nn 1d ERA5 gmi_era5_harvey.pp -o test_gmi.bin -python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test_gmi.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" -rm test_gmi.bin -#echo "Runing GPROF-NN 3D retrieval." -#gprof_nn 3d ERA5 gmi_era5_harvey.pp -o test.bin -#python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" diff --git a/test/test_retrieval_gmi_hr.sh b/test/test_retrieval_gmi_hr.sh deleted file mode 100755 index 102a9e5..0000000 --- a/test/test_retrieval_gmi_hr.sh +++ /dev/null @@ -1,5 +0,0 @@ -#! /bin/bash -wget -q https://rain.atmos.colostate.edu/gprof_nn/test/l1cr_gmi_test.HDF5 -echo "Running GPROF-NN HR retrieval." -gprof_nn hr l1cr_gmi_test.HDF5 -o test.nc -python -c "import xarray as xr; import numpy as np; 1/0 if not np.all(xr.load_dataset('test.nc').surface_precip.data >= 0.0) else 0;" diff --git a/test/test_retrieval_mhs.sh b/test/test_retrieval_mhs.sh deleted file mode 100644 index 4331acd..0000000 --- a/test/test_retrieval_mhs.sh +++ /dev/null @@ -1,6 +0,0 @@ -#! /bin/bash -wget https://rain.atmos.colostate.edu/gprof_nn/test/mhs_era5_harvey.pp -echo "Running GPROF-NN 1D retrieval." -gprof_nn 1d ERA5 mhs_era5_harvey.pp -echo "Running GPROF-NN 3D retrieval." -gprof_nn 3d ERA5 mhs_era5_harvey.pp diff --git a/test/test_retrieval_tmi.sh b/test/test_retrieval_tmi.sh deleted file mode 100755 index 15c8b2c..0000000 --- a/test/test_retrieval_tmi.sh +++ /dev/null @@ -1,18 +0,0 @@ -#! /bin/bash -wget -q https://rain.atmos.colostate.edu/gprof_nn/test/tmipr_era5.pp -echo "Running GPROF-NN 1D retrieval for TMIPR." -gprof_nn 1d ERA5 tmipr_era5.pp -o test_tmipr.bin -python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test_tmipr.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" -rm test_tmipr.bin -#echo "Running GPROF-NN 3D retrieval for TMIPR." 
-#gprof_nn 3d ERA5 tmipr_era5.pp -#python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" - -wget -q https://rain.atmos.colostate.edu/gprof_nn/test/tmipo_era5.pp -echo "Running GPROF-NN 1D retrieval for TMIPO." -gprof_nn 1d ERA5 tmipo_era5.pp -o test_tmipo.bin -python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test_tmipo.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" -rm test_tmipo.bin -#echo "Running GPROF-NN 3D retrieval for TMIPO." -#gprof_nn 3d ERA5 tmipo_era5.pp -#python -c "from gprof_nn.data.retrieval import RetrievalFile; import numpy as np; 1/0 if not np.all(RetrievalFile('test.bin').to_xarray_dataset().surface_precip >= 0.0) else 0;" diff --git a/test/test_sensors.py b/test/test_sensors.py deleted file mode 100644 index 0c82d42..0000000 --- a/test/test_sensors.py +++ /dev/null @@ -1,322 +0,0 @@ -""" -Tests for the data loading function of the sensor classes. -""" -from pathlib import Path - -import numpy as np -import xarray as xr - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.training_data import decompress_and_load - -TEST_FILE_GMI = Path("gmi") / "gprof_nn_gmi_era5.nc.gz" -TEST_FILE_MHS = Path("mhs") / "gprof_nn_mhs_era5.nc.gz" -TEST_FILE_TMI = Path("tmi") / "gprof_nn_tmi_era5.nc.gz" - - -DATA_PATH = get_test_data_path() - - -def test_calculate_smoothing_kernel(): - """ - Ensure that 'calculate_smoothing_kernel' returns kernels with the right - FWHM. - """ - k = sensors.calculate_smoothing_kernel(1, 1, 2, 2, 11) - c = k[5, 5] - c2 = k[5, 3] - assert np.isclose(c2 / c, 0.5) - c = k[5, 5] - c2 = k[3, 5] - assert np.isclose(c2 / c, 0.5) - - -def test_calculate_smoothing_kernels(): - """ - Ensure that 'calculate_smoothing_kernels' returns one kernel for each - viewing angle and that the kernels have the expected shape. - """ - kernels = sensors.calculate_smoothing_kernels(sensors.MHS) - assert len(kernels) == sensors.MHS.n_angles - assert kernels[0].shape == (11, 11) - - -def test_smooth_gmi_field(): - """ - Ensure that smoothing a GMI field inserts the smoothed field along the - right axis. - """ - field = np.zeros((32, 32, 4)) - field[15, 15] = 1.0 - - kernels = sensors.calculate_smoothing_kernels(sensors.MHS) - kernels = [kernels[0]] * 10 - field_s = sensors.smooth_gmi_field(field, kernels) - - assert field_s.shape[2] == sensors.MHS.n_angles - assert np.all(np.isclose(field_s[:, :, 0], field_s[:, :, 1], atol=1e-3)) - - -def test_load_training_data_1d_gmi(): - """ - Ensure that loading the training data for GMI produces realistic - values. - """ - input_file = DATA_PATH / TEST_FILE_GMI - input_data = decompress_and_load(input_file) - - sensor = sensors.GMI - - targets = ["surface_precip", "rain_water_content"] - rng = np.random.default_rng() - - x, y = sensor.load_training_data_1d(input_data, targets, False, rng) - - # TB ranges - assert np.all(x[:, :5] > 20) - assert np.all(x[:, :5] < 500) - # Two-meter temperature - assert np.all(x[:, 15] > 200) - assert np.all(x[:, 15] < 400) - # TCWV - assert np.all(x[:, 16] >= 0) - assert np.all(x[:, 16] < 100) - - # Assert all targets are loaded - assert all(t in y for t in targets) - - # Ensure that loaded surface precip is within the range given - # of the surface precip observed for the different angles. 
- sp_ref = input_data.surface_precip.data - mask = sp_ref >= 0 - sp_ref = sp_ref[mask] - sp = y["surface_precip"] - mask = np.isfinite(sp) - assert np.all(sp_ref[mask].max(axis=-1) >= sp[mask]) - assert np.all(sp_ref[mask].min(axis=-1) <= sp[mask]) - - st = x[:, -22:-4] > 0 - assert np.all(np.isclose(st.sum(axis=1), 1.0)) - - at = x[:, -4:] > 0 - assert np.all(np.isclose(at.sum(axis=1), 1.0)) - - -def test_load_training_data_3d_gmi(): - """ - Ensure that loading the training data for GMI produces realistic - values. - """ - input_file = DATA_PATH / TEST_FILE_GMI - input_data = decompress_and_load(input_file) - - sensor = sensors.GMI - - targets = ["surface_precip", "rain_water_content"] - rng = np.random.default_rng() - - x, y = sensor.load_training_data_3d( - input_data, targets, False, rng - ) - - # TB ranges - valid = np.isfinite(x[:, :5]) - assert np.all(x[:, :5][valid] > 20) - assert np.all(x[:, :5][valid] < 500) - # Two-meter temperature - assert np.all(x[:, 15] > 200) - assert np.all(x[:, 15] < 400) - # TCWV - assert np.all(x[:, 16] >= 0) - assert np.all(x[:, 16] < 100) - - # Assert all targets are loaded - assert all(t in y for t in targets) - - # Ensure that loaded surface precip is within the range given - # of the surface precip observed for the different angles. - sp = y["surface_precip"] - mask = sp >= 0 - assert np.all(sp[mask] < 500) - - tbs = x[:, :15] - valid = tbs >= 0 - assert np.all((tbs[valid] > 20) * (tbs[valid] < 400)) - - st = x[:, -22:-4] > 0 - assert np.all(np.isclose(st.sum(axis=1), 1.0)) - - at = x[:, -4:] > 0 - assert np.all(np.isclose(at.sum(axis=1), 1.0)) - - -def test_load_training_data_1d_mhs(): - """ - Ensure that loading the training data for MHS produces realistic - values. - """ - input_file = DATA_PATH / TEST_FILE_MHS - input_data = decompress_and_load(input_file) - - sensor = sensors.MHS - - targets = ["surface_precip", "rain_water_content"] - rng = np.random.default_rng() - - x, y = sensor.load_training_data_1d(input_data, targets, False, rng) - - # TB ranges - assert np.all(x[:, :5] > 20) - assert np.all(x[:, :5] < 500) - # Earth incidence angles - assert np.all(x[:, 5] > -65) - assert np.all(x[:, 5] < 65) - # Two-meter temperature - assert np.all(x[:, 6] > 200) - assert np.all(x[:, 6] < 400) - # TCWV - assert np.all(x[:, 7] >= 0) - assert np.all(x[:, 7] < 100) - - # Assert all targets are loaded - assert all(t in y for t in targets) - - -def test_load_training_data_3d_mhs(): - """ - Ensure that loading the training data for MHS produces realistic - values. - """ - input_file = DATA_PATH / TEST_FILE_MHS - input_data = decompress_and_load(input_file) - - sensor = sensors.MHS - - targets = ["surface_precip", "rain_water_content"] - rng = np.random.default_rng() - - x, y = sensor.load_training_data_3d(input_data, targets, False, rng) - - # Two-meter temperature - valid = np.isfinite(x[:, 6]) - assert np.all(x[:, 6][valid] > 150) - assert np.all(x[:, 6][valid] < 400) - # TCWV - valid = np.isfinite(x[:, 7]) - assert np.all(x[:, 7][valid] >= 0) - assert np.all(x[:, 7][valid] < 100) - - # Assert all targets are loaded - assert all(t in y for t in targets) - - # Ensure that loaded surface precip is within the range given - # of the surface precip observed for the different angles. - sp = y["surface_precip"] - mask = sp >= 0 - assert np.all(sp[mask] < 500) - - tbs = x[:, :5] - valid = tbs >= 0 - assert np.all((tbs[valid] > 20) * (tbs[valid] < 400)) - - -def test_interpolation_mhs(tmp_path): - """ - Ensure that interpolation of surface precipitation - works. 
- """ - input_file = DATA_PATH / TEST_FILE_MHS - input_data = decompress_and_load(input_file) - - input_data = input_data.sel({"samples": input_data.source == 0}) - - for i in range(10): - input_data.surface_precip[:, :, :, i] = i - - input_data.to_netcdf(tmp_path / "test.nc") - - sensor = sensors.MHS - sensor._angles = np.arange(10) - - targets = ["surface_precip", "rain_water_content"] - rng = np.random.default_rng() - - x, y = sensor.load_training_data_1d(tmp_path / "test.nc", targets, False, rng) - - # Assert all targets are loaded - sp = y["surface_precip"] - va = np.abs(x[:, 5]) - - inds = (sp > 1.0) * (sp < 8.0) - assert np.all(np.isclose(va[inds], sp[inds])) - - -def test_load_training_data_1d_tmi(): - """ - Ensure that loading the training data for TMI produces realistic - values. - """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / TEST_FILE_TMI - input_data = decompress_and_load(input_file) - - sensor = sensors.TMI - - targets = ["surface_precip", "rain_water_content"] - rng = np.random.default_rng() - - x, y = sensor.load_training_data_1d(input_data, targets, False, rng) - - # TB ranges - assert np.all(x[:, :9] > 20) - assert np.all(x[:, :9] < 500) - # Two-meter temperature - assert np.all(x[:, 9] > 200) - assert np.all(x[:, 9] < 400) - # TCWV - assert np.all(x[:, 10] >= 0) - assert np.all(x[:, 10] < 100) - - # Assert all targets are loaded - assert all(t in y for t in targets) - - -def test_load_training_data_3d_tmi(): - """ - Ensure that loading the training data for TMI produces realistic - values. - """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / TEST_FILE_TMI - input_data = decompress_and_load(input_file) - - sensor = sensors.TMI - - targets = ["surface_precip", "rain_water_content"] - rng = np.random.default_rng() - - x, y = sensor.load_training_data_3d(input_data, targets, False, rng) - - # Two-meter temperature - valid = np.isfinite(x[:, 6]) - assert np.all(x[:, 9][valid] > 150) - assert np.all(x[:, 9][valid] < 400) - # TCWV - valid = np.isfinite(x[:, 7]) - assert np.all(x[:, 10][valid] >= 0) - assert np.all(x[:, 10][valid] < 100) - - # Assert all targets are loaded - assert all(t in y for t in targets) - - # Ensure that loaded surface precip is within the range given - # of the surface precip observed for the different angles. - sp = y["surface_precip"] - mask = sp >= 0 - assert np.all(sp[mask] < 500) - - tbs = x[:, :5] - valid = tbs >= 0 - assert np.all((tbs[valid] > 20) * (tbs[valid] < 400)) - diff --git a/test/test_statistics.py b/test/test_statistics.py deleted file mode 100644 index e07b472..0000000 --- a/test/test_statistics.py +++ /dev/null @@ -1,693 +0,0 @@ -""" -Tests for the gprof_nn.statistics module. 
-""" -from pathlib import Path - -import numpy as np -import xarray as xr - -from gprof_nn import sensors -from gprof_nn.data.bin import BinFile -from gprof_nn.data import get_test_data_path -from gprof_nn.definitions import ALL_TARGETS, LAT_BINS, TIME_BINS -from gprof_nn.data.preprocessor import PreprocessorFile -from gprof_nn.data.retrieval import RetrievalFile -from gprof_nn.data.training_data import (GPROF_NN_1D_Dataset, - decompress_and_load) -from gprof_nn.data.combined import GPMCMBFile -from gprof_nn.statistics import (StatisticsProcessor, - TrainingDataStatistics, - BinFileStatistics, - ObservationStatistics, - GlobalDistribution, - ZonalDistribution, - RetrievalStatistics, - GPMCMBStatistics, - resample_scans) - - -DATA_PATH = get_test_data_path() - - -def test_training_statistics_gmi(tmpdir): - """ - Ensure that TrainingDataStatistics class reproduces statistic of - GMI training data file. - """ - files = [DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz"] * 2 - - - stats = [TrainingDataStatistics(kind="1d"), - ZonalDistribution(monthly=False), - GlobalDistribution()] - processor = StatisticsProcessor(sensors.GMI, - files, - stats) - processor.run(2, tmpdir) - input_data = GPROF_NN_1D_Dataset(files[0], - normalize=False, - shuffle=False, - targets=ALL_TARGETS) - input_data = input_data.to_xarray_dataset() - - results = xr.open_dataset(str( - tmpdir / - "training_data_statistics_gmi.nc" - )) - - # Ensure TB dists match. - for st in range(1, 19): - bins = np.linspace(0, 400, 401) - i_st = ((input_data.surface_type == st) * - (input_data.surface_precip >= 0)).data - - tbs = input_data["brightness_temperatures"].data[i_st] - counts_ref, _ = np.histogram(tbs[:, 0], bins=bins) - counts = results["brightness_temperatures"][st - 1, 0].data - - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - tcwv = input_data["total_column_water_vapor"].data[i_st] - bins_tcwv = np.linspace(-0.5, 99.5, 101) - counts_ref, _, _ = np.histogram2d( - tcwv, tbs[:, 0], - bins=(bins_tcwv, bins) - ) - counts = results["brightness_temperatures_tcwv"][st - 1, 0].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface_precip dists match. - bins = np.logspace(-3, np.log10(2e2), 201) - x = input_data["surface_precip"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_precip"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure RWC distributions match. - bins = np.logspace(-4, np.log10(2e1), 201) - x = input_data["rain_water_content"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["rain_water_content"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure two-meter-temperature distributions match. 
- bins = np.linspace(239.5, 339.5, 101) - x = input_data["two_meter_temperature"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["two_meter_temperature"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface type distributions match - bins = np.arange(19) + 0.5 - x = input_data["surface_type"].data - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_type"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # - # Zonal distributions - # - - input_data = decompress_and_load(files[0]) - results = xr.open_dataset(str(tmpdir / "zonal_distribution_gmi.nc")) - lat_bins = np.linspace(-90, 90, 181) - sp_bins = np.logspace(-2, 2.5, 201) - bins = (lat_bins, sp_bins) - sp = input_data["surface_precip"].data - lats = input_data["latitude"].data - valid = sp >= 0.0 - sp = sp[valid] - lats = lats[valid] - cs_ref, _, _ = np.histogram2d(lats, sp, bins=bins) - cs = results["surface_precip_mean"].data - assert np.all(np.isclose(2.0 * cs_ref, cs)) - - # - # Global distributions - # - - input_data = decompress_and_load(files[0]) - results = xr.open_dataset(str(tmpdir / "global_distribution_gmi.nc")) - lat_bins = np.arange(-90, 90 + 1e-3, 5) - lon_bins = np.arange(-180, 180 + 1e-3, 5) - sp_bins = np.logspace(-2, 2.5, 201) - sp = input_data["surface_precip"].data - lons = input_data["longitude"].data - lats = input_data["latitude"].data - valid = sp >= 0.0 - sp = sp[valid] - lats = lats[valid] - lons = lons[valid] - bins = (lat_bins, lon_bins, sp_bins) - vals = np.stack([lats, lons, sp], axis=-1) - cs_ref, _ = np.histogramdd(vals, bins=bins) - cs = results["surface_precip_mean"].data - assert np.all(np.isclose(2.0 * cs_ref, cs)) - - -def test_training_statistics_mhs(tmpdir): - """ - Ensure that TrainingDataStatistics class reproduces statistic of - MHS training data file. - """ - files = [DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc"] * 2 - - - stats = [TrainingDataStatistics(kind="1D"), - GlobalDistribution(), - ZonalDistribution()] - processor = StatisticsProcessor(sensors.MHS, - files, - stats) - processor.run(2, tmpdir) - input_data = GPROF_NN_1D_Dataset(files[0], - normalize=False, - shuffle=False, - targets=ALL_TARGETS) - input_data = input_data.to_xarray_dataset() - - results = xr.open_dataset(str( - tmpdir / - "training_data_statistics_mhs.nc" - )) - - # Ensure total column water vapor distributions match. - st = 1 - bins = np.linspace(-0.5, 99.5, 101) - i_st = (input_data.surface_type == 1).data - - x = input_data["total_column_water_vapor"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["total_column_water_vapor"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - -def test_bin_statistics_gmi(tmpdir): - """ - Ensure that TrainingDataStatistics class reproduces statistic of - GMI bin files. - """ - files = [DATA_PATH / "gmi" / "bin" / "gpm_269_00_16.bin"] * 2 - - - stats = [BinFileStatistics(), - ZonalDistribution(), - GlobalDistribution()] - processor = StatisticsProcessor(sensors.GMI, - files, - stats) - processor.run(2, tmpdir) - input_data = BinFile(files[0]).to_xarray_dataset() - - results = xr.open_dataset(str( - tmpdir / - "bin_file_statistics_gmi.nc" - )) - - # Ensure TB dists match. 
- st = 4 - bins = np.linspace(0, 400, 401) - i_st = (input_data.surface_type == st).data - tbs = input_data["brightness_temperatures"].data[i_st] - counts_ref, _ = np.histogram(tbs[:, 0], bins=bins) - counts = results["brightness_temperatures"][st - 1, 0].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface_precip dists match. - bins = np.logspace(-3, np.log10(2e2), 201) - i_st = (input_data.surface_type == st).data - x = input_data["surface_precip"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_precip"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure RWC distributions match. - bins = np.logspace(-4, np.log10(2e1), 201) - i_st = (input_data.surface_type == st).data - x = input_data["rain_water_content"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["rain_water_content"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure two-meter-temperature distributions match. - bins = np.linspace(239.5, 339.5, 101) - i_st = (input_data.surface_type == st).data - x = input_data["two_meter_temperature"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["two_meter_temperature"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface type distributions match - bins = np.arange(19) + 0.5 - x = input_data["surface_type"].data - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_type"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure conditional means match - st = input_data.surface_type.data[0] - i_t2m = int(np.round(input_data.two_meter_temperature.data[0] - 240)) - mean_sp = results["surface_precip_mean_t2m"][st - 1, i_t2m] - mean_sp_ref = input_data.surface_precip.data.mean() - assert np.isclose(mean_sp_ref, mean_sp) - - # Ensure conditional means match - st = input_data.surface_type.data[0] - i_tcwv = int(np.round(input_data.total_column_water_vapor.data[0])) - mean_sp = results["surface_precip_mean_tcwv"][st - 1, i_tcwv] - mean_sp_ref = input_data.surface_precip.data.mean() - assert np.isclose(mean_sp_ref, mean_sp) - - -def test_bin_statistics_mhs_sea_ice(tmpdir): - """ - Ensure that TrainingDataStatistics class reproduces statistic of - MHS bin file for a sea ice surface. - """ - files = [DATA_PATH / "mhs" / "bin" / "gpm_271_20_16.bin"] * 2 - - - stats = [BinFileStatistics(), - ZonalDistribution(), - GlobalDistribution()] - processor = StatisticsProcessor(sensors.MHS, - files, - stats) - processor.run(2, tmpdir) - input_data = BinFile(files[0]).to_xarray_dataset() - - results = xr.open_dataset(str( - tmpdir / - "bin_file_statistics_mhs.nc" - )) - - # Ensure TB dists match. - st = 2 - bins = np.linspace(0, 400, 401) - inds = (input_data.surface_type == st).data - inds = inds * (input_data.pixel_position == 4).data - tbs = input_data["brightness_temperatures"].data[inds] - counts_ref, _ = np.histogram(tbs[:, 0], bins=bins) - counts = results["brightness_temperatures"][st - 1, 0, 3].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface_precip dists match. - bins = np.logspace(-3, np.log10(2e2), 201) - x = input_data["surface_precip"].data[inds] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_precip"][st - 1, 3].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure RWC distributions match. 
- bins = np.logspace(-4, np.log10(2e1), 201) - i_st = (input_data.surface_type == st).data - x = input_data["rain_water_content"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["rain_water_content"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure two-meter-temperature distributions match. - bins = np.linspace(239.5, 339.5, 101) - i_st = (input_data.surface_type == st).data - x = input_data["two_meter_temperature"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["two_meter_temperature"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface type distributions match - bins = np.arange(19) + 0.5 - x = input_data["surface_type"].data - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_type"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - -def test_bin_statistics_mhs_ocean(tmpdir): - """ - Ensure that TrainingDataStatistics class reproduces statistic of - MHS bin file for a land surface. - """ - files = [DATA_PATH / "mhs" / "bin" / "gpm_289_52_04.bin"] * 2 - - - stats = [BinFileStatistics(), - ZonalDistribution(), - GlobalDistribution()] - processor = StatisticsProcessor(sensors.MHS, - files, - stats) - processor.run(2, tmpdir) - input_data = BinFile(files[0]).to_xarray_dataset() - - results = xr.open_dataset(str( - tmpdir / - "bin_file_statistics_mhs.nc" - )) - - # Ensure TB dists match. - st = 1 - bins = np.linspace(0, 400, 401) - inds = (input_data.surface_type == st).data - tbs = input_data["brightness_temperatures"].data[inds] - counts_ref, _ = np.histogram(tbs[:, 0, 0], bins=bins) - counts = results["brightness_temperatures"][st - 1, 0, 0].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface_precip dists match. - bins = np.logspace(-3, np.log10(2e2), 201) - x = input_data["surface_precip"].data[inds, 0] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_precip"][st - 1, 0].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure RWC distributions match. - bins = np.logspace(-4, np.log10(2e1), 201) - i_st = (input_data.surface_type == st).data - x = input_data["rain_water_content"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["rain_water_content"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure two-meter-temperature distributions match. - bins = np.linspace(239.5, 339.5, 101) - i_st = (input_data.surface_type == st).data - x = input_data["two_meter_temperature"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["two_meter_temperature"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface type distributions match - bins = np.arange(19) + 0.5 - x = input_data["surface_type"].data - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_type"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - -def test_observation_statistics_gmi(tmpdir): - """ - Ensure that TrainingDataStatistics class reproduces statistic of - MHS bin file for an ocean surface. 
- """ - files = [DATA_PATH / "gmi" / "pp" / "GMIERA5_190101_027510.pp"] * 2 - - stats = [ObservationStatistics(conditional=1), - ZonalDistribution(), - GlobalDistribution()] - processor = StatisticsProcessor(sensors.GMI, - files, - stats) - processor.run(2, tmpdir) - input_data = PreprocessorFile(files[0]).to_xarray_dataset() - - results = xr.open_dataset(str( - tmpdir / - "observation_statistics_gmi.nc" - )) - - # Ensure TB dists match. - st = 1 - bins = np.linspace(0, 400, 401) - inds = (input_data.surface_type == st).data - tbs = input_data["brightness_temperatures"].data[inds] - counts_ref, _ = np.histogram(tbs[:, 0], bins=bins) - counts = results["brightness_temperatures"][st - 1, 0].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - bins_tcwv = np.linspace(-0.5, 99.5, 101) - inds = (input_data.surface_type == st).data - tcwv = input_data["total_column_water_vapor"].data[inds] - counts_ref, _, _ = np.histogram2d( - tcwv, tbs[:, 0], bins=(bins_tcwv, bins) - ) - counts = results["brightness_temperatures_tcwv"][st - 1, 0].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure two-meter-temperature distributions match. - bins = np.linspace(240, 330, 201) - i_st = (input_data.surface_type == st).data - x = input_data["two_meter_temperature"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["two_meter_temperature"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface type distributions match - bins = np.arange(19) + 0.5 - x = input_data["surface_type"].data - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_type"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure latitude distributions match. - bins = np.linspace(-90, 90, 181) - lats = input_data["latitude"].data - counts_ref, _ = np.histogram(lats, bins=bins) - counts = results["latitudes"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure local time distributions match. - bins = (np.linspace(0, 24, 25) - 0.5) * 60 - lons = input_data["longitude"] - scan_time = input_data["scan_time"] - local_time = (scan_time - + (lons / 360 * 24 * 60 * 60).astype("timedelta64[s]")) - minutes = local_time.dt.hour * 60 + local_time.dt.minute.data - counts_ref, _ = np.histogram(minutes, bins=bins) - counts = results["local_time"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - -def test_observation_statistics_mhs(tmpdir): - """ - Ensure that TrainingDataStatistics class reproduces statistic of - MHS bin file for an ocean surface. - """ - files = [DATA_PATH / "mhs" / "pp" / "MHS.pp"] * 2 - - stats = [ObservationStatistics(gmi_range=False), - ZonalDistribution(), - GlobalDistribution()] - processor = StatisticsProcessor(sensors.MHS, - files, - stats) - processor.run(2, tmpdir) - input_data = PreprocessorFile(files[0]).to_xarray_dataset() - - results = xr.open_dataset(str( - tmpdir / - "observation_statistics_mhs.nc", - )) - - # Ensure TB dists match. 
- bins = np.linspace(0, 400, 401) - st = 1 - sensor = sensors.MHS - angle_bins = sensor.angle_bins#np.zeros(sensor.angles.size + 1) - #angle_bins[1:-1] = 0.5 * (sensor.angles[1:] + sensor.angles[:-1]) - #angle_bins[0] = 2.0 * angle_bins[1] - angle_bins[2] - #angle_bins[-1] = 2.0 * angle_bins[-2] - angle_bins[-3] - lower = angle_bins[3] - upper = angle_bins[2] - eia = np.abs(input_data.earth_incidence_angle.data) - inds = ((input_data.surface_type.data == st) * - (eia >= lower) * - (eia < upper)) - tbs = input_data["brightness_temperatures"].data[inds] - counts_ref, _ = np.histogram(tbs[:, 0], bins=bins) - counts = results["brightness_temperatures"][st - 1, 0, 2].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - bins_tcwv = np.linspace(-0.5, 99.5, 101) - tcwv = input_data["total_column_water_vapor"].data[inds] - counts_ref, _, _ = np.histogram2d( - tcwv, tbs[:, 0], bins=(bins_tcwv, bins) - ) - counts = results["brightness_temperatures_tcwv"][st - 1, 0, 2].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure two-meter-temperature distributions match. - bins = np.linspace(240, 330, 201) - i_st = (input_data.surface_type == st).data - x = input_data["two_meter_temperature"].data[i_st] - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["two_meter_temperature"][st - 1].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure surface type distributions match - bins = np.arange(19) + 0.5 - x = input_data["surface_type"].data - counts_ref, _ = np.histogram(x, bins=bins) - counts = results["surface_type"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure latitude distributions match. - bins = np.linspace(-90, 90, 181) - lats = input_data["latitude"].data - counts_ref, _ = np.histogram(lats, bins=bins) - counts = results["latitudes"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - # Ensure local time distributions match. - bins = (np.linspace(0, 24, 25) - 0.5) * 60 - lons = input_data["longitude"] - scan_time = input_data["scan_time"] - local_time = (scan_time - + (lons / 360 * 24 * 60 * 60).astype("timedelta64[s]")) - minutes = local_time.dt.hour * 60 + local_time.dt.minute.data - counts_ref, _ = np.histogram(minutes, bins=bins) - counts = results["local_time"].data - assert np.all(np.isclose(counts, 2.0 * counts_ref)) - - -def test_retrieval_statistics_gmi(tmpdir): - """ - Ensure that calculated means of retrieval results statistics match - directly calculated ones. - """ - source_file = DATA_PATH / "gmi" / "retrieval" / "GMIERA5_190101_027510.bin" - # This retrieval file contains profiles so it has to be converted - # to a netcdf file first. 
- data = RetrievalFile(source_file, has_profiles=True).to_xarray_dataset() - data.to_netcdf(tmpdir / "input.nc") - - files = [tmpdir / "input.nc"] * 2 - stats = [RetrievalStatistics()] - processor = StatisticsProcessor(sensors.GMI, - files, - stats) - processor.run(2, str(tmpdir)) - - results = xr.load_dataset(str(tmpdir / "retrieval_statistics.nc")) - - i_b = 18 - mean_sp = results["surface_precip_mean_t2m"][4, i_b] - st = data.surface_type.data - l_t2m, r_t2m = np.linspace(239.5, 339.5, 101)[i_b:i_b + 2] - indices = ((data.surface_type.data == 5) * - (data.two_meter_temperature.data >= l_t2m) * - (data.two_meter_temperature.data < r_t2m) * - (data.surface_precip.data > -999)) - mean_sp_ref = data.surface_precip.data[indices].mean() - - assert np.isclose(mean_sp_ref, mean_sp) - - mean_sp = results["surface_precip_mean_tcwv"][0, i_b] - st = data.surface_type.data - l_tcwv, r_tcwv = np.linspace(-0.5, 99.5, 101)[i_b:i_b + 2] - indices = ((data.surface_type.data == 1) * - (data.total_column_water_vapor.data >= l_tcwv) * - (data.total_column_water_vapor.data < r_tcwv) * - (data.surface_precip.data > -999)) - mean_sp_ref = data.surface_precip.data[indices].mean() - - assert np.isclose(mean_sp_ref, mean_sp) - - -def test_gpm_cmb_statistics(tmpdir): - input_file = ( - DATA_PATH / "cmb" / - "2B.GPM.DPRGMI.CORRA2018.20210829-S205206-E222439.042628.V06A.HDF5" - ) - files = [input_file] * 2 - stats = [GPMCMBStatistics()] - processor = StatisticsProcessor(sensors.GMI, - files, - stats) - processor.run(2, str(tmpdir)) - results = xr.load_dataset(str(tmpdir / "gpm_combined_statistics.nc")) - - input_data = GPMCMBFile(input_file).to_xarray_dataset(smooth=True) - surface_precip = input_data.surface_precip.data - lats = input_data.latitude.data - latitude_bins = np.arange(-90, 90 + 1e-3, 5) - sp_bins = np.logspace(-2, 2.5, 201) - bins = (latitude_bins, sp_bins) - - cs_ref, _, _ = np.histogram2d( - lats.ravel(), - surface_precip.ravel(), - bins=bins - ) - cs = results["surface_precip"].data.sum(axis=1) - - assert np.all(np.isclose(cs, 2.0 * cs_ref)) - - -def test_resample_scans(): - """ - Tests the resampling of observation along the scan dimension. - """ - n_mins = 24 * 60 - lats = np.linspace(-90, 89, n_mins) + 0.5 - lons = np.zeros(n_mins) - scan_time = np.arange(0, n_mins, dtype="datetime64[m]") - - dataset = xr.Dataset({ - "scan_time": (("scans",), scan_time), - "latitude": (("scans", "pixels"), lats[..., np.newaxis]), - "longitude": (("scans", "pixels"), lons[..., np.newaxis]) - }) - - weights = np.zeros(LAT_BINS.size - 1) - weights[LAT_BINS.size // 2:] = 1.0 - dataset = resample_scans(dataset, weights) - assert np.all(dataset.latitude.data > 0) - - weights = np.zeros((LAT_BINS.size - 1, TIME_BINS.size - 1)) - weights[..., TIME_BINS.size // 2:] = 1.0 - dataset = resample_scans(dataset, weights) - minutes = (dataset.scan_time.dt.hour * 60 + - dataset.scan_time.dt.minute) - assert np.all(minutes.data >= n_mins // 2) - - -def test_retrieval_statistics_resampled(tmpdir): - """ - Test calculating retrieval statistics with latitude resampling. - """ - source_file = DATA_PATH / "gmi" / "retrieval" / "GMIERA5_190101_027510.bin" - # This retrieval file contains profiles so it has to be converted - # to a netcdf file first. - data = RetrievalFile(source_file, has_profiles=True).to_xarray_dataset() - data.to_netcdf(tmpdir / "input.nc") - - # Set statistics so that only scans over - # equator are considered. 
- statistics = np.zeros(LAT_BINS.size - 1) - lats = 0.5 * (LAT_BINS[1:] + LAT_BINS[:-1]) - statistics[lats > 0] = 1.0 - - files = [tmpdir / "input.nc"] * 2 - stats = [RetrievalStatistics(statistics)] - processor = StatisticsProcessor(sensors.GMI, - files, - stats) - processor.run(2, str(tmpdir)) - - results = xr.load_dataset(str(tmpdir / "retrieval_statistics.nc")) - - # Select only scans whose mean latitude is above equator. - mean_lats = data.latitude.mean("pixels") - data = data[{"scans": mean_lats > 0}] - - mean_sp = results["surface_precip_mean_t2m"][1, 1] - l_t2m, r_t2m = np.linspace(239.5, 339.5, 101)[1:3] - indices = ((data.surface_type.data == 2) * - (data.two_meter_temperature.data >= l_t2m) * - (data.two_meter_temperature.data < r_t2m) * - (data.surface_precip.data > -999)) - mean_sp_ref = data.surface_precip.data[indices].mean() - - assert np.isclose(mean_sp_ref, mean_sp, atol=1e-2) - - mean_sp = results["surface_precip_mean_tcwv"][1, 2] - st = data.surface_type.data - l_tcwv, r_tcwv = np.linspace(-0.5, 99.5, 101)[2:4] - indices = ((data.surface_type.data == 2) * - (data.total_column_water_vapor.data >= l_tcwv) * - (data.total_column_water_vapor.data < r_tcwv) * - (data.surface_precip.data > -999)) - mean_sp_ref = data.surface_precip.data[indices].mean() - - assert np.isclose(mean_sp_ref, mean_sp, atol=1e-2) - diff --git a/test/test_surface.py b/test/test_surface.py deleted file mode 100644 index aaf2a68..0000000 --- a/test/test_surface.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -Tests for the loading of surface maps for the GPROF-NN data processing. -""" -from datetime import datetime - -import pytest -import numpy as np - -from gprof_nn.data.surface import (read_land_mask, - read_autosnow, - read_emissivity_classes) -from gprof_nn.data.preprocessor import has_preprocessor - - -HAS_PREPROCESSOR = has_preprocessor() - - -@pytest.mark.skipif(not HAS_PREPROCESSOR, reason="Preprocessor missing.") -def test_read_land_mask(): - """ - Test reading of land mask. - """ - mask = read_land_mask("GMI") - assert mask.mask.shape == (180 * 32, 360 * 32) - - mask = read_land_mask("MHS") - assert mask.mask.shape == (180 * 16, 360 * 16) - - # Ensure point in North Atlantic is classified as Ocean. - m = mask.interp({"longitude": -46.0, "latitude": 35.0}) - assert np.isclose(m.mask.data, 0) - - # Ensure point in Africa is classified as land. - m = mask.interp({"longitude": 0.0, "latitude": 20.0}) - assert np.all(m.mask.data > 0) - - -@pytest.mark.skipif(not HAS_PREPROCESSOR, reason="Preprocessor missing.") -def test_read_autosnow(): - """ - Test reading of autosnow files. - """ - autosnow = read_autosnow("2021-01-01T00:00:00") - - # Ensure no snow around equator - autosnow_eq = autosnow.interp({"latitude": 0.0, "longitude": 0.0}, "nearest") - assert np.all(autosnow_eq.snow.data == 0) - - -@pytest.mark.skipif(not HAS_PREPROCESSOR, reason="Preprocessor missing.") -def test_read_emissivity_classes(): - """ - Test reading of emissivity classes. - """ - data = read_emissivity_classes() - - # Ensure point in North Atlantic is classified as Ocean. - data_i = data.interp({"longitude": -46.0, "latitude": 35.0}) - assert np.all(np.isclose(data_i.emissivity.data, 0)) - - # Ensure point in Africa is classified as land. 
- data_i = data.interp({"longitude": 0.0, "latitude": 20.0}) - assert np.all(data_i.emissivity.data > 0) diff --git a/test/test_tiling.py b/test/test_tiling.py deleted file mode 100644 index df39a9e..0000000 --- a/test/test_tiling.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -Test for the tiling and assembling of input. -""" -import numpy as np -import matplotlib.pyplot as plt - -from gprof_nn.tiling import Tiler - -def upsample(x): - m, n = x.shape - x_new = np.zeros((3 * m - 2, n)) - x_new[0::3, :] = x - x_new[1::3, :] = (2 * x[:-1] / 3 + 1 * x[1:] / 3) - x_new[2::3, :] = (1 * x[:-1] / 3 + 2 * x[1:] / 3) - return x_new - - -def test_assembling(): - """ - Tests assembling of tiles. Also tests that the assemble of upsampled - tiles works as expected. - """ - x_3 = np.tile(np.arange(128 * 3 - 2)[..., None], (1, 128)) - x_3 = x_3.astype(np.float32) - x = x_3[::3] - - tiler = Tiler(x, tile_size=(64, 64), overlap=(16, 16)) - - tiles = [ - [tiler.get_tile(i, j) for j in range(tiler.N)] - for i in range(tiler.M) - ] - x_assembled = tiler.assemble(tiles) - assert np.all(np.isclose(x, x_assembled)) - - tiles = [ - [upsample(tiler.get_tile(i, j)) for j in range(tiler.N)] - for i in range(tiler.M) - ] - tiler_3 = Tiler(x_3, tile_size=(3 * 64 - 2, 64), overlap=(3 * 16 - 2, 16)) - x_3_assembled = tiler_3.assemble(tiles) - assert np.all(np.isclose(x_3, x_3_assembled)) diff --git a/test/test_training_data.py b/test/test_training_data.py deleted file mode 100644 index 6d8c65d..0000000 --- a/test/test_training_data.py +++ /dev/null @@ -1,1064 +0,0 @@ -""" -Tests for the Pytorch dataset classes used to load the training -data. -""" -from pathlib import Path - -import numpy as np -import torch -import xarray as xr - -from quantnn.qrnn import QRNN -from quantnn.normalizer import Normalizer -from quantnn.models.pytorch.xception import XceptionFpn - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.training_data import ( - load_variable, - decompress_scene, - decompress_and_load, - remap_scene, - GPROF_NN_1D_Dataset, - GPROF_NN_3D_Dataset, - SimulatorDataset, -) - - -DATA_PATH = get_test_data_path() - -###################################################################### -# GMI - 1D -###################################################################### - -def test_to_xarray_dataset_1d_gmi(): - """ - Ensure that converting training data to 'xarray.Dataset' yield same - Tbs as the ones found in the first batch of the training data when - data is not shuffled. - """ - data_path = Path(__file__).parent / "data" - input_file = data_path / "gmi" / "gprof_nn_gmi_era5.nc" - dataset = GPROF_NN_1D_Dataset( - input_file, - batch_size=64, - normalize=False, - shuffle=False, - targets=["surface_precip", "rain_water_content"] - ) - - # - # Conversion using datasets 'x' attribute. 
- # - - data = dataset.to_xarray_dataset() - x, y = dataset[0] - x = x.numpy() - - tbs = data.brightness_temperatures.data[:x.shape[0]] - tbs_ref = x[:, :15] - valid = np.isfinite(tbs_ref) - assert np.all(np.isclose(tbs[valid], tbs_ref[valid])) - - t2m = data.two_meter_temperature.data[:x.shape[0]] - t2m_ref = x[:, 15] - assert np.all(np.isclose(t2m, t2m_ref)) - - tcwv = data.total_column_water_vapor.data[:x.shape[0]] - tcwv_ref = x[:, 16] - assert np.all(np.isclose(tcwv, tcwv_ref)) - - of = data.ocean_fraction.data[:x.shape[0]] - of_ref = x[:, 17] - assert np.all(np.isclose(of, of_ref)) - - # - # Conversion using only first batch - # - - x, y = dataset[0] - data = dataset.to_xarray_dataset(batch=(x, y)) - x = x.numpy() - - tbs = data.brightness_temperatures.data - tbs_ref = x[:, :15] - valid = np.isfinite(tbs_ref) - assert np.all(np.isclose(tbs[valid], tbs_ref[valid])) - - t2m = data.two_meter_temperature.data - t2m_ref = x[:, 15] - assert np.all(np.isclose(t2m, t2m_ref)) - - tcwv = data.total_column_water_vapor.data - tcwv_ref = x[:, 16] - assert np.all(np.isclose(tcwv, tcwv_ref)) - - of = data.ocean_fraction.data[:x.shape[0]] - of_ref = x[:, 17] - assert np.all(np.isclose(of, of_ref)) - - -def test_permutation_gmi(): - """ - Ensure that permutation permutes the right input features. - """ - # Permute continuous input - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset_1 = GPROF_NN_1D_Dataset( - input_file, - batch_size=16, - shuffle=False, - augment=False, - transform_zeros=False, - targets=["surface_precip"], - ) - dataset_2 = GPROF_NN_1D_Dataset( - input_file, - batch_size=16, - shuffle=False, - augment=False, - transform_zeros=False, - targets=["surface_precip"], - permute=0, - ) - x_1, y_1 = dataset_1[0] - y_1 = y_1["surface_precip"] - x_2, y_2 = dataset_2[0] - y_2 = y_2["surface_precip"] - - assert np.all(np.isclose(y_1, y_2)) - assert ~np.all(np.isclose(x_1[:, :1], x_2[:, :1])) - assert np.all(np.isclose(x_1[:, 1:], x_2[:, 1:])) - - # Permute surface type - dataset_2 = GPROF_NN_1D_Dataset( - input_file, - batch_size=16, - shuffle=False, - augment=False, - transform_zeros=False, - targets=["surface_precip"], - permute=17, - ) - x_2, y_2 = dataset_2[0] - y_2 = y_2["surface_precip"] - - assert np.all(np.isclose(y_1, y_2)) - assert np.all(np.isclose(x_1[:, :-24], x_2[:, :-24])) - assert ~np.all(np.isclose(x_1[:, -24:-4], x_2[:, -24:-4])) - assert np.all(np.isclose(x_1[:, -4:], x_2[:, -4:])) - - # Permute airmass type - dataset_2 = GPROF_NN_1D_Dataset( - input_file, - batch_size=16, - shuffle=False, - augment=False, - transform_zeros=False, - targets=["surface_precip"], - permute=18, - ) - x_2, y_2 = dataset_2[0] - y_2 = y_2["surface_precip"] - - assert np.all(np.isclose(y_1, y_2)) - assert np.all(np.isclose(x_1[:, :-4], x_2[:, :-4])) - - -def test_gprof_1d_dataset_input_gmi(): - """ - Ensure that input variables have realistic values. - """ - data_path = Path(__file__).parent / "data" - input_file = data_path / "gmi" / "gprof_nn_gmi_era5.nc" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"] - ) - x, _ = dataset[0] - x = x.numpy() - - tbs = x[:, :15] - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - - t2m = x[:, 15] - assert np.all((t2m > 180) * (t2m < 350)) - - tcwv = x[:, 16] - assert np.all((tcwv > 0) * (tcwv < 100)) - - -def test_gprof_1d_dataset_gmi(): - """ - Ensure that iterating over single-pixel dataset conserves - statistics. 
- """ - data_path = Path(__file__).parent / "data" - input_file = data_path / "gmi" / "gprof_nn_gmi_era5.nc" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, augment=False, targets=["surface_precip"] - ) - - xs = [] - ys = [] - - x_mean_ref = dataset.x.sum(axis=0) - y_mean_ref = dataset.y["surface_precip"].sum(axis=0) - - for x, y in dataset: - xs.append(x) - ys.append(y["surface_precip"]) - - xs = torch.cat(xs, dim=0) - ys = torch.cat(ys, dim=0) - - x_mean = xs.sum(dim=0).detach().numpy() - y_mean = ys.sum(dim=0).detach().numpy() - - assert np.all(np.isclose(x_mean, x_mean_ref, rtol=1e-3)) - assert np.all(np.isclose(y_mean, y_mean_ref, rtol=1e-3)) - - -def test_gprof_1d_dataset_multi_target_gmi(): - """ - Ensure that iterating over single-pixel dataset conserves - statistics. - """ - data_path = Path(__file__).parent / "data" - input_file = data_path / "gmi" / "gprof_nn_gmi_era5.nc" - dataset = GPROF_NN_1D_Dataset( - input_file, - targets=["surface_precip", - "latent_heat", - "rain_water_content", - "scan_time"], - batch_size=1, - transform_zeros=False, - ) - - xs = [] - ys = {} - - x_mean_ref = np.sum(dataset.x, axis=0) - y_mean_ref = {k: np.sum(dataset.y[k], axis=0) for k in dataset.y} - - for x, y in dataset: - xs.append(x) - for k in y: - ys.setdefault(k, []).append(y[k]) - - xs = torch.cat(xs, dim=0) - ys = {k: torch.cat(ys[k], dim=0) for k in ys} - - x_mean = np.sum(xs.detach().numpy(), axis=0) - y_mean = {k: np.sum(ys[k].detach().numpy(), axis=0) for k in ys} - - assert np.all(np.isclose(x_mean, x_mean_ref, rtol=1e-3)) - for k in y_mean_ref: - assert np.all(np.isclose(y_mean[k], y_mean_ref[k], rtol=1e-3)) - - -def test_profile_variables(): - """ - Test loading of profile variables. - """ - path = Path(__file__).parent - input_file = path / "data" / "gmi" / "gprof_nn_gmi_era5.nc" - - PROFILE_TARGETS = [ - "rain_water_content", - "snow_water_content", - "cloud_water_content", - "latent_heat", - ] - dataset = GPROF_NN_1D_Dataset(input_file, targets=PROFILE_TARGETS, batch_size=1) - x, y = dataset[0] - - -###################################################################### -# MHS - 1D -###################################################################### - - -def test_gprof_1d_dataset_mhs(): - """ - Ensure that iterating over single-pixel dataset conserves - statistics. - """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, - batch_size=1, - augment=False, - targets=["surface_precip"], - sensor=sensors.MHS, - ) - - xs = [] - ys = [] - - x_mean_ref = dataset.x.sum(axis=0) - y_mean_ref = dataset.y["surface_precip"].sum(axis=0) - - for x, y in dataset: - xs.append(x) - ys.append(y["surface_precip"]) - - xs = torch.cat(xs, dim=0) - ys = torch.cat(ys, dim=0) - - x_mean = xs.sum(dim=0).detach().numpy() - y_mean = ys.sum(dim=0).detach().numpy() - - assert np.all(np.isclose(x_mean, x_mean_ref, rtol=1e-3)) - assert np.all(np.isclose(y_mean, y_mean_ref, rtol=1e-3)) - - assert np.all(np.isclose(x[:, 8:26].sum(-1), 1.0)) - - -def test_gprof_1d_dataset_multi_target_mhs(): - """ - Ensure that iterating over single-pixel dataset conserves - statistics. 
- """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, - targets=["surface_precip", "latent_heat", "rain_water_content"], - batch_size=1, - transform_zeros=False, - sensor=sensors.MHS, - ) - - xs = [] - ys = {} - - x_mean_ref = np.mean(dataset.x, axis=0) - y_mean_ref = {k: np.mean(dataset.y[k], axis=0) for k in dataset.y} - - for x, y in dataset: - xs.append(x) - for k in y: - ys.setdefault(k, []).append(y[k]) - - xs = torch.cat(xs, dim=0) - ys = {k: torch.cat(ys[k], dim=0) for k in ys} - - x_mean = np.mean(xs.detach().numpy(), axis=0) - y_mean = {k: np.mean(ys[k].detach().numpy(), axis=0) for k in ys} - - assert np.all(np.isclose(x_mean, x_mean_ref, atol=1e-3)) - for k in y_mean_ref: - assert np.all(np.isclose(y_mean[k], y_mean_ref[k], rtol=1e-3)) - - -def test_gprof_1d_dataset_input_mhs(): - """ - Ensure that input variables have realistic values. - """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"] - ) - x, _ = dataset[0] - x = x.numpy() - - tbs = x[:, :5] - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - - eia = x[:, 5] - eia = eia[np.isfinite(eia)] - assert np.all((eia >= -60) * (eia <= 60)) - - t2m = x[:, 6] - t2m = t2m[np.isfinite(t2m)] - assert np.all((t2m > 180) * (t2m < 350)) - - tcwv = x[:, 7] - tcwv = tcwv[np.isfinite(tcwv)] - assert np.all((tcwv > 0) * (tcwv < 100)) - - -def test_gprof_1d_dataset_pretraining_mhs(): - """ - Test that the correct inputs are loaded when loading a pre-training dataset - for MHS. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"], - sensor=sensors.MHS - ) - - x, _ = dataset[0] - x = x.numpy() - - assert x.shape[1] == 5 + 3 + 18 + 4 - - -def test_to_xarray_dataset_1d_mhs(): - """ - Ensure that converting training data to 'xarray.Dataset' yield same - Tbs as the ones found in the first batch of the training data when - data is not shuffled. - """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, - batch_size=64, - normalize=False, - shuffle=False, - targets=["surface_precip", "rain_water_content"] - ) - - # - # Conversion using datasets 'x' attribute. - # - - data = dataset.to_xarray_dataset() - x, y = dataset[0] - x = x.numpy() - - t2m = data.two_meter_temperature.data[:x.shape[0]] - t2m_ref = x[:, 6] - assert np.all(np.isclose(t2m, t2m_ref)) - - tcwv = data.total_column_water_vapor.data[:x.shape[0]] - tcwv_ref = x[:, 7] - assert np.all(np.isclose(tcwv, tcwv_ref)) - - st = data.surface_type.data[:x.shape[0]] - inds, st_ref = np.where(x[:, -22:-4]) - assert np.all(np.isclose(st[inds], st_ref + 1)) - - at = data.airmass_type.data[:x.shape[0]] - inds, at_ref = np.where(x[:, -4:]) - assert np.all(np.isclose(at[inds], at_ref)) - - -###################################################################### -# TMI - 1D -###################################################################### - - -def test_gprof_1d_dataset_tmi(): - """ - Ensure that iterating over single-pixel dataset conserves - statistics. 
- """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / "tmi" / "gprof_nn_tmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, - batch_size=1, - augment=False, - targets=["surface_precip"] - ) - - xs = [] - ys = [] - - x_mean_ref = dataset.x.mean(axis=0) - y_mean_ref = dataset.y["surface_precip"].mean(axis=0) - - for x, y in dataset: - xs.append(x) - ys.append(y["surface_precip"]) - - xs = torch.cat(xs, dim=0) - ys = torch.cat(ys, dim=0) - - x_mean = xs.mean(dim=0).detach().numpy() - y_mean = ys.mean(dim=0).detach().numpy() - - assert np.all(np.isclose(x_mean, x_mean_ref, rtol=1e-3)) - assert np.all(np.isclose(y_mean, y_mean_ref, rtol=1e-3)) - - assert np.all(np.isclose(x[:, 11:29].sum(-1), 1.0)) - - -def test_gprof_1d_dataset_multi_target_tmi(): - """ - Ensure that iterating over single-pixel dataset conserves - statistics. - """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / "tmi" / "gprof_nn_tmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, - targets=["surface_precip", "latent_heat", "rain_water_content"], - batch_size=1, - transform_zeros=False, - ) - - xs = [] - ys = {} - - x_mean_ref = np.mean(dataset.x, axis=0) - y_mean_ref = {k: np.mean(dataset.y[k], axis=0) for k in dataset.y} - - for x, y in dataset: - xs.append(x) - for k in y: - ys.setdefault(k, []).append(y[k]) - - xs = torch.cat(xs, dim=0) - ys = {k: torch.cat(ys[k], dim=0) for k in ys} - - x_mean = np.mean(xs.detach().numpy(), axis=0) - y_mean = {k: np.mean(ys[k].detach().numpy(), axis=0) for k in ys} - - assert np.all(np.isclose(x_mean, x_mean_ref, atol=1e-3)) - for k in y_mean_ref: - assert np.all(np.isclose(y_mean[k], y_mean_ref[k], rtol=1e-3)) - - -def test_gprof_1d_dataset_input_tmi(): - """ - Ensure that input variables have realistic values. - """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / "tmi" / "gprof_nn_tmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"] - ) - x, _ = dataset[0] - x = x.numpy() - - tbs = x[:, :9] - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - - t2m = x[:, 9] - t2m = t2m[np.isfinite(t2m)] - assert np.all((t2m > 180) * (t2m < 350)) - - tcwv = x[:, 10] - tcwv = tcwv[np.isfinite(tcwv)] - assert np.all((tcwv > 0) * (tcwv < 100)) - - -def test_gprof_1d_dataset_pretraining_tmi(): - """ - Test that the correct inputs are loaded when loading a pre-training dataset - for TMI. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"], - sensor=sensors.TMIPR - ) - - x, _ = dataset[0] - x = x.numpy() - - assert x.shape[1] == 9 + 2 + 18 + 4 - -###################################################################### -# SSMI - 1D -###################################################################### - - -def test_gprof_1d_dataset_ssmi(): - """ - Ensure that iterating over single-pixel dataset conserves - statistics. 
- """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / "ssmi" / "gprof_nn_ssmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, - batch_size=1, - augment=False, - targets=["surface_precip"] - ) - - xs = [] - ys = [] - - x_mean_ref = dataset.x.sum(axis=0) - y_mean_ref = dataset.y["surface_precip"].sum(axis=0) - - for x, y in dataset: - xs.append(x) - ys.append(y["surface_precip"]) - - xs = torch.cat(xs, dim=0) - ys = torch.cat(ys, dim=0) - - x_mean = xs.sum(dim=0).detach().numpy() - y_mean = ys.sum(dim=0).detach().numpy() - - assert np.all(np.isclose(x_mean, x_mean_ref, rtol=1e-3)) - assert np.all(np.isclose(y_mean, y_mean_ref, rtol=1e-3)) - - assert np.all(np.isclose(x[:, 9:27].sum(-1), 1.0)) - - -def test_gprof_1d_dataset_multi_target_ssmi(): - """ - Ensure that iterating over single-pixel dataset conserves - statistics. - """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / "ssmi" / "gprof_nn_ssmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, - targets=["surface_precip", "latent_heat", "rain_water_content"], - batch_size=1, - transform_zeros=False, - ) - - xs = [] - ys = {} - - x_mean_ref = np.mean(dataset.x, axis=0) - y_mean_ref = {k: np.mean(dataset.y[k], axis=0) for k in dataset.y} - - for x, y in dataset: - xs.append(x) - for k in y: - ys.setdefault(k, []).append(y[k]) - - xs = torch.cat(xs, dim=0) - ys = {k: torch.cat(ys[k], dim=0) for k in ys} - - x_mean = np.mean(xs.detach().numpy(), axis=0) - y_mean = {k: np.mean(ys[k].detach().numpy(), axis=0) for k in ys} - - assert np.all(np.isclose(x_mean, x_mean_ref, atol=1e-3)) - for k in y_mean_ref: - assert np.all(np.isclose(y_mean[k], y_mean_ref[k], rtol=1e-3)) - - -def test_gprof_1d_dataset_input_ssmi(): - """ - Ensure that input variables have realistic values. - """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / "ssmi" / "gprof_nn_ssmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"], - sensor=sensors.SSMI_F08 - ) - xs = [dataset[i][0] for i in range(len(dataset))] - x = torch.cat(xs, 0) - x = x.numpy() - - tbs = x[:, :7] - # Some inputs should be missing - assert np.any(np.isnan(tbs[:, 5:])) - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - - t2m = x[:, 7] - t2m = t2m[np.isfinite(t2m)] - assert np.all((t2m > 180) * (t2m < 350)) - - tcwv = x[:, 8] - tcwv = tcwv[np.isfinite(tcwv)] - assert np.all((tcwv > 0) * (tcwv < 100)) - - -def test_gprof_1d_dataset_pretraining_ssmi(): - """ - Test that the correct inputs are loaded when loading a pre-training dataset - for SSMI. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"], - sensor=sensors.SSMI - ) - - x, _ = dataset[0] - x = x.numpy() - - assert x.shape[1] == 7 + 2 + 18 + 4 - - -###################################################################### -# GMI - 3D -###################################################################### - -def test_gprof_3d_dataset_input_gmi(): - """ - Ensure that input variables have realistic values. 
- """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = GPROF_NN_3D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"] - ) - x, _ = dataset[0] - x = x.numpy() - - tbs = x[:, :15] - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - - t2m = x[:, 15] - assert np.all((t2m > 180) * (t2m < 350)) - - tcwv = x[:, 16] - assert np.all((tcwv > 0) * (tcwv < 100)) - - -def test_gprof_3d_dataset_gmi(): - """ - Ensure that iterating over 3D dataset conserves - statistics. - """ - data_path = Path(__file__).parent / "data" - input_file = data_path / "gmi" / "gprof_nn_gmi_era5.nc" - dataset = GPROF_NN_3D_Dataset( - input_file, batch_size=1, augment=False, transform_zeros=True - ) - - xs = [] - ys = [] - - x_mean_ref = dataset.x.sum(axis=0) - y_mean_ref = dataset.y["surface_precip"].sum(axis=0) - - for x, y in dataset: - xs.append(x) - ys.append(y["surface_precip"]) - - xs = torch.cat(xs, dim=0) - ys = torch.cat(ys, dim=0) - - x_mean = xs.sum(dim=0).detach().numpy() - y_mean = ys.sum(dim=0).detach().numpy() - - y_mean = y_mean[np.isfinite(y_mean)] - y_mean_ref = y_mean_ref[np.isfinite(y_mean_ref)] - - assert np.all(np.isclose(x_mean, x_mean_ref, atol=1e-3)) - assert np.all(np.isclose(y_mean, y_mean_ref, atol=1e-3)) - - -def test_gprof_3d_dataset_profiles(): - """ - Ensure that loading of profile variables works. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = GPROF_NN_3D_Dataset( - input_file, - batch_size=1, - augment=False, - transform_zeros=True, - targets=["rain_water_content", "snow_water_content", "cloud_water_content"], - ) - - xs = [] - ys = {} - - x_mean_ref = dataset.x.sum(axis=0) - y_mean_ref = {} - for k in dataset.targets: - y_mean_ref[k] = dataset.y[k].sum(axis=0) - - for x, y in dataset: - xs.append(x) - for k in y: - ys.setdefault(k, []).append(y[k]) - - xs = torch.cat(xs, dim=0) - for k in dataset.targets: - ys[k] = torch.cat(ys[k], dim=0) - - x_mean = xs.sum(dim=0).detach().numpy() - y_mean = {} - for k in dataset.targets: - y_mean[k] = ys[k].sum(dim=0).detach().numpy() - - for k in dataset.targets: - y_mean[k] = y_mean[k][np.isfinite(y_mean[k])] - y_mean_ref[k] = y_mean_ref[k][np.isfinite(y_mean_ref[k])] - - assert np.all(np.isclose(x_mean, x_mean_ref, atol=1e-3)) - for k in dataset.targets: - assert np.all(np.isclose(y_mean[k], y_mean_ref[k], atol=1e-3)) - - -def test_to_xarray_dataset_3d(): - """ - Ensure that converting training data to 'xarray.Dataset' yield same - Tbs as the ones found in the first batch of the training data when - data is not shuffled. - """ - data_path = Path(__file__).parent / "data" - input_file = data_path / "gmi" / "gprof_nn_gmi_era5.nc" - dataset = GPROF_NN_3D_Dataset( - input_file, - batch_size=32, - normalize=False, - shuffle=False, - targets=["surface_precip", "rain_water_content"] - ) - data = dataset.to_xarray_dataset() - x, y = dataset[0] - x = x.numpy() - - tbs = data.brightness_temperatures.data - tbs_ref = x[:, :15] - tbs_ref = np.transpose(tbs_ref, (0, 2, 3, 1)) - valid = np.isfinite(tbs_ref) - assert np.all(np.isclose(tbs[valid], tbs_ref[valid])) - - t2m = data.two_meter_temperature.data - t2m_ref = x[:, 15] - assert np.all(np.isclose(t2m, t2m_ref)) - - tcwv = data.total_column_water_vapor.data - tcwv_ref = x[:, 16] - assert np.all(np.isclose(tcwv, tcwv_ref)) - - -def test_gprof_3d_dataset_input_mhs(): - """ - Ensure that input variables have realistic values. 
- """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - dataset = GPROF_NN_3D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"], - sensor=sensors.MHS - ) - x, _ = dataset[0] - x = x.numpy() - - tbs = x[:, :5] - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - - eia = x[:, 5] - eia = eia[np.isfinite(eia)] - assert np.all((eia >= -60) * (eia <= 60)) - - t2m = x[:, 6] - t2m = t2m[np.isfinite(t2m)] - assert np.all((t2m > 200) * (t2m < 350)) - - tcwv = x[:, 7] - tcwv = tcwv[np.isfinite(tcwv)] - assert np.all((tcwv > 0) * (tcwv < 100)) - - -def test_gprof_3d_dataset_mhs(): - """ - Test loading of 3D training data for MHS sensor. - """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc" - dataset = GPROF_NN_3D_Dataset( - input_file, batch_size=1, augment=False, transform_zeros=True - ) - - xs = [] - ys = [] - - x_mean_ref = dataset.x.sum(axis=0) - y_mean_ref = dataset.y["surface_precip"].sum(axis=0) - - for x, y in dataset: - xs.append(x) - ys.append(y["surface_precip"]) - - xs = torch.cat(xs, dim=0) - ys = torch.cat(ys, dim=0) - - x_mean = xs.sum(dim=0).detach().numpy() - y_mean = ys.sum(dim=0).detach().numpy() - - y_mean = y_mean[np.isfinite(y_mean)] - y_mean_ref = y_mean_ref[np.isfinite(y_mean_ref)] - - assert np.all(np.isclose(x_mean, x_mean_ref, atol=1e-3)) - assert np.all(np.isclose(y_mean, y_mean_ref, atol=1e-3)) - - -def test_gprof_3d_dataset_input_tmi(): - """ - Ensure that input variables have realistic values. - """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / "tmi" / "gprof_nn_tmi_era5.nc.gz" - dataset = GPROF_NN_3D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"] - ) - x, _ = dataset[0] - x = x.numpy() - - tbs = x[:, :9] - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - - t2m = x[:, 9] - t2m = t2m[np.isfinite(t2m)] - assert np.all((t2m > 200) * (t2m < 350)) - - tcwv = x[:, 10] - tcwv = tcwv[np.isfinite(tcwv)] - assert np.all((tcwv > 0) * (tcwv < 100)) - - -def test_gprof_3d_dataset_tmi(): - """ - Test loading of 3D training data for MHS sensor. - """ - DATA_PATH = Path(__file__).parent / "data" - input_file = DATA_PATH / "tmi" / "gprof_nn_tmi_era5.nc" - dataset = GPROF_NN_3D_Dataset( - input_file, batch_size=1, augment=False, transform_zeros=True - ) - - xs = [] - ys = [] - - x_mean_ref = dataset.x.sum(axis=0) - y_mean_ref = dataset.y["surface_precip"].sum(axis=0) - - for x, y in dataset: - xs.append(x) - ys.append(y["surface_precip"]) - - xs = torch.cat(xs, dim=0) - ys = torch.cat(ys, dim=0) - - x_mean = xs.sum(dim=0).detach().numpy() - y_mean = ys.sum(dim=0).detach().numpy() - - y_mean = y_mean[np.isfinite(y_mean)] - y_mean_ref = y_mean_ref[np.isfinite(y_mean_ref)] - - assert np.all(np.isclose(x_mean, x_mean_ref, atol=1e-3)) - assert np.all(np.isclose(y_mean, y_mean_ref, atol=1e-3)) - - -def test_simulator_dataset_gmi(): - """ - Test loading of simulator training data. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = SimulatorDataset(input_file, normalize=False, batch_size=1024) - x, y = dataset[0] - x = x.numpy() - y = {k: y[k].numpy() for k in y} - - tbs = x[:, :15] - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - t2m = x[:, 15] - assert np.all((t2m > 180) * (t2m < 350)) - tcwv = x[:, 16] - assert np.all((tcwv > 0) * (tcwv < 100)) - - # Input Tbs must match simulated plus biases. 
- for i in range(x.shape[0]): - tbs_in = x[i, [0], :, :] - tbs_sim = y[f"simulated_brightness_temperatures_0"][i, :, :, :] - tbs_bias = y[f"brightness_temperature_biases_0"][i, :, :, :] - tbs_sim[tbs_sim <= -900] = np.nan - tbs_bias[tbs_bias <= -900] = np.nan - tbs_out = tbs_sim - tbs_bias - - valid = np.isfinite(tbs_out) * np.isfinite(tbs_in) - tbs_in = tbs_in[valid] - tbs_out = tbs_out[valid] - - if tbs_in.size == 0: - continue - ind = np.argmax(np.abs(tbs_in - tbs_out)) - assert np.all(np.isclose(tbs_in, tbs_out, atol=1e-3)) - - -def test_simulator_dataset_mhs(): - """ - Test loading of simulator training data. - """ - input_file = DATA_PATH / "mhs" / "gprof_nn_mhs_era5.nc.gz" - dataset = SimulatorDataset( - input_file, batch_size=1024, augment=True, normalize=False - ) - x, y = dataset[0] - - x = x.numpy() - tbs = x[:, :15] - tbs = tbs[np.isfinite(tbs)] - assert np.all((tbs > 30) * (tbs < 400)) - t2m = x[:, 15] - t2m = t2m[np.isfinite(t2m)] - assert np.all((t2m > 180) * (t2m < 350)) - tcwv = x[:, 16] - tcwv = tcwv[np.isfinite(tcwv)] - assert np.all((tcwv > 0) * (tcwv < 100)) - - assert np.all(np.isfinite(y["brightness_temperature_biases_0"].numpy())) - assert np.all(np.isfinite(y["simulated_brightness_temperatures_0"].numpy())) - assert "brightness_temperature_biases_0" in y - assert len(y["brightness_temperature_biases_0"].shape) == 4 - assert "simulated_brightness_temperatures_0" in y - assert len(y["simulated_brightness_temperatures_0"].shape) == 5 - - -def test_gprof_3d_dataset_pretraining_mhs(): - """ - Test that the correct inputs are loaded when loading a pre-training dataset - for MHS. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = GPROF_NN_3D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"], - sensor=sensors.MHS - ) - - x, _ = dataset[0] - x = x.numpy() - - assert x.shape[1] == 5 + 3 + 18 + 4 - - -def test_gprof_3d_dataset_pretraining_tmi(): - """ - Test that the correct inputs are loaded when loading a pre-training dataset - for TMI. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = GPROF_NN_3D_Dataset( - input_file, batch_size=1, normalize=False, targets=["surface_precip"], - sensor=sensors.TMIPR - ) - - x, _ = dataset[0] - x = x.numpy() - - assert x.shape[1] == 9 + 2 + 18 + 4 - - -def test_drop_inputs(): - """ - Enusre that dropp - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = GPROF_NN_1D_Dataset( - input_file, batch_size=1, augment=False, targets=["surface_precip"] - ) - - xs = [] - ys = [] - - x_mean_ref = dataset.x.sum(axis=0) - y_mean_ref = dataset.y["surface_precip"].sum(axis=0) - - for x, y in dataset: - xs.append(x) - ys.append(y["surface_precip"]) - - xs = torch.cat(xs, dim=0) - ys = torch.cat(ys, dim=0) - - x_mean = xs.sum(dim=0).detach().numpy() - y_mean = ys.sum(dim=0).detach().numpy() - - assert np.all(np.isclose(x_mean, x_mean_ref, rtol=1e-3)) - assert np.all(np.isclose(y_mean, y_mean_ref, rtol=1e-3)) diff --git a/test/test_utils.py b/test/test_utils.py deleted file mode 100644 index 4e01de0..0000000 --- a/test/test_utils.py +++ /dev/null @@ -1,382 +0,0 @@ -""" -Tests for the ``gprof_nn.utils`` module. 
-""" -from pathlib import Path - -import numpy as np -import xarray as xr - -from gprof_nn.augmentation import get_transformation_coordinates -from gprof_nn.data import get_test_data_path -from gprof_nn.sensors import GMI_VIEWING_GEOMETRY -from gprof_nn.utils import ( - apply_limits, - get_mask, - calculate_interpolation_weights, - interpolate, - calculate_tiles_and_cuts, -) -from gprof_nn.data.utils import ( - load_variable, - decompress_scene, - remap_scene, - upsample_scans, - save_scene, - extract_scenes, - write_training_samples_1d, - write_training_samples_3d, -) -from gprof_nn.data.training_data import decompress_and_load - -from conftest import ( - training_files_3d_gmi, - sim_collocations_gmi -) - - -DATA_PATH = get_test_data_path() - - -def test_apply_limits(): - """ - Ensure that upper and lower bounds are applied correctly. - """ - x = np.random.normal(size=(10, 10)) - - x_l = apply_limits(x, 0.0, None) - x_l = x_l[np.isfinite(x_l)] - assert np.all(x_l >= 0.0) - - x_r = apply_limits(x, None, 0.0) - x_r = x_r[np.isfinite(x_r)] - assert np.all(x_r <= 0.0) - - x = apply_limits(x, 0.0, 0.0) - x = x[np.isfinite(x)] - assert x.size == 0 - - -def test_get_mask(): - """ - Ensure that values extracted with mask are within given limits. - """ - x = np.random.normal(size=(10, 10)) - - mask = get_mask(x, 0.0, None) - x_l = x[mask] - assert np.all(x_l >= 0.0) - - mask = get_mask(x, None, 0.0) - x_r = x[mask] - assert np.all(x_r <= 0.0) - - mask = get_mask(x, 0.0, 0.0) - x = x[mask] - assert x.size == 0 - - -def test_calculate_interpolation_weights(): - """ - Ensure that calculating interpolation weights for the grid values - itself produces a diagonal matrix of weights. - - Also ensure that weights always sum to one across last dimension. - """ - grid = np.arange(0, 11) - weights = calculate_interpolation_weights(grid, grid) - - assert np.all(np.isclose(weights.diagonal(), 1.0)) - - values = np.random.uniform(0, 10, size=(10, 10)) - weights = calculate_interpolation_weights(values, grid) - assert np.all(np.isclose(np.sum(weights, 2), 1.0)) - - -def test_interpolation(): - """ - Ensure that calculating interpolation weights for the grid values - itself produces a diagonal matrix of weights. - """ - grid = np.arange(0, 11) - weights = calculate_interpolation_weights(grid, grid) - y = interpolate(np.repeat(grid.reshape(1, -1), 11, 0), weights) - assert np.all(np.isclose(grid, y)) - - values = np.random.uniform(0, 10, size=(10)) - weights = calculate_interpolation_weights(values, grid) - y = interpolate(np.repeat(grid.reshape(1, -1), 10, 0), weights) - assert np.all(np.isclose(y, values)) - - -def test_load_variable(): - """ - Ensure that loading a variable correctly replaces invalid value and - conserves shape when used without mask. - - Also ensure that masking works. - """ - input_file = DATA_PATH / "gmi" / "gprof_nn_gmi_era5.nc.gz" - dataset = decompress_and_load(input_file) - sp = load_variable(dataset, "surface_precip") - - expected_shape = (dataset.samples.size, - dataset.scans.size, - dataset.pixels.size) - assert sp.shape == expected_shape - - sp = sp[np.isfinite(sp)] - assert np.all((sp >= 0.0) * (sp < 500)) - - sp = load_variable(dataset, "surface_precip") - mask = sp > 10 - sp = load_variable(dataset, "surface_precip", mask) - sp = sp[np.isfinite(sp)] - assert np.all((sp > 10.0) * (sp < 500)) - - -def test_decompress_scene(training_files_3d_gmi): - """ - Ensure that loading a variable correctly replaces invalid value and - conserves shape when used without mask. 
- - Also ensure that masking works. - """ - input_file = training_files_3d_gmi[0] - scene = xr.load_dataset(input_file) #decompress_and_load(input_file)[{"samples": 1}] - - scene_d = decompress_scene( - scene, - [ - "surface_precip", - "rain_water_content", - "rain_water_path", - "surface_type" - ]) - - assert "pixels" in scene_d.rain_water_content.dims - - # Over ocean all pixels where IWP is defines should also - # have a valid surface precip value. - rwp = scene_d.rain_water_path.data - sp = scene_d.surface_precip.data - st = scene_d.surface_type - inds = (st == 1) * (rwp >= 0.0) - assert np.all(sp[inds] >= 0.0) - - -def test_calculate_tiles_and_cuts(): - """ - Test calculation of tiles and cuts for slicing of inputs. - """ - array = np.random.rand(1234, 128) - tiles, cuts = calculate_tiles_and_cuts(array.shape[0], 256, 8) - arrays_raw = [array[tile] for tile in tiles] - assert arrays_raw[-1].shape[0] == 256 - arrays = [arr[cut] for arr, cut in zip(arrays_raw, cuts)] - array_rec = np.concatenate(arrays, 0) - assert array_rec.shape == array.shape - assert np.all(np.isclose(array, array_rec)) - - array = np.random.rand(111, 128) - tiles, cuts = calculate_tiles_and_cuts(array.shape[0], 256, 8) - arrays_raw = [array[tile] for tile in tiles] - assert arrays_raw[-1].shape[0] == 111 - arrays = [arr[cut] for arr, cut in zip(arrays_raw, cuts)] - array_rec = np.concatenate(arrays, 0) - assert array_rec.shape == array.shape - assert np.all(np.isclose(array, array_rec)) - -def test_upsample_scans(): - - array = np.arange(10).astype(np.float32) - array_3 = upsample_scans(array) - - assert array_3.size == 28 - assert np.all(np.isclose(array_3, np.linspace(0, 9, 28))) - - -def test_save_scene( - tmp_path, - sim_collocations_gmi -): - data = sim_collocations_gmi - - save_scene(sim_collocations_gmi, tmp_path / "scene.nc") - data_loaded = xr.load_dataset(tmp_path / "scene.nc") - - # TB differences should be small and invalid values should - # be the same - for var in [ - "brightness_temperatures", - "simulated_brightness_temperatures", - "brightness_temperature_biases" - ]: - tbs = data[var].data - tbs_l = data_loaded[var].data - mask = np.isnan(tbs) + (tbs < -150) - mask_l = np.isnan(tbs_l) - assert np.all(mask == mask_l) - delta = tbs[~mask] - tbs_l[~mask] - assert np.abs(delta).max() <= 0.01 - assert np.abs(delta).max() > 0.0 - - tbs = data.brightness_temperatures.data - tbs_l = data.brightness_temperatures.data - mask = np.isnan(tbs) - mask_l = np.isnan(tbs_l) - assert np.all(mask == mask_l) - delta = tbs[~mask] - tbs_l[~mask] - assert np.abs(delta).max() <= 0.01 - - # Ensure that compression of ancillary data is - # lossless. - for var in [ - "surface_type", - "mountain_type", - "airlifting_index", - "mountain_type", - "land_fraction", - "ice_fraction", - ]: - print("TARGET :: ", var) - trgt = data[var].data - trgt_l = data_loaded[var].data - - valid = trgt >= 0 - valid_l = trgt_l >= 0 - assert np.all(valid == valid_l) - - err = trgt[valid] - trgt_l[valid] - assert np.all(err == 0.0) - - # Ensure that compression of ancillary data and targets is - # lossless. 
- for var in [ - "total_column_water_vapor", - "two_meter_temperature", - "surface_precip", - "ice_water_path", - "rain_water_path", - "cloud_water_path", - "rain_water_content", - "cloud_water_content", - "snow_water_content", - "latent_heat" - ]: - trgt = data[var].data - trgt_l = data_loaded[var].data - - valid = np.isfinite(trgt) - valid_l = np.isfinite(trgt_l) - assert np.all(valid == valid_l) - - err = trgt[valid] - trgt_l[valid] - assert np.all(np.abs(err) < 1e-6) - - -def test_extract_scenes(): - """ - Ensure that extracting scenes produces the expected amount - of valid pixels in the output. - """ - brightness_temperatures = np.random.rand(100, 100, 12) - surface_precip = np.random.rand(100, 100) - surface_precip[surface_precip < 0.5] = np.nan - data = xr.Dataset({ - "brightness_temperatures": ( - ("scans", "pixels", "channels"), brightness_temperatures - ), - "surface_precip": ( - ("scans", "pixels"), surface_precip - ) - }) - - scenes_sp_o = extract_scenes( - data, - 10, - 10, - overlapping=True, - min_valid=50, - reference_var="surface_precip" - ) - scenes_tbs_o = extract_scenes( - data, - 10, - 10, - overlapping=True, - min_valid=50, - reference_var="brightness_temperatures" - ) - scenes_sp = extract_scenes( - data, - 10, - 10, - overlapping=False, - min_valid=50, - reference_var="surface_precip" - ) - scenes_tbs = extract_scenes( - data, - 10, - 10, - overlapping=False, - min_valid=50, - reference_var="brightness_temperatures" - ) - - for scene in scenes_sp_o: - assert np.isfinite(scene.surface_precip.data).sum() >= 50 - - assert len(scenes_sp_o) <= len(scenes_tbs_o) - assert len(scenes_sp) <= len(scenes_sp_o) - assert len(scenes_tbs) <= len(scenes_tbs_o) - - -def test_write_training_samples_1d( - tmp_path, - sim_collocations_gmi -): - """ - Ensure that extracting and writing training samples produces - scenes of the expected size and containing the expected amount - of valid pixels. - """ - data = sim_collocations_gmi - - write_training_samples_1d( - tmp_path, - "sim_", - sim_collocations_gmi, - ) - - samples = sorted(list(tmp_path.glob("*.nc"))) - assert len(samples) == 1 - - -def test_write_training_samples_3d( - tmp_path, - sim_collocations_gmi -): - """ - Ensure that extracting and writing training samples produces - scenes of the expected size and containing the expected amount - of valid pixels. - """ - data = sim_collocations_gmi - - write_training_samples_3d( - tmp_path, - "sim_", - sim_collocations_gmi, - min_valid=512, - n_scans=128, - n_pixels=64 - ) - - samples = sorted(list(tmp_path.glob("*.nc"))) - for sample in samples: - data = xr.load_dataset(sample) - valid = np.isfinite(data.surface_precip.data) - - assert valid.shape == (128, 64) - assert valid.sum() >= 512 diff --git a/test/test_validation.py b/test/test_validation.py deleted file mode 100644 index d598afe..0000000 --- a/test/test_validation.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Tests for the gprof_nn.data.validation module. -""" -import numpy as np -import xarray as xr - -from gprof_nn import sensors -from gprof_nn.data import get_test_data_path -from gprof_nn.data.l1c import L1CFile -from gprof_nn.data.validation import (ValidationData, - unify_grid, - ValidationFileProcessor) -from gprof_nn.utils import great_circle_distance - - -DATA_PATH = get_test_data_path() - - -def test_get_granules(): - """ - Test listing of granules. 
- """ - validation_data = ValidationData(sensors.GMI) - granules = validation_data.get_granules(2016, 10) - assert 15199 in granules - - -def test_open_granule(): - """ - Test listing of granules. - """ - validation_data = ValidationData(sensors.GMI) - data = validation_data.open_granule(2016, 10, 15199) - - sp = data.surface_precip.data - sp = sp[sp >= 0] - assert np.all(sp <= 500) - - lats = data.latitude.data - assert np.all((lats >= 0) * (lats <= 60)) - - lons = data.longitude.data - assert np.all((lons >= -180) * (lons <= 0)) - - -def test_unify_grid(): - """ - Test that the unify grid function yields grids with resolutions - close to 5km. - """ - l1c_file = DATA_PATH / "gmi" / "l1c" / ( - "1C-R.GPM.GMI.XCAL2016-C.20190101-S001447-E014719.027510.V07A.HDF5" - ) - l1c_data = L1CFile(l1c_file).to_xarray_dataset() - lats = l1c_data.latitude.data - lons = l1c_data.longitude.data - - lats_5, lons_5 = unify_grid(lats, lons, sensors.GMI) - - # Along track distance. - d = great_circle_distance( - lats_5[:-1], lons_5[:-1], - lats_5[1:], lons_5[1:] - ) - assert d.min() > 4.4e3 - assert d.max() < 5.5e3 - - - # Across track distance - d = great_circle_distance( - lats_5[:, :-1], lons_5[:, :-1], - lats_5[:, 1:], lons_5[:, 1:] - ) - assert d.min() > 4.8e3 - assert d.max() < 5.2e3 - - -def test_validation_file_processor(tmp_path): - """ - Ensure that mrmrs data is interpolated to 5km x 5km grid. - """ - mrms_file = tmp_path / "mrms.nc" - pp_file = tmp_path / "preprocessor.pp" - - processor = ValidationFileProcessor(sensors.GMI, 2016, 10) - processor.process_granule(15199, mrms_file, pp_file) - - mrms_data = xr.load_dataset(mrms_file) - - assert mrms_data.attrs["sensor"] == "GMI" - - lats = mrms_data.latitude.data - lons = mrms_data.longitude.data - - # Along track distance. - d = great_circle_distance( - lats[:-1], lons[:-1], - lats[1:], lons[1:] - ) - assert d.min() > 4.4e3 - assert d.max() < 5.5e3 - - # Across track distance - d = great_circle_distance( - lats[:, :-1], lons[:, :-1], - lats[:, 1:], lons[:, 1:] - ) - assert d.min() > 4.8e3 - assert d.max() < 5.2e3 - - sp = mrms_data.surface_precip.data - sp = sp[np.isfinite(sp)] - assert np.all((sp >= 0.0) * (sp <= 500.0))