diff --git a/docs/tutorials/visualization_tutorial.ipynb b/docs/tutorials/visualization_tutorial.ipynb index 14d9ba4d1ea..ba8d7a922d2 100644 --- a/docs/tutorials/visualization_tutorial.ipynb +++ b/docs/tutorials/visualization_tutorial.ipynb @@ -142,7 +142,8 @@ "source": [ "#### Changing the agents\n", "\n", - "In the visualization above, all we could see is the agents moving around -- but not how much money they had, or anything else of interest. Let's change it so that agents who are broke (wealth 0) are drawn in red, smaller. (TODO: currently, we can't predict the drawing order of the circles, so a broke agent may be overshadowed by a wealthy agent. We should fix this by doing a hollow circle instead)\n", + "In the visualization above, all we could see is the agents moving around -- but not how much money they had, or anything else of interest. Let's change it so that agents who are broke (wealth 0) are drawn in red, smaller. (TODO: Currently, we can't predict the drawing order of the circles, so a broke agent may be overshadowed by a wealthy agent. We should fix this by doing a hollow circle instead)\n", + "In addition to size and color, an agent's shape can also be customized when using the default drawer. The allowed values for shapes can be found [here](https://matplotlib.org/stable/api/markers_api.html).\n", "\n", "To do this, we go back to our `agent_portrayal` code and add some code to change the portrayal based on the agent properties and launch the server again." ] diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index aadfa206472..83d0e3d8eaf 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,3 +1,5 @@ +from collections import defaultdict + import networkx as nx import solara from matplotlib.figure import Figure @@ -23,12 +25,44 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) +# matplotlib scatter does not allow for multiple shapes in one call +def _split_and_scatter(portray_data, space_ax): + grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []}) + + # Extract data from the dictionary + x = portray_data["x"] + y = portray_data["y"] + s = portray_data["s"] + c = portray_data["c"] + m = portray_data["m"] + + if not (len(x) == len(y) == len(s) == len(c) == len(m)): + raise ValueError( + "Length mismatch in portrayal data lists: " + f"x: {len(x)}, y: {len(y)}, size: {len(s)}, " + f"color: {len(c)}, marker: {len(m)}" + ) + + # Group the data by marker + for i in range(len(x)): + marker = m[i] + grouped_data[marker]["x"].append(x[i]) + grouped_data[marker]["y"].append(y[i]) + grouped_data[marker]["s"].append(s[i]) + grouped_data[marker]["c"].append(c[i]) + + # Plot each group with the same marker + for marker, data in grouped_data.items(): + space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker) + + def _draw_grid(space, space_ax, agent_portrayal): def portray(g): x = [] y = [] s = [] # size c = [] # color + m = [] # shape for i in range(g.width): for j in range(g.height): content = g._grid[i][j] @@ -41,23 +75,23 @@ def portray(g): data = agent_portrayal(agent) x.append(i) y.append(j) - if "size" in data: - s.append(data["size"]) - if "color" in data: - c.append(data["color"]) - out = {"x": x, "y": y} - # This is the default value for the marker size, which auto-scales - # according to the grid area. - out["s"] = (180 / max(g.width, g.height)) ** 2 - if len(s) > 0: - out["s"] = s - if len(c) > 0: - out["c"] = c + + # This is the default value for the marker size, which auto-scales + # according to the grid area. + default_size = (180 / max(g.width, g.height)) ** 2 + # establishing a default prevents misalignment if some agents are not given size, color, etc. + size = data.get("size", default_size) + s.append(size) + color = data.get("color", "tab:blue") + c.append(color) + mark = data.get("shape", "o") + m.append(mark) + out = {"x": x, "y": y, "s": s, "c": c, "m": m} return out space_ax.set_xlim(-1, space.width) space_ax.set_ylim(-1, space.height) - space_ax.scatter(**portray(space)) + _split_and_scatter(portray(space), space_ax) def _draw_network_grid(space, space_ax, agent_portrayal): @@ -77,20 +111,23 @@ def portray(space): y = [] s = [] # size c = [] # color + m = [] # shape for agent in space._agent_to_index: data = agent_portrayal(agent) _x, _y = agent.pos x.append(_x) y.append(_y) - if "size" in data: - s.append(data["size"]) - if "color" in data: - c.append(data["color"]) - out = {"x": x, "y": y} - if len(s) > 0: - out["s"] = s - if len(c) > 0: - out["c"] = c + + # This is matplotlib's default marker size + default_size = 20 + # establishing a default prevents misalignment if some agents are not given size, color, etc. + size = data.get("size", default_size) + s.append(size) + color = data.get("color", "tab:blue") + c.append(color) + mark = data.get("shape", "o") + m.append(mark) + out = {"x": x, "y": y, "s": s, "c": c, "m": m} return out # Determine border style based on space.torus @@ -110,7 +147,7 @@ def portray(space): space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding) # Portray and scatter the agents in the space - space_ax.scatter(**portray(space)) + _split_and_scatter(portray(space), space_ax) @solara.component diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 4f1598846e5..375cbce8fb8 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -103,7 +103,8 @@ def SolaraViz( model_params: Parameters for initializing the model measures: List of callables or data attributes to plot name: Name for display - agent_portrayal: Options for rendering agents (dictionary) + agent_portrayal: Options for rendering agents (dictionary); + Default drawer supports custom `"size"`, `"color"`, and `"shape"`. space_drawer: Method to render the agent space for the model; default implementation is the `SpaceMatplotlib` component; simulations with no space to visualize should