Skip to content

Commit

Permalink
update the lib and main
Browse files Browse the repository at this point in the history
  • Loading branch information
zachcp committed Dec 9, 2024
1 parent 575bd50 commit 46e7602
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 45 deletions.
53 changes: 45 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use anyhow::Result;
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, D};
use ferritin_amplify::ModelOutput;
use polars::prelude::{df, DataFrame};
use polars::io::parquet::*;
use polars::prelude::*;
use polars::prelude::{df, CsvWriter, DataFrame, ParquetWriter};

pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Expand All @@ -27,7 +29,7 @@ pub fn device(cpu: bool) -> Result<Device> {
}

#[derive(Default)]
enum OutputType {
pub enum OutputType {
#[default]
CSV,
PARQUET,
Expand All @@ -37,20 +39,22 @@ enum OutputType {
/// Settings controlling where and in which format model outputs are written.
pub struct OutputConfig {
    /// File format for the contact-map output (CSV or Parquet).
    pub contact_output: OutputType,
    /// File format for the top-k predictions output (CSV or Parquet).
    pub top_k_output: OutputType,
    /// Protein sequence the outputs were generated from
    /// (set from the input FASTA record in `main` — see caller).
    pub sequence: String,
    /// Directory all output files are written into; created on demand by `to_disk`.
    pub outdir: String,
}

pub trait ModelIO {
fn generate_contacts(&self, config: OutputConfig) -> Result<DataFrame>;
fn top_hits(&self,OutputConfig) -> Result<DataFrame>;
fn to_disk(&self,OutputConfig);
fn generate_contacts(&self, config: &OutputConfig) -> Result<DataFrame>;
fn top_hits(&self, config: &OutputConfig) -> Result<DataFrame>;
fn to_disk(&self, config: &OutputConfig) -> Result<()>;
}

impl ModelIO for ModelOutput {
fn top_hits(&self, OutputConfig) -> Result<DataFrame> {
    /// Return the highest-scoring predictions as a DataFrame.
    ///
    /// Not yet implemented: always panics via `todo!`. The commented line
    /// sketches the intended start (argmax over the logits' last dim).
    /// NOTE(review): `config` is currently unused — presumably
    /// `top_k_output` will select the format once this is written; confirm.
    fn top_hits(&self, config: &OutputConfig) -> Result<DataFrame> {
        // let predictions = self.logits.argmax(D::Minus1)?;
        todo!("Need to think through the API a bit");
    }
fn generate_contacts(&self, OutputConfig) -> Result<DataFrame> {
fn generate_contacts(&self, config: &OutputConfig) -> Result<DataFrame> {
let apc = self.get_contact_map()?;
if apc.is_none() {
Ok(DataFrame::empty())
Expand Down Expand Up @@ -78,7 +82,40 @@ impl ModelIO for ModelOutput {
Ok(df)
}
}
fn to_disk(&self, OutputConfig) {}
fn to_disk(&self, config: &OutputConfig) -> Result<()> {
// Validated the pytorch/python AMPLIFY model has the same dims...
// 350M: Contact Map: Ok(Some(Tensor[dims 254, 254, 480; f32, metal:4294969344]))
// 120M: Contact Map: Ok(Some(Tensor[dims 254, 254, 240; f32, metal:4294969344]))
// Lets take the max() of the Softmax values....

let mut contacts = self.generate_contacts(config)?;

println!("Writing Contact Parquet File");
std::fs::create_dir_all(&config.outdir)?;
let outdir = std::path::PathBuf::from(&config.outdir);
match &config.contact_output {
OutputType::CSV => {
let contact_map_file = outdir.join("contact_map.csv");
let mut file = std::fs::File::create(&contact_map_file)?;
CsvWriter::new(&mut file).finish(&mut contacts)?;
}
OutputType::PARQUET => {
let contact_map_file = outdir.join("contact_map.parquet");
let mut file = std::fs::File::create(&contact_map_file)?;
ParquetWriter::new(&mut file).finish(&mut contacts)?;
}
}

// println!("Writing Contact Parquet File");
// println!("Predicting.......");
// let predictions = encoded.logits.argmax(D::Minus1)?;
// println!("Decoding.......");
// let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
// let decoded = tokenizer.decode(indices.as_slice(), true);
// println!("Decoded: {:?}, ", decoded);

Ok(())
}
}

#[cfg(test)]
Expand Down
51 changes: 14 additions & 37 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use candle_hf_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use clap::Parser;
use ferritin_amplify::{AMPLIFYConfig as Config, AMPLIFY};
use plm_local::{device, ModelIO, OutputConfig};
use polars::io::parquet::*;
use polars::prelude::*;
use plm_local::{device, ModelIO, OutputConfig, OutputType};
use tokenizers::Tokenizer;

pub const DTYPE: DType = DType::F32;
Expand All @@ -33,11 +31,11 @@ struct Args {

/// Path to a protein FASTA file
#[arg(long)]
protein_fasta: Option<std::path::PathBuf>,
protein_fasta: Option<String>,

/// Output directory for files
#[arg(long)]
output_dir: Option<std::path::PathBuf>,
output_dir: Option<String>,
}

impl Args {
Expand Down Expand Up @@ -91,51 +89,30 @@ fn main() -> Result<()> {
// default is datetime-model
let output_dir = args.output_dir.unwrap_or_else(|| {
let now = chrono::Local::now();
let dirname = format!("{}_{}", now.format("%Y%m%d_%H%M%S"), args.model_id);
std::path::PathBuf::from(dirname)
format!("{}_{}", now.format("%Y%m%d_%H%M%S"), args.model_id)
});

//
for prot in protein_sequences.iter() {
// let sprot_01 = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL";
let config = OutputConfig {
contact_output: OutputType::CSV,
top_k_output: OutputType::CSV,
sequence: prot.clone(),
outdir: output_dir.clone(),
};

let tokens = tokenizer
.encode(prot.to_string(), false)
.map_err(E::msg)?
.get_ids()
.to_vec();

let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;

println!("Encoding.......");
let encoded = model.forward(&token_ids, None, false, true)?;

println!("Writing Contact Map (todo).......");
let cmap = encoded.get_contact_map();
println!("Contact Map: {:?}", cmap);
// Validated the pytorch/python AMPLIFY model has the same dims...
// 350M: Contact Map: Ok(Some(Tensor[dims 254, 254, 480; f32, metal:4294969344]))
// 120M: Contact Map: Ok(Some(Tensor[dims 254, 254, 240; f32, metal:4294969344]))
// Lets take the max() of the Softmax values....
let mut cmap2 = encoded.contacts()?;
std::fs::create_dir_all(&output_dir)?;
let contact_map_file = output_dir.join("contact_map.parquet");
let contact_map_csv = output_dir.join("contact_map.csv");
let mut file = std::fs::File::create(contact_map_file).unwrap();
ParquetWriter::new(&mut file).finish(&mut cmap2).unwrap();
let mut file = std::fs::File::create(contact_map_csv).unwrap();
CsvWriter::new(&mut file).finish(&mut cmap2).unwrap();

println!("DataFrame: {:?}", cmap2);

println!("Writing Logits as Parquet.......");

println!("Predicting.......");
let predictions = encoded.logits.argmax(D::Minus1)?;

println!("Decoding.......");
let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
let decoded = tokenizer.decode(indices.as_slice(), true);

println!("Decoded: {:?}, ", decoded);
println!("Writing Outputs... ");
let _ = encoded.to_disk(&config)?;
}

Ok(())
Expand Down

0 comments on commit 46e7602

Please sign in to comment.