This repository was archived by the owner on Nov 27, 2023. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 0
Add Positional Encoders #33
Closed
Closed
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
841f718
Copy over Perceiver encoding
jacobbieker c36f34a
Add top-level docstring
jacobbieker 5242c9f
Add stub tests
jacobbieker 6ca30d2
Fix docstring
jacobbieker 62d7114
Add more to docstring
jacobbieker a3562dc
Add assert
jacobbieker a1e3590
Add more into absolute positioning
jacobbieker 65722bd
Add encoding geospatial coordinates
jacobbieker 14af5d8
Fill out absolute position encoding
jacobbieker File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,265 @@ | ||
""" | ||
This file contains various ways of performing positional encoding. | ||
|
||
These encodings can be: | ||
- Relative positioning (i.e. this pixel is this far from the top left, and this many timesteps in the future) | ||
- Absolute positioning (i.e. this pixel is at this latitude/longitude, and is at 16:00) | ||
|
||
These encodings can also be performed with: | ||
- Fourier Features, based off what is done in PerceiverIO | ||
- Coordinates, based off the idea from Coordinate Convolutions | ||
""" | ||
import numpy as np | ||
import torch | ||
import einops | ||
from math import pi | ||
from typing import Union, Optional, Dict, List, Tuple | ||
import datetime | ||
|
||
|
||
def encode_position( | ||
shape: List[..., int], | ||
geospatial_coordinates: Optional[Tuple[List[int, ...], List[int, ...]]], | ||
datetimes: Optional[List[datetime.datetime]], | ||
method: str, | ||
positioning: str, | ||
geospatial_bounds: Optional[List[int, int, int, int]], | ||
**kwargs, | ||
) -> torch.Tensor: | ||
""" | ||
This function wraps a variety of different methods for generating position features for given inputs. | ||
|
||
Args: | ||
shape: The shape of the input to be encoded, should be the largest or finest-grained input | ||
For example, if the inputs are shapes (12, 6, 128, 128) and (1, 6), (12, 6, 128, 128) should be passed in as | ||
shape, as it has the most elements and the input (1, 6) can just subselect the position encoding | ||
geospatial_coordinates: The latitude/longitude of the inputs for shape, in OSGB coordinates, unused if using relative positioning only | ||
datetimes: time of day and date for each of the timesteps in the shape, unused if using relative positioning only | ||
method: Method of the encoding, either 'fourier' for Fourier Features | ||
positioning: The type of positioning used, either 'relative' for relative positioning, or 'absolute', or 'both' | ||
geospatial_bounds: The bounds of the geospatial area covered, in x_min, y_min, x_max, y_max ordering, only used for absolute coordinates | ||
|
||
Returns: | ||
The position encodings for all items in the batch | ||
""" | ||
assert method in [ | ||
"fourier", | ||
], ValueError(f"method must be one of 'fourier', not {method}") | ||
assert positioning in ["relative", "absolute", "both"], ValueError( | ||
f"positioning must be one of 'relative', 'absolute' or 'both', not {positioning}" | ||
) | ||
|
||
if positioning == "relative": | ||
position_encoding = encode_relative_position(shape, **kwargs) | ||
elif positioning == "absolute": | ||
position_encoding = encode_absolute_position( | ||
shape, geospatial_coordinates, geospatial_bounds, datetimes | ||
) | ||
else: | ||
# Both position encodings | ||
position_encoding = torch.cat( | ||
[ | ||
encode_relative_position(shape), | ||
encode_absolute_position( | ||
shape, geospatial_coordinates, geospatial_bounds, datetimes | ||
), | ||
], | ||
dim=-1, | ||
) | ||
return position_encoding | ||
|
||
|
||
def encode_relative_position(shape: List[..., int], **kwargs) -> torch.Tensor: | ||
""" | ||
Encode the relative position of the pixels/voxels | ||
|
||
Args: | ||
shape: | ||
|
||
Returns: | ||
The relative position encoding as a torch Tensor | ||
|
||
""" | ||
position_encoding = encode_fouier_position(1, shape, **kwargs) | ||
return position_encoding | ||
|
||
|
||
def encode_absolute_position( | ||
shape: List[..., int], geospatial_coordinates, geospatial_bounds, datetimes, **kwargs | ||
) -> torch.Tensor: | ||
""" | ||
Encodes the absolute position of the pixels/voxels in time and space | ||
|
||
Args: | ||
shape: Shape to encode positions for | ||
geospatial_coordinates: The geospatial coordinates, in OSGB format | ||
datetimes: Time of day and date as a list of datetimes, one for each timestep | ||
**kwargs: | ||
|
||
Returns: | ||
The absolute position encoding for the given shape | ||
""" | ||
hour_of_day_sin, hour_of_day_cos, day_of_year_sin, day_of_year_cos = create_datetime_features( | ||
datetimes | ||
) | ||
|
||
# Fourier Features of absolute position | ||
encoded_latlon = normalize_geospatial_coordinates( | ||
geospatial_coordinates, geospatial_bounds, **kwargs | ||
) | ||
|
||
# Combine time and space features | ||
# Time features should be in shape [Channels,Timestep] | ||
# Space features should be in [Channels, Height, Width] | ||
# So can just concat along channels, after expanding time features tto Height, Width, and Space along Time | ||
hour_of_day_sin = einops.repeat(hour_of_day_sin, "b c t -> b c t h w", h=shape[-2], w=shape[-1]) | ||
hour_of_day_cos = einops.repeat(hour_of_day_cos, "b c t -> b c t h w", h=shape[-2], w=shape[-1]) | ||
day_of_year_sin = einops.repeat(day_of_year_sin, "b c t -> b c t h w", h=shape[-2], w=shape[-1]) | ||
day_of_year_cos = einops.repeat(day_of_year_cos, "b c t -> b c t h w", h=shape[-2], w=shape[-1]) | ||
# Now do for latlon encoding | ||
encoded_latlon = einops.repeat(encoded_latlon, "b c h w -> b c t h w", t=shape[1]) | ||
|
||
# Now combined into one large encoding | ||
absolute_position_encoding = torch.cat( | ||
[encoded_latlon, hour_of_day_sin, hour_of_day_cos, day_of_year_sin, day_of_year_cos], dim=1 | ||
) | ||
|
||
return absolute_position_encoding | ||
|
||
|
||
def normalize_geospatial_coordinates( | ||
geospatial_coordinates, geospatial_bounds, **kwargs | ||
) -> torch.Tensor: | ||
""" | ||
Normalize the geospatial coordinates by the max extant to keep everything between -1 and 1, in sin and cos | ||
|
||
Args: | ||
geospatial_coordinates: The coordinates for the pixels in the image | ||
geospatial_bounds: The maximum extant | ||
|
||
Returns: | ||
The normalized geospatial coordinates, rescaled to between -1 and 1 | ||
|
||
""" | ||
# Normalize the X first | ||
geospatial_coordinates[0] = (geospatial_coordinates[0] - geospatial_bounds[0]) / ( | ||
geospatial_bounds[2] - geospatial_bounds[0] | ||
) | ||
# Normalize the Y second | ||
geospatial_coordinates[1] = (geospatial_coordinates[1] - geospatial_bounds[1]) / ( | ||
geospatial_bounds[3] - geospatial_bounds[1] | ||
) | ||
|
||
# Now those are between 0 and 1, want between -1 and 1 | ||
geospatial_coordinates[0] = geospatial_coordinates[0] * 2 - 1 | ||
geospatial_coordinates[1] = geospatial_coordinates[1] * 2 - 1 | ||
|
||
# Now create a grid of the coordinates | ||
pos = torch.stack(torch.meshgrid(*geospatial_coordinates), dim=-1) | ||
|
||
# And now convert to Fourier features, based off the absolute positions of the coordinates | ||
encoded_position = fourier_encode(pos, **kwargs) | ||
return encoded_position | ||
|
||
|
||
def create_datetime_features( | ||
datetimes: List[datetime.datetime], | ||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||
""" | ||
Converts a list of datetimes to day of year, hour of day sin and cos representation | ||
|
||
Args: | ||
datetimes: List of datetimes | ||
|
||
Returns: | ||
Tuple of torch Tensors containing the hour of day sin,cos, and day of year sin,cos | ||
""" | ||
hour_of_day = [] | ||
day_of_year = [] | ||
for index in datetimes: | ||
hour_of_day.append((index.hour + (index.minute / 60) / 24)) | ||
day_of_year.append((index.timetuple().tm_yday / 365)) # Get the day of the year | ||
hour_of_day = torch.as_tensor(hour_of_day) | ||
day_of_year = torch.as_tensor(day_of_year) | ||
hour_radians = hour_of_day * 2 * np.pi | ||
day_radians = day_of_year * 2 * np.pi | ||
hour_of_day_sin = torch.sin(hour_radians) | ||
hour_of_day_cos = torch.cos(hour_radians) | ||
day_of_year_sin = torch.sin(day_radians) | ||
day_of_year_cos = torch.cos(day_radians) | ||
|
||
return hour_of_day_sin, hour_of_day_cos, day_of_year_sin, day_of_year_cos | ||
|
||
|
||
def encode_fouier_position( | ||
batch_size: int, | ||
axis: list, | ||
max_frequency: float, | ||
num_frequency_bands: int, | ||
sine_only: bool = False, | ||
) -> torch.Tensor: | ||
""" | ||
Encode the Fourier Features and return them | ||
|
||
Args: | ||
batch_size: Batch size | ||
axis: List containing the size of each axis | ||
max_frequency: Max frequency | ||
num_frequency_bands: Number of frequency bands to use | ||
sine_only: (bool) Whether to only use Sine features or both Sine and Cosine, defaults to both | ||
|
||
Returns: | ||
Torch tensor containing the Fourier Features of shape [Batch, *axis] | ||
""" | ||
axis_pos = list( | ||
map( | ||
lambda size: torch.linspace(-1.0, 1.0, steps=size), | ||
axis, | ||
) | ||
) | ||
pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1) | ||
enc_pos = fourier_encode( | ||
pos, | ||
max_frequency, | ||
num_frequency_bands, | ||
sine_only=sine_only, | ||
) | ||
enc_pos = einops.rearrange(enc_pos, "... n d -> ... (n d)") | ||
enc_pos = einops.repeat(enc_pos, "... -> b ...", b=batch_size) | ||
return enc_pos | ||
|
||
|
||
def fourier_encode( | ||
x: torch.Tensor, | ||
max_freq: float, | ||
num_bands: int = 4, | ||
sine_only: bool = False, | ||
) -> torch.Tensor: | ||
""" | ||
Create Fourier Encoding | ||
|
||
Args: | ||
x: Input Torch Tensor | ||
max_freq: Maximum frequency for the Fourier features | ||
num_bands: Number of frequency bands | ||
sine_only: Whether to only use sine or both sine and cosine features | ||
|
||
Returns: | ||
Torch Tensor with the fourier position encoded concatenated | ||
""" | ||
x = x.unsqueeze(-1) | ||
device, dtype, orig_x = x.device, x.dtype, x | ||
|
||
scales = torch.linspace( | ||
1.0, | ||
max_freq / 2, | ||
num_bands, | ||
device=device, | ||
dtype=dtype, | ||
) | ||
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] | ||
|
||
x = x * scales * pi | ||
x = x.sin() if sine_only else torch.cat([x.sin(), x.cos()], dim=-1) | ||
x = torch.cat((x, orig_x), dim=-1) | ||
return x |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from nowcasting_utils.models.position_encoding import encode_position | ||
import pytest | ||
|
||
def test_fourier_encoding(): | ||
pass | ||
|
||
def test_coordinate_encoding(): | ||
pass | ||
|
||
def test_multi_modality_encoding(): | ||
pass | ||
|
||
def test_5min_30min_encoding(): | ||
pass | ||
|
||
def test_satellite_pv_encoding(): | ||
pass |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
todo?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this PR is very much not done!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wanted to get more thoughts on the design before I actually finished this, incase we want to move it elsewhere, I can make it more simplified, etc.