2 changes: 2 additions & 0 deletions .gitignore
@@ -119,3 +119,5 @@ shinkai-libs/shinkai-non-rust-code/internal_tools_storage
shinkai-bin/shinkai-node/storage_test_node
storage_debug_node2_test
storage_debug_node1_test
storage_debug_nico_testnet1
storage_debug_nico_testnet2
187 changes: 172 additions & 15 deletions shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs
@@ -5,6 +5,7 @@ use crate::managers::tool_router::ToolRouter;
use crate::managers::IdentityManager;
use crate::network::agent_payments_manager::external_agent_offerings_manager::ExtAgentOfferingsManager;
use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager;
use dashmap::DashMap;
use ed25519_dalek::SigningKey;
use futures::Future;
use shinkai_embedding::embedding_generator::RemoteEmbeddingGenerator;
@@ -57,6 +58,8 @@ pub struct JobManager {
pub job_processing_task: Option<tokio::task::JoinHandle<()>>,
// Websocket manager for sending updates to the frontend
pub ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
// Track individual job processing tasks by conversation_inbox_name
pub job_processing_tasks: Arc<DashMap<String, tokio::task::JoinHandle<()>>>,
}

impl JobManager {
@@ -107,6 +110,8 @@ impl JobManager {
.parse::<usize>()
.unwrap_or(NUM_THREADS);

let job_processing_tasks = Arc::new(DashMap::new());

// Start processing both queues
let job_queue_handler = JobManager::process_job_queue(
job_queue_normal.clone(),
@@ -122,6 +127,7 @@
Some(my_agent_payments_manager.clone()),
Some(ext_agent_payments_manager.clone()),
llm_stopper.clone(),
job_processing_tasks.clone(),
|job,
db,
node_profile_name,
@@ -162,6 +168,7 @@
job_queue_manager_immediate: job_queue_immediate,
job_processing_task: Some(job_queue_handler),
ws_manager,
job_processing_tasks,
}
}

@@ -179,6 +186,7 @@
my_agent_payments_manager: Option<Arc<Mutex<MyAgentOfferingsManager>>>,
ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
llm_stopper: Arc<LLMStopper>,
job_processing_tasks: Arc<DashMap<String, tokio::task::JoinHandle<()>>>,
job_processing_fn: impl Fn(
JobForProcessing,
Weak<SqliteManager>,
@@ -259,8 +267,29 @@
let ext_agent_payments_manager = ext_agent_payments_manager.clone();
let llm_stopper = llm_stopper.clone();
let in_progress = processing_jobs.clone();

tokio::spawn(async move {
let tasks_map_for_cleanup = job_processing_tasks.clone();
let tasks_map_for_insert = job_processing_tasks.clone();
let job_id_clone = job_id.clone();

// Get the conversation inbox name for this job
let conversation_inbox_name = if let Some(db_arc) = db_clone.upgrade() {
if let Ok(job) = db_arc.get_job(&job_id) {
job.conversation_inbox_name().to_string()
} else {
// Fall back to deriving the inbox name from the job id if the job can't be loaded
InboxName::get_job_inbox_name_from_params(job_id.clone())
.unwrap_or_else(|_| InboxName::new(job_id.clone()).unwrap())
.to_string()
}
} else {
// Fall back to deriving the inbox name from the job id if the database reference can't be upgraded
InboxName::get_job_inbox_name_from_params(job_id.clone())
.unwrap_or_else(|_| InboxName::new(job_id.clone()).unwrap())
.to_string()
};

let conversation_inbox_name_clone = conversation_inbox_name.clone();
let task_handle = tokio::spawn(async move {
let _ = (job_processing_fn)(
job,
db_clone,
@@ -277,11 +306,18 @@
)
.await;

let _ = queue_immediate.lock().await.dequeue(&job_id).await;
let _ = queue_immediate.lock().await.dequeue(&job_id_clone).await;
let mut inprog = in_progress.lock().await;
inprog.remove(&job_id);
inprog.remove(&job_id_clone);

// Remove task handle from the map when completed
tasks_map_for_cleanup.remove(&conversation_inbox_name_clone);

drop(permit);
});

// Store the task handle
tasks_map_for_insert.insert(conversation_inbox_name, task_handle);
}
// Done immediate; continue so we can re-check normal vs. immediate again
continue;
@@ -356,11 +392,32 @@
let queue_normal = queue_normal.clone();
let my_agent_payments_manager = my_agent_payments_manager.clone();
let ext_agent_payments_manager = ext_agent_payments_manager.clone();
let llm_stopper = llm_stopper.clone();
let llm_stopper = llm_stopper.clone();
let in_progress = processing_jobs.clone();

tokio::spawn(async move {
(job_processing_fn)(
let tasks_map_for_cleanup = job_processing_tasks.clone();
let tasks_map_for_insert = job_processing_tasks.clone();
let job_id_clone = job_id.clone();

// Get the conversation inbox name for this job
let conversation_inbox_name = if let Some(db_arc) = db_clone.upgrade() {
if let Ok(job) = db_arc.get_job(&job_id) {
job.conversation_inbox_name().to_string()
} else {
// Fall back to deriving the inbox name from the job id if the job can't be loaded
InboxName::get_job_inbox_name_from_params(job_id.clone())
.unwrap_or_else(|_| InboxName::new(job_id.clone()).unwrap())
.to_string()
}
} else {
// Fall back to deriving the inbox name from the job id if the database reference can't be upgraded
InboxName::get_job_inbox_name_from_params(job_id.clone())
.unwrap_or_else(|_| InboxName::new(job_id.clone()).unwrap())
.to_string()
};

let conversation_inbox_name_clone = conversation_inbox_name.clone();
let task_handle = tokio::spawn(async move {
let _ = (job_processing_fn)(
job,
db_clone,
node_profile_name,
@@ -376,12 +433,19 @@
)
.await;

let _ = queue_normal.lock().await.dequeue(&job_id).await;
let _ = queue_normal.lock().await.dequeue(&job_id_clone).await;
let mut inprog = in_progress.lock().await;
inprog.remove(&job_id);
inprog.remove(&job_id_clone);

// Remove task handle from the map when completed
tasks_map_for_cleanup.remove(&conversation_inbox_name_clone);

drop(permit);
});

// Store the task handle
tasks_map_for_insert.insert(conversation_inbox_name, task_handle);

// Done with this normal job
break;
}
@@ -408,10 +472,31 @@
let queue_immediate = queue_immediate.clone();
let my_agent_payments_manager = my_agent_payments_manager.clone();
let ext_agent_payments_manager = ext_agent_payments_manager.clone();
let llm_stopper = llm_stopper.clone();
let llm_stopper = llm_stopper.clone();
let in_progress = processing_jobs.clone();

tokio::spawn(async move {
let tasks_map_for_cleanup = job_processing_tasks.clone();
let tasks_map_for_insert = job_processing_tasks.clone();
let imm_id_clone = imm_id.clone();

// Get the conversation inbox name for this job
let conversation_inbox_name = if let Some(db_arc) = db_clone.upgrade() {
if let Ok(job) = db_arc.get_job(&imm_id) {
job.conversation_inbox_name().to_string()
} else {
// Fall back to deriving the inbox name from the job id if the job can't be loaded
InboxName::get_job_inbox_name_from_params(imm_id.clone())
.unwrap_or_else(|_| InboxName::new(imm_id.clone()).unwrap())
.to_string()
}
} else {
// Fall back to deriving the inbox name from the job id if the database reference can't be upgraded
InboxName::get_job_inbox_name_from_params(imm_id.clone())
.unwrap_or_else(|_| InboxName::new(imm_id.clone()).unwrap())
.to_string()
};

let conversation_inbox_name_clone = conversation_inbox_name.clone();
let task_handle = tokio::spawn(async move {
(job_processing_fn)(
imm_job,
db_clone,
@@ -428,11 +513,18 @@
)
.await;

let _ = queue_immediate.lock().await.dequeue(&imm_id).await;
let _ = queue_immediate.lock().await.dequeue(&imm_id_clone).await;
let mut inprog = in_progress.lock().await;
inprog.remove(&imm_id);
inprog.remove(&imm_id_clone);

// Remove task handle from the map when completed
tasks_map_for_cleanup.remove(&conversation_inbox_name_clone);

drop(permit);
});

// Store the task handle
tasks_map_for_insert.insert(conversation_inbox_name, task_handle);
} else {
eprintln!("rx_immediate closed, shutting down...");
return;
Expand Down Expand Up @@ -652,6 +744,71 @@ impl JobManager {

Ok(job_message.job_id.clone().to_string())
}

/// Kills a job by its conversation inbox name.
/// This aborts the running processing task, removes the job from both job queues, and deletes it from the database.
pub async fn kill_job_by_conversation_inbox_name(
&mut self,
conversation_inbox_name: &str,
) -> Result<String, LLMProviderError> {
// First, try to get the job_id from the conversation inbox name
let job_id = if let Ok(inbox_name) = InboxName::new(conversation_inbox_name.to_string()) {
if let Some(job_id) = inbox_name.get_job_id() {
job_id
} else {
return Err(LLMProviderError::SomeError("Not a job inbox".to_string()));
}
} else {
return Err(LLMProviderError::SomeError(
"Invalid conversation inbox name".to_string(),
));
};

// Remove and abort the task if it exists
if let Some((_, task_handle)) = self.job_processing_tasks.remove(conversation_inbox_name) {
task_handle.abort();
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Info,
&format!("Aborted task for conversation inbox: {}", conversation_inbox_name),
);
}

// Remove from the jobs map
self.jobs.lock().await.remove(&job_id);

// Remove from the database
if let Some(db_arc) = self.db.upgrade() {
// Remove job from database
if let Err(e) = db_arc.remove_job(&job_id) {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
&format!("Failed to delete job {} from database: {}", job_id, e),
);
return Err(LLMProviderError::ShinkaiDB(e));
}

// Remove from both job queues
let _ = self.job_queue_manager_normal.lock().await.dequeue(&job_id).await;
let _ = self.job_queue_manager_immediate.lock().await.dequeue(&job_id).await;

shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Info,
&format!(
"Successfully killed job with conversation inbox: {}",
conversation_inbox_name
),
);

Ok(job_id)
} else {
Err(LLMProviderError::DatabaseError(
"Failed to upgrade database reference".to_string(),
))
}
}
}

impl JobManagerTrait for JobManager {
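For context, here is a minimal, self-contained sketch of the task-tracking pattern this PR introduces: a DashMap keyed by conversation inbox name that holds the tokio JoinHandle for each job's processing task, with self-cleanup on normal completion and abort on kill. It assumes only the tokio and dashmap crates; the TaskTracker type, the sleep stand-in for job processing, and the inbox string are illustrative and are not code from this PR.

```rust
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;

use dashmap::DashMap;
use tokio::task::JoinHandle;

#[derive(Default, Clone)]
struct TaskTracker {
    // Maps a conversation inbox name to the handle of the task processing that job.
    tasks: Arc<DashMap<String, JoinHandle<()>>>,
}

impl TaskTracker {
    // Spawn work for an inbox and remember its handle so it can be aborted later.
    fn spawn<F>(&self, inbox: String, fut: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let tasks = self.tasks.clone();
        let key = inbox.clone();
        let handle = tokio::spawn(async move {
            fut.await;
            // Self-cleanup once the job finishes normally, mirroring the
            // tasks_map_for_cleanup.remove(...) calls in the diff above.
            tasks.remove(&key);
        });
        self.tasks.insert(inbox, handle);
    }

    // Abort the task for an inbox, mirroring kill_job_by_conversation_inbox_name.
    fn kill(&self, inbox: &str) -> bool {
        if let Some((_, handle)) = self.tasks.remove(inbox) {
            handle.abort();
            true
        } else {
            false
        }
    }
}

#[tokio::main]
async fn main() {
    let tracker = TaskTracker::default();

    // Illustrative inbox name only; in the PR the real value comes from
    // job.conversation_inbox_name() with a job_id-derived fallback.
    let inbox = "job_inbox::job_123::false".to_string();
    tracker.spawn(inbox.clone(), async {
        // Stand-in for the job_processing_fn future.
        tokio::time::sleep(Duration::from_secs(60)).await;
    });

    // Later, e.g. when the user cancels the conversation, abort the in-flight task.
    let aborted = tracker.kill(&inbox);
    println!("aborted: {aborted}");
}
```

Keying the map by conversation inbox name (rather than job id) matches how the new kill API is addressed, since callers identify the job by its inbox and the job id is only derived from it as a fallback.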