Skip to content
This repository was archived by the owner on Sep 28, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
7b677d5
Move position_encoding from nowcasting_dataset PR
jacobbieker Oct 12, 2021
4337828
Start on encoding modalities checks
jacobbieker Oct 13, 2021
46b7f5f
Add getting needed information for relative encoding
jacobbieker Oct 13, 2021
e81cc10
Add absolute encoding, more steps for relative encoding
jacobbieker Oct 13, 2021
990b22b
Make stub for subselecting position encoding
jacobbieker Oct 13, 2021
53b34e8
Add first bit of subselecting position encoding tensor
jacobbieker Oct 13, 2021
62eb6d9
Remove relative encoding, split into different branch
jacobbieker Oct 13, 2021
04e8de5
Go through the batch of datetimes
jacobbieker Oct 13, 2021
7da54df
Simplify code
jacobbieker Oct 13, 2021
0d27504
Add datetime feature creation test
jacobbieker Oct 13, 2021
f635872
Add test for normalizing geospatial coordinates
jacobbieker Oct 13, 2021
d88f826
Add test for absolute position encoding
jacobbieker Oct 13, 2021
01246ad
Run black on tests
jacobbieker Oct 13, 2021
33791a7
Add test for encode position
jacobbieker Oct 13, 2021
bcf6a5d
Add not implemented test
jacobbieker Oct 13, 2021
fe01d2f
Add test for encoding modalities
jacobbieker Oct 13, 2021
c283baa
Add test for multi-modality encoding
jacobbieker Oct 14, 2021
0d6612b
Add PV to multi-modality test
jacobbieker Oct 14, 2021
61a5347
Remove stub tests
jacobbieker Oct 14, 2021
79b9e21
Add requirement
jacobbieker Oct 14, 2021
d6c5643
Bump version
jacobbieker Oct 14, 2021
c617046
Reduce repeats of time features
jacobbieker Oct 14, 2021
690c828
Add check for identical features in different modalities
jacobbieker Oct 14, 2021
1dc65cc
Make check more robust
jacobbieker Oct 14, 2021
02181b9
Add spatial checks as well
jacobbieker Oct 14, 2021
9529943
Add docstring on kwargs
jacobbieker Oct 14, 2021
0a5ba2e
Remove relative and both options for encoding
jacobbieker Oct 14, 2021
0f45431
Update nowcasting_dataloader/utils/position_encoding.py
jacobbieker Oct 14, 2021
9f40ac8
Update tests/test_position_encoding.py
jacobbieker Oct 14, 2021
ba7b484
Add skipping tests
jacobbieker Oct 14, 2021
33c2903
Merge branch 'main' into jacob/position-encoding
jacobbieker Oct 20, 2021
f1c1f6e
Update nowcasting_dataloader/utils/position_encoding.py
jacobbieker Oct 20, 2021
950f7cd
Update nowcasting_dataloader/utils/position_encoding.py
jacobbieker Oct 20, 2021
6bf2f3b
Update nowcasting_dataloader/utils/position_encoding.py
jacobbieker Oct 20, 2021
d14ee1a
Update nowcasting_dataloader/utils/position_encoding.py
jacobbieker Oct 20, 2021
fcf551e
Address some PR comments
jacobbieker Oct 20, 2021
b04cdd5
Merge remote-tracking branch 'origin/jacob/position-encoding' into ja…
jacobbieker Oct 20, 2021
1de189f
Update nowcasting_dataloader/utils/position_encoding.py
jacobbieker Oct 20, 2021
2f4319f
Address more PR comments
jacobbieker Oct 20, 2021
5da3588
Merge remote-tracking branch 'origin/jacob/position-encoding' into ja…
jacobbieker Oct 20, 2021
f7b46b9
Remove encode_position
jacobbieker Oct 20, 2021
6e8036b
Switch to using constants
jacobbieker Oct 20, 2021
ac10f76
Remove method
jacobbieker Oct 20, 2021
197bc93
Fix import
jacobbieker Oct 20, 2021
6e273b5
Fixes
jacobbieker Oct 20, 2021
0326f76
Fix test
jacobbieker Oct 20, 2021
ce9a595
Change docstring
jacobbieker Oct 20, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
default_language_version:
python: python3.9

files: ^nowcasting_dataloader/
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
Expand All @@ -26,17 +25,3 @@ repos:
hooks:
- id: prettier
types: [yaml]

# Google docstring
- repo: https://github.com/PyCQA/pydocstyle
rev: 6.1.1
hooks:
- id: pydocstyle
args:
[
--convention,
"google",
--add-ignore,
"D200,D210,D212,D415",
"nowcasting_dataloader",
]
1 change: 1 addition & 0 deletions nowcasting_dataloader/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Various utilities for loading and transforming data"""
291 changes: 291 additions & 0 deletions nowcasting_dataloader/utils/position_encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
"""
This file contains various ways of performing positional encoding.

These encodings can be:
- Relative positioning (i.e. this pixel is this far from the top left, and this many timesteps in the future)
- Absolute positioning (i.e. this pixel is at this latitude/longitude, and is at 16:00)

These encodings can also be performed with:
- Fourier Features, based off what is done in PerceiverIO
"""
import numpy as np
import torch
import einops
from math import pi
from typing import Union, Optional, Dict, List, Tuple, Any
import datetime


def encode_modalities(
modalities_to_encode: Dict[str, torch.Tensor],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should use an Enum for modality names, instead of strings? I'll start a new issue: #15

datetimes: Dict[str, List[datetime.datetime]],
geospatial_coordinates: Dict[str, Tuple[np.ndarray, np.ndarray]],
geospatial_bounds: Dict[str, float],
**kwargs,
) -> dict:
"""
Create a consistent position encoding and encode the positions of the different modalities in time and space

This position encoding is added as new keys to the dictionary containing the modalities to encode. This is done
instead of appending the position encoding in case the position encoding needs to be used for the query to the
Perceiver IO model

This code assumes that there is at least 2 timesteps of at least one modality to be encoded

Args:
positioning: The type of positioning used, either 'relative' for relative positioning, or 'absolute', or 'both'
modalities_to_encode: Dict of input modalities, i.e. NWP, Satellite, PV, GSP, etc as torch.Tensors in [B, C, T, H, W] ordering
datetimes: Dict of datetimes for each modality, giving the actual date for each timestep in the modality
geospatial_coordinates: Dict of lat/lon coordinates for each modality with pixels, used to determine smallest spatial step needed, in OSGB coordinates
geospatial_bounds: Max extant of the area where examples could be drawn from, used for normalizing coordinates within an area of interest
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are the geospatial_bounds represented? Is the float the width of the square region of interest?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going with a dictionary containing the x_min, x_max, y_min, y_max, I've updated the docstring!

kwargs: Passed to fourier_encode

Returns:
Input modality dictionary with extra keys added containing the absolute position encoding of the examples
"""
position_encodings = {}
for key in modalities_to_encode.keys():
position_encodings[key + "_position_encoding"] = encode_position(
[
modalities_to_encode[key].shape[0],
*modalities_to_encode[key].shape[2:],
], # We want to remove the channel dimension, as that's not relevant here
geospatial_coordinates=geospatial_coordinates[key],
datetimes=datetimes[key],
geospatial_bounds=geospatial_bounds,
**kwargs,
)
# Update original dictionary
modalities_to_encode.update(position_encodings)
return modalities_to_encode


def encode_position(
shape: List[int],
geospatial_coordinates: List[np.ndarray],
datetimes: List[datetime.datetime],
geospatial_bounds: Dict[str, float],
method: str = "fourier",
**kwargs,
) -> torch.Tensor:
"""
This function wraps a variety of different methods for generating position features for given inputs.

Args:
shape: The shape of the input to be encoded, should be the largest or finest-grained input
For example, if the inputs are shapes (12, 6, 128, 128) and (1, 6), (12, 6, 128, 128) should be passed in as
shape, as it has the most elements and the input (1, 6) can just subselect the position encoding
geospatial_coordinates: The latitude/longitude of the inputs for shape, in OSGB coordinates
datetimes: time of day and date for each of the timesteps in the shape
method: Method of the encoding, either 'fourier' for Fourier Features
geospatial_bounds: The bounds of the geospatial area covered, in a dict with the keys 'x_min', 'y_min', 'x_max', 'y_max'
kwargs: Passed to fourier_encode

Returns:
The position encodings for all items in the batch
"""
assert method in [
"fourier",
], AssertionError(f"method must be one of 'fourier', not '{method}'")

position_encoding = encode_absolute_position(
shape, geospatial_coordinates, geospatial_bounds, datetimes, **kwargs
)
return position_encoding


def encode_absolute_position(
shape: List[int],
geospatial_coordinates: List[np.ndarray],
geospatial_bounds: Dict[str, float],
datetimes: List[datetime.datetime],
**kwargs,
) -> torch.Tensor:
"""
Encodes the absolute position of the pixels/voxels in time and space

This should be done per-modality and can be thought of as the relative position of the input modalities across a
given year and the area of the Earth covered by all the examples.

Args:
shape: Shape to encode positions for
geospatial_coordinates: The geospatial coordinates, in OSGB format
datetimes: Time of day and date as a list of datetimes, one for each timestep
geospatial_bounds: The geospatial bounds of the area where the examples come from, e.g. the coordinates of the area covered by the SEVIRI RSS image
**kwargs:

Returns:
The absolute position encoding for the given shape
"""
datetime_features = create_datetime_features(datetimes)

# Fourier Features of absolute position
encoded_latlon = normalize_geospatial_coordinates(
geospatial_coordinates, geospatial_bounds, **kwargs
)

# Combine time and space features
to_concat = [einops.repeat(encoded_latlon, "b h w c -> b c t h w", t=shape[1])]
for date_feature in datetime_features:
to_concat.append(
einops.repeat(date_feature, "b t -> b c t h w", h=shape[-2], w=shape[-1], c=1)
)

# Now combined into one large encoding
absolute_position_encoding = torch.cat(to_concat, dim=1)

return absolute_position_encoding


def normalize_geospatial_coordinates(
geospatial_coordinates: List[np.ndarray], geospatial_bounds: Dict[str, float], **kwargs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yay, more types!

) -> torch.Tensor:
"""
Normalize the geospatial coordinates by the max extant to keep everything between -1 and 1, in sin and cos

This should work on a batch level, as the geospatial bounds should be the same for every example in the batch

Args:
geospatial_coordinates: The coordinates for the pixels in the image
geospatial_bounds: The maximum extant

Returns:
The normalized geospatial coordinates, rescaled to between -1 and 1 for the whole extant of the training area

"""
# Normalize the X first
geospatial_coordinates[0] = (geospatial_coordinates[0] - geospatial_bounds["x_min"]) / (
geospatial_bounds["x_max"] - geospatial_bounds["x_min"]
)
# Normalize the Y second
geospatial_coordinates[1] = (geospatial_coordinates[1] - geospatial_bounds["y_min"]) / (
geospatial_bounds["y_max"] - geospatial_bounds["y_min"]
)

# Now those are between 0 and 1, want between -1 and 1
geospatial_coordinates[0] = geospatial_coordinates[0] * 2 - 1
geospatial_coordinates[1] = geospatial_coordinates[1] * 2 - 1
# Now create a grid of the coordinates
# Have to do it for each individual example in the batch, and zip together x and y for it
to_concat = []
for idx in range(len(geospatial_coordinates[0])):
x = geospatial_coordinates[0][idx]
y = geospatial_coordinates[1][idx]
grid = torch.meshgrid(x, y)
pos = torch.stack(grid, dim=-1)
encoded_position = fourier_encode(pos, **kwargs)
encoded_position = einops.rearrange(encoded_position, "... n d -> ... (n d)")
to_concat.append(encoded_position)

# And now convert to Fourier features, based off the absolute positions of the coordinates
encoded_position = torch.stack(to_concat, dim=0)
return encoded_position


def create_datetime_features(
datetimes: List[List[datetime.datetime]],
) -> List[torch.Tensor]:
"""
Converts a list of datetimes to day of year, hour of day sin and cos representation

Args:
datetimes: List of list of datetimes for the examples in a batch

Returns:
Tuple of torch Tensors containing the hour of day sin,cos, and day of year sin,cos
"""
hour_of_day = []
day_of_year = []
for batch_idx in range(len(datetimes)):
hours = []
days = []
for index in datetimes[batch_idx]:
hours.append((index.hour + (index.minute / 60) / 24))
days.append((index.timetuple().tm_yday / 365))
hour_of_day.append(hours)
day_of_year.append(days)

outputs = []
for index in [hour_of_day, day_of_year]:
index = torch.as_tensor(index)
radians = index * 2 * np.pi
index_sin = torch.sin(radians)
index_cos = torch.cos(radians)
outputs.append(index_sin)
outputs.append(index_cos)

return outputs


def encode_fouier_position(
batch_size: int,
axis: list,
max_frequency: float,
num_frequency_bands: int,
sine_only: bool = False,
) -> torch.Tensor:
"""
Encode the Fourier Features and return them

Args:
batch_size: Batch size
axis: List containing the size of each axis
max_frequency: Max frequency
num_frequency_bands: Number of frequency bands to use
sine_only: (bool) Whether to only use Sine features or both Sine and Cosine, defaults to both

Returns:
Torch tensor containing the Fourier Features of shape [Batch, *axis]
"""
axis_pos = list(
map(
lambda size: torch.linspace(-1.0, 1.0, steps=size),
axis,
)
)
pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
enc_pos = fourier_encode(
pos,
max_frequency,
num_frequency_bands,
sine_only=sine_only,
)
enc_pos = einops.rearrange(enc_pos, "... n d -> ... (n d)")
enc_pos = einops.repeat(enc_pos, "... -> b ...", b=batch_size)
return enc_pos


def fourier_encode(
x: torch.Tensor,
max_freq: float,
num_bands: int = 4,
sine_only: bool = False,
) -> torch.Tensor:
"""
Create Fourier Encoding

Args:
x: Input Torch Tensor
max_freq: Maximum frequency for the Fourier features
num_bands: Number of frequency bands
sine_only: Whether to only use sine or both sine and cosine features

Returns:
Torch Tensor with the fourier position encoded concatenated
"""
x = x.unsqueeze(-1)
device, dtype, orig_x = x.device, x.dtype, x

scales = torch.linspace(
1.0,
max_freq / 2,
num_bands,
device=device,
dtype=dtype,
)
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]

x = x * scales * pi
x = x.sin() if sine_only else torch.cat([x.sin(), x.cos()], dim=-1)
x = torch.cat((x, orig_x), dim=-1)
return x
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
nowcasting_dataset
torch
pytorch-lightning
einops
Loading