Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions commands.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Run pytest with coverage measured for the open_data_pvnet package only
pytest --cov=src/open_data_pvnet tests/
34 changes: 34 additions & 0 deletions scripts/preprocess_gfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import xarray as xr
import logging
from open_data_pvnet.nwp.gfs_dataset import preprocess_gfs_data

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def process_gfs_year(year: int, input_path: str, output_path: str) -> None:
    """
    Run the GFS preprocessing pipeline for a single year of archived data.

    Opens every ``<year>*.zarr.zip`` archive under ``input_path`` as one lazy
    dataset, applies the shared GFS preprocessing, and writes the result to
    ``<output_path>/gfs_uk_<year>.zarr``.

    Args:
        year (int): Year to process
        input_path (str): Input file pattern
        output_path (str): Output path for processed data
    """
    logger.info(f"Processing GFS data for year {year}")

    # Open all archives for the year as a single (lazily loaded) dataset.
    source_pattern = f"{input_path}/{year}*.zarr.zip"
    dataset = xr.open_mfdataset(source_pattern, engine="zarr")

    # Region selection, channel stacking and re-chunking happen here.
    dataset = preprocess_gfs_data(dataset)

    logger.info(f"Saving processed data to {output_path}")
    dataset.to_zarr(f"{output_path}/gfs_uk_{year}.zarr", mode="w")
    logger.info(f"Completed processing for {year}")


if __name__ == "__main__":
    # Preprocess each archived year in turn; output lands under ./uk.
    years_to_process = (2023, 2024)
    for year in years_to_process:
        process_gfs_year(year=year, input_path="/mnt/storage_b/nwp/gfs/global", output_path="uk")
3 changes: 2 additions & 1 deletion src/open_data_pvnet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from open_data_pvnet.utils.data_uploader import upload_monthly_zarr, upload_to_huggingface
from open_data_pvnet.scripts.archive import handle_archive
from open_data_pvnet.nwp.met_office import CONFIG_PATHS
from open_data_pvnet.nwp.dwd import process_dwd_data

# Removed unused import: process_dwd_data

logger = logging.getLogger(__name__)

Expand Down
27 changes: 27 additions & 0 deletions src/open_data_pvnet/nwp/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Constants for NWP data processing."""

import xarray as xr
import numpy as np

# GFS channels as defined in gfs_data_config.yaml
GFS_CHANNELS = [
    "dlwrf",  # Downward Long-Wave Radiation Flux
    "dswrf",  # Downward Short-Wave Radiation Flux
    "hcc",  # High Cloud Cover
    "lcc",  # Low Cloud Cover
    "mcc",  # Medium Cloud Cover
    "prate",  # Precipitation Rate
    "r",  # Relative Humidity
    "t",  # Temperature
    "tcc",  # Total Cloud Cover
    "u10",  # U-Component of Wind at 10m
    "u100",  # U-Component of Wind at 100m
    "v10",  # V-Component of Wind at 10m
    "v100",  # V-Component of Wind at 100m
    "vis",  # Visibility
]

# Per-channel normalization statistics. These are placeholders that apply the
# identity transform: zero mean and unit standard deviation for every channel.
NWP_MEANS = {
    "gfs": xr.DataArray(
        np.zeros(len(GFS_CHANNELS)),
        coords={"channel": GFS_CHANNELS},
        dims=["channel"],
    )
}

NWP_STDS = {
    "gfs": xr.DataArray(
        np.ones(len(GFS_CHANNELS)),
        coords={"channel": GFS_CHANNELS},
        dims=["channel"],
    )
}
107 changes: 59 additions & 48 deletions src/open_data_pvnet/nwp/dwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@

CONFIG_PATH = PROJECT_BASE / "src/open_data_pvnet/configs/dwd_data_config.yaml"


class DWDHTMLParser(HTMLParser):
def __init__(self):
super().__init__()
self.links = []

def handle_starttag(self, tag, attrs):
if tag == 'a':
if tag == "a":
for attr in attrs:
if attr[0] == 'href':
if attr[0] == "href":
self.links.append(attr[1])

def error(self, message):
pass # Override to prevent errors from being raised


def generate_variable_url(variable: str, year: int, month: int, day: int, hour: int) -> str:
"""
Generate the URL for a specific variable and forecast time.
Expand All @@ -48,6 +50,7 @@ def generate_variable_url(variable: str, year: int, month: int, day: int, hour:
timestamp = f"{year:04d}{month:02d}{day:02d}{hour:02d}"
return f"{base_url}/{hour:02d}/{variable.lower()}/icon-eu_europe_regular-lat-lon_single-level_{timestamp}_*"


def decompress_bz2(input_path: Path, output_path: Path):
"""
Decompress a bz2 file.
Expand All @@ -56,9 +59,10 @@ def decompress_bz2(input_path: Path, output_path: Path):
input_path (Path): Path to the bz2 file
output_path (Path): Path to save the decompressed file
"""
with bz2.open(input_path, 'rb') as source, open(output_path, 'wb') as dest:
with bz2.open(input_path, "rb") as source, open(output_path, "wb") as dest:
dest.write(source.read())


def fetch_dwd_data(year: int, month: int, day: int, hour: int):
"""
Fetch DWD ICON-EU NWP data for the specified year, month, day, and hour.
Expand Down Expand Up @@ -108,10 +112,10 @@ def fetch_dwd_data(year: int, month: int, day: int, hour: int):
logger.debug(f"HTML content: {response.text}")
links = parser.links
logger.debug(f"Extracted links: {links}")

timestamp = f"{year:04d}{month:02d}{day:02d}{hour:02d}"
target_prefix = f"icon-eu_europe_regular-lat-lon_single-level_{timestamp}"

found_file = False
for href in links:
if not href or not href.startswith(target_prefix):
Expand All @@ -125,7 +129,7 @@ def fetch_dwd_data(year: int, month: int, day: int, hour: int):
file_response = requests.get(file_url, stream=True)
file_response.raise_for_status()

with open(compressed_file, 'wb') as f:
with open(compressed_file, "wb") as f:
for chunk in file_response.iter_content(chunk_size=8192):
f.write(chunk)

Expand All @@ -145,27 +149,59 @@ def fetch_dwd_data(year: int, month: int, day: int, hour: int):

return total_files


def convert_grib_to_zarr(raw_dir: Path, zarr_dir: Path) -> xr.Dataset:
    """Convert GRIB2 files to Zarr format.

    Opens every ``*.grib2`` file under ``raw_dir``, renames each file's main
    data variable after the filename prefix, and merges everything into one
    dataset. Unreadable files are logged and skipped (best effort).

    Returns:
        xr.Dataset: The merged dataset, or ``None`` if no file could be read.
    """
    merged_inputs = []
    for grib_file in raw_dir.glob("*.grib2"):
        try:
            ds = xr.open_dataset(grib_file, engine="cfgrib")
            # Downloaded files are named "<variable>_...", so recover the
            # variable name from the stem and rename the main variable to it.
            target_name = grib_file.stem.split("_")[0]
            current_name = list(ds.data_vars)[0]
            merged_inputs.append(ds.rename({current_name: target_name}))
        except Exception as e:
            logger.error(f"Error reading {grib_file}: {e}")
            continue

    if not merged_inputs:
        return None

    return xr.merge(merged_inputs)


def handle_upload(
    zarr_dir: Path,
    raw_dir: Path,
    config_path: Path,
    year: int,
    month: int,
    day: int,
    overwrite: bool,
    archive_type: str,
) -> None:
    """Handle uploading data to Hugging Face.

    Uploads the Zarr archive, then deletes both temporary directories. Any
    failure (upload or cleanup) is logged and re-raised for the caller.
    """
    try:
        upload_to_huggingface(
            config_path, zarr_dir.name, year, month, day, overwrite, archive_type
        )
        logger.info("Upload to Hugging Face completed.")
        # Only clean up once the upload has actually succeeded.
        shutil.rmtree(raw_dir)
        shutil.rmtree(zarr_dir)
        logger.info("Temporary directories cleaned up.")
    except Exception as e:
        logger.error(f"Error during upload: {e}")
        raise


def process_dwd_data(
year: int,
month: int,
day: int,
hour: int,
overwrite: bool = False,
archive_type: str = "zarr.zip",
skip_upload: bool = True, # Skip upload by default until HF token is set
skip_upload: bool = True,
):
"""
Fetch, convert, and upload DWD ICON-EU data.

Args:
year (int): Year of data
month (int): Month of data
day (int): Day of data
hour (int): Hour of data
overwrite (bool): Whether to overwrite existing files. Defaults to False.
archive_type (str): Type of archive to create ("zarr.zip" or "tar")
skip_upload (bool): Whether to skip uploading to Hugging Face. Defaults to True.
"""
"""Fetch, convert, and upload DWD ICON-EU data."""
config = load_config(CONFIG_PATH)
local_output_dir = config["input_data"]["nwp"]["dwd"]["local_output_dir"]

Expand All @@ -186,43 +222,18 @@ def process_dwd_data(
# Step 2: Convert GRIB2 files to Zarr
if not zarr_dir.exists() or overwrite:
zarr_dir.mkdir(parents=True, exist_ok=True)

# Load all GRIB2 files and combine them
datasets = []
for grib_file in raw_dir.glob("*.grib2"):
try:
ds = xr.open_dataset(grib_file, engine='cfgrib')
variable_name = grib_file.stem.split("_")[0] # Get variable name from our filename
# Rename the main variable to match the filename
main_var = list(ds.data_vars)[0]
ds = ds.rename({main_var: variable_name})
datasets.append(ds)
except Exception as e:
logger.error(f"Error reading {grib_file}: {e}")
continue
combined_ds = convert_grib_to_zarr(raw_dir, zarr_dir)

if not datasets:
if combined_ds is None:
logger.warning("No valid GRIB2 files found. Exiting process.")
return

# Merge all datasets
combined_ds = xr.merge(datasets)

# Save to zarr format
logger.info(f"Saving combined dataset to {zarr_dir}")
combined_ds.to_zarr(zarr_dir, mode='w')
combined_ds.to_zarr(zarr_dir, mode="w")

# Step 3: Upload Zarr directory (optional)
if not skip_upload:
try:
upload_to_huggingface(CONFIG_PATH, zarr_dir.name, year, month, day, overwrite, archive_type)
logger.info("Upload to Hugging Face completed.")
shutil.rmtree(raw_dir)
shutil.rmtree(zarr_dir)
logger.info("Temporary directories cleaned up.")
except Exception as e:
logger.error(f"Error during upload: {e}")
raise
handle_upload(zarr_dir, raw_dir, CONFIG_PATH, year, month, day, overwrite, archive_type)
else:
logger.info("Skipping upload to Hugging Face (skip_upload=True)")
logger.info(f"Data is available in {zarr_dir}")
69 changes: 53 additions & 16 deletions src/open_data_pvnet/nwp/gfs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
from torch.utils.data import Dataset
from ocf_data_sampler.config import load_yaml_configuration
from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
import fsspec
import numpy as np

from open_data_pvnet.nwp.constants import NWP_MEANS, NWP_STDS

# Configure logging
logging.basicConfig(level=logging.WARNING)
Expand All @@ -23,31 +22,69 @@
xr.set_options(keep_attrs=True)


def preprocess_gfs_data(gfs: xr.Dataset) -> xr.DataArray:
    """
    Preprocess GFS data to ensure correct coordinates and dimensions.

    Args:
        gfs (xr.Dataset): Raw GFS dataset (a DataArray that already has a
            "channel" dimension is also accepted and passed through step 4).

    Returns:
        xr.DataArray: Preprocessed data with a "channel" dimension (one entry
        per original data variable), cropped to the UK region and re-chunked.
        NOTE: the previous annotation said ``xr.Dataset``, but step 4 converts
        the input with ``to_array``, so the result is a DataArray.
    """
    # 1. Ensure longitude is in the [0, 360) range.
    gfs["longitude"] = (gfs["longitude"] + 360) % 360

    # 2. Select the UK latitude band; slice(65, 45) assumes latitude
    #    coordinates are ordered north -> south — TODO confirm for the source.
    gfs = gfs.sel(latitude=slice(65, 45))

    # 3. The UK straddles the 0° meridian, so take the 350°-360° band and the
    #    0°-10° band separately and stitch them together along longitude.
    western_band = gfs.sel(longitude=slice(350, 360))
    eastern_band = gfs.sel(longitude=slice(0, 10))
    gfs = xr.concat([western_band, eastern_band], dim="longitude")

    # 4. Stack data variables into a single "channel" dimension (Dataset ->
    #    DataArray); skipped if the input was already a DataArray.
    if isinstance(gfs, xr.Dataset):
        gfs = gfs.to_array(dim="channel")

    # 5. Re-chunk: keep init-time and channel axes whole, 4 forecast steps per
    #    chunk, single-pixel spatial chunks (matches the sampler's access
    #    pattern of one location at a time).
    gfs = gfs.chunk(
        {
            "init_time_utc": -1,
            "step": 4,
            "channel": -1,
            "latitude": 1,
            "longitude": 1,
        }
    )

    return gfs


def open_gfs(dataset_path: str) -> xr.DataArray:
"""
Opens the GFS dataset stored in Zarr format and prepares it for processing.
Opens and preprocesses the GFS dataset.

Args:
dataset_path (str): Path to the GFS dataset.
dataset_path (str): Path to the GFS dataset

Returns:
xr.DataArray: The processed GFS data.
xr.DataArray: Processed GFS data with correct dimensions
"""
logging.info("Opening GFS dataset synchronously...")
logging.info("Opening GFS dataset...")
store = fsspec.get_mapper(dataset_path, anon=True)
gfs_dataset: xr.Dataset = xr.open_dataset(
store, engine="zarr", consolidated=True, chunks="auto"
)
gfs_data: xr.DataArray = gfs_dataset.to_array(dim="channel")
gfs_dataset = xr.open_dataset(store, engine="zarr", consolidated=True)

if "init_time" in gfs_data.dims:
logging.debug("Renaming 'init_time' to 'init_time_utc'...")
gfs_data = gfs_data.rename({"init_time": "init_time_utc"})
# Apply preprocessing
gfs_data = preprocess_gfs_data(gfs_dataset)

required_dims = ["init_time_utc", "step", "channel", "latitude", "longitude"]
gfs_data = gfs_data.transpose(*required_dims)
# Ensure required dimensions are present
expected_dims = ["init_time_utc", "step", "channel", "latitude", "longitude"]
if not all(dim in gfs_data.dims for dim in expected_dims):
raise ValueError(
f"Missing required dimensions. Expected {expected_dims}, got {list(gfs_data.dims)}"
)

logging.debug(f"GFS dataset dimensions: {gfs_data.dims}")
return gfs_data


Expand Down
4 changes: 2 additions & 2 deletions src/open_data_pvnet/utils/data_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _validate_config(config):
# First validate required fields
if "general" not in config:
raise ValueError("No general configuration section found")

if "destination_dataset_id" not in config["general"]:
raise ValueError("No destination_dataset_id found in general configuration")

Expand All @@ -41,7 +41,7 @@ def _validate_config(config):

# Check provider configuration
nwp_config = config["input_data"]["nwp"]

# Check if it's a DWD config
if "dwd" in nwp_config:
local_output_dir = nwp_config["dwd"]["local_output_dir"]
Expand Down
Loading
Loading