Skip to content

Commit

Permalink
fix: use correct locking strategy inside baml-cli serve (#943)
Browse files Browse the repository at this point in the history
  • Loading branch information
sxlijin authored Sep 11, 2024
1 parent 8873fe7 commit fcb694d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
10 changes: 7 additions & 3 deletions engine/baml-runtime/src/cli/dev.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use anyhow::{Result};
use anyhow::Result;
use notify_debouncer_full::{new_debouncer, notify::*};
use std::ops::DerefMut;
use std::path::PathBuf;
use std::time::{Duration, Instant};
use std::{path::PathBuf};

use crate::{cli::generate::GenerateArgs, BamlRuntime};

Expand Down Expand Up @@ -118,7 +119,10 @@ Thanks for trying out BAML!
}
.run(defaults);

std::mem::swap(&mut *server.b.lock().await, &mut new_runtime);
std::mem::swap(
server.b.write().await.deref_mut(),
&mut new_runtime,
);
log::info!(
"Reloaded runtime in {}ms ({})",
elapsed.as_millis(),
Expand Down
14 changes: 6 additions & 8 deletions engine/baml-runtime/src/cli/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use core::pin::Pin;
use futures::Stream;
use serde_json::json;
use std::{path::PathBuf, sync::Arc, task::Poll};
use tokio::{net::TcpListener, sync::Mutex};
use tokio::{net::TcpListener, sync::RwLock};
use tokio_stream::StreamExt;

use crate::{
Expand Down Expand Up @@ -115,7 +115,7 @@ Thanks for trying out BAML!
pub(super) struct Server {
src_dir: PathBuf,
port: u16,
pub(super) b: Arc<Mutex<BamlRuntime>>,
pub(super) b: Arc<RwLock<BamlRuntime>>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -201,7 +201,7 @@ impl Server {
Arc::new(Self {
src_dir: src_dir.clone(),
port,
b: Arc::new(Mutex::new(BamlRuntime::from_directory(
b: Arc::new(RwLock::new(BamlRuntime::from_directory(
&src_dir,
std::env::vars().collect(),
)?)),
Expand Down Expand Up @@ -377,10 +377,8 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping`

let ctx_mgr = RuntimeContextManager::new_from_env_vars(std::env::vars().collect(), None);

let (result, _trace_id) = self
.b
.lock()
.await
let locked = self.b.read().await;
let (result, _trace_id) = locked
.call_function(b_fn, &args, &ctx_mgr, None, None)
.await;

Expand Down Expand Up @@ -439,7 +437,7 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping`

let mut result_stream = self
.b
.lock()
.read()
.await
.stream_function(b_fn, &args, &ctx_mgr, None, None)?;

Expand Down

0 comments on commit fcb694d

Please sign in to comment.