Skip to content

Conversation

@AlexandreAdam
Copy link
Owner

Major updates of the project.

  • New architectures with the EDM framework. This includes a better parametrization of the score, improvement on the layers of the neural network and training dynamics (improved EMA, adaptive weighting of the loss, and the possibility to provide a custom proposal for the time-index when training to focus on the spin-glass phase transition).
  • PostHocEMA to synthesize models after training from a checkpoint directory (requires at least 2 ema lengths)
  • Additional SDEs, including the Cosine VP, EDM VE, and the Trig SDE
  • Major refactorization of the code to make it more readable
  • New higher-order ODE and SDE solvers.
  • Computing log probabilities is now much improved and tested. Sampling with higher order SDE solver is now supported out of the box.
  • Hessian Diagonal Model for log probability estimation.
  • A suit of analytical score model is now included.
  • We now have a documentation hosted on ReadTheDocs

AlexandreAdam and others added 20 commits August 1, 2024 22:11
* Major refactoring of the class structure. Still need to finish the fit and Trainer methods, then implement HessianTrace class and method to finetune SBM with LORA weights

* Finished refactoring Trainer, made new tests for save load which are more extensive

* Added and tested HessianDiagonal class, including the second order loss function needed to train it

* Made sure weights of SBM were frozen in HessianDiagonal, moved the test in their own file and made them faster. Also added a parameter to switch between the two loss. Defaults to canonical, though potentially Meng version is better.

* Fixed bug in backward compatibility methods

* Fixed bug in backward compatibility methods v2

* Added and tested save/load methods of LoRA SBM class

* Refactored conditional branch, though still need some cleanup. Added back in the torch.no_grad() decorator around the sample method which was forgotten

* Removed getargspec() since it's deprecated in later version of python

* beta_primitive for linear schedule needed a factor 1/2

* Moved the factor of 1/2 in the mu function and explicitly in the drift function as in Song et al.

* Exposed stopping factor to the API for the Euler Maruyama method

* Forgot np

* Added Tweedie formula at the last step of sampling for denoising

* Implemented conditional branch in MLP DDPM, added tests to cover training of every models under a variety of settings

* Added tests, some error catchers and centralized the conditional branch stuff.

* Added tests for training LoRA SBM, modified the cleanup function to handle the directory for this model

* Added TC2 to flake8 to allow forward reference

* Needed both TC and TC2 for flake8 to work with forward references

* Minor fixes to the tests

* Added __all__ tag to layers, debugged DDPM for 1D and 3D, tests passing except LoRA models

* Added posterior score model with likelihood score function saved with dill

* Fixed import

* Added flake8 type checkning

* .

* fixed some imports

* Added type hint blocks for forward referencing and flake8

* Removed type EnergyModel which is instance of ScoreModel now

* removed lora posterior

* Fixed an initialization problem for LoRA when using base SBM that has a different checkpoint

* Improved test, put back the loading of optimizer checkpoints in the trainer, also made sure all parameters of a ScoreModel are tracked during training for the Posterior fine-tuning tasks

* Removed a print left in the code and added weights_only to torch.load function to remove deprecation warning

* Removed the vv from coverage
…wrapping optional in the fit method, so we can use dataset that output a full batch. Updated the tests for the layers and added test to catch new behaviors.
…dules based on the global step of the training
* Adding solvers and simple score models

* Adding simple model descriptions

* Integrate solvers into score_model sample

* Fix init for simple models

* Add unit tests for simple models

* Solver now handles log_p, I think

* sde step now doesnt include x

* Address review comments

* same update for sde solver

* fix import

* kwargs now pass through score models

* lower case names

* ode now can get P(xt)

* rename ode solver to avoid conflict

* avoid snake case

* housekeeping simple_models to analytic_models and clean sde ode solvers

* clean up solver step size

* refactor solver to propogate args

* clean up solver creation in ScoreModel

* adding solver docstrings

* update analytic model docstrings

* more on passing args. getting tests to run

* solver class can now construct its subclasses

* new unit tests for solvers

* remove comment

* add hook function to solvers

* update docstring

* propogate kwargs in sde

* Housekeeping, add type hints

* merge ananlytic_models into sbm module

* conv likelihood A may be tensor or callable

* return tweedie to score_model

* analytic models handle mu_t

* rename t_scale to sigma_t for consistency

* rename methods

* return denoise method to score_model. solver handle tensor t input

* add check that t input is uniform

* minor updates from discussion, tweedie, joint, and delta logp

* clean up conv likelihood inputs

* remove unecessary pary of conv like test after update

* MVG may now also be diagonal

* Fix MVG bug

* add MVG score model which computes score analytically instead of using autograd

* allow users to set time_steps for solvers

* update docstring

* nice progress bar

* remove unused import
* Removed dollar signs in README

* Updated actions

* Update CI action to run on any branch push

* Fixed bug in CI action

* Modved package to src to avoid conflict

* Reduced memory reqs on lora sbm test for macOS backend

* Added restriction for sending the coverage

* Removed macOS backend test and updated flag for coverage

* Fixed CI action bug

* Removed flag nonsense in CI action

* Removed pip show in CI

* Adding docs folder, currently empty

* Started docs, drafted a style and created intro page

* Modified docs action and added utils for plotting and distributions

* Updated book

* Improved introduction

* Improved intro

* Renamed intro to the score

* Added readthedocs conf file

* Updated actions

* Updated permission of github action

* updated readthedocs conf

* Reset readthedocs file

* Added documentation badge

* Added documentation section

* Drafted some structure for the docs

* Started the score matching section

* Added part on annealing

* Finished annealing score matching section

* Edited score matching part

* .

* Almost finished score matching section

* Added visualization of score learned with DSM

* .

* Finished the score matching section

* Working on overview

* Revision of some of the symbols, added some stuff in getting started

* Found a neat logo

* Added logo in README

* Cleaned up 02 SM notebook a little bit

* Added logo to book

* Updated front page

* Worked on the getting started page

* Added outline for diffusion section
#12)

* Fixed issue with Hessian Diagonal model log likelihood not working properly with the sampler. Now all the logic is built in properly. Also refactored log_prob to log_likelihood

* Changed return_logp to return_dlogp, removed unused arguments in docstrings

* Added docstring for dlogp in ODE Solver class
* Removed dollar signs in README

* Updated actions

* Update CI action to run on any branch push

* Fixed bug in CI action

* Modved package to src to avoid conflict

* Reduced memory reqs on lora sbm test for macOS backend

* Added restriction for sending the coverage

* Removed macOS backend test and updated flag for coverage

* Fixed CI action bug

* Removed flag nonsense in CI action

* Removed pip show in CI

* Adding docs folder, currently empty

* Started docs, drafted a style and created intro page

* Modified docs action and added utils for plotting and distributions

* Updated book

* Improved introduction

* Improved intro

* Renamed intro to the score

* Added readthedocs conf file

* Updated actions

* Updated permission of github action

* updated readthedocs conf

* Reset readthedocs file

* Added documentation badge

* Added documentation section

* Drafted some structure for the docs

* Started the score matching section

* Added part on annealing

* Finished annealing score matching section

* Edited score matching part

* .

* Almost finished score matching section

* Added visualization of score learned with DSM

* .

* Finished the score matching section

* Working on overview

* Revision of some of the symbols, added some stuff in getting started

* Found a neat logo

* Added logo in README

* Cleaned up 02 SM notebook a little bit

* Added logo to book

* Updated front page

* Worked on the getting started page

* Added outline for diffusion section

* Fixed issue with Hessian Diagonal model log likelihood not working properly with the sampler. Now all the logic is built in properly. Also refactored log_prob to log_likelihood

* Changed return_logp to return_dlogp, removed unused arguments in docstrings

* Added docstring for dlogp in ODE Solver class

* Joint training second order

* Started working on EDM parametrization

* Drafted new edm parametrization

* Drafted edm stuff

* Updated the overview section

* Renamed solvers, updated SDE to have the skip connections needed for the EDM setup

* Fixed a bug in Solver constructor

* Updated hyperparameters of some SDE to make sure they work nicely with EDM

* Removed Tweedie from solvers, brought it back in Score Model class

* Added the sampling distribution for EDM

* Added functionality to automatically load EDM model from path

* Updated code for EDMv2Net and tested it in the API

* Added inverse square root learning rate scheduler, tested the saving/loading logic so that global step can be reloaded from a checkpoint

* Made sure analytical score models would throw error upon trying to save or load

* Working on logic of loading checkpoints with different ema_lengths

* Implemented the PostHocEMA wrapper, modified the loading logic so it could handle the case where multiple ema_lengths are present

* Minor changes, preparing for futur updates

* Added my own EMA to control better the soft reset and the various parametrization, now Karras EMA is back a default for the trainer as the preferred method.

* Debugged soft reset

* Cleaned up some files and imports

* Modified extensions in github actions

* Updated and fixed EDM SDE

* .

* Fixed bug in Trainer catched by linter
@AlexandreAdam AlexandreAdam added documentation Improvements or additions to documentation enhancement New feature or request labels Nov 23, 2024
@AlexandreAdam AlexandreAdam self-assigned this Nov 23, 2024
@AlexandreAdam AlexandreAdam changed the title Dev major: Dev Nov 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants