Merge pull request #14 from shenoynikhil/checkpoints
Automatic checkpoint downloading
shenoynikhil authored Dec 12, 2024
2 parents 8924983 + 4e1c5aa commit f577673
Showing 18 changed files with 671 additions and 10 deletions.
117 changes: 117 additions & 0 deletions .github/workflows/release.yml
@@ -0,0 +1,117 @@
```yaml
name: release

on:
  workflow_dispatch:
    inputs:
      release-version:
        description: "A valid Semver version string"
        required: true

permissions:
  contents: write
  pull-requests: write

jobs:
  release:
    # Do not release if not triggered from the default branch
    if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch)

    runs-on: ubuntu-latest
    timeout-minutes: 30

    defaults:
      run:
        shell: bash -l {0}

    steps:
      - name: Checkout the code
        uses: actions/checkout@v3

      - name: Setup mamba
        uses: mamba-org/setup-micromamba@v1
        with:
          environment-file: env.yml
          environment-name: my_env
          cache-environment: true
          cache-downloads: true
          create-args: >-
            pip
            semver
            python-build
            setuptools_scm

      - name: Check the version is valid semver
        run: |
          RELEASE_VERSION="${{ inputs.release-version }}"
          {
            pysemver check $RELEASE_VERSION
          } || {
            echo "The version '$RELEASE_VERSION' is not a valid Semver version string."
            echo "Please use a valid semver version string. More details at https://semver.org/"
            echo "The release process is aborted."
            exit 1
          }

      - name: Check the version is higher than the latest one
        run: |
          # Retrieve the git tags first
          git fetch --prune --unshallow --tags &> /dev/null
          RELEASE_VERSION="${{ inputs.release-version }}"
          LATEST_VERSION=$(git describe --abbrev=0 --tags)
          IS_HIGHER_VERSION=$(pysemver compare $RELEASE_VERSION $LATEST_VERSION)
          if [ "$IS_HIGHER_VERSION" != "1" ]; then
            echo "The version '$RELEASE_VERSION' is not higher than the latest version '$LATEST_VERSION'."
            echo "The release process is aborted."
            exit 1
          fi

      - name: Build Changelog
        id: github_release
        uses: mikepenz/release-changelog-builder-action@v4
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          toTag: "main"

      - name: Configure git
        run: |
          git config --global user.name "${GITHUB_ACTOR}"
          git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"

      - name: Create and push git tag
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          # Tag the release
          git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}"
          # Checkout the git tag
          git checkout "${{ inputs.release-version }}"
          # Push the modified changelogs
          git push origin main
          # Push the tags
          git push origin "${{ inputs.release-version }}"

      - name: Install library
        run: python -m pip install --no-deps .

      - name: Build the wheel and sdist
        run: python -m build --no-isolation

      - name: Publish package to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.PYPI_API_TOKEN }}
          packages-dir: dist/

      - name: Create GitHub Release
        uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844
        with:
          tag_name: ${{ inputs.release-version }}
          body: ${{ steps.github_release.outputs.changelog }}
```
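The two guard steps in the workflow shell out to `pysemver`; the ordering they enforce can be sketched in plain Python. This is a stdlib-only illustration with hypothetical helpers (`parse_semver`, `is_higher`), and it ignores pre-release and build metadata, which real Semver comparison handles:

```python
def parse_semver(version: str) -> tuple:
    """Parse a 'MAJOR.MINOR.PATCH' string into an integer tuple."""
    major, minor, patch = version.split(".")[:3]
    return (int(major), int(minor), int(patch))


def is_higher(release: str, latest: str) -> bool:
    """Mimic the workflow's check that `pysemver compare RELEASE LATEST` is 1."""
    return parse_semver(release) > parse_semver(latest)


print(is_higher("1.2.0", "1.1.9"))  # True: the release may proceed
print(is_higher("1.1.9", "1.2.0"))  # False: the workflow aborts
```

Tuple comparison gives the right lexicographic ordering over (major, minor, patch), which is why the sketch needs no explicit branching.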
21 changes: 19 additions & 2 deletions README.md
````diff
@@ -3,7 +3,13 @@ Implementation of [Equivariant Flow Matching for Molecule Conformer Generation](
 
 ET-Flow is a state-of-the-art generative model for generating small molecule conformations using equivariant transformers and flow matching.
 
-### Setup Environment
+### Install Etflow
+We are now available on PyPI. Install the package with:
+```bash
+pip install etflow
+```
+
+### Setup Dev Environment
 Run the following commands to set up the environment:
 ```bash
 conda env create -n etflow -f env.yml
````
````diff
@@ -15,7 +21,18 @@ python3 -m pip install -e .
 ### Generating Conformations for Custom Smiles
 We have a sample notebook ([generate_confs.ipynb](generate_confs.ipynb)) to generate conformations for custom SMILES input. Pass the config and the corresponding checkpoint path as additional inputs.
 
-[WIP] We are currently adding support to load the model config and checkpoint without custom downloading.
+We have added support for loading the model config and checkpoint with automatic download and caching. See [tutorial.ipynb](tutorial.ipynb) or use the following snippet to load a model and generate conformations for custom SMILES input:
+
+```python
+from etflow import BaseFlow
+
+model = BaseFlow.from_default(model="drugs-o3")
+model.predict(["CN1C=NC2=C1C(=O)N(C(=O)N2C)C"], num_samples=3, as_mol=True)
+```
+
+We currently support the following configurations and checkpoints:
+- `drugs-o3`
+- `qm9-o3`
+- `drugs-so3`
 
 ### Preprocessing Data
 To pre-process the data, perform the following steps:
````
2 changes: 2 additions & 0 deletions env.yml
```diff
@@ -9,6 +9,8 @@ dependencies:
   - python >=3.8
   - pip
   - tqdm
+  - pydantic
+  - fsspec
 
   # Scientific
   - pandas
```
1 change: 1 addition & 0 deletions etflow/__init__.py
@@ -0,0 +1 @@
```python
from etflow.models.model import BaseFlow as BaseFlow
```
12 changes: 12 additions & 0 deletions etflow/_version.py
@@ -0,0 +1,12 @@
```python
try:
    from importlib.metadata import PackageNotFoundError, version
except ModuleNotFoundError:
    # Fall back to the `importlib_metadata` backport on Python < 3.8.
    from importlib_metadata import PackageNotFoundError, version


try:
    __version__ = version("etflow")
except PackageNotFoundError:
    # package is not installed
    __version__ = "dev"
```
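The pattern above (resolve the installed distribution's version, fall back to a sentinel) can be exercised directly. `get_version` below is a hypothetical helper for illustration, not part of the package:

```python
from importlib.metadata import PackageNotFoundError, version


def get_version(package: str) -> str:
    """Return the installed distribution version, or 'dev' when not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return "dev"


# An uninstalled distribution name falls back to the sentinel:
print(get_version("surely-not-an-installed-package-xyz"))  # dev
```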
207 changes: 207 additions & 0 deletions etflow/commons/configs.py
@@ -0,0 +1,207 @@
```python
import os
from typing import Optional

try:
    from pydantic.v1 import BaseModel, Field, validator
except ImportError:
    from pydantic import BaseModel, validator, Field

import fsspec
from tqdm import tqdm
from typing_extensions import Literal

CACHE = "~/.cache/etflow"


def download_with_progress(url, destination, chunk_size=2**20):  # 1MB chunks
    """Download a file with a progress bar using fsspec and tqdm.

    Args:
        url: Source URL
        destination: Local destination path
        chunk_size: Size of chunks to read/write in bytes
    """
    with fsspec.open(url, "rb") as source:
        # Get file size if available
        try:
            file_size = source.size
        except AttributeError:
            file_size = None

        # Create progress bar
        with tqdm(
            total=file_size,
            unit="B",
            unit_scale=True,
            desc=f"Downloading to {destination}",
        ) as pbar:
            # Open destination file
            with open(destination, "wb") as target:
                while True:
                    chunk = source.read(chunk_size)
                    if not chunk:
                        break
                    target.write(chunk)
                    pbar.update(len(chunk))


class BaseConfigSchema(BaseModel):
    class Config(BaseModel.Config):
        case_insensitive = True
        extra = "forbid"


class CheckpointConfigSchema(BaseConfigSchema):
    type: str
    checkpoint_path: str
    cache: Optional[str] = CACHE
    _format: str = ".ckpt"

    @validator("cache")
    def validate_cache(cls, value):
        if not value:
            value = "cache"
        # Expand "~" so the cache lands in the user's home directory
        value = os.path.expanduser(value)
        if not os.path.exists(value):
            os.makedirs(value)
        return value

    def fetch_checkpoint(self) -> "CheckpointConfigSchema":
        self.create_cache()
        if not self.checkpoint_exists():
            download_with_progress(self.checkpoint_path, self.local_path)
        else:
            print(f"Checkpoint found at {self.local_path}")
        return self

    @property
    def local_path(self) -> str:
        return os.path.join(self.cache, self.type + self._format)

    def checkpoint_exists(self) -> bool:
        return os.path.exists(self.local_path)

    def cache_exists(self) -> bool:
        return os.path.exists(self.cache)

    def create_cache(self):
        if not self.cache_exists():
            os.makedirs(self.cache)

    def set_cache(self, cache: str):
        if not cache:
            return
        self.cache = cache


class ModelArgsSchema(BaseConfigSchema):
    network_type: Literal["TorchMDDynamics"] = "TorchMDDynamics"
    hidden_channels: int = 160
    num_layers: int = 20
    num_rbf: int = 64
    rbf_type: Literal["expnorm"] = "expnorm"
    trainable_rbf: bool = True
    activation: Literal["silu"] = "silu"
    neighbor_embedding: bool = True
    cutoff_lower: float = 0.0
    cutoff_upper: float = 10.0
    max_z: int = 100
    node_attr_dim: int = 10
    edge_attr_dim: int = 1
    attn_activation: Literal["silu"] = "silu"
    num_heads: int = 8
    distance_influence: Literal["both"] = "both"
    reduce_op: Literal["sum"] = "sum"
    qk_norm: bool = True
    so3_equivariant: bool = False
    clip_during_norm: bool = True
    parity_switch: Literal["post_hoc"] = "post_hoc"
    output_layer_norm: bool = False

    # flow matching specific
    sigma: float = 0.1
    prior_type: Literal["harmonic"] = "harmonic"
    interpolation_type: Literal["linear"] = "linear"

    # optimizer args
    optimizer_type: Literal["AdamW"] = "AdamW"
    lr: float = 8.0e-4
    weight_decay: float = 1.0e-8

    # lr scheduler args
    lr_scheduler_type: Literal[
        "CosineAnnealingWarmupRestarts"
    ] = "CosineAnnealingWarmupRestarts"
    first_cycle_steps: int = 375_000
    cycle_mult: float = 1.0
    max_lr: float = 5.0e-4
    min_lr: float = 1.0e-8
    warmup_steps: int = 0
    gamma: float = 0.05
    last_epoch: int = -1
    lr_scheduler_monitor: Literal["val/loss"] = "val/loss"
    lr_scheduler_interval: Literal["step"] = "step"
    lr_scheduler_frequency: int = 1


class ModelConfigSchema(BaseConfigSchema):
    model: Literal["BaseFlow"] = "BaseFlow"
    model_args: ModelArgsSchema = ModelArgsSchema()
    checkpoint_config: CheckpointConfigSchema

    def model_dict(self):
        return self.dict(exclude={"checkpoint_config"})


class DRUGS_O3_CHECKPOINT(CheckpointConfigSchema):
    type: Literal["drugs-o3"] = "drugs-o3"
    checkpoint_path: str = (
        "https://zenodo.org/records/14226681/files/drugs-o3.ckpt?download=1"
    )


class DRUGS_SO3_CHECKPOINT(CheckpointConfigSchema):
    type: Literal["drugs-so3"] = "drugs-so3"
    checkpoint_path: str = (
        "https://zenodo.org/records/14226681/files/drugs-so3.ckpt?download=1"
    )


class QM9_O3_CHECKPOINT(CheckpointConfigSchema):
    type: Literal["qm9-o3"] = "qm9-o3"
    checkpoint_path: str = (
        "https://zenodo.org/records/14226681/files/qm9-o3.ckpt?download=1"
    )


class DRUGS_O3(ModelConfigSchema):
    checkpoint_config: CheckpointConfigSchema = Field(
        default_factory=DRUGS_O3_CHECKPOINT
    )


class DRUGS_SO3(ModelConfigSchema):
    checkpoint_config: CheckpointConfigSchema = Field(
        default_factory=DRUGS_SO3_CHECKPOINT
    )
    model_args: ModelArgsSchema = ModelArgsSchema(so3_equivariant=True)


class QM9_O3(ModelConfigSchema):
    checkpoint_config: CheckpointConfigSchema = Field(default_factory=QM9_O3_CHECKPOINT)
    model_args: ModelArgsSchema = ModelArgsSchema(
        output_layer_norm=True,
        lr=7.0e-4,
        first_cycle_steps=250_000,
        max_lr=7.0e-4,
    )


CONFIG_DICT = {
    "drugs-o3": DRUGS_O3,
    "drugs-so3": DRUGS_SO3,
    "qm9-o3": QM9_O3,
}

if __name__ == "__main__":
    print(DRUGS_O3().checkpoint_config)
```
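The checkpoint classes above follow a simple registry pattern: a model name maps to a config that knows its download URL and its local cache path. A minimal stdlib-only sketch of that pattern (class and function names here are hypothetical illustrations, not the etflow API):

```python
import os
from dataclasses import dataclass

CACHE = os.path.expanduser("~/.cache/etflow")


@dataclass
class CheckpointConfig:
    type: str
    checkpoint_path: str
    cache: str = CACHE

    @property
    def local_path(self) -> str:
        # Cached file name is derived from the config's type, e.g. drugs-o3.ckpt
        return os.path.join(self.cache, self.type + ".ckpt")


REGISTRY = {
    "drugs-o3": CheckpointConfig(
        "drugs-o3",
        "https://zenodo.org/records/14226681/files/drugs-o3.ckpt?download=1",
    ),
}


def resolve(name: str) -> CheckpointConfig:
    """Look up a named default config, as a from_default-style entry point might."""
    try:
        return REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown model '{name}'; choose from {sorted(REGISTRY)}")


cfg = resolve("drugs-o3")
print(cfg.local_path)  # e.g. ~/.cache/etflow/drugs-o3.ckpt
```

Keeping the URL and cache location on the config object means callers only ever name a model; everything else (download, caching, path layout) is derived.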
4 changes: 1 addition & 3 deletions etflow/commons/sample.py
```diff
@@ -3,12 +3,10 @@
 from pytorch_lightning import seed_everything
 from torch_geometric.data import Batch, Data
 
-from etflow.models.model import BaseFlow
-
 
 @torch.no_grad()
 def batched_sampling(
-    model: BaseFlow,
+    model,
     data: Data,
     max_batch_size: int = 1,
     num_samples: int = 1,
```
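The `num_samples`/`max_batch_size` split that `batched_sampling` performs can be illustrated with a small stdlib sketch (`chunk_sizes` is a hypothetical helper, not the etflow implementation):

```python
def chunk_sizes(num_samples: int, max_batch_size: int) -> list:
    """Split a sampling request into batches of at most max_batch_size."""
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be >= 1")
    full, rem = divmod(num_samples, max_batch_size)
    # full batches of max_batch_size, plus one trailing partial batch if needed
    return [max_batch_size] * full + ([rem] if rem else [])


print(chunk_sizes(10, 4))  # [4, 4, 2]
```

Capping the per-call batch size bounds peak memory while still producing the full set of requested conformers.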