DART is a novel optimization algorithm designed for deep learning, combining the strengths of Bayesian inference, adaptive learning rates, and momentum-based methods. It advances beyond static or cyclic schedules by integrating a probabilistic learning rate framework, sampling directly from a Dirichlet distribution to promote dynamic, multimodal learning behavior.
- Dirichlet-distributed learning rates enabling multimodal, adaptive exploration.
- Implicit reparameterization gradients for backpropagation through sampling.
- Adaptive moment estimation inspired by Adam, using gradients w.r.t. Dirichlet concentration parameters.
- Bias correction to ensure stable updates early in training.
- Improved convergence and interpretability in stochastic environments.
This work extends findings by:
- Loshchilov & Hutter on warm restarts in SGD,
- An et al. and Yu et al. on cyclic learning rate schedules,
- Kingma & Welling on variational inference,
- Joo et al. on the Dirichlet Variational Autoencoder (DirVAE).
Core Concept: DART replaces fixed or hand-tuned learning rate schedules with learnable probabilistic distributions. At each step, candidate learning rates are sampled from a Dirichlet distribution, and adjusted based on backpropagated gradient information.
For each parameter tensor, we sample candidate learning rates from a Dirichlet distribution using the reparameterization trick:
Step 1: Gamma Sampling
$$z_i \sim \text{Gamma}(\alpha_i, 1), \quad i = 1, \dots, K$$
Step 2: Dirichlet Construction
$$\pi_i = \frac{z_i}{\sum_{j=1}^{K} z_j}$$
Step 3: Learning Rate Scaling
$$\text{lr}_i = \text{lr}_{\min} + (\text{lr}_{\max} - \text{lr}_{\min}) \times \pi_i$$
Where:
- $\alpha_i$ are the concentration parameters (learnable)
- $K$ is the number of parameters
- $\text{lr}_{\min}$ and $\text{lr}_{\max}$ define the learning rate bounds
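The three steps above can be sketched directly in PyTorch. This is a minimal illustration, not DART's internals; the concentration values and learning-rate bounds are placeholder choices:

```python
import torch

# K = 4 learnable concentration parameters (illustrative values)
alpha = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
lr_min, lr_max = 1e-6, 1e-1

# Step 1: Gamma sampling, z_i ~ Gamma(alpha_i, 1); rsample() keeps the
# draw differentiable w.r.t. alpha via implicit reparameterization
z = torch.distributions.Gamma(alpha, torch.ones_like(alpha)).rsample()

# Step 2: Dirichlet construction by normalizing the Gamma draws
pi = z / z.sum()

# Step 3: map each simplex coordinate into the learning-rate range
lr = lr_min + (lr_max - lr_min) * pi
```

Because `rsample()` is used in Step 1, the sampled learning rates remain differentiable with respect to `alpha`.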
The concentration parameters are updated using gradient information through a momentum-based approach inspired by Adam:
Moment Estimates:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
where $g_t = \nabla_{\alpha} \mathcal{L}_t$ is the gradient of the loss with respect to the concentration parameters.
Bias Correction:
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
Parameter Update:
$$\alpha_{t+1} = \alpha_t - \eta \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
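A minimal sketch of these three stages, applying the standard Adam formulas to the concentration parameters. Variable names and the gradient values are illustrative, not DART's actual internals:

```python
import torch

beta1, beta2, eta, eps, t = 0.9, 0.999, 1e-2, 1e-8, 1

alpha = torch.tensor([1.0, 2.0, 3.0])  # concentration parameters
g = torch.tensor([0.1, -0.2, 0.05])    # gradient of loss w.r.t. alpha
m = torch.zeros_like(alpha)            # first-moment estimate
v = torch.zeros_like(alpha)            # second-moment estimate

# Moment estimates
m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g**2

# Bias correction: at t = 1 this exactly undoes the (1 - beta) shrinkage,
# so m_hat == g rather than a tiny fraction of it
m_hat = m / (1 - beta1**t)
v_hat = v / (1 - beta2**t)

# Parameter update on the concentration parameters
alpha = alpha - eta * m_hat / (v_hat.sqrt() + eps)
```

Without the bias-correction step, the near-zero initial moments would make the first updates vanishingly small.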
The key innovation is maintaining gradient flow through the sampling process via implicit reparameterization gradients:
$$\nabla_{\alpha_i} z_i = -\frac{\partial F(z_i; \alpha_i) / \partial \alpha_i}{\partial F(z_i; \alpha_i) / \partial z_i}$$
where $F(z_i; \alpha_i)$ is the CDF of $\text{Gamma}(\alpha_i, 1)$. This enables the optimizer to learn which learning rate distributions work best for different parameters.
The Dirichlet distribution provides several key advantages:
Probability Density Function:
$$f(\boldsymbol{\pi}; \boldsymbol{\alpha}) = \frac{1}{B(\boldsymbol{\alpha})} \prod_{i=1}^{K} \pi_i^{\alpha_i - 1}$$
where $B(\boldsymbol{\alpha}) = \frac{\prod_{i=1}^{K} \Gamma(\alpha_i)}{\Gamma\left(\sum_{i=1}^{K} \alpha_i\right)}$ is the multivariate beta function.
Expected Value:
$$\mathbb{E}[\pi_i] = \frac{\alpha_i}{\alpha_0}, \qquad \alpha_0 = \sum_{j=1}^{K} \alpha_j$$
Variance:
$$\operatorname{Var}[\pi_i] = \frac{\alpha_i (\alpha_0 - \alpha_i)}{\alpha_0^2 (\alpha_0 + 1)}$$
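These closed-form moments are easy to verify empirically. The sketch below compares Monte Carlo estimates from PyTorch's `Dirichlet` against the formulas (the concentration values are arbitrary examples):

```python
import torch

torch.manual_seed(0)
alpha = torch.tensor([2.0, 5.0, 3.0])
a0 = alpha.sum()  # alpha_0 = 10

# Draw 100k samples from Dirichlet(alpha)
samples = torch.distributions.Dirichlet(alpha).sample((100_000,))

mean_mc = samples.mean(dim=0)
var_mc = samples.var(dim=0)

# Closed forms: E[pi_i] = alpha_i / alpha_0,
# Var[pi_i] = alpha_i (alpha_0 - alpha_i) / (alpha_0^2 (alpha_0 + 1))
mean_closed = alpha / a0
var_closed = alpha * (a0 - alpha) / (a0**2 * (a0 + 1))
```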
To enable gradient flow through the stochastic sampling, we use the reparameterization trick: each sample is expressed as $\pi_i = z_i / \sum_{j=1}^{K} z_j$ with $z_i \sim \text{Gamma}(\alpha_i, 1)$, and gradients flow through each $z_i$ by implicit reparameterization. This allows us to compute gradients with respect to the concentration parameters $\alpha_i$.
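PyTorch's `Dirichlet.rsample()` already implements implicit reparameterization gradients, which makes the gradient path easy to demonstrate. The loss below is a dummy scalar purely for illustration:

```python
import torch

alpha = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# rsample() keeps the draw differentiable w.r.t. the concentration
# parameters (implicit reparameterization gradients)
pi = torch.distributions.Dirichlet(alpha).rsample()

# Any downstream scalar backpropagates into alpha
loss = (pi * torch.arange(3.0)).sum()
loss.backward()
```

After `backward()`, `alpha.grad` holds the gradient of the loss with respect to the concentration parameters, which is exactly what the Adam-style update above consumes.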
The computational complexity of DART is:
- Sampling: $O(K)$, where $K$ is the number of parameters
- Gradient Computation: $O(K)$ for concentration parameter updates
- Memory: $O(K)$ for storing concentration parameters and moment estimates
The overall complexity is comparable to Adam while providing enhanced exploration capabilities.
DART offers several theoretical advantages over traditional optimizers:
- Multimodal Exploration: Unlike fixed learning rates, DART's Dirichlet sampling enables exploration across multiple learning rate modes simultaneously.
- Adaptive Variance: The variance of the sampled learning rates adapts with the concentration parameters: a large total concentration $\alpha_0$ concentrates samples near the mean (exploitation), while a small $\alpha_0$ spreads them out (exploration).
- Gradient-Driven Adaptation: Concentration parameters are updated using backpropagated gradient information, so the sampling distribution itself is learned during training.
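The exploration/exploitation effect of the concentration parameters can be seen empirically: scaling all $\alpha_i$ up shrinks the sample variance, exactly as the closed-form variance predicts. A small sketch with arbitrary example values:

```python
import torch

torch.manual_seed(0)

def sample_std(alpha, n=50_000):
    # Empirical standard deviation of the first Dirichlet component
    s = torch.distributions.Dirichlet(alpha).sample((n,))
    return s[:, 0].std().item()

# Diffuse concentration: wide spread of learning-rate weights
low_conc = sample_std(torch.tensor([1.0, 1.0, 1.0]))

# Large concentration: samples cluster tightly around the mean
high_conc = sample_std(torch.tensor([50.0, 50.0, 50.0]))
```

Here `high_conc` comes out much smaller than `low_conc`, matching $\operatorname{Var}[\pi_i] = \alpha_i(\alpha_0 - \alpha_i) / (\alpha_0^2(\alpha_0 + 1))$, which decays as $\alpha_0$ grows.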
After only 60 epochs on the MNIST dataset using a basic MLP, DART achieved:
- Final Loss: $\mathcal{L} = 0.2776$
- Training Stability: reduced variance in loss trajectories
- Convergence Speed: faster convergence compared to fixed learning rates
- Parameter Efficiency: better utilization of different learning rates across layers
The results demonstrate the potential of probabilistic learning rate adaptation in deep learning optimization.
Clone the repo:

```bash
git clone https://github.com/maticos-dev/dart-optimizer.git
cd dart-optimizer
```

Install dependencies:

```bash
pip install -r requirements.txt
```

Or install directly:

```bash
pip install -e .
```

Basic usage:

```python
import torch
import torch.nn as nn
from dartopt import Dart
from dartopt.utils import MLP, Trainer, DartDataBuilder

# Create model and data
model = MLP(input_size=784, output_size=10)
X, y = torch.randn(1000, 784), torch.randint(0, 10, (1000,))
dataset = DartDataBuilder(X, y, device='cpu')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

# Initialize DART optimizer
optimizer = Dart(
    model.parameters(),
    lr=1e-3,            # Base learning rate
    alpha=1.0,          # Dirichlet concentration parameter
    lr_min=1e-6,        # Minimum learning rate
    lr_max=1e-1,        # Maximum learning rate
    betas=(0.9, 0.999)  # Adam-style momentum parameters
)

# Training
trainer = Trainer(dataloader, num_epochs=60)
criterion = nn.CrossEntropyLoss()
history = trainer.train(model, optimizer, criterion)
```

Custom parameter groups with different settings:

```python
optimizer = Dart([
    {'params': model.fc1.parameters(), 'lr': 1e-3, 'alpha': 2.0},
    {'params': model.fc2.parameters(), 'lr': 5e-4, 'alpha': 1.5},
    {'params': model.fc3.parameters(), 'lr': 1e-4, 'alpha': 1.0}
])

# Monitor learning rate samples
lr_samples = optimizer.get_lr_samples()
concentration_params = optimizer.get_concentration_params()
```

Parameters:
- `lr`: Base learning rate for scaling the Dirichlet samples
- `alpha`: Initial concentration parameter for the Dirichlet distribution
- `lr_min`/`lr_max`: Bounds for the sampled learning rates
- `betas`: Momentum parameters for concentration updates (β₁, β₂)
- `eps`: Numerical stability term
- `weight_decay`: L2 regularization coefficient
If you use DART in academic work, please cite the following papers, whose insights played an outsize role in the development of this probabilistic optimizer:
- Kingma & Welling, Auto-Encoding Variational Bayes (2014)
- Loshchilov & Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts (2017)
- Joo et al., Dirichlet Variational Autoencoder (2019)
While traditional optimizers use deterministic rules, DART introduces informed randomness backed by Bayesian theory. It enhances:
- Exploration via sampling
- Adaptation via gradients through distribution parameters
- Interpretability by representing biases as learnable distributions