Skip to content

Commit 066dd42

Browse files
Circle CICircle CI
authored andcommitted
CircleCI update of dev docs (3269).
1 parent 51d92a1 commit 066dd42

File tree

516 files changed

+104193
-102678
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

516 files changed

+104193
-102678
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# Partial Wasserstein in 1D\n\nThis script demonstrates how to compute and visualize the Partial Wasserstein distance between two 1D discrete distributions using `ot.partial.partial_wasserstein_1d`.\n\nWe illustrate the intermediate transport plans for all `k = 1...n`, where `n = min(len(x_a), len(x_b))`.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# sphinx_gallery_thumbnail_number = 5\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom ot.partial import partial_wasserstein_1d\n\n\ndef plot_partial_transport(\n ax, x_a, x_b, indices_a=None, indices_b=None, marginal_costs=None\n):\n y_a = np.ones_like(x_a)\n y_b = -np.ones_like(x_b)\n min_min = min(x_a.min(), x_b.min())\n max_max = max(x_a.max(), x_b.max())\n\n ax.plot([min_min - 1, max_max + 1], [1, 1], \"k-\", lw=0.5, alpha=0.5)\n ax.plot([min_min - 1, max_max + 1], [-1, -1], \"k-\", lw=0.5, alpha=0.5)\n\n # Plot transport lines\n if indices_a is not None and indices_b is not None:\n subset_a = np.sort(x_a[indices_a])\n subset_b = np.sort(x_b[indices_b])\n\n for x_a_i, x_b_j in zip(subset_a, subset_b):\n ax.plot([x_a_i, x_b_j], [1, -1], \"k--\", alpha=0.7)\n\n # Plot all points\n ax.plot(x_a, y_a, \"o\", color=\"C0\", label=\"x_a\", markersize=8)\n ax.plot(x_b, y_b, \"o\", color=\"C1\", label=\"x_b\", markersize=8)\n\n if marginal_costs is not None:\n k = len(marginal_costs)\n ax.set_title(\n f\"Partial Transport - k = {k}, Cumulative Cost = {sum(marginal_costs):.2f}\",\n fontsize=16,\n )\n else:\n ax.set_title(\"Original 1D Discrete Distributions\", fontsize=16)\n ax.legend(loc=\"upper right\", fontsize=14)\n ax.set_yticks([])\n ax.set_xticks([])\n ax.set_ylim(-2, 2)\n ax.set_xlim(min(x_a.min(), x_b.min()) - 1, max(x_a.max(), x_b.max()) + 1)\n ax.axis(\"off\")\n\n\n# Simulate two 1D discrete distributions\nnp.random.seed(0)\nn = 6\nx_a = np.sort(np.random.uniform(0, 10, size=n))\nx_b = np.sort(np.random.uniform(0, 10, size=n))\n\n# Plot original distributions\nplt.figure(figsize=(6, 2))\nplot_partial_transport(plt.gca(), x_a, x_b)\nplt.show()"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {
25+
"collapsed": false
26+
},
27+
"outputs": [],
28+
"source": [
29+
"indices_a, indices_b, marginal_costs = partial_wasserstein_1d(x_a, x_b)\n\n# Compute cumulative cost\ncumulative_costs = np.cumsum(marginal_costs)\n\n# Visualize all partial transport plans\nfor k in range(n):\n plt.figure(figsize=(6, 2))\n plot_partial_transport(\n plt.gca(),\n x_a,\n x_b,\n indices_a[: k + 1],\n indices_b[: k + 1],\n marginal_costs[: k + 1],\n )\n plt.show()"
30+
]
31+
}
32+
],
33+
"metadata": {
34+
"kernelspec": {
35+
"display_name": "Python 3",
36+
"language": "python",
37+
"name": "python3"
38+
},
39+
"language_info": {
40+
"codemirror_mode": {
41+
"name": "ipython",
42+
"version": 3
43+
},
44+
"file_extension": ".py",
45+
"mimetype": "text/x-python",
46+
"name": "python",
47+
"nbconvert_exporter": "python",
48+
"pygments_lexer": "ipython3",
49+
"version": "3.10.18"
50+
}
51+
},
52+
"nbformat": 4,
53+
"nbformat_minor": 0
54+
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)