diff --git a/README.md b/README.md index e05256fc..acb5a5d5 100644 --- a/README.md +++ b/README.md @@ -113,13 +113,13 @@ There does not seem to be an automated way to do this selecting and downloading, ## Configure `nowcasting_dataset` to point to the downloaded data Copy and modify one of the config yaml files in -[`nowcasting_dataset/config/`](https://github.com/openclimatefix/nowcasting_dataset/tree/main/nowcasting_dataset/config) -and modify `prepare_ml_data.py` to use your config file. +[`nowcasting_dataset/config/`](https://github.com/openclimatefix/nowcasting_dataset/tree/main/nowcasting_dataset/config). ## Prepare ML batches -Run [`scripts/prepare_ml_data.py`](https://github.com/openclimatefix/nowcasting_dataset/blob/main/scripts/prepare_ml_data.py) +Run [`scripts/prepare_ml_data.py --help`](https://github.com/openclimatefix/nowcasting_dataset/blob/main/scripts/prepare_ml_data.py) +to learn how to run the `prepare_ml_data.py` script. ## What exactly is in each batch? diff --git a/conftest.py b/conftest.py index b66a63b7..0e4e9a6f 100644 --- a/conftest.py +++ b/conftest.py @@ -22,7 +22,7 @@ register_xr_data_set_to_tensor() -def pytest_addoption(parser): +def pytest_addoption(parser): # noqa: D103 parser.addoption( "--use_cloud_data", action="store_true", @@ -32,12 +32,12 @@ def pytest_addoption(parser): @pytest.fixture -def use_cloud_data(request): +def use_cloud_data(request): # noqa: D103 return request.config.getoption("--use_cloud_data") @pytest.fixture -def sat_filename(use_cloud_data: bool) -> Path: +def sat_filename(use_cloud_data: bool) -> Path: # noqa: D103 if use_cloud_data: return consts.SAT_FILENAME else: @@ -47,24 +47,23 @@ def sat_filename(use_cloud_data: bool) -> Path: @pytest.fixture -def sat_data_source(sat_filename: Path): +def sat_data_source(sat_filename: Path): # noqa: D103 return SatelliteDataSource( image_size_pixels=pytest.IMAGE_SIZE_PIXELS, zarr_path=sat_filename, history_minutes=0, forecast_minutes=5, channels=("HRV",), - n_timesteps_per_batch=2, ) @pytest.fixture -def general_data_source(): +def general_data_source(): # noqa: D103 return MetadataDataSource(history_minutes=0, forecast_minutes=5, object_at_center="GSP") @pytest.fixture -def gsp_data_source(): +def gsp_data_source(): # noqa: D103 return GSPDataSource( image_size_pixels=16, meters_per_pixel=2000, @@ -75,7 +74,7 @@ def gsp_data_source(): @pytest.fixture -def configuration(): +def configuration(): # noqa: D103 filename = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "config", "gcp.yaml") configuration = load_yaml_configuration(filename) @@ -83,5 +82,5 @@ def configuration(): @pytest.fixture -def test_data_folder(): +def test_data_folder(): # noqa: D103 return os.path.join(os.path.dirname(nowcasting_dataset.__file__), "../tests/data") diff --git a/environment.yml b/environment.yml index a35b95c4..36a12961 100644 --- a/environment.yml +++ b/environment.yml @@ -28,7 +28,6 @@ dependencies: # Machine learning - pytorch::pytorch # explicitly specify pytorch channel to prevent conda from using conda-forge for pytorch, and hence installing the CPU-only version. 
- - pytorch-lightning # PV & Geospatial - pvlib @@ -45,6 +44,4 @@ dependencies: - pre-commit - pip: - - neptune-client[pytorch-lightning] - - tilemapbase - git+https://github.com/SheffieldSolar/PV_Live-API diff --git a/notebooks/2021-09/2021-09-07/sat_data.py b/notebooks/2021-09/2021-09-07/sat_data.py index 5082e29a..6d3acfaf 100644 --- a/notebooks/2021-09/2021-09-07/sat_data.py +++ b/notebooks/2021-09/2021-09-07/sat_data.py @@ -1,3 +1,4 @@ +"""Notebook""" from datetime import datetime from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource @@ -9,7 +10,6 @@ forecast_len=12, image_size_pixels=64, meters_per_pixel=2000, - n_timesteps_per_batch=32, ) s.open() diff --git a/nowcasting_dataset/config/__init__.py b/nowcasting_dataset/config/__init__.py index 93de233c..135ba3b2 100644 --- a/nowcasting_dataset/config/__init__.py +++ b/nowcasting_dataset/config/__init__.py @@ -1 +1,3 @@ """ Configuration of the dataset """ +from nowcasting_dataset.config.load import load_yaml_configuration +from nowcasting_dataset.config.model import Configuration, InputData, set_git_commit diff --git a/nowcasting_dataset/config/model.py b/nowcasting_dataset/config/model.py index e4c60cd1..038b00be 100644 --- a/nowcasting_dataset/config/model.py +++ b/nowcasting_dataset/config/model.py @@ -15,16 +15,18 @@ from typing import Optional import git +import pandas as pd from pathy import Pathy from pydantic import BaseModel, Field, root_validator, validator +# nowcasting_dataset imports from nowcasting_dataset.consts import ( DEFAULT_N_GSP_PER_EXAMPLE, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, NWP_VARIABLE_NAMES, SAT_VARIABLE_NAMES, ) - +from nowcasting_dataset.dataset.split import split IMAGE_SIZE_PIXELS_FIELD = Field(64, description="The number of pixels of the region of interest.") METERS_PER_PIXEL_FIELD = Field(2000, description="The number of meters per pixel.") @@ -102,7 +104,7 @@ class Satellite(DataSourceMixin): """Satellite configuration model""" satellite_zarr_path: str = Field( - "gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr", + "gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr", # noqa: E501 description="The path which holds the satellite zarr.", ) satellite_channels: tuple = Field( @@ -116,7 +118,7 @@ class NWP(DataSourceMixin): """NWP configuration model""" nwp_zarr_path: str = Field( - "gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr", + "gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr", # noqa: E501 description="The path which holds the NWP zarr.", ) nwp_channels: tuple = Field(NWP_VARIABLE_NAMES, description="the channels used in the nwp data") @@ -213,7 +215,8 @@ def set_forecast_and_history_minutes(cls, values): Run through the different data sources and if the forecast or history minutes are not set, then set them to the default values """ - + # It would be much better to use nowcasting_dataset.data_sources.ALL_DATA_SOURCE_NAMES, + # but that causes a circular import. 
ALL_DATA_SOURCE_NAMES = ("pv", "satellite", "nwp", "gsp", "topographic", "sun") enabled_data_sources = [ data_source_name @@ -249,8 +252,8 @@ def set_all_to_defaults(cls): class OutputData(BaseModel): """Output data model""" - filepath: str = Field( - "gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/", + filepath: Pathy = Field( + Pathy("gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/"), description=( "Where the data is saved to. If this is running on the cloud then should include" " 'gs://' or 's3://'" @@ -262,7 +265,29 @@ class Process(BaseModel): """Pydantic model of how the data is processed""" seed: int = Field(1234, description="Random seed, so experiments can be repeatable") - batch_size: int = Field(32, description="the number of examples per batch") + batch_size: int = Field(32, description="The number of examples per batch") + t0_datetime_frequency: pd.Timedelta = Field( + pd.Timedelta("5 minutes"), + description=( + "The temporal frequency at which t0 datetimes will be sampled." + " Can be any string that `pandas.Timedelta()` understands." + " For example, if this is set to '5 minutes', then, for each example, the t0 datetime" + " could be at 0, 5, ..., 55 minutes past the hour. If there are DataSources with a" + " lower sample rate (e.g. half-hourly) then these lower-sample-rate DataSources will" + " still produce valid examples. For example, if a half-hourly DataSource is asked for" + " an example with t0=12:05, history_minutes=60, forecast_minutes=60, then it will" + " return data at 11:30, 12:00, 12:30, and 13:00." + ), + ) + split_method: split.SplitMethod = Field( + split.SplitMethod.DAY, + description=( + "The method used to split the t0 datetimes into train, validation and test sets." + ), + ) + n_train_batches: int = 250 + n_validation_batches: int = 10 + n_test_batches: int = 10 upload_every_n_batches: int = Field( 16, description=( diff --git a/nowcasting_dataset/config/on_premises.yaml b/nowcasting_dataset/config/on_premises.yaml index 226254e7..55acec89 100644 --- a/nowcasting_dataset/config/on_premises.yaml +++ b/nowcasting_dataset/config/on_premises.yaml @@ -56,7 +56,7 @@ input_data: topographic_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/Topographic/europe_dem_1km_osgb.tif output_data: - filepath: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/prepared_ML_training_data/v8/ + filepath: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/prepared_ML_training_data/v_testing/ process: batch_size: 32 seed: 1234 diff --git a/nowcasting_dataset/consts.py b/nowcasting_dataset/consts.py index d5e1cd54..f392d709 100644 --- a/nowcasting_dataset/consts.py +++ b/nowcasting_dataset/consts.py @@ -102,3 +102,9 @@ TOPOGRAPHIC_X_COORDS, ] + list(DATETIME_FEATURE_NAMES) T0_DT = "t0_dt" + + +SPATIAL_AND_TEMPORAL_LOCATIONS_OF_EACH_EXAMPLE_FILENAME = ( + "spatial_and_temporal_locations_of_each_example.csv" +) +SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES = ("t0_datetime_UTC", "x_center_OSGB", "y_center_OSGB") diff --git a/nowcasting_dataset/data_sources/__init__.py b/nowcasting_dataset/data_sources/__init__.py index 9f82670f..717e892e 100644 --- a/nowcasting_dataset/data_sources/__init__.py +++ b/nowcasting_dataset/data_sources/__init__.py @@ -1,11 +1,23 @@ """ Various DataSources """ -from nowcasting_dataset.data_sources.data_source import DataSource -from nowcasting_dataset.data_sources.datetime.datetime_data_source import DatetimeDataSource +from 
nowcasting_dataset.data_sources.data_source import DataSource # noqa: F401 +from nowcasting_dataset.data_sources.datetime.datetime_data_source import ( # noqa: F401 + DatetimeDataSource, +) +from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWPDataSource from nowcasting_dataset.data_sources.pv.pv_data_source import PVDataSource from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource -from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource from nowcasting_dataset.data_sources.topographic.topographic_data_source import ( TopographicDataSource, ) + +MAP_DATA_SOURCE_NAME_TO_CLASS = { + "pv": PVDataSource, + "satellite": SatelliteDataSource, + "nwp": NWPDataSource, + "gsp": GSPDataSource, + "topographic": TopographicDataSource, + "sun": SunDataSource, +} +ALL_DATA_SOURCE_NAMES = tuple(MAP_DATA_SOURCE_NAME_TO_CLASS.keys()) diff --git a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py index 827e802a..1dc825af 100644 --- a/nowcasting_dataset/data_sources/data_source.py +++ b/nowcasting_dataset/data_sources/data_source.py @@ -1,17 +1,22 @@ """ General Data Source Class """ import itertools import logging +from concurrent import futures from dataclasses import InitVar, dataclass from numbers import Number -from typing import Iterable, List, Tuple +from pathlib import Path +from typing import Iterable, List, Tuple, Union import pandas as pd import xarray as xr +import nowcasting_dataset.filesystem.utils as nd_fs_utils import nowcasting_dataset.time as nd_time +import nowcasting_dataset.utils as nd_utils from nowcasting_dataset import square +from nowcasting_dataset.consts import SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput -from nowcasting_dataset.dataset.xr_utils import join_dataset_to_batch_dataset +from nowcasting_dataset.dataset.xr_utils import join_list_dataset_to_batch_dataset, make_dim_index logger = logging.getLogger(__name__) @@ -39,7 +44,8 @@ class DataSource: forecast_minutes: int def __post_init__(self): - """ Post Init """ + """Post Init""" + self.check_input_paths_exist() self.sample_period_duration = pd.Timedelta(self.sample_period_minutes, unit="minutes") # TODO: Do we still need all these different representations of sequence lengths? @@ -69,10 +75,14 @@ def __post_init__(self): self._history_duration + self._forecast_duration + self.sample_period_duration ) - def _get_start_dt(self, t0_dt: pd.Timestamp) -> pd.Timestamp: + def _get_start_dt( + self, t0_dt: Union[pd.Timestamp, pd.DatetimeIndex] + ) -> Union[pd.Timestamp, pd.DatetimeIndex]: return t0_dt - self._history_duration - def _get_end_dt(self, t0_dt: pd.Timestamp) -> pd.Timestamp: + def _get_end_dt( + self, t0_dt: Union[pd.Timestamp, pd.DatetimeIndex] + ) -> Union[pd.Timestamp, pd.DatetimeIndex]: return t0_dt + self._forecast_duration def get_contiguous_t0_time_periods(self) -> pd.DataFrame: @@ -99,8 +109,7 @@ def sample_period_minutes(self) -> int: """ This is the default sample period in minutes. - This functions may be overwritten if - the sample period of the data source is not 5 minutes. + This functions may be overwritten if the sample period of the data source is not 5 minutes. """ logging.debug( "Getting sample_period_minutes default of 5 minutes. 
" @@ -112,13 +121,104 @@ def open(self): """Open the data source, if necessary. Called from each worker process. Useful for data sources where the - underlying data source cannot be forked (like Zarr on GCP!). + underlying data source cannot be forked (like Zarr). + + Data sources which can be forked safely should call open() from __init__(). + """ + pass + + def check_input_paths_exist(self) -> None: + """Check any input paths exist. Raise FileNotFoundError if not. - Data sources which can be forked safely should call open() - from __init__(). + Can be overridden by child classes. """ pass + # TODO: Issue #319: Standardise parameter names. + def create_batches( + self, + spatial_and_temporal_locations_of_each_example: pd.DataFrame, + idx_of_first_batch: int, + batch_size: int, + dst_path: Path, + local_temp_path: Path, + upload_every_n_batches: int, + ) -> None: + """Create multiple batches and save them to disk. + + Safe to call from worker processes. + + Args: + spatial_and_temporal_locations_of_each_example: A DataFrame where each row specifies + the spatial and temporal location of an example. The number of rows must be + an exact multiple of `batch_size`. + Columns are: t0_datetime_UTC, x_center_OSGB, y_center_OSGB. + idx_of_first_batch: The batch number of the first batch to create. + batch_size: The number of examples per batch. + dst_path: The final destination path for the batches. Must exist. + local_temp_path: The local temporary path. This is only required when dst_path is a + cloud storage bucket, so files must first be created on the VM's local disk in temp_path + and then uploaded to dst_path every upload_every_n_batches. Must exist. Will be emptied. + upload_every_n_batches: Upload the contents of temp_path to dst_path after this number + of batches have been created. If 0 then will write directly to dst_path. + """ + # Sanity checks: + assert idx_of_first_batch >= 0 + assert batch_size > 0 + assert len(spatial_and_temporal_locations_of_each_example) % batch_size == 0 + assert upload_every_n_batches >= 0 + assert spatial_and_temporal_locations_of_each_example.columns.to_list() == list( + SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES + ) + + self.open() + + # Figure out where to write batches to: + save_batches_locally_and_upload = upload_every_n_batches > 0 + if save_batches_locally_and_upload: + nd_fs_utils.delete_all_files_in_temp_path(local_temp_path) + path_to_write_to = local_temp_path if save_batches_locally_and_upload else dst_path + + # Split locations per example into batches: + n_batches = len(spatial_and_temporal_locations_of_each_example) // batch_size + locations_for_batches = [] + for batch_idx in range(n_batches): + start_example_idx = batch_idx * batch_size + end_example_idx = (batch_idx + 1) * batch_size + locations_for_batch = spatial_and_temporal_locations_of_each_example.iloc[ + start_example_idx:end_example_idx + ] + locations_for_batches.append(locations_for_batch) + + # Loop round each batch: + for n_batches_processed, locations_for_batch in enumerate(locations_for_batches): + batch_idx = idx_of_first_batch + n_batches_processed + logger.debug(f"{self.__class__.__name__} creating batch {batch_idx}!") + + # Generate batch. + batch = self.get_batch( + t0_datetimes=locations_for_batch.t0_datetime_UTC, + x_locations=locations_for_batch.x_center_OSGB, + y_locations=locations_for_batch.y_center_OSGB, + ) + + # Save batch to disk. 
+ netcdf_filename = path_to_write_to / nd_utils.get_netcdf_filename(batch_idx) + batch.to_netcdf(netcdf_filename) + + # Upload if necessary. + if ( + save_batches_locally_and_upload + and n_batches_processed > 0 + and n_batches_processed % upload_every_n_batches == 0 + ): + nd_fs_utils.upload_and_delete_local_files(dst_path, path_to_write_to) + + # Upload last few batches, if necessary: + if save_batches_locally_and_upload: + nd_fs_utils.upload_and_delete_local_files(dst_path, path_to_write_to) + + # TODO: Issue #319: Standardise parameter names. def get_batch( self, t0_datetimes: pd.DatetimeIndex, @@ -131,28 +231,39 @@ def get_batch( Args: t0_datetimes: list of timestamps for the datetime of the batches. The batch will also include data for historic and future depending on `history_minutes` and - `future_minutes`. + `future_minutes`. The batch size is given by the length of the t0_datetimes. x_locations: x center batch locations y_locations: y center batch locations Returns: Batch data. """ - examples = [] - zipped = zip(t0_datetimes, x_locations, y_locations) - for t0_datetime, x_location, y_location in zipped: - output: xr.Dataset = self.get_example(t0_datetime, x_location, y_location) - - examples.append(output) - - # could add option here, to save each data source using - # 1. # DataSourceOutput.to_xr_dataset() to make it a dataset - # 2. DataSourceOutput.save_netcdf(), save to netcdf - - # get the name of the cls, this could be one of the data sources like Sun + assert len(t0_datetimes) == len( + x_locations + ), f"len(t0_datetimes) != len(x_locations): {len(t0_datetimes)} != {len(x_locations)}" + assert len(t0_datetimes) == len( + y_locations + ), f"len(t0_datetimes) != len(y_locations): {len(t0_datetimes)} != {len(y_locations)}" + zipped = list(zip(t0_datetimes, x_locations, y_locations)) + batch_size = len(t0_datetimes) + + with futures.ThreadPoolExecutor(max_workers=batch_size) as executor: + future_examples = [] + for coords in zipped: + t0_datetime, x_location, y_location = coords + future_example = executor.submit( + self.get_example, t0_datetime, x_location, y_location + ) + future_examples.append(future_example) + examples = [future_example.result() for future_example in future_examples] + + # Get the DataSource class, this could be one of the data sources like Sun cls = examples[0].__class__ + # Set the coords to be indices before joining into a batch + examples = [make_dim_index(example) for example in examples] + # join the examples together, and cast them to the cls, so that validation can occur - return cls(join_dataset_to_batch_dataset(examples)) + return cls(join_list_dataset_to_batch_dataset(examples)) def datetime_index(self) -> pd.DatetimeIndex: """Returns a complete list of all available datetimes.""" @@ -180,6 +291,7 @@ def get_contiguous_time_periods(self) -> pd.DataFrame: max_gap_duration=self.sample_period_duration, ) + # TODO: Issue #319: Standardise parameter names. def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: """Find a valid geographical locations for each t0_datetime. @@ -191,10 +303,12 @@ def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], L raise NotImplementedError() # ****************** METHODS THAT MUST BE OVERRIDDEN ********************** + # TODO: Issue #319: Standardise parameter names. def _get_time_slice(self, t0_dt: pd.Timestamp): """Get a single timestep of data. Must be overridden.""" raise NotImplementedError() + # TODO: Issue #319: Standardise parameter names. 
def get_example( self, t0_dt: pd.Timestamp, #: Datetime of "now": The most recent obs. @@ -240,21 +354,20 @@ class ZarrDataSource(ImageDataSource): channels: The Zarr parameters to load. """ - channels: Iterable[str] - #: Mustn't be None, but cannot have a non-default arg in this position :) - n_timesteps_per_batch: int = None + # zarr_path and channels must be set. But dataclasses complains about defining a non-default + # argument after a default argument if we remove the ` = None`. + zarr_path: Union[Path, str] = None + channels: Iterable[str] = None consolidated: bool = True def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): """ Post init """ super().__post_init__(image_size_pixels, meters_per_pixel) self._data = None - if self.n_timesteps_per_batch is None: - # Using hacky default for now. The whole concept of n_timesteps_per_batch - # will be removed when #213 is completed. - # TODO: Remove n_timesteps_per_batch when #213 is completed! - self.n_timesteps_per_batch = 16 - logger.warning("n_timesteps_per_batch is not set! Using default!") + + def check_input_paths_exist(self) -> None: + """Check input paths exist. If not, raise a FileNotFoundError.""" + nd_fs_utils.check_path_exists(self.zarr_path) @property def data(self): @@ -306,10 +419,7 @@ def get_example( f"actual shape {selected_data.shape}" ) - # rename 'variable' to 'channels' - selected_data = selected_data.rename({"variable": "channels"}) - - return selected_data + return selected_data.load() def geospatial_border(self) -> List[Tuple[Number, Number]]: """ diff --git a/nowcasting_dataset/data_sources/data_source_list.py b/nowcasting_dataset/data_sources/data_source_list.py deleted file mode 100644 index 86554054..00000000 --- a/nowcasting_dataset/data_sources/data_source_list.py +++ /dev/null @@ -1,153 +0,0 @@ -"""DataSourceList class.""" - -import logging - -import numpy as np -import pandas as pd - -import nowcasting_dataset.time as nd_time -import nowcasting_dataset.utils as nd_utils -from nowcasting_dataset.config import model -from nowcasting_dataset import data_sources -logger = logging.getLogger(__name__) - - -class DataSourceList(list): - """Hold a list of DataSource objects. - - Attrs: - data_source_which_defines_geospatial_locations: The DataSource used to compute the - geospatial locations of each example. - """ - - @classmethod - def from_config(cls, config_for_all_data_sources: model.InputData): - """Create a DataSource List from an InputData configuration object. - - For each key in each DataSource's configuration object, the string `_` - is removed from the key before passing to the DataSource constructor. This allows us to - have verbose field names in the configuration YAML files, whilst also using standard - constructor arguments for DataSources. 
- """ - data_source_name_to_class = { - "pv": data_sources.PVDataSource, - "satellite": data_sources.SatelliteDataSource, - "nwp": data_sources.NWPDataSource, - "gsp": data_sources.GSPDataSource, - "topographic": data_sources.TopographicDataSource, - "sun": data_sources.SunDataSource, - } - data_source_list = cls([]) - for data_source_name, data_source_class in data_source_name_to_class.items(): - logger.debug(f"Creating {data_source_name} DataSource object.") - config_for_data_source = getattr(config_for_all_data_sources, data_source_name) - if config_for_data_source is None: - logger.info(f"No configuration found for {data_source_name}.") - continue - config_for_data_source = config_for_data_source.dict() - - # Strip `_` from the config option field names. - config_for_data_source = nd_utils.remove_regex_pattern_from_keys( - config_for_data_source, pattern_to_remove=f"^{data_source_name}_" - ) - - try: - data_source = data_source_class(**config_for_data_source) - except Exception: - logger.exception(f"Exception whilst instantiating {data_source_name}!") - raise - data_source_list.append(data_source) - if ( - data_source_name - == config_for_all_data_sources.data_source_which_defines_geospatial_locations - ): - data_source_list.data_source_which_defines_geospatial_locations = data_source - logger.info( - f"DataSource {data_source_name} set as" - " data_source_which_defines_geospatial_locations" - ) - - try: - _ = data_source_list.data_source_which_defines_geospatial_locations - except AttributeError: - logger.warning( - "No DataSource configured as data_source_which_defines_geospatial_locations!" - ) - return data_source_list - - def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeIndex: - """ - Compute the intersection of the t0 datetimes available across all DataSources. - - Args: - freq: The Pandas frequency string. The returned DatetimeIndex will be at this frequency, - and every datetime will be aligned to this frequency. For example, if - freq='5 minutes' then every datetime will be at 00, 05, ..., 55 minutes - past the hour. - - Returns: Valid t0 datetimes, taking into consideration all DataSources, - filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night - datetimes). - """ - logger.debug("Get the intersection of time periods across all DataSources.") - - # Get the intersection of t0 time periods from all data sources. - t0_time_periods_for_all_data_sources = [] - for data_source in self: - logger.debug(f"Getting t0 time periods for {type(data_source).__name__}") - try: - t0_time_periods = data_source.get_contiguous_t0_time_periods() - except NotImplementedError: - pass # Skip data_sources with no concept of time. - else: - t0_time_periods_for_all_data_sources.append(t0_time_periods) - - intersection_of_t0_time_periods = nd_time.intersection_of_multiple_dataframes_of_periods( - t0_time_periods_for_all_data_sources - ) - - t0_datetimes = nd_time.time_periods_to_datetime_index( - time_periods=intersection_of_t0_time_periods, freq=freq - ) - - return t0_datetimes - - def sample_spatial_and_temporal_locations_for_examples( - self, t0_datetimes: pd.DatetimeIndex, n_examples: int - ) -> pd.DataFrame: - """ - Computes the geospatial and temporal locations for each training example. - - The first data_source in this DataSourceList defines the geospatial locations of - each example. - - Args: - t0_datetimes: All available t0 datetimes. 
Can be computed with - `DataSourceList.get_t0_datetimes_across_all_data_sources()` - n_examples: The number of examples requested. - - Returns: - Each row of each the DataFrame specifies the position of each example, using - columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'. - """ - # This code is for backwards-compatibility with code which expects the first DataSource - # in the list to be used to define which DataSource defines the spatial location. - # TODO: Remove this try block after implementing issue #213. - try: - data_source_which_defines_geospatial_locations = ( - self.data_source_which_defines_geospatial_locations - ) - except AttributeError: - data_source_which_defines_geospatial_locations = self[0] - - shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples) - x_locations, y_locations = data_source_which_defines_geospatial_locations.get_locations( - shuffled_t0_datetimes - ) - return pd.DataFrame( - { - "t0_datetime_UTC": shuffled_t0_datetimes, - "x_center_OSGB": x_locations, - "y_center_OSGB": y_locations, - } - ) diff --git a/nowcasting_dataset/data_sources/datasource_output.py b/nowcasting_dataset/data_sources/datasource_output.py index d03a27ad..cc7c4fdd 100644 --- a/nowcasting_dataset/data_sources/datasource_output.py +++ b/nowcasting_dataset/data_sources/datasource_output.py @@ -4,13 +4,11 @@ import logging import os from pathlib import Path -from typing import List import numpy as np -from pydantic import BaseModel, Field from nowcasting_dataset.dataset.xr_utils import PydanticXArrayDataSet -from nowcasting_dataset.filesystem.utils import make_folder +from nowcasting_dataset.filesystem.utils import makedirs from nowcasting_dataset.utils import get_netcdf_filename logger = logging.getLogger(__name__) @@ -30,7 +28,7 @@ def get_name(self) -> str: def save_netcdf(self, batch_i: int, path: Path): """ - Save batch to netcdf file + Save batch to netcdf file in path//. 
Args: batch_i: the batch id, used to make the filename @@ -43,8 +41,8 @@ def save_netcdf(self, batch_i: int, path: Path): # make folder folder = os.path.join(path, name) if batch_i == 0: - # only need to make the folder once, or check that there folder is there once - make_folder(path=folder) + # only need to make the folder once, or check that the folder is there once + makedirs(path=folder) # make file local_filename = os.path.join(folder, filename) diff --git a/nowcasting_dataset/data_sources/datetime/datetime_data_source.py b/nowcasting_dataset/data_sources/datetime/datetime_data_source.py index e83fdab9..a6617691 100644 --- a/nowcasting_dataset/data_sources/datetime/datetime_data_source.py +++ b/nowcasting_dataset/data_sources/datetime/datetime_data_source.py @@ -1,7 +1,6 @@ """ Datetime DataSource - add hour and year features """ from dataclasses import dataclass from numbers import Number -from typing import List, Tuple import pandas as pd diff --git a/nowcasting_dataset/data_sources/datetime/datetime_model.py b/nowcasting_dataset/data_sources/datetime/datetime_model.py index f4f647b4..9c7b3d8f 100644 --- a/nowcasting_dataset/data_sources/datetime/datetime_model.py +++ b/nowcasting_dataset/data_sources/datetime/datetime_model.py @@ -1,6 +1,5 @@ """ Model for output of datetime data """ from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput -from nowcasting_dataset.utils import coord_to_range class Datetime(DataSourceOutput): diff --git a/nowcasting_dataset/data_sources/fake.py b/nowcasting_dataset/data_sources/fake.py index 65468cae..d48b93ee 100644 --- a/nowcasting_dataset/data_sources/fake.py +++ b/nowcasting_dataset/data_sources/fake.py @@ -16,8 +16,8 @@ from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic from nowcasting_dataset.dataset.xr_utils import ( convert_data_array_to_dataset, - join_dataset_to_batch_dataset, join_list_data_array_to_batch_dataset, + join_list_dataset_to_batch_dataset, ) @@ -26,7 +26,7 @@ def datetime_fake(batch_size, seq_length_5): xr_arrays = [create_datetime_dataset(seq_length=seq_length_5) for _ in range(batch_size)] # make dataset - xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays) return Datetime(xr_dataset) @@ -48,7 +48,7 @@ def gsp_fake( ] # make dataset - xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays) return GSP(xr_dataset) @@ -58,7 +58,7 @@ def metadata_fake(batch_size): xr_arrays = [create_metadata_dataset() for _ in range(batch_size)] # make dataset - xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays) return Metadata(xr_dataset) @@ -102,7 +102,7 @@ def pv_fake(batch_size, seq_length_5, n_pv_systems_per_batch): ] # make dataset - xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays) return PV(xr_dataset) @@ -142,7 +142,7 @@ def sun_fake(batch_size, seq_length_5): ] # make dataset - xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays) return Sun(xr_dataset) diff --git a/nowcasting_dataset/data_sources/gsp/eso.py b/nowcasting_dataset/data_sources/gsp/eso.py index c76168ef..8d1f7150 100644 --- a/nowcasting_dataset/data_sources/gsp/eso.py +++ b/nowcasting_dataset/data_sources/gsp/eso.py @@ -1,5 +1,7 @@ """ -This file has a few functions that are used to get GSP (Grid Supply 
Point) information from National Grid ESO. +This file has a few functions that are used to get GSP (Grid Supply Point) information + +The info comes from National Grid ESO. ESO - Electricity System Operator. General information can be found here - https://data.nationalgrideso.com/system/gis-boundaries-for-gb-grid-supply-points @@ -50,8 +52,8 @@ def get_gsp_metadata_from_eso(calculate_centroid: bool = True) -> pd.DataFrame: """ logger.debug("Getting GSP shape file") - # call ESO website. There is a possibility that this API will be replaced and its unclear if this original API will - # will stay operational + # call ESO website. There is a possibility that this API will be replaced and its unclear if + # this original API will will stay operational url = ( "https://data.nationalgrideso.com/api/3/action/datastore_search?" "resource_id=bbe2cc72-a6c6-46e6-8f4e-48b879467368&limit=400" @@ -95,7 +97,8 @@ def get_gsp_shape_from_eso( Get the the gsp shape file from ESO (or a local file) Args: - join_duplicates: If True, any RegionIDs which have multiple entries, will be joined together to give one entry + join_duplicates: If True, any RegionIDs which have multiple entries, will be joined + together to give one entry. load_local_file: Load from a local file, not from ESO save_local_file: Save to a local file, only need to do this is Data is updated. @@ -117,11 +120,11 @@ def get_gsp_shape_from_eso( shape_gpd.rename(columns=rename_load_columns, inplace=True) logger.debug("loading local file for GSP shape data:done") else: - # call ESO website. There is a possibility that this API will be replaced and its unclear if this original API will - # will stay operational + # call ESO website. There is a possibility that this API will be replaced and its unclear + # if this original API will stay operational. url = ( - "https://data.nationalgrideso.com/backend/dataset/2810092e-d4b2-472f-b955-d8bea01f9ec0/resource/" - "a3ed5711-407a-42a9-a63a-011615eea7e0/download/gsp_regions_20181031.geojson" + "https://data.nationalgrideso.com/backend/dataset/2810092e-d4b2-472f-b955-d8bea01f9ec0/" + "resource/a3ed5711-407a-42a9-a63a-011615eea7e0/download/gsp_regions_20181031.geojson" ) with urlopen(url) as response: diff --git a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py index e1338cbc..8868eb1f 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py @@ -14,6 +14,7 @@ import torch import xarray as xr +import nowcasting_dataset.filesystem.utils as nd_fs_utils from nowcasting_dataset.consts import DEFAULT_N_GSP_PER_EXAMPLE from nowcasting_dataset.data_sources.data_source import ImageDataSource from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso @@ -65,6 +66,10 @@ def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): self.rng = np.random.default_rng(seed=seed) self.load() + def check_input_paths_exist(self) -> None: + """Check input paths exist. If not, raise a FileNotFoundError.""" + nd_fs_utils.check_path_exists(self.zarr_path) + @property def sample_period_minutes(self) -> int: """Override the default sample minutes""" @@ -118,13 +123,12 @@ def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], L Returns: list of x and y locations """ - logger.debug("Getting locations for the batch") - # Pick a random GSP for each t0_datetime, and then grab # their geographical location. 
x_locations = [] y_locations = [] + # TODO: Issue 305: Speed up this function by removing this for loop? for t0_dt in t0_datetimes: # Choose start and end times @@ -138,21 +142,15 @@ def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], L random_gsp_id = self.rng.choice(gsp_power.columns) meta_data = self.metadata[(self.metadata["gsp_id"] == random_gsp_id)] - # Make sure there is only one. Sometimes there are multiple gsp_ids at one location - # e.g. 'SELL_1'. Further investigation on this may be needed, - # but going to ignore this for now. See this issue: - # https://github.com/openclimatefix/nowcasting_dataset/issues/272 + # Make sure there is only one GSP. + # Sometimes there are multiple gsp_ids at one location e.g. 'SELL_1'. + # TODO: Issue #272: Further investigation on multiple GSPs may be needed. metadata_for_gsp = meta_data.iloc[0] # Get metadata for GSP x_locations.append(metadata_for_gsp.location_x) y_locations.append(metadata_for_gsp.location_y) - logger.debug( - f"Found locations for GSP id {random_gsp_id} of {metadata_for_gsp.location_x} and " - f"{metadata_for_gsp.location_y}" - ) - return x_locations, y_locations def get_example( @@ -398,7 +396,7 @@ def load_solar_gsp_data( Returns: dataframe of pv data """ - logger.debug(f"Loading Solar GSP Data from GCS {zarr_path} from {start_dt} to {end_dt}") + logger.debug(f"Loading Solar GSP Data from {zarr_path} from {start_dt} to {end_dt}") # Open data - it may be quicker to open byte file first, but decided just to keep it # like this at the moment. gsp_power = xr.open_dataset(zarr_path, engine="zarr") diff --git a/nowcasting_dataset/data_sources/gsp/gsp_model.py b/nowcasting_dataset/data_sources/gsp/gsp_model.py index 753d5da9..2e25b612 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_model.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_model.py @@ -4,7 +4,6 @@ from xarray.ufuncs import isinf, isnan from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput -from nowcasting_dataset.time import make_random_time_vectors logger = logging.getLogger(__name__) @@ -18,8 +17,8 @@ class GSP(DataSourceOutput): @classmethod def model_validation(cls, v): """ Check that all values are non NaNs """ - assert (~isnan(v.data)).all(), f"Some gsp data values are NaNs" - assert (~isinf(v.data)).all(), f"Some gsp data values are Infinite" + assert (~isnan(v.data)).all(), "Some gsp data values are NaNs" + assert (~isinf(v.data)).all(), "Some gsp data values are Infinite" assert (v.data >= 0).all(), f"Some gsp data values are below 0 {v.data.min()}" return v diff --git a/nowcasting_dataset/data_sources/gsp/pvlive.py b/nowcasting_dataset/data_sources/gsp/pvlive.py index 1f0a7912..3e6199ce 100644 --- a/nowcasting_dataset/data_sources/gsp/pvlive.py +++ b/nowcasting_dataset/data_sources/gsp/pvlive.py @@ -20,7 +20,9 @@ def load_pv_gsp_raw_data_from_pvlive( start: datetime, end: datetime, number_of_gsp: int = None, normalize_data: bool = True ) -> pd.DataFrame: """ - Load raw pv gsp data from pvlive. Note that each gsp is loaded separately. Also the data is loaded in 30 day chunks. + Load raw pv gsp data from pvlive. + + Note that each gsp is loaded separately. Also the data is loaded in 30 day chunks. 
Args: start: the start date for gsp data to load @@ -37,14 +39,16 @@ def load_pv_gsp_raw_data_from_pvlive( # setup pv Live class, although here we are getting historic data pvl = PVLive() - # set the first chunk of data, note that 30 day chunks are used except if the end time is smaller than that + # set the first chunk of data, note that 30 day chunks are used except if the end time is + # smaller than that first_start_chunk = start first_end_chunk = min([first_start_chunk + CHUNK_DURATION, end]) gsp_data_df = [] logger.debug(f"Will be getting data for {len(gsp_ids)} gsp ids") # loop over gsp ids - # limit the total number of concurrent tasks to be 4, so that we don't hit the pvlive api too much + # limit the total number of concurrent tasks to be 4, so that we don't hit the pvlive api + # too much future_tasks = [] with futures.ThreadPoolExecutor(max_workers=4) as executor: for gsp_id in gsp_ids: @@ -53,8 +57,8 @@ def load_pv_gsp_raw_data_from_pvlive( start_chunk = first_start_chunk end_chunk = first_end_chunk - # loop over 30 days chunks (nice to see progress instead of waiting a long time for one command - this might - # not be the fastest) + # loop over 30 days chunks (nice to see progress instead of waiting a long time for + # one command - this might not be the fastest) while start_chunk <= end: logger.debug(f"Getting data for gsp id {gsp_id} from {start_chunk} to {end_chunk}") @@ -77,7 +81,7 @@ def load_pv_gsp_raw_data_from_pvlive( if end_chunk > end: end_chunk = end - logger.debug(f"Getting results") + logger.debug("Getting results") # Collect results from each thread. for task in tqdm(future_tasks): one_chunk_one_gsp_gsp_data_df = task.result() diff --git a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py index acb58aab..de4bdbf1 100644 --- a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py +++ b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py @@ -1,7 +1,6 @@ """ Datetime DataSource - add hour and year features """ from dataclasses import dataclass from numbers import Number -from typing import List, Tuple import numpy as np import pandas as pd @@ -42,7 +41,7 @@ def get_example( # TODO: data_dict is unused in this function. Is that a bug? 
# https://github.com/openclimatefix/nowcasting_dataset/issues/279 - data_dict = dict( + data_dict = dict( # noqa: F841 t0_dt=to_numpy(t0_dt), #: Shape: [batch_size,] x_meters_center=np.array(x_meters_center), y_meters_center=np.array(y_meters_center), diff --git a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py index 7b75c692..79c47a06 100644 --- a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py +++ b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py @@ -1,8 +1,6 @@ """ NWP Data Source """ import logging -from concurrent import futures from dataclasses import InitVar, dataclass -from numbers import Number from typing import Iterable, Optional import numpy as np @@ -10,10 +8,8 @@ import xarray as xr from nowcasting_dataset import utils -from nowcasting_dataset.data_sources.data_source import ZarrDataSource -from nowcasting_dataset.data_sources.nwp.nwp_model import NWP -from nowcasting_dataset.dataset.xr_utils import join_list_data_array_to_batch_dataset from nowcasting_dataset.consts import NWP_VARIABLE_NAMES +from nowcasting_dataset.data_sources.data_source import ZarrDataSource _LOG = logging.getLogger(__name__) @@ -43,7 +39,6 @@ class NWPDataSource(ZarrDataSource): hcc : High-level cloud cover in %. """ - zarr_path: str = None channels: Optional[Iterable[str]] = NWP_VARIABLE_NAMES image_size_pixels: InitVar[int] = 2 meters_per_pixel: InitVar[int] = 2_000 @@ -78,69 +73,6 @@ def open(self) -> None: data = self._open_data() self._data = data["UKV"].sel(variable=list(self.channels)) - def get_batch( - self, - t0_datetimes: pd.DatetimeIndex, - x_locations: Iterable[Number], - y_locations: Iterable[Number], - ) -> NWP: - """ - Get batch data - - Args: - t0_datetimes: list of timstamps - x_locations: list of x locations, where the batch data is for - y_locations: list of y locations, where the batch data is for - - Returns: batch data - - """ - # Lazily select time slices. - selections = [] - for t0_dt in t0_datetimes[: self.n_timesteps_per_batch]: - selections.append(self._get_time_slice(t0_dt)) - - # Load entire time slices from disk in multiple threads. - data = [] - with futures.ThreadPoolExecutor(max_workers=self.n_timesteps_per_batch) as executor: - data_futures = [] - # Submit tasks. - for selection in selections: - future = executor.submit(selection.load) - data_futures.append(future) - - # Grab tasks - for future in data_futures: - d = future.result() - data.append(d) - - # Select squares from pre-loaded time slices. - examples = [] - for i, (x_meters_center, y_meters_center) in enumerate(zip(x_locations, y_locations)): - selected_data = data[i % self.n_timesteps_per_batch] - bounding_box = self._square.bounding_box_centered_on( - x_meters_center=x_meters_center, y_meters_center=y_meters_center - ) - selected_data = selected_data.sel( - x=slice(bounding_box.left, bounding_box.right), - y=slice(bounding_box.top, bounding_box.bottom), - ) - - # selected_sat_data is likely to have 1 too many pixels in x and y - # because sel(x=slice(a, b)) is [a, b], not [a, b). 
So trim: - selected_data = selected_data.isel( - x=slice(0, self._square.size_pixels), y=slice(0, self._square.size_pixels) - ) - - t0_dt = t0_datetimes[i] - selected_data = self._post_process_example(selected_data, t0_dt) - - examples.append(selected_data) - - output = join_list_data_array_to_batch_dataset(examples) - - return NWP(output) - def _open_data(self) -> xr.DataArray: return open_nwp(self.zarr_path, consolidated=self.consolidated) @@ -150,8 +82,7 @@ def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: Note that this function does *not* resample from hourly to 5 minutely. Resampling would be very expensive if done on the whole geographical - extent of the NWP data! So resampling is done in - _post_process_example(). + extent of the NWP data! Args: t0_dt: the time slice is around t0_dt. @@ -186,9 +117,7 @@ def _post_process_example( selected_data = selected_data.resample({"target_time": "5T"}) selected_data = selected_data.interpolate() selected_data = selected_data.sel(target_time=slice(start_dt, end_dt)) - selected_data = selected_data.rename({"target_time": "time"}) - selected_data = selected_data.rename({"variable": "channels"}) - + selected_data = selected_data.rename({"target_time": "time", "variable": "channels"}) selected_data.data = selected_data.data.astype(np.float32) return selected_data diff --git a/nowcasting_dataset/data_sources/nwp/nwp_model.py b/nowcasting_dataset/data_sources/nwp/nwp_model.py index f44df98f..42714434 100644 --- a/nowcasting_dataset/data_sources/nwp/nwp_model.py +++ b/nowcasting_dataset/data_sources/nwp/nwp_model.py @@ -6,7 +6,6 @@ from xarray.ufuncs import isinf, isnan from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput -from nowcasting_dataset.time import make_random_time_vectors logger = logging.getLogger(__name__) @@ -23,5 +22,5 @@ class NWP(DataSourceOutput): def model_validation(cls, v): """ Check that all values are not NaNs """ assert (~isnan(v.data)).all(), "Some nwp data values are NaNs" - assert (~isinf(v.data)).all(), f"Some nwp data values are Infinite" + assert (~isinf(v.data)).all(), "Some nwp data values are Infinite" return v diff --git a/nowcasting_dataset/data_sources/pv/pv_data_source.py b/nowcasting_dataset/data_sources/pv/pv_data_source.py index 788b417c..2738286f 100644 --- a/nowcasting_dataset/data_sources/pv/pv_data_source.py +++ b/nowcasting_dataset/data_sources/pv/pv_data_source.py @@ -9,12 +9,13 @@ from pathlib import Path from typing import List, Optional, Tuple, Union -import gcsfs +import fsspec import numpy as np import pandas as pd import torch import xarray as xr +import nowcasting_dataset.filesystem.utils as nd_fs_utils from nowcasting_dataset import geospatial, utils from nowcasting_dataset.consts import DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE from nowcasting_dataset.data_sources.data_source import ImageDataSource @@ -52,6 +53,11 @@ def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): self.rng = np.random.default_rng(seed=seed) self.load() + def check_input_paths_exist(self) -> None: + """Check input paths exist. 
If not, raise a FileNotFoundError.""" + for filename in [self.filename, self.metadata_filename]: + nd_fs_utils.check_path_exists(filename) + def load(self): """ Load metadata and pv power @@ -62,7 +68,7 @@ def load(self): def _load_metadata(self): - logger.debug("Loading Metadata") + logger.debug(f"Loading PV metadata from {self.metadata_filename}") pv_metadata = pd.read_csv(self.metadata_filename, index_col="system_id") pv_metadata.dropna(subset=["longitude", "latitude"], how="any", inplace=True) @@ -88,14 +94,9 @@ def _load_metadata(self): def _load_pv_power(self): - logger.debug("Loading PV Power data") - - if "gs://" not in str(self.filename): - self.load_from_gcs = False + logger.debug(f"Loading PV Power data from {self.filename}") - pv_power = load_solar_pv_data_from_gcs( - self.filename, start_dt=self.start_dt, end_dt=self.end_dt, from_gcs=self.load_from_gcs - ) + pv_power = load_solar_pv_data(self.filename, start_dt=self.start_dt, end_dt=self.end_dt) # A bit of hand-crafted cleaning if 30248 in pv_power.columns: @@ -109,7 +110,8 @@ def _load_pv_power(self): pv_power = drop_pv_systems_which_produce_overnight(pv_power) # Resample to 5-minutely and interpolate up to 15 minutes ahead. - # TODO: Cubic interpolation? + # TODO: Issue #301: Give users the option to NOT resample (because Perceiver IO + # doesn't need all the data to be perfectly aligned). pv_power = pv_power.resample("5T").interpolate(method="time", limit=3) pv_power.dropna(axis="index", how="all", inplace=True) # self.pv_power = dd.from_pandas(pv_power, npartitions=3) @@ -122,10 +124,6 @@ def _get_time_slice(self, t0_dt: pd.Timestamp) -> [pd.DataFrame]: end_dt = self._get_end_dt(t0_dt) del t0_dt # t0 is not used in the rest of this method! selected_pv_power = self.pv_power.loc[start_dt:end_dt].dropna(axis="columns", how="any") - - selected_pv_azimuth_angle = None - selected_pv_elevation_angle = None - return selected_pv_power def _get_central_pv_system_id( @@ -204,8 +202,8 @@ def get_example( Get Example data for PV data Args: - t0_dt: list of timestamps for the datetime of the batches. The batch will also include data - for historic and future depending on 'history_minutes' and 'future_minutes'. + t0_dt: list of timestamps for the datetime of the batches. The batch will also include + data for historic and future depending on 'history_minutes' and 'future_minutes'. x_meters_center: x center batch locations y_meters_center: y center batch locations @@ -232,7 +230,11 @@ def get_example( selected_pv_power = selected_pv_power[all_pv_system_ids] - pv_system_row_number = np.flatnonzero(self.pv_metadata.index.isin(all_pv_system_ids)) + # TODO: Issue #302. pv_system_row_number is assigned to but never used. + # That may indicate a bug? + pv_system_row_number = np.flatnonzero( # noqa: F841 + self.pv_metadata.index.isin(all_pv_system_ids) + ) pv_system_x_coords = self.pv_metadata.location_x[all_pv_system_ids] pv_system_y_coords = self.pv_metadata.location_y[all_pv_system_ids] @@ -270,7 +272,7 @@ def get_example( pv["x_coords"] = x_coords pv["y_coords"] = y_coords - # pad out so that there are always 32 gsp, pad with zeros + # pad out so that there are always n_pv_systems_per_example, pad with zeros pad_n = self.n_pv_systems_per_example - len(pv.id_index) pv = pv.pad(id_index=(0, pad_n), data=((0, 0), (0, pad_n)), constant_values=0) @@ -315,40 +317,30 @@ def datetime_index(self) -> pd.DatetimeIndex: return self.pv_power.index -# TODO: Enable this function to load from any compute environment. See issue #286. 
-def load_solar_pv_data_from_gcs( +def load_solar_pv_data( filename: Union[str, Path], start_dt: Optional[datetime.datetime] = None, end_dt: Optional[datetime.datetime] = None, - from_gcs: bool = True, ) -> pd.DataFrame: """ - Load solar pv data from gcs (although there is an option to load from local - for testing) + Load solar pv data from any compute environment. Args: filename: filename of file to be loaded start_dt: the start datetime, which to trim the data to end_dt: the end datetime, which to trim the data to - from_gcs: option to laod from gcs, or form local file Returns: Solar PV data - """ - gcs = gcsfs.GCSFileSystem(access="read_only") - - logger.debug("Loading Solar PV Data from GCS") + logger.debug(f"Loading Solar PV Data from {filename} from {start_dt} to {end_dt}.") # It is possible to simplify the code below and do # xr.open_dataset(file, engine='h5netcdf') # in the first 'with' block, and delete the second 'with' block. # But that takes 1 minute to load the data, where as loading into memory # first and then loading from memory takes 23 seconds! - if from_gcs: - with gcs.open(filename, mode="rb") as file: - file_bytes = file.read() - else: - with open(filename, mode="rb") as file: - file_bytes = file.read() + with fsspec.open(filename, mode="rb") as file: + file_bytes = file.read() with io.BytesIO(file_bytes) as file: pv_power = xr.open_dataset(file, engine="h5netcdf") diff --git a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py index e6e823bc..b01babf5 100644 --- a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py +++ b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py @@ -1,34 +1,23 @@ """ Satellite Data Source """ import logging -from concurrent import futures from dataclasses import InitVar, dataclass -from numbers import Number from typing import Iterable, Optional -import numpy as np import pandas as pd import xarray as xr import nowcasting_dataset.time as nd_time +from nowcasting_dataset.consts import SAT_VARIABLE_NAMES from nowcasting_dataset.data_sources.data_source import ZarrDataSource from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite -from nowcasting_dataset.dataset.xr_utils import join_list_data_array_to_batch_dataset _LOG = logging.getLogger("nowcasting_dataset") -from nowcasting_dataset.consts import SAT_VARIABLE_NAMES - - @dataclass class SatelliteDataSource(ZarrDataSource): - """ - Satellite Data Source + """Satellite Data Source.""" - zarr_path: Must start with 'gs://' if on GCP. - """ - - zarr_path: str = None channels: Optional[Iterable[str]] = SAT_VARIABLE_NAMES image_size_pixels: InitVar[int] = 128 meters_per_pixel: InitVar[int] = 2_000 @@ -36,7 +25,6 @@ class SatelliteDataSource(ZarrDataSource): def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): """ Post Init """ super().__post_init__(image_size_pixels, meters_per_pixel) - self._cache = {} n_channels = len(self.channels) self._shape_of_example = ( self._total_seq_length, @@ -60,85 +48,57 @@ def open(self) -> None: def _open_data(self) -> xr.DataArray: return open_sat_data(zarr_path=self.zarr_path, consolidated=self.consolidated) - def get_batch( - self, - t0_datetimes: pd.DatetimeIndex, - x_locations: Iterable[Number], - y_locations: Iterable[Number], - ) -> Satellite: - """ - Get batch data - - Load the first _n_timesteps_per_batch concurrently. This - loads the timesteps from disk concurrently, and fills the - cache. 
If we try loading all examples - concurrently, then SatelliteDataSource will try reading from - empty caches, and things are much slower! - - Args: - t0_datetimes: list of timestamps for the datetime of the batches. The batch will also - include data for historic and future depending on `history_minutes` and - `future_minutes`. - x_locations: x center batch locations - y_locations: y center batch locations - - Returns: Batch data - - """ - # Load the first _n_timesteps_per_batch concurrently. This - # loads the timesteps from disk concurrently, and fills the - # cache. If we try loading all examples - # concurrently, then SatelliteDataSource will try reading from - # empty caches, and things are much slower! - zipped = list(zip(t0_datetimes, x_locations, y_locations)) - batch_size = len(t0_datetimes) - - with futures.ThreadPoolExecutor(max_workers=batch_size) as executor: - future_examples = [] - for coords in zipped[: self.n_timesteps_per_batch]: - t0_datetime, x_location, y_location = coords - future_example = executor.submit( - self.get_example, t0_datetime, x_location, y_location - ) - future_examples.append(future_example) - examples = [future_example.result() for future_example in future_examples] - - # Load the remaining examples. This should hit the DataSource caches. - for coords in zipped[self.n_timesteps_per_batch :]: - t0_datetime, x_location, y_location = coords - example = self.get_example(t0_datetime, x_location, y_location) - examples.append(example) - - output = join_list_data_array_to_batch_dataset(examples) - - self._cache = {} - + def _dataset_to_data_source_output(output: xr.Dataset) -> Satellite: return Satellite(output) def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: - try: - return self._cache[t0_dt] - except KeyError: - start_dt = self._get_start_dt(t0_dt) - end_dt = self._get_end_dt(t0_dt) - data = self.data.sel(time=slice(start_dt, end_dt)) - data = data.load() - self._cache[t0_dt] = data - return data - - def _post_process_example( - self, selected_data: xr.DataArray, t0_dt: pd.Timestamp - ) -> xr.DataArray: - - selected_data.data = selected_data.data.astype(np.float32) - - return selected_data + start_dt = self._get_start_dt(t0_dt) + end_dt = self._get_end_dt(t0_dt) + data = self.data.sel(time=slice(start_dt, end_dt)) + return data def datetime_index(self, remove_night: bool = True) -> pd.DatetimeIndex: """Returns a complete list of all available datetimes Args: remove_night: If True then remove datetimes at night. + We're interested in forecasting solar power generation, so we + don't care about nighttime data :) + + In the UK in summer, the sun rises first in the north east, and + sets last in the north west [1]. In summer, the north gets more + hours of sunshine per day. + + In the UK in winter, the sun rises first in the south east, and + sets last in the south west [2]. In winter, the south gets more + hours of sunshine per day. + + | | Summer | Winter | + | ---: | :---: | :---: | + | Sun rises first in | N.E. | S.E. | + | Sun sets last in | N.W. | S.W. | + | Most hours of sunlight | North | South | + + Before training, we select timesteps which have at least some + sunlight. We do this by computing the clearsky global horizontal + irradiance (GHI) for the four corners of the satellite imagery, + and for all the timesteps in the dataset. We only use timesteps + where the maximum global horizontal irradiance across all four + corners is above some threshold. 
+ + The 'clearsky solar irradiance' is the amount of sunlight we'd + expect on a clear day at a specific time and location. The SI unit + of irradiance is watt per square meter. The 'global horizontal + irradiance' (GHI) is the total sunlight that would hit a + horizontal surface on the surface of the Earth. The GHI is the + sum of the direct irradiance (sunlight which takes a direct path + from the Sun to the Earth's surface) and the diffuse horizontal + irradiance (the sunlight scattered from the atmosphere). For more + info, see: https://en.wikipedia.org/wiki/Solar_irradiance + + References: + 1. [Video of June 2019](https://www.youtube.com/watch?v=IOp-tj-IJpk) + 2. [Video of Jan 2019](https://www.youtube.com/watch?v=CJ4prUVa2nQ) """ if self._data is None: sat_data = self._open_data() diff --git a/nowcasting_dataset/data_sources/sun/raw_data_load_save.py b/nowcasting_dataset/data_sources/sun/raw_data_load_save.py index aafed030..d90e806f 100644 --- a/nowcasting_dataset/data_sources/sun/raw_data_load_save.py +++ b/nowcasting_dataset/data_sources/sun/raw_data_load_save.py @@ -72,7 +72,7 @@ def get_azimuth_and_elevation( ) names.append(name) - logger.debug(f"Getting results") + logger.debug("Getting results") # Collect results from each thread. for future_azimuth_and_elevation, name in tqdm(future_azimuth_and_elevation_per_location): diff --git a/nowcasting_dataset/data_sources/sun/sun_data_source.py b/nowcasting_dataset/data_sources/sun/sun_data_source.py index 55c5c16c..3704f0ca 100644 --- a/nowcasting_dataset/data_sources/sun/sun_data_source.py +++ b/nowcasting_dataset/data_sources/sun/sun_data_source.py @@ -7,8 +7,8 @@ import numpy as np import pandas as pd -import xarray as xr +import nowcasting_dataset.filesystem.utils as nd_fs_utils from nowcasting_dataset.data_sources.data_source import DataSource from nowcasting_dataset.data_sources.sun.raw_data_load_save import load_from_zarr, x_y_to_name from nowcasting_dataset.data_sources.sun.sun_model import Sun @@ -28,6 +28,10 @@ def __post_init__(self): super().__post_init__() self._load() + def check_input_paths_exist(self) -> None: + """Check input paths exist. If not, raise a FileNotFoundError.""" + nd_fs_utils.check_path_exists(self.zarr_path) + def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number ) -> Sun: @@ -42,7 +46,8 @@ def get_example( Returns: Dictionary of azimuth and elevation data """ # all sun data is from 2019, analaysis showed over the timescale we are interested in the - # elevation and azimuth angles change by < 1 degree, so to save data, we just use data form 2019 + # elevation and azimuth angles change by < 1 degree, so to save data, we just use data + # from 2019. t0_dt = t0_dt.replace(year=2019) start_dt = self._get_start_dt(t0_dt) @@ -59,9 +64,11 @@ def get_example( ] # lets make sure there is atleast one assert len(location) > 0 - # Take the first location, and x and y coordinates are the first and center entries in this array + # Take the first location, and x and y coordinates are the first and center entries in + # this array. location = location[0] - # make name of column to pull data from. The columns name will be about something like '22222.555,3333.6666' + # make name of column to pull data from. 
The columns name will be about + # something like '22222.555,3333.6666' name = x_y_to_name(x=location[0], y=location[1]) del x_meters_center, y_meters_center diff --git a/nowcasting_dataset/data_sources/topographic/topographic_data_source.py b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py index 87f54832..acc065ea 100644 --- a/nowcasting_dataset/data_sources/topographic/topographic_data_source.py +++ b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py @@ -2,41 +2,18 @@ from dataclasses import dataclass from numbers import Number -import numpy as np import pandas as pd import rioxarray import xarray as xr from rasterio.warp import Resampling -from nowcasting_dataset.consts import TOPOGRAPHIC_DATA +import nowcasting_dataset.filesystem.utils as nd_fs_utils from nowcasting_dataset.data_sources.data_source import ImageDataSource from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset from nowcasting_dataset.geospatial import OSGB from nowcasting_dataset.utils import OpenData -# Means computed with -# out_fp = "europe_dem_1km.tif" -# out = rasterio.open(out_fp) -# data = out.read(masked=True) -# print(np.mean(data)) -# print(np.std(data)) -TOPO_MEAN = xr.DataArray( - data=[ - 365.486887, - ], - dims=["variable"], - coords={"variable": [TOPOGRAPHIC_DATA]}, -).astype(np.float32) - -TOPO_STD = xr.DataArray( - data=[ - 478.841369, - ], - dims=["variable"], - coords={"variable": [TOPOGRAPHIC_DATA]}, -).astype(np.float32) - @dataclass class TopographicDataSource(ImageDataSource): @@ -64,6 +41,10 @@ def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): self._stored_pixel_size_meters = abs(self._data.coords["x"][1] - self._data.coords["x"][0]) self._meters_per_pixel = meters_per_pixel + def check_input_paths_exist(self) -> None: + """Check input paths exist. If not, raise a FileNotFoundError.""" + nd_fs_utils.check_path_exists(self.filename) + def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number ) -> Topographic: @@ -111,6 +92,7 @@ def get_example( f"actual shape {selected_data.shape}" ) + # TODO: Issue #318: Coordinates should be changed just before creating a batch. topo_xd = convert_data_array_to_dataset(selected_data) return Topographic(topo_xd) diff --git a/nowcasting_dataset/dataset/README.md b/nowcasting_dataset/dataset/README.md index 040a47f2..73e05a7e 100644 --- a/nowcasting_dataset/dataset/README.md +++ b/nowcasting_dataset/dataset/README.md @@ -7,27 +7,6 @@ This folder contains the following files 'Batch' pydantic class, to hold batch data in. An 'Example' is one item in the batch. 'BatchML' pydantic class, holds data for a batch, ready for ML models. -## datamodule.py +## xr_utils.py -Contains a class NowcastingDataModule - pl.LightningDataModule -This handles the - - amalgamation of all different data sources, - - making valid datetimes across all the sources, - - splitting into train and validation datasets - - -## datasets.py - -This file contains the following classes - -NetCDFDataset - torch.utils.data.Dataset: Use for loading pre-made batches -NowcastingDataset - torch.utils.data.IterableDataset: Dataset for making batches - - -## subset.py - -Function to subset the 'Batch' - -## fake.py - -A fake dataset, perhaps useful outside this repo. +Utilities for manipulating xarray DataArrays and Datasets. 
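
Several DataSources in this PR (Sun and Topographic above) gain a `check_input_paths_exist()` method that delegates to `nowcasting_dataset.filesystem.utils.check_path_exists`. A minimal sketch of that kind of check, written directly against `fsspec` so the same code works for local paths, `gs://` and `s3://` alike (the class below is a stand-in, not the real DataSource):

```python
from pathlib import Path
from typing import Union

import fsspec


def check_path_exists(path: Union[str, Path]) -> None:
    """Raise FileNotFoundError if `path` does not exist on its filesystem."""
    # fsspec.open(...).fs returns the filesystem implied by the path's protocol.
    filesystem = fsspec.open(str(path)).fs
    if not filesystem.exists(str(path)):
        raise FileNotFoundError(f"{path} does not exist!")


class TopographicDataSourceSketch:
    """Stand-in for a DataSource, showing where the check would be called."""

    def __init__(self, filename: str):
        self.filename = filename

    def check_input_paths_exist(self) -> None:
        """Check input paths exist. If not, raise a FileNotFoundError."""
        check_path_exists(self.filename)
```
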
diff --git a/nowcasting_dataset/dataset/batch.py b/nowcasting_dataset/dataset/batch.py index a46258bf..16d14872 100644 --- a/nowcasting_dataset/dataset/batch.py +++ b/nowcasting_dataset/dataset/batch.py @@ -33,6 +33,7 @@ register_xr_data_array_to_tensor, register_xr_data_set_to_tensor, ) +from nowcasting_dataset.utils import get_netcdf_filename _LOG = logging.getLogger(__name__) @@ -165,7 +166,7 @@ def load_netcdf(local_netcdf_path: Union[Path, str], batch_idx: int): for data_source_name in data_sources_names: local_netcdf_filename = os.path.join( - local_netcdf_path, data_source_name, f"{batch_idx}.nc" + local_netcdf_path, data_source_name, get_netcdf_filename(batch_idx) ) # submit task diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py deleted file mode 100644 index 33a7c0b5..00000000 --- a/nowcasting_dataset/dataset/datamodule.py +++ /dev/null @@ -1,372 +0,0 @@ -""" Data Modules """ -import logging -import warnings -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Dict, Iterable, Optional, Union - -import pandas as pd -import torch - -from nowcasting_dataset import consts, data_sources -from nowcasting_dataset.data_sources.data_source_list import DataSourceList -from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource -from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource -from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource -from nowcasting_dataset.dataset import datasets -from nowcasting_dataset.dataset.split.split import SplitMethod, split_data - -with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - import pytorch_lightning as pl - -logger = logging.getLogger(__name__) - - -@dataclass -class NowcastingDataModule(pl.LightningDataModule): - """ - Nowcasting Data Module, used to make batches - - Attributes (additional to the dataclass attributes): - pv_data_source: PVDataSource - sat_data_source: SatelliteDataSource - data_sources: List[DataSource] - train_t0_datetimes: pd.DatetimeIndex - val_t0_datetimes: pd.DatetimeIndex - """ - - pv_power_filename: Optional[Union[str, Path]] = None - pv_metadata_filename: Optional[Union[str, Path]] = None - batch_size: int = 8 - n_training_batches_per_epoch: int = 25_000 - n_validation_batches_per_epoch: int = 1_000 - n_test_batches_per_epoch: int = 1_000 - history_minutes: int = 30 #: Number of minutes of history, not including t0. - forecast_minutes: int = 60 #: Number of minutes of forecast, not including t0. - sat_filename: Union[str, Path] = consts.SAT_FILENAME - sat_channels: Iterable[str] = ("HRV",) - nwp_base_path: Optional[str] = None - nwp_channels: Optional[Iterable[str]] = ( - "t", - "dswrf", - "prate", - "r", - "sde", - "si10", - "vis", - "lcc", - "mcc", - "hcc", - ) - satellite_image_size_pixels: int = 128 #: Passed to Data Sources. - topographic_filename: Optional[Union[str, Path]] = None - sun_filename: Optional[Union[str, Path]] = None - nwp_image_size_pixels: int = 2 #: Passed to Data Sources. - meters_per_pixel: int = 2000 #: Passed to Data Sources. - pin_memory: bool = True #: Passed to DataLoader. - num_workers: int = 16 #: Passed to DataLoader. - prefetch_factor: int = 64 #: Passed to DataLoader. 
- n_samples_per_timestep: int = 2 #: Passed to NowcastingDataset - collate_fn: Callable = ( - torch.utils.data._utils.collate.default_collate - ) #: Passed to NowcastingDataset - gsp_filename: Optional[Union[str, Path]] = None - train_validation_percentage_split: float = 20 - pv_load_azimuth_and_elevation: bool = False - split_method: SplitMethod = SplitMethod.DAY # which split method should be used - seed: Optional[int] = None # seed used to make quasi random split data - t0_datetime_freq: str = "30T" # Frequency of the t0 datetimes. For example, if set to "30T" - # then create examples with T0 datetimes at thirty minute intervals, at 00 and 30 minutes - # past the hour. - - skip_n_train_batches: int = 0 # number of train batches to skip - skip_n_validation_batches: int = 0 # number of validation batches to skip - skip_n_test_batches: int = 0 # number of test batches to skip - - def __post_init__(self): - """ Post Init """ - super().__init__() - - self.history_length_30_minutes = self.history_minutes // 30 - self.forecast_length_30_minutes = self.forecast_minutes // 30 - - self.history_length_5_minutes = self.history_minutes // 5 - self.forecast_length_5_minutes = self.forecast_minutes // 5 - - # Plus 1 because neither history_length nor forecast_length include t0. - self._total_seq_length_5_minutes = ( - self.history_length_5_minutes + self.forecast_length_5_minutes + 1 - ) - self._total_seq_length_30_minutes = ( - self.history_length_30_minutes + self.forecast_length_30_minutes + 1 - ) - self.contiguous_dataset = None - if self.num_workers == 0: - self.prefetch_factor = 2 # Set to default when not using multiprocessing. - - def prepare_data(self) -> None: - """ Prepare all datasources """ - n_timesteps_per_batch = self.batch_size // self.n_samples_per_timestep - - self.sat_data_source = data_sources.SatelliteDataSource( - zarr_path=self.sat_filename, - image_size_pixels=self.satellite_image_size_pixels, - meters_per_pixel=self.meters_per_pixel, - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - channels=self.sat_channels, - n_timesteps_per_batch=n_timesteps_per_batch, - ) - - self.data_sources = [self.sat_data_source] - sat_datetimes = self.sat_data_source.datetime_index() - - # PV - if self.pv_power_filename is not None: - - self.pv_data_source = data_sources.PVDataSource( - filename=self.pv_power_filename, - metadata_filename=self.pv_metadata_filename, - start_dt=sat_datetimes[0], - end_dt=sat_datetimes[-1], - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - image_size_pixels=self.satellite_image_size_pixels, - meters_per_pixel=self.meters_per_pixel, - get_center=False, - load_azimuth_and_elevation=self.pv_load_azimuth_and_elevation, - ) - - self.data_sources = [self.pv_data_source, self.sat_data_source] - - if self.gsp_filename is not None: - self.gsp_data_source = GSPDataSource( - zarr_path=self.gsp_filename, - start_dt=sat_datetimes[0], - end_dt=sat_datetimes[-1], - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - image_size_pixels=self.satellite_image_size_pixels, - meters_per_pixel=self.meters_per_pixel, - get_center=True, - ) - - # put gsp data source at the start, so data is centered around GSP. This is the current - # approach, but in the future we may take a mix of GSP and PV data as the centroid. 
- self.data_sources = [self.gsp_data_source] + self.data_sources - - # NWP data - if self.nwp_base_path is not None: - self.nwp_data_source = data_sources.NWPDataSource( - zarr_path=self.nwp_base_path, - image_size_pixels=self.nwp_image_size_pixels, - meters_per_pixel=self.meters_per_pixel, - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - channels=self.nwp_channels, - n_timesteps_per_batch=n_timesteps_per_batch, - ) - - self.data_sources.append(self.nwp_data_source) - - # Topographic data - if self.topographic_filename is not None: - self.topo_data_source = data_sources.TopographicDataSource( - filename=self.topographic_filename, - image_size_pixels=self.satellite_image_size_pixels, - meters_per_pixel=self.meters_per_pixel, - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - ) - - self.data_sources.append(self.topo_data_source) - - # Sun data - if self.sun_filename is not None: - self.sun_data_source = SunDataSource( - zarr_path=self.sun_filename, - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - ) - self.data_sources.append(self.sun_data_source) - - self.datetime_data_source = data_sources.DatetimeDataSource( - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - ) - self.data_sources.append(self.datetime_data_source) - - self.data_sources.append( - MetadataDataSource( - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - object_at_center="GSP", - ) - ) - - self.data_sources = DataSourceList(self.data_sources) - - def setup(self, stage="fit"): - """Split data, etc. - - Args: - stage: {'fit', 'predict', 'test', 'validate'} This code ignores this. - - ## Selecting daytime data. - - We're interested in forecasting solar power generation, so we - don't care about nighttime data :) - - In the UK in summer, the sun rises first in the north east, and - sets last in the north west [1]. In summer, the north gets more - hours of sunshine per day. - - In the UK in winter, the sun rises first in the south east, and - sets last in the south west [2]. In winter, the south gets more - hours of sunshine per day. - - | | Summer | Winter | - | ---: | :---: | :---: | - | Sun rises first in | N.E. | S.E. | - | Sun sets last in | N.W. | S.W. | - | Most hours of sunlight | North | South | - - Before training, we select timesteps which have at least some - sunlight. We do this by computing the clearsky global horizontal - irradiance (GHI) for the four corners of the satellite imagery, - and for all the timesteps in the dataset. We only use timesteps - where the maximum global horizontal irradiance across all four - corners is above some threshold. - - The 'clearsky solar irradiance' is the amount of sunlight we'd - expect on a clear day at a specific time and location. The SI unit - of irradiance is watt per square meter. The 'global horizontal - irradiance' (GHI) is the total sunlight that would hit a - horizontal surface on the surface of the Earth. The GHI is the - sum of the direct irradiance (sunlight which takes a direct path - from the Sun to the Earth's surface) and the diffuse horizontal - irradiance (the sunlight scattered from the atmosphere). For more - info, see: https://en.wikipedia.org/wiki/Solar_irradiance - - References: - 1. [Video of June 2019](https://www.youtube.com/watch?v=IOp-tj-IJpk) - 2. [Video of Jan 2019](https://www.youtube.com/watch?v=CJ4prUVa2nQ) - """ - del stage # Not used in this method! 
- self._split_data() - - # Create datasets - logger.debug("Making train dataset") - self.train_dataset = datasets.NowcastingDataset( - t0_datetimes=self.train_t0_datetimes, - data_sources=self.data_sources, - skip_batch_index=self.skip_n_train_batches, - n_batches_per_epoch_per_worker=( - self._n_batches_per_epoch_per_worker(self.n_training_batches_per_epoch) - ), - **self._common_dataset_params(), - ) - logger.debug("Making validation dataset") - self.val_dataset = datasets.NowcastingDataset( - t0_datetimes=self.val_t0_datetimes, - data_sources=self.data_sources, - skip_batch_index=self.skip_n_validation_batches, - n_batches_per_epoch_per_worker=( - self._n_batches_per_epoch_per_worker(self.n_validation_batches_per_epoch) - ), - **self._common_dataset_params(), - ) - logger.debug("Making validation dataset: done") - - logger.debug("Making test dataset") - self.test_dataset = datasets.NowcastingDataset( - t0_datetimes=self.test_t0_datetimes, - data_sources=self.data_sources, - skip_batch_index=self.skip_n_test_batches, - n_batches_per_epoch_per_worker=( - self._n_batches_per_epoch_per_worker(self.n_test_batches_per_epoch) - ), - **self._common_dataset_params(), - ) - logger.debug("Making test dataset: done") - - if self.num_workers == 0: - self.train_dataset.per_worker_init(worker_id=0) - self.val_dataset.per_worker_init(worker_id=0) - self.test_dataset.per_worker_init(worker_id=0) - - logger.debug("Setup: done") - - def _n_batches_per_epoch_per_worker(self, n_batches_per_epoch: int) -> int: - if self.num_workers > 0: - return n_batches_per_epoch // self.num_workers - else: - return n_batches_per_epoch - - def _split_data(self): - """Sets self.train_t0_datetimes and self.val_t0_datetimes.""" - logger.debug("Going to split data") - - self._check_has_prepared_data() - self.t0_datetimes = self._get_t0_datetimes_across_all_data_sources() - - logger.debug(f"Got all start times, there are {len(self.t0_datetimes):,d}") - - data_after_splitting = split_data( - datetimes=self.t0_datetimes, method=self.split_method, seed=self.seed - ) - - self.train_t0_datetimes = data_after_splitting.train - self.val_t0_datetimes = data_after_splitting.validation - self.test_t0_datetimes = data_after_splitting.test - - logger.debug( - f"Split data done, train has {len(self.train_t0_datetimes):,d}, " - f"validation has {len(self.val_t0_datetimes):,d}, " - f"test has {len(self.test_t0_datetimes):,d} t0 datetimes." 
- ) - - def train_dataloader(self) -> torch.utils.data.DataLoader: - """ Train dataloader """ - return torch.utils.data.DataLoader(self.train_dataset, **self._common_dataloader_params()) - - def val_dataloader(self) -> torch.utils.data.DataLoader: - """ Validation dataloader """ - return torch.utils.data.DataLoader(self.val_dataset, **self._common_dataloader_params()) - - def test_dataloader(self) -> torch.utils.data.DataLoader: - """ Test dataloader """ - return torch.utils.data.DataLoader(self.test_dataset, **self._common_dataloader_params()) - - def _common_dataset_params(self) -> Dict: - return dict( - batch_size=self.batch_size, - n_samples_per_timestep=self.n_samples_per_timestep, - collate_fn=self.collate_fn, - ) - - def _common_dataloader_params(self) -> Dict: - return dict( - pin_memory=self.pin_memory, - num_workers=self.num_workers, - worker_init_fn=datasets.worker_init_fn, - prefetch_factor=self.prefetch_factor, - # Disable automatic batching because NowcastingDataset.__iter__ - # returns complete batches - batch_size=None, - batch_sampler=None, - ) - - def _get_t0_datetimes_across_all_data_sources(self) -> pd.DatetimeIndex: - """See DataSourceList.get_t0_datetimes_across_all_data_sources. - - This method will be deleted as part of implementing #213. - """ - return self.data_sources.get_t0_datetimes_across_all_data_sources( - freq=self.t0_datetime_freq - ) - - def _check_has_prepared_data(self): - if not self.has_prepared_data: - raise RuntimeError("Must run prepare_data() first!") diff --git a/nowcasting_dataset/dataset/datasets.py b/nowcasting_dataset/dataset/datasets.py deleted file mode 100644 index 5b03f63c..00000000 --- a/nowcasting_dataset/dataset/datasets.py +++ /dev/null @@ -1,203 +0,0 @@ -""" Dataset and functions""" -import logging -from concurrent import futures -from dataclasses import dataclass -from numbers import Number -from typing import Callable, List, Tuple - -import numpy as np -import pandas as pd -import torch -import xarray as xr - -from nowcasting_dataset import data_sources -from nowcasting_dataset.data_sources.satellite.satellite_data_source import SAT_VARIABLE_NAMES -from nowcasting_dataset.dataset.batch import Batch -from nowcasting_dataset.utils import set_fsspec_for_multiprocess - -logger = logging.getLogger(__name__) - -""" -This file contains the following classes -NetCDFDataset- torch.utils.data.Dataset: Use for loading pre-made batches -NowcastingDataset - torch.utils.data.IterableDataset: Dataset for making batches -""" - -# TODO: Can we get rid of SAT_MEAN and SAT_STD? See issue #231 -SAT_MEAN = xr.DataArray( - data=[ - 93.23458, - 131.71373, - 843.7779, - 736.6148, - 771.1189, - 589.66034, - 862.29816, - 927.69586, - 90.70885, - 107.58985, - 618.4583, - 532.47394, - ], - dims=["sat_variable"], - coords={"sat_variable": list(SAT_VARIABLE_NAMES)}, -).astype(np.float32) - -SAT_STD = xr.DataArray( - data=[ - 115.34247, - 139.92636, - 36.99538, - 57.366386, - 30.346825, - 149.68007, - 51.70631, - 35.872967, - 115.77212, - 120.997154, - 98.57828, - 99.76469, - ], - dims=["sat_variable"], - coords={"sat_variable": list(SAT_VARIABLE_NAMES)}, -).astype(np.float32) - -_LOG = logging.getLogger(__name__) - - -@dataclass -class NowcastingDataset(torch.utils.data.IterableDataset): - """ - The first data_source will be used to select the geo locations each batch. - """ - - batch_size: int - n_batches_per_epoch_per_worker: int - #: Number of times to re-use each timestep. Must exactly divide batch_size. 
- n_samples_per_timestep: int - data_sources: List[data_sources.DataSource] - t0_datetimes: pd.DatetimeIndex #: Valid t0 datetimes. - collate_fn: Callable = torch.utils.data._utils.collate.default_collate - - # useful way to skip batches if creating dataset fails halfway through. - # This might not be that useful, as re-running creation of datasets may cause off issues like duplicate data. - skip_batch_index: int = 0 - batch_index: int = 0 - - def __post_init__(self): - """ Post Init """ - super().__init__() - self._per_worker_init_has_run = False - self._n_timesteps_per_batch = self.batch_size // self.n_samples_per_timestep - - # Sanity checks. - if self.batch_size % self.n_samples_per_timestep != 0: - raise ValueError("n_crops_per_timestep must exactly divide batch_size!") - if len(self.t0_datetimes) < self._n_timesteps_per_batch: - raise ValueError( - f"start_dt_index only has {len(self.start_dt_index)}" - " timestamps." - f" Must have at least {self._n_timesteps_per_batch}!" - ) - - if self.skip_batch_index > 0: - _LOG.warning(f"Will be skipping {self.skip_batch_index}, is this correct?") - - def per_worker_init(self, worker_id: int) -> None: - """ - Called by worker_init_fn on each copy of NowcastingDataset - - This happens after the worker process has been spawned. - """ - # Each worker must have a different seed for its random number gen. - # Otherwise all the workers will output exactly the same data! - self.worker_id = worker_id - seed = torch.initial_seed() - self.rng = np.random.default_rng(seed=seed) - - # Initialise each data_source. - for data_source in self.data_sources: - _LOG.debug(f"Opening {type(data_source).__name__}") - data_source.open() - - # fix for fsspecs - set_fsspec_for_multiprocess() - - self._per_worker_init_has_run = True - - def __iter__(self): - """Yields a complete batch at a time.""" - if not self._per_worker_init_has_run: - raise RuntimeError("per_worker_init() must be run!") - for _ in range(self.n_batches_per_epoch_per_worker): - yield self._get_batch() - - def _get_batch(self) -> Batch: - - _LOG.debug(f"Getting batch {self.batch_index}") - - self.batch_index += 1 - if self.batch_index < self.skip_batch_index: - _LOG.debug(f"Skipping batch {self.batch_index}") - return [] - - t0_datetimes = self._get_t0_datetimes_for_batch() - x_locations, y_locations = self._get_locations(t0_datetimes) - - examples = {} - n_threads = len(self.data_sources) - with futures.ThreadPoolExecutor(max_workers=n_threads) as executor: - # Submit tasks to the executor. - future_examples_per_source = [] - for data_source in self.data_sources: - - future_examples = executor.submit( - data_source.get_batch, - t0_datetimes=t0_datetimes, - x_locations=x_locations, - y_locations=y_locations, - ) - future_examples_per_source.append(future_examples) - - # Collect results from each thread. - for future_examples in future_examples_per_source: - examples_from_source = future_examples.result() - - # print(type(examples_from_source)) - name = type(examples_from_source).__name__.lower() - examples[name] = examples_from_source - - examples["batch_size"] = len(t0_datetimes) - - return Batch(**examples) - - def _get_t0_datetimes_for_batch(self) -> pd.DatetimeIndex: - # Pick random datetimes. - t0_datetimes = self.rng.choice( - self.t0_datetimes, size=self._n_timesteps_per_batch, replace=False - ) - # Duplicate these random datetimes. 
- t0_datetimes = np.tile(t0_datetimes, reps=self.n_samples_per_timestep) - return pd.DatetimeIndex(t0_datetimes) - - def _get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: - return self.data_sources[0].get_locations(t0_datetimes) - - -def worker_init_fn(worker_id): - """Configures each dataset worker process. - - 1. Get fsspec ready for multi process - 2. To call NowcastingDataset.per_worker_init(). - """ - # fix for fsspec when using multprocess - set_fsspec_for_multiprocess() - - # get_worker_info() returns information specific to each worker process. - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: - print("worker_info is None!") - else: - # The NowcastingDataset copy in this worker process. - dataset_obj = worker_info.dataset - dataset_obj.per_worker_init(worker_info.id) diff --git a/nowcasting_dataset/dataset/split/method.py b/nowcasting_dataset/dataset/split/method.py index dfae20c2..eeb64179 100644 --- a/nowcasting_dataset/dataset/split/method.py +++ b/nowcasting_dataset/dataset/split/method.py @@ -13,7 +13,9 @@ def split_method( datetimes: pd.DatetimeIndex, train_test_validation_split: Tuple[int] = (3, 1, 1), - train_test_validation_specific: TrainValidationTestSpecific = default_train_test_validation_specific, + train_test_validation_specific: TrainValidationTestSpecific = ( + default_train_test_validation_specific + ), method: str = "modulo", freq: str = "D", seed: int = 1234, @@ -38,10 +40,11 @@ def split_method( datetimes: list of datetimes train_test_validation_split: how the split is made method: which method to use. Can be modulo or random - freq: This can be D=day, W=week, M=month and Y=year. This means the data is divided up by different periods + freq: This can be D=day, W=week, M=month and Y=year. This means the data is divided up by + different periods seed: random seed used to permutate the data for the 'random' method - train_test_validation_specific: pydandic class of 'train', 'validation' and 'test'. These specifies - which data goes into which datasets + train_test_validation_specific: pydandic class of 'train', 'validation' and 'test'. + These specify which data goes into which datasets Returns: train, validation and test datetimes @@ -56,7 +59,8 @@ def split_method( if method == "modulo": # Method to split by module. - # I.e 1st, 2nd, 3rd periods goes to train, 4th goes to validation, 5th goes to test and repeat. + # I.e 1st, 2nd, 3rd periods goes to train, 4th goes to validation, 5th goes to test and + # repeat. # make which day indexes go i.e if the split is [3,1,1] then the # - train_ indexes = [0,1,2] @@ -133,14 +137,17 @@ def split_by_dates( """ Split datetimes into train, validation and test by two specific datetime splits - Note that the 'train_validation_datetime_split' should be less than the 'validation_test_datetime_split' + Note that the 'train_validation_datetime_split' should be less than the + 'validation_test_datetime_split' Args: datetimes: list of datetimes - train_validation_datetime_split: the datetime which will split the train and validation datetimes. - For example if this is '2021-01-01' then the train datetimes will end by '2021-01-01' and the - validation datetimes will start at '2021-01-01'. - validation_test_datetime_split: the datetime which will split the validation and test datetimes + train_validation_datetime_split: the datetime which will split the train and validation + datetimes. 
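
To make the `modulo` split described in `split_method()` above concrete, here is a tiny stand-alone illustration of how a `(3, 1, 1)` split over daily periods assigns datetimes to train / validation / test. This is a simplified re-implementation for illustration only, not the library function itself.

```python
import pandas as pd

train_test_validation_split = (3, 1, 1)  # 3 days train, 1 day validation, 1 day test, repeating.
cycle_length = sum(train_test_validation_split)  # 5

datetimes = pd.date_range("2020-01-01", "2020-01-31", freq="30T")

# Index of the day each datetime belongs to, counted from the first day.
day_index = (datetimes.normalize() - datetimes[0].normalize()).days
position_in_cycle = day_index % cycle_length

train = datetimes[position_in_cycle < 3]        # days 0, 1, 2 of each 5-day cycle
validation = datetimes[position_in_cycle == 3]  # day 3 of each cycle
test = datetimes[position_in_cycle == 4]        # day 4 of each cycle

assert len(train) + len(validation) + len(test) == len(datetimes)
```
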
+ For example if this is '2021-01-01' then the train datetimes will end by '2021-01-01' and + the validation datetimes will start at '2021-01-01'. + validation_test_datetime_split: the datetime which will split the validation and + test datetimes Returns: train, validation and test datetimes diff --git a/nowcasting_dataset/dataset/split/split.py b/nowcasting_dataset/dataset/split/split.py index ce62873b..4f1e134b 100644 --- a/nowcasting_dataset/dataset/split/split.py +++ b/nowcasting_dataset/dataset/split/split.py @@ -39,8 +39,9 @@ class SplitName(Enum): TEST = "test" -SplitData = namedtuple( - typename="SplitData", +# Create a namedtuple for storing split t0 datetimes. +SplitDateTimes = namedtuple( + typename="SplitDateTimes", field_names=[SplitName.TRAIN.value, SplitName.VALIDATION.value, SplitName.TEST.value], ) @@ -54,7 +55,7 @@ def split_data( ), train_validation_test_datetime_split: Optional[List[pd.Timestamp]] = None, seed: int = 1234, -) -> SplitData: +) -> SplitDateTimes: """ Split the date using various different methods @@ -133,9 +134,11 @@ def split_data( if method == SplitMethod.DAY_RANDOM_TEST_YEAR: # This method splits # 1. test set to be in one year, using 'train_test_validation_specific' - # 2. train and validation by random day, using 'train_test_validation_split' on ratio how to split it + # 2. train and validation by random day, using 'train_test_validation_split' on ratio + # how to split it. # - # This allows us to create a test set for 2021, and train and validation for random days not in 2021 + # This allows us to create a test set for 2021, and train and validation for + # random days not in 2021. # create test set train_datetimes, validation_datetimes, test_datetimes = split_method( @@ -149,10 +152,11 @@ def split_data( elif method == SplitMethod.DAY_RANDOM_TEST_DATE: # This method splits # 1. test set from one date onwards - # 2. train and validation by random day, using 'train_test_validation_split' on ratio how to split it + # 2. train and validation by random day, using 'train_test_validation_split' on ratio + # how to split it. # - # This allows us to create a test set from a specfic date e.g. 2020-07-01, and train and validation - # for random days before that date + # This allows us to create a test set from a specfic date e.g. 2020-07-01, and train + # and validation for random days before that date. # create test set train_datetimes, validation_datetimes, test_datetimes = split_by_dates( @@ -180,10 +184,22 @@ def split_data( else: raise ValueError(f"{method} for splitting day is not implemented") - logger.debug( - f"Split data done, train has {len(train_datetimes):,d}, " - f"validation has {len(validation_datetimes):,d}, " - f"test has {len(test_datetimes):,d} t0 datetimes." + # Sanity check! 
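
With this change `split_data()` returns a `SplitDateTimes` namedtuple and sanity-checks that the three splits are disjoint. A small usage sketch of how calling code might consume that return value (the input datetimes here are placeholders, and the default split options are assumed):

```python
import pandas as pd

from nowcasting_dataset.dataset.split.split import SplitMethod, split_data

datetimes = pd.date_range("2020-01-01", "2020-12-31 23:30", freq="30T")

split_datetimes = split_data(datetimes=datetimes, method=SplitMethod.DAY)

# The result is a namedtuple, so the splits can be unpacked by name...
train = split_datetimes.train
validation = split_datetimes.validation
test = split_datetimes.test

# ...or iterated over, as the new logging code does.
for split_name, split_dt in split_datetimes._asdict().items():
    print(f"{split_name}: {len(split_dt):,d} datetimes")

# The splits should not overlap (this mirrors the sanity check added above).
assert len(train.intersection(validation)) == 0
```
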
+ if method != SplitMethod.SAME: + assert len(train_datetimes.intersection(validation_datetimes)) == 0 + assert len(train_datetimes.intersection(test_datetimes)) == 0 + assert len(test_datetimes.intersection(validation_datetimes)) == 0 + + assert train_datetimes.unique + assert validation_datetimes.unique + assert test_datetimes.unique + + split_datetimes = SplitDateTimes( + train=train_datetimes, validation=validation_datetimes, test=test_datetimes ) - return SplitData(train=train_datetimes, validation=validation_datetimes, test=test_datetimes) + logger.debug("Split data done!") + for split_name, dt in split_datetimes._asdict().items(): + logger.debug(f"{split_name} has {len(dt):,d} datetimes, from {dt[0]} to {dt[-1]}") + + return split_datetimes diff --git a/nowcasting_dataset/dataset/xr_utils.py b/nowcasting_dataset/dataset/xr_utils.py index e4f301a7..d3f85ed6 100644 --- a/nowcasting_dataset/dataset/xr_utils.py +++ b/nowcasting_dataset/dataset/xr_utils.py @@ -11,48 +11,64 @@ import xarray as xr -def join_list_data_array_to_batch_dataset(image_data_arrays: List[xr.DataArray]) -> xr.Dataset: - """ Join a list of data arrays to a dataset byt expanding dims """ - image_data_arrays = [ - convert_data_array_to_dataset(image_data_arrays[i]) for i in range(len(image_data_arrays)) - ] +# TODO: This function is only used in fake.py for testing. +# Maybe we should move this function to fake.py? +def join_list_data_array_to_batch_dataset(data_arrays: List[xr.DataArray]) -> xr.Dataset: + """Join a list of xr.DataArrays into an xr.Dataset by concatenating on the example dim.""" + datasets = [convert_data_array_to_dataset(data_arrays[i]) for i in range(len(data_arrays))] - return join_dataset_to_batch_dataset(image_data_arrays) + return join_list_dataset_to_batch_dataset(datasets) -def join_dataset_to_batch_dataset(image_data_arrays: List[xr.Dataset]) -> xr.Dataset: - """ Join a list of data arrays to a dataset byt expanding dims """ - image_data_arrays = [ - image_data_arrays[i].expand_dims(dim="example").assign_coords(example=("example", [i])) - for i in range(len(image_data_arrays)) - ] +def join_list_dataset_to_batch_dataset(datasets: list[xr.Dataset]) -> xr.Dataset: + """ Join a list of data sets to a dataset by expanding dims """ - return xr.concat(image_data_arrays, dim="example") + new_datasets = [] + for i, dataset in enumerate(datasets): + new_dataset = dataset.expand_dims(dim="example").assign_coords(example=("example", [i])) + new_datasets.append(new_dataset) + return xr.concat(new_datasets, dim="example") -def convert_data_array_to_dataset(data_xarray): + +# TODO: Issue #318: Maybe remove this function and, in calling code, do data_array.to_dataset() +# followed by make_dim_index, to make it more explicit what's happening? At the moment, +# in the calling code, it's not clear that the coordinates are being changed. +def convert_data_array_to_dataset(data_xarray: xr.DataArray) -> xr.Dataset: """ Convert data array to dataset. Reindex dim so that it can be merged with batch""" data = xr.Dataset({"data": data_xarray}) - - return make_dim_index(data_xarray_dataset=data) + return make_dim_index(dataset=data) -def make_dim_index(data_xarray_dataset: xr.Dataset) -> xr.Dataset: - """ Reindex dataset dims so that it can be merged with batch""" +# TODO: Issue #318: Maybe rename this function... maybe to coord_to_range()? +# Not sure what's best right now! :) +def make_dim_index(dataset: xr.Dataset) -> xr.Dataset: + """Reindex dims so that it can be merged with batch. 
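
The refactored `xr_utils` helpers above join per-example Datasets into a batch by adding an `example` dimension (after `make_dim_index` has replaced each dimension's coordinates with a 0..N integer index). A small self-contained sketch of that joining pattern with plain xarray; the toy data has no real coordinates, so the `make_dim_index` step is skipped here.

```python
import numpy as np
import xarray as xr

# Two toy "examples" with the same dimension names but independent data,
# as happens when each example is centred on a different location.
examples = [
    xr.Dataset({"data": xr.DataArray(np.random.rand(2, 3), dims=["time", "x"])})
    for _ in range(2)
]

# Give each example an `example` coordinate and concatenate along it,
# which is what join_list_dataset_to_batch_dataset() does.
examples = [
    ds.expand_dims(dim="example").assign_coords(example=("example", [i]))
    for i, ds in enumerate(examples)
]
batch = xr.concat(examples, dim="example")

assert batch["data"].dims == ("example", "time", "x")
assert batch.sizes["example"] == 2
```
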
- dims = data_xarray_dataset.dims + For each dimension in dataset, change the coords to 0.. len(original_coords), + and append "_index" to the dimension name. + And save the original coordinates in `original_dim_name`. - for dim in dims: - coord = data_xarray_dataset[dim] - data_xarray_dataset[dim] = np.arange(len(coord)) - - data_xarray_dataset = data_xarray_dataset.rename({dim: f"{dim}_index"}) + This is useful to align multiple examples into a single batch. + """ - data_xarray_dataset[dim] = xr.DataArray( - coord, coords=data_xarray_dataset[f"{dim}_index"].coords, dims=[f"{dim}_index"] + original_dim_names = dataset.dims + + for original_dim_name in original_dim_names: + original_coords = dataset[original_dim_name] + new_index_coords = np.arange(len(original_coords)) + new_index_dim_name = f"{original_dim_name}_index" + dataset[original_dim_name] = new_index_coords + dataset = dataset.rename({original_dim_name: new_index_dim_name}) + # Save the original_coords back into dataset, but this time it won't be used as + # coords for the variables payload in the dataset. + dataset[original_dim_name] = xr.DataArray( + original_coords, + coords=[new_index_coords], + dims=[new_index_dim_name], ) - return data_xarray_dataset + return dataset class PydanticXArrayDataSet(xr.Dataset): @@ -64,7 +80,7 @@ class PydanticXArrayDataSet(xr.Dataset): _expected_dimensions = () # Subclasses should set this. - # xarray doesnt support sub classing at the moment - https://github.com/pydata/xarray/issues/3980 + # xarray doesnt support sub classing at the moment: https://github.com/pydata/xarray/issues/3980 __slots__ = () @classmethod @@ -95,7 +111,8 @@ def validate_dims(cls, v: Any) -> Any: ), ( f"{cls.__name__}.dims is wrong! " f"{cls.__name__}.dims is {v.dims}. " - f"But we expected {cls._expected_dimensions}. Note that '_index' is removed, and 'example' is ignored" + f"But we expected {cls._expected_dimensions}." + " Note that '_index' is removed, and 'example' is ignored" ) return v diff --git a/nowcasting_dataset/filesystem/utils.py b/nowcasting_dataset/filesystem/utils.py index b88aa869..24eca738 100644 --- a/nowcasting_dataset/filesystem/utils.py +++ b/nowcasting_dataset/filesystem/utils.py @@ -4,6 +4,8 @@ from typing import List, Union import fsspec +import numpy as np +from pathy import Pathy _LOG = logging.getLogger("nowcasting_dataset") @@ -13,45 +15,57 @@ def upload_and_delete_local_files(dst_path: str, local_path: Path): Upload an entire folder and delete local files to either AWS or GCP """ _LOG.info("Uploading!") - filesystem = fsspec.open(dst_path).fs + filesystem = get_filesystem(dst_path) filesystem.put(str(local_path), dst_path, recursive=True) delete_all_files_in_temp_path(local_path) -def get_maximum_batch_id(path: str): +def get_filesystem(path: Union[str, Path]) -> fsspec.AbstractFileSystem: + r"""Get the fsspect FileSystem from a path. + + For example, if `path` starts with `gs:\\` then return a fsspec.GCSFileSystem. + + It is safe for `path` to include wildcards in the final filename. """ - Get the last batch ID. Works with GCS, AWS, and local. + path = Pathy(path) + return fsspec.open(path.parent).fs - Args: - path: the path folder to look in. Begin with 'gs://' for GCS. Begin with 's3://' for AWS S3. - Returns: the maximum batch id of data in `path`. +def get_maximum_batch_id(path: Pathy) -> int: """ - _LOG.debug(f"Looking for maximum batch id in {path}") + Get the last batch ID. Works with GCS, AWS, and local. 
- filesystem = fsspec.open(path).fs - if not filesystem.exists(path): - _LOG.debug(f"{path} does not exists") - return None + Args: + path: The path folder to look in. + Begin with 'gs://' for GCS. Begin with 's3://' for AWS S3. + Supports wildcards *, **, ?, and [..]. - filenames = get_all_filenames_in_path(path=path) + Returns: The maximum batch id of data in `path`. - # just take filename - filenames = [filename.split("/")[-1] for filename in filenames] + Raises FileNotFoundError if `path` does not exist. + """ + _LOG.debug(f"Looking for maximum batch id in {path}") - # remove suffix - filenames = [filename.split(".")[0] for filename in filenames] + filesystem = get_filesystem(path) + if not filesystem.exists(path.parent): + msg = f"{path.parent} does not exist" + _LOG.warning(msg) + raise FileNotFoundError(msg) - # change to integer - batch_indexes = [int(filename) for filename in filenames if len(filename) > 0] + filenames = filesystem.glob(path) - # if there is no files, return None - if len(batch_indexes) == 0: + # if there is no files, return 0 + if len(filenames) == 0: _LOG.debug(f"Did not find any files in {path}") - return None - - # get the maximum batch id - maximum_batch_id = max(batch_indexes) + return 0 + + # Now that filenames have leading zeros (like 000001.nc), we can use lexographical sorting + # to find the last filename, instead of having to convert all filenames to int. + filenames = np.sort(filenames) + last_filename = filenames[-1] + last_filename = Pathy(last_filename) + last_filename_stem = last_filename.stem + maximum_batch_id = int(last_filename_stem) _LOG.debug(f"Found maximum of batch it of {maximum_batch_id} in {path}") return maximum_batch_id @@ -61,14 +75,14 @@ def delete_all_files_in_temp_path(path: Union[Path, str], delete_dirs: bool = Fa """ Delete all the files in a temporary path. Option to delete the folders or not """ - filesystem = fsspec.open(path).fs + filesystem = get_filesystem(path) filenames = get_all_filenames_in_path(path=path) _LOG.info(f"Deleting {len(filenames)} files from {path}.") if delete_dirs: - for file in filenames: - filesystem.rm(file, recursive=True) + for filename in filenames: + filesystem.rm(str(filename), recursive=True) else: # loop over folder structure, but only delete files for root, dirs, files in filesystem.walk(path): @@ -78,10 +92,10 @@ def delete_all_files_in_temp_path(path: Union[Path, str], delete_dirs: bool = Fa def check_path_exists(path: Union[str, Path]): - """Raises a RuntimeError if `path` does not exist in the local filesystem.""" - filesystem = fsspec.open(path).fs + """Raises a FileNotFoundError if `path` does not exist.""" + filesystem = get_filesystem(path) if not filesystem.exists(path): - raise RuntimeError(f"{path} does not exist!") + raise FileNotFoundError(f"{path} does not exist!") def rename_file(remote_file: str, new_filename: str): @@ -93,20 +107,20 @@ def rename_file(remote_file: str, new_filename: str): new_filename: What the file should be renamed too """ - filesystem = fsspec.open(remote_file).fs + filesystem = get_filesystem(remote_file) filesystem.mv(remote_file, new_filename) def get_all_filenames_in_path(path: Union[str, Path]) -> List[str]: """ - Get all the files names from one folder in gcp + Get all the files names from one folder. Args: - path: the path that we should look in + path: The path that we should look in. - Returns: a list of files names represented as strings. + Returns: A list of filenames represented as strings. 
""" - filesystem = fsspec.open(path).fs + filesystem = get_filesystem(path) return filesystem.ls(path) @@ -120,7 +134,11 @@ def download_to_local(remote_filename: str, local_filename: str): """ _LOG.debug(f"Downloading from GCP {remote_filename} to {local_filename}") - filesystem = fsspec.open(remote_filename).fs + # Check the inputs are strings + remote_filename = str(remote_filename) + local_filename = str(local_filename) + + filesystem = get_filesystem(remote_filename) filesystem.get(remote_filename, local_filename) @@ -136,12 +154,20 @@ def upload_one_file( local_filename: the local file name """ - filesystem = fsspec.open(remote_filename).fs + filesystem = get_filesystem(remote_filename) filesystem.put(local_filename, remote_filename) -def make_folder(path: Union[str, Path]): - """ Make folder """ - filesystem = fsspec.open(path).fs - if not filesystem.exists(path): - filesystem.mkdir(path) +def makedirs(path: Union[str, Path], exist_ok: bool = True) -> None: + """Recursively make directories + + Creates directory at path and any intervening required directories. + + Raises exception if, for instance, the path already exists but is a file. + + Args: + path: The path to create. + exist_ok: If False then raise an exception if `path` already exists. + """ + filesystem = get_filesystem(path) + filesystem.makedirs(path, exist_ok=exist_ok) diff --git a/nowcasting_dataset/geospatial.py b/nowcasting_dataset/geospatial.py index e31fae76..41713e48 100644 --- a/nowcasting_dataset/geospatial.py +++ b/nowcasting_dataset/geospatial.py @@ -98,17 +98,21 @@ def calculate_azimuth_and_elevation_angle( latitude: float, longitude: float, datestamps: [datetime] ) -> pd.DataFrame: """ - Calculation the azimuth angle, and the elevation angle for several datetamps, but for one specific osgb location + Calculation the azimuth angle, and the elevation angle for several datetamps. - More details see: https://www.celestis.com/resources/faq/what-are-the-azimuth-and-elevation-of-a-satellite/ + But for one specific osgb location + + More details see: + https://www.celestis.com/resources/faq/what-are-the-azimuth-and-elevation-of-a-satellite/ Args: latitude: latitude of the pv site longitude: longitude of the pv site - datestamps: list of datestamps to calculate the sun angles. i.e the sun moves from east to west in the day. + datestamps: list of datestamps to calculate the sun angles. i.e the sun moves from east to + west in the day. - Returns: Pandas data frame with the index the same as 'datestamps', with columns of "elevation" and "azimuth" that - have been calculate. + Returns: Pandas data frame with the index the same as 'datestamps', with columns of + "elevation" and "azimuth" that have been calculate. 
""" # get the solor position diff --git a/nowcasting_dataset/manager.py b/nowcasting_dataset/manager.py new file mode 100644 index 00000000..f49edd55 --- /dev/null +++ b/nowcasting_dataset/manager.py @@ -0,0 +1,410 @@ +"""Manager class.""" + +import logging +from concurrent import futures +from pathlib import Path +from typing import Optional, Union + +import numpy as np +import pandas as pd + +# nowcasting_dataset imports +import nowcasting_dataset.time as nd_time +import nowcasting_dataset.utils as nd_utils +from nowcasting_dataset import config +from nowcasting_dataset.consts import ( + SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES, + SPATIAL_AND_TEMPORAL_LOCATIONS_OF_EACH_EXAMPLE_FILENAME, +) +from nowcasting_dataset.data_sources import ALL_DATA_SOURCE_NAMES, MAP_DATA_SOURCE_NAME_TO_CLASS +from nowcasting_dataset.dataset.split import split +from nowcasting_dataset.filesystem import utils as nd_fs_utils + +logger = logging.getLogger(__name__) + + +class Manager: + """The Manager initialises and manage a dict of DataSource objects. + + Attrs: + config: Configuration object. + data_sources: dict[str, DataSource] + data_source_which_defines_geospatial_locations: DataSource: The DataSource used to compute the + geospatial locations of each example. + save_batches_locally_and_upload: bool: Set to True by `load_yaml_configuration()` if + `config.process.upload_every_n_batches > 0`. + local_temp_path: Path: `config.process.local_temp_path` with `~` expanded. + """ + + def __init__(self) -> None: # noqa: D107 + self.config = None + self.data_sources = {} + self.data_source_which_defines_geospatial_locations = None + + def load_yaml_configuration(self, filename: str) -> None: + """Load YAML config from `filename`.""" + logger.debug(f"Loading YAML configuration file {filename}") + self.config = config.load_yaml_configuration(filename) + self.config = config.set_git_commit(self.config) + self.save_batches_locally_and_upload = self.config.process.upload_every_n_batches > 0 + + # TODO: Issue #320: This could be done in the Pydantic model? + self.local_temp_path = Path(self.config.process.local_temp_path).expanduser() + logger.debug(f"config={self.config}") + + def initialise_data_sources( + self, names_of_selected_data_sources: Optional[list[str]] = ALL_DATA_SOURCE_NAMES + ) -> None: + """Initialise DataSources specified in the InputData configuration. + + For each key in each DataSource's configuration object, the string `_` + is removed from the key before passing to the DataSource constructor. This allows us to + have verbose field names in the configuration YAML files, whilst also using standard + constructor arguments for DataSources. + """ + for data_source_name in names_of_selected_data_sources: + logger.debug(f"Creating {data_source_name} DataSource object.") + config_for_data_source = getattr(self.config.input_data, data_source_name) + if config_for_data_source is None: + logger.info(f"No configuration found for {data_source_name}.") + continue + config_for_data_source = config_for_data_source.dict() + + # Strip `_` from the config option field names. 
+ config_for_data_source = nd_utils.remove_regex_pattern_from_keys( + config_for_data_source, pattern_to_remove=f"^{data_source_name}_" + ) + + data_source_class = MAP_DATA_SOURCE_NAME_TO_CLASS[data_source_name] + try: + data_source = data_source_class(**config_for_data_source) + except Exception: + logger.exception(f"Exception whilst instantiating {data_source_name}!") + raise + self.data_sources[data_source_name] = data_source + + # Set data_source_which_defines_geospatial_locations: + try: + self.data_source_which_defines_geospatial_locations = self.data_sources[ + self.config.input_data.data_source_which_defines_geospatial_locations + ] + except KeyError: + msg = ( + "input_data.data_source_which_defines_geospatial_locations=" + f"{self.config.input_data.data_source_which_defines_geospatial_locations}" + " is not a member of the DataSources, so cannot set" + " self.data_source_which_defines_geospatial_locations!" + f" The available DataSources are: {list(self.data_sources.keys())}" + ) + logger.error(msg) + raise RuntimeError(msg) + else: + logger.info( + f"DataSource `{data_source_name}` set as" + " data_source_which_defines_geospatial_locations." + ) + + def create_files_specifying_spatial_and_temporal_locations_of_each_example_if_necessary( + self, + ) -> None: + """Creates CSV files specifying the locations of each example if those files don't exist yet. + + Creates one file per split, in this location: + + ` / / spatial_and_temporal_locations_of_each_example.csv` + + Creates the output directory if it does not exist. + + Works on any compute environment. + """ + if self._locations_csv_file_exists(): + logger.info( + f"{SPATIAL_AND_TEMPORAL_LOCATIONS_OF_EACH_EXAMPLE_FILENAME} already exists!" + ) + return + logger.info( + f"{SPATIAL_AND_TEMPORAL_LOCATIONS_OF_EACH_EXAMPLE_FILENAME} does not exist so" + " will create..." + ) + t0_datetimes = self.get_t0_datetimes_across_all_data_sources( + freq=self.config.process.t0_datetime_frequency + ) + split_t0_datetimes = split.split_data( + datetimes=t0_datetimes, method=self.config.process.split_method + ) + for split_name, datetimes_for_split in split_t0_datetimes._asdict().items(): + n_batches = self._get_n_batches_for_split_name(split_name) + n_examples = n_batches * self.config.process.batch_size + logger.debug( + f"Creating {n_batches:,d} batches x {self.config.process.batch_size:,d} examples" + f" per batch = {n_examples:,d} examples for {split_name}." 
+ ) + df_of_locations = self.sample_spatial_and_temporal_locations_for_examples( + t0_datetimes=datetimes_for_split, n_examples=n_examples + ) + output_filename = self._filename_of_locations_csv_file(split_name) + path_for_csv = self.config.output_data.filepath / split_name + logger.info(f"Making {path_for_csv} if it does not exist.") + nd_fs_utils.makedirs(path_for_csv, exist_ok=True) + logger.debug(f"Writing {output_filename}") + df_of_locations.to_csv(output_filename) + + def _get_n_batches_for_split_name(self, split_name: str) -> int: + return getattr(self.config.process, f"n_{split_name}_batches") + + def _filename_of_locations_csv_file(self, split_name: str) -> Path: + return ( + self.config.output_data.filepath + / split_name + / SPATIAL_AND_TEMPORAL_LOCATIONS_OF_EACH_EXAMPLE_FILENAME + ) + + def _locations_csv_file_exists(self) -> bool: + """Check if filepath/train/spatial_and_temporal_locations_of_each_example.csv exists.""" + filename = self._filename_of_locations_csv_file(split_name=split.SplitName.TRAIN.value) + try: + nd_fs_utils.check_path_exists(filename) + except FileNotFoundError: + logging.info(f"{filename} does not exist!") + return False + else: + logger.info(f"{filename} exists!") + return True + + def get_t0_datetimes_across_all_data_sources( + self, freq: Union[str, pd.Timedelta] + ) -> pd.DatetimeIndex: + """ + Compute the intersection of the t0 datetimes available across all DataSources. + + Args: + freq: The Pandas frequency string. The returned DatetimeIndex will be at this frequency, + and every datetime will be aligned to this frequency. For example, if + freq='5 minutes' then every datetime will be at 00, 05, ..., 55 minutes + past the hour. + + Returns: Valid t0 datetimes, taking into consideration all DataSources, + filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night + datetimes). + """ + logger.debug( + f"Getting the intersection of time periods across all DataSources at freq={freq}..." + ) + if set(self.data_sources.keys()) != set(ALL_DATA_SOURCE_NAMES): + logger.warning( + "Computing available t0 datetimes using less than all available DataSources!" + " Are you sure you mean to do this?!" + ) + + # Get the intersection of t0 time periods from all data sources. + t0_time_periods_for_all_data_sources = [] + for data_source_name, data_source in self.data_sources.items(): + logger.debug(f"Getting t0 time periods for {data_source_name}") + try: + t0_time_periods = data_source.get_contiguous_t0_time_periods() + except NotImplementedError: + # Skip data_sources with no concept of time. + logger.debug(f"Skipping {data_source_name} because it has not concept of datetime.") + else: + t0_time_periods_for_all_data_sources.append(t0_time_periods) + + intersection_of_t0_time_periods = nd_time.intersection_of_multiple_dataframes_of_periods( + t0_time_periods_for_all_data_sources + ) + + t0_datetimes = nd_time.time_periods_to_datetime_index( + time_periods=intersection_of_t0_time_periods, freq=freq + ) + logger.debug( + f"Found {len(t0_datetimes):,d} datetimes at freq=`{freq}` across" + f" DataSources={list(self.data_sources.keys())}." + f" From {t0_datetimes[0]} to {t0_datetimes[-1]}." + ) + return t0_datetimes + + def sample_spatial_and_temporal_locations_for_examples( + self, t0_datetimes: pd.DatetimeIndex, n_examples: int + ) -> pd.DataFrame: + """ + Computes the geospatial and temporal locations for each training example. + + The first data_source in this DataSourceList defines the geospatial locations of + each example. 
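
`sample_spatial_and_temporal_locations_for_examples()` samples `n_examples` t0 datetimes (with replacement, via `np.random.choice`) and asks the geo-defining DataSource for matching x/y centres, returning a three-column DataFrame. A toy sketch of building and saving such a locations table; the coordinate values below are made up, since in the real Manager they come from `get_locations()` of the geo-defining DataSource.

```python
import numpy as np
import pandas as pd

t0_datetimes = pd.date_range("2020-06-01 10:00", "2020-06-01 16:00", freq="30T")
n_examples = 8

shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples)
x_locations = np.random.uniform(0, 700_000, size=n_examples)    # made-up OSGB eastings
y_locations = np.random.uniform(0, 1_200_000, size=n_examples)  # made-up OSGB northings

df_of_locations = pd.DataFrame(
    {
        "t0_datetime_UTC": shuffled_t0_datetimes,
        "x_center_OSGB": x_locations,
        "y_center_OSGB": y_locations,
    }
)
df_of_locations.to_csv("spatial_and_temporal_locations_of_each_example.csv")
```
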
+ + Args: + t0_datetimes: All available t0 datetimes. Can be computed with + `DataSourceList.get_t0_datetimes_across_all_data_sources()` + n_examples: The number of examples requested. + + Returns: + Each row of each the DataFrame specifies the position of each example, using + columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'. + """ + shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples) + # TODO: Issue #304. Speed this up by splitting the shuffled_t0_datetimes across + # multiple processors. Currently takes about half an hour for 25,000 batches. + # But wait until we've implemented issue #305, as that is likely to be sufficient! + ( + x_locations, + y_locations, + ) = self.data_source_which_defines_geospatial_locations.get_locations(shuffled_t0_datetimes) + return pd.DataFrame( + { + "t0_datetime_UTC": shuffled_t0_datetimes, + "x_center_OSGB": x_locations, + "y_center_OSGB": y_locations, + } + ) + + def _get_first_batches_to_create( + self, overwrite_batches: bool + ) -> dict[split.SplitName, dict[str, int]]: + """For each SplitName & for each DataSource name, return the first batch ID to create. + + For example, the returned_dict[SplitName.TRAIN]['gsp'] tells us the first batch_idx to + create for the training set for the GSPDataSource. + """ + # Initialise to zero: + first_batches_to_create: dict[split.SplitName, dict[str, int]] = {} + for split_name in split.SplitName: + first_batches_to_create[split_name] = { + data_source_name: 0 for data_source_name in self.data_sources + } + + if overwrite_batches: + return first_batches_to_create + + # If we're not overwriting batches then find the last batch on disk. + for split_name in split.SplitName: + for data_source_name in self.data_sources: + path = ( + self.config.output_data.filepath / split_name.value / data_source_name / "*.nc" + ) + try: + max_batch_id_on_disk = nd_fs_utils.get_maximum_batch_id(path) + except FileNotFoundError: + max_batch_id_on_disk = -1 + first_batches_to_create[split_name][data_source_name] = max_batch_id_on_disk + 1 + + return first_batches_to_create + + def _check_if_more_batches_are_required_for_split( + self, + split_name: split.SplitName, + first_batches_to_create: dict[split.SplitName, dict[str, int]], + ) -> bool: + """Returns True if batches still need to be created for any DataSource.""" + n_batches_requested = self._get_n_batches_for_split_name(split_name.value) + for data_source_name in self.data_sources: + if first_batches_to_create[split_name][data_source_name] < n_batches_requested: + return True + return False + + def _find_splits_which_need_more_batches( + self, first_batches_to_create: dict[split.SplitName, dict[str, int]] + ) -> list[split.SplitName]: + """Returns list of SplitNames which need more batches to be produced.""" + splits_which_need_more_batches = [] + for split_name in split.SplitName: + if self._check_if_more_batches_are_required_for_split( + split_name, first_batches_to_create + ): + splits_which_need_more_batches.append(split_name) + return splits_which_need_more_batches + + def create_batches(self, overwrite_batches: bool) -> None: + """Create batches (if necessary). + + Make dirs: ` / / `. + + Also make `local_temp_path` if necessary. + + Args: + overwrite_batches: If True then start from batch 0, regardless of which batches have + previously been written to disk. If False then check which batches have previously been + written to disk, and only create any batches which have not yet been written to disk. 
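
The resume logic in `_get_first_batches_to_create()` above relies on `get_maximum_batch_id()` raising `FileNotFoundError` when the output directory does not exist yet. A condensed sketch of how the "first batch to create" is derived for one split and one DataSource; the output path below is a placeholder.

```python
from pathy import Pathy

from nowcasting_dataset.filesystem.utils import get_maximum_batch_id

output_filepath = Pathy("gs://example-bucket/prepared_ML_data/v0")  # placeholder path
split_name = "train"
data_source_name = "satellite"

path = output_filepath / split_name / data_source_name / "*.nc"
try:
    max_batch_id_on_disk = get_maximum_batch_id(path)
except FileNotFoundError:
    # Nothing has been written yet, so start from batch 0.
    max_batch_id_on_disk = -1

idx_of_first_batch = max_batch_id_on_disk + 1
print(f"Will start creating batches at batch {idx_of_first_batch}.")
```
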
+ """ + first_batches_to_create = self._get_first_batches_to_create(overwrite_batches) + + # Check if there's any work to do. + if overwrite_batches: + splits_which_need_more_batches = list(split.SplitName) + else: + splits_which_need_more_batches = self._find_splits_which_need_more_batches( + first_batches_to_create + ) + if len(splits_which_need_more_batches) == 0: + logger.info("All batches have already been created! No work to do!") + return + + # Load locations for each example off disk. + locations_for_each_example_of_each_split: dict[split.SplitName, pd.DataFrame] = {} + for split_name in splits_which_need_more_batches: + filename = self._filename_of_locations_csv_file(split_name.value) + logger.info(f"Loading {filename}.") + locations_for_each_example = pd.read_csv(filename, index_col=0) + assert locations_for_each_example.columns.to_list() == list( + SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES + ) + # Converting to datetimes is much faster using `pd.to_datetime()` than + # passing `parse_dates` into `pd.read_csv()`. + locations_for_each_example["t0_datetime_UTC"] = pd.to_datetime( + locations_for_each_example["t0_datetime_UTC"] + ) + locations_for_each_example_of_each_split[split_name] = locations_for_each_example + + # Fire up a separate process for each DataSource, pass it the list of batches to + # create, and tell it whether to call utils.upload_and_delete_local_files(). + # TODO: Issue #321: Split this up into separate functions! + n_data_sources = len(self.data_sources) + nd_utils.set_fsspec_for_multiprocess() + for split_name in splits_which_need_more_batches: + locations_for_split = locations_for_each_example_of_each_split[split_name] + with futures.ProcessPoolExecutor(max_workers=n_data_sources) as executor: + future_create_batches_jobs = [] + for worker_id, (data_source_name, data_source) in enumerate( + self.data_sources.items() + ): + # Get the indices of the first batch and first example, and subset locations_for_split. + idx_of_first_batch = first_batches_to_create[split_name][data_source_name] + idx_of_first_example = idx_of_first_batch * self.config.process.batch_size + locations = locations_for_split.loc[idx_of_first_example:] + + # Get paths. + dst_path = ( + self.config.output_data.filepath / split_name.value / data_source_name + ) + local_temp_path = ( + self.local_temp_path + / split_name.value + / data_source_name + / f"worker_{worker_id}" + ) + + # Make folders. + nd_fs_utils.makedirs(dst_path, exist_ok=True) + if self.save_batches_locally_and_upload: + nd_fs_utils.makedirs(local_temp_path, exist_ok=True) + + # Submit data_source.create_batches task to the worker process. + future = executor.submit( + data_source.create_batches, + spatial_and_temporal_locations_of_each_example=locations, + idx_of_first_batch=idx_of_first_batch, + batch_size=self.config.process.batch_size, + dst_path=dst_path, + local_temp_path=local_temp_path, + upload_every_n_batches=self.config.process.upload_every_n_batches, + ) + future_create_batches_jobs.append(future) + + # Wait for all futures to finish: + for future, data_source_name in zip( + future_create_batches_jobs, self.data_sources.keys() + ): + # Call exception() to propagate any exceptions raised by the worker process into + # the main process, and to wait for the worker to finish.
+ exception = future.exception() + if exception is not None: + logger.exception(f"Worker process {data_source_name} raised exception!") + raise exception diff --git a/nowcasting_dataset/time.py b/nowcasting_dataset/time.py index 08cbaef8..0adc95f3 100644 --- a/nowcasting_dataset/time.py +++ b/nowcasting_dataset/time.py @@ -263,9 +263,11 @@ def make_random_time_vectors(batch_size, seq_length_5_minutes, seq_length_30_min Returns: - t0_dt: [batch_size] random init datetimes - - time_5: [batch_size, seq_length_5_minutes] random sequence of datetimes, with 5 mins deltas. - t0_dt is in the middle of the sequence - - time_30: [batch_size, seq_length_30_minutes] random sequence of datetimes, with 30 mins deltas. + - time_5: [batch_size, seq_length_5_minutes] random sequence of datetimes, with + 5 mins deltas. + t0_dt is in the middle of the sequence + - time_30: [batch_size, seq_length_30_minutes] random sequence of datetimes, with + 30 mins deltas. t0_dt is in the middle of the sequence """ delta_5 = pd.Timedelta(minutes=5) diff --git a/nowcasting_dataset/utils.py b/nowcasting_dataset/utils.py index 34ea2512..a60fd08d 100644 --- a/nowcasting_dataset/utils.py +++ b/nowcasting_dataset/utils.py @@ -1,12 +1,10 @@ """ utils functions """ -import hashlib import logging +import os +import re import tempfile -from pathlib import Path -from typing import Optional +from functools import wraps -import re -import os import fsspec.asyn import gcsfs import numpy as np @@ -15,8 +13,9 @@ import xarray as xr import nowcasting_dataset -from nowcasting_dataset.consts import Array +import nowcasting_dataset.filesystem.utils as nd_fs_utils from nowcasting_dataset.config import load, model +from nowcasting_dataset.consts import Array logger = logging.getLogger(__name__) @@ -34,6 +33,7 @@ def set_fsspec_for_multiprocess() -> None: fsspec.asyn.loop[0] = None +# TODO: Issue #170. Is this function still used? def is_monotonically_increasing(a: Array) -> bool: """ Check the array is monotonically increasing """ # TODO: Can probably replace with pd.Index.is_monotonic_increasing() @@ -45,12 +45,14 @@ def is_monotonically_increasing(a: Array) -> bool: return np.all(np.diff(a) > 0) +# TODO: Issue #170. Is this function still used? def is_unique(a: Array) -> bool: """ Check array has unique values """ # TODO: Can probably replace with pd.Index.is_unique() return len(a) == len(np.unique(a)) +# TODO: Issue #170. Is this function still used? def scale_to_0_to_1(a: Array) -> Array: """Scale to the range [0, 1].""" a = a - a.min() @@ -60,6 +62,7 @@ def scale_to_0_to_1(a: Array) -> Array: return a +# TODO: Issue #170. Is this function still used? def sin_and_cos(df: pd.DataFrame) -> pd.DataFrame: """ For every column in df, creates cols for sin and cos of that col. @@ -93,26 +96,13 @@ def sin_and_cos(df: pd.DataFrame) -> pd.DataFrame: return output_df -def get_netcdf_filename(batch_idx: int, add_hash: bool = False) -> Path: - """Generate full filename, excluding path. - - Filename includes the first 6 digits of the MD5 hash of the filename, - as recommended by Google Cloud in order to distribute data across - multiple back-end servers. - - Add option to turn on and off hashing - - """ - filename = f"{batch_idx}.nc" - # Remove 'hash' at the moment.
In the future could has the configuration file, and use this to make sure we are - # saving and loading the same thing - if add_hash: - hash_of_filename = hashlib.md5(filename.encode()).hexdigest() - filename = f"{hash_of_filename[0:6]}_{filename}" - - return filename +def get_netcdf_filename(batch_idx: int) -> str: + """Generate full filename, excluding path.""" + assert 0 <= batch_idx < 1e6 + return f"{batch_idx:06d}.nc" +# TODO: Issue #170. Is this function still used? def to_numpy(value): """ Change generic data to numpy""" if isinstance(value, xr.DataArray): @@ -133,31 +123,14 @@ def to_numpy(value): return value -def coord_to_range( - da: xr.DataArray, dim: str, prefix: Optional[str], dtype=np.int32 -) -> xr.DataArray: - """ - TODO - - TODO: Actually, I think this is over-complicated? I think we can - just strip off the 'coord' from the dimension. - - """ - coord = da[dim] - da[dim] = np.arange(len(coord), dtype=dtype) - if prefix is not None: - da[f"{prefix}_{dim}_coords"] = xr.DataArray(coord, coords=[da[dim]], dims=[dim]) - return da - - class OpenData: - """ General method to open a file, but if from GCS, the file is downloaded to a temp file first """ + """Open a file, but if from GCS, the file is downloaded to a temp file first.""" def __init__(self, file_name): """ Check file is there, and create temporary file """ self.file_name = file_name - filesystem = fsspec.open(file_name).fs + filesystem = nd_fs_utils.get_filesystem(file_name) if not filesystem.exists(file_name): raise RuntimeError(f"{file_name} does not exist!") @@ -169,7 +142,7 @@ def __enter__(self): 1. if from gcs, download the file to temporary file, and return the temporary file name 2. if local, return local file name """ - fs = fsspec.open(self.file_name).fs + fs = nd_fs_utils.get_filesystem(self.file_name) if type(fs) == gcsfs.GCSFileSystem: fs.get_file(self.file_name, self.temp_file.name) filename = self.temp_file.name @@ -206,3 +179,16 @@ def get_config_with_test_paths(config_filename: str) -> model.Configuration: config = load.load_yaml_configuration(filename) config.set_base_path(local_path) return config + + +def arg_logger(func): + """A function decorator to log all the args and kwargs passed into a function.""" + # Adapted from https://stackoverflow.com/a/23983263/732596 + @wraps(func) + def inner_func(*args, **kwargs): + logger.debug( + f"Arguments passed into function `{func.__name__}`:" f" args={args}; kwargs={kwargs}" + ) + return func(*args, **kwargs) + + return inner_func diff --git a/scripts/generate_data_for_tests/get_test_data.py b/scripts/generate_data_for_tests/get_test_data.py index 4a7a51fc..f2542075 100644 --- a/scripts/generate_data_for_tests/get_test_data.py +++ b/scripts/generate_data_for_tests/get_test_data.py @@ -1,3 +1,4 @@ +"""Get test data.""" import io import os import time @@ -10,9 +11,7 @@ import xarray as xr import nowcasting_dataset -from nowcasting_dataset.config.model import Configuration from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWP_VARIABLE_NAMES, open_nwp -from nowcasting_dataset.dataset.batch import Batch # set up BUCKET = Path("solar-pv-nowcasting-data") @@ -100,13 +99,13 @@ # ### satellite -# s = SatelliteDataSource(filename="gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/" +# s = SatelliteDataSource(filename="gs://solar-pv-nowcasting-data/satellite/" +# "EUMETSAT/SEVIRI_RSS/OSGB36/" # "all_zarr_int16_single_timestep.zarr", # history_length=6, # forecast_length=12, # image_size_pixels=64, -# meters_per_pixel=2000, -# 
n_timesteps_per_batch=32) +# meters_per_pixel=2000) # # s.open() # start_dt = datetime.fromisoformat("2019-01-01 00:00:00.000+00:00") diff --git a/scripts/prepare_ml_data.py b/scripts/prepare_ml_data.py index 0d648530..cc3445c0 100755 --- a/scripts/prepare_ml_data.py +++ b/scripts/prepare_ml_data.py @@ -2,207 +2,73 @@ """Pre-prepares batches of data. -Usage: - -First, manually create the directories given by the constants -DST_TRAIN_PATH and DST_VALIDATION_PATH, and create the -LOCAL_TEMP_PATH. Note that all files will be deleted from -LOCAL_TEMP_PATH when this script starts up. - -Currently caluclating azimuth and elevation angles, takes about 15 mins for 2548 PV systems, -for about 1 year. - +Please run `./prepare_ml_data.py --help` for full details! """ import logging -import os -from pathlib import Path -from typing import Union -import neptune.new as neptune -import numpy as np -import torch -from neptune.new.integrations.python_logger import NeptuneHandler +import click from pathy import Pathy +# nowcasting_dataset imports import nowcasting_dataset -from nowcasting_dataset.config.load import load_yaml_configuration -from nowcasting_dataset.config.model import set_git_commit -from nowcasting_dataset.config.save import save_yaml_configuration -from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWP_VARIABLE_NAMES -from nowcasting_dataset.data_sources.satellite.satellite_data_source import SAT_VARIABLE_NAMES -from nowcasting_dataset.dataset.datamodule import NowcastingDataModule -from nowcasting_dataset.filesystem import utils -from nowcasting_dataset.filesystem.utils import check_path_exists - -logging.basicConfig(format="%(asctime)s %(levelname)s %(pathname)s %(lineno)d %(message)s") -_LOG = logging.getLogger("nowcasting_dataset") -_LOG.setLevel(logging.INFO) +from nowcasting_dataset import utils +from nowcasting_dataset.data_sources import ALL_DATA_SOURCE_NAMES +from nowcasting_dataset.manager import Manager +# Set up logging. +logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s at %(pathname)s#L%(lineno)d") logging.getLogger("nowcasting_dataset.data_source").setLevel(logging.WARNING) - -ENABLE_NEPTUNE_LOGGING = False - -# load configuration, this can be changed to a different filename as needed. -# TODO: Pass this in as a command-line argument. -# See https://github.com/openclimatefix/nowcasting_dataset/issues/171 -filename = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "config", "on_premises.yaml") -config = load_yaml_configuration(filename) -config = set_git_commit(config) - -# Solar PV data -PV_DATA_FILENAME = config.input_data.pv.pv_filename -PV_METADATA_FILENAME = config.input_data.pv.pv_metadata_filename - -# Satellite data -SAT_ZARR_PATH = config.input_data.satellite.satellite_zarr_path - -# Numerical weather predictions -NWP_ZARR_PATH = config.input_data.nwp.nwp_zarr_path - -# GSP data -GSP_ZARR_PATH = config.input_data.gsp.gsp_zarr_path - -# Topographic data -TOPO_TIFF_PATH = config.input_data.topographic.topographic_filename - -# Paths for output data. -DST_NETCDF4_PATH = Pathy(config.output_data.filepath) -DST_TRAIN_PATH = DST_NETCDF4_PATH / "train" -DST_VALIDATION_PATH = DST_NETCDF4_PATH / "validation" -DST_TEST_PATH = DST_NETCDF4_PATH / "test" -LOCAL_TEMP_PATH = Path(config.process.local_temp_path).expanduser() - -UPLOAD_EVERY_N_BATCHES = config.process.upload_every_n_batches - -# Necessary to avoid "RuntimeError: receieved 0 items of ancdata". 
See: -# https://discuss.pytorch.org/t/runtimeerror-received-0-items-of-ancdata/4999/2 -torch.multiprocessing.set_sharing_strategy("file_system") - -np.random.seed(config.process.seed) - - -def check_directories_exist(): - _LOG.info("Checking if all paths exist...") - for path in [ - PV_DATA_FILENAME, - PV_METADATA_FILENAME, - SAT_ZARR_PATH, - NWP_ZARR_PATH, - GSP_ZARR_PATH, - TOPO_TIFF_PATH, - DST_TRAIN_PATH, - DST_VALIDATION_PATH, - DST_TEST_PATH, - ]: - check_path_exists(path) - - if UPLOAD_EVERY_N_BATCHES > 0: - check_path_exists(LOCAL_TEMP_PATH) - _LOG.info("Success! All paths exist!") - - -def get_data_module(): - num_workers = 4 - - # get the batch id already made - maximum_batch_id_train = utils.get_maximum_batch_id(os.path.join(DST_TRAIN_PATH, "metadata")) - maximum_batch_id_validation = utils.get_maximum_batch_id( - os.path.join(DST_VALIDATION_PATH, "metadata") - ) - maximum_batch_id_test = utils.get_maximum_batch_id(os.path.join(DST_TEST_PATH, "metadata")) - - if maximum_batch_id_train is None: - maximum_batch_id_train = 0 - - if maximum_batch_id_validation is None: - maximum_batch_id_validation = 0 - - if maximum_batch_id_test is None: - maximum_batch_id_test = 0 - - data_module = NowcastingDataModule( - batch_size=config.process.batch_size, - history_minutes=config.input_data.default_history_minutes, #: Number of minutes of history, not including t0. - forecast_minutes=config.input_data.default_forecast_minutes, #: Number of minutes of forecast. - satellite_image_size_pixels=config.input_data.satellite.satellite_image_size_pixels, - nwp_image_size_pixels=config.input_data.nwp.nwp_image_size_pixels, - nwp_channels=NWP_VARIABLE_NAMES, - sat_channels=SAT_VARIABLE_NAMES, - pv_power_filename=PV_DATA_FILENAME, - pv_metadata_filename=PV_METADATA_FILENAME, - sat_filename=SAT_ZARR_PATH, - nwp_base_path=NWP_ZARR_PATH, - gsp_filename=GSP_ZARR_PATH, - topographic_filename=TOPO_TIFF_PATH, - sun_filename=config.input_data.sun.sun_zarr_path, - pin_memory=False, #: Passed to DataLoader. - num_workers=num_workers, #: Passed to DataLoader. - prefetch_factor=8, #: Passed to DataLoader. - n_samples_per_timestep=8, #: Passed to NowcastingDataset - n_training_batches_per_epoch=25_008, # Add pre-fetch factor! - n_validation_batches_per_epoch=1_008, - n_test_batches_per_epoch=1_008, - collate_fn=lambda x: x, - skip_n_train_batches=maximum_batch_id_train // num_workers, - skip_n_validation_batches=maximum_batch_id_validation // num_workers, - skip_n_test_batches=maximum_batch_id_test // num_workers, - seed=config.process.seed, - ) - _LOG.info("prepare_data()") - data_module.prepare_data() - _LOG.info("setup()") - data_module.setup() - return data_module - - -def iterate_over_dataloader_and_write_to_disk( - dataloader: torch.utils.data.DataLoader, dst_path: Union[Pathy, Path] -): - _LOG.info("Getting first batch") - if UPLOAD_EVERY_N_BATCHES > 0: - local_output_path = LOCAL_TEMP_PATH - else: - local_output_path = dst_path - - for batch_i, batch in enumerate(dataloader): - _LOG.info(f"Got batch {batch_i}") - - batch.save_netcdf(batch_i=batch_i, path=local_output_path) - - if UPLOAD_EVERY_N_BATCHES > 0 and batch_i > 0 and batch_i % UPLOAD_EVERY_N_BATCHES == 0: - utils.upload_and_delete_local_files(dst_path, LOCAL_TEMP_PATH) - - # Make sure we upload the last few batches, if necessary. 
- if UPLOAD_EVERY_N_BATCHES > 0: - utils.upload_and_delete_local_files(dst_path, LOCAL_TEMP_PATH) - - -def main(): - if ENABLE_NEPTUNE_LOGGING: - run = neptune.init( - project="OpenClimateFix/nowcasting-data", - capture_stdout=True, - capture_stderr=True, - capture_hardware_metrics=False, - ) - _LOG.addHandler(NeptuneHandler(run=run)) - - check_directories_exist() - if UPLOAD_EVERY_N_BATCHES > 0: - utils.delete_all_files_in_temp_path(path=LOCAL_TEMP_PATH) - - datamodule = get_data_module() - - _LOG.info("Finished preparing datamodule!") - _LOG.info("Preparing training data...") - iterate_over_dataloader_and_write_to_disk(datamodule.train_dataloader(), DST_TRAIN_PATH) - _LOG.info("Preparing validation data...") - iterate_over_dataloader_and_write_to_disk(datamodule.val_dataloader(), DST_VALIDATION_PATH) - _LOG.info("Preparing test data...") - iterate_over_dataloader_and_write_to_disk(datamodule.test_dataloader(), DST_TEST_PATH) - _LOG.info("Done!") - - save_yaml_configuration(config) +logger = logging.getLogger("nowcasting_dataset") +logger.setLevel(logging.DEBUG) + +default_config_filename = Pathy(nowcasting_dataset.__file__).parent / "config" / "on_premises.yaml" + + +@click.command() +@click.option( + "--config_filename", + default=default_config_filename, + help="The filename of the YAML configuration file.", +) +@click.option( + "--data_source", + multiple=True, + default=ALL_DATA_SOURCE_NAMES, + type=click.Choice(ALL_DATA_SOURCE_NAMES), + help=( + "If you want to process just a subset of the DataSources in the config file, then enter" + " the names of those DataSources using the --data_source option. Enter one name per" + " --data_source option. You can use --data_source multiple times. For example:" + " --data_source nwp --data_source satellite. Note that only these DataSources" + " will be used when computing the available datetime periods across all the" + " DataSources, so be very careful about setting --data_source when creating the" + " spatial_and_temporal_locations_of_each_example.csv files!" + ), +) +@click.option( + "--overwrite_batches", + default=False, + is_flag=True, + help=( + "Overwrite any existing batches in the destination directory, for the selected" + " DataSource(s). If this flag is not set, and if there are existing batches," + " then this script will start generating new batches (if necessary) after the" + " existing batches." + ), +) +@utils.arg_logger +def main(config_filename: str, data_source: list[str], overwrite_batches: bool): + """Generate pre-prepared batches of data.""" + manager = Manager() + manager.load_yaml_configuration(config_filename) + manager.initialise_data_sources(names_of_selected_data_sources=data_source) + # TODO: Issue #323: maybe don't allow + # create_files_specifying_spatial_and_temporal_locations_of_each_example to be run if a subset + # of data_sources is passed in at the command line. + manager.create_files_specifying_spatial_and_temporal_locations_of_each_example_if_necessary() + manager.create_batches(overwrite_batches) + # TODO: Issue #316: save_yaml_configuration(config) + # TODO: Issue #317: Validate ML data.
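+ # For reference, an illustrative invocation (the config path shown is simply the packaged + # default; any other YAML configuration file can be passed instead): + #   ./prepare_ml_data.py --config_filename nowcasting_dataset/config/on_premises.yaml --data_source gsp --data_source satellite --overwrite_batches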
+ logger.info("Done!") if __name__ == "__main__": diff --git a/setup.py b/setup.py index 825266e6..e37c1bad 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ author_email="info@openclimatefix.org", company="Open Climate Fix Ltd", install_requires=install_requires, - extras_require={"torch": ["torch", "pytorch_lightning"]}, + extras_require={"torch": ["torch"]}, long_description=long_description, long_description_content_type="text/markdown", package_data={"config": ["nowcasting_dataset/config/*.yaml"]}, diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 15571b93..eaa7032a 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -1,3 +1,4 @@ +"""Test config.""" import os import tempfile from datetime import datetime @@ -30,6 +31,8 @@ def test_yaml_load(): assert isinstance(config, Configuration) +# TODO: Issue #316: Remove save_yaml_configuration() and this test. +@pytest.mark.skip("This test will be removed when issue #316 is implemented") def test_yaml_save(): """ Check a configuration can be saved to a yaml file @@ -88,7 +91,7 @@ def test_load_to_gcs(): Check that configuration can be loaded to gcs """ config = load_yaml_configuration( - filename="gs://solar-pv-nowcasting-data/prepared_ML_training_data/v-default/configuration.yaml" + filename="gs://solar-pv-nowcasting-data/prepared_ML_training_data/v-default/configuration.yaml" # noqa: E501 ) assert isinstance(config, Configuration) diff --git a/tests/data_sources/gsp/test_gsp_model.py b/tests/data_sources/gsp/test_gsp_model.py index fd308ba4..b0d5241c 100644 --- a/tests/data_sources/gsp/test_gsp_model.py +++ b/tests/data_sources/gsp/test_gsp_model.py @@ -1,3 +1,4 @@ +"""Test GSP.""" import os import tempfile @@ -8,11 +9,11 @@ from nowcasting_dataset.data_sources.gsp.gsp_model import GSP -def test_gsp_init(): +def test_gsp_init(): # noqa: D103 _ = gsp_fake(batch_size=4, seq_length_30=5, n_gsp_per_batch=6) -def test_gsp_validation(): +def test_gsp_validation(): # noqa: D103 gsp = gsp_fake(batch_size=4, seq_length_30=5, n_gsp_per_batch=6) GSP.model_validation(gsp) @@ -22,10 +23,10 @@ def test_gsp_validation(): GSP.model_validation(gsp) -def test_gsp_save(): +def test_gsp_save(): # noqa: D103 with tempfile.TemporaryDirectory() as dirpath: gsp = gsp_fake(batch_size=4, seq_length_30=5, n_gsp_per_batch=6) gsp.save_netcdf(path=dirpath, batch_i=0) - assert os.path.exists(f"{dirpath}/gsp/0.nc") + assert os.path.exists(f"{dirpath}/gsp/000000.nc") diff --git a/tests/data_sources/satellite/test_satellite_data_source.py b/tests/data_sources/satellite/test_satellite_data_source.py index 86e20798..58d5bf2f 100644 --- a/tests/data_sources/satellite/test_satellite_data_source.py +++ b/tests/data_sources/satellite/test_satellite_data_source.py @@ -1,18 +1,19 @@ +"""Test SatelliteDataSource.""" import numpy as np import pandas as pd import pytest -def test_satellite_data_source_init(sat_data_source): +def test_satellite_data_source_init(sat_data_source): # noqa: D103 pass -def test_open(sat_data_source): +def test_open(sat_data_source): # noqa: D103 sat_data_source.open() assert sat_data_source.data is not None -def test_datetime_index(sat_data_source): +def test_datetime_index(sat_data_source): # noqa: D103 datetimes = sat_data_source.datetime_index() assert isinstance(datetimes, pd.DatetimeIndex) assert len(datetimes) > 0 @@ -34,7 +35,7 @@ def test_datetime_index(sat_data_source): (2001, 2001, -124_000, 130_000, 130_000, -124_000), ], ) -def test_get_example(sat_data_source, x, y, left, right, top, 
bottom): +def test_get_example(sat_data_source, x, y, left, right, top, bottom): # noqa: D103 sat_data_source.open() t0_dt = pd.Timestamp("2019-01-01T13:00") sat_data = sat_data_source.get_example(t0_dt=t0_dt, x_meters_center=x, y_meters_center=y) @@ -48,7 +49,7 @@ def test_get_example(sat_data_source, x, y, left, right, top, bottom): assert len(sat_data.y) == pytest.IMAGE_SIZE_PIXELS -def test_geospatial_border(sat_data_source): +def test_geospatial_border(sat_data_source): # noqa: D103 border = sat_data_source.geospatial_border() correct_border = [(-110000, 1094000), (-110000, -58000), (730000, 1094000), (730000, -58000)] np.testing.assert_array_equal(border, correct_border) diff --git a/tests/data_sources/satellite/test_satellite_model.py b/tests/data_sources/satellite/test_satellite_model.py index df2e9e88..83953f1d 100644 --- a/tests/data_sources/satellite/test_satellite_model.py +++ b/tests/data_sources/satellite/test_satellite_model.py @@ -1,3 +1,4 @@ +"""Test Satellite model.""" import os import tempfile @@ -8,11 +9,11 @@ from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite -def test_satellite_init(): +def test_satellite_init(): # noqa: D103 _ = satellite_fake() -def test_satellite_validation(): +def test_satellite_validation(): # noqa: D103 sat = satellite_fake() Satellite.model_validation(sat) @@ -22,9 +23,9 @@ def test_satellite_validation(): Satellite.model_validation(sat) -def test_satellite_save(): +def test_satellite_save(): # noqa: D103 with tempfile.TemporaryDirectory() as dirpath: satellite_fake().save_netcdf(path=dirpath, batch_i=0) - assert os.path.exists(f"{dirpath}/satellite/0.nc") + assert os.path.exists(f"{dirpath}/satellite/000000.nc") diff --git a/tests/data_sources/sun/test_sun_model.py b/tests/data_sources/sun/test_sun_model.py index bcd60444..a6a0bc94 100644 --- a/tests/data_sources/sun/test_sun_model.py +++ b/tests/data_sources/sun/test_sun_model.py @@ -1,3 +1,4 @@ +# noqa: D100 import os import tempfile @@ -8,11 +9,11 @@ from nowcasting_dataset.data_sources.sun.sun_model import Sun -def test_sun_init(): +def test_sun_init(): # noqa: D103 _ = sun_fake(batch_size=4, seq_length_5=17) -def test_sun_validation(): +def test_sun_validation(): # noqa: D103 sun = sun_fake(batch_size=4, seq_length_5=17) Sun.model_validation(sun) @@ -22,7 +23,7 @@ def test_sun_validation(): Sun.model_validation(sun) -def test_sun_validation_elevation(): +def test_sun_validation_elevation(): # noqa: D103 sun = sun_fake(batch_size=4, seq_length_5=17) Sun.model_validation(sun) @@ -32,7 +33,7 @@ def test_sun_validation_elevation(): Sun.model_validation(sun) -def test_sun_validation_azimuth(): +def test_sun_validation_azimuth(): # noqa: D103 sun = sun_fake(batch_size=4, seq_length_5=17) Sun.model_validation(sun) @@ -42,10 +43,9 @@ def test_sun_validation_azimuth(): Sun.model_validation(sun) -def test_sun_save(): - +def test_sun_save(): # noqa: D103 with tempfile.TemporaryDirectory() as dirpath: sun = sun_fake(batch_size=4, seq_length_5=17) sun.save_netcdf(path=dirpath, batch_i=0) - assert os.path.exists(f"{dirpath}/sun/0.nc") + assert os.path.exists(f"{dirpath}/sun/000000.nc") diff --git a/tests/data_sources/test_data_source_list.py b/tests/data_sources/test_data_source_list.py deleted file mode 100644 index 1ed6d35a..00000000 --- a/tests/data_sources/test_data_source_list.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -from datetime import datetime - -import nowcasting_dataset -from nowcasting_dataset.data_sources.data_source_list import DataSourceList -from 
nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource -import nowcasting_dataset.utils as nd_utils - - -def test_sample_spatial_and_temporal_locations_for_examples(): - local_path = os.path.dirname(nowcasting_dataset.__file__) + "/.." - - gsp = GSPDataSource( - zarr_path=f"{local_path}/tests/data/gsp/test.zarr", - start_dt=datetime(2019, 1, 1), - end_dt=datetime(2019, 1, 2), - history_minutes=30, - forecast_minutes=60, - image_size_pixels=64, - meters_per_pixel=2000, - ) - - data_source_list = DataSourceList([gsp]) - t0_datetimes = data_source_list.get_t0_datetimes_across_all_data_sources(freq="30T") - locations = data_source_list.sample_spatial_and_temporal_locations_for_examples( - t0_datetimes=t0_datetimes, n_examples=10 - ) - - assert locations.columns.to_list() == ["t0_datetime_UTC", "x_center_OSGB", "y_center_OSGB"] - assert len(locations) == 10 - - -def test_from_config(): - config = nd_utils.get_config_with_test_paths("test.yaml") - data_source_list = DataSourceList.from_config(config.input_data) - assert len(data_source_list) == 6 - assert isinstance( - data_source_list.data_source_which_defines_geospatial_locations, GSPDataSource - ) diff --git a/tests/data_sources/test_nwp_data_source.py b/tests/data_sources/test_nwp_data_source.py index 7ed8b695..fe7d785d 100644 --- a/tests/data_sources/test_nwp_data_source.py +++ b/tests/data_sources/test_nwp_data_source.py @@ -1,3 +1,4 @@ +# noqa: D100 import os import pandas as pd @@ -11,39 +12,36 @@ NWP_ZARR_PATH = f"{PATH}/../tests/data/nwp_data/test.zarr" -def test_nwp_data_source_init(): +def test_nwp_data_source_init(): # noqa: D103 _ = NWPDataSource( zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, - n_timesteps_per_batch=8, ) -def test_nwp_data_source_open(): +def test_nwp_data_source_open(): # noqa: D103 nwp = NWPDataSource( zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, - n_timesteps_per_batch=8, channels=["t"], ) nwp.open() -def test_nwp_data_source_batch(): +def test_nwp_data_source_batch(): # noqa: D103 nwp = NWPDataSource( zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, - n_timesteps_per_batch=8, channels=["t"], ) nwp.open() - t0_datetimes = nwp._data.init_time[2:10].values + t0_datetimes = nwp._data.init_time[2:6].values x = nwp._data.x[0:4].values y = nwp._data.y[0:4].values @@ -52,12 +50,11 @@ def test_nwp_data_source_batch(): assert batch.data.shape == (4, 1, 19, 2, 2) -def test_nwp_get_contiguous_time_periods(): +def test_nwp_get_contiguous_time_periods(): # noqa: D103 nwp = NWPDataSource( zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, - n_timesteps_per_batch=8, channels=["t"], ) @@ -68,12 +65,11 @@ def test_nwp_get_contiguous_time_periods(): pd.testing.assert_frame_equal(contiguous_time_periods, correct_time_periods) -def test_nwp_get_contiguous_t0_time_periods(): +def test_nwp_get_contiguous_t0_time_periods(): # noqa: D103 nwp = NWPDataSource( zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, - n_timesteps_per_batch=8, channels=["t"], ) diff --git a/tests/data_sources/test_pv_data_source.py b/tests/data_sources/test_pv_data_source.py index 55904999..b2c818e2 100644 --- a/tests/data_sources/test_pv_data_source.py +++ b/tests/data_sources/test_pv_data_source.py @@ -1,3 +1,4 @@ +"""Test PVDataSource.""" import logging import os from datetime import datetime @@ -13,7 +14,7 @@ logger = logging.getLogger(__name__) -def test_get_example_and_batch(): +def test_get_example_and_batch(): # noqa: D103 path = 
os.path.dirname(nowcasting_dataset.__file__) @@ -39,12 +40,12 @@ def test_get_example_and_batch(): _ = pv_data_source.get_example(pv_data_source.pv_power.index[0], x_locations[0], y_locations[0]) batch = pv_data_source.get_batch( - pv_data_source.pv_power.index[6:11], x_locations[0:10], y_locations[0:10] + pv_data_source.pv_power.index[6:16], x_locations[0:10], y_locations[0:10] ) - assert batch.data.shape == (5, 19, 128) + assert batch.data.shape == (10, 19, 128) -def test_drop_pv_systems_which_produce_overnight(): +def test_drop_pv_systems_which_produce_overnight(): # noqa: D103 pv_power = pd.DataFrame(index=pd.date_range("2010-01-01", "2010-01-02", freq="5 min")) _ = drop_pv_systems_which_produce_overnight(pv_power=pv_power) diff --git a/tests/dataset/test_batch.py b/tests/dataset/test_batch.py index 5f771f22..dcd9e38d 100644 --- a/tests/dataset/test_batch.py +++ b/tests/dataset/test_batch.py @@ -1,3 +1,4 @@ +"""Test Batch.""" import os import tempfile @@ -8,25 +9,25 @@ @pytest.fixture -def configuration(): +def configuration(): # noqa: D103 con = Configuration() con.input_data = InputData.set_all_to_defaults() con.process.batch_size = 4 return con -def test_model(configuration): +def test_model(configuration): # noqa: D103 _ = Batch.fake(configuration=configuration) -def test_model_save_to_netcdf(configuration): +def test_model_save_to_netcdf(configuration): # noqa: D103 with tempfile.TemporaryDirectory() as dirpath: Batch.fake(configuration=configuration).save_netcdf(path=dirpath, batch_i=0) - assert os.path.exists(f"{dirpath}/satellite/0.nc") + assert os.path.exists(f"{dirpath}/satellite/000000.nc") -def test_model_load_from_netcdf(configuration): +def test_model_load_from_netcdf(configuration): # noqa: D103 with tempfile.TemporaryDirectory() as dirpath: Batch.fake(configuration=configuration).save_netcdf(path=dirpath, batch_i=0) diff --git a/tests/filesystem/test_local.py b/tests/filesystem/test_local.py index 0f514cc3..642a5655 100644 --- a/tests/filesystem/test_local.py +++ b/tests/filesystem/test_local.py @@ -1,3 +1,4 @@ +# noqa: D100 import os import tempfile from pathlib import Path @@ -7,12 +8,12 @@ delete_all_files_in_temp_path, download_to_local, get_all_filenames_in_path, - make_folder, + makedirs, upload_one_file, ) -def test_check_file_exists(): +def test_check_file_exists(): # noqa: D103 file1 = "test_file1.txt" file2 = "test_dir/test_file2.txt" @@ -27,7 +28,7 @@ def test_check_file_exists(): # add fake file to dir os.mkdir(f"{tmpdirname}/test_dir") - path_and_filename_2 = os.path.join(local_path, file2) + _ = os.path.join(local_path, file2) with open(os.path.join(local_path, file2), "w"): pass @@ -35,7 +36,7 @@ def test_check_file_exists(): check_path_exists(path=f"{tmpdirname}/test_dir") -def test_make_folder(): +def test_makedirs(): # noqa: D103 folder_1 = "test_dir_1" folder_2 = "test_dir_2" @@ -48,7 +49,7 @@ def test_make_folder(): folder_2 = os.path.join(local_path, folder_2) # use the make folder function - make_folder(folder_1) + makedirs(folder_1) check_path_exists(path=folder_1) # make a folder @@ -58,7 +59,7 @@ def test_make_folder(): check_path_exists(path=folder_2) -def test_delete_local_files(): +def test_delete_local_files(): # noqa: D103 file1 = "test_file1.txt" folder1 = "test_dir" @@ -88,7 +89,7 @@ def test_delete_local_files(): assert os.path.exists(path_and_folder_1) -def test_delete_local_files_and_folder(): +def test_delete_local_files_and_folder(): # noqa: D103 file1 = "test_file1.txt" folder1 = "test_dir" @@ -118,7 +119,7 @@ def 
test_delete_local_files_and_folder(): assert not os.path.exists(path_and_folder_1) -def test_download(): +def test_download(): # noqa: D103 file1 = "test_file1.txt" file2 = "test_dir/test_file2.txt" @@ -128,18 +129,18 @@ def test_download(): local_path = Path(tmpdirname) # add fake file to dir - path_and_filename_1 = os.path.join(local_path, file1) + path_and_filename_1 = local_path / file1 with open(path_and_filename_1, "w"): pass # add fake file to dir - os.mkdir(f"{tmpdirname}/test_dir") - path_and_filename_2 = os.path.join(local_path, file2) - with open(os.path.join(local_path, file2), "w"): + os.mkdir(local_path / "test_dir") + path_and_filename_2 = local_path / file2 + with open(path_and_filename_2, "w"): pass # run function - path_and_filename_3 = os.path.join(local_path, file3) + path_and_filename_3 = local_path / file3 download_to_local(remote_filename=path_and_filename_1, local_filename=path_and_filename_3) # check the object are not there @@ -147,7 +148,7 @@ def test_download(): assert len(filenames) == 3 -def test_upload(): +def test_upload(): # noqa: D103 file1 = "test_file1.txt" file2 = "test_dir/test_file2.txt" @@ -163,7 +164,7 @@ def test_upload(): # add fake file to dir os.mkdir(f"{tmpdirname}/test_dir") - path_and_filename_2 = os.path.join(local_path, file2) + _ = os.path.join(local_path, file2) with open(os.path.join(local_path, file2), "w"): pass diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py deleted file mode 100644 index 5e8637f1..00000000 --- a/tests/test_datamodule.py +++ /dev/null @@ -1,155 +0,0 @@ -import logging -import os -from pathlib import Path - -import numpy as np -import pandas as pd -import pytest - -import nowcasting_dataset -from nowcasting_dataset.config.load import load_yaml_configuration -from nowcasting_dataset.dataset import datamodule -from nowcasting_dataset.dataset.batch import Batch -from nowcasting_dataset.dataset.datamodule import NowcastingDataModule -from nowcasting_dataset.dataset.split.split import SplitMethod -import nowcasting_dataset.utils as nd_utils - -logging.basicConfig(format="%(asctime)s %(levelname)s %(pathname)s %(lineno)d %(message)s") -_LOG = logging.getLogger("nowcasting_dataset") -_LOG.setLevel(logging.DEBUG) - - -@pytest.fixture -def nowcasting_datamodule(sat_filename: Path): - return datamodule.NowcastingDataModule(sat_filename=sat_filename) - - -def test_prepare_data(nowcasting_datamodule: datamodule.NowcastingDataModule): - nowcasting_datamodule.prepare_data() - - -def test_get_daylight_datetime_index( - nowcasting_datamodule: datamodule.NowcastingDataModule, use_cloud_data: bool -): - nowcasting_datamodule.prepare_data() - nowcasting_datamodule.t0_datetime_freq = "5T" - t0_datetimes = nowcasting_datamodule._get_t0_datetimes_across_all_data_sources() - assert isinstance(t0_datetimes, pd.DatetimeIndex) - if not use_cloud_data: - # The testing sat_data.zarr has contiguous data from 12:05 to 18:00. - # nowcasting_datamodule.history_minutes = 30 - # nowcasting_datamodule.forecast_minutes = 60 - # Daylight ends at 16:20. 
- # So the expected t0_datetimes start at 12:35 (12:05 + 30 minutes) - # and end at 15:20 (16:20 - 60 minutes) - print(t0_datetimes) - correct_t0_datetimes = pd.date_range("2019-01-01 12:35", "2019-01-01 15:20", freq="5 min") - np.testing.assert_array_equal(t0_datetimes, correct_t0_datetimes) - - -def test_setup(nowcasting_datamodule: datamodule.NowcastingDataModule): - # Check it throws RuntimeError if we try running - # setup() before running prepare_data(): - with pytest.raises(RuntimeError): - nowcasting_datamodule.setup() - nowcasting_datamodule.prepare_data() - nowcasting_datamodule.setup() - - -@pytest.mark.parametrize("config_filename", ["test.yaml", "nwp_size_test.yaml"]) -def test_data_module(config_filename): - - # load configuration, this can be changed to a different filename as needed - config = nd_utils.get_config_with_test_paths(config_filename) - - data_module = NowcastingDataModule( - batch_size=config.process.batch_size, - history_minutes=30, #: Number of timesteps of history, not including t0. - forecast_minutes=60, #: Number of timesteps of forecast. - satellite_image_size_pixels=config.input_data.satellite.satellite_image_size_pixels, - nwp_image_size_pixels=config.input_data.nwp.nwp_image_size_pixels, - nwp_channels=config.input_data.nwp.nwp_channels[0:1], - sat_channels=config.input_data.satellite.satellite_channels, # reduced for test data - pv_power_filename=config.input_data.pv.pv_filename, - pv_metadata_filename=config.input_data.pv.pv_metadata_filename, - sat_filename=config.input_data.satellite.satellite_zarr_path, - nwp_base_path=config.input_data.nwp.nwp_zarr_path, - gsp_filename=config.input_data.gsp.gsp_zarr_path, - topographic_filename=config.input_data.topographic.topographic_filename, - sun_filename=config.input_data.sun.sun_zarr_path, - pin_memory=True, #: Passed to DataLoader. - num_workers=0, #: Passed to DataLoader. - prefetch_factor=8, #: Passed to DataLoader. - n_samples_per_timestep=16, #: Passed to NowcastingDataset - n_training_batches_per_epoch=200, # Add pre-fetch factor! - n_validation_batches_per_epoch=200, - collate_fn=lambda x: x, - skip_n_train_batches=0, - skip_n_validation_batches=0, - train_validation_percentage_split=50, - pv_load_azimuth_and_elevation=True, - split_method=SplitMethod.SAME, - ) - - _LOG.info("prepare_data()") - data_module.prepare_data() - _LOG.info("setup()") - data_module.setup() - - data_generator = iter(data_module.train_dataset) - batch = next(data_generator) - - assert batch.batch_size == config.process.batch_size - assert type(batch) == Batch - - assert batch.satellite is not None - assert batch.nwp is not None - assert batch.sun is not None - assert batch.topographic is not None - assert batch.pv is not None - assert batch.gsp is not None - assert batch.metadata is not None - assert batch.datetime is not None - - -def test_batch_to_batch_to_dataset(): - config = nd_utils.get_config_with_test_paths("test.yaml") - - data_module = NowcastingDataModule( - batch_size=config.process.batch_size, - history_minutes=30, #: Number of timesteps of history, not including t0. - forecast_minutes=60, #: Number of timesteps of forecast. 
- satellite_image_size_pixels=config.input_data.satellite.satellite_image_size_pixels, - nwp_image_size_pixels=config.input_data.nwp.nwp_image_size_pixels, - nwp_channels=config.input_data.nwp.nwp_channels[0:1], - sat_channels=config.input_data.satellite.satellite_channels, # reduced for test data - pv_power_filename=config.input_data.pv.pv_filename, - pv_metadata_filename=config.input_data.pv.pv_metadata_filename, - sat_filename=config.input_data.satellite.satellite_zarr_path, - nwp_base_path=config.input_data.nwp.nwp_zarr_path, - gsp_filename=config.input_data.gsp.gsp_zarr_path, - topographic_filename=config.input_data.topographic.topographic_filename, - sun_filename=config.input_data.sun.sun_zarr_path, - pin_memory=True, #: Passed to DataLoader. - num_workers=0, #: Passed to DataLoader. - prefetch_factor=8, #: Passed to DataLoader. - n_samples_per_timestep=16, #: Passed to NowcastingDataset - n_training_batches_per_epoch=200, # Add pre-fetch factor! - n_validation_batches_per_epoch=200, - collate_fn=lambda x: x, - skip_n_train_batches=0, - skip_n_validation_batches=0, - train_validation_percentage_split=50, - pv_load_azimuth_and_elevation=False, - split_method=SplitMethod.SAME, - ) - - _LOG.info("prepare_data()") - data_module.prepare_data() - _LOG.info("setup()") - data_module.setup() - - data_generator = iter(data_module.train_dataset) - batch = next(data_generator) - - assert type(batch) == Batch diff --git a/tests/test_dataset.py b/tests/test_dataset.py deleted file mode 100644 index 31e56041..00000000 --- a/tests/test_dataset.py +++ /dev/null @@ -1,65 +0,0 @@ -import numpy as np -import pandas as pd -import pytest - -import nowcasting_dataset.time as nd_time -from nowcasting_dataset.dataset.batch import Batch -from nowcasting_dataset.dataset.datasets import NowcastingDataset - - -def _get_t0_datetimes(data_source, freq) -> pd.DatetimeIndex: - t0_periods = data_source.get_contiguous_t0_time_periods() - t0_datetimes = nd_time.time_periods_to_datetime_index(t0_periods, freq=freq) - return t0_datetimes - - -@pytest.fixture -def dataset(sat_data_source, general_data_source): - t0_datetimes = _get_t0_datetimes(sat_data_source, freq="5T") - - return NowcastingDataset( - batch_size=8, - n_batches_per_epoch_per_worker=64, - n_samples_per_timestep=2, - data_sources=[sat_data_source, general_data_source], - t0_datetimes=t0_datetimes, - ) - - -@pytest.fixture -def dataset_gsp(gsp_data_source, general_data_source): - t0_datetimes = _get_t0_datetimes(gsp_data_source, freq="30T") - - return NowcastingDataset( - batch_size=8, - n_batches_per_epoch_per_worker=64, - n_samples_per_timestep=2, - data_sources=[gsp_data_source, general_data_source], - t0_datetimes=t0_datetimes, - ) - - -def test_post_init(dataset: NowcastingDataset): - assert dataset._n_timesteps_per_batch == 4 - assert not dataset._per_worker_init_has_run - - -def test_per_worker_init(dataset: NowcastingDataset): - WORKER_ID = 1 - dataset.per_worker_init(worker_id=WORKER_ID) - assert isinstance(dataset.rng, np.random.Generator) - assert dataset.worker_id == WORKER_ID - - -def test_get_batch(dataset: NowcastingDataset): - dataset.per_worker_init(worker_id=1) - with pytest.raises(NotImplementedError): - _ = dataset._get_batch() - - -def test_get_batch_gsp(dataset_gsp: NowcastingDataset): - dataset_gsp.per_worker_init(worker_id=1) - batch = dataset_gsp._get_batch() - assert isinstance(batch, Batch) - - assert batch.gsp is not None diff --git a/tests/test_manager.py b/tests/test_manager.py new file mode 100644 index 00000000..751bbeb0 --- 
/dev/null +++ b/tests/test_manager.py @@ -0,0 +1,47 @@ +"""Test Manager.""" +from datetime import datetime +from pathlib import Path + +import nowcasting_dataset +from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource +from nowcasting_dataset.manager import Manager + + +def test_sample_spatial_and_temporal_locations_for_examples(): # noqa: D103 + local_path = Path(nowcasting_dataset.__file__).parent.parent + + gsp = GSPDataSource( + zarr_path=f"{local_path}/tests/data/gsp/test.zarr", + start_dt=datetime(2019, 1, 1), + end_dt=datetime(2019, 1, 2), + history_minutes=30, + forecast_minutes=60, + image_size_pixels=64, + meters_per_pixel=2000, + ) + + manager = Manager() + manager.data_sources = {"gsp": gsp} + manager.data_source_which_defines_geospatial_locations = gsp + t0_datetimes = manager.get_t0_datetimes_across_all_data_sources(freq="30T") + locations = manager.sample_spatial_and_temporal_locations_for_examples( + t0_datetimes=t0_datetimes, n_examples=10 + ) + + assert locations.columns.to_list() == ["t0_datetime_UTC", "x_center_OSGB", "y_center_OSGB"] + assert len(locations) == 10 + assert (t0_datetimes[0] <= locations["t0_datetime_UTC"]).all() + assert (t0_datetimes[-1] >= locations["t0_datetime_UTC"]).all() + + +def test_load_yaml_configuration(): # noqa: D103 + manager = Manager() + local_path = Path(nowcasting_dataset.__file__).parent.parent + filename = local_path / "tests" / "config" / "test.yaml" + manager.load_yaml_configuration(filename=filename) + manager.initialise_data_sources() + assert len(manager.data_sources) == 6 + assert isinstance(manager.data_source_which_defines_geospatial_locations, GSPDataSource) + + +# TODO: Issue #322: Test the other Manager methods! diff --git a/tests/test_utils.py b/tests/test_utils.py index eb3a9564..e86510b0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,4 @@ +# noqa: D100 import numpy as np import pandas as pd import pytest @@ -5,7 +6,7 @@ from nowcasting_dataset import utils -def test_is_monotically_increasing(): +def test_is_monotically_increasing(): # noqa: D103 assert utils.is_monotonically_increasing([1, 2, 3, 4]) assert not utils.is_monotonically_increasing([1, 2, 3, 3]) assert not utils.is_monotonically_increasing([1, 2, 3, 0]) @@ -15,7 +16,7 @@ def test_is_monotically_increasing(): assert not utils.is_monotonically_increasing(index[::-1]) -def test_sin_and_cos(): +def test_sin_and_cos(): # noqa: D103 df = pd.DataFrame({"a": range(30), "b": np.arange(30) - 30}) with pytest.raises(ValueError) as _: utils.sin_and_cos(pd.DataFrame({"a": [-1, 0, 1]})) @@ -29,12 +30,11 @@ def test_sin_and_cos(): np.testing.assert_array_equal(sin_and_cos.columns, ["a_sin", "a_cos", "b_sin", "b_cos"]) -def test_get_netcdf_filename(): - assert utils.get_netcdf_filename(10) == "10.nc" - assert utils.get_netcdf_filename(10, add_hash=True) == "77eb6f_10.nc" +def test_get_netcdf_filename(): # noqa: D103 + assert utils.get_netcdf_filename(10) == "000010.nc" -def test_remove_regex_pattern_from_keys(): +def test_remove_regex_pattern_from_keys(): # noqa: D103 d = { "satellite_zarr_path": "/a/b/c/foo.zarr", "bar": "baz",