Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 4d01e8a

Browse files
authored
Merge pull request #256 from openclimatefix/jack/get_contiguous_time_periods
Implement `DataSource.get_contiguous_time_periods()`
2 parents 7de1a35 + 0abff22 commit 4d01e8a

File tree

6 files changed

+62
-57
lines changed

6 files changed

+62
-57
lines changed

nowcasting_dataset/data_sources/data_source.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def __post_init__(self):
6666

6767
self._history_duration = pd.Timedelta(self.history_minutes, unit="minutes")
6868
self._forecast_duration = pd.Timedelta(self.forecast_minutes, unit="minutes")
69-
# Add sample_period_duration because neither history_duration not forecast_duration include t0.
69+
# Add sample_period_duration because neither history_duration not forecast_duration
70+
# include t0.
7071
self._total_seq_duration = (
7172
self._history_duration + self._forecast_duration + self.sample_period_duration
7273
)
@@ -112,13 +113,13 @@ def get_batch(
112113
Get Batch Data
113114
114115
Args:
115-
t0_datetimes: list of timestamps for the datetime of the batches. The batch will also include data
116-
for historic and future depending on 'history_minutes' and 'future_minutes'.
116+
t0_datetimes: list of timestamps for the datetime of the batches. The batch will also
117+
include data for historic and future depending on `history_minutes` and
118+
`future_minutes`.
117119
x_locations: x center batch locations
118120
y_locations: y center batch locations
119121
120-
Returns: Batch data
121-
122+
Returns: Batch data.
122123
"""
123124
examples = []
124125
zipped = zip(t0_datetimes, x_locations, y_locations)
@@ -176,31 +177,34 @@ def get_contiguous_time_periods(self) -> pd.DataFrame:
176177
Returns:
177178
pd.DataFrame where each row represents a single time period. The pd.DataFrame
178179
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
179-
"""
180180
181-
# TODO: Use nd_time.get_contiguous_time_periods()
182-
# See https://github.com/openclimatefix/nowcasting_dataset/issues/223
183-
raise NotImplementedError()
184-
185-
def _get_time_slice(self, t0_dt: pd.Timestamp):
186-
"""Get a single timestep of data. Must be overridden."""
187-
raise NotImplementedError()
181+
Raises:
182+
NotImplementedError if this DataSource has no concept of a datetime index.
183+
"""
184+
datetimes = self.datetime_index()
185+
return nd_time.get_contiguous_time_periods(
186+
datetimes=datetimes,
187+
min_seq_length=self._total_seq_length,
188+
max_gap_duration=self.sample_period_duration,
189+
)
188190

189-
# ****************** METHODS THAT MUST BE OVERRIDDEN **********************
190191
def get_locations_for_batch(
191192
self, t0_datetimes: pd.DatetimeIndex
192193
) -> Tuple[List[Number], List[Number]]:
193-
"""Find a valid geographical location for each t0_datetime.
194+
"""Find a valid geographical locations for each t0_datetime.
195+
196+
Should be overridden by DataSources which may be used to define the locations
197+
for each batch.
194198
195199
Returns: x_locations, y_locations. Each has one entry per t0_datetime.
196200
Locations are in OSGB coordinates.
197201
"""
198-
# TODO: Do this properly, using PV locations!
199-
locations = [20_000, 40_000, 500_000, 600_000, 100_000, 100_000, 250_000, 250_000]
200-
201-
location = np.random.choice(locations, size=(len(t0_datetimes), 2))
202+
raise NotImplementedError()
202203

203-
return location[:, 0], location[:, 1]
204+
# ****************** METHODS THAT MUST BE OVERRIDDEN **********************
205+
def _get_time_slice(self, t0_dt: pd.Timestamp):
206+
"""Get a single timestep of data. Must be overridden."""
207+
raise NotImplementedError()
204208

205209
def get_example(
206210
self,
@@ -273,8 +277,8 @@ def get_example(
273277
Get Example data
274278
275279
Args:
276-
t0_dt: list of timestamps for the datetime of the batches. The batch will also include data
277-
for historic and future depending on 'history_minutes' and 'future_minutes'.
280+
t0_dt: list of timestamps for the datetime of the batches. The batch will also include
281+
data for historic and future depending on `history_minutes` and `future_minutes`.
278282
x_meters_center: x center batch locations
279283
y_meters_center: y center batch locations
280284

nowcasting_dataset/time.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,20 @@ def get_start_datetimes(
172172

173173

174174
def get_contiguous_time_periods(
175-
datetimes: pd.DatetimeIndex, min_seq_length: int, max_gap: pd.Timedelta = THIRTY_MINUTES
175+
datetimes: pd.DatetimeIndex,
176+
min_seq_length: int,
177+
max_gap_duration: pd.Timedelta = THIRTY_MINUTES,
176178
) -> pd.DataFrame:
177179
"""Returns a pd.DataFrame where each row records the boundary of a contiguous time periods.
178180
179181
Args:
180182
datetimes: The pd.DatetimeIndex of the timeseries. Must be sorted.
181-
min_seq_length: Sequences of min_seq_length or shorter will be discarded.
182-
max_gap: If any pair of consecutive `datetimes` is more than `max_gap` apart, then this pair
183-
of `datetimes` will be considered a "gap" between two contiguous sequences.
183+
min_seq_length: Sequences of min_seq_length or shorter will be discarded. Typically, this
184+
would be set to the `total_seq_length` of each machine learning example.
185+
max_gap_duration: If any pair of consecutive `datetimes` is more than `max_gap_duration`
186+
apart, then this pair of `datetimes` will be considered a "gap" between two contiguous
187+
sequences. Typically, `max_gap_duration` would be set to the sample period of
188+
the timeseries.
184189
185190
Returns:
186191
pd.DataFrame where each row represents a single time period. The pd.DataFrame
@@ -193,7 +198,7 @@ def get_contiguous_time_periods(
193198
assert datetimes.is_unique
194199

195200
# Find indices of gaps larger than max_gap:
196-
gap_mask = np.diff(datetimes) > max_gap
201+
gap_mask = np.diff(datetimes) > max_gap_duration
197202
gap_indices = np.argwhere(gap_mask)[:, 0]
198203

199204
# gap_indicies are the indices into dt_index for the timestep immediately

tests/data_sources/test_data_source.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
def test_image_data_source():
5-
65
_ = ImageDataSource(
76
image_size_pixels=64,
87
meters_per_pixel=2000,

tests/data_sources/test_nwp_data_source.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import os
2+
import pandas as pd
23

34
import nowcasting_dataset
45
from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWPDataSource
56

67

7-
def test_nwp_data_source_init():
8+
PATH = os.path.dirname(nowcasting_dataset.__file__)
89

9-
path = os.path.dirname(nowcasting_dataset.__file__)
10+
# Solar PV data (test data)
11+
NWP_FILENAME = f"{PATH}/../tests/data/nwp_data/test.zarr"
1012

11-
# Solar PV data (test data)
12-
NWP_FILENAME = f"{path}/../tests/data/nwp_data/test.zarr"
1313

14+
def test_nwp_data_source_init():
1415
_ = NWPDataSource(
1516
filename=NWP_FILENAME,
1617
history_minutes=30,
@@ -21,12 +22,6 @@ def test_nwp_data_source_init():
2122

2223

2324
def test_nwp_data_source_open():
24-
25-
path = os.path.dirname(nowcasting_dataset.__file__)
26-
27-
# Solar PV data (test data)
28-
NWP_FILENAME = f"{path}/../tests/data/nwp_data/test.zarr"
29-
3025
nwp = NWPDataSource(
3126
filename=NWP_FILENAME,
3227
history_minutes=30,
@@ -40,12 +35,6 @@ def test_nwp_data_source_open():
4035

4136

4237
def test_nwp_data_source_batch():
43-
44-
path = os.path.dirname(nowcasting_dataset.__file__)
45-
46-
# Solar PV data (test data)
47-
NWP_FILENAME = f"{path}/../tests/data/nwp_data/test.zarr"
48-
4938
nwp = NWPDataSource(
5039
filename=NWP_FILENAME,
5140
history_minutes=30,
@@ -64,3 +53,20 @@ def test_nwp_data_source_batch():
6453
batch = nwp.get_batch(t0_datetimes=t0_datetimes, x_locations=x, y_locations=y)
6554

6655
assert batch.data.shape == (4, 1, 19, 2, 2)
56+
57+
58+
def test_nwp_get_contiguous_time_periods():
59+
nwp = NWPDataSource(
60+
filename=NWP_FILENAME,
61+
history_minutes=30,
62+
forecast_minutes=60,
63+
convert_to_numpy=True,
64+
n_timesteps_per_batch=8,
65+
channels=["t"],
66+
)
67+
68+
contiguous_time_periods = nwp.get_contiguous_time_periods()
69+
correct_time_periods = pd.DataFrame(
70+
[{"start_dt": pd.Timestamp("2019-01-01 00:00"), "end_dt": pd.Timestamp("2019-01-02 02:00")}]
71+
)
72+
pd.testing.assert_frame_equal(contiguous_time_periods, correct_time_periods)

tests/test_dataset.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import pytest
44

55
import nowcasting_dataset.time as nd_time
6-
from nowcasting_dataset.consts import GSP_DATETIME_INDEX
76
from nowcasting_dataset.dataset.datasets import NowcastingDataset
87
from nowcasting_dataset.dataset.batch import Batch
98

@@ -56,16 +55,8 @@ def test_per_worker_init(dataset: NowcastingDataset):
5655

5756
def test_get_batch(dataset: NowcastingDataset):
5857
dataset.per_worker_init(worker_id=1)
59-
batch = dataset._get_batch()
60-
assert isinstance(batch, Batch)
61-
assert batch.satellite is not None
62-
assert batch.satellite.data.shape == (
63-
8,
64-
2,
65-
pytest.IMAGE_SIZE_PIXELS,
66-
pytest.IMAGE_SIZE_PIXELS,
67-
1,
68-
)
58+
with pytest.raises(NotImplementedError):
59+
_ = dataset._get_batch()
6960

7061

7162
def test_get_batch_gsp(dataset_gsp: NowcastingDataset):

tests/test_time.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_get_contiguous_time_periods_1_with_1_chunk(min_seq_length):
6868
freq = pd.Timedelta(5, unit="minutes")
6969
dt_index = pd.date_range("2010-01-01", "2010-01-02", freq=freq)
7070
periods: pd.DataFrame = nd_time.get_contiguous_time_periods(
71-
dt_index, min_seq_length=min_seq_length, max_gap=freq
71+
dt_index, min_seq_length=min_seq_length, max_gap_duration=freq
7272
)
7373
correct_periods = pd.DataFrame([{"start_dt": dt_index[0], "end_dt": dt_index[-1]}])
7474
pd.testing.assert_frame_equal(periods, correct_periods)
@@ -81,7 +81,7 @@ def test_get_contiguous_time_periods_2_with_2_chunks(min_seq_length):
8181
dt_index2 = pd.date_range("2010-02-01", "2010-02-02", freq=freq)
8282
dt_index = dt_index1.union(dt_index2)
8383
periods: pd.DataFrame = nd_time.get_contiguous_time_periods(
84-
dt_index, min_seq_length=min_seq_length, max_gap=freq
84+
dt_index, min_seq_length=min_seq_length, max_gap_duration=freq
8585
)
8686
correct_periods = pd.DataFrame(
8787
[

0 commit comments

Comments
 (0)