Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 0082a8d

Browse files
authored
Add Local Filesystem Support (#142)
* Add local option
* Add testing local option
* Bump version
* Update test to check for example data
1 parent 16c3ad6 commit 0082a8d

File tree

4 files changed: 51 additions and 6 deletions.

nowcasting_dataset/dataset/datasets.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,15 @@ def __init__(
159159
self.gcs = None
160160
self.s3_resource = None
161161

162-
assert cloud in ["gcp", "aws"]
162+
assert cloud in ["gcp", "aws", "local"]
163163

164164
if not os.path.isdir(self.tmp_path):
165165
os.mkdir(self.tmp_path)
166166

167167
def per_worker_init(self, worker_id: int):
168168
if self.cloud == "gcp":
169169
self.gcs = gcsfs.GCSFileSystem()
170-
else:
170+
elif self.cloud == "aws":
171171
self.s3_resource = boto3.resource("s3")
172172

173173
def __len__(self):
@@ -198,15 +198,18 @@ def __getitem__(self, batch_idx: int) -> example.Example:
198198
local_filename=local_netcdf_filename,
199199
gcs=self.gcs,
200200
)
201-
else:
201+
elif self.cloud == "aws":
202202
aws_download_to_local(
203203
remote_filename=remote_netcdf_filename,
204204
local_filename=local_netcdf_filename,
205205
s3_resource=self.s3_resource,
206206
)
207+
else:
208+
local_netcdf_filename = remote_netcdf_filename
207209

208210
netcdf_batch = xr.load_dataset(local_netcdf_filename)
209-
os.remove(local_netcdf_filename)
211+
if self.cloud != "local":
212+
os.remove(local_netcdf_filename)
210213

211214
batch = example.Example(
212215
sat_datetime_index=netcdf_batch.sat_time_coords,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
setup(
1111
name="nowcasting_dataset",
12-
version="0.1.4",
12+
version="0.1.5",
1313
license="MIT",
1414
description="Nowcasting Dataset",
1515
author="Jack Kelly, Peter Dudfield, Jacob Bieker",
File renamed without changes.

tests/test_netcdf_dataset.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
def test_subselect_date():
27-
dataset = xr.open_dataset("tests/data/test.nc")
27+
dataset = xr.open_dataset("tests/data/0.nc")
2828
x = example.Example(
2929
sat_data=dataset["sat_data"],
3030
nwp=dataset["nwp"],
@@ -44,6 +44,48 @@ def test_subselect_date():
4444
assert batch[NWP_DATA].shape[2] == 5
4545

4646

47+
def test_netcdf_dataset_local():
48+
DATA_PATH = "tests/data"
49+
TEMP_PATH = "tests/data/temp/"
50+
51+
train_dataset = NetCDFDataset(
52+
1,
53+
DATA_PATH,
54+
TEMP_PATH,
55+
cloud="local",
56+
history_minutes=10,
57+
forecast_minutes=10,
58+
current_timestep_index=7,
59+
required_keys=(NWP_DATA, NWP_TARGET_TIME, SATELLITE_DATA, SATELLITE_DATETIME_INDEX),
60+
)
61+
62+
dataloader_config = dict(
63+
pin_memory=True,
64+
num_workers=1,
65+
prefetch_factor=1,
66+
worker_init_fn=worker_init_fn,
67+
persistent_workers=True,
68+
# Disable automatic batching because dataset
69+
# returns complete batches.
70+
batch_size=None,
71+
)
72+
73+
_ = torch.utils.data.DataLoader(train_dataset, **dataloader_config)
74+
75+
train_dataset.per_worker_init(1)
76+
t = iter(train_dataset)
77+
data = next(t)
78+
79+
sat_data = data[SATELLITE_DATA]
80+
81+
# Satellite data is in 5-minute increments, so 10 minutes of history plus the current step plus 10 minutes of forecast gives 2 + 1 + 2 = 5 timesteps
82+
assert sat_data.shape[1] == 5
83+
assert data[NWP_DATA].shape[2] == 5
84+
85+
# Make sure file isn't deleted!
86+
assert os.path.exists("tests/data/0.nc")
87+
88+
4789
@pytest.mark.skip("CD does not have access to GCS")
4890
def test_get_dataloaders_gcp():
4991
DATA_PATH = "gs://solar-pv-nowcasting-data/prepared_ML_training_data/v5/"

0 commit comments

Comments
 (0)