Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
set_mapbox_access_token,
defaults,
get_trendline_results,
IdentityMap,
Constant,
)

from . import data, colors # noqa: F401
Expand Down Expand Up @@ -95,4 +97,6 @@
"colors",
"set_mapbox_access_token",
"get_trendline_results",
"IdentityMap",
"Constant",
]
45 changes: 41 additions & 4 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,30 @@ def __init__(self):
defaults = PxDefaults()
del PxDefaults


class IdentityMap(object):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could these two classes be moved to a _special_inputs.py file to shorten a bit _core.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, if needed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as you wish, I just feel that the length of this file makes it a bit overwhelming.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

"""
`dict`-like object which can be passed in to arguments like `color_discrete_map` to
use the provided data values as colors, rather than mapping them to colors cycled
from `color_discrete_sequence`.
"""

def __getitem__(self, key):
return key

def __contains__(self, key):
return True

def copy(self):
return self


class Constant(object):
def __init__(self, value, label=None):
self.value = value
self.label = label


MAPBOX_TOKEN = None


Expand Down Expand Up @@ -919,6 +943,8 @@ def build_dataframe(args, attrables, array_attrables):
else:
df_output[df_input.columns] = df_input[df_input.columns]

constants = dict()

# Loop over possible arguments
for field_name in attrables:
# Massaging variables
Expand Down Expand Up @@ -950,8 +976,15 @@ def build_dataframe(args, attrables, array_attrables):
"pandas MultiIndex is not supported by plotly express "
"at the moment." % field
)
# ----------------- argument is a constant ----------------------
if isinstance(argument, Constant):
col_name = _check_name_not_reserved(
str(argument.label) if argument.label is not None else field,
reserved_names,
)
constants[col_name] = argument.value
# ----------------- argument is a col name ----------------------
if isinstance(argument, str) or isinstance(
elif isinstance(argument, str) or isinstance(
argument, int
): # just a column name given as str or int
if not df_provided:
Expand Down Expand Up @@ -1032,6 +1065,9 @@ def build_dataframe(args, attrables, array_attrables):
else:
args[field_name][i] = str(col_name)

for col_name in constants:
df_output[col_name] = constants[col_name]

args["data_frame"] = df_output
return args

Expand Down Expand Up @@ -1402,9 +1438,10 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
for col, val, m in zip(grouper, group_name, grouped_mappings):
if col != one_group:
key = get_label(args, col)
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if not isinstance(m.val_map, IdentityMap):
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if m.variable == "animation_frame":
frame_name = val
trace_name = ", ".join(trace_name_labels.values())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,48 @@ def test_size_column():
df = px.data.tips()
fig = px.scatter(df, x=df["size"], y=df.tip)
assert fig.data[0].hovertemplate == "size=%{x}<br>tip=%{y}<extra></extra>"


def test_identity_map():
fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=["red", "blue"],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"
assert "color=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"


def test_constants():
fig = px.scatter(x=px.Constant(1), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" in fig.data[0].hovertemplate

fig = px.scatter(x=px.Constant(1, label="time"), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" not in fig.data[0].hovertemplate
assert "time=" in fig.data[0].hovertemplate

fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=px.Constant("red", label="the_identity_label"),
hover_data=[px.Constant("data", label="the_data")],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[0].customdata[0][0] == "data"
assert fig.data[1].marker.color == "red"
assert "color=" not in fig.data[0].hovertemplate
assert "the_identity_label=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert "the_data=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"