Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Issue/342 manager test #363

Merged
merged 6 commits into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion nowcasting_dataset/dataset/split/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def split_data(

logger.debug("Split data done!")
for split_name, dt in split_datetimes._asdict().items():
logger.debug(f"{split_name} has {len(dt):,d} datetimes, from {dt[0]} to {dt[-1]}")
if len(dt) == 0:
# only a warning is made as this may happen during unittests
logger.warning(f"{split_name} has {len(dt):,d} datetimes")
else:
logger.debug(f"{split_name} has {len(dt):,d} datetimes, from {dt[0]} to {dt[-1]}")

return split_datetimes
4 changes: 2 additions & 2 deletions nowcasting_dataset/filesystem/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
_LOG = logging.getLogger("nowcasting_dataset")


def upload_and_delete_local_files(dst_path: Union[str, Path], local_path: Union[str, Path]):
    """Recursively copy the folder at ``local_path`` to ``dst_path``, then delete the local files.

    The destination filesystem is resolved from ``dst_path`` (e.g. AWS or GCP).
    """
    _LOG.info("Uploading!")
    # Resolve the remote filesystem from the destination path, then mirror the folder there.
    fs = get_filesystem(dst_path)
    fs.copy(str(local_path), str(dst_path), recursive=True)
    # Once the copy succeeds, clear out the local temporary files.
    delete_all_files_in_temp_path(local_path)


Expand Down
4 changes: 4 additions & 0 deletions nowcasting_dataset/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ def create_batches(self, overwrite_batches: bool) -> None:
for worker_id, (data_source_name, data_source) in enumerate(
self.data_sources.items()
):

if len(locations_for_split) == 0:
break

# Get indexes of first batch and example. And subset locations_for_split.
idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
idx_of_first_example = idx_of_first_batch * self.config.process.batch_size
Expand Down
3 changes: 3 additions & 0 deletions tests/config/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ process:
local_temp_path: ~/temp/
seed: 1234
upload_every_n_batches: 16
n_train_batches: 2
n_validation_batches: 0
n_test_batches: 0
81 changes: 80 additions & 1 deletion tests/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Test Manager."""
import os
import tempfile
from datetime import datetime
from pathlib import Path

Expand Down Expand Up @@ -76,4 +78,81 @@ def test_get_daylight_datetime_index():
np.testing.assert_array_equal(t0_datetimes, correct_t0_datetimes)


# TODO: Issue #322: Test the other Manager methods!
def test_batches():
    """Make batches from GSP and satellite data sources and check the files land on disk."""
    sat_zarr_path = (
        Path(nowcasting_dataset.__file__).parent.parent / "tests" / "data" / "sat_data.zarr"
    )

    sat = SatelliteDataSource(
        zarr_path=sat_zarr_path,
        history_minutes=30,
        forecast_minutes=60,
        image_size_pixels=64,
        meters_per_pixel=2000,
        channels=("HRV",),
    )

    gsp_zarr_path = (
        Path(nowcasting_dataset.__file__).parent.parent / "tests" / "data" / "gsp" / "test.zarr"
    )

    gsp = GSPDataSource(
        zarr_path=gsp_zarr_path,
        start_dt=datetime(2019, 1, 1),
        end_dt=datetime(2019, 1, 2),
        history_minutes=30,
        forecast_minutes=60,
        image_size_pixels=64,
        meters_per_pixel=2000,
    )

    manager = Manager()

    # Load the test configuration shipped with the repo.
    config_path = (
        Path(nowcasting_dataset.__file__).parent.parent / "tests" / "config" / "test.yaml"
    )
    manager.load_yaml_configuration(filename=config_path)

    with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path:  # noqa 101

        # Point both the output and the local scratch space at throwaway directories.
        manager.config.output_data.filepath = Path(dst_path)
        manager.local_temp_path = Path(local_temp_path)

        # Use GSP and satellite data sources; GSP defines the geospatial example locations.
        manager.data_sources = {"gsp": gsp, "sat": sat}
        manager.data_source_which_defines_geospatial_locations = gsp

        # Write the file specifying the spatial/temporal location of each example.
        manager.create_files_specifying_spatial_and_temporal_locations_of_each_example_if_necessary()  # noqa 101

        # Make the batches themselves.
        manager.create_batches(overwrite_batches=True)

        # tests/config/test.yaml requests 2 train batches, so batch files 0 and 1
        # should exist for each data source.
        assert os.path.exists(f"{dst_path}/train")
        assert os.path.exists(f"{dst_path}/train/gsp")
        for source_name in ("gsp", "sat"):
            for batch_idx in ("000000", "000001"):
                assert os.path.exists(f"{dst_path}/train/{source_name}/{batch_idx}.nc")


def test_save_config():
    """Check that the Manager writes its configuration file to the destination path."""
    manager = Manager()

    # Load the test configuration shipped with the repo.
    config_path = (
        Path(nowcasting_dataset.__file__).parent.parent / "tests" / "config" / "test.yaml"
    )
    manager.load_yaml_configuration(filename=config_path)

    with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path:  # noqa 101

        # Point both the output and the local scratch space at throwaway directories.
        manager.config.output_data.filepath = Path(dst_path)
        manager.local_temp_path = Path(local_temp_path)

        # Saving should create configuration.yaml in the output directory.
        manager.save_yaml_configuration()

        assert os.path.exists(f"{dst_path}/configuration.yaml")