diff --git a/Cargo.toml b/Cargo.toml index ef86912..fe9f578 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [package] name = "plm-local" +version = "0.1.1" edition = "2021" authors = ["Zach Charlop-Powers"] description = "Local LLMs for Proteins" @@ -12,12 +13,13 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-metal-kernels"] [dependencies] anyhow = "1.0.94" candle-core = "0.8.1" -candle-examples = "0.8.1" candle-hf-hub = "0.3.3" candle-metal-kernels = { version = "0.8.1", optional = true } candle-nn = "0.8.1" +chrono = "0.4.39" clap = { version = "4.5.23", features = ["derive"] } ferritin-amplify = { git = "https://github.com/zachcp/ferritin", version = "*", package = "ferritin-amplify" } +polars = { version = "0.45.0", features = ["polars-io", "parquet"] } serde_json = "1.0.133" tokenizers = { version = "0.21.0" } diff --git a/Readme.md b/Readme.md new file mode 100644 index 0000000..dd08589 --- /dev/null +++ b/Readme.md @@ -0,0 +1,22 @@ +# plm-local + + +Local-first protein language models. WIP. 
+ + +```shell +# run AMPLIFY350M +cargo run --release \ + --features metal -- \ + --model-id 350M \ + --protein-string \ + MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL + +# run AMPLIFY120M +cargo run --release \ + --features metal -- \ + --model-id 120M \ + --protein-string \ + MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL + +``` diff --git a/recipe.yaml b/recipe.yaml index c674de4..d7c1a8b 100644 --- a/recipe.yaml +++ b/recipe.yaml @@ -1,5 +1,5 @@ context: - version: "0.1.0" + version: "0.1.1" package: name: plm-local diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..4487018 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,169 @@ +use anyhow::Result; +use candle_core::utils::{cuda_is_available, metal_is_available}; +use candle_core::{Device, D}; +use ferritin_amplify::ModelOutput; +use polars::prelude::*; +use polars::prelude::{df, CsvWriter, DataFrame, ParquetWriter}; +use tokenizers::Tokenizer; + +pub fn device(cpu: bool) -> Result { + if cpu { + Ok(Device::Cpu) + } else if cuda_is_available() { + Ok(Device::new_cuda(0)?) + } else if metal_is_available() { + Ok(Device::new_metal(0)?) 
+ } else { + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + { + println!( + "Running on CPU, to run on GPU(metal), build this example with `--features metal`" + ); + } + #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] + { + println!("Running on CPU, to run on GPU, build this example with `--features cuda`"); + } + Ok(Device::Cpu) + } +} + +pub enum OutputType { + CSV, + PARQUET, +} + +pub struct OutputConfig { + pub contact_output: OutputType, + pub top_k_output: OutputType, + pub sequence: String, + pub outdir: String, + pub tokenizer: Tokenizer, +} + +pub trait ModelIO { + fn generate_contacts(&self, config: &OutputConfig) -> Result; + fn top_hits(&self, config: &OutputConfig) -> Result; + fn to_disk(&self, config: &OutputConfig) -> Result<()>; +} + +impl ModelIO for ModelOutput { + fn top_hits(&self, config: &OutputConfig) -> Result { + // let predictions = self.logits.argmax(D::Minus1)?; + todo!("Need to think through the API a bit"); + } + fn generate_contacts(&self, config: &OutputConfig) -> Result { + let apc = self.get_contact_map()?; + if apc.is_none() { + Ok(DataFrame::empty()) + } else { + let restensor = apc.unwrap(); + let (seqlen, _seqlen2, _) = restensor.dims3()?; + let contact_probs = candle_nn::ops::softmax(&restensor, D::Minus1)?; + let max_probs = contact_probs.max(D::Minus1)?; + let flattened = max_probs.flatten_all()?; + let values: Vec = flattened.to_vec1()?; + let indices_1: Vec = (1..=seqlen) + .map(|x| x as i32) + .cycle() + .take(seqlen * seqlen) + .collect(); + let indices_2: Vec = (1..=seqlen) + .map(|x| x as i32) + .flat_map(|x| std::iter::repeat(x).take(seqlen)) + .collect(); + let df = df! [ + "index_1" => &indices_1, + "index_2" => &indices_2, + "value" => &values, + ]?; + Ok(df) + } + } + fn to_disk(&self, config: &OutputConfig) -> Result<()> { + // Validated the pytorch/python AMPLIFY model has the same dims... 
+ // 350M: Contact Map: Ok(Some(Tensor[dims 254, 254, 480; f32, metal:4294969344])) + // 120M: Contact Map: Ok(Some(Tensor[dims 254, 254, 240; f32, metal:4294969344])) + // Lets take the max() of the Softmax values.... + + let mut contacts = self.generate_contacts(config)?; + + println!("Writing Contact Parquet File"); + std::fs::create_dir_all(&config.outdir)?; + let outdir = std::path::PathBuf::from(&config.outdir); + match &config.contact_output { + OutputType::CSV => { + let contact_map_file = outdir.join("contact_map.csv"); + let mut file = std::fs::File::create(&contact_map_file)?; + CsvWriter::new(&mut file).finish(&mut contacts)?; + } + OutputType::PARQUET => { + let contact_map_file = outdir.join("contact_map.parquet"); + let mut file = std::fs::File::create(&contact_map_file)?; + ParquetWriter::new(&mut file).finish(&mut contacts)?; + } + } + + println!("Writing Top Output..."); + let predictions = self.logits.argmax(D::Minus1)?; + let indices: Vec = predictions.to_vec2()?[0].to_vec(); + let decoded = config + .tokenizer + .decode(indices.as_slice(), true) + .map_err(|e| anyhow::anyhow!("{}", e))?; + // std::fs::write(&decoded_path, &decoded)?; + // let decoded = &config.tokenizer.decode(indices.as_slice(), true)?; + // println!("Decoded: {:?}", decoded); + + let decoded_path = outdir.join("decoded.txt"); + std::fs::write(&decoded_path, decoded)?; + + println!("Writing Sequence..."); + let sequence_path = outdir.join("sequence.txt"); + std::fs::write(&sequence_path, &config.sequence)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use candle_core::{Device, Tensor}; + use polars::prelude::*; + use std::fs::File; + + #[test] + fn test_parquet_conversion() -> anyhow::Result<()> { + let tensor = Tensor::new(&[[0f32, 1., 3.], [2., 3., 4.], [4., 5., 6.]], &Device::Cpu)?; + let (length, width) = tensor.dims2()?; + println!("Tensor Dims: {:?}. 
{}, {}", tensor.dims(), length, width); + let flattened = tensor.flatten_all()?; + + let values: Vec = flattened.to_vec1()?; + let indices_01: Vec = (1..=width) + .map(|x| x as i32) + .cycle() + .take(width * width) + .collect(); + + let indices_02: Vec = (1..=width) + .map(|x| x as i32) + .flat_map(|x| std::iter::repeat(x).take(width)) + .take(width * width) + .collect(); + + let mut df = df! [ + "index_1" => &indices_01, + "index_2" => &indices_02, + "values" => &values, + ]?; + + let path = "output.parquet"; + ParquetWriter::new(File::create(path)?).finish(&mut df)?; + + let csv_path = "output.csv"; + CsvWriter::new(File::create(csv_path)?).finish(&mut df)?; + Ok(()) + } +} diff --git a/src/main.rs b/src/main.rs index 6c0f650..115b74a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,10 @@ use anyhow::{Error as E, Result}; -use candle_core::{DType, Tensor, D}; -use candle_examples::device; +use candle_core::{DType, Tensor}; use candle_hf_hub::{api::sync::Api, Repo, RepoType}; use candle_nn::VarBuilder; use clap::Parser; use ferritin_amplify::{AMPLIFYConfig as Config, AMPLIFY}; +use plm_local::{device, ModelIO, OutputConfig, OutputType}; use tokenizers::Tokenizer; pub const DTYPE: DType = DType::F32; @@ -18,7 +18,7 @@ pub const DTYPE: DType = DType::F32; )] struct Args { /// Run on CPU rather than on GPU. - #[arg(long)] + #[arg(long, default_value_t = false)] cpu: bool, /// Which AMPLIFY Model to use, either '120M' or '350M'. 
@@ -31,7 +31,11 @@ struct Args { /// Path to a protein FASTA file #[arg(long)] - protein_fasta: Option, + protein_fasta: Option, + + /// Output directory for files + #[arg(long)] + output_dir: Option, } impl Args { @@ -82,27 +86,34 @@ fn main() -> Result<()> { )); }; + // default is datetime-model + let output_dir = args.output_dir.unwrap_or_else(|| { + let now = chrono::Local::now(); + format!("{}_{}", now.format("%Y%m%d_%H%M%S"), args.model_id) + }); + + // for prot in protein_sequences.iter() { - // let sprot_01 = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL"; + let config = OutputConfig { + contact_output: OutputType::CSV, + top_k_output: OutputType::CSV, + sequence: prot.clone(), + outdir: output_dir.clone(), + tokenizer: tokenizer.clone(), + }; let tokens = tokenizer .encode(prot.to_string(), false) .map_err(E::msg)? .get_ids() .to_vec(); - let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; - println!("Encoding......."); - let encoded = model.forward(&token_ids, None, false, false)?; - println!("Predicting......."); - let predictions = encoded.logits.argmax(D::Minus1)?; - - println!("Decoding......."); - let indices: Vec = predictions.to_vec2()?[0].to_vec(); - let decoded = tokenizer.decode(indices.as_slice(), true); + println!("Encoding......."); + let encoded = model.forward(&token_ids, None, false, true)?; - println!("Decoded: {:?}, ", decoded); + println!("Writing Outputs... "); + let _ = encoded.to_disk(&config)?; } Ok(())