Skip to content

add swarm plot to the scatter documentation #5149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: doc-prod
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions doc/python/line-and-scatter.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,126 @@ fig.update_traces(textposition="bottom right")
fig.show()
```

### Swarm (or Beeswarm) Plots

Swarm plots show the distribution of values in a column by giving each entry one dot and adjusting the y-value so that dots do not overlap and appear symmetrically around the y=0 line. They complement histograms, box plots, and violin plots. This example could be generalized to implement a swarm plot for multiple categories by adjusting the y-coordinate for each category.

```python
import pandas as pd
import plotly.express as px
import collections

def negative_1_if_count_is_odd(count):
# if this is an odd numbered entry in its bin, make its y coordinate negative
# the y coordinate of the first entry is 0, so entries 3, 5, and 7 get negative y coordinates
if count%2 == 1:
return -1
else:
return 1




def swarm(
X_series,
point_size=16,
fig_width = 800,
gap_multiplier=1.2,
bin_fraction=0.95, #bin fraction slightly undersizes the bins to avoid collisions
):
#sorting will align columns in attractive arcs rather than having columns the vary unpredicatbly in the x-dimension
X_series=X_series.copy().sort_values()


# we need to reason in terms of the marker size that is measured in px
# so we need to think about each x-coordinate as being a fraction of the way from the
# minimum X value to the maximum X value
min_x = min(X_series)
max_x = max(X_series)

list_of_rows = []
# we will count the number of points in each "bin" / vertical strip of the graph
# to be able to assign a y-coordinate that avoids overlapping
bin_counter = collections.Counter()

for x_val in X_series:
# assign this x_value to bin number
# each bin is a vertical strip slightly narrower than one marker

bin=(((fig_width*bin_fraction*(x_val-min_x))/(max_x-min_x)) // point_size)

#update the count of dots in that strip
bin_counter.update([bin])


# the collision free y coordinate gives the items in a vertical bin
# coordinates: 0, 1, -1, 2, -2, 3, -3 ... and so on to evenly spread
# their locations above and below the y-axis (we'll make a correction below to deal with even numbers of entries)
# we then scale this by the point_size*gap_multiplier to get a y coordinate in px

collision_free_y_coordinate=(bin_counter[bin]//2)*negative_1_if_count_is_odd(bin_counter[bin])*point_size*gap_multiplier
list_of_rows.append({"x":x_val,"y":collision_free_y_coordinate,"bin":bin})



for row in list_of_rows:
bin = row["bin"]
#see if we need to "look left" to avoid a possible collision
for other_row in list_of_rows:
if (other_row["bin"]==bin-1 ):
#"bubble" the entry up until we find a slot that avoids a collision
while ((other_row["y"]==row["y"])
and (((fig_width*(row["x"]-other_row["x"]))/(max_x-min_x) // point_size) < 1)):
print(row)
print(other_row)
print(((fig_width*(row["x"]-other_row["x"] ))/(max_x-min_x) // point_size))

print("updating to fix collision")
bin_counter.update([bin])
print(bin_counter[bin])
row["y"]=(bin_counter[bin]//2)*negative_1_if_count_is_odd(bin_counter[bin])*point_size*gap_multiplier
print(row["y"])

# if the number of points is even,
# move y-coordinates down to put an equal number of entries above and below the axis
for row in list_of_rows:
if bin_counter[row["bin"]]%2==0:
row["y"]-=point_size*gap_multiplier/2


df = pd.DataFrame(list_of_rows)
# one way to make this code more flexible to e.g. handle multiple categories would be to return a list of "swarmified" y coordinates here
# you could then generate "swarmified" y coordinates for each category and add category specific offsets before scatterplotting them

fig = px.scatter(
df,
x="x",
y="y",
)
#we want to suppress the y coordinate in the hover value because the y-coordinate is irrelevant/misleading
fig.update_traces(
marker_size=point_size,
#suppress the y coordinate because the y-coordinate is irrelevant
hovertemplate="<b>value</b>: %{x}",
)
# we have to set the width and height because we aim to avoid icon collisions and we specify the icon size
# in the same units as the width and height
fig.update_layout(width=fig_width, height=(point_size*max(bin_counter.values())+200))
fig.update_yaxes(
showticklabels=False, # Turn off y-axis labels
ticks='', # Remove the ticks
title=""
)
return fig



df_iris = px.data.iris() # iris is a pandas DataFrame
x = df_iris["sepal_length"]
fig = swarm(x)
fig.show()
```

## Scatter and line plots with go.Scatter

If Plotly Express does not provide a good starting point, it is possible to use [the more generic `go.Scatter` class from `plotly.graph_objects`](/python/graph-objects/). Whereas `plotly.express` has two functions `scatter` and `line`, `go.Scatter` can be used both for plotting points (makers) or lines, depending on the value of `mode`. The different options of `go.Scatter` are documented in its [reference page](https://plotly.com/python/reference/scatter/).
Expand Down