Support Writing Contact Maps (#1)
* spec out parquet conversion

* start stubbing out parquet options.

* turn on the get_contact_map

* add Readme
zachcp authored Dec 9, 2024
1 parent 69526ca commit cd7bdfc
Showing 5 changed files with 221 additions and 17 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
@@ -1,5 +1,6 @@
[package]
name = "plm-local"
version = "0.1.1"
edition = "2021"
authors = ["Zach Charlop-Powers<[email protected]>"]
description = "Local LLMs for Proteins"
@@ -12,12 +13,13 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-metal-kernels"]
 [dependencies]
 anyhow = "1.0.94"
 candle-core = "0.8.1"
-candle-examples = "0.8.1"
 candle-hf-hub = "0.3.3"
 candle-metal-kernels = { version = "0.8.1", optional = true }
 candle-nn = "0.8.1"
+chrono = "0.4.39"
 clap = { version = "4.5.23", features = ["derive"] }
 ferritin-amplify = { git = "https://github.com/zachcp/ferritin", version = "*", package = "ferritin-amplify" }
+polars = { version = "0.45.0", features = ["polars-io", "parquet"] }
 serde_json = "1.0.133"
 tokenizers = { version = "0.21.0" }

22 changes: 22 additions & 0 deletions Readme.md
@@ -0,0 +1,22 @@
# plm-local


Local-first protein language models. WIP.


```shell
# run AMPLIFY350M
cargo run --release \
--features metal -- \
--model-id 350M \
--protein-string \
MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL

# run AMPLIFY120M
cargo run --release \
--features metal -- \
--model-id 120M \
--protein-string \
MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL

```
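
Each run writes its outputs into a timestamped directory (or into `--output-dir`, if supplied). A sketch of the expected layout — the directory name below is illustrative, assuming the default CSV contact output:

```shell
20241209_153045_350M/
├── contact_map.csv   # (index_1, index_2, value) rows, one per residue pair
├── decoded.txt       # argmax-decoded sequence from the model logits
└── sequence.txt      # the input protein sequence
```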
2 changes: 1 addition & 1 deletion recipe.yaml
@@ -1,5 +1,5 @@
 context:
-  version: "0.1.0"
+  version: "0.1.1"
 
 package:
   name: plm-local
169 changes: 169 additions & 0 deletions src/lib.rs
@@ -0,0 +1,169 @@
use anyhow::Result;
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, D};
use ferritin_amplify::ModelOutput;
use polars::prelude::*;
use polars::prelude::{df, CsvWriter, DataFrame, ParquetWriter};
use tokenizers::Tokenizer;

pub fn device(cpu: bool) -> Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            println!(
                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
            );
        }
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
        }
        Ok(Device::Cpu)
    }
}
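
// A minimal call site (sketch): pass `true` to force CPU; otherwise CUDA,
// then Metal, are tried before falling back to CPU.
//
//     let dev = device(false)?;
//     let t = candle_core::Tensor::zeros((2, 3), candle_core::DType::F32, &dev)?;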

pub enum OutputType {
    CSV,
    PARQUET,
}

pub struct OutputConfig {
    pub contact_output: OutputType,
    pub top_k_output: OutputType,
    pub sequence: String,
    pub outdir: String,
    pub tokenizer: Tokenizer,
}

pub trait ModelIO {
    fn generate_contacts(&self, config: &OutputConfig) -> Result<DataFrame>;
    fn top_hits(&self, config: &OutputConfig) -> Result<DataFrame>;
    fn to_disk(&self, config: &OutputConfig) -> Result<()>;
}
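
// Intended usage (sketch, mirroring src/main.rs): the AMPLIFY forward pass
// returns a `ModelOutput`, which this trait extends with disk serialization:
//
//     let output = model.forward(&token_ids, None, false, true)?;
//     output.to_disk(&config)?;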

impl ModelIO for ModelOutput {
    fn top_hits(&self, config: &OutputConfig) -> Result<DataFrame> {
        // let predictions = self.logits.argmax(D::Minus1)?;
        todo!("Need to think through the API a bit");
    }
    fn generate_contacts(&self, config: &OutputConfig) -> Result<DataFrame> {
        let apc = self.get_contact_map()?;
        if apc.is_none() {
            Ok(DataFrame::empty())
        } else {
            let restensor = apc.unwrap();
            let (seqlen, _seqlen2, _) = restensor.dims3()?;
            let contact_probs = candle_nn::ops::softmax(&restensor, D::Minus1)?;
            let max_probs = contact_probs.max(D::Minus1)?;
            let flattened = max_probs.flatten_all()?;
            let values: Vec<f32> = flattened.to_vec1()?;
            let indices_1: Vec<i32> = (1..=seqlen)
                .map(|x| x as i32)
                .cycle()
                .take(seqlen * seqlen)
                .collect();
            let indices_2: Vec<i32> = (1..=seqlen)
                .map(|x| x as i32)
                .flat_map(|x| std::iter::repeat(x).take(seqlen))
                .collect();
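            // Worked example: with seqlen = 3, the row-major flatten pairs up as
            //   indices_1 (column, fast-varying): 1, 2, 3, 1, 2, 3, 1, 2, 3
            //   indices_2 (row, slow-varying):    1, 1, 1, 2, 2, 2, 3, 3, 3
            // so each (index_1, index_2, value) row addresses one cell of the map.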
            let df = df![
                "index_1" => &indices_1,
                "index_2" => &indices_2,
                "value" => &values,
            ]?;
            Ok(df)
        }
    }
    fn to_disk(&self, config: &OutputConfig) -> Result<()> {
        // Validated that the pytorch/python AMPLIFY model has the same dims...
        // 350M: Contact Map: Ok(Some(Tensor[dims 254, 254, 480; f32, metal:4294969344]))
        // 120M: Contact Map: Ok(Some(Tensor[dims 254, 254, 240; f32, metal:4294969344]))
        // Let's take the max() of the Softmax values....

        let mut contacts = self.generate_contacts(config)?;

        println!("Writing Contact Map File");
        std::fs::create_dir_all(&config.outdir)?;
        let outdir = std::path::PathBuf::from(&config.outdir);
        match &config.contact_output {
            OutputType::CSV => {
                let contact_map_file = outdir.join("contact_map.csv");
                let mut file = std::fs::File::create(&contact_map_file)?;
                CsvWriter::new(&mut file).finish(&mut contacts)?;
            }
            OutputType::PARQUET => {
                let contact_map_file = outdir.join("contact_map.parquet");
                let mut file = std::fs::File::create(&contact_map_file)?;
                ParquetWriter::new(&mut file).finish(&mut contacts)?;
            }
        }

        println!("Writing Top Output...");
        let predictions = self.logits.argmax(D::Minus1)?;
        let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
        let decoded = config
            .tokenizer
            .decode(indices.as_slice(), true)
            .map_err(|e| anyhow::anyhow!("{}", e))?;

        let decoded_path = outdir.join("decoded.txt");
        std::fs::write(&decoded_path, decoded)?;

        println!("Writing Sequence...");
        let sequence_path = outdir.join("sequence.txt");
        std::fs::write(&sequence_path, &config.sequence)?;

        Ok(())
    }
}
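
// Reading a written contact map back into polars (sketch; assumes the parquet
// output path used above and the crate's "parquet" feature):
//
//     let file = std::fs::File::open("contact_map.parquet")?;
//     let df = ParquetReader::new(file).finish()?;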

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Tensor};
    use polars::prelude::*;
    use std::fs::File;

    #[test]
    fn test_parquet_conversion() -> anyhow::Result<()> {
        let tensor = Tensor::new(&[[0f32, 1., 3.], [2., 3., 4.], [4., 5., 6.]], &Device::Cpu)?;
        let (length, width) = tensor.dims2()?;
        println!("Tensor Dims: {:?}. {}, {}", tensor.dims(), length, width);
        let flattened = tensor.flatten_all()?;

        let values: Vec<f32> = flattened.to_vec1()?;
        let indices_01: Vec<i32> = (1..=width)
            .map(|x| x as i32)
            .cycle()
            .take(width * width)
            .collect();

        let indices_02: Vec<i32> = (1..=width)
            .map(|x| x as i32)
            .flat_map(|x| std::iter::repeat(x).take(width))
            .take(width * width)
            .collect();

        let mut df = df![
            "index_1" => &indices_01,
            "index_2" => &indices_02,
            "values" => &values,
        ]?;

        let path = "output.parquet";
        ParquetWriter::new(File::create(path)?).finish(&mut df)?;

        let csv_path = "output.csv";
        CsvWriter::new(File::create(csv_path)?).finish(&mut df)?;
        Ok(())
    }
}
41 changes: 26 additions & 15 deletions src/main.rs
@@ -1,10 +1,10 @@
 use anyhow::{Error as E, Result};
-use candle_core::{DType, Tensor, D};
-use candle_examples::device;
+use candle_core::{DType, Tensor};
 use candle_hf_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use clap::Parser;
 use ferritin_amplify::{AMPLIFYConfig as Config, AMPLIFY};
+use plm_local::{device, ModelIO, OutputConfig, OutputType};
 use tokenizers::Tokenizer;

pub const DTYPE: DType = DType::F32;
@@ -18,7 +18,7 @@ pub const DTYPE: DType = DType::F32;
 )]
 struct Args {
     /// Run on CPU rather than on GPU.
-    #[arg(long)]
+    #[arg(long, default_value_t = false)]
     cpu: bool,
 
     /// Which AMPLIFY Model to use, either '120M' or '350M'.
@@ -31,7 +31,7 @@ struct Args {

     /// Path to a protein FASTA file
     #[arg(long)]
-    protein_fasta: Option<std::path::PathBuf>,
+    protein_fasta: Option<String>,
+
+    /// Output directory for files
+    #[arg(long)]
+    output_dir: Option<String>,
 }

impl Args {
@@ -82,27 +86,34 @@ fn main() -> Result<()> {
         ));
     };
 
+    // default is datetime-model
+    let output_dir = args.output_dir.unwrap_or_else(|| {
+        let now = chrono::Local::now();
+        format!("{}_{}", now.format("%Y%m%d_%H%M%S"), args.model_id)
+    });
+
     //
     for prot in protein_sequences.iter() {
         // let sprot_01 = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL";
+        let config = OutputConfig {
+            contact_output: OutputType::CSV,
+            top_k_output: OutputType::CSV,
+            sequence: prot.clone(),
+            outdir: output_dir.clone(),
+            tokenizer: tokenizer.clone(),
+        };
 
         let tokens = tokenizer
             .encode(prot.to_string(), false)
             .map_err(E::msg)?
             .get_ids()
             .to_vec();
 
         let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
-        println!("Encoding.......");
-        let encoded = model.forward(&token_ids, None, false, false)?;
-
-        println!("Predicting.......");
-        let predictions = encoded.logits.argmax(D::Minus1)?;
-
-        println!("Decoding.......");
-        let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
-        let decoded = tokenizer.decode(indices.as_slice(), true);
-
-        println!("Decoded: {:?}, ", decoded);
+        println!("Encoding.......");
+        let encoded = model.forward(&token_ids, None, false, true)?;
+
+        println!("Writing Outputs... ");
+        let _ = encoded.to_disk(&config)?;
     }
 
     Ok(())
