1- # noqa: D100
21import warnings
32from collections .abc import Callable
43from dataclasses import fields
@@ -75,6 +74,7 @@ def collect_agent_data(
7574 "stroke" : [], # Stroke color
7675 "strokeWidth" : [],
7776 "filled" : [],
77+ "tooltip" : [],
7878 }
7979
8080 # Import here to avoid circular import issues
@@ -129,6 +129,7 @@ def collect_agent_data(
129129 linewidths = dict_data .pop (
130130 "linewidths" , style_fields .get ("linewidths" )
131131 ),
132+ tooltip = dict_data .pop ("tooltip" , None ),
132133 )
133134 if dict_data :
134135 ignored_keys = list (dict_data .keys ())
@@ -184,6 +185,7 @@ def collect_agent_data(
184185 # FIXME: Make filled user-controllable
185186 filled_value = True
186187 arguments ["filled" ].append (filled_value )
188+ arguments ["tooltip" ].append (aps .tooltip )
187189
188190 final_data = {}
189191 for k , v in arguments .items ():
@@ -199,87 +201,84 @@ def collect_agent_data(
199201
200202 return final_data
201203
204+
205+
202206 def draw_agents (
203207 self , arguments , chart_width : int = 450 , chart_height : int = 350 , ** kwargs
204208 ):
205- """Draw agents using Altair backend.
206-
207- Args:
208- arguments: Dictionary containing agent data arrays.
209- chart_width: Width of the chart.
210- chart_height: Height of the chart.
211- **kwargs: Additional keyword arguments for customization.
212- Checkout respective `SpaceDrawer` class on details how to pass **kwargs.
213-
214- Returns:
215- alt.Chart: The Altair chart representing the agents, or None if no agents.
216- """
209+ """Draw agents using Altair backend."""
217210 if arguments ["loc" ].size == 0 :
218211 return None
219212
220- # To get a continuous scale for color the domain should be between [0, 1]
221- # that's why changing the the domain of strokeWidth beforehand.
222- stroke_width = [data / 10 for data in arguments ["strokeWidth" ]]
223-
224- # Agent data preparation
225- df_data = {
226- "x" : arguments ["loc" ][:, 0 ],
227- "y" : arguments ["loc" ][:, 1 ],
228- "size" : arguments ["size" ],
229- "shape" : arguments ["shape" ],
230- "opacity" : arguments ["opacity" ],
231- "strokeWidth" : stroke_width ,
232- "original_color" : arguments ["color" ],
233- "is_filled" : arguments ["filled" ],
234- "original_stroke" : arguments ["stroke" ],
235- }
236- df = pd .DataFrame (df_data )
237-
238- # To ensure distinct shapes according to agent portrayal
239- unique_shape_names_in_data = df ["shape" ].unique ().tolist ()
240-
241- fill_colors = []
242- stroke_colors = []
243- for i in range (len (df )):
244- filled = df ["is_filled" ][i ]
245- main_color = df ["original_color" ][i ]
246- stroke_spec = (
247- df ["original_stroke" ][i ]
248- if isinstance (df ["original_stroke" ][i ], str )
249- else None
250- )
251- if filled :
252- fill_colors .append (main_color )
253- stroke_colors .append (stroke_spec )
213+ # Prepare a list of dictionaries, which is a robust way to create a DataFrame
214+ records = []
215+ for i in range (len (arguments ["loc" ])):
216+ record = {
217+ "x" : arguments ["loc" ][i ][0 ],
218+ "y" : arguments ["loc" ][i ][1 ],
219+ "size" : arguments ["size" ][i ],
220+ "shape" : arguments ["shape" ][i ],
221+ "opacity" : arguments ["opacity" ][i ],
222+ "strokeWidth" : arguments ["strokeWidth" ][i ] / 10 , # Scale for continuous domain
223+ "original_color" : arguments ["color" ][i ],
224+ }
225+ # Add tooltip data if available
226+ tooltip = arguments ["tooltip" ][i ]
227+ if tooltip :
228+ record .update (tooltip )
229+
230+ # Determine fill and stroke colors
231+ if arguments ["filled" ][i ]:
232+ record ["viz_fill_color" ] = arguments ["color" ][i ]
233+ record ["viz_stroke_color" ] = arguments ["stroke" ][i ] if isinstance (arguments ["stroke" ][i ], str ) else None
254234 else :
255- fill_colors .append (None )
256- stroke_colors .append (main_color )
257- df ["viz_fill_color" ] = fill_colors
258- df ["viz_stroke_color" ] = stroke_colors
235+ record ["viz_fill_color" ] = None
236+ record ["viz_stroke_color" ] = arguments ["color" ][i ]
237+
238+ records .append (record )
239+
240+ df = pd .DataFrame (records )
241+
242+ # Ensure all columns that should be numeric are, handling potential Nones
243+ numeric_cols = ['x' , 'y' , 'size' , 'opacity' , 'strokeWidth' , 'original_color' ]
244+ for col in numeric_cols :
245+ if col in df .columns :
246+ df [col ] = pd .to_numeric (df [col ], errors = 'coerce' )
247+
248+
249+ # Get tooltip keys from the first valid record
250+ tooltip_list = ["x" , "y" ]
251+ # This is the corrected line:
252+ if any (t is not None for t in arguments ["tooltip" ]):
253+ first_valid_tooltip = next ((t for t in arguments ["tooltip" ] if t ), None )
254+ if first_valid_tooltip :
255+ tooltip_list .extend (first_valid_tooltip .keys ())
259256
260257 # Extract additional parameters from kwargs
261- # FIXME: Add more parameters to kwargs
262258 title = kwargs .pop ("title" , "" )
263259 xlabel = kwargs .pop ("xlabel" , "" )
264260 ylabel = kwargs .pop ("ylabel" , "" )
265-
266- # Tooltip list for interactivity
267- # FIXME: Add more fields to tooltip (preferably from agent_portrayal)
268- tooltip_list = ["x" , "y" ]
261+ legend_title = kwargs .pop ("legend_title" , "Color" )
269262
270263 # Handle custom colormapping
271264 cmap = kwargs .pop ("cmap" , "viridis" )
272265 vmin = kwargs .pop ("vmin" , None )
273266 vmax = kwargs .pop ("vmax" , None )
274267
275- color_is_numeric = np . issubdtype (df ["original_color" ]. dtype , np . number )
268+ color_is_numeric = pd . api . types . is_numeric_dtype (df ["original_color" ])
276269 if color_is_numeric :
277270 color_min = vmin if vmin is not None else df ["original_color" ].min ()
278271 color_max = vmax if vmax is not None else df ["original_color" ].max ()
279272
280273 fill_encoding = alt .Fill (
281274 "original_color:Q" ,
282275 scale = alt .Scale (scheme = cmap , domain = [color_min , color_max ]),
276+ legend = alt .Legend (
277+ title = legend_title ,
278+ orient = "right" ,
279+ type = "gradient" ,
280+ gradientLength = 200 ,
281+ ),
283282 )
284283 else :
285284 fill_encoding = alt .Fill (
@@ -290,6 +289,7 @@ def draw_agents(
290289
291290 # Determine space dimensions
292291 xmin , xmax , ymin , ymax = self .space_drawer .get_viz_limits ()
292+ unique_shape_names_in_data = df ["shape" ].dropna ().unique ().tolist ()
293293
294294 chart = (
295295 alt .Chart (df )
@@ -316,16 +316,10 @@ def draw_agents(
316316 ),
317317 title = "Shape" ,
318318 ),
319- opacity = alt .Opacity (
320- "opacity:Q" ,
321- title = "Opacity" ,
322- scale = alt .Scale (domain = [0 , 1 ], range = [0 , 1 ]),
323- ),
319+ opacity = alt .Opacity ("opacity:Q" , title = "Opacity" , scale = alt .Scale (domain = [0 , 1 ], range = [0 , 1 ])),
324320 fill = fill_encoding ,
325321 stroke = alt .Stroke ("viz_stroke_color:N" , scale = None ),
326- strokeWidth = alt .StrokeWidth (
327- "strokeWidth:Q" , scale = alt .Scale (domain = [0 , 1 ])
328- ),
322+ strokeWidth = alt .StrokeWidth ("strokeWidth:Q" , scale = alt .Scale (domain = [0 , 1 ])),
329323 tooltip = tooltip_list ,
330324 )
331325 .properties (title = title , width = chart_width , height = chart_height )
@@ -437,4 +431,4 @@ def draw_propertylayer(
437431 main_charts .append (current_chart )
438432
439433 base = alt .layer (* main_charts ).resolve_scale (color = "independent" )
440- return base
434+ return base
0 commit comments