
Commit 702cd5e

updated readme and docs

1 parent 999271c commit 702cd5e

File tree

4 files changed: 66 additions & 58 deletions


README.md

Lines changed: 50 additions & 40 deletions
@@ -2,65 +2,61 @@

# torchzero

-This is a work-in-progress optimizers library for pytorch with composable zeroth, first, second order and quasi newton methods, gradient approximation, line searches and a whole lot of other stuff.
-
-Most optimizers are modular, meaning you can chain them like this:
+`torchzero` implements a large number of optimization modules that can be chained together to create custom optimizers:

```py
-optimizer = torchzero.optim.Modular(model.parameters(), [*list of modules*])`
+import torchzero as tz
+
+optimizer = tz.Modular(
+    model.parameters(),
+    tz.m.Adam(),
+    tz.m.Cautious(),
+    tz.m.LR(1e-3),
+    tz.m.WeightDecay(1e-4)
+)
+
+# standard training loop
+for batch in dataset:
+    preds = model(batch)
+    loss = criterion(preds)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
```

-For example you might use `[ClipNorm(4), LR(1e-3), NesterovMomentum(0.9)]` for standard SGD with gradient clipping and nesterov momentum. Move `ClipNorm` to the end to clip the update instead of the gradients. If you don't have access to gradients, add a `RandomizedFDM()` at the beginning to approximate them via randomized finite differences. Add `Cautious()` to make the optimizer cautious.
+Each module takes the output of the previous module and applies a further transformation. This modular design avoids redundant code, such as reimplementing cautioning, orthogonalization, laplacian smoothing, etc. for every optimizer. It is also easy to experiment with grafting, interpolation between different optimizers, and perhaps some weirder combinations like nested momentum.

-Each new module takes previous module update and works on it. That way there is no need to reimplement stuff like laplacian smoothing for all optimizers, and it is easy to experiment with grafting, interpolation between different optimizers, and perhaps some weirder combinations like nested momentum.
+Modules are not limited to gradient transformations. They can perform other operations like line searches, exponential moving average (EMA) and stochastic weight averaging (SWA), gradient accumulation, gradient approximation, and more.

-# How to use
+There are over 100 modules, all accessible within the `tz.m` namespace. For example, the Adam update rule is available as `tz.m.Adam`. A complete list of modules is available in the [documentation](https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html).

-All modules are defined in `torchzero.modules`. You can generally mix and match them however you want. Some pre-made optimizers are available in `torchzero.optim`.
+## Closure

-Some optimizers require closure, which should look like this:
+Some modules and optimizers in torchzero, particularly line-search methods and gradient approximation modules, require a closure function. This is similar to how `torch.optim.LBFGS` works in PyTorch. In torchzero, the closure needs to accept a boolean `backward` argument (though the argument can have any name). When `backward=True`, the closure should zero out old gradients using `opt.zero_grad()` and compute new gradients using `loss.backward()`.

```py
def closure(backward = True):
-preds = model(inputs)
-loss = loss_fn(preds, targets)
+    preds = model(inputs)
+    loss = loss_fn(preds, targets)

-# if you can't call loss.backward(), and instead use gradient-free methods,
-# they always call closure with backward=False.
-# so you can remove the part below, but keep the unused backward argument.
-if backward:
-optimizer.zero_grad()
-loss.backward()
-return loss
+    if backward:
+        optimizer.zero_grad()
+        loss.backward()
+    return loss

optimizer.step(closure)
```

-This closure will also work with all built in pytorch optimizers, including LBFGS, all optimizers in this library, as well as most custom ones.
+If you intend to use gradient-free methods, the `backward` argument is still required in the closure. Simply leave it unused. Gradient-free and gradient approximation methods always call closure with `backward=False`.

-# Contents
+All built-in pytorch optimizers, as well as most custom ones, support closures too. So the code above will work with all other optimizers out of the box, and you can switch between different optimizers without rewriting your training loop.

-Docs are available at [torchzero.readthedocs.io](https://torchzero.readthedocs.io/en/latest/). A preliminary list of all modules is available here <https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html#classes>. Some of the implemented algorithms:
+# Documentation

-- SGD/Rprop/RMSProp/AdaGrad/Adam as composable modules. They are also tested to exactly match built in pytorch versions.
-- Cautious Optimizers (<https://huggingface.co/papers/2411.16085>)
-- Optimizer grafting (<https://openreview.net/forum?id=FpKgG31Z_i9>)
-- Laplacian smoothing (<https://arxiv.org/abs/1806.06317>)
-- Polyak momentum, nesterov momentum
-- Gradient norm and value clipping, gradient normalization
-- Gradient centralization (<https://arxiv.org/abs/2004.01461>)
-- Learning rate droput (<https://pubmed.ncbi.nlm.nih.gov/35286266/>).
-- Forward gradient (<https://arxiv.org/abs/2202.08587>)
-- Gradient approximation via finite difference or randomized finite difference, which includes SPSA, RDSA, FDSA and Gaussian smoothing (<https://arxiv.org/abs/2211.13566v3>)
-- Various line searches
-- Exact Newton's method (with Levenberg-Marquardt regularization), newton with hessian approximation via finite difference, subspace finite differences newton.
-- Directional newton via one additional forward pass
+For more information on how to create, use and extend torchzero modules, please refer to the documentation at [torchzero.readthedocs.io](https://torchzero.readthedocs.io/en/latest/index.html).

-All modules should be quite fast, especially on models with many different parameters, due to `_foreach` operations.
+# Extra

-I am getting to the point where I can start focusing on good docs and tests. As of now, the code should be considered experimental, untested and subject to change, so feel free but be careful if using this for actual project.
-
-# Wrappers
+Some other optimization-related things in torchzero:

### scipy.optimize.minimize wrapper

@@ -71,12 +67,26 @@ from torchzero.optim.wrappers.scipy import ScipyMinimize
opt = ScipyMinimize(model.parameters(), method = 'trust-krylov')
```

-Use as any other optimizer (make sure closure accepts `backward` argument like one from **How to use**). Note that it performs full minimization on each step.
+Use as any other closure-based optimizer, but make sure the closure accepts a `backward` argument. Note that it performs full minimization on each step.

### Nevergrad wrapper

+[Nevergrad](https://github.com/facebookresearch/nevergrad) is an optimization library by Facebook with an insane number of gradient-free methods.
+
```py
+from torchzero.optim.wrappers.nevergrad import NevergradOptimizer
opt = NevergradOptimizer(bench.parameters(), ng.optimizers.NGOptBase, budget = 1000)
```

-Use as any other optimizer (make sure closure accepts `backward` argument like one from **How to use**).
+Use as any other closure-based optimizer, but make sure the closure accepts a `backward` argument.
+
+### NLopt wrapper
+
+[NLopt](https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/) is another optimization library similar to scipy.optimize.minimize, with a large number of both gradient-based and gradient-free methods.
+
+```py
+from torchzero.optim.wrappers.nlopt import NLOptOptimizer
+opt = NLOptOptimizer(bench.parameters(), 'LD_TNEWTON_PRECOND_RESTART', maxeval = 1000)
+```
+
+Use as any other closure-based optimizer, but make sure the closure accepts a `backward` argument. Note that it performs full minimization on each step.
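
The closure pattern above also covers gradient-free use. Below is a minimal sketch of such a setup; it assumes `tz.m.RandomizedFDM` (the randomized finite-differences module referenced in the README) works with its default settings, so treat the module arguments as indicative rather than exact:

```py
import torch
import torchzero as tz

# toy problem, only to make the sketch self-contained
model = torch.nn.Linear(10, 1)
inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
loss_fn = torch.nn.MSELoss()

# gradient approximation goes first, so later modules receive the estimated gradient
optimizer = tz.Modular(
    model.parameters(),
    tz.m.RandomizedFDM(),  # assumed default settings
    tz.m.LR(1e-2),
)

def closure(backward=True):
    preds = model(inputs)
    loss = loss_fn(preds, targets)
    if backward:  # gradient-free modules call the closure with backward=False, so this branch is skipped
        optimizer.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    optimizer.step(closure)
```

Because the closure keeps its `backward` branch, the same loop also works unchanged with gradient-based modules, with the wrappers above, or with built-in closure-based optimizers such as `torch.optim.LBFGS`.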

docs/source/FAQ.rst

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ Using torchzero optimizers is generally similar to using built-in PyTorch optimi
opt.zero_grad()


-Some modules and optimizers in torchzero, particularly line-search methods and gradient approximation modules, require a closure function. This is similar to how :code:`torch.optim.LBFGS` works in PyTorch. In torchzero, the closure function for these optimizers needs to accept an argument (we'll call it backward, though the argument can have any name). When :code:`backward=True`, the closure should zero out gradients using :code:`opt.zero_grad()`, and compute gradients using :code:`loss.backward()`.
+Some modules and optimizers in torchzero, particularly line-search methods and gradient approximation modules, require a closure function. This is similar to how :code:`torch.optim.LBFGS` works in PyTorch. In torchzero, the closure needs to accept a boolean :code:`backward` argument (though the argument can have any name). When :code:`backward=True`, the closure should zero out gradients using :code:`opt.zero_grad()` and compute gradients using :code:`loss.backward()`.

Here's how a training loop with a closure looks:

docs/source/implementing.rst

Lines changed: 10 additions & 12 deletions
@@ -19,12 +19,10 @@ Like in pytorch, putting all settings into :code:`defaults` dictionary allows to
super().__init__(defaults)


-Note:, please don't add :code:`lr` setting to your modules. When learning rate is part of the update rule, like in Adam, I rename it to :code:`alpha` and set to 1 by default. Learning rate should be controlled by a separate :py:class:`tz.m.LR<torchzero.modules.LR>` module, this avoids unintended compounding of learning rate modifications when using learning rate schedulers and per-parameter lr settings (see :ref:`How do we handle learning rates?`).
+Note: please don't use an :code:`lr` setting in your modules. When learning rate is part of the update rule, like in Adam, I rename it to :code:`alpha` and set it to 1 by default. Learning rate should be controlled by a separate :py:class:`tz.m.LR<torchzero.modules.LR>` module; this avoids unintended compounding of learning rate modifications when using learning rate schedulers and per-parameter lr settings (see :ref:`How do we handle learning rates?`).

-Implementing update rule
+Implementing the update rule
=============================
-Now we can implement the update rule.
-
Update logic in :code:`OptimizerModule` is defined in the :code:`step` method. By default it calls :code:`_update`, which in turn calls :code:`_single_tensor_update`. You can overwrite one of those three methods depending on how much control you need.

Method 1. Overwriting _single_tensor_update
@@ -35,12 +33,12 @@ For most update rules overwriting `_single_tensor_update` is the most convenient
:code:`_single_tensor_update` accepts the following arguments:

* :code:`vars`: :py:mod:`tz.core.OptimizationVars<torchzero.core.OptimizationVars>` object with various useful attributes, such as closure, list of current update tensors, loss, list of gradient tensors. For now we don't need this.
-* :code:`ascent`: torch.Tensor of the ascent - gradient or another modules update. This is what the update rule modifies.
-* :code:`param`: torch.Tensor of the parameter, useful for implementing weight decay and accessing per-parameter states.
-* :code:`grad`: torch.Tensor of the initial gradient, not transformed by previous modules. Sometimes gradient is never evaluated, like in gradient free methods, so this may be None.
+* :code:`ascent`: torch.Tensor with the ascent direction (update), which is the gradient if this module is first, or the update generated by the previous module. This is what the update rule should modify and return.
+* :code:`param`: torch.Tensor with the parameter, useful for implementing weight decay and accessing per-parameter states.
+* :code:`grad`: torch.Tensor with the initial gradient, not transformed by previous modules. Useful for things like cautious optimizers that compare the update sign with the gradient sign. Sometimes the gradient is never evaluated, like in gradient-free methods, so this may be None.
* per-parameter settings in any order, in the Adam example below :code:`beta1, beta2, eps, alpha`. Everything passed to :code:`defaults` will be accessible there.

-The method should return the updated ascent tensor. Please do not update :code:`param` directly.
+The method should return the updated ascent direction tensor. Please do not update :code:`param` directly.

Here is a ready to use Adam implementation through overwriting :code:`_single_tensor_update`:

@@ -94,20 +92,20 @@ Here is a ready to use Adam implementation through overwriting :code:`_single_te

Method 2. Overwriting _update
+++++++++++++++++++++++++++++++++++++++++++++
-:code:`_update` is similar to :code:`_single_tensor_update`, however you get access to all ascent tensors in a single list, as opposed to looping through each element. That way you can use pytorch `_foreach_xxx <https://pytorch.org/docs/stable/torch.html#foreach-operations>`_ operations for better performance. Most modules in torchzero are implemented through overwriting `_update` and with _foreach operations.
+:code:`_update` is similar to :code:`_single_tensor_update`, however you get access to all ascent tensors in a single list, as opposed to looping through each element. That way you can use pytorch `_foreach_xxx <https://pytorch.org/docs/stable/torch.html#foreach-operations>`_ operations for better performance. Most modules in torchzero are implemented by overwriting :code:`_update` and using :code:`_foreach` operations.

:code:`_update` accepts the following arguments:

* :code:`vars`: :py:mod:`tz.core.OptimizationVars<torchzero.core.OptimizationVars>` object with various useful attributes, such as closure, list of current update tensors, loss, list of gradient tensors. For now we don't need this.
-* :code:`ascent`: :py:mod:`tz.TensorList<torchzero.TensorList>` - list of tensors of the ascent direction (gradient or update) for each parameter with :code:`requires_grad = True`. :code:`TensorList` is a subclass of python list with some additional methods, but we won't use those methods for now.
+* :code:`ascent`: :py:mod:`tz.TensorList<torchzero.TensorList>` - list of tensors of the ascent direction (gradient or update) for each parameter with :code:`requires_grad = True`. :code:`TensorList` is a subclass of python list with some additional methods, but we won't use those methods for now. As it is a subclass of list, you can pass it directly to :code:`torch._foreach_xxx` methods.

The method should return the updated ascent :code:`TensorList`.

To make working with lists of tensors more convenient, :code:`OptimizerModule` also has some helper methods.

* :code:`self.get_params()`: returns a list of tensors of all params with :code:`requires_grad = True`.
-* :code:`self.get_group_key(key)`, :code:`self.get_group_keys(keys)`: return list of values of a per-parameter setting (such as beta1, beta2, eps) for each parameter with :code:`requires_grad = True`.
-* :code:`self.get_state_key(key)`, :code:`self.get_state_keys(keys)`: return a list of tensors of a state (e.g. exponential average) of each parameter with :code:`requires_grad = True`, initializes the state to zeroes if it doesn't exist.
+* :code:`self.get_group_key(key)`, :code:`self.get_group_keys(*keys)`: return a list of values of a per-parameter setting (such as beta1, beta2, eps) for each parameter with :code:`requires_grad = True`.
+* :code:`self.get_state_key(key)`, :code:`self.get_state_keys(*keys)`: return a list of tensors of a state (e.g. exponential average) for each parameter with :code:`requires_grad = True`; the state is initialized to zeroes if it doesn't exist.

Here is a ready to use Adam implementation through overwriting :code:`_update` using :code:`_foreach` methods. Using a lot of :code:`_foreach_xxx` methods is not very readable, but it is fast.
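
To make the bullets above concrete, here is a minimal sketch of a heavy-ball momentum module written against the API described in this file (`OptimizerModule`, `defaults`, `_update`, `get_state_key`, `get_group_key`). The import path and exact signatures are assumptions, so treat it as an illustration rather than a drop-in implementation; the Adam examples referenced above remain the canonical reference.

```py
# A hypothetical heavy-ball momentum module, sketched from the API described above.
# The import path of OptimizerModule is assumed and may differ in the actual library.
import torch
from torchzero.core import OptimizerModule

class HeavyBall(OptimizerModule):
    def __init__(self, momentum: float = 0.9):
        defaults = dict(momentum=momentum)  # per-parameter setting; deliberately no `lr` here
        super().__init__(defaults)

    def _update(self, vars, ascent):
        # per-parameter state tensors, zero-initialized on first use
        velocity = self.get_state_key('velocity')
        # per-parameter values of the 'momentum' setting
        momentum = self.get_group_key('momentum')

        # velocity <- momentum * velocity + ascent, using _foreach operations
        torch._foreach_mul_(velocity, momentum)
        torch._foreach_add_(velocity, ascent)

        # the new ascent direction is the velocity; copy it into `ascent`
        # so that later modules do not mutate the stored state
        for a, v in zip(ascent, velocity):
            a.copy_(v)
        return ascent
```

There is no learning rate parameter because, per the note at the top of this file, the step size would come from a separate `tz.m.LR` module placed after this one, e.g. `tz.Modular(model.parameters(), HeavyBall(0.9), tz.m.LR(1e-2))`.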

docs/source/introduction.rst

Lines changed: 5 additions & 5 deletions
@@ -1,13 +1,13 @@
Introduction
==================

-`torchzero` is a library for pytorch that offers a flexible and modular way to build optimizers for various tasks. By combining smaller, reusable modules, you can easily customize and experiment with different optimization strategies.
+torchzero is a library for pytorch that offers a flexible and modular way to build optimizers for various tasks. By combining smaller, reusable modules, you can easily customize and experiment with different optimization strategies.

-Each module in `torchzero` takes the output of the previous module and applies a further transformation. This modular design avoids redundant code, such as reimplementing Laplacian smoothing, cautioning, orthogonalization, etc for every optimizer. It also simplifies experimenting with advanced techniques like optimizer grafting, interpolation, and complex combinations like nested momentum.
+Each module takes the output of the previous module and applies a further transformation. This modular design avoids redundant code, such as reimplementing Laplacian smoothing, cautioning, orthogonalization, etc. for every optimizer. It also simplifies experimenting with advanced techniques like optimizer grafting, interpolation, and complex combinations like nested momentum.

-Many modules in `torchzero` perform gradient transformations. They receive an "ascent direction," which is initially the gradient, modify it, and pass it to the next module in the chain. Typically, the first module uses the raw gradient as the starting ascent direction. However, modules are not limited to gradient transformations. They can perform other operations like line searches, exponential moving average (EMA) and stochastic weight averaging (SWA), gradient accumulation, gradient approximation, and more.
+Many modules perform gradient transformations. They receive an "ascent direction," which is initially the gradient, modify it, and pass it to the next module in the chain. Typically, the first module uses the raw gradient as the starting ascent direction. However, modules are not limited to gradient transformations. They can perform other operations like line searches, exponential moving average (EMA) and stochastic weight averaging (SWA), gradient accumulation, gradient approximation, and more.

-`torchzero` provides over 100 modules, all accessible within the :py:mod:`tz.m<torchzero.modular>` namespace. For example, the Adam module is available as :py:class:`tz.m.Adam<torchzero.modules.Adam>`. You can find a complete list of modules in the `torchzero` documentation: https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html.
+torchzero provides over 100 modules, all accessible within the :py:mod:`tz.m<torchzero.modular>` namespace. For example, the Adam module is available as :py:class:`tz.m.Adam<torchzero.modules.Adam>`. You can find a complete list of modules in the torchzero documentation: https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html.

To combine these modules and create a custom optimizer, use tz.Modular, and then use it as any other pytorch optimizer. Here’s an example of how to define a Cautious Adam optimizer with gradient clipping and decoupled weight decay:

@@ -44,4 +44,4 @@ To combine these modules and create a custom optimizer, use tz.Modular, and then
print(epoch, loss.item(), end = ' \r')


-Please head over to :ref:FAQ for more examples and information.
+Please head over to :ref:`FAQ` for more examples and information.
