Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 4fa6173

Browse files
Merge pull request #195 from openclimatefix/issue/166-batch-pydantic
Issue/166 batch pydantic
2 parents 73bbe20 + ee09b9a commit 4fa6173

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+2511
-998
lines changed

conftest.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from nowcasting_dataset.config.load import load_yaml_configuration
1010
from nowcasting_dataset.data_sources import SatelliteDataSource
1111
from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource
12+
from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource
1213

1314
pytest.IMAGE_SIZE_PIXELS = 128
1415

@@ -50,6 +51,14 @@ def sat_data_source(sat_filename: Path):
5051
)
5152

5253

54+
@pytest.fixture
55+
def general_data_source():
56+
57+
return MetadataDataSource(
58+
history_minutes=0, forecast_minutes=5, object_at_center="GSP", convert_to_numpy=True
59+
)
60+
61+
5362
@pytest.fixture
5463
def gsp_data_source():
5564
return GSPDataSource(
@@ -65,9 +74,9 @@ def gsp_data_source():
6574
@pytest.fixture
6675
def configuration():
6776
filename = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "config", "gcp.yaml")
68-
config = load_yaml_configuration(filename)
77+
configuration = load_yaml_configuration(filename)
6978

70-
return config
79+
return configuration
7180

7281

7382
@pytest.fixture

notebooks/2021-09/2021-09-07/sat_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import datetime
22

3-
from nowcasting_dataset.data_sources.satellite_data_source import SatelliteDataSource
3+
from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource
44

55
s = SatelliteDataSource(
66
filename="gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/"
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from pydantic import BaseModel, Field, validator
2+
from typing import Union
3+
import numpy as np
4+
import xarray as xr
5+
import torch
6+
from nowcasting_dataset.config.model import Configuration
7+
8+
9+
Array = Union[xr.DataArray, np.ndarray, torch.Tensor]
10+
11+
12+
class Satellite(BaseModel):
13+
14+
# width: int = Field(..., g=0, description="The width of the satellite image")
15+
# height: int = Field(..., g=0, description="The width of the satellite image")
16+
# num_channels: int = Field(..., g=0, description="The width of the satellite image")
17+
18+
# Shape: [batch_size,] seq_length, width, height, channel
19+
image_data: Array = Field(
20+
...,
21+
description="Satellites images. Shape: [batch_size,] seq_length, width, height, channel",
22+
)
23+
x_coords: Array = Field(
24+
...,
25+
description="The x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width",
26+
)
27+
y_coords: Array = Field(
28+
...,
29+
description="The y (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width",
30+
)
31+
32+
# @validator("sat_data")
33+
# def image_shape(cls, v):
34+
# assert v.shape[-1] == cls.num_channels
35+
# assert v.shape[-2] == cls.height
36+
# assert v.shape[-3] == cls.width
37+
#
38+
# @validator("x_coords")
39+
# def x_coords_shape(cls, v):
40+
# assert v.shape[-1] == cls.width
41+
#
42+
# @validator("y_coords")
43+
# def y_coords_shape(cls, v):
44+
# assert v.shape[-1] == cls.height
45+
#
46+
class Config:
47+
arbitrary_types_allowed = True
48+
49+
50+
class Batch(BaseModel):
51+
52+
batch_size: int = Field(
53+
...,
54+
g=0,
55+
description="The size of this batch. If the batch size is 0, "
56+
"then this item stores one data item",
57+
)
58+
59+
satellite: Satellite
60+
61+
62+
class FakeDataset(torch.utils.data.Dataset):
63+
"""Fake dataset."""
64+
65+
def __init__(self, configuration: Configuration = Configuration(), length: int = 10):
66+
"""
67+
Init
68+
69+
Args:
70+
configuration: configuration object
71+
length: length of dataset
72+
"""
73+
self.batch_size = configuration.process.batch_size
74+
self.seq_length_5 = (
75+
configuration.process.seq_len_5_minutes
76+
) # the sequence data in 5 minute steps
77+
self.seq_length_30 = (
78+
configuration.process.seq_len_30_minutes
79+
) # the sequence data in 30 minute steps
80+
self.satellite_image_size_pixels = configuration.process.satellite_image_size_pixels
81+
self.nwp_image_size_pixels = configuration.process.nwp_image_size_pixels
82+
self.number_sat_channels = len(configuration.process.sat_channels)
83+
self.number_nwp_channels = len(configuration.process.nwp_channels)
84+
self.length = length
85+
86+
def __len__(self):
87+
"""Number of pieces of data"""
88+
return self.length
89+
90+
def per_worker_init(self, worker_id: int):
91+
"""Not needed"""
92+
pass
93+
94+
def __getitem__(self, idx):
95+
"""
96+
Get item, use for iter and next method
97+
98+
Args:
99+
idx: batch index
100+
101+
Returns: Dictionary of random data
102+
103+
"""
104+
105+
sat = Satellite(
106+
image_data=np.random.randn(
107+
self.batch_size,
108+
self.seq_length_5,
109+
self.satellite_image_size_pixels,
110+
self.satellite_image_size_pixels,
111+
self.number_sat_channels,
112+
),
113+
x_coords=torch.sort(torch.randn(self.batch_size, self.satellite_image_size_pixels))[0],
114+
y_coords=torch.sort(
115+
torch.randn(self.batch_size, self.satellite_image_size_pixels), descending=True
116+
)[0],
117+
)
118+
119+
# Note need to return as nested dict
120+
return Batch(satellite=sat, batch_size=self.batch_size).dict()
121+
122+
123+
train = torch.utils.data.DataLoader(FakeDataset())
124+
i = iter(train)
125+
x = next(i)
126+
127+
x = Batch(**x)
128+
# IT WORKS
129+
assert type(x.satellite.image_data) == torch.Tensor

notebooks/2021-10/2021-10-08/no_validation.py

Whitespace-only changes.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
3+
import numpy as np
4+
import xarray as xr
5+
from nowcasting_dataset.utils import coord_to_range
6+
7+
8+
def get_satellite_xrarray_data_array(
9+
batch_size, seq_length_5, satellite_image_size_pixels, number_sat_channels=10
10+
):
11+
12+
r = np.random.randn(
13+
# self.batch_size,
14+
seq_length_5,
15+
satellite_image_size_pixels,
16+
satellite_image_size_pixels,
17+
number_sat_channels,
18+
)
19+
20+
time = np.sort(np.random.randn(seq_length_5))
21+
22+
x_coords = np.sort(np.random.randint(0, 1000, (satellite_image_size_pixels)))
23+
y_coords = np.sort(np.random.randint(0, 1000, (satellite_image_size_pixels)))[::-1].copy()
24+
25+
sat_xr = xr.DataArray(
26+
data=r,
27+
dims=["time", "x", "y", "channels"],
28+
coords=dict(
29+
# batch=range(0,self.batch_size),
30+
x=list(x_coords),
31+
y=list(y_coords),
32+
time=list(time),
33+
channels=range(0, number_sat_channels),
34+
),
35+
attrs=dict(
36+
description="Ambient temperature.",
37+
units="degC",
38+
),
39+
name="sata_data",
40+
)
41+
42+
return sat_xr
43+
44+
45+
def sat_data_array_to_dataset(sat_xr):
46+
ds = sat_xr.to_dataset(name="sat_data")
47+
# ds["sat_data"] = ds["sat_data"].astype(np.int16)
48+
49+
for dim in ["time", "x", "y"]:
50+
# This does seem like the right way to do it
51+
# https://ecco-v4-python-tutorial.readthedocs.io/ECCO_v4_Saving_Datasets_and_DataArrays_to_NetCDF.html
52+
ds = coord_to_range(ds, dim, prefix="sat")
53+
ds = ds.rename(
54+
{
55+
"channels": f"sat_channels",
56+
"x": f"sat_x",
57+
"y": f"sat_y",
58+
}
59+
)
60+
61+
# ds["sat_x_coords"] = ds["sat_x_coords"].astype(np.int32)
62+
# ds["sat_y_coords"] = ds["sat_y_coords"].astype(np.int32)
63+
64+
return ds
65+
66+
67+
def to_netcdf(batch_xr, local_filename):
68+
encoding = {name: {"compression": "lzf"} for name in batch_xr.data_vars}
69+
batch_xr.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)
70+
71+
72+
# 1. try to save netcdf files not using coord to range function
73+
sat_xrs = [get_satellite_xrarray_data_array(4, 19, 32) for _ in range(0, 10)]
74+
75+
### error ###
76+
# cant do this step as x/y index has duplicate values
77+
sat_dataset = xr.merge(sat_xrs)
78+
to_netcdf(sat_dataset, "test_no_alignment.nc")
79+
###
80+
81+
# but can save it as separate files
82+
os.mkdir("test_no_alignment")
83+
[sat_xrs[i].to_netcdf(f"test_no_alignment/{i}.nc", engine="h5netcdf") for i in range(0, 10)]
84+
# 10 files about 1.5MB
85+
86+
# 2.
87+
sat_xrs = [get_satellite_xrarray_data_array(4, 19, 32) for _ in range(0, 10)]
88+
sat_xrs = [sat_data_array_to_dataset(sat_xr) for sat_xr in sat_xrs]
89+
90+
sat_dataset = xr.concat(sat_xrs, dim="example")
91+
to_netcdf(sat_dataset, "test_alignment.nc")
92+
# this 15 MB
93+
94+
95+
# conclusion
96+
# no major improvement in compression by joining datasets together, buts by joining array together,
97+
# it does make it easier to get array ready ML
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from pydantic import BaseModel, Field, validator
2+
from typing import Union, List
3+
import numpy as np
4+
import xarray as xr
5+
import torch
6+
from nowcasting_dataset.config.model import Configuration
7+
8+
9+
Array = Union[xr.DataArray, np.ndarray, torch.Tensor]
10+
11+
12+
class Satellite(BaseModel):
13+
# Shape: [batch_size,] seq_length, width, height, channel
14+
image_data: xr.DataArray = Field(
15+
...,
16+
description="Satellites images. Shape: [batch_size,] seq_length, width, height, channel",
17+
)
18+
19+
class Config:
20+
arbitrary_types_allowed = True
21+
22+
@validator("image_data")
23+
def v_image_data(cls, v):
24+
print("validating image data")
25+
return v
26+
27+
28+
class Batch(BaseModel):
29+
30+
batch_size: int = 0
31+
satellite: Satellite
32+
33+
@validator("batch_size")
34+
def v_image_data(cls, v):
35+
print("validating batch size")
36+
return v
37+
38+
39+
s = Satellite(image_data=xr.DataArray())
40+
s_dict = s.dict()
41+
42+
x = Satellite(**s_dict)
43+
x = Satellite.construct(Satellite.__fields_set__, **s_dict)
44+
45+
46+
batch = Batch(batch_size=5, satellite=s)
47+
48+
b_dict = batch.dict()
49+
50+
x = Batch(**b_dict)
51+
x = Batch.construct(Batch.__fields_set__, **b_dict)
52+
53+
54+
# class Satellite(BaseModel):
55+
#
56+
# image_data: xr.DataArray
57+
#
58+
# # validate
59+
#
60+
# def to_dataset(self):
61+
# pass
62+
#
63+
# def from_dateset(self):
64+
# pass
65+
#
66+
# def to_numpy(self) -> SatelliteNumpy:
67+
# pass
68+
#
69+
#
70+
# class SatelliteNumpy(BaseModel):
71+
#
72+
# image_data: np.ndarray
73+
# x: np.ndarray
74+
# # more
75+
#
76+
#
77+
# class Example(BaseModel):
78+
#
79+
# satelllite: Satellite
80+
# # more
81+
#
82+
#
83+
# class Batch(BaseModel):
84+
#
85+
# batch_size: int = 0
86+
# examples: List[Example]
87+
#
88+
# def to/from_netcdf():
89+
# pass
90+
#
91+
#
92+
# class BatchNumpy(BaseModel):
93+
#
94+
# batch_size: int = 0
95+
# satellite: SatellliteNumpy
96+
# # more
97+
#
98+
# def from_batch(self) -> BatchNumpy:
99+
# """ change to Batch numpy structure """

nowcasting_dataset/config/gcp.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ input_data:
66
satellite_zarr_path: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
77
solar_pv_data_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc
88
solar_pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv
9-
gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/PVOutput.org/PV/GSP/v1/pv_gsp.zarr
9+
gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/GSP/v1/pv_gsp.zarr
1010
topographic_filename: gs://solar-pv-nowcasting-data/Topographic/europe_dem_1km_osgb.tif
1111
output_data:
12-
filepath: gs://solar-pv-nowcasting-data/prepared_ML_training_data/v6/
12+
filepath: gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/
1313
process:
1414
local_temp_path: ~/temp/
1515
seed: 1234

0 commit comments

Comments
 (0)