Skip to content
This repository was archived by the owner on Sep 28, 2023. It is now read-only.

Commit 03b5cb2

Browse files
committed
Add stub version of encoding for Batch
1 parent e322488 commit 03b5cb2

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

nowcasting_dataloader/utils/position_encoding.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,96 @@
1313
from math import pi
1414
from typing import Union, Optional, Dict, List, Tuple, Any
1515
import datetime
16+
from nowcasting_dataset.dataset.batch import Batch
1617

1718
TIME_DIM = 2
1819
HEIGHT_DIM = 3
1920
WIDTH_DIM = 4
2021

22+
SEVIRI_RSS_BOUNDS = {"x_min": 0, "y_min": 0, "x_max": 1, "y_max": 1}
23+
24+
25+
def generate_position_encodings_for_batch(batch: Batch, **kwargs) -> dict[str, torch.Tensor]:
26+
"""
27+
Generates positional encodings and returns them as a dictionary
28+
29+
This is not returned with the Batch, as that would require more keys, etc. in Batch
30+
31+
Args:
32+
batch: Batch object holding the data
33+
34+
Returns:
35+
Dictionary containing the keys of the modalities in the Batch + '_position_encoding'
36+
"""
37+
38+
assert batch.datetime is not None, "Datetime must be set for position encoding to work"
39+
40+
position_encodings = {}
41+
# Go for each modality where a position encoding makes sense
42+
if batch.satellite is not None:
43+
position_encodings[batch.satellite.key + "_position_encoding"] = encode_absolute_position(
44+
[
45+
batch.batch_size,
46+
len(batch.satellite.time),
47+
len(batch.satellite.x),
48+
len(batch.satellite.y),
49+
], # We want to remove the channel dimension, as that's not relevant here
50+
geospatial_coordinates=[batch.satellite.x, batch.satellite.y],
51+
datetimes=batch.satellite.time,
52+
geospatial_bounds=SEVIRI_RSS_BOUNDS,
53+
**kwargs,
54+
)
55+
56+
if batch.nwp is not None:
57+
position_encodings[batch.satellite.key + "_position_encoding"] = encode_absolute_position(
58+
[
59+
batch.batch_size,
60+
len(batch.satellite.time),
61+
len(batch.satellite.x),
62+
len(batch.satellite.y),
63+
], # We want to remove the channel dimension, as that's not relevant here
64+
geospatial_coordinates=[batch.satellite.x, batch.satellite.y],
65+
datetimes=batch.satellite.time,
66+
geospatial_bounds=SEVIRI_RSS_BOUNDS,
67+
**kwargs,
68+
)
69+
70+
if batch.gsp is not None:
71+
position_encodings[batch.satellite.key + "_position_encoding"] = encode_absolute_position(
72+
[
73+
batch.batch_size,
74+
len(batch.gsp.time),
75+
len(batch.gsp.x),
76+
len(batch.gsp.y),
77+
], # We want to remove the channel dimension, as that's not relevant here
78+
geospatial_coordinates=[batch.gsp.x, batch.gsp.y],
79+
datetimes=batch.gsp.time,
80+
geospatial_bounds=SEVIRI_RSS_BOUNDS,
81+
**kwargs,
82+
)
83+
84+
if batch.pv is not None:
85+
position_encodings[batch.satellite.key + "_position_encoding"] = encode_absolute_position(
86+
[
87+
batch.batch_size,
88+
len(batch.satellite.time),
89+
len(batch.satellite.x),
90+
len(batch.satellite.y),
91+
], # We want to remove the channel dimension, as that's not relevant here
92+
geospatial_coordinates=[batch.satellite.x, batch.satellite.y],
93+
datetimes=batch.satellite.time,
94+
geospatial_bounds=SEVIRI_RSS_BOUNDS,
95+
**kwargs,
96+
)
97+
98+
if batch.sun is not None:
99+
pass
100+
101+
if batch.topographic is not None:
102+
pass
103+
104+
return NotImplementedError
105+
21106

22107
def encode_modalities(
23108
modalities_to_encode: Dict[str, torch.Tensor],

0 commit comments

Comments
 (0)