diff --git a/files/shinkai_welcome.md b/files/shinkai_welcome.md index a299ecefe..dcc730e88 100644 --- a/files/shinkai_welcome.md +++ b/files/shinkai_welcome.md @@ -16,9 +16,9 @@ At its core, an AI agent in Shinkai starts with a base AI model (like those from 2. **Give Instructions:** Write a "System Prompt" detailing how you want your agent to act, what knowledge it should focus on, or the persona it should adopt. 3. **Equip with Tools:** Grant your agent specific skills by enabling "Tools" (you can build your own with our specialized AI, download from the AI Store or manually create them). -┌──────────────────┐ ┌────────────────────────┐ ┌──────────────────┐ -│ 1. Pick a Model │→ │ 2. Prompt + Add Tools │→ │ 3. Launch Agent │ -└──────────────────┘ └────────────────────────┘ └──────────────────┘ +* ┌──────────────────┐ ┌────────────────────────┐ ┌──────────────────┐ +* │ 1. Pick a Model │→ │ 2. Prompt + Add Tools │→ │ 3. Launch Agent │ +* └──────────────────┘ └────────────────────────┘ └──────────────────┘ ## Tools: Supercharging Your Agents diff --git a/shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs b/shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs index 442e05aac..07072ee03 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs @@ -777,37 +777,20 @@ impl JobManager { // Remove from the jobs map self.jobs.lock().await.remove(&job_id); - // Remove from the database - if let Some(db_arc) = self.db.upgrade() { - // Remove job from database - if let Err(e) = db_arc.remove_job(&job_id) { - shinkai_log( - ShinkaiLogOption::JobExecution, - ShinkaiLogLevel::Error, - &format!("Failed to delete job {} from database: {}", job_id, e), - ); - return Err(LLMProviderError::ShinkaiDB(e)); - } - - // Remove from both job queues - let _ = self.job_queue_manager_normal.lock().await.dequeue(&job_id).await; - let _ = self.job_queue_manager_immediate.lock().await.dequeue(&job_id).await; - - shinkai_log( - ShinkaiLogOption::JobExecution, - ShinkaiLogLevel::Info, - &format!( - "Successfully killed job with conversation inbox: {}", - conversation_inbox_name - ), - ); - - Ok(job_id) - } else { - Err(LLMProviderError::DatabaseError( - "Failed to upgrade database reference".to_string(), - )) - } + // Remove from both job queues + let _ = self.job_queue_manager_normal.lock().await.dequeue(&job_id).await; + let _ = self.job_queue_manager_immediate.lock().await.dequeue(&job_id).await; + + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Info, + &format!( + "Successfully killed job with conversation inbox: {}", + conversation_inbox_name + ), + ); + + Ok(job_id) } } diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs index 97e5b60d4..3486a8466 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs @@ -18,7 +18,7 @@ use shinkai_message_primitives::schemas::job_config::JobConfig; use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{LLMProviderInterface, OpenAI}; use shinkai_message_primitives::schemas::prompts::Prompt; use shinkai_message_primitives::schemas::ws_types::{ - ToolMetadata, ToolStatus, ToolStatusType, WSMessageType, WSMetadata, WSUpdateHandler, WidgetMetadata, + ToolMetadata, ToolStatus, ToolStatusType, WSMessageType, WSMetadata, WSUpdateHandler, WidgetMetadata }; use 
shinkai_message_primitives::shinkai_message::shinkai_message_schemas::WSTopic; use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; @@ -151,21 +151,6 @@ impl LLMService for OpenAI { } } - if let Some(ref msg_id) = tracing_message_id { - let network_info = json!({ - "url": url, - "payload": payload_log - }); - if let Err(e) = db.add_tracing( - msg_id, - inbox_name.as_ref().map(|i| i.get_value()).as_deref(), - "llm_network_request", - &network_info, - ) { - eprintln!("failed to add network request trace: {:?}", e); - } - } - if is_stream { handle_streaming_response( client, diff --git a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs index 261901472..881467d65 100644 --- a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs +++ b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs @@ -17,11 +17,7 @@ impl Node { let listen_address_clone = self.listen_address; let libp2p_manager_clone = self.libp2p_manager.clone(); tokio::spawn(async move { - let _ = Self::ping_all( - listen_address_clone, - libp2p_manager_clone, - ) - .await; + let _ = Self::ping_all(listen_address_clone, libp2p_manager_clone).await; }); } NodeCommand::GetPublicKeys(sender) => { @@ -797,6 +793,28 @@ impl Node { let _ = Node::v2_remove_job(db_clone, bearer, job_id, res).await; }); } + NodeCommand::V2ApiKillJob { + bearer, + conversation_inbox_name, + res, + } => { + let db_clone = self.db.clone(); + let job_manager_clone = self.job_manager.clone().unwrap(); + let ws_manager_clone = self.ws_manager.clone(); + let llm_stopper_clone = self.llm_stopper.clone(); + tokio::spawn(async move { + let _ = Node::v2_api_kill_job( + db_clone, + job_manager_clone, + ws_manager_clone, + llm_stopper_clone, + bearer, + conversation_inbox_name, + res, + ) + .await; + }); + } NodeCommand::V2ApiVecFSRetrievePathSimplifiedJson { bearer, payload, res } => { let db_clone = Arc::clone(&self.db); @@ -1245,11 +1263,16 @@ impl Node { let _ = Node::v2_api_get_shinkai_tool_metadata(db_clone, bearer, tool_router_key, res).await; }); } - NodeCommand::V2ApiGetToolWithOffering { bearer, tool_key_name, res } => { + NodeCommand::V2ApiGetToolWithOffering { + bearer, + tool_key_name, + res, + } => { let db_clone = Arc::clone(&self.db); let node_name_clone = self.node_name.clone(); tokio::spawn(async move { - let _ = Node::v2_api_get_tool_with_offering(db_clone, node_name_clone, bearer, tool_key_name, res).await; + let _ = Node::v2_api_get_tool_with_offering(db_clone, node_name_clone, bearer, tool_key_name, res) + .await; }); } NodeCommand::V2ApiGetToolsWithOfferings { bearer, res } => { @@ -1669,7 +1692,11 @@ impl Node { let _ = Node::v2_api_get_job_scope(db_clone, bearer, job_id, res).await; }); } - NodeCommand::V2ApiGetMessageTraces { bearer, message_id, res } => { + NodeCommand::V2ApiGetMessageTraces { + bearer, + message_id, + res, + } => { let db_clone = Arc::clone(&self.db); tokio::spawn(async move { let _ = Node::v2_api_get_message_traces(db_clone, bearer, message_id, res).await; diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs index d44bc73c7..76cfa52e9 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs @@ -25,7 +25,7 @@ use tokio::sync::Mutex; use x25519_dalek::PublicKey as EncryptionPublicKey; use 
crate::{
-    llm_provider::job_manager::JobManager, managers::IdentityManager, network::{node_error::NodeError, Node}
+    llm_provider::{job_manager::JobManager, llm_stopper::LLMStopper}, managers::IdentityManager, network::{node_error::NodeError, ws_manager::WebSocketManager, Node}
 };
 use x25519_dalek::StaticSecret as EncryptionStaticKey;
@@ -1605,6 +1605,98 @@ impl Node {
         Ok(())
     }
 
+    pub async fn v2_api_kill_job(
+        db: Arc,
+        job_manager: Arc<Mutex<JobManager>>,
+        ws_manager: Option<Arc<Mutex<WebSocketManager>>>,
+        llm_stopper: Arc<LLMStopper>,
+        bearer: String,
+        conversation_inbox_name: String,
+        res: Sender<Result<SendResponseBody, APIError>>,
+    ) -> Result<(), NodeError> {
+        // Validate the bearer token
+        if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
+            return Ok(());
+        }
+
+        // Kill the job and capture necessary info
+        let (job_id, identity_sk, node_name) = {
+            let mut jm = job_manager.lock().await;
+            match jm.kill_job_by_conversation_inbox_name(&conversation_inbox_name).await {
+                Ok(job_id) => {
+                    let id_sk = clone_signature_secret_key(&jm.identity_secret_key);
+                    let node_name = jm.node_profile_name.clone();
+                    (job_id, id_sk, node_name)
+                }
+                Err(err) => {
+                    let api_error = APIError {
+                        code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
+                        error: "Internal Server Error".to_string(),
+                        message: err.to_string(),
+                    };
+                    let _ = res.send(Err(api_error)).await;
+                    return Ok(());
+                }
+            }
+        };
+
+        // Obtain partial assistant message from the WebSocket manager
+        let partial_text = if let Some(manager) = ws_manager.as_ref() {
+            manager
+                .lock()
+                .await
+                .get_fragment(&conversation_inbox_name)
+                .await
+                .unwrap_or_default()
+        } else {
+            String::new()
+        };
+
+        // Signal the LLM to stop processing
+        llm_stopper.stop(&conversation_inbox_name);
+
+        // Insert an assistant message with the partial text
+        let ai_message = ShinkaiMessageBuilder::job_message_from_llm_provider(
+            job_id.clone(),
+            partial_text,
+            Vec::new(),
+            None,
+            identity_sk,
+            node_name.node_name.clone(),
+            node_name.node_name.clone(),
+        )
+        .map_err(|_| NodeError {
+            message: "Failed to build message".to_string(),
+        })?;
+
+        if let Err(err) = db.add_message_to_job_inbox(&job_id, &ai_message, None, None).await {
+            let api_error = APIError {
+                code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
+                error: "Internal Server Error".to_string(),
+                message: format!("Failed to add message: {}", err),
+            };
+            let _ = res.send(Err(api_error)).await;
+            return Ok(());
+        }
+
+        if let Some(manager) = ws_manager {
+            manager.lock().await.clear_fragment(&conversation_inbox_name).await;
+        }
+
+        // Clear any stop signal set for this job
+        llm_stopper.reset(&conversation_inbox_name);
+
+        let _ = res
+            .send(Ok(SendResponseBody {
+                status: "success".to_string(),
+                message: "Job killed successfully".to_string(),
+                data: None,
+            }))
+            .await;
+
+        Ok(())
+    }
+
     pub async fn v2_export_messages_from_inbox(
         db: Arc,
         bearer: String,
@@ -1761,11 +1853,7 @@
 
         for messages in v2_chat_messages {
             for message in messages {
-                let role = if message
-                    .sender_subidentity
-                    .to_lowercase()
-                    .contains("/agent/")
-                {
+                let role = if message.sender_subidentity.to_lowercase().contains("/agent/") {
                     "assistant"
                 } else {
                     "user"
diff --git a/shinkai-bin/shinkai-node/src/network/ws_manager.rs b/shinkai-bin/shinkai-node/src/network/ws_manager.rs
index 9b3384c5a..ea15ca94b 100644
--- a/shinkai-bin/shinkai-node/src/network/ws_manager.rs
+++ b/shinkai-bin/shinkai-node/src/network/ws_manager.rs
@@ -42,6 +42,7 @@ pub struct WebSocketManager {
     identity_manager_trait: Arc>,
     encryption_secret_key: EncryptionStaticKey,
     message_queue: MessageQueue,
+    message_fragments: Arc<Mutex<HashMap<String, String>>>,
 }
 
 impl Clone for WebSocketManager {
@@ -55,6 +56,7 @@
             identity_manager_trait: Arc::clone(&self.identity_manager_trait),
             encryption_secret_key: self.encryption_secret_key.clone(),
             message_queue: Arc::clone(&self.message_queue),
+            message_fragments: Arc::clone(&self.message_fragments),
         }
     }
 }
@@ -87,6 +89,7 @@
             identity_manager_trait,
             encryption_secret_key,
             message_queue: Arc::new(Mutex::new(VecDeque::new())),
+            message_fragments: Arc::new(Mutex::new(HashMap::new())),
         }));
 
         let manager_clone = Arc::clone(&manager);
@@ -493,6 +496,16 @@
             }
         }
     }
+
+    pub async fn get_fragment(&self, inbox: &str) -> Option<String> {
+        let fragments = self.message_fragments.lock().await;
+        fragments.get(inbox).cloned()
+    }
+
+    pub async fn clear_fragment(&self, inbox: &str) {
+        let mut fragments = self.message_fragments.lock().await;
+        fragments.remove(inbox);
+    }
 }
 
 #[async_trait]
@@ -505,6 +518,18 @@ impl WSUpdateHandler for WebSocketManager {
     async fn queue_message(
         &self,
         topic: WSTopic,
         subtopic: String,
         update: String,
         metadata: WSMessageType,
         is_stream: bool,
     ) {
+        if is_stream && matches!(topic, WSTopic::Inbox) {
+            let mut fragments = self.message_fragments.lock().await;
+            let entry = fragments.entry(subtopic.clone()).or_default();
+            entry.push_str(&update);
+
+            if let WSMessageType::Metadata(meta) = &metadata {
+                if meta.is_done {
+                    fragments.remove(&subtopic);
+                }
+            }
+        }
+
         let mut queue = self.message_queue.lock().await;
         queue.push_back((topic, subtopic, update, metadata, is_stream));
     }
diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_jobs.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_jobs.rs
index 63df6cb01..419a2ee75 100644
--- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_jobs.rs
+++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_jobs.rs
@@ -6,21 +6,29 @@
 use serde::Deserialize;
 use serde_json::json;
 use shinkai_message_primitives::{
     schemas::{
-        job_config::JobConfig, llm_providers::serialized_llm_provider::{
-            Exo, Gemini, Groq, LLMProviderInterface, Ollama, OpenAI, SerializedLLMProvider, ShinkaiBackend
-        }, shinkai_name::{ShinkaiName, ShinkaiSubidentityType}, smart_inbox::{LLMProviderSubset, V2SmartInbox}
-    }, shinkai_message::{
-        shinkai_message::NodeApiData, shinkai_message_schemas::{
-            APIChangeJobAgentRequest, AssociatedUI, CallbackAction, ExportInboxMessagesFormat, JobCreationInfo, JobMessage, V2ChatMessage
-        }
-    }, shinkai_utils::job_scope::MinimalJobScope
+        job_config::JobConfig,
+        llm_providers::serialized_llm_provider::{
+            Exo, Gemini, Groq, LLMProviderInterface, Ollama, OpenAI, SerializedLLMProvider, ShinkaiBackend,
+        },
+        shinkai_name::{ShinkaiName, ShinkaiSubidentityType},
+        smart_inbox::{LLMProviderSubset, V2SmartInbox},
+    },
+    shinkai_message::{
+        shinkai_message::NodeApiData,
+        shinkai_message_schemas::{
+            APIChangeJobAgentRequest, AssociatedUI, CallbackAction, ExportInboxMessagesFormat, JobCreationInfo,
+            JobMessage, V2ChatMessage,
+        },
+    },
+    shinkai_utils::job_scope::MinimalJobScope,
 };
 use utoipa::{OpenApi, ToSchema};
 use warp::multipart::FormData;
 use warp::Filter;
 use crate::{
-    node_api_router::{APIError, SendResponseBody, SendResponseBodyData}, node_commands::NodeCommand
+    node_api_router::{APIError, SendResponseBody, SendResponseBodyData},
+    node_commands::NodeCommand,
 };
 
 use super::api_v2_router::{create_success_response, with_sender};
@@ -167,6 +175,13 @@ pub fn job_routes(
         .and(warp::body::json())
         .and_then(remove_job_handler);
 
+    let kill_job_route = warp::path("kill_job")
+        .and(warp::post())
+        .and(with_sender(node_commands_sender.clone()))
+        .and(warp::header::<String>("authorization"))
+        .and(warp::body::json())
+        .and_then(kill_job_handler);
+
     let export_messages_from_inbox_route = warp::path("export_messages_from_inbox")
         .and(warp::post())
         .and(with_sender(node_commands_sender.clone()))
@@ -208,6 +223,7 @@ pub fn job_routes(
         .or(get_message_traces_route)
         .or(fork_job_messages_route)
         .or(remove_job_route)
+        .or(kill_job_route)
         .or(export_messages_from_inbox_route)
         .or(add_messages_god_mode_route)
         .or(get_job_provider_route)
@@ -277,6 +293,11 @@ pub struct RemoveJobRequest {
     pub job_id: String,
 }
 
+#[derive(Deserialize, ToSchema)]
+pub struct KillJobRequest {
+    pub conversation_inbox_name: String,
+}
+
 #[derive(Deserialize, ToSchema)]
 pub struct ExportInboxMessagesRequest {
     pub inbox_name: String,
@@ -1241,6 +1262,45 @@ pub async fn remove_job_handler(
     }
 }
 
+#[utoipa::path(
+    post,
+    path = "/v2/kill_job",
+    request_body = KillJobRequest,
+    responses(
+        (status = 200, description = "Successfully killed job", body = SendResponseBody),
+        (status = 400, description = "Bad request", body = APIError),
+        (status = 500, description = "Internal server error", body = APIError)
+    )
+)]
+pub async fn kill_job_handler(
+    node_commands_sender: Sender<NodeCommand>,
+    authorization: String,
+    payload: KillJobRequest,
+) -> Result<impl warp::Reply, warp::Rejection> {
+    let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string();
+    let (res_sender, res_receiver) = async_channel::bounded(1);
+    node_commands_sender
+        .send(NodeCommand::V2ApiKillJob {
+            bearer,
+            conversation_inbox_name: payload.conversation_inbox_name,
+            res: res_sender,
+        })
+        .await
+        .map_err(|_| warp::reject::reject())?;
+    let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?;
+
+    match result {
+        Ok(response) => {
+            let response = create_success_response(response);
+            Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK))
+        }
+        Err(error) => Ok(warp::reply::with_status(
+            warp::reply::json(&error),
+            StatusCode::from_u16(error.code).unwrap(),
+        )),
+    }
+}
+
 #[utoipa::path(
     post,
     path = "/v2/export_messages_from_inbox",
@@ -1397,7 +1457,7 @@ pub async fn get_job_provider_handler(
        UpdateJobConfigRequest, UpdateSmartInboxNameRequest, SerializedLLMProvider, JobCreationInfo, JobMessage, NodeApiData,
        LLMProviderSubset, AssociatedUI, MinimalJobScope, CallbackAction, ShinkaiName, LLMProviderInterface, RetryMessageRequest,
        UpdateJobScopeRequest, ExportInboxMessagesFormat, ExportInboxMessagesRequest,
-       ShinkaiSubidentityType, OpenAI, Ollama, Groq, Gemini, Exo, ShinkaiBackend, SendResponseBody, SendResponseBodyData, APIError, GetToolingLogsRequest, GetMessageTracesRequest, ForkJobMessagesRequest, RemoveJobRequest)
+       ShinkaiSubidentityType, OpenAI, Ollama, Groq, Gemini, Exo, ShinkaiBackend, SendResponseBody, SendResponseBodyData, APIError, GetToolingLogsRequest, GetMessageTracesRequest, ForkJobMessagesRequest, RemoveJobRequest, KillJobRequest)
    ),
    tags(
        (name = "jobs", description = "Job API endpoints")
diff --git a/shinkai-libs/shinkai-http-api/src/node_commands.rs b/shinkai-libs/shinkai-http-api/src/node_commands.rs
index 2ef034238..86f480f69 100644
--- a/shinkai-libs/shinkai-http-api/src/node_commands.rs
+++ b/shinkai-libs/shinkai-http-api/src/node_commands.rs
@@ -283,6 +283,11 @@ pub enum NodeCommand {
         job_id: String,
         res: Sender>,
     },
+    V2ApiKillJob {
+        bearer: String,
+        conversation_inbox_name: String,
+        res: Sender<Result<SendResponseBody, APIError>>,
+    },
     V2ApiVecFSRetrievePathSimplifiedJson {
         bearer: String,
         payload: APIVecFsRetrievePathSimplifiedJson,
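
For reference, a minimal client-side sketch (Rust, using reqwest) of how the endpoint introduced in this patch might be called. The POST /v2/kill_job route, the Authorization: Bearer header, the KillJobRequest body field, and the success message come from kill_job_handler and v2_api_kill_job above; the host, port, bearer token, and example inbox name below are illustrative assumptions only.

    // Hypothetical call to the new kill_job endpoint; adjust host, port, and token for a real node.
    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        let client = reqwest::Client::new();
        let response = client
            .post("http://127.0.0.1:9550/v2/kill_job") // assumed local node address
            .bearer_auth("example-api-token") // stripped as "Bearer " in kill_job_handler
            .json(&json!({
                // KillJobRequest: the job is addressed by its conversation inbox name, not its job id
                "conversation_inbox_name": "job_inbox::example_job_id::false"
            }))
            .send()
            .await?;

        let status = response.status();
        let body = response.text().await?;
        // On success the node replies with a SendResponseBody-style JSON ("Job killed successfully").
        println!("{} {}", status, body);
        Ok(())
    }

The request is keyed by the conversation inbox name rather than the job id, mirroring the KillJobRequest schema and the kill_job_by_conversation_inbox_name call on JobManager.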