Skip to content

Commit

Permalink
deploy: 721fdee
Browse files Browse the repository at this point in the history
  • Loading branch information
fracape committed May 3, 2024
1 parent 8e9be4c commit 3fe0a29
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
32 changes: 19 additions & 13 deletions _modules/compressai/entropy_models/entropy_models.html
Original file line number Diff line number Diff line change
Expand Up @@ -650,20 +650,24 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig
<span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_scale</span> <span class="o">**</span> <span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">channels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span>

<span class="bp">self</span><span class="o">.</span><span class="n">matrices</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterList</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">biases</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterList</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">factors</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterList</span><span class="p">()</span>

<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">init</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">expm1</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="n">scale</span> <span class="o">/</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]))</span>
<span class="n">matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
<span class="n">matrix</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">fill_</span><span class="p">(</span><span class="n">init</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;_matrix</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">matrix</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">matrices</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">matrix</span><span class="p">))</span>

<span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="n">bias</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;_bias</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">bias</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">biases</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">bias</span><span class="p">))</span>

<span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">):</span>
<span class="n">factor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">zeros_</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;_factor</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">factor</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">factors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">factor</span><span class="p">))</span>

<span class="bp">self</span><span class="o">.</span><span class="n">quantiles</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">init</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">init_scale</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_scale</span><span class="p">])</span>
Expand Down Expand Up @@ -723,24 +727,23 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig
<span class="c1"># TorchScript not yet working (nn.Mmodule indexing not supported)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">inputs</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">matrix</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;_matrix</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">matrix</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">matrices</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">if</span> <span class="n">stop_gradient</span><span class="p">:</span>
<span class="n">matrix</span> <span class="o">=</span> <span class="n">matrix</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">matrix</span><span class="p">),</span> <span class="n">logits</span><span class="p">)</span>

<span class="n">bias</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;_bias</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">biases</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">if</span> <span class="n">stop_gradient</span><span class="p">:</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="n">logits</span> <span class="o">+=</span> <span class="n">bias</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span> <span class="o">+</span> <span class="n">bias</span>

<span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">):</span>
<span class="n">factor</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;_factor</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">factor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">factors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">if</span> <span class="n">stop_gradient</span><span class="p">:</span>
<span class="n">factor</span> <span class="o">=</span> <span class="n">factor</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="n">logits</span> <span class="o">+=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
<span class="k">return</span> <span class="n">logits</span>

<span class="nd">@torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">unused</span>
<span class="k">def</span> <span class="nf">_likelihood</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">stop_gradient</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
Expand All @@ -758,10 +761,13 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig

<span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">is_scripting</span><span class="p">():</span>
<span class="c1"># x from B x C x ... to C x B x ...</span>
<span class="n">perm</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
<span class="n">perm</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">perm</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">perm</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">perm</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># Compute inverse permutation</span>
<span class="n">inv_perm</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">))[</span><span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">perm</span><span class="p">)]</span>
<span class="n">perm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="n">inv_perm</span> <span class="o">=</span> <span class="n">perm</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
<span class="c1"># TorchScript in 2D for static inference</span>
Expand Down
3 changes: 2 additions & 1 deletion _modules/compressai/models/base.html
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ <h1>Source code for compressai.models.base</h1><div class="highlight"><pre>

<span class="kn">from</span> <span class="nn">compressai.entropy_models</span> <span class="kn">import</span> <span class="n">EntropyBottleneck</span><span class="p">,</span> <span class="n">GaussianConditional</span>
<span class="kn">from</span> <span class="nn">compressai.latent_codecs</span> <span class="kn">import</span> <span class="n">LatentCodec</span>
<span class="kn">from</span> <span class="nn">compressai.models.utils</span> <span class="kn">import</span> <span class="n">update_registered_buffers</span>
<span class="kn">from</span> <span class="nn">compressai.models.utils</span> <span class="kn">import</span> <span class="n">remap_old_keys</span><span class="p">,</span> <span class="n">update_registered_buffers</span>

<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span>
<span class="s2">&quot;CompressionModel&quot;</span><span class="p">,</span>
Expand Down Expand Up @@ -395,6 +395,7 @@ <h1>Source code for compressai.models.base</h1><div class="highlight"><pre>
<span class="p">[</span><span class="s2">&quot;_quantized_cdf&quot;</span><span class="p">,</span> <span class="s2">&quot;_offset&quot;</span><span class="p">,</span> <span class="s2">&quot;_cdf_length&quot;</span><span class="p">],</span>
<span class="n">state_dict</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">remap_old_keys</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">)</span>

<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">GaussianConditional</span><span class="p">):</span>
<span class="n">update_registered_buffers</span><span class="p">(</span>
Expand Down

0 comments on commit 3fe0a29

Please sign in to comment.