Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HySpecNet-11k: add additional metadata #2569

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/data/hyspecnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

np.random.seed(0)

# Tile name purposefully shortened to avoid Windows git filename length limit.
tiles = ['ENMAP01_20221103T162438Z']
tiles = ['ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z']
patches = ['Y01460273_X05670694', 'Y01460273_X06950822']

profile = {
Expand All @@ -41,7 +40,8 @@
for tile in tiles:
for patch in patches:
# Split CSV
path = os.path.join(tile, f'{tile}-{patch}', f'{tile}-{patch}-DATA.npy')
# Directory purposefully shortened to avoid Windows git filename length limit.
path = os.path.join(f'{tile}-{patch}-DATA.npy')
for split in ['train', 'val', 'test']:
with open(os.path.join(root, 'splits', 'easy', f'{split}.csv'), 'a+') as f:
f.write(f'{path}\n')
Expand Down
Binary file modified tests/data/hyspecnet/hyspecnet-11k-01.tar.gz
Binary file not shown.
Binary file modified tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X05670694-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X06950822-DATA.npy
4 changes: 2 additions & 2 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X05670694-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X06950822-DATA.npy
4 changes: 2 additions & 2 deletions tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy
ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X05670694-DATA.npy
ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X06950822-DATA.npy
2 changes: 1 addition & 1 deletion tests/datasets/test_hyspecnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_plot(self, dataset: HySpecNet11k) -> None:
plt.close()

def test_plot_rgb(self, dataset: HySpecNet11k) -> None:
dataset = HySpecNet11k(root=dataset.root, bands=(1, 2, 3))
dataset = HySpecNet11k(root=dataset.root, bands=('B1', 'B2', 'B3'))
match = 'Dataset does not contain some of the RGB bands'
with pytest.raises(RGBBandsMissingError, match=match):
dataset.plot(dataset[0])
67 changes: 33 additions & 34 deletions torchgeo/datasets/hyspecnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""HySpecNet dataset."""

import os
import re
from collections.abc import Callable, Sequence
from typing import ClassVar

Expand All @@ -14,36 +15,16 @@
from matplotlib.figure import Figure
from torch import Tensor

from .enmap import EnMAP
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import Path, download_url, extract_archive, percentile_normalization

# https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb
invalid_channels = [
126,
127,
128,
129,
130,
131,
132,
133,
134,
135,
136,
137,
138,
139,
140,
160,
161,
162,
163,
164,
165,
166,
]
valid_channels_ids = [c + 1 for c in range(224) if c not in invalid_channels]
from .utils import (
Path,
disambiguate_timestamp,
download_url,
extract_archive,
percentile_normalization,
)


class HySpecNet11k(NonGeoDataset):
Expand Down Expand Up @@ -99,15 +80,16 @@ class HySpecNet11k(NonGeoDataset):
'hyspecnet-11k-splits.tar.gz': '94fad9e3c979c612c29a045406247d6c',
}

all_bands = valid_channels_ids
rgb_bands = (43, 28, 10)
all_bands = EnMAP.all_bands
default_bands = EnMAP.default_bands
rgb_bands = EnMAP.rgb_bands

def __init__(
self,
root: Path = 'data',
split: str = 'train',
strategy: str = 'easy',
bands: Sequence[int] = all_bands,
bands: Sequence[str] | None = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
Expand All @@ -131,11 +113,14 @@ def __init__(
self.root = root
self.split = split
self.strategy = strategy
self.bands = bands
self.bands = bands or self.default_bands
self.transforms = transforms
self.download = download
self.checksum = checksum

self.wavelengths = torch.tensor([EnMAP.wavelengths[b] for b in self.bands])
self.band_indices = [self.all_bands.index(b) + 1 for b in self.bands]

self._verify()

path = os.path.join(root, 'hyspecnet-11k', 'splits', strategy, f'{split}.csv')
Expand All @@ -159,9 +144,23 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
Returns:
Data and label at that index.
"""
file = self.files[index].replace('DATA.npy', 'SPECTRAL_IMAGE.TIF')
path = self.files[index].replace('DATA.npy', 'SPECTRAL_IMAGE.TIF')
file = os.path.basename(path)
match = re.match(EnMAP.filename_regex, file, re.VERBOSE)
assert match
mint, maxt = disambiguate_timestamp(match.group('date'), EnMAP.date_format)

with rio.open(os.path.join(self.root, 'hyspecnet-11k', 'patches', file)) as src:
sample = {'image': torch.tensor(src.read(self.bands).astype('float32'))}
minx, maxx = src.bounds.left, src.bounds.right
miny, maxy = src.bounds.bottom, src.bounds.top
sample = {
'image': torch.tensor(src.read(self.band_indices).astype('float32')),
'x': torch.tensor((minx + maxx) / 2),
'y': torch.tensor((miny + maxy) / 2),
't': torch.tensor((mint + maxt) / 2),
'wavelength': self.wavelengths,
'res': torch.tensor(30),
}

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down