@@ -1257,16 +1257,17 @@ <h1>Source code for ot.backend</h1><div class="highlight"><pre>
1257
1257
1258
1258
<div class="viewcode-block" id="Backend.kl_div">
1259
1259
<a class="viewcode-back" href="../../gen_modules/ot.backend.html#ot.backend.Backend.kl_div">[docs]</a>
1260
- <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
1260
+ <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">mass</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n"> eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
1261
1261
<span class="w"> </span><span class="sa">r</span><span class="sd">"""</span>
1262
- <span class="sd"> Computes the Kullback-Leibler divergence.</span>
1262
+ <span class="sd"> Computes the (Generalized) Kullback-Leibler divergence.</span>
1263
1263
1264
1264
<span class="sd"> This function follows the api from :any:`scipy.stats.entropy`.</span>
1265
1265
1266
1266
<span class="sd"> Parameter eps is used to avoid numerical errors and is added in the log.</span>
1267
1267
1268
1268
<span class="sd"> .. math::</span>
1269
- <span class="sd"> KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)</span>
1269
+ <span class="sd"> KL(p,q) = \langle \mathbf{p}, log(\mathbf{p} / \mathbf{q} + eps \rangle</span>
1270
+ <span class="sd"> + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle</span>
1270
1271
1271
1272
<span class="sd"> See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html</span>
1272
1273
<span class="sd"> """</span>
@@ -1908,8 +1909,11 @@ <h1>Source code for ot.backend</h1><div class="highlight"><pre>
1908
1909
1909
1910
<div class="viewcode-block" id="NumpyBackend.kl_div">
1910
1911
<a class="viewcode-back" href="../../gen_modules/ot.backend.html#ot.backend.NumpyBackend.kl_div">[docs]</a>
1911
- <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
1912
- <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span></div>
1912
+ <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">mass</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
1913
+ <span class="n">value</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span>
1914
+ <span class="k">if</span> <span class="n">mass</span><span class="p">:</span>
1915
+ <span class="n">value</span> <span class="o">=</span> <span class="n">value</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">q</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span>
1916
+ <span class="k">return</span> <span class="n">value</span></div>
1913
1917
1914
1918
1915
1919
<div class="viewcode-block" id="NumpyBackend.isfinite">
@@ -2550,8 +2554,11 @@ <h1>Source code for ot.backend</h1><div class="highlight"><pre>
2550
2554
2551
2555
<div class="viewcode-block" id="JaxBackend.kl_div">
2552
2556
<a class="viewcode-back" href="../../gen_modules/ot.backend.html#ot.backend.JaxBackend.kl_div">[docs]</a>
2553
- <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
2554
- <span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span></div>
2557
+ <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">mass</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
2558
+ <span class="n">value</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span>
2559
+ <span class="k">if</span> <span class="n">mass</span><span class="p">:</span>
2560
+ <span class="n">value</span> <span class="o">=</span> <span class="n">value</span> <span class="o">+</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">q</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span>
2561
+ <span class="k">return</span> <span class="n">value</span></div>
2555
2562
2556
2563
2557
2564
<div class="viewcode-block" id="JaxBackend.isfinite">
@@ -3280,8 +3287,11 @@ <h1>Source code for ot.backend</h1><div class="highlight"><pre>
3280
3287
3281
3288
<div class="viewcode-block" id="TorchBackend.kl_div">
3282
3289
<a class="viewcode-back" href="../../gen_modules/ot.backend.html#ot.backend.TorchBackend.kl_div">[docs]</a>
3283
- <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
3284
- <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span></div>
3290
+ <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">mass</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
3291
+ <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span>
3292
+ <span class="k">if</span> <span class="n">mass</span><span class="p">:</span>
3293
+ <span class="n">value</span> <span class="o">=</span> <span class="n">value</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">q</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span>
3294
+ <span class="k">return</span> <span class="n">value</span></div>
3285
3295
3286
3296
3287
3297
<div class="viewcode-block" id="TorchBackend.isfinite">
@@ -3924,8 +3934,11 @@ <h1>Source code for ot.backend</h1><div class="highlight"><pre>
3924
3934
3925
3935
<div class="viewcode-block" id="CupyBackend.kl_div">
3926
3936
<a class="viewcode-back" href="../../gen_modules/ot.backend.html#ot.backend.CupyBackend.kl_div">[docs]</a>
3927
- <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
3928
- <span class="k">return</span> <span class="n">cp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">cp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span></div>
3937
+ <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">mass</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
3938
+ <span class="n">value</span> <span class="o">=</span> <span class="n">cp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">cp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span>
3939
+ <span class="k">if</span> <span class="n">mass</span><span class="p">:</span>
3940
+ <span class="n">value</span> <span class="o">=</span> <span class="n">value</span> <span class="o">+</span> <span class="n">cp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">q</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span>
3941
+ <span class="k">return</span> <span class="n">value</span></div>
3929
3942
3930
3943
3931
3944
<div class="viewcode-block" id="CupyBackend.isfinite">
@@ -4591,8 +4604,11 @@ <h1>Source code for ot.backend</h1><div class="highlight"><pre>
4591
4604
4592
4605
<div class="viewcode-block" id="TensorflowBackend.kl_div">
4593
4606
<a class="viewcode-back" href="../../gen_modules/ot.backend.html#ot.backend.TensorflowBackend.kl_div">[docs]</a>
4594
- <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
4595
- <span class="k">return</span> <span class="n">tnp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">tnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span></div>
4607
+ <span class="k">def</span> <span class="nf">kl_div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">mass</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">):</span>
4608
+ <span class="n">value</span> <span class="o">=</span> <span class="n">tnp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p</span> <span class="o">*</span> <span class="n">tnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">p</span> <span class="o">/</span> <span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span>
4609
+ <span class="k">if</span> <span class="n">mass</span><span class="p">:</span>
4610
+ <span class="n">value</span> <span class="o">=</span> <span class="n">value</span> <span class="o">+</span> <span class="n">tnp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">q</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span>
4611
+ <span class="k">return</span> <span class="n">value</span></div>
4596
4612
4597
4613
4598
4614
<div class="viewcode-block" id="TensorflowBackend.isfinite">
0 commit comments