From d534b822c9bd502277d3867caf3539975820dca5 Mon Sep 17 00:00:00 2001 From: Eddie Date: Wed, 30 Apr 2025 18:05:26 -0400 Subject: [PATCH 1/9] Simulated Tool --- Cargo.lock | 28 +- .../generic_chain/generic_inference_chain.rs | 16 +- .../sheet_ui_inference_chain.rs | 4 +- .../shinkai-node/src/managers/galxe_quests.rs | 2 +- .../shinkai-node/src/managers/tool_router.rs | 63 ++- .../src/network/v1_api/api_v1_commands.rs | 54 +- .../network/v2_api/api_v2_commands_tools.rs | 79 ++- .../tool_definitions/definition_generation.rs | 2 +- .../tool_execution/execution_coordinator.rs | 25 +- .../native_tools/config_setup.rs | 12 +- shinkai-libs/shinkai-sqlite/src/errors.rs | 2 + .../src/shinkai_tool_manager.rs | 472 ++++++++++++++---- .../src/tools/deno_tools.rs | 2 +- .../shinkai-tools-primitives/src/tools/mod.rs | 1 + .../src/tools/shinkai_tool.rs | 40 +- .../src/tools/simulated_tool.rs | 245 +++++++++ 16 files changed, 843 insertions(+), 204 deletions(-) create mode 100644 shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs diff --git a/Cargo.lock b/Cargo.lock index a9f5a42a1..c17e00f44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6869,7 +6869,7 @@ checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" [[package]] name = "shinkai-spreadsheet-llm" -version = "0.9.20" +version = "0.9.21" dependencies = [ "async-trait", "chrono", @@ -6884,7 +6884,7 @@ dependencies = [ [[package]] name = "shinkai_crypto_identities" -version = "0.9.20" +version = "0.9.21" dependencies = [ "chrono", "dashmap", @@ -6900,7 +6900,7 @@ dependencies = [ [[package]] name = "shinkai_embedding" -version = "0.9.20" +version = "0.9.21" dependencies = [ "async-trait", "bincode", @@ -6921,7 +6921,7 @@ dependencies = [ [[package]] name = "shinkai_fs" -version = "0.9.20" +version = "0.9.21" dependencies = [ "async-trait", "bincode", @@ -6956,7 +6956,7 @@ dependencies = [ [[package]] name = "shinkai_http_api" -version = "0.9.20" +version = "0.9.21" dependencies = [ "anyhow", "async-channel 1.9.0", @@ -6993,7 +6993,7 @@ dependencies = [ [[package]] name = "shinkai_job_queue_manager" -version = "0.9.20" +version = "0.9.21" dependencies = [ "chrono", "serde", @@ -7007,7 +7007,7 @@ dependencies = [ [[package]] name = "shinkai_message_primitives" -version = "0.9.20" +version = "0.9.21" dependencies = [ "aes-gcm", "async-trait", @@ -7035,14 +7035,14 @@ dependencies = [ [[package]] name = "shinkai_mini_libs" -version = "0.9.20" +version = "0.9.21" dependencies = [ "base64 0.22.1", ] [[package]] name = "shinkai_node" -version = "0.9.20" +version = "0.9.21" dependencies = [ "aes-gcm", "anyhow", @@ -7116,7 +7116,7 @@ dependencies = [ [[package]] name = "shinkai_ocr" -version = "0.9.20" +version = "0.9.21" dependencies = [ "anyhow", "image 0.25.5", @@ -7131,7 +7131,7 @@ dependencies = [ [[package]] name = "shinkai_sheet" -version = "0.9.20" +version = "0.9.21" dependencies = [ "async-channel 1.9.0", "chrono", @@ -7145,7 +7145,7 @@ dependencies = [ [[package]] name = "shinkai_sqlite" -version = "0.9.20" +version = "0.9.21" dependencies = [ "bincode", "blake3", @@ -7177,7 +7177,7 @@ dependencies = [ [[package]] name = "shinkai_tcp_relayer" -version = "0.9.20" +version = "0.9.21" dependencies = [ "chrono", "clap 3.2.25", @@ -7196,7 +7196,7 @@ dependencies = [ [[package]] name = "shinkai_tools_primitives" -version = "0.9.20" +version = "0.9.21" dependencies = [ "anyhow", "base64 0.22.1", diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs index 4f40c25c3..1acd4ec0c 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs @@ -1,6 +1,6 @@ use crate::llm_provider::error::LLMProviderError; use crate::llm_provider::execution::chains::inference_chain_trait::{ - InferenceChain, InferenceChainContext, InferenceChainContextTrait, InferenceChainResult + InferenceChain, InferenceChainContext, InferenceChainContextTrait, InferenceChainResult, }; use crate::llm_provider::execution::prompts::general_prompts::JobPromptGenerator; use crate::llm_provider::execution::user_message_parser::ParsedUserMessage; @@ -24,7 +24,7 @@ use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provide use shinkai_message_primitives::schemas::shinkai_fs::ShinkaiFileChunkCollection; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::schemas::ws_types::{ - ToolMetadata, ToolStatus, ToolStatusType, WSMessageType, WSUpdateHandler, WidgetMetadata + ToolMetadata, ToolStatus, ToolStatusType, WSMessageType, WSUpdateHandler, WidgetMetadata, }; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::WSTopic; use shinkai_message_primitives::shinkai_utils::job_scope::MinimalJobScope; @@ -369,13 +369,17 @@ impl GenericInferenceChain { // If tool not found directly, try FTS and vector search let sanitized_query = tool_name.replace(|c: char| !c.is_alphanumeric() && c != ' ', " "); + // TODO [SIMULATED] + // Simulated tools should not be included in the FTS search. We have to detect if this is Agent Test Screen. // Perform FTS search - let fts_results = tool_router.sqlite_manager.search_tools_fts(&sanitized_query); + let fts_results = tool_router.sqlite_manager.search_tools_fts(&sanitized_query, true); // Perform vector search let vector_results = tool_router .sqlite_manager - .tool_vector_search(&sanitized_query, 5, false, true) + // TODO [SIMULATED] + // include_simulated should be false by default and only turned on if the user is the agent-test screen. + .tool_vector_search(&sanitized_query, 5, false, true, true) .await; match (fts_results, vector_results) { @@ -493,7 +497,9 @@ impl GenericInferenceChain { // to find the most relevant tools for the user's message if let Some(tool_router) = &tool_router { let results = tool_router - .combined_tool_search(&user_message.clone(), 7, false, true) + // TODO [SIMULATED] + // include_simulated should be false by default and only turned on if the user is the agent-test screen. + .combined_tool_search(&user_message.clone(), 7, false, true, true) .await; match results { diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs index 9b11d0d75..6ab534966 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs @@ -1,6 +1,6 @@ use crate::llm_provider::error::LLMProviderError; use crate::llm_provider::execution::chains::inference_chain_trait::{ - InferenceChain, InferenceChainContext, InferenceChainContextTrait, InferenceChainResult + InferenceChain, InferenceChainContext, InferenceChainContextTrait, InferenceChainResult, }; use crate::llm_provider::execution::chains::sheet_ui_chain::sheet_rust_functions::SheetRustFunctions; use crate::llm_provider::execution::prompts::general_prompts::JobPromptGenerator; @@ -240,7 +240,7 @@ impl SheetUIInferenceChain { // Search in JS Tools let results = tool_router - .vector_search_enabled_tools(&user_message.clone(), 2) + .vector_search_enabled_tools(&user_message.clone(), 2, false) .await .unwrap(); for result in results { diff --git a/shinkai-bin/shinkai-node/src/managers/galxe_quests.rs b/shinkai-bin/shinkai-node/src/managers/galxe_quests.rs index bcdd5fe13..e0dfc1cca 100644 --- a/shinkai-bin/shinkai-node/src/managers/galxe_quests.rs +++ b/shinkai-bin/shinkai-node/src/managers/galxe_quests.rs @@ -391,7 +391,7 @@ pub async fn compute_download_store_quest(db: Arc) -> Result Result, ToolError> { let tool_headers = self .sqlite_manager - .tool_vector_search(query, num_of_results, false, false) + .tool_vector_search(query, num_of_results, false, false, include_simulated) .await .map_err(|e| ToolError::DatabaseError(e.to_string()))?; // Note: we can add more code here to filter out low confidence results @@ -563,10 +564,11 @@ impl ToolRouter { &self, query: &str, num_of_results: u64, + include_simulated: bool, ) -> Result, ToolError> { let tool_headers = self .sqlite_manager - .tool_vector_search(query, num_of_results, false, true) + .tool_vector_search(query, num_of_results, false, true, include_simulated) .await .map_err(|e| ToolError::DatabaseError(e.to_string()))?; // Note: we can add more code here to filter out low confidence results @@ -578,10 +580,11 @@ impl ToolRouter { &self, query: &str, num_of_results: u64, + include_simulated: bool, ) -> Result, ToolError> { let tool_headers = self .sqlite_manager - .tool_vector_search(query, num_of_results, true, true) + .tool_vector_search(query, num_of_results, true, true, include_simulated) .await .map_err(|e| ToolError::DatabaseError(e.to_string()))?; // Note: we can add more code here to filter out low confidence results @@ -656,8 +659,37 @@ impl ToolRouter { } match shinkai_tool { + ShinkaiTool::Simulated(simulated_tool, _is_enabled) => { + let node_env: crate::utils::environment::NodeEnvironment = fetch_node_environment(); + let bearer = context.db().read_api_v2_key().unwrap_or_default().unwrap_or_default(); + let app_id = match context.full_job().associated_ui().as_ref() { + Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(), + _ => context.full_job().job_id().to_string(), + }; + + let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone(); + let result = simulated_tool + .run( + bearer, + node_env.api_listen_address.ip().to_string(), + node_env.api_listen_address.port(), + app_id, + tool_id, + // TODO This should resolve the LLM provider, and not the agent_id + context.agent().clone().get_llm_provider_id().to_string(), + function_args, + function_config_vec, + ) + .await?; + let result_str = serde_json::to_string(&result) + .map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?; + return Ok(ToolCallFunctionResponse { + response: result_str, + function_call, + }); + } ShinkaiTool::Python(python_tool, _is_enabled) => { - let node_env = fetch_node_environment(); + let node_env: crate::utils::environment::NodeEnvironment = fetch_node_environment(); let node_storage_path = node_env .node_storage_path .clone() @@ -673,7 +705,7 @@ impl ToolRouter { let tools: Vec = context .db() .clone() - .get_all_tool_headers()? + .get_all_tool_headers(false)? .into_iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { Ok(tool_router_key) => Some(tool_router_key), @@ -690,8 +722,8 @@ impl ToolRouter { let envs = generate_execution_environment( context.db(), context.agent().clone().get_id().to_string(), - tool_id.clone(), app_id.clone(), + tool_id.clone(), agent_id, shinkai_tool.tool_router_key().to_string_without_version().clone(), app_id.clone(), @@ -853,7 +885,7 @@ impl ToolRouter { let tools: Vec = context .db() .clone() - .get_all_tool_headers()? + .get_all_tool_headers(false)? .into_iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { Ok(tool_router_key) => Some(tool_router_key), @@ -1212,7 +1244,7 @@ impl ToolRouter { let tools: Vec = self .sqlite_manager .clone() - .get_all_tool_headers()? + .get_all_tool_headers(false)? .into_iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { Ok(tool_router_key) => Some(tool_router_key), @@ -1284,6 +1316,7 @@ impl ToolRouter { num_of_results: u64, include_disabled: bool, include_network: bool, + include_simulated: bool, ) -> Result, ToolError> { // Sanitize the query to handle special characters let sanitized_query = query.replace(|c: char| !c.is_alphanumeric() && c != ' ', " "); @@ -1292,14 +1325,22 @@ impl ToolRouter { let vector_start_time = Instant::now(); let vector_search_result = self .sqlite_manager - .tool_vector_search(&sanitized_query, num_of_results, include_disabled, include_network) + .tool_vector_search( + &sanitized_query, + num_of_results, + include_disabled, + include_network, + include_simulated, + ) .await; let vector_elapsed_time = vector_start_time.elapsed(); println!("Time taken for vector search: {:?}", vector_elapsed_time); // Start the timer for FTS search let fts_start_time = Instant::now(); - let fts_search_result = self.sqlite_manager.search_tools_fts(&sanitized_query); + let fts_search_result = self + .sqlite_manager + .search_tools_fts(&sanitized_query, include_simulated); let fts_elapsed_time = fts_start_time.elapsed(); println!("Time taken for FTS search: {:?}", fts_elapsed_time); diff --git a/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs b/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs index 4ad13747d..06161bda7 100644 --- a/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs @@ -1,42 +1,42 @@ use crate::managers::identity_manager::IdentityManagerTrait; use crate::managers::tool_router::ToolRouter; use crate::{ - llm_provider::job_manager::JobManager, managers::IdentityManager, network::{ - node::ProxyConnectionInfo, node_error::NodeError, node_shareable_logic::validate_message_main_logic, Node - }, utils::update_global_identity::update_global_identity_name + llm_provider::job_manager::JobManager, + managers::IdentityManager, + network::{ + node::ProxyConnectionInfo, node_error::NodeError, node_shareable_logic::validate_message_main_logic, Node, + }, + utils::update_global_identity::update_global_identity_name, }; -use aes_gcm::aead::{generic_array::GenericArray, Aead}; -use aes_gcm::Aes256Gcm; -use aes_gcm::KeyInit; -use async_channel::Sender; -use blake3::Hasher; -use ed25519_dalek::{SigningKey, VerifyingKey}; -use log::error; -use reqwest::StatusCode; -use serde_json::{json, Value as JsonValue}; - -use shinkai_embedding::embedding_generator::RemoteEmbeddingGenerator; -use shinkai_embedding::model_type::EmbeddingModelType; use shinkai_http_api::api_v1::api_v1_handlers::APIUseRegistrationCodeSuccessResponse; use shinkai_http_api::node_api_router::{APIError, SendResponseBodyData}; use shinkai_message_primitives::schemas::identity::{ - DeviceIdentity, Identity, IdentityType, RegistrationCode, StandardIdentity, StandardIdentityType + DeviceIdentity, Identity, IdentityType, RegistrationCode, StandardIdentity, StandardIdentityType, }; use shinkai_message_primitives::schemas::inbox_permission::InboxPermission; use shinkai_message_primitives::schemas::smart_inbox::SmartInbox; use shinkai_message_primitives::schemas::ws_types::WSUpdateHandler; use shinkai_message_primitives::{ schemas::{ - inbox_name::InboxName, llm_providers::serialized_llm_provider::SerializedLLMProvider, shinkai_name::{ShinkaiName, ShinkaiSubidentityType} - }, shinkai_message::{ - shinkai_message::{MessageBody, MessageData, ShinkaiMessage}, shinkai_message_schemas::{ - APIAddAgentRequest, APIAddOllamaModels, APIChangeJobAgentRequest, APIGetMessagesFromInboxRequest, APIReadUpToTimeRequest, IdentityPermissions, MessageSchemaType, RegistrationCodeRequest, RegistrationCodeType - } - }, shinkai_utils::{ + inbox_name::InboxName, + llm_providers::serialized_llm_provider::SerializedLLMProvider, + shinkai_name::{ShinkaiName, ShinkaiSubidentityType}, + }, + shinkai_message::{ + shinkai_message::{MessageBody, MessageData, ShinkaiMessage}, + shinkai_message_schemas::{ + APIAddAgentRequest, APIAddOllamaModels, APIChangeJobAgentRequest, APIGetMessagesFromInboxRequest, + APIReadUpToTimeRequest, IdentityPermissions, MessageSchemaType, RegistrationCodeRequest, + RegistrationCodeType, + }, + }, + shinkai_utils::{ encryption::{ - clone_static_secret_key, encryption_public_key_to_string, string_to_encryption_public_key, EncryptionMethod - }, shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}, signatures::{clone_signature_secret_key, signature_public_key_to_string, string_to_signature_public_key} - } + clone_static_secret_key, encryption_public_key_to_string, string_to_encryption_public_key, EncryptionMethod, + }, + shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}, + signatures::{clone_signature_secret_key, signature_public_key_to_string, string_to_signature_public_key}, + }, }; use shinkai_sqlite::errors::SqliteManagerError; use shinkai_sqlite::SqliteManager; @@ -2596,7 +2596,7 @@ impl Node { // Perform the internal search using tool_router if let Some(tool_router) = tool_router { - match tool_router.vector_search_all_tools(&search_query, 5).await { + match tool_router.vector_search_all_tools(&search_query, 5, false).await { Ok(tools) => { let tools_json = serde_json::to_value(tools).map_err(|err| NodeError { message: format!("Failed to serialize tools: {}", err), @@ -2669,7 +2669,7 @@ impl Node { // List all Shinkai tools let tools = { - match sqlite_manager.get_all_tool_headers() { + match sqlite_manager.get_all_tool_headers(false) { Ok(tools) => tools, Err(err) => { let api_error = APIError { diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs index da32f83ba..b9b1e4e17 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs @@ -1,7 +1,14 @@ use crate::{ - llm_provider::job_manager::JobManager, managers::{tool_router::ToolRouter, IdentityManager}, network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node}, tools::{ - tool_definitions::definition_generation::{generate_tool_definitions, get_all_tools}, tool_execution::execution_coordinator::{execute_code, execute_mcp_tool_cmd, execute_tool_cmd}, tool_generation::v2_create_and_send_job_message, tool_prompts::{generate_code_prompt, tool_metadata_implementation_prompt} - }, utils::environment::NodeEnvironment + llm_provider::job_manager::JobManager, + managers::{tool_router::ToolRouter, IdentityManager}, + network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node}, + tools::{ + tool_definitions::definition_generation::{generate_tool_definitions, get_all_tools}, + tool_execution::execution_coordinator::{execute_code, execute_mcp_tool_cmd, execute_tool_cmd}, + tool_generation::v2_create_and_send_job_message, + tool_prompts::{generate_code_prompt, tool_metadata_implementation_prompt}, + }, + utils::environment::NodeEnvironment, }; use async_channel::Sender; use chrono::Utc; @@ -11,26 +18,44 @@ use serde_json::{json, Map, Value}; use shinkai_http_api::node_api_router::{APIError, SendResponseBodyData}; use shinkai_message_primitives::{ schemas::{ - inbox_name::InboxName, indexable_version::IndexableVersion, job::JobLike, job_config::JobConfig, shinkai_name::ShinkaiSubidentityType, tool_router_key::ToolRouterKey - }, shinkai_message::shinkai_message_schemas::{CallbackAction, JobCreationInfo, MessageSchemaType}, shinkai_utils::{ - job_scope::MinimalJobScope, shinkai_message_builder::ShinkaiMessageBuilder, signatures::clone_signature_secret_key - } + inbox_name::InboxName, indexable_version::IndexableVersion, job::JobLike, job_config::JobConfig, + shinkai_name::ShinkaiSubidentityType, tool_router_key::ToolRouterKey, + }, + shinkai_message::shinkai_message_schemas::{CallbackAction, JobCreationInfo, MessageSchemaType}, + shinkai_utils::{ + job_scope::MinimalJobScope, shinkai_message_builder::ShinkaiMessageBuilder, + signatures::clone_signature_secret_key, + }, }; use shinkai_message_primitives::{ schemas::{ - shinkai_name::ShinkaiName, shinkai_tools::{CodeLanguage, DynamicToolType} - }, shinkai_message::shinkai_message_schemas::JobMessage + shinkai_name::ShinkaiName, + shinkai_tools::{CodeLanguage, DynamicToolType}, + }, + shinkai_message::shinkai_message_schemas::JobMessage, }; use shinkai_sqlite::{errors::SqliteManagerError, SqliteManager}; use shinkai_tools_primitives::tools::{ - deno_tools::DenoTool, error::ToolError, parameters::Parameters, python_tools::PythonTool, shinkai_tool::{ShinkaiTool, ShinkaiToolWithAssets}, tool_config::{OAuth, ToolConfig}, tool_output_arg::ToolOutputArg, tool_playground::{ToolPlayground, ToolPlaygroundMetadata} + deno_tools::DenoTool, + error::ToolError, + parameters::Parameters, + python_tools::PythonTool, + shinkai_tool::{ShinkaiTool, ShinkaiToolWithAssets}, + tool_config::{OAuth, ToolConfig}, + tool_output_arg::ToolOutputArg, + tool_playground::{ToolPlayground, ToolPlaygroundMetadata}, }; use shinkai_tools_primitives::tools::{ - shinkai_tool::ShinkaiToolHeader, tool_types::{OperatingSystem, RunnerType, ToolResult} + shinkai_tool::ShinkaiToolHeader, + tool_types::{OperatingSystem, RunnerType, ToolResult}, }; use std::{collections::HashMap, path::PathBuf}; use std::{ - env, fs::File, io::{Read, Write}, sync::Arc, time::Instant + env, + fs::File, + io::{Read, Write}, + sync::Arc, + time::Instant, }; use tokio::fs; use tokio::{process::Command, sync::Mutex}; @@ -148,9 +173,9 @@ impl Node { .iter() .map(|tool| tool.to_string_without_version()) .collect::>(); - db.tool_vector_search_with_vector_limited(embedding, 5, tool_names) + db.tool_vector_search_with_vector_limited(embedding, 5, tool_names, false) } else { - db.tool_vector_search(&sanitized_query, 5, false, true).await + db.tool_vector_search(&sanitized_query, 5, false, true, false).await }; let vector_elapsed_time = vector_start_time.elapsed(); @@ -158,7 +183,7 @@ impl Node { // Start the timer for FTS search let fts_start_time = Instant::now(); - let fts_search_result = db.search_tools_fts(&sanitized_query); + let fts_search_result = db.search_tools_fts(&sanitized_query, false); let fts_elapsed_time = fts_start_time.elapsed(); println!("Time taken for FTS search: {:?}", fts_elapsed_time); @@ -238,7 +263,7 @@ impl Node { } // List all tools - match db.get_all_tool_headers() { + match db.get_all_tool_headers(false) { Ok(tools) => { // Group tools by their base key (without version) use std::collections::HashMap; @@ -371,7 +396,7 @@ impl Node { let _bearer = Self::get_bearer_token(db.clone(), &res).await?; // List all tools - match db.get_all_tool_headers() { + match db.get_all_tool_headers(false) { Ok(tools) => { // Group tools by their base key (without version) use std::collections::HashMap; @@ -1483,7 +1508,7 @@ impl Node { let all_tools: Vec = db .clone() - .get_all_tool_headers()? + .get_all_tool_headers(false)? .into_iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { Ok(tool_router_key) => Some(tool_router_key), @@ -1614,7 +1639,7 @@ impl Node { .map(|t| t.to_string()) .collect(); let user_tools: Vec = tools.iter().map(|tools| tools.to_string_with_version()).collect(); - let all_tool_headers = db.clone().get_all_tool_headers()?; + let all_tool_headers = db.clone().get_all_tool_headers(false)?; all_tool_headers .into_iter() .map(|tool| ToolRouterKey::from_string(&tool.tool_router_key)) @@ -2850,7 +2875,7 @@ impl Node { } // Get all tools - match db.get_all_tool_headers() { + match db.get_all_tool_headers(false) { Ok(tools) => { let mut tool_statuses: Vec<(String, bool)> = Vec::new(); @@ -2910,7 +2935,7 @@ impl Node { } // Get all tools - match db.get_all_tool_headers() { + match db.get_all_tool_headers(false) { Ok(tools) => { let mut tool_statuses: Vec<(String, bool)> = Vec::new(); @@ -3113,7 +3138,7 @@ impl Node { } Err(_) => { // Create a new playground from the tool data - let output = new_tool.output_arg(); + let output = ToolOutputArg::empty(); let output_json = output.json; // Attempt to parse the output_json into a meaningful result let result: ToolResult = if !output_json.is_empty() { @@ -3609,7 +3634,7 @@ impl Node { // Get all tool-key-paths let tool_list: Vec = db - .get_all_tool_headers() + .get_all_tool_headers(false) .map_err(|e| APIError { code: 500, error: "Failed to get tool headers".to_string(), @@ -3826,7 +3851,7 @@ LANGUAGE={env_language} } // List all tools - match db.get_all_tool_headers() { + match db.get_all_tool_headers(false) { Ok(tools) => { // Group tools by their base key (without version) use std::collections::HashMap; @@ -3915,7 +3940,7 @@ LANGUAGE={env_language} tool.disable(); tool.disable_mcp(); } - + if let Err(e) = db.update_tool(tool).await { let err = APIError { code: 500, @@ -3925,7 +3950,7 @@ LANGUAGE={env_language} let _ = res.send(Err(err)).await; return Ok(()); } - + let response = json!({ "tool_router_key": tool_router_key, "enabled": enabled, @@ -4132,7 +4157,7 @@ LANGUAGE={env_language} } let tools: Vec = db - .get_all_tool_headers() + .get_all_tool_headers(false) .map_err(|_| ToolError::ExecutionError("Failed to get tool headers".to_string()))? .iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { diff --git a/shinkai-bin/shinkai-node/src/tools/tool_definitions/definition_generation.rs b/shinkai-bin/shinkai-node/src/tools/tool_definitions/definition_generation.rs index 7bbbb00e6..fa41d49b9 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_definitions/definition_generation.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_definitions/definition_generation.rs @@ -30,7 +30,7 @@ pub fn get_rust_tools() -> Vec { } pub async fn get_all_tools(sqlite_manager: Arc) -> Vec { - let mut all_tools = match sqlite_manager.get_all_tool_headers() { + let mut all_tools = match sqlite_manager.get_all_tool_headers(false) { Ok(data) => data, Err(_) => Vec::new(), }; diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs index 88efa2122..391856d9c 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs @@ -325,6 +325,23 @@ pub async fn execute_tool_cmd( } match tool { + ShinkaiTool::Simulated(simulated_tool, _) => { + let node_env = fetch_node_environment(); + + return simulated_tool + .run( + bearer, + node_env.api_listen_address.ip().to_string(), + node_env.api_listen_address.port(), + app_id, + tool_id, + llm_provider, + parameters, + extra_config, + ) + .await + .map(|result| json!(result.data)); + } ShinkaiTool::Rust(_, _) => { try_to_execute_rust_tool( &tool_router_key, @@ -394,7 +411,7 @@ pub async fn execute_tool_cmd( .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; let tools: Vec = db .clone() - .get_all_tool_headers() + .get_all_tool_headers(false) .map_err(|_| ToolError::ExecutionError("Failed to get tool headers".to_string()))? .into_iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { @@ -452,7 +469,7 @@ pub async fn execute_tool_cmd( .ok_or_else(|| ToolError::ExecutionError("Node storage path is not set".to_string()))?; let tools: Vec = db .clone() - .get_all_tool_headers() + .get_all_tool_headers(false) .map_err(|_| ToolError::ExecutionError("Failed to get tool headers".to_string()))? .into_iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { @@ -604,7 +621,7 @@ pub async fn execute_code( // Route based on the prefix let tools: Vec = db .clone() - .get_all_tool_headers() + .get_all_tool_headers(false) .map_err(|_| ToolError::ExecutionError("Failed to get tool headers".to_string()))? .into_iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { @@ -698,7 +715,7 @@ pub async fn check_code( eprintln!("[check_code] code_extracted: {}", code_extracted); let tools: Vec = sqlite_manager .clone() - .get_all_tool_headers() + .get_all_tool_headers(false) .map_err(|_| ToolError::ExecutionError("Failed to get tool headers".to_string()))? .into_iter() .filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) { diff --git a/shinkai-bin/shinkai-node/src/tools/tool_implementation/native_tools/config_setup.rs b/shinkai-bin/shinkai-node/src/tools/tool_implementation/native_tools/config_setup.rs index 025d3af2e..d788ba106 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_implementation/native_tools/config_setup.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_implementation/native_tools/config_setup.rs @@ -145,7 +145,7 @@ async fn select_tool_router_key_from_intent( llm_provider: String, ) -> Result { let all_tool_headers = db - .get_all_tool_headers() + .get_all_tool_headers(false) .map_err(|e| ToolError::ExecutionError(format!("Failed to get all tool headers: {}", e)))?; let mut list_of_tools: String = "".to_string(); @@ -418,12 +418,18 @@ mod tests { assert_eq!(config1.description, "API Key"); assert!(config1.required); assert_eq!(config1.type_name, Some("string".to_string())); - assert_eq!(config1.key_value, Some(serde_json::Value::String("new_key".to_string()))); + assert_eq!( + config1.key_value, + Some(serde_json::Value::String("new_key".to_string())) + ); assert_eq!(config2.key_name, "secret"); assert_eq!(config2.description, "Secret"); assert!(!config2.required); assert_eq!(config2.type_name, Some("string".to_string())); - assert_eq!(config2.key_value, Some(serde_json::Value::String("old_secret".to_string()))); + assert_eq!( + config2.key_value, + Some(serde_json::Value::String("old_secret".to_string())) + ); } _ => panic!("Expected Deno tool"), } diff --git a/shinkai-libs/shinkai-sqlite/src/errors.rs b/shinkai-libs/shinkai-sqlite/src/errors.rs index bbf88cf5e..99056a43b 100644 --- a/shinkai-libs/shinkai-sqlite/src/errors.rs +++ b/shinkai-libs/shinkai-sqlite/src/errors.rs @@ -90,6 +90,8 @@ pub enum SqliteManagerError { ValidationError(String), #[error("Tool type mismatch")] ToolTypeMismatch, + #[error("Tool is simulated: {0}")] + ToolSimulated(String), // Add other error variants as needed } diff --git a/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs b/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs index 009858d5f..20e51850b 100644 --- a/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs +++ b/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs @@ -259,6 +259,7 @@ impl SqliteManager { num_results: u64, include_disabled: bool, include_network: bool, + include_simulated: bool, ) -> Result, SqliteManagerError> { // TODO: implement an LRU cache for the vector search // so we are not searching the database every time @@ -272,42 +273,21 @@ impl SqliteManager { // Perform the vector search to get tool_keys and distances let conn = self.get_connection()?; - let query = match (include_disabled, include_network) { - (true, true) => { - "SELECT v.tool_key, v.distance + let mut query = "SELECT v.tool_key, v.distance FROM shinkai_tools_vec_items v - WHERE v.embedding MATCH json(?1) - ORDER BY distance - LIMIT ?2" - } - (true, false) => { - "SELECT v.tool_key, v.distance - FROM shinkai_tools_vec_items v - WHERE v.embedding MATCH json(?1) - AND v.is_network = 0 - ORDER BY distance - LIMIT ?2" - } - (false, true) => { - "SELECT v.tool_key, v.distance - FROM shinkai_tools_vec_items v - WHERE v.embedding MATCH json(?1) - AND v.is_enabled = 1 - ORDER BY distance - LIMIT ?2" - } - (false, false) => { - "SELECT v.tool_key, v.distance - FROM shinkai_tools_vec_items v - WHERE v.embedding MATCH json(?1) - AND v.is_enabled = 1 - AND v.is_network = 0 - ORDER BY distance - LIMIT ?2" - } - }; + WHERE v.embedding MATCH json(?1)" + .to_string(); - let mut stmt = conn.prepare(query)?; + if !include_disabled { + query = format!("{} AND v.is_enabled = 1", query); + } + if !include_network { + query = format!("{} AND v.is_network = 0", query); + } + + query = format!("{} ORDER BY distance LIMIT ?2", query); + + let mut stmt = conn.prepare(&query)?; // Retrieve tool_keys and distances let tool_keys_and_distances: Vec<(String, f64)> = stmt @@ -317,7 +297,7 @@ impl SqliteManager { // Retrieve the corresponding ShinkaiToolHeaders and pair with distances let mut tools_with_distances = Vec::new(); for (tool_key, distance) in tool_keys_and_distances { - if let Ok(tool_header) = self.get_tool_header_by_key(&tool_key) { + if let Ok(tool_header) = self.get_tool_header_by_key(&tool_key, include_simulated) { tools_with_distances.push((tool_header, distance)); } } @@ -332,6 +312,7 @@ impl SqliteManager { num_results: u64, include_disabled: bool, include_network: bool, + include_simulated: bool, ) -> Result, SqliteManagerError> { if query.is_empty() { return Ok(Vec::new()); @@ -344,17 +325,32 @@ impl SqliteManager { })?; // Use the new function to perform the search - self.tool_vector_search_with_vector(embedding, num_results, include_disabled, include_network) + self.tool_vector_search_with_vector( + embedding, + num_results, + include_disabled, + include_network, + include_simulated, + ) } /// Retrieves a ShinkaiToolHeader based on its tool_key - pub fn get_tool_header_by_key(&self, tool_key: &str) -> Result { + pub fn get_tool_header_by_key( + &self, + tool_key: &str, + include_simulated: bool, + ) -> Result { let conn = self.get_connection()?; - let mut stmt = - conn.prepare("SELECT tool_header FROM shinkai_tools WHERE tool_key = ?1 ORDER BY version DESC LIMIT 1")?; + let mut stmt = conn.prepare( + "SELECT tool_header, tool_type FROM shinkai_tools WHERE tool_key = ?1 ORDER BY version DESC LIMIT 1", + )?; - let tool_header_data: Vec = stmt - .query_row(params![tool_key.to_lowercase()], |row| row.get(0)) + let tool_header_data: (Vec, String) = stmt + .query_row(params![tool_key.to_lowercase()], |row| { + let tool_header_data: Vec = row.get(0)?; + let tool_type = row.get(1)?; + Ok((tool_header_data, tool_type)) + }) .map_err(|e| { if e == rusqlite::Error::QueryReturnedNoRows { eprintln!("Tool not found with key: {}", tool_key); @@ -365,7 +361,11 @@ impl SqliteManager { } })?; - let tool_header: ShinkaiToolHeader = serde_json::from_slice(&tool_header_data).map_err(|e| { + if !include_simulated && tool_header_data.1 == "Simulated" { + return Err(SqliteManagerError::ToolSimulated(tool_key.to_string())); + } + + let tool_header: ShinkaiToolHeader = serde_json::from_slice(&tool_header_data.0).map_err(|e| { eprintln!("Deserialization error: {}", e); SqliteManagerError::SerializationError(e.to_string()) })?; @@ -510,9 +510,13 @@ impl SqliteManager { } /// Retrieves all ShinkaiToolHeader entries from the shinkai_tools table - pub fn get_all_tool_headers(&self) -> Result, SqliteManagerError> { + pub fn get_all_tool_headers(&self, include_simulated: bool) -> Result, SqliteManagerError> { let conn = self.get_connection()?; - let mut stmt = conn.prepare("SELECT tool_header FROM shinkai_tools")?; + let mut query = "SELECT tool_header FROM shinkai_tools".to_string(); + if !include_simulated { + query = format!("{} WHERE tool_type != 'Simulated'", query); + } + let mut stmt = conn.prepare(&query)?; let header_iter = stmt.query_map([], |row| { let tool_header_data: Vec = row.get(0)?; @@ -714,7 +718,11 @@ impl SqliteManager { } // Search the FTS table - pub fn search_tools_fts(&self, query: &str) -> Result, SqliteManagerError> { + pub fn search_tools_fts( + &self, + query: &str, + include_simulated: bool, + ) -> Result, SqliteManagerError> { // Get a connection from the in-memory pool for FTS operations let fts_conn = self .fts_pool @@ -748,23 +756,41 @@ impl SqliteManager { SqliteManagerError::DatabaseError(e) })?; - // Only fetch tool header if we haven't seen this one already - if seen.insert(name.clone()) { - let mut stmt = - conn.prepare("SELECT tool_header FROM shinkai_tools WHERE name = ?1 ORDER BY version DESC")?; - let tool_header_data: Vec = - stmt.query_row(rusqlite::params![name], |row| row.get(0)).map_err(|e| { - eprintln!("Persistent DB query error: {}", e); - SqliteManagerError::DatabaseError(e) - })?; - - let tool_header: ShinkaiToolHeader = serde_json::from_slice(&tool_header_data).map_err(|e| { + // Check if the tool is in the list. + if seen.contains(&name) { + continue; + } + + let mut query = "SELECT tool_header FROM shinkai_tools WHERE name = ?1".to_string(); + if !include_simulated { + query = format!("{} AND tool_type != 'Simulated'", query); + } + query = format!("{} ORDER BY version DESC", query); + + let mut stmt = conn.prepare(&query)?; + let tool_header_data: Result, rusqlite::Error> = + stmt.query_row(rusqlite::params![name], |row| row.get(0)); + + // If tool is not found, it was a simulated tool. + if tool_header_data.is_err() { + let err = tool_header_data.err().unwrap(); + if err == rusqlite::Error::QueryReturnedNoRows { + eprintln!("Tool not found: {}", name); + continue; + } else { + eprintln!("Persistent DB query error: {}", err); + return Err(SqliteManagerError::DatabaseError(err)); + } + } + + let tool_header: ShinkaiToolHeader = + serde_json::from_slice(&tool_header_data.unwrap()).map_err(|e| { eprintln!("Deserialization error: {}", e); SqliteManagerError::SerializationError(e.to_string()) })?; - tool_headers.push(tool_header); - } + seen.insert(name.clone()); + tool_headers.push(tool_header); } } @@ -835,6 +861,7 @@ impl SqliteManager { vector: Vec, num_results: u64, tool_keys: Vec, + include_simulated: bool, ) -> Result, SqliteManagerError> { // Serialize the vector to a JSON array string for the database query let vector_json = serde_json::to_string(&vector).map_err(|e| { @@ -849,11 +876,12 @@ impl SqliteManager { let mut current_limit = num_results * 2; // Adjust this multiplier as needed // SQL query to perform the vector search - let query = "SELECT v.tool_key, v.distance + let mut query = "SELECT v.tool_key, v.distance FROM shinkai_tools_vec_items v - WHERE v.embedding MATCH json(?1) - ORDER BY v.distance - LIMIT ?2"; + WHERE v.embedding MATCH json(?1)" + .to_string(); + + query = format!("{} ORDER BY v.distance LIMIT ?2", query); let mut tools_with_distances = Vec::new(); @@ -870,7 +898,7 @@ impl SqliteManager { // Filter results based on the provided tool keys for (tool_key, distance) in &tool_keys_and_distances { if tool_keys.contains(tool_key) { - if let Ok(tool_header) = self.get_tool_header_by_key(tool_key) { + if let Ok(tool_header) = self.get_tool_header_by_key(tool_key, include_simulated) { tools_with_distances.push((tool_header, *distance)); } } @@ -945,7 +973,9 @@ mod tests { use shinkai_tools_primitives::tools::deno_tools::DenoTool; use shinkai_tools_primitives::tools::network_tool::NetworkTool; use shinkai_tools_primitives::tools::parameters::Parameters; + use shinkai_tools_primitives::tools::parameters::Property; use shinkai_tools_primitives::tools::python_tools::PythonTool; + use shinkai_tools_primitives::tools::simulated_tool::SimulatedTool; use shinkai_tools_primitives::tools::tool_config::BasicConfig; use shinkai_tools_primitives::tools::tool_config::ToolConfig; use shinkai_tools_primitives::tools::tool_output_arg::ToolOutputArg; @@ -965,6 +995,131 @@ mod tests { SqliteManager::new(db_path, api_url, model_type).unwrap() } + #[tokio::test] + async fn add_simulated_random_number_generator_tool() { + let manager = setup_test_db().await; + let simulated_tool_example = SimulatedTool { + name: "Random Number Generator".to_string(), + description: "returns a random number, with a optional seed".to_string(), + keywords: vec!["random".to_string(), "number".to_string()], + config: vec![], + input_args: Parameters { + schema_type: "object".to_string(), + properties: { + let mut props = std::collections::HashMap::new(); + props.insert( + "seed".to_string(), + Property::new("number".to_string(), "seed for the random number".to_string()), + ); + props + }, + required: vec![], + }, + result: ToolResult { + r#type: "object".to_string(), + properties: serde_json::json!({ + "random_number": { + "type": "number", + "description": "The random number" + } + }), + required: vec![], + }, + embedding: None, + }; + + let vector = SqliteManager::generate_vector_for_testing(0.1); + let _ = manager.add_tool_with_vector(ShinkaiTool::Simulated(simulated_tool_example.clone(), true), vector); + let t = manager.get_tool_by_key(&simulated_tool_example.get_tool_router_key()); + assert!(t.is_ok()); + let tool = t.unwrap(); + assert_eq!(tool.name(), simulated_tool_example.clone().name); + } + + #[tokio::test] + async fn add_simulated_crypto_price_tool() { + let manager = setup_test_db().await; + let simulated_tool_example = SimulatedTool { + name: "Historical Crypto Prices".to_string(), + description: "returns the price of crypto over a given time period".to_string(), + keywords: vec!["crypto".to_string(), "price".to_string(), "historical".to_string()], + config: vec![], + input_args: Parameters { + schema_type: "object".to_string(), + properties: { + let mut props = std::collections::HashMap::new(); + props.insert( + "crypto_symbols".to_string(), + Property::with_array_items( + "string".to_string(), + Property::new( + "string".to_string(), + "the crypto symbols to get the price of".to_string(), + ), + ), + ); + props.insert( + "start_date".to_string(), + Property::new( + "string".to_string(), + "the start date of the price. Default 10 days ago.".to_string(), + ), + ); + props.insert( + "end_date".to_string(), + Property::new( + "string".to_string(), + "the end date of the price. Default is now.".to_string(), + ), + ); + props.insert( + "interval".to_string(), + Property::new( + "string".to_string(), + "the interval of the price. e.g., 5m, 15m, 1h, 4h, 12h, 1D, 1W, 1M, 1Y. Default is 1D." + .to_string(), + ), + ); + props + }, + required: vec!["crypto_symbols".to_string()], + }, + result: ToolResult { + r#type: "object".to_string(), + properties: serde_json::json!({ + "prices": { + "type": "array", + "description": "The price of the stock", + "items": { + "type": "object", + "properties": { + "date": { + "type": "string", + "description": "The date of the price" + }, + "price": { + "type": "number", + "description": "The price of the stock" + } + } + } + } + }), + required: vec![], + }, + embedding: None, + }; + let vector = SqliteManager::generate_vector_for_testing(0.1); + + let _ = manager.add_tool_with_vector(ShinkaiTool::Simulated(simulated_tool_example.clone(), true), vector); + + let t = manager.get_tool_by_key(&simulated_tool_example.get_tool_router_key()); + + assert!(t.is_ok()); + let tool = t.unwrap(); + assert_eq!(tool.name(), simulated_tool_example.clone().name); + } + #[tokio::test] async fn test_add_deno_tool() { let manager = setup_test_db().await; @@ -1124,9 +1279,21 @@ mod tests { tool_set: None, }; + // Create a SimulatedTool instance + let simulated_tool = SimulatedTool { + name: "Simulated Test Tool".to_string(), + description: "A simulated tool for testing".to_string(), + keywords: vec!["simulated".to_string(), "test".to_string()], + config: vec![], + input_args: Parameters::new(), + result: ToolResult::new("object".to_string(), serde_json::Value::Null, vec![]), + embedding: None, + }; + let shinkai_tool_1 = ShinkaiTool::Deno(deno_tool_1, true); let shinkai_tool_2 = ShinkaiTool::Deno(deno_tool_2, true); let shinkai_tool_3 = ShinkaiTool::Deno(deno_tool_3, true); + let shinkai_simulated = ShinkaiTool::Simulated(simulated_tool, true); // Add the tools to the database with different vectors manager @@ -1138,22 +1305,42 @@ mod tests { manager .add_tool_with_vector(shinkai_tool_3.clone(), SqliteManager::generate_vector_for_testing(0.9)) .unwrap(); + manager + .add_tool_with_vector( + shinkai_simulated.clone(), + SqliteManager::generate_vector_for_testing(0.3), + ) + .unwrap(); // Generate an embedding vector for the query that is close to the first tool let embedding_query = SqliteManager::generate_vector_for_testing(0.09); - // Perform a vector search using the generated embedding + // Perform a vector search using the generated embedding, excluding simulated tools let num_results = 1; let search_results: Vec = manager - .tool_vector_search_with_vector(embedding_query, num_results, true, true) + .tool_vector_search_with_vector(embedding_query.clone(), num_results, true, true, false) .unwrap() .iter() .map(|(tool, _distance)| tool.clone()) .collect(); - // Assert that the search results contain the first tool + // Assert that the search results contain the first tool and not the simulated tool assert_eq!(search_results.len(), 1); assert_eq!(search_results[0].name, "Deno Test Tool 1"); + assert!(!search_results.iter().any(|t| t.name == "Simulated Test Tool")); + + // Now perform a search including simulated tools + let search_results_with_simulated: Vec = manager + .tool_vector_search_with_vector(embedding_query, 10, true, true, true) + .unwrap() + .iter() + .map(|(tool, _distance)| tool.clone()) + .collect(); + + // Assert that the simulated tool is included in the results + assert!(search_results_with_simulated + .iter() + .any(|t| t.name == "Simulated Test Tool")); } #[tokio::test] @@ -1256,7 +1443,7 @@ mod tests { .unwrap(); // Print out the name and key for each tool in the database - let all_tools = manager.get_all_tool_headers().unwrap(); + let all_tools = manager.get_all_tool_headers(false).unwrap(); for tool in &all_tools { eprintln!("Tool name: {}, Tool key: {}", tool.name, tool.tool_router_key); } @@ -1433,6 +1620,17 @@ mod tests { }, ]; + // Add a simulated tool + let simulated_tool = SimulatedTool { + name: "Simulated Analysis Tool".to_string(), + description: "A simulated tool for testing".to_string(), + keywords: vec!["simulated".to_string(), "analysis".to_string()], + config: vec![], + input_args: Parameters::new(), + result: ToolResult::new("object".to_string(), serde_json::Value::Null, vec![]), + embedding: None, + }; + // Add all tools to the database for (i, tool) in tools.into_iter().enumerate() { let shinkai_tool = ShinkaiTool::Deno(tool, true); @@ -1444,29 +1642,43 @@ mod tests { } } - // Test exact match - match manager.search_tools_fts("Text Analysis") { + // Add the simulated tool + let shinkai_simulated = ShinkaiTool::Simulated(simulated_tool, true); + let vector = SqliteManager::generate_vector_for_testing(0.4); + manager.add_tool_with_vector(shinkai_simulated, vector).unwrap(); + + // Test exact match - should not include simulated tool + match manager.search_tools_fts("Text Analysis", false) { Ok(results) => { eprintln!("Search results: {:?}", results); assert_eq!(results.len(), 1); assert_eq!(results[0].name, "Text Analysis Helper"); + assert!(!results.iter().any(|t| t.name == "Simulated Analysis Tool")); } Err(e) => eprintln!("Search failed: {:?}", e), } - // Test partial match - let results = manager.search_tools_fts("visualization").unwrap(); + // Test partial match - should not include simulated tool + let results = manager.search_tools_fts("visualization", false).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].name, "Data Visualization Tool"); + assert!(!results.iter().any(|t| t.name == "Simulated Analysis Tool")); - // Test case insensitive match - let results = manager.search_tools_fts("IMAGE").unwrap(); + // Test case insensitive match - should not include simulated tool + let results = manager.search_tools_fts("IMAGE", false).unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0].name, "Image Processing Tool"); + assert!(!results.iter().any(|t| t.name == "Simulated Analysis Tool")); // Test no match - let results = manager.search_tools_fts("nonexistent").unwrap(); + let results = manager.search_tools_fts("nonexistent", false).unwrap(); assert_eq!(results.len(), 0); + + // Test with include_simulated=true to verify simulated tool is included + let results = manager.search_tools_fts("analysis", true).unwrap(); + assert_eq!(results.len(), 2); + assert!(results.iter().any(|t| t.name == "Text Analysis Helper")); + assert!(results.iter().any(|t| t.name == "Simulated Analysis Tool")); } #[tokio::test] @@ -1526,9 +1738,21 @@ mod tests { tool_set: None, }; + // Create a SimulatedTool instance + let simulated_tool = SimulatedTool { + name: "Simulated Test Tool".to_string(), + description: "A simulated tool for testing".to_string(), + keywords: vec!["simulated".to_string(), "test".to_string()], + config: vec![], + input_args: Parameters::new(), + result: ToolResult::new("object".to_string(), serde_json::Value::Null, vec![]), + embedding: None, + }; + // Add both tools to the database let shinkai_enabled = ShinkaiTool::Deno(enabled_tool, true); let shinkai_disabled = ShinkaiTool::Deno(disabled_tool, false); + let shinkai_simulated = ShinkaiTool::Simulated(simulated_tool, true); manager .add_tool_with_vector(shinkai_enabled.clone(), SqliteManager::generate_vector_for_testing(0.1)) @@ -1539,32 +1763,41 @@ mod tests { SqliteManager::generate_vector_for_testing(0.2), ) .unwrap(); + manager + .add_tool_with_vector( + shinkai_simulated.clone(), + SqliteManager::generate_vector_for_testing(0.3), + ) + .unwrap(); - // Test search excluding disabled tools + // Test search excluding disabled tools and simulated tools let embedding_query = SqliteManager::generate_vector_for_testing(0.15); let search_results: Vec = manager - .tool_vector_search_with_vector(embedding_query.clone(), 10, false, true) + .tool_vector_search_with_vector(embedding_query.clone(), 10, false, true, false) .unwrap() .iter() .map(|(tool, _distance)| tool.clone()) .collect(); - // Should only find the enabled tool + // Should only find the enabled tool, not the disabled or simulated tools assert_eq!(search_results.len(), 1); assert_eq!(search_results[0].name, "Enabled Test Tool"); + assert!(!search_results.iter().any(|t| t.name == "Disabled Test Tool")); + assert!(!search_results.iter().any(|t| t.name == "Simulated Test Tool")); - // Test search including disabled tools + // Test search including disabled tools but excluding simulated tools let search_results: Vec = manager - .tool_vector_search_with_vector(embedding_query.clone(), 10, true, true) + .tool_vector_search_with_vector(embedding_query.clone(), 10, true, true, false) .unwrap() .iter() .map(|(tool, _distance)| tool.clone()) .collect(); - // Should find both tools + // Should find both enabled and disabled tools, but not simulated tools assert_eq!(search_results.len(), 2); assert!(search_results.iter().any(|t| t.name == "Enabled Test Tool")); assert!(search_results.iter().any(|t| t.name == "Disabled Test Tool")); + assert!(!search_results.iter().any(|t| t.name == "Simulated Test Tool")); // Now disable the previously enabled tool if let ShinkaiTool::Deno(mut deno_tool, _is_enabled) = shinkai_enabled { @@ -1576,15 +1809,15 @@ mod tests { .unwrap(); } - // Search again excluding disabled tools - should now return empty results + // Search again excluding disabled tools and simulated tools - should now return empty results let search_results: Vec = manager - .tool_vector_search_with_vector(embedding_query, 10, false, true) + .tool_vector_search_with_vector(embedding_query, 10, false, true, false) .unwrap() .iter() .map(|(tool, _distance)| tool.clone()) .collect(); - // Should find no tools as both are now disabled + // Should find no tools as both are now disabled and simulated tools are excluded assert_eq!(search_results.len(), 0); } @@ -1673,10 +1906,22 @@ mod tests { restrictions: None, }; + // Create a SimulatedTool instance + let simulated_tool = SimulatedTool { + name: "Simulated Test Tool".to_string(), + description: "A simulated tool for testing".to_string(), + keywords: vec!["simulated".to_string(), "test".to_string()], + config: vec![], + input_args: Parameters::new(), + result: ToolResult::new("object".to_string(), serde_json::Value::Null, vec![]), + embedding: None, + }; + // Wrap the tools in ShinkaiTool variants let shinkai_enabled_non_network = ShinkaiTool::Deno(enabled_non_network_tool, true); let shinkai_disabled_non_network = ShinkaiTool::Deno(disabled_non_network_tool, false); let shinkai_enabled_network = ShinkaiTool::Network(enabled_network_tool, true); + let shinkai_simulated = ShinkaiTool::Simulated(simulated_tool, true); // Add the tools to the database manager @@ -1697,12 +1942,24 @@ mod tests { SqliteManager::generate_vector_for_testing(0.3), ) .unwrap(); + manager + .add_tool_with_vector( + shinkai_simulated.clone(), + SqliteManager::generate_vector_for_testing(0.4), + ) + .unwrap(); // Perform searches and verify results // Search including only enabled non-network tools let search_results: Vec = manager - .tool_vector_search_with_vector(SqliteManager::generate_vector_for_testing(0.15), 10, false, false) + .tool_vector_search_with_vector( + SqliteManager::generate_vector_for_testing(0.15), + 10, + false, + false, + false, + ) .unwrap() .iter() .map(|(tool, _distance)| tool.clone()) @@ -1710,10 +1967,11 @@ mod tests { assert_eq!(search_results.len(), 1); assert_eq!(search_results[0].name, "Enabled Non-Network Tool"); + assert!(!search_results.iter().any(|t| t.name == "Simulated Test Tool")); // Search including only enabled tools (both network and non-network) let search_results: Vec = manager - .tool_vector_search_with_vector(SqliteManager::generate_vector_for_testing(0.25), 10, false, true) + .tool_vector_search_with_vector(SqliteManager::generate_vector_for_testing(0.25), 10, false, true, false) .unwrap() .iter() .map(|(tool, _distance)| tool.clone()) @@ -1722,10 +1980,11 @@ mod tests { assert_eq!(search_results.len(), 2); assert!(search_results.iter().any(|t| t.name == "Enabled Non-Network Tool")); assert!(search_results.iter().any(|t| t.name == "Enabled Network Tool")); + assert!(!search_results.iter().any(|t| t.name == "Simulated Test Tool")); // Search including all non-network tools (enabled and disabled) let search_results: Vec = manager - .tool_vector_search_with_vector(SqliteManager::generate_vector_for_testing(0.15), 10, true, false) + .tool_vector_search_with_vector(SqliteManager::generate_vector_for_testing(0.15), 10, true, false, false) .unwrap() .iter() .map(|(tool, _distance)| tool.clone()) @@ -1734,10 +1993,11 @@ mod tests { assert_eq!(search_results.len(), 2); assert!(search_results.iter().any(|t| t.name == "Enabled Non-Network Tool")); assert!(search_results.iter().any(|t| t.name == "Disabled Non-Network Tool")); + assert!(!search_results.iter().any(|t| t.name == "Simulated Test Tool")); // Search including all tools (enabled, disabled, network, and non-network) let search_results: Vec = manager - .tool_vector_search_with_vector(SqliteManager::generate_vector_for_testing(0.25), 10, true, true) + .tool_vector_search_with_vector(SqliteManager::generate_vector_for_testing(0.25), 10, true, true, false) .unwrap() .iter() .map(|(tool, _distance)| tool.clone()) @@ -1747,6 +2007,7 @@ mod tests { assert!(search_results.iter().any(|t| t.name == "Enabled Non-Network Tool")); assert!(search_results.iter().any(|t| t.name == "Disabled Non-Network Tool")); assert!(search_results.iter().any(|t| t.name == "Enabled Network Tool")); + assert!(!search_results.iter().any(|t| t.name == "Simulated Test Tool")); } #[tokio::test] @@ -1859,7 +2120,7 @@ mod tests { // Perform the limited search let results = manager - .tool_vector_search_with_vector_limited(search_vector.clone(), 2, limited_tool_keys.clone()) + .tool_vector_search_with_vector_limited(search_vector.clone(), 2, limited_tool_keys.clone(), false) .unwrap(); // Verify results @@ -1867,7 +2128,7 @@ mod tests { // Perform the limited search let results = manager - .tool_vector_search_with_vector_limited(search_vector, 10, limited_tool_keys) + .tool_vector_search_with_vector_limited(search_vector, 10, limited_tool_keys, false) .unwrap(); // Verify results @@ -1944,17 +2205,35 @@ mod tests { tool_set: None, }; + // Add a simulated tool + let simulated_tool = SimulatedTool { + name: "Simulated Versioned Tool".to_string(), + description: "A simulated versioned tool".to_string(), + keywords: vec!["simulated".to_string(), "version".to_string()], + config: vec![], + input_args: Parameters::new(), + result: ToolResult::new("object".to_string(), serde_json::Value::Null, vec![]), + embedding: None, + }; + // Wrap the DenoTools in ShinkaiTool::Deno variants let shinkai_tool_v1 = ShinkaiTool::Deno(deno_tool_v1, true); let shinkai_tool_v2 = ShinkaiTool::Deno(deno_tool_v2, true); + let shinkai_simulated = ShinkaiTool::Simulated(simulated_tool, true); - // Add both tools to the database + // Add all tools to the database manager .add_tool_with_vector(shinkai_tool_v1.clone(), SqliteManager::generate_vector_for_testing(0.1)) .unwrap(); manager .add_tool_with_vector(shinkai_tool_v2.clone(), SqliteManager::generate_vector_for_testing(0.2)) .unwrap(); + manager + .add_tool_with_vector( + shinkai_simulated.clone(), + SqliteManager::generate_vector_for_testing(0.3), + ) + .unwrap(); // Retrieve and verify both tools are added let retrieved_tool_v1 = manager @@ -1990,7 +2269,7 @@ mod tests { // Perform a vector search and ensure it only returns one result let search_vector = SqliteManager::generate_vector_for_testing(0.2); let search_results = manager - .tool_vector_search_with_vector(search_vector, 1, true, true) + .tool_vector_search_with_vector(search_vector, 1, true, true, false) .unwrap(); // Verify that only one result is returned @@ -1998,11 +2277,18 @@ mod tests { assert_eq!(search_results[0].0.name, "Versioned Tool"); assert_eq!(search_results[0].0.version, "2.0"); - // Perform an FTS search and ensure it only returns one result (version 2.0) - let fts_results = manager.search_tools_fts("Versioned Tool").unwrap(); + // Perform an FTS search and ensure it only returns one result (version 2.0) and not the simulated tool + let fts_results = manager.search_tools_fts("Versioned Tool", false).unwrap(); assert_eq!(fts_results.len(), 1); assert_eq!(fts_results[0].name, "Versioned Tool"); assert_eq!(fts_results[0].version, "2.0"); + assert!(!fts_results.iter().any(|t| t.name == "Simulated Versioned Tool")); + + // Test with include_simulated=true to verify simulated tool is included + let fts_results = manager.search_tools_fts("Versioned Tool", true).unwrap(); + assert_eq!(fts_results.len(), 2); + assert!(fts_results.iter().any(|t| t.name == "Versioned Tool")); + assert!(fts_results.iter().any(|t| t.name == "Simulated Versioned Tool")); } #[tokio::test] diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs index 4146581a1..a6d0c4fc0 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs @@ -38,7 +38,7 @@ pub struct DenoTool { pub description: String, pub keywords: Vec, pub input_args: Parameters, - pub output_arg: ToolOutputArg, + pub output_arg: ToolOutputArg, // DEPRICATED. Use "Result" Instance instead. pub activated: bool, pub embedding: Option>, pub result: ToolResult, diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/mod.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/mod.rs index 5d58183f6..927e5385b 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/mod.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/mod.rs @@ -8,6 +8,7 @@ pub mod python_tools; pub mod rust_tools; pub mod shared_execution; pub mod shinkai_tool; +pub mod simulated_tool; pub mod tool_config; pub mod tool_output_arg; pub mod tool_playground; diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs index 60c601bcd..259b7c5d1 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs @@ -6,15 +6,18 @@ use serde_json::{self, Value}; use shinkai_message_primitives::schemas::tool_router_key::ToolRouterKey; use shinkai_message_primitives::schemas::{ - indexable_version::IndexableVersion, shinkai_tool_offering::{ShinkaiToolOffering, UsageType} + indexable_version::IndexableVersion, + shinkai_tool_offering::{ShinkaiToolOffering, UsageType}, }; use super::agent_tool_wrapper::AgentToolWrapper; +use super::simulated_tool::SimulatedTool; use super::tool_config::OAuth; use super::tool_playground::{SqlQuery, SqlTable}; use super::tool_types::{OperatingSystem, RunnerType}; use super::{ - deno_tools::DenoTool, network_tool::NetworkTool, parameters::Parameters, python_tools::PythonTool, tool_config::ToolConfig, tool_output_arg::ToolOutputArg + deno_tools::DenoTool, network_tool::NetworkTool, parameters::Parameters, python_tools::PythonTool, + tool_config::ToolConfig, tool_output_arg::ToolOutputArg, }; pub type IsEnabled = bool; @@ -27,6 +30,7 @@ pub enum ShinkaiTool { Deno(DenoTool, IsEnabled), Python(PythonTool, IsEnabled), Agent(AgentToolWrapper, IsEnabled), + Simulated(SimulatedTool, IsEnabled), } #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] @@ -83,7 +87,7 @@ impl ShinkaiTool { enabled: self.is_enabled(), mcp_enabled: Some(self.is_mcp_enabled()), input_args: self.input_args(), - output_arg: self.output_arg(), + output_arg: ToolOutputArg::empty(), config: self.get_js_tool_config().cloned(), usage_type: self.get_usage_type(), tool_offering: None, @@ -98,7 +102,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => ("local".to_string(), d.author.clone(), d.name.clone()), ShinkaiTool::Python(p, _) => ("local".to_string(), p.author.clone(), p.name.clone()), ShinkaiTool::Agent(a, _) => ("local".to_string(), a.author.clone(), a.agent_id.clone()), - _ => unreachable!(), + ShinkaiTool::Simulated(s, _) => (s.get_source(), s.get_author(), s.name.clone()), }; ToolRouterKey::new(provider, author, name, None) } @@ -130,6 +134,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => d.name.clone(), ShinkaiTool::Python(p, _) => p.name.clone(), ShinkaiTool::Agent(a, _) => a.name.clone(), + ShinkaiTool::Simulated(s, _) => s.name.clone(), } } /// Tool description @@ -140,6 +145,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => d.description.clone(), ShinkaiTool::Python(p, _) => p.description.clone(), ShinkaiTool::Agent(a, _) => a.description.clone(), + ShinkaiTool::Simulated(s, _) => s.description.clone(), } } @@ -151,17 +157,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => d.input_args.clone(), ShinkaiTool::Python(p, _) => p.input_args.clone(), ShinkaiTool::Agent(a, _) => a.input_args.clone(), - } - } - - /// Returns the input arguments of the tool - pub fn output_arg(&self) -> ToolOutputArg { - match self { - ShinkaiTool::Rust(r, _) => r.output_arg.clone(), - ShinkaiTool::Network(n, _) => n.output_arg.clone(), - ShinkaiTool::Deno(d, _) => d.output_arg.clone(), - ShinkaiTool::Python(p, _) => p.output_arg.clone(), - ShinkaiTool::Agent(a, _) => a.output_arg.clone(), + ShinkaiTool::Simulated(s, _) => s.input_args.clone(), } } @@ -173,6 +169,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(_, _) => "Deno", ShinkaiTool::Python(_, _) => "Python", ShinkaiTool::Agent(_, _) => "Agent", + ShinkaiTool::Simulated(_, _) => "Simulated", } } @@ -292,6 +289,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => d.embedding = Some(embedding), ShinkaiTool::Python(p, _) => p.embedding = Some(embedding), ShinkaiTool::Agent(a, _) => a.embedding = Some(embedding), + ShinkaiTool::Simulated(s, _) => s.embedding = Some(embedding), } } @@ -335,6 +333,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => d.embedding.clone(), ShinkaiTool::Python(p, _) => p.embedding.clone(), ShinkaiTool::Agent(a, _) => a.embedding.clone(), + ShinkaiTool::Simulated(s, _) => s.embedding.clone(), } } @@ -363,6 +362,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => d.author.clone(), ShinkaiTool::Python(p, _) => p.author.clone(), ShinkaiTool::Agent(a, _) => a.author.clone(), + ShinkaiTool::Simulated(s, _) => "@@local.shinkai".to_string(), } } @@ -374,6 +374,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => d.version.clone(), ShinkaiTool::Python(p, _) => p.version.clone(), ShinkaiTool::Agent(_a, _) => "1.0.0".to_string(), + ShinkaiTool::Simulated(s, _) => "1.0.0".to_string(), } } @@ -394,6 +395,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(_, enabled) => *enabled, ShinkaiTool::Python(_, enabled) => *enabled, ShinkaiTool::Agent(_a, enabled) => *enabled, + ShinkaiTool::Simulated(_, enabled) => *enabled, } } @@ -405,6 +407,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(tool, is_enabled) => *is_enabled && tool.mcp_enabled.unwrap_or(false), ShinkaiTool::Python(tool, is_enabled) => *is_enabled && tool.mcp_enabled.unwrap_or(false), ShinkaiTool::Agent(a, is_enabled) => *is_enabled && a.mcp_enabled.unwrap_or(false), + ShinkaiTool::Simulated(_, is_enabled) => *is_enabled, } } @@ -416,6 +419,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(_, enabled) => *enabled = true, ShinkaiTool::Python(_, enabled) => *enabled = true, ShinkaiTool::Agent(_, enabled) => *enabled = true, + ShinkaiTool::Simulated(_, enabled) => *enabled = true, } } @@ -426,6 +430,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(tool, _) => tool.mcp_enabled = Some(true), ShinkaiTool::Python(tool, _) => tool.mcp_enabled = Some(true), ShinkaiTool::Agent(tool, _) => tool.mcp_enabled = Some(true), + ShinkaiTool::Simulated(_, _) => (), } } @@ -437,6 +442,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(_, enabled) => *enabled = false, ShinkaiTool::Python(_, enabled) => *enabled = false, ShinkaiTool::Agent(_, enabled) => *enabled = false, + ShinkaiTool::Simulated(_, enabled) => *enabled = false, } } @@ -447,6 +453,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(tool, _) => tool.mcp_enabled = Some(false), ShinkaiTool::Python(tool, _) => tool.mcp_enabled = Some(false), ShinkaiTool::Agent(tool, _) => tool.mcp_enabled = Some(false), + ShinkaiTool::Simulated(_, _) => (), } } @@ -466,6 +473,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(js_tool, _) => js_tool.config.clone(), ShinkaiTool::Python(python_tool, _) => python_tool.config.clone(), ShinkaiTool::Agent(_a, _) => vec![], + ShinkaiTool::Simulated(_, _) => vec![], } } @@ -477,6 +485,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(deno_tool, _) => deno_tool.check_required_config_fields(), ShinkaiTool::Python(_, _) => true, ShinkaiTool::Agent(_, _) => true, + ShinkaiTool::Simulated(_, _) => true, } } @@ -550,6 +559,7 @@ impl ShinkaiTool { ShinkaiTool::Deno(d, _) => d.keywords.clone(), ShinkaiTool::Python(p, _) => p.keywords.clone(), ShinkaiTool::Agent(_a, _) => vec![], + ShinkaiTool::Simulated(s, _) => s.keywords.clone(), } } } diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs new file mode 100644 index 000000000..40d71d30e --- /dev/null +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs @@ -0,0 +1,245 @@ +use super::parameters::Parameters; +use super::tool_config::ToolConfig; +use super::tool_playground::ToolPlaygroundMetadata; +use super::tool_types::{RunnerType, ToolResult}; +use crate::tools::error::ToolError; +use serde_json::{json, Map}; +use shinkai_message_primitives::schemas::tool_router_key::ToolRouterKey; +use shinkai_tools_runner::tools::run_result::RunResult; +use std::collections::HashMap; +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct SimulatedTool { + pub name: String, + pub description: String, + pub keywords: Vec, + pub config: Vec, + pub input_args: Parameters, + pub result: ToolResult, + pub embedding: Option>, +} + +impl SimulatedTool { + pub fn to_json(&self) -> Result { + serde_json::to_string(self).map_err(|_| ToolError::FailedJSONParsing) + } + + pub fn from_json(json: &str) -> Result { + let deserialized: Self = serde_json::from_str(json)?; + Ok(deserialized) + } + + pub fn get_author(&self) -> String { + "@@localhost.shinkai".to_string() + } + + pub fn get_version(&self) -> Option { + None + } + + pub fn get_source(&self) -> String { + "@@simulated.local".to_string() + } + + pub fn get_tool_router_key(&self) -> String { + let trk = ToolRouterKey { + source: self.get_source(), + author: self.get_author(), + name: self.name.clone(), + version: None, + }; + trk.to_string_without_version() + } + + pub async fn build_example_json( + key: &String, + hash_map: &Map, + ) -> Result<(String, serde_json::Value), ToolError> { + let unknown = serde_json::Value::String("unknown".to_string()); + let r#type = hash_map.get("type").unwrap_or(&unknown).as_str().unwrap_or_default(); + + if r#type == "array" { + let items = hash_map.get("items").unwrap().as_object().unwrap(); + let _items_type = items + .get("type") + .unwrap_or(&unknown.clone()) + .as_str() + .unwrap_or_default(); + } + + match r#type { + "string" => { + return Ok((key.to_string(), serde_json::Value::String("EXAMPLE_VALUE".to_string()))); + } + "number" => { + return Ok(( + key.to_string(), + serde_json::Value::Number(serde_json::Number::from(100)), + )); + } + "boolean" => { + return Ok((key.to_string(), serde_json::Value::Bool(true))); + } + "array" => { + return Ok((key.to_string(), serde_json::Value::Array(vec![]))); + } + "object" => { + // Create recursive call to build the object + let properties = hash_map.get("properties").unwrap().as_object().unwrap(); + let mut object = HashMap::new(); + for (property_key, property_value) in properties { + let property_values = property_value.as_object().unwrap(); + let property_value_json = + Box::pin(SimulatedTool::build_example_json(&property_key, property_values)).await?; + object.insert(property_key.to_string(), property_value_json); + } + let o = serde_json::to_value(&object).unwrap(); + return Ok((key.to_string(), o)); + } + _ => { + return Err(ToolError::FailedJSONParsing); + } + } + } + + pub async fn run( + &self, + bearer_token: String, + api_ip: String, + api_port: u16, + app_id: String, + tool_id: String, + llm_provider: String, + parameters: serde_json::Map, + extra_config: Vec, + ) -> Result { + let metadata = ToolPlaygroundMetadata { + name: self.name.clone(), + homepage: None, + version: "1.0.0".to_string(), + description: self.description.clone(), + author: "@@local.shinkai".to_string(), + keywords: self.keywords.clone(), + configurations: self.config.clone(), + parameters: self.input_args.clone(), + result: self.result.clone(), + sql_tables: vec![], + sql_queries: vec![], + tools: None, + oauth: None, + runner: RunnerType::OnlyHost, + operating_system: vec![], + tool_set: None, + }; + + let mut example_result: HashMap = HashMap::new(); + let properties = self.result.properties.as_object().unwrap(); + println!("result: {:?}", self.result.properties); + println!("properties: {:?}", properties); + for (property_name, object) in properties.iter() { + println!("property_name: {}", property_name); + println!("object: {:?}", object); + let property_values = object.as_object().unwrap(); + let (key, value) = SimulatedTool::build_example_json(&property_name, property_values).await?; + example_result.insert(key, value); + } + + let prompt = format!( + r#" + + * You are a TOOL simulator. + * The TOOL description is given in the metadata tag. + * The TOOL inputs: "parameters" and "configuration" are given in the inputs tag. + * Given the description, parameters and configuration, generate a mock response. + * Simulate the expected successful output of the tool. + * Do not let the user know this is a mock response. + * use the output_example tag as example for the response. + + + +{} + + + +parameters: {} + +extra_config: {} + + + + * Write a valid JSON Object + * Follow the output_example tag as base example, you may add more key/values. + * Do not output any other comments, ideas, planning, thoughts or comments. + + + +```json +{} +``` + + "#, + serde_json::to_value(&metadata).unwrap(), + serde_json::to_value(¶meters).unwrap(), + serde_json::to_value(&extra_config).unwrap(), + serde_json::to_value(&example_result).unwrap(), + ); + + // TODO Check if HTTP or HTTPS is used + let url = format!("http://{}:{}/v2/tool_execution", api_ip, api_port); + let client = reqwest::Client::new(); + let response = client + .post(url) + .header("Authorization", format!("Bearer {}", bearer_token)) + .header("x-shinkai-tool-id", tool_id) + .header("x-shinkai-app-id", app_id) + .header("x-shinkai-llm-provider", llm_provider.clone()) + .header("Content-Type", "application/json; charset=utf-8") + .json(&json!({ + "tool_router_key": "local:::__official_shinkai:::shinkai_llm_prompt_processor", + "llm_provider": llm_provider.clone(), + "parameters": { + "prompt": prompt + } + })) + .send() + .await?; + + let body = response.json::().await?; + // We expect the response to have this format, but cannot guarantee it + // So we try to parse the response as JSON and if it fails, we return the entire object + // { + // "message": "```json\n{\"random_number\":57}\n```" + // } + if body.get("message").is_none() { + return Err(ToolError::ExecutionError(format!( + "[SimulatedTool] No message found in response: {}", + body + ))); + } + + let message_value = body.get("message").unwrap(); + let message = message_value.as_str().unwrap_or_default(); + let mut message_split = message.split("\n").collect::>(); + let len = message_split.clone().len(); + + if message_split[0] == "```json" { + message_split[0] = ""; + } + if message_split[len - 1] == "```" { + message_split[len - 1] = ""; + } + let cleaned_json = message_split.join(" "); + + match serde_json::from_str::(&cleaned_json) { + Ok(data) => return Ok(RunResult { data }), + Err(e) => { + println!( + "[SimulatedTool] Could not parse as JSON: {} - so returning entire object", + body + ); + return Ok(RunResult { + data: message_value.clone(), + }); + } + } + } +} From 0097a5494972f7b2157ce004c6b7018c792e108e Mon Sep 17 00:00:00 2001 From: Eddie Date: Wed, 30 Apr 2025 19:05:32 -0400 Subject: [PATCH 2/9] added placeholder endpoints --- .../src/network/handle_commands_list.rs | 23 ++++++ .../src/network/v1_api/api_v1_commands.rs | 11 ++- .../v2_api/api_v2_commands_my_agent_offers.rs | 31 +++++++ .../network/v2_api/api_v2_commands_tools.rs | 60 ++++++++++++++ .../api_v2/api_v2_handlers_my_agent_offers.rs | 82 +++++++++++++++++++ .../src/api_v2/api_v2_handlers_tools.rs | 60 ++++++++++++++ .../src/api_v2/api_v2_router.rs | 3 + .../shinkai-http-api/src/node_commands.rs | 12 +++ 8 files changed, 281 insertions(+), 1 deletion(-) diff --git a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs index 41fbff301..a52a8d74e 100644 --- a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs +++ b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs @@ -2025,6 +2025,17 @@ impl Node { // let _ = Node::v2_api_list_files_in_inbox(db_clone, bearer, inbox_name, res).await; // }); // } + NodeCommand::V2ApiGenerateAgentFromPrompt { + bearer, + prompt, + res, + } => { + let db_clone = Arc::clone(&self.db); + let node_name = self.node_name.clone(); + tokio::spawn(async move { + let _ = Node::v2_api_generate_agent_from_prompt(db_clone, bearer, prompt, node_name, res).await; + }); + } NodeCommand::V2ApiGetToolOffering { bearer, tool_key_name, @@ -3222,6 +3233,18 @@ impl Node { let _ = Node::v2_api_get_preferences(db_clone, bearer, res).await; }); } + NodeCommand::V2ApiCreateSimulatedTool { + bearer, + name, + prompt, + agent_id, + res, + } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_create_simulated_tool(db_clone, bearer, name, prompt, agent_id, res).await; + }); + } _ => (), } } diff --git a/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs b/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs index 06161bda7..f4b9278e9 100644 --- a/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs @@ -8,6 +8,16 @@ use crate::{ }, utils::update_global_identity::update_global_identity_name, }; +use async_channel::Sender; +use blake3::Hasher; +use ed25519_dalek::SigningKey; +use ed25519_dalek::VerifyingKey; +use log::error; +use reqwest::StatusCode; +use serde_json::json; +use serde_json::Value as JsonValue; +use shinkai_embedding::embedding_generator::RemoteEmbeddingGenerator; +use shinkai_embedding::model_type::EmbeddingModelType; use shinkai_http_api::api_v1::api_v1_handlers::APIUseRegistrationCodeSuccessResponse; use shinkai_http_api::node_api_router::{APIError, SendResponseBodyData}; use shinkai_message_primitives::schemas::identity::{ @@ -41,7 +51,6 @@ use shinkai_message_primitives::{ use shinkai_sqlite::errors::SqliteManagerError; use shinkai_sqlite::SqliteManager; use shinkai_tools_primitives::tools::shinkai_tool::ShinkaiTool; - use std::{convert::TryInto, env, sync::Arc, time::Instant}; use tokio::sync::Mutex; use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey}; diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs index 0bfa774ba..c05c4985d 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs @@ -298,4 +298,35 @@ impl Node { Ok(()) } + + pub async fn v2_api_generate_agent_from_prompt( + db: Arc, + bearer: String, + prompt: String, + _: ShinkaiName, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + let json_response = serde_json::json!(format!( + r#"{{ + "name": "Agent 1", + "description": "Agent 1 is a helpful agent that can help with tasks", + "system_prompt": {prompt}, + "simulated_tools": [{{ + "name": "Get Crypto Token Price", + "prompt": "Get the price of a token" + }}, {{ + "name": "Swap Crypto Tokens", + "prompt": "Swap tokens" + }}] + }}"# + )); + + let _ = res.send(Ok(json_response)).await; + Ok(()) + } } diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs index b9b1e4e17..c82a22367 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs @@ -41,6 +41,7 @@ use shinkai_tools_primitives::tools::{ parameters::Parameters, python_tools::PythonTool, shinkai_tool::{ShinkaiTool, ShinkaiToolWithAssets}, + simulated_tool::SimulatedTool, tool_config::{OAuth, ToolConfig}, tool_output_arg::ToolOutputArg, tool_playground::{ToolPlayground, ToolPlaygroundMetadata}, @@ -4272,6 +4273,65 @@ LANGUAGE={env_language} Ok(()) } + + pub async fn v2_api_create_simulated_tool( + db: Arc, + bearer: String, + name: String, + prompt: String, + agent_id: String, + res: Sender>, + ) -> Result<(), NodeError> { + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Create a simulated tool + let simulated_tool = SimulatedTool { + name: format!("{agent_id}-{name}"), + description: prompt, + keywords: vec![], + config: vec![], + input_args: Parameters::new(), + result: ToolResult { + r#type: "object".to_string(), + properties: json!({ + "return": { + "type": "string", + "description": "The return value of the tool" + } + }), + required: vec!["return".to_string()], + }, + embedding: None, + }; + + // Save the tool + let save_result = db.add_tool(ShinkaiTool::Simulated(simulated_tool.clone(), true)).await; + + match save_result { + Ok(_) => { + let _ = res + .send(Ok(json!({ + "status": "success", + "message": "Simulated tool created successfully", + "tool_router_key": simulated_tool.get_tool_router_key(), + }))) + .await; + Ok(()) + } + Err(err) => { + let _ = res + .send(Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create simulated tool: {}", err), + })) + .await; + Ok(()) + } + } + } } #[cfg(test)] diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_my_agent_offers.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_my_agent_offers.rs index e69de29bb..51a023f73 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_my_agent_offers.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_my_agent_offers.rs @@ -0,0 +1,82 @@ +use async_channel::Sender; +use serde::Deserialize; +use serde_json::json; +use utoipa::{OpenApi, ToSchema}; +use warp::http::StatusCode; +use warp::Filter; + +use crate::{node_api_router::APIError, node_commands::NodeCommand}; + +use super::api_v2_router::{create_success_response, with_sender}; + +pub fn my_agent_offers_routes( + node_commands_sender: Sender, +) -> impl Filter + Clone { + let generate_agent_from_prompt_route = warp::path("generate_agent_from_prompt") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(generate_agent_from_prompt_handler); + + generate_agent_from_prompt_route +} + +#[derive(Deserialize, ToSchema)] +pub struct GenerateAgentFromPromptRequest { + pub prompt: String, +} + +#[utoipa::path( + post, + path = "/v2/generate_agent_from_prompt", + request_body = GenerateAgentFromPromptRequest, + responses( + (status = 200, description = "Successfully set tool offering", body = Value), + (status = 400, description = "Bad request", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn generate_agent_from_prompt_handler( + node_commands_sender: Sender, + authorization: String, + payload: GenerateAgentFromPromptRequest, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let (res_sender, res_receiver) = async_channel::bounded(1); + node_commands_sender + .send(NodeCommand::V2ApiGenerateAgentFromPrompt { + bearer, + prompt: payload.prompt, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(response) => { + let response = create_success_response(json!({ "result": response })); + Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK)) + } + Err(error) => Ok(warp::reply::with_status( + warp::reply::json(&error), + StatusCode::from_u16(error.code).unwrap(), + )), + } +} + +#[derive(OpenApi)] +#[openapi( + paths( + generate_agent_from_prompt_handler, + ), + components( + schemas(GenerateAgentFromPromptRequest, APIError) + ), + tags( + (name = "my_agent", description = "My Agents") + ) +)] +pub struct ToolOfferingsApiDoc; diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs index 8c2c25a68..abfc8afb8 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs @@ -283,6 +283,13 @@ pub fn tool_routes( .and(warp::body::json()) .and_then(tool_check_handler); + let create_simulated_tool_route = warp::path("create_simulated_tool") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(create_simulated_tool_handler); + tool_execution_route .or(code_execution_route) .or(tool_definitions_route) @@ -319,6 +326,7 @@ pub fn tool_routes( .or(set_tool_mcp_enabled_route) .or(copy_tool_asset_route) .or(tool_check_route) + .or(create_simulated_tool_route) } pub fn safe_folder_name(tool_router_key: &str) -> String { @@ -2362,6 +2370,57 @@ pub async fn tool_check_handler( } } + + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct SimulatedShinkaiToolRequest { + pub agent_id: String, + pub name: String, + pub prompt: String +} + +#[utoipa::path( + post, + path = "/v2/create_simulated_tool", + request_body = SimulatedShinkaiToolRequest, + responses( + (status = 200, description = "Successfully created simulated tool", body = Value), + (status = 400, description = "Bad request", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn create_simulated_tool_handler( + sender: Sender, + authorization: String, + payload: SimulatedShinkaiToolRequest, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let (res_sender, res_receiver) = async_channel::bounded(1); + sender + .send(NodeCommand::V2ApiCreateSimulatedTool { + bearer, + agent_id: payload.agent_id.clone(), + name: payload.name.clone(), + prompt: payload.prompt.clone(), + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(response) => { + let response = create_success_response(response); + Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK)) + } + Err(error) => Ok(warp::reply::with_status( + warp::reply::json(&error), + StatusCode::from_u16(error.code).unwrap(), + )), + } +} + #[derive(OpenApi)] #[openapi( paths( @@ -2397,6 +2456,7 @@ pub async fn tool_check_handler( set_tool_mcp_enabled_handler, copy_tool_assets_handler, tool_check_handler, + create_simulated_tool_handler, ), components( schemas( diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_router.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_router.rs index e157d7e2a..f15810f90 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_router.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_router.rs @@ -2,6 +2,7 @@ use crate::node_commands::NodeCommand; use super::api_v2_handlers_cron::cron_routes; use super::api_v2_handlers_ext_agent_offers::ext_agent_offers_routes; +use super::api_v2_handlers_my_agent_offers::my_agent_offers_routes; use super::api_v2_handlers_general::general_routes; use super::api_v2_handlers_jobs::job_routes; use super::api_v2_handlers_oauth::oauth_routes; @@ -25,6 +26,7 @@ pub fn v2_routes( let vecfs_routes = vecfs_routes(node_commands_sender.clone(), node_name.clone()); let job_routes = job_routes(node_commands_sender.clone(), node_name.clone()); let ext_agent_offers = ext_agent_offers_routes(node_commands_sender.clone()); + let my_agent_offers = my_agent_offers_routes(node_commands_sender.clone()); let wallet_routes = wallet_routes(node_commands_sender.clone()); let custom_prompt = prompt_routes(node_commands_sender.clone()); let swagger_ui_routes = swagger_ui_routes(); @@ -37,6 +39,7 @@ pub fn v2_routes( .or(vecfs_routes) .or(job_routes) .or(ext_agent_offers) + .or(my_agent_offers) .or(wallet_routes) .or(custom_prompt) .or(swagger_ui_routes) diff --git a/shinkai-libs/shinkai-http-api/src/node_commands.rs b/shinkai-libs/shinkai-http-api/src/node_commands.rs index 42ed23622..8e7a9b23a 100644 --- a/shinkai-libs/shinkai-http-api/src/node_commands.rs +++ b/shinkai-libs/shinkai-http-api/src/node_commands.rs @@ -816,6 +816,11 @@ pub enum NodeCommand { payload: APIAddOllamaModels, res: Sender>, }, + V2ApiGenerateAgentFromPrompt { + bearer: String, + prompt: String, + res: Sender>, + }, V2ApiGetToolOffering { bearer: String, tool_key_name: String, @@ -1342,4 +1347,11 @@ pub enum NodeCommand { bearer: String, res: Sender>, }, + V2ApiCreateSimulatedTool { + bearer: String, + name: String, + prompt: String, + agent_id: String, + res: Sender>, + }, } From 913df2e46299990b3ca61d65ce2af541d6177df7 Mon Sep 17 00:00:00 2001 From: Eddie Date: Thu, 1 May 2025 14:56:42 -0400 Subject: [PATCH 3/9] Agent & Tool Simulation Creation from LLM --- .../src/network/handle_commands_list.rs | 53 ++++- .../v2_api/api_v2_commands_my_agent_offers.rs | 184 +++++++++++++++--- .../network/v2_api/api_v2_commands_tools.rs | 139 +++++++++++-- .../api_v2/api_v2_handlers_my_agent_offers.rs | 2 + .../src/api_v2/api_v2_handlers_tools.rs | 4 +- .../shinkai-http-api/src/node_commands.rs | 2 + .../src/tools/simulated_tool.rs | 21 +- .../src/tools/tool_playground.rs | 6 +- 8 files changed, 354 insertions(+), 57 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs index a52a8d74e..794fbb76f 100644 --- a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs +++ b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs @@ -1977,7 +1977,12 @@ impl Node { let _ = Node::v2_api_add_shinkai_tool(db_clone, bearer, node_env, shinkai_tool, res).await; }); } - NodeCommand::V2ApiGetShinkaiTool { bearer, payload, serialize_config, res } => { + NodeCommand::V2ApiGetShinkaiTool { + bearer, + payload, + serialize_config, + res, + } => { let db_clone = Arc::clone(&self.db); tokio::spawn(async move { let _ = Node::v2_api_get_shinkai_tool(db_clone, bearer, payload, serialize_config, res).await; @@ -2028,12 +2033,31 @@ impl Node { NodeCommand::V2ApiGenerateAgentFromPrompt { bearer, prompt, + llm_provider, res, } => { let db_clone = Arc::clone(&self.db); let node_name = self.node_name.clone(); + let identity_manager_clone = self.identity_manager.clone(); + let job_manager_clone = self.job_manager.clone().unwrap(); + let encryption_secret_key_clone = self.encryption_secret_key.clone(); + let encryption_public_key_clone = self.encryption_public_key.clone(); + let signing_secret_key_clone = self.identity_secret_key.clone(); tokio::spawn(async move { - let _ = Node::v2_api_generate_agent_from_prompt(db_clone, bearer, prompt, node_name, res).await; + let _ = Node::v2_api_generate_agent_from_prompt( + db_clone, + bearer, + prompt, + llm_provider, + node_name, + identity_manager_clone, + job_manager_clone, + encryption_secret_key_clone, + encryption_public_key_clone, + signing_secret_key_clone, + res, + ) + .await; }); } NodeCommand::V2ApiGetToolOffering { @@ -3238,11 +3262,34 @@ impl Node { name, prompt, agent_id, + llm_provider, res, } => { let db_clone = Arc::clone(&self.db); + let node_name_clone = self.node_name.clone(); + let identity_manager_clone = self.identity_manager.clone(); + let job_manager_clone = self.job_manager.clone().unwrap(); + let encryption_secret_key_clone = self.encryption_secret_key.clone(); + let encryption_public_key_clone = self.encryption_public_key; + let signing_secret_key_clone = self.identity_secret_key.clone(); + tokio::spawn(async move { - let _ = Node::v2_api_create_simulated_tool(db_clone, bearer, name, prompt, agent_id, res).await; + let _ = Node::v2_api_create_simulated_tool( + db_clone, + bearer, + name, + prompt, + agent_id, + llm_provider, + node_name_clone, + identity_manager_clone, + job_manager_clone, + encryption_secret_key_clone, + encryption_public_key_clone, + signing_secret_key_clone, + res, + ) + .await; }); } _ => (), diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs index c05c4985d..819e81f71 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs @@ -1,19 +1,26 @@ -use std::sync::Arc; - +use crate::{ + llm_provider::job_manager::JobManager, + managers::IdentityManager, + network::{ + agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager, node_error::NodeError, Node, + }, + tools::tool_implementation::{ + native_tools::llm_prompt_processor::LlmPromptProcessorTool, tool_traits::ToolExecutor, + }, +}; use async_channel::Sender; +use ed25519_dalek::SigningKey; use reqwest::StatusCode; -use serde_json::Value; - +use serde_json::{Map, Value}; use shinkai_http_api::node_api_router::APIError; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::schemas::shinkai_tool_offering::UsageTypeInquiry; use shinkai_sqlite::{errors::SqliteManagerError, SqliteManager}; -use shinkai_tools_primitives::tools::shinkai_tool::ShinkaiTool; -use tokio::sync::{Mutex, RwLock}; - -use crate::network::{ - agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager, node_error::NodeError, Node, -}; +use shinkai_tools_primitives::tools::{error::ToolError, shinkai_tool::ShinkaiTool}; +use std::sync::Arc; +use tokio::sync::Mutex; +use x25519_dalek::PublicKey as EncryptionPublicKey; +use x25519_dalek::StaticSecret as EncryptionStaticKey; impl Node { pub async fn v2_api_request_invoice( @@ -303,7 +310,13 @@ impl Node { db: Arc, bearer: String, prompt: String, - _: ShinkaiName, + llm_provider: String, + node_name: ShinkaiName, + identity_manager: Arc>, + job_manager: Arc>, + encryption_secret_key: EncryptionStaticKey, + encryption_public_key: EncryptionPublicKey, + signing_secret_key: SigningKey, res: Sender>, ) -> Result<(), NodeError> { // Validate the bearer token @@ -311,22 +324,139 @@ impl Node { return Ok(()); } - let json_response = serde_json::json!(format!( - r#"{{ - "name": "Agent 1", - "description": "Agent 1 is a helpful agent that can help with tasks", - "system_prompt": {prompt}, - "simulated_tools": [{{ - "name": "Get Crypto Token Price", - "prompt": "Get the price of a token" - }}, {{ - "name": "Swap Crypto Tokens", - "prompt": "Swap tokens" - }}] - }}"# - )); - - let _ = res.send(Ok(json_response)).await; + let mut retries = 3; + + let mut agent; + // We are using a LLM to generate the agent. + // So the return type might be flacky. + // We will retry 3 times. + loop { + // Create a new agent from the prompt + agent = Self::create_agent_from_prompt( + prompt.clone(), + llm_provider.clone(), + db.clone(), + bearer.clone(), + node_name.clone(), + identity_manager.clone(), + job_manager.clone(), + encryption_secret_key.clone(), + encryption_public_key.clone(), + signing_secret_key.clone(), + ) + .await; + + if let Err(e) = agent { + if retries > 0 { + retries -= 1; + continue; + } + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to generate agent: {}", e), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + break; + } + + let _ = res.send(Ok(agent.unwrap())).await; Ok(()) } + + async fn create_agent_from_prompt( + agent_prompt: String, + llm_provider: String, + db: Arc, + bearer: String, + node_name: ShinkaiName, + identity_manager: Arc>, + job_manager: Arc>, + encryption_secret_key: EncryptionStaticKey, + encryption_public_key: EncryptionPublicKey, + signing_secret_key: SigningKey, + ) -> Result { + let static_prompt = r#" + + * You are a generator of AGENT definitions. + * An AGENT contains a name, indications, instructions and a list of tools to achieve a specific goal. + * An AGENT is specified in JSON format. + * An AGENT will later be called by the user with a prompt. + * The "rules" tag has the instructions you must follow to generate the AGENT. + * The "command" tag has the definition of the AGENT. + * The "output" tag has an example of the output you must generate. + * Do not include any other text than the JSON. + + + + * The name must be concise. + * The indications must be a short description of the AGENT. + * The instructions, is the system prompt, the mind of the AGENT. It must be a markdown of steps, actions to achieve the goal or goals of the AGENT. + * The tools must be a list of tools that the AGENT will need to achieve the goal. + * A tool must description must contain: it's action, the expected inputs and return object. + * The output must be a valid JSON. + + + +```json + { + "name": "AGENT_NAME", + "indications": "AGENT_INDICATIONS", + "instructions": "AGENT_INSTRUCTIONS", + "tools": [{ + "name": "TOOL_NAME", + "description": "TOOL_DESCRIPTION", + }] + } +``` + +"#; + + let prompt = format!( + "{} + + +{} + +", + static_prompt, agent_prompt + ); + let mut parameters = Map::new(); + parameters.insert("prompt".to_string(), prompt.clone().into()); + parameters.insert("llm_provider".to_string(), llm_provider.clone().into()); + + let body = LlmPromptProcessorTool::execute( + bearer, + "tool_id".to_string(), + "app_id".to_string(), + db, + node_name, + identity_manager, + job_manager, + encryption_secret_key, + encryption_public_key, + signing_secret_key, + ¶meters, + llm_provider, + ) + .await?; + + let message_value = body.get("message").unwrap(); + let message = message_value.as_str().unwrap_or_default(); + let mut message_split = message.split("\n").collect::>(); + let len = message_split.clone().len(); + + if message_split[0] == "```json" { + message_split[0] = ""; + } + if message_split[len - 1] == "```" { + message_split[len - 1] = ""; + } + let cleaned_json = message_split.join(" "); + + serde_json::from_str::(&cleaned_json) + .map_err(|e| ToolError::ExecutionError(format!("Failed to parse JSON: {}", e))) + } } diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs index c82a22367..ff7caab4b 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs @@ -2,6 +2,9 @@ use crate::{ llm_provider::job_manager::JobManager, managers::{tool_router::ToolRouter, IdentityManager}, network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node}, + tools::tool_implementation::{ + native_tools::llm_prompt_processor::LlmPromptProcessorTool, tool_traits::ToolExecutor, + }, tools::{ tool_definitions::definition_generation::{generate_tool_definitions, get_all_tools}, tool_execution::execution_coordinator::{execute_code, execute_mcp_tool_cmd, execute_tool_cmd}, @@ -12,7 +15,8 @@ use crate::{ }; use async_channel::Sender; use chrono::Utc; -use ed25519_dalek::{ed25519::signature::SignerMut, SigningKey}; +use ed25519_dalek::ed25519::signature::SignerMut; +use ed25519_dalek::SigningKey; use reqwest::StatusCode; use serde_json::{json, Map, Value}; use shinkai_http_api::node_api_router::{APIError, SendResponseBodyData}; @@ -50,16 +54,18 @@ use shinkai_tools_primitives::tools::{ shinkai_tool::ShinkaiToolHeader, tool_types::{OperatingSystem, RunnerType, ToolResult}, }; +use std::sync::Arc; use std::{collections::HashMap, path::PathBuf}; use std::{ env, fs::File, io::{Read, Write}, - sync::Arc, time::Instant, }; use tokio::fs; -use tokio::{process::Command, sync::Mutex}; +use tokio::process::Command; +use tokio::sync::Mutex; +use uuid::Uuid; use x25519_dalek::PublicKey as EncryptionPublicKey; use x25519_dalek::StaticSecret as EncryptionStaticKey; use zip::{write::FileOptions, ZipWriter}; @@ -4280,29 +4286,63 @@ LANGUAGE={env_language} name: String, prompt: String, agent_id: String, + llm_provider: String, + node_name: ShinkaiName, + identity_manager: Arc>, + job_manager: Arc>, + encryption_secret_key: EncryptionStaticKey, + encryption_public_key: EncryptionPublicKey, + signing_secret_key: SigningKey, res: Sender>, ) -> Result<(), NodeError> { if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { return Ok(()); } + let mut retries = 3; + let mut tool_metadata_implementation; + loop { + tool_metadata_implementation = Self::create_tool_metadata_from_prompt( + prompt.clone(), + llm_provider.clone(), + db.clone(), + bearer.clone(), + node_name.clone(), + identity_manager.clone(), + job_manager.clone(), + encryption_secret_key.clone(), + encryption_public_key.clone(), + signing_secret_key.clone(), + ) + .await; + + if let Err(e) = tool_metadata_implementation { + if retries > 0 { + retries -= 1; + continue; + } + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create simulated tool: {:?}", e), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + break; + } + let tool_metadata_implementation = tool_metadata_implementation.unwrap(); + let random = Uuid::new_v4().to_string(); + let metadata: ToolPlaygroundMetadata = serde_json::from_value(tool_metadata_implementation).unwrap(); + // Create a simulated tool let simulated_tool = SimulatedTool { - name: format!("{agent_id}-{name}"), - description: prompt, - keywords: vec![], - config: vec![], - input_args: Parameters::new(), - result: ToolResult { - r#type: "object".to_string(), - properties: json!({ - "return": { - "type": "string", - "description": "The return value of the tool" - } - }), - required: vec!["return".to_string()], - }, + name: format!("{}-{}-{}", agent_id, metadata.name, random), + description: metadata.description.to_string(), + keywords: metadata.keywords, + config: metadata.configurations, + input_args: metadata.parameters, + result: metadata.result, embedding: None, }; @@ -4332,6 +4372,69 @@ LANGUAGE={env_language} } } } + + async fn create_tool_metadata_from_prompt( + tool_prompt: String, + llm_provider: String, + db: Arc, + bearer: String, + node_name: ShinkaiName, + identity_manager: Arc>, + job_manager: Arc>, + encryption_secret_key: EncryptionStaticKey, + encryption_public_key: EncryptionPublicKey, + signing_secret_key: SigningKey, + ) -> Result { + let tool_metadata_implementation = tool_metadata_implementation_prompt( + CodeLanguage::Typescript, + tool_prompt, + vec![], + identity_manager.clone(), + ) + .await + .map_err(|e| { + ToolError::ExecutionError(format!( + "Failed to generate tool metadata implementation: {}", + e.message + )) + })?; + + let mut parameters = Map::new(); + parameters.insert("prompt".to_string(), tool_metadata_implementation.clone().into()); + parameters.insert("llm_provider".to_string(), llm_provider.clone().into()); + + let body = LlmPromptProcessorTool::execute( + bearer, + "tool_id".to_string(), + "app_id".to_string(), + db, + node_name, + identity_manager, + job_manager, + encryption_secret_key, + encryption_public_key, + signing_secret_key, + ¶meters, + llm_provider, + ) + .await?; + + let message_value = body.get("message").unwrap(); + let message = message_value.as_str().unwrap_or_default(); + let mut message_split = message.split("\n").collect::>(); + let len = message_split.clone().len(); + + if message_split[0] == "```json" { + message_split[0] = ""; + } + if message_split[len - 1] == "```" { + message_split[len - 1] = ""; + } + let cleaned_json = message_split.join(" "); + + serde_json::from_str::(&cleaned_json) + .map_err(|e| ToolError::ExecutionError(format!("Failed to parse JSON: {}", e))) + } } #[cfg(test)] diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_my_agent_offers.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_my_agent_offers.rs index 51a023f73..6ef069c15 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_my_agent_offers.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_my_agent_offers.rs @@ -25,6 +25,7 @@ pub fn my_agent_offers_routes( #[derive(Deserialize, ToSchema)] pub struct GenerateAgentFromPromptRequest { pub prompt: String, + pub llm_provider: String, } #[utoipa::path( @@ -48,6 +49,7 @@ pub async fn generate_agent_from_prompt_handler( .send(NodeCommand::V2ApiGenerateAgentFromPrompt { bearer, prompt: payload.prompt, + llm_provider: payload.llm_provider, res: res_sender, }) .await diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs index abfc8afb8..3381b6eed 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs @@ -2376,7 +2376,8 @@ pub async fn tool_check_handler( pub struct SimulatedShinkaiToolRequest { pub agent_id: String, pub name: String, - pub prompt: String + pub prompt: String, + pub llm_provider: String, } #[utoipa::path( @@ -2402,6 +2403,7 @@ pub async fn create_simulated_tool_handler( agent_id: payload.agent_id.clone(), name: payload.name.clone(), prompt: payload.prompt.clone(), + llm_provider: payload.llm_provider.clone(), res: res_sender, }) .await diff --git a/shinkai-libs/shinkai-http-api/src/node_commands.rs b/shinkai-libs/shinkai-http-api/src/node_commands.rs index 8e7a9b23a..3d2040e08 100644 --- a/shinkai-libs/shinkai-http-api/src/node_commands.rs +++ b/shinkai-libs/shinkai-http-api/src/node_commands.rs @@ -819,6 +819,7 @@ pub enum NodeCommand { V2ApiGenerateAgentFromPrompt { bearer: String, prompt: String, + llm_provider: String, res: Sender>, }, V2ApiGetToolOffering { @@ -1352,6 +1353,7 @@ pub enum NodeCommand { name: String, prompt: String, agent_id: String, + llm_provider: String, res: Sender>, }, } diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs index 40d71d30e..c5bd73f90 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs @@ -58,12 +58,15 @@ impl SimulatedTool { let r#type = hash_map.get("type").unwrap_or(&unknown).as_str().unwrap_or_default(); if r#type == "array" { - let items = hash_map.get("items").unwrap().as_object().unwrap(); - let _items_type = items - .get("type") - .unwrap_or(&unknown.clone()) - .as_str() - .unwrap_or_default(); + let items = hash_map.get("items"); + if items.is_some() { + let items = items.unwrap().as_object().unwrap(); + let _items_type = items + .get("type") + .unwrap_or(&unknown.clone()) + .as_str() + .unwrap_or_default(); + } } match r#type { @@ -84,7 +87,11 @@ impl SimulatedTool { } "object" => { // Create recursive call to build the object - let properties = hash_map.get("properties").unwrap().as_object().unwrap(); + let properties = hash_map.get("properties"); + if properties.is_none() { + return Ok((key.to_string(), json!({}))); + } + let properties = properties.unwrap().as_object().unwrap(); let mut object = HashMap::new(); for (property_key, property_value) in properties { let property_values = property_value.as_object().unwrap(); diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs index 14f207376..e93e3e7dc 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs @@ -61,6 +61,7 @@ pub struct ToolPlaygroundMetadata { pub tools: Option>, pub oauth: Option>, pub runner: RunnerType, + #[serde(rename = "operatingSystem")] pub operating_system: Vec, pub tool_set: Option, } @@ -137,7 +138,10 @@ where let mut property = HashMap::new(); property.insert("description", basic.description.clone()); // If type_name is None, default to "string" - let type_value = basic.type_name.as_ref().map_or_else(|| "string".to_string(), |t| t.clone()); + let type_value = basic + .type_name + .as_ref() + .map_or_else(|| "string".to_string(), |t| t.clone()); property.insert("type", type_value); properties.insert(basic.key_name.clone(), property); From 4825928bf40aced67cb2add3e47bf5747bfff043 Mon Sep 17 00:00:00 2001 From: Eddie Date: Thu, 1 May 2025 16:41:02 -0400 Subject: [PATCH 4/9] Updated endpoint for listing tools --- .../src/network/handle_commands_list.rs | 8 +- .../v2_api/api_v2_commands_my_agent_offers.rs | 10 +++ .../network/v2_api/api_v2_commands_tools.rs | 80 +++++++++++++------ .../tests/it/a3_micropayment_flow_tests.rs | 9 ++- .../src/api_v2/api_v2_handlers_tools.rs | 3 +- .../shinkai-http-api/src/node_commands.rs | 1 + .../src/shinkai_tool_manager.rs | 4 +- .../src/tools/simulated_tool.rs | 11 ++- 8 files changed, 92 insertions(+), 34 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs index 794fbb76f..f5b1735a5 100644 --- a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs +++ b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs @@ -1918,7 +1918,12 @@ impl Node { let _ = Node::v2_api_scan_ollama_models(db_clone, bearer, res).await; }); } - NodeCommand::V2ApiListAllShinkaiTools { bearer, category, res } => { + NodeCommand::V2ApiListAllShinkaiTools { + bearer, + category, + include_simulated, + res, + } => { let db_clone = Arc::clone(&self.db); let tool_router_clone = self.tool_router.clone(); let node_name_clone = self.node_name.clone(); @@ -1928,6 +1933,7 @@ impl Node { bearer, node_name_clone, category, + include_simulated, tool_router_clone, res, ) diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs index 819e81f71..c9cad6fcd 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_my_agent_offers.rs @@ -362,6 +362,16 @@ impl Node { break; } + // { + // "name": "AGENT_NAME", + // "indications": "AGENT_INDICATIONS", + // "instructions": "AGENT_INSTRUCTIONS", + // "tools": [{ + // "name": "TOOL_NAME", + // "description": "TOOL_DESCRIPTION", + // }] + // } + let _ = res.send(Ok(agent.unwrap())).await; Ok(()) } diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs index ff7caab4b..3d66c8b00 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs @@ -261,6 +261,7 @@ impl Node { bearer: String, node_name: ShinkaiName, category: Option, + include_simulated: bool, tool_router: Option>, res: Sender>, ) -> Result<(), NodeError> { @@ -270,7 +271,7 @@ impl Node { } // List all tools - match db.get_all_tool_headers(false) { + match db.get_all_tool_headers(include_simulated) { Ok(tools) => { // Group tools by their base key (without version) use std::collections::HashMap; @@ -4303,7 +4304,7 @@ LANGUAGE={env_language} let mut tool_metadata_implementation; loop { tool_metadata_implementation = Self::create_tool_metadata_from_prompt( - prompt.clone(), + format!("{}: {}", name, prompt), llm_provider.clone(), db.clone(), bearer.clone(), @@ -4348,29 +4349,62 @@ LANGUAGE={env_language} // Save the tool let save_result = db.add_tool(ShinkaiTool::Simulated(simulated_tool.clone(), true)).await; + if let Err(e) = save_result { + let _ = res + .send(Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create simulated tool: {}", e), + })) + .await; + return Ok(()); + } - match save_result { - Ok(_) => { - let _ = res - .send(Ok(json!({ - "status": "success", - "message": "Simulated tool created successfully", - "tool_router_key": simulated_tool.get_tool_router_key(), - }))) - .await; - Ok(()) - } - Err(err) => { - let _ = res - .send(Err(APIError { - code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), - error: "Internal Server Error".to_string(), - message: format!("Failed to create simulated tool: {}", err), - })) - .await; - Ok(()) - } + let agent = db.clone().get_agent(&agent_id); + if let Err(e) = agent { + let _ = res + .send(Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to get agent: {}", e), + })) + .await; + return Ok(()); + } + let agent = agent.unwrap(); + if agent.is_none() { + let _ = res + .send(Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: "Failed to get agent".to_string(), + })) + .await; + return Ok(()); + } + + let mut agent = agent.unwrap(); + agent.tools.push(simulated_tool.get_tool_router_key()); + let status = db.update_agent(agent); + if let Err(e) = status { + let _ = res + .send(Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to update agent: {}", e), + })) + .await; + return Ok(()); } + + let _ = res + .send(Ok(json!({ + "status": "success", + "message": "Simulated tool created successfully", + "tool_router_key": simulated_tool.get_tool_router_key_string(), + }))) + .await; + Ok(()) } async fn create_tool_metadata_from_prompt( diff --git a/shinkai-bin/shinkai-node/tests/it/a3_micropayment_flow_tests.rs b/shinkai-bin/shinkai-node/tests/it/a3_micropayment_flow_tests.rs index 3c8993edf..a7a870140 100644 --- a/shinkai-bin/shinkai-node/tests/it/a3_micropayment_flow_tests.rs +++ b/shinkai-bin/shinkai-node/tests/it/a3_micropayment_flow_tests.rs @@ -3,15 +3,16 @@ use shinkai_http_api::node_commands::NodeCommand; use shinkai_message_primitives::schemas::invoices::{Invoice, InvoiceStatusEnum}; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::schemas::shinkai_tool_offering::{ - AssetPayment, ShinkaiToolOffering, ToolPrice, UsageType, UsageTypeInquiry + AssetPayment, ShinkaiToolOffering, ToolPrice, UsageType, UsageTypeInquiry, }; use shinkai_message_primitives::schemas::wallet_complementary::{WalletRole, WalletSource}; use shinkai_message_primitives::schemas::wallet_mixed::{Asset, NetworkIdentifier}; use shinkai_message_primitives::shinkai_utils::encryption::{ - encryption_public_key_to_string, encryption_secret_key_to_string, unsafe_deterministic_encryption_keypair + encryption_public_key_to_string, encryption_secret_key_to_string, unsafe_deterministic_encryption_keypair, }; use shinkai_message_primitives::shinkai_utils::signatures::{ - clone_signature_secret_key, signature_public_key_to_string, signature_secret_key_to_string, unsafe_deterministic_signature_keypair + clone_signature_secret_key, signature_public_key_to_string, signature_secret_key_to_string, + unsafe_deterministic_signature_keypair, }; use shinkai_message_primitives::shinkai_utils::utils::hash_string; use shinkai_node::network::Node; @@ -348,6 +349,7 @@ fn micropayment_flow_test() { .send(NodeCommand::V2ApiListAllShinkaiTools { bearer: api_v2_key.to_string(), category: None, + include_simulated: false, res: sender, }) .await @@ -534,6 +536,7 @@ fn micropayment_flow_test() { .send(NodeCommand::V2ApiListAllShinkaiTools { bearer: api_v2_key.to_string(), category: None, + include_simulated: false, res: sender, }) .await diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs index 3381b6eed..5c63f5c70 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs @@ -695,12 +695,13 @@ pub async fn list_all_shinkai_tools_handler( ) -> Result { let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); let category = query_params.get("category").cloned(); - + let include_simulated = query_params.get("include_simulated").cloned().unwrap_or("false".to_string()).parse::().unwrap_or(false); let (res_sender, res_receiver) = async_channel::bounded(1); sender .send(NodeCommand::V2ApiListAllShinkaiTools { bearer, category, + include_simulated, res: res_sender, }) .await diff --git a/shinkai-libs/shinkai-http-api/src/node_commands.rs b/shinkai-libs/shinkai-http-api/src/node_commands.rs index 3d2040e08..9ae7fe284 100644 --- a/shinkai-libs/shinkai-http-api/src/node_commands.rs +++ b/shinkai-libs/shinkai-http-api/src/node_commands.rs @@ -705,6 +705,7 @@ pub enum NodeCommand { V2ApiListAllShinkaiTools { bearer: String, category: Option, + include_simulated: bool, res: Sender>, }, V2ApiListAllMcpShinkaiTools { diff --git a/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs b/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs index 20e51850b..3ba7e1809 100644 --- a/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs +++ b/shinkai-libs/shinkai-sqlite/src/shinkai_tool_manager.rs @@ -1030,7 +1030,7 @@ mod tests { let vector = SqliteManager::generate_vector_for_testing(0.1); let _ = manager.add_tool_with_vector(ShinkaiTool::Simulated(simulated_tool_example.clone(), true), vector); - let t = manager.get_tool_by_key(&simulated_tool_example.get_tool_router_key()); + let t = manager.get_tool_by_key(&simulated_tool_example.get_tool_router_key_string()); assert!(t.is_ok()); let tool = t.unwrap(); assert_eq!(tool.name(), simulated_tool_example.clone().name); @@ -1113,7 +1113,7 @@ mod tests { let _ = manager.add_tool_with_vector(ShinkaiTool::Simulated(simulated_tool_example.clone(), true), vector); - let t = manager.get_tool_by_key(&simulated_tool_example.get_tool_router_key()); + let t = manager.get_tool_by_key(&simulated_tool_example.get_tool_router_key_string()); assert!(t.is_ok()); let tool = t.unwrap(); diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs index c5bd73f90..98c835ec8 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/simulated_tool.rs @@ -40,14 +40,17 @@ impl SimulatedTool { "@@simulated.local".to_string() } - pub fn get_tool_router_key(&self) -> String { - let trk = ToolRouterKey { + pub fn get_tool_router_key(&self) -> ToolRouterKey { + ToolRouterKey { source: self.get_source(), author: self.get_author(), name: self.name.clone(), version: None, - }; - trk.to_string_without_version() + } + } + + pub fn get_tool_router_key_string(&self) -> String { + self.get_tool_router_key().to_string_without_version() } pub async fn build_example_json( From d18ba2e29f014c38cbd3cdcbcf2564418dd359dd Mon Sep 17 00:00:00 2001 From: Eddie Date: Thu, 1 May 2025 16:50:20 -0400 Subject: [PATCH 5/9] Only find in tools & embeddings if Agent is calling --- .../generic_chain/generic_inference_chain.rs | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs index 1acd4ec0c..70f00e26a 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs @@ -226,7 +226,10 @@ impl GenericInferenceChain { &format!("start_generic_inference_chain> message: {:?}", user_message), ); let start_time = Instant::now(); - + let is_agent_with_tools = match &llm_provider { + ProviderOrAgent::Agent(agent) => !agent.tools.is_empty(), + ProviderOrAgent::LLMProvider(_) => false, + }; /* How it (should) work: @@ -369,17 +372,17 @@ impl GenericInferenceChain { // If tool not found directly, try FTS and vector search let sanitized_query = tool_name.replace(|c: char| !c.is_alphanumeric() && c != ' ', " "); - // TODO [SIMULATED] - // Simulated tools should not be included in the FTS search. We have to detect if this is Agent Test Screen. + // include_simulated should be false by default and only turned on if used by agent. // Perform FTS search - let fts_results = tool_router.sqlite_manager.search_tools_fts(&sanitized_query, true); + let fts_results = tool_router + .sqlite_manager + .search_tools_fts(&sanitized_query, is_agent_with_tools); // Perform vector search let vector_results = tool_router .sqlite_manager - // TODO [SIMULATED] - // include_simulated should be false by default and only turned on if the user is the agent-test screen. - .tool_vector_search(&sanitized_query, 5, false, true, true) + // include_simulated should be false by default and only turned on if used by agent. + .tool_vector_search(&sanitized_query, 5, false, true, is_agent_with_tools) .await; match (fts_results, vector_results) { @@ -451,10 +454,7 @@ impl GenericInferenceChain { let tools_allowed = job_config.as_ref().and_then(|config| config.use_tools).unwrap_or(false); // 2c. Check if the LLM provider is an agent with tools - let is_agent_with_tools = match &llm_provider { - ProviderOrAgent::Agent(agent) => !agent.tools.is_empty(), - ProviderOrAgent::LLMProvider(_) => false, - }; + // is_agent_with_tools // 2d. Check if the LLM provider/agent has tool capabilities let can_use_tools = ModelCapabilitiesManager::has_tool_capabilities_for_provider_or_agent( @@ -497,9 +497,8 @@ impl GenericInferenceChain { // to find the most relevant tools for the user's message if let Some(tool_router) = &tool_router { let results = tool_router - // TODO [SIMULATED] - // include_simulated should be false by default and only turned on if the user is the agent-test screen. - .combined_tool_search(&user_message.clone(), 7, false, true, true) + // include_simulated should be false by default and only turned on if used by agent. + .combined_tool_search(&user_message.clone(), 7, false, true, is_agent_with_tools) .await; match results { From 1fec883c3c7e106d41ad8f1c5b4fa08e1160d3b5 Mon Sep 17 00:00:00 2001 From: Eddie Date: Fri, 2 May 2025 09:57:39 -0400 Subject: [PATCH 6/9] operating system alternative names --- .../shinkai-tools-primitives/src/tools/tool_playground.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs index e93e3e7dc..f7f007a59 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/tool_playground.rs @@ -61,7 +61,7 @@ pub struct ToolPlaygroundMetadata { pub tools: Option>, pub oauth: Option>, pub runner: RunnerType, - #[serde(rename = "operatingSystem")] + #[serde(alias = "operatingSystem", alias = "operating_system")] pub operating_system: Vec, pub tool_set: Option, } From 5d406ec71be6cdd1431f656f0427ebc31e85a2d4 Mon Sep 17 00:00:00 2001 From: Eddie Date: Fri, 2 May 2025 10:34:47 -0400 Subject: [PATCH 7/9] tool headers --- .../src/tools/agent_tool_wrapper.rs | 3 ++- .../src/tools/deno_tools.rs | 2 +- .../src/tools/shinkai_tool.rs | 19 ++++++++++++++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/agent_tool_wrapper.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/agent_tool_wrapper.rs index 875d13045..448818353 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/agent_tool_wrapper.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/agent_tool_wrapper.rs @@ -106,6 +106,7 @@ fn default_input_args() -> Parameters { fn default_output_arg() -> ToolOutputArg { ToolOutputArg { - json: r#"{"type":"string","description":"Agent response"}"#.to_string(), + json: r#"{"type": "object", "properties": {"message": {"type": "string", "description":"Agent response"}}}"# + .to_string(), } } diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs index a6d0c4fc0..3c602ea01 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs @@ -38,7 +38,7 @@ pub struct DenoTool { pub description: String, pub keywords: Vec, pub input_args: Parameters, - pub output_arg: ToolOutputArg, // DEPRICATED. Use "Result" Instance instead. + pub output_arg: ToolOutputArg, // DEPRICATED. Use "result" instead. pub activated: bool, pub embedding: Option>, pub result: ToolResult, diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs index 259b7c5d1..bca661a41 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/shinkai_tool.rs @@ -87,7 +87,7 @@ impl ShinkaiTool { enabled: self.is_enabled(), mcp_enabled: Some(self.is_mcp_enabled()), input_args: self.input_args(), - output_arg: ToolOutputArg::empty(), + output_arg: self.get_output_args(), config: self.get_js_tool_config().cloned(), usage_type: self.get_usage_type(), tool_offering: None, @@ -161,6 +161,23 @@ impl ShinkaiTool { } } + pub fn get_output_args(&self) -> ToolOutputArg { + match self { + ShinkaiTool::Rust(r, _) => r.output_arg.clone(), + ShinkaiTool::Network(n, _) => n.output_arg.clone(), + ShinkaiTool::Deno(d, _) => ToolOutputArg { + json: serde_json::to_string(&d.result).unwrap_or_default(), + }, + ShinkaiTool::Python(p, _) => ToolOutputArg { + json: serde_json::to_string(&p.result).unwrap_or_default(), + }, + ShinkaiTool::Agent(a, _) => a.output_arg.clone(), + ShinkaiTool::Simulated(s, _) => ToolOutputArg { + json: serde_json::to_string(&s.result).unwrap_or_default(), + }, + } + } + /// Returns the output arguments of the tool pub fn tool_type(&self) -> &'static str { match self { From 44fb15f34c599396a99e60ddf5693b5c55bbce8c Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Tue, 6 May 2025 21:56:16 -0500 Subject: [PATCH 8/9] nbsp --- .../src/llm_provider/execution/job_execution_core.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs index 94395bce4..6e0f9520f 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs @@ -203,7 +203,7 @@ impl JobManager { shinkai_log( ShinkaiLogOption::JobExecution, ShinkaiLogLevel::Debug, - &format!("Retrieved {} image files", image_files.len()), + &format!("Retrieved: {} image files", image_files.len()), ); let start = Instant::now(); From 00849828205f8fe6af14ac03be171ec114ce3c61 Mon Sep 17 00:00:00 2001 From: Eddie Date: Wed, 14 May 2025 11:07:22 -0400 Subject: [PATCH 9/9] fix use imports --- shinkai-bin/shinkai-node/src/managers/tool_router.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/managers/tool_router.rs b/shinkai-bin/shinkai-node/src/managers/tool_router.rs index 9e65fd90c..d5181148c 100644 --- a/shinkai-bin/shinkai-node/src/managers/tool_router.rs +++ b/shinkai-bin/shinkai-node/src/managers/tool_router.rs @@ -29,10 +29,6 @@ use shinkai_message_primitives::schemas::llm_providers::agent::Agent; use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_message_primitives::schemas::shinkai_tools::CodeLanguage; use shinkai_message_primitives::schemas::{ - indexable_version::IndexableVersion, - invoices::{Invoice, InvoiceStatusEnum}, - job::JobLike, - llm_providers::common_agent_llm_provider::ProviderOrAgent, shinkai_name::ShinkaiName, shinkai_preferences::ShinkaiInternalComms, shinkai_tool_offering::{AssetPayment, ToolPrice, UsageType, UsageTypeInquiry},