42 changes: 42 additions & 0 deletions configs/experiment/sample_uncapped_2AA_diffusion.yaml
@@ -0,0 +1,42 @@
# @package _global_

defaults:
- override /sampler: "diffusion.yaml"
- override /callbacks:
- sampler/save_sample.yaml
- _self_

callbacks:
sampler:
save_sample:
save_trajectory: false

init_datasets:
_target_: jamun.data.parse_datasets_from_directory
root: "${paths.data_path}/timewarp/2AA-1-large/test/"
traj_pattern: "^(.*)-traj-arrays.npz"
pdb_pattern: "^(.*)-traj-state0.pdb"
subsample: 1
num_workers: 16

sampler:
sigma_min: 1e-4
sigma_max: 1.0
rho: 7
num_steps: 64
use_second_order_correction: true

finetune_on_init: false

num_batches: 2048
repeat_init_samples: 16
num_init_samples_per_dataset: 1
continue_chain: false


wandb_train_run_path: ???
checkpoint_type: "best_so_far"

logger:
wandb:
group: sample_uncapped_2AA
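
Note: the sampler overrides above (sigma_min, sigma_max, rho, num_steps, second-order correction) follow the usual parametrization of a Karras-style (EDM) noise schedule. The sketch below, which is illustrative only and not part of this diff, shows how such a schedule is typically built from these values; jamun's actual diffusion sampler may construct it differently.

import torch

def karras_sigma_schedule(sigma_min: float, sigma_max: float, rho: float, num_steps: int) -> torch.Tensor:
    # Interpolate between sigma_max and sigma_min in sigma**(1/rho) space (Karras et al., 2022).
    ramp = torch.linspace(0.0, 1.0, num_steps)
    min_inv_rho = sigma_min ** (1.0 / rho)
    max_inv_rho = sigma_max ** (1.0 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, torch.zeros(1)])  # final sigma = 0 for the fully denoised sample

sigmas = karras_sigma_schedule(sigma_min=1e-4, sigma_max=1.0, rho=7, num_steps=64)
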
68 changes: 68 additions & 0 deletions configs/experiment/train_uncapped_2AA_diffusion.yaml
@@ -0,0 +1,68 @@
# @package _global_

defaults:
- override /callbacks:
- visualize_denoise.yaml
- timing.yaml
- ema.yaml
- ema_model_checkpoint.yaml
- _self_

compute_average_squared_distance_from_data: true

model:
sigma_distribution:
_target_: jamun.distributions.ClippedLogNormalSigma
log_sigma_mean: -3.2188758248682006 # log(0.04)
log_sigma_std: 1.0
max_radius: 1.0
optim:
lr: 2e-3
arch:
irreps_hidden: "120x0e + 32x1e"
hidden_layer_factory:
_target_: "e3tools.nn.SeparableConvBlock"


callbacks:
viz:
sigma_list: [0.04]

data:
datamodule:
batch_size: 32
datasets:
train:
_target_: jamun.data.parse_datasets_from_directory
root: "${paths.data_path}/timewarp/2AA-1-large/train/"
traj_pattern: "^(.*)-traj-arrays.npz"
pdb_pattern: "^(.*)-traj-state0.pdb"
num_workers: 16

val:
_target_: jamun.data.parse_datasets_from_directory
root: "${paths.data_path}/timewarp/2AA-1-large/val/"
traj_pattern: "^(.*)-traj-arrays.npz"
pdb_pattern: "^(.*)-traj-state0.pdb"
subsample: 100
max_datasets: 20
num_workers: 16

test:
_target_: jamun.data.parse_datasets_from_directory
root: "${paths.data_path}/timewarp/2AA-1-large/test/"
traj_pattern: "^(.*)-traj-arrays.npz"
pdb_pattern: "^(.*)-traj-state0.pdb"
subsample: 100
max_datasets: 20
num_workers: 16

trainer:
val_check_interval: 50000
check_val_every_n_epoch: null
max_epochs: 10
num_sanity_val_steps: 0

logger:
wandb:
group: train_uncapped_2AA
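
Note: log_sigma_mean of -3.2188758248682006 is ln(0.04), so training noise levels concentrate around sigma of roughly 0.04 in the model's distance units. The sketch below is a rough guess at what a clipped log-normal sigma distribution could do; the clipping bounds are hypothetical placeholders, and jamun.distributions.ClippedLogNormalSigma may differ.

import math
import torch

def sample_sigma(batch_size: int,
                 log_sigma_mean: float = math.log(0.04),
                 log_sigma_std: float = 1.0,
                 sigma_min: float = 1e-4,   # hypothetical clip bounds, not taken from the config
                 sigma_max: float = 1.0) -> torch.Tensor:
    # Draw log(sigma) from a Normal, exponentiate, and clip to a finite range.
    log_sigma = torch.randn(batch_size) * log_sigma_std + log_sigma_mean
    return log_sigma.exp().clamp(min=sigma_min, max=sigma_max)

sigmas = sample_sigma(32)  # one noise level per example in a batch of 32
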
1 change: 1 addition & 0 deletions profiling/profile.sh
@@ -13,6 +13,7 @@ nsys profile \
-t cuda,nvtx,osrt,cudnn,cublas \
-s cpu \
-x true \
--pytorch=autograd-nvtx \
-o nsys.profile \
--force-overwrite true \
--capture-range=cudaProfilerApi \
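
Note: --capture-range=cudaProfilerApi (already in this script) restricts nsys capture to the region between explicit CUDA profiler start/stop calls, and the new --pytorch=autograd-nvtx flag enables NVTX annotation of PyTorch autograd operations in the timeline. The following is a hypothetical sketch of matching instrumentation on the Python side; jamun's own training loop may already handle this differently.

import torch

def run_profiled_steps(model, optimizer, loss_fn, batches, warmup: int = 3):
    for step, (x, y) in enumerate(batches):
        if step == warmup:
            torch.cuda.profiler.start()   # cudaProfilerStart: nsys begins capturing here
        with torch.autograd.profiler.emit_nvtx(enabled=step >= warmup):
            loss = loss_fn(model(x), y)   # forward/backward ops get NVTX ranges
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    torch.cuda.profiler.stop()            # cudaProfilerStop: nsys stops capturing
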
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -44,6 +44,13 @@ dependencies = [
jamun_train = "jamun.cmdline.train:main"
jamun_sample = "jamun.cmdline.sample:main"

[project.optional-dependencies]
analysis = [
"polars>=1.32.0",
"pyarrow>=21.0.0",
"seaborn>=0.13.2",
]

[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"
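
Note: assuming the distribution is named jamun and these extras follow the standard PEP 621 layout suggested by the [project.optional-dependencies] table, they would be installed with `pip install "jamun[analysis]"`.
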
2 changes: 1 addition & 1 deletion src/jamun/callbacks/sampler/_chemical_validity.py
@@ -19,7 +19,7 @@ def __init__(
datasets=datasets,
metric_fn=lambda dataset: ChemicalValidityMetrics(*args, dataset=dataset, **kwargs),
)
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(
f"Initialized ChemicalValidityMetricsCallback with datasets of labels: {list(self.meters.keys())}."
)
4 changes: 2 additions & 2 deletions src/jamun/callbacks/sampler/_measure_sampling_time.py
@@ -73,7 +73,7 @@ def on_after_sample_batch(self, sample, fabric, batch_idx):
fabric.log("sampler/avg_time_per_graph", self.total_sampling_time / self.total_num_graphs, step=batch_idx)

# Log to console
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(
f"Sampled batch {batch_idx} with {num_graphs} samples in {time_elapsed:.4f} seconds "
f"({time_elapsed / num_graphs:.4f} seconds per sample)."
@@ -102,7 +102,7 @@ def on_sample_end(self, fabric):
fabric.log("sampler/std_batch_time", torch.std(torch.tensor(self.batch_times)).item())

# Log to console.
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(
f"Total sampling time: {self.total_sampling_time:.4f} seconds "
f"for {self.total_num_graphs} samples "
2 changes: 1 addition & 1 deletion src/jamun/callbacks/sampler/_posebusters.py
@@ -19,5 +19,5 @@ def __init__(
datasets=datasets,
metric_fn=lambda dataset: PoseBustersMetrics(*args, dataset=dataset, **kwargs),
)
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(f"Initialized PoseBustersCallback with datasets of labels: {list(self.meters.keys())}.")
2 changes: 1 addition & 1 deletion src/jamun/callbacks/sampler/_ramachandran.py
@@ -17,7 +17,7 @@ def __init__(
datasets=datasets,
metric_fn=lambda dataset: RamachandranPlotMetrics(dataset=dataset),
)
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(
f"Initialized RamachandranPlotMetricsCallback with datasets of labels: {list(self.meters.keys())}."
)
2 changes: 1 addition & 1 deletion src/jamun/callbacks/sampler/_save_trajectory.py
@@ -19,5 +19,5 @@ def __init__(
datasets=datasets,
metric_fn=lambda dataset: SaveTrajectory(*args, dataset=dataset, **kwargs),
)
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(f"Initialized SaveTrajectoryCallback with datasets of labels: {list(self.meters.keys())}.")
2 changes: 1 addition & 1 deletion src/jamun/callbacks/sampler/_score_distribution.py
@@ -19,5 +19,5 @@ def __init__(
datasets=datasets,
metric_fn=lambda dataset: ScoreDistributionMetrics(*args, dataset=dataset, **kwargs),
)
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(f"Initialized ScoreDistributionCallback with datasets of labels: {list(self.meters.keys())}.")
2 changes: 1 addition & 1 deletion src/jamun/callbacks/sampler/_trajectory_animation.py
@@ -19,5 +19,5 @@ def __init__(
datasets=datasets,
metric_fn=lambda dataset: TrajectoryVisualizer(*args, dataset=dataset, **kwargs),
)
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(f"Initialized TrajectoryVisualizerCallback with datasets of labels: {list(self.meters.keys())}.")
2 changes: 1 addition & 1 deletion src/jamun/callbacks/sampler/_visualize_samples.py
@@ -19,5 +19,5 @@ def __init__(
datasets=datasets,
metric_fn=lambda dataset: SampleVisualizer(*args, dataset=dataset, **kwargs),
)
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.info(f"Initialized SampleVisualizerCallback with datasets of labels: {list(self.meters.keys())}.")
46 changes: 30 additions & 16 deletions src/jamun/cmdline/sample.py
@@ -1,3 +1,4 @@
import logging
import os
import sys
import traceback
@@ -25,11 +26,13 @@
dotenv.load_dotenv(".env", verbose=True)
OmegaConf.register_new_resolver("format", format_resolver)

py_logger = logging.getLogger("jamun")


def sample_loop(
fabric,
model,
batch_sampler,
sampler,
num_batches: int,
init_graphs: torch_geometric.data.Data,
continue_chain: bool = False,
@@ -38,7 +41,7 @@ def sample_loop(
model_wrapped = jamun.utils.ModelSamplingWrapper(
model=model,
init_graphs=init_graphs,
sigma=batch_sampler.sigma,
sigma=sampler.sigma,
)

y_init = model_wrapped.sample_initial_noisy_positions()
@@ -57,7 +60,7 @@ def sample_loop(
for batch_idx in iterable:
fabric.call("on_before_sample_batch", fabric=fabric, batch_idx=batch_idx)

out = batch_sampler.sample(model=model_wrapped, y_init=y_init, v_init=v_init)
out = sampler.sample(model=model_wrapped, y_init=y_init, v_init=v_init)
samples = model_wrapped.unbatch_samples(out)

# Start next chain from the end state of the previous chain?
@@ -90,6 +93,27 @@ def get_initial_graphs(
def run(cfg):
log_cfg = OmegaConf.to_container(cfg, throw_on_missing=True, resolve=True)

rank_zero_logging_level = cfg.get("rank_zero_logging_level", "INFO")
non_rank_zero_logging_level = cfg.get("non_rank_zero_logging_level", "ERROR")

if rank_zero_only.rank == 0:
level = logging.getLevelNamesMapping()[rank_zero_logging_level]
else:
level = logging.getLevelNamesMapping()[non_rank_zero_logging_level]

py_logger.setLevel(level)

loggers = instantiate_dict_cfg(cfg.get("logger"), verbose=(rank_zero_only.rank == 0))
wandb_logger = None
for logger in loggers:
if isinstance(logger, pl.loggers.WandbLogger):
wandb_logger = logger

callbacks = instantiate_dict_cfg(cfg.get("callbacks"), verbose=(rank_zero_only.rank == 0))
fabric = hydra.utils.instantiate(cfg.fabric, callbacks=callbacks, loggers=loggers)

fabric.launch()

dist_log(f"{OmegaConf.to_yaml(log_cfg)}")
dist_log(f"{os.getcwd()=}")
dist_log(f"{torch.__config__.parallel_info()}")
@@ -100,14 +124,8 @@ def run(cfg):
dist_log(f"Setting float_32_matmul_precision to {matmul_prec}")
torch.set_float32_matmul_precision(matmul_prec)

loggers = instantiate_dict_cfg(cfg.get("logger"), verbose=(rank_zero_only.rank == 0))
wandb_logger = None
for logger in loggers:
if isinstance(logger, pl.loggers.WandbLogger):
wandb_logger = logger

if rank_zero_only.rank == 0 and wandb_logger:
dist_log(f"{wandb_logger.experiment.name=}")
py_logger.info(f"{wandb_logger.experiment.name=}")
wandb_logger.experiment.config.update({"cfg": log_cfg, "version": jamun.__version__, "cwd": os.getcwd()})

# Load the checkpoint either given the wandb run path or the checkpoint path.
@@ -128,14 +146,10 @@ def run(cfg):
repeat=cfg.repeat_init_samples,
)

callbacks = instantiate_dict_cfg(cfg.get("callbacks"), verbose=(rank_zero_only.rank == 0))
fabric = hydra.utils.instantiate(cfg.fabric, callbacks=callbacks, loggers=loggers)

fabric.launch()
fabric.setup(model)
model.eval()

batch_sampler = hydra.utils.instantiate(cfg.batch_sampler)
sampler = hydra.utils.instantiate(cfg.sampler)

if seed := cfg.get("seed"):
# During sampling, we want ranks to generate different chains.
@@ -172,7 +186,7 @@ def run(cfg):
sample_loop(
fabric=fabric,
model=model,
batch_sampler=batch_sampler,
sampler=sampler,
init_graphs=init_graphs,
num_batches=cfg.num_batches,
continue_chain=cfg.continue_chain,
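
Note: one portability point on the new logging-level lookup used in both entry points: logging.getLevelNamesMapping() only exists on Python 3.11 and newer. A hypothetical fallback for older interpreters could look like the following (the helper name is illustrative, not part of the diff):

import logging

def level_from_name(name: str) -> int:
    if hasattr(logging, "getLevelNamesMapping"):  # Python 3.11+
        return logging.getLevelNamesMapping()[name]
    level = logging.getLevelName(name)  # returns an int for known level names like "INFO"
    if not isinstance(level, int):
        raise ValueError(f"Unknown logging level: {name!r}")
    return level

assert level_from_name("INFO") == logging.INFO
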
13 changes: 13 additions & 0 deletions src/jamun/cmdline/train.py
@@ -1,3 +1,4 @@
import logging
import os
import pathlib
import sys
@@ -22,6 +23,8 @@
dotenv.load_dotenv(".env", verbose=True)
OmegaConf.register_new_resolver("format", format_resolver)

py_logger = logging.getLogger("jamun")


def compute_average_squared_distance_from_config(cfg: OmegaConf) -> float:
"""Computes the average squared distance for normalization from the data."""
@@ -36,6 +39,16 @@ def compute_average_squared_distance_from_config(cfg: OmegaConf) -> float:
def run(cfg):
log_cfg = OmegaConf.to_container(cfg, throw_on_missing=True, resolve=True)

rank_zero_logging_level = cfg.get("rank_zero_logging_level", "INFO")
non_rank_zero_logging_level = cfg.get("non_rank_zero_logging_level", "ERROR")

if rank_zero_only.rank == 0:
level = logging.getLevelNamesMapping()[rank_zero_logging_level]
else:
level = logging.getLevelNamesMapping()[non_rank_zero_logging_level]

py_logger.setLevel(level)

dist_log(f"{OmegaConf.to_yaml(log_cfg)}")
dist_log(f"{os.getcwd()=}")
dist_log(f"{torch.__config__.parallel_info()}")
4 changes: 2 additions & 2 deletions src/jamun/data/_mdtraj.py
@@ -104,7 +104,7 @@ def __init__(
self.original_topology.atom(atom_indices[i]), self.original_topology.atom(atom_indices[i + 1])
)

py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.warning(
f"Dataset {self.label()}: No bonds found in topology. Assuming a coarse-grained model and creating bonds between consecutive residues."
)
@@ -250,7 +250,7 @@ def __init__(
self.traj.topology.atom(atom_indices[i]), self.traj.topology.atom(atom_indices[i + 1])
)

py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)
py_logger.warning(
f"Dataset {self.label()}: No bonds found in topology. Assuming a coarse-grained model and creating bonds between consecutive residues."
)
1 change: 1 addition & 0 deletions src/jamun/data/_utils.py
@@ -46,6 +46,7 @@ def download_file(url: str, path: str, verbose: bool = False, block_size: int |
pbar.update(len(data))


# FIXME: num_workers > 0 breaks singleton caching of datasets
def parse_datasets_from_directory(
root: str,
traj_pattern: str,
2 changes: 1 addition & 1 deletion src/jamun/hydra/utils.py
@@ -18,7 +18,7 @@ def instantiate_dict_cfg(cfg: DictConfig | None, verbose: bool = False):
raise TypeError("cfg must be a DictConfig")

if verbose:
py_logger = logging.getLogger("jamun")
py_logger = logging.getLogger(__name__)

for k, v in cfg.items():
if isinstance(v, DictConfig):
2 changes: 1 addition & 1 deletion src/jamun/hydra_config/sample.yaml
@@ -1,7 +1,7 @@
defaults:
- _self_
- model: denoiser_pretrained
- batch_sampler: walkjump.yaml
- sampler: walkjump.yaml
- logger: default
- paths: default
- hydra: default