@@ -224,7 +224,10 @@ impl GenericInferenceChain {
&format!("start_generic_inference_chain> message: {:?}", user_message),
);
let start_time = Instant::now();

let is_agent_with_tools = match &llm_provider {
ProviderOrAgent::Agent(agent) => !agent.tools.is_empty(),
ProviderOrAgent::LLMProvider(_) => false,
};
/*
How it (should) work:

@@ -367,13 +370,17 @@ impl GenericInferenceChain {
// If tool not found directly, try FTS and vector search
let sanitized_query = tool_name.replace(|c: char| !c.is_alphanumeric() && c != ' ', " ");

// include_simulated should be false by default and only turned on if used by agent.
// Perform FTS search
let fts_results = tool_router.sqlite_manager.search_tools_fts(&sanitized_query);
let fts_results = tool_router
.sqlite_manager
.search_tools_fts(&sanitized_query, is_agent_with_tools);

// Perform vector search
let vector_results = tool_router
.sqlite_manager
.tool_vector_search(&sanitized_query, 5, false, true)
// include_simulated should be false by default and only turned on if used by agent.
.tool_vector_search(&sanitized_query, 5, false, true, is_agent_with_tools)
.await;

match (fts_results, vector_results) {
@@ -445,10 +452,7 @@ impl GenericInferenceChain {
let tools_allowed = job_config.as_ref().and_then(|config| config.use_tools).unwrap_or(false);

// 2c. Check if the LLM provider is an agent with tools
let is_agent_with_tools = match &llm_provider {
ProviderOrAgent::Agent(agent) => !agent.tools.is_empty(),
ProviderOrAgent::LLMProvider(_) => false,
};
// is_agent_with_tools

// 2d. Check if the LLM provider/agent has tool capabilities
let can_use_tools = ModelCapabilitiesManager::has_tool_capabilities_for_provider_or_agent(
Expand Down Expand Up @@ -491,7 +495,8 @@ impl GenericInferenceChain {
// to find the most relevant tools for the user's message
if let Some(tool_router) = &tool_router {
let results = tool_router
.combined_tool_search(&user_message.clone(), 7, false, true)
// include_simulated should be false by default and only turned on if used by agent.
.combined_tool_search(&user_message.clone(), 7, false, true, is_agent_with_tools)
.await;

match results {
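Net effect of the changes in this file: the chain derives a single is_agent_with_tools flag from the provider up front and passes it as the new include_simulated argument to search_tools_fts, tool_vector_search, and combined_tool_search, so simulated tools only surface for agents that actually own tools. A minimal, self-contained sketch of that pattern (the types and the search function below are simplified stand-ins, not the crate's real APIs):

    // Simplified stand-ins; the real code fans the flag out to FTS and vector
    // search on the SqliteManager, but the gating logic is the same.
    struct Agent {
        tools: Vec<String>,
    }

    enum ProviderOrAgent {
        Agent(Agent),
        LLMProvider(String),
    }

    fn search_tools(query: &str, include_simulated: bool) -> Vec<String> {
        // Placeholder search: simulated tools are only eligible when the flag is set.
        let _ = (query, include_simulated);
        Vec::new()
    }

    fn gather_tools(llm_provider: &ProviderOrAgent, user_message: &str) -> Vec<String> {
        // include_simulated stays false by default and is only turned on for an
        // agent that has tools attached, mirroring the comments in the hunks above.
        let is_agent_with_tools = match llm_provider {
            ProviderOrAgent::Agent(agent) => !agent.tools.is_empty(),
            ProviderOrAgent::LLMProvider(_) => false,
        };
        search_tools(user_message, is_agent_with_tools)
    }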
@@ -203,7 +203,7 @@ impl JobManager {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Debug,
&format!("Retrieved {} image files", image_files.len()),
&format!("Retrieved: {} image files", image_files.len()),
);

let start = Instant::now();
2 changes: 1 addition & 1 deletion shinkai-bin/shinkai-node/src/managers/galxe_quests.rs
@@ -391,7 +391,7 @@ pub async fn compute_download_store_quest(db: Arc<SqliteManager>) -> Result<bool

// Get all installed tools
let installed_tools = db
.get_all_tool_headers()
.get_all_tool_headers(false)
.map_err(|e| format!("Failed to get installed tools: {}", e))?;

// Count tools that were downloaded (exist in default tools but don't have playground)
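get_all_tool_headers now takes a boolean argument, and every call site in this PR passes false. The diff does not name the parameter, but given the rest of the change set it presumably controls whether simulated tools show up in the listing; a hedged, self-contained sketch of that assumed behavior:

    // Assumption for illustration only: the new flag filters simulated tools
    // out of the header listing unless the caller opts in.
    struct ShinkaiToolHeader {
        name: String,
        is_simulated: bool,
    }

    fn get_all_tool_headers(all: &[ShinkaiToolHeader], include_simulated: bool) -> Vec<&ShinkaiToolHeader> {
        all.iter()
            .filter(|h| include_simulated || !h.is_simulated)
            .collect()
    }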
87 changes: 73 additions & 14 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
@@ -5,12 +5,15 @@ use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall
use crate::llm_provider::job_manager::JobManager;
use crate::network::node_shareable_logic::ZipFileContents;
use crate::network::zip_export_import::zip_export_import::{
get_agent_from_zip, get_tool_from_zip, import_agent, import_tool
get_agent_from_zip, get_tool_from_zip, import_agent, import_tool,
};
use crate::network::Node;
use crate::tools::tool_definitions::definition_generation::{generate_tool_definitions, get_rust_tools};
use crate::tools::tool_execution::{
execute_agent_dynamic::execute_agent_tool, execution_coordinator::override_tool_config, execution_custom::try_to_execute_rust_tool, execution_header_generator::{check_tool, generate_execution_environment}
execute_agent_dynamic::execute_agent_tool,
execution_coordinator::override_tool_config,
execution_custom::try_to_execute_rust_tool,
execution_header_generator::{check_tool, generate_execution_environment},
};
use crate::utils::environment::{fetch_node_environment, NodeEnvironment};
use async_std::path::PathBuf;
@@ -19,18 +22,33 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use shinkai_embedding::embedding_generator::EmbeddingGenerator;
use shinkai_fs::shinkai_file_manager::ShinkaiFileManager;
use shinkai_message_primitives::schemas::indexable_version::IndexableVersion;
use shinkai_message_primitives::schemas::invoices::{Invoice, InvoiceStatusEnum};
use shinkai_message_primitives::schemas::job::JobLike;
use shinkai_message_primitives::schemas::llm_providers::agent::Agent;
use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent;
use shinkai_message_primitives::schemas::shinkai_tools::CodeLanguage;
use shinkai_message_primitives::schemas::{
indexable_version::IndexableVersion, invoices::{Invoice, InvoiceStatusEnum}, job::JobLike, llm_providers::common_agent_llm_provider::ProviderOrAgent, shinkai_name::ShinkaiName, shinkai_preferences::ShinkaiInternalComms, shinkai_tool_offering::{AssetPayment, ToolPrice, UsageType, UsageTypeInquiry}, tool_router_key::ToolRouterKey, wallet_mixed::{Asset, NetworkIdentifier}, ws_types::{PaymentMetadata, WSMessageType, WidgetMetadata}
shinkai_name::ShinkaiName,
shinkai_preferences::ShinkaiInternalComms,
shinkai_tool_offering::{AssetPayment, ToolPrice, UsageType, UsageTypeInquiry},
tool_router_key::ToolRouterKey,
wallet_mixed::{Asset, NetworkIdentifier},
ws_types::{PaymentMetadata, WSMessageType, WidgetMetadata},
};
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{AssociatedUI, WSTopic};
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption};
use shinkai_sqlite::errors::SqliteManagerError;
use shinkai_sqlite::files::prompts_data;
use shinkai_sqlite::SqliteManager;
use shinkai_tools_primitives::tools::{
error::ToolError, network_tool::NetworkTool, parameters::Parameters, rust_tools::RustTool, shinkai_tool::{ShinkaiTool, ShinkaiToolHeader}, tool_config::ToolConfig, tool_output_arg::ToolOutputArg
error::ToolError,
network_tool::NetworkTool,
parameters::Parameters,
rust_tools::RustTool,
shinkai_tool::{ShinkaiTool, ShinkaiToolHeader},
tool_config::ToolConfig,
tool_output_arg::ToolOutputArg,
};
use std::env;
use std::sync::Arc;
@@ -711,10 +729,11 @@ impl ToolRouter {
&self,
query: &str,
num_of_results: u64,
include_simulated: bool,
) -> Result<Vec<ShinkaiToolHeader>, ToolError> {
let tool_headers = self
.sqlite_manager
.tool_vector_search(query, num_of_results, false, false)
.tool_vector_search(query, num_of_results, false, false, include_simulated)
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;
// Note: we can add more code here to filter out low confidence results
@@ -726,10 +745,11 @@ impl ToolRouter {
&self,
query: &str,
num_of_results: u64,
include_simulated: bool,
) -> Result<Vec<ShinkaiToolHeader>, ToolError> {
let tool_headers = self
.sqlite_manager
.tool_vector_search(query, num_of_results, false, true)
.tool_vector_search(query, num_of_results, false, true, include_simulated)
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;
// Note: we can add more code here to filter out low confidence results
@@ -741,10 +761,11 @@ impl ToolRouter {
&self,
query: &str,
num_of_results: u64,
include_simulated: bool,
) -> Result<Vec<ShinkaiToolHeader>, ToolError> {
let tool_headers = self
.sqlite_manager
.tool_vector_search(query, num_of_results, true, true)
.tool_vector_search(query, num_of_results, true, true, include_simulated)
.await
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;
// Note: we can add more code here to filter out low confidence results
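Each of these three wrappers now forwards a caller-supplied include_simulated flag to SqliteManager::tool_vector_search, whose argument order — query, num_of_results, include_disabled, include_network, include_simulated — is spelled out by the combined_tool_search hunk further down. A hedged sketch of one such call inside the crate (types as imported in this file; the surrounding function is illustrative only):

    use shinkai_sqlite::SqliteManager;
    use shinkai_tools_primitives::tools::{error::ToolError, shinkai_tool::ShinkaiToolHeader};

    // Illustrative only: same call shape as the wrappers above, with the five
    // arguments in the order used by combined_tool_search below.
    async fn example_vector_search(
        sqlite_manager: &SqliteManager,
        query: &str,
        include_simulated: bool,
    ) -> Result<Vec<ShinkaiToolHeader>, ToolError> {
        let headers = sqlite_manager
            // (query, num_of_results, include_disabled, include_network, include_simulated)
            .tool_vector_search(query, 5, false, true, include_simulated)
            .await
            .map_err(|e| ToolError::DatabaseError(e.to_string()))?;
        Ok(headers)
    }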
@@ -819,8 +840,37 @@ impl ToolRouter {
}

match shinkai_tool {
ShinkaiTool::Simulated(simulated_tool, _is_enabled) => {
let node_env: crate::utils::environment::NodeEnvironment = fetch_node_environment();
let bearer = context.db().read_api_v2_key().unwrap_or_default().unwrap_or_default();
let app_id = match context.full_job().associated_ui().as_ref() {
Some(AssociatedUI::Cron(cron_id)) => cron_id.clone(),
_ => context.full_job().job_id().to_string(),
};

let tool_id = shinkai_tool.tool_router_key().to_string_without_version().clone();
let result = simulated_tool
.run(
bearer,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
app_id,
tool_id,
// TODO This should resolve the LLM provider, and not the agent_id
context.agent().clone().get_llm_provider_id().to_string(),
function_args,
function_config_vec,
)
.await?;
let result_str = serde_json::to_string(&result)
.map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?;
return Ok(ToolCallFunctionResponse {
response: result_str,
function_call,
});
}
ShinkaiTool::Python(python_tool, _is_enabled) => {
let node_env = fetch_node_environment();
let node_env: crate::utils::environment::NodeEnvironment = fetch_node_environment();
let node_storage_path = node_env
.node_storage_path
.clone()
@@ -836,7 +886,7 @@ impl ToolRouter {
let tools: Vec<ToolRouterKey> = context
.db()
.clone()
.get_all_tool_headers()?
.get_all_tool_headers(false)?
.into_iter()
.filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) {
Ok(tool_router_key) => Some(tool_router_key),
@@ -853,8 +903,8 @@ impl ToolRouter {
let envs = generate_execution_environment(
context.db(),
context.agent().clone().get_id().to_string(),
tool_id.clone(),
app_id.clone(),
tool_id.clone(),
agent_id,
shinkai_tool.tool_router_key().to_string_without_version().clone(),
app_id.clone(),
@@ -1016,7 +1066,7 @@ impl ToolRouter {
let tools: Vec<ToolRouterKey> = context
.db()
.clone()
.get_all_tool_headers()?
.get_all_tool_headers(false)?
.into_iter()
.filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) {
Ok(tool_router_key) => Some(tool_router_key),
@@ -1375,7 +1425,7 @@ impl ToolRouter {
let tools: Vec<ToolRouterKey> = self
.sqlite_manager
.clone()
.get_all_tool_headers()?
.get_all_tool_headers(false)?
.into_iter()
.filter_map(|tool| match ToolRouterKey::from_string(&tool.tool_router_key) {
Ok(tool_router_key) => Some(tool_router_key),
@@ -1447,6 +1497,7 @@ impl ToolRouter {
num_of_results: u64,
include_disabled: bool,
include_network: bool,
include_simulated: bool,
) -> Result<Vec<ShinkaiToolHeader>, ToolError> {
// Sanitize the query to handle special characters
let sanitized_query = query.replace(|c: char| !c.is_alphanumeric() && c != ' ', " ");
@@ -1455,14 +1506,22 @@ impl ToolRouter {
let vector_start_time = Instant::now();
let vector_search_result = self
.sqlite_manager
.tool_vector_search(&sanitized_query, num_of_results, include_disabled, include_network)
.tool_vector_search(
&sanitized_query,
num_of_results,
include_disabled,
include_network,
include_simulated,
)
.await;
let vector_elapsed_time = vector_start_time.elapsed();
println!("Time taken for vector search: {:?}", vector_elapsed_time);

// Start the timer for FTS search
let fts_start_time = Instant::now();
let fts_search_result = self.sqlite_manager.search_tools_fts(&sanitized_query);
let fts_search_result = self
.sqlite_manager
.search_tools_fts(&sanitized_query, include_simulated);
let fts_elapsed_time = fts_start_time.elapsed();
println!("Time taken for FTS search: {:?}", fts_elapsed_time);

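The new ShinkaiTool::Simulated arm earlier in this file runs the simulated tool through the node's own API: it reads the bearer key, picks the cron id or job id as the app_id, and passes the tool router key plus the agent's LLM provider id (with a TODO noting the provider should be resolved rather than the agent id). A condensed, self-contained sketch of that flow with stand-in types — the real SimulatedTool::run signature is known here only from its call site above:

    // Stand-in types; the argument order mirrors the call site in the Simulated arm.
    use serde_json::{json, Map, Value};

    struct SimulatedTool;

    impl SimulatedTool {
        async fn run(
            &self,
            _bearer: String,
            _api_ip: String,
            _api_port: u16,
            _app_id: String,
            _tool_id: String,
            _llm_provider: String,
            args: Map<String, Value>,
            _configs: Vec<Value>,
        ) -> Result<Value, String> {
            // The real tool drives an LLM job on the node; this stub just echoes.
            Ok(json!({ "args": args }))
        }
    }

    async fn call_simulated_tool(tool: &SimulatedTool, args: Map<String, Value>) -> Result<String, String> {
        let result = tool
            .run(
                "bearer".to_string(),                   // db.read_api_v2_key() in the real code
                "127.0.0.1".to_string(),                // node API listen IP
                9550,                                   // node API listen port (assumed value)
                "job-or-cron-id".to_string(),           // cron id when present, else the job id
                "local:::simulated:::demo".to_string(), // tool router key without version (illustrative key)
                "llm-provider-id".to_string(),          // currently the agent's LLM provider id (see TODO)
                args,
                Vec::new(),                             // ToolConfig list in the real code
            )
            .await?;
        // The chain serializes the value before returning it as the tool response.
        serde_json::to_string(&result).map_err(|e| e.to_string())
    }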
74 changes: 73 additions & 1 deletion shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
@@ -1766,7 +1766,12 @@ impl Node {
let _ = Node::v2_api_scan_ollama_models(db_clone, bearer, res).await;
});
}
NodeCommand::V2ApiListAllShinkaiTools { bearer, category, res } => {
NodeCommand::V2ApiListAllShinkaiTools {
bearer,
category,
include_simulated,
res,
} => {
let db_clone = Arc::clone(&self.db);
let tool_router_clone = self.tool_router.clone();
let node_name_clone = self.node_name.clone();
@@ -1776,6 +1781,7 @@ impl Node {
bearer,
node_name_clone,
category,
include_simulated,
tool_router_clone,
res,
)
@@ -1888,6 +1894,36 @@ impl Node {
// let _ = Node::v2_api_list_files_in_inbox(db_clone, bearer, inbox_name, res).await;
// });
// }
NodeCommand::V2ApiGenerateAgentFromPrompt {
bearer,
prompt,
llm_provider,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_name = self.node_name.clone();
let identity_manager_clone = self.identity_manager.clone();
let job_manager_clone = self.job_manager.clone().unwrap();
let encryption_secret_key_clone = self.encryption_secret_key.clone();
let encryption_public_key_clone = self.encryption_public_key.clone();
let signing_secret_key_clone = self.identity_secret_key.clone();
tokio::spawn(async move {
let _ = Node::v2_api_generate_agent_from_prompt(
db_clone,
bearer,
prompt,
llm_provider,
node_name,
identity_manager_clone,
job_manager_clone,
encryption_secret_key_clone,
encryption_public_key_clone,
signing_secret_key_clone,
res,
)
.await;
});
}
NodeCommand::V2ApiGetToolOffering {
bearer,
tool_key_name,
@@ -3073,6 +3109,42 @@ impl Node {
let _ = Node::v2_api_get_preferences(db_clone, bearer, res).await;
});
}

NodeCommand::V2ApiCreateSimulatedTool {
bearer,
name,
prompt,
agent_id,
llm_provider,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_name_clone = self.node_name.clone();
let identity_manager_clone = self.identity_manager.clone();
let job_manager_clone = self.job_manager.clone().unwrap();
let encryption_secret_key_clone = self.encryption_secret_key.clone();
let encryption_public_key_clone = self.encryption_public_key;
let signing_secret_key_clone = self.identity_secret_key.clone();

tokio::spawn(async move {
let _ = Node::v2_api_create_simulated_tool(
db_clone,
bearer,
name,
prompt,
agent_id,
llm_provider,
node_name_clone,
identity_manager_clone,
job_manager_clone,
encryption_secret_key_clone,
encryption_public_key_clone,
signing_secret_key_clone,
res,
)
.await;
});
}
NodeCommand::V2ApiGetLastUsedAgentsAndLLMs { bearer, last, res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
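The handlers added in this file imply three new or extended NodeCommand variants. A self-contained sketch of their shape, inferred from the destructuring patterns above — the field types and the response-channel payload are assumptions, not the crate's real definitions:

    // Hypothetical mirror of the command variants; only the field names are
    // taken from the handlers above, everything else is assumed.
    use serde_json::Value;
    use tokio::sync::oneshot::Sender;

    type ApiResponse = Result<Value, String>; // assumed response payload

    enum NodeCommandSketch {
        V2ApiListAllShinkaiTools {
            bearer: String,
            category: Option<String>,  // assumed optional category filter
            include_simulated: bool,   // new flag threaded into the tool listing
            res: Sender<ApiResponse>,
        },
        V2ApiGenerateAgentFromPrompt {
            bearer: String,
            prompt: String,
            llm_provider: String,      // assumed: provider id as a plain string
            res: Sender<ApiResponse>,
        },
        V2ApiCreateSimulatedTool {
            bearer: String,
            name: String,
            prompt: String,
            agent_id: String,          // assumed type
            llm_provider: String,      // assumed type
            res: Sender<ApiResponse>,
        },
    }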