-
Notifications
You must be signed in to change notification settings - Fork 10
add H3, H3Conv and Hyena as model architecture #289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
c19a72a
5205237
fbd52a3
f5f355b
4a3bb02
d7db753
89943ac
9bc177c
cb14b50
5ce37bd
e4eb00d
ea3be5a
d5c003b
1d437cc
7160355
85e30bd
c137e9f
584bde2
a930573
a58fdf2
81f3204
2752d83
28696ca
80c1962
1c74176
7a18b8e
aaa6125
740c08d
d7120ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,12 @@ | |
| ConditionalRNN, | ||
| Transformer, | ||
| StructuredStateSpaceSequenceModel, | ||
| ) # , H3Model, H3ConvModel, HyenaModel | ||
| H3Model, | ||
| HyenaModel, | ||
| MambaModel, | ||
| Mamba2Model, | ||
| Mamba3Model, | ||
| ) | ||
| from clm.functions import load_dataset, write_to_csv_file | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -41,6 +46,9 @@ def add_args(parser): | |
| parser.add_argument( | ||
| "--n_layers", type=int, help="Number of layers in the model" | ||
| ) | ||
| parser.add_argument( | ||
| "--n_blocks", type=int, help="Number of blocks for S4 model" | ||
| ) | ||
| parser.add_argument( | ||
| "--state_dim", type=int, help="State dimension for S4 model" | ||
| ) | ||
|
|
@@ -50,9 +58,40 @@ def add_args(parser): | |
| parser.add_argument( | ||
| "--n_heads", type=int, help="Number of heads for the model" | ||
| ) | ||
| parser.add_argument( | ||
| "--head_dim", type=int, help="Dimension of head for the H3 model" | ||
| ) | ||
| parser.add_argument( | ||
| "--n_order_heads", | ||
| type=int, | ||
| help="Number of groups input channels for filter computation for the Hyena model", | ||
| ) | ||
| parser.add_argument( | ||
| "--exp_factor", type=int, help="Expansion factor for Transformer model" | ||
| ) | ||
| parser.add_argument( | ||
| "--bias", action="store_true", help="Use bias in Transformer model" | ||
| ) | ||
| parser.add_argument( | ||
| "--use_fast_fftconv", | ||
| action="store_true", | ||
| help="Use fast FFT convolution for H3 model", | ||
| ) | ||
| parser.add_argument( | ||
| "--measure", | ||
| type=str, | ||
| help="Measure parameter for the H3 model", | ||
| ) | ||
| parser.add_argument( | ||
| "--mode", type=str, help="Mode parameter for the H3 model" | ||
| ) | ||
| parser.add_argument( | ||
| "--lr", type=float, help="Learning rate for the H3 model" | ||
| ) | ||
| parser.add_argument("--order", type=int, help="Order for Hyena model") | ||
| parser.add_argument( | ||
| "--filter_order", type=int, help="Filter order for Hyena model" | ||
| ) | ||
| parser.add_argument( | ||
| "--dropout", type=float, help="Dropout rate for the RNN" | ||
| ) | ||
|
|
@@ -95,6 +134,9 @@ def add_args(parser): | |
| parser.add_argument( | ||
| "--batch_size", type=int, help="Batch size for training" | ||
| ) | ||
| parser.add_argument( | ||
| "--max_len", type=int, help="Maximum length of the generated sequences" | ||
| ) | ||
| parser.add_argument( | ||
| "--sample_mols", type=int, help="Number of molecules to generate" | ||
| ) | ||
|
|
@@ -121,12 +163,23 @@ def sample_molecules_RNN( | |
| embedding_size, | ||
| hidden_size, | ||
| n_layers, | ||
| n_blocks, | ||
| state_dim, | ||
| n_ssm, | ||
| n_heads, | ||
| head_dim, | ||
| n_order_heads, | ||
| exp_factor, | ||
| bias, | ||
| use_fast_fftconv, | ||
| measure, | ||
| mode, | ||
| lr, | ||
| order, | ||
| filter_order, | ||
| dropout, | ||
| batch_size, | ||
| max_len, | ||
| sample_mols, | ||
| vocab_file, | ||
| model_file, | ||
|
|
@@ -150,98 +203,101 @@ def sample_molecules_RNN( | |
|
|
||
| heldout_dataset = None | ||
|
|
||
| if model_type == "S4": | ||
| if model_type in [ | ||
| "S4", | ||
| "H3", | ||
| "Hyena", | ||
| "Transformer", | ||
| "Mamba", | ||
| "Mamba2", | ||
| "Mamba3", | ||
| ]: | ||
| assert ( | ||
| heldout_file is not None | ||
| ), "heldout_file must be provided for conditional RNN Model" | ||
| heldout_dataset = load_dataset( | ||
| representation=representation, | ||
| input_file=heldout_file, | ||
| vocab_file=vocab_file, | ||
| ) | ||
| not conditional | ||
| ), f"Conditional mode is not implemented for {model_type} model" | ||
|
|
||
| if model_type == "S4": | ||
| model = StructuredStateSpaceSequenceModel( | ||
| vocabulary=vocab, # heldout_dataset.vocabulary | ||
| model_dim=embedding_size, | ||
| state_dim=state_dim, | ||
| n_layers=n_layers, | ||
| n_blocks=n_blocks, | ||
| n_ssm=n_ssm, | ||
| dropout=dropout, | ||
| max_len=max_len, | ||
| ) | ||
| # elif model_type == "H3": | ||
| # assert ( | ||
| # heldout_file is not None | ||
| # ), "heldout_file must be provided for conditional RNN Model" | ||
| # heldout_dataset = load_dataset( | ||
| # representation=representation, | ||
| # input_file=heldout_file, | ||
| # vocab_file=vocab_file, | ||
| # ) | ||
| # model = H3Model( | ||
| # vocabulary=vocab, | ||
| # n_layers=n_layers, | ||
| # d_model=embedding_size, | ||
| # d_state=64, | ||
| # head_dim=1, | ||
| # dropout=dropout, | ||
| # max_len=250, | ||
| # use_fast_fftconv=False, | ||
| # ) | ||
| # elif model_type == "H3Conv": | ||
| # assert ( | ||
| # heldout_file is not None | ||
| # ), "heldout_file must be provided for conditional RNN Model" | ||
| # heldout_dataset = load_dataset( | ||
| # representation=representation, | ||
| # input_file=heldout_file, | ||
| # vocab_file=vocab_file, | ||
| # ) | ||
| # model = H3ConvModel( | ||
| # vocabulary=vocab, | ||
| # n_layers=n_layers, | ||
| # d_model=embedding_size, | ||
| # head_dim=1, | ||
| # dropout=dropout, | ||
| # max_len=250, | ||
| # use_fast_fftconv=False, | ||
| # ) | ||
| # elif model_type == "Hyena": | ||
| # assert ( | ||
| # heldout_file is not None | ||
| # ), "heldout_file must be provided for conditional RNN Model" | ||
| # heldout_dataset = load_dataset( | ||
| # representation=representation, | ||
| # input_file=heldout_file, | ||
| # vocab_file=vocab_file, | ||
| # ) | ||
| # model = HyenaModel( | ||
| # vocabulary=vocab, | ||
| # n_layers=n_layers, | ||
| # d_model=embedding_size, | ||
| # order=2, | ||
| # filter_order=64, | ||
| # num_heads=1, | ||
| # dropout=dropout, | ||
| # max_len=250, | ||
| # inner_factor=1, | ||
| # ) | ||
|
|
||
| elif model_type == "Transformer": | ||
| assert ( | ||
| heldout_file is not None | ||
| ), "heldout_file must be provided for conditional RNN Model" | ||
| heldout_dataset = load_dataset( | ||
| representation=representation, | ||
| input_file=heldout_file, | ||
| vocab_file=vocab_file, | ||
| elif model_type == "H3": | ||
| model = H3Model( | ||
| vocabulary=vocab, | ||
| n_layers=n_layers, | ||
| model_dim=embedding_size, | ||
| state_dim=state_dim, | ||
| head_dim=head_dim, | ||
| dropout=dropout, | ||
| use_fast_fftconv=use_fast_fftconv, | ||
| measure=measure, | ||
| mode=mode, | ||
| lr=lr, | ||
| max_len=max_len, | ||
| ) | ||
|
|
||
| elif model_type == "Hyena": | ||
| model = HyenaModel( | ||
| vocabulary=vocab, | ||
| n_layers=n_layers, | ||
| d_model=embedding_size, | ||
| order=order, | ||
| filter_order=filter_order, | ||
| n_order_heads=n_order_heads, | ||
| dropout=dropout, | ||
| max_len=max_len, | ||
| ) | ||
| elif model_type == "Mamba": | ||
| model = MambaModel( | ||
| vocabulary=vocab, | ||
| n_layers=n_layers, | ||
| model_dim=embedding_size, | ||
| d_state=state_dim, | ||
| d_conv=4, | ||
| expand=2, | ||
| dropout=dropout, | ||
| max_len=max_len, | ||
| ) | ||
| elif model_type == "Mamba2": | ||
| model = Mamba2Model( | ||
| vocabulary=vocab, | ||
| n_layers=n_layers, | ||
| model_dim=embedding_size, | ||
| d_state=state_dim, | ||
| d_conv=4, | ||
| expand=2, | ||
| dropout=dropout, | ||
| max_len=max_len, | ||
| ) | ||
| elif model_type == "Mamba3": | ||
| model = Mamba3Model( | ||
| vocabulary=vocab, | ||
| n_layers=n_layers, | ||
| model_dim=embedding_size, | ||
| d_state=state_dim, | ||
| headdim=32, | ||
| expand=2, | ||
| chunk_size=64, | ||
| dropout=dropout, | ||
| max_len=max_len, | ||
| ) | ||
|
Comment on lines
+230
to
+255
|
||
|
|
||
| elif model_type == "Transformer": | ||
| model = Transformer( | ||
| vocabulary=vocab, | ||
| n_blocks=n_layers, | ||
| n_heads=n_heads, | ||
| embedding_size=embedding_size, | ||
| dropout=dropout, | ||
| exp_factor=exp_factor, | ||
| bias=True, | ||
| bias=bias, | ||
| max_len=max_len, | ||
| ) | ||
|
|
||
| elif model_type == "RNN": | ||
|
|
@@ -332,12 +388,23 @@ def main(args): | |
| embedding_size=args.embedding_size, | ||
| hidden_size=args.hidden_size, | ||
| n_layers=args.n_layers, | ||
| n_blocks=args.n_blocks, | ||
| state_dim=args.state_dim, | ||
| n_ssm=args.n_ssm, | ||
| n_heads=args.n_heads, | ||
| head_dim=args.head_dim, | ||
| n_order_heads=args.n_order_heads, | ||
| exp_factor=args.exp_factor, | ||
| bias=args.bias, | ||
| use_fast_fftconv=args.use_fast_fftconv, | ||
| measure=args.measure, | ||
| mode=args.mode, | ||
| lr=args.lr, | ||
| order=args.order, | ||
| filter_order=args.filter_order, | ||
| dropout=args.dropout, | ||
| batch_size=args.batch_size, | ||
| max_len=args.max_len, | ||
| sample_mols=args.sample_mols, | ||
| vocab_file=args.vocab_file, | ||
| model_file=args.model_file, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dependency `safari @ git+https://github.com/GuptaVishu2002/safari.git@fix-setup` pulls third-party code directly from a Git repository using a mutable ref (`fix-setup`), which enables supply-chain attacks if that branch/tag is ever compromised or force-moved. An attacker who gains control of that repo could silently change the code at the same ref and have malicious code executed in any environment that installs this project. To mitigate this, pin Git-based dependencies to immutable commit hashes (or published versions on a trusted index) and periodically update them intentionally, rather than tracking branches/tags.