Skip to content

Commit

Permalink
HySpecNet-11k: add additional metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Feb 9, 2025
1 parent b9404d1 commit b03aabe
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 44 deletions.
6 changes: 3 additions & 3 deletions tests/data/hyspecnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

np.random.seed(0)

# Tile name purposefully shortened to avoid Windows git filename length limit.
tiles = ['ENMAP01_20221103T162438Z']
tiles = ['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z']
patches = ['Y01460273_X05670694', 'Y01460273_X06950822']

profile = {
Expand All @@ -41,7 +40,8 @@
for tile in tiles:
for patch in patches:
# Split CSV
path = os.path.join(tile, f'{tile}-{patch}', f'{tile}-{patch}-DATA.npy')
# Directory purposefully shortened to avoid Windows git filename length limit.
path = os.path.join(f'{tile}-{patch}-DATA.npy')
for split in ['train', 'val', 'test']:
with open(os.path.join(root, 'splits', 'easy', f'{split}.csv'), 'a+') as f:
f.write(f'{path}\n')
Expand Down
Binary file modified tests/data/hyspecnet/hyspecnet-11k-01.tar.gz
Binary file not shown.
Binary file modified tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X05670694-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X06950822-DATA.npy
4 changes: 2 additions & 2 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X05670694-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X06950822-DATA.npy
4 changes: 2 additions & 2 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X05670694-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X06950822-DATA.npy
2 changes: 1 addition & 1 deletion tests/datasets/test_hyspecnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_plot(self, dataset: HySpecNet11k) -> None:
plt.close()

def test_plot_rgb(self, dataset: HySpecNet11k) -> None:
dataset = HySpecNet11k(root=dataset.root, bands=(1, 2, 3))
dataset = HySpecNet11k(root=dataset.root, bands=('B1', 'B2', 'B3'))
match = 'Dataset does not contain some of the RGB bands'
with pytest.raises(RGBBandsMissingError, match=match):
dataset.plot(dataset[0])
67 changes: 33 additions & 34 deletions torchgeo/datasets/hyspecnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""HySpecNet dataset."""

import os
import re
from collections.abc import Callable, Sequence
from typing import ClassVar

Expand All @@ -14,36 +15,16 @@
from matplotlib.figure import Figure
from torch import Tensor

from .enmap import EnMAP
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import Path, download_url, extract_archive, percentile_normalization

# https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb
invalid_channels = [
126,
127,
128,
129,
130,
131,
132,
133,
134,
135,
136,
137,
138,
139,
140,
160,
161,
162,
163,
164,
165,
166,
]
valid_channels_ids = [c + 1 for c in range(224) if c not in invalid_channels]
from .utils import (
Path,
disambiguate_timestamp,
download_url,
extract_archive,
percentile_normalization,
)


class HySpecNet11k(NonGeoDataset):
Expand Down Expand Up @@ -99,15 +80,16 @@ class HySpecNet11k(NonGeoDataset):
'hyspecnet-11k-splits.tar.gz': '94fad9e3c979c612c29a045406247d6c',
}

all_bands = valid_channels_ids
rgb_bands = (43, 28, 10)
all_bands = EnMAP.all_bands
default_bands = EnMAP.default_bands
rgb_bands = EnMAP.rgb_bands

def __init__(
self,
root: Path = 'data',
split: str = 'train',
strategy: str = 'easy',
bands: Sequence[int] = all_bands,
bands: Sequence[str] | None = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
Expand All @@ -131,11 +113,14 @@ def __init__(
self.root = root
self.split = split
self.strategy = strategy
self.bands = bands
self.bands = bands or self.default_bands
self.transforms = transforms
self.download = download
self.checksum = checksum

self.wavelengths = torch.tensor([EnMAP.wavelengths[b] for b in self.bands])
self.band_indices = [self.all_bands.index(b) + 1 for b in self.bands]

self._verify()

path = os.path.join(root, 'hyspecnet-11k', 'splits', strategy, f'{split}.csv')
Expand All @@ -159,9 +144,23 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
Returns:
Data and label at that index.
"""
file = self.files[index].replace('DATA.npy', 'SPECTRAL_IMAGE.TIF')
path = self.files[index].replace('DATA.npy', 'SPECTRAL_IMAGE.TIF')
file = os.path.basename(path)
match = re.match(EnMAP.filename_regex, file, re.VERBOSE)
assert match
mint, maxt = disambiguate_timestamp(match.group('date'), EnMAP.date_format)

with rio.open(os.path.join(self.root, 'hyspecnet-11k', 'patches', file)) as src:
sample = {'image': torch.tensor(src.read(self.bands).astype('float32'))}
minx, maxx = src.bounds.left, src.bounds.right
miny, maxy = src.bounds.bottom, src.bounds.top
sample = {
'image': torch.tensor(src.read(self.band_indices).astype('float32')),
'x': torch.tensor((minx + maxx) / 2),
'y': torch.tensor((miny + maxy) / 2),
't': torch.tensor((mint + maxt) / 2),
'wavelength': self.wavelengths,
'res': torch.tensor(30),
}

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down

0 comments on commit b03aabe

Please sign in to comment.