diff --git a/configs/experiment/sample_custom.yaml b/configs/experiment/sample_custom.yaml index 3e7e581..967a201 100644 --- a/configs/experiment/sample_custom.yaml +++ b/configs/experiment/sample_custom.yaml @@ -1,6 +1,6 @@ # @package _global_ -num_sampling_steps_per_batch: 100000 +num_sampling_steps_per_batch: 10000 num_batches: 1 num_init_samples_per_dataset: 1 repeat_init_samples: 1 @@ -18,11 +18,21 @@ continue_chain: true # wandb_train_run_path: prescient-design/jamun/x6rwt91k # trpcage -wandb_train_run_path: prescient-design/jamun/30ud8v9x +# wandb_train_run_path: prescient-design/jamun/30ud8v9x + +# proteing +# wandb_train_run_path: prescient-design/jamun/uikn2rrg + +# ATLAS +wandb_train_run_path: prescient-design/jamun/psxyw586 init_pdbs: - chignolin: /data/bucket/kleinhej/fast-folding/processed/chignolin/filtered.pdb - trpcage: /data/bucket/kleinhej/fast-folding/processed/trpcage/filtered.pdb + # chignolin: /data/bucket/kleinhej/fast-folding/processed/chignolin/filtered.pdb + # trpcage: /data/bucket/kleinhej/fast-folding/processed/trpcage/filtered.pdb + # proteing: /data/bucket/kleinhej/fast-folding/processed/proteing/filtered.pdb + # 3g5k_D: /data/bucket/kleinhej/atlas/3g5k_D/protein/3g5k_D.pdb + # 6hr2_B: /data/bucket/kleinhej/atlas/6hr2_B/protein/6hr2_B.pdb + 2vfx_C: /data/bucket/kleinhej/atlas/2vfx_C/protein/2vfx_C.pdb checkpoint_type: best_so_far sigma: 0.04 diff --git a/configs/experiment/train_atlas.yaml b/configs/experiment/train_atlas.yaml new file mode 100644 index 0000000..299c009 --- /dev/null +++ b/configs/experiment/train_atlas.yaml @@ -0,0 +1,65 @@ +# @package _global_ + +defaults: + - override /callbacks: + - timing.yaml + - lr_monitor.yaml + - model_checkpoint.yaml + + +model: + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.04 + max_radius: 1.0 + optim: + lr: 0.002 + +callbacks: + viz: + sigma_list: ["${model.sigma_distribution.sigma}"] + +data: + datamodule: + batch_size: 4 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory_new + root: "${paths.data_path}/atlas" + traj_pattern: "^(.*)/protein/(.*).xtc" + topology_pattern: "^(.*)/protein/(.*).pdb" + filter_codes_csv: "${paths.data_path}/atlas-splits/train.csv" + filter_codes_csv_header: "code" + as_iterable: true + + val: + _target_: jamun.data.parse_datasets_from_directory_new + root: "${paths.data_path}/atlas" + traj_pattern: "^(.*)/protein/(.*).xtc" + topology_pattern: "^(.*)/protein/(.*).pdb" + filter_codes_csv: "${paths.data_path}/atlas-splits/val.csv" + filter_codes_csv_header: "code" + max_datasets: 10 + subsample: 100 + as_iterable: true + + test: + _target_: jamun.data.parse_datasets_from_directory_new + root: "${paths.data_path}/atlas" + traj_pattern: "^(.*)/protein/(.*).xtc" + topology_pattern: "^(.*)/protein/(.*).pdb" + filter_codes_csv: "${paths.data_path}/atlas-splits/test.csv" + filter_codes_csv_header: "code" + max_datasets: 10 + subsample: 100 + as_iterable: true + + +trainer: + val_check_interval: 6000 + limit_val_batches: 1000 + max_epochs: 100 + +logger: + wandb: + group: train_atlas diff --git a/configs/experiment/train_trpcage.yaml b/configs/experiment/train_trpcage.yaml index 01d577e..aa7cafb 100644 --- a/configs/experiment/train_trpcage.yaml +++ b/configs/experiment/train_trpcage.yaml @@ -21,7 +21,6 @@ data: root: "${paths.data_path}/fast-folding/processed/trpcage" traj_pattern: "train/^(.*).xtc" pdb_file: "filtered.pdb" - max_datasets: 10 val: _target_: jamun.data.parse_datasets_from_directory diff --git a/configs/experiment/train_uncapped_2AA_alignment.yaml b/configs/experiment/train_uncapped_2AA_alignment.yaml deleted file mode 100644 index 25cc4bb..0000000 --- a/configs/experiment/train_uncapped_2AA_alignment.yaml +++ /dev/null @@ -1,50 +0,0 @@ -# @package _global_ - -defaults: - - override /model: energy.yaml - -model: - sigma_distribution: - _target_: jamun.distributions.ConstantSigma - sigma: 0.04 - max_radius: 1.0 - optim: - lr: 0.0001 - -callbacks: - viz: - sigma_list: ["${model.sigma_distribution.sigma}"] - -data: - datamodule: - batch_size: 32 - datasets: - train: - _target_: jamun.data.parse_datasets_from_directory - root: "${paths.data_path}/timewarp/2AA-1-large/train/" - traj_pattern: "^(.*)-traj-arrays.npz" - pdb_pattern: "^(.*)-traj-state0.pdb" - - val: - _target_: jamun.data.parse_datasets_from_directory - root: "${paths.data_path}/timewarp/2AA-1-large/val/" - traj_pattern: "^(.*)-traj-arrays.npz" - pdb_pattern: "^(.*)-traj-state0.pdb" - subsample: 100 - max_datasets: 20 - - test: - _target_: jamun.data.parse_datasets_from_directory - root: "${paths.data_path}/timewarp/2AA-1-large/test/" - traj_pattern: "^(.*)-traj-arrays.npz" - pdb_pattern: "^(.*)-traj-state0.pdb" - subsample: 100 - max_datasets: 20 - -trainer: - val_check_interval: 0.1 - max_epochs: 10 - -logger: - wandb: - group: train_uncapped_2AA diff --git a/configs/experiment/train_uncapped_4AA_alignment.yaml b/configs/experiment/train_uncapped_4AA_alignment.yaml new file mode 100644 index 0000000..06f3f63 --- /dev/null +++ b/configs/experiment/train_uncapped_4AA_alignment.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +defaults: + - override /model/arch: mlp.yaml + +model: + arch: + num_nodes: 32 + sigma_distribution: + _target_: jamun.distributions.ConstantSigma + sigma: 0.05 + max_radius: 1.0 + optim: + lr: 0.001 + normalization_type: null + use_alignment_estimators: true + alignment_correction_order: 0 + use_torch_compile: false + rotational_augmentation: true + +callbacks: + viz: + sigma_list: ["${model.sigma_distribution.sigma}"] + +data: + datamodule: + batch_size: 32 + datasets: + train: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/timewarp/4AA-large/train/" + traj_pattern: "^(AEQN*)-traj-arrays.npz" + pdb_pattern: "^(AEQN*)-traj-state0.pdb" + max_datasets: 10 + + val: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/timewarp/4AA-large/train/" + traj_pattern: "^(AEQN*)-traj-arrays.npz" + pdb_pattern: "^(AEQN*)-traj-state0.pdb" + max_datasets: 10 + + test: + _target_: jamun.data.parse_datasets_from_directory + root: "${paths.data_path}/timewarp/4AA-large/train/" + traj_pattern: "^(AEQN*)-traj-arrays.npz" + pdb_pattern: "^(AEQN*)-traj-state0.pdb" + max_datasets: 10 + +trainer: + val_check_interval: 1000 + check_val_every_n_epoch: null + max_epochs: 10000 + log_every_n_steps: 1 + +logger: + wandb: + group: train_uncapped_4AA_alignment diff --git a/scripts/slurm/README.md b/scripts/slurm/README.md index a428010..3a9c22c 100644 --- a/scripts/slurm/README.md +++ b/scripts/slurm/README.md @@ -2,8 +2,11 @@ These are some example [SLURM](https://slurm.schedmd.com/documentation.html) launcher scripts that have worked for us, but might need some modifications for your cluster and SLURM version. -Launch with: +For example, we would launch a training run with: ```bash -sbatch train.sh -sbatch sample.sh +sbatch train_uncapped_2AA.sh +``` +and then a corresponding sampling run with: +```bash +sbatch sample_uncapped_2AA.sh ``` diff --git a/scripts/slurm/train_atlas.sh b/scripts/slurm/train_atlas.sh new file mode 100644 index 0000000..2a10e07 --- /dev/null +++ b/scripts/slurm/train_atlas.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +#SBATCH --partition g6e +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 1 +#SBATCH --gpus-per-node 1 +#SBATCH --cpus-per-task 8 +#SBATCH --time 3-0 +#SBATCH --mem=32G + +eval "$(conda shell.bash hook)" +conda activate jamun + +set -eux + +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "hostname = $(hostname)" + +export HYDRA_FULL_ERROR=1 +export CUDA_HOME=/usr/local/cuda-12.8/ + +# export TORCH_COMPILE_DEBUG=1 +# export TORCH_LOGS="+dynamo" +# export TORCHDYNAMO_VERBOSE=1 + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +nvidia-smi + +srun --cpus-per-task 8 --cpu-bind=cores,verbose \ + jamun_train --config-dir=/homefs/home/daigavaa/jamun/configs \ + experiment=train_atlas.yaml \ + ++trainer.devices=$SLURM_GPUS_PER_NODE \ + ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ + ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","atlas"] \ + ++run_key=$RUN_KEY diff --git a/scripts/slurm/train_uncapped_4AA.sh b/scripts/slurm/train_uncapped_4AA.sh index 9d39e17..dee7a2f 100644 --- a/scripts/slurm/train_uncapped_4AA.sh +++ b/scripts/slurm/train_uncapped_4AA.sh @@ -30,7 +30,6 @@ nvidia-smi srun --cpus-per-task 8 --cpu-bind=cores,verbose \ jamun_train --config-dir=/homefs/home/daigavaa/jamun/configs \ experiment=train_uncapped_4AA.yaml \ - ++model.sigma_distribution.sigma=0.04 \ ++trainer.devices=$SLURM_GPUS_PER_NODE \ ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","uncapped_4AA"] \ diff --git a/scripts/slurm/train_uncapped_4AA_alignment.sh b/scripts/slurm/train_uncapped_4AA_alignment.sh new file mode 100644 index 0000000..b69775f --- /dev/null +++ b/scripts/slurm/train_uncapped_4AA_alignment.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +#SBATCH --partition g6e +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 2 +#SBATCH --gpus-per-node 1 +#SBATCH --cpus-per-task 4 +#SBATCH --time 1-0 +#SBATCH --array 0-2 +#SBATCH --mem=60G + +eval "$(conda shell.bash hook)" +conda activate jamun + +set -eux + +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "hostname = $(hostname)" + +export HYDRA_FULL_ERROR=1 +# export TORCH_COMPILE_DEBUG=1 +# export TORCH_LOGS="+dynamo" +# export TORCHDYNAMO_VERBOSE=1 + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +nvidia-smi + +srun --cpus-per-task 4 --cpu-bind=cores,verbose \ + jamun_train --config-dir=/homefs/home/daigavaa/jamun/configs \ + experiment=train_uncapped_4AA_alignment.yaml \ + ++trainer.devices=$SLURM_GPUS_PER_NODE \ + ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ + ++model.use_alignment_estimators=true \ + ++model.alignment_correction_order=$SLURM_ARRAY_TASK_ID \ + ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","align-4AA","'${SLURM_ARRAY_TASK_ID}'"] \ + ++run_key=$RUN_KEY diff --git a/scripts/slurm/train_uncapped_4AA_no_alignment.sh b/scripts/slurm/train_uncapped_4AA_no_alignment.sh new file mode 100644 index 0000000..d00c18c --- /dev/null +++ b/scripts/slurm/train_uncapped_4AA_no_alignment.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +#SBATCH --partition g6e +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 2 +#SBATCH --gpus-per-node 1 +#SBATCH --cpus-per-task 4 +#SBATCH --time 1-0 +#SBATCH --mem=60G + +eval "$(conda shell.bash hook)" +conda activate jamun + +set -eux + +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "hostname = $(hostname)" + +export HYDRA_FULL_ERROR=1 +# export TORCH_COMPILE_DEBUG=1 +# export TORCH_LOGS="+dynamo" +# export TORCHDYNAMO_VERBOSE=1 + +# NOTE: We generate this in submit script instead of using time-based default to ensure consistency across ranks. +RUN_KEY=$(openssl rand -hex 12) +echo "RUN_KEY = ${RUN_KEY}" + +nvidia-smi + +# Define the array of sigma values +# Bash arrays are 0-indexed +# declare -a SIGMAS=(0.1 0.2 0.5 1.0) +# SIGMA=${SIGMAS[$SLURM_ARRAY_TASK_ID]} + +srun --cpus-per-task 4 --cpu-bind=cores,verbose \ + jamun_train --config-dir=/homefs/home/daigavaa/jamun/configs \ + experiment=train_uncapped_4AA_alignment.yaml \ + ++trainer.devices=$SLURM_GPUS_PER_NODE \ + ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ + ++model.use_alignment_estimators=false \ + ++model.alignment_correction_order=null \ + ++logger.wandb.tags=["'${SLURM_JOB_ID}'","'${RUN_KEY}'","train","align-4AA"] \ + ++run_key=$RUN_KEY diff --git a/src/jamun/callbacks/_visualize_denoise.py b/src/jamun/callbacks/_visualize_denoise.py index bbafecc..075851e 100644 --- a/src/jamun/callbacks/_visualize_denoise.py +++ b/src/jamun/callbacks/_visualize_denoise.py @@ -48,9 +48,8 @@ def on_validation_batch_start(self, trainer, pl_module, data, data_idx, dataload del topology.pos, topology.batch, topology.num_graphs x, batch, num_graphs = data.pos, data.batch, data.num_graphs - for sigma in self.sigma_list: - xhat, _, y = pl_module.noise_and_denoise( + xhat, aux = pl_module.noise_and_denoise( x, topology, batch, num_graphs, sigma, use_alignment_estimators=pl_module.use_alignment_estimators ) xhat_graphs = topology.clone() @@ -60,7 +59,7 @@ def on_validation_batch_start(self, trainer, pl_module, data, data_idx, dataload x_graphs.pos = x y_graphs = topology.clone() - y_graphs.pos = y + y_graphs.pos = aux["y"] for xhat_graph, y_graph, x_graph in zip( torch_geometric.data.Batch.to_data_list(xhat_graphs), diff --git a/src/jamun/data/_dloader.py b/src/jamun/data/_dloader.py index 34f93d5..ae0af69 100644 --- a/src/jamun/data/_dloader.py +++ b/src/jamun/data/_dloader.py @@ -4,7 +4,7 @@ import lightning.pytorch as pl import numpy as np import torch_geometric.loader -from torch.utils.data import ConcatDataset, Dataset, IterableDataset +from torch.utils.data import Dataset, IterableDataset, ConcatDataset from jamun import utils diff --git a/src/jamun/data/_mdtraj.py b/src/jamun/data/_mdtraj.py index 0776ba9..14519ba 100644 --- a/src/jamun/data/_mdtraj.py +++ b/src/jamun/data/_mdtraj.py @@ -325,6 +325,8 @@ def save_topology_pdb(self, filename: str | None = None): utils.save_pdb(self.traj[0], filename) def __getitem__(self, idx): + if idx >= self.traj.n_frames: + idx = self.traj.n_frames - 1 graph = self.graph.clone("pos") graph.pos = torch.tensor(self.traj.xyz[idx]) if self.transform: @@ -332,7 +334,7 @@ def __getitem__(self, idx): return graph def __len__(self): - return self.traj.n_frames + return max(32, self.traj.n_frames) @functools.cached_property def topology(self) -> md.Topology: diff --git a/src/jamun/hydra_config/model/arch/e3conv_separable.yaml b/src/jamun/hydra_config/model/arch/e3conv_separable.yaml index cf54ffc..c7f8bae 100644 --- a/src/jamun/hydra_config/model/arch/e3conv_separable.yaml +++ b/src/jamun/hydra_config/model/arch/e3conv_separable.yaml @@ -28,7 +28,7 @@ atom_embedder_factory: num_atom_codes: 10 num_residue_types: 25 hidden_layer_factory: - _target_: e3tools.nn.SeparableConvBlock + _target_: e3tools.nn.FusedSeparableConvBlock _partial_: true output_head_factory: _target_: e3tools.nn.EquivariantMLP diff --git a/src/jamun/hydra_config/model/arch/mlp.yaml b/src/jamun/hydra_config/model/arch/mlp.yaml new file mode 100644 index 0000000..8c4bd2b --- /dev/null +++ b/src/jamun/hydra_config/model/arch/mlp.yaml @@ -0,0 +1,19 @@ +_target_: jamun.model.arch.MLP +_partial_: true +input_dims: 3 +embed_dims: 32 +noise_input_dims: 1 +n_layers: 2 +num_nodes: ??? +atom_embedder_factory: + _target_: jamun.model.embedding.ResidueAtomEmbedder + _partial_: true + atom_type_embedding_dim: 8 + atom_code_embedding_dim: 8 + residue_code_embedding_dim: 8 + residue_index_embedding_dim: 8 + use_residue_sequence_index: false + num_atom_types: 20 + max_sequence_length: 10 + num_atom_codes: 10 + num_residue_types: 25 diff --git a/src/jamun/hydra_config/model/arch/transformer.yaml b/src/jamun/hydra_config/model/arch/transformer.yaml new file mode 100644 index 0000000..d51ec6d --- /dev/null +++ b/src/jamun/hydra_config/model/arch/transformer.yaml @@ -0,0 +1,19 @@ +_target_: jamun.model.arch.Transformer +_partial_: true +input_dims: 3 +embed_dims: 32 +noise_input_dims: 1 +n_layers: 6 +n_heads: 4 +atom_embedder_factory: + _target_: jamun.model.embedding.ResidueAtomEmbedder + _partial_: true + atom_type_embedding_dim: 8 + atom_code_embedding_dim: 8 + residue_code_embedding_dim: 8 + residue_index_embedding_dim: 8 + use_residue_sequence_index: false + num_atom_types: 20 + max_sequence_length: 10 + num_atom_codes: 10 + num_residue_types: 25 diff --git a/src/jamun/hydra_config/model/denoiser.yaml b/src/jamun/hydra_config/model/denoiser.yaml index 999cf6c..710e4bc 100644 --- a/src/jamun/hydra_config/model/denoiser.yaml +++ b/src/jamun/hydra_config/model/denoiser.yaml @@ -1,5 +1,5 @@ defaults: - - arch: e3conv.yaml + - arch: e3conv_separable.yaml - optim: adam.yaml - lr_scheduler_config: null - _self_ diff --git a/src/jamun/hydra_config/model/energy.yaml b/src/jamun/hydra_config/model/energy.yaml index ba45680..49ee7d5 100644 --- a/src/jamun/hydra_config/model/energy.yaml +++ b/src/jamun/hydra_config/model/energy.yaml @@ -1,5 +1,5 @@ defaults: - - arch: e3conv.yaml + - arch: e3conv_separable.yaml - optim: adam.yaml - lr_scheduler_config: null - _self_ diff --git a/src/jamun/model/arch/__init__.py b/src/jamun/model/arch/__init__.py index 4b9edfd..b12f93c 100644 --- a/src/jamun/model/arch/__init__.py +++ b/src/jamun/model/arch/__init__.py @@ -1,3 +1,5 @@ from .e3conv import E3Conv +from .mlp import MLP from .ophiuchus import Ophiuchus from .orb import MoleculeGNSWrapper +from .transformer import Transformer diff --git a/src/jamun/model/arch/mlp.py b/src/jamun/model/arch/mlp.py new file mode 100644 index 0000000..0547ca6 --- /dev/null +++ b/src/jamun/model/arch/mlp.py @@ -0,0 +1,72 @@ +from collections.abc import Callable + +import torch +import torch.nn as nn +import torch_geometric.data + + +class MLP(nn.Module): + """A simple MLP architecture with noise conditioning.""" + + def __init__( + self, + atom_embedder_factory: Callable[..., torch.nn.Module], + n_layers: int, + input_dims: int, + embed_dims: int, + noise_input_dims: int, + num_nodes: int, + ): + super().__init__() + self.atom_embedder = atom_embedder_factory() + self.input_dims = input_dims + self.embed_dims = embed_dims + self.num_nodes = num_nodes + + # Since we concatenate atom embeddings to input features, adjust input dimensions. + input_dims += self.atom_embedder.irreps_out.dim + + self.input_projection = nn.Linear(input_dims, embed_dims) + self.layers = nn.ModuleList() + self.layer_norms = nn.ModuleList() + for _ in range(n_layers): + self.layers.append( + nn.Sequential( + nn.Linear(num_nodes * (embed_dims + noise_input_dims), num_nodes * embed_dims), + nn.GELU(), + ) + ) + self.layer_norms.append(nn.LayerNorm(num_nodes * embed_dims)) + self.output_projection = nn.Linear(num_nodes * embed_dims, num_nodes * self.input_dims) + + def forward( + self, + pos: torch.Tensor, + topology: torch_geometric.data.Batch, + batch: torch.Tensor, + num_graphs: int, + c_noise: torch.Tensor, + c_in: torch.Tensor, + ) -> torch.Tensor: + del c_in + c_noise = c_noise.unsqueeze(0).expand(pos.shape[0], -1) + + # Concatenate position and atom embeddings + x = torch.cat([pos, self.atom_embedder(topology)], dim=-1) + + # Project to embedding dimension and flatten per graph + x = self.input_projection(x) + x = x.view(num_graphs, self.num_nodes * self.embed_dims) + c_noise = c_noise.view(num_graphs, self.num_nodes * c_noise.shape[-1]) + + # Process through MLP layers + for layer, layer_norm in zip(self.layers, self.layer_norms): + x_with_noise = torch.cat([x, c_noise], dim=-1) + x = layer(x_with_noise) + x + x = layer_norm(x) + + # Output projection + x = self.output_projection(x) + x = x.view(num_graphs * self.num_nodes, self.input_dims) + + return x diff --git a/src/jamun/model/arch/transformer.py b/src/jamun/model/arch/transformer.py new file mode 100644 index 0000000..bd06b11 --- /dev/null +++ b/src/jamun/model/arch/transformer.py @@ -0,0 +1,145 @@ +from collections.abc import Callable + +import torch +import torch.nn as nn +import torch_geometric.data + + +class NoiseConditionalSelfAttention(nn.Module): + """Self-attention layer with noise conditioning.""" + + def __init__(self, input_dims: int, embed_dims: int, n_heads: int, noise_input_dims: int): + super().__init__() + self.embed_dims = embed_dims + self.n_heads = n_heads + self.head_dim = embed_dims // n_heads + + assert embed_dims % n_heads == 0, "embed_dims must be divisible by n_heads" + + self.Q = nn.Linear(input_dims + noise_input_dims, embed_dims) + self.K = nn.Linear(input_dims + noise_input_dims, embed_dims) + self.V = nn.Linear(input_dims + noise_input_dims, embed_dims) + self.output_projection = nn.Linear(embed_dims, embed_dims) + + def forward(self, x: torch.Tensor, c_noise: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor: + num_nodes, _ = x.shape + + # Concatenate input with noise conditioning + x_with_noise = torch.cat([x, c_noise], dim=-1) + + # Compute Q, K, V + Q = self.Q(x_with_noise) # [num_nodes, embed_dims] + K = self.K(x_with_noise) # [num_nodes, embed_dims] + V = self.V(x_with_noise) # [num_nodes, embed_dims] + + # Reshape for multi-head attention + Q = Q.view(num_nodes, self.n_heads, self.head_dim).transpose(-3, -2) + K = K.view(num_nodes, self.n_heads, self.head_dim).transpose(-3, -2) + V = V.view(num_nodes, self.n_heads, self.head_dim).transpose(-3, -2) + + # Scaled dot-product attention + scale = self.head_dim**-0.5 + logits = (Q @ K.transpose(-2, -1)) * scale + + # Mask attention scores for different graphs in the batch + mask = batch.unsqueeze(0) != batch.unsqueeze(1) + mask = mask.unsqueeze(0).expand(self.n_heads, -1, -1) # [n_heads, num_nodes, num_nodes] + # print("Mask", mask.shape) + # print("Logits before mask", logits.shape) + logits = logits.masked_fill(mask, float("-inf")) + attn_weights = torch.softmax(logits, dim=-1) + attn_output = attn_weights @ V + + # Concatenate heads + attn_output = attn_output.transpose(-3, -2).contiguous().view(num_nodes, self.embed_dims) + + # Output projection + output = self.output_projection(attn_output) + return output + + +class Transformer(nn.Module): + """A simple Transformer architecture with noise conditioning.""" + + def __init__( + self, + atom_embedder_factory: Callable[..., torch.nn.Module], + n_layers: int, + input_dims: int, + embed_dims: int, + noise_input_dims: int, + n_heads: int, + num_nodes: int, + ): + super().__init__() + self.atom_embedder = atom_embedder_factory() + self.input_dims = input_dims + self.embed_dims = embed_dims + self.num_nodes = num_nodes + + # Since we concatenate atom embeddings to input features, adjust input dimensions. + input_dims += self.atom_embedder.irreps_out.dim + + # Input projection to match embedding dimension + self.input_projection = nn.Linear(input_dims, embed_dims) + + self.self_attention_layers = nn.ModuleList() + self.self_attention_layer_norms = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.ffn_layer_norms = nn.ModuleList() + + for _ in range(n_layers): + self.self_attention_layers.append( + NoiseConditionalSelfAttention( + input_dims=embed_dims, + embed_dims=embed_dims, + n_heads=n_heads, + noise_input_dims=noise_input_dims, + ) + ) + self.self_attention_layer_norms.append(nn.LayerNorm(embed_dims)) + self.ffn_layer_norms.append(nn.LayerNorm(embed_dims)) + + # Feed-forward network + self.ffn_layers.append( + nn.Sequential(nn.Linear(embed_dims, embed_dims * 4), nn.ReLU(), nn.Linear(embed_dims * 4, embed_dims)) + ) + + self.output_projection = nn.Linear(embed_dims, self.input_dims) + + def forward( + self, + pos: torch.Tensor, + topology: torch_geometric.data.Batch, + batch: torch.Tensor, + num_graphs: int, + c_noise: torch.Tensor, + c_in: torch.Tensor, + ) -> torch.Tensor: + del c_in + c_noise = c_noise.unsqueeze(0).expand(pos.shape[0], -1) + + # c_noise: [num_nodes, noise_input_dims] + # pos: [num_nodes, input_dims] + x = torch.cat([pos, self.atom_embedder(topology)], dim=-1) + + # Project input to embedding dimension + x = self.input_projection(x) + + for self_attention, self_attention_layer_norm, ffn, ffn_layer_norm in zip( + self.self_attention_layers, self.self_attention_layer_norms, self.ffn_layers, self.ffn_layer_norms + ): + # Self-attention with residual connection and layer norm + attn_output = self_attention(x, c_noise, batch, num_graphs) + x = x + attn_output + x = self_attention_layer_norm(x) + + # Feed-forward with residual connection and layer norm + ffn_output = ffn(x) + x = x + ffn_output + x = ffn_layer_norm(x) + + # Final projection back to input dimensions + x = self.output_projection(x) + + return x diff --git a/src/jamun/model/denoiser.py b/src/jamun/model/denoiser.py index 03fd0bd..66eb940 100644 --- a/src/jamun/model/denoiser.py +++ b/src/jamun/model/denoiser.py @@ -244,16 +244,30 @@ def noise_and_denoise( else: xtarget = x + xtarget_second_order = align_A_to_B_batched_f( + x, + y, + batch, + num_graphs, + sigma=sigma, + correction_order=2, + ) + with torch.cuda.nvtx.range("xhat"): xhat = self.xhat(y, topology, batch, num_graphs, sigma) - return xhat, xtarget, y + return xhat, { + "xtarget": xtarget, + "xtarget_second_order": xtarget_second_order, + "y": y, + } def compute_loss( self, *, x: torch.Tensor, xtarget: torch.Tensor, + xtarget_second_order: torch.Tensor, xhat: torch.Tensor, topology: torch_geometric.data.Batch, batch: torch.Tensor, @@ -268,6 +282,17 @@ def compute_loss( for key in aux_xtarget: aux[f"xtarget/{key}"] = aux_xtarget[key] + aux_xtarget_second_order = compute_rmsd_metrics( + x=xtarget_second_order, + xhat=xhat, + batch=batch, + num_graphs=num_graphs, + sigma=sigma, + mean_center=self.mean_center, + ) + for key in aux_xtarget_second_order: + aux[f"xtarget_second_order/{key}"] = aux_xtarget_second_order[key] + aux_x = compute_rmsd_metrics( x=x, xhat=xhat, batch=batch, num_graphs=num_graphs, sigma=sigma, mean_center=self.mean_center ) @@ -303,7 +328,7 @@ def noise_and_compute_loss( use_alignment_estimators: bool, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Add noise to the input and compute the loss.""" - xhat, xtarget, _ = self.noise_and_denoise( + xhat, aux = self.noise_and_denoise( x=x, topology=topology, batch=batch, @@ -313,7 +338,8 @@ def noise_and_compute_loss( ) return self.compute_loss( x=x, - xtarget=xtarget, + xtarget=aux["xtarget"], + xtarget_second_order=aux["xtarget_second_order"], xhat=xhat, topology=topology, batch=batch, @@ -336,7 +362,7 @@ def training_step(self, data: torch_geometric.data.Batch, data_idx: int): if self.rotational_augmentation: with torch.cuda.nvtx.range("rotational_augmentation"): R = e3nn.o3.rand_matrix(device=self.device, dtype=x.dtype) - x = torch.einsum("ni,ij->nj", x, R.T) + x = torch.einsum("Ni,ij->Nj", x, R.T) loss, aux = self.noise_and_compute_loss( x=x, @@ -370,7 +396,7 @@ def validation_step(self, data: torch_geometric.data.Batch, data_idx: int): x, batch, num_graphs = data.pos, data.batch, data.num_graphs if self.rotational_augmentation: R = e3nn.o3.rand_matrix(device=self.device, dtype=x.dtype) - x = torch.einsum("ni,ij->nj", x, R.T) + x = torch.einsum("Ni,ij->Nj", x, R.T) loss, aux = self.noise_and_compute_loss( x=x, diff --git a/src/jamun/model/utils.py b/src/jamun/model/utils.py index 0330ac5..cf97e5c 100644 --- a/src/jamun/model/utils.py +++ b/src/jamun/model/utils.py @@ -86,6 +86,7 @@ def add_edges( def compute_rmsd_metrics( + *, x: torch.Tensor, xhat: torch.Tensor, batch: torch.Tensor, @@ -126,6 +127,7 @@ def compute_rmsd_metrics( return { "mse": mse, + "mse_aligned": mse_aligned, "rmsd": rmsd, "rmsd_aligned": rmsd_aligned, "scaled_rmsd": scaled_rmsd, diff --git a/src/jamun/utils/align.py b/src/jamun/utils/align.py index 552457e..692a94c 100644 --- a/src/jamun/utils/align.py +++ b/src/jamun/utils/align.py @@ -118,16 +118,33 @@ def kabsch_algorithm( # SVD to get rotation. U, S, VH = torch.linalg.svd(H) + verbose = False + + if verbose: + print("sigma:", sigma) + print("S before sign correction:", S) # Compute corrected S. R_check = torch.einsum("Gki,Gjk->Gij", VH, U) # V U^T + if verbose: + print("det U:", torch.linalg.det(U)) + print("det V:", torch.linalg.det(VH)) + print("U[0]:", U[0]) + print("VH[0]:", VH[0]) dets = torch.linalg.det(R_check) signs = torch.ones(num_graphs, 3, device=dets.device) - signs[:, 2] = dets + signs[:, -1] = (dets >= 0).float() * 2 - 1 S = torch.einsum("Gk,Gk->Gk", signs, S) + if verbose: + print("det R:", dets) + print("signs:", signs) + print("S after sign correction:", S) # Remove reflections. S = alignment_correction_upto_order(S, sigma=sigma, correction_order=correction_order) + if verbose: + print("S after correction:", S) + print() R = torch.einsum("Gki,Gk,Gk,Gjk->Gij", VH, signs, S, U) # V S U^T # Align y to x.