Commit f577673

Merge pull request #14 from shenoynikhil/checkpoints
Automatic checkpoint downloading
2 parents 8924983 + 4e1c5aa commit f577673

18 files changed, +671 -10 lines changed
.github/workflows/release.yml

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+name: release
+
+on:
+  workflow_dispatch:
+    inputs:
+      release-version:
+        description: "A valid Semver version string"
+        required: true
+
+permissions:
+  contents: write
+  pull-requests: write
+
+jobs:
+  release:
+    # Do not release if not triggered from the default branch
+    if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch)
+
+    runs-on: ubuntu-latest
+    timeout-minutes: 30
+
+    defaults:
+      run:
+        shell: bash -l {0}
+
+    steps:
+      - name: Checkout the code
+        uses: actions/checkout@v3
+
+      - name: Setup mamba
+        uses: mamba-org/setup-micromamba@v1
+        with:
+          environment-file: env.yml
+          environment-name: my_env
+          cache-environment: true
+          cache-downloads: true
+          create-args: >-
+            pip
+            semver
+            python-build
+            setuptools_scm
+
+      - name: Check the version is valid semver
+        run: |
+          RELEASE_VERSION="${{ inputs.release-version }}"
+
+          {
+            pysemver check $RELEASE_VERSION
+          } || {
+            echo "The version '$RELEASE_VERSION' is not a valid Semver version string."
+            echo "Please use a valid semver version string. More details at https://semver.org/"
+            echo "The release process is aborted."
+            exit 1
+          }
+
+      - name: Check the version is higher than the latest one
+        run: |
+          # Retrieve the git tags first
+          git fetch --prune --unshallow --tags &> /dev/null
+
+          RELEASE_VERSION="${{ inputs.release-version }}"
+          LATEST_VERSION=$(git describe --abbrev=0 --tags)
+
+          IS_HIGHER_VERSION=$(pysemver compare $RELEASE_VERSION $LATEST_VERSION)
+
+          if [ "$IS_HIGHER_VERSION" != "1" ]; then
+            echo "The version '$RELEASE_VERSION' is not higher than the latest version '$LATEST_VERSION'."
+            echo "The release process is aborted."
+            exit 1
+          fi
+
+      - name: Build Changelog
+        id: github_release
+        uses: mikepenz/release-changelog-builder-action@v4
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          toTag: "main"
+
+      - name: Configure git
+        run: |
+          git config --global user.name "${GITHUB_ACTOR}"
+          git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"
+
+      - name: Create and push git tag
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          # Tag the release
+          git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}"
+
+          # Checkout the git tag
+          git checkout "${{ inputs.release-version }}"
+
+          # Push the modified changelogs
+          git push origin main
+
+          # Push the tags
+          git push origin "${{ inputs.release-version }}"
+
+      - name: Install library
+        run: python -m pip install --no-deps .
+
+      - name: Build the wheel and sdist
+        run: python -m build --no-isolation
+
+      - name: Publish package to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          password: ${{ secrets.PYPI_API_TOKEN }}
+          packages-dir: dist/
+
+      - name: Create GitHub Release
+        uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844
+        with:
+          tag_name: ${{ inputs.release-version }}
+          body: ${{steps.github_release.outputs.changelog}}
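
The two version gates in this workflow delegate to `pysemver check` and `pysemver compare`, and the second gate continues only when the comparison prints `1`. As a rough illustration of what that gate computes, here is a stdlib-only sketch of Semver precedence. The helper names `parse` and `compare` are hypothetical, and the pattern is a simplified approximation (it does not enforce the leading-zero rules, for instance), not the `semver` package's implementation.

```python
import re
from typing import List, Tuple

# Simplified Semver pattern (assumption: leading-zero rules are not enforced).
_SEMVER = re.compile(
    r"([0-9]+)\.([0-9]+)\.([0-9]+)"  # major.minor.patch
    r"(?:-([0-9A-Za-z.-]+))?"        # optional pre-release identifiers
    r"(?:\+[0-9A-Za-z.-]+)?"         # optional build metadata (ignored)
)


def parse(version: str) -> Tuple[Tuple[int, int, int], List[str]]:
    """Split a version into its numeric core and pre-release identifiers."""
    match = _SEMVER.fullmatch(version)
    if match is None:
        raise ValueError(f"'{version}' is not a valid Semver version string")
    core = (int(match.group(1)), int(match.group(2)), int(match.group(3)))
    prerelease = match.group(4).split(".") if match.group(4) else []
    return core, prerelease


def _identifier_key(identifier: str) -> Tuple[int, int, str]:
    # Numeric identifiers rank below alphanumeric ones and compare as integers.
    if identifier.isdigit():
        return (0, int(identifier), "")
    return (1, 0, identifier)


def compare(a: str, b: str) -> int:
    """Return 1 if a > b, 0 if they are equal, -1 if a < b."""
    core_a, pre_a = parse(a)
    core_b, pre_b = parse(b)
    if core_a != core_b:
        return 1 if core_a > core_b else -1
    # A release outranks any pre-release of the same core version.
    if not pre_a or not pre_b:
        return (not pre_a) - (not pre_b)
    keys_a = [_identifier_key(i) for i in pre_a]
    keys_b = [_identifier_key(i) for i in pre_b]
    return (keys_a > keys_b) - (keys_a < keys_b)
```

Because the workflow tests for `!= "1"`, re-releasing the current latest version (a result of `0`) is rejected just like a lower one.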

README.md

Lines changed: 19 additions & 2 deletions
@@ -3,7 +3,13 @@ Implementation of [Equivariant Flow Matching for Molecule Conformer Generation](
 
 ET-Flow is a state-of-the-art generative model for generating small molecule conformations using equivariant transformers and flow matching.
 
-### Setup Environment
+### Install Etflow
+We are now available on PyPI. Easily install the package using the following command:
+```bash
+pip install etflow
+```
+
+### Setup dev Environment
 Run the following commands to setup the environment:
 ```bash
 conda env create -n etflow -f env.yml
@@ -15,7 +21,18 @@ python3 -m pip install -e .
 ### Generating Conformations for Custom Smiles
 We have a sample notebook ([generate_confs.ipynb](generate_confs.ipynb)) to generate conformations for custom smiles input. One needs to pass the config and corresponding checkpoint path in order as additional inputs.
 
-[WIP] We are currently adding support to load the model config and checkpoint without custom downloading.
+We have added support to load the model config and checkpoint with automatic download and caching. See ([tutorial.ipynb](tutorial.ipynb)) or use the following snippet to load the model and generate conformations for custom smiles input.
+
+```python
+from etflow import BaseFlow
+model=BaseFlow.from_default(model="drugs-o3")
+model.predict(['CN1C=NC2=C1C(=O)N(C(=O)N2C)C'], num_samples=3, as_mol=True)
+```
+
+We currently support the following configurations and checkpoint:
+- `drugs-o3`
+- `qm9-o3`
+- `drugs-so3`
 
 ### Preprocessing Data
 To pre-process the data, perform the following steps,

env.yml

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@ dependencies:
   - python >=3.8
   - pip
   - tqdm
+  - pydantic
+  - fsspec
 
   # Scientific
   - pandas

etflow/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from etflow.models.model import BaseFlow as BaseFlow

etflow/_version.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+try:
+    from importlib.metadata import PackageNotFoundError, version
+except ModuleNotFoundError:
+    # Try backported to PY<38 `importlib_metadata`.
+    from importlib_metadata import PackageNotFoundError, version
+
+
+try:
+    __version__ = version("etflow")
+except PackageNotFoundError:
+    # package is not installed
+    __version__ = "dev"
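
The fallback in this file can be exercised without installing anything. The sketch below wraps the same lookup in a hypothetical helper, `resolve_version`, and asks for a distribution name that is assumed not to exist, so the `PackageNotFoundError` branch runs. It assumes Python 3.8 or newer, where `importlib.metadata` ships in the standard library.

```python
from importlib.metadata import PackageNotFoundError, version


def resolve_version(distribution: str, default: str = "dev") -> str:
    """Return the installed version of `distribution`, or `default` if absent."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        # The distribution is not installed, e.g. when running from a source tree.
        return default


# Hypothetical name, chosen only because it should not be installed anywhere.
print(resolve_version("surely-not-an-installed-distribution-xyz"))
```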

etflow/commons/configs.py

Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
+import os
+from typing import Optional
+
+try:
+    from pydantic.v1 import BaseModel, Field, validator
+except ImportError:
+    from pydantic import BaseModel, validator, Field
+
+import fsspec
+from tqdm import tqdm
+from typing_extensions import Literal
+
+CACHE = "~/.cache/etflow"
+
+
+def download_with_progress(url, destination, chunk_size=2**20):  # 1MB chunks
+    """
+    Download a file with progress bar using fsspec and tqdm.
+
+    Args:
+        url: Source URL
+        destination: Local destination path
+        chunk_size: Size of chunks to read/write in bytes
+    """
+    with fsspec.open(url, "rb") as source:
+        # Get file size if available
+        try:
+            file_size = source.size
+        except AttributeError:
+            file_size = None
+
+        # Create progress bar
+        with tqdm(
+            total=file_size,
+            unit="B",
+            unit_scale=True,
+            desc=f"Downloading to {destination}",
+        ) as pbar:
+            # Open destination file
+            with open(destination, "wb") as target:
+                while True:
+                    chunk = source.read(chunk_size)
+                    if not chunk:
+                        break
+                    target.write(chunk)
+                    pbar.update(len(chunk))
+
+
+class BaseConfigSchema(BaseModel):
+    class Config(BaseModel.Config):
+        case_insensitive = True
+        extra = "forbid"
+
+
+class CheckpointConfigSchema(BaseConfigSchema):
+    type: str
+    checkpoint_path: str
+    cache: Optional[str] = CACHE
+    _format: Literal[str] = ".ckpt"
+
+    @validator("cache")
+    def validate_cache(cls, value):
+        if not value:
+            value = "cache"
+        if not os.path.exists(value):
+            os.makedirs(value)
+        return value
+
+    def fetch_checkpoint(self) -> str:
+        self.create_cache()
+        if not self.checkpoint_exists():
+            download_with_progress(self.checkpoint_path, self.local_path)
+        else:
+            print(f"Checkpoint found at {self.local_path}")
+        return self
+
+    @property
+    def local_path(self) -> str:
+        return os.path.join(self.cache, self.type + self._format)
+
+    def checkpoint_exists(self) -> bool:
+        return os.path.exists(self.local_path)
+
+    def cache_exists(self) -> bool:
+        return os.path.exists(self.cache)
+
+    def create_cache(self):
+        if not self.cache_exists():
+            os.makedirs(self.cache)
+
+    def set_cache(self, cache: str):
+        if not cache:
+            return
+        self.cache = cache
+
+
+class ModelArgsSchema(BaseConfigSchema):
+    network_type: Literal["TorchMDDynamics"] = "TorchMDDynamics"
+    hidden_channels: int = 160
+    num_layers: int = 20
+    num_rbf: int = 64
+    rbf_type: Literal["expnorm"] = "expnorm"
+    trainable_rbf: bool = True
+    activation: Literal["silu"] = "silu"
+    neighbor_embedding: bool = True
+    cutoff_lower: float = 0.0
+    cutoff_upper: float = 10.0
+    max_z: int = 100
+    node_attr_dim: int = 10
+    edge_attr_dim: int = 1
+    attn_activation: Literal["silu"] = "silu"
+    num_heads: int = 8
+    distance_influence: Literal["both"] = "both"
+    reduce_op: Literal["sum"] = "sum"
+    qk_norm: bool = True
+    so3_equivariant: bool = False
+    clip_during_norm: bool = True
+    parity_switch: Literal["post_hoc"] = "post_hoc"
+    output_layer_norm: bool = False
+
+    # flow matching specific
+    sigma: float = 0.1
+    prior_type: Literal["harmonic"] = "harmonic"
+    interpolation_type: Literal["linear"] = "linear"
+
+    # optimizer args
+    optimizer_type: Literal["AdamW"] = "AdamW"
+    lr: float = 8.0e-4
+    weight_decay: float = 1.0e-8
+
+    # lr scheduler args
+    lr_scheduler_type: Literal[
+        "CosineAnnealingWarmupRestarts"
+    ] = "CosineAnnealingWarmupRestarts"
+    first_cycle_steps: int = 375_000
+    cycle_mult: float = 1.0
+    max_lr: float = 5.0e-4
+    min_lr: float = 1.0e-8
+    warmup_steps: int = 0
+    gamma: float = 0.05
+    last_epoch: int = -1
+    lr_scheduler_monitor: Literal["val/loss"] = "val/loss"
+    lr_scheduler_interval: Literal["step"] = "step"
+    lr_scheduler_frequency: int = 1
+
+
+class ModelConfigSchema(BaseConfigSchema):
+    model: Literal["BaseFlow"] = "BaseFlow"
+    model_args: ModelArgsSchema = ModelArgsSchema()
+    checkpoint_config: CheckpointConfigSchema
+
+    def model_dict(self):
+        return self.dict(exclude={"checkpoint_config"})
+
+
+class DRUGS_O3_CHECKPOINT(CheckpointConfigSchema):
+    type: Literal["drugs-o3"] = "drugs-o3"
+    checkpoint_path: str = (
+        "https://zenodo.org/records/14226681/files/drugs-o3.ckpt?download=1"
+    )
+
+
+class DRUGS_SO3_CHECKPOINT(CheckpointConfigSchema):
+    type: Literal["drugs-so3"] = "drugs-so3"
+    checkpoint_path: str = (
+        "https://zenodo.org/records/14226681/files/drugs-so3.ckpt?download=1"
+    )
+
+
+class QM9_O3_CHECKPOINT(CheckpointConfigSchema):
+    type: Literal["qm9-o3"] = "qm9-o3"
+    checkpoint_path: str = (
+        "https://zenodo.org/records/14226681/files/qm9-o3.ckpt?download=1"
+    )
+
+
+class DRUGS_O3(ModelConfigSchema):
+    checkpoint_config: CheckpointConfigSchema = Field(
+        default_factory=DRUGS_O3_CHECKPOINT
+    )
+
+
+class DRUGS_SO3(ModelConfigSchema):
+    checkpoint_config: CheckpointConfigSchema = Field(
+        default_factory=DRUGS_SO3_CHECKPOINT
+    )
+    model_args: ModelConfigSchema = ModelArgsSchema(so3_equivariant=True)
+
+
+class QM9_O3(ModelConfigSchema):
+    checkpoint_config: CheckpointConfigSchema = Field(default_factory=QM9_O3_CHECKPOINT)
+    model_args: ModelConfigSchema = ModelArgsSchema(
+        output_layer_norm=True,
+        lr=7.0e-4,
+        first_cycle_steps=250_000,
+        max_lr=7.0e-4,
+    )
+
+
+CONFIG_DICT = {
+    "drugs-o3": DRUGS_O3,
+    "drugs-so3": DRUGS_SO3,
+    "qm9-o3": QM9_O3,
+}
+
+if __name__ == "__main__":
+    print(DRUGS_O3().checkpoint_config)
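
To make the flow behind `CONFIG_DICT` and `fetch_checkpoint` concrete, here is a stdlib-only sketch of the same idea: look a name up in a registry, resolve a cache path, and copy the source in chunks only when the cached file is missing. Everything in it is illustrative. `CheckpointSpec`, `REGISTRY`, `copy_in_chunks`, `fetch` and the `opener` parameter are hypothetical names, the copy reads from any binary file object instead of going through `fsspec`, and the demo source is an in-memory buffer rather than the Zenodo URLs. The sketch also makes two choices that are assumptions, not part of the diff: it calls `os.path.expanduser` on the cache directory, on the assumption that `~/.cache/etflow` is meant to point into the user's home directory (`os.path.exists` and `os.makedirs` do not expand `~` themselves), and it writes to a temporary `.part` file before renaming.

```python
import io
import os
import tempfile
from dataclasses import dataclass
from typing import BinaryIO, Callable, Dict


@dataclass(frozen=True)
class CheckpointSpec:
    """Hypothetical stand-in for a checkpoint config entry."""

    name: str
    source: str  # where the checkpoint comes from (a URL in the real package)
    suffix: str = ".ckpt"

    def local_path(self, cache_dir: str) -> str:
        # Expand "~" explicitly; os.path functions do not do it implicitly.
        return os.path.join(os.path.expanduser(cache_dir), self.name + self.suffix)


# Hypothetical registry keyed by model name, mirroring the CONFIG_DICT idea.
REGISTRY: Dict[str, CheckpointSpec] = {
    "demo-model": CheckpointSpec(name="demo-model", source="memory://demo"),
}


def copy_in_chunks(source: BinaryIO, target: BinaryIO, chunk_size: int = 2**20) -> int:
    """Copy `source` to `target` chunk by chunk and return the bytes written."""
    written = 0
    while True:
        chunk = source.read(chunk_size)
        if not chunk:
            break
        target.write(chunk)
        written += len(chunk)
    return written


def fetch(name: str, cache_dir: str, opener: Callable[[str], BinaryIO]) -> str:
    """Return the cached path for `name`, copying the source only if missing."""
    if name not in REGISTRY:
        raise KeyError(f"Unknown model '{name}'; expected one of {sorted(REGISTRY)}")
    spec = REGISTRY[name]
    path = spec.local_path(cache_dir)
    if os.path.exists(path):
        return path  # cache hit: nothing to download
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # Write under a temporary name first so an interrupted copy is not
    # mistaken for a complete checkpoint on the next call.
    partial = path + ".part"
    with opener(spec.source) as source, open(partial, "wb") as target:
        copy_in_chunks(source, target)
    os.replace(partial, path)
    return path


if __name__ == "__main__":
    calls = []

    def fake_opener(source: str) -> BinaryIO:
        calls.append(source)
        return io.BytesIO(b"fake checkpoint bytes")

    with tempfile.TemporaryDirectory() as cache:
        first = fetch("demo-model", cache, fake_opener)
        second = fetch("demo-model", cache, fake_opener)
        print(first == second, len(calls))
```

Passing the opener in as a parameter keeps the cache logic testable without network access; in a real setting it could be a thin wrapper around whatever remote-file library is in use.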

etflow/commons/sample.py

Lines changed: 1 addition & 3 deletions
@@ -3,12 +3,10 @@
 from pytorch_lightning import seed_everything
 from torch_geometric.data import Batch, Data
 
-from etflow.models.model import BaseFlow
-
 
 @torch.no_grad()
 def batched_sampling(
-    model: BaseFlow,
+    model,
     data: Data,
     max_batch_size: int = 1,
     num_samples: int = 1,
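
This hunk drops both the `BaseFlow` import and the annotation on `model`. The commit does not say why. One plausible reason, given that `etflow/__init__.py` now imports `BaseFlow` at package import time, is avoiding a circular import, but that is an assumption. If that is the motivation, a common way to keep the type hint without a runtime import is a `typing.TYPE_CHECKING` guard combined with a string annotation. The sketch below shows the pattern on a hypothetical stub with a simplified signature; the body is a placeholder, not the real sampler.

```python
from typing import TYPE_CHECKING, Any, List

if TYPE_CHECKING:
    # Seen by static type checkers only; never executed at runtime,
    # so it cannot take part in an import cycle.
    from etflow.models.model import BaseFlow


def batched_sampling_stub(
    model: "BaseFlow",
    data: Any,
    max_batch_size: int = 1,
    num_samples: int = 1,
) -> List[Any]:
    """Placeholder body: returns one entry per requested sample."""
    return [data for _ in range(num_samples)]
```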
