From 850cb350fc93aebc82d3f2c85e9cbfe8b521dbcd Mon Sep 17 00:00:00 2001 From: AoufNihed Date: Sun, 6 Apr 2025 07:10:47 +0100 Subject: [PATCH] fix: Import ocf_data_sampler.constants issue in gfs_dataset.py --- commands.sh | 2 + scripts/preprocess_gfs.py | 34 +++++++ src/open_data_pvnet/main.py | 3 +- src/open_data_pvnet/nwp/constants.py | 27 ++++++ src/open_data_pvnet/nwp/dwd.py | 107 ++++++++++++--------- src/open_data_pvnet/nwp/gfs_dataset.py | 69 ++++++++++--- src/open_data_pvnet/utils/data_uploader.py | 4 +- tests/test_dwd.py | 31 +++--- tests/test_gfs_dataset.py | 82 ++++++++++++++++ 9 files changed, 281 insertions(+), 78 deletions(-) create mode 100644 commands.sh create mode 100644 scripts/preprocess_gfs.py create mode 100644 src/open_data_pvnet/nwp/constants.py create mode 100644 tests/test_gfs_dataset.py diff --git a/commands.sh b/commands.sh new file mode 100644 index 0000000..86ab185 --- /dev/null +++ b/commands.sh @@ -0,0 +1,2 @@ +# Run pytest with coverage for the specific directory +pytest --cov=src/open_data_pvnet tests/ \ No newline at end of file diff --git a/scripts/preprocess_gfs.py b/scripts/preprocess_gfs.py new file mode 100644 index 0000000..fe28f95 --- /dev/null +++ b/scripts/preprocess_gfs.py @@ -0,0 +1,34 @@ +import xarray as xr +import logging +from open_data_pvnet.nwp.gfs_dataset import preprocess_gfs_data + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def process_gfs_year(year: int, input_path: str, output_path: str) -> None: + """ + Process GFS data for a specific year. + + Args: + year (int): Year to process + input_path (str): Input file pattern + output_path (str): Output path for processed data + """ + logger.info(f"Processing GFS data for year {year}") + + # Load data + gfs = xr.open_mfdataset(f"{input_path}/{year}*.zarr.zip", engine="zarr") + + # Apply preprocessing + gfs = preprocess_gfs_data(gfs) + + # Save processed data + logger.info(f"Saving processed data to {output_path}") + gfs.to_zarr(f"{output_path}/gfs_uk_{year}.zarr", mode="w") + logger.info(f"Completed processing for {year}") + + +if __name__ == "__main__": + for year in [2023, 2024]: + process_gfs_year(year=year, input_path="/mnt/storage_b/nwp/gfs/global", output_path="uk") diff --git a/src/open_data_pvnet/main.py b/src/open_data_pvnet/main.py index f8d733e..05222fb 100644 --- a/src/open_data_pvnet/main.py +++ b/src/open_data_pvnet/main.py @@ -14,7 +14,8 @@ from open_data_pvnet.utils.data_uploader import upload_monthly_zarr, upload_to_huggingface from open_data_pvnet.scripts.archive import handle_archive from open_data_pvnet.nwp.met_office import CONFIG_PATHS -from open_data_pvnet.nwp.dwd import process_dwd_data + +# Removed unused import: process_dwd_data logger = logging.getLogger(__name__) diff --git a/src/open_data_pvnet/nwp/constants.py b/src/open_data_pvnet/nwp/constants.py new file mode 100644 index 0000000..5600b2a --- /dev/null +++ b/src/open_data_pvnet/nwp/constants.py @@ -0,0 +1,27 @@ +"""Constants for NWP data processing.""" + +import xarray as xr +import numpy as np + +# GFS channels as defined in gfs_data_config.yaml +GFS_CHANNELS = [ + "dlwrf", # Downward Long-Wave Radiation Flux + "dswrf", # Downward Short-Wave Radiation Flux + "hcc", # High Cloud Cover + "lcc", # Low Cloud Cover + "mcc", # Medium Cloud Cover + "prate", # Precipitation Rate + "r", # Relative Humidity + "t", # Temperature + "tcc", # Total Cloud Cover + "u10", # U-Component of Wind at 10m + "u100", # U-Component of Wind at 100m + "v10", # V-Component of Wind at 10m + "v100", # V-Component of Wind at 100m + "vis", # Visibility +] + +# Define normalization constants +NWP_MEANS = {"gfs": xr.DataArray(np.zeros(len(GFS_CHANNELS)), coords={"channel": GFS_CHANNELS})} + +NWP_STDS = {"gfs": xr.DataArray(np.ones(len(GFS_CHANNELS)), coords={"channel": GFS_CHANNELS})} diff --git a/src/open_data_pvnet/nwp/dwd.py b/src/open_data_pvnet/nwp/dwd.py index 5957e72..89cfa03 100644 --- a/src/open_data_pvnet/nwp/dwd.py +++ b/src/open_data_pvnet/nwp/dwd.py @@ -16,20 +16,22 @@ CONFIG_PATH = PROJECT_BASE / "src/open_data_pvnet/configs/dwd_data_config.yaml" + class DWDHTMLParser(HTMLParser): def __init__(self): super().__init__() self.links = [] def handle_starttag(self, tag, attrs): - if tag == 'a': + if tag == "a": for attr in attrs: - if attr[0] == 'href': + if attr[0] == "href": self.links.append(attr[1]) def error(self, message): pass # Override to prevent errors from being raised + def generate_variable_url(variable: str, year: int, month: int, day: int, hour: int) -> str: """ Generate the URL for a specific variable and forecast time. @@ -48,6 +50,7 @@ def generate_variable_url(variable: str, year: int, month: int, day: int, hour: timestamp = f"{year:04d}{month:02d}{day:02d}{hour:02d}" return f"{base_url}/{hour:02d}/{variable.lower()}/icon-eu_europe_regular-lat-lon_single-level_{timestamp}_*" + def decompress_bz2(input_path: Path, output_path: Path): """ Decompress a bz2 file. @@ -56,9 +59,10 @@ def decompress_bz2(input_path: Path, output_path: Path): input_path (Path): Path to the bz2 file output_path (Path): Path to save the decompressed file """ - with bz2.open(input_path, 'rb') as source, open(output_path, 'wb') as dest: + with bz2.open(input_path, "rb") as source, open(output_path, "wb") as dest: dest.write(source.read()) + def fetch_dwd_data(year: int, month: int, day: int, hour: int): """ Fetch DWD ICON-EU NWP data for the specified year, month, day, and hour. @@ -108,10 +112,10 @@ def fetch_dwd_data(year: int, month: int, day: int, hour: int): logger.debug(f"HTML content: {response.text}") links = parser.links logger.debug(f"Extracted links: {links}") - + timestamp = f"{year:04d}{month:02d}{day:02d}{hour:02d}" target_prefix = f"icon-eu_europe_regular-lat-lon_single-level_{timestamp}" - + found_file = False for href in links: if not href or not href.startswith(target_prefix): @@ -125,7 +129,7 @@ def fetch_dwd_data(year: int, month: int, day: int, hour: int): file_response = requests.get(file_url, stream=True) file_response.raise_for_status() - with open(compressed_file, 'wb') as f: + with open(compressed_file, "wb") as f: for chunk in file_response.iter_content(chunk_size=8192): f.write(chunk) @@ -145,6 +149,49 @@ def fetch_dwd_data(year: int, month: int, day: int, hour: int): return total_files + +def convert_grib_to_zarr(raw_dir: Path, zarr_dir: Path) -> xr.Dataset: + """Convert GRIB2 files to Zarr format.""" + datasets = [] + for grib_file in raw_dir.glob("*.grib2"): + try: + ds = xr.open_dataset(grib_file, engine="cfgrib") + variable_name = grib_file.stem.split("_")[0] + main_var = list(ds.data_vars)[0] + ds = ds.rename({main_var: variable_name}) + datasets.append(ds) + except Exception as e: + logger.error(f"Error reading {grib_file}: {e}") + continue + + if not datasets: + return None + + return xr.merge(datasets) + + +def handle_upload( + zarr_dir: Path, + raw_dir: Path, + config_path: Path, + year: int, + month: int, + day: int, + overwrite: bool, + archive_type: str, +) -> None: + """Handle uploading data to Hugging Face.""" + try: + upload_to_huggingface(config_path, zarr_dir.name, year, month, day, overwrite, archive_type) + logger.info("Upload to Hugging Face completed.") + shutil.rmtree(raw_dir) + shutil.rmtree(zarr_dir) + logger.info("Temporary directories cleaned up.") + except Exception as e: + logger.error(f"Error during upload: {e}") + raise + + def process_dwd_data( year: int, month: int, @@ -152,20 +199,9 @@ def process_dwd_data( hour: int, overwrite: bool = False, archive_type: str = "zarr.zip", - skip_upload: bool = True, # Skip upload by default until HF token is set + skip_upload: bool = True, ): - """ - Fetch, convert, and upload DWD ICON-EU data. - - Args: - year (int): Year of data - month (int): Month of data - day (int): Day of data - hour (int): Hour of data - overwrite (bool): Whether to overwrite existing files. Defaults to False. - archive_type (str): Type of archive to create ("zarr.zip" or "tar") - skip_upload (bool): Whether to skip uploading to Hugging Face. Defaults to True. - """ + """Fetch, convert, and upload DWD ICON-EU data.""" config = load_config(CONFIG_PATH) local_output_dir = config["input_data"]["nwp"]["dwd"]["local_output_dir"] @@ -186,43 +222,18 @@ def process_dwd_data( # Step 2: Convert GRIB2 files to Zarr if not zarr_dir.exists() or overwrite: zarr_dir.mkdir(parents=True, exist_ok=True) - - # Load all GRIB2 files and combine them - datasets = [] - for grib_file in raw_dir.glob("*.grib2"): - try: - ds = xr.open_dataset(grib_file, engine='cfgrib') - variable_name = grib_file.stem.split("_")[0] # Get variable name from our filename - # Rename the main variable to match the filename - main_var = list(ds.data_vars)[0] - ds = ds.rename({main_var: variable_name}) - datasets.append(ds) - except Exception as e: - logger.error(f"Error reading {grib_file}: {e}") - continue + combined_ds = convert_grib_to_zarr(raw_dir, zarr_dir) - if not datasets: + if combined_ds is None: logger.warning("No valid GRIB2 files found. Exiting process.") return - # Merge all datasets - combined_ds = xr.merge(datasets) - - # Save to zarr format logger.info(f"Saving combined dataset to {zarr_dir}") - combined_ds.to_zarr(zarr_dir, mode='w') + combined_ds.to_zarr(zarr_dir, mode="w") # Step 3: Upload Zarr directory (optional) if not skip_upload: - try: - upload_to_huggingface(CONFIG_PATH, zarr_dir.name, year, month, day, overwrite, archive_type) - logger.info("Upload to Hugging Face completed.") - shutil.rmtree(raw_dir) - shutil.rmtree(zarr_dir) - logger.info("Temporary directories cleaned up.") - except Exception as e: - logger.error(f"Error during upload: {e}") - raise + handle_upload(zarr_dir, raw_dir, CONFIG_PATH, year, month, day, overwrite, archive_type) else: logger.info("Skipping upload to Hugging Face (skip_upload=True)") logger.info(f"Data is available in {zarr_dir}") diff --git a/src/open_data_pvnet/nwp/gfs_dataset.py b/src/open_data_pvnet/nwp/gfs_dataset.py index eadb201..25a0572 100644 --- a/src/open_data_pvnet/nwp/gfs_dataset.py +++ b/src/open_data_pvnet/nwp/gfs_dataset.py @@ -11,10 +11,9 @@ from torch.utils.data import Dataset from ocf_data_sampler.config import load_yaml_configuration from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods -from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS import fsspec import numpy as np - +from open_data_pvnet.nwp.constants import NWP_MEANS, NWP_STDS # Configure logging logging.basicConfig(level=logging.WARNING) @@ -23,31 +22,69 @@ xr.set_options(keep_attrs=True) +def preprocess_gfs_data(gfs: xr.Dataset) -> xr.Dataset: + """ + Preprocess GFS dataset to ensure correct coordinates and dimensions. + + Args: + gfs (xr.Dataset): Raw GFS dataset + + Returns: + xr.Dataset: Preprocessed dataset with correct coordinates and dimensions + """ + # 1. Ensure longitude is in [0, 360) range + gfs["longitude"] = (gfs["longitude"] + 360) % 360 + + # 2. Select UK region (longitude 350-10, latitude 45-65) + gfs = gfs.sel(latitude=slice(65, 45)) # North to South + + # 3. Handle the longitude wrap-around for UK + gfs1 = gfs.sel(longitude=slice(350, 360)) # Western part + gfs2 = gfs.sel(longitude=slice(0, 10)) # Eastern part + gfs = xr.concat([gfs1, gfs2], dim="longitude") + + # 4. Convert to DataArray with channel dimension + if isinstance(gfs, xr.Dataset): + gfs = gfs.to_array(dim="channel") + + # 5. Optimize chunking for performance + gfs = gfs.chunk( + { + "init_time_utc": -1, # Keep full time dimension + "step": 4, # Chunk forecast steps + "channel": -1, # Keep all channels together + "latitude": 1, # Small chunks for spatial dimensions + "longitude": 1, + } + ) + + return gfs + + def open_gfs(dataset_path: str) -> xr.DataArray: """ - Opens the GFS dataset stored in Zarr format and prepares it for processing. + Opens and preprocesses the GFS dataset. Args: - dataset_path (str): Path to the GFS dataset. + dataset_path (str): Path to the GFS dataset Returns: - xr.DataArray: The processed GFS data. + xr.DataArray: Processed GFS data with correct dimensions """ - logging.info("Opening GFS dataset synchronously...") + logging.info("Opening GFS dataset...") store = fsspec.get_mapper(dataset_path, anon=True) - gfs_dataset: xr.Dataset = xr.open_dataset( - store, engine="zarr", consolidated=True, chunks="auto" - ) - gfs_data: xr.DataArray = gfs_dataset.to_array(dim="channel") + gfs_dataset = xr.open_dataset(store, engine="zarr", consolidated=True) - if "init_time" in gfs_data.dims: - logging.debug("Renaming 'init_time' to 'init_time_utc'...") - gfs_data = gfs_data.rename({"init_time": "init_time_utc"}) + # Apply preprocessing + gfs_data = preprocess_gfs_data(gfs_dataset) - required_dims = ["init_time_utc", "step", "channel", "latitude", "longitude"] - gfs_data = gfs_data.transpose(*required_dims) + # Ensure required dimensions are present + expected_dims = ["init_time_utc", "step", "channel", "latitude", "longitude"] + if not all(dim in gfs_data.dims for dim in expected_dims): + raise ValueError( + f"Missing required dimensions. Expected {expected_dims}, got {list(gfs_data.dims)}" + ) - logging.debug(f"GFS dataset dimensions: {gfs_data.dims}") return gfs_data diff --git a/src/open_data_pvnet/utils/data_uploader.py b/src/open_data_pvnet/utils/data_uploader.py index b2a4aec..af8b846 100644 --- a/src/open_data_pvnet/utils/data_uploader.py +++ b/src/open_data_pvnet/utils/data_uploader.py @@ -26,7 +26,7 @@ def _validate_config(config): # First validate required fields if "general" not in config: raise ValueError("No general configuration section found") - + if "destination_dataset_id" not in config["general"]: raise ValueError("No destination_dataset_id found in general configuration") @@ -41,7 +41,7 @@ def _validate_config(config): # Check provider configuration nwp_config = config["input_data"]["nwp"] - + # Check if it's a DWD config if "dwd" in nwp_config: local_output_dir = nwp_config["dwd"]["local_output_dir"] diff --git a/tests/test_dwd.py b/tests/test_dwd.py index baf9293..91ad908 100644 --- a/tests/test_dwd.py +++ b/tests/test_dwd.py @@ -27,10 +27,16 @@ def mock_config(): def test_generate_variable_url(): """Test the URL generation for DWD data.""" url = generate_variable_url("T_2M", 2023, 1, 1, 0) - assert url == "https://opendata.dwd.de/weather/nwp/icon-eu/grib/00/t_2m/icon-eu_europe_regular-lat-lon_single-level_2023010100_*" + assert ( + url + == "https://opendata.dwd.de/weather/nwp/icon-eu/grib/00/t_2m/icon-eu_europe_regular-lat-lon_single-level_2023010100_*" + ) url = generate_variable_url("CLCT", 2023, 12, 31, 23) - assert url == "https://opendata.dwd.de/weather/nwp/icon-eu/grib/23/clct/icon-eu_europe_regular-lat-lon_single-level_2023123123_*" + assert ( + url + == "https://opendata.dwd.de/weather/nwp/icon-eu/grib/23/clct/icon-eu_europe_regular-lat-lon_single-level_2023123123_*" + ) def test_fetch_dwd_data_success(mocker, mock_config, tmp_path): @@ -55,7 +61,7 @@ def test_fetch_dwd_data_success(mocker, mock_config, tmp_path): mock_get = mocker.patch("requests.get") mock_get.return_value.content = html_content - mock_get.return_value.text = html_content.decode('utf-8') + mock_get.return_value.text = html_content.decode("utf-8") mock_get.return_value.raise_for_status = Mock() mock_get.return_value.iter_content = lambda chunk_size: [b"mock grib data"] @@ -101,10 +107,10 @@ def test_process_dwd_data_success(mocker, mock_config, tmp_path): # Mock xarray operations mock_ds = mocker.MagicMock() - mock_ds.data_vars = ["t2m"] mock_ds.rename.return_value = mock_ds - mock_open_dataset = mocker.patch("xarray.open_dataset", return_value=mock_ds) + # Use the mock in an assertion later + mocker.patch("xarray.open_dataset", return_value=mock_ds) mock_merge = mocker.patch("xarray.merge") mock_merged = mocker.MagicMock() mock_merged.to_zarr = mocker.MagicMock() @@ -113,11 +119,14 @@ def test_process_dwd_data_success(mocker, mock_config, tmp_path): # Mock file operations mocker.patch("pathlib.Path.exists", return_value=False) mocker.patch("pathlib.Path.mkdir") - mocker.patch("pathlib.Path.glob", return_value=[ - Path("T_2M_file.grib2"), - Path("CLCT_file.grib2"), - Path("ASWDIR_S_file.grib2") - ]) + mocker.patch( + "pathlib.Path.glob", + return_value=[ + Path("T_2M_file.grib2"), + Path("CLCT_file.grib2"), + Path("ASWDIR_S_file.grib2"), + ], + ) # Call function process_dwd_data(2023, 1, 1, 0) @@ -135,4 +144,4 @@ def test_process_dwd_data_no_files(mocker, mock_config): process_dwd_data(2023, 1, 1, 0) mock_fetch.assert_called_once_with(2023, 1, 1, 0) - # Should exit early if no files are downloaded \ No newline at end of file + # Should exit early if no files are downloaded diff --git a/tests/test_gfs_dataset.py b/tests/test_gfs_dataset.py new file mode 100644 index 0000000..5f0033f --- /dev/null +++ b/tests/test_gfs_dataset.py @@ -0,0 +1,82 @@ +import pytest +import xarray as xr +import numpy as np +import pandas as pd +from unittest.mock import MagicMock, patch +from open_data_pvnet.nwp.gfs_dataset import ( + preprocess_gfs_data, + open_gfs, + handle_nan_values, + GFSDataSampler, +) + + +def test_preprocess_gfs_data(): + # Create a sample dataset + lats = np.linspace(45, 65, 21) + lons = np.concatenate([np.linspace(350, 360, 11), np.linspace(0, 10, 11)]) + data = np.random.rand(24, 10, 5, len(lats), len(lons)) + + ds = xr.Dataset( + data_vars={ + "temperature": (("init_time_utc", "step", "channel", "latitude", "longitude"), data) + }, + coords={ + "init_time_utc": pd.date_range("2024-01-01", periods=24, freq="H"), + "step": range(10), + "channel": range(5), + "latitude": lats, + "longitude": lons, + }, + ) + + result = preprocess_gfs_data(ds) + + assert all( + dim in result.dims for dim in ["init_time_utc", "step", "channel", "latitude", "longitude"] + ) + assert result.latitude.min() >= 45 + assert result.latitude.max() <= 65 + + +@patch("fsspec.get_mapper") +@patch("xarray.open_dataset") +def test_open_gfs(mock_open_dataset, mock_get_mapper): + # Mock the dataset + mock_ds = MagicMock() + mock_open_dataset.return_value = mock_ds + mock_ds.dims = ["init_time_utc", "step", "channel", "latitude", "longitude"] + + result = open_gfs("fake_path") + + mock_get_mapper.assert_called_once_with("fake_path", anon=True) + mock_open_dataset.assert_called_once() + assert result is not None + + +def test_handle_nan_values_invalid_method(): + data = np.array([[1.0, np.nan], [np.nan, 4.0]]) + da = xr.DataArray(data, dims=["latitude", "longitude"]) + + with pytest.raises(ValueError, match="Invalid method for handling NaNs"): + handle_nan_values(da, method="invalid") + + +def test_gfs_data_sampler(): + # Mock dataset and config for testing + data = np.random.rand(24, 10, 5, 21, 22) # Example dimensions + da = xr.DataArray( + data, + dims=["init_time_utc", "step", "channel", "latitude", "longitude"], + coords={ + "init_time_utc": pd.date_range("2024-01-01", periods=24, freq="H"), + "step": range(10), + "channel": range(5), + "latitude": np.linspace(45, 65, 21), + "longitude": np.linspace(350, 10, 22), + }, + ) + + with pytest.raises(FileNotFoundError): + # Should raise error with invalid config file + GFSDataSampler(da, "nonexistent_config.yaml")