Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ dependencies = [
"pulp<2.8.0",
"rdkit",
"s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging",
"safari @ git+https://github.com/GuptaVishu2002/safari.git@fix-setup",
Comment on lines 28 to +29
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dependency safari @ git+https://github.com/GuptaVishu2002/safari.git@fix-setup pulls third-party code directly from a Git repository using a mutable ref (fix-setup), which enables supply-chain attacks if that branch/tag is ever compromised or force-moved. An attacker who gains control of that repo could silently change the code at the same ref and have malicious code executed in any environment that installs this project. To mitigate this, pin Git-based dependencies to immutable commit hashes (or published versions on a trusted index) and periodically update them intentionally, rather than tracking branches/tags.

Copilot uses AI. Check for mistakes.
"transformers==4.38.2",
"mamba-ssm @ https://github.com/state-spaces/mamba/releases/download/v2.2.6.post3/mamba_ssm-2.2.6.post3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
"causal-conv1d @ https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu118torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
"mamba3-minimal @ https://github.com/GuptaVishu2002/mamba3-minimal/fix-packaging",
"scikit-learn",
"scipy==1.11.1",
"selfies",
"hydra-core",
"pytorch-lightning",
# snakemake -> stopit imports pkg_resources but does not declare
# setuptools as a dependency, so we require it explicitly here
"setuptools",
Expand Down
9 changes: 8 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ threadpoolctl==3.2.0
throttler==1.2.2
tomli==2.0.1
toposort==1.10
torch==2.3.0
torch==2.4.0
tqdm==4.66.4
traitlets==5.14.3
twine==5.1.0
Expand All @@ -131,3 +131,10 @@ zipp==3.19.0
einops==0.6.0
opt_einsum==3.3.0
s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging
safari @ git+https://github.com/GuptaVishu2002/safari.git@fix-setup
pytorch-lightning==2.6.0
hydra-core==1.3.2
transformers==4.38.2
mamba-ssm @ https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
causal-conv1d @ https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu118torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
mamba3-minimal @ git+https://github.com/GuptaVishu2002/mamba3-minimal.git@fix-packaging
219 changes: 143 additions & 76 deletions src/clm/commands/sample_molecules_RNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
ConditionalRNN,
Transformer,
StructuredStateSpaceSequenceModel,
) # , H3Model, H3ConvModel, HyenaModel
H3Model,
HyenaModel,
MambaModel,
Mamba2Model,
Mamba3Model,
)
from clm.functions import load_dataset, write_to_csv_file

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -41,6 +46,9 @@ def add_args(parser):
parser.add_argument(
"--n_layers", type=int, help="Number of layers in the model"
)
parser.add_argument(
"--n_blocks", type=int, help="Number of blocks for S4 model"
)
parser.add_argument(
"--state_dim", type=int, help="State dimension for S4 model"
)
Expand All @@ -50,9 +58,40 @@ def add_args(parser):
parser.add_argument(
"--n_heads", type=int, help="Number of heads for the model"
)
parser.add_argument(
"--head_dim", type=int, help="Dimension of head for the H3 model"
)
parser.add_argument(
"--n_order_heads",
type=int,
help="Number of groups input channels for filter computation for the Hyena model",
)
parser.add_argument(
"--exp_factor", type=int, help="Expansion factor for Transformer model"
)
parser.add_argument(
"--bias", action="store_true", help="Use bias in Transformer model"
)
parser.add_argument(
"--use_fast_fftconv",
action="store_true",
help="Use fast FFT convolution for H3 model",
)
parser.add_argument(
"--measure",
type=str,
help="Measure parameter for the H3 model",
)
parser.add_argument(
"--mode", type=str, help="Mode parameter for the H3 model"
)
parser.add_argument(
"--lr", type=float, help="Learning rate for the H3 model"
)
parser.add_argument("--order", type=int, help="Order for Hyena model")
parser.add_argument(
"--filter_order", type=int, help="Filter order for Hyena model"
)
parser.add_argument(
"--dropout", type=float, help="Dropout rate for the RNN"
)
Expand Down Expand Up @@ -95,6 +134,9 @@ def add_args(parser):
parser.add_argument(
"--batch_size", type=int, help="Batch size for training"
)
parser.add_argument(
"--max_len", type=int, help="Maximum length of the generated sequences"
)
parser.add_argument(
"--sample_mols", type=int, help="Number of molecules to generate"
)
Expand All @@ -121,12 +163,23 @@ def sample_molecules_RNN(
embedding_size,
hidden_size,
n_layers,
n_blocks,
state_dim,
n_ssm,
n_heads,
head_dim,
n_order_heads,
exp_factor,
bias,
use_fast_fftconv,
measure,
mode,
lr,
order,
filter_order,
dropout,
batch_size,
max_len,
sample_mols,
vocab_file,
model_file,
Expand All @@ -150,98 +203,101 @@ def sample_molecules_RNN(

heldout_dataset = None

if model_type == "S4":
if model_type in [
"S4",
"H3",
"Hyena",
"Transformer",
"Mamba",
"Mamba2",
"Mamba3",
]:
assert (
heldout_file is not None
), "heldout_file must be provided for conditional RNN Model"
heldout_dataset = load_dataset(
representation=representation,
input_file=heldout_file,
vocab_file=vocab_file,
)
not conditional
), f"Conditional mode is not implemented for {model_type} model"

if model_type == "S4":
model = StructuredStateSpaceSequenceModel(
vocabulary=vocab, # heldout_dataset.vocabulary
model_dim=embedding_size,
state_dim=state_dim,
n_layers=n_layers,
n_blocks=n_blocks,
n_ssm=n_ssm,
dropout=dropout,
max_len=max_len,
)
# elif model_type == "H3":
# assert (
# heldout_file is not None
# ), "heldout_file must be provided for conditional RNN Model"
# heldout_dataset = load_dataset(
# representation=representation,
# input_file=heldout_file,
# vocab_file=vocab_file,
# )
# model = H3Model(
# vocabulary=vocab,
# n_layers=n_layers,
# d_model=embedding_size,
# d_state=64,
# head_dim=1,
# dropout=dropout,
# max_len=250,
# use_fast_fftconv=False,
# )
# elif model_type == "H3Conv":
# assert (
# heldout_file is not None
# ), "heldout_file must be provided for conditional RNN Model"
# heldout_dataset = load_dataset(
# representation=representation,
# input_file=heldout_file,
# vocab_file=vocab_file,
# )
# model = H3ConvModel(
# vocabulary=vocab,
# n_layers=n_layers,
# d_model=embedding_size,
# head_dim=1,
# dropout=dropout,
# max_len=250,
# use_fast_fftconv=False,
# )
# elif model_type == "Hyena":
# assert (
# heldout_file is not None
# ), "heldout_file must be provided for conditional RNN Model"
# heldout_dataset = load_dataset(
# representation=representation,
# input_file=heldout_file,
# vocab_file=vocab_file,
# )
# model = HyenaModel(
# vocabulary=vocab,
# n_layers=n_layers,
# d_model=embedding_size,
# order=2,
# filter_order=64,
# num_heads=1,
# dropout=dropout,
# max_len=250,
# inner_factor=1,
# )

elif model_type == "Transformer":
assert (
heldout_file is not None
), "heldout_file must be provided for conditional RNN Model"
heldout_dataset = load_dataset(
representation=representation,
input_file=heldout_file,
vocab_file=vocab_file,
elif model_type == "H3":
model = H3Model(
vocabulary=vocab,
n_layers=n_layers,
model_dim=embedding_size,
state_dim=state_dim,
head_dim=head_dim,
dropout=dropout,
use_fast_fftconv=use_fast_fftconv,
measure=measure,
mode=mode,
lr=lr,
max_len=max_len,
)

elif model_type == "Hyena":
model = HyenaModel(
vocabulary=vocab,
n_layers=n_layers,
d_model=embedding_size,
order=order,
filter_order=filter_order,
n_order_heads=n_order_heads,
dropout=dropout,
max_len=max_len,
)
elif model_type == "Mamba":
model = MambaModel(
vocabulary=vocab,
n_layers=n_layers,
model_dim=embedding_size,
d_state=state_dim,
d_conv=4,
expand=2,
dropout=dropout,
max_len=max_len,
)
elif model_type == "Mamba2":
model = Mamba2Model(
vocabulary=vocab,
n_layers=n_layers,
model_dim=embedding_size,
d_state=state_dim,
d_conv=4,
expand=2,
dropout=dropout,
max_len=max_len,
)
elif model_type == "Mamba3":
model = Mamba3Model(
vocabulary=vocab,
n_layers=n_layers,
model_dim=embedding_size,
d_state=state_dim,
headdim=32,
expand=2,
chunk_size=64,
dropout=dropout,
max_len=max_len,
)
Comment on lines +230 to +255
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New sampling branches for H3, Hyena, Mamba, Mamba2, and Mamba3 were added, but the test suite only exercises sampling for RNN/conditional RNN. Add at least one sample_molecules_RNN test per new model type (similar to the existing RNN sampling tests) so regressions in model loading/sampling don’t go untested.

Copilot uses AI. Check for mistakes.

elif model_type == "Transformer":
model = Transformer(
vocabulary=vocab,
n_blocks=n_layers,
n_heads=n_heads,
embedding_size=embedding_size,
dropout=dropout,
exp_factor=exp_factor,
bias=True,
bias=bias,
max_len=max_len,
)

elif model_type == "RNN":
Expand Down Expand Up @@ -332,12 +388,23 @@ def main(args):
embedding_size=args.embedding_size,
hidden_size=args.hidden_size,
n_layers=args.n_layers,
n_blocks=args.n_blocks,
state_dim=args.state_dim,
n_ssm=args.n_ssm,
n_heads=args.n_heads,
head_dim=args.head_dim,
n_order_heads=args.n_order_heads,
exp_factor=args.exp_factor,
bias=args.bias,
use_fast_fftconv=args.use_fast_fftconv,
measure=args.measure,
mode=args.mode,
lr=args.lr,
order=args.order,
filter_order=args.filter_order,
dropout=args.dropout,
batch_size=args.batch_size,
max_len=args.max_len,
sample_mols=args.sample_mols,
vocab_file=args.vocab_file,
model_file=args.model_file,
Expand Down
Loading
Loading