diff --git a/backend/rust/backend-burn/src/main.rs b/backend/rust/backend-burn/src/main.rs index 6aadfaba69e6..530e7a3b088e 100644 --- a/backend/rust/backend-burn/src/main.rs +++ b/backend/rust/backend-burn/src/main.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; use std::net::SocketAddr; +use std::process::{id, Command}; +use std::sync::{Arc, Mutex}; use bunker::pb::Result as PbResult; use bunker::pb::{ @@ -15,16 +17,12 @@ use tonic::{Request, Response, Status}; use async_trait::async_trait; use tracing::{event, span, Level}; -use tracing_subscriber::filter::LevelParseError; - -use std::fs; -use std::process::{Command,id}; use models::*; // implement BackendService trait in bunker #[derive(Default, Debug)] -struct BurnBackend; +pub struct BurnBackend; #[async_trait] impl BackendService for BurnBackend { @@ -42,25 +40,7 @@ impl BackendService for BurnBackend { #[tracing::instrument] async fn predict(&self, request: Request) -> Result, Status> { - let mut models: Vec> = vec![Box::new(models::MNINST::new())]; - let result = models[0].predict(request.into_inner()); - - match result { - Ok(res) => { - let reply = Reply { - message: res.into(), - }; - let res = Response::new(reply); - Ok(res) - } - Err(e) => { - let reply = Reply { - message: e.to_string().into(), - }; - let res = Response::new(reply); - Ok(res) - } - } + todo!("predict") } #[tracing::instrument] @@ -68,7 +48,7 @@ impl BackendService for BurnBackend { &self, request: Request, ) -> Result, Status> { - todo!() + todo!("load_model") } #[tracing::instrument] @@ -121,35 +101,34 @@ impl BackendService for BurnBackend { &self, request: Request, ) -> Result, Status> { - // Here we do not need to cover the windows platform let mut breakdown = HashMap::new(); - let mut memory_usage: u64=0; + let mut memory_usage: u64 = 0; #[cfg(target_os = "linux")] { - let pid =id(); - let stat = fs::read_to_string(format!("/proc/{}/stat", pid)).expect("Failed to read stat file"); + let pid = id(); + let stat = fs::read_to_string(format!("/proc/{}/stat", pid)) + .expect("Failed to read stat file"); let stats: Vec<&str> = stat.split_whitespace().collect(); memory_usage = stats[23].parse::().expect("Failed to parse RSS"); } - #[cfg(target_os="macos")] + #[cfg(target_os = "macos")] { - let output=Command::new("ps") - .arg("-p") - .arg(id().to_string()) - .arg("-o") - .arg("rss=") - .output() - .expect("failed to execute process"); - - memory_usage = String::from_utf8_lossy(&output.stdout) - .trim() - .parse::() - .expect("Failed to parse memory usage"); + let output = Command::new("ps") + .arg("-p") + .arg(id().to_string()) + .arg("-o") + .arg("rss=") + .output() + .expect("failed to execute process"); + memory_usage = String::from_utf8_lossy(&output.stdout) + .trim() + .parse::() + .expect("Failed to parse memory usage"); } breakdown.insert("RSS".to_string(), memory_usage); @@ -167,6 +146,87 @@ impl BackendService for BurnBackend { } } +#[cfg(test)] +mod tests { + use super::*; + use tonic::Request; + + #[tokio::test] + async fn test_health() { + let backend = BurnBackend::default(); + let request = Request::new(HealthMessage {}); + let response = backend.health(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let message_str = String::from_utf8(response.get_ref().message.clone()).unwrap(); + assert_eq!(message_str, "OK"); + } + #[tokio::test] + async fn test_status() { + let backend = BurnBackend::default(); + let request = Request::new(HealthMessage {}); + let response = backend.status(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let state = response.get_ref().state; + assert_eq!(state, 0); + } + + #[tokio::test] + async fn test_load_model() { + let backend = BurnBackend::default(); + let request = Request::new(ModelOptions { + model: "test".to_string(), + context_size: 0, + seed: 0, + n_batch: 0, + f16_memory: false, + m_lock: false, + m_map: false, + vocab_only: false, + low_vram: false, + embeddings: false, + numa: false, + ngpu_layers: 0, + main_gpu: "".to_string(), + tensor_split: "".to_string(), + threads: 1, + library_search_path: "".to_string(), + rope_freq_base: 0.0, + rope_freq_scale: 0.0, + rms_norm_eps: 0.0, + ngqa: 0, + model_file: "".to_string(), + device: "".to_string(), + use_triton: false, + model_base_name: "".to_string(), + use_fast_tokenizer: false, + pipeline_type: "".to_string(), + scheduler_type: "".to_string(), + cuda: false, + cfg_scale: 0.0, + img2img: false, + clip_model: "".to_string(), + clip_subfolder: "".to_string(), + clip_skip: 0, + tokenizer: "".to_string(), + lora_base: "".to_string(), + lora_adapter: "".to_string(), + no_mul_mat_q: false, + draft_model: "".to_string(), + audio_path: "".to_string(), + quantization: "".to_string(), + }); + let response = backend.load_model(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + //TO_DO: add test for response + } +} + #[tokio::main] async fn main() -> Result<(), Box> { let subscriber = tracing_subscriber::fmt() diff --git a/backend/rust/models/src/lib.rs b/backend/rust/models/src/lib.rs index f3302e83ef73..b739bf0b0d51 100644 --- a/backend/rust/models/src/lib.rs +++ b/backend/rust/models/src/lib.rs @@ -1,9 +1,16 @@ +use bunker::pb::{ModelOptions, PredictOptions}; + pub(crate) mod mnist; pub use mnist::mnist::MNINST; -use bunker::pb::{ModelOptions, PredictOptions}; - +/// Trait for implementing a Language Model. pub trait LLM { + /// Loads the model from the given options. fn load_model(&mut self, request: ModelOptions) -> Result>; + /// Predicts the output for the given input options. fn predict(&mut self, request: PredictOptions) -> Result>; } + +pub struct LLModel { + model: Box, +} diff --git a/backend/rust/models/src/mnist/mnist.rs b/backend/rust/models/src/mnist/mnist.rs index 995b2706ed05..7a727bbbf441 100644 --- a/backend/rust/models/src/mnist/mnist.rs +++ b/backend/rust/models/src/mnist/mnist.rs @@ -4,7 +4,7 @@ //! Adapter by Aisuko use burn::{ - backend::wgpu::{compute::init_async, AutoGraphicsApi, WgpuDevice}, + backend::wgpu::{AutoGraphicsApi, WgpuDevice}, module::Module, nn::{self, BatchNorm, PaddingConfig2d}, record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, @@ -12,7 +12,6 @@ use burn::{ }; // https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/model.bin -static STATE_ENCODED: &[u8] = include_bytes!("model.bin"); const NUM_CLASSES: usize = 10; @@ -36,7 +35,7 @@ pub struct MNINST { } impl MNINST { - pub fn new() -> Self { + pub fn new(model_name: &str) -> Self { let conv1 = ConvBlock::new([1, 8], [3, 3]); // 1 input channel, 8 output channels, 3x3 kernel size let conv2 = ConvBlock::new([8, 16], [3, 3]); // 8 input channels, 16 output channels, 3x3 kernel size let conv3 = ConvBlock::new([16, 24], [3, 3]); // 16 input channels, 24 output channels, 3x3 kernel size @@ -59,8 +58,9 @@ impl MNINST { fc2: fc2, activation: nn::GELU::new(), }; + let state_encoded: &[u8] = &std::fs::read(model_name).expect("Failed to load model"); let record = BinBytesRecorder::::default() - .load(STATE_ENCODED.to_vec()) + .load(state_encoded.to_vec()) .expect("Failed to decode state"); instance.load_record(record) @@ -178,7 +178,7 @@ mod tests { pub type Backend = burn::backend::NdArrayBackend; #[test] fn test_inference() { - let mut model = MNINST::::new(); + let mut model = MNINST::::new("model.bin"); let output = model.inference(&[0.0; 28 * 28]).unwrap(); assert_eq!(output.len(), 10); } diff --git a/backend/rust/models/src/mnist/mod.rs b/backend/rust/models/src/mnist/mod.rs index d53b76c6c7a4..8cc85ce0e520 100644 --- a/backend/rust/models/src/mnist/mod.rs +++ b/backend/rust/models/src/mnist/mod.rs @@ -1,14 +1,20 @@ use crate::LLM; -use bunker::pb::{ModelOptions, PredictOptions}; pub(crate) mod mnist; +use mnist::MNINST; + +use bunker::pb::{ModelOptions, PredictOptions}; + #[cfg(feature = "ndarray")] pub type Backend = burn::backend::NdArrayBackend; -impl LLM for mnist::MNINST { +impl LLM for MNINST { fn load_model(&mut self, request: ModelOptions) -> Result> { - todo!("load model") + let model = request.model_file; + let instance = MNINST::::new(&model); + *self = instance; + Ok("".to_string()) } fn predict(&mut self, pre_ops: PredictOptions) -> Result> {