diff --git a/src/topotoolbox/flow_object.py b/src/topotoolbox/flow_object.py index 23538e7..83e1642 100644 --- a/src/topotoolbox/flow_object.py +++ b/src/topotoolbox/flow_object.py @@ -143,58 +143,18 @@ def _d8_carve(self, process, which could lead to issues when using very large DEMs. """ dims = grid.dims - dem = np.asarray(grid, dtype=np.float32) - filled_dem = np.zeros_like(dem, dtype=np.float32) - restore_nans = False - if bc is None: - bc = np.ones_like(dem, dtype=np.uint8) - bc[1:-1, 1:-1] = 0 # Set interior pixels to 0 + (aux, filled_dem, flats) = grid.auxiliary_topography(bc, hybrid) - nans = np.isnan(dem) - dem[nans] = -np.inf - bc[nans] = 1 - restore_nans = True - - if not validate_alignment(grid, bc): - err = ("The shape of the provided boundary conditions does not " - f"match the shape of the DEM. {dims}") - raise ValueError(err)from None - - bc = np.asarray(bc, dtype=np.uint8) - - queue = np.zeros(np.prod(dem.shape), dtype=np.int64) - if hybrid: - _grid.fillsinks_hybrid(filled_dem, queue, dem, bc, dims) - else: - _grid.fillsinks(filled_dem, dem, bc, dims) - - if restore_nans: - dem[nans] = np.nan - filled_dem[nans] = np.nan - - flats = np.zeros_like(dem, dtype=np.int32) - _grid.identifyflats(flats, filled_dem, dims) - - costs = np.zeros_like(dem, dtype=np.float32) - conncomps = np.zeros_like(dem, dtype=np.int64) - _grid.gwdt_computecosts(costs, conncomps, flats, dem, filled_dem, dims) - - dist = np.zeros_like(flats, dtype=np.float32) - prev = conncomps # prev: dtype=np.int64 - heap = queue # heap: dtype=np.int64 - back = np.zeros_like(flats, dtype=np.int64) - _grid.gwdt(dist, prev, costs, flats, heap, back, dims) - - node = heap # node: dtype=np.int64 - direction = np.zeros_like(dem, dtype=np.uint8) + node = np.zeros_like(aux, dtype=np.int64) # node: dtype=np.int64 + direction = np.zeros_like(aux, dtype=np.uint8) _grid.flow_routing_d8_carve( - node, direction, filled_dem, dist, flats, dims) + node, direction, filled_dem, aux, flats, dims) # ravel is used here to flatten the arrays. The memory order should not matter # because we only need a block of contiguous memory interpreted as a 1D array. - source = np.ravel(conncomps) # source: dtype=int64 - target = np.ravel(back) # target: dtype=int64 + source = np.zeros(aux.size, dtype=np.int64) # source: dtype=int64 + target = np.zeros(aux.size, dtype=np.int64) # target: dtype=int64 edge_count = _grid.flow_routing_d8_edgelist( source, target, node, direction, dims) diff --git a/src/topotoolbox/grid_object.py b/src/topotoolbox/grid_object.py index cc22eb9..270641c 100755 --- a/src/topotoolbox/grid_object.py +++ b/src/topotoolbox/grid_object.py @@ -311,6 +311,58 @@ def fillsinks(self, return result + def auxiliary_topography(self, bc=None, hybrid=True): + """Compute auxiliary topography + + This is used for least-cost auxiliary topography flow routing. + + Returns + ------- + (aux: np.ndarray, filled_dem: np.ndarray, flats: np.ndarray) + The auxiliary topography, sink-filled DEM and an array + identifying flats in the sink-filled DEM. + """ + dims = self.dims + dem = np.asarray(self, dtype=np.float32) + filled_dem = np.zeros_like(dem) + + restore_nans = False + if bc is None: + bc = np.ones_like(dem, dtype=np.uint8) + bc[1:-1, 1:-1] = 0 # Set interior pixels to 0 + + nans = np.isnan(dem) + dem[nans] = -np.inf + bc[nans] = 1 + restore_nans = True + + bc = np.asarray(bc, dtype=np.uint8) + + queue = np.zeros(np.prod(dem.shape), dtype=np.int64) + if hybrid: + _grid.fillsinks_hybrid(filled_dem, queue, dem, bc, dims) + else: + _grid.fillsinks(filled_dem, dem, bc, dims) + + if restore_nans: + dem[nans] = np.nan + filled_dem[nans] = np.nan + + flats = np.zeros_like(dem, dtype=np.int32) + _grid.identifyflats(flats, filled_dem, dims) + + costs = np.zeros_like(dem, dtype=np.float32) + conncomps = np.zeros_like(dem, dtype=np.int64) + _grid.gwdt_computecosts(costs, conncomps, flats, dem, filled_dem, dims) + + aux = np.zeros_like(flats, dtype=np.float32) + prev = conncomps # prev: dtype=np.int64 + heap = queue # heap: dtype=np.int64 + back = np.zeros_like(flats, dtype=np.int64) + _grid.gwdt(aux, prev, costs, flats, heap, back, dims) + + return (aux, filled_dem, flats) + def identifyflats( self, raw: bool = False, output: list[str] | None = None) -> tuple: """Identifies flats and sills in a digital elevation model (DEM).