Skip to content

Commit 93d3a98

Browse files
authored
Merge pull request #219 from lucasimi/ui-to-plotly
UI to plotly
2 parents 2207be8 + a7bc234 commit 93d3a98

18 files changed

+500
-300
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ and in the
3939

4040
- **Flexible visualization**
4141

42-
Multiple visualization backends supported (e.g., Plotly, Matplotlib) for
42+
Multiple visualization backends supported (Plotly, Matplotlib, PyVis) for
4343
generating high-quality Mapper graph representations with adjustable
4444
layouts and styling.
4545

app/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ numpy>=1.25.2,<2.0.0
33
scikit-learn>=1.5.0,<1.6.0
44
umap-learn>=0.5.7,<0.6.0
55
pandas>=2.1.0,<3.0.0
6-
tda-mapper>=0.10.0,<0.11.0
6+
tda-mapper>=0.11.0,<0.12.0
77
plotly>=6.0.0,<7.0.0

app/streamlit_app.py

Lines changed: 46 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from sklearn.decomposition import PCA
2323
from umap import UMAP
2424

25+
from tdamapper._plot_plotly import _marker_size
26+
from tdamapper.core import aggregate_graph
2527
from tdamapper.cover import BallCover, CubicalCover
2628
from tdamapper.learn import MapperAlgorithm
2729
from tdamapper.plot import MapperPlot
@@ -100,6 +102,16 @@
100102

101103
V_CMAP_TWILIGHT = "Twilight (Cyclic)"
102104

105+
V_CMAPS = {
106+
V_CMAP_JET: "Jet",
107+
V_CMAP_VIRIDIS: "Viridis",
108+
V_CMAP_CIVIDIS: "Cividis",
109+
V_CMAP_SPECTRAL: "Spectral",
110+
V_CMAP_PORTLAND: "Portland",
111+
V_CMAP_HSV: "HSV",
112+
V_CMAP_TWILIGHT: "Twilight",
113+
}
114+
103115
GIT_REPO_URL = "https://github.com/lucasimi/tda-mapper-python"
104116

105117
ICON_URL = f"{GIT_REPO_URL}/raw/main/docs/source/logos/tda-mapper-logo-icon.png"
@@ -165,14 +177,6 @@ def _fix_data(data):
165177
return df
166178

167179

168-
def _get_dim(fig):
169-
dim = 2
170-
for trace in fig.data:
171-
if "3d" in trace.type:
172-
dim = 3
173-
return dim
174-
175-
176180
def _get_graph_no_attribs(graph):
177181
graph_no_attribs = nx.Graph()
178182
graph_no_attribs.add_nodes_from(graph.nodes())
@@ -561,50 +565,6 @@ def plot_agg_input_section():
561565
return agg, agg_name
562566

563567

564-
def plot_cmap_input_section():
565-
cmap_type = st.selectbox(
566-
"Colormap",
567-
options=[
568-
V_CMAP_JET,
569-
V_CMAP_VIRIDIS,
570-
V_CMAP_CIVIDIS,
571-
V_CMAP_PORTLAND,
572-
V_CMAP_SPECTRAL,
573-
V_CMAP_HSV,
574-
V_CMAP_TWILIGHT,
575-
],
576-
)
577-
cmap = None
578-
if cmap_type == V_CMAP_JET:
579-
cmap = "Jet"
580-
elif cmap_type == V_CMAP_VIRIDIS:
581-
cmap = "Viridis"
582-
elif cmap_type == V_CMAP_CIVIDIS:
583-
cmap = "Cividis"
584-
elif cmap_type == V_CMAP_PORTLAND:
585-
cmap = "Portland"
586-
elif cmap_type == V_CMAP_SPECTRAL:
587-
cmap = "Spectral"
588-
elif cmap_type == V_CMAP_HSV:
589-
cmap = "HSV"
590-
elif cmap_type == V_CMAP_TWILIGHT:
591-
cmap = "Twilight"
592-
return cmap
593-
594-
595-
def plot_color_input_section(df_X, df_y):
596-
X_cols = list(df_X.columns)
597-
y_cols = list(df_y.columns)
598-
col_feat = st.selectbox(
599-
"Color",
600-
options=y_cols + X_cols,
601-
)
602-
if col_feat in X_cols:
603-
return df_X[col_feat].to_numpy(), col_feat
604-
elif col_feat in y_cols:
605-
return df_y[col_feat].to_numpy(), col_feat
606-
607-
608568
@st.cache_data(
609569
hash_funcs={
610570
"networkx.classes.graph.Graph": lambda g: _encode_graph(
@@ -654,15 +614,13 @@ def mapper_plot_section(mapper_graph):
654614
hash_funcs={"tdamapper.plot.MapperPlot": lambda mp: mp.positions},
655615
show_spinner="Rendering Mapper",
656616
)
657-
def compute_mapper_fig(
658-
mapper_plot, colors, node_size, cmap, _agg, agg_name, colors_feat
659-
):
617+
def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name):
660618
logger.info("Generating Mapper figure")
661619
mapper_fig = mapper_plot.plot_plotly(
662620
colors,
663621
node_size=node_size,
664622
agg=_agg,
665-
title=f"{agg_name} of {colors_feat}",
623+
title=[f"{agg_name} of {c}" for c in colors.columns],
666624
cmap=cmap,
667625
width=600,
668626
height=600,
@@ -673,23 +631,17 @@ def compute_mapper_fig(
673631
def mapper_figure_section(df_X, df_y, mapper_plot):
674632
st.header("🎨 Plot")
675633
agg, agg_name = plot_agg_input_section()
676-
cmap = plot_cmap_input_section()
677-
colors, colors_feat = plot_color_input_section(df_X, df_y)
678-
node_size = st.slider("Node size", min_value=0.1, max_value=10.0, value=1.0)
634+
colors = pd.concat([df_y, df_X], axis=1)
679635
mapper_fig = compute_mapper_fig(
680636
mapper_plot,
681637
colors=colors,
682-
node_size=node_size,
638+
node_size=1.0,
683639
_agg=agg,
684-
cmap=cmap,
640+
cmap=["Jet", "Viridis", "Cividis"],
685641
agg_name=agg_name,
686-
colors_feat=colors_feat,
687642
)
688-
dim = _get_dim(mapper_fig)
689643
mapper_fig.update_layout(
690-
dragmode="orbit" if dim == 3 else "pan",
691-
uirevision="constant",
692-
margin=dict(b=0, l=0, r=0, t=0),
644+
margin=dict(b=5, l=5, r=5, t=5),
693645
)
694646
mapper_fig.update_xaxes(
695647
showline=False,
@@ -699,16 +651,41 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
699651
scaleanchor="x",
700652
scaleratio=1,
701653
)
654+
702655
return mapper_fig
703656

704657

705-
def mapper_rendering_section(mapper_graph, mapper_fig):
658+
def _compute_colors_agg(mapper_plot, df_X, df_y, col_feat, agg):
659+
X_cols = list(df_X.columns)
660+
y_cols = list(df_y.columns)
661+
colors = np.array([])
662+
if col_feat in X_cols:
663+
colors = df_X[col_feat].to_numpy()
664+
elif col_feat in y_cols:
665+
colors = df_y[col_feat].to_numpy()
666+
return aggregate_graph(colors, mapper_plot.graph, agg)
667+
668+
669+
def _edge_colors(mapper_plot, df_X, df_y, col_feat, agg):
670+
colors_avg = []
671+
colors_agg = _compute_colors_agg(mapper_plot, df_X, df_y, col_feat, agg)
672+
for edge in mapper_plot.graph.edges():
673+
c0, c1 = colors_agg[edge[0]], colors_agg[edge[1]]
674+
colors_avg.append(c0)
675+
colors_avg.append(c1)
676+
colors_avg.append(c1)
677+
return colors_avg
678+
679+
680+
def mapper_rendering_section(mapper_fig):
706681
config = {
707682
"scrollZoom": True,
708683
"displaylogo": False,
709684
"modeBarButtonsToRemove": ["zoom", "pan"],
710685
}
711-
st.plotly_chart(mapper_fig, use_container_width=True, config=config)
686+
st.plotly_chart(
687+
mapper_fig, use_container_width=True, config=config, key="mapper_plot"
688+
)
712689

713690

714691
def data_summary_section(df_X, df_y, mapper_graph):
@@ -826,7 +803,7 @@ def main():
826803
with col_0:
827804
data_summary_section(df_X, df_y, mapper_graph)
828805
with col_1:
829-
mapper_rendering_section(mapper_graph, mapper_fig)
806+
mapper_rendering_section(mapper_fig)
830807
col_2, col_3 = st.columns([1, 3])
831808
with col_2:
832809
data_download_button(df_X, df_y)

benchmarks/benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.decomposition import PCA
1010

1111
import tdamapper as tm
12-
from tdamapper.clustering import TrivialClustering
12+
from tdamapper.core import TrivialClustering
1313

1414

1515
def _segment(cardinality, dimension, noise=0.1, start=None, end=None):
@@ -81,7 +81,7 @@ def run_gm(X, n, p):
8181

8282
def run_tm(X, n, p):
8383
t0 = time.time()
84-
tm.core.MapperAlgorithm(
84+
tm.learn.MapperAlgorithm(
8585
cover=tm.cover.CubicalCover(
8686
n_intervals=n,
8787
overlap_frac=p,

docs/source/index.rst

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,27 @@ Core features
5050

5151
- **Efficient construction**
5252

53-
Leverages optimized spatial search techniques and parallelization to accelerate the construction of Mapper graphs, supporting the analysis of high-dimensional datasets.
53+
Leverages optimized spatial search techniques and parallelization to
54+
accelerate the construction of Mapper graphs, supporting the analysis of
55+
high-dimensional datasets.
5456

5557
- **Scikit-learn integration**
5658

57-
Provides custom estimators that are fully compatible with scikit-learn's API, enabling seamless integration into scikit-learn pipelines for tasks such as dimensionality reduction, clustering, and feature extraction.
59+
Provides custom estimators that are fully compatible with scikit-learn's
60+
API, enabling seamless integration into scikit-learn pipelines for tasks
61+
such as dimensionality reduction, clustering, and feature extraction.
5862

5963
- **Flexible visualization**
6064

61-
Multiple visualization backends supported (e.g., Plotly, Matplotlib) for generating high-quality Mapper graph representations with adjustable layouts and styling.
65+
Multiple visualization backends supported (Plotly, Matplotlib, PyVis) for
66+
generating high-quality Mapper graph representations with adjustable
67+
layouts and styling.
6268

6369
- **Interactive app**
6470

65-
Provides an interactive web-based interface (via Streamlit) for dynamic exploration of Mapper graph structures, offering real-time adjustments to parameters and visualizations.
71+
Provides an interactive web-based interface (via Streamlit) for dynamic
72+
exploration of Mapper graph structures, offering real-time adjustments to
73+
parameters and visualizations.
6674

6775

6876
Background

docs/source/notebooks/circles.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
cover=CubicalCover(n_intervals=10, overlap_frac=0.3), clustering=DBSCAN()
7878
)
7979
graph = mapper.fit_transform(X, y)
80+
print(f"nodes: {len(graph.nodes())}, edges: {len(graph.edges())}")
8081

8182
# %% [markdown]
8283
# ### Visualization
@@ -92,9 +93,15 @@
9293
# %%
9394
plot = MapperPlot(graph, dim=2, iterations=60, seed=42)
9495

95-
fig = plot.plot_plotly(colors=labels, cmap="jet", agg=np.nanmean, width=600, height=600)
96+
fig = plot.plot_plotly(
97+
colors=labels,
98+
cmap=["jet", "viridis", "cividis"],
99+
agg=np.nanmean,
100+
width=600,
101+
height=600,
102+
)
96103

97-
fig.show(config={"scrollZoom": True})
104+
fig.show(config={"scrollZoom": True}, renderer="notebook_connected")
98105
# fig.write_image("circles_mean.png", width=500, height=500)
99106

100107
# %% [markdown]
@@ -107,14 +114,16 @@
107114
# data where such ambiguity is common.
108115

109116
# %%
110-
plot.plot_plotly_update(
111-
fig,
117+
118+
fig = plot.plot_plotly(
112119
colors=labels,
113-
cmap="viridis",
120+
cmap=["jet", "viridis", "cividis"],
114121
agg=np.nanstd,
122+
width=600,
123+
height=600,
115124
)
116125

117-
fig.show(config={"scrollZoom": True})
126+
fig.show(config={"scrollZoom": True}, renderer="notebook_connected")
118127
# fig.write_image("circles_std.png", width=500, height=500)
119128

120129
# %% [markdown]

docs/source/notebooks/digits.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
)
7373

7474
graph = mapper.fit_transform(X, y)
75+
print(f"nodes: {len(graph.nodes())}, edges: {len(graph.edges())}")
7576

7677
# %% [markdown]
7778
# ### Visualization
@@ -100,15 +101,15 @@ def mode(arr):
100101

101102
fig = plot.plot_plotly(
102103
colors=labels,
103-
cmap="jet",
104+
cmap=["jet", "viridis", "cividis"],
104105
agg=mode,
105106
title="mode of digits",
106107
width=600,
107108
height=600,
108109
node_size=0.5,
109110
)
110111

111-
fig.show(config={"scrollZoom": True})
112+
fig.show(config={"scrollZoom": True}, renderer="notebook_connected")
112113

113114
# %% [markdown]
114115
# We also color the nodes by the **entropy** of their digit labels, which
@@ -131,15 +132,15 @@ def entropy(arr):
131132

132133
fig = plot.plot_plotly(
133134
colors=labels,
134-
cmap="viridis",
135+
cmap=["jet", "viridis", "cividis"],
135136
agg=entropy,
136137
title="entropy of digits",
137138
width=600,
138139
height=600,
139140
node_size=0.5,
140141
)
141142

142-
fig.show(config={"scrollZoom": True})
143+
fig.show(config={"scrollZoom": True}, renderer="notebook_connected")
143144

144145
# %% [markdown]
145146
# ### Identifying high-entropy

docs/source/quickstart.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ Development
2727
2828
pip install git+https://github.com/lucasimi/tda-mapper-python
2929
30-
- To install from the latest commit of develop branch
30+
- To install from the latest commit of a branch
3131

3232
.. code:: bash
3333
34-
pip install git+https://github.com/lucasimi/tda-mapper-python@develop
34+
pip install git+https://github.com/lucasimi/tda-mapper-python@[name-of-the-branch]
3535
3636
3737
How To Use

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "tda-mapper"
7-
version = "0.10.0"
7+
version = "0.11.0"
88
description = "A simple and efficient Python implementation of Mapper algorithm for Topological Data Analysis"
99
readme = "README.md"
1010
authors = [{ name = "Luca Simi", email = "[email protected]" }]
@@ -49,6 +49,7 @@ dev = [
4949
"black[jupyter]",
5050
"isort",
5151
"flake8",
52+
"nbformat>=4.2.0",
5253
]
5354

5455
[project.urls]

0 commit comments

Comments
 (0)