From bab475893842fdd7cd92432ca772301cacd6823d Mon Sep 17 00:00:00 2001 From: Wang Boyu Date: Tue, 26 Jul 2022 23:14:38 +0800 Subject: [PATCH] create RasterLayer from file --- mesa_geo/raster_layers.py | 33 ++++++++++++++++++++++++++------- tests/test_RasterLayer.py | 8 ++++---- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/mesa_geo/raster_layers.py b/mesa_geo/raster_layers.py index 618416a1..8b9ae3e6 100644 --- a/mesa_geo/raster_layers.py +++ b/mesa_geo/raster_layers.py @@ -15,6 +15,7 @@ Sequence, Iterator, Iterable, + Type, ) import numpy as np @@ -154,7 +155,7 @@ class RasterLayer(RasterBase): cells: List[List[Cell]] _neighborhood_cache: Dict[Any, List[Coordinate]] - def __init__(self, width, height, crs, total_bounds, cell_cls=Cell): + def __init__(self, width, height, crs, total_bounds, cell_cls: Type[Cell] = Cell): super().__init__(width, height, crs, total_bounds) self.cell_cls = cell_cls self.cells = [] @@ -232,13 +233,13 @@ def coord_iter(self) -> Iterator[Tuple[Cell, int, int]]: for col in range(self.height): yield self.cells[row][col], row, col # cell, x, y - def apply_raster(self, data: np.ndarray, name: str = None) -> None: + def apply_raster(self, data: np.ndarray, attr_name: str = None) -> None: assert data.shape == (1, self.height, self.width) - if name is None: - name = f"attribute_{len(self.cell_cls.__dict__)}" + if attr_name is None: + attr_name = f"attribute_{len(self.cell_cls.__dict__)}" for x in range(self.width): for y in range(self.height): - setattr(self.cells[x][y], name, data[0, self.height - y - 1, x]) + setattr(self.cells[x][y], attr_name, data[0, self.height - y - 1, x]) def iter_neighborhood( self, @@ -396,6 +397,24 @@ def to_image(self, colormap) -> ImageLayer: values[:, row, col] = colormap(cell) return ImageLayer(values=values, crs=self.crs, total_bounds=self.total_bounds) + @classmethod + def from_file( + cls, raster_file, cell_cls: Type[Cell] = Cell, attr_name: str = None + ) -> RasterLayer: + with rio.open(raster_file, "r") as dataset: + values = dataset.read() + _, height, width = values.shape + total_bounds = [ + dataset.bounds.left, + dataset.bounds.bottom, + dataset.bounds.right, + dataset.bounds.top, + ] + obj = cls(width, height, dataset.crs, total_bounds, cell_cls) + obj._transform = dataset.transform + obj.apply_raster(values, attr_name=attr_name) + return obj + class ImageLayer(RasterBase): _values: np.ndarray @@ -458,8 +477,8 @@ def to_crs(self, crs, inplace=False) -> ImageLayer | None: return layer @classmethod - def from_file(cls, raster_file: str) -> ImageLayer: - with rio.open(raster_file, "r") as dataset: + def from_file(cls, image_file) -> ImageLayer: + with rio.open(image_file, "r") as dataset: values = dataset.read() total_bounds = [ dataset.bounds.left, diff --git a/tests/test_RasterLayer.py b/tests/test_RasterLayer.py index 6a33650a..4e8c116a 100644 --- a/tests/test_RasterLayer.py +++ b/tests/test_RasterLayer.py @@ -38,12 +38,12 @@ def test_apple_raster(self): """ self.assertEqual(self.raster_layer.cells[0][1].attribute_5, 3) - self.raster_layer.apply_raster(raster_data, name="elevation") + self.raster_layer.apply_raster(raster_data, attr_name="elevation") self.assertEqual(self.raster_layer.cells[0][1].elevation, 3) def test_get_min_cell(self): self.raster_layer.apply_raster( - np.array([[[1, 2], [3, 4], [5, 6]]]), name="elevation" + np.array([[[1, 2], [3, 4], [5, 6]]]), attr_name="elevation" ) min_cell = min( @@ -63,7 +63,7 @@ def test_get_min_cell(self): self.assertEqual(min_cell.elevation, 1) self.raster_layer.apply_raster( - np.array([[[1, 2], [3, 4], [5, 6]]]), name="water_level" + np.array([[[1, 2], [3, 4], [5, 6]]]), attr_name="water_level" ) min_cell = min( self.raster_layer.get_neighboring_cells( @@ -77,7 +77,7 @@ def test_get_min_cell(self): def test_get_max_cell(self): self.raster_layer.apply_raster( - np.array([[[1, 2], [3, 4], [5, 6]]]), name="elevation" + np.array([[[1, 2], [3, 4], [5, 6]]]), attr_name="elevation" ) max_cell = max(