Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit fd6a1e8

Browse files
Merge pull request #363 from openclimatefix/issue/342-manger-test
Issue/342 manger test
2 parents 642a110 + acb298b commit fd6a1e8

File tree

5 files changed

+94
-4
lines changed

5 files changed

+94
-4
lines changed

nowcasting_dataset/dataset/split/split.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ def split_data(
200200

201201
logger.debug("Split data done!")
202202
for split_name, dt in split_datetimes._asdict().items():
203-
logger.debug(f"{split_name} has {len(dt):,d} datetimes, from {dt[0]} to {dt[-1]}")
203+
if len(dt) == 0:
204+
# only a warning is made as this may happen during unittests
205+
logger.warning(f"{split_name} has {len(dt):,d} datetimes")
206+
else:
207+
logger.debug(f"{split_name} has {len(dt):,d} datetimes, from {dt[0]} to {dt[-1]}")
204208

205209
return split_datetimes

nowcasting_dataset/filesystem/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
_LOG = logging.getLogger("nowcasting_dataset")
1111

1212

13-
def upload_and_delete_local_files(dst_path: str, local_path: Path):
13+
def upload_and_delete_local_files(dst_path: Union[str, Path], local_path: Union[str, Path]):
1414
"""
1515
Upload an entire folder and delete local files to either AWS or GCP
1616
"""
1717
_LOG.info("Uploading!")
1818
filesystem = get_filesystem(dst_path)
19-
filesystem.put(str(local_path), dst_path, recursive=True)
19+
filesystem.copy(str(local_path), str(dst_path), recursive=True)
2020
delete_all_files_in_temp_path(local_path)
2121

2222

nowcasting_dataset/manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ def create_batches(self, overwrite_batches: bool) -> None:
369369
for worker_id, (data_source_name, data_source) in enumerate(
370370
self.data_sources.items()
371371
):
372+
373+
if len(locations_for_split) == 0:
374+
break
375+
372376
# Get indexes of first batch and example. And subset locations_for_split.
373377
idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
374378
idx_of_first_example = idx_of_first_batch * self.config.process.batch_size

tests/config/test.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@ process:
3030
local_temp_path: ~/temp/
3131
seed: 1234
3232
upload_every_n_batches: 16
33+
n_train_batches: 2
34+
n_validation_batches: 0
35+
n_test_batches: 0

tests/test_manager.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Test Manager."""
2+
import os
3+
import tempfile
24
from datetime import datetime
35
from pathlib import Path
46

@@ -76,4 +78,81 @@ def test_get_daylight_datetime_index():
7678
np.testing.assert_array_equal(t0_datetimes, correct_t0_datetimes)
7779

7880

79-
# TODO: Issue #322: Test the other Manager methods!
81+
def test_batches():
82+
"""Test that batches can be made"""
83+
filename = Path(nowcasting_dataset.__file__).parent.parent / "tests" / "data" / "sat_data.zarr"
84+
85+
sat = SatelliteDataSource(
86+
zarr_path=filename,
87+
history_minutes=30,
88+
forecast_minutes=60,
89+
image_size_pixels=64,
90+
meters_per_pixel=2000,
91+
channels=("HRV",),
92+
)
93+
94+
filename = (
95+
Path(nowcasting_dataset.__file__).parent.parent / "tests" / "data" / "gsp" / "test.zarr"
96+
)
97+
98+
gsp = GSPDataSource(
99+
zarr_path=filename,
100+
start_dt=datetime(2019, 1, 1),
101+
end_dt=datetime(2019, 1, 2),
102+
history_minutes=30,
103+
forecast_minutes=60,
104+
image_size_pixels=64,
105+
meters_per_pixel=2000,
106+
)
107+
108+
manager = Manager()
109+
110+
# load config
111+
local_path = Path(nowcasting_dataset.__file__).parent.parent
112+
filename = local_path / "tests" / "config" / "test.yaml"
113+
manager.load_yaml_configuration(filename=filename)
114+
115+
with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path: # noqa 101
116+
117+
# set local temp path, and dst path
118+
manager.config.output_data.filepath = Path(dst_path)
119+
manager.local_temp_path = Path(local_temp_path)
120+
121+
# just set satellite as data source
122+
manager.data_sources = {"gsp": gsp, "sat": sat}
123+
manager.data_source_which_defines_geospatial_locations = gsp
124+
125+
# make file for locations
126+
manager.create_files_specifying_spatial_and_temporal_locations_of_each_example_if_necessary() # noqa 101
127+
128+
# make batches
129+
manager.create_batches(overwrite_batches=True)
130+
131+
assert os.path.exists(f"{dst_path}/train")
132+
assert os.path.exists(f"{dst_path}/train/gsp")
133+
assert os.path.exists(f"{dst_path}/train/gsp/000000.nc")
134+
assert os.path.exists(f"{dst_path}/train/sat/000000.nc")
135+
assert os.path.exists(f"{dst_path}/train/gsp/000001.nc")
136+
assert os.path.exists(f"{dst_path}/train/sat/000001.nc")
137+
138+
139+
def test_save_config():
140+
"""Test that configuration file is saved"""
141+
142+
manager = Manager()
143+
144+
# load config
145+
local_path = Path(nowcasting_dataset.__file__).parent.parent
146+
filename = local_path / "tests" / "config" / "test.yaml"
147+
manager.load_yaml_configuration(filename=filename)
148+
149+
with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path: # noqa 101
150+
151+
# set local temp path, and dst path
152+
manager.config.output_data.filepath = Path(dst_path)
153+
manager.local_temp_path = Path(local_temp_path)
154+
155+
# save config
156+
manager.save_yaml_configuration()
157+
158+
assert os.path.exists(f"{dst_path}/configuration.yaml")

0 commit comments

Comments
 (0)