diff --git a/pyproject.toml b/pyproject.toml index 86aa9acf..0da7e041 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,9 +26,16 @@ dependencies = [ "pulp<2.8.0", "rdkit", "s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging", + "safari @ git+https://github.com/GuptaVishu2002/safari.git@fix-setup", + "transformers==4.38.2", + "mamba-ssm @ https://github.com/state-spaces/mamba/releases/download/v2.2.6.post3/mamba_ssm-2.2.6.post3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl", + "causal-conv1d @ https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu118torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl", + "mamba3-minimal @ git+https://github.com/GuptaVishu2002/mamba3-minimal.git@fix-packaging", "scikit-learn", "scipy==1.11.1", "selfies", + "hydra-core", + "pytorch-lightning", # snakemake->stopit needs pkg_resources, but is failing # to specify setuptools as a dependency "setuptools", diff --git a/requirements.txt b/requirements.txt index e5502ffb..b4266a54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -117,7 +117,7 @@ threadpoolctl==3.2.0 throttler==1.2.2 tomli==2.0.1 toposort==1.10 -torch==2.3.0 +torch==2.4.0 tqdm==4.66.4 traitlets==5.14.3 twine==5.1.0 @@ -131,3 +131,10 @@ zipp==3.19.0 einops==0.6.0 opt_einsum==3.3.0 s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging +safari @ git+https://github.com/GuptaVishu2002/safari.git@fix-setup +pytorch-lightning==2.6.0 +hydra-core==1.3.2 +transformers==4.38.2 +mamba-ssm @ https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +causal-conv1d @ https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu118torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +mamba3-minimal @ git+https://github.com/GuptaVishu2002/mamba3-minimal.git@fix-packaging diff --git 
a/src/clm/commands/sample_molecules_RNN.py b/src/clm/commands/sample_molecules_RNN.py index fd9685df..03622bb1 100644 --- a/src/clm/commands/sample_molecules_RNN.py +++ b/src/clm/commands/sample_molecules_RNN.py @@ -11,7 +11,12 @@ ConditionalRNN, Transformer, StructuredStateSpaceSequenceModel, -) # , H3Model, H3ConvModel, HyenaModel + H3Model, + HyenaModel, + MambaModel, + Mamba2Model, + Mamba3Model, +) from clm.functions import load_dataset, write_to_csv_file logger = logging.getLogger(__name__) @@ -41,6 +46,9 @@ def add_args(parser): parser.add_argument( "--n_layers", type=int, help="Number of layers in the model" ) + parser.add_argument( + "--n_blocks", type=int, help="Number of blocks for S4 model" + ) parser.add_argument( "--state_dim", type=int, help="State dimension for S4 model" ) @@ -50,9 +58,40 @@ def add_args(parser): parser.add_argument( "--n_heads", type=int, help="Number of heads for the model" ) + parser.add_argument( + "--head_dim", type=int, help="Dimension of head for the H3 model" + ) + parser.add_argument( + "--n_order_heads", + type=int, + help="Number of groups input channels for filter computation for the Hyena model", + ) parser.add_argument( "--exp_factor", type=int, help="Expansion factor for Transformer model" ) + parser.add_argument( + "--bias", action="store_true", help="Use bias in Transformer model" + ) + parser.add_argument( + "--use_fast_fftconv", + action="store_true", + help="Use fast FFT convolution for H3 model", + ) + parser.add_argument( + "--measure", + type=str, + help="Measure parameter for the H3 model", + ) + parser.add_argument( + "--mode", type=str, help="Mode parameter for the H3 model" + ) + parser.add_argument( + "--lr", type=float, help="Learning rate for the H3 model" + ) + parser.add_argument("--order", type=int, help="Order for Hyena model") + parser.add_argument( + "--filter_order", type=int, help="Filter order for Hyena model" + ) parser.add_argument( "--dropout", type=float, help="Dropout rate for the RNN" ) 
@@ -95,6 +134,9 @@ def add_args(parser): parser.add_argument( "--batch_size", type=int, help="Batch size for training" ) + parser.add_argument( + "--max_len", type=int, help="Maximum length of the generated sequences" + ) parser.add_argument( "--sample_mols", type=int, help="Number of molecules to generate" ) @@ -121,12 +163,23 @@ def sample_molecules_RNN( embedding_size, hidden_size, n_layers, + n_blocks, state_dim, n_ssm, n_heads, + head_dim, + n_order_heads, exp_factor, + bias, + use_fast_fftconv, + measure, + mode, + lr, + order, + filter_order, dropout, batch_size, + max_len, sample_mols, vocab_file, model_file, @@ -150,90 +203,92 @@ def sample_molecules_RNN( heldout_dataset = None - if model_type == "S4": + if model_type in [ + "S4", + "H3", + "Hyena", + "Transformer", + "Mamba", + "Mamba2", + "Mamba3", + ]: assert ( - heldout_file is not None - ), "heldout_file must be provided for conditional RNN Model" - heldout_dataset = load_dataset( - representation=representation, - input_file=heldout_file, - vocab_file=vocab_file, - ) + not conditional + ), f"Conditional mode is not implemented for {model_type} model" + + if model_type == "S4": model = StructuredStateSpaceSequenceModel( vocabulary=vocab, # heldout_dataset.vocabulary model_dim=embedding_size, state_dim=state_dim, - n_layers=n_layers, + n_blocks=n_blocks, n_ssm=n_ssm, dropout=dropout, + max_len=max_len, ) - # elif model_type == "H3": - # assert ( - # heldout_file is not None - # ), "heldout_file must be provided for conditional RNN Model" - # heldout_dataset = load_dataset( - # representation=representation, - # input_file=heldout_file, - # vocab_file=vocab_file, - # ) - # model = H3Model( - # vocabulary=vocab, - # n_layers=n_layers, - # d_model=embedding_size, - # d_state=64, - # head_dim=1, - # dropout=dropout, - # max_len=250, - # use_fast_fftconv=False, - # ) - # elif model_type == "H3Conv": - # assert ( - # heldout_file is not None - # ), "heldout_file must be provided for conditional RNN Model" - 
# heldout_dataset = load_dataset( - # representation=representation, - # input_file=heldout_file, - # vocab_file=vocab_file, - # ) - # model = H3ConvModel( - # vocabulary=vocab, - # n_layers=n_layers, - # d_model=embedding_size, - # head_dim=1, - # dropout=dropout, - # max_len=250, - # use_fast_fftconv=False, - # ) - # elif model_type == "Hyena": - # assert ( - # heldout_file is not None - # ), "heldout_file must be provided for conditional RNN Model" - # heldout_dataset = load_dataset( - # representation=representation, - # input_file=heldout_file, - # vocab_file=vocab_file, - # ) - # model = HyenaModel( - # vocabulary=vocab, - # n_layers=n_layers, - # d_model=embedding_size, - # order=2, - # filter_order=64, - # num_heads=1, - # dropout=dropout, - # max_len=250, - # inner_factor=1, - # ) - elif model_type == "Transformer": - assert ( - heldout_file is not None - ), "heldout_file must be provided for conditional RNN Model" - heldout_dataset = load_dataset( - representation=representation, - input_file=heldout_file, - vocab_file=vocab_file, + elif model_type == "H3": + model = H3Model( + vocabulary=vocab, + n_layers=n_layers, + model_dim=embedding_size, + state_dim=state_dim, + head_dim=head_dim, + dropout=dropout, + use_fast_fftconv=use_fast_fftconv, + measure=measure, + mode=mode, + lr=lr, + max_len=max_len, + ) + + elif model_type == "Hyena": + model = HyenaModel( + vocabulary=vocab, + n_layers=n_layers, + d_model=embedding_size, + order=order, + filter_order=filter_order, + n_order_heads=n_order_heads, + dropout=dropout, + max_len=max_len, + ) + elif model_type == "Mamba": + model = MambaModel( + vocabulary=vocab, + n_layers=n_layers, + model_dim=embedding_size, + d_state=state_dim, + d_conv=4, + expand=2, + dropout=dropout, + max_len=max_len, + ) + elif model_type == "Mamba2": + model = Mamba2Model( + vocabulary=vocab, + n_layers=n_layers, + model_dim=embedding_size, + d_state=state_dim, + d_conv=4, + expand=2, + dropout=dropout, + max_len=max_len, ) + elif 
model_type == "Mamba3": + model = Mamba3Model( + vocabulary=vocab, + n_layers=n_layers, + model_dim=embedding_size, + d_state=state_dim, + headdim=32, + expand=2, + chunk_size=64, + dropout=dropout, + max_len=max_len, + ) + + elif model_type == "Transformer": model = Transformer( vocabulary=vocab, n_blocks=n_layers, @@ -241,7 +296,8 @@ def sample_molecules_RNN( embedding_size=embedding_size, dropout=dropout, exp_factor=exp_factor, - bias=True, + bias=bias, + max_len=max_len, ) elif model_type == "RNN": @@ -332,12 +388,23 @@ def main(args): embedding_size=args.embedding_size, hidden_size=args.hidden_size, n_layers=args.n_layers, + n_blocks=args.n_blocks, state_dim=args.state_dim, n_ssm=args.n_ssm, n_heads=args.n_heads, + head_dim=args.head_dim, + n_order_heads=args.n_order_heads, exp_factor=args.exp_factor, + bias=args.bias, + use_fast_fftconv=args.use_fast_fftconv, + measure=args.measure, + mode=args.mode, + lr=args.lr, + order=args.order, + filter_order=args.filter_order, dropout=args.dropout, batch_size=args.batch_size, + max_len=args.max_len, sample_mols=args.sample_mols, vocab_file=args.vocab_file, model_file=args.model_file, diff --git a/src/clm/commands/train_models_RNN.py b/src/clm/commands/train_models_RNN.py index 84ace720..ededcbc0 100644 --- a/src/clm/commands/train_models_RNN.py +++ b/src/clm/commands/train_models_RNN.py @@ -11,7 +11,12 @@ ConditionalRNN, Transformer, StructuredStateSpaceSequenceModel, -) # , H3Model, H3ConvModel, HyenaModel + H3Model, + HyenaModel, + MambaModel, + Mamba2Model, + Mamba3Model, +) from clm.loggers import EarlyStopping, track_loss, print_update from clm.functions import write_smiles, load_dataset @@ -31,108 +36,122 @@ def add_args(parser): default="SMILES", help="Molecular representation format (one of: SMILES/SELFIES)", ) - parser.add_argument( "--model_type", type=str, help="Type of model used (e.g., S4, Transformer)", ) - parser.add_argument( "--rnn_type", type=str, help="Type of RNN used (e.g., LSTM, GRU)" ) - 
parser.add_argument( "--embedding_size", type=int, help="Size of the embedding layer" ) - parser.add_argument( "--hidden_size", type=int, help="Size of the hidden layers" ) - parser.add_argument( "--n_layers", type=int, help="Number of layers in the model" ) - + parser.add_argument( + "--n_blocks", type=int, help="Number of blocks for S4 model" + ) parser.add_argument( "--state_dim", type=int, help="State dimension for S4 model" ) - parser.add_argument( "--n_ssm", type=int, help="Number of SSM layers for S4 model" ) - parser.add_argument( "--n_heads", type=int, help="Number of heads for the model" ) - + parser.add_argument( + "--head_dim", type=int, help="Dimension of head for the H3 model" + ) + parser.add_argument( + "--n_order_heads", + type=int, + help="Number of groups input channels for filter computation for the Hyena model", + ) parser.add_argument( "--exp_factor", type=int, help="Expansion factor for Transformer model" ) - + parser.add_argument( + "--bias", action="store_true", help="Use bias in Transformer model" + ) + parser.add_argument( + "--use_fast_fftconv", + action="store_true", + help="Use fast FFT convolution for H3 model", + ) + parser.add_argument( + "--measure", + type=str, + help="Measure parameter for the H3 model", + ) + parser.add_argument( + "--mode", type=str, help="Mode parameter for the H3 model" + ) + parser.add_argument( + "--lr", type=float, help="Learning rate for the H3 model" + ) + parser.add_argument("--order", type=int, help="Order for Hyena model") + parser.add_argument( + "--filter_order", type=int, help="Filter order for Hyena model" + ) parser.add_argument( "--dropout", type=float, help="Dropout rate for the RNN" ) - parser.add_argument( "--batch_size", type=int, help="Batch size for training" ) - + parser.add_argument( + "--max_len", type=int, help="Maximum length of the generated sequences" + ) parser.add_argument( "--learning_rate", type=float, help="Learning rate for the optimizer" ) - parser.add_argument( 
"--max_epochs", type=int, help="Maximum number of epochs for training" ) - parser.add_argument( "--patience", type=int, help="Patience for early stopping" ) - parser.add_argument( "--log_every_steps", type=int, help="Logging frequency in steps" ) - parser.add_argument( "--log_every_epochs", type=int, help="Logging frequency in epochs" ) - parser.add_argument( "--sample_mols", type=int, help="Number of molecules to sample for evaluation", ) - parser.add_argument( "--input_file", type=str, required=True, help="Input file path for training data", ) - parser.add_argument( "--vocab_file", type=str, required=True, help="Output path for the vocabulary file ({fold} is populated automatically)", ) - parser.add_argument( "--smiles_file", type=str, default=None, help="File path for additional SMILES data (optional)", ) - parser.add_argument( "--model_file", type=str, help="File path to save the trained model" ) - parser.add_argument( "--loss_file", type=str, help="File path to save the training loss data" ) - parser.add_argument( "--conditional", action="store_true", @@ -203,12 +222,23 @@ def train_models_RNN( embedding_size, hidden_size, n_layers, + n_blocks, state_dim, n_ssm, n_heads, + head_dim, + n_order_heads, exp_factor, + bias, + use_fast_fftconv, + measure, + mode, + lr, + order, + filter_order, dropout, batch_size, + max_len, learning_rate, max_epochs, patience, @@ -238,46 +268,73 @@ def train_models_RNN( vocabulary=dataset.vocabulary, model_dim=embedding_size, state_dim=state_dim, - n_layers=n_layers, + n_blocks=n_blocks, n_ssm=n_ssm, dropout=dropout, + max_len=max_len, ) - # elif model_type == "H3": - # model = H3Model( - # vocabulary=dataset.vocabulary, - # n_layers=n_layers, - # d_model=embedding_size, - # d_state=64, - # head_dim=1, - # dropout=dropout, - # max_len=250, - # use_fast_fftconv=False, - # ) + elif model_type == "H3": + model = H3Model( + vocabulary=dataset.vocabulary, + n_layers=n_layers, + model_dim=embedding_size, + state_dim=state_dim, + 
head_dim=head_dim, + dropout=dropout, + use_fast_fftconv=use_fast_fftconv, + measure=measure, + mode=mode, + lr=lr, + max_len=max_len, + ) - # elif model_type == "H3Conv": - # model = H3ConvModel( - # vocabulary=dataset.vocabulary, - # n_layers=n_layers, - # d_model=embedding_size, - # head_dim=1, - # dropout=dropout, - # max_len=250, - # use_fast_fftconv=False, - # ) + elif model_type == "Hyena": + model = HyenaModel( + vocabulary=dataset.vocabulary, + n_layers=n_layers, + d_model=embedding_size, + order=order, + filter_order=filter_order, + n_order_heads=n_order_heads, + dropout=dropout, + max_len=max_len, + ) - # elif model_type == "Hyena": - # model = HyenaModel( - # vocabulary=dataset.vocabulary, - # n_layers=n_layers, - # d_model=embedding_size, - # order=2, - # filter_order=64, - # num_heads=1, - # dropout=dropout, - # max_len=250, - # inner_factor=1, - # ) + elif model_type == "Mamba": + model = MambaModel( + vocabulary=dataset.vocabulary, + n_layers=n_layers, + model_dim=embedding_size, + d_state=state_dim, + d_conv=4, + expand=2, + dropout=dropout, + max_len=max_len, + ) + elif model_type == "Mamba2": + model = Mamba2Model( + vocabulary=dataset.vocabulary, + n_layers=n_layers, + model_dim=embedding_size, + d_state=state_dim, + d_conv=4, + expand=2, + dropout=dropout, + max_len=max_len, + ) + elif model_type == "Mamba3": + model = Mamba3Model( + vocabulary=dataset.vocabulary, + n_layers=n_layers, + model_dim=embedding_size, + d_state=state_dim, + headdim=32, + expand=2, + chunk_size=64, + dropout=dropout, + max_len=max_len, + ) elif model_type == "Transformer": model = Transformer( @@ -287,7 +344,8 @@ def train_models_RNN( embedding_size=embedding_size, dropout=dropout, exp_factor=exp_factor, - bias=True, + bias=bias, + max_len=max_len, ) elif model_type == "RNN": @@ -388,12 +446,23 @@ def main(args): embedding_size=args.embedding_size, hidden_size=args.hidden_size, n_layers=args.n_layers, + n_blocks=args.n_blocks, state_dim=args.state_dim, 
class Mamba3Model(nn.Module):
    """CLM wrapper around the Mamba-3 SSM architecture.

    Mamba-3 improves over Mamba-2 with:
      - Trapezoidal discretization (second-order accurate state update)
      - Complex-valued SSM via data-dependent RoPE (enables state-tracking)
      - MIMO formulation (better hardware utilisation during decode)
      - QK-Normalisation on B, C projections
      - Learnable BC bias (head-specific, channel-wise, init to ones)
      - No short convolution (trapezoidal + bias removes the need for conv1d)

    The backbone is imported from the mamba3-minimal package:
      https://github.com/GuptaVishu2002/mamba3-minimal/tree/fix-packaging

    Architecture follows Llama design:
      Embedding -> N x [RMSNorm -> Mamba3 -> RMSNorm -> SwiGLU] -> RMSNorm -> LM Head

    The public interface (``forward`` / ``loss`` / ``sample``) mirrors
    MambaModel and Mamba2Model so the class can be selected by the same
    training and sampling commands.

    Parameters
    ----------
    vocabulary
        Object providing ``len()``, a ``dictionary`` mapping with the
        ``"SOS"``, ``"EOS"`` and padding entries, and ``decode()``.
    n_layers, model_dim, d_state, headdim, expand, chunk_size
        Forwarded to ``Mamba3Config``.
    dropout
        Stored and used to build ``dropout_layer`` (see the NOTE in
        ``__init__``: the layer is not applied in ``forward``).
    max_len
        Default maximum number of tokens produced by ``sample``.
    use_mimo, mimo_rank
        Forwarded to ``Mamba3Config``.
    **kwargs
        Accepted and ignored, for call-site compatibility.

    Raises
    ------
    ValueError
        If ``expand * model_dim`` is not divisible by ``headdim`` or if
        ``d_state`` is odd.
    """

    def __init__(
        self,
        vocabulary,
        n_layers: int = 4,
        model_dim: int = 256,
        d_state: int = 128,
        headdim: int = 64,
        expand: int = 2,
        chunk_size: int = 64,
        dropout: float = 0.1,
        max_len: int = 250,
        use_mimo: bool = False,
        mimo_rank: int = 4,
        **kwargs,
    ):
        super().__init__()

        # Validate the hyperparameters up front. Explicit exceptions are used
        # instead of `assert` so the checks survive `python -O`, matching the
        # ValueError raised by Mamba2Model for invalid head dimensions.
        d_inner = expand * model_dim
        if d_inner % headdim != 0:
            raise ValueError(
                f"d_inner (expand*model_dim = {d_inner}) must be divisible by headdim ({headdim})"
            )
        if d_state % 2 != 0:
            raise ValueError(
                f"d_state ({d_state}) must be even for complex SSM / RoPE pairing"
            )

        # Device
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Vocabulary
        self.vocabulary = vocabulary
        self.vocabulary_size = len(self.vocabulary)
        self.padding_idx = self.vocabulary.dictionary[""]

        # Hyperparameters (stored for repr / checkpointing)
        self.model_dim = model_dim
        self.d_state = d_state
        self.headdim = headdim
        self.expand = expand
        self.chunk_size = chunk_size
        self.n_layers = n_layers
        self.dropout_p = dropout
        self.max_len = max_len
        self.use_mimo = use_mimo
        self.mimo_rank = mimo_rank

        # Build the Mamba-3 backbone. Mamba3LMHeadModel has its own embedding
        # and LM head; the embedding is replaced below with one that uses our
        # vocabulary's padding index, and the LM head is reused as the output
        # projection.
        cfg = Mamba3Config(
            d_model=model_dim,
            n_layer=n_layers,
            d_state=d_state,
            expand=expand,
            headdim=headdim,
            chunk_size=chunk_size,
            vocab_size=self.vocabulary_size,
            pad_vocab_size_multiple=1,  # exact size; no padding of vocab
            use_mimo=use_mimo,
            mimo_rank=mimo_rank,
        )
        self.backbone = Mamba3LMHeadModel(cfg, device=str(self.device))

        # Replace the embedding so padding_idx is honoured.
        self.backbone.backbone.embedding = nn.Embedding(
            self.vocabulary_size,
            model_dim,
            padding_idx=self.padding_idx,
        ).to(self.device)

        # Re-tie lm_head weights to the new embedding.
        self.backbone.lm_head.weight = self.backbone.backbone.embedding.weight

        # NOTE(review): this layer is constructed but never applied in
        # forward(), so the `dropout` argument currently has no effect on the
        # Mamba-3 backbone. Kept so the attribute remains available.
        self.dropout_layer = nn.Dropout(dropout)

        # Per-token loss; padding positions contribute zero.
        self.loss_fn = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx, reduction="none"
        )

        # Move everything to the target device
        if torch.cuda.is_available():
            self.cuda()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits for a batch of token indices.

        Parameters
        ----------
        x : (batch_size, seq_len) token indices

        Returns
        -------
        logits : (batch_size, seq_len, vocab_size)
        """
        # The backbone returns (logits, inference_caches); only the logits are
        # needed for training / teacher-forced evaluation.
        logits, _ = self.backbone(x, h=None)
        return logits

    def loss(self, batch) -> torch.Tensor:
        """Compute the teacher-forced cross-entropy loss for one batch.

        ``batch`` is a ``(padded, lengths, descriptors)`` triple; only
        ``padded`` is used and it is expected in (seq_len, batch_size) layout.
        The per-token losses are summed over positions for each sequence and
        then averaged over the batch.
        """
        padded, _lengths, _descriptors = batch

        # (seq_len, batch_size) -> (batch_size, seq_len). transpose() only
        # swaps strides, so .contiguous() makes a real copy for the kernels.
        padded = padded.to(self.device).transpose(0, 1).contiguous()

        # The chunked forward pass needs seq_len to be a multiple of
        # chunk_size; right-pad with the padding token when it is not.
        seq_len = padded.shape[1]
        remainder = seq_len % self.chunk_size
        if remainder != 0:
            padded = F.pad(
                padded, (0, self.chunk_size - remainder), value=self.padding_idx
            )

        logits = self(padded)  # (batch_size, padded_seq_len, vocab_size)

        # Shift by one for next-token prediction and trim back to the original
        # (pre-chunk-padding) length.
        targets = padded[:, 1:seq_len]  # (batch_size, seq_len - 1)
        logits = logits[:, : seq_len - 1, :]

        if targets.shape[1] == 0:
            # Nothing to predict (single-token sequences): zero loss that is
            # still connected to the graph.
            return logits.sum() * 0.0

        # CrossEntropyLoss takes class scores on dim 1: (N, C, L) vs (N, L).
        per_token = self.loss_fn(logits.transpose(1, 2), targets)
        return per_token.sum(dim=1).mean()

    def _draw_tokens(self, step_logits, finished, nll, stop_token, pad_token):
        """Sample one token per sequence and update the running NLL in place.

        Returns ``(tokens, finished)``; ``tokens`` is ``None`` when the
        distribution contains NaN/Inf, in which case nothing is updated.
        """
        step_logits = torch.clamp(step_logits, min=-1e4, max=1e4)
        prob = F.softmax(step_logits, dim=-1)
        if torch.isnan(prob).any() or torch.isinf(prob).any():
            return None, finished

        tokens = torch.multinomial(prob, num_samples=1).squeeze(1)
        step_nll = F.nll_loss(
            F.log_softmax(step_logits, dim=-1),
            tokens,
            ignore_index=pad_token,
            reduction="none",
        )
        # Sequences that already emitted EOS no longer accumulate loss; the
        # EOS token itself is still counted because `finished` is updated last.
        step_nll[finished] = 0
        nll += step_nll
        return tokens, finished | (tokens == stop_token)

    def sample(
        self,
        *,
        n_sequences: int,
        max_len: int = None,
        return_smiles: bool = True,
        return_losses: bool = False,
        descriptors=None,
    ):
        """Auto-regressively sample sequences from the model.

        The SOS token is run through the chunked forward pass to obtain the
        per-layer inference caches; every following token is decoded by
        feeding the previous token together with those caches.

        Parameters
        ----------
        n_sequences : int
            Number of sequences to generate in parallel.
        max_len : int, optional
            Maximum generation length (defaults to self.max_len).
        return_smiles : bool
            If True, decode each token sequence with ``vocabulary.decode``;
            otherwise return the list of per-step ``(n_sequences, 1)`` tensors.
        return_losses : bool
            Also return the per-sequence NLL as a NumPy array if True.
        descriptors : ignored
            Accepted for API compatibility with ConditionalRNN; has no effect.
        """
        if max_len is None:
            max_len = self.max_len

        was_training = self.training
        self.eval()

        start_token = self.vocabulary.dictionary["SOS"]
        stop_token = self.vocabulary.dictionary["EOS"]
        pad_token = self.vocabulary.dictionary[""]

        finished = torch.zeros(
            n_sequences, dtype=torch.bool, device=self.device
        )
        log_probs = torch.zeros(n_sequences, device=self.device)
        sequences = []

        try:
            with torch.no_grad():
                current_ids = torch.full(
                    (n_sequences, 1),
                    start_token,
                    dtype=torch.long,
                    device=self.device,
                )

                # The chunked path needs chunk_size tokens, so SOS is
                # left-padded and only the last position's logits are used.
                # NOTE(review): training never sees left padding before SOS;
                # confirm the backbone state is unaffected by these tokens.
                pad_len = self.chunk_size - 1
                if pad_len > 0:
                    padded_ids = F.pad(
                        current_ids, (pad_len, 0), value=pad_token
                    )
                else:
                    padded_ids = current_ids

                logits_init, h = self.backbone(padded_ids, h=None)

                outputs, finished = self._draw_tokens(
                    logits_init[:, -1, :],
                    finished,
                    log_probs,
                    stop_token,
                    pad_token,
                )
                if outputs is None:
                    # Fallback: feed SOS again on the next step.
                    outputs = current_ids.squeeze(1)
                else:
                    sequences.append(outputs.view(-1, 1))

                # Remaining tokens: one recurrent step per token.
                for _ in range(max_len - 1):
                    if finished.all():
                        break

                    logits_step, h = self.backbone(outputs.unsqueeze(1), h=h)
                    drawn, finished = self._draw_tokens(
                        logits_step[:, -1, :],
                        finished,
                        log_probs,
                        stop_token,
                        pad_token,
                    )
                    if drawn is None:
                        break
                    outputs = drawn
                    sequences.append(outputs.view(-1, 1))

            if sequences:
                seqs = torch.cat(sequences, dim=1)
            else:
                # Nothing could be sampled: return a single SOS per sequence.
                seqs = torch.full(
                    (n_sequences, 1),
                    start_token,
                    dtype=torch.long,
                    device=self.device,
                )

            if return_smiles:
                smiles = [
                    self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs
                ]
            else:
                smiles = sequences
        finally:
            # Restore the caller's mode even if sampling raised.
            if was_training:
                self.train()

        if return_losses:
            return smiles, log_probs.detach().cpu().numpy()
        return smiles
class MambaModel(nn.Module):
    """Chemical language model built from a stack of Mamba blocks.

    Tokens are embedded, passed through ``n_layers`` pre-norm residual Mamba
    blocks (LayerNorm -> Mamba -> Dropout, added to the block input), then a
    final LayerNorm and a linear projection to vocabulary logits.

    Parameters
    ----------
    vocabulary
        Object providing ``len()``, a ``dictionary`` mapping with the
        ``"SOS"``, ``"EOS"`` and padding entries, and ``decode()``.
    n_layers : number of Mamba blocks.
    model_dim : embedding / model width (``d_model`` of each Mamba block).
    d_state, d_conv, expand : forwarded to each ``Mamba`` block.
    dropout : dropout probability applied to every block output.
    max_len : default maximum number of tokens produced by ``sample``.
    **kwargs : accepted and ignored, for call-site compatibility.
    """

    def __init__(
        self,
        vocabulary,
        n_layers=4,
        model_dim=256,
        d_state=16,
        d_conv=4,
        expand=2,
        dropout=0.1,
        max_len=250,
        **kwargs,
    ):
        super().__init__()

        # Device
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Vocabulary
        self.vocabulary = vocabulary
        self.vocabulary_size = len(self.vocabulary)
        self.padding_idx = self.vocabulary.dictionary[""]

        # Hyperparameters
        self.model_dim = model_dim
        self.d_state = d_state
        self.d_conv = d_conv
        self.n_layers = n_layers
        self.expand = expand
        self.dropout = dropout
        self.max_len = max_len

        # nn.Embedding takes padding_idx as a plain int, so the index is
        # passed directly rather than wrapped in a device tensor.
        self.embedding = nn.Embedding(
            self.vocabulary_size,
            self.model_dim,
            padding_idx=self.padding_idx,
        )

        # Stack of Mamba layers
        # This module uses roughly 3 * expand * d_model^2 parameters
        self.mamba_layers = nn.ModuleList(
            [
                Mamba(
                    d_model=self.model_dim,  # Model dimension d_model
                    d_state=self.d_state,  # SSM state expansion factor
                    d_conv=self.d_conv,  # Local convolution width
                    expand=self.expand,  # Block expansion factor
                )
                for _ in range(n_layers)
            ]
        )

        # One pre-norm per Mamba layer
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(self.model_dim) for _ in range(n_layers)]
        )

        # Dropout applied to each block output before the residual add
        self.dropout_layer = nn.Dropout(dropout)

        # Output projection
        self.output_embedding = nn.Linear(self.model_dim, self.vocabulary_size)

        # Per-token loss; padding positions contribute zero.
        self.loss_fn = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx, reduction="none"
        )

        # Final layer norm applied after all Mamba layers, before output projection
        self.final_norm = nn.LayerNorm(self.model_dim)

        # Move to GPU
        if torch.cuda.is_available():
            self.cuda()

    def forward(self, x):
        """Return logits for a batch of token indices.

        x: (batch_size, seq_len)
        Returns: (batch_size, seq_len, vocab_size)
        """
        x = self.embedding(x)  # (batch_size, seq_len, model_dim)

        # Pre-norm residual blocks; the Mamba layer is fed (B, L, D).
        for mamba_layer, layer_norm in zip(self.mamba_layers, self.layer_norms):
            x = x + self.dropout_layer(mamba_layer(layer_norm(x)))

        return self.output_embedding(self.final_norm(x))

    def loss(self, batch):
        """Compute the teacher-forced cross-entropy loss for one batch.

        ``batch`` is a ``(padded, lengths, descriptors)`` triple; only
        ``padded`` is used and it is expected in (seq_len, batch_size) layout.
        The per-token losses are summed over positions for each sequence and
        then averaged over the batch.
        """
        padded, _lengths, _descriptors = batch

        # (seq_len, batch_size) -> (batch_size, seq_len); .contiguous() makes
        # a real copy, consistent with the other Mamba wrappers in this file.
        padded = padded.to(self.device).transpose(0, 1).contiguous()

        logits = self(padded)[:, :-1, :]
        targets = padded[:, 1:]

        if targets.shape[1] == 0:
            # Nothing to predict (single-token sequences): zero loss that is
            # still connected to the graph.
            return logits.sum() * 0.0

        # CrossEntropyLoss takes class scores on dim 1: (N, C, L) vs (N, L).
        per_token = self.loss_fn(logits.transpose(1, 2), targets)
        return per_token.sum(dim=1).mean()

    def sample(
        self,
        *,
        n_sequences,
        max_len=None,
        return_smiles=True,
        return_losses=False,
        descriptors=None,
    ):
        """Auto-regressively sample ``n_sequences`` sequences.

        Every step re-runs the full forward pass on the prefix generated so
        far and samples the next token from the last position's logits.

        Parameters
        ----------
        n_sequences : number of sequences to generate in parallel.
        max_len : maximum generation length (defaults to ``self.max_len``).
        return_smiles : if True, decode each sequence with
            ``vocabulary.decode``; otherwise return the list of per-step
            ``(n_sequences, 1)`` tensors.
        return_losses : also return the per-sequence NLL as a NumPy array.
        descriptors : accepted for API compatibility; not used.
        """
        if max_len is None:
            max_len = self.max_len

        was_training = self.training
        self.eval()

        start_token = self.vocabulary.dictionary["SOS"]
        stop_token = self.vocabulary.dictionary["EOS"]
        pad_token = self.vocabulary.dictionary[""]

        # Running prefix, starting with one SOS token per sequence.
        current_seq = torch.full(
            (n_sequences, 1),
            start_token,
            dtype=torch.long,
            device=self.device,
        )
        finished = torch.zeros(
            n_sequences, dtype=torch.bool, device=self.device
        )
        log_probs = torch.zeros(n_sequences, device=self.device)
        sequences = []

        try:
            with torch.no_grad():
                for _ in range(max_len):
                    logits = self(current_seq)[:, -1, :]
                    logits = torch.clamp(logits, min=-1e4, max=1e4)
                    prob = F.softmax(logits, dim=-1)

                    # Stop rather than sample from an invalid distribution.
                    if torch.isnan(prob).any() or torch.isinf(prob).any():
                        break

                    tokens = torch.multinomial(prob, num_samples=1)
                    sequences.append(tokens)
                    flat_tokens = tokens.squeeze(1)

                    step_nll = F.nll_loss(
                        F.log_softmax(logits, dim=-1),
                        flat_tokens,
                        ignore_index=pad_token,
                        reduction="none",
                    )
                    # Sequences that already emitted EOS stop accumulating
                    # loss; the EOS token itself is still counted.
                    step_nll[finished] = 0
                    log_probs += step_nll

                    finished = finished | (flat_tokens == stop_token)
                    if finished.all():
                        break

                    # Extend the prefix in place of rebuilding it from the
                    # list of sampled tokens on every step.
                    current_seq = torch.cat([current_seq, tokens], dim=1)

            if sequences:
                seqs = torch.cat(sequences, 1)
            else:
                # Nothing could be sampled: return a single SOS per sequence.
                seqs = torch.full(
                    (n_sequences, 1),
                    start_token,
                    dtype=torch.long,
                    device=self.device,
                )

            if return_smiles:
                smiles = [
                    self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs
                ]
            else:
                smiles = sequences
        finally:
            # Restore the caller's mode even if sampling raised.
            if was_training:
                self.train()

        if return_losses:
            return smiles, log_probs.detach().cpu().numpy()
        return smiles
+ d_inner = int(self.expand * self.model_dim) + _headdim = None + for hd in [64, 32, 16, 8]: + if d_inner % hd != 0: + continue + nheads_candidate = d_inner // hd + if (2 * d_inner + 2 * self.d_state + nheads_candidate) % 8 == 0: + _headdim = hd + break + if _headdim is None: + raise ValueError( + f"No valid headdim found for model_dim={self.model_dim}, " + f"expand={self.expand}, d_state={self.d_state}. " + f"Ensure (2*expand*d_model + 2*d_state + nheads) is divisible by 8." + ) + self.mamba2_layers = nn.ModuleList( + [ + Mamba2( + d_model=self.model_dim, # Model dimension d_model + d_state=self.d_state, # SSM state expansion factor, typically 64 or 128 + d_conv=self.d_conv, # Local convolution width + expand=self.expand, # Block expansion factor + headdim=_headdim, + ) + for _ in range(n_layers) + ] + ) + + # Layer norm for each layer + self.layer_norms = nn.ModuleList( + [nn.LayerNorm(self.model_dim) for _ in range(n_layers)] + ) + + # Dropout + self.dropout_layer = nn.Dropout(dropout) + + # Output projection + self.output_embedding = nn.Linear(self.model_dim, self.vocabulary_size) + + # Loss function + self.loss_fn = nn.CrossEntropyLoss( + ignore_index=self.padding_idx, reduction="none" + ) + + # Final layer norm applied after all Mamba2 layers, before output projection + self.final_norm = nn.LayerNorm(self.model_dim) + + # Move to GPU + if torch.cuda.is_available(): + self.cuda() + + def forward(self, x): + """ + x: (batch_size, seq_len) + Returns: (batch_size, seq_len, vocab_size) + """ + batch_size, seq_len = x.size() + + # Embed + x = self.embedding(x) # (batch_size, seq_len, model_dim) + + # Apply Mamba2 layers with residual connections + for mamba2_layer, layer_norm in zip( + self.mamba2_layers, self.layer_norms + ): + residual = x + x = layer_norm(x) + x = mamba2_layer(x) # Mamba2 expects (B, L, D) + x = self.dropout_layer(x) + x = x + residual + + x = self.final_norm(x) + # Project to vocabulary + logits = self.output_embedding(x) # (batch_size, 
seq_len, vocab_size) + + return logits + + def loss(self, batch): + # Collate always returns (padded, lengths, descriptors); descriptor ignored here + padded, lengths, _ = batch + + padded = padded.to(self.device) + + # Collate returns (seq_len, batch_size); transpose to (batch_size, seq_len). + # .contiguous() is critical: transpose() only swaps strides without copying + # memory. Non-standard strides propagate through nn.Embedding into + # causal_conv1d, causing: RuntimeError: strides must be multiples of 8 + padded = padded.transpose(0, 1).contiguous() + + logits = self(padded) + + targets = padded[:, 1:] + logits = logits[:, :-1, :] + + loss = 0.0 + actual_len = min(logits.shape[1], targets.shape[1]) + for char_idx in range(actual_len): + loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + + return loss.mean() + + def sample( + self, + *, + n_sequences, + max_len=None, + return_smiles=True, + return_losses=False, + descriptors=None, + ): + if max_len is None: + max_len = self.max_len + + was_training = self.training + self.eval() + + start_token = self.vocabulary.dictionary["SOS"] + stop_token = self.vocabulary.dictionary["EOS"] + pad_token = self.vocabulary.dictionary[""] + + inputs = ( + torch.empty(n_sequences).fill_(start_token).long().to(self.device) + ) + loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + finished = torch.zeros(n_sequences).byte().to(self.device) + log_probs = torch.zeros(n_sequences).to(self.device) + sequences = [] + + with torch.no_grad(): + for step in range(max_len): + if step == 0: + current_seq = inputs.unsqueeze(1) + else: + seq_list = [inputs.unsqueeze(1)] + sequences + current_seq = torch.cat(seq_list, dim=1) + + logits = self(current_seq)[:, -1, :] + logits = torch.clamp(logits, min=-1e4, max=1e4) + prob = F.softmax(logits, dim=-1) + + if torch.isnan(prob).any() or torch.isinf(prob).any(): + break + + outputs = torch.multinomial(prob, num_samples=1).squeeze(1) + sequences.append(outputs.view(-1, 
1)) + + log_prob = F.log_softmax(logits, dim=-1) + losses = loss_fn(log_prob, outputs) + losses[finished.bool()] = 0 + log_probs += losses + + finished = torch.ge(finished + (outputs == stop_token), 1) + if torch.prod(finished) == 1: + break + + seqs = ( + torch.cat(sequences, 1) + if sequences + else torch.empty(n_sequences, 1, dtype=torch.long) + .fill_(start_token) + .to(self.device) + ) + + if return_smiles: + smiles = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] + else: + smiles = sequences + + if was_training: + self.train() + + if return_losses: + return smiles, log_probs.detach().cpu().numpy() + else: + return smiles + + +class H3Model(nn.Module): + def __init__( + self, + vocabulary, + n_layers=4, + model_dim=256, + state_dim=64, + head_dim=1, + dropout=0.1, + max_len=250, + use_fast_fftconv=False, + # SSM kernel parameters + measure="diag-lin", + mode="diag", + lr=None, + **kernel_args, + ): + super(H3Model, self).__init__() + + # Device + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + # Vocabulary + self.vocabulary = vocabulary + self.vocabulary_size = len(self.vocabulary) + self.padding_idx = self.vocabulary.dictionary[""] + + # Hyperparameters + self.model_dim = model_dim + self.state_dim = state_dim + self.n_layers = n_layers + self.head_dim = head_dim + self.dropout = dropout + self.max_len = max_len + self.use_fast_fftconv = use_fast_fftconv and HAS_FFTCONV + self.measure = measure + self.mode = mode + + # Model components + self.embedding = nn.Embedding( + self.vocabulary_size, self.model_dim, padding_idx=self.padding_idx + ) + + # Stack of H3 layers using actual Safari implementation + self.h3_layers = nn.ModuleList( + [ + H3( + d_model=self.model_dim, + d_state=self.state_dim, + l_max=max_len, + head_dim=self.head_dim, + use_fast_fftconv=self.use_fast_fftconv, + dropout=self.dropout, + layer_idx=i, + mode=self.mode, # Use S4D variant + measure=self.measure, + lr=None if lr == 0.0 else lr, + 
**kernel_args, + ) + for i in range(n_layers) + ] + ) + + # Layer norm for each layer + self.layer_norms = nn.ModuleList( + [nn.LayerNorm(self.model_dim) for _ in range(n_layers)] + ) + + # Dropout + self.dropout_layer = nn.Dropout(dropout) + + # Final layer norm applied after all H3 layers, before output projection. + self.final_norm = nn.LayerNorm(self.model_dim) + + # Output projection + self.output_embedding = nn.Linear(self.model_dim, self.vocabulary_size) + + # Loss function + self.loss_fn = nn.CrossEntropyLoss( + ignore_index=self.padding_idx, reduction="none" + ) + + # Move to GPU + if torch.cuda.is_available(): + self.cuda() + + def forward(self, x): + """ + x: (batch_size, seq_len) + Returns: (batch_size, seq_len, vocab_size) + """ + batch_size, seq_len = x.size() + + # Embed + x = self.embedding(x) # (batch_size, seq_len, model_dim) + + # Apply H3 layers with residual connections + for h3_layer, layer_norm in zip(self.h3_layers, self.layer_norms): + residual = x + x = layer_norm(x) + x = h3_layer(x) # H3 expects (B, L, H) + x = self.dropout_layer(x) + x = x + residual + + # Final layer norm before output projection + x = self.final_norm(x) + + # Project to vocabulary + logits = self.output_embedding(x) # (batch_size, seq_len, vocab_size) + + return logits + + def loss(self, batch): + """Compute loss for a batch.""" + # Collate always returns (padded, lengths, descriptors); descriptor ignored here + padded, lengths, _ = batch + + padded = padded.to(self.device) + + # Collate always returns (seq_len, batch_size); transpose to (batch_size, seq_len) + padded = padded.transpose(0, 1) + + # Forward pass + logits = self(padded) # (batch_size, seq_len, vocab_size) + + # Calculate loss (predict next token) + targets = padded[:, 1:] # (batch_size, seq_len-1) + logits = logits[:, :-1, :] # (batch_size, seq_len-1, vocab_size) + + # Compute loss + loss = 0.0 + actual_len = min(logits.shape[1], targets.shape[1]) + + for char_idx in range(actual_len): + loss += 
self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + + return loss.mean() + + def sample( + self, + *, + n_sequences, + max_len=None, + return_smiles=True, + return_losses=False, + descriptors=None, + ): + """Sample sequences from the model.""" + if max_len is None: + max_len = self.max_len + + was_training = self.training + self.eval() + + # Get tokens + start_token = self.vocabulary.dictionary["SOS"] + stop_token = self.vocabulary.dictionary["EOS"] + pad_token = self.vocabulary.dictionary[""] + + # Initialize + inputs = ( + torch.empty(n_sequences).fill_(start_token).long().to(self.device) + ) + + # Loss function + loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + + # Sampling loop + finished = torch.zeros(n_sequences).byte().to(self.device) + log_probs = torch.zeros(n_sequences).to(self.device) + sequences = [] + + with torch.no_grad(): + for step in range(max_len): + # Get logits for all sequences so far + if step == 0: + current_seq = inputs.unsqueeze(1) # (n_sequences, 1) + else: + # Build full sequence so far + seq_list = [inputs.unsqueeze(1)] + sequences + current_seq = torch.cat( + seq_list, dim=1 + ) # (n_sequences, step+1) + + # Forward pass + logits = self(current_seq) # (n_sequences, step+1, vocab_size) + logits = logits[ + :, -1, : + ] # Get last position (n_sequences, vocab_size) + + # Clamp and sample + logits = torch.clamp(logits, min=-1e4, max=1e4) + prob = F.softmax(logits, dim=-1) + + if torch.isnan(prob).any() or torch.isinf(prob).any(): + break + + outputs = torch.multinomial(prob, num_samples=1).squeeze(1) + sequences.append(outputs.view(-1, 1)) + + # Calculate NLL + log_prob = F.log_softmax(logits, dim=-1) + losses = loss_fn(log_prob, outputs) + + # Zero losses if finished + losses[finished.bool()] = 0 + log_probs += losses + + # Check if finished + finished = torch.ge(finished + (outputs == stop_token), 1) + if torch.prod(finished) == 1: + break + + # Concatenate sequences and decode + seqs = ( + torch.cat(sequences, 
1) + if sequences + else torch.empty(n_sequences, 1, dtype=torch.long) + .fill_(start_token) + .to(self.device) + ) + if return_smiles: + smiles = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] + else: + smiles = sequences + + if was_training: + self.train() + + if return_losses: + return smiles, log_probs.detach().cpu().numpy() + else: + return smiles + + +class HyenaModel(nn.Module): + def __init__( + self, + vocabulary, + n_layers=4, + d_model=256, + order=2, + filter_order=64, + n_order_heads=1, + dropout=0.25, + max_len=250, + **hyena_args, + ): + super(HyenaModel, self).__init__() + + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self.vocabulary = vocabulary + self.vocabulary_size = len(vocabulary) + self.padding_idx = vocabulary.dictionary[""] + self.d_model = d_model + self.n_layers = n_layers + self.dropout = dropout + self.max_len = max_len + + self.embedding = nn.Embedding( + self.vocabulary_size, d_model, padding_idx=self.padding_idx + ) + + self.hyena_layers = nn.ModuleList( + [ + HyenaOperator( + d_model=d_model, + l_max=max_len, + order=order, + filter_order=filter_order, + num_heads=n_order_heads, + dropout=dropout, + **hyena_args, + ) + for _ in range(n_layers) + ] + ) + + self.layer_norms = nn.ModuleList( + [nn.LayerNorm(d_model) for _ in range(n_layers)] + ) + + self.dropout_layer = nn.Dropout(dropout) + # Final layer norm applied after all Hyena layers, before output projection. 
+ self.final_norm = nn.LayerNorm(d_model) + self.output_embedding = nn.Linear(d_model, self.vocabulary_size) + self.loss_fn = nn.CrossEntropyLoss( + ignore_index=self.padding_idx, reduction="none" + ) + + if torch.cuda.is_available(): + self.cuda() + + def forward(self, x): + x = self.embedding(x) + + for hyena_layer, layer_norm in zip(self.hyena_layers, self.layer_norms): + residual = x + x = layer_norm(x) + x = hyena_layer(x) + x = self.dropout_layer(x) + x = x + residual + + # Final layer norm before output projection + x = self.final_norm(x) + return self.output_embedding(x) + + def loss(self, batch): + # Collate always returns (padded, lengths, descriptors); descriptor ignored here + padded, lengths, _ = batch + + padded = padded.to(self.device) + # Collate always returns (seq_len, batch_size); transpose to (batch_size, seq_len) + padded = padded.transpose(0, 1) + + logits = self(padded) + targets = padded[:, 1:] + logits = logits[:, :-1, :] + + loss = 0.0 + actual_len = min(logits.shape[1], targets.shape[1]) + for char_idx in range(actual_len): + loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + + return loss.mean() + + def sample( + self, + *, + n_sequences, + max_len=None, + return_smiles=True, + return_losses=False, + descriptors=None, + ): + if max_len is None: + max_len = self.max_len + + was_training = self.training + self.eval() + + start_token = self.vocabulary.dictionary["SOS"] + stop_token = self.vocabulary.dictionary["EOS"] + pad_token = self.vocabulary.dictionary[""] + + inputs = ( + torch.empty(n_sequences).fill_(start_token).long().to(self.device) + ) + loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + + finished = torch.zeros(n_sequences).byte().to(self.device) + log_probs = torch.zeros(n_sequences).to(self.device) + sequences = [] -# class H3Model(nn.Module): -# def __init__( -# self, -# vocabulary, -# n_layers=4, -# d_model=256, -# d_state=64, -# head_dim=1, -# dropout=0.1, -# max_len=250, -# 
use_fast_fftconv=False, -# ): -# super(H3Model, self).__init__() - -# if H3 is None: -# raise ImportError( -# "H3 modules not found. Make sure src.models.sequence.h3 is available." -# ) - -# # detect device -# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -# # vocabulary -# self.vocabulary = vocabulary -# self.vocabulary_size = len(self.vocabulary) -# self.padding_idx = self.vocabulary.dictionary[""] -# padding_t = torch.tensor(self.padding_idx).to(self.device) - -# # hyperparams -# self.n_layers = n_layers -# self.d_model = d_model -# self.d_state = d_state -# self.head_dim = head_dim -# self.dropout = dropout -# self.max_len = max_len -# self.use_fast_fftconv = use_fast_fftconv - -# # model components -# self.embedding = nn.Embedding( -# self.vocabulary_size, self.d_model, padding_idx=padding_t -# ) - -# # H3 layers -# self.layers = nn.ModuleList([ -# H3( -# d_model=self.d_model, -# d_state=self.d_state, -# l_max=self.max_len, -# head_dim=self.head_dim, -# use_fast_fftconv=self.use_fast_fftconv, -# dropout=self.dropout, -# layer_idx=i, -# ) -# for i in range(self.n_layers) -# ]) - -# # dropout and output -# self.norm = nn.LayerNorm(self.d_model) -# self.dropout_layer = nn.Dropout(dropout) -# self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) - -# # loss function (ignoring padding) -# self.loss_fn = nn.CrossEntropyLoss( -# ignore_index=self.padding_idx, reduction="none" -# ) - -# # move to GPU -# if torch.cuda.is_available(): -# self.cuda() - -# def forward(self, x, inference_params=None): - -# batch_size, seq_len = x.size() - -# # Embed the input -# x = self.embedding(x) # (batch_size, seq_len, d_model) - -# # Pass through H3 layers -# for layer in self.layers: -# x = layer(x, inference_params=inference_params) -# if self.dropout > 0: -# x = self.dropout_layer(x) - -# # Normalize and project to vocabulary -# x = self.norm(x) -# logits = self.output_projection(x) # (batch_size, seq_len, vocab_size) - -# return 
logits - -# def loss(self, batch): -# if len(batch) == 3: -# padded, lengths, _ = batch -# else: -# padded, lengths = batch - -# padded = padded.to(self.device) - -# # Handle different input formats -# if padded.dim() == 2: -# if padded.shape[0] > padded.shape[1]: -# padded = padded.transpose(0, 1) - -# # Forward pass -# logits = self(padded) - -# # Calculate loss -# targets = padded[:, 1:] -# logits = logits[:, :-1, :] - -# loss = 0.0 -# actual_len = min(logits.shape[1], targets.shape[1]) - -# for char_idx in range(actual_len): -# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) - -# return loss.mean() - -# def sample( -# self, -# *, -# n_sequences, -# max_len=None, -# return_smiles=True, -# return_losses=False, -# descriptors=None, -# ): -# if max_len is None: -# max_len = self.max_len - -# self.eval() - -# # Get start/stop tokens -# start_token = self.vocabulary.dictionary["SOS"] -# stop_token = self.vocabulary.dictionary["EOS"] -# pad_token = self.vocabulary.dictionary[""] - -# # Create inference params -# class InferenceParams: -# def __init__(self, max_seqlen, batch_size): -# self.max_seqlen = max_seqlen -# self.max_batch_size = batch_size -# self.sequence_len_offset = 0 -# self.key_value_memory_dict = {} - -# inference_params = InferenceParams(max_len, n_sequences) - -# # Initialize with start tokens - keep only current token for recurrent stepping -# current_token = torch.full( -# (n_sequences, 1), start_token, dtype=torch.long, device=self.device -# ) - -# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) - -# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) -# log_probs = torch.zeros(n_sequences, device=self.device) -# sequences = [] - -# with torch.no_grad(): -# for step in range(max_len): -# # Process only the current token in recurrent mode -# logits = self(current_token, inference_params=inference_params) -# logits = logits[:, -1, :] # Get last (and only) position - -# logits = 
torch.clamp(logits, min=-1e4, max=1e4) -# prob = F.softmax(logits, dim=-1) - -# if torch.isnan(prob).any() or torch.isinf(prob).any(): -# break - -# outputs = torch.multinomial(prob, num_samples=1) -# sequences.append(outputs) - -# log_prob = F.log_softmax(logits, dim=-1) -# losses = loss_fn(log_prob, outputs.squeeze(1)) -# losses[finished] = 0 -# log_probs += losses - -# # Update current token for next step (don't accumulate) -# current_token = outputs -# inference_params.sequence_len_offset += 1 - -# finished = finished | (outputs.squeeze(1) == stop_token) -# if finished.all(): -# break - -# seqs = torch.cat(sequences, 1) if sequences else torch.full( -# (n_sequences, 1), start_token, dtype=torch.long, device=self.device -# ) - -# if return_smiles: -# outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] -# else: -# outputs = sequences - -# if return_losses: -# return outputs, log_probs.detach().cpu().numpy() -# else: -# return outputs - - -# class H3ConvModel(nn.Module): -# def __init__( -# self, -# vocabulary, -# n_layers=4, -# d_model=256, -# head_dim=1, -# dropout=0.1, -# max_len=250, -# use_fast_fftconv=False, -# ): -# super(H3ConvModel, self).__init__() - -# if H3Conv is None: -# raise ImportError( -# "H3Conv modules not found. Make sure src.models.sequence.h3_conv is available." 
-# ) - -# # detect device -# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -# # vocabulary -# self.vocabulary = vocabulary -# self.vocabulary_size = len(self.vocabulary) -# self.padding_idx = self.vocabulary.dictionary[""] -# padding_t = torch.tensor(self.padding_idx).to(self.device) - -# # hyperparams -# self.n_layers = n_layers -# self.d_model = d_model -# self.head_dim = head_dim -# self.dropout = dropout -# self.max_len = max_len -# self.use_fast_fftconv = use_fast_fftconv - -# # model components -# self.embedding = nn.Embedding( -# self.vocabulary_size, self.d_model, padding_idx=padding_t -# ) - -# # H3Conv layers -# self.layers = nn.ModuleList([ -# H3Conv( -# d_model=self.d_model, -# l_max=self.max_len, -# head_dim=self.head_dim, -# use_fast_fftconv=self.use_fast_fftconv, -# dropout=self.dropout, -# layer_idx=i, -# ) -# for i in range(self.n_layers) -# ]) - -# # dropout and output -# self.norm = nn.LayerNorm(self.d_model) -# self.dropout_layer = nn.Dropout(dropout) -# self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) - -# # loss function (ignoring padding) -# self.loss_fn = nn.CrossEntropyLoss( -# ignore_index=self.padding_idx, reduction="none" -# ) - -# # move to GPU -# if torch.cuda.is_available(): -# self.cuda() - -# def forward(self, x, inference_params=None): -# batch_size, seq_len = x.size() - -# # Embed the input -# x = self.embedding(x) # (batch_size, seq_len, d_model) - -# # Pass through H3Conv layers -# for layer in self.layers: -# x = layer(x, inference_params=inference_params) -# if self.dropout > 0: -# x = self.dropout_layer(x) - -# # Normalize and project to vocabulary -# x = self.norm(x) -# logits = self.output_projection(x) - -# return logits - -# def loss(self, batch): -# if len(batch) == 3: -# padded, lengths, _ = batch -# else: -# padded, lengths = batch - -# padded = padded.to(self.device) - -# # Handle different input formats -# if padded.dim() == 2: -# if padded.shape[0] > 
padded.shape[1]: -# padded = padded.transpose(0, 1) - -# # Forward pass -# logits = self(padded) - -# # Calculate loss -# targets = padded[:, 1:] -# logits = logits[:, :-1, :] - -# loss = 0.0 -# actual_len = min(logits.shape[1], targets.shape[1]) - -# for char_idx in range(actual_len): -# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) - -# return loss.mean() - -# def sample( -# self, -# *, -# n_sequences, -# max_len=None, -# return_smiles=True, -# return_losses=False, -# descriptors=None, -# ): -# if max_len is None: -# max_len = self.max_len - -# self.eval() - -# start_token = self.vocabulary.dictionary["SOS"] -# stop_token = self.vocabulary.dictionary["EOS"] -# pad_token = self.vocabulary.dictionary[""] - -# # H3Conv doesn't use stateful inference, process full sequence each time -# inputs = torch.full( -# (n_sequences, 1), start_token, dtype=torch.long, device=self.device -# ) - -# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) - -# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) -# log_probs = torch.zeros(n_sequences, device=self.device) -# sequences = [] - -# with torch.no_grad(): -# for step in range(max_len): -# logits = self(inputs) -# logits = logits[:, -1, :] - -# logits = torch.clamp(logits, min=-1e4, max=1e4) -# prob = F.softmax(logits, dim=-1) - -# if torch.isnan(prob).any() or torch.isinf(prob).any(): -# break - -# outputs = torch.multinomial(prob, num_samples=1) -# sequences.append(outputs) - -# log_prob = F.log_softmax(logits, dim=-1) -# losses = loss_fn(log_prob, outputs.squeeze(1)) -# losses[finished] = 0 -# log_probs += losses - -# inputs = torch.cat([inputs, outputs], dim=1) - -# finished = finished | (outputs.squeeze(1) == stop_token) -# if finished.all(): -# break - -# seqs = torch.cat(sequences, 1) if sequences else torch.full( -# (n_sequences, 1), start_token, dtype=torch.long, device=self.device -# ) - -# if return_smiles: -# outputs = [self.vocabulary.decode(seq.cpu().numpy()) 
for seq in seqs] -# else: -# outputs = sequences - -# if return_losses: -# return outputs, log_probs.detach().cpu().numpy() -# else: -# return outputs - - -# class HyenaModel(nn.Module): -# def __init__( -# self, -# vocabulary, -# n_layers=4, -# d_model=256, -# order=2, -# filter_order=64, -# num_heads=1, -# dropout=0.1, -# max_len=250, -# inner_factor=1, -# ): -# super(HyenaModel, self).__init__() - -# # detect device -# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -# # vocabulary -# self.vocabulary = vocabulary -# self.vocabulary_size = len(self.vocabulary) -# self.padding_idx = self.vocabulary.dictionary[""] -# padding_t = torch.tensor(self.padding_idx).to(self.device) - -# # hyperparams -# self.n_layers = n_layers -# self.d_model = d_model -# self.order = order -# self.filter_order = filter_order -# self.num_heads = num_heads -# self.dropout = dropout -# self.max_len = max_len -# self.inner_factor = inner_factor - -# # model components -# self.embedding = nn.Embedding( -# self.vocabulary_size, self.d_model, padding_idx=padding_t -# ) - -# # Hyena layers -# self.layers = nn.ModuleList([ -# HyenaOperator( -# d_model=self.d_model, -# l_max=self.max_len, -# order=self.order, -# filter_order=self.filter_order, -# num_heads=self.num_heads, -# inner_factor=self.inner_factor, -# dropout=self.dropout, -# ) -# for i in range(self.n_layers) -# ]) - -# # dropout and output -# self.norm = nn.LayerNorm(self.d_model) -# self.dropout_layer = nn.Dropout(dropout) -# self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) - -# # loss function (ignoring padding) -# self.loss_fn = nn.CrossEntropyLoss( -# ignore_index=self.padding_idx, reduction="none" -# ) - -# # move to GPU -# if torch.cuda.is_available(): -# self.cuda() - -# def forward(self, x): -# batch_size, seq_len = x.size() - -# # Embed the input -# x = self.embedding(x) # (batch_size, seq_len, d_model) - -# # Pass through Hyena layers -# for layer in self.layers: -# residual = 
x -# x = layer(x) -# x = x + residual # Residual connection -# if self.dropout > 0: -# x = self.dropout_layer(x) - -# # Normalize and project to vocabulary -# x = self.norm(x) -# logits = self.output_projection(x) - -# return logits - -# def loss(self, batch): -# if len(batch) == 3: -# padded, lengths, _ = batch -# else: -# padded, lengths = batch - -# padded = padded.to(self.device) - -# # Handle different input formats -# if padded.dim() == 2: -# if padded.shape[0] > padded.shape[1]: -# padded = padded.transpose(0, 1) - -# # Forward pass -# logits = self(padded) - -# # Calculate loss -# targets = padded[:, 1:] -# logits = logits[:, :-1, :] - -# loss = 0.0 -# actual_len = min(logits.shape[1], targets.shape[1]) - -# for char_idx in range(actual_len): -# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) - -# return loss.mean() - -# def sample( -# self, -# *, -# n_sequences, -# max_len=None, -# return_smiles=True, -# return_losses=False, -# descriptors=None, -# ): -# if max_len is None: -# max_len = self.max_len - -# self.eval() - -# start_token = self.vocabulary.dictionary["SOS"] -# stop_token = self.vocabulary.dictionary["EOS"] -# pad_token = self.vocabulary.dictionary[""] - -# # Initialize with start tokens -# inputs = torch.full( -# (n_sequences, 1), start_token, dtype=torch.long, device=self.device -# ) - -# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) - -# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) -# log_probs = torch.zeros(n_sequences, device=self.device) -# sequences = [] - -# with torch.no_grad(): -# for step in range(max_len): -# # Hyena processes full sequence each time (stateless) -# logits = self(inputs) -# logits = logits[:, -1, :] + with torch.no_grad(): + for step in range(max_len): + if step == 0: + current_seq = inputs.unsqueeze(1) + else: + seq_list = [inputs.unsqueeze(1)] + sequences + current_seq = torch.cat(seq_list, dim=1) -# logits = torch.clamp(logits, min=-1e4, max=1e4) -# 
prob = F.softmax(logits, dim=-1) + logits = self(current_seq) + logits = logits[:, -1, :] -# if torch.isnan(prob).any() or torch.isinf(prob).any(): -# break + logits = torch.clamp(logits, min=-1e4, max=1e4) + prob = F.softmax(logits, dim=-1) -# outputs = torch.multinomial(prob, num_samples=1) -# sequences.append(outputs) + if torch.isnan(prob).any() or torch.isinf(prob).any(): + break -# log_prob = F.log_softmax(logits, dim=-1) -# losses = loss_fn(log_prob, outputs.squeeze(1)) -# losses[finished] = 0 -# log_probs += losses + outputs = torch.multinomial(prob, num_samples=1).squeeze(1) + sequences.append(outputs.view(-1, 1)) -# inputs = torch.cat([inputs, outputs], dim=1) + log_prob = F.log_softmax(logits, dim=-1) + losses = loss_fn(log_prob, outputs) + losses[finished.bool()] = 0 + log_probs += losses -# finished = finished | (outputs.squeeze(1) == stop_token) -# if finished.all(): -# break + finished = torch.ge(finished + (outputs == stop_token), 1) + if torch.prod(finished) == 1: + break -# seqs = torch.cat(sequences, 1) if sequences else torch.full( -# (n_sequences, 1), start_token, dtype=torch.long, device=self.device -# ) + # Concatenate sequences and decode + seqs = ( + torch.cat(sequences, 1) + if sequences + else torch.empty(n_sequences, 1, dtype=torch.long) + .fill_(start_token) + .to(self.device) + ) + if return_smiles: + outputs = [ + self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs + ] + else: + outputs = sequences -# if return_smiles: -# outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] -# else: -# outputs = sequences + if was_training: + self.train() -# if return_losses: -# return outputs, log_probs.detach().cpu().numpy() -# else: -# return outputs + if return_losses: + return outputs, log_probs.detach().cpu().numpy() + else: + return outputs class StructuredStateSpaceSequenceModel(nn.Module): @@ -584,7 +1180,7 @@ def __init__( vocabulary, model_dim=256, state_dim=64, - n_layers=4, + n_blocks=2, n_ssm=1, dropout=0.25, 
max_len=250, @@ -600,12 +1196,11 @@ def __init__( self.vocabulary = vocabulary self.vocabulary_size = len(self.vocabulary) self.padding_idx = self.vocabulary.dictionary[""] - padding_t = torch.tensor(self.padding_idx).to(self.device) # hyperparams self.model_dim = model_dim self.state_dim = state_dim - self.n_layers = n_layers + self.n_blocks = n_blocks self.n_ssm = n_ssm self.dropout = dropout self.max_len = max_len @@ -628,12 +1223,12 @@ def __init__( # model components self.embedding = nn.Embedding( - self.vocabulary_size, self.model_dim, padding_idx=padding_t + self.vocabulary_size, self.model_dim, padding_idx=self.padding_idx ) self.model = SequenceModel( d_model=self.model_dim, - n_layers=self.n_layers, + n_layers=self.n_blocks, transposed=False, # Changed to False - expect (batch, length, dim) dropout=self.dropout, layer=self.layer_config, @@ -685,19 +1280,13 @@ def recurrent_step(self, x_t): return x_t def loss(self, batch): - if len(batch) == 3: - padded, lengths, _ = batch - else: - padded, lengths = batch + # Collate always returns (padded, lengths, descriptors); descriptor ignored here + padded, lengths, _ = batch padded = padded.to(self.device) - # Handle different input formats - # RNN collate returns (seq_len, batch_size) format - # S4 model expects (batch_size, seq_len) format - # Always transpose since collate always uses (seq_len, batch_size) - if padded.dim() == 2: - padded = padded.transpose(0, 1) + # Collate always returns (seq_len, batch_size); transpose to (batch_size, seq_len) + padded = padded.transpose(0, 1) # batch_size = padded.shape[0] # seq_len = padded.shape[1] @@ -734,6 +1323,7 @@ def sample( if max_len is None: max_len = self.max_len + was_training = self.training # IMPORTANT: Set model to eval mode before sampling self.eval() @@ -814,6 +1404,9 @@ def sample( else: outputs = sequences + if was_training: + self.train() + # Optionally return losses if return_losses: return outputs, log_probs.detach().cpu().numpy() @@ -844,10 +1437,11 
@@ def __init__( # embedding layer self.padding_idx = self.vocabulary.dictionary[""] - padding_t = torch.tensor(self.padding_idx).to(self.device) self.embedding_size = embedding_size self.embedding = nn.Embedding( - self.vocabulary_size, self.embedding_size, padding_idx=padding_t + self.vocabulary_size, + self.embedding_size, + padding_idx=self.padding_idx, ) # RNN architecture @@ -922,6 +1516,9 @@ def sample( return_losses=False, descriptors=None, ): + was_training = self.training + self.eval() + # get start/stop tokens start_token = self.vocabulary.dictionary["SOS"] stop_token = self.vocabulary.dictionary["EOS"] @@ -956,25 +1553,26 @@ def sample( finished = torch.zeros(n_sequences).byte().to(self.device) log_probs = torch.zeros(n_sequences).to(self.device) sequences = [] - for step in range(max_len): - embedded = self.embedding(inputs) - output, hidden = self.rnn(embedded, hidden) - logits = self.decoder(output) - prob = F.softmax(logits, dim=2) - inputs = torch.multinomial(prob.squeeze(0), num_samples=1).view( - 1, -1 - ) - sequences.append(inputs.view(-1, 1)) - # calculate NLL too - log_prob = F.log_softmax(logits.squeeze(0), dim=1) - losses = loss(log_prob, inputs.squeeze(0)) - # zero losses if we are finished sampling - losses[finished.squeeze(0).bool()] = 0 - log_probs += losses - # track whether sampling is done for all molecules - finished = torch.ge(finished + (inputs == stop_token), 1) - if torch.prod(finished) == 1: - break + with torch.no_grad(): + for step in range(max_len): + embedded = self.embedding(inputs) + output, hidden = self.rnn(embedded, hidden) + logits = self.decoder(output) + prob = F.softmax(logits, dim=2) + inputs = torch.multinomial(prob.squeeze(0), num_samples=1).view( + 1, -1 + ) + sequences.append(inputs.view(-1, 1)) + # calculate NLL too + log_prob = F.log_softmax(logits.squeeze(0), dim=1) + losses = loss(log_prob, inputs.squeeze(0)) + # zero losses if we are finished sampling + losses[finished.squeeze(0).bool()] = 0 + log_probs += 
losses + # track whether sampling is done for all molecules + finished = torch.ge(finished + (inputs == stop_token), 1) + if torch.prod(finished) == 1: + break # concatenate sequences and decode seqs = torch.cat(sequences, 1) @@ -985,6 +1583,9 @@ def sample( else: outputs = sequences + if was_training: + self.train() + # optionally return losses if return_losses: return outputs, log_probs.detach().cpu().numpy() @@ -1172,7 +1773,6 @@ def __init__( self.vocabulary = vocabulary self.vocabulary_size = len(self.vocabulary) self.padding_idx = self.vocabulary.dictionary[""] - padding_t = torch.tensor(self.padding_idx).to(self.device) # hyperparams self.n_blocks = n_blocks @@ -1188,7 +1788,7 @@ def __init__( wte=nn.Embedding( self.vocabulary_size, self.embedding_size, - padding_idx=padding_t, + padding_idx=self.padding_idx, ), wpe=nn.Embedding(self.max_len, self.embedding_size), drop=nn.Dropout(self.dropout), @@ -1255,17 +1855,13 @@ def forward(self, x): return logits def loss(self, batch): - if len(batch) == 3: - padded, lengths, _ = batch - else: - padded, lengths = batch + # Collate always returns (padded, lengths, descriptors); descriptor ignored here + padded, lengths, _ = batch padded = padded.to(self.device) - # RNN collate returns (seq_len, batch_size) format - # Transformer expects (batch_size, seq_len) format - if padded.dim() == 2: - padded = padded.transpose(0, 1) + # Collate always returns (seq_len, batch_size); transpose to (batch_size, seq_len) + padded = padded.transpose(0, 1) # Get actual sequence length from batch actual_seq_len = padded.shape[1] @@ -1294,6 +1890,7 @@ def sample( # Reset recurrent state before sampling # self.reset_state(n_sequences, device=self.device) + was_training = self.training self.eval() torch.cuda.empty_cache() @@ -1346,7 +1943,7 @@ def sample( if torch.prod(finished) == 1: break - # concatenate sequences and decode + # Concatenate sequences and decode seqs = ( torch.cat(sequences, 1) if sequences @@ -1363,6 +1960,9 @@ def 
sample( torch.cuda.empty_cache() + if was_training: + self.train() + # optionally return losses if return_losses: return outputs, log_probs.detach().cpu().numpy() @@ -1454,9 +2054,10 @@ def __init__( ) # Add num_descriptors to self.hidden_size self.padding_idx = self.vocabulary.dictionary[""] - padding_t = torch.tensor(self.padding_idx).to(self.device) self.embedding = nn.Embedding( - self.vocabulary_size, self.embedding_size, padding_idx=padding_t + self.vocabulary_size, + self.embedding_size, + padding_idx=self.padding_idx, ) self.n_layers = n_layers self.rnn_type = rnn_type @@ -1596,6 +2197,9 @@ def sample( n_sequences is None or len(descriptors) == n_sequences ), "When providing descriptor values, either omit n_sequences or make them conform to the number of descriptors" + was_training = self.training + self.eval() + # get start/stop tokens start_token = self.vocabulary.dictionary["SOS"] stop_token = self.vocabulary.dictionary["EOS"] @@ -1633,42 +2237,47 @@ def sample( finished = torch.zeros(n_sequences).byte().to(self.device) log_probs = torch.zeros(n_sequences).to(self.device) sequences = [] - for step in range(max_len): - embedded = self.embedding(inputs) - if self.conditional_emb_l: - combined_embedding = self.conditional_to_emb( - descriptors.float() - ) - embedded = torch.cat( - [embedded, combined_embedding], axis=2 - ).float() - elif self.conditional_emb: - embedded = torch.cat([embedded, descriptors], axis=2).float() - - output, hidden = self.rnn(embedded, hidden) - if self.conditional_dec_l: - combined_embedding = self.conditional_to_dec( - descriptors.float() + with torch.no_grad(): + for step in range(max_len): + embedded = self.embedding(inputs) + if self.conditional_emb_l: + combined_embedding = self.conditional_to_emb( + descriptors.float() + ) + embedded = torch.cat( + [embedded, combined_embedding], axis=2 + ).float() + elif self.conditional_emb: + embedded = torch.cat( + [embedded, descriptors], axis=2 + ).float() + + output, hidden = 
self.rnn(embedded, hidden) + if self.conditional_dec_l: + combined_embedding = self.conditional_to_dec( + descriptors.float() + ) + output = torch.cat( + [output, combined_embedding], axis=2 + ).float() + elif self.conditional_dec: + output = torch.cat([output, descriptors], axis=2).float() + + logits = self.decoder(output) + prob = F.softmax(logits, dim=2) + inputs = torch.multinomial(prob.squeeze(0), num_samples=1).view( + 1, -1 ) - output = torch.cat([output, combined_embedding], axis=2).float() - elif self.conditional_dec: - output = torch.cat([output, descriptors], axis=2).float() - - logits = self.decoder(output) - prob = F.softmax(logits, dim=2) - inputs = torch.multinomial(prob.squeeze(0), num_samples=1).view( - 1, -1 - ) - sequences.append(inputs.view(-1, 1)) - log_prob = F.log_softmax(logits.squeeze(0), dim=1) - losses = loss(log_prob, inputs.squeeze(0)) - # zero losses if we are finished sampling - losses[finished.squeeze(0).bool()] = 0 - log_probs += losses - # track whether sampling is done for all molecules - finished = torch.ge(finished + (inputs == stop_token), 1) - if torch.prod(finished) == 1: - break + sequences.append(inputs.view(-1, 1)) + log_prob = F.log_softmax(logits.squeeze(0), dim=1) + losses = loss(log_prob, inputs.squeeze(0)) + # zero losses if we are finished sampling + losses[finished.squeeze(0).bool()] = 0 + log_probs += losses + # track whether sampling is done for all molecules + finished = torch.ge(finished + (inputs == stop_token), 1) + if torch.prod(finished) == 1: + break # concatenate sequences and decode seqs = torch.cat(sequences, 1) @@ -1677,6 +2286,9 @@ def sample( else: smiles = sequences + if was_training: + self.train() + if return_losses: return smiles, log_probs.detach().cpu().numpy() else: diff --git a/tests/test_snakemake_steps.py b/tests/test_snakemake_steps.py index c8170dd1..eb952df8 100644 --- a/tests/test_snakemake_steps.py +++ b/tests/test_snakemake_steps.py @@ -195,12 +195,23 @@ def 
test_02_train_models_RNN(tmp_path): embedding_size=32, hidden_size=256, n_layers=3, + n_blocks=1, state_dim=64, n_ssm=1, n_heads=4, + head_dim=1, + n_order_heads=1, exp_factor=4, + bias=True, + use_fast_fftconv=False, + order=2, + filter_order=64, + measure="diag-lin", + mode="diag", + lr=0.0, dropout=0, batch_size=64, + max_len=250, learning_rate=0.001, max_epochs=3, patience=5000, @@ -226,13 +237,24 @@ def test_02_train_models_conditional_RNN(tmp_path): rnn_type="LSTM", embedding_size=32, hidden_size=256, - n_layers=3, + n_layers=2, + n_blocks=1, state_dim=64, n_ssm=1, n_heads=4, + head_dim=1, + n_order_heads=1, exp_factor=4, + bias=True, + use_fast_fftconv=False, + order=2, + filter_order=64, + measure="diag-lin", + mode="diag", + lr=0.0, dropout=0, batch_size=64, + max_len=250, learning_rate=0.001, max_epochs=3, patience=5000, @@ -259,13 +281,24 @@ def test_02_train_models_S4(tmp_path): rnn_type="S4", embedding_size=32, hidden_size=256, - n_layers=3, + n_layers=2, + n_blocks=1, state_dim=64, n_ssm=1, n_heads=4, + head_dim=1, + n_order_heads=1, exp_factor=4, + bias=True, + use_fast_fftconv=False, + order=2, + filter_order=64, + measure="diag-lin", + mode="diag", + lr=0.0, dropout=0, batch_size=64, + max_len=250, learning_rate=0.001, max_epochs=3, patience=5000, @@ -284,6 +317,92 @@ def test_02_train_models_S4(tmp_path): # so we simply ensure that this step runs without errors. 
+def test_02_train_models_H3(tmp_path): + train_models_RNN.train_models_RNN( + representation="SMILES", + model_type="H3", + rnn_type="H3", + embedding_size=32, + hidden_size=256, + n_layers=2, + n_blocks=1, + state_dim=64, + n_ssm=2, + n_heads=2, + head_dim=1, + n_order_heads=1, + exp_factor=4, + bias=True, + use_fast_fftconv=False, + order=2, + filter_order=64, + measure="diag-lin", + mode="diag", + lr=0.0, + dropout=0, + batch_size=64, + max_len=250, + learning_rate=0.001, + max_epochs=3, + patience=5000, + log_every_steps=100, + log_every_epochs=1, + sample_mols=100, + input_file=test_dir + / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.smi", + vocab_file=test_dir + / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.vocabulary", + model_file=tmp_path / "LOTUS_truncated_SMILES_0_0_model_H3.pt", + loss_file=tmp_path / "LOTUS_truncated_SMILES_0_0_loss.csv", + smiles_file=None, + ) + # Model loss values can vary between platforms and architectures, + # so we simply ensure that this step runs without errors. 
+ + +def test_02_train_models_Hyena(tmp_path): + train_models_RNN.train_models_RNN( + representation="SMILES", + model_type="Hyena", + rnn_type="Hyena", + embedding_size=32, + hidden_size=256, + n_layers=2, + n_blocks=1, + state_dim=64, + n_ssm=2, + n_heads=2, + head_dim=1, + n_order_heads=1, + exp_factor=4, + bias=True, + use_fast_fftconv=False, + order=2, + filter_order=64, + measure="diag-lin", + mode="diag", + lr=0.0, + dropout=0, + batch_size=64, + max_len=250, + learning_rate=0.001, + max_epochs=3, + patience=5000, + log_every_steps=100, + log_every_epochs=1, + sample_mols=100, + input_file=test_dir + / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.smi", + vocab_file=test_dir + / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.vocabulary", + model_file=tmp_path / "LOTUS_truncated_SMILES_0_0_model_Hyena.pt", + loss_file=tmp_path / "LOTUS_truncated_SMILES_0_0_loss.csv", + smiles_file=None, + ) + # Model loss values can vary between platforms and architectures, + # so we simply ensure that this step runs without errors. 
+ + def test_02_train_models_Transformer(tmp_path): train_models_RNN.train_models_RNN( representation="SMILES", @@ -292,12 +411,23 @@ def test_02_train_models_Transformer(tmp_path): embedding_size=32, hidden_size=256, n_layers=3, + n_blocks=1, state_dim=64, n_ssm=1, n_heads=4, + head_dim=1, + n_order_heads=1, exp_factor=4, + bias=True, + use_fast_fftconv=False, + order=2, + filter_order=64, + measure="diag-lin", + mode="diag", + lr=0.0, dropout=0, batch_size=64, + max_len=250, learning_rate=0.001, max_epochs=3, patience=5000, @@ -327,12 +457,23 @@ def test_03_sample_molecules_RNN(tmp_path): embedding_size=32, hidden_size=256, n_layers=3, + n_blocks=1, state_dim=64, n_ssm=1, n_heads=4, + head_dim=1, + n_order_heads=1, exp_factor=4, + bias=True, + use_fast_fftconv=False, + order=2, + filter_order=64, + measure="diag-lin", + mode="diag", + lr=0.0, dropout=0, batch_size=64, + max_len=250, sample_mols=100, vocab_file=test_dir / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.vocabulary", @@ -357,12 +498,23 @@ def test_03_sample_molecules_conditional_RNN(tmp_path): embedding_size=32, hidden_size=256, n_layers=3, + n_blocks=1, state_dim=64, n_ssm=1, n_heads=4, + head_dim=1, + n_order_heads=1, exp_factor=4, + bias=True, + use_fast_fftconv=False, + order=2, + filter_order=64, + measure="diag-lin", + mode="diag", + lr=0.0, dropout=0, batch_size=64, + max_len=250, sample_mols=100, vocab_file=test_dir / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.vocabulary", diff --git a/workflow/Snakefile_data b/workflow/Snakefile_data index e8083d02..6e6adad4 100644 --- a/workflow/Snakefile_data +++ b/workflow/Snakefile_data @@ -175,12 +175,23 @@ rule train_models_RNN: '--embedding_size {MODEL_PARAMS[embedding_size]} ' '--hidden_size {MODEL_PARAMS[hidden_size]} ' '--n_layers {MODEL_PARAMS[n_layers]} ' + '--n_blocks {MODEL_PARAMS[n_blocks]} ' '--state_dim {MODEL_PARAMS[state_dim]} ' '--n_ssm {MODEL_PARAMS[n_ssm]} ' '--n_heads {MODEL_PARAMS[n_heads]} ' + '--head_dim {MODEL_PARAMS[head_dim]} 
' + '--n_order_heads {MODEL_PARAMS[n_order_heads]} ' '--exp_factor {MODEL_PARAMS[exp_factor]} ' + f'{"--bias" if MODEL_PARAMS["bias"] else ""} ' + f'{"--use_fast_fftconv" if MODEL_PARAMS["use_fast_fftconv"] else ""} ' + '--measure {MODEL_PARAMS[measure]} ' + '--mode {MODEL_PARAMS[mode]} ' + '--lr {MODEL_PARAMS[lr]} ' + '--order {MODEL_PARAMS[order]} ' + '--filter_order {MODEL_PARAMS[filter_order]} ' '--dropout {MODEL_PARAMS[dropout]} ' '--batch_size {MODEL_PARAMS[batch_size]} ' + '--max_len {MODEL_PARAMS[max_len]} ' '--learning_rate {MODEL_PARAMS[learning_rate]} ' '--max_epochs {MODEL_PARAMS[max_epochs]} ' '--patience {MODEL_PARAMS[patience]} ' @@ -211,7 +222,7 @@ rule sample_molecules_RNN: output_file = PATHS['input_file'] resources: mem_mb=12000, - runtime=120+MODEL_PARAMS["sample_mols"]//10000, + runtime=240+MODEL_PARAMS["sample_mols"]//10000, slurm_extra="--gres=gpu:1" shell: 'clm sample_molecules_RNN ' @@ -222,12 +233,23 @@ rule sample_molecules_RNN: '--embedding_size {MODEL_PARAMS[embedding_size]} ' '--hidden_size {MODEL_PARAMS[hidden_size]} ' '--n_layers {MODEL_PARAMS[n_layers]} ' + '--n_blocks {MODEL_PARAMS[n_blocks]} ' '--state_dim {MODEL_PARAMS[state_dim]} ' '--n_ssm {MODEL_PARAMS[n_ssm]} ' '--n_heads {MODEL_PARAMS[n_heads]} ' + '--head_dim {MODEL_PARAMS[head_dim]} ' + '--n_order_heads {MODEL_PARAMS[n_order_heads]} ' '--exp_factor {MODEL_PARAMS[exp_factor]} ' + f'{"--bias" if MODEL_PARAMS["bias"] else ""} ' + f'{"--use_fast_fftconv" if MODEL_PARAMS["use_fast_fftconv"] else ""} ' + '--measure {MODEL_PARAMS[measure]} ' + '--mode {MODEL_PARAMS[mode]} ' + '--lr {MODEL_PARAMS[lr]} ' + '--order {MODEL_PARAMS[order]} ' + '--filter_order {MODEL_PARAMS[filter_order]} ' '--dropout {MODEL_PARAMS[dropout]} ' '--batch_size {MODEL_PARAMS[batch_size]} ' + '--max_len {MODEL_PARAMS[max_len]} ' '--sample_mols {MODEL_PARAMS[sample_mols]} ' '--vocab_file {input.vocab_file} ' '--model_file {input.model_file} ' diff --git a/workflow/config/config.yaml 
b/workflow/config/config.yaml index 05be6577..75e15ea4 100644 --- a/workflow/config/config.yaml +++ b/workflow/config/config.yaml @@ -62,17 +62,28 @@ preprocess: # Parameters that define the neural network model and training process. model_params: - model_type: RNN # Type of model to be used. Available options are 'RNN', 'Transformer', 'S4'. + model_type: RNN # Type of model to be used. Available options are 'RNN', 'Transformer', 'S4', 'H3', 'Hyena'. rnn_type: LSTM # Applicable to model_type 'RNN' only. Available options are 'LSTM' and 'GRU' embedding_size: 128 # Size of the embedding vectors that represent each token in the input sequence. hidden_size: 1024 # Size of the hidden state of the model. n_layers: 3 # Number of stacked RNN layers in the model. - state_dim: 64 # State dimension for S4 model. Applicable if model_type is 'S4'. + n_blocks: 2 # Number of blocks for S4 model. Applicable if model_type is 'S4'. + state_dim: 64 # State dimension for S4 model. Applicable if model_type is 'S4' or 'H3'. n_ssm: 1 # Number of SSM layers for S4 model. Applicable if model_type is 'S4'. n_heads: 4 # Number of attention heads for the model. Applicable if model_type is 'Transformer'. + head_dim: 1 # Dimension of head for the model. Applicable if model_type is 'H3'. + n_order_heads: 1 # Number of groups of input channels for filter computation for the Hyena model. Applicable if model_type is 'Hyena'. exp_factor: 4 # Expansion factor for Transformer model. Applicable if model_type is 'Transformer'. + bias: true # Whether to include bias terms in the model's layers. Applicable if model_type is 'Transformer'. + use_fast_fftconv: false # Whether to use fast FFT convolution for H3 model. Applicable if model_type is 'H3'. + order: 2 # Order of the convolution for Hyena model. Applicable if model_type is 'Hyena'. + filter_order: 64 # Filter order for Hyena model. Applicable if model_type is 'Hyena'. + measure: "diag-lin" # Measure for H3 layers. Applicable if model_type is 'H3'. 
+ mode: "diag" # Mode for H3 layers. Applicable if model_type is 'H3'. + lr: 0.0 # Learning rate for the H3 layers. Applicable if model_type is 'H3'. dropout: 0 # Dropout rate applied to the RNN layer for regularization. batch_size: 64 # Number of samples processed before the models internal parameters are updated. + max_len: 250 # Maximum length of input sequences. Sequences longer than this will be truncated, and shorter sequences will be padded. learning_rate: 0.001 # Used by the optimizer to update model parameters. max_epochs: 999999 # Maximum number of training epochs (complete passes through the training dataset). patience: 50000 # Number of steps with no improvement in the validation loss after which early stopping is triggered. diff --git a/workflow/config/config_fast.yaml b/workflow/config/config_fast.yaml index 3502b043..dcba4948 100644 --- a/workflow/config/config_fast.yaml +++ b/workflow/config/config_fast.yaml @@ -51,17 +51,28 @@ preprocess: # Parameters that define the neural network model and training process. model_params: - model_type: RNN # Type of model to be used. Available options are 'RNN', 'Transformer', 'S4'. + model_type: RNN # Type of model to be used. Available options are 'RNN', 'Transformer', 'S4', 'H3', 'Hyena'. rnn_type: LSTM # Applicable to model_type 'RNN' only. Available options are 'LSTM' and 'GRU' embedding_size: 32 # Size of the embedding vectors that represent each token in the input sequence. hidden_size: 256 # Size of the hidden state of the model. n_layers: 3 # Number of stacked RNN layers in the model. - state_dim: 64 # State dimension for S4 model. Applicable if model_type is 'S4'. + n_blocks: 2 # Number of blocks for S4 model. Applicable if model_type is 'S4'. + state_dim: 64 # State dimension for S4 model. Applicable if model_type is 'S4' or 'H3'. n_ssm: 1 # Number of SSM layers for S4 model. Applicable if model_type is 'S4'. n_heads: 4 # Number of attention heads for the model. Applicable if model_type is 'Transformer'. 
+ head_dim: 1 # Dimension of head for the model. Applicable if model_type is 'H3'. + n_order_heads: 1 # Number of groups of input channels for filter computation for the Hyena model. Applicable if model_type is 'Hyena'. exp_factor: 4 # Expansion factor for Transformer model. Applicable if model_type is 'Transformer'. + bias: true # Whether to include bias terms in the model's layers. Applicable if model_type is 'Transformer'. + use_fast_fftconv: false # Whether to use fast FFT convolution for H3 model. Applicable if model_type is 'H3'. + order: 2 # Order of the convolution for Hyena model. Applicable if model_type is 'Hyena'. + filter_order: 64 # Filter order for Hyena model. Applicable if model_type is 'Hyena'. + measure: "diag-lin" # Measure for H3 layers. Applicable if model_type is 'H3'. + mode: "diag" # Mode for H3 layers. Applicable if model_type is 'H3'. + lr: 0.0 # Learning rate for the H3 layers. Applicable if model_type is 'H3'. dropout: 0 # Dropout rate applied to the RNN layer for regularization. batch_size: 64 # Number of samples processed before the models internal parameters are updated. + max_len: 250 # Maximum length of input sequences. Sequences longer than this will be truncated, and shorter sequences will be padded. learning_rate: 0.001 # Used by the optimizer to update model parameters. max_epochs: 3 # Maximum number of training epochs (complete passes through the training dataset). patience: 50000 # Number of steps with no improvement in the validation loss after which early stopping is triggered.