diff --git a/nowcasting_dataset/dataset/split/split.py b/nowcasting_dataset/dataset/split/split.py index 4f1e134b..f873728e 100644 --- a/nowcasting_dataset/dataset/split/split.py +++ b/nowcasting_dataset/dataset/split/split.py @@ -200,6 +200,10 @@ def split_data( logger.debug("Split data done!") for split_name, dt in split_datetimes._asdict().items(): - logger.debug(f"{split_name} has {len(dt):,d} datetimes, from {dt[0]} to {dt[-1]}") + if len(dt) == 0: + # only a warning is made as this may happen during unittests + logger.warning(f"{split_name} has {len(dt):,d} datetimes") + else: + logger.debug(f"{split_name} has {len(dt):,d} datetimes, from {dt[0]} to {dt[-1]}") return split_datetimes diff --git a/nowcasting_dataset/filesystem/utils.py b/nowcasting_dataset/filesystem/utils.py index 4f9feed9..25599a8c 100644 --- a/nowcasting_dataset/filesystem/utils.py +++ b/nowcasting_dataset/filesystem/utils.py @@ -10,13 +10,13 @@ _LOG = logging.getLogger("nowcasting_dataset") -def upload_and_delete_local_files(dst_path: str, local_path: Path): +def upload_and_delete_local_files(dst_path: Union[str, Path], local_path: Union[str, Path]): """ Upload an entire folder and delete local files to either AWS or GCP """ _LOG.info("Uploading!") filesystem = get_filesystem(dst_path) - filesystem.put(str(local_path), dst_path, recursive=True) + filesystem.copy(str(local_path), str(dst_path), recursive=True) delete_all_files_in_temp_path(local_path) diff --git a/nowcasting_dataset/manager.py b/nowcasting_dataset/manager.py index 9383a73e..5a180fed 100644 --- a/nowcasting_dataset/manager.py +++ b/nowcasting_dataset/manager.py @@ -369,6 +369,10 @@ def create_batches(self, overwrite_batches: bool) -> None: for worker_id, (data_source_name, data_source) in enumerate( self.data_sources.items() ): + + if len(locations_for_split) == 0: + break + # Get indexes of first batch and example. And subset locations_for_split. 
idx_of_first_batch = first_batches_to_create[split_name][data_source_name] idx_of_first_example = idx_of_first_batch * self.config.process.batch_size diff --git a/tests/config/test.yaml b/tests/config/test.yaml index 7cfc3153..37f846cc 100644 --- a/tests/config/test.yaml +++ b/tests/config/test.yaml @@ -30,3 +30,6 @@ process: local_temp_path: ~/temp/ seed: 1234 upload_every_n_batches: 16 + n_train_batches: 2 + n_validation_batches: 0 + n_test_batches: 0 diff --git a/tests/test_manager.py b/tests/test_manager.py index 81daf75e..0056ee2e 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,4 +1,6 @@ """Test Manager.""" +import os +import tempfile from datetime import datetime from pathlib import Path @@ -76,4 +78,81 @@ def test_get_daylight_datetime_index(): np.testing.assert_array_equal(t0_datetimes, correct_t0_datetimes) -# TODO: Issue #322: Test the other Manager methods! +def test_batches(): + """Test that batches can be made""" + filename = Path(nowcasting_dataset.__file__).parent.parent / "tests" / "data" / "sat_data.zarr" + + sat = SatelliteDataSource( + zarr_path=filename, + history_minutes=30, + forecast_minutes=60, + image_size_pixels=64, + meters_per_pixel=2000, + channels=("HRV",), + ) + + filename = ( + Path(nowcasting_dataset.__file__).parent.parent / "tests" / "data" / "gsp" / "test.zarr" + ) + + gsp = GSPDataSource( + zarr_path=filename, + start_dt=datetime(2019, 1, 1), + end_dt=datetime(2019, 1, 2), + history_minutes=30, + forecast_minutes=60, + image_size_pixels=64, + meters_per_pixel=2000, + ) + + manager = Manager() + + # load config + local_path = Path(nowcasting_dataset.__file__).parent.parent + filename = local_path / "tests" / "config" / "test.yaml" + manager.load_yaml_configuration(filename=filename) + + with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path: # noqa 101 + + # set local temp path, and dst path + manager.config.output_data.filepath = Path(dst_path) + 
manager.local_temp_path = Path(local_temp_path) + + # set the gsp and satellite data sources (gsp also defines the geospatial locations) + manager.data_sources = {"gsp": gsp, "sat": sat} + manager.data_source_which_defines_geospatial_locations = gsp + + # make file for locations + manager.create_files_specifying_spatial_and_temporal_locations_of_each_example_if_necessary() # noqa 101 + + # make batches + manager.create_batches(overwrite_batches=True) + + assert os.path.exists(f"{dst_path}/train") + assert os.path.exists(f"{dst_path}/train/gsp") + assert os.path.exists(f"{dst_path}/train/gsp/000000.nc") + assert os.path.exists(f"{dst_path}/train/sat/000000.nc") + assert os.path.exists(f"{dst_path}/train/gsp/000001.nc") + assert os.path.exists(f"{dst_path}/train/sat/000001.nc") + + +def test_save_config(): + """Test that configuration file is saved""" + + manager = Manager() + + # load config + local_path = Path(nowcasting_dataset.__file__).parent.parent + filename = local_path / "tests" / "config" / "test.yaml" + manager.load_yaml_configuration(filename=filename) + + with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path: # noqa 101 + + # set local temp path, and dst path + manager.config.output_data.filepath = Path(dst_path) + manager.local_temp_path = Path(local_temp_path) + + # save config + manager.save_yaml_configuration() + + assert os.path.exists(f"{dst_path}/configuration.yaml")