Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions configs/experiment/sample_custom.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# @package _global_

num_sampling_steps_per_batch: 100000
num_sampling_steps_per_batch: 10000
num_batches: 1
num_init_samples_per_dataset: 1
repeat_init_samples: 1
Expand All @@ -18,11 +18,21 @@ continue_chain: true
# wandb_train_run_path: prescient-design/jamun/x6rwt91k

# trpcage
wandb_train_run_path: prescient-design/jamun/30ud8v9x
# wandb_train_run_path: prescient-design/jamun/30ud8v9x

# proteing
# wandb_train_run_path: prescient-design/jamun/uikn2rrg

# ATLAS
wandb_train_run_path: prescient-design/jamun/psxyw586

init_pdbs:
chignolin: /data/bucket/kleinhej/fast-folding/processed/chignolin/filtered.pdb
trpcage: /data/bucket/kleinhej/fast-folding/processed/trpcage/filtered.pdb
# chignolin: /data/bucket/kleinhej/fast-folding/processed/chignolin/filtered.pdb
# trpcage: /data/bucket/kleinhej/fast-folding/processed/trpcage/filtered.pdb
# proteing: /data/bucket/kleinhej/fast-folding/processed/proteing/filtered.pdb
# 3g5k_D: /data/bucket/kleinhej/atlas/3g5k_D/protein/3g5k_D.pdb
# 6hr2_B: /data/bucket/kleinhej/atlas/6hr2_B/protein/6hr2_B.pdb
2vfx_C: /data/bucket/kleinhej/atlas/2vfx_C/protein/2vfx_C.pdb

checkpoint_type: best_so_far
sigma: 0.04
Expand Down
65 changes: 65 additions & 0 deletions configs/experiment/train_atlas.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# @package _global_
# Hydra experiment config: train a JAMUN denoiser on the ATLAS MD-trajectory
# dataset. Data roots and split CSVs are resolved relative to ${paths.data_path}.

defaults:
  - override /callbacks:
      - timing.yaml
      - lr_monitor.yaml
      - model_checkpoint.yaml

model:
  sigma_distribution:
    # Train at a single, constant noise level.
    _target_: jamun.distributions.ConstantSigma
    sigma: 0.04
  max_radius: 1.0  # neighbor cutoff radius — NOTE(review): presumably nm; confirm
  optim:
    lr: 0.002

callbacks:
  viz:
    # Visualize denoising only at the (single) training sigma.
    sigma_list: ["${model.sigma_distribution.sigma}"]

data:
  datamodule:
    batch_size: 4
    datasets:
      train:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/atlas"
        # Trajectories and matching topologies live under <code>/protein/.
        # NOTE(review): "." in these regexes is an unescaped wildcard — assumed
        # intentional; confirm against parse_datasets_from_directory_new.
        traj_pattern: "^(.*)/protein/(.*).xtc"
        topology_pattern: "^(.*)/protein/(.*).pdb"
        # Only PDB codes listed in the split CSV are used.
        filter_codes_csv: "${paths.data_path}/atlas-splits/train.csv"
        filter_codes_csv_header: "code"
        as_iterable: true  # stream frames rather than materializing datasets

      val:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/atlas"
        traj_pattern: "^(.*)/protein/(.*).xtc"
        topology_pattern: "^(.*)/protein/(.*).pdb"
        filter_codes_csv: "${paths.data_path}/atlas-splits/val.csv"
        filter_codes_csv_header: "code"
        # Keep validation cheap: few datasets, subsampled frames.
        max_datasets: 10
        subsample: 100
        as_iterable: true

      test:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/atlas"
        traj_pattern: "^(.*)/protein/(.*).xtc"
        topology_pattern: "^(.*)/protein/(.*).pdb"
        filter_codes_csv: "${paths.data_path}/atlas-splits/test.csv"
        filter_codes_csv_header: "code"
        max_datasets: 10
        subsample: 100
        as_iterable: true

trainer:
  # Iterable datasets have no natural epoch boundary, so validate on a step interval.
  val_check_interval: 6000
  limit_val_batches: 1000
  max_epochs: 100

logger:
  wandb:
    group: train_atlas
1 change: 0 additions & 1 deletion configs/experiment/train_trpcage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ data:
root: "${paths.data_path}/fast-folding/processed/trpcage"
traj_pattern: "train/^(.*).xtc"
pdb_file: "filtered.pdb"
max_datasets: 10

val:
_target_: jamun.data.parse_datasets_from_directory
Expand Down
50 changes: 0 additions & 50 deletions configs/experiment/train_uncapped_2AA_alignment.yaml

This file was deleted.

58 changes: 58 additions & 0 deletions configs/experiment/train_uncapped_4AA_alignment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# @package _global_
# Hydra experiment config: train an MLP-architecture denoiser with alignment
# estimators enabled on uncapped 4AA Timewarp peptides.

defaults:
  - override /model/arch: mlp.yaml

model:
  arch:
    num_nodes: 32
  sigma_distribution:
    _target_: jamun.distributions.ConstantSigma
    sigma: 0.05
  max_radius: 1.0
  optim:
    lr: 0.001
  normalization_type: null
  use_alignment_estimators: true
  alignment_correction_order: 0  # overridden per-run by the SLURM array launcher
  use_torch_compile: false
  rotational_augmentation: true

callbacks:
  viz:
    sigma_list: ["${model.sigma_distribution.sigma}"]

data:
  datamodule:
    batch_size: 32
    datasets:
      # NOTE(review): "^(AEQN*)" as a regex matches "AEQ" followed by zero or
      # more "N"s — if the single peptide code "AEQN" was intended, this should
      # probably be "^(AEQN)". Confirm against parse_datasets_from_directory.
      # NOTE(review): train, val, and test all read from the train/ split —
      # confirm this overfitting-style setup is intentional.
      train:
        _target_: jamun.data.parse_datasets_from_directory
        root: "${paths.data_path}/timewarp/4AA-large/train/"
        traj_pattern: "^(AEQN*)-traj-arrays.npz"
        pdb_pattern: "^(AEQN*)-traj-state0.pdb"
        max_datasets: 10

      val:
        _target_: jamun.data.parse_datasets_from_directory
        root: "${paths.data_path}/timewarp/4AA-large/train/"
        traj_pattern: "^(AEQN*)-traj-arrays.npz"
        pdb_pattern: "^(AEQN*)-traj-state0.pdb"
        max_datasets: 10

      test:
        _target_: jamun.data.parse_datasets_from_directory
        root: "${paths.data_path}/timewarp/4AA-large/train/"
        traj_pattern: "^(AEQN*)-traj-arrays.npz"
        pdb_pattern: "^(AEQN*)-traj-state0.pdb"
        max_datasets: 10

trainer:
  # Validate by optimizer step, not by epoch.
  val_check_interval: 1000
  check_val_every_n_epoch: null
  max_epochs: 10000
  log_every_n_steps: 1

logger:
  wandb:
    group: train_uncapped_4AA_alignment
9 changes: 6 additions & 3 deletions scripts/slurm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

These are some example [SLURM](https://slurm.schedmd.com/documentation.html) launcher scripts that have worked for us, but might need some modifications for your cluster and SLURM version.

Launch with:
For example, we would launch a training run with:
```bash
sbatch train.sh
sbatch sample.sh
sbatch train_uncapped_2AA.sh
```
and then a corresponding sampling run with:
```bash
sbatch sample_uncapped_2AA.sh
```
38 changes: 38 additions & 0 deletions scripts/slurm/train_atlas.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env bash
# SLURM launcher for the train_atlas experiment (single node, single GPU).

#SBATCH --partition g6e
#SBATCH --nodes 1
#SBATCH --ntasks-per-node 1
#SBATCH --gpus-per-node 1
#SBATCH --cpus-per-task 8
#SBATCH --time 3-0
#SBATCH --mem=32G

eval "$(conda shell.bash hook)"
conda activate jamun

set -eux

echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
echo "hostname = $(hostname)"

export HYDRA_FULL_ERROR=1
export CUDA_HOME=/usr/local/cuda-12.8/

# export TORCH_COMPILE_DEBUG=1
# export TORCH_LOGS="+dynamo"
# export TORCHDYNAMO_VERBOSE=1

# Allow overriding the (user-specific) config directory without editing the
# script; defaults to the original hard-coded path for backward compatibility.
CONFIG_DIR="${JAMUN_CONFIG_DIR:-/homefs/home/daigavaa/jamun/configs}"

# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks.
RUN_KEY=$(openssl rand -hex 12)
echo "RUN_KEY = ${RUN_KEY}"

nvidia-smi

# Use SLURM_CPUS_PER_TASK so srun stays in sync with the #SBATCH header above.
srun --cpus-per-task "${SLURM_CPUS_PER_TASK}" --cpu-bind=cores,verbose \
    jamun_train --config-dir="${CONFIG_DIR}" \
    experiment=train_atlas.yaml \
    ++trainer.devices=$SLURM_GPUS_PER_NODE \
    ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \
    ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","atlas"] \
    ++run_key=$RUN_KEY
1 change: 0 additions & 1 deletion scripts/slurm/train_uncapped_4AA.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ nvidia-smi
srun --cpus-per-task 8 --cpu-bind=cores,verbose \
jamun_train --config-dir=/homefs/home/daigavaa/jamun/configs \
experiment=train_uncapped_4AA.yaml \
++model.sigma_distribution.sigma=0.04 \
++trainer.devices=$SLURM_GPUS_PER_NODE \
++trainer.num_nodes=$SLURM_JOB_NUM_NODES \
++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","uncapped_4AA"] \
Expand Down
39 changes: 39 additions & 0 deletions scripts/slurm/train_uncapped_4AA_alignment.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env bash
# SLURM launcher: sweep model.alignment_correction_order over {0,1,2} via a job
# array for the train_uncapped_4AA_alignment experiment.

#SBATCH --partition g6e
#SBATCH --nodes 1
#SBATCH --ntasks-per-node 2
#SBATCH --gpus-per-node 1
#SBATCH --cpus-per-task 4
#SBATCH --time 1-0
#SBATCH --array 0-2
#SBATCH --mem=60G
# NOTE(review): ntasks-per-node=2 with gpus-per-node=1 launches two ranks while
# trainer.devices is set from SLURM_GPUS_PER_NODE (=1) below — confirm intended.

eval "$(conda shell.bash hook)"
conda activate jamun

set -eux

echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
echo "hostname = $(hostname)"

export HYDRA_FULL_ERROR=1
# export TORCH_COMPILE_DEBUG=1
# export TORCH_LOGS="+dynamo"
# export TORCHDYNAMO_VERBOSE=1

# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks.
RUN_KEY=$(openssl rand -hex 12)
echo "RUN_KEY = ${RUN_KEY}"

nvidia-smi

# The array index selects the alignment correction order for this run.
srun --cpus-per-task 4 --cpu-bind=cores,verbose \
    jamun_train --config-dir=/homefs/home/daigavaa/jamun/configs \
    experiment=train_uncapped_4AA_alignment.yaml \
    ++trainer.devices=$SLURM_GPUS_PER_NODE \
    ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \
    ++model.use_alignment_estimators=true \
    ++model.alignment_correction_order=$SLURM_ARRAY_TASK_ID \
    ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","align-4AA","'${SLURM_ARRAY_TASK_ID}'"] \
    ++run_key=$RUN_KEY
43 changes: 43 additions & 0 deletions scripts/slurm/train_uncapped_4AA_no_alignment.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# SLURM launcher: no-alignment baseline for the train_uncapped_4AA_alignment
# experiment (alignment estimators disabled).

#SBATCH --partition g6e
#SBATCH --nodes 1
#SBATCH --ntasks-per-node 2
#SBATCH --gpus-per-node 1
#SBATCH --cpus-per-task 4
#SBATCH --time 1-0
#SBATCH --mem=60G
# NOTE(review): ntasks-per-node=2 with gpus-per-node=1 launches two ranks while
# trainer.devices is set from SLURM_GPUS_PER_NODE (=1) below — confirm intended.

eval "$(conda shell.bash hook)"
conda activate jamun

set -eux

echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
echo "hostname = $(hostname)"

export HYDRA_FULL_ERROR=1
# export TORCH_COMPILE_DEBUG=1
# export TORCH_LOGS="+dynamo"
# export TORCHDYNAMO_VERBOSE=1

# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks.
RUN_KEY=$(openssl rand -hex 12)
echo "RUN_KEY = ${RUN_KEY}"

nvidia-smi

# Define the array of sigma values
# Bash arrays are 0-indexed
# NOTE(review): the commented sigma sweep references SLURM_ARRAY_TASK_ID, but
# this script declares no "#SBATCH --array" — re-add it if the sweep is revived.
# declare -a SIGMAS=(0.1 0.2 0.5 1.0)
# SIGMA=${SIGMAS[$SLURM_ARRAY_TASK_ID]}

# Reuses the alignment experiment config but turns the estimators off.
srun --cpus-per-task 4 --cpu-bind=cores,verbose \
    jamun_train --config-dir=/homefs/home/daigavaa/jamun/configs \
    experiment=train_uncapped_4AA_alignment.yaml \
    ++trainer.devices=$SLURM_GPUS_PER_NODE \
    ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \
    ++model.use_alignment_estimators=false \
    ++model.alignment_correction_order=null \
    ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","align-4AA"] \
    ++run_key=$RUN_KEY
5 changes: 2 additions & 3 deletions src/jamun/callbacks/_visualize_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def on_validation_batch_start(self, trainer, pl_module, data, data_idx, dataload
del topology.pos, topology.batch, topology.num_graphs

x, batch, num_graphs = data.pos, data.batch, data.num_graphs

for sigma in self.sigma_list:
xhat, _, y = pl_module.noise_and_denoise(
xhat, aux = pl_module.noise_and_denoise(
x, topology, batch, num_graphs, sigma, use_alignment_estimators=pl_module.use_alignment_estimators
)
xhat_graphs = topology.clone()
Expand All @@ -60,7 +59,7 @@ def on_validation_batch_start(self, trainer, pl_module, data, data_idx, dataload
x_graphs.pos = x

y_graphs = topology.clone()
y_graphs.pos = y
y_graphs.pos = aux["y"]

for xhat_graph, y_graph, x_graph in zip(
torch_geometric.data.Batch.to_data_list(xhat_graphs),
Expand Down
2 changes: 1 addition & 1 deletion src/jamun/data/_dloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import lightning.pytorch as pl
import numpy as np
import torch_geometric.loader
from torch.utils.data import ConcatDataset, Dataset, IterableDataset
from torch.utils.data import Dataset, IterableDataset, ConcatDataset

from jamun import utils

Expand Down
Loading
Loading