From 0370d5463488764123953e97ab59e2484552e1c6 Mon Sep 17 00:00:00 2001 From: Keith Bennett Date: Mon, 9 Jun 2025 14:00:32 +0100 Subject: [PATCH] Fix/improve ray filtering flags in plot_rays() --- sdf_helper/sdf_helper.py | 96 ++++++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 44 deletions(-) diff --git a/sdf_helper/sdf_helper.py b/sdf_helper/sdf_helper.py index 7311f7b..4ea2c2d 100644 --- a/sdf_helper/sdf_helper.py +++ b/sdf_helper/sdf_helper.py @@ -1262,19 +1262,31 @@ def plot_rays(var, skip=1, rays=None, **kwargs): ---------- var : sdf.Block The SDF variable for the rays to plot + rays : list + A list of ray numbers to plot + ray_start : integer + The first ray number to plot + ray_stop : integer + The last ray number to plot skip : integer Number of rays to skip before selecting the next one to plot """ - ray_start = -1 - l = "ray_start" - if l in kwargs: - ray_start = kwargs[l] + if rays: + if isinstance(rays, int): + rays = [rays] + else: + ray_start, ray_stop = None, None + + l = "ray_start" + if l in kwargs: + ray_start = kwargs[l] + + l = "ray_stop" + if l in kwargs: + ray_stop = kwargs[l] - ray_stop = 1e9 - l = "ray_stop" - if l in kwargs: - ray_stop = kwargs[l] + ray_slice = slice(ray_start, ray_stop, skip) if isinstance(var, sdf.BlockStitchedPath): v = var.data[0] @@ -1303,27 +1315,26 @@ def plot_rays(var, skip=1, rays=None, **kwargs): v = var.data[0] vmin = v.data.min() vmax = v.data.max() - for iray, v in enumerate(var.data): - if iray < ray_start: - continue - if iray > ray_stop: - break - if iray % skip == 0: - vmin = min(vmin, v.data.min()) - vmax = max(vmax, v.data.max()) + if rays: + ray_list = [var.data[i] for i in rays] + else: + ray_list = var.data[ray_slice] + for v in ray_list: + vmin = min(vmin, v.data.min()) + vmax = max(vmax, v.data.max()) if k0 not in kwargs: kwargs[k0] = vmin if k1 not in kwargs: kwargs[k1] = vmax - for iray, v in enumerate(var.data): - if iray < ray_start: - continue - if iray > ray_stop: - break - if iray % skip == 0: - plot_auto(v, update=False, **kwargs) - kwargs["hold"] = True + if rays: + ray_list = [var.data[i] for i in rays] + else: + ray_list = var.data[ray_slice] + + for v in ray_list: + plot_auto(v, update=False, **kwargs) + kwargs["hold"] = True plot_auto(var.data[0], axis_only=True, **kwargs) kwargs["hold"] = True @@ -1341,17 +1352,16 @@ def plot_rays(var, skip=1, rays=None, **kwargs): if k not in kwargs and not (k0 in kwargs and k1 in kwargs): vmin = var.data.min() vmax = var.data.max() - iray = -1 - for k in data.keys(): + + if rays: + ray_list = [data[i] for i in rays] + else: + ray_list = data[ray_slice] + + for k in ray_list.keys(): if k.startswith(start) and k.endswith(end): - iray += 1 - if iray < ray_start: - continue - if iray > ray_stop: - break - if iray % skip == 0: - vmin = min(vmin, data[k].data.min()) - vmax = max(vmax, data[k].data.max()) + vmin = min(vmin, data[k].data.min()) + vmax = max(vmax, data[k].data.max()) if k0 not in kwargs: kwargs[k0] = vmin if k1 not in kwargs: @@ -1367,17 +1377,15 @@ def plot_rays(var, skip=1, rays=None, **kwargs): + ")$" ) - iray = -1 - for k in data.keys(): + if rays: + ray_list = [data[i] for i in rays] + else: + ray_list = data[ray_slice] + + for k in ray_list.keys(): if k.startswith(start) and k.endswith(end): - iray += 1 - if iray < ray_start: - continue - if iray > ray_stop: - break - if iray % skip == 0: - plot_auto(data[k], hold=True, update=False, **kwargs) - kwargs["hold"] = True + plot_auto(data[k], hold=True, update=False, **kwargs) + kwargs["hold"] = True plot_auto(var, axis_only=True, **kwargs) kwargs["hold"] = True