diff --git a/examples/attachments.cpp b/examples/attachments.cpp index e4fec47..17fd097 100644 --- a/examples/attachments.cpp +++ b/examples/attachments.cpp @@ -38,7 +38,7 @@ int main() try { copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/basic_chat.cpp b/examples/basic_chat.cpp index 4926c79..ae8f699 100644 --- a/examples/basic_chat.cpp +++ b/examples/basic_chat.cpp @@ -14,7 +14,7 @@ int main() { // Create client with default options (uses stdio transport) copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/byok.cpp b/examples/byok.cpp index 96eb7a7..4fde885 100644 --- a/examples/byok.cpp +++ b/examples/byok.cpp @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) // Create client copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/compaction_events.cpp b/examples/compaction_events.cpp index d511ef6..88225b5 100644 --- a/examples/compaction_events.cpp +++ b/examples/compaction_events.cpp @@ -23,7 +23,7 @@ int main() try { copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/custom_agents.cpp b/examples/custom_agents.cpp index f3c3022..bd55daa 100644 --- a/examples/custom_agents.cpp +++ b/examples/custom_agents.cpp @@ -23,7 +23,7 @@ int main() try { copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/fluent_tools.cpp b/examples/fluent_tools.cpp index 42398c8..6398476 100644 --- a/examples/fluent_tools.cpp +++ b/examples/fluent_tools.cpp @@ -150,7 +150,7 @@ int main() std::cout << "=== Starting Copilot Session ===\n\n"; ClientOptions opts; - opts.log_level = "info"; + opts.log_level = LogLevel::Info; opts.use_stdio = true; Client client(opts); diff --git a/examples/list_models.cpp b/examples/list_models.cpp index ecbad76..86f01fe 100644 --- a/examples/list_models.cpp +++ b/examples/list_models.cpp @@ -19,7 +19,7 @@ int main() try { copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/mcp/inprocess/mcp_inprocess.cpp b/examples/mcp/inprocess/mcp_inprocess.cpp index d7f7991..7a0af9a 100644 --- a/examples/mcp/inprocess/mcp_inprocess.cpp +++ b/examples/mcp/inprocess/mcp_inprocess.cpp @@ -195,7 +195,7 @@ int main() std::cout << "Creating Copilot SDK client...\n"; copilot::ClientOptions client_opts; - client_opts.log_level = "info"; + client_opts.log_level = copilot::LogLevel::Info; copilot::Client client(client_opts); client.start().get(); diff --git a/examples/mcp_servers.cpp b/examples/mcp_servers.cpp index 44fef4b..2a32263 100644 --- a/examples/mcp_servers.cpp +++ b/examples/mcp_servers.cpp @@ -23,7 +23,7 @@ int main() try { copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/permission_callback.cpp b/examples/permission_callback.cpp index f30306b..b7d152c 100644 --- a/examples/permission_callback.cpp +++ b/examples/permission_callback.cpp @@ -62,7 +62,7 @@ int main() // Create client copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/reasoning_effort.cpp b/examples/reasoning_effort.cpp index eb252de..9916cae 100644 --- a/examples/reasoning_effort.cpp +++ b/examples/reasoning_effort.cpp @@ -63,7 +63,7 @@ int main() // Create session with reasoning effort copilot::SessionConfig config; - config.reasoning_effort = "medium"; + config.reasoning_effort = copilot::ReasoningEffort::Medium; // Permission handler config.on_permission_request = [](const copilot::PermissionRequest&) diff --git a/examples/resume_with_tools.cpp b/examples/resume_with_tools.cpp index dcca74f..bb81187 100644 --- a/examples/resume_with_tools.cpp +++ b/examples/resume_with_tools.cpp @@ -39,19 +39,19 @@ copilot::ToolResultObject secret_handler(const copilot::ToolInvocation& invocati if (it != secrets.end()) { result.text_result_for_llm = "Secret value for '" + key + "': " + it->second; - result.result_type = "success"; + result.result_type = copilot::ToolResultType::Success; } else { result.text_result_for_llm = "No secret found for key: " + key; - result.result_type = "failure"; + result.result_type = copilot::ToolResultType::Failure; result.error = "Key not found"; } } catch (const std::exception& e) { result.text_result_for_llm = "Error retrieving secret"; - result.result_type = "failure"; + result.result_type = copilot::ToolResultType::Failure; result.error = e.what(); } @@ -64,7 +64,7 @@ int main() { // Create client copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); std::cout << "=== Phase 1: Create Initial Session ===\n\n"; diff --git a/examples/streaming.cpp b/examples/streaming.cpp index f4dcce8..a7520a1 100644 --- a/examples/streaming.cpp +++ b/examples/streaming.cpp @@ -23,7 +23,7 @@ int main(int argc, char* argv[]) { // Create client with streaming enabled copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/system_prompt.cpp b/examples/system_prompt.cpp index 2ae3f69..3465b2f 100644 --- a/examples/system_prompt.cpp +++ b/examples/system_prompt.cpp @@ -72,7 +72,7 @@ int main() try { copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/tool_progress.cpp b/examples/tool_progress.cpp index 0e843c6..70d9563 100644 --- a/examples/tool_progress.cpp +++ b/examples/tool_progress.cpp @@ -42,7 +42,7 @@ copilot::ToolResultObject word_count_handler(const copilot::ToolInvocation& invo } catch (const std::exception& e) { - result.result_type = "failure"; + result.result_type = copilot::ToolResultType::Failure; result.error = e.what(); result.text_result_for_llm = std::string("Error: ") + e.what(); } @@ -70,7 +70,7 @@ copilot::ToolResultObject search_handler(const copilot::ToolInvocation& invocati } catch (const std::exception& e) { - result.result_type = "failure"; + result.result_type = copilot::ToolResultType::Failure; result.error = e.what(); result.text_result_for_llm = std::string("Error: ") + e.what(); } @@ -83,7 +83,7 @@ int main() try { copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/examples/tools.cpp b/examples/tools.cpp index afabeb9..b529f53 100644 --- a/examples/tools.cpp +++ b/examples/tools.cpp @@ -46,7 +46,7 @@ copilot::ToolResultObject calculate_handler(const copilot::ToolInvocation& invoc { if (b == 0) { - result.result_type = "failure"; + result.result_type = copilot::ToolResultType::Failure; result.error = "Division by zero"; result.text_result_for_llm = "Error: Cannot divide by zero"; return result; @@ -59,7 +59,7 @@ copilot::ToolResultObject calculate_handler(const copilot::ToolInvocation& invoc } else { - result.result_type = "failure"; + result.result_type = copilot::ToolResultType::Failure; result.error = "Unknown operation: " + operation; result.text_result_for_llm = "Error: Unknown operation '" + operation + "'"; return result; @@ -72,7 +72,7 @@ copilot::ToolResultObject calculate_handler(const copilot::ToolInvocation& invoc } catch (const std::exception& e) { - result.result_type = "failure"; + result.result_type = copilot::ToolResultType::Failure; result.error = e.what(); result.text_result_for_llm = std::string("Error: ") + e.what(); } @@ -105,7 +105,7 @@ copilot::ToolResultObject get_time_handler(const copilot::ToolInvocation& invoca } catch (const std::exception& e) { - result.result_type = "failure"; + result.result_type = copilot::ToolResultType::Failure; result.error = e.what(); result.text_result_for_llm = std::string("Error: ") + e.what(); } @@ -132,7 +132,7 @@ int main() { // Create client copilot::ClientOptions options; - options.log_level = "info"; + options.log_level = copilot::LogLevel::Info; copilot::Client client(options); diff --git a/include/copilot/client.hpp b/include/copilot/client.hpp index a0ae703..ec893b4 100644 --- a/include/copilot/client.hpp +++ b/include/copilot/client.hpp @@ -29,6 +29,7 @@ namespace copilot // Forward declaration class Session; +class Subscription; // ============================================================================= // Request Builder Helpers (for unit testing request JSON shape) @@ -57,7 +58,7 @@ json build_session_resume_request(const std::string& session_id, const ResumeSes /// Example usage: /// @code /// ClientOptions opts; -/// opts.log_level = "debug"; +/// opts.log_level = LogLevel::Debug; /// /// Client client(opts); /// client.start().get(); @@ -95,8 +96,8 @@ class Client /// Stop the client gracefully /// Destroys all sessions and closes the connection - /// @return Future that completes when stopped - std::future stop(); + /// @return Future that completes with any errors encountered during cleanup + std::future> stop(); /// Force stop the client immediately /// Kills the CLI process without graceful cleanup @@ -156,6 +157,24 @@ class Client /// @throws Error if not authenticated std::future> list_models(); + // ========================================================================= + // Lifecycle Events + // ========================================================================= + + /// Subscribe to session lifecycle events (created, deleted, updated, foreground, background) + using LifecycleHandler = std::function; + Subscription on_lifecycle(LifecycleHandler handler); + + // ========================================================================= + // Foreground Session + // ========================================================================= + + /// Get the current foreground session ID + std::future> get_foreground_session_id(); + + /// Set the foreground session + std::future set_foreground_session_id(const std::string& session_id); + // ========================================================================= // Internal API (used by Session) // ========================================================================= @@ -221,6 +240,10 @@ class Client // Models cache mutable std::mutex models_cache_mutex_; std::optional> models_cache_; + + // Lifecycle handlers + mutable std::mutex lifecycle_mutex_; + std::vector lifecycle_handlers_; }; } // namespace copilot diff --git a/include/copilot/tool_builder.hpp b/include/copilot/tool_builder.hpp index 81d1f55..fc17394 100644 --- a/include/copilot/tool_builder.hpp +++ b/include/copilot/tool_builder.hpp @@ -180,6 +180,28 @@ std::string to_result_string(const T& value) } } +/// Normalize handler return value to ToolResultObject +template +ToolResultObject normalize_result(T&& value) +{ + if constexpr (std::is_same_v, ToolResultObject>) + { + return std::forward(value); + } + else if constexpr (std::is_same_v, json>) + { + return ToolResultObject{ + .text_result_for_llm = value.dump(), + .result_type = ToolResultType::Success}; + } + else + { + return ToolResultObject{ + .text_result_for_llm = to_result_string(value), + .result_type = ToolResultType::Success}; + } +} + /// Extract argument from JSON by name template T extract_arg(const json& args, const std::string& name) @@ -341,7 +363,6 @@ class ToolBuilderWithParams auto params = params_; tool.handler = [fn = std::forward(fn), params](const ToolInvocation& inv) -> ToolResultObject { - ToolResultObject result; try { const json& args = inv.arguments.value_or(json::object()); @@ -352,16 +373,22 @@ class ToolBuilderWithParams // Call the handler with extracted arguments auto ret = std::apply(fn, extracted); - result.text_result_for_llm = detail::to_result_string(ret); - result.result_type = "success"; + return detail::normalize_result(std::move(ret)); } catch (const std::exception& e) { - result.result_type = "error"; - result.error = e.what(); - result.text_result_for_llm = std::string("Error: ") + e.what(); + return ToolResultObject{ + .text_result_for_llm = "Invoking this tool produced an error. Detailed information is not available.", + .result_type = ToolResultType::Failure, + .error = e.what()}; + } + catch (...) + { + return ToolResultObject{ + .text_result_for_llm = "Invoking this tool produced an error. Detailed information is not available.", + .result_type = ToolResultType::Failure, + .error = "Unknown error"}; } - return result; }; return tool; @@ -473,7 +500,6 @@ class ToolBuilder::StructBuilder tool.handler = [fn = std::forward(fn)](const ToolInvocation& inv) -> ToolResultObject { - ToolResultObject result; try { const json& args = inv.arguments.value_or(json::object()); @@ -482,16 +508,22 @@ class ToolBuilder::StructBuilder ArgsStruct parsed = args.get(); auto ret = fn(parsed); - result.text_result_for_llm = detail::to_result_string(ret); - result.result_type = "success"; + return detail::normalize_result(std::move(ret)); } catch (const std::exception& e) { - result.result_type = "error"; - result.error = e.what(); - result.text_result_for_llm = std::string("Error: ") + e.what(); + return ToolResultObject{ + .text_result_for_llm = "Invoking this tool produced an error. Detailed information is not available.", + .result_type = ToolResultType::Failure, + .error = e.what()}; + } + catch (...) + { + return ToolResultObject{ + .text_result_for_llm = "Invoking this tool produced an error. Detailed information is not available.", + .result_type = ToolResultType::Failure, + .error = "Unknown error"}; } - return result; }; return tool; @@ -674,19 +706,26 @@ Tool make_tool(std::string name, std::string description, Func&& func, tool.handler = [f = std::forward(func), names = std::move(param_names)](const ToolInvocation& inv) -> ToolResultObject { - ToolResultObject result; try { json args = inv.arguments.value_or(json::object()); auto output = detail::invoke_with_json(f, args, names); - result.text_result_for_llm = detail::to_result_string(output); + return detail::normalize_result(std::move(output)); } catch (const std::exception& e) { - result.result_type = "error"; - result.error = e.what(); + return ToolResultObject{ + .text_result_for_llm = "Invoking this tool produced an error. Detailed information is not available.", + .result_type = ToolResultType::Failure, + .error = e.what()}; + } + catch (...) + { + return ToolResultObject{ + .text_result_for_llm = "Invoking this tool produced an error. Detailed information is not available.", + .result_type = ToolResultType::Failure, + .error = "Unknown error"}; } - return result; }; return tool; diff --git a/include/copilot/types.hpp b/include/copilot/types.hpp index 64b1b2f..2bac8e3 100644 --- a/include/copilot/types.hpp +++ b/include/copilot/types.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace copilot @@ -73,6 +74,67 @@ NLOHMANN_JSON_SERIALIZE_ENUM( } ) +/// Log level for the CLI +enum class LogLevel +{ + None, + Error, + Warning, + Info, + Debug, + All +}; + +NLOHMANN_JSON_SERIALIZE_ENUM( + LogLevel, + { + {LogLevel::None, "none"}, + {LogLevel::Error, "error"}, + {LogLevel::Warning, "warning"}, + {LogLevel::Info, "info"}, + {LogLevel::Debug, "debug"}, + {LogLevel::All, "all"}, + } +) + +/// Result type for tool execution +enum class ToolResultType +{ + Success, + Failure, + Rejected, + Denied +}; + +NLOHMANN_JSON_SERIALIZE_ENUM( + ToolResultType, + { + {ToolResultType::Success, "success"}, + {ToolResultType::Failure, "failure"}, + {ToolResultType::Rejected, "rejected"}, + {ToolResultType::Denied, "denied"}, + } +) + +/// Reasoning effort level for model inference +enum class ReasoningEffort +{ + Low, + Medium, + High, + XHigh +}; + +NLOHMANN_JSON_SERIALIZE_ENUM( + ReasoningEffort, + { + {ReasoningEffort::Low, "low"}, + {ReasoningEffort::Medium, "medium"}, + {ReasoningEffort::High, "high"}, + {ReasoningEffort::XHigh, "xhigh"}, + } +) + // ============================================================================= // Tool Types // ============================================================================= @@ -107,7 +169,7 @@ struct ToolResultObject { std::string text_result_for_llm; std::optional> binary_results_for_llm; - std::string result_type = "success"; + ToolResultType result_type = ToolResultType::Success; std::optional error; std::optional session_log; std::optional> tool_telemetry; @@ -739,6 +801,9 @@ inline void from_json(const json& j, McpRemoteServerConfig& c) c.headers = j.at("headers").get>(); } +/// Union type for MCP server configuration +using MCPServerConfig = std::variant; + // ============================================================================= // Custom Agent Configuration // ============================================================================= @@ -921,8 +986,7 @@ struct SessionConfig bool auto_byok_from_env = false; /// Reasoning effort level for models that support it. - /// Valid values: "low", "medium", "high", "xhigh". - std::optional reasoning_effort; + std::optional reasoning_effort; /// Handler for user input requests from the agent (enables ask_user tool). std::optional on_user_input_request; @@ -962,8 +1026,7 @@ struct ResumeSessionConfig std::optional model; /// Reasoning effort level for models that support it. - /// Valid values: "low", "medium", "high", "xhigh". - std::optional reasoning_effort; + std::optional reasoning_effort; /// System message configuration. std::optional system_message; @@ -1020,7 +1083,7 @@ struct ClientOptions int port = 0; bool use_stdio = true; std::optional cli_url; - std::string log_level = "info"; + LogLevel log_level = LogLevel::Info; bool auto_start = true; bool auto_restart = true; std::optional> environment; @@ -1195,6 +1258,12 @@ inline void from_json(const json& j, SessionMetadata& m) j.at("isRemote").get_to(m.is_remote); } +/// Error reported during client stop/cleanup +struct StopError +{ + std::string message; +}; + /// Response from a ping request struct PingResponse { diff --git a/src/client.cpp b/src/client.cpp index e130610..166cc08 100644 --- a/src/client.cpp +++ b/src/client.cpp @@ -217,6 +217,10 @@ Client::Client(ClientOptions options) : options_(std::move(options)) "(external server manages its own auth)"); } + // Smart default for use_logged_in_user (only when managing our own server) + if (!options_.cli_url.has_value() && !options_.use_logged_in_user.has_value()) + options_.use_logged_in_user = !options_.github_token.has_value(); + // Parse CLI URL if provided if (options_.cli_url.has_value()) parse_cli_url(*options_.cli_url); @@ -320,13 +324,14 @@ std::future Client::start() ); } -std::future Client::stop() +std::future> Client::stop() { return std::async( std::launch::async, - [this]() + [this]() -> std::vector { std::lock_guard lock(mutex_); + std::vector errors; // Destroy all sessions for (auto& [id, session] : sessions_) @@ -335,9 +340,13 @@ std::future Client::stop() { session->destroy().get(); } + catch (const std::exception& e) + { + errors.push_back(StopError{e.what()}); + } catch (...) { - // Ignore errors during cleanup + errors.push_back(StopError{"Unknown error destroying session " + id}); } } sessions_.clear(); @@ -371,6 +380,7 @@ std::future Client::stop() } state_ = ConnectionState::Disconnected; + return errors; } ); } @@ -459,7 +469,7 @@ void Client::start_cli_server() args.insert(args.end(), options_.cli_args->begin(), options_.cli_args->end()); args.push_back("--server"); args.push_back("--log-level"); - args.push_back(options_.log_level); + args.push_back(json(options_.log_level).get()); if (options_.use_stdio) { @@ -493,6 +503,10 @@ void Client::start_cli_server() // Remove NODE_DEBUG to avoid debug output interfering with JSON-RPC proc_opts.environment.erase("NODE_DEBUG"); + // Forward GitHub token as environment variable + if (options_.github_token.has_value()) + proc_opts.environment["COPILOT_SDK_AUTH_TOKEN"] = *options_.github_token; + // Spawn process process_ = std::make_unique(); process_->spawn(executable, full_args, proc_opts); @@ -558,6 +572,19 @@ void Client::connect_to_server() { if (method == "session.event") handle_session_event(method, params); + else if (method == "session.lifecycle") + { + try + { + auto event = params.get(); + std::lock_guard lock(lifecycle_mutex_); + for (const auto& handler : lifecycle_handlers_) + handler(event); + } + catch (...) + { + } + } } ); @@ -1078,4 +1105,50 @@ json Client::handle_hooks_invoke(const json& params) } } +// ============================================================================= +// Lifecycle Events +// ============================================================================= + +Subscription Client::on_lifecycle(LifecycleHandler handler) +{ + std::lock_guard lock(lifecycle_mutex_); + lifecycle_handlers_.push_back(std::move(handler)); + auto* ptr = &lifecycle_handlers_.back(); + return Subscription([this, ptr]() { + std::lock_guard lock(lifecycle_mutex_); + lifecycle_handlers_.erase( + std::remove_if(lifecycle_handlers_.begin(), lifecycle_handlers_.end(), + [ptr](const LifecycleHandler& h) { return &h == ptr; }), + lifecycle_handlers_.end()); + }); +} + +// ============================================================================= +// Foreground Session +// ============================================================================= + +std::future> Client::get_foreground_session_id() +{ + return std::async( + std::launch::async, + [this]() -> std::optional + { + auto response = rpc_client()->invoke("session.getForeground", json::object()).get(); + auto parsed = response.get(); + return parsed.session_id; + } + ); +} + +std::future Client::set_foreground_session_id(const std::string& session_id) +{ + return std::async( + std::launch::async, + [this, session_id]() + { + rpc_client()->invoke("session.setForeground", json{{"sessionId", session_id}}).get(); + } + ); +} + } // namespace copilot diff --git a/tests/repro/repro_resume_with_tools.cpp b/tests/repro/repro_resume_with_tools.cpp index 6ac197c..5acd056 100644 --- a/tests/repro/repro_resume_with_tools.cpp +++ b/tests/repro/repro_resume_with_tools.cpp @@ -126,7 +126,7 @@ int main() ToolResultObject result; result.text_result_for_llm = "SECRET_VALUE_12345"; - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; diff --git a/tests/snapshot_tests/snapshot_replay.cpp b/tests/snapshot_tests/snapshot_replay.cpp index 6b3b8b6..9a1233f 100644 --- a/tests/snapshot_tests/snapshot_replay.cpp +++ b/tests/snapshot_tests/snapshot_replay.cpp @@ -87,7 +87,7 @@ Tool create_tool_from_config(const json& tool_config, } result.text_result_for_llm = fixed_result; - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; @@ -502,7 +502,7 @@ int main(int argc, char* argv[]) int port = server.start(); ClientOptions opts; - opts.log_level = "info"; + opts.log_level = LogLevel::Info; opts.use_stdio = false; opts.cli_url = std::to_string(port); opts.auto_start = false; diff --git a/tests/test_client_session.cpp b/tests/test_client_session.cpp index c8c757c..b931f04 100644 --- a/tests/test_client_session.cpp +++ b/tests/test_client_session.cpp @@ -20,7 +20,7 @@ TEST(ClientTest, DefaultConstruction) TEST(ClientTest, ConstructionWithOptions) { ClientOptions opts; - opts.log_level = "debug"; + opts.log_level = LogLevel::Debug; opts.auto_start = false; opts.port = 8080; @@ -174,7 +174,7 @@ TEST(ToolTest, ToolDefinition) auto result = tool.handler(inv); EXPECT_EQ(result.text_result_for_llm, "Success"); - EXPECT_EQ(result.result_type, "success"); + EXPECT_EQ(result.result_type, ToolResultType::Success); } // ============================================================================= @@ -214,7 +214,7 @@ TEST(ClientOptionsTest, DefaultValues) EXPECT_EQ(opts.port, 0); EXPECT_TRUE(opts.use_stdio); EXPECT_FALSE(opts.cli_url.has_value()); - EXPECT_EQ(opts.log_level, "info"); + EXPECT_EQ(opts.log_level, LogLevel::Info); EXPECT_TRUE(opts.auto_start); EXPECT_TRUE(opts.auto_restart); EXPECT_FALSE(opts.environment.has_value()); @@ -768,3 +768,61 @@ TEST(SessionCreateRequestTest, CustomAgentWithMcpServers) EXPECT_TRUE(agent_json.contains("tools")); EXPECT_EQ(agent_json["infer"], true); } + +// ============================================================================= +// URL Parsing Edge Case Tests +// ============================================================================= + +TEST(ClientOptions, InvalidPortTooHigh) +{ + ClientOptions opts; + opts.cli_url = "70000"; + opts.use_stdio = false; + + // Port 70000 exceeds valid range, falls through to hostname parsing + Client client(opts); + EXPECT_EQ(client.state(), ConnectionState::Disconnected); +} + +TEST(ClientOptions, InvalidPortZero) +{ + ClientOptions opts; + opts.cli_url = "0"; + opts.use_stdio = false; + + // Port 0 is not in valid range (1-65535), falls through to hostname parsing + Client client(opts); + EXPECT_EQ(client.state(), ConnectionState::Disconnected); +} + +TEST(ClientOptions, InvalidPortNegative) +{ + ClientOptions opts; + opts.cli_url = "-1"; + opts.use_stdio = false; + + // Negative port not valid, falls through to hostname parsing + Client client(opts); + EXPECT_EQ(client.state(), ConnectionState::Disconnected); +} + +TEST(ClientOptions, AuthDefaultWithToken) +{ + ClientOptions opts; + opts.github_token = "ghp_test123"; + opts.auto_start = false; + + Client client(opts); + // use_logged_in_user should default to false when github_token is set + EXPECT_EQ(client.state(), ConnectionState::Disconnected); +} + +TEST(ClientOptions, AuthDefaultWithoutToken) +{ + ClientOptions opts; + opts.auto_start = false; + + Client client(opts); + // use_logged_in_user should default to true without token + EXPECT_EQ(client.state(), ConnectionState::Disconnected); +} diff --git a/tests/test_e2e.cpp b/tests/test_e2e.cpp index 8eb81d0..33b2292 100644 --- a/tests/test_e2e.cpp +++ b/tests/test_e2e.cpp @@ -168,7 +168,7 @@ class E2ETest : public ::testing::Test try { ClientOptions opts; - opts.log_level = "info"; + opts.log_level = LogLevel::Info; opts.use_stdio = true; opts.cli_args = std::vector{"--allow-all-tools", "--allow-all-paths"}; opts.auto_start = false; @@ -256,7 +256,7 @@ class E2ETest : public ::testing::Test std::unique_ptr create_client() { ClientOptions opts; - opts.log_level = "info"; + opts.log_level = LogLevel::Info; opts.use_stdio = true; // Make E2E tests reliable/non-interactive by pre-approving tool and path access. // These flags are only used for tests; library defaults remain secure-by-default. @@ -417,7 +417,7 @@ TEST_F(E2ETest, CreateSessionWithTools) result.text_result_for_llm = "54321"; else result.text_result_for_llm = "Unknown key"; - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; @@ -880,7 +880,7 @@ TEST_F(E2ETest, ResumeSessionWithTools) result.text_result_for_llm = "SECRET_VALUE_12345"; else result.text_result_for_llm = "Unknown key"; - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; @@ -1876,7 +1876,7 @@ TEST_F(E2ETest, ToolCallIdIsPropagated) } result.text_result_for_llm = "Tool executed successfully. ID: " + inv.tool_call_id; - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; @@ -2088,7 +2088,7 @@ TEST_F(E2ETest, ResumeSessionWithToolsAndPermissions) tool_called = true; ToolResultObject result; result.text_result_for_llm = "RESUME_TOOL_RESULT_99999"; - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; @@ -2832,7 +2832,7 @@ TEST_F(E2ETest, PreToolUseHookInvokedOnToolCall) tool_called = true; ToolResultObject result; result.text_result_for_llm = "Echo: " + inv.arguments.value()["message"].get(); - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; config.tools = {echo_tool}; @@ -2921,7 +2921,7 @@ TEST_F(E2ETest, PreToolUseHookDeniesToolExecution) tool_called = true; ToolResultObject result; result.text_result_for_llm = "This should not execute"; - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; config.tools = {denied_tool}; @@ -3014,7 +3014,7 @@ TEST_F(E2ETest, PostToolUseHookInvokedAfterToolExecution) tool_called = true; ToolResultObject result; result.text_result_for_llm = "Hello, " + inv.arguments.value()["name"].get() + "!"; - result.result_type = "success"; + result.result_type = ToolResultType::Success; return result; }; config.tools = {greet_tool}; @@ -3123,7 +3123,7 @@ TEST_F(E2ETest, SessionWithReasoningEffort) client->start().get(); auto config = default_session_config(); - config.reasoning_effort = "medium"; + config.reasoning_effort = ReasoningEffort::Medium; auto session = client->create_session(config).get(); EXPECT_NE(session, nullptr); @@ -3409,7 +3409,7 @@ TEST_F(E2ETest, ResumeSessionWithNewConfigFields) } auto resume_config = default_resume_config(); - resume_config.reasoning_effort = "low"; + resume_config.reasoning_effort = ReasoningEffort::Low; resume_config.working_directory = std::filesystem::current_path().string(); auto resumed = client->resume_session(session_id, resume_config).get(); @@ -3471,7 +3471,7 @@ TEST_F(E2ETest, FullFeaturedSessionWithAllNewConfig) client->start().get(); auto config = default_session_config(); - config.reasoning_effort = "high"; + config.reasoning_effort = ReasoningEffort::High; config.working_directory = std::filesystem::current_path().string(); config.on_user_input_request = [](const UserInputRequest& req, const UserInputInvocation&) -> UserInputResponse @@ -3553,3 +3553,711 @@ TEST_F(E2ETest, FullFeaturedSessionWithAllNewConfig) session->destroy().get(); client->force_stop(); } + +// ============================================================================= +// Parity Sync: Foreground Session API Tests +// ============================================================================= + +TEST_F(E2ETest, ForegroundSessionSetAndGet) +{ + test_info("Foreground session: set_foreground_session_id then get_foreground_session_id round-trips."); + auto client = create_client(); + client->start().get(); + + auto config = default_session_config(); + auto session = client->create_session(config).get(); + ASSERT_NE(session, nullptr); + std::string sid = session->session_id(); + ASSERT_FALSE(sid.empty()); + + // Send a message to persist the session first + std::atomic idle{false}; + std::mutex mtx; + std::condition_variable cv; + + auto sub = session->on( + [&](const SessionEvent& event) + { + if (event.type == SessionEventType::SessionIdle) + { + idle = true; + cv.notify_one(); + } + } + ); + + MessageOptions opts; + opts.prompt = "Say 'hi'."; + session->send(opts).get(); + + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(30), [&]() { return idle.load(); }); + } + + // Set this session as foreground + client->set_foreground_session_id(sid).get(); + + // Small delay to allow the server to persist + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + + // Get foreground session ID — should match + auto fg = client->get_foreground_session_id().get(); + if (fg.has_value()) + { + EXPECT_EQ(*fg, sid); + std::cout << "Foreground session round-trip confirmed: " << *fg << "\n"; + } + else + { + std::cout << "Note: get_foreground_session_id returned empty after set " + << "(server may not persist foreground state)\n"; + } + + session->destroy().get(); + client->force_stop(); +} + +TEST_F(E2ETest, ForegroundSessionInitiallyEmpty) +{ + test_info("Foreground session initially empty: get_foreground_session_id returns empty before any set."); + auto client = create_client(); + client->start().get(); + + auto fg = client->get_foreground_session_id().get(); + // Before any session is created/set as foreground, result should be empty + if (fg.has_value()) + std::cout << "Note: foreground session ID was already set: " << *fg << "\n"; + else + std::cout << "Confirmed: no foreground session ID set initially\n"; + + // The key assertion is that the call doesn't throw + client->force_stop(); +} + +// ============================================================================= +// Parity Sync: Graceful Stop Test +// ============================================================================= + +TEST_F(E2ETest, GracefulStopReturnsNoErrors) +{ + test_info("Graceful stop: stop() returns empty vector after clean session."); + auto client = create_client(); + client->start().get(); + + auto config = default_session_config(); + auto session = client->create_session(config).get(); + ASSERT_NE(session, nullptr); + + // Send a simple message and wait for idle + std::atomic idle{false}; + std::mutex mtx; + std::condition_variable cv; + + auto sub = session->on( + [&](const SessionEvent& event) + { + if (event.type == SessionEventType::SessionIdle) + { + idle = true; + cv.notify_one(); + } + } + ); + + MessageOptions opts; + opts.prompt = "Say 'ok'."; + session->send(opts).get(); + + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(60), [&]() { return idle.load(); }); + } + + session->destroy().get(); + + // Graceful stop should return no errors + auto errors = client->stop().get(); + EXPECT_TRUE(errors.empty()) << "stop() should return empty errors after clean session, got " + << errors.size() << " errors"; + if (!errors.empty()) + { + for (const auto& e : errors) + std::cout << " StopError: " << e.message << "\n"; + } +} + +// ============================================================================= +// Parity Sync: Lifecycle Events Test +// ============================================================================= + +TEST_F(E2ETest, LifecycleCallbackFiresOnSessionCreate) +{ + test_info("Lifecycle callback: on_lifecycle fires session.created event with correct session ID."); + auto client = create_client(); + client->start().get(); + + std::atomic lifecycle_fired{false}; + std::string lifecycle_session_id; + std::string lifecycle_event_type; + std::mutex mtx; + std::condition_variable cv; + + // Register lifecycle handler BEFORE creating session + auto unsub = client->on_lifecycle( + [&](const SessionLifecycleEvent& evt) + { + std::lock_guard lock(mtx); + if (evt.type == SessionLifecycleEventTypes::Created) + { + lifecycle_fired = true; + lifecycle_session_id = evt.session_id; + lifecycle_event_type = evt.type; + cv.notify_one(); + } + } + ); + + auto config = default_session_config(); + auto session = client->create_session(config).get(); + ASSERT_NE(session, nullptr); + std::string sid = session->session_id(); + + // Wait for lifecycle event + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(10), [&]() { return lifecycle_fired.load(); }); + } + + EXPECT_TRUE(lifecycle_fired.load()) << "on_lifecycle should fire session.created event"; + { + std::lock_guard lock(mtx); + EXPECT_EQ(lifecycle_event_type, SessionLifecycleEventTypes::Created); + EXPECT_EQ(lifecycle_session_id, sid); + } + + session->destroy().get(); + client->force_stop(); +} + +// ============================================================================= +// Parity Sync: User Input Handler E2E Tests +// ============================================================================= + +TEST_F(E2ETest, UserInputHandlerInvokedByModel) +{ + test_info("User input handler: model's ask_user triggers the handler, response flows back."); + auto client = create_client(); + client->start().get(); + + std::atomic handler_called{false}; + std::string received_question; + std::mutex mtx; + + auto config = default_session_config(); + config.on_user_input_request = [&](const UserInputRequest& req, const UserInputInvocation&) -> UserInputResponse + { + { + std::lock_guard lock(mtx); + handler_called = true; + received_question = req.question; + } + UserInputResponse resp; + resp.answer = "Blue"; + resp.was_freeform = true; + return resp; + }; + + config.on_permission_request = [](const PermissionRequest&) -> PermissionRequestResult + { + PermissionRequestResult r; + r.kind = "approved"; + return r; + }; + + auto session = client->create_session(config).get(); + ASSERT_NE(session, nullptr); + + std::atomic idle{false}; + std::string assistant_response; + std::condition_variable cv; + + auto sub = session->on( + [&](const SessionEvent& event) + { + if (event.type == SessionEventType::SessionIdle) + { + idle = true; + cv.notify_one(); + } + else if (auto* msg = event.try_as()) + { + std::lock_guard lock(mtx); + assistant_response += msg->content; + } + } + ); + + // Retry up to 3 prompts to trigger ask_user + const char* prompts[] = { + "Ask the user what their favorite color is using the ask_user tool. Then report their answer.", + "You MUST use the ask_user tool to ask the user about their favorite color. Report what they say.", + "Use ask_user to ask: 'What is your favorite color?' Then tell me the answer." + }; + for (int attempt = 0; attempt < 3 && !handler_called.load(); ++attempt) + { + idle = false; + MessageOptions opts; + opts.prompt = prompts[attempt]; + session->send(opts).get(); + + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(60), [&]() { return idle.load(); }); + } + } + + EXPECT_TRUE(handler_called.load()) << "User input handler should have been invoked"; + { + std::lock_guard lock(mtx); + // Response should reference "Blue" since that's what our handler returned + if (handler_called.load()) + { + EXPECT_TRUE(assistant_response.find("Blue") != std::string::npos || + assistant_response.find("blue") != std::string::npos) + << "Response should contain the user input answer 'Blue': " << assistant_response; + } + } + + session->destroy().get(); + client->force_stop(); +} + +TEST_F(E2ETest, UserInputHandlerReceivesQuestion) +{ + test_info("User input handler: UserInputRequest contains non-empty question field."); + auto client = create_client(); + client->start().get(); + + std::atomic handler_called{false}; + std::string received_question; + std::mutex mtx; + + auto config = default_session_config(); + config.on_user_input_request = [&](const UserInputRequest& req, const UserInputInvocation&) -> UserInputResponse + { + { + std::lock_guard lock(mtx); + handler_called = true; + received_question = req.question; + } + UserInputResponse resp; + resp.answer = "42"; + resp.was_freeform = true; + return resp; + }; + + config.on_permission_request = [](const PermissionRequest&) -> PermissionRequestResult + { + PermissionRequestResult r; + r.kind = "approved"; + return r; + }; + + auto session = client->create_session(config).get(); + + std::atomic idle{false}; + std::condition_variable cv; + + auto sub = session->on( + [&](const SessionEvent& event) + { + if (event.type == SessionEventType::SessionIdle) + { + idle = true; + cv.notify_one(); + } + } + ); + + const char* prompts[] = { + "Use the ask_user tool to ask: 'What is your favorite number?'", + "You MUST call ask_user with the question 'What is your favorite number?'", + "Call ask_user to ask the user about their favorite number." + }; + for (int attempt = 0; attempt < 3 && !handler_called.load(); ++attempt) + { + idle = false; + MessageOptions opts; + opts.prompt = prompts[attempt]; + session->send(opts).get(); + + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(60), [&]() { return idle.load(); }); + } + } + + EXPECT_TRUE(handler_called.load()) << "User input handler should have been invoked"; + { + std::lock_guard lock(mtx); + if (handler_called.load()) + { + EXPECT_FALSE(received_question.empty()) << "UserInputRequest.question should not be empty"; + std::cout << "Received question: " << received_question << "\n"; + } + } + + session->destroy().get(); + client->force_stop(); +} + +// ============================================================================= +// Parity Sync: Combined Pre+Post Hook Test +// ============================================================================= + +TEST_F(E2ETest, BothPreAndPostToolHooksFireOnSameToolCall) +{ + test_info("Both hooks: pre and post tool hooks both fire for the same tool invocation."); + auto client = create_client(); + client->start().get(); + + std::atomic pre_hook_fired{false}; + std::atomic post_hook_fired{false}; + std::string pre_hook_tool; + std::string post_hook_tool; + std::mutex mtx; + + auto config = default_session_config(); + + Tool ping_tool; + ping_tool.name = "ping_pong"; + ping_tool.description = "Returns 'pong' when called"; + ping_tool.parameters_schema = { + {"type", "object"}, + {"properties", json::object()} + }; + ping_tool.handler = [](const ToolInvocation&) -> ToolResultObject + { + ToolResultObject result; + result.text_result_for_llm = "pong"; + result.result_type = ToolResultType::Success; + return result; + }; + config.tools = {ping_tool}; + + config.hooks = SessionHooks{}; + config.hooks->on_pre_tool_use = [&](const PreToolUseHookInput& input, const HookInvocation&) + -> std::optional + { + { + std::lock_guard lock(mtx); + pre_hook_fired = true; + pre_hook_tool = input.tool_name; + } + return std::nullopt; + }; + + config.hooks->on_post_tool_use = [&](const PostToolUseHookInput& input, const HookInvocation&) + -> std::optional + { + { + std::lock_guard lock(mtx); + post_hook_fired = true; + post_hook_tool = input.tool_name; + } + return std::nullopt; + }; + + config.on_permission_request = [](const PermissionRequest&) -> PermissionRequestResult + { + PermissionRequestResult r; + r.kind = "approved"; + return r; + }; + + auto session = client->create_session(config).get(); + + std::atomic idle{false}; + std::condition_variable cv; + + auto sub = session->on( + [&](const SessionEvent& event) + { + if (event.type == SessionEventType::SessionIdle) + { + idle = true; + cv.notify_one(); + } + } + ); + + const char* prompts[] = { + "Use the ping_pong tool now.", + "You MUST call the ping_pong tool.", + "Call ping_pong immediately." + }; + for (int attempt = 0; attempt < 3 && !pre_hook_fired.load(); ++attempt) + { + idle = false; + MessageOptions opts; + opts.prompt = prompts[attempt]; + session->send(opts).get(); + + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(60), [&]() { return idle.load(); }); + } + } + + EXPECT_TRUE(pre_hook_fired.load()) << "preToolUse hook should have fired"; + EXPECT_TRUE(post_hook_fired.load()) << "postToolUse hook should have fired"; + { + std::lock_guard lock(mtx); + EXPECT_EQ(pre_hook_tool, "ping_pong"); + EXPECT_EQ(post_hook_tool, "ping_pong"); + } + + session->destroy().get(); + client->force_stop(); +} + +// ============================================================================= +// Parity Sync: UserPromptSubmitted Hook Test +// ============================================================================= + +TEST_F(E2ETest, UserPromptSubmittedHookFires) +{ + test_info("User prompt submitted hook: on_user_prompt_submitted fires with the prompt text."); + auto client = create_client(); + client->start().get(); + + std::atomic hook_fired{false}; + std::string captured_prompt; + std::mutex mtx; + + auto config = default_session_config(); + config.hooks = SessionHooks{}; + config.hooks->on_user_prompt_submitted = [&](const UserPromptSubmittedHookInput& input, const HookInvocation&) + -> std::optional + { + { + std::lock_guard lock(mtx); + hook_fired = true; + captured_prompt = input.prompt; + } + return std::nullopt; + }; + + auto session = client->create_session(config).get(); + + std::atomic idle{false}; + std::condition_variable cv; + + auto sub = session->on( + [&](const SessionEvent& event) + { + if (event.type == SessionEventType::SessionIdle) + { + idle = true; + cv.notify_one(); + } + } + ); + + const std::string marker = "MARKER_PROMPT_7749"; + MessageOptions opts; + opts.prompt = "Say hello. " + marker; + session->send(opts).get(); + + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(60), [&]() { return idle.load(); }); + } + + EXPECT_TRUE(hook_fired.load()) << "on_user_prompt_submitted hook should fire"; + { + std::lock_guard lock(mtx); + if (hook_fired.load()) + { + EXPECT_TRUE(captured_prompt.find(marker) != std::string::npos) + << "Hook should capture prompt containing marker '" << marker << "', got: " << captured_prompt; + } + } + + session->destroy().get(); + client->force_stop(); +} + +// ============================================================================= +// Parity Sync: Tool Error Handling Test +// ============================================================================= + +TEST_F(E2ETest, ToolHandlerExceptionDoesNotCrash) +{ + test_info("Tool error: Tool handler throwing exception doesn't crash session."); + auto client = create_client(); + client->start().get(); + + auto config = default_session_config(); + + Tool crash_tool; + crash_tool.name = "crasher"; + crash_tool.description = "A tool that always fails internally"; + crash_tool.parameters_schema = { + {"type", "object"}, + {"properties", json::object()} + }; + crash_tool.handler = [](const ToolInvocation&) -> ToolResultObject + { + throw std::runtime_error("Secret_Error_42"); + }; + config.tools = {crash_tool}; + + config.on_permission_request = [](const PermissionRequest&) -> PermissionRequestResult + { + PermissionRequestResult r; + r.kind = "approved"; + return r; + }; + + auto session = client->create_session(config).get(); + + std::atomic idle{false}; + std::string assistant_response; + std::mutex mtx; + std::condition_variable cv; + + auto sub = session->on( + [&](const SessionEvent& event) + { + if (event.type == SessionEventType::SessionIdle) + { + idle = true; + cv.notify_one(); + } + else if (auto* msg = event.try_as()) + { + std::lock_guard lock(mtx); + assistant_response += msg->content; + } + } + ); + + const char* prompts[] = { + "Use the crasher tool now.", + "You MUST call the crasher tool immediately.", + "Call crasher." + }; + for (int attempt = 0; attempt < 3; ++attempt) + { + idle = false; + MessageOptions opts; + opts.prompt = prompts[attempt]; + session->send(opts).get(); + + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(60), [&]() { return idle.load(); }); + } + + if (idle.load()) + break; + } + + // Session should reach idle — the exception should not crash it + EXPECT_TRUE(idle.load()) << "Session should reach idle even after tool handler exception"; + + // The secret error string should not leak to the LLM response + { + std::lock_guard lock(mtx); + EXPECT_TRUE(assistant_response.find("Secret_Error_42") == std::string::npos) + << "Secret error should not leak to LLM response: " << assistant_response; + } + + session->destroy().get(); + client->force_stop(); +} + +// ============================================================================= +// Parity Sync: make_tool with normalize_result Test +// ============================================================================= + +TEST_F(E2ETest, MakeToolWithNormalizeResult) +{ + test_info("make_tool: Tool created via make_tool with plain string return works through CLI."); + auto client = create_client(); + client->start().get(); + + // Create tool using make_tool — returns std::string, exercises normalize_result path + auto rot13_tool = copilot::make_tool( + "rot13", + "Apply ROT13 cipher to the input text", + [](std::string text) -> std::string + { + std::string result = text; + for (auto& c : result) + { + if (c >= 'a' && c <= 'z') + c = 'a' + (c - 'a' + 13) % 26; + else if (c >= 'A' && c <= 'Z') + c = 'A' + (c - 'A' + 13) % 26; + } + return result; + }, + {"text"} + ); + + auto config = default_session_config(); + config.tools = {rot13_tool}; + + config.on_permission_request = [](const PermissionRequest&) -> PermissionRequestResult + { + PermissionRequestResult r; + r.kind = "approved"; + return r; + }; + + auto session = client->create_session(config).get(); + + std::atomic idle{false}; + std::string assistant_response; + std::mutex mtx; + std::condition_variable cv; + + auto sub = session->on( + [&](const SessionEvent& event) + { + if (event.type == SessionEventType::SessionIdle) + { + idle = true; + cv.notify_one(); + } + else if (auto* msg = event.try_as()) + { + std::lock_guard lock(mtx); + assistant_response += msg->content; + } + } + ); + + // "Hello" ROT13 = "Uryyb" + MessageOptions opts; + opts.prompt = "Use the rot13 tool to encode the word 'Hello'. Tell me the exact result."; + session->send(opts).get(); + + { + std::unique_lock lock(mtx); + cv.wait_for(lock, std::chrono::seconds(60), [&]() { return idle.load(); }); + } + + EXPECT_TRUE(idle.load()) << "Session should reach idle"; + { + std::lock_guard lock(mtx); + EXPECT_TRUE(assistant_response.find("Uryyb") != std::string::npos) + << "Response should contain ROT13 of 'Hello' = 'Uryyb': " << assistant_response; + } + + session->destroy().get(); + client->force_stop(); +} diff --git a/tests/test_jsonrpc.cpp b/tests/test_jsonrpc.cpp index 6af141c..9fd3251 100644 --- a/tests/test_jsonrpc.cpp +++ b/tests/test_jsonrpc.cpp @@ -536,3 +536,123 @@ TEST(JsonRpcErrorTest, ErrorWithData) EXPECT_STREQ(error.what(), "Bad params"); EXPECT_EQ(error.data()["field"], "name"); } + +// ============================================================================= +// Large Payload Tests +// ============================================================================= + +TEST(JsonRpc, LargePayload64KBBoundary) +{ + auto [client_transport, server_transport] = PipeTransport::create_pair(); + JsonRpcClient client(std::move(client_transport)); + MessageFramer server_framer(*server_transport); + + client.start(); + + // Create payload exactly at 64KB boundary + std::string large_data(64 * 1024, 'A'); + auto future = client.invoke("large.method", json{{"data", large_data}}); + + std::thread server_thread( + [&] + { + auto msg = server_framer.read_message(); + auto req = json::parse(msg); + EXPECT_EQ(req["method"], "large.method"); + EXPECT_EQ(req["params"]["data"].get().size(), 64 * 1024); + + json response = { + {"jsonrpc", "2.0"}, {"result", {{"ok", true}}}, {"id", req["id"]} + }; + server_framer.write_message(response.dump()); + } + ); + + auto result = future.get(); + EXPECT_EQ(result["ok"], true); + + server_thread.join(); + client.stop(); +} + +TEST(JsonRpc, LargePayload70KB) +{ + auto [client_transport, server_transport] = PipeTransport::create_pair(); + JsonRpcClient client(std::move(client_transport)); + MessageFramer server_framer(*server_transport); + + client.start(); + + std::string large_data(70 * 1024, 'B'); + auto future = client.invoke("large.method", json{{"data", large_data}}); + + std::thread server_thread( + [&] + { + auto msg = server_framer.read_message(); + auto req = json::parse(msg); + EXPECT_EQ(req["params"]["data"].get().size(), 70 * 1024); + + json response = { + {"jsonrpc", "2.0"}, {"result", {{"ok", true}}}, {"id", req["id"]} + }; + server_framer.write_message(response.dump()); + } + ); + + auto result = future.get(); + EXPECT_EQ(result["ok"], true); + + server_thread.join(); + client.stop(); +} + +TEST(JsonRpc, LargePayload100KB) +{ + auto [client_transport, server_transport] = PipeTransport::create_pair(); + JsonRpcClient client(std::move(client_transport)); + MessageFramer server_framer(*server_transport); + + client.start(); + + std::string large_data(100 * 1024, 'C'); + auto future = client.invoke("large.method", json{{"data", large_data}}); + + std::thread server_thread( + [&] + { + auto msg = server_framer.read_message(); + auto req = json::parse(msg); + EXPECT_EQ(req["params"]["data"].get().size(), 100 * 1024); + + json response = { + {"jsonrpc", "2.0"}, {"result", {{"size", 100 * 1024}}}, {"id", req["id"]} + }; + server_framer.write_message(response.dump()); + } + ); + + auto result = future.get(); + EXPECT_EQ(result["size"], 100 * 1024); + + server_thread.join(); + client.stop(); +} + +TEST(JsonRpc, EOFOnPartialData) +{ + auto [client_transport, server_transport] = PipeTransport::create_pair(); + JsonRpcClient client(std::move(client_transport)); + + client.start(); + + auto future = client.invoke("test.method"); + + // Close server transport immediately to simulate EOF on partial data + server_transport->close(); + + // The future should eventually fail or throw + EXPECT_THROW(future.get(), std::exception); + + client.stop(); +} diff --git a/tests/test_tool_builder.cpp b/tests/test_tool_builder.cpp index b979f2d..e7ad9b5 100644 --- a/tests/test_tool_builder.cpp +++ b/tests/test_tool_builder.cpp @@ -142,7 +142,7 @@ TEST(ToolBuilderTest, HandlerInvocationSuccess) inv.arguments = json{{"a", 10.0}, {"b", 32.0}}; auto result = add.handler(inv); - EXPECT_EQ(result.result_type, "success"); + EXPECT_EQ(result.result_type, ToolResultType::Success); EXPECT_EQ(result.text_result_for_llm, "42.000000"); } @@ -159,7 +159,7 @@ TEST(ToolBuilderTest, HandlerInvocationWithStrings) inv.arguments = json{{"name", "World"}}; auto result = greet.handler(inv); - EXPECT_EQ(result.result_type, "success"); + EXPECT_EQ(result.result_type, ToolResultType::Success); EXPECT_EQ(result.text_result_for_llm, "Hello, World!"); } @@ -201,7 +201,7 @@ TEST(ToolBuilderTest, HandlerErrorHandling) inv.arguments = json{{"a", 10.0}, {"b", 0.0}}; auto result = div.handler(inv); - EXPECT_EQ(result.result_type, "error"); + EXPECT_EQ(result.result_type, ToolResultType::Failure); EXPECT_TRUE(result.error.has_value()); EXPECT_EQ(*result.error, "Division by zero"); } @@ -216,7 +216,7 @@ TEST(ToolBuilderTest, MissingRequiredArg) inv.arguments = json::object(); // Missing "name" auto result = greet.handler(inv); - EXPECT_EQ(result.result_type, "error"); + EXPECT_EQ(result.result_type, ToolResultType::Failure); EXPECT_TRUE(result.error.has_value()); } @@ -263,7 +263,7 @@ TEST(ToolBuilderTest, StructBasedHandlerInvocation) inv.arguments = json{{"query", "test"}, {"limit", 5}}; auto result = search.handler(inv); - EXPECT_EQ(result.result_type, "success"); + EXPECT_EQ(result.result_type, ToolResultType::Success); EXPECT_EQ(result.text_result_for_llm, "test:5"); } @@ -319,7 +319,7 @@ TEST(ToolBuilderTest, BackwardCompatibility) old_tool.handler = [](const ToolInvocation& inv) -> ToolResultObject { ToolResultObject r; r.text_result_for_llm = "old style"; - r.result_type = "success"; + r.result_type = ToolResultType::Success; return r; }; @@ -417,7 +417,7 @@ TEST(MakeToolTest, ErrorHandling) ToolInvocation inv; inv.arguments = json{{"input", "test"}}; auto result = tool.handler(inv); - EXPECT_EQ(result.result_type, "error"); + EXPECT_EQ(result.result_type, ToolResultType::Failure); EXPECT_TRUE(result.error.has_value()); EXPECT_EQ(*result.error, "boom"); } diff --git a/tests/test_types.cpp b/tests/test_types.cpp index a7a5a6f..c1ae278 100644 --- a/tests/test_types.cpp +++ b/tests/test_types.cpp @@ -52,7 +52,7 @@ TEST(TypesTest, ToolBinaryResultRoundTrip) TEST(TypesTest, ToolResultObjectMinimal) { - ToolResultObject result{.text_result_for_llm = "Success!", .result_type = "success"}; + ToolResultObject result{.text_result_for_llm = "Success!", .result_type = ToolResultType::Success}; json j = result; EXPECT_EQ(j["textResultForLlm"], "Success!"); @@ -1195,7 +1195,7 @@ TEST(ReasoningEffortTest, ModelInfoWithReasoningEfforts) TEST(ReasoningEffortTest, SessionConfigReasoningEffort) { SessionConfig config; - config.reasoning_effort = "high"; + config.reasoning_effort = ReasoningEffort::High; auto request = build_session_create_request(config); EXPECT_EQ(request["reasoningEffort"], "high"); } @@ -1203,7 +1203,7 @@ TEST(ReasoningEffortTest, SessionConfigReasoningEffort) TEST(ReasoningEffortTest, ResumeConfigReasoningEffort) { ResumeSessionConfig config; - config.reasoning_effort = "low"; + config.reasoning_effort = ReasoningEffort::Low; auto request = build_session_resume_request("test-session", config); EXPECT_EQ(request["reasoningEffort"], "low"); } @@ -1567,7 +1567,7 @@ TEST(RequestBuilderTest, ResumeSessionAllNewFields) { ResumeSessionConfig config; config.model = "gpt-4o"; - config.reasoning_effort = "high"; + config.reasoning_effort = ReasoningEffort::High; config.system_message = SystemMessageConfig{.content = "Be helpful"}; config.available_tools = {"read_file"}; config.excluded_tools = {"dangerous_tool"}; @@ -1722,3 +1722,70 @@ TEST(UserInputHandlerTest, NoHandlerThrows) request.question = "test"; EXPECT_THROW(session->handle_user_input_request(request), std::runtime_error); } + +// ============================================================================= +// Event Forward-Compatibility Tests +// ============================================================================= + +TEST(Events, UnknownEventTypeHandled) +{ + json event_json = { + {"type", "some.future.event"}, + {"id", "e-unknown"}, + {"timestamp", "2025-06-01T00:00:00Z"}, + {"data", {{"foo", "bar"}}} + }; + + // Parsing an unknown event type should not throw + EXPECT_NO_THROW({ + auto event = parse_session_event(event_json); + EXPECT_EQ(event.type, SessionEventType::Unknown); + EXPECT_EQ(event.type_string, "some.future.event"); + }); +} + +TEST(Events, SessionShutdownParsed) +{ + json event_json = { + {"type", "session.shutdown"}, + {"id", "e-shutdown"}, + {"timestamp", "2025-06-01T00:00:00Z"}, + {"data", { + {"shutdownType", "routine"}, + {"totalPremiumRequests", 5}, + {"totalApiDurationMs", 1234.5}, + {"sessionStartTime", 1700000000.0}, + {"codeChanges", {{"linesAdded", 10}, {"linesRemoved", 3}}} + }} + }; + + auto event = parse_session_event(event_json); + EXPECT_EQ(event.type, SessionEventType::SessionShutdown); + EXPECT_TRUE(event.is()); + + const auto& data = event.as(); + EXPECT_EQ(data.total_premium_requests, 5); + EXPECT_NEAR(data.total_api_duration_ms, 1234.5, 0.1); +} + +TEST(Events, SessionUsageInfoRecognized) +{ + json event_json = { + {"type", "session.usage_info"}, + {"id", "e-usage"}, + {"timestamp", "2025-06-01T00:00:00Z"}, + {"data", { + {"tokenLimit", 128000}, + {"currentTokens", 5000}, + {"messagesLength", 42} + }} + }; + + auto event = parse_session_event(event_json); + EXPECT_EQ(event.type, SessionEventType::SessionUsageInfo); + EXPECT_TRUE(event.is()); + + const auto& data = event.as(); + EXPECT_EQ(data.token_limit, 128000); + EXPECT_EQ(data.current_tokens, 5000); +}