Skip to content
This repository was archived by the owner on Sep 28, 2023. It is now read-only.

Commit 2de25e2

Browse files
jacobbiekerflowirtzJackKelly
authored
Add Absolute Position Encoder (#2)
#minor * Move position_encoding from nowcasting_dataset PR * Start on encoding modalities checks * Add getting needed information for relative encoding * Add absolute encoding, more steps for relative encoding * Make stub for subselecting position encoding * Add first bit of subselecting position encoding tensor * Remove relative encoding, split into different branch * Go through the batch of datetimes * Simplify code * Add datetime feature creation test * Add test for normalizing geospatial coordinates * Add test for absolute position encoding * Run black on tests * Add test for encode position * Add not implemented test * Add test for encoding modalities * Add test for multi-modality encoding * Add PV to multi-modality test * Remove stub tests * Add requirement * Bump version * Reduce repeats of time features Instead of matching the number of channels in the spatial dimension, just have it be single channel, as its a single value per timestep * Add check for identical features in different modalities * Make check more robust * Add spatial checks as well * Add docstring on kwargs * Remove relative and both options for encoding * Update nowcasting_dataloader/utils/position_encoding.py Co-authored-by: Flo <[email protected]> * Update tests/test_position_encoding.py Co-authored-by: Flo <[email protected]> * Add skipping tests * Update nowcasting_dataloader/utils/position_encoding.py Co-authored-by: Jack Kelly <[email protected]> * Update nowcasting_dataloader/utils/position_encoding.py Co-authored-by: Jack Kelly <[email protected]> * Update nowcasting_dataloader/utils/position_encoding.py Co-authored-by: Jack Kelly <[email protected]> * Update nowcasting_dataloader/utils/position_encoding.py Co-authored-by: Jack Kelly <[email protected]> * Address some PR comments * Update nowcasting_dataloader/utils/position_encoding.py Co-authored-by: Jack Kelly <[email protected]> * Address more PR comments * Remove encode_position * Switch to using constants * Remove method * Fix import * Fixes * Fix test * Change docstring Co-authored-by: Flo <[email protected]> Co-authored-by: Jack Kelly <[email protected]>
1 parent 7cad5c3 commit 2de25e2

File tree

8 files changed

+425
-18
lines changed

8 files changed

+425
-18
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
default_language_version:
22
python: python3.9
33

4-
files: ^nowcasting_dataloader/
54
repos:
65
- repo: https://github.com/pre-commit/pre-commit-hooks
76
rev: v3.4.0
@@ -26,17 +25,3 @@ repos:
2625
hooks:
2726
- id: prettier
2827
types: [yaml]
29-
30-
# Google docstring
31-
- repo: https://github.com/PyCQA/pydocstyle
32-
rev: 6.1.1
33-
hooks:
34-
- id: pydocstyle
35-
args:
36-
[
37-
--convention,
38-
"google",
39-
--add-ignore,
40-
"D200,D210,D212,D415",
41-
"nowcasting_dataloader",
42-
]

nowcasting_dataloader/data_sources/pv/pv_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
PV_SYSTEM_ROW_NUMBER,
1414
PV_SYSTEM_ID,
1515
)
16-
from nowcasting_dataset.data_sources.datasource_output import (
16+
from nowcasting_dataloader.data_sources.datasource_output import (
1717
DataSourceOutputML,
1818
)
1919
from nowcasting_dataset.time import make_random_time_vectors

nowcasting_dataloader/data_sources/satellite/satellite_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pydantic import Field
99

1010
from nowcasting_dataset.consts import Array
11-
from nowcasting_dataset.data_sources.datasource_output import (
11+
from nowcasting_dataloader.data_sources.datasource_output import (
1212
DataSourceOutputML,
1313
)
1414
from nowcasting_dataset.time import make_random_time_vectors

nowcasting_dataloader/data_sources/sun/sun_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import Field, validator
66

77
from nowcasting_dataset.consts import Array, SUN_AZIMUTH_ANGLE, SUN_ELEVATION_ANGLE
8-
from nowcasting_dataset.data_sources.datasource_output import (
8+
from nowcasting_dataloader.data_sources.datasource_output import (
99
DataSourceOutputML,
1010
)
1111
from nowcasting_dataset.time import make_random_time_vectors
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Various utilities for loading and transforming data"""
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
"""
2+
This file contains various ways of performing positional encoding.
3+
4+
These encodings can be:
5+
- Absolute positioning (i.e. this pixel is at this latitude/longitude, and is at 16:00)
6+
7+
These encodings can also be performed with:
8+
- Fourier Features, based off what is done in PerceiverIO
9+
"""
10+
import numpy as np
11+
import torch
12+
import einops
13+
from math import pi
14+
from typing import Union, Optional, Dict, List, Tuple, Any
15+
import datetime
16+
17+
TIME_DIM = 2
18+
HEIGHT_DIM = 3
19+
WIDTH_DIM = 4
20+
21+
22+
def encode_modalities(
23+
modalities_to_encode: Dict[str, torch.Tensor],
24+
datetimes: Dict[str, List[datetime.datetime]],
25+
geospatial_coordinates: Dict[str, Tuple[np.ndarray, np.ndarray]],
26+
geospatial_bounds: Dict[str, float],
27+
**kwargs,
28+
) -> dict[str, torch.Tensor]:
29+
"""
30+
Create a consistent position encoding and encode the positions of the different modalities in time and space
31+
32+
This position encoding is added as new keys to the dictionary containing the modalities to encode. This is done
33+
instead of appending the position encoding in case the position encoding needs to be used for the query to the
34+
Perceiver IO model
35+
36+
This code assumes that there is at least 2 timesteps of at least one modality to be encoded
37+
38+
Args:
39+
modalities_to_encode: Dict of input modalities, i.e. NWP, Satellite, PV, GSP, etc as torch.Tensors in [B, C, T, H, W] ordering
40+
datetimes: Dict of datetimes for each modality, giving the actual date for each timestep in the modality
41+
geospatial_coordinates: Dict of x, y coordinates for each modality with pixels, used to determine smallest spatial step needed, in OSGB coordinates
42+
geospatial_bounds: Max extant of the area where examples could be drawn from, used for normalizing coordinates within an area of interest
43+
in the format of a dictionary with the keys {'x_min', 'x_max', 'y_min', 'y_max'}
44+
kwargs: Passed to fourier_encode
45+
46+
Returns:
47+
Input modality dictionary where for every 'key' in modalities_to_encode, a new key called 'key+'_position_encoding' will be added
48+
containing the absolute position encoding of the examples
49+
"""
50+
position_encodings = {}
51+
for key in modalities_to_encode.keys():
52+
position_encodings[key + "_position_encoding"] = encode_absolute_position(
53+
shape=modalities_to_encode[key].shape,
54+
geospatial_coordinates=geospatial_coordinates[key],
55+
datetimes=datetimes[key],
56+
geospatial_bounds=geospatial_bounds,
57+
**kwargs,
58+
)
59+
# Update original dictionary
60+
modalities_to_encode.update(position_encodings)
61+
return modalities_to_encode
62+
63+
64+
def encode_absolute_position(
65+
shape: List[int],
66+
geospatial_coordinates: List[np.ndarray],
67+
geospatial_bounds: Dict[str, float],
68+
datetimes: List[datetime.datetime],
69+
**kwargs,
70+
) -> torch.Tensor:
71+
"""
72+
Encodes the absolute position of the pixels/voxels in time and space
73+
74+
This should be done per-modality and can be thought of as the relative position of the input modalities across a
75+
given year and the area of the Earth covered by all the examples.
76+
77+
Args:
78+
shape: Shape to encode positions for
79+
geospatial_coordinates: The geospatial coordinates, in OSGB format
80+
datetimes: Time of day and date as a list of datetimes, one for each timestep
81+
geospatial_bounds: The geospatial bounds of the area where the examples come from, e.g. the coordinates of the area covered by the SEVIRI RSS image
82+
**kwargs:
83+
84+
Returns:
85+
The absolute position encoding for the given shape
86+
"""
87+
datetime_features = create_datetime_features(datetimes)
88+
89+
# Fourier Features of absolute position
90+
encoded_geo_position = normalize_geospatial_coordinates(
91+
geospatial_coordinates, geospatial_bounds, **kwargs
92+
)
93+
94+
# Combine time and space features
95+
to_concat = [einops.repeat(encoded_geo_position, "b h w c -> b c t h w", t=shape[TIME_DIM])]
96+
for date_feature in datetime_features:
97+
to_concat.append(
98+
einops.repeat(
99+
date_feature, "b t -> b c t h w", h=shape[HEIGHT_DIM], w=shape[WIDTH_DIM], c=1
100+
)
101+
)
102+
103+
# Now combined into one large encoding
104+
absolute_position_encoding = torch.cat(to_concat, dim=1)
105+
106+
return absolute_position_encoding
107+
108+
109+
def normalize_geospatial_coordinates(
110+
geospatial_coordinates: List[np.ndarray], geospatial_bounds: Dict[str, float], **kwargs
111+
) -> torch.Tensor:
112+
"""
113+
Normalize the geospatial coordinates by the max extant to keep everything between -1 and 1, in sin and cos
114+
115+
This normalization should be against a set geospatial area, so that the same place has the same spatial encoding
116+
every time.
117+
118+
Args:
119+
geospatial_coordinates: The coordinates for the pixels in the image
120+
geospatial_bounds: The maximum extant
121+
122+
Returns:
123+
The normalized geospatial coordinates, rescaled to between -1 and 1 for the whole extant of the training area
124+
125+
"""
126+
# Normalize the X first
127+
geospatial_coordinates[0] = (geospatial_coordinates[0] - geospatial_bounds["x_min"]) / (
128+
geospatial_bounds["x_max"] - geospatial_bounds["x_min"]
129+
)
130+
# Normalize the Y second
131+
geospatial_coordinates[1] = (geospatial_coordinates[1] - geospatial_bounds["y_min"]) / (
132+
geospatial_bounds["y_max"] - geospatial_bounds["y_min"]
133+
)
134+
135+
# Now those are between 0 and 1, want between -1 and 1
136+
geospatial_coordinates[0] = geospatial_coordinates[0] * 2 - 1
137+
geospatial_coordinates[1] = geospatial_coordinates[1] * 2 - 1
138+
# Now create a grid of the coordinates
139+
# Have to do it for each individual example in the batch, and zip together x and y for it
140+
to_concat = []
141+
for idx in range(len(geospatial_coordinates[0])):
142+
x = geospatial_coordinates[0][idx]
143+
y = geospatial_coordinates[1][idx]
144+
grid = torch.meshgrid(x, y)
145+
pos = torch.stack(grid, dim=-1)
146+
encoded_position = fourier_encode(pos, **kwargs)
147+
encoded_position = einops.rearrange(encoded_position, "... n d -> ... (n d)")
148+
to_concat.append(encoded_position)
149+
150+
encoded_position = torch.stack(to_concat, dim=0)
151+
return encoded_position
152+
153+
154+
def create_datetime_features(
155+
datetimes: List[List[datetime.datetime]],
156+
) -> List[torch.Tensor]:
157+
"""
158+
Converts a list of datetimes to day of year, hour of day sin and cos representation
159+
160+
Args:
161+
datetimes: List of list of datetimes for the examples in a batch
162+
163+
Returns:
164+
Tuple of torch Tensors containing the hour of day sin,cos, and day of year sin,cos
165+
"""
166+
hour_of_day = []
167+
day_of_year = []
168+
for batch_idx in range(len(datetimes)):
169+
hours = []
170+
days = []
171+
for index in datetimes[batch_idx]:
172+
hours.append((index.hour + (index.minute / 60) / 24))
173+
days.append((index.timetuple().tm_yday / 365))
174+
hour_of_day.append(hours)
175+
day_of_year.append(days)
176+
177+
outputs = []
178+
for index in [hour_of_day, day_of_year]:
179+
index = torch.as_tensor(index)
180+
radians = index * 2 * np.pi
181+
index_sin = torch.sin(radians)
182+
index_cos = torch.cos(radians)
183+
outputs.append(index_sin)
184+
outputs.append(index_cos)
185+
186+
return outputs
187+
188+
189+
def fourier_encode(
190+
x: torch.Tensor,
191+
max_freq: float,
192+
num_bands: int = 4,
193+
sine_only: bool = False,
194+
) -> torch.Tensor:
195+
"""
196+
Create Fourier Encoding
197+
198+
Args:
199+
x: Input Torch Tensor
200+
max_freq: Maximum frequency for the Fourier features
201+
num_bands: Number of frequency bands
202+
sine_only: Whether to only use sine or both sine and cosine features
203+
204+
Returns:
205+
Torch Tensor with the fourier position encoded concatenated
206+
"""
207+
x = x.unsqueeze(-1)
208+
device, dtype, orig_x = x.device, x.dtype, x
209+
210+
scales = torch.linspace(
211+
1.0,
212+
max_freq / 2,
213+
num_bands,
214+
device=device,
215+
dtype=dtype,
216+
)
217+
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
218+
219+
x = x * scales * pi
220+
x = x.sin() if sine_only else torch.cat([x.sin(), x.cos()], dim=-1)
221+
x = torch.cat((x, orig_x), dim=-1)
222+
return x

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
nowcasting_dataset
22
torch
33
pytorch-lightning
4+
einops

0 commit comments

Comments
 (0)