From af5dcc7fdf097bf31550c6018a1e1d3d7262b7e8 Mon Sep 17 00:00:00 2001
From: afourney
Date: Fri, 7 Feb 2025 10:01:44 -0800
Subject: [PATCH 01/36] Significant updates to agbench. (#5313)

- Updated HumanEval template to use AgentChat
- Updated templates to use config.yaml for model and other configuration
- Read environment from ENV.yaml (ENV.json still supported but deprecated)
- Temporarily removed WebArena and AssistantBench. Neither had viable
  Templates after `autogen_magentic_one` was removed. Templates need to be
  updated to AgentChat (in a future PR, but this PR is getting big enough
  already)
---
 .../benchmarks/AssistantBench/.gitignore      |   1 -
 .../benchmarks/AssistantBench/ENV.json.sample |   5 -
 .../benchmarks/AssistantBench/README.md       |  78 ----
 .../Scripts/assistantbench_evaluator.py       | 127 ------
 .../AssistantBench/Scripts/custom_tabulate.py | 232 -----------
 .../Scripts/evaluate_utils/LICENSE            | 201 ---------
 .../Scripts/evaluate_utils/__init__.py        |   0
 .../Scripts/evaluate_utils/evaluate_dicts.py  |  71 ----
 .../evaluate_utils/evaluate_factory.py        |  30 --
 .../evaluate_utils/evaluate_numbers.py        |  35 --
 .../evaluate_utils/evaluate_strings.py        | 180 --------
 .../Scripts/evaluate_utils/readme.md          |   1 -
 .../Scripts/evaluate_utils/utils.py           |  24 --
 .../AssistantBench/Scripts/init_tasks.py      |  93 -----
 .../agbench/benchmarks/GAIA/ENV.json.sample   |   5 -
 .../packages/agbench/benchmarks/GAIA/ENV.yaml |  30 ++
 .../agbench/benchmarks/GAIA/README.md         |  26 +-
 .../GAIA/Scripts/custom_tabulate.py           | 245 +++++------
 .../Templates/MagenticOne/expected_answer.txt |   1 +
 .../GAIA/Templates/MagenticOne/prompt.txt     |   1 +
 .../Templates/MagenticOne/requirements.txt    |   5 +
 .../GAIA/Templates/MagenticOne/scenario.py    |  89 ++++
 .../SelectorGroupChat/expected_answer.txt     |   1 +
 .../Templates/SelectorGroupChat/prompt.txt    |   1 +
 .../SelectorGroupChat/requirements.txt        |   5 +
 .../Templates/SelectorGroupChat/scenario.py   | 176 ++++++++
 .../agbench/benchmarks/GAIA/config.yaml       |  38 ++
 .../agbench/benchmarks/HumanEval/ENV.yaml     |  30 ++
 .../agbench/benchmarks/HumanEval/README.md    |  24 +-
 .../HumanEval/Scripts/init_tasks.py           |   4 +-
 .../AgentChat/custom_code_executor.py         |  54 +++
 .../HumanEval/Templates/AgentChat/prompt.txt  |   1 +
 .../Templates/AgentChat/requirements.txt      |   4 +
 .../HumanEval/Templates/AgentChat/scenario.py |  54 +++
 .../HumanEval/Templates/AgentChat/test.txt    |   1 +
 .../agbench/benchmarks/HumanEval/config.yaml  |  29 ++
 .../agbench/benchmarks/WebArena/ENV.sample    |   0
 .../agbench/benchmarks/WebArena/README.md     |   8 -
 .../WebArena/Scripts/custom_tabulate.py       |  35 --
 .../benchmarks/WebArena/Scripts/init_tasks.py | 122 ------
 .../Common/evaluation_harness/ATTRIBUTION     |  28 --
 .../Common/evaluation_harness/LICENSE         | 201 ---------
 .../Common/evaluation_harness/__init__.py     |   6 -
 .../Common/evaluation_harness/env_config.py   |  90 ----
 .../Common/evaluation_harness/evaluators.py   | 387 ------------------
 .../evaluation_harness/helper_functions.py    | 233 -----------
 .../Common/evaluation_harness/openai_utils.py | 275 -------------
 .../agbench/src/agbench/res/Dockerfile        |  19 +-
 .../packages/agbench/src/agbench/run_cmd.py   | 155 ++++++-
 .../agbench/src/agbench/tabulate_cmd.py       |  19 +-
 50 files changed, 793 insertions(+), 2687 deletions(-)
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/.gitignore
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/ENV.json.sample
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/README.md
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/assistantbench_evaluator.py
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/custom_tabulate.py
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/LICENSE
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/__init__.py
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_dicts.py
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_factory.py
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_numbers.py
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_strings.py
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/readme.md
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/utils.py
 delete mode 100644 python/packages/agbench/benchmarks/AssistantBench/Scripts/init_tasks.py
 delete mode 100644 python/packages/agbench/benchmarks/GAIA/ENV.json.sample
 create mode 100644 python/packages/agbench/benchmarks/GAIA/ENV.yaml
 create mode 100644 python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/expected_answer.txt
 create mode 100644 python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/prompt.txt
 create mode 100644 python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/requirements.txt
 create mode 100644 python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/scenario.py
 create mode 100644 python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/expected_answer.txt
 create mode 100644 python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/prompt.txt
 create mode 100644 python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/requirements.txt
 create mode 100644 python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/scenario.py
 create mode 100644 python/packages/agbench/benchmarks/GAIA/config.yaml
 create mode 100644 python/packages/agbench/benchmarks/HumanEval/ENV.yaml
 create mode 100644 python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/custom_code_executor.py
 create mode 100644 python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/prompt.txt
 create mode 100644 python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/requirements.txt
 create mode 100644 python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/scenario.py
 create mode 100644 python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/test.txt
 create mode 100644 python/packages/agbench/benchmarks/HumanEval/config.yaml
 delete mode 100644 python/packages/agbench/benchmarks/WebArena/ENV.sample
 delete mode 100644 python/packages/agbench/benchmarks/WebArena/README.md
 delete mode 100644 python/packages/agbench/benchmarks/WebArena/Scripts/custom_tabulate.py
 delete mode 100644 python/packages/agbench/benchmarks/WebArena/Scripts/init_tasks.py
 delete mode 100644 python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/ATTRIBUTION
 delete mode 100644 python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/LICENSE
 delete mode 100644 python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/__init__.py
 delete mode 100644 python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/env_config.py
 delete mode
100644 python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/evaluators.py delete mode 100644 python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/helper_functions.py delete mode 100644 python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/openai_utils.py diff --git a/python/packages/agbench/benchmarks/AssistantBench/.gitignore b/python/packages/agbench/benchmarks/AssistantBench/.gitignore deleted file mode 100644 index f6c9d117b084..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/.gitignore +++ /dev/null @@ -1 +0,0 @@ -ENV.json \ No newline at end of file diff --git a/python/packages/agbench/benchmarks/AssistantBench/ENV.json.sample b/python/packages/agbench/benchmarks/AssistantBench/ENV.json.sample deleted file mode 100644 index 1f2c4915e3c7..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/ENV.json.sample +++ /dev/null @@ -1,5 +0,0 @@ -{ - "BING_API_KEY": "YOUR_KEY_KEY", - "HOMEPAGE": "https://www.bing.com/", - "WEB_SURFER_DEBUG_DIR": "/autogen/debug" -} diff --git a/python/packages/agbench/benchmarks/AssistantBench/README.md b/python/packages/agbench/benchmarks/AssistantBench/README.md deleted file mode 100644 index 30bcf881fb00..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/README.md +++ /dev/null @@ -1,78 +0,0 @@ -# AssistantBench Benchmark - -This scenario implements the [AssistantBench](https://assistantbench.github.io/) agent benchmark. Before you begin, make sure you have followed the instructions in `../README.md` to prepare your environment. We modify the evaluation code from AssistantBench in [Scripts](Scripts) and retain the license including it here [LICENSE](Scripts/evaluate_utils/LICENSE). Please find the original AssistantBench evaluation code here [https://huggingface.co/spaces/AssistantBench/leaderboard/tree/main/evaluation](https://huggingface.co/spaces/AssistantBench/leaderboard/tree/main/evaluation). - -### Setup Environment Variables for AgBench - -Navigate to AssistantBench - -```bash -cd benchmarks/AssistantBench -``` - -Create a file called ENV.json with the following (required) contents (If you're using MagenticOne) - -```json -{ - "BING_API_KEY": "REPLACE_WITH_YOUR_BING_API_KEY", - "HOMEPAGE": "https://www.bing.com/", - "WEB_SURFER_DEBUG_DIR": "/autogen/debug", - "CHAT_COMPLETION_KWARGS_JSON": "{\"api_version\": \"2024-02-15-preview\", \"azure_endpoint\": \"YOUR_ENDPOINT/\", \"model_capabilities\": {\"function_calling\": true, \"json_output\": true, \"vision\": true}, \"azure_ad_token_provider\": \"DEFAULT\", \"model\": \"gpt-4o-2024-05-13\"}", - "CHAT_COMPLETION_PROVIDER": "azure" -} -``` - -You can also use the openai client by replacing the last two entries in the ENV file by: - -- `CHAT_COMPLETION_PROVIDER='openai'` -- `CHAT_COMPLETION_KWARGS_JSON` with the following JSON structure: - -```json -{ - "api_key": "REPLACE_WITH_YOUR_API", - "model": "gpt-4o-2024-05-13" -} -``` - -Now initialize the tasks. - -```bash -python Scripts/init_tasks.py -``` - -Note: This will attempt to download AssistantBench from Huggingface, but this requires authentication. - -After running the script, you should see the new following folders and files: - -``` -. 
-./Downloads -./Downloads/AssistantBench -./Downloads/AssistantBench/assistant_bench_v1.0_dev.jsonl -./Downloads/AssistantBench/assistant_bench_v1.0_dev.jsonl -./Tasks -./Tasks/assistant_bench_v1.0_dev.jsonl -./Tasks/assistant_bench_v1.0_dev.jsonl -``` - -Then run `Scripts/init_tasks.py` again. - -Once the script completes, you should now see a folder in your current directory called `Tasks` that contains one JSONL file per template in `Templates`. - -### Running AssistantBench - -Now to run a specific subset of AssistantBench use: - -```bash -agbench run Tasks/assistant_bench_v1.0_dev__MagenticOne.jsonl -``` - -You should see the command line print the raw logs that shows the agents in action To see a summary of the results (e.g., task completion rates), in a new terminal run the following: - -```bash -agbench tabulate Results/assistant_bench_v1.0_dev__MagenticOne -``` - -## References - -Yoran, Ori, Samuel Joseph Amouyal, Chaitanya Malaviya, Ben Bogin, Ofir Press, and Jonathan Berant. "AssistantBench: Can Web Agents Solve Realistic and Time-Consuming Tasks?." arXiv preprint arXiv:2407.15711 (2024). https://arxiv.org/abs/2407.15711 diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/assistantbench_evaluator.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/assistantbench_evaluator.py deleted file mode 100644 index 56d1a04faa67..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/assistantbench_evaluator.py +++ /dev/null @@ -1,127 +0,0 @@ -# This Script is slightly modified from the creators of the AssistantBench dataset https://huggingface.co/spaces/AssistantBench/leaderboard/blob/main/evaluation/evaluator.py -import json -from evaluate_utils.evaluate_factory import get_evaluator -import numpy as np - - -def find_isnan(samp): - try: - if np.isnan(samp): - return True - else: - return False - except: - return False - - -def fix_ans(answer): - try: - answer = answer.replace("{'", '{"').replace("', '", '", "').replace("': '", '": "').replace("'}", '"}') - answer = answer.replace("': ", '": ') - return answer - except: - return answer - - -def parse_answer(answer): - if len(answer) == 1: - ans, is_num = fix_number(answer[0]) - if is_num: - return ans, "number" - try: - ans = json.loads(fix_ans(answer[0])) - return [ans], "json" - except: - ans, is_num = fix_number(answer[0]) - if is_num: - return ans, "number" - else: - return answer[0], "string" - else: - try: - ans = [json.loads(fix_ans(ex)) for ex in answer] - return ans, "json" - except: - return answer, "string list" - - -def fix_number(number): - if type(number) == str: - copy_ans = number - copy_ans = " ".join(" ".join(" ".join(copy_ans.split("$")).split("%")).split("sqft")).strip() - copy_ans = copy_ans.strip() - copy_ans = copy_ans.replace(",", ".").replace(" square kilometers", "") - try: - return float(copy_ans), True - except: - return number, False - elif type(number) == int: - return float(number), True - else: - return number, True - - -def fix_prediction(prediction, gold_answer, evaluator): - if ( - type(prediction) == list - and len(prediction) == 1 - and (type(prediction[0]) == int or ((type(prediction[0]) == str) and prediction[0].isnumeric())) - ): - prediction = fix_number(prediction[0]) - - if type(prediction) != list: - prediction, is_num = fix_number(prediction) - if evaluator == "json": - try: - prediction = [json.loads(pred) for pred in prediction.split("\n")] - except: - prediction = [prediction] - - if (hasattr(type(prediction), "__len__")) and 
(len(prediction) == 0): - return prediction, False - - if (type(prediction) == list and len(prediction) > 1) and type(gold_answer) == float: - return prediction, False - - return prediction, True - - -def question_scorer(prediction, gold_answer): - """ - prediction: str or list of str - gold_answer: str or list of str - - returns a float between 0 and 1 - """ - try: - try: - prediction = json.loads(prediction) - except: - prediction = prediction - - answer_list = ( - [x for x in gold_answer.split("\n") if len(x.strip()) > 0] if type(gold_answer) != list else gold_answer - ) - gold_answer, evaluator = parse_answer(answer_list) - prediction, run_eval = fix_prediction(prediction, gold_answer, evaluator) - - has_ans = 1.0 - if (type(prediction) != float and len(prediction) == 0) or find_isnan(prediction): - has_ans = 0.0 - - if not run_eval: - return 0.0 - - metric_eval = get_evaluator(evaluator) - accuracy = metric_eval(prediction, gold_answer) - # double check if the accuracy is a number between 0 and 1 - if 0 <= accuracy <= 1: - return accuracy - else: - # throw exception - raise ValueError(f"Accuracy should be a float between 0 and 1, but got {accuracy}") - except Exception as e: - print( - f"Something went wrong while evaluating prediction {prediction} vs gold answer {gold_answer} with error {e}" - ) - return 0.0 diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/custom_tabulate.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/custom_tabulate.py deleted file mode 100644 index 61c40acc72f8..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/custom_tabulate.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -import sys -import re -from agbench.tabulate_cmd import default_tabulate -import json -import pandas as pd -import sqlite3 -import glob -import numpy as np -sys.path.append(os.path.dirname(__file__)) - -from assistantbench_evaluator import question_scorer - -EXCLUDE_DIR_NAMES = ["__pycache__"] - - -def normalize_answer(a): - # Lower case - # Trim (left and right) - # standardize comma separated values - # Replace multiple spaces with one space - # Remove trailing punctuation - norm_answer = ", ".join(a.strip().lower().split(",")) - norm_answer = re.sub(r"[\.\!\?]+$", "", re.sub(r"\s+", " ", norm_answer)) - return norm_answer - - -def scorer(instance_dir): - # Read the expected answer - expected_answer_file = os.path.join(instance_dir, "expected_answer.txt") - if not os.path.isfile(expected_answer_file): - return None - - expected_answer = None - with open(expected_answer_file, "rt") as fh: - expected_answer = fh.read().strip() - - # Read the console - console_log_file = os.path.join(instance_dir, "console_log.txt") - if not os.path.isfile(console_log_file): - return None - - console_log = "" - with open(console_log_file, "rt") as fh: - console_log = fh.read() - - final_answer = None - m = re.search(r"FINAL ANSWER:(.*?)\n", console_log, re.DOTALL) - if m: - final_answer = m.group(1).strip() - - # Missing the final answer line - if final_answer is None: - return None - # get accuracy from assistantbench util, no normalization done for accuracy - accuracy = question_scorer(final_answer, expected_answer) - n_ex = normalize_answer(expected_answer) - n_final = normalize_answer(final_answer) - return (accuracy, n_ex, n_final) - - -def get_number_of_chat_messages(chat_messages_dir): - result = 0 - for file in glob.glob(f"{chat_messages_dir}/*_messages.json"): - with open(file, "r") as f: - content = json.load(f) - for agent, messages in 
content.items(): - result += len(messages) - return result - - -def main(args): - parsed_args, all_results = default_tabulate(args, scorer=scorer) - excel_path = parsed_args.excel - - if excel_path: - excel_dir = os.path.dirname(excel_path) or "." - if not os.path.exists(excel_dir): - os.makedirs(excel_dir, exist_ok=True) - - if not excel_path.endswith((".xlsx", ".xls")): - excel_path += ".xlsx" - - runlogs = ( - parsed_args.runlogs - if parsed_args.runlogs.endswith("/") - else parsed_args.runlogs + "/" - ) - - if os.path.isdir(runlogs): - task_ids = sorted( - [ - task_id - for task_id in os.listdir(runlogs) - if task_id not in EXCLUDE_DIR_NAMES - ], - key=lambda s: os.path.getmtime(os.path.join(parsed_args.runlogs, s)), - ) - else: - raise ValueError("please input a valid directory to tabulate result") - - trials = ( - sorted(os.listdir(f"{runlogs}{task_ids[0]}"), key=lambda x: int(x)) - if len(task_ids) > 0 - else [] - ) - dbnames = [ - [f"{runlogs}{task_id}/{trial}/telemetry.db" for task_id in task_ids] - for trial in trials - ] - - query = """ - SELECT cost, session_id, response, start_time, end_time - FROM ( - SELECT invocation_id, cost, session_id, response, start_time, end_time, - ROW_NUMBER() OVER (PARTITION BY invocation_id ORDER BY start_time) as rn - FROM chat_completions - ) - WHERE rn = 1; - """ - - with pd.ExcelWriter(excel_path, engine="openpyxl") as writer: - for trial_index, each_trial in enumerate(dbnames): - result_df = pd.DataFrame( - columns=[ - "id", - "status", - "expected_answer", - "final_answer", - "cost", - "latency", - "num_of_llm_requests", - "num_of_chat_messages", - "prompt_tokens", - "completion_tokens", - "total_tokens", - "model", - ] - ) - - result_df_type_mapping = { - "id": str, - "status": bool, - "expected_answer": str, - "final_answer": str, - "cost": float, - "latency": float, - "num_of_llm_requests": int, - "num_of_chat_messages": int, - "prompt_tokens": int, - "completion_tokens": int, - "total_tokens": int, - } - - for dbname, scorer_results in zip(each_trial, all_results): - task_id = scorer_results[0] - scorer_result = scorer_results[trial_index + 1] - - status, expected_answer, final_answer = ( - scorer_result if scorer_result else (False, "", "") - ) - - con = sqlite3.connect(dbname) - - # TODO: if large amount of data, add chunksize - telemetry_df = pd.read_sql_query(query, con) - - earliest_starttime = pd.to_datetime( - telemetry_df["start_time"], format="%Y-%m-%d %H:%M:%S.%f" - ).min() - latest_endtime = pd.to_datetime( - telemetry_df["end_time"], format="%Y-%m-%d %H:%M:%S.%f" - ).max() - - num_of_chat_messages = get_number_of_chat_messages( - chat_messages_dir=os.path.dirname(dbname) - ) - result = { - "id": task_id, - "status": status, - "expected_answer": expected_answer, - "final_answer": final_answer, - "cost": telemetry_df["cost"].sum(), - "latency": ( - latest_endtime - earliest_starttime - ).total_seconds(), - "num_of_llm_requests": len(telemetry_df), - "num_of_chat_messages": num_of_chat_messages, - "prompt_tokens": telemetry_df["response"] - .apply( - lambda x: json.loads(x)["usage"]["prompt_tokens"] - if "usage" in json.loads(x) - and "prompt_tokens" in json.loads(x)["usage"] - else 0 - ) - .sum(), - "completion_tokens": telemetry_df["response"] - .apply( - lambda x: json.loads(x)["usage"]["completion_tokens"] - if "usage" in json.loads(x) - and "completion_tokens" in json.loads(x)["usage"] - else 0 - ) - .sum(), - "total_tokens": telemetry_df["response"] - .apply( - lambda x: json.loads(x)["usage"]["total_tokens"] - if "usage" 
in json.loads(x) - and "total_tokens" in json.loads(x)["usage"] - else 0 - ) - .sum(), - "model": telemetry_df["response"] - .apply( - lambda x: json.loads(x)["model"] - if "model" in json.loads(x) - else "" - ) - .unique(), - } - - result_df = result_df.astype(result_df_type_mapping) - result_df = pd.concat( - [result_df, pd.DataFrame([result])], ignore_index=True - ) - result_df.to_excel( - writer, sheet_name=f"trial_{trial_index}", index=False - ) - - -if __name__ == "__main__" and __package__ is None: - main(sys.argv) diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/LICENSE b/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/LICENSE deleted file mode 100644 index f49a4e16e68b..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/__init__.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_dicts.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_dicts.py deleted file mode 100644 index 9ce61c8cea99..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_dicts.py +++ /dev/null @@ -1,71 +0,0 @@ -# From AssistantBench modified slightly. 
-from typing import Dict, List -import numpy as np - -from .utils import _align_bags - - -def calculate_f1_score(precision, recall): - if precision + recall == 0: - return 0 # Handle the case to avoid division by zero - return 2 * (precision * recall) / (precision + recall) - - -def calc_recall(pred: Dict, gold: Dict, use_gold_for_eval: bool): - from .evaluate_factory import get_evaluator_from_gold_answer - - recall = [] - for gold_key, gold_value in gold.items(): - pred_value = pred.get(gold_key) - gold_value = fix_number(gold_value) - pred_value = fix_number(pred_value) - if gold_key not in pred: - recall.append(0) - else: - evaluator = ( - get_evaluator_from_gold_answer(type(gold_value)) - if use_gold_for_eval - else get_evaluator_from_gold_answer(type(pred_value)) - ) - if type(pred_value) != type(gold_value): - recall.append(0) - continue - recall.append(evaluator(pred_value, gold_value)) - avg_recall = np.average(recall) - return avg_recall - - -def fix_number(number): - if type(number) == str: - copy_ans = number - copy_ans = " ".join( - " ".join(" ".join(copy_ans.split("$")).split("%")).split("sqft") - ).strip() - copy_ans = copy_ans.strip() - copy_ans = copy_ans.replace(",", ".") - try: - return float(copy_ans) - except: - return number - elif type(number) == int: - return float(number) - else: - return number - - -def evaluate_pair_of_dicts(pred: Dict, gold: Dict): - recall = calc_recall(pred, gold, True) - precision = calc_recall(gold, pred, False) - f1 = calculate_f1_score(precision, recall) - return f1 - - -def evaluate_dicts(pred: List[Dict], gold: List[Dict]): - if not ( - type(pred) == dict - or len(pred) == 0 - or (type(pred) == list and type(pred[0]) == dict) - ): - return 0 - max_alignment_scores = _align_bags(pred, gold, evaluate_pair_of_dicts) - return np.average(max_alignment_scores) diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_factory.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_factory.py deleted file mode 100644 index 6a63c0a26eeb..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_factory.py +++ /dev/null @@ -1,30 +0,0 @@ -#From AssistantBench modified slightly. - -from typing import Union, Dict - -from .evaluate_dicts import evaluate_dicts -from .evaluate_numbers import evaluate_numbers -from .evaluate_strings import evaluate_strings - -EvaluatorFactory = { - "string": evaluate_strings, - "number": evaluate_numbers, - "json": evaluate_dicts, - "string list": evaluate_strings, -} - -EvaluatorFactoryFromType = { - str: evaluate_strings, - int: evaluate_numbers, - float: evaluate_numbers, - bool: evaluate_strings, - list: evaluate_strings, -} - - -def get_evaluator(evaluator: str): - return EvaluatorFactory[evaluator] - - -def get_evaluator_from_gold_answer(gold_answer: Union[str, int, float]): - return EvaluatorFactoryFromType[gold_answer] diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_numbers.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_numbers.py deleted file mode 100644 index 74a51b512653..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_numbers.py +++ /dev/null @@ -1,35 +0,0 @@ -#From AssistantBench modified slightly. 
- -from typing import Union -import numpy as np - - -# Renamed calc_z function to distance_function_log -def distance_function_log(pred: float, gold: float): - if pred == gold == 0: - return 1 - if pred == 0: - pred = 1e-4 - if gold == 0: - gold = 1e-4 - if pred > gold: - return max(0, 1 - np.log(pred / gold)) - else: - return max(0, 1 - np.log(gold / pred)) - - -def evaluate_numbers(pred: Union[float, str], gold: float): - res = None - if type(pred) != float and type(pred) != int: - try: - pred = float(pred) - except ValueError: - res = 0 - if type(gold) != float and type(gold) != int: - try: - gold = float(gold) - except ValueError: - res = 0 - if res is None: - res = distance_function_log(pred, gold) - return res diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_strings.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_strings.py deleted file mode 100644 index 301eff3b7764..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/evaluate_strings.py +++ /dev/null @@ -1,180 +0,0 @@ -""" -From AssistantBench modified slightly. -Evaluation for two strings or list of strings. - -Code taken from the DROP benchmark - https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py -""" - -from collections import defaultdict -from typing import List, Set, Tuple, Union -import string -import re -import numpy as np -from scipy.optimize import linear_sum_assignment - - -# From here through _normalize_answer was originally copied from: -# https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ -# Then cleaned up and modified a bit. -def _remove_articles(text: str) -> str: - regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) - return re.sub(regex, " ", text) - - -def _white_space_fix(text: str) -> str: - return " ".join(text.split()) - - -EXCLUDE = set(string.punctuation) - - -def _remove_punc(text: str) -> str: - if not _is_number(text): - return "".join(ch for ch in text if ch not in EXCLUDE) - else: - return text - - -def _lower(text: str) -> str: - return text.lower() - - -def _tokenize(text: str) -> List[str]: - return re.split(" |-", text) - - -def _normalize_answer(text: str) -> str: - """Lower text and remove punctuation, articles and extra whitespace.""" - - parts = [ - _white_space_fix( - _remove_articles(_normalize_number(_remove_punc(_lower(token)))) - ) - for token in _tokenize(text) - ] - parts = [part for part in parts if part.strip()] - normalized = " ".join(parts).strip() - return normalized - - -def _is_number(text: str) -> bool: - try: - float(text) - return True - except ValueError: - return False - - -def _normalize_number(text: str) -> str: - if _is_number(text): - return str(float(text)) - else: - return text - - -def _answer_to_bags( - answer: Union[str, List[str], Tuple[str, ...]], -) -> Tuple[List[str], List[Set[str]]]: - if isinstance(answer, (list, tuple)): - raw_spans = answer - else: - raw_spans = [answer] - normalized_spans: List[str] = [] - token_bags = [] - for raw_span in raw_spans: - normalized_span = _normalize_answer(raw_span) - normalized_spans.append(normalized_span) - token_bags.append(set(normalized_span.split())) - return normalized_spans, token_bags - - -def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]: - """ - Takes gold and predicted answer sets and first finds the optimal 1-1 alignment - between them and gets maximum metric values over 
all the answers. - """ - scores = np.zeros([len(gold), len(predicted)]) - for gold_index, gold_item in enumerate(gold): - for pred_index, pred_item in enumerate(predicted): - if _match_numbers_if_present(gold_item, pred_item): - scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item) - row_ind, col_ind = linear_sum_assignment(-scores) - - max_scores = np.zeros([max(len(gold), len(predicted))]) - for row, column in zip(row_ind, col_ind): - max_scores[row] = max(max_scores[row], scores[row, column]) - return max_scores - - -def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float: - intersection = len(gold_bag.intersection(predicted_bag)) - if not predicted_bag: - precision = 1.0 - else: - precision = intersection / float(len(predicted_bag)) - if not gold_bag: - recall = 1.0 - else: - recall = intersection / float(len(gold_bag)) - f1 = ( - (2 * precision * recall) / (precision + recall) - if not (precision == 0.0 and recall == 0.0) - else 0.0 - ) - return f1 - - -def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool: - gold_numbers = set() - predicted_numbers = set() - for word in gold_bag: - if _is_number(word): - gold_numbers.add(word) - for word in predicted_bag: - if _is_number(word): - predicted_numbers.add(word) - if (not gold_numbers) or gold_numbers.intersection(predicted_numbers): - return True - return False - - -def get_metrics( - predicted: Union[str, List[str], Tuple[str, ...]], - gold: Union[str, List[str], Tuple[str, ...]], -) -> Tuple[float, float]: - """ - Takes a predicted answer and a gold answer (that are both either a string or a list of - strings), and returns exact match and the DROP F1 metric for the prediction. If you are - writing a script for evaluating objects in memory (say, the output of predictions during - validation, or while training), this is the function you want to call, after using - :func:`answer_json_to_strings` when reading the gold answer from the released data file. - """ - predicted_bags = _answer_to_bags(predicted) - gold_bags = _answer_to_bags(gold) - - if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len( - gold_bags[0] - ): - exact_match = 1.0 - else: - exact_match = 0.0 - - f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1]) - f1 = np.mean(f1_per_bag) - f1 = round(f1, 2) - return exact_match, f1 - - -def evaluate_strings(prediction, gold): - if type(prediction) != list and type(prediction) != str: - prediction = str(prediction) - if type(gold) != list and type(gold) != str: - gold = str(gold) - try: - predicted_bags = _answer_to_bags(prediction) - gold_bags = _answer_to_bags(gold) - f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1]) - f1 = np.mean(f1_per_bag) - except Exception: - f1 = 0.0 - return f1 diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/readme.md b/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/readme.md deleted file mode 100644 index 733706ff4eeb..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/readme.md +++ /dev/null @@ -1 +0,0 @@ -These files were obtained from the creators of the AssistantBench benchmark and modified slightly. 
You can find the latest version at [https://huggingface.co/spaces/AssistantBench/leaderboard/tree/main/evaluation](https://huggingface.co/spaces/AssistantBench/leaderboard/tree/main/evaluation) diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/utils.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/utils.py deleted file mode 100644 index ea55f392a55a..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/evaluate_utils/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import List, Set, Tuple, Union, Callable -import numpy as np -from scipy.optimize import linear_sum_assignment - - -def _align_bags( - predicted: List[Set[str]], - gold: List[Set[str]], - method: Callable[[object, object], float], -) -> List[float]: - """ - Takes gold and predicted answer sets and first finds the optimal 1-1 alignment - between them and gets maximum metric values over all the answers. - """ - scores = np.zeros([len(gold), len(predicted)]) - for gold_index, gold_item in enumerate(gold): - for pred_index, pred_item in enumerate(predicted): - scores[gold_index, pred_index] = method(pred_item, gold_item) - row_ind, col_ind = linear_sum_assignment(-scores) - - max_scores = np.zeros([max(len(gold), len(predicted))]) - for row, column in zip(row_ind, col_ind): - max_scores[row] = max(max_scores[row], scores[row, column]) - return max_scores diff --git a/python/packages/agbench/benchmarks/AssistantBench/Scripts/init_tasks.py b/python/packages/agbench/benchmarks/AssistantBench/Scripts/init_tasks.py deleted file mode 100644 index 752739cb6d9f..000000000000 --- a/python/packages/agbench/benchmarks/AssistantBench/Scripts/init_tasks.py +++ /dev/null @@ -1,93 +0,0 @@ -import json -import os -import re -import sys - -from huggingface_hub import snapshot_download - -SCRIPT_PATH = os.path.realpath(__file__) -SCRIPT_NAME = os.path.basename(SCRIPT_PATH) -SCRIPT_DIR = os.path.dirname(SCRIPT_PATH) - -SCENARIO_DIR = os.path.realpath(os.path.join(SCRIPT_DIR, os.path.pardir)) -TEMPLATES_DIR = os.path.join(SCENARIO_DIR, "Templates") -TASKS_DIR = os.path.join(SCENARIO_DIR, "Tasks") -DOWNLOADS_DIR = os.path.join(SCENARIO_DIR, "Downloads") -REPO_DIR = os.path.join(DOWNLOADS_DIR, "AssistantBench") - - -def download_assistantbench(): - """Download the AssistantBench benchmark from Hugging Face.""" - - if not os.path.isdir(DOWNLOADS_DIR): - os.mkdir(DOWNLOADS_DIR) - - """Download the AssistantBench dataset from Hugging Face Hub""" - snapshot_download( - repo_id="AssistantBench/AssistantBench", - repo_type="dataset", - local_dir=REPO_DIR, - local_dir_use_symlinks=True, - ) - - -def create_jsonl(data_file_path, file_name, template): - """Creates a JSONL scenario file with a given name, and template path.""" - tasks = [] - with open(data_file_path) as fh: - for line in fh: - data = json.loads(line) - tasks.append(data) - file_name = os.path.basename(file_name) - if not os.path.isdir(TASKS_DIR): - os.mkdir(TASKS_DIR) - - with open(os.path.join(TASKS_DIR, file_name), "wt") as fh: - for task in tasks: - if "answer" not in task or task["answer"] is None: - task["answer"] = "" - print(f"Converting: [{file_name}] {task['id']}") - template_cp_list = [template] - record = { - "id": task["id"], - "template": template_cp_list, - "substitutions": { - "scenario.py": { - "__FILE_NAME__": "", - }, - "expected_answer.txt": {"__EXPECTED_ANSWER__": task["answer"]}, - "prompt.txt": {"__PROMPT__": task["task"]}, - }, - "difficulty": task["difficulty"], - "explanation": 
task["explanation"],
-                "metadata": task["metadata"],
-                "gold_url": task["gold_url"],
-                "set": task["set"],
-            }
-            fh.write(json.dumps(record).strip() + "\n")
-
-
-###############################################################################
-def main():
-    ab_validation_files = os.path.join(REPO_DIR, "assistant_bench_v1.0_dev.jsonl")
-    ab_test_files = os.path.join(REPO_DIR, "assistant_bench_v1.0_test.jsonl")
-
-    if not os.path.isfile(ab_validation_files) or not os.path.isfile(ab_test_files):
-        download_assistantbench()
-
-    if not os.path.isfile(ab_validation_files) or not os.path.isfile(ab_test_files):
-        sys.exit(f"Error: '{REPO_DIR}' does not appear to be a copy of the AssistantBench repository.")
-
-    templates = {}
-    for entry in os.scandir(TEMPLATES_DIR):
-        if entry.is_dir():
-            templates[re.sub(r"\s", "", entry.name)] = entry.path
-    print(templates)
-    # make a copy of the data in the Tasks directory
-    for t in templates.items():
-        create_jsonl(ab_validation_files, f"assistant_bench_v1.0_dev__{t[0]}.jsonl", t[1])
-        create_jsonl(ab_test_files, f"assistant_bench_v1.0_test__{t[0]}.jsonl", t[1])
-
-
-if __name__ == "__main__" and __package__ is None:
-    main()
diff --git a/python/packages/agbench/benchmarks/GAIA/ENV.json.sample b/python/packages/agbench/benchmarks/GAIA/ENV.json.sample
deleted file mode 100644
index 1f2c4915e3c7..000000000000
--- a/python/packages/agbench/benchmarks/GAIA/ENV.json.sample
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-    "BING_API_KEY": "YOUR_KEY_KEY",
-    "HOMEPAGE": "https://www.bing.com/",
-    "WEB_SURFER_DEBUG_DIR": "/autogen/debug"
-}
diff --git a/python/packages/agbench/benchmarks/GAIA/ENV.yaml b/python/packages/agbench/benchmarks/GAIA/ENV.yaml
new file mode 100644
index 000000000000..e2778d51e07d
--- /dev/null
+++ b/python/packages/agbench/benchmarks/GAIA/ENV.yaml
@@ -0,0 +1,30 @@
+# ENV.yaml
+#
+# This file specifies environment variables to be passed to the Docker task
+# instances or virtual environments. These values are ephemeral, and are
+# discarded when the task concludes. This is useful for passing API keys, etc.
+# since they will not be saved in logs or to any task output.
+#
+# String values can reference environment variables on the host machine.
+# For example:
+#
+# OPENAI_API_KEY: ${OPENAI_API_KEY}
+#
+# Will copy the host's OPENAI_API_KEY environment variable to the corresponding
+# variable in the task environment.
+#
+# Complex values will be converted to JSON, and then passed as a string to the
+# task environment.
For example: +# +# MODEL_CONFIG: +# provider: autogen_ext.models.openai.OpenAIChatCompletionClient +# config: +# model: gpt-4o +# +# Will be converted to: +# +# MODEL_CONFIG: >- +# {"provider": "autogen_ext.models.openai.OpenAIChatCompletionClient", "config": {"model": "gpt-4o"}} +# + +OPENAI_API_KEY: ${OPENAI_API_KEY} diff --git a/python/packages/agbench/benchmarks/GAIA/README.md b/python/packages/agbench/benchmarks/GAIA/README.md index 753d8e4ed51a..ef98a24e4b4e 100644 --- a/python/packages/agbench/benchmarks/GAIA/README.md +++ b/python/packages/agbench/benchmarks/GAIA/README.md @@ -10,31 +10,7 @@ Navigate to GAIA cd benchmarks/GAIA ``` -Create a file called ENV.json with the following (required) contents (If you're using MagenticOne) - -```json -{ - "BING_API_KEY": "REPLACE_WITH_YOUR_BING_API_KEY", - "HOMEPAGE": "https://www.bing.com/", - "WEB_SURFER_DEBUG_DIR": "/autogen/debug", - "CHAT_COMPLETION_KWARGS_JSON": "{\"api_version\": \"2024-02-15-preview\", \"azure_endpoint\": \"YOUR_ENDPOINT/\", \"model_capabilities\": {\"function_calling\": true, \"json_output\": true, \"vision\": true}, \"azure_ad_token_provider\": \"DEFAULT\", \"model\": \"gpt-4o-2024-05-13\"}", - "CHAT_COMPLETION_PROVIDER": "azure" -} -``` - -You can also use the openai client by replacing the last two entries in the ENV file by: - -- `CHAT_COMPLETION_PROVIDER='openai'` -- `CHAT_COMPLETION_KWARGS_JSON` with the following JSON structure: - -```json -{ - "api_key": "REPLACE_WITH_YOUR_API", - "model": "gpt-4o-2024-05-13" -} -``` - -You might need to add additional packages to the requirements.txt file inside the Templates/MagenticOne folder. +Update `config.yaml` to point to your model host, as appropriate. The default configuration points to 'gpt-4o'. Now initialize the tasks. diff --git a/python/packages/agbench/benchmarks/GAIA/Scripts/custom_tabulate.py b/python/packages/agbench/benchmarks/GAIA/Scripts/custom_tabulate.py index ec51863e9c7b..1b23ee219f7f 100644 --- a/python/packages/agbench/benchmarks/GAIA/Scripts/custom_tabulate.py +++ b/python/packages/agbench/benchmarks/GAIA/Scripts/custom_tabulate.py @@ -6,12 +6,14 @@ import pandas as pd import sqlite3 import glob +import string +import warnings import numpy as np EXCLUDE_DIR_NAMES = ["__pycache__"] -def normalize_answer(a): +def in_house_normalize_answer(a): # Lower case # Trim (left and right) # standardize comma separated values @@ -22,6 +24,106 @@ def normalize_answer(a): return norm_answer +def in_house_question_scorer( + model_answer: str, + ground_truth: str, +) -> bool: + n_ma = in_house_normalize_answer(model_answer) + n_gt = in_house_normalize_answer(ground_truth) + return (n_gt != "" and n_gt == n_ma) + + +def gaia_question_scorer( + model_answer: str, + ground_truth: str, +) -> bool: + #FROM: https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/scorer.py + + def normalize_number_str(number_str: str) -> float: + # we replace these common units and commas to allow + # conversion to float + for char in ["$", "%", ","]: + number_str = number_str.replace(char, "") + try: + return float(number_str) + except ValueError: + print(f"String {number_str} cannot be normalized to number str.") + return float("inf") + + def split_string(s: str, char_list: list[str] = [",", ";"],) -> list[str]: + pattern = f"[{''.join(char_list)}]" + return re.split(pattern, s) + + def normalize_str(input_str, remove_punct=True) -> str: + """ + Normalize a string by: + - Removing all white spaces + - Optionally removing punctuation (if remove_punct is True) + - Converting 
to lowercase + Parameters: + - input_str: str, the string to normalize + - remove_punct: bool, whether to remove punctuation (default: True) + Returns: + - str, the normalized string + """ + # Remove all white spaces. Required e.g for seagull vs. sea gull + no_spaces = re.sub(r"\s", "", input_str) + + # Remove punctuation, if specified. + if remove_punct: + translator = str.maketrans("", "", string.punctuation) + return no_spaces.lower().translate(translator) + else: + return no_spaces.lower() + + + def is_float(element: any) -> bool: + try: + float(element) + return True + except ValueError: + return False + + # if gt is a number + if is_float(ground_truth): + normalized_answer = normalize_number_str(model_answer) + return normalized_answer == float(ground_truth) + + # if gt is a list + elif any(char in ground_truth for char in [",", ";"]): + # question with the fish: normalization removes punct + + gt_elems = split_string(ground_truth) + ma_elems = split_string(model_answer) + + # check length is the same + if len(gt_elems) != len(ma_elems): + #warnings.warn( + # "Answer lists have different lengths, returning False.", UserWarning + #) + return False + + # compare each element as float or str + comparisons = [] + for ma_elem, gt_elem in zip(ma_elems, gt_elems): + if is_float(gt_elem): + normalized_ma_elem = normalize_number_str(ma_elem) + comparisons.append(normalized_ma_elem == float(gt_elem)) + else: + # we do not remove punct since comparisons can include punct + comparisons.append( + normalize_str(ma_elem, remove_punct=False) + == normalize_str(gt_elem, remove_punct=False) + ) + return all(comparisons) + + # if gt is a str + else: + return normalize_str(model_answer) == normalize_str(ground_truth) + + +############## + def scorer(instance_dir): # Read the expected answer expected_answer_file = os.path.join(instance_dir, "expected_answer.txt") @@ -51,147 +153,12 @@ def scorer(instance_dir): return None # Return true if they are equal after normalization - n_ex = normalize_answer(expected_answer) - n_final = normalize_answer(final_answer) - return ( - (n_ex != "" and n_ex == n_final), - n_ex, - n_final - ) - - -def get_number_of_chat_messages(chat_messages_dir): - result = 0 - for file in glob.glob(f"{chat_messages_dir}/*_messages.json"): - with open(file, "r") as f: - content = json.load(f) - for agent, messages in content.items(): - result += len(messages) - return result + # return in_house_question_scorer(final_answer, expected_answer) + return gaia_question_scorer(final_answer, expected_answer) def main(args): - parsed_args, all_results = default_tabulate(args, scorer=scorer) - excel_path = parsed_args.excel - - if excel_path: - excel_dir = os.path.dirname(excel_path) or "." 
- if not os.path.exists(excel_dir): - os.makedirs(excel_dir, exist_ok=True) - - if not excel_path.endswith((".xlsx", ".xls")): - excel_path += ".xlsx" - - runlogs = parsed_args.runlogs if parsed_args.runlogs.endswith("/") else parsed_args.runlogs + "/" - - if os.path.isdir(runlogs): - task_ids = sorted( - [task_id for task_id in os.listdir(runlogs) if task_id not in EXCLUDE_DIR_NAMES], - key=lambda s: os.path.getmtime(os.path.join(parsed_args.runlogs, s)), - ) - else: - raise ValueError("please input a valid directory to tabulate result") - - trials = sorted(os.listdir(f"{runlogs}{task_ids[0]}"), key=lambda x: int(x)) if len(task_ids) > 0 else [] - dbnames = [[f"{runlogs}{task_id}/{trial}/telemetry.db" for task_id in task_ids] for trial in trials] - - query = """ - SELECT cost, session_id, response, start_time, end_time - FROM ( - SELECT invocation_id, cost, session_id, response, start_time, end_time, - ROW_NUMBER() OVER (PARTITION BY invocation_id ORDER BY start_time) as rn - FROM chat_completions - ) - WHERE rn = 1; - """ - - with pd.ExcelWriter(excel_path, engine="openpyxl") as writer: - for trial_index, each_trial in enumerate(dbnames): - result_df = pd.DataFrame( - columns=[ - "id", - "status", - "expected_answer", - "final_answer", - "cost", - "latency", - "num_of_llm_requests", - "num_of_chat_messages", - "prompt_tokens", - "completion_tokens", - "total_tokens", - "model", - ] - ) - - result_df_type_mapping = { - "id": str, - "status": bool, - "expected_answer": str, - "final_answer": str, - "cost": float, - "latency": float, - "num_of_llm_requests": int, - "num_of_chat_messages": int, - "prompt_tokens": int, - "completion_tokens": int, - "total_tokens": int, - } - - for dbname, scorer_results in zip(each_trial, all_results): - task_id = scorer_results[0] - scorer_result = scorer_results[trial_index + 1] - - status, expected_answer, final_answer = scorer_result if scorer_result else (False,"","") - - con = sqlite3.connect(dbname) - - # TODO: if large amount of data, add chunksize - telemetry_df = pd.read_sql_query(query, con) - - earliest_starttime = pd.to_datetime(telemetry_df["start_time"], format="%Y-%m-%d %H:%M:%S.%f").min() - latest_endtime = pd.to_datetime(telemetry_df["end_time"], format="%Y-%m-%d %H:%M:%S.%f").max() - - num_of_chat_messages = get_number_of_chat_messages(chat_messages_dir=os.path.dirname(dbname)) - result = { - "id": task_id, - "status": status, - "expected_answer": expected_answer, - "final_answer": final_answer, - "cost": telemetry_df["cost"].sum(), - "latency": (latest_endtime - earliest_starttime).total_seconds(), - "num_of_llm_requests": len(telemetry_df), - "num_of_chat_messages": num_of_chat_messages, - "prompt_tokens": telemetry_df["response"] - .apply( - lambda x: json.loads(x)["usage"]["prompt_tokens"] - if "usage" in json.loads(x) and "prompt_tokens" in json.loads(x)["usage"] - else 0 - ) - .sum(), - "completion_tokens": telemetry_df["response"] - .apply( - lambda x: json.loads(x)["usage"]["completion_tokens"] - if "usage" in json.loads(x) and "completion_tokens" in json.loads(x)["usage"] - else 0 - ) - .sum(), - "total_tokens": telemetry_df["response"] - .apply( - lambda x: json.loads(x)["usage"]["total_tokens"] - if "usage" in json.loads(x) and "total_tokens" in json.loads(x)["usage"] - else 0 - ) - .sum(), - "model": telemetry_df["response"] - .apply(lambda x: json.loads(x)["model"] if "model" in json.loads(x) else "") - .unique(), - } - - result_df = result_df.astype(result_df_type_mapping) - result_df = pd.concat([result_df, 
pd.DataFrame([result])], ignore_index=True) - result_df.to_excel(writer, sheet_name=f"trial_{trial_index}", index=False) - + default_tabulate(args, scorer=scorer) if __name__ == "__main__" and __package__ is None: main(sys.argv) diff --git a/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/expected_answer.txt b/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/expected_answer.txt new file mode 100644 index 000000000000..8153c2bf8242 --- /dev/null +++ b/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/expected_answer.txt @@ -0,0 +1 @@ +__EXPECTED_ANSWER__ diff --git a/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/prompt.txt b/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/prompt.txt new file mode 100644 index 000000000000..482f50dca311 --- /dev/null +++ b/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/prompt.txt @@ -0,0 +1 @@ +__PROMPT__ diff --git a/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/requirements.txt b/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/requirements.txt new file mode 100644 index 000000000000..3db8bfa55857 --- /dev/null +++ b/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/requirements.txt @@ -0,0 +1,5 @@ +tiktoken +pyyaml +/autogen_python/packages/autogen-core +/autogen_python/packages/autogen-ext[openai,magentic-one] +/autogen_python/packages/autogen-agentchat diff --git a/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/scenario.py b/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/scenario.py new file mode 100644 index 000000000000..7f43c111e29a --- /dev/null +++ b/python/packages/agbench/benchmarks/GAIA/Templates/MagenticOne/scenario.py @@ -0,0 +1,89 @@ +import asyncio +import os +import yaml +import warnings +from autogen_ext.agents.magentic_one import MagenticOneCoderAgent +from autogen_agentchat.teams import MagenticOneGroupChat +from autogen_agentchat.ui import Console +from autogen_core.models import ModelFamily +from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor +from autogen_agentchat.conditions import TextMentionTermination +from autogen_core.models import ChatCompletionClient +from autogen_ext.agents.web_surfer import MultimodalWebSurfer +from autogen_ext.agents.file_surfer import FileSurfer +from autogen_agentchat.agents import CodeExecutorAgent +from autogen_agentchat.messages import TextMessage + +# Suppress warnings about the requests.Session() not being closed +warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning) + +async def main() -> None: + + # Load model configuration and create the model client. 
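+    # Note: this config.yaml is the per-task copy of the benchmark-level
+    # config.yaml (agbench copies it into each expanded task just before the
+    # scenario runs). Each *_client key below is expected to hold a
+    # ChatCompletionClient component configuration.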
+ with open("config.yaml", "r") as f: + config = yaml.safe_load(f) + + orchestrator_client = ChatCompletionClient.load_component(config["orchestrator_client"]) + coder_client = ChatCompletionClient.load_component(config["coder_client"]) + web_surfer_client = ChatCompletionClient.load_component(config["web_surfer_client"]) + file_surfer_client = ChatCompletionClient.load_component(config["file_surfer_client"]) + + # Read the prompt + prompt = "" + with open("prompt.txt", "rt") as fh: + prompt = fh.read().strip() + filename = "__FILE_NAME__".strip() + + # Set up the team + coder = MagenticOneCoderAgent( + "Assistant", + model_client = coder_client, + ) + + executor = CodeExecutorAgent("ComputerTerminal", code_executor=LocalCommandLineCodeExecutor()) + + file_surfer = FileSurfer( + name="FileSurfer", + model_client = file_surfer_client, + ) + + web_surfer = MultimodalWebSurfer( + name="WebSurfer", + model_client = web_surfer_client, + downloads_folder=os.getcwd(), + debug_dir="logs", + to_save_screenshots=True, + ) + + team = MagenticOneGroupChat( + [coder, executor, file_surfer, web_surfer], + model_client=orchestrator_client, + max_turns=20, + final_answer_prompt= f""", +We have completed the following task: + +{prompt} + +The above messages contain the conversation that took place to complete the task. +Read the above conversation and output a FINAL ANSWER to the question. +To output the final answer, use the following template: FINAL ANSWER: [YOUR FINAL ANSWER] +Your FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. +ADDITIONALLY, your FINAL ANSWER MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.) +If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise. +If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'. +If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings. +""".strip() + ) + + # Prepare the prompt + filename_prompt = "" + if len(filename) > 0: + filename_prompt = f"The question is about a file, document or image, which can be accessed by the filename '{filename}' in the current working directory." 
+ task = f"{prompt}\n\n{filename_prompt}" + + # Run the task + stream = team.run_stream(task=task.strip()) + await Console(stream) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/expected_answer.txt b/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/expected_answer.txt new file mode 100644 index 000000000000..8153c2bf8242 --- /dev/null +++ b/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/expected_answer.txt @@ -0,0 +1 @@ +__EXPECTED_ANSWER__ diff --git a/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/prompt.txt b/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/prompt.txt new file mode 100644 index 000000000000..482f50dca311 --- /dev/null +++ b/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/prompt.txt @@ -0,0 +1 @@ +__PROMPT__ diff --git a/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/requirements.txt b/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/requirements.txt new file mode 100644 index 000000000000..3db8bfa55857 --- /dev/null +++ b/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/requirements.txt @@ -0,0 +1,5 @@ +tiktoken +pyyaml +/autogen_python/packages/autogen-core +/autogen_python/packages/autogen-ext[openai,magentic-one] +/autogen_python/packages/autogen-agentchat diff --git a/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/scenario.py b/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/scenario.py new file mode 100644 index 000000000000..e2e1d8fae009 --- /dev/null +++ b/python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/scenario.py @@ -0,0 +1,176 @@ +import asyncio +import os +import yaml +import warnings +from typing import Sequence +from autogen_ext.agents.magentic_one import MagenticOneCoderAgent +from autogen_agentchat.teams import SelectorGroupChat +from autogen_agentchat.conditions import MaxMessageTermination +from autogen_agentchat.ui import Console +from autogen_agentchat.utils import content_to_str +from autogen_core.models import ModelFamily +from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor +from autogen_agentchat.conditions import TextMentionTermination +from autogen_agentchat.base import TerminationCondition, TerminatedException +from autogen_core.models import ChatCompletionClient +from autogen_ext.agents.web_surfer import MultimodalWebSurfer +from autogen_ext.agents.file_surfer import FileSurfer +from autogen_agentchat.agents import CodeExecutorAgent +from autogen_agentchat.messages import TextMessage, AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage +from autogen_core.models import LLMMessage, UserMessage, AssistantMessage + +# Suppress warnings about the requests.Session() not being closed +warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning) + +async def main() -> None: + + # Load model configuration and create the model client. 
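+    # As in the MagenticOne template, four separate component configs are read
+    # here; the default config.yaml simply aliases them all to a single client,
+    # but each agent may be pointed at a different model.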
+ with open("config.yaml", "r") as f: + config = yaml.safe_load(f) + + orchestrator_client = ChatCompletionClient.load_component(config["orchestrator_client"]) + coder_client = ChatCompletionClient.load_component(config["coder_client"]) + web_surfer_client = ChatCompletionClient.load_component(config["web_surfer_client"]) + file_surfer_client = ChatCompletionClient.load_component(config["file_surfer_client"]) + + # Read the prompt + prompt = "" + with open("prompt.txt", "rt") as fh: + prompt = fh.read().strip() + filename = "__FILE_NAME__".strip() + + # Set up the team + coder = MagenticOneCoderAgent( + "Assistant", + model_client = coder_client, + ) + + executor = CodeExecutorAgent("ComputerTerminal", code_executor=LocalCommandLineCodeExecutor()) + + file_surfer = FileSurfer( + name="FileSurfer", + model_client = file_surfer_client, + ) + + web_surfer = MultimodalWebSurfer( + name="WebSurfer", + model_client = web_surfer_client, + downloads_folder=os.getcwd(), + debug_dir="logs", + to_save_screenshots=True, + ) + + # Prepare the prompt + filename_prompt = "" + if len(filename) > 0: + filename_prompt = f"The question is about a file, document or image, which can be accessed by the filename '{filename}' in the current working directory." + task = f"{prompt}\n\n{filename_prompt}" + + # Termination conditions + max_messages_termination = MaxMessageTermination(max_messages=20) + llm_termination = LLMTermination( + prompt=f"""Consider the following task: +{task.strip()} + +Does the above conversation suggest that the task has been solved? +If so, reply "TERMINATE", otherwise reply "CONTINUE" +""", + model_client=orchestrator_client + ) + + termination = max_messages_termination | llm_termination + + # Create the team + team = SelectorGroupChat( + [coder, executor, file_surfer, web_surfer], + model_client=orchestrator_client, + termination_condition=termination, + ) + + # Run the task + stream = team.run_stream(task=task.strip()) + result = await Console(stream) + + # Do one more inference to format the results + final_context: Sequence[LLMMessage] = [] + for message in result.messages: + if isinstance(message, TextMessage): + final_context.append(UserMessage(content=message.content, source=message.source)) + elif isinstance(message, MultiModalMessage): + if orchestrator_client.model_info["vision"]: + final_context.append(UserMessage(content=message.content, source=message.source)) + else: + final_context.append(UserMessage(content=content_to_str(message.content), source=message.source)) + final_context.append(UserMessage( + content=f"""We have completed the following task: +{prompt} + +The above messages contain the conversation that took place to complete the task. +Read the above conversation and output a FINAL ANSWER to the question. +To output the final answer, use the following template: FINAL ANSWER: [YOUR FINAL ANSWER] +Your FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. +ADDITIONALLY, your FINAL ANSWER MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.) +If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise. +If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'. 
+If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
+""".strip(),
+        source="user"))
+
+    # Call the model to evaluate
+    response = await orchestrator_client.create(final_context)
+    print(response.content, flush=True)
+
+
+class LLMTermination(TerminationCondition):
+    """Terminate the conversation if an LLM determines the task is complete.
+
+    Args:
+        prompt: The prompt to evaluate with the LLM
+        model_client: The LLM model_client to use
+        termination_phrase: The phrase to look for in the LLM output to trigger termination
+    """
+
+    def __init__(self, prompt: str, model_client: ChatCompletionClient, termination_phrase: str = "TERMINATE") -> None:
+        self._prompt = prompt
+        self._model_client = model_client
+        self._termination_phrase = termination_phrase
+        self._terminated = False
+        self._context: Sequence[LLMMessage] = []
+
+    @property
+    def terminated(self) -> bool:
+        return self._terminated
+
+    async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
+        if self._terminated:
+            raise TerminatedException("Termination condition has already been reached")
+
+        # Build the context
+        for message in messages:
+            if isinstance(message, TextMessage):
+                self._context.append(UserMessage(content=message.content, source=message.source))
+            elif isinstance(message, MultiModalMessage):
+                if self._model_client.model_info["vision"]:
+                    self._context.append(UserMessage(content=message.content, source=message.source))
+                else:
+                    self._context.append(UserMessage(content=content_to_str(message.content), source=message.source))
+
+        if len(self._context) == 0:
+            return None
+
+        # Call the model to evaluate
+        response = await self._model_client.create(self._context + [UserMessage(content=self._prompt, source="user")])
+
+        # Check the model's response (not the last chat message) for the termination phrase
+        if isinstance(response.content, str) and self._termination_phrase in response.content:
+            self._terminated = True
+            return StopMessage(content=response.content, source="LLMTermination")
+        return None
+
+    async def reset(self) -> None:
+        self._terminated = False
+        self._context = []
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/python/packages/agbench/benchmarks/GAIA/config.yaml b/python/packages/agbench/benchmarks/GAIA/config.yaml
new file mode 100644
index 000000000000..a13c1b1a4598
--- /dev/null
+++ b/python/packages/agbench/benchmarks/GAIA/config.yaml
@@ -0,0 +1,38 @@
+# config.yaml
+#
+# The contents of this file will be copied into the 'config.yaml' file of
+# every expanded Task, just prior to running the scenario. This provides a
+# good place to store model or other configurations important for the scenario.
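+#
+# An Azure OpenAI variant can be sketched along the same lines (an untested
+# example; the field names assume autogen_ext's AzureOpenAIChatCompletionClient):
+#
+#model_config: &client
+#  provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient
+#  config:
+#    model: gpt-4o
+#    azure_deployment: YOUR_DEPLOYMENT
+#    azure_endpoint: https://YOUR_ENDPOINT.openai.azure.com/
+#    api_version: "2024-06-01"
+#    api_key: YOUR_API_KEY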
+
+###############################
+# Open AI model configuration #
+###############################
+model_config: &client
+  provider: autogen_ext.models.openai.OpenAIChatCompletionClient
+  config:
+    model: gpt-4o
+
+
+##############################
+# Ollama model configuration #
+##############################
+#model_config: &client
+#  provider: autogen_ext.models.openai.OpenAIChatCompletionClient
+#  config:
+#    model: deepseek-r1:7b
+#    base_url: http://localhost:11434/v1/
+#    api_key: ollama
+#    model_info:
+#      function_calling: false
+#      json_output: false
+#      vision: false
+#      family: r1
+#
+
+#######################
+# Used by MagenticOne #
+#######################
+orchestrator_client: *client
+coder_client: *client
+web_surfer_client: *client
+file_surfer_client: *client
diff --git a/python/packages/agbench/benchmarks/HumanEval/ENV.yaml b/python/packages/agbench/benchmarks/HumanEval/ENV.yaml
new file mode 100644
index 000000000000..e2778d51e07d
--- /dev/null
+++ b/python/packages/agbench/benchmarks/HumanEval/ENV.yaml
@@ -0,0 +1,30 @@
+# ENV.yaml
+#
+# This file specifies environment variables to be passed to the Docker task
+# instances or virtual environments. These values are ephemeral, and are
+# discarded when the task concludes. This is useful for passing API keys, etc.,
+# since they will not be saved in logs or to any task output.
+#
+# String values can reference environment variables on the host machine.
+# For example:
+#
+# OPENAI_API_KEY: ${OPENAI_API_KEY}
+#
+# will copy the host's OPENAI_API_KEY environment variable to the corresponding
+# variable in the task environment.
+#
+# Complex values will be converted to JSON, and then passed as a string to the
+# task environment. For example:
+#
+# MODEL_CONFIG:
+#   provider: autogen_ext.models.openai.OpenAIChatCompletionClient
+#   config:
+#     model: gpt-4o
+#
+# will be converted to:
+#
+# MODEL_CONFIG: >-
+#   {"provider": "autogen_ext.models.openai.OpenAIChatCompletionClient", "config": {"model": "gpt-4o"}}
+#
+
+OPENAI_API_KEY: ${OPENAI_API_KEY}
diff --git a/python/packages/agbench/benchmarks/HumanEval/README.md b/python/packages/agbench/benchmarks/HumanEval/README.md
index 25acc2630523..0c045af4cc1a 100644
--- a/python/packages/agbench/benchmarks/HumanEval/README.md
+++ b/python/packages/agbench/benchmarks/HumanEval/README.md
@@ -15,26 +15,8 @@ Navigate to HumanEval
 cd benchmarks/HumanEval
 ```
 
-Create a file called ENV.json with the following (required) contents (If you're using MagenticOne)
+Update `config.yaml` to point to your model host, as appropriate. The default configuration points to 'gpt-4o'.
 
-```json
-{
-    "CHAT_COMPLETION_KWARGS_JSON": "{\"api_version\": \"2024-02-15-preview\", \"azure_endpoint\": \"YOUR_ENDPOINT/\", \"model_capabilities\": {\"function_calling\": true, \"json_output\": true, \"vision\": true}, \"azure_ad_token_provider\": \"DEFAULT\", \"model\": \"gpt-4o-2024-05-13\"}",
-    "CHAT_COMPLETION_PROVIDER": "azure"
-}
-```
-
-You can also use the openai client by replacing the last two entries in the ENV file by:
-
-- `CHAT_COMPLETION_PROVIDER='openai'`
-- `CHAT_COMPLETION_KWARGS_JSON` with the following JSON structure:
-
-```json
-{
-    "api_key": "REPLACE_WITH_YOUR_API",
-    "model": "gpt-4o-2024-05-13"
-}
-```
 
 Now initialize the tasks.
@@ -51,13 +33,13 @@ Once the script completes, you should now see a folder in your current directory
 
 Now to run a specific subset of HumanEval use:
 
 ```bash
-agbench run Tasks/human_eval_MagenticOne.jsonl
+agbench run Tasks/human_eval_AgentChat.jsonl
 ```
 
 You should see the command line print the raw logs that shows the agents in action
 
 To see a summary of the results (e.g., task completion rates), in a new terminal run the following:
 
 ```bash
-agbench tabulate Results/human_eval_MagenticOne
+agbench tabulate Results/human_eval_AgentChat
 ```
diff --git a/python/packages/agbench/benchmarks/HumanEval/Scripts/init_tasks.py b/python/packages/agbench/benchmarks/HumanEval/Scripts/init_tasks.py
index df4e6b194841..2dc7d4f0fb7b 100644
--- a/python/packages/agbench/benchmarks/HumanEval/Scripts/init_tasks.py
+++ b/python/packages/agbench/benchmarks/HumanEval/Scripts/init_tasks.py
@@ -93,9 +93,9 @@ def create_jsonl(name, tasks, template):
             "id": task["task_id"].replace("/", "_"),
             "template": template,
             "substitutions": {
-                "scenario.py": {"__ENTRY_POINT__": task["entry_point"]},
                 "prompt.txt": {"__PROMPT__": task["prompt"]},
-                "unit_tests.py": {"__TEST__": task["test"]},
+                "test.txt": {"__TEST__": task["test"]},
+                "custom_code_executor.py": {"__ENTRY_POINT__": task["entry_point"]},
             },
         }
diff --git a/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/custom_code_executor.py b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/custom_code_executor.py
new file mode 100644
index 000000000000..5d9893e057d0
--- /dev/null
+++ b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/custom_code_executor.py
@@ -0,0 +1,54 @@
+import re
+from typing import List, Sequence
+
+from autogen_core.code_executor import CodeBlock, CodeExecutor
+from autogen_agentchat.agents import CodeExecutorAgent
+
+
+class CustomCodeExecutorAgent(CodeExecutorAgent):
+
+    def __init__(
+        self,
+        name: str,
+        code_executor: CodeExecutor,
+        *,
+        description: str = "A computer terminal that performs no other action than running Python scripts (provided to it quoted in ```python code blocks), or sh shell scripts (provided to it quoted in ```sh code blocks).",
+        sources: Sequence[str] | None = None,
+    ) -> None:
+        super().__init__(name=name, description=description, code_executor=code_executor, sources=sources)
+        self._test_code = ""
+        with open("test.txt", "rt") as fh:
+            self._test_code = fh.read()
+
+
+    def _extract_markdown_code_blocks(self, markdown_text: str) -> List[CodeBlock]:
+        code_blocks = super()._extract_markdown_code_blocks(markdown_text)
+        new_blocks: List[CodeBlock] = []
+        for block in code_blocks:
+
+            # Handle deepseek-style reasoning output: strip a leading
+            # <think>...</think> block, keeping only the code that follows it.
+            code_content = block.code
+            #m = re.search(r"^\s*<think>\s*(.*?)\s*</think>\s*(.*?)\s*$", code_content, re.DOTALL)
+            #if m:
+            #    code_content = m.group(2)
+
+            # If python, wrap the extracted code in a unit testing harness
+            if block.language and block.language.lower() == "python":
+                code_content = self._test_code + """
+
+def run_tests(candidate):
+    try:
+        check(candidate)
+        # We can search for this string in the output
+        print("ALL TESTS PASSED !#!#")
+        print("TERMINATE")
+    except AssertionError:
+        print("SOME TESTS FAILED - TRY AGAIN !#!#")
+
+""" + code_content + """
+
+run_tests(__ENTRY_POINT__)
+"""
+            new_blocks.append(CodeBlock(code=code_content, language=block.language))
+
+        return new_blocks
diff --git a/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/prompt.txt b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/prompt.txt
new file mode 100644
index
000000000000..482f50dca311 --- /dev/null +++ b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/prompt.txt @@ -0,0 +1 @@ +__PROMPT__ diff --git a/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/requirements.txt b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/requirements.txt new file mode 100644 index 000000000000..5ba1405ce6e0 --- /dev/null +++ b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/requirements.txt @@ -0,0 +1,4 @@ +pyyaml +/autogen_python/packages/autogen-core +/autogen_python/packages/autogen-ext[openai] +/autogen_python/packages/autogen-agentchat diff --git a/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/scenario.py b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/scenario.py new file mode 100644 index 000000000000..96d8cd968bee --- /dev/null +++ b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/scenario.py @@ -0,0 +1,54 @@ +import asyncio +import os +import yaml +from autogen_ext.agents.magentic_one import MagenticOneCoderAgent +from autogen_agentchat.teams import RoundRobinGroupChat +from autogen_agentchat.ui import Console +from autogen_core.models import ModelFamily +from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor +from autogen_agentchat.conditions import TextMentionTermination +from custom_code_executor import CustomCodeExecutorAgent +from autogen_core.models import ChatCompletionClient + +async def main() -> None: + + # Load model configuration and create the model client. + with open("config.yaml", "r") as f: + config = yaml.safe_load(f) + model_client = ChatCompletionClient.load_component(config["model_config"]) + + # Coder + coder_agent = MagenticOneCoderAgent( + name="coder", + model_client=model_client, + ) + + # Executor + executor = CustomCodeExecutorAgent( + name="executor", + code_executor=LocalCommandLineCodeExecutor(), + sources=["coder"], + ) + + # Termination condition + termination = TextMentionTermination(text="TERMINATE", sources=["executor"]) + + # Define a team + agent_team = RoundRobinGroupChat([coder_agent, executor], max_turns=12, termination_condition=termination) + + prompt = "" + with open("prompt.txt", "rt") as fh: + prompt = fh.read() + + task = f"""Complete the following python function. Format your output as Markdown python code block containing the entire function definition: + +```python +{prompt} +``` +""" + + # Run the team and stream messages to the console. + stream = agent_team.run_stream(task=task) + await Console(stream) + +asyncio.run(main()) diff --git a/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/test.txt b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/test.txt new file mode 100644 index 000000000000..91318587b914 --- /dev/null +++ b/python/packages/agbench/benchmarks/HumanEval/Templates/AgentChat/test.txt @@ -0,0 +1 @@ +__TEST__ diff --git a/python/packages/agbench/benchmarks/HumanEval/config.yaml b/python/packages/agbench/benchmarks/HumanEval/config.yaml new file mode 100644 index 000000000000..9e2f22819d7a --- /dev/null +++ b/python/packages/agbench/benchmarks/HumanEval/config.yaml @@ -0,0 +1,29 @@ +# config.yaml +# +# The contents of this file will be copied into the 'config.yaml' file of +# every expanded Task, just prior to running the scenario. This provides a +# good place to store model or other configurations important for the scenario. 
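+#
+# Unlike the GAIA benchmark's config.yaml, a single 'model_config' entry is
+# defined here; scenario.py loads it with
+# ChatCompletionClient.load_component(config["model_config"]).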
+ +############################### +# Open AI model configuration # +############################### +model_config: + provider: autogen_ext.models.openai.OpenAIChatCompletionClient + config: + model: gpt-4o + + +############################## +# Ollama model configuration # +############################## +#model_config: +# provider: autogen_ext.models.openai.OpenAIChatCompletionClient +# config: +# model: deepseek-r1:7b +# base_url: http://localhost:11434/v1/ +# api_key: ollama +# model_info: +# function_calling: false +# json_output: false +# vision: false +# family: r1 diff --git a/python/packages/agbench/benchmarks/WebArena/ENV.sample b/python/packages/agbench/benchmarks/WebArena/ENV.sample deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/packages/agbench/benchmarks/WebArena/README.md b/python/packages/agbench/benchmarks/WebArena/README.md deleted file mode 100644 index 74e17f892dad..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# WebArena Benchmark - -This scenario implements the [WebArena](https://github.com/web-arena-x/webarena/tree/main) benchmark. The evaluation code has been modified from WebArena in [evaluation_harness](Templates/Common/evaluation_harness) we retain the License from WebArena and include it here [LICENSE](Templates/Common/evaluation_harness/LICENSE). - - -## References - -Zhou, Shuyan, Frank F. Xu, Hao Zhu, Xuhui Zhou, Robert Lo, Abishek Sridhar, Xianyi Cheng et al. "Webarena: A realistic web environment for building autonomous agents." arXiv preprint arXiv:2307.13854 (2023). \ No newline at end of file diff --git a/python/packages/agbench/benchmarks/WebArena/Scripts/custom_tabulate.py b/python/packages/agbench/benchmarks/WebArena/Scripts/custom_tabulate.py deleted file mode 100644 index 6697a08749a5..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Scripts/custom_tabulate.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -import sys -import re -from agbench.tabulate_cmd import default_tabulate - - -def scorer(instance_dir): - - # Read the console - console_log_file = os.path.join(instance_dir, "console_log.txt") - if not os.path.isfile(console_log_file): - return None - - console_log = "" - with open(console_log_file, "rt") as fh: - console_log = fh.read() - - final_score = None - m = re.search(r"FINAL SCORE:(.*?)\n", console_log, re.DOTALL) - if m: - final_score = m.group(1).strip() - - # Missing the final answer line - if final_score is None: - return None - else: - return float(final_score) > 0 - - -def main(args): - default_tabulate(args, scorer=scorer) - - -if __name__ == "__main__" and __package__ is None: - main(sys.argv) diff --git a/python/packages/agbench/benchmarks/WebArena/Scripts/init_tasks.py b/python/packages/agbench/benchmarks/WebArena/Scripts/init_tasks.py deleted file mode 100644 index 5ba3fd4d08f4..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Scripts/init_tasks.py +++ /dev/null @@ -1,122 +0,0 @@ -# -# Run this file to download the human_eval dataset, and create a corresponding testbed scenario: -# (default: ../scenarios/human_eval_two_agents_gpt4.jsonl and ./scenarios/human_eval_two_agents_gpt35.jsonl) -# - -import requests -import tarfile -import hashlib -import io -import json -import os -import re -import sys - -URL = "https://raw.githubusercontent.com/web-arena-x/webarena/main/config_files/test.raw.json" - -SCRIPT_PATH = os.path.realpath(__file__) -SCRIPT_NAME = os.path.basename(SCRIPT_PATH) -SCRIPT_DIR = 
os.path.dirname(SCRIPT_PATH) - -SCENARIO_DIR = os.path.realpath(os.path.join(SCRIPT_DIR, os.path.pardir)) -TEMPLATES_DIR = os.path.join(SCENARIO_DIR, "Templates") -TASKS_DIR = os.path.join(SCENARIO_DIR, "Tasks") -DOWNLOADS_DIR = os.path.join(SCENARIO_DIR, "Downloads") - - -def download(): - """Download the WebArena dataset (if not already downloaded). - Return a JSON list of problem instances.""" - - if not os.path.isdir(DOWNLOADS_DIR): - os.mkdir(DOWNLOADS_DIR) - - json_file = os.path.join(DOWNLOADS_DIR, "test.raw.json") - - if not os.path.isfile(json_file): - # Send a HTTP request to the URL - response = requests.get(URL, stream=True) - response.raise_for_status() - - # If the HTTP request returns a status code 200, proceed - with open(json_file, "wb") as fh: - for chunk in response.iter_content(chunk_size=512): - fh.write(chunk) - - # Load the problems - problems = None - with open(json_file, "rb") as fh: - problems = json.load(fh) - return problems - - -def create_jsonl(name, tasks, template): - """Creates a JSONL scenario file with a given name, dictionary of MATH problems, and template path.""" - - # Create a task directory if it doesn't exist - if not os.path.isdir(TASKS_DIR): - os.mkdir(TASKS_DIR) - - # Create the jsonl file - prompt_fields = ["task_id", "intent_template_id", "sites", "require_login", "start_url", "geolocation", "intent"] - with open(os.path.join(TASKS_DIR, name + ".jsonl"), "wt") as fh: - for task in tasks: - print(f"Converting: {name}, {task['task_id']}") - - task_prompt = {} - for field in prompt_fields: - task_prompt[field] = task[field] - - record = { - "id": str(task["task_id"]), - "template": [os.path.join(TEMPLATES_DIR, "Common"), template], - "substitutions": { - "task_prompt.json.txt": {"__TASK_PROMPT__": json.dumps(task_prompt, indent=4)}, - "full_task.json.txt": {"__FULL_TASK__": json.dumps(task, indent=4)}, - }, - } - - fh.write(json.dumps(record).strip() + "\n") - - -############################################################################### -def main(): - tasks = download() - - # list all directories in the Templates directory - # and populate a dictionary with the name and path - templates = {} - for entry in os.scandir(TEMPLATES_DIR): - if entry.is_dir(): - if entry.name == "Common": # Skip the common template, which will be included in all - continue - templates[re.sub(r"\s", "", entry.name)] = entry.path - - # Divide the tasks by their websites and if they are validation or test - page_groups = dict() - for task in tasks: - - # We don't know how the intent ids are distributed, so hash them to get a uniform distribution - template_hash = hashlib.md5(str(task["intent_template_id"]).encode("utf-8")).hexdigest() - - # The full hash will consist of 32 hexadecimal digits. 
We can get a 50/50 split by checking if the first digit is in the range (0-7) vs (8-F) - task_set = "validation" if template_hash[0] in "01234567" else "test" - - key = task["sites"][0] - if len(task["sites"]) > 1: - key = "several_sites" - key = task_set + "_" + key - - # key = "__".join(sorted([s for s in task["sites"]])) - if key not in page_groups: - page_groups[key] = list() - page_groups[key].append(task) - - # Create the json files - for t in templates.items(): - for pg in page_groups: - create_jsonl(f"webarena__{pg}_{t[0]}", page_groups[pg], t[1]) - - -if __name__ == "__main__" and __package__ is None: - main() diff --git a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/ATTRIBUTION b/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/ATTRIBUTION deleted file mode 100644 index 0713904fb45b..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/ATTRIBUTION +++ /dev/null @@ -1,28 +0,0 @@ -The contents of this `evaluation_harness` folder are adapted from: - - https://github.com/web-arena-x/webarena - -under the following license: - -========================================================================================================= - -Copyright (c) 2024 Jing Yu Koh, Robert Lo, Lawrence Jang, Vikram Duvvur, Ming Chong Lim, and Po-Yu Huang - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/LICENSE b/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/LICENSE deleted file mode 100644 index f49a4e16e68b..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
\ No newline at end of file diff --git a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/__init__.py b/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/__init__.py deleted file mode 100644 index e942c1066769..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .evaluators import * -from .helper_functions import ( - shopping_get_latest_order_url, - shopping_get_sku_latest_review_author, - shopping_get_sku_latest_review_rating, -) diff --git a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/env_config.py b/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/env_config.py deleted file mode 100644 index ed84ae4735ef..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/env_config.py +++ /dev/null @@ -1,90 +0,0 @@ -# websites domain -import os - -REDDIT = os.environ.get("REDDIT", "") -SHOPPING = os.environ.get("SHOPPING", "") -SHOPPING_ADMIN = os.environ.get("SHOPPING_ADMIN", "") -GITLAB = os.environ.get("GITLAB", "") -WIKIPEDIA = os.environ.get("WIKIPEDIA", "") -MAP = os.environ.get("MAP", "") -HOMEPAGE = os.environ.get("HOMEPAGE", "") - -REDDIT_USERNAME = os.environ.get("REDDIT_USERNAME", "") -REDDIT_PASSWORD = os.environ.get("REDDIT_PASSWORD", "") - -GITLAB_USERNAME = os.environ.get("GITLAB_USERNAME", "") -GITLAB_PASSWORD = os.environ.get("GITLAB_PASSWORD", "") - -SHOPPING_USERNAME = os.environ.get("SHOPPING_USERNAME", "") -SHOPPING_PASSWORD = os.environ.get("SHOPPING_PASSWORD", "") - -SHOPPING_ADMIN_USERNAME = os.environ.get("SHOPPING_ADMIN_USERNAME", "") -SHOPPING_ADMIN_PASSWORD = os.environ.get("SHOPPING_ADMIN_PASSWORD", "") - -assert REDDIT and SHOPPING and SHOPPING_ADMIN and GITLAB and WIKIPEDIA and MAP and HOMEPAGE, ( - "Please setup the URLs to each site. Current: \n" - + f"Reddit: {REDDIT}\n" - + f"Shopping: {SHOPPING}\n" - + f"Shopping Admin: {SHOPPING_ADMIN}\n" - + f"Gitlab: {GITLAB}\n" - + f"Wikipedia: {WIKIPEDIA}\n" - + f"Map: {MAP}\n" - + f"Homepage: {HOMEPAGE}\n" -) - -ACCOUNTS = { - "reddit": {"username": REDDIT_USERNAME, "password": REDDIT_PASSWORD}, - "gitlab": {"username": GITLAB_USERNAME, "password": GITLAB_PASSWORD}, - "shopping": {"username": SHOPPING_USERNAME, "password": SHOPPING_PASSWORD}, - "shopping_admin": {"username": SHOPPING_ADMIN_USERNAME, "password": SHOPPING_ADMIN_PASSWORD}, - "shopping_site_admin": {"username": SHOPPING_ADMIN_USERNAME, "password": SHOPPING_ADMIN_PASSWORD}, -} - -URL_MAPPINGS = { - REDDIT: "http://reddit.com", - SHOPPING: "http://onestopmarket.com", - SHOPPING_ADMIN: "http://luma.com/admin", - GITLAB: "http://gitlab.com", - WIKIPEDIA: "http://wikipedia.org", - MAP: "http://openstreetmap.org", - HOMEPAGE: "http://homepage.com", -} - -# ADDED BY MSR Frontiers -######################### -SITE_URLS = { - "reddit": REDDIT, - "gitlab": GITLAB, - "shopping": SHOPPING, - "shopping_admin": SHOPPING_ADMIN, - "shopping_site_admin": SHOPPING_ADMIN, - "map": MAP, - "wikipedia": WIKIPEDIA, -} - -LOGIN_PROMPTS = { - "reddit": f"Type '{REDDIT}' into the address bar to navigate to the site. Click 'Log in', type the username '{ACCOUNTS['reddit']['username']}', and password is '{ACCOUNTS['reddit']['password']}'. Finally click the login button.", - "gitlab": f"Type '{GITLAB}' into the address bar to navigate to the site. 
At the log in prompt, type the username '{ACCOUNTS['gitlab']['username']}', and the password '{ACCOUNTS['gitlab']['password']}'. Finally click the 'Sign in' button.", - "shopping": f"Type '{SHOPPING}' into the address bar to navigate to the site. Click 'Sign In' at the top of the page. Enter the Email '{ACCOUNTS['shopping']['username']}', and password '{ACCOUNTS['shopping']['password']}'. Finally click the 'Sign In' button.", - "shopping_admin": f"Type '{SHOPPING_ADMIN}' into the address bar to navigate to the site. At the log in prompt, enter the username '{ACCOUNTS['shopping_admin']['username']}', and the password '{ACCOUNTS['shopping_admin']['password']}'. Finally click the 'Sign In' button.", -} - -SITE_DESCRIPTIONS = { - "reddit": "a Postmill forum populated with a large sample of data crawled from Reddit. Postmill is similar to Reddit, but the UI is distinct, and 'subreddits' begin with /f/ rather than /r/", - "gitlab": "a Gitlab site populated with various programming projects. Gitlab is similar to GitHub, though the UIs are slightly different", - "shopping": "an online store built with the Magento open source eCommerce platform", - "shopping_admin": "the content management admin portal for an online store running the Magento open source eCommerce software", -} - - -def url_to_sitename(url): - if url.startswith(REDDIT): - return "reddit" - elif url.startswith(GITLAB): - return "gitlab" - elif url.startswith(SHOPPING): - return "shopping" - elif url.startswith(SHOPPING_ADMIN): - return "shopping_admin" - else: - return None diff --git a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/evaluators.py b/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/evaluators.py deleted file mode 100644 index 05c9a3bc15a1..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/evaluators.py +++ /dev/null @@ -1,387 +0,0 @@ -"""From WebArena. 
base class for evaluation""" - -# answer string match -import collections -import html -import importlib -import json -import time -import urllib -import inspect -from pathlib import Path -from typing import Any, Tuple, Union, TypedDict, Dict - -from beartype import beartype -from nltk.tokenize import word_tokenize # type: ignore -from playwright.async_api import CDPSession, Page - -import numpy as np -import numpy.typing as npt - -from .helper_functions import ( - PseudoPage, - gitlab_get_project_memeber_role, - llm_fuzzy_match, - llm_ua_match, - reddit_get_post_url, - shopping_get_latest_order_url, - shopping_get_sku_latest_review_author, - shopping_get_sku_latest_review_rating, -) - - -# Subset used for evaluation (added by: adamfo) -##################################################################### -class Action(TypedDict): - answer: str - - -Observation = str | npt.NDArray[np.uint8] - - -class StateInfo(TypedDict): - observation: dict[str, Observation] - info: Dict[str, Any] - - -Trajectory = list[Union[Action, StateInfo]] - - -def make_answer_trajecotry(answer: str) -> Trajectory: - ans = Action() - ans["answer"] = answer - return [ans] - - -##################################################################### -class Evaluator(object): - def __init__(self, eval_tag: str = "") -> None: - self.eval_tag = eval_tag - - @beartype - async def __call__( - self, - trajectory: Trajectory, - config_file: Path | str, - page: Page | PseudoPage, - client: CDPSession, - azure_config: dict[str, Any] | None = None, - ) -> float: - raise NotImplementedError - - @staticmethod - def get_last_action(trajectory: Trajectory) -> Action: - try: - # is_bearable(trajectory[-1], Action) - last_action = trajectory[-1] - except Exception: - raise ValueError("The last element of trajectory should be an action, add a fake stop action if needed") - - return last_action # type: ignore[return-value] - - @staticmethod - def get_last_state(trajectory: Trajectory) -> StateInfo: - try: - # is_bearable(trajectory[-2], StateInfo) - last_state = trajectory[-2] - except Exception: - raise ValueError( - "The second last element of trajectory should be a state, add a fake stop action if needed" - ) - - return last_state # type: ignore[return-value] - - -class StringEvaluator(Evaluator): - """Check whether the answer is correct with: - exact match: the answer is exactly the same as the reference answer - must include: each phrase in the reference answer must be included in the answer - fuzzy match: the answer is similar to the reference answer, using LLM judge - """ - - @staticmethod - @beartype - def clean_answer(answer: str) -> str: - answer = answer.strip() - if answer.startswith("'") and answer.endswith("'"): - answer = answer[1:-1] - elif answer.startswith('"') and answer.endswith('"'): - answer = answer[1:-1] - return answer.lower() - - @staticmethod - @beartype - def exact_match(ref: str, pred: str) -> float: - return float(StringEvaluator.clean_answer(pred) == StringEvaluator.clean_answer(ref)) - - @staticmethod - @beartype - def must_include(ref: str, pred: str, tokenize: bool = False) -> float: - clean_ref = StringEvaluator.clean_answer(ref) - clean_pred = StringEvaluator.clean_answer(pred) - # tokenize the answer if the ref is a single word - # prevent false positive (e.g, 0) - if tokenize and len(clean_ref) == 1 and len(word_tokenize(clean_ref)) == 1: - tok_pred = word_tokenize(clean_pred) - return float(clean_ref in tok_pred) - else: - return float(clean_ref in clean_pred) - - @staticmethod - @beartype - def 
fuzzy_match(ref: str, pred: str, intent: str, azure_config: dict[str, Any] | None) -> float: - return llm_fuzzy_match(pred, ref, intent, azure_config) - - @staticmethod - @beartype - def ua_match(ref: str, pred: str, intent: str, azure_config: dict[str, Any] | None) -> float: - return llm_ua_match(pred, ref, intent, azure_config) - - async def __call__( - self, - trajectory: Trajectory, - config_file: Path | str, - page: Page | PseudoPage | None = None, - client: CDPSession | None = None, - azure_config: dict[str, Any] | None = None, - ) -> float: - with open(config_file, "r") as f: - configs = json.load(f) - - last_action = self.get_last_action(trajectory) - pred = self.clean_answer(last_action["answer"]) - - score = 1.0 - for approach, value in configs["eval"]["reference_answers"].items(): - match approach: - case "exact_match": - score *= self.exact_match(ref=value, pred=pred) - - case "must_include": - assert isinstance(value, list) - for must_value in value: - score *= self.must_include( - ref=must_value, - pred=pred, - tokenize=(len(value) == 1), - ) - case "fuzzy_match": - intent = configs["intent"] - if value == "N/A": - # if the instruction only asks the model to generate N/A when encountering an unachievable task - # without more concrete reasons - score *= self.exact_match(ref=value, pred=pred) - # if the instruction also asks the model to generate the reason why the task is unachievable - # this should be the default as it will prevent false positive N/A` - if score != 1: - score = 1.0 * self.ua_match( - intent=configs["intent"], - ref=configs["eval"]["string_note"], - pred=pred, - azure_config=azure_config, - ) - else: - assert isinstance(value, list) - for reference in value: - score *= self.fuzzy_match( - ref=reference, pred=pred, intent=intent, azure_config=azure_config - ) - return score - - -class URLEvaluator(Evaluator): - """Check URL matching""" - - @beartype - async def __call__( - self, - trajectory: Trajectory, - config_file: Path | str, - page: Page | PseudoPage, - client: CDPSession | None = None, - azure_config: dict[str, Any] | None = None, - ) -> float: - with open(config_file, "r") as f: - configs = json.load(f) - - def clean_url(url: str) -> str: - url = str(url) - url = url.rstrip("/") - return url - - def parse_url(url: str) -> tuple[str, dict[str, list[str]]]: - """Parse a URL into its base, path, and query components.""" - parsed_url = urllib.parse.urlparse(url) - base_path = parsed_url.netloc + parsed_url.path - query = urllib.parse.parse_qs(parsed_url.query) - return base_path, query - - def parse_urls( - urls: list[str], - ) -> tuple[list[str], dict[str, set[str]]]: - """Parse a list of URLs.""" - base_paths = [] - queries = collections.defaultdict(set) - for url in urls: - base_path, query = parse_url(url) - base_paths.append(base_path) - for k, v in query.items(): - queries[k].update(v) - return base_paths, queries - - pred = clean_url(page.url) - ref_urls = configs["eval"]["reference_url"].split(" |OR| ") - ref_urls = [clean_url(url) for url in ref_urls] - matching_rule = configs["eval"].get("url_note", "GOLD in PRED") - if matching_rule == "GOLD in PRED": - print(f"Pred: {pred}") - print(f"Ref: {ref_urls}") - ref_base_paths, ref_queries = parse_urls(ref_urls) - pred_base_paths, pred_query = parse_url(pred) - - base_score = float(any([ref_base_path in pred_base_paths for ref_base_path in ref_base_paths])) - query_score = 1.0 - for k, possible_values in ref_queries.items(): - query_score *= float( - any(possible_ref_value in pred_query.get(k, []) for 
possible_ref_value in possible_values) - ) - score = base_score * query_score - - else: - raise ValueError(f"Unknown matching rule: {matching_rule}") - - return score - - -class HTMLContentEvaluator(Evaluator): - """Check whether the contents appear in the page""" - - @beartype - async def __call__( - self, - trajectory: Trajectory, - config_file: Path | str, - page: Page | PseudoPage, - client: CDPSession | None = None, - azure_config: dict[str, Any] | None = None, - ) -> float: - with open(config_file, "r") as f: - configs = json.load(f) - - targets = configs["eval"]["program_html"] - - score = 1.0 - for target in targets: - target_url: str = target["url"] # which url to check - if target_url.startswith("func"): - func = target_url.split("func:")[1] - func = func.replace("__last_url__", page.url) - target_url = eval(func) - if inspect.isawaitable(target_url): - target_url = await target_url - - locator: str = target["locator"] # js element locator - - # navigate to that url - if target_url != "last": - await page.goto(target_url) - time.sleep(3) # TODO [shuyanzh]: fix this hard-coded sleep - - # empty, use the full page - if not locator.strip(): - selected_element = await page.content() - # use JS to select the element - elif locator.startswith("document.") or locator.startswith("[...document."): - if "prep_actions" in target: - try: - for prep_action in target["prep_actions"]: - await page.evaluate(f"() => {prep_action}") - except Exception: - pass - try: - selected_element = await page.evaluate(f"() => {locator}") - selected_element = str(selected_element) - if not selected_element: - selected_element = "" - except Exception: - # the page is wrong, return empty - selected_element = "" - # run program to call API - elif locator.startswith("func:"): # a helper function - func = locator.split("func:")[1] - func = func.replace("__page__", "page") - selected_element = eval(func) - if inspect.isawaitable(selected_element): - selected_element = await selected_element - else: - raise ValueError(f"Unknown locator: {locator}") - - selected_element = html.unescape(selected_element) - - if "exact_match" in target["required_contents"]: - required_contents = target["required_contents"]["exact_match"] - cur_score = StringEvaluator.exact_match(ref=required_contents, pred=selected_element) - score *= float(cur_score) - # print(f"[exact match] {cur_score}, selected element: {selected_element}, required contents: {required_contents}") - elif "must_include" in target["required_contents"]: - required_contents = target["required_contents"]["must_include"] - assert isinstance(required_contents, list) - for content in required_contents: - content_or = content.split(" |OR| ") - cur_score = any( - [ - StringEvaluator.must_include( - ref=content, - pred=selected_element, - tokenize=False, - ) - for content in content_or - ] - ) - score *= float(cur_score) - # print(f"[must include] {cur_score}, selected element: {selected_element}, required contents: {content_or}") - else: - raise ValueError(f"Unknown required_contents: {target['required_contents'].keys()}") - return score - - -class EvaluatorComb: - def __init__(self, evaluators: list[Evaluator]) -> None: - self.evaluators = evaluators - - @beartype - async def __call__( - self, - trajectory: Trajectory, - config_file: Path | str, - page: Page | PseudoPage, - client: CDPSession, - azure_config: dict[str, Any] | None = None, - ) -> float: - score = 1.0 - for evaluator in self.evaluators: - cur_score = await evaluator(trajectory, config_file, page, client, 
azure_config) - score *= cur_score - return score - - -@beartype -def evaluator_router(config_file: Path | str) -> EvaluatorComb: - """Router to get the evaluator class""" - with open(config_file, "r") as f: - configs = json.load(f) - - eval_types = configs["eval"]["eval_types"] - evaluators: list[Evaluator] = [] - for eval_type in eval_types: - match eval_type: - case "string_match": - evaluators.append(StringEvaluator()) - case "url_match": - evaluators.append(URLEvaluator()) - case "program_html": - evaluators.append(HTMLContentEvaluator()) - case _: - raise ValueError(f"eval_type {eval_type} is not supported") - - return EvaluatorComb(evaluators) diff --git a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/helper_functions.py b/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/helper_functions.py deleted file mode 100644 index eff8520b5ab4..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/helper_functions.py +++ /dev/null @@ -1,233 +0,0 @@ -"""From WebArena with minor modifications. Implements helper functions to assist evaluation cases where other evaluators are not suitable.""" - -import json -from typing import Any -from urllib.parse import urlparse - -import requests -from playwright.async_api import Page - -from .env_config import ( - ACCOUNTS, - GITLAB, - MAP, - REDDIT, - SHOPPING, - SHOPPING_ADMIN, - WIKIPEDIA, -) - -from .openai_utils import ( - generate_from_openai_chat_completion, -) - -import autogen - - -def shopping_get_auth_token() -> str: - response = requests.post( - url=f"{SHOPPING}/rest/default/V1/integration/admin/token", - headers={"content-type": "application/json"}, - data=json.dumps( - { - "username": ACCOUNTS["shopping_site_admin"]["username"], - "password": ACCOUNTS["shopping_site_admin"]["password"], - } - ), - ) - token: str = response.json() - return token - - -def shopping_get_latest_order_url() -> str: - """Get the latest order url from the shopping website.""" - - header = { - "Authorization": f"Bearer {shopping_get_auth_token()}", - "Content-Type": "application/json", - } - - params = { - "searchCriteria[sortOrders][0][field]": "created_at", - "searchCriteria[sortOrders][0][direction]": "DESC", - "searchCriteria[pageSize]": "1", - } - - response = requests.get(f"{SHOPPING}/rest/V1/orders", params=params, headers=header) - assert response.status_code == 200 - response_obj = response.json()["items"][0] - order_id = int(response_obj["increment_id"]) - order_url = f"{SHOPPING}/sales/order/view/order_id/{order_id}/" - return order_url - - -def shopping_get_sku_latest_review_author(sku: str) -> str: - """Get the latest review for shopping admin.""" - header = { - "Authorization": f"Bearer {shopping_get_auth_token()}", - "Content-Type": "application/json", - } - response = requests.get(f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header) - assert response.status_code == 200 - response_obj = response.json() - if len(response_obj) == 0: - return "" - author: str = response_obj[-1]["nickname"] - return author - - -def shopping_get_sku_latest_review_rating(sku: str) -> str: - """Get the latest review for shopping admin.""" - header = { - "Authorization": f"Bearer {shopping_get_auth_token()}", - "Content-Type": "application/json", - } - response = requests.get(f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header) - assert response.status_code == 200 - response_obj = response.json() - if len(response_obj) == 0: - return "" - assert 
response_obj[0]["ratings"][0]["rating_name"] == "Rating" - rating: str = str(response_obj[-1]["ratings"][0]["percent"]) - return rating - - -def reddit_get_post_url(url: str) -> str: - """Get the post url""" - # Url is http://domain/f/subreddit/post_id/... - # get domain, subreddit, post_id - domain = urlparse(url).netloc - tok_url = urlparse(url).path.split("/") - # not a valid post/comment url, return the url as is - if len(tok_url) < 4: - return url - if tok_url[1] != "f": - return url - subreddit = urlparse(url).path.split("/")[2] - post_id = urlparse(url).path.split("/")[3] - scheme = urlparse(url).scheme - post_url = f"{scheme}://{domain}/f/{subreddit}/{post_id}/" - return post_url - - -async def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str: - # get the account index - try: - account_idx = await page.evaluate( - f"""(() => {{ - const elements = document.querySelectorAll("td[data-label='Account'] span.gl-avatar-labeled-sublabel"); - let index = -1; // Default value if not found - - for(let i = 0; i < elements.length; i++) {{ - if(elements[i].outerText === '@{account_name}') {{ - index = i; - break; - }} - }} - - return index; - }})()""" - ) - - # get the role - role: str = await page.evaluate( - f"""(() => {{ - return document.querySelectorAll("td.col-max-role span")[{account_idx}].outerText; - }})()""" - ) - except Exception: - role = "" - - return role - - -def llm_fuzzy_match(pred: str, reference: str, question: str, azure_config: dict[str, Any] | None) -> float: - """Check whether the prediction matches the reference with GPT4-turbo""" - messages: list[dict[str, Any]] = [] - # construct the question to ask - message = "Help a teacher to grade the answer of a student given a question. Keep in mind that the student may use different phrasing or wording to answer the question. The goal is to evaluate whether the answer is semantically equivalent to the reference answer.\n" - message += f"question: {question}\n" - message += f"reference answer: {reference}\n" - message += "all the string 'N/A' that you see is a special sequence that means 'not achievable'\n" - message += f"student answer: {pred}\n" - message += "Conclude the judgement by correct/incorrect/partially correct." - messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": message}, - ] - - response = None - if azure_config is None: - response = generate_from_openai_chat_completion( - model="gpt-4-1106-preview", - messages=messages, - temperature=0, - max_tokens=768, - top_p=1.0, - context_length=0, - ).lower() - else: - client = autogen.OpenAIWrapper(**azure_config) - raw_response = client.create(context=None, messages=messages) - response = client.extract_text_or_completion_object(raw_response)[0].lower() - - if "partially correct" in response or "incorrect" in response: - return 0.0 - else: - assert "correct" in response - return 1.0 - - -def llm_ua_match(pred: str, reference: str, question: str, azure_config: dict[str, Any] | None) -> float: - """Check whether the prediction matches the reference with GPT-turbo""" - messages: list[dict[str, Any]] = [] - # construct the question to ask - message = "" - message += f"task: {question}\n" - message += f"actual unachievable reason: {reference}\n" - message += f"reported unachievable reason: {pred}\n" - message += ( - "The task described above is inherently unachievable due to the reason specified under 'actual unachievable reason'. 
" - "An individual previously attempted this task and was unable to complete it. They provided a reason for their failure, " - "which is listed under 'reported unachievable reason'. Your role is to review both the actual and reported reasons. " - "Determine if the reported reason aligns with the actual reason, even if implicitly. " - "If the stated reason is in line with the actual reason, respond with 'same'. Otherwise, respond with 'different'." - ) - messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": message}, - ] - - response = None - if azure_config is None: - response = generate_from_openai_chat_completion( - model="gpt-4-1106-preview", - messages=messages, - temperature=0, - max_tokens=768, - top_p=1.0, - context_length=0, - ).lower() - else: - client = autogen.OpenAIWrapper(**azure_config) - raw_response = client.create(context=None, messages=messages) - response = client.extract_text_or_completion_object(raw_response)[0].lower() - - if "different" in response: - return 0.0 - else: - assert "same" in response - return 1.0 - - -class PseudoPage: - def __init__(self, original_page: Page, url: str): - self.url = url - self.original_page = original_page - - def __getattr__(self, attr: str) -> Any: - # Delegate attribute access to the original page object - if attr not in ["url"]: - return getattr(self.original_page, attr) - else: - return getattr(self, attr) diff --git a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/openai_utils.py b/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/openai_utils.py deleted file mode 100644 index 1381f392cdf2..000000000000 --- a/python/packages/agbench/benchmarks/WebArena/Templates/Common/evaluation_harness/openai_utils.py +++ /dev/null @@ -1,275 +0,0 @@ -"""Tools to generate from OpenAI prompts. 
-Adopted from https://github.com/zeno-ml/zeno-build/""" - -import asyncio -import logging -import os -import random -import time -from typing import Any - -import aiolimiter -import openai -from openai import AsyncOpenAI, OpenAI - -client = None -aclient = None -if "OPENAI_API_KEY" not in os.environ and "OAI_CONFIG_LIST" not in os.environ: - raise ValueError("Neither OPENAI_API_KEY nor OAI_CONFIG_LIST is defined in the environment.") - -if "OPENAI_API_KEY" in os.environ: - client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) - aclient = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"]) -from tqdm.asyncio import tqdm_asyncio - - -def retry_with_exponential_backoff( # type: ignore - func, - initial_delay: float = 1, - exponential_base: float = 2, - jitter: bool = True, - max_retries: int = 3, - errors: tuple[Any] = ( - openai.RateLimitError, - openai.BadRequestError, - openai.InternalServerError, - ), -): - """Retry a function with exponential backoff.""" - - def wrapper(*args, **kwargs): # type: ignore - # Initialize variables - num_retries = 0 - delay = initial_delay - - # Loop until a successful response or max_retries is hit or an exception is raised - while True: - try: - - return func(*args, **kwargs) - - # Retry on specified errors - except errors: - # Increment retries - num_retries += 1 - - # Check if max retries has been reached - if num_retries > max_retries: - raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") - - # Increment the delay - delay *= exponential_base * (1 + jitter * random.random()) - - # Sleep for the delay - time.sleep(delay) - - # Raise exceptions for any errors not specified - except Exception as e: - raise e - - return wrapper - - -async def _throttled_openai_completion_acreate( - engine: str, - prompt: str, - temperature: float, - max_tokens: int, - top_p: float, - limiter: aiolimiter.AsyncLimiter, -) -> dict[str, Any]: - async with limiter: - for _ in range(3): - try: - return await aclient.completions.create( - engine=engine, - prompt=prompt, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - ) - except openai.RateLimitError: - logging.warning("OpenAI API rate limit exceeded. Sleeping for 10 seconds.") - await asyncio.sleep(10) - except openai.APIError as e: - logging.warning(f"OpenAI API error: {e}") - break - return {"choices": [{"message": {"content": ""}}]} - - -async def agenerate_from_openai_completion( - prompts: list[str], - engine: str, - temperature: float, - max_tokens: int, - top_p: float, - context_length: int, - requests_per_minute: int = 300, -) -> list[str]: - """Generate from OpenAI Completion API. - - Args: - prompts: list of prompts - temperature: Temperature to use. - max_tokens: Maximum number of tokens to generate. - top_p: Top p to use. - context_length: Length of context to use. - requests_per_minute: Number of requests per minute to allow. - - Returns: - List of generated responses. 
- """ - if "OPENAI_API_KEY" not in os.environ: - raise ValueError("OPENAI_API_KEY environment variable must be set when using OpenAI API.") - - limiter = aiolimiter.AsyncLimiter(requests_per_minute) - async_responses = [ - _throttled_openai_completion_acreate( - engine=engine, - prompt=prompt, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - limiter=limiter, - ) - for prompt in prompts - ] - responses = await tqdm_asyncio.gather(*async_responses) - return [x["choices"][0]["text"] for x in responses] - - -@retry_with_exponential_backoff -def generate_from_openai_completion( - prompt: str, - engine: str, - temperature: float, - max_tokens: int, - top_p: float, - context_length: int, - stop_token: str | None = None, -) -> str: - if "OPENAI_API_KEY" not in os.environ: - raise ValueError("OPENAI_API_KEY environment variable must be set when using OpenAI API.") - - response = client.completions.create( - prompt=prompt, - engine=engine, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - stop=[stop_token], - ) - answer: str = response["choices"][0]["text"] - return answer - - -async def _throttled_openai_chat_completion_acreate( - model: str, - messages: list[dict[str, str]], - temperature: float, - max_tokens: int, - top_p: float, - limiter: aiolimiter.AsyncLimiter, -) -> dict[str, Any]: - async with limiter: - for _ in range(3): - try: - return await aclient.chat.completions.create( - model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - ) - except openai.RateLimitError: - logging.warning("OpenAI API rate limit exceeded. Sleeping for 10 seconds.") - await asyncio.sleep(10) - except asyncio.exceptions.TimeoutError: - logging.warning("OpenAI API timeout. Sleeping for 10 seconds.") - await asyncio.sleep(10) - except openai.APIError as e: - logging.warning(f"OpenAI API error: {e}") - break - return {"choices": [{"message": {"content": ""}}]} - - -async def agenerate_from_openai_chat_completion( - messages_list: list[list[dict[str, str]]], - engine: str, - temperature: float, - max_tokens: int, - top_p: float, - context_length: int, - requests_per_minute: int = 300, -) -> list[str]: - """Generate from OpenAI Chat Completion API. - - Args: - messages_list: list of message list - temperature: Temperature to use. - max_tokens: Maximum number of tokens to generate. - top_p: Top p to use. - context_length: Length of context to use. - requests_per_minute: Number of requests per minute to allow. - - Returns: - List of generated responses. 
- """ - if "OPENAI_API_KEY" not in os.environ: - raise ValueError("OPENAI_API_KEY environment variable must be set when using OpenAI API.") - - limiter = aiolimiter.AsyncLimiter(requests_per_minute) - async_responses = [ - _throttled_openai_chat_completion_acreate( - model=engine, - messages=message, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - limiter=limiter, - ) - for message in messages_list - ] - responses = await tqdm_asyncio.gather(*async_responses) - return [x["choices"][0]["message"]["content"] for x in responses] - - -@retry_with_exponential_backoff -def generate_from_openai_chat_completion( - messages: list[dict[str, str]], - model: str, - temperature: float, - max_tokens: int, - top_p: float, - context_length: int, - stop_token: str | None = None, -) -> str: - if "OPENAI_API_KEY" not in os.environ: - raise ValueError("OPENAI_API_KEY environment variable must be set when using OpenAI API.") - response = client.chat.completions.create( - model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - ) - answer: str = response.choices[0].message.content - return answer - - -@retry_with_exponential_backoff -# debug only -def fake_generate_from_openai_chat_completion( - messages: list[dict[str, str]], - model: str, - temperature: float, - max_tokens: int, - top_p: float, - context_length: int, - stop_token: str | None = None, -) -> str: - if "OPENAI_API_KEY" not in os.environ: - raise ValueError("OPENAI_API_KEY environment variable must be set when using OpenAI API.") - - answer = "Let's think step-by-step. This page shows a list of links and buttons. There is a search box with the label 'Search query'. I will click on the search box to type the query. So the action I will perform is \"click [60]\"." 
- return answer diff --git a/python/packages/agbench/src/agbench/res/Dockerfile b/python/packages/agbench/src/agbench/res/Dockerfile index a7da943f343d..033c92162658 100644 --- a/python/packages/agbench/src/agbench/res/Dockerfile +++ b/python/packages/agbench/src/agbench/res/Dockerfile @@ -11,23 +11,20 @@ RUN ln -snf /usr/share/zoneinfo/US/Pacific /etc/localtime && echo "US/Pacific" > # Upgrade pip RUN pip install --upgrade pip -# Pre-load autogen_core dependencies, but not autogen_core itself since we'll often want to install the latest from source -RUN pip install openai pillow aiohttp typing-extensions pydantic types-aiofiles grpcio protobuf +# Pre-load autogen to get the dependencies, but then uninstall them (leaving dependencies in place) +RUN pip install autogen-core autogen-agentchat autogen-ext pyyaml +RUN pip uninstall --yes autogen-core autogen-agentchat autogen-ext + +# Optional markitdown dependencies +RUN pip install markitdown SpeechRecognition pydub youtube_transcript_api==0.6.0 # Pre-load popular packages as per https://learnpython.com/blog/most-popular-python-packages/ RUN pip install numpy pandas matplotlib seaborn scikit-learn requests urllib3 nltk pytest -# Pre-load packages needed for mdconvert file utils -RUN pip install python-docx pdfminer.six python-pptx SpeechRecognition openpyxl pydub mammoth puremagic youtube_transcript_api==0.6.0 - # Pre-load Playwright RUN pip install playwright RUN playwright install --with-deps chromium -# Fix an incompatibility with numpy -RUN pip uninstall --yes numpy -RUN pip install "numpy<2.0" - # Webarena (evaluation code) -RUN pip install beartype aiolimiter -RUN /usr/bin/echo -e "import nltk\nnltk.download('punkt')" | python +#RUN pip install beartype aiolimiter +#RUN /usr/bin/echo -e "import nltk\nnltk.download('punkt')" | python diff --git a/python/packages/agbench/src/agbench/run_cmd.py b/python/packages/agbench/src/agbench/run_cmd.py index 9d94d79170fe..181088d44be8 100644 --- a/python/packages/agbench/src/agbench/run_cmd.py +++ b/python/packages/agbench/src/agbench/run_cmd.py @@ -5,6 +5,7 @@ import os import pathlib import random +import re import shutil import subprocess import sys @@ -14,6 +15,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast import docker +import yaml from azure.core.exceptions import ClientAuthenticationError from azure.identity import DefaultAzureCredential, get_bearer_token_provider from docker.errors import APIError, DockerException, ImageNotFound @@ -38,7 +40,9 @@ # Do not use this field to specify the name of an existing image (e.g., on Dockerhub) DEFAULT_DOCKER_IMAGE_TAG = "agbench" -DEFAULT_ENV_FILE = "ENV.json" +DEFAULT_ENV_FILE_JSON = "ENV.json" +DEFAULT_ENV_FILE_YAML = "ENV.yaml" +DEFAULT_CONFIG_YAML = "config.yaml" # Get a random number generator for subsampling subsample_rng = random.Random(425) @@ -55,10 +59,12 @@ def run_scenarios( scenario: str, n_repeats: int, is_native: bool, + config_file: Union[None, str], token_provider: Optional[Callable[[], str]], docker_image: Optional[str] = None, results_dir: str = "Results", subsample: Union[None, int, float] = None, + env_file: Union[None, str] = None, ) -> None: """ Run a set agbench scenarios a given number of times. 
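For orientation, the new constants above establish a simple lookup order: an explicit `--env` or `--config` argument wins; otherwise `ENV.yaml` (or the deprecated `ENV.json`) and `config.yaml` are picked up from the working directory. A minimal sketch of that resolution, using a hypothetical helper name (the patch inlines this logic in `get_scenario_env` and `expand_scenario` rather than defining such a function):

```python
import os

# Hypothetical helper illustrating the lookup order implemented inline
# by get_scenario_env() and expand_scenario() in this patch.
def resolve_env_file(explicit: str | None) -> str | None:
    if explicit is not None:
        return explicit            # -e/--env on the command line always wins
    if os.path.isfile("ENV.yaml"):
        return "ENV.yaml"          # preferred YAML format
    if os.path.isfile("ENV.json"):
        return "ENV.json"          # still read, but deprecated
    return None                    # fall back to the built-in default env
```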
@@ -151,10 +157,10 @@ def run_scenarios( print(f"Running scenario {results_repetition}") # Expand the scenario - expand_scenario(scenario_dir, instance, results_repetition) + expand_scenario(scenario_dir, instance, results_repetition, config_file) # Prepare the environment (keys/values that need to be added) - env = get_scenario_env(token_provider) + env = get_scenario_env(token_provider=token_provider, env_file=env_file) # Run the scenario if is_native: @@ -171,7 +177,9 @@ def run_scenarios( file_handle.close() -def expand_scenario(scenario_dir: str, scenario: ScenarioInstance, output_dir: str) -> None: +def expand_scenario( + scenario_dir: str, scenario: ScenarioInstance, output_dir: str, config_file: Union[str, None] +) -> None: """ Expand a scenario into a folder. Despite some awkwardness created by backwards compatibility and notational conveniences, expansion is conceptually simple. @@ -244,16 +252,26 @@ def expand_scenario(scenario_dir: str, scenario: ScenarioInstance, output_dir: s line = line.replace(k, v) fh.write(line) + # Copy the config + if config_file is None: + if os.path.isfile(DEFAULT_CONFIG_YAML): + config_file = DEFAULT_CONFIG_YAML + + if config_file is not None: + src_path = pathlib.Path(config_file).absolute() + dest_path = pathlib.Path(os.path.join(output_dir, "config.yaml")).absolute() + shutil.copyfile(src_path, dest_path) + else: + logging.warning(f"No {DEFAULT_CONFIG_YAML} file found.") -def get_scenario_env( - token_provider: Optional[Callable[[], str]] = None, env_file: str = DEFAULT_ENV_FILE -) -> Dict[str, str]: + +def get_scenario_env(token_provider: Optional[Callable[[], str]] = None, env_file: str | None = None) -> Dict[str, str]: """ Return a dictionary of environment variables needed to run a scenario. Args: config_list (list): An AutoGen OAI_CONFIG_LIST to be used when running scenarios. - env_file (str): The path to the env_file to read. (default: DEFAULT_ENV_FILE) + env_file (str): The path to the env_file to read. (if None, default to DEFAULT_ENV_FILE) Returns: A dictionary of keys and values that need to be added to the system environment. """ @@ -264,10 +282,6 @@ def get_scenario_env( if openai_api_key is not None and len(openai_api_key.strip()) > 0: env["OPENAI_API_KEY"] = openai_api_key - bing_api_key = os.environ.get("BING_API_KEY") - if bing_api_key is not None and len(bing_api_key.strip()) > 0: - env["BING_API_KEY"] = bing_api_key - ## Support Azure auth tokens azure_openai_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN") if not azure_openai_ad_token and token_provider: @@ -282,13 +296,91 @@ def get_scenario_env( env["AZURE_OPENAI_AD_TOKEN"] = azure_openai_ad_token # Update with any values from the ENV.json file - if os.path.isfile(env_file): + env_file_contents: Dict[str, Any] = {} + if env_file is None: + # Env file was not specified, so read the default, or warn if the default file is missing. + if os.path.isfile(DEFAULT_ENV_FILE_YAML): + with open(DEFAULT_ENV_FILE_YAML, "r") as fh: + env_file_contents = yaml.safe_load(fh) + elif os.path.isfile(DEFAULT_ENV_FILE_JSON): + with open(DEFAULT_ENV_FILE_JSON, "rt") as fh: + env_file_contents = json.loads(fh.read()) + logging.warning(f"JSON environment files are deprecated. Migrate to '{DEFAULT_ENV_FILE_YAML}'") + else: + logging.warning( + f"The environment file '{DEFAULT_ENV_FILE_YAML}' was not found. A default environment will be provided, containing the keys: {env.keys()}" + ) + else: + # Env file was specified. Throw an error if the file can't be read. 
with open(env_file, "rt") as fh: - env.update(json.loads(fh.read())) + if env_file.endswith(".json"): + logging.warning("JSON environment files are deprecated. Migrate to YAML") + env_file_contents = json.loads(fh.read()) + else: + env_file_contents = yaml.safe_load(fh) + + # Apply substitutions in-place + substitute_env_variables(env_file_contents) + + # Flatten any structures + for key, value in env_file_contents.items(): + if isinstance(value, dict) or isinstance(value, list): + env_file_contents[key] = json.dumps(value) + + # Warn about carrying env variables + if "OPENAI_API_KEY" in env and "OPENAI_API_KEY" not in env_file_contents: + logging.warning( + f"Implicit inclusion of OPENAI_API_KEY in the task environment is deprecated. Add it to {DEFAULT_ENV_FILE_YAML} instead. E.g.,\n" + + """ + +OPENAI_API_KEY: ${OPENAI_API_KEY} + +""" + ) + + # Apply the loaded variables + env.update(cast(Dict[str, str], env_file_contents)) return env +def substitute_env_variables(json_data: Any) -> None: + """ + Recursively replaces any instance of "${ENV_VARIABLE}" with os.environ("ENV_VARIABLE") in a structure returned from json.loads() + """ + + def replace_env_var(match: Any) -> str: + var_name = match.group(1) + return os.environ.get(var_name, "") + + pattern = re.compile(r"\$\{(\w+)\}") + + def replace_in_dict(d: Dict[str, Any]) -> None: + for key, value in d.items(): + if isinstance(value, str): + d[key] = pattern.sub(replace_env_var, value) + elif isinstance(value, dict): + replace_in_dict(cast(Dict[str, Any], value)) + elif isinstance(value, list): + # Note: with the task mypy complains of a redundant cast + # without the cast, pyright complains the type is unknown + replace_in_list(cast(List[Any], value)) # type: ignore + + def replace_in_list(lst: List[Any]) -> None: + for i, item in enumerate(lst): + if isinstance(item, str): + lst[i] = pattern.sub(replace_env_var, item) + elif isinstance(item, dict): + replace_in_dict(cast(Dict[str, Any], item)) + elif isinstance(item, list): + replace_in_list(cast(List[Any], item)) # type: ignore + + if isinstance(json_data, dict): + replace_in_dict(cast(Dict[str, Any], json_data)) + elif isinstance(json_data, list): + replace_in_list(cast(List[Any], json_data)) # type: ignore + + def run_scenario_natively(work_dir: str, env: Mapping[str, str], timeout: int = TASK_TIMEOUT) -> None: """ Run a scenario in the native environment. @@ -647,9 +739,11 @@ def run_scenarios_subset( scenarios: List[Dict[str, Any]], n_repeats: int, is_native: bool, + config_file: Union[None, str], docker_image: Optional[str] = None, results_dir: str = "Results", subsample: Union[None, int, float] = None, + env_file: Union[None, str] = None, ) -> None: """ Run a subset of agbench scenarios a given number of times. 
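To make the substitution behavior concrete, here is a small usage sketch of `substitute_env_variables` as defined above; the key names and values are hypothetical:

```python
import json
import os

from agbench.run_cmd import substitute_env_variables  # added in this patch

os.environ["MY_KEY"] = "sk-test-123"  # hypothetical host variable
env_file_contents = {
    "OPENAI_API_KEY": "${MY_KEY}",
    "model_config": {"endpoint": "https://example.com/${MY_KEY}"},
}
substitute_env_variables(env_file_contents)  # replaces ${MY_KEY} in place
# As in get_scenario_env(), nested structures are then flattened to JSON strings:
for key, value in env_file_contents.items():
    if isinstance(value, (dict, list)):
        env_file_contents[key] = json.dumps(value)
print(env_file_contents["OPENAI_API_KEY"])  # sk-test-123
```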
@@ -680,10 +774,10 @@ def run_scenarios_subset( print(f"Running scenario {results_repetition}") # Expand the scenario - expand_scenario(".", instance, results_repetition) # type: ignore + expand_scenario(".", instance, results_repetition, config_file) # type: ignore # Prepare the environment (keys/values that need to be added) - env = get_scenario_env() + env = get_scenario_env(env_file=env_file) # Run the scenario if is_native: @@ -715,9 +809,11 @@ def run_parallel(args: argparse.Namespace) -> None: scenario_subset, args.repeat, args.native, + args.config, args.docker_image, "Results", args.subsample, + args.env, ) for scenario_subset in scenarios ] @@ -742,7 +838,7 @@ def get_azure_token_provider() -> Optional[Callable[[], str]]: except ClientAuthenticationError: error_message = traceback.format_exc() print( - f"Azure token provider failed loading. Try using 'az login --use-device-code':\n{error_message}\n\nContinuing without Azure token provider..." + f"Azure token provider failed loading. Try using 'az login --use-device-code'\n\nError details:\n{error_message}\n\nContinuing without Azure token provider..." ) logging.disable(logging.NOTSET) return None @@ -776,7 +872,6 @@ def run_cli(args: Sequence[str]) -> None: help='Run on a subsample of the tasks in the JSONL file(s). If a decimal value is specified, then run on the given proportion of tasks in each file. For example "0.7" would run on 70%% of tasks, and "1.0" would run on 100%% of tasks. If an integer value is specified, then randomly select *that* number of tasks from each specified JSONL file. For example "7" would run tasks, while "1" would run only 1 task from each specified JSONL file. (default: 1.0; which is 100%%)', default=None, ) - parser.add_argument( "-p", "--parallel", @@ -784,7 +879,22 @@ def run_cli(args: Sequence[str]) -> None: help="The number of parallel processes to run (default: 1).", default=1, ) - + parser.add_argument( + "-e", + "--env", + type=str, + help="The environment file to load into Docker, or into the native task context (default: '" + + DEFAULT_ENV_FILE_YAML + + "').", + default=None, + ) + parser.add_argument( + "-c", + "--config", + type=str, + help="The config file to copy into the Task (default: '" + DEFAULT_CONFIG_YAML + "').", + default=None, + ) parser.add_argument( "-d", "--docker-image", @@ -802,6 +912,11 @@ def run_cli(args: Sequence[str]) -> None: parsed_args = parser.parse_args(args) + if parsed_args.config is not None: + # Make sure the config file is readable, so that we fail early + with open(parsed_args.config, "r"): + pass + # don't support parallel and subsample together if parsed_args.parallel > 1 and parsed_args.subsample is not None: sys.exit("The options --parallel and --subsample can not be used together currently. 
Exiting.") @@ -861,7 +976,9 @@ def run_cli(args: Sequence[str]) -> None: scenario=parsed_args.scenario, n_repeats=parsed_args.repeat, is_native=True if parsed_args.native else False, + config_file=parsed_args.config, token_provider=azure_token_provider, docker_image=parsed_args.docker_image, subsample=subsample, + env_file=parsed_args.env, ) diff --git a/python/packages/agbench/src/agbench/tabulate_cmd.py b/python/packages/agbench/src/agbench/tabulate_cmd.py index 4aac0e2d03d0..2b1eb37b933c 100644 --- a/python/packages/agbench/src/agbench/tabulate_cmd.py +++ b/python/packages/agbench/src/agbench/tabulate_cmd.py @@ -19,6 +19,10 @@ "ALL TESTS PASSED !#!#", ] +COMPLETED_STRINGS = [ + "SCENARIO.PY COMPLETE !#!#", +] + EXCLUDE_DIR_NAMES = ["__pycache__"] @@ -63,10 +67,19 @@ def default_scorer(instance_dir: str, success_strings: List[str] = SUCCESS_STRIN if os.path.isfile(console_log): with open(console_log, "rt") as fh: content = fh.read() + + # It succeeded for s in success_strings: if s in content: return True - return False + + # It completed without succeeding + for s in COMPLETED_STRINGS: + if s in content: + return False + + # Has not, or did not, complete + return None else: return None @@ -187,7 +200,9 @@ def _count_equals(value: Optional[bool], trial: int) -> int: failures = 0 for row in all_results: if isinstance(row[i + 1], tuple): - failures += row[i + 1][0] != 1 + failures += row[i + 1][0] not in [1, None] + else: + failures += row[i + 1] not in [1, None] footer_row.append(failures) footer.append(footer_row) From 901ab1276d76d1af987d080fd1dad694dd8c8db5 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 7 Feb 2025 10:32:14 -0800 Subject: [PATCH 02/36] feat: enhance AzureAIChatCompletionClient validation and add unit tests (#5417) Resolves #5414 --- .../models/azure/_azure_ai_client.py | 13 ++-- .../models/test_azure_ai_model_client.py | 60 +++++++++++++++++++ 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py index 4d8d5eb50630..317bf1db8bb1 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py @@ -246,6 +246,10 @@ def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfi raise ValueError("credential is required for AzureAIChatCompletionClient") if "model_info" not in config: raise ValueError("model_info is required for AzureAIChatCompletionClient") + if "family" not in config["model_info"]: + raise ValueError( + "family is required for model_info in AzureAIChatCompletionClient. See autogen_core.models.ModelFamily for options." 
+ ) if _is_github_model(config["endpoint"]) and "model" not in config: raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient") return cast(AzureAIChatCompletionClientConfig, config) @@ -512,7 +516,8 @@ def capabilities(self) -> ModelInfo: def __del__(self) -> None: # TODO: This is a hack to close the open client - try: - asyncio.get_running_loop().create_task(self._client.close()) - except RuntimeError: - asyncio.run(self._client.close()) + if hasattr(self, "_client"): + try: + asyncio.get_running_loop().create_task(self._client.close()) + except RuntimeError: + asyncio.run(self._client.close()) diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py index d2662a0a270b..d21c249b9571 100644 --- a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py @@ -7,6 +7,7 @@ from autogen_core import CancellationToken, FunctionCall, Image from autogen_core.models import CreateResult, ModelFamily, UserMessage from autogen_ext.models.azure import AzureAIChatCompletionClient +from autogen_ext.models.azure.config import GITHUB_MODELS_ENDPOINT from azure.ai.inference.aio import ( ChatCompletionsClient, ) @@ -104,6 +105,65 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient ) +@pytest.mark.asyncio +async def test_azure_ai_chat_completion_client_validation() -> None: + with pytest.raises(ValueError, match="endpoint is required"): + AzureAIChatCompletionClient( + model="model", + credential=AzureKeyCredential("api_key"), + model_info={ + "json_output": False, + "function_calling": False, + "vision": False, + "family": "unknown", + }, + ) + + with pytest.raises(ValueError, match="credential is required"): + AzureAIChatCompletionClient( + model="model", + endpoint="endpoint", + model_info={ + "json_output": False, + "function_calling": False, + "vision": False, + "family": "unknown", + }, + ) + + with pytest.raises(ValueError, match="model is required"): + AzureAIChatCompletionClient( + endpoint=GITHUB_MODELS_ENDPOINT, + credential=AzureKeyCredential("api_key"), + model_info={ + "json_output": False, + "function_calling": False, + "vision": False, + "family": "unknown", + }, + ) + + with pytest.raises(ValueError, match="model_info is required"): + AzureAIChatCompletionClient( + model="model", + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + ) + + with pytest.raises(ValueError, match="family is required"): + AzureAIChatCompletionClient( + model="model", + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_info={ + "json_output": False, + "function_calling": False, + "vision": False, + # Missing family. 
+ }, # type: ignore + ) + + @pytest.mark.asyncio async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None: assert azure_client From 07c5dc751485d49bf93d4c57e841dc6ea6a24757 Mon Sep 17 00:00:00 2001 From: so2liu Date: Sat, 8 Feb 2025 02:42:27 +0800 Subject: [PATCH 03/36] =?UTF-8?q?fix:=20streaming=20token=20mode=20cannot?= =?UTF-8?q?=20work=20in=20function=20calls=20and=20will=20infi=E2=80=A6=20?= =?UTF-8?q?(#5396)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix: Prevent empty messages accumulation in streaming mode Co-authored-by: Eric Zhu --- .../autogen-ext/src/autogen_ext/models/openai/_openai_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index e6f8da279d96..b1f23e1040dd 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -767,7 +767,7 @@ async def create_stream( stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason maybe_model = chunk.model # First try get content - if choice.delta.content is not None: + if choice.delta.content: content_deltas.append(choice.delta.content) if len(choice.delta.content) > 0: yield choice.delta.content From 5fcb3b8061c79c61cbc5ecf8f07c868ed5500de2 Mon Sep 17 00:00:00 2001 From: Wei Jen Lu Date: Fri, 7 Feb 2025 19:58:56 +0000 Subject: [PATCH 04/36] Fix typo in Swarm doc (#5435) Fix typo in Swarm doc --- .../docs/src/user-guide/agentchat-user-guide/swarm.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/swarm.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/swarm.ipynb index 41ba103b5d3c..df9a6cfe4a3c 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/swarm.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/swarm.ipynb @@ -46,7 +46,7 @@ "\n", "1. Each agent has the ability to generate {py:class}`~autogen_agentchat.messages.HandoffMessage`\n", " to signal which other agents it can hand off to. For {py:class}`~autogen_agentchat.agents.AssistantAgent`, this means setting the `handoffs` argument.\n", - "2. When the team starts on a task, the first speaker agents operate on the task and make locallized decision about whether to hand off and to whom.\n", + "2. When the team starts on a task, the first speaker agents operate on the task and make localized decision about whether to hand off and to whom.\n", "3. When an agent generates a {py:class}`~autogen_agentchat.messages.HandoffMessage`, the receiving agent takes over the task with the same message context.\n", "4. The process continues until a termination condition is met.\n", "\n", From f20ba9127d1b53f4f21991b9953c848c96405c61 Mon Sep 17 00:00:00 2001 From: afourney Date: Fri, 7 Feb 2025 12:08:28 -0800 Subject: [PATCH 05/36] M1 docker (#5437) Presently MagenticOne and the m1 CLI use the LocalCommandLineExecutor (presumably copied from the agbench code, which already runs in Docker). This pr defaults m1 to Docker, and adds a code_executor parameter to MagenticOne, which defaults to local for now to maintain backward compatibility -- but this behavior is immediately deprecated. 
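A minimal sketch of the new `code_executor` parameter in use, consistent with the CLI change below (the model name and task are placeholders, and an OPENAI_API_KEY is assumed to be set in the environment):

```python
import asyncio

from autogen_ext.code_executors.docker import DockerCommandLineCodeExecutor
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o")
    # Passing a code_executor explicitly avoids the deprecation warning and
    # runs any generated code inside Docker rather than on the host.
    async with DockerCommandLineCodeExecutor(work_dir=".") as code_executor:
        m1 = MagenticOne(client=client, code_executor=code_executor)
        result = await m1.run(task="Write and run a Python script that prints 'hello'.")
        print(result)


asyncio.run(main())
```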
--- .../src/autogen_ext/teams/magentic_one.py | 13 ++++++++++++- .../packages/magentic-one-cli/pyproject.toml | 2 +- .../src/magentic_one_cli/_m1.py | 18 +++++++++++++----- python/uv.lock | 4 ++-- 4 files changed, 28 insertions(+), 9 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py b/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py index 0b4bf8aea932..55f068d169d2 100644 --- a/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py +++ b/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py @@ -5,6 +5,7 @@ from autogen_agentchat.base import ChatAgent from autogen_agentchat.teams import MagenticOneGroupChat from autogen_core import CancellationToken +from autogen_core.code_executor import CodeExecutor from autogen_core.models import ChatCompletionClient from autogen_ext.agents.file_surfer import FileSurfer @@ -126,14 +127,24 @@ def __init__( client: ChatCompletionClient, hil_mode: bool = False, input_func: InputFuncType | None = None, + code_executor: CodeExecutor | None = None, ): self.client = client self._validate_client_capabilities(client) + if code_executor is None: + warnings.warn( + "Instantiating MagenticOne without a code_executor is deprecated. Provide a code_executor to clear this warning (e.g., code_executor=LocalCommandLineCodeExecutor() ).", + DeprecationWarning, + stacklevel=2, + ) + code_executor = LocalCommandLineCodeExecutor() + fs = FileSurfer("FileSurfer", model_client=client) ws = MultimodalWebSurfer("WebSurfer", model_client=client) coder = MagenticOneCoderAgent("Coder", model_client=client) - executor = CodeExecutorAgent("Executor", code_executor=LocalCommandLineCodeExecutor()) + executor = CodeExecutorAgent("ComputerTerminal", code_executor=code_executor) + agents: List[ChatAgent] = [fs, ws, coder, executor] if hil_mode: user_proxy = UserProxyAgent("User", input_func=input_func) diff --git a/python/packages/magentic-one-cli/pyproject.toml b/python/packages/magentic-one-cli/pyproject.toml index 82bb75ce95d0..880c57d9036d 100644 --- a/python/packages/magentic-one-cli/pyproject.toml +++ b/python/packages/magentic-one-cli/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ dependencies = [ "pyyaml>=5.1", "autogen-agentchat>=0.4.4,<0.5", - "autogen-ext[openai,magentic-one,rich]>=0.4.4,<0.5", + "autogen-ext[docker,openai,magentic-one,rich]>=0.4.4,<0.5", ] [project.scripts] diff --git a/python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py b/python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py index c3936c050fa2..f9191e35ee2a 100644 --- a/python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py +++ b/python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py @@ -9,6 +9,7 @@ from autogen_agentchat.ui import Console, UserInputManager from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from autogen_ext.code_executors.docker import DockerCommandLineCodeExecutor from autogen_ext.teams.magentic_one import MagenticOne from autogen_ext.ui import RichConsole @@ -113,12 +114,19 @@ def main() -> None: # Run the task async def run_task(task: str, hil_mode: bool, use_rich_console: bool) -> None: input_manager = UserInputManager(callback=cancellable_input) - m1 = MagenticOne(client=client, hil_mode=hil_mode, input_func=input_manager.get_wrapped_callback()) - if use_rich_console: - await RichConsole(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager) - else: - await Console(m1.run_stream(task=task), 
output_stats=False, user_input_manager=input_manager) + async with DockerCommandLineCodeExecutor(work_dir=os.getcwd()) as code_executor: + m1 = MagenticOne( + client=client, + hil_mode=hil_mode, + input_func=input_manager.get_wrapped_callback(), + code_executor=code_executor, + ) + + if use_rich_console: + await RichConsole(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager) + else: + await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager) task = args.task if isinstance(args.task, str) else args.task[0] asyncio.run(run_task(task, not args.no_hil, args.rich)) diff --git a/python/uv.lock b/python/uv.lock index 87c5320f0343..ceb0e23ada8b 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -3647,7 +3647,7 @@ version = "0.2.3" source = { editable = "packages/magentic-one-cli" } dependencies = [ { name = "autogen-agentchat" }, - { name = "autogen-ext", extra = ["magentic-one", "openai", "rich"] }, + { name = "autogen-ext", extra = ["docker", "magentic-one", "openai", "rich"] }, { name = "pyyaml" }, ] @@ -3659,7 +3659,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "autogen-agentchat", editable = "packages/autogen-agentchat" }, - { name = "autogen-ext", extras = ["openai", "magentic-one", "rich"], editable = "packages/autogen-ext" }, + { name = "autogen-ext", extras = ["docker", "openai", "magentic-one", "rich"], editable = "packages/autogen-ext" }, { name = "pyyaml", specifier = ">=5.1" }, ] From abdc0da4f1954c37e81e90fddcb1f2796a41356e Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 7 Feb 2025 12:17:56 -0800 Subject: [PATCH 06/36] Add sample chat application with FastAPI (#5433) Introduce a sample chat application using AgentChat and FastAPI, demonstrating single-agent and team chat functionalities, along with state persistence and conversation history management. 
Resolves #5423 --------- Co-authored-by: Victor Dibia Co-authored-by: Victor Dibia --- python/samples/agentchat_fastapi/.gitignore | 5 + python/samples/agentchat_fastapi/README.md | 70 ++++++ .../samples/agentchat_fastapi/app_agent.html | 195 ++++++++++++++++ python/samples/agentchat_fastapi/app_agent.py | 111 +++++++++ .../samples/agentchat_fastapi/app_team.html | 217 ++++++++++++++++++ python/samples/agentchat_fastapi/app_team.py | 166 ++++++++++++++ .../model_config_template.yaml | 26 +++ 7 files changed, 790 insertions(+) create mode 100644 python/samples/agentchat_fastapi/.gitignore create mode 100644 python/samples/agentchat_fastapi/README.md create mode 100644 python/samples/agentchat_fastapi/app_agent.html create mode 100644 python/samples/agentchat_fastapi/app_agent.py create mode 100644 python/samples/agentchat_fastapi/app_team.html create mode 100644 python/samples/agentchat_fastapi/app_team.py create mode 100644 python/samples/agentchat_fastapi/model_config_template.yaml diff --git a/python/samples/agentchat_fastapi/.gitignore b/python/samples/agentchat_fastapi/.gitignore new file mode 100644 index 000000000000..77066c6b1263 --- /dev/null +++ b/python/samples/agentchat_fastapi/.gitignore @@ -0,0 +1,5 @@ +model_config.yaml +agent_state.json +agent_history.json +team_state.json +team_history.json diff --git a/python/samples/agentchat_fastapi/README.md b/python/samples/agentchat_fastapi/README.md new file mode 100644 index 000000000000..b7ba4012ea85 --- /dev/null +++ b/python/samples/agentchat_fastapi/README.md @@ -0,0 +1,70 @@ +# AgentChat App with FastAPI + +This sample demonstrates how to create a simple chat application using +[AgentChat](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/index.html) +and [FastAPI](https://fastapi.tiangolo.com/). + +You will be using the following features of AgentChat: + +1. Agent: + - `AssistantAgent` + - `UserProxyAgent` with a custom websocket input function +2. Team: `RoundRobinGroupChat` +3. State persistence: `save_state` and `load_state` methods of both agent and team. + +## Setup + +Install the required packages with OpenAI support: + +```bash +pip install -U "autogen-ext[openai]" "fastapi" "uvicorn" "PyYAML" +``` + +To use models other than OpenAI, see the [Models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html) documentation. + +Create a new file named `model_config.yaml` in the same directory as this README file to configure your model settings. +See `model_config_template.yaml` for an example. + +## Chat with a single agent + +To start the FastAPI server for single-agent chat, run: + +```bash +python app_agent.py +``` + +Visit http://localhost:8001 in your browser to start chatting. + +## Chat with a team of agents + +To start the FastAPI server for team chat, run: + +```bash +python app_team.py +``` + +Visit http://localhost:8002 in your browser to start chatting. + +The team also includes a `UserProxyAgent` agent with a custom websocket input function +that allows the user to send messages to the team from the browser. + +The team follows a round-robin strategy so each agent will take turns to respond. +When it is the user's turn, the input box will be enabled. +Once the user sends a message, the input box will be disabled and the agents +will take turns to respond. + +## State persistence + +The agents and team use the `load_state` and `save_state` methods to load and save +their state from and to files on each turn. 
+For the agent, the state is saved to and loaded from `agent_state.json`. +For the team, the state is saved to and loaded from `team_state.json`. +You can inspect the state files to see the state of the agents and team +once you have chatted with them. + +When the server restarts, the agents and team will load their state from the state files +to maintain their state across restarts. + +Additionally, the apps uses separate JSON files, +`agent_history.json` and `team_history.json`, to store the conversation history +for display in the browser. diff --git a/python/samples/agentchat_fastapi/app_agent.html b/python/samples/agentchat_fastapi/app_agent.html new file mode 100644 index 000000000000..36a017d54da3 --- /dev/null +++ b/python/samples/agentchat_fastapi/app_agent.html @@ -0,0 +1,195 @@ + + + + + + + AutoGen FastAPI Sample: Agent + + + + +
+ [app_agent.html chat-interface markup not recoverable: extraction reduced this 195-line hunk to bare diff markers]
+ + + + + diff --git a/python/samples/agentchat_fastapi/app_agent.py b/python/samples/agentchat_fastapi/app_agent.py new file mode 100644 index 000000000000..9f114a651655 --- /dev/null +++ b/python/samples/agentchat_fastapi/app_agent.py @@ -0,0 +1,111 @@ +import json +import os +from typing import Any + +import aiofiles +import yaml +from autogen_agentchat.agents import AssistantAgent +from autogen_agentchat.messages import TextMessage +from autogen_core import CancellationToken +from autogen_core.models import ChatCompletionClient +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles + +app = FastAPI() + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allows all origins + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers +) + +# Serve static files +app.mount("/static", StaticFiles(directory="."), name="static") + +@app.get("/") +async def root(): + """Serve the chat interface HTML file.""" + return FileResponse("app_agent.html") + +model_config_path = "model_config.yaml" +state_path = "agent_state.json" +history_path = "agent_history.json" + + +async def get_agent() -> AssistantAgent: + """Get the assistant agent, load state from file.""" + # Get model client from config. + async with aiofiles.open(model_config_path, "r") as file: + model_config = yaml.safe_load(await file.read()) + model_client = ChatCompletionClient.load_component(model_config) + # Create the assistant agent. + agent = AssistantAgent( + name="assistant", + model_client=model_client, + system_message="You are a helpful assistant.", + ) + # Load state from file. + if not os.path.exists(state_path): + return agent # Return agent without loading state. + async with aiofiles.open(state_path, "r") as file: + state = json.loads(await file.read()) + await agent.load_state(state) + return agent + + +async def get_history() -> list[dict[str, Any]]: + """Get chat history from file.""" + if not os.path.exists(history_path): + return [] + async with aiofiles.open(history_path, "r") as file: + return json.loads(await file.read()) + + +@app.get("/history") +async def history() -> list[dict[str, Any]]: + try: + return await get_history() + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@app.post("/chat", response_model=TextMessage) +async def chat(request: TextMessage) -> TextMessage: + try: + # Get the agent and respond to the message. + agent = await get_agent() + response = await agent.on_messages(messages=[request], cancellation_token=CancellationToken()) + + # Save agent state to file. + state = await agent.save_state() + async with aiofiles.open(state_path, "w") as file: + await file.write(json.dumps(state)) + + # Save chat history to file. 
+ history = await get_history() + history.append(request.model_dump()) + history.append(response.chat_message.model_dump()) + async with aiofiles.open(history_path, "w") as file: + await file.write(json.dumps(history)) + + assert isinstance(response.chat_message, TextMessage) + return response.chat_message + except Exception as e: + error_message = { + "type": "error", + "content": f"Error: {str(e)}", + "source": "system" + } + raise HTTPException(status_code=500, detail=error_message) from e + + +# Example usage +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/python/samples/agentchat_fastapi/app_team.html b/python/samples/agentchat_fastapi/app_team.html new file mode 100644 index 000000000000..94a823518c16 --- /dev/null +++ b/python/samples/agentchat_fastapi/app_team.html @@ -0,0 +1,217 @@ + + + + + + + AutoGen FastAPI Sample: Team + + + + +
+ [app_team.html chat-interface markup not recoverable: extraction reduced this 217-line hunk to bare diff markers]
+ + + + + diff --git a/python/samples/agentchat_fastapi/app_team.py b/python/samples/agentchat_fastapi/app_team.py new file mode 100644 index 000000000000..123f9770f7ed --- /dev/null +++ b/python/samples/agentchat_fastapi/app_team.py @@ -0,0 +1,166 @@ +import json +import logging +import os +from typing import Any, Awaitable, Callable, Optional + +import aiofiles +import yaml +from autogen_agentchat.agents import AssistantAgent, UserProxyAgent +from autogen_agentchat.base import TaskResult +from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent +from autogen_agentchat.teams import RoundRobinGroupChat +from autogen_core import CancellationToken +from autogen_core.models import ChatCompletionClient +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles + +logger = logging.getLogger(__name__) + +app = FastAPI() + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allows all origins + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers +) + +model_config_path = "model_config.yaml" +state_path = "team_state.json" +history_path = "team_history.json" + +# Serve static files +app.mount("/static", StaticFiles(directory="."), name="static") + +@app.get("/") +async def root(): + """Serve the chat interface HTML file.""" + return FileResponse("app_team.html") + + +async def get_team( + user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]], +) -> RoundRobinGroupChat: + # Get model client from config. + async with aiofiles.open(model_config_path, "r") as file: + model_config = yaml.safe_load(await file.read()) + model_client = ChatCompletionClient.load_component(model_config) + # Create the team. + agent = AssistantAgent( + name="assistant", + model_client=model_client, + system_message="You are a helpful assistant.", + ) + yoda = AssistantAgent( + name="yoda", + model_client=model_client, + system_message="Repeat the same message in the tone of Yoda.", + ) + user_proxy = UserProxyAgent( + name="user", + input_func=user_input_func, # Use the user input function. + ) + team = RoundRobinGroupChat( + [agent, yoda, user_proxy], + ) + # Load state from file. + if not os.path.exists(state_path): + return team + async with aiofiles.open(state_path, "r") as file: + state = json.loads(await file.read()) + await team.load_state(state) + return team + + +async def get_history() -> list[dict[str, Any]]: + """Get chat history from file.""" + if not os.path.exists(history_path): + return [] + async with aiofiles.open(history_path, "r") as file: + return json.loads(await file.read()) + + +@app.get("/history") +async def history() -> list[dict[str, Any]]: + try: + return await get_history() + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@app.websocket("/ws/chat") +async def chat(websocket: WebSocket): + await websocket.accept() + + # User input function used by the team. + async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str: + data = await websocket.receive_json() + message = TextMessage.model_validate(data) + return message.content + + try: + while True: + # Get user message. + data = await websocket.receive_json() + request = TextMessage.model_validate(data) + + try: + # Get the team and respond to the message. 
+ team = await get_team(_user_input) + history = await get_history() + stream = team.run_stream(task=request) + async for message in stream: + if isinstance(message, TaskResult): + continue + await websocket.send_json(message.model_dump()) + if not isinstance(message, UserInputRequestedEvent): + # Don't save user input events to history. + history.append(message.model_dump()) + + # Save team state to file. + async with aiofiles.open(state_path, "w") as file: + state = await team.save_state() + await file.write(json.dumps(state)) + + # Save chat history to file. + async with aiofiles.open(history_path, "w") as file: + await file.write(json.dumps(history)) + + except Exception as e: + # Send error message to client + error_message = { + "type": "error", + "content": f"Error: {str(e)}", + "source": "system" + } + await websocket.send_json(error_message) + # Re-enable input after error + await websocket.send_json({ + "type": "UserInputRequestedEvent", + "content": "An error occurred. Please try again.", + "source": "system" + }) + + except WebSocketDisconnect: + logger.info("Client disconnected") + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + try: + await websocket.send_json({ + "type": "error", + "content": f"Unexpected error: {str(e)}", + "source": "system" + }) + except: + pass + + +# Example usage +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8002) diff --git a/python/samples/agentchat_fastapi/model_config_template.yaml b/python/samples/agentchat_fastapi/model_config_template.yaml new file mode 100644 index 000000000000..9768f5df0fe1 --- /dev/null +++ b/python/samples/agentchat_fastapi/model_config_template.yaml @@ -0,0 +1,26 @@ +# Use Open AI with key +provider: autogen_ext.models.openai.OpenAIChatCompletionClient +config: + model: gpt-4o + api_key: REPLACE_WITH_YOUR_API_KEY +# Use Azure Open AI with key +# provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient +# config: +# model: gpt-4o +# azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/ +# azure_deployment: {your-azure-deployment} +# api_version: {your-api-version} +# api_key: REPLACE_WITH_YOUR_API_KEY +# Use Azure OpenAI with AD token provider. 
+# provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient
+# config:
+#   model: gpt-4o
+#   azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/
+#   azure_deployment: {your-azure-deployment}
+#   api_version: {your-api-version}
+#   azure_ad_token_provider:
+#     provider: autogen_ext.auth.azure.AzureTokenProvider
+#     config:
+#       provider_kind: DefaultAzureCredential
+#       scopes:
+#         - https://cognitiveservices.azure.com/.default

From 707c3cf6559ddda7c7b153ec6ab2fd5e11d8483a Mon Sep 17 00:00:00 2001
From: abhijeethaval
Date: Sat, 8 Feb 2025 02:11:01 +0530
Subject: [PATCH 07/36] Update teams.ipynb: In the sample code the termination
 condition is set to the text "APPROVE" but the documentation mentions
 "TERMINATE" (#5426)

Co-authored-by: Eric Zhu
---
 .../user-guide/agentchat-user-guide/tutorial/teams.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb
index d12a273edbda..0ddaf05399e8 100644
--- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb
+++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb
@@ -112,7 +112,7 @@
    "source": [
     "The team runs the agents until the termination condition was met.\n",
     "In this case, the team ran agents following a round-robin order until the\n",
-    "termination condition was met when the word \"TERMINATE\" was detected in the\n",
+    "termination condition was met when the word \"APPROVE\" was detected in the\n",
     "agent's response.\n",
     "When the team stops, it returns a {py:class}`~autogen_agentchat.base.TaskResult` object with all the messages produced by the agents in the team."
    ]

From 0008c9cb449258c113d397c917b9ccb6effedbbb Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Fri, 7 Feb 2025 12:52:08 -0800
Subject: [PATCH 08/36] fix: do not count agent event in MaxMessageTermination
 condition (#5436)

Resolves #5425
---
 .../conditions/_terminations.py         | 16 +++++++++++-----
 .../tests/test_termination_condition.py | 14 +++++++++++++-
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/conditions/_terminations.py b/python/packages/autogen-agentchat/src/autogen_agentchat/conditions/_terminations.py
index 7c46545e8c16..7caf8e0d6e71 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/conditions/_terminations.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/conditions/_terminations.py
@@ -6,7 +6,7 @@
 from typing_extensions import Self
 
 from ..base import TerminatedException, TerminationCondition
-from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
+from ..messages import AgentEvent, BaseChatMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
 
 
 class StopMessageTerminationConfig(BaseModel):
@@ -48,6 +48,7 @@ def _from_config(cls, config: StopMessageTerminationConfig) -> Self:
 
 class MaxMessageTerminationConfig(BaseModel):
     max_messages: int
+    include_agent_event: bool = False
 
 
 class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminationConfig]):
@@ -55,14 +56,17 @@ class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminatio
 
     Args:
         max_messages: The maximum number of messages allowed in the conversation.
+ include_agent_event: If True, include :class:`~autogen_agentchat.messages.AgentEvent` in the message count. + Otherwise, only include :class:`~autogen_agentchat.messages.ChatMessage`. Defaults to False. """ component_config_schema = MaxMessageTerminationConfig component_provider_override = "autogen_agentchat.conditions.MaxMessageTermination" - def __init__(self, max_messages: int) -> None: + def __init__(self, max_messages: int, include_agent_event: bool = False) -> None: self._max_messages = max_messages self._message_count = 0 + self._include_agent_event = include_agent_event @property def terminated(self) -> bool: @@ -71,7 +75,7 @@ def terminated(self) -> bool: async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None: if self.terminated: raise TerminatedException("Termination condition has already been reached") - self._message_count += len(messages) + self._message_count += len([m for m in messages if self._include_agent_event or isinstance(m, BaseChatMessage)]) if self._message_count >= self._max_messages: return StopMessage( content=f"Maximum number of messages {self._max_messages} reached, current message count: {self._message_count}", @@ -83,11 +87,13 @@ async def reset(self) -> None: self._message_count = 0 def _to_config(self) -> MaxMessageTerminationConfig: - return MaxMessageTerminationConfig(max_messages=self._max_messages) + return MaxMessageTerminationConfig( + max_messages=self._max_messages, include_agent_event=self._include_agent_event + ) @classmethod def _from_config(cls, config: MaxMessageTerminationConfig) -> Self: - return cls(max_messages=config.max_messages) + return cls(max_messages=config.max_messages, include_agent_event=config.include_agent_event) class TextMentionTerminationConfig(BaseModel): diff --git a/python/packages/autogen-agentchat/tests/test_termination_condition.py b/python/packages/autogen-agentchat/tests/test_termination_condition.py index 68c782e0cb81..54171e69c87e 100644 --- a/python/packages/autogen-agentchat/tests/test_termination_condition.py +++ b/python/packages/autogen-agentchat/tests/test_termination_condition.py @@ -12,7 +12,7 @@ TimeoutTermination, TokenUsageTermination, ) -from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage +from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage, UserInputRequestedEvent from autogen_core.models import RequestUsage @@ -74,6 +74,18 @@ async def test_max_message_termination() -> None: is not None ) + termination = MaxMessageTermination(2, include_agent_event=True) + assert await termination([]) is None + await termination.reset() + assert await termination([TextMessage(content="Hello", source="user")]) is None + await termination.reset() + assert ( + await termination( + [TextMessage(content="Hello", source="user"), UserInputRequestedEvent(request_id="1", source="agent")] + ) + is not None + ) + @pytest.mark.asyncio async def test_mention_termination() -> None: From 9f1c4c924f672c565d15ce7dde33c8501d170a9a Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Fri, 7 Feb 2025 17:57:08 -0500 Subject: [PATCH 09/36] .NET AgentChat Part 1: Abstractions, Base Classes, RoundRobin (#5434) Stands up an initial implementation of the AgentChat project, including abstractions, base classes, and the `RoundRobinGroupChat` implementation. 
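To orient readers of the large diff below, here is a minimal, hypothetical usage sketch built only from the abstractions this patch introduces (`ChatAgentBase`, `ITeam`, `TaskResult`, `StopMessageTermination`, `RoundRobinGroupChat`). The generic parameters of `ChatAgentBase.HandleAsync` are inferred from context, and the `RoundRobinGroupChat` constructor shape (participants plus an optional `terminationCondition`, mirroring the Python API) is an assumption, not something confirmed by this patch:

```csharp
// Hypothetical smoke test for the new AgentChat abstractions.
// Assumptions (not confirmed by the diff): the generic parameters of
// ChatAgentBase.HandleAsync, and the RoundRobinGroupChat constructor shape.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AutoGen.AgentChat.Abstractions;
using Microsoft.AutoGen.AgentChat.Agents;
using Microsoft.AutoGen.AgentChat.GroupChat;
using Microsoft.AutoGen.AgentChat.Terminations;

internal static class Program
{
    private static async Task Main()
    {
        // Assumed constructor; the real signature lives in RoundRobinGroupChat.cs.
        ITeam team = new RoundRobinGroupChat(
            [new EchoAgent("echo")],
            terminationCondition: new StopMessageTermination());

        // ITaskRunner.RunAsync wraps the string in a TextMessage with Source = "user".
        TaskResult result = await team.RunAsync("Say hello.");

        foreach (AgentMessage message in result.Messages)
        {
            Console.WriteLine($"{message.Source}: {message.GetType().Name}");
        }
    }
}

// An agent that echoes the last text message it received, then asks to stop.
internal sealed class EchoAgent(string name) : ChatAgentBase(name, "Echoes the last message.")
{
    public override IEnumerable<Type> ProducedMessageTypes => [typeof(StopMessage)];

    public override ValueTask<Response> HandleAsync(
        IEnumerable<ChatMessage> messages, CancellationToken cancellationToken)
    {
        string lastText = messages.OfType<TextMessage>().LastOrDefault()?.Content ?? "(no input)";

        // Emitting a StopMessage lets StopMessageTermination end the run.
        return ValueTask.FromResult(new Response
        {
            Message = new StopMessage { Content = $"echo: {lastText}", Source = this.Name }
        });
    }

    public override ValueTask ResetAsync(CancellationToken cancellationToken) => ValueTask.CompletedTask;
}
```

The real end-to-end exercise of the API is in `AgentChatSmokeTest.cs`, added by this patch.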
--- dotnet/AutoGen.sln | 26 +- .../AgentChat/Abstractions/ChatAgent.cs | 193 ++++++ .../AgentChat/Abstractions/Handoff.cs | 50 ++ .../AgentChat/Abstractions/ITeam.cs | 17 + .../AgentChat/Abstractions/MessageHandling.cs | 38 ++ .../AgentChat/Abstractions/Messages.cs | 616 ++++++++++++++++++ .../AgentChat/Abstractions/Tasks.cs | 117 ++++ .../AgentChat/Abstractions/Termination.cs | 204 ++++++ .../AgentChat/Abstractions/Tools.cs | 186 ++++++ .../AgentChat/Agents/ChatAgentBase.cs | 50 ++ .../AgentChat/GroupChat/ChatAgentRouter.cs | 98 +++ .../AgentChat/GroupChat/Events.cs | 84 +++ .../AgentChat/GroupChat/GroupChatBase.cs | 235 +++++++ .../GroupChat/GroupChatHandlerRouter.cs | 48 ++ .../GroupChat/GroupChatManagerBase.cs | 187 ++++++ .../AgentChat/GroupChat/GroupChatOptions.cs | 38 ++ .../GroupChat/HostableAgentAdapter.cs | 26 + .../GroupChat/OutputCollectorAgent.cs | 133 ++++ .../GroupChat/RoundRobinGroupChat.cs | 57 ++ .../Microsoft.AutoGen.AgentChat.csproj | 23 + .../Terminations/StopMessageTermination.cs | 40 ++ .../Microsoft.AutoGen/Contracts/AgentId.cs | 2 +- .../Contracts/MessageContext.cs | 3 + .../AgentChatSmokeTest.cs | 106 +++ .../Microsoft.AutoGen.AgentChat.Tests.csproj | 15 + 25 files changed, 2584 insertions(+), 8 deletions(-) create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/ChatAgent.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Handoff.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/ITeam.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/MessageHandling.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Messages.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Tasks.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Termination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Tools.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Agents/ChatAgentBase.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/ChatAgentRouter.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/Events.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatBase.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatHandlerRouter.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatManagerBase.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatOptions.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/HostableAgentAdapter.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/OutputCollectorAgent.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/RoundRobinGroupChat.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Microsoft.AutoGen.AgentChat.csproj create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/StopMessageTermination.cs create mode 100644 dotnet/test/Microsoft.AutoGen.AgentChat.Tests/AgentChatSmokeTest.cs create mode 100644 dotnet/test/Microsoft.AutoGen.AgentChat.Tests/Microsoft.AutoGen.AgentChat.Tests.csproj diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index 74b7ac965592..a267a11e6fa9 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -122,6 +122,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GettingStartedGrpc", "sampl EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc.Tests", 
"test\Microsoft.AutoGen.Core.Grpc.Tests\Microsoft.AutoGen.Core.Grpc.Tests.csproj", "{23A028D3-5EB1-4FA0-9CD1-A1340B830579}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.AgentChat", "src\Microsoft.AutoGen\AgentChat\Microsoft.AutoGen.AgentChat.csproj", "{7F828599-56E8-4597-8F68-EE26FD631417}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.AgentChat.Tests", "test\Microsoft.AutoGen.AgentChat.Tests\Microsoft.AutoGen.AgentChat.Tests.csproj", "{217A4F86-8ADD-4998-90BA-880092A019F5}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -300,16 +304,14 @@ Global {70A8D4B5-D0A6-4098-A6F3-6ED274B65E7D}.Debug|Any CPU.Build.0 = Debug|Any CPU {70A8D4B5-D0A6-4098-A6F3-6ED274B65E7D}.Release|Any CPU.ActiveCfg = Release|Any CPU {70A8D4B5-D0A6-4098-A6F3-6ED274B65E7D}.Release|Any CPU.Build.0 = Release|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.CoreOnly|Any CPU.ActiveCfg = Debug|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.CoreOnly|Any CPU.Build.0 = Debug|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Debug|Any CPU.Build.0 = Debug|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Release|Any CPU.ActiveCfg = Release|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Release|Any CPU.Build.0 = Release|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Debug|Any CPU.Build.0 = Debug|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.ActiveCfg = Release|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.Build.0 = Release|Any CPU + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Release|Any CPU.Build.0 = Release|Any CPU {C3740DF1-18B1-4607-81E4-302F0308C848}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {C3740DF1-18B1-4607-81E4-302F0308C848}.Debug|Any CPU.Build.0 = Debug|Any CPU {C3740DF1-18B1-4607-81E4-302F0308C848}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -318,6 +320,14 @@ Global {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Debug|Any CPU.Build.0 = Debug|Any CPU {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.ActiveCfg = Release|Any CPU {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.Build.0 = Release|Any CPU + {7F828599-56E8-4597-8F68-EE26FD631417}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7F828599-56E8-4597-8F68-EE26FD631417}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7F828599-56E8-4597-8F68-EE26FD631417}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7F828599-56E8-4597-8F68-EE26FD631417}.Release|Any CPU.Build.0 = Release|Any CPU + {217A4F86-8ADD-4998-90BA-880092A019F5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {217A4F86-8ADD-4998-90BA-880092A019F5}.Debug|Any CPU.Build.0 = Debug|Any CPU + {217A4F86-8ADD-4998-90BA-880092A019F5}.Release|Any CPU.ActiveCfg = Release|Any CPU + {217A4F86-8ADD-4998-90BA-880092A019F5}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -368,11 +378,13 @@ Global {EAFFE339-26CB-4019-991D-BCCE8E7D33A1} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {58AD8E1D-83BD-4950-A324-1A20677D78D9} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} 
{70A8D4B5-D0A6-4098-A6F3-6ED274B65E7D} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6} - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {AAD593FE-A49B-425E-A9FE-A0022CD25E3D} = {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6} + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {C3740DF1-18B1-4607-81E4-302F0308C848} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6} {23A028D3-5EB1-4FA0-9CD1-A1340B830579} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {7F828599-56E8-4597-8F68-EE26FD631417} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {217A4F86-8ADD-4998-90BA-880092A019F5} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/ChatAgent.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/ChatAgent.cs new file mode 100644 index 000000000000..6cad4f88dfe2 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/ChatAgent.cs @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatAgent.cs + +using System.Text.RegularExpressions; + +namespace Microsoft.AutoGen.AgentChat.Abstractions; + +/// +/// A valid name for an agent. +/// +/// +/// To ensure parity with Python, we require agent names to be Python identifiers. +/// +public struct AgentName +{ + // + // TODO: Ensure that only valid C# identifiers can pass the validation on Python? + + /* + From https://docs.python.org/3/reference/lexical_analysis.html#identifiers: + ``` + identifier ::= xid_start xid_continue* + id_start ::= + id_continue ::= + xid_start ::= + xid_continue ::= + ``` + + Note: we are not going to deal with normalization; it would require a lot of effort for likely little gain + (this will mean that, strictly speaking, .NET will support a subset of the identifiers that Python does) + + The Unicode category codes mentioned above stand for: + + * Lu - uppercase letters + * Ll - lowercase letters + * Lt - titlecase letters + * Lm - modifier letters + * Lo - other letters + * Nl - letter numbers* + * Mn - nonspacing marks + * Mc - spacing combining marks* + * Nd - decimal numbers + * Pc - connector punctuations + + Of these, most are captured by "word characters" in .NET, \w, only needing \p{Nl} and \p{Mc} to be added. + While Copilot /thinks/ that \p{Pc} is needed, it is not, as it is part of \w in .NET. 
+
+    * Other_ID_Start - explicit list of characters in PropList.txt to support backwards compatibility
+    * Other_ID_Continue - likewise
+
+    # ================================================
+
+    1885..1886    ; Other_ID_Start # Mn   [2] MONGOLIAN LETTER ALI GALI BALUDA..MONGOLIAN LETTER ALI GALI THREE BALUDA
+    2118          ; Other_ID_Start # Sm       SCRIPT CAPITAL P
+    212E          ; Other_ID_Start # So       ESTIMATED SYMBOL
+    309B..309C    ; Other_ID_Start # Sk   [2] KATAKANA-HIRAGANA VOICED SOUND MARK..KATAKANA-HIRAGANA SEMI-VOICED SOUND MARK
+
+    # Total code points: 6
+
+    The pattern for this in .NET is [\u1885-\u1886\u2118\u212E\u309B-\u309C]
+
+    # ================================================
+
+    00B7          ; Other_ID_Continue # Po       MIDDLE DOT
+    0387          ; Other_ID_Continue # Po       GREEK ANO TELEIA
+    1369..1371    ; Other_ID_Continue # No   [9] ETHIOPIC DIGIT ONE..ETHIOPIC DIGIT NINE
+    19DA          ; Other_ID_Continue # No       NEW TAI LUE THAM DIGIT ONE
+    200C..200D    ; Other_ID_Continue # Cf   [2] ZERO WIDTH NON-JOINER..ZERO WIDTH JOINER
+    30FB          ; Other_ID_Continue # Po       KATAKANA MIDDLE DOT
+    FF65          ; Other_ID_Continue # Po       HALFWIDTH KATAKANA MIDDLE DOT
+
+    # Total code points: 16
+
+    The pattern for this in .NET is [\u00B7\u0387\u1369-\u1371\u19DA\u200C\u200D\u30FB\uFF65]
+
+    # ================================================
+
+    Classes for "IdStart": {Lu, Ll, Lt, Lm, Lo, Nl, '_', Other_ID_Start}
+    pattern: [\p{Lu}\p{Ll}\p{Lt}\p{Lm}\p{Lo}\p{Nl}_\u1885-\u1886\u2118\u212E\u309B-\u309C]
+
+    Classes for "IdContinue": {\w, Nl, Mc, Other_ID_Start, Other_ID_Continue}
+    pattern: [\w\p{Nl}\p{Mc}_\u1885-\u1886\u2118\u212E\u309B-\u309C\u00B7\u0387\u1369-\u1371\u19DA\u200C\u200D\u30FB\uFF65]
+
+    Match group for identifiers:
+    (?(?:[\p{Lu}\p{Ll}\p{Lt}\p{Lm}\p{Lo}\p{Nl}_\u1885-\u1886\u2118\u212E\u309B-\u309C])(?:[\w\p{Nl}\p{Mc}_\u1885-\u1886\u2118\u212E\u309B-\u309C\u00B7\u0387\u1369-\u1371\u19DA\u200C\u200D\u30FB\uFF65])*)
+    */
+
+    private const string IdStartClass = @"[\p{Lu}\p{Ll}\p{Lt}\p{Lm}\p{Lo}\p{Nl}_\u1885-\u1886\u2118\u212E\u309B-\u309C]";
+    private const string IdContinueClass = @"[\w\p{Nl}\p{Mc}_\u1885-\u1886\u2118\u212E\u309B-\u309C\u00B7\u0387\u1369-\u1371\u19DA\u200C\u200D\u30FB\uFF65]";
+
+    private static readonly Regex AgentNameRegex = new Regex($"^{IdStartClass}{IdContinueClass}*$", RegexOptions.Compiled | RegexOptions.Singleline);
+
+    public string Value { get; }
+
+    public AgentName(string name)
+    {
+        AgentName.CheckValid(name);
+
+        this.Value = name;
+    }
+
+    public static bool IsValid(string name) => AgentNameRegex.IsMatch(name);
+
+    public static void CheckValid(string name)
+    {
+        if (!AgentName.IsValid(name))
+        {
+            throw new ArgumentException($"Agent name '{name}' is not a valid identifier.");
+        }
+    }
+
+    // Implicit cast to string
+    public static implicit operator string(AgentName agentName) => agentName.Value;
+}
+
+/// 
+/// A response from calling 's .
+/// 
+public class Response
+{
+    /// 
+    /// A chat message produced by the agent as a response.
+    /// 
+    public required ChatMessage Message { get; set; }
+
+    /// 
+    /// Inner messages produced by the agent.
+    /// 
+    public List? InnerMessages { get; set; }
+}
+
+/// 
+/// Base class for representing a stream of messages interspersing responses () and
+/// internal processing messages (). This functions as a discriminated union.
+/// 
+/// The response type. Usually .
+/// The internal message type. Usually .
+public class StreamingFrame() where TInternalMessage : AgentMessage
+{
+    public enum FrameType
+    {
+        InternalMessage,
+        Response
+    }
+
+    public FrameType Type { get; set; }
+
+    public TInternalMessage? InternalMessage { get; set; }
+    public TResponse? Response { get; set; }
+}
+
+/// 
+/// Base class for representing a stream of messages with internal messages of any subtype.
+/// 
+/// The response type. Usually .
+public class StreamingFrame : StreamingFrame;
+
+/// 
+/// The stream frame for 's 
+/// 
+public class ChatStreamFrame : StreamingFrame;
+
+/// 
+/// An agent that can participate in a chat.
+/// 
+public interface IChatAgent :
+    IHandleChat, Response>,
+    IHandleStream, ChatStreamFrame>
+{
+    /// 
+    /// The name of the agent. This is used by the team to uniquely identify the agent. It should be unique within the team.
+    /// 
+    AgentName Name { get; }
+
+    /// 
+    /// The description of the agent. This is used by the team to make decisions about which agents to use. The description
+    /// should describe the agent's capabilities and how to interact with it.
+    /// 
+    string Description { get; }
+
+    /// 
+    /// The types of messages that the agent produces.
+    /// 
+    IEnumerable ProducedMessageTypes { get; } // TODO: Is there a way to make this part of the type somehow?
+                                              // Annotations, or IProduce<>? Do we ever actually access this?
+
+    /// 
+    /// Reset the agent to its initialization state.
+    /// 
+    /// 
+    /// 
+    ValueTask ResetAsync(CancellationToken cancellationToken);
+}
diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Handoff.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Handoff.cs
new file mode 100644
index 000000000000..0153f111c101
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Handoff.cs
@@ -0,0 +1,50 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Handoff.cs
+
+namespace Microsoft.AutoGen.AgentChat.Abstractions;
+
+/// 
+/// Handoff configuration.
+/// 
+/// The name of the target agent receiving the handoff.
+/// The description of the handoff such as the condition under which it should happen and the target
+/// agent's ability. If not provided, it is generated from the target agent's name.
+/// The name of this handoff configuration. If not provided, it is generated from the target agent's name.
+/// The message to the target agent. If not provided, it is generated from the target agent's name.
+public class Handoff(string target, string? description = null, string? name = null, string? message = null)
+{
+    private static string? CheckName(string? name)
+    {
+        if (name != null && !AgentName.IsValid(name))
+        {
+            throw new ArgumentException($"Handoff name '{name}' is not a valid identifier.");
+        }
+
+        return name;
+    }
+
+    /// 
+    /// The name of the target agent receiving the handoff.
+    /// 
+    public AgentName Target { get; } = new AgentName(target);
+
+    /// 
+    /// The description of the handoff such as the condition under which it should happen and the target.
+    /// 
+    public string Description { get; } = description ?? $"Handoff to {target}";
+
+    /// 
+    /// The name of this handoff configuration.
+    /// 
+    public string Name { get; } = CheckName(name) ?? $"transfer_to_{target.ToLowerInvariant()}";
+
+    /// 
+    /// The content of the HandoffMessage that will be sent.
+    /// 
+    public string Message { get; } = message ?? $"Transferred to {target}, adopting the role of {target} immediately.";
+
+    /// 
+    /// Handoff Tool to execute the handoff.
+ /// + public ITool HandoffTool => new CallableTool(this.Name, this.Description, () => { return this.Message; }); +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/ITeam.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/ITeam.cs new file mode 100644 index 000000000000..e6c16cba430b --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/ITeam.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ITeam.cs + +namespace Microsoft.AutoGen.AgentChat.Abstractions; + +/// +/// A team of agents. +/// +public interface ITeam : ITaskRunner +{ + /// + /// Reset the team and all its participants to its initial state. + /// + /// + /// A representing the asynchronous operation. + ValueTask ResetAsync(CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/MessageHandling.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/MessageHandling.cs new file mode 100644 index 000000000000..4183723e8c81 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/MessageHandling.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// MessageHandling.cs + +namespace Microsoft.AutoGen.AgentChat.Abstractions; + +public interface IHandleChat +{ + ValueTask HandleAsync(TIn item) + { + return this.HandleAsync(item, CancellationToken.None); + } + + ValueTask HandleAsync(TIn item, CancellationToken cancellationToken); +} + +public interface IHandleChat // TODO: Map this to IHandle<> somehow? +{ + ValueTask HandleAsync(TIn item) + { + return this.HandleAsync(item, CancellationToken.None); + } + + ValueTask HandleAsync(TIn item, CancellationToken cancellationToken); +} + +public interface IHandleDefault : IHandleChat +{ +} + +public interface IHandleStream +{ + IAsyncEnumerable StreamAsync(TIn item) + { + return this.StreamAsync(item, CancellationToken.None); + } + + IAsyncEnumerable StreamAsync(TIn item, CancellationToken cancellationToken); +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Messages.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Messages.cs new file mode 100644 index 000000000000..83a3781e66ca --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Messages.cs @@ -0,0 +1,616 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Messages.cs + +using System.Collections; +using System.Diagnostics; +using System.Text; + +using Microsoft.Extensions.AI; + +namespace Microsoft.AutoGen.AgentChat.Abstractions; + +/// +/// The base class for all messages that can be sent between agents. +/// +/// +/// This functions as a combination of both BaseMessage and AgentMessage on the Python side. +/// +public abstract class AgentMessage +{ + /// + /// The name of the agent that sent this message. + /// + public required string Source { get; set; } + + // IMPORTANT NOTE: Unlike the ITypeMarshal implementation in ProtobufTypeMarshal, + // the .ToWire() call on this is intended to be used for directly converting a concrete message type to its leaf representation. + // In the context of Protobuf these may not be the same due to discriminated union types being real types, as opposed to + // a runtime union restriction. 
+ //public IMessage ToWire() + //{ + // return this switch + // { + // ChatMessage chatMessage => ProtobufTypeMarshal.Convert(chatMessage), + // AgentEvent agentEvent => ProtobufTypeMarshal.Convert(agentEvent), + // _ => throw new InvalidOperationException($"Unknown type {this.GetType().Name}"), + // }; + //} +} + +/// +/// Events emitted by agents and teams when they work, not used for agent-to-agent communication. +/// +public abstract class AgentEvent : AgentMessage +{ + public Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage() + => ToCompletionClientMessage(role: ChatRole.Assistant); + + /// + /// Converts the to a . + /// + /// + /// This should usually be + /// + /// The role of the agent that is sending the message. + /// + /// A that represents the . + /// + public abstract Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role); +} + +/// +/// Messages for agent-to-agent communication. +/// +public abstract class ChatMessage : AgentMessage +{ + /// + /// Converts the to a . + /// + /// The role of the agent that is sending the message. + /// + /// A that represents the . + /// + public abstract Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role); +} + +// Leaf Classes + +/// +/// A text message. +/// +public class TextMessage : ChatMessage +{ + /// + /// The content of the message. + /// + public required string Content { get; set; } + + /// /> + public override Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role) + { + return new Microsoft.Extensions.AI.ChatMessage(role, this.Content) { AuthorName = this.Source }; + } +} + +/// +/// The data inside of a multi-modal message. Can be either a string or an Image. +/// +/// +/// This presents an API surface around the types that are supported by AgentChat, rather +/// than allowing any . +/// +public struct MultiModalData +{ + /// + /// Supported Types of . + /// + public enum Type + { + String, Image + } + + /// + /// Checks the type of the and wraps it in a instance if + /// it is a supported type. + /// + /// The to wrap. + /// A instance wrapping the . + /// + /// Thrown if the is not a or . + /// + public static MultiModalData CheckTypeAndCreate(AIContent item) + { + if (item is TextContent text) + { + return new MultiModalData(text); + } + else if (item is ImageContent image) + { + return new MultiModalData(image); + } + else + { + throw new ArgumentException("Only TextContent and ImageContent are allowed in MultiModalMessage"); + } + } + + /// + /// Initializes a new instance of the with a . + /// + /// The text to wrap. + public MultiModalData(string text) + { + ContentType = Type.String; + AIContent = new TextContent(text); + } + + /// + /// Initializes a new instance of the with a . + /// + /// The to wrap. + public MultiModalData(TextContent textContent) + { + ContentType = Type.String; + AIContent = textContent; + } + + /// + /// Initializes a new instance of the with an . + /// + /// The image to wrap. + public MultiModalData(ImageContent image) + { + ContentType = Type.Image; + AIContent = image; + } + + /// + /// Gets the wrapped by this instance. + /// + public Type ContentType { get; } + + /// + /// Gets the wrapped by this instance. + /// + public AIContent AIContent { get; } +} + +/// +/// A multi-modal message. 
+/// +public class MultiModalMessage : ChatMessage, IList +{ + /// " + public AIContent this[int index] + { + get => this.Content[index].AIContent; + set => this.Content[index] = MultiModalData.CheckTypeAndCreate(value); + } + + /// + /// The contents of the message. + /// + public List Content { get; private set; } = new List(); + + /// + public int Count => this.Content.Count; + + /// + public bool IsReadOnly => false; + + /// + /// Adds a range of to the message. The type does not need + /// to be checked because it was already validated when the + /// was created. + /// + /// The items to add. + internal void AddRangeUnchecked(IEnumerable items) + { + this.Content.AddRange(items); + } + + /// + /// Checks and adds a range of to the message. + /// + /// The items to add. + public void AddRange(IEnumerable items) + { + foreach (AIContent item in items) + { + this.Content.Add(MultiModalData.CheckTypeAndCreate(item)); + } + } + + /// + /// Adds a range of to the message. + /// + /// The items to add. + public void AddRange(IEnumerable textItems) + { + foreach (TextContent item in textItems) + { + this.Add(item); + } + } + + /// + /// Adds a range of to the message. + /// + /// The items to add. + public void AddRange(IEnumerable textItems) + { + foreach (string item in textItems) + { + this.Add(item); + } + } + + /// + /// Adds a range of to the message. + /// + /// The items to add. + public void AddRange(IEnumerable images) + { + foreach (ImageContent image in images) + { + this.Add(image); + } + } + + /// + /// Checks and adds an to the message. + /// + /// The item to add. + public void Add(AIContent item) + { + this.Content.Add(MultiModalData.CheckTypeAndCreate(item)); + } + + /// + /// Adds a to the message. + /// + /// The text to add. + public void Add(string text) + { + this.Content.Add(new(text)); + } + + /// + /// Adds a to the message. + /// + /// The image to add. + public void Add(ImageContent image) + { + this.Content.Add(new(image)); + } + + /// + /// Adds a to the message. + /// + /// The to add. 
+ public void Add(TextContent text) + { + this.Content.Add(new(text)); + } + + /// + public void Clear() + { + this.Content.Clear(); + } + + /// + public bool Contains(AIContent item) + { + return this.Content.Any(x => x.AIContent == item); + } + + /// + public void CopyTo(AIContent[] array, int arrayIndex) + { + if (array == null) + { + throw new ArgumentNullException(nameof(array)); + } + + if (arrayIndex < 0 || arrayIndex >= array.Length) + { + throw new ArgumentOutOfRangeException(nameof(arrayIndex)); + } + + if (array.Length - arrayIndex < this.Content.Count) + { + throw new ArgumentException("The number of elements in the source is greater than the available space from arrayIndex to the end of the destination array."); + } + + for (var i = 0; i < this.Content.Count; i++) + { + array[arrayIndex + i] = this.Content[i].AIContent; + } + } + + /// + public IEnumerator GetEnumerator() + { + return this.Content.Select(x => x.AIContent).GetEnumerator(); + } + + /// + public int IndexOf(AIContent item) + { + return this.Content.FindIndex(x => x.AIContent == item); + } + + /// + public int IndexOf(string text) + { + return this.Content.FindIndex(x => x.ContentType == MultiModalData.Type.String && ((TextContent)x.AIContent).Text == text); + } + + /// /> + public void Insert(int index, AIContent item) + { + this.Content.Insert(index, MultiModalData.CheckTypeAndCreate(item)); + } + + /// + public void Insert(int index, string text) + { + this.Content.Insert(index, new(text)); + } + + /// + public void Insert(int index, TextContent text) + { + this.Content.Insert(index, new(text)); + } + + /// + public void Insert(int index, ImageContent image) + { + this.Content.Insert(index, new(image)); + } + + /// + public bool Remove(AIContent item) + { + int targetIndex = Content.FindIndex(x => x.AIContent == item); + if (targetIndex == -1) + { + return false; + } + + this.Content.RemoveAt(targetIndex); + return true; + } + + /// + public void RemoveAt(int index) + { + this.Content.RemoveAt(index); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + /// /> + public override Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role) + { + StringBuilder contentBuilder = new StringBuilder(); + foreach (MultiModalData item in this.Content) + { + if (item.ContentType == MultiModalData.Type.String) + { + contentBuilder.AppendLine(item.AIContent.RawRepresentation as string ?? ""); + } + else if (item.ContentType == MultiModalData.Type.Image) + { + contentBuilder.AppendLine("[Image]"); + } + } + + return new Microsoft.Extensions.AI.ChatMessage(role, contentBuilder.ToString()) { AuthorName = this.Source }; + } +} + +/// +/// A message requesting stop of a conversation. +/// +public class StopMessage : ChatMessage +{ + public required string Content { get; set; } + + public override Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role) + { + Debug.Assert(role == ChatRole.Assistant, "StopMessage can only come from agents in the Assistant Role"); + return new Microsoft.Extensions.AI.ChatMessage(ChatRole.Assistant, this.Content) { AuthorName = this.Source }; + } +} + +/// +/// A message requesting handoff of a conversation to another agent. +/// +public class HandoffMessage : ChatMessage +{ + /// + /// The name of the target agent to handoff to. + /// + public required string Target { get; set; } + + /// + /// The handoff message to the target agent. 
+    /// 
+    public required string Context { get; set; }
+
+    /// />
+    public override Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role)
+    {
+        Debug.Assert(role == ChatRole.Assistant, "HandoffMessage can only come from agents in the Assistant Role");
+        return new Microsoft.Extensions.AI.ChatMessage(ChatRole.Assistant, this.Context) { AuthorName = this.Source };
+    }
+}
+
+/// 
+/// A request to call a function.
+/// 
+public class FunctionCall
+{
+    // TODO: Should this be part of the Autogen "Core" (and what does that even mean on the .NET side?)
+    // It is unfortunate that we have to duplicate this type, but in order to be compatible with Python, it is necessary for
+    // us to be able to process incoming FunctionCalls with parameters in the form of a JSON string. This means that without
+    // knowing the target function, and unless the types are specified inline in the JSON, we cannot deserialize them in a
+    // generic manner (or we need to have a central registry of function calls, which is undesirable).
+    // The solution, for now, is to keep the representation as JSON and provide a helper that binds the JSON to a candidate
+    // schema.
+
+    /// 
+    /// An identifier representing this specific request. Responses will include this identifier.
+    /// 
+    public required string Id { get; set; }
+
+    /// 
+    /// The arguments to pass to the function in JSON format.
+    /// 
+    public string? Arguments { get; set; }
+
+    /// 
+    /// The name of the function to call.
+    /// 
+    public required string Name { get; set; }
+}
+
+/// 
+/// The result of a function call.
+/// 
+public class FunctionExecutionResult
+{
+    /// 
+    /// The identifier of the request that this result is for.
+    /// 
+    public required string Id { get; set; }
+
+    /// 
+    /// The result of calling the function.
+    /// 
+    public required string Content { get; set; }
+}
+
+/// 
+/// An event signaling a request to use tools.
+/// 
+public class ToolCallRequestEvent : AgentEvent
+{
+    /// 
+    /// The tool calls.
+    /// 
+    public List Content { get; private set; } = new List();
+
+    /// />
+    public override Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role)
+    {
+        Debug.Assert(role == ChatRole.Assistant, "ToolCallMessage can only come from agents in the Assistant Role");
+        return new Microsoft.Extensions.AI.ChatMessage(ChatRole.Assistant, (IList)this.Content) { AuthorName = this.Source };
+    }
+}
+
+/// 
+/// An event signaling the execution of tool calls.
+/// 
+public class ToolCallExecutionEvent : AgentEvent
+{
+    /// 
+    /// The tool call results.
+    /// 
+    public List Content { get; private set; } = new List();
+
+    /// />
+    public override Microsoft.Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role)
+    {
+        Debug.Assert(role == ChatRole.Tool, "ToolCallResultMessage can only come from agents in the Tool Role");
+        return new Microsoft.Extensions.AI.ChatMessage(ChatRole.Tool, (IList)this.Content) { AuthorName = this.Source };
+    }
+}
+
+/// 
+/// A message summarizing the results of tool calls.
+/// 
+public class ToolCallSummaryMessage : ChatMessage
+{
+    /// 
+    /// Summary of the tool call results.
+    /// 
+    public required string Content { get; set; }
+
+    public override Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role)
+    {
+        Debug.Assert(role == ChatRole.Assistant, "ToolCallSummaryMessage can only come from agents in the Assistant Role");
+        return new Microsoft.Extensions.AI.ChatMessage(ChatRole.Assistant, this.Content) { AuthorName = this.Source };
+    }
+}
+
+/// 
+/// An event signaling that the user proxy has requested user input. Published prior to invoking the
+/// input callback.
+/// 
+public class UserInputRequestedEvent : AgentEvent
+{
+    /// 
+    /// Identifier for the user input request.
+    /// 
+    public required string RequestId { get; set; }
+
+    /// />
+    public override Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role)
+    {
+        throw new Exception("UserInputRequestedEvent should not be sent to the completion client");
+    }
+}
+
+public static class CompletionChatMessageExtensions
+{
+    /// 
+    /// Flattens a into a single 
+    /// containing all of the content in the original message as a single string.
+    /// 
+    /// 
+    /// 
+    /// 
+    /// The to flatten.
+    /// 
+    /// 
+    /// A new that is a flattened version of the input.
+    /// 
+    public static Microsoft.Extensions.AI.ChatMessage Flatten(this Microsoft.Extensions.AI.ChatMessage msg)
+    {
+        if (msg.Contents.Count == 1 && msg.Contents[0] is TextContent)
+        {
+            return msg;
+        }
+
+        StringBuilder contentBuilder = new StringBuilder();
+        foreach (AIContent content in msg.Contents)
+        {
+            if (content is TextContent textContent)
+            {
+                contentBuilder.AppendLine(textContent.Text);
+            }
+            else if (content is ImageContent)
+            {
+                contentBuilder.AppendLine("[Image]");
+            }
+            else
+            {
+                contentBuilder.AppendLine($"[{content.GetType().Name}]");
+            }
+        }
+
+        return new Microsoft.Extensions.AI.ChatMessage(msg.Role, contentBuilder.ToString())
+        {
+            AuthorName = msg.AuthorName,
+            AdditionalProperties = msg.AdditionalProperties
+        };
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Tasks.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Tasks.cs
new file mode 100644
index 000000000000..2b4d66d93bdd
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Tasks.cs
@@ -0,0 +1,117 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Tasks.cs
+
+namespace Microsoft.AutoGen.AgentChat.Abstractions;
+
+/// 
+/// Result of running a task.
+/// 
+/// Messages produced by the task.
+public struct TaskResult(List messages)
+{
+    /// 
+    /// Messages produced by the task.
+    /// 
+    public List Messages { get; } = messages;
+
+    /// 
+    /// The reason the task stopped.
+    /// 
+    public string? StopReason = null;
+}
+
+/// 
+/// The stream frame for .
+/// 
+public class TaskFrame : StreamingFrame
+{
+    /// 
+    /// Create a new with a response.
+    /// 
+    /// Result of running a task.
+    public TaskFrame(TaskResult response)
+    {
+        this.Response = response;
+        this.Type = TaskFrame.FrameType.Response;
+    }
+
+    /// 
+    /// Create a new with an internal message.
+    /// 
+    /// The internal message.
+    public TaskFrame(AgentMessage message)
+    {
+        this.InternalMessage = message;
+        this.Type = TaskFrame.FrameType.InternalMessage;
+    }
+}
+
+/// 
+/// A task runner.
+/// 
+public interface ITaskRunner
+{
+    private static ChatMessage ToMessage(string task) => new TextMessage { Content = task, Source = "user" };
+
+    /// 
+    /// Run the task and return the result.
+    /// 
+    /// The task definition in text form.
+    /// 
+    /// The result of running the task.
+    async ValueTask RunAsync(string task, CancellationToken cancellationToken = default) =>
+        await this.RunAsync(ToMessage(task)!, cancellationToken);
+
+    /// 
+    /// Run the task and return the result.
+    /// 
+    /// 
+    /// The runner is stateful and a subsequent call to this method will continue from where the previous
+    /// call left off. If the task is not specified, the runner will continue with the current task.
+    /// 
+    /// The task definition as a message.
+    /// 
+    /// The result of running the task.
+    /// If no response is generated.
+    async ValueTask RunAsync(ChatMessage task, CancellationToken cancellationToken = default)
+    {
+        await foreach (TaskFrame frame in this.StreamAsync(task, cancellationToken))
+        {
+            if (frame.Type == TaskFrame.FrameType.Response)
+            {
+                return frame.Response!;
+            }
+        }
+
+        throw new InvalidOperationException("The stream should have returned the final result.");
+    }
+
+    /// 
+    /// Run the task and produce a stream of and the final 
+    /// is the last frame in the stream.
+    /// 
+    /// 
+    /// The runner is stateful and a subsequent call to this method will continue from where the previous
+    /// call left off. If the task is not specified, the runner will continue with the current task.
+    /// 
+    /// The task definition as a string.
+    /// 
+    /// A stream of containing internal messages and intermediate results followed by
+    /// the final 
+    IAsyncEnumerable StreamAsync(string task, CancellationToken cancellationToken = default) =>
+        this.StreamAsync(ToMessage(task), cancellationToken);
+
+    /// 
+    /// Run the task and produce a stream of and the final 
+    /// is the last frame in the stream.
+    /// 
+    /// 
+    /// The runner is stateful and a subsequent call to this method will continue from where the previous
+    /// call left off. If the task is not specified, the runner will continue with the current task.
+    /// 
+    /// The task definition as a message.
+    /// 
+    /// A stream of containing internal messages and intermediate results followed by
+    /// the final 
+    IAsyncEnumerable StreamAsync(ChatMessage? task, CancellationToken cancellationToken = default);
+}
diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Termination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Termination.cs
new file mode 100644
index 000000000000..f5cb00cfe979
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Termination.cs
@@ -0,0 +1,204 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Termination.cs
+
+namespace Microsoft.AutoGen.AgentChat.Abstractions;
+
+/// 
+/// A stateful condition that determines when a conversation should be terminated.
+///
+/// A termination condition takes a sequence of objects since the last time the
+/// condition was checked, and returns a if the conversation should be terminated,
+/// or null otherwise.
+///
+/// Once a termination condition has been reached, it must be before it can be used again.
+///
+/// Termination conditions can be combined using the and methods.
+/// 
+public interface ITerminationCondition
+{
+    /// 
+    /// Checks if the termination condition has been reached.
+    /// 
+    bool IsTerminated { get; }
+
+    /// 
+    /// Check if the conversation should be terminated based on the messages received
+    /// since the last time the condition was called.
+    /// Return a if the conversation should be terminated, or null otherwise.
+    /// 
+    /// The messages received since the last time the condition was called.
+    /// A if the conversation should be terminated, or null
+    /// otherwise.
+    /// If the termination condition has already been reached.
+    ValueTask CheckAndUpdateAsync(IList messages);
+
+    /// 
+    /// Resets the termination condition.
+    /// 
+    void Reset();
+
+    /// 
+    /// Combine this termination condition with another using a logical OR.
+    /// 
+    /// Another termination condition.
+    /// The combined termination condition, with appropriate short-circuiting.
+    ITerminationCondition Or(ITerminationCondition other)
+    {
+        return new CombinerCondition(CombinerCondition.Or, this, other);
+    }
+
+    /// 
+    /// Combine this termination condition with another using a logical AND.
+    /// 
+    /// Another termination condition.
+    /// The combined termination condition, with appropriate short-circuiting.
+    ITerminationCondition And(ITerminationCondition other)
+    {
+        return new CombinerCondition(CombinerCondition.And, this, other);
+    }
+}
+
+/// 
+/// Exception thrown when a termination condition has already been reached.
+/// 
+public sealed class TerminatedException : Exception
+{
+    public TerminatedException() : base("The termination condition has already been reached.")
+    {
+    }
+}
+
+/// 
+/// A termination condition that combines multiple termination conditions using a logical AND or OR.
+/// 
+internal sealed class CombinerCondition : ITerminationCondition
+{
+    public const bool Conjunction = true;
+    public const bool Disjunction = false;
+
+    public const bool And = Conjunction;
+    public const bool Or = Disjunction;
+
+    private List stopMessages = new List();
+    private List clauses;
+    private readonly bool conjunction;
+
+    /// 
+    /// Create a new with the given conjunction and clauses.
+    /// 
+    /// The conjunction to use when combining the clauses.
+    /// The termination conditions to combine.
+    public CombinerCondition(bool conjunction, params IEnumerable clauses)
+    {
+        // Flatten the list of clauses by unwrapping included CombinerConditions if their
+        // conjunctions match (since combiners with associative conjunctions can be hoisted).
+        IEnumerable flattened =
+            clauses.SelectMany(c =>
+                (c is CombinerCondition combiner && combiner.conjunction == conjunction)
+                ? (IEnumerable)combiner.clauses
+                : new[] { c });
+
+        this.conjunction = conjunction;
+
+        this.clauses = flattened.ToList();
+    }
+
+    /// 
+    public bool IsTerminated { get; private set; }
+
+    /// 
+    public void Reset()
+    {
+        this.stopMessages.Clear();
+        this.clauses.ForEach(c => c.Reset());
+
+        this.IsTerminated = false;
+    }
+
+    /// 
+    public async ValueTask CheckAndUpdateAsync(IList messages)
+    {
+        if (this.IsTerminated)
+        {
+            throw new TerminatedException();
+        }
+
+        // When operating as a conjunction, we may have accumulated terminating conditions, but we will not fire until
+        // all of them are complete. In this case, we need to avoid continuing to interact with already terminated
+        // clauses, because trying to update them will throw.
+        var candidateTerminations = this.conjunction ? this.clauses.Where(clause => !clause.IsTerminated) : clauses;
+
+        // TODO: Do we really need these to be ValueTasks? (Alternatively: Do we really need to run them explicitly
+        // on every invocation, or is a Worker pattern more appropriate?)
+        List> tasks = candidateTerminations.Select(c => c.CheckAndUpdateAsync(messages).AsTask()).ToList();
+        StopMessage?[] results = await Task.WhenAll(tasks);
+
+        bool raiseTermination = this.conjunction; // if or, we start with false until we observe a true
+                                                  // if and, we start with true until we observe a false
+
+        foreach (StopMessage?
maybeStop in results) + { + if (maybeStop != null) + { + this.stopMessages.Add(maybeStop); + if (!this.conjunction) + { + // If any clause terminates, the disjunction terminates + raiseTermination = true; + } + } + else if (this.conjunction) + { + // If any clause does not terminate, the conjunction does not terminate + raiseTermination = false; + } + } + + if (raiseTermination) + { + this.IsTerminated = true; + + return new StopMessage + { + Content = string.Join("; ", stopMessages.Select(sm => sm.Content)), + Source = string.Join(", ", stopMessages.Select(sm => sm.Source)) + }; + } + + return null; + } + + /// + /// + /// If this condition is already a disjunction, the new condition is added to the list of clauses. + /// + ITerminationCondition ITerminationCondition.Or(ITerminationCondition other) + { + if (this.conjunction == Or) + { + this.clauses.Add(other); + return this; + } + else + { + return new CombinerCondition(Or, this, new CombinerCondition(Or, other)); + } + } + + /// + /// + /// If this condition is already a conjunction, the new condition is added to the list of clauses. + /// + ITerminationCondition ITerminationCondition.And(ITerminationCondition other) + { + if (this.conjunction == And) + { + this.clauses.Add(other); + return this; + } + else + { + return new CombinerCondition(And, this, new CombinerCondition(And, other)); + } + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Tools.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Tools.cs new file mode 100644 index 000000000000..fe10182d11bf --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Tools.cs @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Tools.cs + +using System.ComponentModel; +using System.Reflection; +using Microsoft.Extensions.AI; + +namespace Microsoft.AutoGen.AgentChat.Abstractions; + +// TODO: This likely should live as a "Component" in an Agent-building ClassLib? +// It seems like it could have applicability beyond AgentChat. + +public static class ReflectionExtensions +{ + public static AIFunctionParameterMetadata ToAIFunctionMetadata(this ParameterInfo pi) + { + return new AIFunctionParameterMetadata(pi.Name!) + { + Description = pi.GetCustomAttribute()?.Description, + + ParameterType = pi.ParameterType, + + HasDefaultValue = pi.HasDefaultValue, + IsRequired = !pi.HasDefaultValue, + DefaultValue = pi.DefaultValue, + + // Schema = JSONSchema of type + }; + } + + public static AIFunctionReturnParameterMetadata ToAIFunctionReturnMetadata(this ParameterInfo rpi) + { + return new AIFunctionReturnParameterMetadata + { + Description = rpi.GetCustomAttribute()?.Description, + + ParameterType = rpi.ParameterType + + //Schema = JSONSchema of type + }; + } +} + +public class ParameterSchema(string name, Type type, bool isRequired = false, object? defaultValue = default) +{ + public string Name { get; } = name; + public Type Type { get; } = type; + public bool IsRequired { get; } = isRequired; + + public object? 
DefaultValue { get; } = defaultValue; + + public static implicit operator ParameterSchema(ParameterInfo parameterInfo) + { + Type parameterType = parameterInfo.ParameterType; + return ParameterSchema.Create(parameterType, parameterInfo.Name!, parameterInfo.HasDefaultValue, parameterInfo.DefaultValue); + } + + public static implicit operator ParameterSchema(AIFunctionParameterMetadata parameterMetadata) + { + Type parameterType = parameterMetadata.ParameterType!; // TODO: Deal with missing ParameterTypes + return ParameterSchema.Create(parameterType, + parameterMetadata.Name, + parameterMetadata.IsRequired, + parameterMetadata.DefaultValue); + } +} + +// TODO: Can this be obviated by AIFunctionParameter? +public class ParameterSchema(string name, bool isRequired = false, T? defaultValue = default) + : ParameterSchema(name, typeof(T), isRequired, defaultValue) +{ + public static ParameterSchema Create(Type type, string name, bool isRequired = false, object? defaultValue = default) + { + Type parameterSchemaType = typeof(ParameterSchema<>).MakeGenericType(type); + ParameterSchema? maybeResult = Activator.CreateInstance(parameterSchemaType, name, isRequired, defaultValue) as ParameterSchema; + return maybeResult!; + } +} + +/// +/// A tool that can be executed by agents. +/// +public interface ITool +{ + string Name { get; } + string Description { get; } + + public IEnumerable Parameters { get; } + public Type ReturnType { get; } + + // TODO: State serialization + + // TODO: Can we somehow make this a ValueTask? + public Task ExecuteAsync(IEnumerable parameters, CancellationToken cancellationToken = default); + + /// + /// This tool represented as an . + /// + public AIFunction AIFunction + { + get + { + return CallableTool.CreateAIFunction(this.Name, this.Description, this.ExecuteAsync); + } + } +} + +public static class TypeExtensions +{ + private static ISet TaskTypes = new HashSet([typeof(Task<>), typeof(ValueTask<>)]); + + public static Type UnwrapReturnIfAsync(this Type type) + { + if (type.IsGenericType && TaskTypes.Contains(type.GetGenericTypeDefinition())) + { + return type.GetGenericArguments()[0]; + } + else if (type == typeof(Task) || type == typeof(ValueTask)) + { + return typeof(void); + } + else + { + return type; + } + } +} + +/// +/// Projects a as an . +/// +/// The to wrap. +public class AIFunctionTool(AIFunction aiFunction) : ITool +{ + /// + public AIFunction AIFunction { get; } = aiFunction; + + /// + public string Name => this.AIFunction.Metadata.Name; + + /// + public string Description => this.AIFunction.Metadata.Description; + + /// + public IEnumerable Parameters => from rawParameter in this.AIFunction.Metadata.Parameters + select (ParameterSchema)rawParameter; + + /// + public Type ReturnType => this.AIFunction.Metadata.ReturnParameter.ParameterType!; // TODO: Deal with missing return types + + /// + public Task ExecuteAsync(IEnumerable parameters, CancellationToken cancellationToken = default) + => this.ExecuteAsync(parameters, cancellationToken); +} + +/// +/// Projects a delegate as a by wrapping it in . +/// +/// The name of the tool. +/// The description of the tool. +/// The delegate to wrap. 
+public class CallableTool(string name, string description, Delegate callable) + : AIFunctionTool(CreateAIFunction(name, description, callable)) +{ + internal static AIFunction CreateAIFunction(string name, string description, Delegate callable) + { + MethodInfo methodInfo = callable.Method; + + IEnumerable parameters = + from parameterInfo in methodInfo.GetParameters() + select parameterInfo.ToAIFunctionMetadata(); + + AIFunctionReturnParameterMetadata returnParameter = methodInfo.ReturnParameter.ToAIFunctionReturnMetadata(); + + AIFunctionFactoryCreateOptions createOptions = new() + { + Name = name, + Description = description, + Parameters = parameters.ToList(), + ReturnParameter = returnParameter, + // SerializerOptions = TODO: How do we maintain consistency with Python? + }; + + return AIFunctionFactory.Create(callable, createOptions); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Agents/ChatAgentBase.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Agents/ChatAgentBase.cs new file mode 100644 index 000000000000..1ddffad4aa70 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Agents/ChatAgentBase.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatAgentBase.cs + +using System.Runtime.CompilerServices; +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Agents; + +/// +/// Base class for a chat agent. +/// +public abstract class ChatAgentBase : IChatAgent +{ + public ChatAgentBase(string name, string description) + { + Name = new AgentName(name); + Description = description; + } + + /// + public AgentName Name { get; } + + /// + public string Description { get; } + + /// + public virtual async IAsyncEnumerable StreamAsync(IEnumerable item, [EnumeratorCancellation] CancellationToken cancellationToken) + { + Response response = await (this).HandleAsync(item, cancellationToken); + if (response.InnerMessages != null) + { + foreach (var message in response.InnerMessages) + { + // It would be really nice to have type unions in .NET; need to think about how to make this interface nicer. + yield return new ChatStreamFrame { Type = ChatStreamFrame.FrameType.InternalMessage, InternalMessage = message }; + } + } + + yield return new ChatStreamFrame { Type = ChatStreamFrame.FrameType.Response, Response = response }; + } + + /// + public abstract IEnumerable ProducedMessageTypes { get; } + + /// + public abstract ValueTask HandleAsync(IEnumerable item, CancellationToken cancellationToken); + + /// + public abstract ValueTask ResetAsync(CancellationToken cancellationToken); +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/ChatAgentRouter.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/ChatAgentRouter.cs new file mode 100644 index 000000000000..a195cd4d6743 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/ChatAgentRouter.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// ChatAgentRouter.cs
+
+using Microsoft.AutoGen.AgentChat.Abstractions;
+using Microsoft.AutoGen.Contracts;
+using Microsoft.AutoGen.Core;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.AutoGen.AgentChat.GroupChat;
+
+public struct AgentChatConfig(IChatAgent chatAgent, string parentTopicType, string outputTopicType)
+{
+    public string ParticipantTopicType => this.Name;
+    public string ParentTopicType { get; } = parentTopicType;
+    public string OutputTopicType { get; } = outputTopicType;
+
+    public IChatAgent ChatAgent { get; } = chatAgent;
+
+    public string Name => this.ChatAgent.Name;
+    public string Description => this.ChatAgent.Description;
+}
+
+internal sealed class ChatAgentRouter : HostableAgentAdapter,
+    IHandle<GroupChatStart>,
+    IHandle<GroupChatAgentResponse>,
+    IHandle<GroupChatRequestPublish>,
+    IHandle<GroupChatReset>
+{
+    private readonly TopicId parentTopic;
+    private readonly TopicId outputTopic;
+    private readonly IChatAgent agent;
+
+    public ChatAgentRouter(AgentInstantiationContext agentCtx, AgentChatConfig config, ILogger? logger = null) : base(agentCtx, config.Description, logger)
+    {
+        this.parentTopic = new TopicId(config.ParentTopicType, this.Id.Key);
+        this.outputTopic = new TopicId(config.OutputTopicType, this.Id.Key);
+
+        this.agent = config.ChatAgent;
+    }
+
+    public List<ChatMessage> MessageBuffer { get; private set; } = new();
+
+    public ValueTask HandleAsync(GroupChatStart item, MessageContext messageContext)
+    {
+        if (item.Messages != null)
+        {
+            this.MessageBuffer.AddRange(item.Messages);
+        }
+
+        return ValueTask.CompletedTask;
+    }
+
+    public ValueTask HandleAsync(GroupChatAgentResponse item, MessageContext messageContext)
+    {
+        this.MessageBuffer.Add(item.AgentResponse.Message);
+
+        return ValueTask.CompletedTask;
+    }
+
+    public async ValueTask HandleAsync(GroupChatRequestPublish item, MessageContext messageContext)
+    {
+        Response? response = null;
+
+        // TODO: Is there a better abstraction here than IAsyncEnumerable? Though the awkwardness mainly comes from
+        // the lack of real type unions in C#, which is why we need to create the StreamingFrame type in the first
+        // place.
+        await foreach (ChatStreamFrame frame in this.agent.StreamAsync(this.MessageBuffer, messageContext.CancellationToken))
+        {
+            // TODO: call publish message
+            switch (frame.Type)
+            {
+                case ChatStreamFrame.FrameType.Response:
+                    await this.PublishMessageAsync(new GroupChatMessage { Message = frame.Response!.Message }, this.outputTopic);
+                    response = frame.Response;
+                    break;
+                case ChatStreamFrame.FrameType.InternalMessage:
+                    await this.PublishMessageAsync(new GroupChatMessage { Message = frame.InternalMessage! }, this.outputTopic);
+                    break;
+            }
+        }
+
+        if (response == null)
+        {
+            throw new InvalidOperationException("The agent did not produce a final response. Check the agent's on_messages_stream method.");
+        }
+
+        this.MessageBuffer.Clear();
+
+        await this.PublishMessageAsync(new GroupChatAgentResponse { AgentResponse = response }, this.parentTopic);
+    }
+
+    public ValueTask HandleAsync(GroupChatReset item, MessageContext messageContext)
+    {
+        this.MessageBuffer.Clear();
+        return this.agent.ResetAsync(messageContext.CancellationToken);
+    }
+}
+
diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/Events.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/Events.cs
new file mode 100644
index 000000000000..8eb88ddb3720
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/Events.cs
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Events.cs
+
+using Microsoft.AutoGen.AgentChat.Abstractions;
+
+// using ProtobufTypeMarshal = Microsoft.AutoGen.AgentChat.WireProtocol.ProtobufTypeMarshal;
+
+namespace Microsoft.AutoGen.AgentChat.GroupChat;
+
+public class GroupChatEventBase /*: IWireable*/
+{
+    // public IMessage ToWire()
+    // {
+    //     return this switch
+    //     {
+    //         GroupChatStart groupChatStart => ProtobufTypeMarshal.Convert(groupChatStart),
+    //         GroupChatAgentResponse groupChatAgentResponse => ProtobufTypeMarshal.Convert(groupChatAgentResponse),
+    //         GroupChatRequestPublish groupChatRequestPublish => ProtobufTypeMarshal.Convert(groupChatRequestPublish),
+    //         GroupChatMessage groupChatMessage => ProtobufTypeMarshal.Convert(groupChatMessage),
+    //         GroupChatTermination groupChatTermination => ProtobufTypeMarshal.Convert(groupChatTermination),
+    //         GroupChatReset groupChatReset => ProtobufTypeMarshal.Convert(groupChatReset),
+    //         _ => throw new InvalidOperationException($"Unknown type {this.GetType().Name}"),
+    //     };
+    // }
+}
+
+/// <summary>
+/// A request to start a group chat.
+/// </summary>
+public class GroupChatStart : GroupChatEventBase
+{
+    /// <summary>
+    /// An optional list of messages to start the group chat.
+    /// </summary>
+    public List<ChatMessage>? Messages { get; set; }
+}
+
+/// <summary>
+/// A response published to a group chat.
+/// </summary>
+public class GroupChatAgentResponse : GroupChatEventBase
+{
+    /// <summary>
+    /// The response from an agent.
+    /// </summary>
+    public required Response AgentResponse { get; set; }
+}
+
+/// <summary>
+/// A request to publish a message to a group chat.
+/// </summary>
+public class GroupChatRequestPublish : GroupChatEventBase
+{
+}
+
+/// <summary>
+/// A message from a group chat.
+/// </summary>
+public class GroupChatMessage : GroupChatEventBase
+{
+    /// <summary>
+    /// The message that was published.
+    /// </summary>
+    public required AgentMessage Message { get; set; }
+}
+
+/// <summary>
+/// A message indicating that the group chat was terminated.
+/// </summary>
+public class GroupChatTermination : GroupChatEventBase
+{
+    /// <summary>
+    /// The stop message that indicates the reason for termination.
+    /// </summary>
+    public required StopMessage Message { get; set; }
+}
+
+/// <summary>
+/// A request to reset the agents in the group chat.
+/// </summary>
+public class GroupChatReset : GroupChatEventBase
+{
+}
+
diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatBase.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatBase.cs
new file mode 100644
index 000000000000..5192cac6d086
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatBase.cs
@@ -0,0 +1,235 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GroupChatBase.cs + +using System.Diagnostics; +using System.Reflection; +using System.Runtime.CompilerServices; +using Microsoft.AutoGen.AgentChat.Abstractions; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; + +namespace Microsoft.AutoGen.AgentChat.GroupChat; + +internal static class AgentsRuntimeExtensions +{ + public static async ValueTask RegisterChatAgentAsync(this IAgentRuntime runtime, AgentChatConfig config) + { + AgentType type = config.Name; + + AgentType resultType = await runtime.RegisterAgentFactoryAsync(type, + (id, runtime) => + { + AgentInstantiationContext agentContext = new AgentInstantiationContext(id, runtime); + return ValueTask.FromResult(new ChatAgentRouter(agentContext, config)); + }); + + await runtime.AddSubscriptionAsync(new TypeSubscription(config.ParticipantTopicType, type)); + await runtime.AddSubscriptionAsync(new TypeSubscription(config.ParentTopicType, type)); + + return resultType; + } + + public static async ValueTask RegisterGroupChatManagerAsync(this IAgentRuntime runtime, GroupChatOptions options, string teamId, Func factory) + where TManager : GroupChatManagerBase + { + AgentType type = GroupChatBase.GroupChatManagerTopicType; + AgentId expectedId = new AgentId(type, teamId); + + AgentType resultType = await runtime.RegisterAgentFactoryAsync(type, + (id, runtime) => + { + Debug.Assert(expectedId == id, $"Expecting the AgentId {expectedId} to be the teamId {id}"); + + AgentInstantiationContext agentContext = new AgentInstantiationContext(id, runtime); + TManager gcm = factory(options); // TODO: Should we allow this to be async? + + return ValueTask.FromResult(new GroupChatHandlerRouter(agentContext, gcm)); + }); + + await runtime.AddSubscriptionAsync(new TypeSubscription(GroupChatBase.GroupChatManagerTopicType, resultType)); + await runtime.AddSubscriptionAsync(new TypeSubscription(options.GroupChatTopicType, resultType)); + + return resultType; + } + + public static async ValueTask RegisterOutputCollectorAsync(this IAgentRuntime runtime, IOutputCollectionSink sink, string outputTopicType) + { + AgentType type = GroupChatBase.CollectorAgentType; + AgentType resultType = await runtime.RegisterAgentFactoryAsync(type, + (id, runtime) => + { + AgentInstantiationContext agentContext = new AgentInstantiationContext(id, runtime); + return ValueTask.FromResult(new OutputCollectorAgent(agentContext, sink)); + }); + + await runtime.AddSubscriptionAsync(new TypeSubscription(outputTopicType, type)); + + return resultType; + } +} + +public abstract class GroupChatBase : ITeam where TManager : GroupChatManagerBase +{ + // TODO: Where do these come from? + internal const string GroupTopicType = "group_topic"; + internal const string OutputTopicType = "output_topic"; + internal const string GroupChatManagerTopicType = "group_chat_manager"; + internal const string CollectorAgentType = "collect_output_messages"; + + private GroupChatOptions GroupChatOptions { get; } + + private readonly List messageThread = new(); + private Dictionary Participants { get; } = new(); + + protected GroupChatBase(List participants, ITerminationCondition? terminationCondition = null, int? 
maxTurns = null) + { + this.GroupChatOptions = new GroupChatOptions(GroupTopicType, OutputTopicType) + { + TerminationCondition = terminationCondition, + MaxTurns = maxTurns, + }; + + foreach (var participant in participants) + { + AgentChatConfig config = new AgentChatConfig(participant, GroupTopicType, OutputTopicType); + this.Participants[participant.Name] = config; + this.GroupChatOptions.Participants[participant.Name] = new GroupParticipant(config.ParticipantTopicType, participant.Description); + } + + this.messageThread = new List(); // TODO: Allow injecting this + + this.TeamId = Guid.NewGuid().ToString().ToLowerInvariant(); + } + + public string TeamId + { + get; + private set; + } + + public virtual TManager CreateChatManager(GroupChatOptions options) + { + try + { + if (Activator.CreateInstance(typeof(TManager), options) is TManager result) + { + return result; + }; + } + catch (TargetInvocationException tie) + { + throw new Exception("Could not create chat manager", tie.InnerException); + } + catch (Exception ex) + { + throw new Exception("Could not create chat manager", ex); + } + + throw new Exception("Could not create chat manager; make sure that it contains a ctor() or ctor(GroupChatOptions), or override the CreateChatManager method"); + } + + // TODO: Turn this into an IDisposable-based utility + private int running; // = 0 + private bool EnsureSingleRun() + { + return Interlocked.CompareExchange(ref running, 1, 0) == 0; + } + + private void EndRun() + { + this.running = 0; + } + + public IAsyncEnumerable StreamAsync(string task, CancellationToken cancellationToken) + { + if (String.IsNullOrEmpty(task)) + { + throw new ArgumentNullException(nameof(task)); + } + + // TODO: Send this on + TextMessage taskStart = new() + { + Content = task, + Source = "user" + }; + + return this.StreamAsync(taskStart, cancellationToken); + } + + public ValueTask ResetAsync(CancellationToken cancel) + { + return ValueTask.CompletedTask; + } + + public async IAsyncEnumerable StreamAsync(ChatMessage? task, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (task == null) + { + throw new ArgumentNullException(nameof(task)); + } + + if (!this.EnsureSingleRun()) + { + throw new InvalidOperationException("The task is already running."); + } + + // TODO: How do we allow the user to configure this? 
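+        // The run below proceeds, in outline, as follows: a fresh in-process runtime is
+        // stood up for this single run; each participant, the group chat manager, and an
+        // output collector are registered against it; the initial GroupChatStart is sent
+        // to the manager; and OutputSink frames are drained into TaskFrames until a
+        // terminal frame arrives, after which the runtime is allowed to run to idle via
+        // shutdownTask in the finally block.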
+ //AgentsAppBuilder builder = new AgentsAppBuilder().UseInProcessRuntime(); + InProcessRuntime runtime = new InProcessRuntime(); + + foreach (AgentChatConfig config in this.Participants.Values) + { + await runtime.RegisterChatAgentAsync(config); + } + + await runtime.RegisterGroupChatManagerAsync(this.GroupChatOptions, this.TeamId, this.CreateChatManager); + + OutputSink outputSink = new OutputSink(); + await runtime.RegisterOutputCollectorAsync(outputSink, this.GroupChatOptions.OutputTopicType); + + await runtime.StartAsync(); + + Task shutdownTask = Task.CompletedTask; + + try + { + // TODO: Protos + GroupChatStart taskMessage = new GroupChatStart + { + Messages = [task] + }; + + List runMessages = new(); + + AgentId chatManagerId = new AgentId(GroupChatManagerTopicType, this.TeamId); + await runtime.SendMessageAsync(taskMessage, chatManagerId, cancellationToken: cancellationToken); + + shutdownTask = Task.Run(runtime.RunUntilIdleAsync); + + while (true) + { + OutputSink.SinkFrame frame = await outputSink.WaitForDataAsync(cancellationToken); + runMessages.AddRange(frame.Messages); + + foreach (AgentMessage message in frame.Messages) + { + yield return new TaskFrame(message); + } + + if (frame.IsTerminal) + { + TaskResult result = new TaskResult(runMessages); + yield return new TaskFrame(result); + break; + } + } + } + finally + { + this.EndRun(); + + await shutdownTask; + } + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatHandlerRouter.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatHandlerRouter.cs new file mode 100644 index 000000000000..9855dcafb7ea --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatHandlerRouter.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GroupChatHandlerRouter.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AutoGen.AgentChat.GroupChat; + +internal delegate ValueTask MessagePublishServicer(GroupChatEventBase event_, string topicType, CancellationToken cancellation = default); + +internal interface IGroupChatHandler : IHandle, IHandle, IHandle +{ + void AttachMessagePublishServicer(MessagePublishServicer? servicer = null); + void DetachMessagePublishServicer() => this.AttachMessagePublishServicer(null); +} + +internal sealed class GroupChatHandlerRouter : HostableAgentAdapter, + IHandle, + IHandle, + IHandle + + where TManager : GroupChatManagerBase, IGroupChatHandler +{ + public const string DefaultDescription = "Group chat manager"; + + private TManager ChatManager { get; } + + public GroupChatHandlerRouter(AgentInstantiationContext agentCtx, TManager chatManager, ILogger? 
logger = null) : base(agentCtx, DefaultDescription, logger) + { + this.ChatManager = chatManager; + this.ChatManager.AttachMessagePublishServicer(PublishMessageServicer); + } + + private ValueTask PublishMessageServicer(GroupChatEventBase event_, string topicType, CancellationToken cancellation = default) + { + return this.PublishMessageAsync(event_, new TopicId(topicType, this.Id.Key), cancellationToken: cancellation); + } + + public ValueTask HandleAsync(GroupChatStart item, MessageContext messageContext) + => this.ChatManager.HandleAsync(item, messageContext); + + public ValueTask HandleAsync(GroupChatAgentResponse item, MessageContext messageContext) + => this.ChatManager.HandleAsync(item, messageContext); + + public ValueTask HandleAsync(object item, MessageContext messageContext) + => this.ChatManager.HandleAsync(item, messageContext); +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatManagerBase.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatManagerBase.cs new file mode 100644 index 000000000000..69c57e344d89 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatManagerBase.cs @@ -0,0 +1,187 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GroupChatManagerBase.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; +using Microsoft.AutoGen.Contracts; + +namespace Microsoft.AutoGen.AgentChat.GroupChat; + +public abstract class GroupChatManagerBase : IGroupChatHandler +{ + private GroupChatOptions options; + + // TODO: We may be able to abstract this out at the Core level + private MessagePublishServicer? PublishServicer { get; set; } + + // It would be so awesome if we could avoid passing GroupChatOptions to the constructor + // and use something like Python's context manager mechanism to pick up the options from + // the logical stack. But that's very difficult in C#, because the user code could do all + // sorts of weird things like shunt the execution to a different thread. We cannot even + // assume that we are in an async context (much less that we are in the same async context) + public GroupChatManagerBase(GroupChatOptions options) : base() + { + this.options = options; + + this.MessageThread = new List(); + } + + protected string GroupChatTopicType => this.options.GroupChatTopicType; + protected string OutputTopicType => this.options.OutputTopicType; + + protected Dictionary Participants => this.options.Participants; + + protected ITerminationCondition? TerminationCondition => this.options.TerminationCondition; + protected int? MaxTurns => this.options.MaxTurns; + + private int CurrentTurn { get; set; } + + protected List MessageThread; + + void IGroupChatHandler.AttachMessagePublishServicer(MessagePublishServicer? servicer) + { + this.PublishServicer = servicer; + } + + private ValueTask PublishMessageAsync(GroupChatEventBase message, string topicType, CancellationToken cancellation = default) + { + return this.PublishServicer?.Invoke(message, topicType, cancellation) ?? ValueTask.CompletedTask; + } + + protected ValueTask PublishMessageAsync(ChatMessage message, string topicType, CancellationToken cancellation = default) + { + return this.PublishMessageAsync(new GroupChatMessage { Message = message }, topicType, cancellation); + } + + protected virtual async ValueTask ValidateGroupState(List? 
messages) + { + } + + public abstract ValueTask SelectSpeakerAsync(List thread); + + public async ValueTask HandleAsync(GroupChatStart item, MessageContext messageContext) + { + if (this.TerminationCondition != null && this.TerminationCondition.IsTerminated) + { + // skipReset is used here to match the Python code + await this.TerminateAsync("The chat has already terminated", skipReset: true); + + StopMessage earlyStop = new StopMessage + { + Content = "The chat has already terminated", + Source = GroupChatBase.GroupChatManagerTopicType + }; + + await this.PublishMessageAsync(new GroupChatTermination { Message = earlyStop }, this.OutputTopicType); + + return; + } + + if (item.Messages != null) + { + this.MessageThread.AddRange(item.Messages); + } + + await this.ValidateGroupState(item.Messages); + + if (item.Messages != null) + { + await this.PublishMessageAsync(item, this.OutputTopicType); + await this.PublishMessageAsync(item, this.GroupChatTopicType); + + // item.Messages is IList but we need IList + // Unfortunately, IList does not support type variance, so we have to do this rather ugly thing + // TODO: Check if we really need to have AgentMessage on the interface of ITerminationCondition + List converted = [.. item.Messages.Cast()]; + + if (await this.TerminateIfNeededAsync(converted)) + { + return; + } + } + + await this.ProcessNextSpeakerAsync(); + } + + public async ValueTask HandleAsync(GroupChatAgentResponse item, MessageContext messageContext) + { + List delta = new List(); + + if (item.AgentResponse.InnerMessages != null) + { + this.MessageThread.AddRange(item.AgentResponse.InnerMessages); + delta.AddRange(item.AgentResponse.InnerMessages); + } + + this.MessageThread.Add(item.AgentResponse.Message); + delta.Add(item.AgentResponse.Message); + + if (await this.TerminateIfNeededAsync(delta)) + { + return; + } + + this.CurrentTurn++; + if (this.MaxTurns.HasValue && this.MaxTurns.Value <= this.CurrentTurn) + { + await this.TerminateAsync($"Maximum number of turns ({this.MaxTurns.Value}) reached."); + return; + } + + await this.ProcessNextSpeakerAsync(); + } + + private ValueTask TerminateAsync(string message, bool skipReset = false) + { + StopMessage stopMessage = new StopMessage + { + Content = message, + Source = GroupChatBase.GroupChatManagerTopicType + }; + + return this.TerminateAsync(stopMessage, skipReset); + } + + private async ValueTask TerminateAsync(StopMessage stopMessage, bool skipReset = false) + { + await this.PublishMessageAsync(new GroupChatTermination { Message = stopMessage }, this.OutputTopicType); + + if (!skipReset) + { + this.TerminationCondition?.Reset(); + this.CurrentTurn = 0; + } + } + + private async ValueTask TerminateIfNeededAsync(params IList incomingMessages) + { + if (this.TerminationCondition == null) + { + return false; + } + + StopMessage? 
stopMessage = await this.TerminationCondition.CheckAndUpdateAsync(incomingMessages); + if (stopMessage != null) + { + await this.TerminateAsync(stopMessage); + + return true; + } + + return false; + } + + // TODO: Figure out how to route this to the right method + //private ValueTask ProcessNextSpeakerAsync(params IList incomingMessages) + // => this.ProcessNextSpeakerAsync(incomingMessages); + + private async ValueTask ProcessNextSpeakerAsync() + { + string nextSpeakerTopic = await this.SelectSpeakerAsync(this.MessageThread); + await this.PublishMessageAsync(new GroupChatRequestPublish { }, nextSpeakerTopic); + } + + public ValueTask HandleAsync(object item, MessageContext messageContext) + { + throw new InvalidOperationException($"Unhandled message in group chat manager: {item.GetType()}"); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatOptions.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatOptions.cs new file mode 100644 index 000000000000..e0907e5231c7 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/GroupChatOptions.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GroupChatOptions.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.GroupChat; + +public struct GroupParticipant(string topicType, string description) +{ + public string TopicType { get; } = topicType; + public string Description { get; } = description; + + // Destructuring from a tuple + public GroupParticipant((string topicType, string description) tuple) : this(tuple.topicType, tuple.description) + { + } + + // Destructuring to a tuple + public void Deconstruct(out string topicType, out string description) + { + topicType = this.TopicType; + description = this.Description; + } + + public static implicit operator GroupParticipant((string topicType, string description) tuple) => new GroupParticipant(tuple); + public static implicit operator (string topicType, string description)(GroupParticipant participant) => (participant.TopicType, participant.Description); +} + +public class GroupChatOptions(string groupTopicType, string outputTopicType) +{ + public string GroupChatTopicType { get; } = groupTopicType; + public string OutputTopicType { get; } = outputTopicType; + + public ITerminationCondition? TerminationCondition { get; set; } + public int? MaxTurns { get; set; } + + public Dictionary Participants { get; } = new Dictionary(); +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/HostableAgentAdapter.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/HostableAgentAdapter.cs new file mode 100644 index 000000000000..86f20607cc9d --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/HostableAgentAdapter.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// HostableAgentAdapter.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AutoGen.AgentChat.GroupChat; + +public class AgentInstantiationContext(AgentId id, IAgentRuntime runtime) +{ + public AgentId Id { get; } = id; + public IAgentRuntime Runtime { get; } = runtime; +} + +internal class HostableAgentAdapter : BaseAgent +{ + public HostableAgentAdapter(AgentId id, IAgentRuntime runtime, string description, ILogger? logger = null) : base(id, runtime, description, logger) + { + } + + public HostableAgentAdapter(AgentInstantiationContext agentCtx, string description, ILogger? 
logger = null) : base(agentCtx.Id, agentCtx.Runtime, description, logger) + { + } +} + diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/OutputCollectorAgent.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/OutputCollectorAgent.cs new file mode 100644 index 000000000000..d17fa266b6aa --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/OutputCollectorAgent.cs @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OutputCollectorAgent.cs + +using System.Diagnostics; +using Microsoft.AutoGen.AgentChat.GroupChat; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AutoGen.AgentChat.Abstractions; + +internal interface IOutputCollectionSink +{ + void CollectMessage(AgentMessage message); + void Terminate(StopMessage message); +} + +internal sealed class OutputSink : IOutputCollectionSink +{ + public sealed class SinkFrame + { + public StopMessage? Termination { get; set; } + public List Messages { get; } = new(); + + public bool IsTerminal => this.Termination != null; + } + + private readonly object sync = new(); + private SemaphoreSlim semapohre = new SemaphoreSlim(1, 1); + + private SinkFrame? receivingSinkFrame; + + private void RunSynchronized(Action frameAction) + { + // Make sure we do not overlap with Terminate + lock (this.sync) + { + if (this.receivingSinkFrame == null) + { + this.receivingSinkFrame = new SinkFrame(); + } + + frameAction(this.receivingSinkFrame); + } + + semapohre.Release(); + } + + public void CollectMessage(AgentMessage message) + { + this.RunSynchronized( + frame => + { + frame.Messages.Add(message); + }); + } + + public void Terminate(StopMessage message) + { + this.RunSynchronized( + frame => + { + frame.Termination = message; + }); + } + + public async Task WaitForDataAsync(CancellationToken cancellation) + { + while (true) + { + SinkFrame? lastFrame; + lock (this.sync) + { + lastFrame = Interlocked.Exchange(ref this.receivingSinkFrame, null); + + if (lastFrame != null) + { + return lastFrame; + } + } + + await this.semapohre.WaitAsync(cancellation); + } + } +} + +// TODO: Abstract the core logic of this out into the equivalent of ClosureAgent, because that seems like a +// useful facility to have in Core +internal sealed class OutputCollectorAgent : BaseAgent, + IHandle, + IHandle, + IHandle +{ + private IOutputCollectionSink Sink { get; } + + public OutputCollectorAgent(AgentInstantiationContext ctx, IOutputCollectionSink sink, ILogger? 
logger = null) : base(ctx.Id, ctx.Runtime, string.Empty, logger) + { + this.Sink = sink; + } + + private void ForwardMessageInternal(ChatMessage message, CancellationToken cancel = default) + { + if (!cancel.IsCancellationRequested) + { + this.Sink.CollectMessage(message); + } + } + + public ValueTask HandleAsync(GroupChatStart item, MessageContext context) + { + item.Messages?.ForEach(item => this.ForwardMessageInternal(item, context.CancellationToken)); + + return ValueTask.CompletedTask; + } + + public ValueTask HandleAsync(GroupChatMessage item, MessageContext context) + { + Debug.Assert(item.Message is ChatMessage, "We should never receive internal messages into the output queue?"); + if (item.Message is ChatMessage chatMessage) + { + this.ForwardMessageInternal(chatMessage, context.CancellationToken); + } + + return ValueTask.CompletedTask; + } + + public ValueTask HandleAsync(GroupChatTermination item, MessageContext context) + { + this.Sink.Terminate(item.Message); + + return ValueTask.CompletedTask; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/RoundRobinGroupChat.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/RoundRobinGroupChat.cs new file mode 100644 index 000000000000..daa715063096 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/GroupChat/RoundRobinGroupChat.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// RoundRobinGroupChat.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.GroupChat; + +/// +/// A group chat manager that selects the next speaker in a round-robin fashion. +/// +public class RoundRobinGroupChatManager : GroupChatManagerBase +{ + private readonly List participantTopicTypes; + private int nextSpeakerIndex; + + public RoundRobinGroupChatManager(GroupChatOptions options) : base(options) + { + this.participantTopicTypes = [.. from candidateTopic in options.Participants.Values + select candidateTopic.TopicType]; + this.nextSpeakerIndex = 0; + } + + public override ValueTask SelectSpeakerAsync(List thread) + { + string result = this.participantTopicTypes[this.nextSpeakerIndex].ToString(); + + this.nextSpeakerIndex = (this.nextSpeakerIndex + 1) % this.participantTopicTypes.Count; + + return ValueTask.FromResult(result); + } +} + +/// +/// A team that runs a group chat with a participants taking turns in a round-robin fashion to publish +/// a message to all. +/// +/// If a single participant is in the team, the participant will be the only speaker. +/// +public class RoundRobinGroupChat : GroupChatBase +{ + /// + /// Initializes a new round-robin group chat. + /// + /// The participants in the group chat. + /// + /// The termination condition for the group chat. Defaults to null. Without a termination + /// condition, the group chat will run indefinitely. + /// + /// + /// The maximum number of turns for the group chat. Defaults to null, meaning no limit. + /// Note that the gets first priority for checking the termination + /// if both are provided. + /// + public RoundRobinGroupChat(List participants, ITerminationCondition? terminationCondition = null, int? 
maxTurns = null) : base(participants, terminationCondition, maxTurns)
+    {
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Microsoft.AutoGen.AgentChat.csproj b/dotnet/src/Microsoft.AutoGen/AgentChat/Microsoft.AutoGen.AgentChat.csproj
new file mode 100644
index 000000000000..6d2cb3fbe14e
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Microsoft.AutoGen.AgentChat.csproj
@@ -0,0 +1,23 @@
+
+
+
+    net8.0
+    enable
+    enable
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/StopMessageTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/StopMessageTermination.cs
new file mode 100644
index 000000000000..194fe1afe289
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/StopMessageTermination.cs
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// StopMessageTermination.cs
+
+using Microsoft.AutoGen.AgentChat.Abstractions;
+
+namespace Microsoft.AutoGen.AgentChat.Terminations;
+
+/// <summary>
+/// Terminate a conversation if a <see cref="StopMessage"/> is received.
+/// </summary>
+public class StopMessageTermination : ITerminationCondition
+{
+    public bool IsTerminated { get; private set; }
+
+    public ValueTask<StopMessage?> CheckAndUpdateAsync(IList<AgentMessage> messages)
+    {
+        if (this.IsTerminated)
+        {
+            throw new TerminatedException();
+        }
+
+        foreach (AgentMessage item in messages)
+        {
+            if (item is StopMessage)
+            {
+                this.IsTerminated = true;
+
+                StopMessage result = new() { Content = "Stop message received", Source = nameof(StopMessageTermination) };
+                return ValueTask.FromResult<StopMessage?>(result);
+            }
+        }
+
+        return ValueTask.FromResult<StopMessage?>(null);
+    }
+
+    public void Reset()
+    {
+        this.IsTerminated = false;
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/AgentId.cs b/dotnet/src/Microsoft.AutoGen/Contracts/AgentId.cs
index 3098547231ac..da5c5aa65ef2 100644
--- a/dotnet/src/Microsoft.AutoGen/Contracts/AgentId.cs
+++ b/dotnet/src/Microsoft.AutoGen/Contracts/AgentId.cs
@@ -14,7 +14,7 @@ namespace Microsoft.AutoGen.Contracts;
 /// See the Python equivalent:
 /// AgentId in AutoGen (Python).
 /// </summary>
-[DebuggerDisplay($"AgentId(type=\"{nameof(Type)}\", key=\"{nameof(Key)}\")")]
+[DebuggerDisplay($"AgentId(type=\"{{{nameof(Type)}}}\", key=\"{{{nameof(Key)}}}\")")]
 public struct AgentId
 {
     private static readonly Regex TypeRegex = new(@"^[a-zA-Z_][a-zA-Z0-9_]*$", RegexOptions.Compiled);
diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/MessageContext.cs b/dotnet/src/Microsoft.AutoGen/Contracts/MessageContext.cs
index e99f0c42d941..58b580fee506 100644
--- a/dotnet/src/Microsoft.AutoGen/Contracts/MessageContext.cs
+++ b/dotnet/src/Microsoft.AutoGen/Contracts/MessageContext.cs
@@ -9,6 +9,9 @@ namespace Microsoft.AutoGen.Contracts;
 /// </summary>
 public class MessageContext(string messageId, CancellationToken cancellationToken)
 {
+    public MessageContext(CancellationToken cancellation) : this(Guid.NewGuid().ToString(), cancellation)
+    { }
+
     /// <summary>
     /// Gets or sets the unique identifier for this message.
     /// </summary>
diff --git a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/AgentChatSmokeTest.cs b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/AgentChatSmokeTest.cs
new file mode 100644
index 000000000000..018f80cae38a
--- /dev/null
+++ b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/AgentChatSmokeTest.cs
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AgentChatSmokeTest.cs
+
+using Microsoft.AutoGen.AgentChat.Abstractions;
+using Microsoft.AutoGen.AgentChat.Agents;
+using Microsoft.AutoGen.AgentChat.GroupChat;
+using Microsoft.AutoGen.AgentChat.Terminations;
+using Xunit;
+
+namespace Microsoft.AutoGen.AgentChat.Tests;
+
+public class AgentChatSmokeTest
+{
+    public class SpeakMessageAgent : ChatAgentBase
+    {
+        public SpeakMessageAgent(string name, string description, string content) : base(name, description)
+        {
+            this.Content = content;
+        }
+
+        public string Content { get; private set; }
+
+        public override IEnumerable<Type> ProducedMessageTypes => [typeof(HandoffMessage)];
+
+        public override ValueTask<Response> HandleAsync(IEnumerable<ChatMessage> item, CancellationToken cancellationToken)
+        {
+            Response result = new()
+            {
+                Message = new TextMessage { Content = this.Content, Source = this.Name }
+            };
+
+            return ValueTask.FromResult(result);
+        }
+
+        public override ValueTask ResetAsync(CancellationToken cancellationToken)
+        {
+            return ValueTask.CompletedTask;
+        }
+    }
+
+    public class TerminatingAgent : ChatAgentBase
+    {
+        public List<ChatMessage>? IncomingMessages { get; private set; }
+
+        public TerminatingAgent(string name, string description) : base(name, description)
+        {
+        }
+
+        public override IEnumerable<Type> ProducedMessageTypes => [typeof(StopMessage)];
+
+        public override ValueTask<Response> HandleAsync(IEnumerable<ChatMessage> item, CancellationToken cancellationToken)
+        {
+            this.IncomingMessages = item.ToList();
+
+            string content = "Terminating";
+            if (item.Any())
+            {
+                ChatMessage lastMessage = item.Last();
+
+                switch (lastMessage)
+                {
+                    case TextMessage textMessage:
+                        content = $"Terminating; got: {textMessage.Content}";
+                        break;
+                    case HandoffMessage handoffMessage:
+                        content = $"Terminating; got handoff: {handoffMessage.Context}";
+                        break;
+                }
+            }
+
+            Response result = new()
+            {
+                Message = new StopMessage { Content = content, Source = this.Name }
+            };
+
+            return ValueTask.FromResult(result);
+        }
+
+        public override ValueTask ResetAsync(CancellationToken cancellationToken)
+        {
+            this.IncomingMessages = null;
+
+            return ValueTask.CompletedTask;
+        }
+    }
+
+    [Fact]
+    public async Task Test_RoundRobin_SpeakAndTerminating()
+    {
+        TerminatingAgent terminatingAgent = new("Terminate", "Terminate");
+
+        ITeam chat = new RoundRobinGroupChat(
+            [
+                new SpeakMessageAgent("Speak", "Speak", "Hello"),
+                terminatingAgent
+            ],
+            terminationCondition: new StopMessageTermination());
+
+        TaskResult result = await chat.RunAsync("");
+
+        Assert.Equal(3, result.Messages.Count);
+        Assert.Equal("", Assert.IsType<TextMessage>(result.Messages[0]).Content);
+        Assert.Equal("Hello", Assert.IsType<TextMessage>(result.Messages[1]).Content);
+        Assert.Equal("Terminating; got: Hello", Assert.IsType<StopMessage>(result.Messages[2]).Content);
+    }
+}
diff --git a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/Microsoft.AutoGen.AgentChat.Tests.csproj b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/Microsoft.AutoGen.AgentChat.Tests.csproj
new file mode 100644
index 000000000000..e505771e69e2
--- /dev/null
+++ b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/Microsoft.AutoGen.AgentChat.Tests.csproj
@@ -0,0 +1,15 @@
+
+
+
+    $(TestTargetFrameworks)
+    enable
+    enable
+    True
+
+
+
+
+
+
+
+

From 73a7ba5764d2c3b25247b19ec6339887eb853130 Mon Sep 17 00:00:00 2001
From: Rohan Thacker
Date: Sat, 8 Feb 2025 04:57:14 +0530
Subject: [PATCH 10/36] Added the Claude family of models to ModelFamily
 (#5443)

Added the Claude family of models to the `ModelFamily` class.
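For illustration only, a minimal sketch of how a client might advertise one of these new families through `model_info` when a Claude model is served behind an OpenAI-compatible endpoint; the gateway URL, model name, and capability flags below are assumptions, not part of this patch:

```python
from autogen_core.models import ModelFamily, ModelInfo
from autogen_ext.models.openai import OpenAIChatCompletionClient

# Hypothetical gateway serving Claude behind an OpenAI-style API; the URL,
# model name, and capability flags are illustrative assumptions.
client = OpenAIChatCompletionClient(
    model="claude-3-5-sonnet",
    base_url="https://my-gateway.example.com/v1",
    api_key="sk-...",
    model_info=ModelInfo(
        vision=True,
        function_calling=True,
        json_output=True,
        family=ModelFamily.CLAUDE_3_5_SONNET,
    ),
)
```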
Co-authored-by: Eric Zhu --- .../src/autogen_core/models/_model_client.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/packages/autogen-core/src/autogen_core/models/_model_client.py b/python/packages/autogen-core/src/autogen_core/models/_model_client.py index e3ff0403fb88..4a046a3cf5f9 100644 --- a/python/packages/autogen-core/src/autogen_core/models/_model_client.py +++ b/python/packages/autogen-core/src/autogen_core/models/_model_client.py @@ -27,6 +27,11 @@ class ModelFamily: GEMINI_1_5_FLASH = "gemini-1.5-flash" GEMINI_1_5_PRO = "gemini-1.5-pro" GEMINI_2_0_FLASH = "gemini-2.0-flash" + CLAUDE_3_HAIKU = "claude-3-haiku" + CLAUDE_3_SONNET = "claude-3-sonnet" + CLAUDE_3_OPUS = "claude-3-opus" + CLAUDE_3_5_HAIKU = "claude-3.5-haiku" + CLAUDE_3_5_SONNET = "claude-3.5-sonnet" UNKNOWN = "unknown" ANY: TypeAlias = Literal[ @@ -39,6 +44,11 @@ class ModelFamily: "gemini-1.5-flash", "gemini-1.5-pro", "gemini-2.0-flash", + "claude-3-haiku", + "claude-3-sonnet", + "claude-3-opus", + "claude-3.5-haiku", + "claude-3.5-sonnet", "unknown", ] From 9494ac97a0792b6207aa15a31ec02c67fbdff9db Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 7 Feb 2025 15:37:44 -0800 Subject: [PATCH 11/36] AGS Improvements (Add Test Button in Team Builder View + Others) (#5416) ## Why are these changes needed? - Add ability to test teams in Team Builder view - Update Gallery (add deep research default team, fix bug with gallery serialization) - UI fixes - improve drag drop component experience - improve new session experience (single click rather than 3 clicks to create a session) - fix bug with stop reason not being shown in some cases Image Image ## Related issue number Closes #5392 ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. 
---
 .../autogenstudio-user-guide/faq.md           |  88 +++--
 .../autogenstudio-user-guide/index.md         |   4 +-
 .../autogenstudio-user-guide/usage.md         |   8 +-
 .../autogenstudio/gallery/builder.py          |  84 ++++-
 .../autogen-studio/autogenstudio/version.py   |   2 +-
 .../frontend/src/components/sidebar.tsx       |   2 +-
 .../views/gallery/default_gallery.json        | 346 +++++++++++++++---
 .../src/components/views/gallery/store.tsx    |   2 +-
 .../components/views/session/chat/chat.tsx    |   3 -
 .../components/views/session/chat/runview.tsx |   6 -
 .../src/components/views/session/manager.tsx  |   1 -
 .../components/views/session/newsession.tsx   |  16 +-
 .../src/components/views/session/sidebar.tsx  |   6 +-
 .../components/views/team/builder/builder.css |  52 +--
 .../components/views/team/builder/builder.tsx |  76 +++-
 .../components/views/team/builder/library.tsx |   6 +-
 .../views/team/builder/testdrawer.tsx         |  98 +++++
 .../src/components/views/team/sidebar.tsx     |   4 +-
 python/packages/autogen-studio/pyproject.toml |   1 +
 python/uv.lock                                |   4 +-
 20 files changed, 668 insertions(+), 141 deletions(-)
 create mode 100644 python/packages/autogen-studio/frontend/src/components/views/team/builder/testdrawer.tsx

diff --git a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/faq.md b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/faq.md
index 26c52628d380..66034743af35 100644
--- a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/faq.md
+++ b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/faq.md
@@ -15,29 +15,67 @@ A: You can specify the directory where files are stored by setting the `--appdir`
 
 Yes. AutoGen standardizes on the openai model api format, and you can use any api server that offers an openai compliant endpoint.
 
-AutoGen Studio is based on declaritive specifications which applies to models as well.
+AutoGen Studio is based on declarative specifications, which apply to models as well. Note that you can define your [model client](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/components/model-clients.html) in python and dump it to a json file for use in AutoGen Studio.
 
-An example of the openai model client is shown below:
+In the following sample, we will define an OpenAI, AzureOpenAI and a local model client in python and dump them to a json file.
+
+```python
+from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
+from autogen_core.models import ModelInfo
+
+model_client = OpenAIChatCompletionClient(
+    model="gpt-4o-mini",
+)
+print(model_client.dump_component().model_dump_json())
+
+az_model_client = AzureOpenAIChatCompletionClient(
+    azure_deployment="{your-azure-deployment}",
+    model="gpt-4o",
+    api_version="2024-06-01",
+    azure_endpoint="https://{your-custom-endpoint}.openai.azure.com/",
+    api_key="sk-...",
+)
+print(az_model_client.dump_component().model_dump_json())
+
+mistral_vllm_model = OpenAIChatCompletionClient(
+    model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+    base_url="http://localhost:1234/v1",
+    model_info=ModelInfo(vision=False, function_calling=True, json_output=False, family="unknown"),
+)
+print(mistral_vllm_model.dump_component().model_dump_json())
+```
+
+OpenAI
 
 ```json
 {
-  "model": "gpt-4o-mini",
-  "model_type": "OpenAIChatCompletionClient",
-  "api_key": "your-api-key"
+  "provider": "autogen_ext.models.openai.OpenAIChatCompletionClient",
+  "component_type": "model",
+  "version": 1,
+  "component_version": 1,
+  "description": "Chat completion client for OpenAI hosted models.",
+  "label": "OpenAIChatCompletionClient",
+  "config": { "model": "gpt-4o-mini" }
 }
 ```
 
-An example of the azure openai model client is shown below:
+Azure OpenAI
 
 ```json
 {
-  "model": "gpt-4o-mini",
-  "model_type": "AzureOpenAIChatCompletionClient",
-  "azure_deployment": "gpt-4o-mini",
-  "api_version": "2024-02-15-preview",
-  "azure_endpoint": "https://your-endpoint.openai.azure.com/",
-  "api_key": "your-api-key",
-  "component_type": "model"
+  "provider": "autogen_ext.models.openai.AzureOpenAIChatCompletionClient",
+  "component_type": "model",
+  "version": 1,
+  "component_version": 1,
+  "description": "Chat completion client for Azure OpenAI hosted models.",
+  "label": "AzureOpenAIChatCompletionClient",
+  "config": {
+    "model": "gpt-4o",
+    "api_key": "sk-...",
+    "azure_endpoint": "https://{your-custom-endpoint}.openai.azure.com/",
+    "azure_deployment": "{your-azure-deployment}",
+    "api_version": "2024-06-01"
+  }
 }
 ```
 
@@ -45,21 +83,27 @@ Have a local model server like Ollama, vLLM or LMStudio that provide an OpenAI c
 
 ```json
 {
-  "model": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
-  "model_type": "OpenAIChatCompletionClient",
-  "base_url": "http://localhost:1234/v1",
-  "api_version": "1.0",
+  "provider": "autogen_ext.models.openai.OpenAIChatCompletionClient",
   "component_type": "model",
-  "model_capabilities": {
-    "vision": false,
-    "function_calling": true,
-    "json_output": false
+  "version": 1,
+  "component_version": 1,
+  "description": "Chat completion client for OpenAI hosted models.",
+  "label": "OpenAIChatCompletionClient",
+  "config": {
+    "model": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+    "model_info": {
+      "vision": false,
+      "function_calling": true,
+      "json_output": false,
+      "family": "unknown"
+    },
+    "base_url": "http://localhost:1234/v1"
   }
 }
 ```
 
 ```{caution}
-It is important that you add the `model_capabilities` field to the model client specification for custom models. This is used by the framework instantiate and use the model correctly.
+It is important that you add the `model_info` field to the model client specification for custom models. This is used by the framework to instantiate and use the model correctly.
Also, the `AssistantAgent` and many other agents in AgentChat require the model to have the `function_calling` capability. ``` ## Q: The server starts but I can't access the UI diff --git a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/index.md b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/index.md index 09de3f9ac14f..608adf78a093 100644 --- a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/index.md +++ b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/index.md @@ -12,7 +12,9 @@ myst: AutoGen Studio is a low-code interface built to help you rapidly prototype AI agents, enhance them with tools, compose them into teams and interact with them to accomplish tasks. It is built on [AutoGen AgentChat](https://microsoft.github.io/autogen) - a high-level API for building multi-agent applications. -![AutoGen Studio](https://media.githubusercontent.com/media/microsoft/autogen/refs/heads/main/python/packages/autogen-studio/docs/ags_screen.png) +> See a video tutorial on AutoGen Studio v0.4 (02/25) - [https://youtu.be/oum6EI7wohM](https://youtu.be/oum6EI7wohM) + +[![A Friendly Introduction to AutoGen Studio v0.4](https://img.youtube.com/vi/oum6EI7wohM/maxresdefault.jpg)](https://www.youtube.com/watch?v=oum6EI7wohM) Code for AutoGen Studio is on GitHub at [microsoft/autogen](https://github.com/microsoft/autogen/tree/main/python/packages/autogen-studio) diff --git a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/usage.md b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/usage.md index e185a9128566..dd2cc6d1e727 100644 --- a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/usage.md +++ b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/usage.md @@ -10,7 +10,9 @@ myst: AutoGen Studio provides a Team Builder interface where developers can define multiple components and behaviors. Users can create teams, add agents to teams, attach tools and models to agents, and define team termination conditions. After defining a team, users can test it in the Playground view to accomplish various tasks through direct interaction. -![AutoGen Studio](https://media.githubusercontent.com/media/microsoft/autogen/refs/heads/main/python/packages/autogen-studio/docs/ags_screen.png) +> See a video tutorial on AutoGen Studio v0.4 (02/25) - [https://youtu.be/oum6EI7wohM](https://youtu.be/oum6EI7wohM) + +[![A Friendly Introduction to AutoGen Studio v0.4](https://img.youtube.com/vi/oum6EI7wohM/maxresdefault.jpg)](https://www.youtube.com/watch?v=oum6EI7wohM) ## Declarative Specification of Componenents @@ -100,8 +102,6 @@ This example shows a team with a single agent, using the `RoundRobinGroupChat` t ## Building an Agent Team -
-
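+One way to sanity-check a team outside the UI is to load its exported JSON back into AgentChat in Python. This is a minimal sketch, assuming a `team.json` file exported from Team Builder and that the team component round-trips via `load_component`:
+
+```python
+import asyncio
+import json
+
+from autogen_agentchat.teams import BaseGroupChat
+
+async def main() -> None:
+    # team.json is an assumed filename for a spec exported from Team Builder.
+    with open("team.json") as f:
+        spec = json.load(f)
+
+    # Reconstruct the team from its declarative component spec and run a task.
+    team = BaseGroupChat.load_component(spec)
+    result = await team.run(task="Write a haiku about autumn.")
+    print(result.messages[-1].content)
+
+asyncio.run(main())
+```
+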
 AutoGen Studio integrates closely with all component abstractions provided by AutoGen AgentChat, including {py:class}`~autogen_agentchat.teams`, {py:class}`~autogen_agentchat.agents`, {py:class}`~autogen_core.models`, {py:class}`~autogen_core.tools`, and termination {py:class}`~autogen_agentchat.conditions`.
 
@@ -117,6 +117,8 @@ Team Builder Operations:
 - Agents: Add models and tools
 - Save team configurations
 
+Note: For each node in the visual builder, you can click on the edit icon (top right) to view and edit the JSON configuration.
+
 ## Gallery - Sharing and Reusing Components
 
 A Gallery is a collection of components - teams, agents, models, tools, and terminations - that can be shared and reused across projects.

diff --git a/python/packages/autogen-studio/autogenstudio/gallery/builder.py b/python/packages/autogen-studio/autogenstudio/gallery/builder.py
index 7ef14c2061b6..0d9a36f83471 100644
--- a/python/packages/autogen-studio/autogenstudio/gallery/builder.py
+++ b/python/packages/autogen-studio/autogenstudio/gallery/builder.py
@@ -148,7 +148,7 @@ def create_default_gallery() -> Gallery:
     mistral_vllm_model = OpenAIChatCompletionClient(
         model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
         base_url="http://localhost:1234/v1",
-        model_info=ModelInfo(vision=False, function_calling=True, json_output=False),
+        model_info=ModelInfo(vision=False, function_calling=True, json_output=False, family="unknown"),
     )
     builder.add_model(
         mistral_vllm_model.dump_component(),
@@ -236,6 +236,7 @@ def create_default_gallery() -> Gallery:
         model_client=base_model,
         termination_condition=web_termination,
     )
+
     builder.add_team(
         websurfer_team.dump_component(),
         label="Web Agent Team (Operator)",
@@ -256,8 +257,8 @@ def create_default_gallery() -> Gallery:
 
     builder.add_tool(
         tools.fetch_webpage_tool.dump_component(),
-        label="Webpage Generation Tool",
-        description="A tool that generates a webpage from a list of images. Requires beautifulsoup4 html2text library to function.",
+        label="Fetch Webpage Tool",
+        description="A tool that fetches the content of a webpage and converts it to markdown. Requires the requests and beautifulsoup4 libraries to function.",
     )
 
     builder.add_tool(
@@ -272,6 +273,83 @@ def create_default_gallery() -> Gallery:
         description="A tool that performs Google searches using the Google Custom Search API. Requires the requests library, [GOOGLE_API_KEY, GOOGLE_CSE_ID] to be set, env variable to function.",
     )
 
+    # Create deep research agent
+    model_client = OpenAIChatCompletionClient(model="gpt-4o", temperature=0.7)
+
+    research_assistant = AssistantAgent(
+        name="research_assistant",
+        description="A research assistant that performs web searches and analyzes information",
+        model_client=model_client,
+        tools=[tools.google_search_tool, tools.fetch_webpage_tool],
+        system_message="""You are a research assistant focused on finding accurate information.
+    Use the google_search tool to find relevant information.
+    Break down complex queries into specific search terms.
+    Always verify information across multiple sources when possible.
+    When you find relevant information, explain why it's relevant and how it connects to the query. When you get feedback from the verifier agent, use your tools to act on the feedback and make progress.""",
+    )
+
+    verifier = AssistantAgent(
+        name="verifier",
+        description="A verification specialist who ensures research quality and completeness",
+        model_client=model_client,
+        system_message="""You are a research verification specialist.
+    Your role is to:
+    1. Verify that search queries are effective and suggest improvements if needed
+    2. Explore drill downs where needed, e.g., if the answer is likely in a link in the returned search results, suggest clicking on the link
+    3. Suggest additional angles or perspectives to explore. Be judicious in suggesting new paths to avoid scope creep or wasting resources; if the task appears to be addressed and we can provide a report, do this and respond with "TERMINATE".
+    4. Track progress toward answering the original question
+    5. When the research is complete, provide a detailed summary in markdown format. For incomplete research, end your message with "CONTINUE RESEARCH". For complete research, end your message with APPROVED.
+    Your responses should be structured as:
+    - Progress Assessment
+    - Gaps/Issues (if any)
+    - Suggestions (if needed)
+    - Next Steps or Final Summary""",
+    )
+
+    summary_agent = AssistantAgent(
+        name="summary_agent",
+        description="A summary agent that provides a detailed markdown summary of the research as a report to the user.",
+        model_client=model_client,
+        system_message="""You are a summary agent. Your role is to provide a detailed markdown summary of the research as a report to the user. Your report should have a reasonable title that matches the research question and should summarize the key details in the results found in a natural and actionable manner. The main results/answer should be in the first paragraph.
+    Your report should end with the word "TERMINATE" to signal the end of the conversation.""",
+    )
+
+    termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(max_messages=30)
+
+    selector_prompt = """You are coordinating a research team by selecting the team member to speak/act next. The following team member roles are available:
+    {roles}.
+    The research_assistant performs searches and analyzes information.
+    The verifier evaluates progress and ensures completeness.
+    The summary_agent provides a detailed markdown summary of the research as a report to the user.
+
+    Given the current context, select the most appropriate next speaker.
+    The research_assistant should search and analyze.
+    The verifier should evaluate progress and guide the research (select this role if there is a need to verify/evaluate progress). You should ONLY select the summary_agent role if the research is complete and it is time to generate a report.
+
+    Base your selection on:
+    1. Current stage of research
+    2. Last speaker's findings or suggestions
+    3. Need for verification vs need for new information
+    Read the following conversation. Then select the next role from {participants} to play. Only return the role.
+
+    {history}
+
+    Read the above conversation. Then select the next role from {participants} to play.
ONLY RETURN THE ROLE.""" + + deep_research_team = SelectorGroupChat( + participants=[research_assistant, verifier, summary_agent], + model_client=model_client, + termination_condition=termination, + selector_prompt=selector_prompt, + allow_repeated_speaker=True, + ) + + builder.add_team( + deep_research_team.dump_component(), + label="Deep Research Team", + description="A team that performs deep research using web searches, verification, and summarization.", + ) + return builder.build() diff --git a/python/packages/autogen-studio/autogenstudio/version.py b/python/packages/autogen-studio/autogenstudio/version.py index 171811d227da..f1a865c952a5 100644 --- a/python/packages/autogen-studio/autogenstudio/version.py +++ b/python/packages/autogen-studio/autogenstudio/version.py @@ -1,3 +1,3 @@ -VERSION = "0.4.0" +VERSION = "0.4.1" __version__ = VERSION APP_NAME = "autogenstudio" diff --git a/python/packages/autogen-studio/frontend/src/components/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/sidebar.tsx index 1aa1ba4abdb4..b01672d81829 100644 --- a/python/packages/autogen-studio/frontend/src/components/sidebar.tsx +++ b/python/packages/autogen-studio/frontend/src/components/sidebar.tsx @@ -220,7 +220,7 @@ const Sidebar = ({ link, meta, isMobile }: SidebarProps) => { ], }) } - className="group flex gap-x-3 rounded-md p-2 text-sm font-medium text-primary hover:text-accent hover:bg-secondary justify-center" + className="group hidden flex gap-x-3 rounded-md p-2 text-sm font-medium text-primary hover:text-accent hover:bg-secondary justify-center" > diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json b/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json index 6572c57048e2..1a8b4e555435 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json @@ -1,10 +1,11 @@ { "id": "gallery_default", "name": "Default Component Gallery", + "url": null, "metadata": { "author": "AutoGen Team", - "created_at": "2025-02-01T17:12:38.179093", - "updated_at": "2025-02-01T17:12:38.701658", + "created_at": "2025-02-06T11:52:58.243877", + "updated_at": "2025-02-06T11:52:58.497032", "version": "1.0.0", "description": "A default gallery containing basic components for human-in-loop conversations", "tags": ["human-in-loop", "assistant", "web agents"], @@ -53,7 +54,7 @@ "description": "Create custom tools by wrapping standard Python functions.", "label": "FunctionTool", "config": { - "source_code": "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == \"+\":\n return str(a + b)\n elif operator == \"-\":\n return str(a - b)\n elif operator == \"*\":\n return str(a * b)\n elif operator == \"/\":\n if b == 0:\n return \"Error: Division by zero\"\n return str(a / b)\n else:\n return \"Error: Invalid operator. Please use +, -, *, or /\"\n except Exception as e:\n return f\"Error: {str(e)}\"\n", + "source_code": "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == \"+\":\n return str(a + b)\n elif operator == \"-\":\n return str(a - b)\n elif operator == \"*\":\n return str(a * b)\n elif operator == \"/\":\n if b == 0:\n return \"Error: Division by zero\"\n return str(a / b)\n else:\n return \"Error: Invalid operator. 
Please use +, -, *, or /\"\n except Exception as e:\n return f\"Error: {str(e)}\"\n", "name": "calculator", "description": "A simple calculator that performs basic arithmetic operations", "global_imports": [], @@ -250,6 +251,249 @@ "selector_prompt": "You are the cordinator of role play game. The following roles are available:\n{roles}. Given a task, the websurfer_agent will be tasked to address it by browsing the web and providing information. The assistant_agent will be tasked with verifying the information provided by the websurfer_agent and summarizing the information to present a final answer to the user. If the task needs assistance from a human user (e.g., providing feedback, preferences, or the task is stalled), you should select the user_proxy role to provide the necessary information.\n\nRead the following conversation. Then select the next role from {participants} to play. Only return the role.\n\n{history}\n\nRead the above conversation. Then select the next role from {participants} to play. Only return the role.", "allow_repeated_speaker": false } + }, + { + "provider": "autogen_agentchat.teams.SelectorGroupChat", + "component_type": "team", + "version": 1, + "component_version": 1, + "description": "A team that performs deep research using web searches, verification, and summarization.", + "label": "Deep Research Team", + "config": { + "participants": [ + { + "provider": "autogen_agentchat.agents.AssistantAgent", + "component_type": "agent", + "version": 1, + "component_version": 1, + "description": "An agent that provides assistance with tool use.", + "label": "AssistantAgent", + "config": { + "name": "research_assistant", + "model_client": { + "provider": "autogen_ext.models.openai.OpenAIChatCompletionClient", + "component_type": "model", + "version": 1, + "component_version": 1, + "description": "Chat completion client for OpenAI hosted models.", + "label": "OpenAIChatCompletionClient", + "config": { + "temperature": 0.7, + "model": "gpt-4o" + } + }, + "tools": [ + { + "provider": "autogen_core.tools.FunctionTool", + "component_type": "tool", + "version": 1, + "component_version": 1, + "description": "Create custom tools by wrapping standard Python functions.", + "label": "FunctionTool", + "config": { + "source_code": "async def google_search(\n query: str,\n num_results: int = 5,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 15000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True,\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API and optionally fetch webpage content.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., en, es, fr)\n country: Optional country code for search results (e.g., us, uk)\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n - content: Webpage content in markdown (if include_content=True)\n \"\"\"\n api_key = os.getenv(\"GOOGLE_API_KEY\")\n cse_id = os.getenv(\"GOOGLE_CSE_ID\")\n\n if not api_key or not cse_id:\n raise ValueError(\"Missing required environment 
variables. Please set GOOGLE_API_KEY and GOOGLE_CSE_ID.\")\n\n num_results = min(max(1, num_results), 10)\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n params = {\n \"key\": api_key,\n \"cx\": cse_id,\n \"q\": query,\n \"num\": num_results,\n \"hl\": language,\n \"safe\": \"active\" if safe_search else \"off\",\n }\n\n if country:\n params[\"gl\"] = country\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\"https://www.googleapis.com/customsearch/v1\", params=params, timeout=10)\n response.raise_for_status()\n data = response.json()\n\n results = []\n if \"items\" in data:\n for item in data[\"items\"]:\n result = {\"title\": item.get(\"title\", \"\"), \"link\": item.get(\"link\", \"\")}\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n results.append(result)\n\n return results\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\") from e\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\") from e\n", + "name": "google_search", + "description": "\n Perform Google searches using the Custom Search API with optional webpage content fetching.\n Requires GOOGLE_API_KEY and GOOGLE_CSE_ID environment variables to be set.\n ", + "global_imports": [ + { + "module": "typing", + "imports": ["List", "Dict", "Optional"] + }, + "os", + "httpx", + "html2text", + { + "module": "bs4", + "imports": ["BeautifulSoup"] + }, + { + "module": "urllib.parse", + "imports": ["urljoin"] + } + ], + "has_cancellation_support": false + } + }, + { + "provider": "autogen_core.tools.FunctionTool", + "component_type": "tool", + "version": 1, + "component_version": 1, + "description": "Create custom tools by wrapping standard Python functions.", + "label": "FunctionTool", + "config": { + "source_code": "async def fetch_webpage(\n url: str, include_images: bool = True, max_length: Optional[int] = None, headers: Optional[Dict[str, str]] = None\n) -> str:\n \"\"\"Fetch a webpage and convert it to markdown format.\n\n Args:\n url: The URL of the webpage to fetch\n include_images: Whether to include image references in the markdown\n max_length: 
Maximum length of the output markdown (if None, no limit)\n headers: Optional HTTP headers for the request\n\n Returns:\n str: Markdown version of the webpage content\n\n Raises:\n ValueError: If the URL is invalid or the page can't be fetched\n \"\"\"\n # Use default headers if none provided\n if headers is None:\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n # Fetch the webpage\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n # Parse HTML\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n # Configure HTML to Markdown converter\n h2t = html2text.HTML2Text()\n h2t.body_width = 0 # No line wrapping\n h2t.ignore_images = not include_images\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n # Convert to markdown\n markdown = h2t.handle(str(soup))\n\n # Trim if max_length is specified\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to fetch webpage: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error processing webpage: {str(e)}\") from e\n", + "name": "fetch_webpage", + "description": "Fetch a webpage and convert it to markdown format, with options for including images and limiting length", + "global_imports": [ + "os", + "html2text", + { + "module": "typing", + "imports": ["Optional", "Dict"] + }, + "httpx", + { + "module": "bs4", + "imports": ["BeautifulSoup"] + }, + { + "module": "html2text", + "imports": ["HTML2Text"] + }, + { + "module": "urllib.parse", + "imports": ["urljoin"] + } + ], + "has_cancellation_support": false + } + } + ], + "handoffs": [], + "model_context": { + "provider": "autogen_core.model_context.UnboundedChatCompletionContext", + "component_type": "chat_completion_context", + "version": 1, + "component_version": 1, + "description": "An unbounded chat completion context that keeps a view of all the messages.", + "label": "UnboundedChatCompletionContext", + "config": {} + }, + "description": "A research assistant that performs web searches and analyzes information", + "system_message": "You are a research assistant focused on finding accurate information.\n Use the google_search tool to find relevant information.\n Break down complex queries into specific search terms.\n Always verify information across multiple sources when possible.\n When you find relevant information, explain why it's relevant and how it connects to the query. 
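The serialized agent above is the `dump_component()` form of a tool-using `AssistantAgent`. A rough Python equivalent (a sketch, assuming the `google_search` and `fetch_webpage` coroutines are defined as in the embedded `source_code` fields):

```python
# Rough sketch of the Python equivalent of the serialized research_assistant.
# Assumes the google_search and fetch_webpage coroutines from the embedded
# source_code fields above are already defined in this module.
from autogen_agentchat.agents import AssistantAgent
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient

model_client = OpenAIChatCompletionClient(model="gpt-4o", temperature=0.7)

research_assistant = AssistantAgent(
    name="research_assistant",
    description="A research assistant that performs web searches and analyzes information",
    model_client=model_client,
    tools=[
        FunctionTool(google_search, description="Google search with optional page content"),
        FunctionTool(fetch_webpage, description="Fetch a webpage and convert it to markdown"),
    ],
    system_message="You are a research assistant focused on finding accurate information.",
)

# research_assistant.dump_component() then yields JSON in the shape shown above.
```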
When you get feedback from the verifier agent, use your tools to act on the feedback and make progress.", + "model_client_stream": false, + "reflect_on_tool_use": false, + "tool_call_summary_format": "{result}" + } + }, + { + "provider": "autogen_agentchat.agents.AssistantAgent", + "component_type": "agent", + "version": 1, + "component_version": 1, + "description": "An agent that provides assistance with tool use.", + "label": "AssistantAgent", + "config": { + "name": "verifier", + "model_client": { + "provider": "autogen_ext.models.openai.OpenAIChatCompletionClient", + "component_type": "model", + "version": 1, + "component_version": 1, + "description": "Chat completion client for OpenAI hosted models.", + "label": "OpenAIChatCompletionClient", + "config": { + "temperature": 0.7, + "model": "gpt-4o" + } + }, + "tools": [], + "handoffs": [], + "model_context": { + "provider": "autogen_core.model_context.UnboundedChatCompletionContext", + "component_type": "chat_completion_context", + "version": 1, + "component_version": 1, + "description": "An unbounded chat completion context that keeps a view of all the messages.", + "label": "UnboundedChatCompletionContext", + "config": {} + }, + "description": "A verification specialist who ensures research quality and completeness", + "system_message": "You are a research verification specialist.\n Your role is to:\n 1. Verify that search queries are effective and suggest improvements if needed\n 2. Explore drill-downs where needed, e.g., if the answer is likely in a link in the returned search results, suggest clicking on the link\n 3. Suggest additional angles or perspectives to explore. Be judicious in suggesting new paths to avoid scope creep or wasting resources. If the task appears to be addressed and we can provide a report, do so and respond with \"TERMINATE\".\n 4. Track progress toward answering the original question\n 5. When the research is complete, provide a detailed summary in markdown format\n \n For incomplete research, end your message with \"CONTINUE RESEARCH\". \n For complete research, end your message with APPROVED.\n \n Your responses should be structured as:\n - Progress Assessment\n - Gaps/Issues (if any)\n - Suggestions (if needed)\n - Next Steps or Final Summary", + "model_client_stream": false, + "reflect_on_tool_use": false, + "tool_call_summary_format": "{result}" + } + }, + { + "provider": "autogen_agentchat.agents.AssistantAgent", + "component_type": "agent", + "version": 1, + "component_version": 1, + "description": "An agent that provides assistance with tool use.", + "label": "AssistantAgent", + "config": { + "name": "summary_agent", + "model_client": { + "provider": "autogen_ext.models.openai.OpenAIChatCompletionClient", + "component_type": "model", + "version": 1, + "component_version": 1, + "description": "Chat completion client for OpenAI hosted models.", + "label": "OpenAIChatCompletionClient", + "config": { + "temperature": 0.7, + "model": "gpt-4o" + } + }, + "tools": [], + "handoffs": [], + "model_context": { + "provider": "autogen_core.model_context.UnboundedChatCompletionContext", + "component_type": "chat_completion_context", + "version": 1, + "component_version": 1, + "description": "An unbounded chat completion context that keeps a view of all the messages.", + "label": "UnboundedChatCompletionContext", + "config": {} + }, + "description": "A summary agent that provides a detailed markdown summary of the research as a report to the user.", + "system_message": "You are a summary agent. 
Your role is to provide a detailed markdown summary of the research as a report to the user. Your report should have a reasonable title that matches the research question and should summarize the key details in the results found in a natural and actionable manner. The main results/answer should be in the first paragraph.\n Your report should end with the word \"TERMINATE\" to signal the end of the conversation.", + "model_client_stream": false, + "reflect_on_tool_use": false, + "tool_call_summary_format": "{result}" + } + } + ], + "model_client": { + "provider": "autogen_ext.models.openai.OpenAIChatCompletionClient", + "component_type": "model", + "version": 1, + "component_version": 1, + "description": "Chat completion client for OpenAI hosted models.", + "label": "OpenAIChatCompletionClient", + "config": { + "temperature": 0.7, + "model": "gpt-4o" + } + }, + "termination_condition": { + "provider": "autogen_agentchat.base.OrTerminationCondition", + "component_type": "termination", + "version": 1, + "component_version": 1, + "label": "OrTerminationCondition", + "config": { + "conditions": [ + { + "provider": "autogen_agentchat.conditions.TextMentionTermination", + "component_type": "termination", + "version": 1, + "component_version": 1, + "description": "Terminate the conversation if a specific text is mentioned.", + "label": "TextMentionTermination", + "config": { + "text": "TERMINATE" + } + }, + { + "provider": "autogen_agentchat.conditions.MaxMessageTermination", + "component_type": "termination", + "version": 1, + "component_version": 1, + "description": "Terminate the conversation after a maximum number of messages have been exchanged.", + "label": "MaxMessageTermination", + "config": { + "max_messages": 30 + } + } + ] + } + }, + "selector_prompt": "You are coordinating a research team by selecting the team member to speak/act next. The following team member roles are available:\n {roles}.\n The research_assistant performs searches and analyzes information.\n The verifier evaluates progress and ensures completeness.\n The summary_agent provides a detailed markdown summary of the research as a report to the user.\n\n Given the current context, select the most appropriate next speaker.\n The research_assistant should search and analyze.\n The verifier should evaluate progress and guide the research (select this role if there is a need to verify/evaluate progress). You should ONLY select the summary_agent role if the research is complete and it is time to generate a report.\n\n Base your selection on:\n 1. Current stage of research\n 2. Last speaker's findings or suggestions\n 3. Need for verification vs need for new information\n \n Read the following conversation. Then select the next role from {participants} to play. Only return the role.\n\n {history}\n\n Read the above conversation. Then select the next role from {participants} to play. ONLY RETURN THE ROLE.", + "allow_repeated_speaker": true + } } ], "components": { @@ -283,7 +527,7 @@ "description": "Create custom tools by wrapping standard Python functions.", "label": "FunctionTool", "config": { - "source_code": "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == \"+\":\n return str(a + b)\n elif operator == \"-\":\n return str(a - b)\n elif operator == \"*\":\n return str(a * b)\n elif operator == \"/\":\n if b == 0:\n return \"Error: Division by zero\"\n return str(a / b)\n else:\n return \"Error: Invalid operator. 
Please use +, -, *, or /\"\n except Exception as e:\n return f\"Error: {str(e)}\"\n", + "source_code": "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == \"+\":\n return str(a + b)\n elif operator == \"-\":\n return str(a - b)\n elif operator == \"*\":\n return str(a * b)\n elif operator == \"/\":\n if b == 0:\n return \"Error: Division by zero\"\n return str(a / b)\n else:\n return \"Error: Invalid operator. Please use +, -, *, or /\"\n except Exception as e:\n return f\"Error: {str(e)}\"\n", "name": "calculator", "description": "A simple calculator that performs basic arithmetic operations", "global_imports": [], @@ -409,6 +653,12 @@ "label": "Mistral-7B vllm", "config": { "model": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF", + "model_info": { + "vision": false, + "function_calling": true, + "json_output": false, + "family": "unknown" + }, "base_url": "http://localhost:1234/v1" } } @@ -419,10 +669,10 @@ "component_type": "tool", "version": 1, "component_version": 1, - "description": "Create custom tools by wrapping standard Python functions.", - "label": "FunctionTool", + "description": "A tool that performs basic arithmetic operations (addition, subtraction, multiplication, division).", + "label": "Calculator Tool", "config": { - "source_code": "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == \"+\":\n return str(a + b)\n elif operator == \"-\":\n return str(a - b)\n elif operator == \"*\":\n return str(a * b)\n elif operator == \"/\":\n if b == 0:\n return \"Error: Division by zero\"\n return str(a / b)\n else:\n return \"Error: Invalid operator. Please use +, -, *, or /\"\n except Exception as e:\n return f\"Error: {str(e)}\"\n", + "source_code": "def calculator(a: float, b: float, operator: str) -> str:\n try:\n if operator == \"+\":\n return str(a + b)\n elif operator == \"-\":\n return str(a - b)\n elif operator == \"*\":\n return str(a * b)\n elif operator == \"/\":\n if b == 0:\n return \"Error: Division by zero\"\n return str(a / b)\n else:\n return \"Error: Invalid operator. Please use +, -, *, or /\"\n except Exception as e:\n return f\"Error: {str(e)}\"\n", "name": "calculator", "description": "A simple calculator that performs basic arithmetic operations", "global_imports": [], @@ -437,10 +687,13 @@ "description": "A tool that generates images based on a text description using OpenAI's DALL-E model. 
Note: Requires OpenAI API key to function.", "label": "Image Generation Tool", "config": { - "source_code": "async def generate_image(\n query: str,\n output_dir: Optional[Path] = None,\n image_size: Literal[\"1024x1024\", \"512x512\", \"256x256\"] = \"1024x1024\"\n) -> List[str]:\n \"\"\"\n Generate images using OpenAI's DALL-E model based on a text description.\n\n Args:\n query: Natural language description of the desired image\n output_dir: Directory to save generated images (default: current directory)\n image_size: Size of generated image (1024x1024, 512x512, or 256x256)\n\n Returns:\n List[str]: Paths to the generated image files\n \"\"\"\n # Initialize the OpenAI client\n client = OpenAI()\n\n # Generate images using DALL-E 3\n response = client.images.generate(\n model=\"dall-e-3\",\n prompt=query,\n n=1,\n response_format=\"b64_json\",\n size=image_size\n )\n\n saved_files = []\n\n # Process the response\n if response.data:\n for image_data in response.data:\n # Generate a unique filename\n file_name = f\"{uuid.uuid4()}.png\"\n\n # Use output_dir if provided, otherwise use current directory\n file_path = Path(output_dir) / file_name if output_dir else Path(file_name)\n\n base64_str = image_data.b64_json \n img = Image.open(io.BytesIO(base64.decodebytes(bytes(base64_str, \"utf-8\")))) \n\n # Save the image to a file \n img.save(file_path) \n\n saved_files.append(str(file_path))\n\n return saved_files\n", + "source_code": "async def generate_image(\n query: str, output_dir: Optional[Path] = None, image_size: Literal[\"1024x1024\", \"512x512\", \"256x256\"] = \"1024x1024\"\n) -> List[str]:\n \"\"\"\n Generate images using OpenAI's DALL-E model based on a text description.\n\n Args:\n query: Natural language description of the desired image\n output_dir: Directory to save generated images (default: current directory)\n image_size: Size of generated image (1024x1024, 512x512, or 256x256)\n\n Returns:\n List[str]: Paths to the generated image files\n \"\"\"\n # Initialize the OpenAI client\n client = OpenAI()\n\n # Generate images using DALL-E 3\n response = client.images.generate(model=\"dall-e-3\", prompt=query, n=1, response_format=\"b64_json\", size=image_size)\n\n saved_files = []\n\n # Process the response\n if response.data:\n for image_data in response.data:\n # Generate a unique filename\n file_name = f\"{uuid.uuid4()}.png\"\n\n # Use output_dir if provided, otherwise use current directory\n file_path = Path(output_dir) / file_name if output_dir else Path(file_name)\n\n base64_str = image_data.b64_json\n img = Image.open(io.BytesIO(base64.decodebytes(bytes(base64_str, \"utf-8\"))))\n\n # Save the image to a file\n img.save(file_path)\n\n saved_files.append(str(file_path))\n\n return saved_files\n", "name": "generate_image", "description": "Generate images using DALL-E based on text descriptions.", "global_imports": [ + "io", + "uuid", + "base64", { "module": "typing", "imports": ["List", "Optional", "Literal"] @@ -449,10 +702,6 @@ "module": "pathlib", "imports": ["Path"] }, - { - "module": "uuid", - "imports": ["uuid4"] - }, { "module": "openai", "imports": ["OpenAI"] @@ -460,14 +709,6 @@ { "module": "PIL", "imports": ["Image"] - }, - { - "module": "io", - "imports": ["BytesIO"] - }, - { - "module": "base64", - "imports": ["decodebytes"] } ], "has_cancellation_support": false @@ -481,10 +722,13 @@ "description": "A tool that generates a PDF file from a list of images.Requires the PyFPDF and pillow library to function.", "label": "PDF Generation Tool", "config": { - 
"source_code": "async def generate_pdf(\n sections: List[Dict[str, Optional[str]]], \n output_file: str = \"report.pdf\",\n report_title: str = \"PDF Report\"\n) -> str:\n \"\"\"\n Generate a PDF report with formatted sections including text and images.\n\n Args:\n sections: List of dictionaries containing section details with keys:\n - title: Section title\n - level: Heading level (title, h1, h2)\n - content: Section text content\n - image: Optional image URL or file path\n output_file: Name of output PDF file\n report_title: Title shown at top of report\n\n Returns:\n str: Path to the generated PDF file\n \"\"\"\n def normalize_text(text: str) -> str:\n \"\"\"Normalize Unicode text to ASCII.\"\"\"\n return unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('ascii')\n\n def get_image(image_url_or_path):\n \"\"\"Fetch image from URL or local path.\"\"\"\n if image_url_or_path.startswith((\"http://\", \"https://\")):\n response = requests.get(image_url_or_path)\n if response.status_code == 200:\n return BytesIO(response.content)\n elif Path(image_url_or_path).is_file():\n return open(image_url_or_path, 'rb')\n return None\n\n def add_rounded_corners(img, radius=6):\n \"\"\"Add rounded corners to an image.\"\"\"\n mask = Image.new('L', img.size, 0)\n draw = ImageDraw.Draw(mask)\n draw.rounded_rectangle([(0, 0), img.size], radius, fill=255)\n img = ImageOps.fit(img, mask.size, centering=(0.5, 0.5))\n img.putalpha(mask)\n return img\n\n class PDF(FPDF):\n \"\"\"Custom PDF class with header and content formatting.\"\"\"\n def header(self):\n self.set_font(\"Arial\", \"B\", 12)\n normalized_title = normalize_text(report_title)\n self.cell(0, 10, normalized_title, 0, 1, \"C\")\n\n def chapter_title(self, txt): \n self.set_font(\"Arial\", \"B\", 12)\n normalized_txt = normalize_text(txt)\n self.cell(0, 10, normalized_txt, 0, 1, \"L\")\n self.ln(2)\n\n def chapter_body(self, body):\n self.set_font(\"Arial\", \"\", 12)\n normalized_body = normalize_text(body)\n self.multi_cell(0, 10, normalized_body)\n self.ln()\n\n def add_image(self, img_data):\n img = Image.open(img_data)\n img = add_rounded_corners(img)\n img_path = Path(f\"temp_{uuid.uuid4().hex}.png\")\n img.save(img_path, format=\"PNG\")\n self.image(str(img_path), x=None, y=None, w=190 if img.width > 190 else img.width)\n self.ln(10)\n img_path.unlink()\n\n # Initialize PDF\n pdf = PDF()\n pdf.add_page()\n font_size = {\"title\": 16, \"h1\": 14, \"h2\": 12, \"body\": 12}\n\n # Add sections\n for section in sections:\n title = section.get(\"title\", \"\")\n level = section.get(\"level\", \"h1\")\n content = section.get(\"content\", \"\")\n image = section.get(\"image\")\n\n pdf.set_font(\"Arial\", \"B\" if level in font_size else \"\", font_size.get(level, font_size[\"body\"]))\n pdf.chapter_title(title)\n\n if content:\n pdf.chapter_body(content)\n\n if image:\n img_data = get_image(image)\n if img_data:\n pdf.add_image(img_data)\n if isinstance(img_data, BytesIO):\n img_data.close()\n\n pdf.output(output_file)\n return output_file\n", + "source_code": "async def generate_pdf(\n sections: List[Dict[str, Optional[str]]], output_file: str = \"report.pdf\", report_title: str = \"PDF Report\"\n) -> str:\n \"\"\"\n Generate a PDF report with formatted sections including text and images.\n\n Args:\n sections: List of dictionaries containing section details with keys:\n - title: Section title\n - level: Heading level (title, h1, h2)\n - content: Section text content\n - image: Optional image URL or file path\n output_file: Name of 
output PDF file\n report_title: Title shown at top of report\n\n Returns:\n str: Path to the generated PDF file\n \"\"\"\n\n def normalize_text(text: str) -> str:\n \"\"\"Normalize Unicode text to ASCII.\"\"\"\n return unicodedata.normalize(\"NFKD\", text).encode(\"ascii\", \"ignore\").decode(\"ascii\")\n\n def get_image(image_url_or_path):\n \"\"\"Fetch image from URL or local path.\"\"\"\n if image_url_or_path.startswith((\"http://\", \"https://\")):\n response = requests.get(image_url_or_path)\n if response.status_code == 200:\n return BytesIO(response.content)\n elif Path(image_url_or_path).is_file():\n return open(image_url_or_path, \"rb\")\n return None\n\n def add_rounded_corners(img, radius=6):\n \"\"\"Add rounded corners to an image.\"\"\"\n mask = Image.new(\"L\", img.size, 0)\n draw = ImageDraw.Draw(mask)\n draw.rounded_rectangle([(0, 0), img.size], radius, fill=255)\n img = ImageOps.fit(img, mask.size, centering=(0.5, 0.5))\n img.putalpha(mask)\n return img\n\n class PDF(FPDF):\n \"\"\"Custom PDF class with header and content formatting.\"\"\"\n\n def header(self):\n self.set_font(\"Arial\", \"B\", 12)\n normalized_title = normalize_text(report_title)\n self.cell(0, 10, normalized_title, 0, 1, \"C\")\n\n def chapter_title(self, txt):\n self.set_font(\"Arial\", \"B\", 12)\n normalized_txt = normalize_text(txt)\n self.cell(0, 10, normalized_txt, 0, 1, \"L\")\n self.ln(2)\n\n def chapter_body(self, body):\n self.set_font(\"Arial\", \"\", 12)\n normalized_body = normalize_text(body)\n self.multi_cell(0, 10, normalized_body)\n self.ln()\n\n def add_image(self, img_data):\n img = Image.open(img_data)\n img = add_rounded_corners(img)\n img_path = Path(f\"temp_{uuid.uuid4().hex}.png\")\n img.save(img_path, format=\"PNG\")\n self.image(str(img_path), x=None, y=None, w=190 if img.width > 190 else img.width)\n self.ln(10)\n img_path.unlink()\n\n # Initialize PDF\n pdf = PDF()\n pdf.add_page()\n font_size = {\"title\": 16, \"h1\": 14, \"h2\": 12, \"body\": 12}\n\n # Add sections\n for section in sections:\n title = section.get(\"title\", \"\")\n level = section.get(\"level\", \"h1\")\n content = section.get(\"content\", \"\")\n image = section.get(\"image\")\n\n pdf.set_font(\"Arial\", \"B\" if level in font_size else \"\", font_size.get(level, font_size[\"body\"]))\n pdf.chapter_title(title)\n\n if content:\n pdf.chapter_body(content)\n\n if image:\n img_data = get_image(image)\n if img_data:\n pdf.add_image(img_data)\n if isinstance(img_data, BytesIO):\n img_data.close()\n\n pdf.output(output_file)\n return output_file\n", "name": "generate_pdf", "description": "Generate PDF reports with formatted sections containing text and images", "global_imports": [ + "uuid", + "requests", + "unicodedata", { "module": "typing", "imports": ["List", "Dict", "Optional"] @@ -493,14 +737,6 @@ "module": "pathlib", "imports": ["Path"] }, - { - "module": "uuid", - "imports": ["uuid4"] - }, - { - "module": "requests", - "imports": ["get"] - }, { "module": "fpdf", "imports": ["FPDF"] @@ -512,10 +748,6 @@ { "module": "io", "imports": ["BytesIO"] - }, - { - "module": "unicodedata", - "imports": ["normalize"] } ], "has_cancellation_support": false @@ -526,21 +758,20 @@ "component_type": "tool", "version": 1, "component_version": 1, - "description": "A tool that generates a webpage from a list of images. Requires beautifulsoup4 html2text library to function.", - "label": "Webpage Generation Tool", + "description": "A tool that fetches the content of a webpage and converts it to markdown. 
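Because the reworked `fetch_webpage` shown below is now a coroutine built on `httpx.AsyncClient`, callers must await it from an event loop; a small usage sketch (URL and length cap are illustrative):

```python
# Usage sketch: the reworked fetch_webpage is async and must be awaited.
# Assumes fetch_webpage is defined as in the source_code field below.
import asyncio


async def demo() -> None:
    markdown = await fetch_webpage(
        "https://example.com",
        include_images=False,  # drop image references from the markdown
        max_length=2000,       # truncate long pages
    )
    print(markdown)


asyncio.run(demo())
```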
Requires the httpx, html2text, and beautifulsoup4 libraries to function.", + "label": "Fetch Webpage Tool", + "config": { - "source_code": "async def fetch_webpage(\n url: str,\n include_images: bool = True,\n max_length: Optional[int] = None,\n headers: Optional[Dict[str, str]] = None\n) -> str:\n \"\"\"\n Fetch a webpage and convert it to markdown format.\n\n Args:\n url: The URL of the webpage to fetch\n include_images: Whether to include image references in the markdown\n max_length: Maximum length of the output markdown (if None, no limit)\n headers: Optional HTTP headers for the request\n\n Returns:\n str: Markdown version of the webpage content\n\n Raises:\n ValueError: If the URL is invalid or the page can't be fetched\n \"\"\"\n # Use default headers if none provided\n if headers is None:\n headers = {\n 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'\n }\n\n try:\n # Fetch the webpage\n response = requests.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n # Parse HTML\n soup = BeautifulSoup(response.text, 'html.parser')\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all(['a', 'img']):\n if tag.get('href'):\n tag['href'] = urljoin(url, tag['href'])\n if tag.get('src'):\n tag['src'] = urljoin(url, tag['src'])\n\n # Configure HTML to Markdown converter\n h2t = html2text.HTML2Text()\n h2t.body_width = 0 # No line wrapping\n h2t.ignore_images = not include_images\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n # Convert to markdown\n markdown = h2t.handle(str(soup))\n\n # Trim if max_length is specified\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except requests.RequestException as e:\n raise ValueError(f\"Failed to fetch webpage: {str(e)}\")\n except Exception as e:\n raise ValueError(f\"Error processing webpage: {str(e)}\")\n", + "source_code": "async def fetch_webpage(\n url: str, include_images: bool = True, max_length: Optional[int] = None, headers: Optional[Dict[str, str]] = None\n) -> str:\n \"\"\"Fetch a webpage and convert it to markdown format.\n\n Args:\n url: The URL of the webpage to fetch\n include_images: Whether to include image references in the markdown\n max_length: Maximum length of the output markdown (if None, no limit)\n headers: Optional HTTP headers for the request\n\n Returns:\n str: Markdown version of the webpage content\n\n Raises:\n ValueError: If the URL is invalid or the page can't be fetched\n \"\"\"\n # Use default headers if none provided\n if headers is None:\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n # Fetch the webpage\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n # Parse HTML\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n # Configure HTML to Markdown converter\n h2t = html2text.HTML2Text()\n h2t.body_width = 0 # No line wrapping\n h2t.ignore_images = not 
include_images\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n # Convert to markdown\n markdown = h2t.handle(str(soup))\n\n # Trim if max_length is specified\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to fetch webpage: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error processing webpage: {str(e)}\") from e\n", "name": "fetch_webpage", "description": "Fetch a webpage and convert it to markdown format, with options for including images and limiting length", "global_imports": [ + "os", + "html2text", { "module": "typing", "imports": ["Optional", "Dict"] }, - { - "module": "requests", - "imports": ["get", "RequestException"] - }, + "httpx", { "module": "bs4", "imports": ["BeautifulSoup"] @@ -565,25 +796,25 @@ "description": "A tool that performs Bing searches using the Bing Web Search API. Requires the requests library, BING_SEARCH_KEY env variable to function.", "label": "Bing Search Tool", "config": { - "source_code": "async def bing_search(\n query: str,\n num_results: int = 5,\n include_snippets: bool = True,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: str = \"moderate\",\n response_filter: str = \"webpages\"\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Bing search using the Bing Web Search API. \n\n Args:\n query: Search query string\n num_results: Number of results to return (max 50)\n include_snippets: Include result snippets in output\n language: Language code for search results (e.g., 'en', 'es', 'fr')\n country: Optional market code for search results (e.g., 'us', 'uk')\n safe_search: SafeSearch setting ('off', 'moderate', or 'strict')\n response_filter: Type of results ('webpages', 'news', 'images', or 'videos')\n\n Returns:\n List[Dict[str, str]]: List of search results\n\n Raises:\n ValueError: If API credentials are invalid or request fails\n \"\"\"\n # Get and validate API key\n api_key = os.getenv('BING_SEARCH_KEY', '').strip()\n\n if not api_key:\n raise ValueError(\n \"BING_SEARCH_KEY environment variable is not set. \"\n \"Please obtain an API key from Azure Portal.\"\n )\n\n # Validate safe_search parameter\n valid_safe_search = ['off', 'moderate', 'strict']\n if safe_search.lower() not in valid_safe_search:\n raise ValueError(\n f\"Invalid safe_search value. Must be one of: {', '.join(valid_safe_search)}\"\n )\n\n # Validate response_filter parameter\n valid_filters = ['webpages', 'news', 'images', 'videos']\n if response_filter.lower() not in valid_filters:\n raise ValueError(\n f\"Invalid response_filter value. Must be one of: {', '.join(valid_filters)}\"\n )\n\n # Build request headers and parameters\n headers = {\n 'Ocp-Apim-Subscription-Key': api_key,\n 'Accept': 'application/json'\n }\n\n params = {\n 'q': query,\n 'count': min(max(1, num_results), 50),\n 'mkt': f'{language}-{country.upper()}' if country else language,\n 'safeSearch': safe_search.capitalize(),\n 'responseFilter': response_filter,\n 'textFormat': 'raw'\n }\n\n # Make the request\n try:\n response = requests.get(\n 'https://api.bing.microsoft.com/v7.0/search',\n headers=headers,\n params=params,\n timeout=10\n )\n\n # Handle common error cases\n if response.status_code == 401:\n raise ValueError(\n \"Authentication failed. Please verify your Bing Search API key.\"\n )\n elif response.status_code == 403:\n raise ValueError(\n \"Access forbidden. 
This could mean:\\n\"\n \"1. The API key is invalid\\n\"\n \"2. The API key has expired\\n\"\n \"3. You've exceeded your API quota\"\n )\n elif response.status_code == 429:\n raise ValueError(\"API quota exceeded. Please try again later.\")\n\n response.raise_for_status()\n data = response.json()\n\n # Process results based on response_filter\n results = []\n if response_filter == 'webpages' and 'webPages' in data:\n items = data['webPages']['value']\n elif response_filter == 'news' and 'news' in data:\n items = data['news']['value']\n elif response_filter == 'images' and 'images' in data:\n items = data['images']['value']\n elif response_filter == 'videos' and 'videos' in data:\n items = data['videos']['value']\n else:\n if not any(key in data for key in ['webPages', 'news', 'images', 'videos']):\n return [] # No results found\n raise ValueError(f\"No {response_filter} results found in API response\")\n\n # Extract relevant information based on result type\n for item in items:\n result = {'title': item.get('name', '')}\n\n if response_filter == 'webpages':\n result['link'] = item.get('url', '')\n if include_snippets:\n result['snippet'] = item.get('snippet', '')\n\n elif response_filter == 'news':\n result['link'] = item.get('url', '')\n if include_snippets:\n result['snippet'] = item.get('description', '')\n result['date'] = item.get('datePublished', '')\n\n elif response_filter == 'images':\n result['link'] = item.get('contentUrl', '')\n result['thumbnail'] = item.get('thumbnailUrl', '')\n if include_snippets:\n result['snippet'] = item.get('description', '')\n\n elif response_filter == 'videos':\n result['link'] = item.get('contentUrl', '')\n result['thumbnail'] = item.get('thumbnailUrl', '')\n if include_snippets:\n result['snippet'] = item.get('description', '')\n result['duration'] = item.get('duration', '')\n\n results.append(result)\n\n return results[:num_results]\n\n except requests.RequestException as e:\n error_msg = str(e)\n if \"InvalidApiKey\" in error_msg:\n raise ValueError(\n \"Invalid API key. Please check your BING_SEARCH_KEY environment variable.\"\n )\n elif \"KeyExpired\" in error_msg:\n raise ValueError(\"API key has expired. Please generate a new key.\")\n else:\n raise ValueError(f\"Search request failed: {error_msg}\")\n except json.JSONDecodeError:\n raise ValueError(\n \"Failed to parse API response. 
\"\n \"Please verify your API credentials and try again.\"\n )\n except Exception as e:\n raise ValueError(f\"Unexpected error during search: {str(e)}\")\n", + "source_code": "async def bing_search(\n query: str,\n num_results: int = 5,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 15000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: str = \"moderate\",\n response_filter: str = \"webpages\",\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Bing search using the Bing Web Search API.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 50)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., 'en', 'es', 'fr')\n country: Optional market code for search results (e.g., 'us', 'uk')\n safe_search: SafeSearch setting ('off', 'moderate', or 'strict')\n response_filter: Type of results ('webpages', 'news', 'images', or 'videos')\n\n Returns:\n List[Dict[str, str]]: List of search results\n\n Raises:\n ValueError: If API credentials are invalid or request fails\n \"\"\"\n # Get and validate API key\n api_key = os.getenv(\"BING_SEARCH_KEY\", \"\").strip()\n\n if not api_key:\n raise ValueError(\n \"BING_SEARCH_KEY environment variable is not set. \" \"Please obtain an API key from Azure Portal.\"\n )\n\n # Validate safe_search parameter\n valid_safe_search = [\"off\", \"moderate\", \"strict\"]\n if safe_search.lower() not in valid_safe_search:\n raise ValueError(f\"Invalid safe_search value. Must be one of: {', '.join(valid_safe_search)}\")\n\n # Validate response_filter parameter\n valid_filters = [\"webpages\", \"news\", \"images\", \"videos\"]\n if response_filter.lower() not in valid_filters:\n raise ValueError(f\"Invalid response_filter value. 
Must be one of: {', '.join(valid_filters)}\")\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n # Build request headers and parameters\n headers = {\"Ocp-Apim-Subscription-Key\": api_key, \"Accept\": \"application/json\"}\n\n params = {\n \"q\": query,\n \"count\": min(max(1, num_results), 50),\n \"mkt\": f\"{language}-{country.upper()}\" if country else language,\n \"safeSearch\": safe_search.capitalize(),\n \"responseFilter\": response_filter,\n \"textFormat\": \"raw\",\n }\n\n # Make the request\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\n \"https://api.bing.microsoft.com/v7.0/search\", headers=headers, params=params, timeout=10\n )\n\n # Handle common error cases\n if response.status_code == 401:\n raise ValueError(\"Authentication failed. Please verify your Bing Search API key.\")\n elif response.status_code == 403:\n raise ValueError(\n \"Access forbidden. This could mean:\\n\"\n \"1. The API key is invalid\\n\"\n \"2. The API key has expired\\n\"\n \"3. You've exceeded your API quota\"\n )\n elif response.status_code == 429:\n raise ValueError(\"API quota exceeded. 
Please try again later.\")\n\n response.raise_for_status()\n data = response.json()\n\n # Process results based on response_filter\n results = []\n if response_filter == \"webpages\" and \"webPages\" in data:\n items = data[\"webPages\"][\"value\"]\n elif response_filter == \"news\" and \"news\" in data:\n items = data[\"news\"][\"value\"]\n elif response_filter == \"images\" and \"images\" in data:\n items = data[\"images\"][\"value\"]\n elif response_filter == \"videos\" and \"videos\" in data:\n items = data[\"videos\"][\"value\"]\n else:\n if not any(key in data for key in [\"webPages\", \"news\", \"images\", \"videos\"]):\n return [] # No results found\n raise ValueError(f\"No {response_filter} results found in API response\")\n\n # Extract relevant information based on result type\n for item in items:\n result = {\"title\": item.get(\"name\", \"\")}\n\n if response_filter == \"webpages\":\n result[\"link\"] = item.get(\"url\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n elif response_filter == \"news\":\n result[\"link\"] = item.get(\"url\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n result[\"date\"] = item.get(\"datePublished\", \"\")\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n elif response_filter == \"images\":\n result[\"link\"] = item.get(\"contentUrl\", \"\")\n result[\"thumbnail\"] = item.get(\"thumbnailUrl\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n\n elif response_filter == \"videos\":\n result[\"link\"] = item.get(\"contentUrl\", \"\")\n result[\"thumbnail\"] = item.get(\"thumbnailUrl\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n result[\"duration\"] = item.get(\"duration\", \"\")\n\n results.append(result)\n\n return results[:num_results]\n\n except httpx.RequestError as e:\n error_msg = str(e)\n if \"InvalidApiKey\" in error_msg:\n raise ValueError(\"Invalid API key. Please check your BING_SEARCH_KEY environment variable.\") from e\n elif \"KeyExpired\" in error_msg:\n raise ValueError(\"API key has expired. Please generate a new key.\") from e\n else:\n raise ValueError(f\"Search request failed: {error_msg}\") from e\n except json.JSONDecodeError:\n raise ValueError(\"Failed to parse API response. \" \"Please verify your API credentials and try again.\") from None\n except Exception as e:\n raise ValueError(f\"Unexpected error during search: {str(e)}\") from e\n", + "name": "bing_search", - "description": "\n Perform Bing searches using the Bing Web Search API. \n Requires BING_SEARCH_KEY environment variable.\n Supports web, news, image, and video searches.\n See function documentation for detailed setup instructions.\n ", + "description": "\n Perform Bing searches using the Bing Web Search API. 
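A small usage sketch for the updated `bing_search` coroutine, assuming it is defined as in the `source_code` above and `BING_SEARCH_KEY` is set; the query and filter values are illustrative:

```python
# Usage sketch: bing_search is async and reads BING_SEARCH_KEY from the environment.
# Assumes bing_search is defined as in the source_code above.
import asyncio


async def demo() -> None:
    results = await bing_search(
        "autogen agentchat",
        num_results=3,
        include_content=False,   # skip full-page fetches for a quicker query
        response_filter="news",  # one of: webpages, news, images, videos
    )
    for item in results:
        print(item["title"], item.get("link", ""))


asyncio.run(demo())
```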
Requires BING_SEARCH_KEY environment variable.\n Supports web, news, image, and video searches.\n See function documentation for detailed setup instructions.\n ", "global_imports": [ { "module": "typing", "imports": ["List", "Dict", "Optional"] }, + "os", + "httpx", + "json", + "html2text", { - "module": "os", - "imports": ["getenv"] - }, - { - "module": "requests", - "imports": ["get", "RequestException"] + "module": "bs4", + "imports": ["BeautifulSoup"] }, { - "module": "json", - "imports": ["JSONDecodeError"] + "module": "urllib.parse", + "imports": ["urljoin"] } ], "has_cancellation_support": false @@ -597,21 +828,24 @@ "description": "A tool that performs Google searches using the Google Custom Search API. Requires the requests library, [GOOGLE_API_KEY, GOOGLE_CSE_ID] to be set, env variable to function.", "label": "Google Search Tool", "config": { - "source_code": "async def google_search(\n query: str,\n num_results: int = 5,\n include_snippets: bool = True,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API. \n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n language: Language code for search results (e.g., 'en', 'es', 'fr')\n country: Optional country code for search results (e.g., 'us', 'uk')\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n\n Raises:\n ValueError: If API keys are missing or invalid, or if the request fails\n \"\"\"\n # Get API credentials\n api_key = os.getenv('GOOGLE_API_KEY')\n cse_id = os.getenv('GOOGLE_CSE_ID')\n\n if not api_key or not cse_id:\n raise ValueError(\n \"Missing required environment variables. Please set GOOGLE_API_KEY and GOOGLE_CSE_ID. 
\"\n \"See function documentation for setup instructions.\"\n )\n\n # Ensure num_results is within allowed range\n num_results = min(max(1, num_results), 10)\n\n # Build request parameters\n params = {\n 'key': api_key,\n 'cx': cse_id,\n 'q': query,\n 'num': num_results,\n 'hl': language,\n 'safe': 'active' if safe_search else 'off',\n }\n\n if country:\n params['gl'] = country\n\n # Make the request\n try:\n response = requests.get(\n 'https://www.googleapis.com/customsearch/v1',\n params=params,\n timeout=10\n )\n response.raise_for_status()\n data = response.json()\n\n # Process results\n results = []\n if 'items' in data:\n for item in data['items']:\n result = {\n 'title': item.get('title', ''),\n 'link': item.get('link', '')\n }\n if include_snippets:\n result['snippet'] = item.get('snippet', '')\n results.append(result)\n\n return results\n\n except requests.RequestException as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\")\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\")\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\")\n", + "source_code": "async def google_search(\n query: str,\n num_results: int = 5,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 15000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True,\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API and optionally fetch webpage content.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., en, es, fr)\n country: Optional country code for search results (e.g., us, uk)\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n - content: Webpage content in markdown (if include_content=True)\n \"\"\"\n api_key = os.getenv(\"GOOGLE_API_KEY\")\n cse_id = os.getenv(\"GOOGLE_CSE_ID\")\n\n if not api_key or not cse_id:\n raise ValueError(\"Missing required environment variables. 
Please set GOOGLE_API_KEY and GOOGLE_CSE_ID.\")\n\n num_results = min(max(1, num_results), 10)\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n params = {\n \"key\": api_key,\n \"cx\": cse_id,\n \"q\": query,\n \"num\": num_results,\n \"hl\": language,\n \"safe\": \"active\" if safe_search else \"off\",\n }\n\n if country:\n params[\"gl\"] = country\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\"https://www.googleapis.com/customsearch/v1\", params=params, timeout=10)\n response.raise_for_status()\n data = response.json()\n\n results = []\n if \"items\" in data:\n for item in data[\"items\"]:\n result = {\"title\": item.get(\"title\", \"\"), \"link\": item.get(\"link\", \"\")}\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n results.append(result)\n\n return results\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\") from e\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\") from e\n", "name": "google_search", - "description": "\n Perform Google searches using the Custom Search API. 
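Likewise for the updated `google_search`, which now optionally pulls full page content into each result; a usage sketch assuming the coroutine is defined as in the `source_code` above and both `GOOGLE_API_KEY` and `GOOGLE_CSE_ID` are set:

```python
# Usage sketch: google_search is async; GOOGLE_API_KEY and GOOGLE_CSE_ID must be set.
# Assumes google_search is defined as in the source_code above.
import asyncio


async def demo() -> None:
    results = await google_search(
        "multi-agent LLM frameworks",
        num_results=5,
        include_content=True,     # fetch each result page as markdown
        content_max_length=5000,  # cap fetched page content
    )
    for item in results:
        print(item["title"], "->", item["link"])


asyncio.run(demo())
```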
\n Requires GOOGLE_API_KEY and GOOGLE_CSE_ID environment variables to be set.\n See function documentation for setup instructions.\n ", + "description": "\n Perform Google searches using the Custom Search API with optional webpage content fetching.\n Requires GOOGLE_API_KEY and GOOGLE_CSE_ID environment variables to be set.\n ", "global_imports": [ { "module": "typing", "imports": ["List", "Dict", "Optional"] }, + "os", + "httpx", + "html2text", { - "module": "os", - "imports": ["getenv"] + "module": "bs4", + "imports": ["BeautifulSoup"] }, { - "module": "requests", - "imports": ["get", "RequestException"] + "module": "urllib.parse", + "imports": ["urljoin"] } ], "has_cancellation_support": false diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx b/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx index 7c8edc0e2947..465554b58a74 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx @@ -151,7 +151,7 @@ export const useGalleryStore = create()( }, }), { - name: "gallery-storage-v2", + name: "gallery-storage-v3", } ) ); diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx index 8f5ea8bcd39a..e5b3bd11d070 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx @@ -148,8 +148,6 @@ export default function ChatView({ session }: ChatViewProps) { setCurrentRun((current) => { if (!current || !session?.id) return null; - console.log("WebSocket message:", message); - switch (message.type) { case "error": if (inputTimeoutRef.current) { @@ -305,7 +303,6 @@ export default function ChatView({ session }: ChatViewProps) { } try { - console.log("Sending input response:", response); activeSocketRef.current.send( JSON.stringify({ type: "input_response", diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx index 005febf5042c..ddf662e7629e 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/chat/runview.tsx @@ -125,12 +125,6 @@ const RunView: React.FC = ({ const lastResultMessage = run.team_result?.task_result.messages.slice(-1)[0]; const lastMessage = run.messages.slice(-1)[0]; - console.log("lastResultMessage", lastResultMessage); - console.log( - "lastMessage", - run.messages[run.messages.length - 1]?.config?.content - ); - return (
{/* Run Header */} diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/manager.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/manager.tsx index 49d72b20ab6a..97be1dd0d045 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/manager.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/manager.tsx @@ -128,7 +128,6 @@ export const SessionManager: React.FC = () => { const handleQuickStart = async (teamId: number, teamName: string) => { if (!user?.email) return; - console.log("Quick start session", teamId, teamName); try { const defaultName = `${teamName.substring( 0, diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/newsession.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/newsession.tsx index f2288773d802..611b6bf57b3b 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/newsession.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/newsession.tsx @@ -71,14 +71,14 @@ const NewSessionControls = ({ // Update state first setSelectedTeamId(newTeamId); - // Save to localStorage - if (typeof window !== "undefined") { - localStorage.setItem("lastUsedTeamId", e.key); - } - - // Delay the session start to allow UI to update - await new Promise((resolve) => setTimeout(resolve, 100)); - onStartSession(newTeamId, selectedTeam.component.label || ""); + // // Save to localStorage + // if (typeof window !== "undefined") { + // localStorage.setItem("lastUsedTeamId", e.key); + // } + + // // Delay the session start to allow UI to update + // await new Promise((resolve) => setTimeout(resolve, 100)); + // onStartSession(newTeamId, selectedTeam.component.label || ""); }; const hasNoTeams = !isLoading && teams.length === 0; diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/views/session/sidebar.tsx index fdc77cc6f981..44530aba1c95 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/sidebar.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/session/sidebar.tsx @@ -103,6 +103,10 @@ export const Sidebar: React.FC = ({
Recents{" "} + + {" "} + ({sessions.length}){" "} + {" "} {isLoading && ( )} @@ -117,7 +121,7 @@ export const Sidebar: React.FC = ({
)} -
+
{sessions.map((s) => (
= ({ // const [isDirty, setIsDirty] = useState(false); const editorRef = useRef(null); const [messageApi, contextHolder] = message.useMessage(); + const [activeDragItem, setActiveDragItem] = useState( + null + ); + + const [testDrawerVisible, setTestDrawerVisible] = useState(false); const { undo, @@ -262,11 +284,23 @@ export const TeamBuilder: React.FC = ({ // Pass both new node data AND target node id addNode(position, draggedItem.config, nodeId); + setActiveDragItem(null); + }; + + const handleTestDrawerClose = () => { + console.log("TestDrawer closed"); + setTestDrawerVisible(false); }; const onDragStart = (item: DragItem) => { // We can add any drag start logic here if needed }; + const handleDragStart = (event: DragStartEvent) => { + const { active } = event; + if (active.data.current) { + setActiveDragItem(active.data.current as DragItemData); + } + }; return (
{contextHolder} @@ -304,6 +338,18 @@ export const TeamBuilder: React.FC = ({
+ + +
); }; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx index 34118174e464..65b277b6bfe1 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx @@ -51,7 +51,7 @@ const PresetItem: React.FC = ({ const style = { transform: CSS.Transform.toString(transform), - opacity: isDragging ? 0.5 : undefined, + opacity: isDragging ? 0.8 : undefined, }; return ( @@ -60,7 +60,7 @@ const PresetItem: React.FC = ({ style={style} {...attributes} {...listeners} - className="p-2 text-primary mb-2 border border-secondary rounded cursor-move hover:bg-secondary transition-colors " + className={`p-2 text-primary mb-2 border rounded cursor-move bg-secondary transition-colors`} >
@@ -178,7 +178,7 @@ export const ComponentLibrary: React.FC = () => { return (
diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/testdrawer.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/testdrawer.tsx new file mode 100644 index 000000000000..0848abd7f5b3 --- /dev/null +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/testdrawer.tsx @@ -0,0 +1,98 @@ +import React, { useContext, useEffect, useState } from "react"; + +import { Drawer, Button, message, Checkbox } from "antd"; +import { Team, Session } from "../../../types/datamodel"; +import ChatView from "../../session/chat/chat"; +import { appContext } from "../../../../hooks/provider"; +import { sessionAPI } from "../../session/api"; + +interface TestDrawerProps { + isVisble: boolean; + team: Team; + onClose: () => void; +} + +const TestDrawer = ({ isVisble, onClose, team }: TestDrawerProps) => { + const [session, setSession] = useState(null); + const { user } = useContext(appContext); + const [loading, setLoading] = useState(false); + const [deleteOnClose, setDeleteOnClose] = useState(true); + const [messageApi, contextHolder] = message.useMessage(); + + const createSession = async (teamId: number, teamName: string) => { + if (!user?.email) return; + try { + const defaultName = `Test Session ${teamName.substring( + 0, + 20 + )} - ${new Date().toLocaleString()} `; + const created = await sessionAPI.createSession( + { + name: defaultName, + team_id: teamId, + }, + user.email + ); + setSession(created); + } catch (error) { + messageApi.error("Error creating session"); + } + }; + + const deleteSession = async (sessionId: number) => { + if (!user?.email) return; + try { + await sessionAPI.deleteSession(sessionId, user.email); + setSession(null); // Clear session state after successful deletion + } catch (error) { + messageApi.error("Error deleting session"); + } + }; + + // Single effect to handle session creation when drawer opens + useEffect(() => { + if (isVisble && team?.id && !session) { + setLoading(true); + createSession( + team.id, + team.component.label || team.component.component_type + ).finally(() => { + setLoading(false); + }); + } + }, [isVisble, team?.id]); + + // Single cleanup handler in the Drawer's onClose + const handleClose = async () => { + if (session?.id && deleteOnClose) { + // Only delete if flag is true + await deleteSession(session.id); + } + onClose(); + }; + + return ( +
+ {contextHolder} + Test Team: {team.component.label}} + size="large" + placement="right" + onClose={handleClose} + open={isVisble} + extra={ + setDeleteOnClose(e.target.checked)} + > + Delete session on close + + } + > + {loading &&

Creating a test session...

} + {session && <ChatView session={session} />} +
+
+ ); +}; +export default TestDrawer; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx index b6aca411e4e1..e60c3dcff3e9 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx @@ -255,7 +255,9 @@ export const TeamSidebar: React.FC = ({ onClick={(e) => { e.stopPropagation(); galleryTeam.label = - galleryTeam.label + "_" + new Date().getTime(); + galleryTeam.label + + "_" + + (new Date().getTime() + "").substring(0, 5); onCreateTeam({ component: galleryTeam, }); diff --git a/python/packages/autogen-studio/pyproject.toml b/python/packages/autogen-studio/pyproject.toml index 63fc9a65794f..f4579354a604 100644 --- a/python/packages/autogen-studio/pyproject.toml +++ b/python/packages/autogen-studio/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "alembic", "loguru", "pyyaml", + "html2text", "autogen-core>=0.4.5,<0.5", "autogen-agentchat>=0.4.5,<0.5", "autogen-ext[magentic-one, openai, azure]>=0.4.2,<0.5", diff --git a/python/uv.lock b/python/uv.lock index ceb0e23ada8b..dbe382896347 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -749,7 +749,7 @@ requires-dist = [ [[package]] name = "autogenstudio" -version = "0.4.0" +version = "0.4.1" source = { editable = "packages/autogen-studio" } dependencies = [ { name = "aiofiles" }, @@ -759,6 +759,7 @@ dependencies = [ { name = "autogen-ext", extra = ["azure", "magentic-one", "openai"] }, { name = "azure-identity" }, { name = "fastapi", extra = ["standard"] }, + { name = "html2text" }, { name = "loguru" }, { name = "numpy" }, { name = "psycopg" }, @@ -790,6 +791,7 @@ requires-dist = [ { name = "azure-identity" }, { name = "fastapi", marker = "extra == 'web'" }, { name = "fastapi", extras = ["standard"] }, + { name = "html2text" }, { name = "loguru" }, { name = "numpy", specifier = "<2.0.0" }, { name = "psycopg" }, From edbd20167b1d8c87807ab04cfbe20add3400ac20 Mon Sep 17 00:00:00 2001 From: Ryan Sweet Date: Fri, 7 Feb 2025 16:28:55 -0800 Subject: [PATCH 12/36] bring back grpc service (#5377) Restoring the grpc + Orleans server into the project ## Why are these changes needed? This is the distributed agent runtime for .NET that can manage routing messages amongst a fleet of grpc agent runtimes. ## Related issue number ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. 
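## Example: attaching a worker to the gateway

A minimal, hypothetical sketch (not part of this PR) of how a worker process might attach to this runtime, inferred from the diffs below: `AgentRpc.AgentRpcClient` is the client generated from the AgentRpc gRPC service, the gateway looks workers up by a `client-id` request header (see GrpcGateway.RegisterAgentTypeAsync below), and the duplex OpenChannel call is assumed to come first so the gateway can map that client id to a live connection. The address and agent type name are placeholders.

using Grpc.Core;
using Grpc.Net.Client;
using Microsoft.AutoGen.Protobuf;

// Hypothetical worker bootstrap; address and agent type name are placeholders.
var channel = GrpcChannel.ForAddress("http://localhost:50051");
var client = new AgentRpc.AgentRpcClient(channel);
var headers = new Metadata { { "client-id", Guid.NewGuid().ToString() } };

// Open the duplex message channel first; the gateway associates client-id -> connection.
using var worker = client.OpenChannel(headers);

// Register an agent type this worker can host; RpcRequests targeting the type
// can then be routed to this connection by the gateway.
await client.RegisterAgentAsync(new RegisterAgentTypeRequest { Type = "MyAgent" }, headers);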
--------- Co-authored-by: Jack Gerrits Co-authored-by: Jacob Alber --- .github/workflows/dotnet-build.yml | 12 + dotnet/AutoGen.sln | 52 +++ .../src/Microsoft.AutoGen/AgentHost/Host.cs | 2 +- .../Microsoft.AutoGen.AgentHost.csproj | 2 +- .../Microsoft.AutoGen/AgentHost/Program.cs | 2 +- .../Agents/AIAgent/InferenceAgent.cs | 44 ++ .../IOAgent/ConsoleAgent/IHandleConsole.cs | 65 +++ .../Agents/IOAgent/FileAgent/IHandleFileIO.cs | 75 ++++ .../Agents/IOAgent/IProcessIO.cs | 23 + .../Agents/Microsoft.AutoGen.Agents.csproj | 24 + .../Agents/protos/agent_events.proto | 43 ++ .../Core.Grpc/GrpcAgentRuntime.cs | 2 + .../Core/AgentRuntimeExtensions.cs | 2 +- .../Core/TypePrefixSubscriptionAttribute.cs | 2 +- .../Runtime.Grpc/Abstractions/IAgentGrain.cs | 10 - .../Services/Grpc/GrpcGatewayService.cs | 61 --- .../Services/Orleans/AgentStateGrain.cs | 46 -- .../Services/Orleans/ISubscriptionsGrain.cs | 10 - .../Services/Orleans/SubscriptionsGrain.cs | 50 --- .../Orleans/Surrogates/AgentStateSurrogate.cs | 55 --- .../Abstractions/AgentsRegistryState.cs | 15 + .../Abstractions/IConnection.cs | 7 + .../Abstractions/IGateway.cs | 11 +- .../Abstractions/IGatewayRegistry.cs | 7 +- .../Abstractions/IRegistry.cs | 38 ++ .../Abstractions/IRegistryGrain.cs | 4 +- ...rosoft.AutoGen.RuntimeGateway.Grpc.csproj} | 6 + .../Services/AgentWorkerHostingExtensions.cs | 7 +- .../Services/Grpc/GrpcGateway.cs | 414 +++++++++--------- .../Services/Grpc/GrpcGatewayService.cs | 110 +++++ .../Services/Grpc/GrpcWorkerConnection.cs | 11 +- .../Orleans/OrleansRuntimeHostingExtenions.cs | 2 +- .../Services/Orleans/RegistryGrain.cs | 103 ++--- .../AddSubscriptionRequestSurrogate.cs | 5 +- .../AddSubscriptionResponseSurrogate.cs | 18 +- .../Orleans/Surrogates/AgentIdSurrogate.cs | 4 +- .../Orleans/Surrogates/CloudEventSurrogate.cs | 4 +- .../Surrogates/GetSubscriptionsRequest.cs | 4 +- .../RegisterAgentTypeRequestSurrogate.cs | 5 +- .../RegisterAgentTypeResponseSurrogate.cs | 18 +- .../Surrogates/RemoveSubscriptionRequest.cs | 3 +- .../Surrogates/RemoveSubscriptionResponse.cs | 15 +- .../Orleans/Surrogates/RpcRequestSurrogate.cs | 4 +- .../Surrogates/RpcResponseSurrogate.cs | 4 +- .../Surrogates/SubscriptionSurrogate.cs | 5 +- .../TypePrefixSubscriptionSurrogate.cs | 4 +- .../Surrogates/TypeSubscriptionSurrogate.cs | 5 +- .../HelloAgent.AppHost.csproj | 2 +- ...Microsoft.AutoGen.Integration.Tests.csproj | 15 +- .../GrpcGatewayServiceTests.cs | 187 -------- .../TestAgent.cs | 46 -- .../GrpcGatewayServiceTests.cs | 241 ++++++++++ .../Helpers/AgentTypes.cs | 18 + .../Helpers/AgentsMetadata.cs | 104 +++++ .../Helpers/Grpc/TestAsyncStreamReader.cs | 2 +- .../Helpers/Grpc/TestGrpcClient.cs | 9 +- .../Helpers/Grpc/TestServerCallContext.cs | 3 +- .../Helpers/Grpc/TestServerStreamWriter.cs | 2 +- .../Helpers/Orleans/ClusterCollection.cs | 2 +- .../Helpers/Orleans/ClusterFixture.cs | 2 +- .../Orleans/SiloBuilderConfigurator.cs | 2 +- .../Helpers/ReflectionHelper.cs | 70 +++ ....AutoGen.RuntimeGateway.Grpc.Tests.csproj} | 6 +- .../TestAgent.cs | 45 ++ python/pyproject.toml | 2 +- .../protos/agent_events_pb2.py | 4 +- python/uv.lock | 10 - 67 files changed, 1343 insertions(+), 844 deletions(-) create mode 100644 dotnet/src/Microsoft.AutoGen/Agents/AIAgent/InferenceAgent.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Agents/IOAgent/ConsoleAgent/IHandleConsole.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Agents/IOAgent/FileAgent/IHandleFileIO.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Agents/IOAgent/IProcessIO.cs create mode 
100644 dotnet/src/Microsoft.AutoGen/Agents/Microsoft.AutoGen.Agents.csproj create mode 100644 dotnet/src/Microsoft.AutoGen/Agents/protos/agent_events.proto delete mode 100644 dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IAgentGrain.cs delete mode 100644 dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs delete mode 100644 dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/AgentStateGrain.cs delete mode 100644 dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/ISubscriptionsGrain.cs delete mode 100644 dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/SubscriptionsGrain.cs delete mode 100644 dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentStateSurrogate.cs create mode 100644 dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/AgentsRegistryState.cs create mode 100644 dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IConnection.cs rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Abstractions/IGateway.cs (58%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Abstractions/IGatewayRegistry.cs (93%) create mode 100644 dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IRegistry.cs rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Abstractions/IRegistryGrain.cs (63%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc/Microsoft.AutoGen.Runtime.Grpc.csproj => RuntimeGateway.Grpc/Microsoft.AutoGen.RuntimeGateway.Grpc.csproj} (85%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/AgentWorkerHostingExtensions.cs (79%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Grpc/GrpcGateway.cs (58%) create mode 100644 dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcGatewayService.cs rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Grpc/GrpcWorkerConnection.cs (91%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/OrleansRuntimeHostingExtenions.cs (98%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/RegistryGrain.cs (75%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs (86%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs (61%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/AgentIdSurrogate.cs (88%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/CloudEventSurrogate.cs (91%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs (88%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs (89%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs (62%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs (89%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs (67%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => 
RuntimeGateway.Grpc}/Services/Orleans/Surrogates/RpcRequestSurrogate.cs (92%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/RpcResponseSurrogate.cs (91%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/SubscriptionSurrogate.cs (92%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs (90%) rename dotnet/src/Microsoft.AutoGen/{Runtime.Grpc => RuntimeGateway.Grpc}/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs (89%) delete mode 100644 dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs delete mode 100644 dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/TestAgent.cs create mode 100644 dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/GrpcGatewayServiceTests.cs create mode 100644 dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/AgentTypes.cs create mode 100644 dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/AgentsMetadata.cs rename dotnet/test/{Microsoft.AutoGen.Runtime.Grpc.Tests => Microsoft.AutoGen.RuntimeGateway.Grpc.Tests}/Helpers/Grpc/TestAsyncStreamReader.cs (96%) rename dotnet/test/{Microsoft.AutoGen.Runtime.Grpc.Tests => Microsoft.AutoGen.RuntimeGateway.Grpc.Tests}/Helpers/Grpc/TestGrpcClient.cs (91%) rename dotnet/test/{Microsoft.AutoGen.Runtime.Grpc.Tests => Microsoft.AutoGen.RuntimeGateway.Grpc.Tests}/Helpers/Grpc/TestServerCallContext.cs (94%) rename dotnet/test/{Microsoft.AutoGen.Runtime.Grpc.Tests => Microsoft.AutoGen.RuntimeGateway.Grpc.Tests}/Helpers/Grpc/TestServerStreamWriter.cs (97%) rename dotnet/test/{Microsoft.AutoGen.Runtime.Grpc.Tests => Microsoft.AutoGen.RuntimeGateway.Grpc.Tests}/Helpers/Orleans/ClusterCollection.cs (78%) rename dotnet/test/{Microsoft.AutoGen.Runtime.Grpc.Tests => Microsoft.AutoGen.RuntimeGateway.Grpc.Tests}/Helpers/Orleans/ClusterFixture.cs (87%) rename dotnet/test/{Microsoft.AutoGen.Runtime.Grpc.Tests => Microsoft.AutoGen.RuntimeGateway.Grpc.Tests}/Helpers/Orleans/SiloBuilderConfigurator.cs (89%) create mode 100644 dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/ReflectionHelper.cs rename dotnet/test/{Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj => Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.csproj} (74%) create mode 100644 dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/TestAgent.cs diff --git a/.github/workflows/dotnet-build.yml b/.github/workflows/dotnet-build.yml index 8539502b9e33..ca55a6f142c7 100644 --- a/.github/workflows/dotnet-build.yml +++ b/.github/workflows/dotnet-build.yml @@ -65,6 +65,18 @@ jobs: - uses: actions/checkout@v4 with: lfs: true + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + version: "0.5.18" + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - run: uv sync --locked --all-extras + working-directory: ./python + - name: Prepare python venv + run: | + source ${{ github.workspace }}/python/.venv/bin/activate - name: Setup .NET 8.0 uses: actions/setup-dotnet@v4 with: diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index a267a11e6fa9..cb590f21dbd0 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -122,6 +122,22 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GettingStartedGrpc", "sampl EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc.Tests", 
"test\Microsoft.AutoGen.Core.Grpc.Tests\Microsoft.AutoGen.Core.Grpc.Tests.csproj", "{23A028D3-5EB1-4FA0-9CD1-A1340B830579}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.RuntimeGateway.Grpc", "src\Microsoft.AutoGen\RuntimeGateway.Grpc\Microsoft.AutoGen.RuntimeGateway.Grpc.csproj", "{BE420A71-7615-4DFD-BE94-9409397949F1}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.RuntimeGateway.Grpc.Tests", "test\Microsoft.AutoGen.RuntimeGateway.Grpc.Tests\Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.csproj", "{CDD859F3-1B60-4ECE-8472-54DF8EFCA682}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Integration.Tests", "test\Microsoft.AutoGen.Integration.Tests\Microsoft.AutoGen.Integration.Tests.csproj", "{7A11022E-4E5D-4A4A-AADF-E715C2ECF800}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.AgentHost", "src\Microsoft.AutoGen\AgentHost\Microsoft.AutoGen.AgentHost.csproj", "{50C2E8D5-68AB-45A3-B96F-355E1F8AC039}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Hello.AppHost", "samples\Hello\Hello.AppHost\Hello.AppHost.csproj", "{B8E77E57-C983-4EEA-9589-906271486D80}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.AutoGen", "Microsoft.AutoGen", "{81BA12F2-2D2F-42C1-AF83-FBDAA1A78A45}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Agents", "src\Microsoft.AutoGen\Agents\Microsoft.AutoGen.Agents.csproj", "{EF954ED3-87D5-40F1-8557-E7179F43EA0E}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgent.AppHost", "test\Microsoft.AutoGen.Integration.Tests.AppHosts\HelloAgent.AppHost\HelloAgent.AppHost.csproj", "{50082F76-917F-42EE-8869-8C72630423A7}" + Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.AgentChat", "src\Microsoft.AutoGen\AgentChat\Microsoft.AutoGen.AgentChat.csproj", "{7F828599-56E8-4597-8F68-EE26FD631417}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.AgentChat.Tests", "test\Microsoft.AutoGen.AgentChat.Tests\Microsoft.AutoGen.AgentChat.Tests.csproj", "{217A4F86-8ADD-4998-90BA-880092A019F5}" @@ -320,6 +336,34 @@ Global {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Debug|Any CPU.Build.0 = Debug|Any CPU {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.ActiveCfg = Release|Any CPU {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.Build.0 = Release|Any CPU + {BE420A71-7615-4DFD-BE94-9409397949F1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BE420A71-7615-4DFD-BE94-9409397949F1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BE420A71-7615-4DFD-BE94-9409397949F1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BE420A71-7615-4DFD-BE94-9409397949F1}.Release|Any CPU.Build.0 = Release|Any CPU + {CDD859F3-1B60-4ECE-8472-54DF8EFCA682}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {CDD859F3-1B60-4ECE-8472-54DF8EFCA682}.Debug|Any CPU.Build.0 = Debug|Any CPU + {CDD859F3-1B60-4ECE-8472-54DF8EFCA682}.Release|Any CPU.ActiveCfg = Release|Any CPU + {CDD859F3-1B60-4ECE-8472-54DF8EFCA682}.Release|Any CPU.Build.0 = Release|Any CPU + {7A11022E-4E5D-4A4A-AADF-E715C2ECF800}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7A11022E-4E5D-4A4A-AADF-E715C2ECF800}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7A11022E-4E5D-4A4A-AADF-E715C2ECF800}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7A11022E-4E5D-4A4A-AADF-E715C2ECF800}.Release|Any CPU.Build.0 = Release|Any CPU + {50C2E8D5-68AB-45A3-B96F-355E1F8AC039}.Debug|Any CPU.ActiveCfg = Debug|Any 
CPU + {50C2E8D5-68AB-45A3-B96F-355E1F8AC039}.Debug|Any CPU.Build.0 = Debug|Any CPU + {50C2E8D5-68AB-45A3-B96F-355E1F8AC039}.Release|Any CPU.ActiveCfg = Release|Any CPU + {50C2E8D5-68AB-45A3-B96F-355E1F8AC039}.Release|Any CPU.Build.0 = Release|Any CPU + {B8E77E57-C983-4EEA-9589-906271486D80}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B8E77E57-C983-4EEA-9589-906271486D80}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B8E77E57-C983-4EEA-9589-906271486D80}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B8E77E57-C983-4EEA-9589-906271486D80}.Release|Any CPU.Build.0 = Release|Any CPU + {EF954ED3-87D5-40F1-8557-E7179F43EA0E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EF954ED3-87D5-40F1-8557-E7179F43EA0E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EF954ED3-87D5-40F1-8557-E7179F43EA0E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EF954ED3-87D5-40F1-8557-E7179F43EA0E}.Release|Any CPU.Build.0 = Release|Any CPU + {50082F76-917F-42EE-8869-8C72630423A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {50082F76-917F-42EE-8869-8C72630423A7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {50082F76-917F-42EE-8869-8C72630423A7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {50082F76-917F-42EE-8869-8C72630423A7}.Release|Any CPU.Build.0 = Release|Any CPU {7F828599-56E8-4597-8F68-EE26FD631417}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {7F828599-56E8-4597-8F68-EE26FD631417}.Debug|Any CPU.Build.0 = Debug|Any CPU {7F828599-56E8-4597-8F68-EE26FD631417}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -383,6 +427,14 @@ Global {3D83C6DB-ACEA-48F3-959F-145CCD2EE135} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {C3740DF1-18B1-4607-81E4-302F0308C848} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6} {23A028D3-5EB1-4FA0-9CD1-A1340B830579} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {BE420A71-7615-4DFD-BE94-9409397949F1} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {CDD859F3-1B60-4ECE-8472-54DF8EFCA682} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {7A11022E-4E5D-4A4A-AADF-E715C2ECF800} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {50C2E8D5-68AB-45A3-B96F-355E1F8AC039} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {B8E77E57-C983-4EEA-9589-906271486D80} = {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} + {81BA12F2-2D2F-42C1-AF83-FBDAA1A78A45} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {EF954ED3-87D5-40F1-8557-E7179F43EA0E} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {50082F76-917F-42EE-8869-8C72630423A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {7F828599-56E8-4597-8F68-EE26FD631417} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {217A4F86-8ADD-4998-90BA-880092A019F5} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection diff --git a/dotnet/src/Microsoft.AutoGen/AgentHost/Host.cs b/dotnet/src/Microsoft.AutoGen/AgentHost/Host.cs index 1ecf42c79589..0176b3faa3e3 100644 --- a/dotnet/src/Microsoft.AutoGen/AgentHost/Host.cs +++ b/dotnet/src/Microsoft.AutoGen/AgentHost/Host.cs @@ -4,7 +4,7 @@ using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.Hosting; -namespace Microsoft.AutoGen.Runtime.Grpc; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc; public static class Host { diff --git a/dotnet/src/Microsoft.AutoGen/AgentHost/Microsoft.AutoGen.AgentHost.csproj b/dotnet/src/Microsoft.AutoGen/AgentHost/Microsoft.AutoGen.AgentHost.csproj index 33b051ad917b..dea603ebc668 100644 --- a/dotnet/src/Microsoft.AutoGen/AgentHost/Microsoft.AutoGen.AgentHost.csproj +++ b/dotnet/src/Microsoft.AutoGen/AgentHost/Microsoft.AutoGen.AgentHost.csproj @@ -15,7 +15,7 @@ - + \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/AgentHost/Program.cs 
b/dotnet/src/Microsoft.AutoGen/AgentHost/Program.cs
index 024ca0d4309f..3e32c50a3a0f 100644
--- a/dotnet/src/Microsoft.AutoGen/AgentHost/Program.cs
+++ b/dotnet/src/Microsoft.AutoGen/AgentHost/Program.cs
@@ -2,5 +2,5 @@
 // Program.cs
 using Microsoft.Extensions.Hosting;
-var app = await Microsoft.AutoGen.Runtime.Grpc.Host.StartAsync(local: false, useGrpc: true).ConfigureAwait(false);
+var app = await Microsoft.AutoGen.RuntimeGateway.Grpc.Host.StartAsync(local: false, useGrpc: true).ConfigureAwait(false);
 await app.WaitForShutdownAsync();
diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AIAgent/InferenceAgent.cs b/dotnet/src/Microsoft.AutoGen/Agents/AIAgent/InferenceAgent.cs
new file mode 100644
index 000000000000..d3dc100012eb
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Agents/AIAgent/InferenceAgent.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// InferenceAgent.cs
+using Google.Protobuf;
+using Microsoft.AutoGen.Contracts;
+using Microsoft.AutoGen.Core;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.Logging;
+namespace Microsoft.AutoGen.Agents;
+/// <summary>
+/// Base class for inference agents using the Microsoft.Extensions.AI library.
+/// </summary>
+/// <typeparam name="T"></typeparam>
+/// <param name="id"></param>
+/// <param name="runtime"></param>
+/// <param name="name"></param>
+/// <param name="logger"></param>
+/// <param name="client"></param>
+public abstract class InferenceAgent<T>(
+    AgentId id,
+    IAgentRuntime runtime,
+    string name,
+    ILogger<InferenceAgent<T>>? logger,
+    IChatClient client)
+    : BaseAgent(id, runtime, name, logger)
+    where T : IMessage, new()
+{
+    protected IChatClient ChatClient { get; } = client;
+    private ILogger<InferenceAgent<T>>? Logger => _logger as ILogger<InferenceAgent<T>>;
+    private Task<ChatCompletion> CompleteAsync(
+        IList<ChatMessage> chatMessages,
+        ChatOptions? options = null,
+        CancellationToken cancellationToken = default)
+    {
+        return ChatClient.CompleteAsync(chatMessages, options, cancellationToken);
+    }
+    private IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+        IList<ChatMessage> chatMessages,
+        ChatOptions? options = null,
+        CancellationToken cancellationToken = default)
+    {
+        return ChatClient.CompleteStreamingAsync(chatMessages, options, cancellationToken);
+    }
+
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/ConsoleAgent/IHandleConsole.cs b/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/ConsoleAgent/IHandleConsole.cs
new file mode 100644
index 000000000000..651be87314e0
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/ConsoleAgent/IHandleConsole.cs
@@ -0,0 +1,65 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// IHandleConsole.cs
+using Google.Protobuf;
+using Microsoft.AutoGen.Contracts;
+
+namespace Microsoft.AutoGen.Agents;
+/// <summary>
+/// Default interface methods for an event handler for Input and Output that writes or reads from the console
+/// Can be used inside your agents by inheriting from this interface
+/// public class MyAgent : BaseAgent, IHandleConsole
+/// </summary>
+public interface IHandleConsole : IHandle<Input>, IHandle<Output>, IProcessIO
+{
+    /// <summary>
+    /// Prototype for Publish Message Async method
+    /// </summary>
+    /// <typeparam name="T"></typeparam>
+    /// <param name="message"></param>
+    /// <param name="topic"></param>
+    /// <param name="messageId"></param>
+    /// <param name="token"></param>
+    /// <returns>ValueTask</returns>
+    ValueTask PublishMessageAsync<T>(T message, TopicId topic, string? messageId, CancellationToken token = default) where T : IMessage;
+
+    /// <summary>
+    /// Receives events of type Output and writes them to the console
+    /// then runs the ProcessOutputAsync method which you should implement in your agent
+    /// </summary>
+    /// <param name="item"></param>
+    /// <param name="messageContext"></param>
+    /// <returns>ValueTask</returns>
+    async ValueTask IHandle<Output>.HandleAsync(Output item, MessageContext messageContext)
+    {
+        // Assuming item has a property `Message` that we want to write to the console
+        Console.WriteLine(item.Message);
+        await ProcessOutputAsync(item.Message);
+
+        var evt = new OutputWritten
+        {
+            Route = "console"
+        };
+        await PublishMessageAsync(evt, new TopicId("OutputWritten"), null, token: CancellationToken.None).ConfigureAwait(false);
+    }
+
+    /// <summary>
+    /// Receives events of type Input and reads from the console, then runs the ProcessInputAsync method
+    /// which you should implement in your agent
+    /// </summary>
+    /// <param name="item"></param>
+    /// <param name="messageContext"></param>
+    /// <returns></returns>
+    async ValueTask IHandle<Input>.HandleAsync(Input item, MessageContext messageContext)
+    {
+        Console.WriteLine("Please enter input:");
+        string content = Console.ReadLine() ?? string.Empty;
+
+        await ProcessInputAsync(content);
+
+        var evt = new InputProcessed
+        {
+            Route = "console"
+        };
+        await PublishMessageAsync(evt, new TopicId("InputProcessed"), null, token: CancellationToken.None).ConfigureAwait(false);
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/FileAgent/IHandleFileIO.cs b/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/FileAgent/IHandleFileIO.cs
new file mode 100644
index 000000000000..b3d670004e19
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/FileAgent/IHandleFileIO.cs
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// IHandleFileIO.cs
+
+using Google.Protobuf;
+using Microsoft.AutoGen.Contracts;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.AutoGen.Agents;
+/// <summary>
+/// Default interface methods for an event handler for Input and Output that writes or reads from a file
+/// Can be used inside your agents by inheriting from this interface
+/// public class MyAgent : BaseAgent, IHandleFileIO
+/// </summary>
+public interface IHandleFileIO : IHandle<Input>, IHandle<Output>, IProcessIO
+{
+    // A Logger instance to log messages
+    ILogger LogTarget { get; }
+    // The path to the input file
+    string InputPath { get; }
+    // The path to the output file
+    string OutputPath { get; }
+    // The route of the agent (used in the post-process events)
+    const string Route = "Microsoft.AutoGen.Agents.IHandleFileIO";
+
+    /// <summary>
+    /// Prototype for Publish Message Async method
+    /// </summary>
+    /// <typeparam name="T"></typeparam>
+    /// <param name="message"></param>
+    /// <param name="topic"></param>
+    /// <param name="messageId"></param>
+    /// <param name="token"></param>
+    /// <returns>ValueTask</returns>
+    ValueTask PublishMessageAsync<T>(T message, TopicId topic, string? messageId, CancellationToken token = default) where T : IMessage;
+    async ValueTask IHandle<Input>.HandleAsync(Input item, MessageContext messageContext)
+    {
+
+        // validate that the file exists
+        if (!File.Exists(InputPath))
+        {
+            var errorMessage = $"File not found: {InputPath}";
+            LogTarget.LogError(errorMessage);
+            //publish IOError event
+            var err = new IOError
+            {
+                Message = errorMessage
+            };
+            await PublishMessageAsync(err, new TopicId("IOError"), null, token: CancellationToken.None).ConfigureAwait(false);
+            return;
+        }
+        string content;
+        using (var reader = new StreamReader(item.Message))
+        {
+            content = await reader.ReadToEndAsync(CancellationToken.None);
+        }
+        await ProcessInputAsync(content);
+        var evt = new InputProcessed
+        {
+            Route = Route
+        };
+        await PublishMessageAsync(evt, new TopicId("InputProcessed"), null, token: CancellationToken.None).ConfigureAwait(false);
+    }
+    async ValueTask IHandle<Output>.HandleAsync(Output item, MessageContext messageContext)
+    {
+        using (var writer = new StreamWriter(OutputPath, append: true))
+        {
+            await writer.WriteLineAsync(item.Message);
+        }
+        var evt = new OutputWritten
+        {
+            Route = Route
+        };
+        await PublishMessageAsync(evt, new TopicId("OutputWritten"), null, token: CancellationToken.None).ConfigureAwait(false);
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/IProcessIO.cs b/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/IProcessIO.cs
new file mode 100644
index 000000000000..e348f3e1ca71
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Agents/IOAgent/IProcessIO.cs
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// IProcessIO.cs
+
+namespace Microsoft.AutoGen.Agents;
+
+/// <summary>
+/// Default Interface methods for processing input and output shared by IOAgents that should be implemented in your agent
+/// </summary>
+public interface IProcessIO
+{
+    /// <summary>
+    /// Implement this method in your agent to process the output
+    /// </summary>
+    /// <param name="message"></param>
+    /// <returns>Task</returns>
+    static Task ProcessOutputAsync(string message) { return Task.CompletedTask; }
+    /// <summary>
+    /// Implement this method in your agent to process the input
+    /// </summary>
+    /// <param name="message"></param>
+    /// <returns>Task</returns>
+    static Task<string> ProcessInputAsync(string message) { return Task.FromResult(message); }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Microsoft.AutoGen.Agents.csproj b/dotnet/src/Microsoft.AutoGen/Agents/Microsoft.AutoGen.Agents.csproj
new file mode 100644
index 000000000000..5032e95a12a0
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Agents/Microsoft.AutoGen.Agents.csproj
@@ -0,0 +1,24 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>net8.0</TargetFramework>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>enable</Nullable>
+  </PropertyGroup>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+</Project>
diff --git a/dotnet/src/Microsoft.AutoGen/Agents/protos/agent_events.proto b/dotnet/src/Microsoft.AutoGen/Agents/protos/agent_events.proto
new file mode 100644
index 000000000000..414d79f9678c
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Agents/protos/agent_events.proto
@@ -0,0 +1,43 @@
+syntax = "proto3";
+
+package agents;
+
+option csharp_namespace = "Microsoft.AutoGen.Agents";
+message TextMessage {
+    string textMessage = 1;
+    string source = 2;
+}
+message Input {
+    string message = 1;
+}
+message InputProcessed {
+    string route = 1;
+}
+message Output {
+    string message = 1;
+}
+message OutputWritten {
+    string route = 1;
+}
+message IOError {
+    string message = 1;
+}
+message NewMessageReceived {
+    string message = 1;
+}
+message ResponseGenerated {
+    string response = 1;
+}
+message GoodBye {
+    string message = 1;
+}
+message MessageStored {
+    string message = 1;
+}
+message ConversationClosed {
+    string user_id = 1;
+    string user_message = 2;
+}
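As a usage illustration for the IO interfaces above: a console-driven agent only inherits IHandleConsole and supplies the IProcessIO hooks. This is a non-authoritative sketch; it assumes the BaseAgent constructor shape shown in InferenceAgent.cs, and assumes BaseAgent exposes a PublishMessageAsync that satisfies the interface's PublishMessageAsync member.

using System.Threading.Tasks;
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Core;
using Microsoft.Extensions.Logging;

namespace Microsoft.AutoGen.Agents.Samples;

// Illustrative only: IHandleConsole contributes HandleAsync(Input)/HandleAsync(Output)
// as default interface methods. Assumes BaseAgent already provides a matching
// PublishMessageAsync<T> to satisfy the interface member.
public class MyConsoleAgent(AgentId id, IAgentRuntime runtime, ILogger<MyConsoleAgent>? logger)
    : BaseAgent(id, runtime, "A console echo agent", logger), IHandleConsole
{
    // The IProcessIO docs above ask agents to implement these hooks;
    // shown here as trivial examples.
    public static Task<string> ProcessInputAsync(string message)
        => Task.FromResult(message.ToUpperInvariant());

    public static Task ProcessOutputAsync(string message) => Task.CompletedTask;
}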
+message Shutdown { + string message = 1; +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index f3dc619fdbef..324874a0005b 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -155,6 +155,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc var messageContext = new MessageContext(request.RequestId, cancellationToken) { + Sender = request.Source?.FromProtobuf() ?? null, Topic = null, IsRpc = true @@ -275,6 +276,7 @@ public Task StopAsync(CancellationToken cancellationToken) var request = new RpcRequest { RequestId = Guid.NewGuid().ToString(), + Source = sender?.ToProtobuf() ?? null, Target = recepient.ToProtobuf(), Payload = payload, diff --git a/dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeExtensions.cs index fd7a479adde3..dfda552956ba 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeExtensions.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeExtensions.cs @@ -62,7 +62,7 @@ private static ISubscriptionDefinition[] BindSubscriptionsForAgentType(AgentType var classSubscriptions = runtimeType.GetCustomAttributes().Select(t => t.Bind(agentType)); subscriptions.AddRange(classSubscriptions); - var prefixSubscriptions = runtimeType.GetCustomAttributes().Select(t => t.Bind(agentType)); + var prefixSubscriptions = runtimeType.GetCustomAttributes().Select(t => t.Bind(agentType)); subscriptions.AddRange(prefixSubscriptions); } diff --git a/dotnet/src/Microsoft.AutoGen/Core/TypePrefixSubscriptionAttribute.cs b/dotnet/src/Microsoft.AutoGen/Core/TypePrefixSubscriptionAttribute.cs index 57105c8b9629..be48ab8b195f 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/TypePrefixSubscriptionAttribute.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/TypePrefixSubscriptionAttribute.cs @@ -6,7 +6,7 @@ namespace Microsoft.AutoGen.Core; [AttributeUsage(AttributeTargets.All)] -public class TopicPrefixSubscriptionAttribute(string topic) : Attribute, IUnboundSubscriptionDefinition +public class TypePrefixSubscriptionAttribute(string topic) : Attribute, IUnboundSubscriptionDefinition { public string Topic { get; } = topic; diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IAgentGrain.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IAgentGrain.cs deleted file mode 100644 index 947b6b0cbc0a..000000000000 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IAgentGrain.cs +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// IAgentGrain.cs - -namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; - -internal interface IAgentGrain : IGrainWithStringKey -{ - ValueTask ReadStateAsync(); - ValueTask WriteStateAsync(Contracts.AgentState state, string eTag); -} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs deleted file mode 100644 index 9481922943c9..000000000000 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// GrpcGatewayService.cs - -using Grpc.Core; -using Microsoft.AutoGen.Contracts; - -namespace Microsoft.AutoGen.Runtime.Grpc; - -// gRPC service which handles communication between the agent worker and the cluster. 
-public sealed class GrpcGatewayService(GrpcGateway gateway) : AgentRpc.AgentRpcBase -{ - private readonly GrpcGateway Gateway = (GrpcGateway)gateway; - - public override async Task OpenChannel(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) - { - try - { - await Gateway.ConnectToWorkerProcess(requestStream, responseStream, context).ConfigureAwait(true); - } - catch - { - if (context.CancellationToken.IsCancellationRequested) - { - return; - } - throw; - } - } - public override async Task GetState(AgentId request, ServerCallContext context) - { - var state = await Gateway.ReadAsync(request); - return new GetStateResponse { AgentState = state }; - } - public override async Task SaveState(AgentState request, ServerCallContext context) - { - await Gateway.StoreAsync(request); - return new SaveStateResponse - { - Success = true // TODO: Implement error handling - }; - } - public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) - { - request.RequestId = context.Peer; - return await Gateway.SubscribeAsync(request).ConfigureAwait(true); - } - public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) - { - return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true); - } - public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) - { - var subscriptions = await Gateway.GetSubscriptionsAsync(request); - return new GetSubscriptionsResponse { Subscriptions = { subscriptions } }; - } - public override async Task RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) - { - request.RequestId = context.Peer; - return await Gateway.RegisterAgentTypeAsync(request).ConfigureAwait(true); - } -} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/AgentStateGrain.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/AgentStateGrain.cs deleted file mode 100644 index 97869cd91fd1..000000000000 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/AgentStateGrain.cs +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// AgentStateGrain.cs - -using Microsoft.AutoGen.Runtime.Grpc.Abstractions; - -namespace Microsoft.AutoGen.Runtime.Grpc; - -internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState state) : Grain, IAgentState, IAgentGrain -{ - /// - public async ValueTask WriteStateAsync(AgentState newState, string eTag, CancellationToken cancellationToken = default) - { - // etags for optimistic concurrency control - // if the Etag is null, its a new state - // if the passed etag is null or empty, we should not check the current state's Etag - caller doesnt care - // if both etags are set, they should match or it means that the state has changed since the last read. - if ((string.IsNullOrEmpty(state.Etag)) || (string.IsNullOrEmpty(eTag)) || (string.Equals(state.Etag, eTag, StringComparison.Ordinal))) - { - state.State = newState; - await state.WriteStateAsync().ConfigureAwait(false); - } - else - { - //TODO - this is probably not the correct behavior to just throw - I presume we want to somehow let the caller know that the state has changed and they need to re-read it - throw new ArgumentException( - "The provided ETag does not match the current ETag. 
The state has been modified by another request."); - } - return state.Etag; - } - - /// - public ValueTask ReadStateAsync(CancellationToken cancellationToken = default) - { - return ValueTask.FromResult(state.State); - } - - ValueTask IAgentGrain.ReadStateAsync() - { - return ReadStateAsync(); - } - - ValueTask IAgentGrain.WriteStateAsync(AgentState state, string eTag) - { - return WriteStateAsync(state, eTag); - } -} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/ISubscriptionsGrain.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/ISubscriptionsGrain.cs deleted file mode 100644 index 60c17b7c6597..000000000000 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/ISubscriptionsGrain.cs +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// ISubscriptionsGrain.cs - -namespace Microsoft.AutoGen.Runtime.Grpc; -public interface ISubscriptionsGrain : IGrainWithIntegerKey -{ - ValueTask SubscribeAsync(string agentType, string topic); - ValueTask UnsubscribeAsync(string agentType, string topic); - ValueTask>> GetSubscriptions(string agentType); -} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/SubscriptionsGrain.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/SubscriptionsGrain.cs deleted file mode 100644 index 632cc7cefde8..000000000000 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/SubscriptionsGrain.cs +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// SubscriptionsGrain.cs - -namespace Microsoft.AutoGen.Runtime.Grpc; - -internal sealed class SubscriptionsGrain([PersistentState("state", "PubSubStore")] IPersistentState state) : Grain, ISubscriptionsGrain -{ - private readonly Dictionary> _subscriptions = new(); - public ValueTask>> GetSubscriptions(string? agentType = null) - { - //if agentType is null, return all subscriptions else filter on agentType - if (agentType != null) - { - return new ValueTask>>(_subscriptions.Where(x => x.Value.Contains(agentType)).ToDictionary(x => x.Key, x => x.Value)); - } - return new ValueTask>>(_subscriptions); - } - public async ValueTask SubscribeAsync(string agentType, string topic) - { - if (!_subscriptions.TryGetValue(topic, out var subscriptions)) - { - subscriptions = _subscriptions[topic] = []; - } - if (!subscriptions.Contains(agentType)) - { - subscriptions.Add(agentType); - } - _subscriptions[topic] = subscriptions; - state.State.Subscriptions = _subscriptions; - await state.WriteStateAsync().ConfigureAwait(false); - } - public async ValueTask UnsubscribeAsync(string agentType, string topic) - { - if (!_subscriptions.TryGetValue(topic, out var subscriptions)) - { - subscriptions = _subscriptions[topic] = []; - } - if (!subscriptions.Contains(agentType)) - { - subscriptions.Remove(agentType); - } - _subscriptions[topic] = subscriptions; - state.State.Subscriptions = _subscriptions; - await state.WriteStateAsync(); - } -} -public sealed class SubscriptionsState -{ - public Dictionary> Subscriptions { get; set; } = new(); -} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentStateSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentStateSurrogate.cs deleted file mode 100644 index a5291f942155..000000000000 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentStateSurrogate.cs +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Microsoft Corporation. 
All rights reserved. -// AgentStateSurrogate.cs - -using Google.Protobuf; -using Microsoft.AutoGen.Contracts; - -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; - -[GenerateSerializer] -public struct AgentStateSurrogate -{ - [Id(0)] - public string Id; - [Id(1)] - public string TextData; - [Id(2)] - public ByteString BinaryData; - [Id(3)] - public AgentId AgentId; - [Id(4)] - public string Etag; - [Id(5)] - public ByteString ProtoData; -} - -[RegisterConverter] -public sealed class AgentStateSurrogateConverter : - IConverter -{ - public AgentState ConvertFromSurrogate( - in AgentStateSurrogate surrogate) - { - var agentState = new AgentState - { - AgentId = surrogate.AgentId, - BinaryData = surrogate.BinaryData, - TextData = surrogate.TextData, - ETag = surrogate.Etag - }; - //agentState.ProtoData = surrogate.ProtoData; - return agentState; - } - - public AgentStateSurrogate ConvertToSurrogate( - in AgentState value) => - new AgentStateSurrogate - { - AgentId = value.AgentId, - BinaryData = value.BinaryData, - TextData = value.TextData, - Etag = value.ETag, - //ProtoData = value.ProtoData.Value - }; -} - diff --git a/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/AgentsRegistryState.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/AgentsRegistryState.cs new file mode 100644 index 000000000000..2baa70b33ef8 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/AgentsRegistryState.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AgentsRegistryState.cs +using System.Collections.Concurrent; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; +public class AgentsRegistryState +{ + public ConcurrentDictionary> AgentsToTopicsMap { get; set; } = []; + public ConcurrentDictionary> AgentsToTopicsPrefixMap { get; set; } = []; + public ConcurrentDictionary> TopicToAgentTypesMap { get; set; } = []; + public ConcurrentDictionary> TopicPrefixToAgentTypesMap { get; set; } = []; + public ConcurrentDictionary> GuidSubscriptionsMap { get; set; } = []; + public string Etag { get; set; } = Guid.NewGuid().ToString(); +} diff --git a/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IConnection.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IConnection.cs new file mode 100644 index 000000000000..a5044ff69b4f --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IConnection.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IConnection.cs + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; +public interface IConnection +{ +} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGateway.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IGateway.cs similarity index 58% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGateway.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IGateway.cs index 33bb94f7c49b..b8aeae041e86 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IGateway.cs @@ -1,18 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// IGateway.cs -using Microsoft.AutoGen.Contracts; +using Grpc.Core; +using Microsoft.AutoGen.Protobuf; -namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; public interface IGateway : IGrainObserver { ValueTask InvokeRequestAsync(RpcRequest request); - ValueTask BroadcastEventAsync(CloudEvent evt); - ValueTask StoreAsync(Contracts.AgentState value); - ValueTask ReadAsync(AgentId agentId); - ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request); + ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, ServerCallContext context); ValueTask SubscribeAsync(AddSubscriptionRequest request); ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request); ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request); - Task SendMessageAsync(IConnection connection, CloudEvent cloudEvent); } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGatewayRegistry.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IGatewayRegistry.cs similarity index 93% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGatewayRegistry.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IGatewayRegistry.cs index cb3778418040..3d47696da3ee 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGatewayRegistry.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IGatewayRegistry.cs @@ -1,9 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IGatewayRegistry.cs +using Microsoft.AutoGen.Protobuf; -using Microsoft.AutoGen.Contracts; - -namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; /// /// Interface for managing agent registration, placement, and subscriptions. @@ -30,7 +29,7 @@ public interface IGatewayRegistry : IRegistry /// The request containing agent type details. /// The worker to register the agent type with. /// A task representing the asynchronous operation. - ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, IGateway worker); + ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, string clientId, IGateway worker); /// /// Adds a new worker to the registry. diff --git a/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IRegistry.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IRegistry.cs new file mode 100644 index 000000000000..215ace6e5dda --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IRegistry.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IRegistry.cs +using Microsoft.AutoGen.Protobuf; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; + +public interface IRegistry +{ + + /// + /// Gets a list of agents subscribed to and handling the specified topic and event type. + /// + /// The topic to check subscriptions for. + /// The event type to check subscriptions for. + /// A task representing the asynchronous operation, with the list of agent IDs as the result. + ValueTask> GetSubscribedAndHandlingAgentsAsync(string topic, string key); + + /// + /// Subscribes an agent to a topic. + /// + /// The subscription request. + /// A task representing the asynchronous operation. + /// removing CancellationToken from here as it is not compatible with Orleans Serialization + ValueTask SubscribeAsync(AddSubscriptionRequest request); + + /// + /// Unsubscribes an agent from a topic. 
+ /// + /// The unsubscription request. + /// A task representing the asynchronous operation. + /// removing CancellationToken from here as it is not compatible with Orleans Serialization + ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request); // TODO: This should have its own request type. + + /// + /// Gets the subscriptions for a specified agent type. + /// + /// A task representing the asynchronous operation, with the subscriptions as the result. + ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request); +} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IRegistryGrain.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IRegistryGrain.cs similarity index 63% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IRegistryGrain.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IRegistryGrain.cs index 81b59858619c..a44da1ce5b22 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IRegistryGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Abstractions/IRegistryGrain.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IRegistryGrain.cs -namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; /// /// Orleans specific interface, needed to mark the key /// -[Alias("Microsoft.AutoGen.Runtime.Grpc.Abstractions.IRegistryGrain")] +[Alias("Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions.IRegistryGrain")] public interface IRegistryGrain : IGatewayRegistry, IGrainWithIntegerKey { } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Microsoft.AutoGen.Runtime.Grpc.csproj b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Microsoft.AutoGen.RuntimeGateway.Grpc.csproj similarity index 85% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Microsoft.AutoGen.Runtime.Grpc.csproj rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Microsoft.AutoGen.RuntimeGateway.Grpc.csproj index b874a657d8f2..f9a568e89ade 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Microsoft.AutoGen.Runtime.Grpc.csproj +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Microsoft.AutoGen.RuntimeGateway.Grpc.csproj @@ -7,6 +7,7 @@ + @@ -29,4 +30,9 @@ + + + + + \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/AgentWorkerHostingExtensions.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/AgentWorkerHostingExtensions.cs similarity index 79% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/AgentWorkerHostingExtensions.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/AgentWorkerHostingExtensions.cs index 3b130ca4bed5..37b9e1f1b6df 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/AgentWorkerHostingExtensions.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/AgentWorkerHostingExtensions.cs @@ -3,12 +3,11 @@ using System.Diagnostics; using Microsoft.AspNetCore.Builder; -using Microsoft.AutoGen.Core; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Hosting; -namespace Microsoft.AutoGen.Runtime.Grpc; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc; public static class AgentWorkerHostingExtensions { public static WebApplicationBuilder AddAgentService(this WebApplicationBuilder builder) @@ -18,10 +17,6 @@ public static WebApplicationBuilder AddAgentService(this WebApplicationBuilder b 
builder.Services.TryAddSingleton(DistributedContextPropagator.Current); builder.Services.AddGrpc(); - builder.Services.AddKeyedSingleton("AgentsMetadata", (sp, key) => - { - return ReflectionHelper.GetAgentsMetadata(AppDomain.CurrentDomain.GetAssemblies()); - }); builder.Services.AddSingleton(); builder.Services.AddSingleton(sp => (IHostedService)sp.GetRequiredService()); diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcGateway.cs similarity index 58% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcGateway.cs index 26c99c894248..e657ac36fef0 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcGateway.cs @@ -4,46 +4,81 @@ using System.Collections.Concurrent; using Grpc.Core; using Microsoft.AutoGen.Contracts; -using Microsoft.AutoGen.Runtime.Grpc.Abstractions; +using Microsoft.AutoGen.Protobuf; +using Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Microsoft.AutoGen.Runtime.Grpc; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc; +/// +/// Represents the gRPC gateway service that handles communication between the agent worker and the cluster. +/// public sealed class GrpcGateway : BackgroundService, IGateway { private static readonly TimeSpan s_agentResponseTimeout = TimeSpan.FromSeconds(30); private readonly ILogger _logger; private readonly IClusterClient _clusterClient; - //private readonly ConcurrentDictionary _agentState = new(); private readonly IRegistryGrain _gatewayRegistry; private readonly IGateway _reference; - // The agents supported by each worker process. private readonly ConcurrentDictionary> _supportedAgentTypes = []; - public readonly ConcurrentDictionary _workers = new(); - internal readonly ConcurrentDictionary _workersByConnection = new(); - private readonly ConcurrentDictionary _subscriptionsByAgentType = new(); - private readonly ConcurrentDictionary> _subscriptionsByTopic = new(); - private readonly ISubscriptionsGrain _subscriptions; - - // The mapping from agent id to worker process. + public readonly ConcurrentDictionary _workers = new(); private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new(); - // RPC private readonly ConcurrentDictionary<(GrpcWorkerConnection, string), TaskCompletionSource> _pendingRequests = new(); + + /// + /// Initializes a new instance of the class. + /// + /// The cluster client. + /// The logger. public GrpcGateway(IClusterClient clusterClient, ILogger logger) { _logger = logger; _clusterClient = clusterClient; _reference = clusterClient.CreateObjectReference(this); _gatewayRegistry = clusterClient.GetGrain(0); - _subscriptions = clusterClient.GetGrain(0); } + + /// + /// Executes the background service. + /// + /// The cancellation token. + /// A task that represents the asynchronous operation. 
+    /// <summary>
+    /// Executes the background service.
+    /// </summary>
+    /// <param name="stoppingToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
+    protected override async Task ExecuteAsync(CancellationToken stoppingToken)
+    {
+        while (!stoppingToken.IsCancellationRequested)
+        {
+            try
+            {
+                await _gatewayRegistry.AddWorkerAsync(_reference);
+            }
+            catch (Exception exception)
+            {
+                _logger.LogWarning(exception, "Error adding worker to registry.");
+            }
+            await Task.Delay(TimeSpan.FromSeconds(15), stoppingToken);
+        }
+        try
+        {
+            await _gatewayRegistry.RemoveWorkerAsync(_reference);
+        }
+        catch (Exception exception)
+        {
+            _logger.LogWarning(exception, "Error removing worker from registry.");
+        }
+    }
+
+    /// <summary>
+    /// Invokes a request asynchronously.
+    /// </summary>
+    /// <param name="request">The RPC request.</param>
+    /// <param name="cancellationToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the RPC response.</returns>
     public async ValueTask<RpcResponse> InvokeRequestAsync(RpcRequest request, CancellationToken cancellationToken = default)
     {
         var agentId = (request.Target.Type, request.Target.Key);
         if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion.IsCompleted == true)
         {
-            // Activate the agent on a compatible worker process.
             if (_supportedAgentTypes.TryGetValue(request.Target.Type, out var workers))
             {
                 connection = workers[Random.Shared.Next(workers.Count)];
@@ -54,109 +89,143 @@ public async ValueTask<RpcResponse> InvokeRequestAsync(RpcRequest request, Cance
                 return new(new RpcResponse { Error = "Agent not found." });
             }
         }
-        // Proxy the request to the agent.
         var originalRequestId = request.RequestId;
         var newRequestId = Guid.NewGuid().ToString();
         var completion = _pendingRequests[(connection, newRequestId)] = new(TaskCreationOptions.RunContinuationsAsynchronously);
         request.RequestId = newRequestId;
         await connection.ResponseStream.WriteAsync(new Message { Request = request }, cancellationToken).ConfigureAwait(false);
-        // Wait for the response and send it back to the caller.
         var response = await completion.Task.WaitAsync(s_agentResponseTimeout);
         response.RequestId = originalRequestId;
         return response;
     }
-    public async ValueTask StoreAsync(AgentState value, CancellationToken cancellationToken = default)
-    {
-        _ = value.AgentId ?? throw new ArgumentNullException(nameof(value.AgentId));
-        var agentState = _clusterClient.GetGrain<IAgentGrain>($"{value.AgentId.Type}:{value.AgentId.Key}");
-        await agentState.WriteStateAsync(value, value.ETag);
-    }
-    public async ValueTask<AgentState> ReadAsync(AgentId agentId, CancellationToken cancellationToken = default)
-    {
-        var agentState = _clusterClient.GetGrain<IAgentGrain>($"{agentId.Type}:{agentId.Key}");
-        return await agentState.ReadStateAsync();
-    }
-    public async ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default)
+
+    /// <summary>
+    /// Registers an agent type asynchronously.
+    /// </summary>
+    /// <param name="request">The register agent type request.</param>
+    /// <param name="context">The server call context.</param>
+    /// <param name="cancellationToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the register agent type response.</returns>
+    public async ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(RegisterAgentTypeRequest request, ServerCallContext context, CancellationToken cancellationToken = default)
     {
         try
         {
-            var connection = _workersByConnection[request.RequestId];
+            var clientId = context.RequestHeaders.Get("client-id")?.Value ??
+                throw new RpcException(new Status(StatusCode.InvalidArgument, "Grpc Client ID is required."));
+            if (!_workers.TryGetValue(clientId, out var connection))
+            {
+                throw new RpcException(new Status(StatusCode.InvalidArgument, $"Grpc Worker Connection not found for ClientId {clientId}."));
+            }
             connection.AddSupportedType(request.Type);
             _supportedAgentTypes.GetOrAdd(request.Type, _ => []).Add(connection);
-            await _gatewayRegistry.RegisterAgentTypeAsync(request, _reference).ConfigureAwait(true);
-            return new RegisterAgentTypeResponse
-            {
-                Success = true,
-                RequestId = request.RequestId
-            };
+            await _gatewayRegistry.RegisterAgentTypeAsync(request, clientId, _reference).ConfigureAwait(true);
+            return new RegisterAgentTypeResponse { };
         }
         catch (Exception ex)
         {
-            return new RegisterAgentTypeResponse
-            {
-                Success = false,
-                RequestId = request.RequestId,
-                Error = ex.Message
-            };
+            throw new RpcException(new Status(StatusCode.Internal, ex.Message));
         }
     }
+
+    /// <summary>
+    /// Subscribes to a topic asynchronously.
+    /// </summary>
+    /// <param name="request">The add subscription request.</param>
+    /// <param name="cancellationToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the add subscription response.</returns>
     public async ValueTask<AddSubscriptionResponse> SubscribeAsync(AddSubscriptionRequest request, CancellationToken cancellationToken = default)
     {
         try
         {
             await _gatewayRegistry.SubscribeAsync(request).ConfigureAwait(true);
-            return new AddSubscriptionResponse
-            {
-                Success = true,
-                RequestId = request.RequestId
-            };
+            return new AddSubscriptionResponse { };
         }
         catch (Exception ex)
         {
-            return new AddSubscriptionResponse
-            {
-                Success = false,
-                RequestId = request.RequestId,
-                Error = ex.Message
-            };
+            throw new RpcException(new Status(StatusCode.Internal, ex.Message));
         }
     }
-    protected override async Task ExecuteAsync(CancellationToken stoppingToken)
+
+    /// <summary>
+    /// Unsubscribes from a topic asynchronously.
+    /// </summary>
+    /// <param name="request">The remove subscription request.</param>
+    /// <param name="cancellationToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the remove subscription response.</returns>
+    public async ValueTask<RemoveSubscriptionResponse> UnsubscribeAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default)
     {
-        while (!stoppingToken.IsCancellationRequested)
-        {
-            try
-            {
-                await _gatewayRegistry.AddWorkerAsync(_reference);
-            }
-            catch (Exception exception)
-            {
-                _logger.LogWarning(exception, "Error adding worker to registry.");
-            }
-            await Task.Delay(TimeSpan.FromSeconds(15), stoppingToken);
-        }
         try
         {
-            await _gatewayRegistry.RemoveWorkerAsync(_reference);
+            await _gatewayRegistry.UnsubscribeAsync(request).ConfigureAwait(true);
+            return new RemoveSubscriptionResponse { };
         }
-        catch (Exception exception)
+        catch (Exception ex)
         {
-            _logger.LogWarning(exception, "Error removing worker from registry.");
+            throw new RpcException(new Status(StatusCode.Internal, ex.Message));
         }
     }
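Aside: RegisterAgentTypeAsync above resolves the worker connection from a `client-id` gRPC request header rather than from the request body. A hedged sketch of what a caller might look like (the address and client id are made up; `AgentRpc.AgentRpcClient` is the client generated from the same proto as the `AgentRpc.AgentRpcBase` service later in this patch):

    using Grpc.Core;
    using Grpc.Net.Client;
    using Microsoft.AutoGen.Protobuf;

    // Illustrative only: register an agent type, passing the required client-id header.
    var channel = GrpcChannel.ForAddress("http://localhost:5001");  // assumed address
    var client = new AgentRpc.AgentRpcClient(channel);
    var headers = new Metadata { { "client-id", "worker-1" } };     // hypothetical id
    var response = await client.RegisterAgentAsync(
        new RegisterAgentTypeRequest { Type = "MyAgent" }, headers);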
+
+    /// <summary>
+    /// Gets the subscriptions asynchronously.
+    /// </summary>
+    /// <param name="request">The get subscriptions request.</param>
+    /// <param name="cancellationToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the list of subscriptions.</returns>
+    public ValueTask<List<Subscription>> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken cancellationToken = default)
+    {
+        return _gatewayRegistry.GetSubscriptionsAsync(request);
+    }
+
+    async ValueTask<RpcResponse> IGateway.InvokeRequestAsync(RpcRequest request)
+    {
+        return await InvokeRequestAsync(request, default).ConfigureAwait(false);
+    }
+
+    ValueTask<RegisterAgentTypeResponse> IGateway.RegisterAgentTypeAsync(RegisterAgentTypeRequest request, ServerCallContext context)
+    {
+        return RegisterAgentTypeAsync(request, context, default);
+    }
+
+    ValueTask<AddSubscriptionResponse> IGateway.SubscribeAsync(AddSubscriptionRequest request)
+    {
+        return SubscribeAsync(request, default);
+    }
+
+    ValueTask<RemoveSubscriptionResponse> IGateway.UnsubscribeAsync(RemoveSubscriptionRequest request)
+    {
+        return UnsubscribeAsync(request, default);
+    }
+
+    ValueTask<List<Subscription>> IGateway.GetSubscriptionsAsync(GetSubscriptionsRequest request)
+    {
+        return GetSubscriptionsAsync(request);
+    }
+
+    /// <summary>
+    /// Connects to a worker process.
+    /// </summary>
+    /// <param name="requestStream">The request stream.</param>
+    /// <param name="responseStream">The response stream.</param>
+    /// <param name="context">The server call context.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
     internal async Task ConnectToWorkerProcess(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
     {
         _logger.LogInformation("Received new connection from {Peer}.", context.Peer);
+        var clientId = (context.RequestHeaders.Get("client-id")?.Value) ??
+            throw new RpcException(new Status(StatusCode.InvalidArgument, "Client ID is required."));
         var workerProcess = new GrpcWorkerConnection(this, requestStream, responseStream, context);
-        _workers.GetOrAdd(workerProcess, workerProcess);
-        _workersByConnection.GetOrAdd(context.Peer, workerProcess);
+        _workers.GetOrAdd(clientId, workerProcess);
         await workerProcess.Connect().ConfigureAwait(false);
     }
-    internal async Task SendMessageAsync(GrpcWorkerConnection connection, CloudEvent cloudEvent, CancellationToken cancellationToken = default)
-    {
-        await connection.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false);
-    }
+
+    /// <summary>
+    /// Handles received messages from a worker connection.
+    /// </summary>
+    /// <param name="connection">The worker connection.</param>
+    /// <param name="message">The received message.</param>
+    /// <param name="cancellationToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
     internal async Task OnReceivedMessageAsync(GrpcWorkerConnection connection, Message message, CancellationToken cancellationToken = default)
     {
         _logger.LogInformation("Received message {Message} from connection {Connection}.", message, connection);
@@ -171,18 +240,17 @@ internal async Task OnReceivedMessageAsync(GrpcWorkerConnection connection, Mess
             case Message.MessageOneofCase.CloudEvent:
                 await DispatchEventAsync(message.CloudEvent, cancellationToken);
                 break;
-            case Message.MessageOneofCase.RegisterAgentTypeRequest:
-                await RegisterAgentTypeAsync(connection, message.RegisterAgentTypeRequest);
-                break;
-            case Message.MessageOneofCase.AddSubscriptionRequest:
-                await AddSubscriptionAsync(connection, message.AddSubscriptionRequest);
-                break;
             default:
-                // if it wasn't recognized return bad request
                 await RespondBadRequestAsync(connection, $"Unknown message type for message '{message}'.");
                 break;
         };
     }
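Aside: the same `client-id` header gates ConnectToWorkerProcess above. A sketch of how a worker might open the duplex channel and consume gateway-to-worker messages (illustrative; the address and id are assumptions):

    using Grpc.Core;
    using Grpc.Net.Client;
    using Microsoft.AutoGen.Protobuf;

    // Illustrative only: open the bidirectional message stream with the client-id
    // header, then read messages pushed by the gateway as they arrive.
    var channel = GrpcChannel.ForAddress("http://localhost:5001");  // assumed address
    var client = new AgentRpc.AgentRpcClient(channel);
    using var call = client.OpenChannel(new Metadata { { "client-id", "worker-1" } });
    await foreach (var message in call.ResponseStream.ReadAllAsync())
    {
        // dispatch message.CloudEvent / message.Request / message.Response here
    }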
+
+    /// <summary>
+    /// Dispatches a response to a pending request.
+    /// </summary>
+    /// <param name="connection">The worker connection.</param>
+    /// <param name="response">The RPC response.</param>
     private void DispatchResponse(GrpcWorkerConnection connection, RpcResponse response)
     {
         if (!_pendingRequests.TryRemove((connection, response.RequestId), out var completion))
@@ -190,26 +258,15 @@ private void DispatchResponse(GrpcWorkerConnection connection, RpcResponse respo
             _logger.LogWarning("Received response for unknown request id: {RequestId}.", response.RequestId);
             return;
         }
-        // Complete the request.
         completion.SetResult(response);
     }
-    private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, RegisterAgentTypeRequest msg)
-    {
-        connection.AddSupportedType(msg.Type);
-        _supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection);
-        await _gatewayRegistry.RegisterAgentTypeAsync(msg, _reference).ConfigureAwait(true);
-        Message response = new()
-        {
-            RegisterAgentTypeResponse = new()
-            {
-                RequestId = msg.RequestId,
-                Error = "",
-                Success = true
-            }
-        };
-        await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false);
-    }
+
+    /// <summary>
+    /// Dispatches an event to the appropriate agents.
+    /// </summary>
+    /// <param name="evt">The cloud event.</param>
+    /// <param name="cancellationToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
     private async ValueTask DispatchEventAsync(CloudEvent evt, CancellationToken cancellationToken = default)
     {
         var registry = _clusterClient.GetGrain<IRegistryGrain>(0);
@@ -227,7 +284,7 @@ private async ValueTask DispatchEventAsync(CloudEvent evt, CancellationToken can
                 var activeConnections = connections.Where(c => c.Completion?.IsCompleted == false).ToList();
                 foreach (var connection in activeConnections)
                 {
-                    tasks.Add(this.SendMessageAsync(connection, evt, cancellationToken));
+                    tasks.Add(this.WriteResponseAsync(connection, evt, cancellationToken));
                 }
             }
         }
@@ -238,6 +295,13 @@ private async ValueTask DispatchEventAsync(CloudEvent evt, CancellationToken can
             _logger.LogWarning("No agent types found for event type {EventType}.", evt.Type);
         }
     }
+
+    /// <summary>
+    /// Dispatches a request to the appropriate agent.
+    /// </summary>
+    /// <param name="connection">The worker connection.</param>
+    /// <param name="request">The RPC request.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
     private async ValueTask DispatchRequestAsync(GrpcWorkerConnection connection, RpcRequest request)
     {
         var requestId = request.RequestId;
@@ -260,6 +324,14 @@ await InvokeRequestDelegate(connection, request, async request =>
             return await gateway.InvokeRequestAsync(request).ConfigureAwait(true);
         }).ConfigureAwait(false);
     }
+
+    /// <summary>
+    /// Invokes a request delegate.
+    /// </summary>
+    /// <param name="connection">The worker connection.</param>
+    /// <param name="request">The RPC request.</param>
+    /// <param name="func">The function to invoke.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
     private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection, RpcRequest request, Func<RpcRequest, Task<RpcResponse>> func)
     {
         try
@@ -273,9 +345,16 @@ private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection,
             await connection.ResponseStream.WriteAsync(new Message { Response = new RpcResponse { RequestId = request.RequestId, Error = ex.Message } }).ConfigureAwait(false);
         }
     }
+
+    /// <summary>
+    /// Handles the removal of a worker process.
+    /// </summary>
+    /// <param name="workerProcess">The worker process.</param>
     internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess)
     {
-        _workers.TryRemove(workerProcess, out _);
+        var clientId = workerProcess.ServerCallContext.RequestHeaders.Get("client-id")?.Value ??
+            throw new RpcException(new Status(StatusCode.InvalidArgument, "Grpc Client ID is required."));
+        _workers.TryRemove(clientId, out _);
         var types = workerProcess.GetSupportedTypes();
         foreach (var type in types)
         {
@@ -284,7 +363,6 @@ internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess)
                 supported.Remove(workerProcess);
             }
         }
-        // Any agents activated on that worker are also gone.
         foreach (var pair in _agentDirectory)
         {
             if (pair.Value == workerProcess)
@@ -293,39 +371,24 @@ internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess)
             }
         }
     }
+
+    /// <summary>
+    /// Responds with a bad request error.
+    /// </summary>
+    /// <param name="connection">The worker connection.</param>
+    /// <param name="error">The error message.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
     private static async ValueTask RespondBadRequestAsync(GrpcWorkerConnection connection, string error)
     {
         throw new RpcException(new Status(StatusCode.InvalidArgument, error));
     }
-    private async ValueTask AddSubscriptionAsync(GrpcWorkerConnection connection, AddSubscriptionRequest request)
-    {
-        var topic = "";
-        var agentType = "";
-        if (request.Subscription.TypePrefixSubscription is not null)
-        {
-            topic = request.Subscription.TypePrefixSubscription.TopicTypePrefix;
-            agentType = request.Subscription.TypePrefixSubscription.AgentType;
-        }
-        else if (request.Subscription.TypeSubscription is not null)
-        {
-            topic = request.Subscription.TypeSubscription.TopicType;
-            agentType = request.Subscription.TypeSubscription.AgentType;
-        }
-        _subscriptionsByAgentType[agentType] = request.Subscription;
-        _subscriptionsByTopic.GetOrAdd(topic, _ => []).Add(agentType);
-        await _subscriptions.SubscribeAsync(topic, agentType);
-        //var response = new SubscriptionResponse { RequestId = request.RequestId, Error = "", Success = true };
-        Message response = new()
-        {
-            AddSubscriptionResponse = new()
-            {
-                RequestId = request.RequestId,
-                Error = "",
-                Success = true
-            }
-        };
-        await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false);
-    }
+
+    /// <summary>
+    /// Dispatches an event to the specified agent types.
+    /// </summary>
+    /// <param name="agentTypes">The agent types.</param>
+    /// <param name="evt">The cloud event.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
     private async ValueTask DispatchEventToAgentsAsync(IEnumerable<string> agentTypes, CloudEvent evt)
     {
         var tasks = new List<Task>(agentTypes.Count());
@@ -335,86 +398,33 @@ private async ValueTask DispatchEventToAgentsAsync(IEnumerable<string> agentType
         {
             foreach (var connection in connections)
             {
-                tasks.Add(this.SendMessageAsync(connection, evt));
+                tasks.Add(this.WriteResponseAsync(connection, evt));
             }
         }
     }
     await Task.WhenAll(tasks).ConfigureAwait(false);
 }
-    public async ValueTask BroadcastEventAsync(CloudEvent evt, CancellationToken cancellationToken = default)
-    {
-        var tasks = new List<Task>(_workers.Count);
-        foreach (var (_, connection) in _supportedAgentTypes)
-        {
-            tasks.Add(this.SendMessageAsync((IConnection)connection[0], evt, default));
-        }
-        await Task.WhenAll(tasks).ConfigureAwait(false);
-    }
-    Task IGateway.SendMessageAsync(IConnection connection, CloudEvent cloudEvent)
-    {
-        return this.SendMessageAsync(connection, cloudEvent, default);
-    }
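Aside: DispatchEventToAgentsAsync above fans one CloudEvent out to every live connection for each subscribed agent type and awaits the whole batch. The core pattern, stripped of gateway details (illustrative only):

    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Threading.Tasks;

    // Illustrative only: write one payload through many writers concurrently,
    // then await them all so a slow connection does not block the others serially.
    static Task FanOutAsync<T>(IEnumerable<Func<T, Task>> writers, T payload)
    {
        var tasks = writers.Select(write => write(payload)).ToList();
        return Task.WhenAll(tasks);
    }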
-    public async Task SendMessageAsync(IConnection connection, CloudEvent cloudEvent, CancellationToken cancellationToken = default)
+
+    /// <summary>
+    /// Writes a response to a worker connection.
+    /// </summary>
+    /// <param name="connection">The worker connection.</param>
+    /// <param name="cloudEvent">The cloud event.</param>
+    /// <param name="cancellationToken">The cancellation token.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
+    private async Task WriteResponseAsync(GrpcWorkerConnection connection, CloudEvent cloudEvent, CancellationToken cancellationToken = default)
     {
-        var queue = (GrpcWorkerConnection)connection;
-        await queue.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false);
+        await connection.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false);
     }
-    public async ValueTask<RemoveSubscriptionResponse> UnsubscribeAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default)
-    {
-        try
-        {
-            await _gatewayRegistry.UnsubscribeAsync(request).ConfigureAwait(true);
-            return new RemoveSubscriptionResponse
-
-            {
-                Success = true,
-            };
-        }
-        catch (Exception ex)
-        {
-            return new RemoveSubscriptionResponse
-            {
-                Success = false,
-                Error = ex.Message
-            };
-        }
-    }
-    public ValueTask<List<Subscription>> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken cancellationToken = default)
-    {
-        return _gatewayRegistry.GetSubscriptionsAsync(request);
-    }
-    async ValueTask<RpcResponse> IGateway.InvokeRequestAsync(RpcRequest request)
-    {
-        return await InvokeRequestAsync(request, default).ConfigureAwait(false);
-    }
-    async ValueTask IGateway.BroadcastEventAsync(CloudEvent evt)
-    {
-        await BroadcastEventAsync(evt, default).ConfigureAwait(false);
-    }
-    ValueTask IGateway.StoreAsync(AgentState value)
-    {
-        return StoreAsync(value, default);
-    }
-    ValueTask<AgentState> IGateway.ReadAsync(AgentId agentId)
-    {
-        return ReadAsync(agentId, default);
-    }
-    ValueTask<RegisterAgentTypeResponse> IGateway.RegisterAgentTypeAsync(RegisterAgentTypeRequest request)
-    {
-        return RegisterAgentTypeAsync(request, default);
-    }
-    ValueTask<AddSubscriptionResponse> IGateway.SubscribeAsync(AddSubscriptionRequest request)
-    {
-        return SubscribeAsync(request, default);
-    }
-    ValueTask<RemoveSubscriptionResponse> IGateway.UnsubscribeAsync(RemoveSubscriptionRequest request)
-    {
-        return UnsubscribeAsync(request, default);
-    }
-    ValueTask<List<Subscription>> IGateway.GetSubscriptionsAsync(GetSubscriptionsRequest request)
-    {
-        return GetSubscriptionsAsync(request);
+
+    /// <summary>
+    /// Writes a response to a worker connection.
+    /// </summary>
+    /// <param name="connection">The worker connection.</param>
+    /// <param name="cloudEvent">The cloud event.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
+    public async Task WriteResponseAsync(IConnection connection, CloudEvent cloudEvent)
+    {
+        await WriteResponseAsync((GrpcWorkerConnection)connection, cloudEvent, default).ConfigureAwait(false);
     }
 }
diff --git a/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcGatewayService.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcGatewayService.cs
new file mode 100644
index 000000000000..1f04647db322
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcGatewayService.cs
@@ -0,0 +1,110 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GrpcGatewayService.cs
+using Grpc.Core;
+using Microsoft.AutoGen.Protobuf;
+
+namespace Microsoft.AutoGen.RuntimeGateway.Grpc;
+
+/// <summary>
+/// Represents the gRPC service which handles communication between the agent worker and the cluster.
+/// </summary>
+public sealed class GrpcGatewayService(GrpcGateway gateway) : AgentRpc.AgentRpcBase
+{
+    private readonly GrpcGateway Gateway = (GrpcGateway)gateway;
+
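Aside: GrpcGatewayService is the ASP.NET Core surface for the gateway. A sketch of a minimal host that exposes it (illustrative; the patch's own AddAgentService extension above performs the real registration, and the Orleans client setup is omitted here):

    using Microsoft.AspNetCore.Builder;
    using Microsoft.Extensions.DependencyInjection;

    // Illustrative only: requires the Grpc.AspNetCore package; GrpcGateway's
    // IClusterClient dependency must also be registered (not shown).
    var builder = WebApplication.CreateBuilder(args);
    builder.Services.AddGrpc();
    builder.Services.AddSingleton<GrpcGateway>();
    var app = builder.Build();
    app.MapGrpcService<GrpcGatewayService>();  // serves OpenChannel, RegisterAgent, ...
    app.Run();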
+    /// <summary>
+    /// Method run on first connect from a worker process.
+    /// </summary>
+    /// <param name="requestStream">The request stream.</param>
+    /// <param name="responseStream">The response stream.</param>
+    /// <param name="context">The server call context.</param>
+    /// <returns>A task that represents the asynchronous operation.</returns>
+    public override async Task OpenChannel(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
+    {
+        try
+        {
+            await Gateway.ConnectToWorkerProcess(requestStream, responseStream, context).ConfigureAwait(true);
+        }
+        catch
+        {
+            if (context.CancellationToken.IsCancellationRequested)
+            {
+                return;
+            }
+            throw;
+        }
+    }
+
+    /// <summary>
+    /// Adds a subscription.
+    /// </summary>
+    /// <param name="request">The add subscription request.</param>
+    /// <param name="context">The server call context.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the add subscription response.</returns>
+    public override async Task<AddSubscriptionResponse> AddSubscription(AddSubscriptionRequest request, ServerCallContext context)
+    {
+        try
+        {
+            return await Gateway.SubscribeAsync(request).ConfigureAwait(true);
+        }
+        catch (Exception e)
+        {
+            throw new RpcException(new Status(StatusCode.Internal, e.Message));
+        }
+    }
+
+    /// <summary>
+    /// Removes a subscription.
+    /// </summary>
+    /// <param name="request">The remove subscription request.</param>
+    /// <param name="context">The server call context.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the remove subscription response.</returns>
+    public override async Task<RemoveSubscriptionResponse> RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context)
+    {
+        try
+        {
+            return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true);
+        }
+        catch (Exception e)
+        {
+            throw new RpcException(new Status(StatusCode.Internal, e.Message));
+        }
+    }
+
+    /// <summary>
+    /// Gets the subscriptions.
+    /// </summary>
+    /// <param name="request">The get subscriptions request.</param>
+    /// <param name="context">The server call context.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the get subscriptions response.</returns>
+    public override async Task<GetSubscriptionsResponse> GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context)
+    {
+        try
+        {
+            var subscriptions = await Gateway.GetSubscriptionsAsync(request);
+            return new GetSubscriptionsResponse { Subscriptions = { subscriptions } };
+        }
+        catch (Exception e)
+        {
+            throw new RpcException(new Status(StatusCode.Internal, e.Message));
+        }
+    }
+
+    /// <summary>
+    /// Registers an agent type (factory)
+    /// </summary>
+    /// <param name="request">The register agent type request.</param>
+    /// <param name="context">The server call context.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the register agent type response.</returns>
+    public override async Task<RegisterAgentTypeResponse> RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context)
+    {
+        try
+        {
+            return await Gateway.RegisterAgentTypeAsync(request, context).ConfigureAwait(true);
+        }
+        catch (Exception e)
+        {
+            throw new RpcException(new Status(StatusCode.Internal, e.Message));
+        }
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcWorkerConnection.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcWorkerConnection.cs
similarity index 91%
rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcWorkerConnection.cs
rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcWorkerConnection.cs
index cba0f8c4772b..0cdfc9a2898b 100644
--- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcWorkerConnection.cs
+++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Grpc/GrpcWorkerConnection.cs
@@ -1,12 +1,13 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
// GrpcWorkerConnection.cs - using System.Threading.Channels; using Grpc.Core; +using Microsoft.AutoGen.Protobuf; +using Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; -namespace Microsoft.AutoGen.Runtime.Grpc; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc; -internal sealed class GrpcWorkerConnection : IAsyncDisposable, IConnection +public sealed class GrpcWorkerConnection : IAsyncDisposable, IConnection { private static long s_nextConnectionId; private Task _readTask = Task.CompletedTask; @@ -102,9 +103,9 @@ public async Task RunWritePump() await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); try { - await foreach (var message in _outboundMessages.Reader.ReadAllAsync(_shutdownCancellationToken.Token)) + await foreach (var message in _outboundMessages.Reader.ReadAllAsync(_shutdownCancellationToken.Token).ConfigureAwait(false)) { - await ResponseStream.WriteAsync(message); + await ResponseStream.WriteAsync(message).ConfigureAwait(false); } } catch (OperationCanceledException) diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/OrleansRuntimeHostingExtenions.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/OrleansRuntimeHostingExtenions.cs similarity index 98% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/OrleansRuntimeHostingExtenions.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/OrleansRuntimeHostingExtenions.cs index e83db26ad0b7..9abf9ac048c7 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/OrleansRuntimeHostingExtenions.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/OrleansRuntimeHostingExtenions.cs @@ -9,7 +9,7 @@ using Orleans.Configuration; using Orleans.Serialization; -namespace Microsoft.AutoGen.Runtime.Grpc; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc; public static class OrleansRuntimeHostingExtenions { diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/RegistryGrain.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/RegistryGrain.cs similarity index 75% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/RegistryGrain.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/RegistryGrain.cs index 9de7065fdb62..522c20614606 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/RegistryGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/RegistryGrain.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
 // RegistryGrain.cs
-using Microsoft.AutoGen.Contracts;
-using Microsoft.AutoGen.Runtime.Grpc.Abstractions;
+using Microsoft.AutoGen.Protobuf;
+using Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions;

-namespace Microsoft.AutoGen.Runtime.Grpc;
+namespace Microsoft.AutoGen.RuntimeGateway.Grpc;
 internal sealed class RegistryGrain([PersistentState("state", "AgentRegistryStore")] IPersistentState<AgentsRegistryState> state) : Grain, IRegistryGrain
 {
     private readonly Dictionary<IGateway, WorkerState> _workerStates = new();
@@ -16,7 +16,7 @@ public override Task OnActivateAsync(CancellationToken cancellationToken)
         this.RegisterGrainTimer(static state => state.PurgeInactiveWorkers(), this, TimeSpan.FromSeconds(30), TimeSpan.FromSeconds(30));
         return base.OnActivateAsync(cancellationToken);
     }
-    public ValueTask<List<string>> GetSubscribedAndHandlingAgentsAsync(string topic, string eventType)
+    public ValueTask<List<string>> GetSubscribedAndHandlingAgentsAsync(string topic, string key)
     {
         List<string> agents = [];
         // get all agent types that are subscribed to the topic
@@ -29,19 +29,19 @@ public ValueTask<List<string>> GetSubscribedAndHandlingAgentsAsync(string topic,
             }*/
             agents.AddRange(subscribedAgentTypes.ToList());
         }
-        if (state.State.TopicToAgentTypesMap.TryGetValue(eventType, out var eventHandlingAgents))
+        if (state.State.TopicToAgentTypesMap.TryGetValue(key, out var eventHandlingAgents))
         {
             agents.AddRange(eventHandlingAgents.ToList());
         }
-        if (state.State.TopicToAgentTypesMap.TryGetValue(topic + "." + eventType, out var combo))
+        if (state.State.TopicToAgentTypesMap.TryGetValue(topic + "." + key, out var combo))
         {
             agents.AddRange(combo.ToList());
         }
-        // instead of an exact match, we can also check for a prefix match where key starts with the eventType
-        if (state.State.TopicToAgentTypesMap.Keys.Any(key => key.StartsWith(eventType)))
+        // instead of an exact match, we can also check for a prefix match from the TopicPrefixToAgentTypesMap
+        if (state.State.TopicPrefixToAgentTypesMap.Keys.Any(key => key.StartsWith(topic)))
         {
-            state.State.TopicToAgentTypesMap.Where(
-                kvp => kvp.Key.StartsWith(eventType))
+            state.State.TopicPrefixToAgentTypesMap.Where(
+                kvp => kvp.Key.StartsWith(topic))
                 .SelectMany(kvp => kvp.Value)
                 .Distinct()
                 .ToList()
@@ -51,7 +51,6 @@ public ValueTask<List<string>> GetSubscribedAndHandlingAgentsAsync(string topic,
                 });
         }
         agents = agents.Distinct().ToList();
-
         return new ValueTask<List<string>>(agents);
     }
     public ValueTask<(IGateway? Worker, bool NewPlacement)> GetOrPlaceAgent(AgentId agentId)
@@ -94,18 +93,8 @@ public ValueTask RemoveWorkerAsync(IGateway worker)
         }
         return ValueTask.CompletedTask;
     }
-    public async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest registration, IGateway gateway)
+    public async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest registration, string clientId, IGateway gateway)
     {
-        if (!_supportedAgentTypes.TryGetValue(registration.Type, out var supportedAgentTypes))
-        {
-            supportedAgentTypes = _supportedAgentTypes[registration.Type] = [];
-        }
-
-        if (!supportedAgentTypes.Contains(gateway))
-        {
-            supportedAgentTypes.Add(gateway);
-        }
-
         var workerState = GetOrAddWorker(gateway);
         workerState.SupportedTypes.Add(registration.Type);
@@ -159,9 +148,7 @@ private WorkerState GetOrAddWorker(IGateway worker)
         workerState.LastSeen = DateTimeOffset.UtcNow;
         return workerState;
     }
-
     public ValueTask<IGateway?> GetCompatibleWorkerAsync(string type) => new(GetCompatibleWorkerCore(type));
-
     private IGateway? GetCompatibleWorkerCore(string type)
     {
         if (_supportedAgentTypes.TryGetValue(type, out var workers))
@@ -178,9 +165,27 @@ public async ValueTask SubscribeAsync(AddSubscriptionRequest subscription)
         subscription.Subscription.Id = guid;
         switch (subscription.Subscription.SubscriptionCase)
         {
-            //TODO: this doesnt look right
             case Subscription.SubscriptionOneofCase.TypePrefixSubscription:
-                break;
+                {
+                    // add the topic to the set of topics for the agent type
+                    // (note: reads and writes must use the same prefix map)
+                    state.State.AgentsToTopicsPrefixMap.TryGetValue(subscription.Subscription.TypePrefixSubscription.AgentType, out var topics);
+                    if (topics is null)
+                    {
+                        topics = new HashSet<string>();
+                        state.State.AgentsToTopicsPrefixMap[subscription.Subscription.TypePrefixSubscription.AgentType] = topics;
+                    }
+                    topics.Add(subscription.Subscription.TypePrefixSubscription.TopicTypePrefix);
+
+                    // add the agent type to the set of agent types for the topic
+                    state.State.TopicPrefixToAgentTypesMap.TryGetValue(subscription.Subscription.TypePrefixSubscription.TopicTypePrefix, out var agents);
+                    if (agents is null)
+                    {
+                        agents = new HashSet<string>();
+                        state.State.TopicPrefixToAgentTypesMap[subscription.Subscription.TypePrefixSubscription.TopicTypePrefix] = agents;
+                    }
+                    agents.Add(subscription.Subscription.TypePrefixSubscription.AgentType);
+                    break;
+                }
             case Subscription.SubscriptionOneofCase.TypeSubscription:
                 {
                     // add the topic to the set of topics for the agent type
@@ -200,20 +205,19 @@ public async ValueTask SubscribeAsync(AddSubscriptionRequest subscription)
                         state.State.TopicToAgentTypesMap[subscription.Subscription.TypeSubscription.TopicType] = agents;
                     }
                     agents.Add(subscription.Subscription.TypeSubscription.AgentType);
-
-                    // add the subscription by Guid
-                    state.State.GuidSubscriptionsMap.TryGetValue(guid, out var existingSubscriptions);
-                    if (existingSubscriptions is null)
-                    {
-                        existingSubscriptions = new HashSet<Subscription>();
-                        state.State.GuidSubscriptionsMap[guid] = existingSubscriptions;
-                    }
-                    existingSubscriptions.Add(subscription.Subscription);
                     break;
                 }
             default:
                 throw new InvalidOperationException("Invalid subscription type");
         }
+        // add the subscription by Guid
+        state.State.GuidSubscriptionsMap.TryGetValue(guid, out var existingSubscriptions);
+        if (existingSubscriptions is null)
+        {
+            existingSubscriptions = new HashSet<Subscription>();
+            state.State.GuidSubscriptionsMap[guid] = existingSubscriptions;
+        }
+        existingSubscriptions.Add(subscription.Subscription);
         await state.WriteStateAsync().ConfigureAwait(false);
     }
     public async ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request)
@@ -239,17 +243,25 @@ public async ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request)
                     // remove the agent type from the set of agent types for the topic
                     state.State.TopicToAgentTypesMap.TryGetValue(subscription.TypeSubscription.TopicType, out var agents);
                     agents?.Remove(subscription.TypeSubscription.AgentType);
-
-                    //remove the subscription by Guid
-                    state.State.GuidSubscriptionsMap.TryGetValue(guid, out var existingSubscriptions);
-                    existingSubscriptions?.Remove(subscription);
                     break;
                 }
             case Subscription.SubscriptionOneofCase.TypePrefixSubscription:
-                break;
+                {
+                    // remove the topic from the set of topics for the agent type
+                    state.State.AgentsToTopicsPrefixMap.TryGetValue(subscription.TypePrefixSubscription.AgentType, out var topics);
+                    topics?.Remove(subscription.TypePrefixSubscription.TopicTypePrefix);
+
+                    // remove the agent type from the set of agent types for the topic
+                    state.State.TopicPrefixToAgentTypesMap.TryGetValue(subscription.TypePrefixSubscription.TopicTypePrefix, out var agents);
+                    agents?.Remove(subscription.TypePrefixSubscription.AgentType);
+                    break;
+                }
             default:
                 throw new InvalidOperationException("Invalid subscription type");
         }
+        //remove the subscription by Guid
+        state.State.GuidSubscriptionsMap.TryGetValue(guid, out var existingSubscriptions);
+        existingSubscriptions?.Remove(subscription);
         }
         state.State.GuidSubscriptionsMap.Remove(guid, out _);
     }
@@ -265,17 +277,6 @@ public ValueTask<List<Subscription>> GetSubscriptionsAsync(GetSubscriptionsReque
         }
         return new(subscriptions);
     }
-    public ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, IAgentRuntime worker)
-    {
-        var (_, _) = (request, worker);
-        var e = "RegisterAgentTypeAsync(RegisterAgentTypeRequest request, IAgentRuntime worker) is not implemented when using the Grpc runtime.";
-        throw new NotImplementedException(e);
-    }
-    public ValueTask UnregisterAgentTypeAsync(string type, IAgentRuntime worker)
-    {
-        var e = "UnregisterAgentTypeAsync(string type, IAgentRuntime worker) is not implemented when using the Grpc runtime.";
-        throw new NotImplementedException(e);
-    }
     private sealed class WorkerState
     {
         public HashSet<string> SupportedTypes { get; set; } = [];
diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs
similarity index 86%
rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs
rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs
index 37e3af1b9d17..793301cf4d5f 100644
--- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs
+++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs
@@ -1,7 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // AddSubscriptionRequestSurrogate.cs
+using Microsoft.AutoGen.Protobuf;

-namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates;
+namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates;

 [GenerateSerializer]
 public struct AddSubscriptionRequestSurrogate
@@ -21,7 +22,6 @@ public AddSubscriptionRequest ConvertFromSurrogate(
     {
         var request = new AddSubscriptionRequest()
         {
-            RequestId = surrogate.RequestId,
             Subscription = surrogate.Subscription
         };
         return request;
@@ -31,7 +31,6 @@ public AddSubscriptionRequestSurrogate ConvertToSurrogate(
         in AddSubscriptionRequest value) =>
         new AddSubscriptionRequestSurrogate
         {
-            RequestId = value.RequestId,
             Subscription = value.Subscription
         };
 }
diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs
similarity index 61%
rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs
rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs
index 4c15784e0fcc..6a07a114d3e6 100644
--- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs
+++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs
@@ -1,7 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
// AddSubscriptionResponseSurrogate.cs -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct AddSubscriptionResponseSurrogate @@ -20,20 +22,10 @@ public sealed class AddSubscriptionResponseSurrogateConverter : { public AddSubscriptionResponse ConvertFromSurrogate( in AddSubscriptionResponseSurrogate surrogate) => - new AddSubscriptionResponse - { - RequestId = surrogate.RequestId, - Success = surrogate.Success, - Error = surrogate.Error - }; + new AddSubscriptionResponse { }; public AddSubscriptionResponseSurrogate ConvertToSurrogate( in AddSubscriptionResponse value) => - new AddSubscriptionResponseSurrogate - { - RequestId = value.RequestId, - Success = value.Success, - Error = value.Error - }; + new AddSubscriptionResponseSurrogate { }; } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs similarity index 88% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs index ddef9e997575..af7728d1254c 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs @@ -3,9 +3,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentIdSurrogate.cs -using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct AgentIdSurrogate diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs similarity index 91% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs index 22359a08981c..a69f7cfc89a5 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// CloudEventSurrogate.cs - using Google.Protobuf; using Google.Protobuf.WellKnownTypes; +using Microsoft.AutoGen.Contracts; -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; // TODO: Add the rest of the properties [GenerateSerializer] diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs similarity index 88% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs index ab4722ff8c74..e53948041828 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // GetSubscriptionsRequest.cs -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct GetSubscriptionsRequestSurrogate diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs similarity index 89% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs index fa50e597fabe..9bcba2391bb4 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs @@ -2,8 +2,9 @@ // RegisterAgentTypeRequestSurrogate.cs using Google.Protobuf.Collections; +using Microsoft.AutoGen.Protobuf; -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct RegisterAgentTypeRequestSurrogate @@ -27,7 +28,6 @@ public RegisterAgentTypeRequest ConvertFromSurrogate( { var request = new RegisterAgentTypeRequest() { - RequestId = surrogate.RequestId, Type = surrogate.Type }; /* future @@ -40,7 +40,6 @@ public RegisterAgentTypeRequestSurrogate ConvertToSurrogate( in RegisterAgentTypeRequest value) => new RegisterAgentTypeRequestSurrogate { - RequestId = value.RequestId, Type = value.Type, /* future Events = value.Events, diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs similarity index 62% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs index 2c7d6788a76c..c91fb3833c30 100644 --- 
a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // RegisterAgentTypeResponseSurrogate.cs -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct RegisterAgentTypeResponseSurrogate @@ -20,20 +22,10 @@ public sealed class RegisterAgentTypeResponseSurrogateConverter : { public RegisterAgentTypeResponse ConvertFromSurrogate( in RegisterAgentTypeResponseSurrogate surrogate) => - new RegisterAgentTypeResponse - { - RequestId = surrogate.RequestId, - Success = surrogate.Success, - Error = surrogate.Error - }; + new RegisterAgentTypeResponse { }; public RegisterAgentTypeResponseSurrogate ConvertToSurrogate( in RegisterAgentTypeResponse value) => - new RegisterAgentTypeResponseSurrogate - { - RequestId = value.RequestId, - Success = value.Success, - Error = value.Error - }; + new RegisterAgentTypeResponseSurrogate { }; } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs similarity index 89% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs index 27299728baa8..9b397c2bb82d 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs @@ -1,7 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // RemoveSubscriptionRequest.cs +using Microsoft.AutoGen.Protobuf; -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct RemoveSubscriptionRequestSurrogate diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs similarity index 67% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs index 88253c99b916..eec77162942e 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs @@ -1,7 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// RemoveSubscriptionResponse.cs +using Microsoft.AutoGen.Protobuf; -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct RemoveSubscriptionResponseSurrogate @@ -20,18 +21,10 @@ public sealed class SubscriptionResponseSurrogateConverter : { public RemoveSubscriptionResponse ConvertFromSurrogate( in RemoveSubscriptionResponseSurrogate surrogate) => - new RemoveSubscriptionResponse - { - Success = surrogate.Success, - Error = surrogate.Error - }; + new RemoveSubscriptionResponse { }; public RemoveSubscriptionResponseSurrogate ConvertToSurrogate( in RemoveSubscriptionResponse value) => - new RemoveSubscriptionResponseSurrogate - { - Success = value.Success, - Error = value.Error - }; + new RemoveSubscriptionResponseSurrogate { }; } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs similarity index 92% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs index 9791a68d7952..4b9fdb2500f6 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs @@ -2,9 +2,9 @@ // RpcRequestSurrogate.cs using Google.Protobuf.Collections; -using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct RpcRequestSurrogate diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs similarity index 91% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs index 5c9fac246f84..999ae2bf6502 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// RpcResponseSurrogate.cs - using Google.Protobuf.Collections; +using Microsoft.AutoGen.Protobuf; -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct RpcResponseSurrogate diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs similarity index 92% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs index 1fd56c176278..6942ada72b0b 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs @@ -1,9 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SubscriptionSurrogate.cs +using Microsoft.AutoGen.Protobuf; -using Microsoft.AutoGen.Contracts; - -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct SubscriptionSurrogate diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs similarity index 90% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs index ca4d721315e8..c38d84641b11 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // TypePrefixSubscriptionSurrogate.cs -using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct TypePrefixSubscriptionSurrogate diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs similarity index 89% rename from dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs rename to dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs index 57fa202ebfc3..958c1c4d682f 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs @@ -1,9 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// TypeSubscriptionSurrogate.cs -using Microsoft.AutoGen.Contracts; - -namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; +using Microsoft.AutoGen.Protobuf; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Orleans.Surrogates; [GenerateSerializer] public struct TypeSubscriptionSurrogate diff --git a/dotnet/test/Microsoft.AutoGen.Integration.Tests.AppHosts/HelloAgent.AppHost/HelloAgent.AppHost.csproj b/dotnet/test/Microsoft.AutoGen.Integration.Tests.AppHosts/HelloAgent.AppHost/HelloAgent.AppHost.csproj index 1442ebe3d05d..441d48d18cb5 100644 --- a/dotnet/test/Microsoft.AutoGen.Integration.Tests.AppHosts/HelloAgent.AppHost/HelloAgent.AppHost.csproj +++ b/dotnet/test/Microsoft.AutoGen.Integration.Tests.AppHosts/HelloAgent.AppHost/HelloAgent.AppHost.csproj @@ -16,6 +16,6 @@ - + diff --git a/dotnet/test/Microsoft.AutoGen.Integration.Tests/Microsoft.AutoGen.Integration.Tests.csproj b/dotnet/test/Microsoft.AutoGen.Integration.Tests/Microsoft.AutoGen.Integration.Tests.csproj index 320aa44deccb..b3d7cbf8776a 100644 --- a/dotnet/test/Microsoft.AutoGen.Integration.Tests/Microsoft.AutoGen.Integration.Tests.csproj +++ b/dotnet/test/Microsoft.AutoGen.Integration.Tests/Microsoft.AutoGen.Integration.Tests.csproj @@ -28,15 +28,15 @@ - - + + .venv - $(RepoRoot)..\python + $(RepoRoot)..\python @@ -44,10 +44,17 @@ $(PythonVenvRoot)\$(PythonVirtualEnvironmentName)\ True + ~/.local/bin/uv + True + uv + $(Uv) - + + + + diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs deleted file mode 100644 index fcae1ec3dcdb..000000000000 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// GrpcGatewayServiceTests.cs - -using FluentAssertions; -using Microsoft.AutoGen.Contracts; -using Microsoft.AutoGen.Core; -using Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc; -using Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Orleans; -using Microsoft.Extensions.Logging; -using Moq; -using NewMessageReceived = Tests.Events.NewMessageReceived; - -namespace Microsoft.AutoGen.Runtime.Grpc.Tests; -[Collection(ClusterCollection.Name)] -[Trait("Category", "GRPC")] -public class GrpcGatewayServiceTests -{ - private readonly ClusterFixture _fixture; - - public GrpcGatewayServiceTests(ClusterFixture fixture) - { - _fixture = fixture; - } - [Fact] - public async Task Test_OpenChannel() - { - var logger = Mock.Of>(); - var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); - var service = new GrpcGatewayService(gateway); - var client = new TestGrpcClient(); - - gateway._workers.Count.Should().Be(0); - var task = OpenChannel(service, client); - gateway._workers.Count.Should().Be(1); - client.Dispose(); - await task; - } - - [Fact] - public async Task Test_Message_Exchange_Through_Gateway() - { - var logger = Mock.Of>(); - var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); - var service = new GrpcGatewayService(gateway); - var client = new TestGrpcClient(); - var task = OpenChannel(service: service, client); - await service.RegisterAgent(await CreateRegistrationRequest(service, typeof(PBAgent), client.CallContext.Peer), client.CallContext); - await service.RegisterAgent(await CreateRegistrationRequest(service, typeof(GMAgent), client.CallContext.Peer), client.CallContext); - - var inputEvent = new NewMessageReceived { Message = $"Start-{client.CallContext.Peer}" }.ToCloudEvent("gh-gh-gh", "gh-gh-gh"); - - client.AddMessage(new Message { CloudEvent = inputEvent }); - var newMessageReceived = await client.ReadNext(); - newMessageReceived!.CloudEvent.Type.Should().Be(GetFullName(typeof(NewMessageReceived))); - newMessageReceived.CloudEvent.Source.Should().Be("gh-gh-gh"); - var secondMessage = await client.ReadNext(); - secondMessage!.CloudEvent.Type.Should().Be(GetFullName(typeof(NewMessageReceived))); - - // Simulate an agent, by publishing a new message in the request stream - var helloEvent = new Hello { Message = $"Hello test-{client.CallContext.Peer}" }.ToCloudEvent("gh-gh-gh", "gh-gh-gh"); - client.AddMessage(new Message { CloudEvent = helloEvent }); - var helloMessageReceived = await client.ReadNext(); - helloMessageReceived!.CloudEvent.Type.Should().Be(GetFullName(typeof(Hello))); - helloMessageReceived.CloudEvent.Source.Should().Be("gh-gh-gh"); - client.Dispose(); - await task; - } - - [Fact] - public async Task Test_RegisterAgent_Should_Succeed() - { - var logger = Mock.Of>(); - var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); - var service = new GrpcGatewayService(gateway); - var client = new TestGrpcClient(); - var task = OpenChannel(service: service, client); - var response = await service.RegisterAgent(await CreateRegistrationRequest(service, typeof(PBAgent), client.CallContext.Peer), client.CallContext); - response.Success.Should().BeTrue(); - client.Dispose(); - await task; - } - - [Fact] - public async Task Test_RegisterAgent_Should_Fail_For_Wrong_ConnectionId() - { - var logger = Mock.Of>(); - var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); - var service = new GrpcGatewayService(gateway); - var client = new TestGrpcClient(); - var response = await service.RegisterAgent(await CreateRegistrationRequest(service, typeof(PBAgent), 
"faulty_connection_id"), client.CallContext); - response.Success.Should().BeFalse(); - client.Dispose(); - } - - [Fact] - public async Task Test_SaveState() - { - var logger = Mock.Of>(); - var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); - var service = new GrpcGatewayService(gateway); - var callContext = TestServerCallContext.Create(); - var response = await service.SaveState(new AgentState { AgentId = new AgentId { Key = "Test", Type = "test" } }, callContext); - response.Should().NotBeNull(); - } - - [Fact] - public async Task Test_GetState() - { - var logger = Mock.Of>(); - var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); - var service = new GrpcGatewayService(gateway); - var callContext = TestServerCallContext.Create(); - var response = await service.GetState(new AgentId { Key = "", Type = "" }, callContext); - response.Should().NotBeNull(); - } - - private async Task CreateRegistrationRequest(GrpcGatewayService service, Type type, string requestId) - { - var registration = new RegisterAgentTypeRequest - { - Type = type.Name, - RequestId = requestId - }; - var assembly = type.Assembly; - var eventTypes = ReflectionHelper.GetAgentsMetadata(assembly); - var events = eventTypes.GetEventsForAgent(type)?.ToList(); - var topics = eventTypes.GetTopicsForAgent(type)?.ToList(); - if (events is not null && topics is not null) { events.AddRange(topics); } - var client = new TestGrpcClient(); - - if (events != null) - { - foreach (var e in events) - { - var subscriptionRequest = new Message - { - AddSubscriptionRequest = new AddSubscriptionRequest - { - RequestId = Guid.NewGuid().ToString(), - Subscription = new Subscription - { - TypeSubscription = new TypeSubscription - { - AgentType = type.Name, - TopicType = type.Name + "." + e - } - } - } - }; - await service.AddSubscription(subscriptionRequest.AddSubscriptionRequest, client.CallContext); - } - } - var topicTypes = type.GetCustomAttributes(typeof(TopicSubscriptionAttribute), true).Cast().Select(t => t.Topic).ToList(); - if (topicTypes != null) - { - foreach (var topicType in topicTypes) - { - var subscriptionRequest = new Message - { - AddSubscriptionRequest = new AddSubscriptionRequest - { - RequestId = Guid.NewGuid().ToString(), - Subscription = new Subscription - { - TypeSubscription = new TypeSubscription - { - AgentType = type.Name, - TopicType = topicType - } - } - } - }; - await service.AddSubscription(subscriptionRequest.AddSubscriptionRequest, client.CallContext); - } - } - return registration; - } - - private Task OpenChannel(GrpcGatewayService service, TestGrpcClient client) - { - return service.OpenChannel(client.RequestStream, client.ResponseStream, client.CallContext); - } - private string GetFullName(Type type) - { - return ReflectionHelper.GetMessageDescriptor(type)!.FullName; - } -} diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/TestAgent.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/TestAgent.cs deleted file mode 100644 index e479a2cd7d33..000000000000 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/TestAgent.cs +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// TestAgent.cs - -using System.Collections.Concurrent; -using Microsoft.AutoGen.Contracts; -using Microsoft.AutoGen.Core; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; - -namespace Microsoft.AutoGen.Runtime.Grpc.Tests; - -[TopicSubscription("gh-gh-gh")] -public class PBAgent([FromKeyedServices("AgentsMetadata")] AgentsMetadata eventTypes, ILogger? logger = null) - : Agent(eventTypes, logger) - , IHandle - , IHandle -{ - public async Task Handle(NewMessageReceived item, CancellationToken cancellationToken = default) - { - ReceivedMessages[AgentId.Key] = item.Message; - var hello = new Hello { Message = item.Message }; - await PublishMessageAsync(hello); - } - public Task Handle(GoodBye item, CancellationToken cancellationToken) - { - _logger.LogInformation($"Received GoodBye message {item.Message}"); - return Task.CompletedTask; - } - - public static ConcurrentDictionary ReceivedMessages { get; private set; } = new(); -} - -[TopicSubscription("gh-gh-gh")] -public class GMAgent([FromKeyedServices("AgentsMetadata")] AgentsMetadata eventTypes, ILogger? logger = null) - : Agent(eventTypes, logger) - , IHandle -{ - public async Task Handle(Hello item, CancellationToken cancellationToken) - { - _logger.LogInformation($"Received Hello message {item.Message}"); - ReceivedMessages[AgentId.Key] = item.Message; - await PublishMessageAsync(new GoodBye { Message = "" }); - } - - public static ConcurrentDictionary ReceivedMessages { get; private set; } = new(); -} diff --git a/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/GrpcGatewayServiceTests.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/GrpcGatewayServiceTests.cs new file mode 100644 index 000000000000..e9009c92394e --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/GrpcGatewayServiceTests.cs @@ -0,0 +1,241 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// GrpcGatewayServiceTests.cs + +using FluentAssertions; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; +using Microsoft.AutoGen.Protobuf; +using Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Grpc; +using Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Orleans; +using Microsoft.Extensions.Logging; +using Moq; +using NewMessageReceived = Tests.Events.NewMessageReceived; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests; +[Collection(ClusterCollection.Name)] +[Trait("Category", "UnitV2")] +public class GrpcGatewayServiceTests +{ + private readonly ClusterFixture _fixture; + + public GrpcGatewayServiceTests(ClusterFixture fixture) + { + _fixture = fixture; + } + [Fact] + public async Task Test_OpenChannel() + { + var logger = Mock.Of>(); + var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); + var service = new GrpcGatewayService(gateway); + var client = new TestGrpcClient(); + + gateway._workers.Count.Should().Be(0); + var task = OpenChannel(service, client); + gateway._workers.Count.Should().Be(1); + client.Dispose(); + await task; + } + + [Fact] + public async Task Test_Message_Exchange_Through_Gateway() + { + var logger = Mock.Of>(); + var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); + var service = new GrpcGatewayService(gateway); + var client = new TestGrpcClient(); + var task = OpenChannel(service: service, client); + await service.RegisterAgent(await CreateRegistrationRequest(service, typeof(PBAgent)), client.CallContext); + await service.RegisterAgent(await CreateRegistrationRequest(service, typeof(GMAgent)), client.CallContext); + + //var inputEvent = new NewMessageReceived { Message = $"Start-{client.CallContext.Peer}" }.ToCloudEvent("gh-gh-gh", "gh-gh-gh"); + var newMessage = new NewMessageReceived { Message = $"Start-{client.CallContext.Peer}" }; + var eventType = GetFullName(typeof(NewMessageReceived)); + var inputEvent = CloudEventExtensions.CreateCloudEvent( + Google.Protobuf.WellKnownTypes.Any.Pack(newMessage), + new TopicId(eventType, "gh-gh-gh"), + eventType, + null, + Guid.NewGuid().ToString()); + + client.AddMessage(new Message { CloudEvent = inputEvent }); + var newMessageReceived = await client.ReadNext(); + newMessageReceived!.CloudEvent.Type.Should().Be(GetFullName(typeof(NewMessageReceived))); + newMessageReceived.CloudEvent.Source.Should().Be("gh-gh-gh"); + var secondMessage = await client.ReadNext(); + secondMessage!.CloudEvent.Type.Should().Be(GetFullName(typeof(NewMessageReceived))); + + // Simulate an agent, by publishing a new message in the request stream + //var helloEvent = new Hello { Message = $"Hello test-{client.CallContext.Peer}" }.ToCloudEvent("gh-gh-gh", "gh-gh-gh"); + var hello = new Hello { Message = $"Hello test-{client.CallContext.Peer}" }; + var eventTypeHello = GetFullName(typeof(Hello)); + var helloEvent = CloudEventExtensions.CreateCloudEvent( + Google.Protobuf.WellKnownTypes.Any.Pack(message: hello), + new TopicId(eventTypeHello, "gh-gh-gh"), + eventTypeHello, + null, + Guid.NewGuid().ToString() + ); + client.AddMessage(new Message { CloudEvent = helloEvent }); + var helloMessageReceived = await client.ReadNext(); + helloMessageReceived!.CloudEvent.Type.Should().Be(GetFullName(typeof(Hello))); + helloMessageReceived.CloudEvent.Source.Should().Be("gh-gh-gh"); + client.Dispose(); + await task; + } + + [Fact] + public async Task Test_RegisterAgent_Should_Succeed() + { + var logger = Mock.Of>(); + var gateway = new GrpcGateway(_fixture.Cluster.Client, logger); + var service = new 
GrpcGatewayService(gateway); + var client = new TestGrpcClient(); + var task = OpenChannel(service: service, client); + var response = await service.RegisterAgent(await CreateRegistrationRequest(service, typeof(PBAgent)), client.CallContext); + response.GetType().Should().Be(typeof(RegisterAgentTypeResponse)); + client.Dispose(); + await task; + } + + private async Task CreateRegistrationRequest(GrpcGatewayService service, Type type) + { + var registration = new RegisterAgentTypeRequest + { + Type = type.Name, + }; + var assembly = type.Assembly; + var eventTypes = ReflectionHelper.GetAgentsMetadata(assembly); + var events = eventTypes.GetEventsForAgent(type)?.ToList(); + var topics = eventTypes.GetTopicsForAgent(type)?.ToList(); + var topicsPrefix = eventTypes.GetTopicsPrefixForAgent(type)?.ToList(); + if (events is not null && topics is not null) { events.AddRange(topics); } + var client = new TestGrpcClient(); + + if (events != null) + { + foreach (var e in events) + { + var subscriptionRequest = new AddSubscriptionRequest + { + Subscription = new Subscription + { + Id = Guid.NewGuid().ToString(), + TypeSubscription = new Protobuf.TypeSubscription + { + AgentType = type.Name, + TopicType = type.Name + "." + e + } + } + + }; + await service.AddSubscription(subscriptionRequest, client.CallContext); + } + } + var topicTypes = type.GetCustomAttributes(typeof(TypeSubscriptionAttribute), true).Cast().Select(t => t.Topic).ToList(); + if (topicTypes != null) + { + foreach (var topicType in topicTypes) + { + var subscriptionRequest = new AddSubscriptionRequest + { + Subscription = new Subscription + { + Id = Guid.NewGuid().ToString(), + TypeSubscription = new Protobuf.TypeSubscription + { + AgentType = type.Name, + TopicType = topicType + } + } + + }; + await service.AddSubscription(subscriptionRequest, client.CallContext); + } + } + var topicPrefixTypes = type.GetCustomAttributes(typeof(TypePrefixSubscriptionAttribute), true).Cast().Select(t => t.Topic).ToList(); + if (topicPrefixTypes != null) + { + foreach (var topicType in topicPrefixTypes) + { + var subscriptionRequest = new AddSubscriptionRequest + { + Subscription = new Subscription + { + Id = Guid.NewGuid().ToString(), + TypePrefixSubscription = new Protobuf.TypePrefixSubscription + { + AgentType = type.Name, + TopicTypePrefix = topicType + } + } + + }; + await service.AddSubscription(subscriptionRequest, client.CallContext); + } + } + return registration; + } + + private Task OpenChannel(GrpcGatewayService service, TestGrpcClient client) + { + return service.OpenChannel(client.RequestStream, client.ResponseStream, client.CallContext); + } + private string GetFullName(Type type) + { + return ReflectionHelper.GetMessageDescriptor(type)!.FullName; + } + /// duplicate code here because I could not get InternalsVisibleTo to work + internal static class Constants + { + public const string DATA_CONTENT_TYPE_PROTOBUF_VALUE = "application/x-protobuf"; + public const string DATA_CONTENT_TYPE_JSON_VALUE = "application/json"; + public const string DATA_CONTENT_TYPE_TEXT_VALUE = "text/plain"; + + public const string DATA_CONTENT_TYPE_ATTR = "datacontenttype"; + public const string DATA_SCHEMA_ATTR = "dataschema"; + public const string AGENT_SENDER_TYPE_ATTR = "agagentsendertype"; + public const string AGENT_SENDER_KEY_ATTR = "agagentsenderkey"; + + public const string MESSAGE_KIND_ATTR = "agmsgkind"; + public const string MESSAGE_KIND_VALUE_PUBLISH = "publish"; + public const string MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request"; + public const 
string MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response"; + } + internal static class CloudEventExtensions + { + // Convert an ISubscrptionDefinition to a Protobuf Subscription + internal static CloudEvent CreateCloudEvent(Google.Protobuf.WellKnownTypes.Any payload, TopicId topic, string dataType, Contracts.AgentId? sender, string messageId) + { + var attributes = new Dictionary + { + { + Constants.DATA_CONTENT_TYPE_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = Constants.DATA_CONTENT_TYPE_PROTOBUF_VALUE } + }, + { + Constants.DATA_SCHEMA_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = dataType } + }, + { + Constants.MESSAGE_KIND_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = Constants.MESSAGE_KIND_VALUE_PUBLISH } + } + }; + + if (sender != null) + { + var senderNonNull = (Contracts.AgentId)sender; + attributes.Add(Constants.AGENT_SENDER_TYPE_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = senderNonNull.Type }); + attributes.Add(Constants.AGENT_SENDER_KEY_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = senderNonNull.Key }); + } + + return new CloudEvent + { + ProtoData = payload, + Type = topic.Type, + Source = topic.Source, + Id = messageId, + Attributes = { attributes } + }; + + } + } +} diff --git a/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/AgentTypes.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/AgentTypes.cs new file mode 100644 index 000000000000..d8e286b1fbb3 --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/AgentTypes.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AgentTypes.cs +using Microsoft.AutoGen.Core; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests; +public sealed class AgentTypes(Dictionary types) +{ + public Dictionary Types { get; } = types; + public static AgentTypes? GetAgentTypesFromAssembly() + { + var agents = AppDomain.CurrentDomain.GetAssemblies() + .SelectMany(assembly => assembly.GetTypes()) + .Where(type => ReflectionHelper.IsSubclassOfGeneric(type, typeof(BaseAgent)) + && !type.IsAbstract) + .ToDictionary(type => type.Name, type => type); + + return new AgentTypes(agents); + } +} diff --git a/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/AgentsMetadata.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/AgentsMetadata.cs new file mode 100644 index 000000000000..016bfc329bfe --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/AgentsMetadata.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AgentsMetadata.cs + +using System.Collections.Concurrent; +using Google.Protobuf.Reflection; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests; + +/// +/// Represents a collection of event types and their associated metadata. +/// +public sealed class AgentsMetadata +{ + /// + /// Initializes a new instance of the class. + /// + /// The type registry containing protobuf type information. + /// A dictionary mapping event names to their corresponding types. + /// A dictionary mapping types to a set of event names associated with those types. + /// A dictionary mapping types to a set of topics associated with those types. + /// A dictionary mapping types to a set of topics associated with those types. 
+ /// + public AgentsMetadata( + TypeRegistry typeRegistry, + Dictionary types, + Dictionary> eventsMap, + Dictionary> topicsMap, + Dictionary> topicsPrefixMap) + { + TypeRegistry = typeRegistry; + _types = new(types); + _eventsMap = new(eventsMap); + _topicsMap = new(topicsMap); + _topicsPrefixMap = new(topicsPrefixMap); + } + + /// + /// Gets the type registry containing protobuf type information. + /// + public TypeRegistry TypeRegistry { get; } + + private ConcurrentDictionary _types; + + private ConcurrentDictionary> _eventsMap; + private ConcurrentDictionary> _topicsMap; + private ConcurrentDictionary> _topicsPrefixMap; + + /// + /// Checks if a given type handles a specific event name. + /// + /// The type to check. + /// The event name to check. + /// true if the type handles the event name; otherwise, false. + public bool CheckIfTypeHandles(Type type, string eventName) + { + if (_eventsMap.TryGetValue(type, out var events)) + { + return events.Contains(eventName); + } + return false; + } + + /// + /// Gets the event type by its name. + /// + /// The name of the event type. + /// The event type if found; otherwise, null. + public Type? GetEventTypeByName(string type) + { + if (_types.TryGetValue(type, out var eventType)) + { + return eventType; + } + return null; + } + + public HashSet? GetEventsForAgent(Type agent) + { + if (_eventsMap.TryGetValue(agent, out var events)) + { + return events; + } + return null; + } + + public HashSet? GetTopicsForAgent(Type agent) + { + if (_topicsMap.TryGetValue(agent, out var topics)) + { + return topics; + } + return null; + } + + public HashSet? GetTopicsPrefixForAgent(Type type) + { + if (_topicsPrefixMap.TryGetValue(type, out var topics)) + { + return topics; + } + return null; + } +} + diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestAsyncStreamReader.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestAsyncStreamReader.cs similarity index 96% rename from dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestAsyncStreamReader.cs rename to dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestAsyncStreamReader.cs index 4f26711d149f..a0708a13b484 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestAsyncStreamReader.cs +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestAsyncStreamReader.cs @@ -16,7 +16,7 @@ using System.Threading.Channels; using Grpc.Core; -namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Grpc; public class TestAsyncStreamReader : IDisposable, IAsyncStreamReader where T : class diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs similarity index 91% rename from dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs rename to dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs index e47f26eda159..8c325678cbf2 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestGrpcClient.cs @@ -1,15 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// TestGrpcClient.cs - -using Microsoft.AutoGen.Contracts; - -namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc; +using Microsoft.AutoGen.Protobuf; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Grpc; internal sealed class TestGrpcClient : IDisposable { public TestAsyncStreamReader RequestStream { get; } public TestServerStreamWriter ResponseStream { get; } public TestServerCallContext CallContext { get; } - private CancellationTokenSource CallContextCancellation = new(); public TestGrpcClient() { @@ -28,7 +25,6 @@ public void AddMessage(Message message) { RequestStream.AddMessage(message); } - public void Dispose() { CallContextCancellation.Cancel(); @@ -36,3 +32,4 @@ public void Dispose() ResponseStream.Dispose(); } } + diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestServerCallContext.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestServerCallContext.cs similarity index 94% rename from dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestServerCallContext.cs rename to dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestServerCallContext.cs index 47f25155602d..491eb112b4bb 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestServerCallContext.cs +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestServerCallContext.cs @@ -15,7 +15,7 @@ using Grpc.Core; -namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Grpc; public class TestServerCallContext : ServerCallContext { @@ -68,6 +68,7 @@ protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders) public static TestServerCallContext Create(Metadata? requestHeaders = null, CancellationToken cancellationToken = default) { + requestHeaders ??= new Metadata() { { "client-id", Guid.NewGuid().ToString() } }; return new TestServerCallContext(requestHeaders ?? 
new Metadata(), cancellationToken); } } diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestServerStreamWriter.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestServerStreamWriter.cs similarity index 97% rename from dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestServerStreamWriter.cs rename to dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestServerStreamWriter.cs index ca2aeab2e410..92074b2fabc6 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Grpc/TestServerStreamWriter.cs +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Grpc/TestServerStreamWriter.cs @@ -16,7 +16,7 @@ using System.Threading.Channels; using Grpc.Core; -namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Grpc; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Grpc; public class TestServerStreamWriter : IDisposable, IServerStreamWriter where T : class { diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/ClusterCollection.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/ClusterCollection.cs similarity index 78% rename from dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/ClusterCollection.cs rename to dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/ClusterCollection.cs index d61dc7b21c50..e391a47f3e6c 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/ClusterCollection.cs +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/ClusterCollection.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // ClusterCollection.cs -namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Orleans; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Orleans; [CollectionDefinition(Name)] public sealed class ClusterCollection : ICollectionFixture diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/ClusterFixture.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/ClusterFixture.cs similarity index 87% rename from dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/ClusterFixture.cs rename to dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/ClusterFixture.cs index 9db2f7f654d4..cb1320a251c3 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/ClusterFixture.cs +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/ClusterFixture.cs @@ -3,7 +3,7 @@ using Orleans.TestingHost; -namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Orleans; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Orleans; public sealed class ClusterFixture : IDisposable { diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs similarity index 89% rename from dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs rename to dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs index bb960f7b1107..731ab83694c8 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs @@ -4,7 +4,7 @@ using Orleans.Serialization; using 
Orleans.TestingHost; -namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Orleans; +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Orleans; public class SiloBuilderConfigurator : ISiloConfigurator { diff --git a/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/ReflectionHelper.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/ReflectionHelper.cs new file mode 100644 index 000000000000..12e9b799b97c --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Helpers/ReflectionHelper.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ReflectionHelper.cs +using System.Reflection; +using Google.Protobuf; +using Google.Protobuf.Reflection; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests; +public sealed class ReflectionHelper +{ + public static bool IsSubclassOfGeneric(Type type, Type genericBaseType) + { + while (type != null && type != typeof(object)) + { + if (genericBaseType == (type.IsGenericType ? type.GetGenericTypeDefinition() : type)) + { + return true; + } + if (type.BaseType == null) + { + return false; + } + type = type.BaseType; + } + return false; + } + public static AgentsMetadata GetAgentsMetadata(params Assembly[] assemblies) + { + var interfaceType = typeof(IMessage); + var pairs = assemblies + .SelectMany(assembly => assembly.GetTypes()) + .Where(type => interfaceType.IsAssignableFrom(type) && type.IsClass && !type.IsAbstract) + .Select(t => (t, GetMessageDescriptor(t))); + + var descriptors = pairs.Select(t => t.Item2); + var typeRegistry = TypeRegistry.FromMessages(descriptors); + var types = pairs.ToDictionary(item => item.Item2?.FullName ?? "", item => item.t); + + var eventsMap = assemblies + .SelectMany(assembly => assembly.GetTypes()) + .Where(type => IsSubclassOfGeneric(type, typeof(BaseAgent)) && !type.IsAbstract) + .Select(t => (t, t.GetInterfaces() + .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>)) + .Select(i => GetMessageDescriptor(i.GetGenericArguments().First())?.FullName ?? "").ToHashSet())) + .ToDictionary(item => item.t, item => item.Item2); + var topicsMap = assemblies + .SelectMany(assembly => assembly.GetTypes()) + .Where(type => IsSubclassOfGeneric(type, typeof(BaseAgent)) && !type.IsAbstract) + .Select(t => (t, t.GetCustomAttributes().Select(a => a.Topic).ToHashSet())) + .ToDictionary(item => item.t, item => item.Item2); + var topicsPrefixMap = assemblies + .SelectMany(assembly => assembly.GetTypes()) + .Where(type => IsSubclassOfGeneric(type, typeof(BaseAgent)) && !type.IsAbstract) + .Select(t => (t, t.GetCustomAttributes().Select(a => a.Topic).ToHashSet())) + .ToDictionary(item => item.t, item => item.Item2); + return new AgentsMetadata(typeRegistry, types, eventsMap, topicsMap, topicsPrefixMap); + } + + /// + /// Gets the message descriptor for the specified type. + /// + /// The type to get the message descriptor for. + /// The message descriptor if found; otherwise, null. + public static MessageDescriptor? 
GetMessageDescriptor(Type type) + { + var property = type.GetProperty("Descriptor", BindingFlags.Static | BindingFlags.Public); + return property?.GetValue(null) as MessageDescriptor; + } +} diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.csproj similarity index 74% rename from dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj rename to dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.csproj index c8b00ee268b0..066a49c2de68 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.csproj @@ -8,17 +8,13 @@ - - runtime; build; native; contentfiles; analyzers; buildtransitive - all - - + diff --git a/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/TestAgent.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/TestAgent.cs new file mode 100644 index 000000000000..493b0370433a --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/TestAgent.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TestAgent.cs +using System.Collections.Concurrent; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; +using Microsoft.AutoGen.Protobuf; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests; + +[TypeSubscription("gh-gh-gh")] +public class PBAgent(Contracts.AgentId id, IAgentRuntime runtime, ILogger? logger = null) + : BaseAgent(id, runtime, "Test Agent", logger), + IHandle, + IHandle +{ + public async ValueTask HandleAsync(NewMessageReceived item, MessageContext messageContext) + { + var key = messageContext.MessageId ?? Guid.NewGuid().ToString(); + ReceivedMessages.AddOrUpdate(key, item.Message, (k, v) => item.Message); + var hello = new Hello { Message = item.Message }; + var typeFullName = typeof(Hello).FullName ?? throw new InvalidOperationException("Type full name is null"); + await PublishMessageAsync(hello, new TopicId(typeFullName), "gh-gh-gh"); + } + public async ValueTask HandleAsync(GoodBye item, MessageContext context) + { + _logger.LogInformation($"Received GoodBye message {item.Message}"); + } + public static ConcurrentDictionary ReceivedMessages { get; private set; } = new(); +} + +[TypeSubscription("gh-gh-gh")] +public class GMAgent(Contracts.AgentId id, IAgentRuntime runtime, ILogger? logger = null) + : BaseAgent(id, runtime, "Test Agent", logger), + IHandle +{ + public async ValueTask HandleAsync(Hello item, MessageContext messageContext) + { + var key = messageContext.MessageId ?? Guid.NewGuid().ToString(); + ReceivedMessages.AddOrUpdate(key, item.Message, (k, v) => item.Message); + var typeFullName = typeof(GoodBye).FullName ?? 
throw new InvalidOperationException("Type full name is null"); + await PublishMessageAsync(new GoodBye { Message = "" }, new TopicId(typeFullName, "gh-gh-gh")); + } + public static ConcurrentDictionary ReceivedMessages { get; private set; } = new(); +} diff --git a/python/pyproject.toml b/python/pyproject.toml index 29cf8f963f09..832f7518a870 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -103,7 +103,7 @@ gen-proto = [ ] -gen-proto-samples = "python -m grpc_tools.protoc --python_out=./samples/core_xlang_hello_python_agent/protos --grpc_python_out=./samples/core_xlang_hello_python_agent/protos --mypy_out=./samples/core_xlang_hello_python_agent/protos --mypy_grpc_out=./samples/core_xlang_hello_python_agent/protos --proto_path ../protos/ agent_events.proto" +gen-proto-samples = "python -m grpc_tools.protoc --python_out=./samples/core_xlang_hello_python_agent/protos --grpc_python_out=./samples/core_xlang_hello_python_agent/protos --mypy_out=./samples/core_xlang_hello_python_agent/protos --mypy_grpc_out=./samples/core_xlang_hello_python_agent/protos --proto_path ../dotnet/src/Microsoft.AutoGen/Agents/protos/ agent_events.proto" [[tool.poe.tasks.gen-test-proto.sequence]] cmd = "python -m grpc_tools.protoc --python_out=./packages/autogen-core/tests/protos --grpc_python_out=./packages/autogen-core/tests/protos --mypy_out=./packages/autogen-core/tests/protos --mypy_grpc_out=./packages/autogen-core/tests/protos --proto_path ./packages/autogen-core/tests/protos serialization_test.proto" diff --git a/python/samples/core_xlang_hello_python_agent/protos/agent_events_pb2.py b/python/samples/core_xlang_hello_python_agent/protos/agent_events_pb2.py index fdd42804947d..4d65bcefd3cc 100644 --- a/python/samples/core_xlang_hello_python_agent/protos/agent_events_pb2.py +++ b/python/samples/core_xlang_hello_python_agent/protos/agent_events_pb2.py @@ -24,14 +24,14 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_events.proto\x12\x06\x61gents\"2\n\x0bTextMessage\x12\x13\n\x0btextMessage\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"\x18\n\x05Input\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x1f\n\x0eInputProcessed\x12\r\n\x05route\x18\x01 \x01(\t\"\x19\n\x06Output\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x1e\n\rOutputWritten\x12\r\n\x05route\x18\x01 \x01(\t\"\x1a\n\x07IOError\x12\x0f\n\x07message\x18\x01 \x01(\t\"%\n\x12NewMessageReceived\x12\x0f\n\x07message\x18\x01 \x01(\t\"%\n\x11ResponseGenerated\x12\x10\n\x08response\x18\x01 \x01(\t\"\x1a\n\x07GoodBye\x12\x0f\n\x07message\x18\x01 \x01(\t\" \n\rMessageStored\x12\x0f\n\x07message\x18\x01 \x01(\t\";\n\x12\x43onversationClosed\x12\x0f\n\x07user_id\x18\x01 \x01(\t\x12\x14\n\x0cuser_message\x18\x02 \x01(\t\"\x1b\n\x08Shutdown\x12\x0f\n\x07message\x18\x01 \x01(\tB\x1e\xaa\x02\x1bMicrosoft.AutoGen.Contractsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_events.proto\x12\x06\x61gents\"2\n\x0bTextMessage\x12\x13\n\x0btextMessage\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"\x18\n\x05Input\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x1f\n\x0eInputProcessed\x12\r\n\x05route\x18\x01 \x01(\t\"\x19\n\x06Output\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x1e\n\rOutputWritten\x12\r\n\x05route\x18\x01 \x01(\t\"\x1a\n\x07IOError\x12\x0f\n\x07message\x18\x01 \x01(\t\"%\n\x12NewMessageReceived\x12\x0f\n\x07message\x18\x01 \x01(\t\"%\n\x11ResponseGenerated\x12\x10\n\x08response\x18\x01 \x01(\t\"\x1a\n\x07GoodBye\x12\x0f\n\x07message\x18\x01 \x01(\t\" 
\n\rMessageStored\x12\x0f\n\x07message\x18\x01 \x01(\t\";\n\x12\x43onversationClosed\x12\x0f\n\x07user_id\x18\x01 \x01(\t\x12\x14\n\x0cuser_message\x18\x02 \x01(\t\"\x1b\n\x08Shutdown\x12\x0f\n\x07message\x18\x01 \x01(\tB\x1b\xaa\x02\x18Microsoft.AutoGen.Agentsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'agent_events_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None - _globals['DESCRIPTOR']._serialized_options = b'\252\002\033Microsoft.AutoGen.Contracts' + _globals['DESCRIPTOR']._serialized_options = b'\252\002\030Microsoft.AutoGen.Agents' _globals['_TEXTMESSAGE']._serialized_start=30 _globals['_TEXTMESSAGE']._serialized_end=80 _globals['_INPUT']._serialized_start=82 diff --git a/python/uv.lock b/python/uv.lock index dbe382896347..eab21f4ab3cd 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -4308,7 +4308,6 @@ name = "nvidia-cublas-cu12" version = "12.4.5.8" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 }, { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, ] @@ -4317,7 +4316,6 @@ name = "nvidia-cuda-cupti-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 }, { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, ] @@ -4326,7 +4324,6 @@ name = "nvidia-cuda-nvrtc-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 }, { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, ] @@ -4335,7 +4332,6 @@ name = "nvidia-cuda-runtime-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 }, { url = 
"https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, ] @@ -4358,7 +4354,6 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, ] @@ -4367,7 +4362,6 @@ name = "nvidia-curand-cu12" version = "10.3.5.147" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 }, { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, ] @@ -4381,7 +4375,6 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, ] @@ -4393,7 +4386,6 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, ] @@ -4410,7 +4402,6 @@ name = "nvidia-nvjitlink-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 }, { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, ] @@ -4419,7 +4410,6 @@ name = "nvidia-nvtx-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 }, { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, ] From 51b601b6d844df7ae6963fd93481038844696e03 Mon Sep 17 00:00:00 2001 From: Ryan Sweet Date: Fri, 7 Feb 2025 19:59:38 -0800 Subject: [PATCH 13/36] re-add hello app host (#5445) --- dotnet/AutoGen.sln | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index cb590f21dbd0..014433663b60 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -136,12 +136,12 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.AutoGen", "Micros EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Agents", "src\Microsoft.AutoGen\Agents\Microsoft.AutoGen.Agents.csproj", "{EF954ED3-87D5-40F1-8557-E7179F43EA0E}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgent.AppHost", "test\Microsoft.AutoGen.Integration.Tests.AppHosts\HelloAgent.AppHost\HelloAgent.AppHost.csproj", "{50082F76-917F-42EE-8869-8C72630423A7}" - Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.AgentChat", "src\Microsoft.AutoGen\AgentChat\Microsoft.AutoGen.AgentChat.csproj", "{7F828599-56E8-4597-8F68-EE26FD631417}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.AgentChat.Tests", "test\Microsoft.AutoGen.AgentChat.Tests\Microsoft.AutoGen.AgentChat.Tests.csproj", "{217A4F86-8ADD-4998-90BA-880092A019F5}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloAgent.AppHost", "test\Microsoft.AutoGen.Integration.Tests.AppHosts\HelloAgent.AppHost\HelloAgent.AppHost.csproj", "{0C371D65-7EF9-44EA-8128-A105DA82A80E}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -360,10 +360,6 @@ Global {EF954ED3-87D5-40F1-8557-E7179F43EA0E}.Debug|Any CPU.Build.0 = Debug|Any CPU {EF954ED3-87D5-40F1-8557-E7179F43EA0E}.Release|Any CPU.ActiveCfg = Release|Any CPU {EF954ED3-87D5-40F1-8557-E7179F43EA0E}.Release|Any CPU.Build.0 = Release|Any CPU - {50082F76-917F-42EE-8869-8C72630423A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {50082F76-917F-42EE-8869-8C72630423A7}.Debug|Any CPU.Build.0 = Debug|Any CPU - {50082F76-917F-42EE-8869-8C72630423A7}.Release|Any CPU.ActiveCfg = Release|Any CPU - {50082F76-917F-42EE-8869-8C72630423A7}.Release|Any CPU.Build.0 = Release|Any CPU 
{7F828599-56E8-4597-8F68-EE26FD631417}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {7F828599-56E8-4597-8F68-EE26FD631417}.Debug|Any CPU.Build.0 = Debug|Any CPU {7F828599-56E8-4597-8F68-EE26FD631417}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -372,6 +368,10 @@ Global {217A4F86-8ADD-4998-90BA-880092A019F5}.Debug|Any CPU.Build.0 = Debug|Any CPU {217A4F86-8ADD-4998-90BA-880092A019F5}.Release|Any CPU.ActiveCfg = Release|Any CPU {217A4F86-8ADD-4998-90BA-880092A019F5}.Release|Any CPU.Build.0 = Release|Any CPU + {0C371D65-7EF9-44EA-8128-A105DA82A80E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0C371D65-7EF9-44EA-8128-A105DA82A80E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0C371D65-7EF9-44EA-8128-A105DA82A80E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0C371D65-7EF9-44EA-8128-A105DA82A80E}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -434,9 +434,9 @@ Global {B8E77E57-C983-4EEA-9589-906271486D80} = {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} {81BA12F2-2D2F-42C1-AF83-FBDAA1A78A45} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {EF954ED3-87D5-40F1-8557-E7179F43EA0E} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} - {50082F76-917F-42EE-8869-8C72630423A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {7F828599-56E8-4597-8F68-EE26FD631417} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {217A4F86-8ADD-4998-90BA-880092A019F5} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {0C371D65-7EF9-44EA-8128-A105DA82A80E} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} From be085567ea5b43ea24f729d06cf8cec2442cd636 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Sat, 8 Feb 2025 14:54:05 +1000 Subject: [PATCH 14/36] fix: remove sk tool adapter plugin name (#5444) ## Why are these changes needed? Semantic kernel prepends the plugin name to the tool name when passing the tools to model clients and this is causing a mismatch between tool names in SK and the AssistantAgent. Since plugin names are optional, we have opted to remove it. ## Related issue number Closes #5420 ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. 
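## Example (illustrative)

A minimal sketch of the name mismatch this change removes. The `add` tool and the `autogen_ext.tools.semantic_kernel` import path below are assumptions for illustration, not code taken from this PR:

```python
# Sketch only: assumes KernelFunctionFromTool is re-exported from
# autogen_ext.tools.semantic_kernel, and uses a made-up "add" tool.
from autogen_core.tools import FunctionTool
from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool


async def add(x: int, y: int) -> int:
    """Add two numbers."""
    return x + y


tool = FunctionTool(add, description="Add two numbers.")

# Before this change the adapter forced plugin_name="autogen_tools", so
# Semantic Kernel advertised the function to the model as
# "autogen_tools-add" while AssistantAgent had registered plain "add",
# and the returned tool call never matched. With plugin_name left unset
# (the new default behavior), both sides agree on "add".
kernel_function = KernelFunctionFromTool(tool)
assert kernel_function.metadata.plugin_name is None
assert kernel_function.metadata.name == "add"
```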
--------- Co-authored-by: Leonardo Pinheiro --- .../models/semantic_kernel/_sk_chat_completion_adapter.py | 2 +- .../tools/semantic_kernel/_kernel_function_from_tool.py | 2 +- .../tests/models/test_sk_chat_completion_adapter.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py index d0cf64c41e96..6cea28d0f2f5 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py @@ -313,7 +313,7 @@ def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSch for tool in tools: if isinstance(tool, BaseTool): # Convert Tool to KernelFunction using KernelFunctionFromTool - kernel_function = KernelFunctionFromTool(tool, plugin_name="autogen_tools") # type: ignore + kernel_function = KernelFunctionFromTool(tool) # type: ignore self._tools_plugin.functions[tool.schema["name"]] = kernel_function def _process_tool_calls(self, result: ChatMessageContent) -> list[FunctionCall]: diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py b/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py index a11646c876ad..1c2fec0e373d 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py @@ -106,7 +106,7 @@ def __init__(self, tool: BaseTool[InputT, OutputT], plugin_name: str | None = No return_parameter=return_param, is_prompt=False, is_asynchronous=True, - plugin_name=plugin_name or "", + plugin_name=plugin_name, ) super().__init__(metadata=metadata) self._tool = tool diff --git a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py index 71b74a9a0cb6..1b5a6ea03fa5 100644 --- a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py +++ b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py @@ -71,7 +71,7 @@ async def mock_get_chat_message_contents( FunctionCallContent( id="call_UwVVI0iGEmcPwmKUigJcuuuF", function_name="calculator", - plugin_name="autogen_tools", + plugin_name=None, arguments="{}", ) ], @@ -113,7 +113,7 @@ async def mock_get_streaming_chat_message_contents( FunctionCallContent( id="call_n8135GXc2kbiaaDdpImsB1VW", function_name="calculator", - plugin_name="autogen_tools", + plugin_name=None, arguments="", content_type="function_call", # type: ignore ) From 0b659de36d0308fe7d56919ee24e9e7fadae577c Mon Sep 17 00:00:00 2001 From: afourney Date: Fri, 7 Feb 2025 23:04:24 -0800 Subject: [PATCH 15/36] Mitigates #5401 by optionally prepending names to messages. (#5448) Mitigates #5401 by optionally prepending names to messages. 
Co-authored-by: Eric Zhu --- .../models/openai/_openai_client.py | 57 ++++++++++++++----- .../models/openai/config/__init__.py | 2 + .../tests/models/test_openai_model_client.py | 54 +++++++++++++++++- 3 files changed, 99 insertions(+), 14 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index b1f23e1040dd..23b46cdb4891 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -137,11 +137,11 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole: return "tool" -def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam: +def user_message_to_oai(message: UserMessage, prepend_name: bool = False) -> ChatCompletionUserMessageParam: assert_valid_name(message.source) if isinstance(message.content, str): return ChatCompletionUserMessageParam( - content=message.content, + content=(f"{message.source} said:\n" if prepend_name else "") + message.content, role="user", name=message.source, ) @@ -149,10 +149,18 @@ def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam: parts: List[ChatCompletionContentPartParam] = [] for part in message.content: if isinstance(part, str): - oai_part = ChatCompletionContentPartTextParam( - text=part, - type="text", - ) + if prepend_name: + # Append the name to the first text part + oai_part = ChatCompletionContentPartTextParam( + text=f"{message.source} said:\n" + part, + type="text", + ) + prepend_name = False + else: + oai_part = ChatCompletionContentPartTextParam( + text=part, + type="text", + ) parts.append(oai_part) elif isinstance(part, Image): # TODO: support url based images @@ -211,11 +219,11 @@ def assistant_message_to_oai( ) -def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]: +def to_oai_type(message: LLMMessage, prepend_name: bool = False) -> Sequence[ChatCompletionMessageParam]: if isinstance(message, SystemMessage): return [system_message_to_oai(message)] elif isinstance(message, UserMessage): - return [user_message_to_oai(message)] + return [user_message_to_oai(message, prepend_name)] elif isinstance(message, AssistantMessage): return [assistant_message_to_oai(message)] else: @@ -356,8 +364,10 @@ def __init__( create_args: Dict[str, Any], model_capabilities: Optional[ModelCapabilities] = None, # type: ignore model_info: Optional[ModelInfo] = None, + add_name_prefixes: bool = False, ): self._client = client + self._add_name_prefixes = add_name_prefixes if model_capabilities is None and model_info is None: try: self._model_info = _model_info.get_info(create_args["model"]) @@ -451,7 +461,7 @@ async def create( if self.model_info["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output.") - oai_messages_nested = [to_oai_type(m) for m in messages] + oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages] oai_messages = [item for sublist in oai_messages_nested for item in sublist] if self.model_info["function_calling"] is False and len(tools) > 0: @@ -672,7 +682,7 @@ async def create_stream( create_args = self._create_args.copy() create_args.update(extra_create_args) - oai_messages_nested = [to_oai_type(m) for m in messages] + oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages] oai_messages = [item for sublist in 
oai_messages_nested for item in sublist] # TODO: allow custom handling. @@ -874,7 +884,7 @@ def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | # Message tokens. for message in messages: num_tokens += tokens_per_message - oai_message = to_oai_type(message) + oai_message = to_oai_type(message, prepend_name=self._add_name_prefixes) for oai_message_part in oai_message: for key, value in oai_message_part.items(): if value is None: @@ -992,6 +1002,11 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA top_p (optional, float): user (optional, str): default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements. + add_name_prefixes (optional, bool): Whether to prepend the `source` value + to each :class:`~autogen_core.models.UserMessage` content. E.g., + "this is content" becomes "Reviewer said: this is content." + This can be useful for models that do not support the `name` field in + message. Defaults to False. To use this client, you must install the `openai` extension: @@ -1074,11 +1089,19 @@ def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]): model_info = kwargs["model_info"] del copied_args["model_info"] + add_name_prefixes: bool = False + if "add_name_prefixes" in kwargs: + add_name_prefixes = kwargs["add_name_prefixes"] + client = _openai_client_from_config(copied_args) create_args = _create_args_from_config(copied_args) super().__init__( - client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info + client=client, + create_args=create_args, + model_capabilities=model_capabilities, + model_info=model_info, + add_name_prefixes=add_name_prefixes, ) def __getstate__(self) -> Dict[str, Any]: @@ -1215,11 +1238,19 @@ def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]): model_info = kwargs["model_info"] del copied_args["model_info"] + add_name_prefixes: bool = False + if "add_name_prefixes" in kwargs: + add_name_prefixes = kwargs["add_name_prefixes"] + client = _azure_openai_client_from_config(copied_args) create_args = _create_args_from_config(copied_args) self._raw_config: Dict[str, Any] = copied_args super().__init__( - client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info + client=client, + create_args=create_args, + model_capabilities=model_capabilities, + model_info=model_info, + add_name_prefixes=add_name_prefixes, ) def __getstate__(self) -> Dict[str, Any]: diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py index 367564187990..b85e7c22ccdc 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py @@ -34,6 +34,7 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False): max_retries: int model_capabilities: ModelCapabilities # type: ignore model_info: ModelInfo + add_name_prefixes: bool """What functionality the model supports, determined by default from model name but is overriden if value passed.""" default_headers: Dict[str, str] | None @@ -75,6 +76,7 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel): max_retries: int | None = None model_capabilities: ModelCapabilities | None = None # type: ignore model_info: ModelInfo | None = None + add_name_prefixes: bool | None = None default_headers: 
Dict[str, str] | None = None diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index f2a8ff94319f..69d9281042c0 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -22,7 +22,7 @@ from autogen_core.tools import BaseTool, FunctionTool from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient from autogen_ext.models.openai._model_info import resolve_model -from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools +from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions from openai.resources.chat.completions import AsyncCompletions from openai.types.chat.chat_completion import ChatCompletion, Choice @@ -1050,4 +1050,56 @@ async def test_ollama() -> None: assert chunks[-1].thought is not None +@pytest.mark.asyncio +async def test_add_name_prefixes(monkeypatch: pytest.MonkeyPatch) -> None: + sys_message = SystemMessage(content="You are a helpful AI agent, and you answer questions in a friendly way.") + assistant_message = AssistantMessage(content="Hello, how can I help you?", source="Assistant") + user_text_message = UserMessage(content="Hello, I am from Seattle.", source="Adam") + user_mm_message = UserMessage( + content=[ + "Here is a postcard from Seattle:", + Image.from_base64( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC" + ), + ], + source="Adam", + ) + + # Default conversion + oai_sys = to_oai_type(sys_message)[0] + oai_asst = to_oai_type(assistant_message)[0] + oai_text = to_oai_type(user_text_message)[0] + oai_mm = to_oai_type(user_mm_message)[0] + + converted_sys = to_oai_type(sys_message, prepend_name=True)[0] + converted_asst = to_oai_type(assistant_message, prepend_name=True)[0] + converted_text = to_oai_type(user_text_message, prepend_name=True)[0] + converted_mm = to_oai_type(user_mm_message, prepend_name=True)[0] + + # Invariants + assert "content" in oai_sys + assert "content" in oai_asst + assert "content" in oai_text + assert "content" in oai_mm + assert "content" in converted_sys + assert "content" in converted_asst + assert "content" in converted_text + assert "content" in converted_mm + assert oai_sys["role"] == converted_sys["role"] + assert oai_sys["content"] == converted_sys["content"] + assert oai_asst["role"] == converted_asst["role"] + assert oai_asst["content"] == converted_asst["content"] + assert oai_text["role"] == converted_text["role"] + assert oai_mm["role"] == converted_mm["role"] + assert isinstance(oai_mm["content"], list) + assert isinstance(converted_mm["content"], list) + assert len(oai_mm["content"]) == len(converted_mm["content"]) + assert "text" in converted_mm["content"][0] + assert "text" in oai_mm["content"][0] + + # Name prepended + assert str(converted_text["content"]) == "Adam said:\n" + str(oai_text["content"]) + assert str(converted_mm["content"][0]["text"]) == "Adam said:\n" + str(oai_mm["content"][0]["text"]) + + # TODO: add integration tests for Azure OpenAI using AAD token. 
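A minimal usage sketch of the new `add_name_prefixes` option added above (the model name and message content are illustrative, and an `OPENAI_API_KEY` in the environment is assumed):

```python
# Minimal sketch of the new option; model and message are illustrative,
# and OPENAI_API_KEY is assumed to be set in the environment.
import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o", add_name_prefixes=True)
    # The request content becomes "Adam said:\nHello, I am from Seattle."
    # rather than relying on the OpenAI-specific `name` field, which some
    # backends ignore.
    result = await client.create(
        messages=[UserMessage(content="Hello, I am from Seattle.", source="Adam")]
    )
    print(result.content)


asyncio.run(main())
```

This trades the OpenAI-specific `name` field for an inline prefix in the message content, which is why it helps model backends that drop `name`.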
From 426b8984854f0c5284656a14ae01f1272891364a Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sat, 8 Feb 2025 15:13:46 -0800 Subject: [PATCH 16/36] fix: improve speaker selection in SelectorGroupChat for weaker models (#5454) Don't throw an exception when the model makes a mistake. Use retries, and if it does not succeed after a fixed number of attempts, fall back to the previous speaker if available, or the first participant. Resolves #5453 --- .../teams/_group_chat/_selector_group_chat.py | 107 ++++++++++++------ .../tests/test_group_chat.py | 62 ++++++++++ .../tests/test_group_chat_endpoint.py | 47 +++++--- 3 files changed, 169 insertions(+), 47 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index de0ef3247c69..70de5713814d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Mapping, Sequence from autogen_core import Component, ComponentModel -from autogen_core.models import ChatCompletionClient, ModelFamily, SystemMessage, UserMessage +from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage from pydantic import BaseModel from typing_extensions import Self @@ -39,6 +39,7 @@ def __init__( selector_prompt: str, allow_repeated_speaker: bool, selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None, + max_selector_attempts: int, ) -> None: super().__init__( group_topic_type, @@ -53,6 +54,7 @@ def __init__( self._previous_speaker: str | None = None self._allow_repeated_speaker = allow_repeated_speaker self._selector_func = selector_func + self._max_selector_attempts = max_selector_attempts async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: pass @@ -131,41 +133,72 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: # Select the next speaker.
if len(participants) > 1: - select_speaker_prompt = self._selector_prompt.format( - roles=roles, participants=str(participants), history=history - ) - select_speaker_messages: List[SystemMessage | UserMessage] - if self._model_client.model_info["family"] in [ - ModelFamily.GPT_4, - ModelFamily.GPT_4O, - ModelFamily.GPT_35, - ModelFamily.O1, - ModelFamily.O3, - ]: - select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] - else: - # Many other models need a UserMessage to respond to - select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")] - - response = await self._model_client.create(messages=select_speaker_messages) - - assert isinstance(response.content, str) - mentions = self._mentioned_agents(response.content, self._participant_topic_types) - if len(mentions) != 1: - raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}") - agent_name = list(mentions.keys())[0] - if ( - not self._allow_repeated_speaker - and self._previous_speaker is not None - and agent_name == self._previous_speaker - ): - trace_logger.warning(f"Selector selected the previous speaker: {agent_name}") + agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts) else: agent_name = participants[0] self._previous_speaker = agent_name trace_logger.debug(f"Selected speaker: {agent_name}") return agent_name + async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str: + select_speaker_prompt = self._selector_prompt.format( + roles=roles, participants=str(participants), history=history + ) + select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage] + if self._model_client.model_info["family"] in [ + ModelFamily.GPT_4, + ModelFamily.GPT_4O, + ModelFamily.GPT_35, + ModelFamily.O1, + ModelFamily.O3, + ]: + select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] + else: + # Many other models need a UserMessage to respond to + select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="user")] + + num_attempts = 0 + while num_attempts < max_attempts: + num_attempts += 1 + response = await self._model_client.create(messages=select_speaker_messages) + assert isinstance(response.content, str) + select_speaker_messages.append(AssistantMessage(content=response.content, source="selector")) + mentions = self._mentioned_agents(response.content, self._participant_topic_types) + if len(mentions) == 0: + trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})") + feedback = f"No valid name was mentioned. Please select from: {str(participants)}." + select_speaker_messages.append(UserMessage(content=feedback, source="user")) + elif len(mentions) > 1: + trace_logger.debug(f"Model selected multiple names: {str(mentions)} (attempt {num_attempts})") + feedback = ( + f"Expected exactly one name to be mentioned. Please select only one from: {str(participants)}." + ) + select_speaker_messages.append(UserMessage(content=feedback, source="user")) + else: + agent_name = list(mentions.keys())[0] + if ( + not self._allow_repeated_speaker + and self._previous_speaker is not None + and agent_name == self._previous_speaker + ): + trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})") + feedback = ( + f"Repeated speaker is not allowed, please select a different name from: {str(participants)}." 
+ ) + select_speaker_messages.append(UserMessage(content=feedback, source="user")) + else: + # Valid selection + trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})") + return agent_name + + if self._previous_speaker is not None: + trace_logger.warning(f"Model failed to select a speaker after {max_attempts} attempts, using the previous speaker.") + return self._previous_speaker + trace_logger.warning( + f"Model failed to select a speaker after {max_attempts} attempts and there was no previous speaker, using the first participant." + ) + return participants[0] + def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]: """Counts the number of times each agent is mentioned in the provided message content. Agent names will match under any of the following conditions (all case-sensitive): @@ -210,6 +243,7 @@ class SelectorGroupChatConfig(BaseModel): selector_prompt: str allow_repeated_speaker: bool # selector_func: ComponentModel | None + max_selector_attempts: int = 3 class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): @@ -226,13 +260,17 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit. selector_prompt (str, optional): The prompt template to use for selecting the next speaker. Must contain '{roles}', '{participants}', and '{history}' to be filled in. - allow_repeated_speaker (bool, optional): Whether to allow the same speaker to be selected - consecutively. Defaults to False. + allow_repeated_speaker (bool, optional): Whether to include the previous speaker in the list of candidates to be selected for the next turn. + Defaults to False. The model may still select the previous speaker -- a warning will be logged if this happens. + max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3. + If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available, + otherwise the first participant will be used. selector_func (Callable[[Sequence[AgentEvent | ChatMessage]], str | None], optional): A custom selector function that takes the conversation history and returns the name of the next speaker. If provided, this function will be used to override the model to select the next speaker. If the function returns None, the model will be used to select the next speaker. + Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. @@ -368,6 +406,7 @@ def __init__( Read the above conversation. Then select the next role from {participants} to play. Only return the role.
""", allow_repeated_speaker: bool = False, + max_selector_attempts: int = 3, selector_func: Callable[[Sequence[AgentEvent | ChatMessage]], str | None] | None = None, ): super().__init__( @@ -390,6 +429,7 @@ def __init__( self._model_client = model_client self._allow_repeated_speaker = allow_repeated_speaker self._selector_func = selector_func + self._max_selector_attempts = max_selector_attempts def _create_group_chat_manager_factory( self, @@ -411,6 +451,7 @@ def _create_group_chat_manager_factory( self._selector_prompt, self._allow_repeated_speaker, self._selector_func, + self._max_selector_attempts, ) def _to_config(self) -> SelectorGroupChatConfig: @@ -421,6 +462,7 @@ def _to_config(self) -> SelectorGroupChatConfig: max_turns=self._max_turns, selector_prompt=self._selector_prompt, allow_repeated_speaker=self._allow_repeated_speaker, + max_selector_attempts=self._max_selector_attempts, # selector_func=self._selector_func.dump_component() if self._selector_func else None, ) @@ -435,6 +477,7 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self: max_turns=config.max_turns, selector_prompt=config.selector_prompt, allow_repeated_speaker=config.allow_repeated_speaker, + max_selector_attempts=config.max_selector_attempts, # selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[AgentEvent | ChatMessage]], str | None]) # if config.selector_func # else None, diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index c04d7344029f..098d68349ca1 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -740,6 +740,68 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte assert result2 == result +@pytest.mark.asyncio +async def test_selector_group_chat_succcess_after_2_attempts() -> None: + model_client = ReplayChatCompletionClient( + ["agent2, agent3", "agent2"], + ) + agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1) + agent2 = _EchoAgent("agent2", description="echo agent 2") + agent3 = _EchoAgent("agent3", description="echo agent 3") + team = SelectorGroupChat( + participants=[agent1, agent2, agent3], + model_client=model_client, + max_turns=1, + ) + result = await team.run(task="Write a program that prints 'Hello, world!'") + assert len(result.messages) == 2 + assert result.messages[0].content == "Write a program that prints 'Hello, world!'" + assert result.messages[1].source == "agent2" + + +@pytest.mark.asyncio +async def test_selector_group_chat_fall_back_to_first_after_3_attempts() -> None: + model_client = ReplayChatCompletionClient( + [ + "agent2, agent3", # Multiple speakers + "agent5", # Non-existent speaker + "agent3, agent1", # Multiple speakers + ] + ) + agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1) + agent2 = _EchoAgent("agent2", description="echo agent 2") + agent3 = _EchoAgent("agent3", description="echo agent 3") + team = SelectorGroupChat( + participants=[agent1, agent2, agent3], + model_client=model_client, + max_turns=1, + ) + result = await team.run(task="Write a program that prints 'Hello, world!'") + assert len(result.messages) == 2 + assert result.messages[0].content == "Write a program that prints 'Hello, world!'" + assert result.messages[1].source == "agent1" + + +@pytest.mark.asyncio +async def test_selector_group_chat_fall_back_to_previous_after_3_attempts() -> None: + model_client = 
ReplayChatCompletionClient( + ["agent2", "agent2", "agent2", "agent2"], + ) + agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1) + agent2 = _EchoAgent("agent2", description="echo agent 2") + agent3 = _EchoAgent("agent3", description="echo agent 3") + team = SelectorGroupChat( + participants=[agent1, agent2, agent3], + model_client=model_client, + max_turns=2, + ) + result = await team.run(task="Write a program that prints 'Hello, world!'") + assert len(result.messages) == 3 + assert result.messages[0].content == "Write a program that prints 'Hello, world!'" + assert result.messages[1].source == "agent2" + assert result.messages[2].source == "agent2" + + @pytest.mark.asyncio async def test_selector_group_chat_custom_selector(monkeypatch: pytest.MonkeyPatch) -> None: model = "gpt-4o-2024-05-13" diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_endpoint.py b/python/packages/autogen-agentchat/tests/test_group_chat_endpoint.py index fcd3b8839575..301c2f21d777 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_endpoint.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_endpoint.py @@ -4,10 +4,29 @@ from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import SelectorGroupChat from autogen_agentchat.ui import Console -from autogen_core.models import ModelFamily +from autogen_core.models import ChatCompletionClient, ModelFamily from autogen_ext.models.openai import OpenAIChatCompletionClient +async def _test_selector_group_chat(model_client: ChatCompletionClient) -> None: + assistant = AssistantAgent( + "assistant", + description="A helpful assistant agent.", + model_client=model_client, + system_message="You are a helpful assistant.", + ) + + critic = AssistantAgent( + "critic", + description="A critic agent to provide feedback.", + model_client=model_client, + system_message="Provide feedback.", + ) + + team = SelectorGroupChat([assistant, critic], model_client=model_client, max_turns=2) + await Console(team.run_stream(task="Draft a short email about organizing a holiday party for new year.")) + + @pytest.mark.asyncio async def test_selector_group_chat_gemini() -> None: try: @@ -26,20 +45,18 @@ async def test_selector_group_chat_gemini() -> None: "family": ModelFamily.GEMINI_1_5_FLASH, }, ) + await _test_selector_group_chat(model_client) - assistant = AssistantAgent( - "assistant", - description="A helpful assistant agent.", - model_client=model_client, - system_message="You are a helpful assistant.", - ) - critic = AssistantAgent( - "critic", - description="A critic agent to provide feedback.", - model_client=model_client, - system_message="Provide feedback.", - ) +@pytest.mark.asyncio +async def test_selector_group_chat_openai() -> None: + try: + api_key = os.environ["OPENAI_API_KEY"] + except KeyError: + pytest.skip("OPENAI_API_KEY not set in environment variables.") - team = SelectorGroupChat([assistant, critic], model_client=model_client, max_turns=2) - await Console(team.run_stream(task="Draft a short email about organizing a holiday party for new year.")) + model_client = OpenAIChatCompletionClient( + model="gpt-4o-mini", + api_key=api_key, + ) + await _test_selector_group_chat(model_client) From 15891e8cef188b7cb640883fad173e055cca861c Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sat, 8 Feb 2025 15:53:02 -0800 Subject: [PATCH 17/36] docs: enhance human-in-the-loop tutorial with FastAPI websocket example (#5455) Added a websocket example and link to the sample directory. 
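The notebook snippet added below leaves the user-proxy and team wiring as placeholder comments. The following sketch, not part of the diff, fills in those steps under stated assumptions: the agent names, model choice, turn limit, and the relaying of team messages back over the websocket are all illustrative.

```python
# Hypothetical completion of the elided steps in the notebook snippet below.
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from fastapi import FastAPI, WebSocket

app = FastAPI()


@app.websocket("/ws/chat")
async def chat(websocket: WebSocket) -> None:
    await websocket.accept()

    async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:
        data = await websocket.receive_json()  # Wait for user message from websocket.
        message = TextMessage.model_validate(data)  # Assume user message is a TextMessage.
        return message.content

    # Create the user proxy with the custom input function.
    user_proxy = UserProxyAgent("user_proxy", input_func=_user_input)
    assistant = AssistantAgent(
        "assistant", model_client=OpenAIChatCompletionClient(model="gpt-4o-mini")
    )
    team = RoundRobinGroupChat([assistant, user_proxy], max_turns=6)

    # Run the team, relaying each text message back over the websocket.
    async for event in team.run_stream(task="Write a short poem and revise it based on my feedback."):
        if isinstance(event, TextMessage):
            await websocket.send_json(event.model_dump())
```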
--- .../tutorial/human-in-the-loop.ipynb | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/human-in-the-loop.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/human-in-the-loop.ipynb index a7c0121d20c9..bcabe6958cb8 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/human-in-the-loop.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/human-in-the-loop.ipynb @@ -129,7 +129,28 @@ "through `user_proxy` to approve the generated poem.\n", "\n", "You can provide your own input function to the {py:class}`~autogen_agentchat.agents.UserProxyAgent`\n", - "to customize the feedback process." + "to customize the feedback process.\n", + "For example, when the team is running as a web service, you can use a custom\n", + "input function to wait for messages from a web socket connection.\n", + "The following code snippet shows an example of a custom input function\n", + "when using the [FastAPI](https://fastapi.tiangolo.com/) web framework:\n", + "\n", + "```python\n", + "@app.websocket(\"/ws/chat\")\n", + "async def chat(websocket: WebSocket):\n", + " await websocket.accept()\n", + "\n", + " async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str:\n", + " data = await websocket.receive_json() # Wait for user message from websocket.\n", + " message = TextMessage.model_validate(data) # Assume user message is a TextMessage.\n", + " return message.content\n", + " \n", + " # Create user proxy with custom input function\n", + " # Run the team with the user proxy\n", + " # ...\n", + "```\n", + "\n", + "See the [AgentChat FastAPI sample](https://github.com/microsoft/autogen/blob/main/python/samples/agentchat_fastapi) for a complete example." ] }, { From b5eaab85015e2434c13717c77f11397da1e52dad Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sat, 8 Feb 2025 18:08:14 -0800 Subject: [PATCH 18/36] fix & doc: update selector prompt documentation and remove validation checks (#5456) --- .../teams/_group_chat/_selector_group_chat.py | 9 +- .../selector-group-chat.ipynb | 1207 ++++++++++------- 2 files changed, 693 insertions(+), 523 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 70de5713814d..a69934a64dab 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -259,7 +259,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): Without a termination condition, the group chat will run indefinitely. max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit. selector_prompt (str, optional): The prompt template to use for selecting the next speaker. - Must contain '{roles}', '{participants}', and '{history}' to be filled in. + Available fields: '{roles}', '{participants}', and '{history}'. allow_repeated_speaker (bool, optional): Whether to include the previous speaker in the list of candidates to be selected for the next turn. Defaults to False. The model may still select the previous speaker -- a warning will be logged if this happens.
max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3. @@ -418,13 +418,6 @@ def __init__( # Validate the participants. if len(participants) < 2: raise ValueError("At least two participants are required for SelectorGroupChat.") - # Validate the selector prompt. - if "{roles}" not in selector_prompt: - raise ValueError("The selector prompt must contain '{roles}'") - if "{participants}" not in selector_prompt: - raise ValueError("The selector prompt must contain '{participants}'") - if "{history}" not in selector_prompt: - raise ValueError("The selector prompt must contain '{history}'") self._selector_prompt = selector_prompt self._model_client = model_client self._allow_repeated_speaker = allow_repeated_speaker diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/selector-group-chat.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/selector-group-chat.ipynb index efb7c24190fb..6fcfa6dc50a1 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/selector-group-chat.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/selector-group-chat.ipynb @@ -1,518 +1,695 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Selector Group Chat" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "{py:class}`~autogen_agentchat.teams.SelectorGroupChat` implements a team where participants take turns broadcasting messages to all other members. A generative model (e.g., an LLM) selects the next speaker based on the shared context, enabling dynamic, context-aware collaboration.\n", - "\n", - "Key features include:\n", - "\n", - "- Model-based speaker selection\n", - "- Configurable participant roles and descriptions\n", - "- Prevention of consecutive turns by the same speaker (optional)\n", - "- Customizable selection prompting\n", - "- Customizable selection function to override the default model-based selection\n", - "\n", - "```{note}\n", - "{py:class}`~autogen_agentchat.teams.SelectorGroupChat` is a high-level API. For more control and customization, refer to the [Group Chat Pattern](../core-user-guide/design-patterns/group-chat.ipynb) in the Core API documentation to implement your own group chat logic.\n", - "```\n", - "\n", - "## How Does it Work?\n", - "\n", - "{py:class}`~autogen_agentchat.teams.SelectorGroupChat` is a group chat similar to {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat`,\n", - "but with a model-based next speaker selection mechanism.\n", - "When the team receives a task through {py:meth}`~autogen_agentchat.teams.BaseGroupChat.run` or {py:meth}`~autogen_agentchat.teams.BaseGroupChat.run_stream`,\n", - "the following steps are executed:\n", - "\n", - "1. The team analyzes the current conversation context, including the conversation history and participants' {py:attr}`~autogen_agentchat.base.ChatAgent.name` and {py:attr}`~autogen_agentchat.base.ChatAgent.description` attributes, to determine the next speaker using a model. By default, the team will not select the same speak consecutively unless it is the only agent available. This can be changed by setting `allow_repeated_speaker=True`. You can also override the model by providing a custom selection function.\n", - "2. The team prompts the selected speaker agent to provide a response, which is then **broadcasted** to all other participants.\n", - "3. 
The termination condition is checked to determine if the conversation should end, if not, the process repeats from step 1.\n", - "4. When the conversation ends, the team returns the {py:class}`~autogen_agentchat.base.TaskResult` containing the conversation history from this task.\n", - "\n", - "Once the team finishes the task, the conversation context is kept within the team and all participants, so the next task can continue from the previous conversation context.\n", - "You can reset the conversation context by calling {py:meth}`~autogen_agentchat.teams.BaseGroupChat.reset`.\n", - "\n", - "In this section, we will demonstrate how to use {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with a simple example for a web search and data analysis task." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example: Web Search/Analysis" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Sequence\n", - "\n", - "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n", - "from autogen_agentchat.messages import AgentEvent, ChatMessage\n", - "from autogen_agentchat.teams import SelectorGroupChat\n", - "from autogen_agentchat.ui import Console\n", - "from autogen_ext.models.openai import OpenAIChatCompletionClient" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Agents\n", - "\n", - "![Selector Group Chat](selector-group-chat.svg)\n", - "\n", - "This system uses three specialized agents:\n", - "\n", - "- **Planning Agent**: The strategic coordinator that breaks down complex tasks into manageable subtasks. \n", - "- **Web Search Agent**: An information retrieval specialist that interfaces with the `search_web_tool`.\n", - "- **Data Analyst Agent**: An agent specialist in performing calculations equipped with `percentage_change_tool`. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The tools `search_web_tool` and `percentage_change_tool` are external tools that the agents can use to perform their tasks." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# Note: This example uses mock tools instead of real APIs for demonstration purposes\n", - "def search_web_tool(query: str) -> str:\n", - " if \"2006-2007\" in query:\n", - " return \"\"\"Here are the total points scored by Miami Heat players in the 2006-2007 season:\n", - " Udonis Haslem: 844 points\n", - " Dwayne Wade: 1397 points\n", - " James Posey: 550 points\n", - " ...\n", - " \"\"\"\n", - " elif \"2007-2008\" in query:\n", - " return \"The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\"\n", - " elif \"2008-2009\" in query:\n", - " return \"The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\"\n", - " return \"No data found.\"\n", - "\n", - "\n", - "def percentage_change_tool(start: float, end: float) -> float:\n", - " return ((end - start) / start) * 100" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create the specialized agents using the {py:class}`~autogen_agentchat.agents.AssistantAgent` class.\n", - "It is important to note that the agents' {py:attr}`~autogen_agentchat.base.ChatAgent.name` and {py:attr}`~autogen_agentchat.base.ChatAgent.description` attributes are used by the model to determine the next speaker,\n", - "so it is recommended to provide meaningful names and descriptions." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n", - "\n", - "planning_agent = AssistantAgent(\n", - " \"PlanningAgent\",\n", - " description=\"An agent for planning tasks, this agent should be the first to engage when given a new task.\",\n", - " model_client=model_client,\n", - " system_message=\"\"\"\n", - " You are a planning agent.\n", - " Your job is to break down complex tasks into smaller, manageable subtasks.\n", - " Your team members are:\n", - " Web search agent: Searches for information\n", - " Data analyst: Performs calculations\n", - "\n", - " You only plan and delegate tasks - you do not execute them yourself.\n", - "\n", - " When assigning tasks, use this format:\n", - " 1. : \n", - "\n", - " After all tasks are complete, summarize the findings and end with \"TERMINATE\".\n", - " \"\"\",\n", - ")\n", - "\n", - "web_search_agent = AssistantAgent(\n", - " \"WebSearchAgent\",\n", - " description=\"A web search agent.\",\n", - " tools=[search_web_tool],\n", - " model_client=model_client,\n", - " system_message=\"\"\"\n", - " You are a web search agent.\n", - " Your only tool is search_tool - use it to find information.\n", - " You make only one search call at a time.\n", - " Once you have the results, you never do calculations based on them.\n", - " \"\"\",\n", - ")\n", - "\n", - "data_analyst_agent = AssistantAgent(\n", - " \"DataAnalystAgent\",\n", - " description=\"A data analyst agent. Useful for performing calculations.\",\n", - " model_client=model_client,\n", - " tools=[percentage_change_tool],\n", - " system_message=\"\"\"\n", - " You are a data analyst.\n", - " Given the tasks you have been assigned, you should analyze the data and provide results using the tools provided.\n", - " \"\"\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```{note}\n", - "By default, {py:class}`~autogen_agentchat.agents.AssistantAgent` returns the\n", - "tool output as the response. 
If your tool does not return a well-formed\n", - "string in natural language format, you may want to add a reflection step\n", - "within the agent by setting `reflect_on_tool_use=True` when creating the agent.\n", - "This will allow the agent to reflect on the tool output and provide a natural\n", - "language response.\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Workflow\n", - "\n", - "1. The task is received by the {py:class}`~autogen_agentchat.teams.SelectorGroupChat` which, based on agent descriptions, selects the most appropriate agent to handle the initial task (typically the Planning Agent).\n", - "\n", - "2. The **Planning Agent** analyzes the task and breaks it down into subtasks, assigning each to the most appropriate agent using the format:\n", - " ` : `\n", - "\n", - "3. Based on the conversation context and agent descriptions, the {py:class}`~autogen_agent.teams.SelectorGroupChat` manager dynamically selects the next agent to handle their assigned subtask.\n", - "\n", - "4. The **Web Search Agent** performs searches one at a time, storing results in the shared conversation history.\n", - "\n", - "5. The **Data Analyst** processes the gathered information using available calculation tools when selected.\n", - "\n", - "6. The workflow continues with agents being dynamically selected until either:\n", - " - The Planning Agent determines all subtasks are complete and sends \"TERMINATE\"\n", - " - An alternative termination condition is met (e.g., a maximum number of messages)\n", - "\n", - "When defining your agents, make sure to include a helpful {py:attr}`~autogen_agentchat.base.ChatAgent.description` since this is used to decide which agent to select next." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create the team with two termination conditions:\n", - "{py:class}`~autogen_agentchat.conditions.TextMentionTermination` to end the conversation when the Planning Agent sends \"TERMINATE\",\n", - "and {py:class}`~autogen_agentchat.conditions.MaxMessageTermination` to limit the conversation to 25 messages to avoid infinite loop." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "text_mention_termination = TextMentionTermination(\"TERMINATE\")\n", - "max_messages_termination = MaxMessageTermination(max_messages=25)\n", - "termination = text_mention_termination | max_messages_termination\n", - "\n", - "team = SelectorGroupChat(\n", - " [planning_agent, web_search_agent, data_analyst_agent],\n", - " model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n", - " termination_condition=termination,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we run the team with a task to find information about an NBA player." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------- user ----------\n", - "Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\n", - "---------- PlanningAgent ----------\n", - "To address this request, we will divide the task into manageable subtasks. \n", - "\n", - "1. Web search agent: Identify the Miami Heat player with the highest points in the 2006-2007 season.\n", - "2. 
Web search agent: Gather the total rebounds for the identified player during the 2007-2008 season.\n", - "3. Web search agent: Gather the total rebounds for the identified player during the 2008-2009 season.\n", - "4. Data analyst: Calculate the percentage change in total rebounds for the identified player between the 2007-2008 and 2008-2009 seasons.\n", - "[Prompt tokens: 159, Completion tokens: 122]\n", - "---------- WebSearchAgent ----------\n", - "[FunctionCall(id='call_xdYlGP2lsqDeWdSiOlwOBNiO', arguments='{\"query\":\"Miami Heat highest points player 2006-2007 season\"}', name='search_web_tool')]\n", - "[Prompt tokens: 271, Completion tokens: 26]\n", - "---------- WebSearchAgent ----------\n", - "[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_xdYlGP2lsqDeWdSiOlwOBNiO')]\n", - "---------- WebSearchAgent ----------\n", - "Tool calls:\n", - "search_web_tool({\"query\":\"Miami Heat highest points player 2006-2007 season\"}) = Here are the total points scored by Miami Heat players in the 2006-2007 season:\n", - " Udonis Haslem: 844 points\n", - " Dwayne Wade: 1397 points\n", - " James Posey: 550 points\n", - " ...\n", - " \n", - "---------- DataAnalystAgent ----------\n", - "[FunctionCall(id='call_asjxKtAVGfqrYl0jMpvwCrvV', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='WebSearchAgent'), FunctionCall(id='call_8L91Kizt0KU6RNwUgvNx7S0s', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='WebSearchAgent')]\n", - "[Prompt tokens: 345, Completion tokens: 68]\n", - "---------- DataAnalystAgent ----------\n", - "[FunctionExecutionResult(content=\"Error: The tool 'WebSearchAgent' is not available.\", call_id='call_asjxKtAVGfqrYl0jMpvwCrvV'), FunctionExecutionResult(content=\"Error: The tool 'WebSearchAgent' is not available.\", call_id='call_8L91Kizt0KU6RNwUgvNx7S0s')]\n", - "---------- DataAnalystAgent ----------\n", - "Tool calls:\n", - "WebSearchAgent({\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}) = Error: The tool 'WebSearchAgent' is not available.\n", - "WebSearchAgent({\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}) = Error: The tool 'WebSearchAgent' is not available.\n", - "---------- WebSearchAgent ----------\n", - "[FunctionCall(id='call_imvRJ2jhpPdovBbx8MFjlFVS', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_U30KVmFG1aeXPbqGJjDmJ6iJ', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')]\n", - "[Prompt tokens: 445, Completion tokens: 70]\n", - "---------- WebSearchAgent ----------\n", - "[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_imvRJ2jhpPdovBbx8MFjlFVS'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_U30KVmFG1aeXPbqGJjDmJ6iJ')]\n", - "---------- WebSearchAgent ----------\n", - "Tool calls:\n", - "search_web_tool({\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}) = The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\n", - "search_web_tool({\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}) = The number of total rebounds for Dwayne Wade in the Miami Heat season 
2008-2009 is 398.\n", - "---------- DataAnalystAgent ----------\n", - "[FunctionCall(id='call_CtAnvcbitN0JiwBfiLVzb5Do', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')]\n", - "[Prompt tokens: 562, Completion tokens: 20]\n", - "---------- DataAnalystAgent ----------\n", - "[FunctionExecutionResult(content='85.98130841121495', call_id='call_CtAnvcbitN0JiwBfiLVzb5Do')]\n", - "---------- DataAnalystAgent ----------\n", - "Tool calls:\n", - "percentage_change_tool({\"start\":214,\"end\":398}) = 85.98130841121495\n", - "---------- PlanningAgent ----------\n", - "Summary of Findings:\n", - "\n", - "1. Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring a total of 1,397 points.\n", - "2. Dwyane Wade's total rebounds during the 2007-2008 season were 214.\n", - "3. Dwyane Wade's total rebounds during the 2008-2009 season were 398.\n", - "4. The percentage change in Dwyane Wade's total rebounds between the 2007-2008 and 2008-2009 seasons was approximately 85.98%.\n", - "\n", - "TERMINATE\n", - "[Prompt tokens: 590, Completion tokens: 122]\n", - "---------- Summary ----------\n", - "Number of messages: 15\n", - "Finish reason: Text 'TERMINATE' mentioned\n", - "Total prompt tokens: 2372\n", - "Total completion tokens: 428\n", - "Duration: 9.21 seconds\n" - ] - }, - { - "data": { - "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=159, completion_tokens=122), content='To address this request, we will divide the task into manageable subtasks. \\n\\n1. Web search agent: Identify the Miami Heat player with the highest points in the 2006-2007 season.\\n2. Web search agent: Gather the total rebounds for the identified player during the 2007-2008 season.\\n3. Web search agent: Gather the total rebounds for the identified player during the 2008-2009 season.\\n4. 
Data analyst: Calculate the percentage change in total rebounds for the identified player between the 2007-2008 and 2008-2009 seasons.', type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=271, completion_tokens=26), content=[FunctionCall(id='call_xdYlGP2lsqDeWdSiOlwOBNiO', arguments='{\"query\":\"Miami Heat highest points player 2006-2007 season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_xdYlGP2lsqDeWdSiOlwOBNiO')], type='ToolCallExecutionEvent'), TextMessage(source='WebSearchAgent', models_usage=None, content='Tool calls:\\nsearch_web_tool({\"query\":\"Miami Heat highest points player 2006-2007 season\"}) = Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', type='TextMessage'), ToolCallRequestEvent(source='DataAnalystAgent', models_usage=RequestUsage(prompt_tokens=345, completion_tokens=68), content=[FunctionCall(id='call_asjxKtAVGfqrYl0jMpvwCrvV', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='WebSearchAgent'), FunctionCall(id='call_8L91Kizt0KU6RNwUgvNx7S0s', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='WebSearchAgent')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='DataAnalystAgent', models_usage=None, content=[FunctionExecutionResult(content=\"Error: The tool 'WebSearchAgent' is not available.\", call_id='call_asjxKtAVGfqrYl0jMpvwCrvV'), FunctionExecutionResult(content=\"Error: The tool 'WebSearchAgent' is not available.\", call_id='call_8L91Kizt0KU6RNwUgvNx7S0s')], type='ToolCallExecutionEvent'), TextMessage(source='DataAnalystAgent', models_usage=None, content='Tool calls:\\nWebSearchAgent({\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}) = Error: The tool \\'WebSearchAgent\\' is not available.\\nWebSearchAgent({\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}) = Error: The tool \\'WebSearchAgent\\' is not available.', type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=445, completion_tokens=70), content=[FunctionCall(id='call_imvRJ2jhpPdovBbx8MFjlFVS', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_U30KVmFG1aeXPbqGJjDmJ6iJ', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_imvRJ2jhpPdovBbx8MFjlFVS'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_U30KVmFG1aeXPbqGJjDmJ6iJ')], type='ToolCallExecutionEvent'), TextMessage(source='WebSearchAgent', models_usage=None, content='Tool calls:\\nsearch_web_tool({\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}) = The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\\nsearch_web_tool({\"query\": 
\"Dwyane Wade total rebounds 2008-2009 season\"}) = The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', type='TextMessage'), ToolCallRequestEvent(source='DataAnalystAgent', models_usage=RequestUsage(prompt_tokens=562, completion_tokens=20), content=[FunctionCall(id='call_CtAnvcbitN0JiwBfiLVzb5Do', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='DataAnalystAgent', models_usage=None, content=[FunctionExecutionResult(content='85.98130841121495', call_id='call_CtAnvcbitN0JiwBfiLVzb5Do')], type='ToolCallExecutionEvent'), TextMessage(source='DataAnalystAgent', models_usage=None, content='Tool calls:\\npercentage_change_tool({\"start\":214,\"end\":398}) = 85.98130841121495', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=590, completion_tokens=122), content=\"Summary of Findings:\\n\\n1. Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring a total of 1,397 points.\\n2. Dwyane Wade's total rebounds during the 2007-2008 season were 214.\\n3. Dwyane Wade's total rebounds during the 2008-2009 season were 398.\\n4. The percentage change in Dwyane Wade's total rebounds between the 2007-2008 and 2008-2009 seasons was approximately 85.98%.\\n\\nTERMINATE\", type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "task = \"Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\"\n", - "\n", - "# Use asyncio.run(...) if you are running this in a script.\n", - "await Console(team.run_stream(task=task))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As we can see, after the Web Search Agent conducts the necessary searches and the Data Analyst Agent completes the necessary calculations, we find that Dwayne Wade was the Miami Heat player with the highest points in the 2006-2007 season, and the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons is 85.98%!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Custom Selector Function" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Often times we want better control over the selection process. \n", - "To this end, we can set the `selector_func` argument with a custom selector function to override the default model-based selection.\n", - "For instance, we want the Planning Agent to speak immediately after any specialized agent to check the progress.\n", - "\n", - "```{note}\n", - "Returning `None` from the custom selector function will use the default model-based selection.\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------- user ----------\n", - "Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\n", - "---------- PlanningAgent ----------\n", - "To address this query, we'll need to break it down into a few specific tasks:\n", - "\n", - "1. 
Web search agent: Identify the Miami Heat player with the highest points in the 2006-2007 NBA season.\n", - "2. Web search agent: Find the total number of rebounds by this player in the 2007-2008 NBA season.\n", - "3. Web search agent: Find the total number of rebounds by this player in the 2008-2009 NBA season.\n", - "4. Data analyst: Calculate the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons.\n", - "\n", - "Let's get started with these tasks.\n", - "[Prompt tokens: 159, Completion tokens: 132]\n", - "---------- WebSearchAgent ----------\n", - "[FunctionCall(id='call_TSUHOBKhpHmTNoYeJzwSP5V4', arguments='{\"query\":\"Miami Heat highest points player 2006-2007 season\"}', name='search_web_tool')]\n", - "[Prompt tokens: 281, Completion tokens: 26]\n", - "---------- WebSearchAgent ----------\n", - "[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_TSUHOBKhpHmTNoYeJzwSP5V4')]\n", - "---------- WebSearchAgent ----------\n", - "Tool calls:\n", - "search_web_tool({\"query\":\"Miami Heat highest points player 2006-2007 season\"}) = Here are the total points scored by Miami Heat players in the 2006-2007 season:\n", - " Udonis Haslem: 844 points\n", - " Dwayne Wade: 1397 points\n", - " James Posey: 550 points\n", - " ...\n", - " \n", - "---------- PlanningAgent ----------\n", - "1. Web search agent: Find the total number of rebounds by Dwayne Wade in the 2007-2008 NBA season.\n", - "2. Web search agent: Find the total number of rebounds by Dwayne Wade in the 2008-2009 NBA season.\n", - "[Prompt tokens: 382, Completion tokens: 54]\n", - "---------- DataAnalystAgent ----------\n", - "[FunctionCall(id='call_BkPBFkpuTG6c3eeoACrrRX7V', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_5LQquT7ZUAAQRf7gvckeTVdQ', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')]\n", - "[Prompt tokens: 416, Completion tokens: 68]\n", - "---------- DataAnalystAgent ----------\n", - "[FunctionExecutionResult(content=\"Error: The tool 'search_web_tool' is not available.\", call_id='call_BkPBFkpuTG6c3eeoACrrRX7V'), FunctionExecutionResult(content=\"Error: The tool 'search_web_tool' is not available.\", call_id='call_5LQquT7ZUAAQRf7gvckeTVdQ')]\n", - "---------- DataAnalystAgent ----------\n", - "Tool calls:\n", - "search_web_tool({\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}) = Error: The tool 'search_web_tool' is not available.\n", - "search_web_tool({\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}) = Error: The tool 'search_web_tool' is not available.\n", - "---------- PlanningAgent ----------\n", - "It seems there was a miscommunication in task assignment. Let me reassess and reassign the tasks correctly.\n", - "\n", - "1. Web search agent: Find the total number of rebounds by Dwayne Wade in the 2007-2008 NBA season.\n", - "2. 
Web search agent: Find the total number of rebounds by Dwayne Wade in the 2008-2009 NBA season.\n", - "[Prompt tokens: 525, Completion tokens: 76]\n", - "---------- WebSearchAgent ----------\n", - "[FunctionCall(id='call_buIWOtu1dJqPaxJmqMyuRkpj', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_qcnHKdoPsNAzMlPvoBvqmt8n', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')]\n", - "[Prompt tokens: 599, Completion tokens: 70]\n", - "---------- WebSearchAgent ----------\n", - "[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_buIWOtu1dJqPaxJmqMyuRkpj'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_qcnHKdoPsNAzMlPvoBvqmt8n')]\n", - "---------- WebSearchAgent ----------\n", - "Tool calls:\n", - "search_web_tool({\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}) = The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\n", - "search_web_tool({\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}) = The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\n", - "---------- PlanningAgent ----------\n", - "With this information, we can proceed to calculate the percentage change in Dwyane Wade's total rebounds from the 2007-2008 season to the 2008-2009 season.\n", - "\n", - "1. Data analyst: Calculate the percentage change in Dwyane Wade's total rebounds between the 2007-2008 (214 rebounds) and the 2008-2009 (398 rebounds) NBA seasons.\n", - "[Prompt tokens: 711, Completion tokens: 83]\n", - "---------- DataAnalystAgent ----------\n", - "[FunctionCall(id='call_RjbFpLCehz1Nlk5kYmyMUenB', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')]\n", - "[Prompt tokens: 806, Completion tokens: 20]\n", - "---------- DataAnalystAgent ----------\n", - "[FunctionExecutionResult(content='85.98130841121495', call_id='call_RjbFpLCehz1Nlk5kYmyMUenB')]\n", - "---------- DataAnalystAgent ----------\n", - "Tool calls:\n", - "percentage_change_tool({\"start\":214,\"end\":398}) = 85.98130841121495\n", - "---------- PlanningAgent ----------\n", - "Based on the data collected, Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 NBA season, scoring a total of 1,397 points. Between the 2007-2008 and 2008-2009 seasons, Dwyane Wade's total rebounds increased from 214 to 398. This represents an approximate 85.98% increase in his total rebounds.\n", - "\n", - "TERMINATE\n", - "[Prompt tokens: 834, Completion tokens: 90]\n", - "---------- Summary ----------\n", - "Number of messages: 18\n", - "Finish reason: Text 'TERMINATE' mentioned\n", - "Total prompt tokens: 4713\n", - "Total completion tokens: 619\n", - "Duration: 11.72 seconds\n" - ] - }, - { - "data": { - "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=159, completion_tokens=132), content=\"To address this query, we'll need to break it down into a few specific tasks:\\n\\n1. 
Web search agent: Identify the Miami Heat player with the highest points in the 2006-2007 NBA season.\\n2. Web search agent: Find the total number of rebounds by this player in the 2007-2008 NBA season.\\n3. Web search agent: Find the total number of rebounds by this player in the 2008-2009 NBA season.\\n4. Data analyst: Calculate the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons.\\n\\nLet's get started with these tasks.\", type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=281, completion_tokens=26), content=[FunctionCall(id='call_TSUHOBKhpHmTNoYeJzwSP5V4', arguments='{\"query\":\"Miami Heat highest points player 2006-2007 season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_TSUHOBKhpHmTNoYeJzwSP5V4')], type='ToolCallExecutionEvent'), TextMessage(source='WebSearchAgent', models_usage=None, content='Tool calls:\\nsearch_web_tool({\"query\":\"Miami Heat highest points player 2006-2007 season\"}) = Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=382, completion_tokens=54), content='1. Web search agent: Find the total number of rebounds by Dwayne Wade in the 2007-2008 NBA season.\\n2. Web search agent: Find the total number of rebounds by Dwayne Wade in the 2008-2009 NBA season.', type='TextMessage'), ToolCallRequestEvent(source='DataAnalystAgent', models_usage=RequestUsage(prompt_tokens=416, completion_tokens=68), content=[FunctionCall(id='call_BkPBFkpuTG6c3eeoACrrRX7V', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_5LQquT7ZUAAQRf7gvckeTVdQ', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='DataAnalystAgent', models_usage=None, content=[FunctionExecutionResult(content=\"Error: The tool 'search_web_tool' is not available.\", call_id='call_BkPBFkpuTG6c3eeoACrrRX7V'), FunctionExecutionResult(content=\"Error: The tool 'search_web_tool' is not available.\", call_id='call_5LQquT7ZUAAQRf7gvckeTVdQ')], type='ToolCallExecutionEvent'), TextMessage(source='DataAnalystAgent', models_usage=None, content='Tool calls:\\nsearch_web_tool({\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}) = Error: The tool \\'search_web_tool\\' is not available.\\nsearch_web_tool({\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}) = Error: The tool \\'search_web_tool\\' is not available.', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=525, completion_tokens=76), content='It seems there was a miscommunication in task assignment. Let me reassess and reassign the tasks correctly.\\n\\n1. Web search agent: Find the total number of rebounds by Dwayne Wade in the 2007-2008 NBA season.\\n2. 
Web search agent: Find the total number of rebounds by Dwayne Wade in the 2008-2009 NBA season.', type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=599, completion_tokens=70), content=[FunctionCall(id='call_buIWOtu1dJqPaxJmqMyuRkpj', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_qcnHKdoPsNAzMlPvoBvqmt8n', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_buIWOtu1dJqPaxJmqMyuRkpj'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_qcnHKdoPsNAzMlPvoBvqmt8n')], type='ToolCallExecutionEvent'), TextMessage(source='WebSearchAgent', models_usage=None, content='Tool calls:\\nsearch_web_tool({\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}) = The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\\nsearch_web_tool({\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}) = The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=711, completion_tokens=83), content=\"With this information, we can proceed to calculate the percentage change in Dwyane Wade's total rebounds from the 2007-2008 season to the 2008-2009 season.\\n\\n1. Data analyst: Calculate the percentage change in Dwyane Wade's total rebounds between the 2007-2008 (214 rebounds) and the 2008-2009 (398 rebounds) NBA seasons.\", type='TextMessage'), ToolCallRequestEvent(source='DataAnalystAgent', models_usage=RequestUsage(prompt_tokens=806, completion_tokens=20), content=[FunctionCall(id='call_RjbFpLCehz1Nlk5kYmyMUenB', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='DataAnalystAgent', models_usage=None, content=[FunctionExecutionResult(content='85.98130841121495', call_id='call_RjbFpLCehz1Nlk5kYmyMUenB')], type='ToolCallExecutionEvent'), TextMessage(source='DataAnalystAgent', models_usage=None, content='Tool calls:\\npercentage_change_tool({\"start\":214,\"end\":398}) = 85.98130841121495', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=834, completion_tokens=90), content=\"Based on the data collected, Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 NBA season, scoring a total of 1,397 points. Between the 2007-2008 and 2008-2009 seasons, Dwyane Wade's total rebounds increased from 214 to 398. 
This represents an approximate 85.98% increase in his total rebounds.\\n\\nTERMINATE\", type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:\n", - " if messages[-1].source != planning_agent.name:\n", - " return planning_agent.name\n", - " return None\n", - "\n", - "\n", - "# Reset the previous team and run the chat again with the selector function.\n", - "await team.reset()\n", - "team = SelectorGroupChat(\n", - " [planning_agent, web_search_agent, data_analyst_agent],\n", - " model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n", - " termination_condition=termination,\n", - " selector_func=selector_func,\n", - ")\n", - "\n", - "await Console(team.run_stream(task=task))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can see from the conversation log that the Planning Agent always speaks immediately after the specialized agents." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.7" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Selector Group Chat" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "{py:class}`~autogen_agentchat.teams.SelectorGroupChat` implements a team where participants take turns broadcasting messages to all other members. A generative model (e.g., an LLM) selects the next speaker based on the shared context, enabling dynamic, context-aware collaboration.\n", + "\n", + "Key features include:\n", + "\n", + "- Model-based speaker selection\n", + "- Configurable participant roles and descriptions\n", + "- Prevention of consecutive turns by the same speaker (optional)\n", + "- Customizable selection prompting\n", + "- Customizable selection function to override the default model-based selection\n", + "\n", + "```{note}\n", + "{py:class}`~autogen_agentchat.teams.SelectorGroupChat` is a high-level API. For more control and customization, refer to the [Group Chat Pattern](../core-user-guide/design-patterns/group-chat.ipynb) in the Core API documentation to implement your own group chat logic.\n", + "```\n", + "\n", + "## How Does it Work?\n", + "\n", + "{py:class}`~autogen_agentchat.teams.SelectorGroupChat` is a group chat similar to {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat`,\n", + "but with a model-based next speaker selection mechanism.\n", + "When the team receives a task through {py:meth}`~autogen_agentchat.teams.BaseGroupChat.run` or {py:meth}`~autogen_agentchat.teams.BaseGroupChat.run_stream`,\n", + "the following steps are executed:\n", + "\n", + "1. The team analyzes the current conversation context, including the conversation history and participants' {py:attr}`~autogen_agentchat.base.ChatAgent.name` and {py:attr}`~autogen_agentchat.base.ChatAgent.description` attributes, to determine the next speaker using a model. By default, the team will not select the same speaker consecutively unless it is the only agent available. This can be changed by setting `allow_repeated_speaker=True`.
You can also override the model by providing a custom selection function.\n", + "2. The team prompts the selected speaker agent to provide a response, which is then **broadcast** to all other participants.\n", + "3. The termination condition is checked to determine if the conversation should end; if not, the process repeats from step 1.\n", + "4. When the conversation ends, the team returns the {py:class}`~autogen_agentchat.base.TaskResult` containing the conversation history from this task.\n", + "\n", + "Once the team finishes the task, the conversation context is kept within the team and all participants, so the next task can continue from the previous conversation context.\n", + "You can reset the conversation context by calling {py:meth}`~autogen_agentchat.teams.BaseGroupChat.reset`.\n", + "\n", + "In this section, we will demonstrate how to use {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with a simple example for a web search and data analysis task." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example: Web Search/Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Sequence\n", + "\n", + "from autogen_agentchat.agents import AssistantAgent, UserProxyAgent\n", + "from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n", + "from autogen_agentchat.messages import AgentEvent, ChatMessage\n", + "from autogen_agentchat.teams import SelectorGroupChat\n", + "from autogen_agentchat.ui import Console\n", + "from autogen_ext.models.openai import OpenAIChatCompletionClient" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Agents\n", + "\n", + "![Selector Group Chat](selector-group-chat.svg)\n", + "\n", + "This system uses three specialized agents:\n", + "\n", + "- **Planning Agent**: The strategic coordinator that breaks down complex tasks into manageable subtasks. \n", + "- **Web Search Agent**: An information retrieval specialist that interfaces with the `search_web_tool`.\n", + "- **Data Analyst Agent**: A calculation specialist equipped with the `percentage_change_tool`. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The tools `search_web_tool` and `percentage_change_tool` are external tools that the agents can use to perform their tasks."
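For reference, `percentage_change_tool` (defined in the next cell) computes `((end - start) / start) * 100`, so the figure the agents report later in this walkthrough can be checked by hand against the rebound totals returned by the mock search tool:

```python
# Percent change from 214 rebounds (2007-2008) to 398 rebounds (2008-2009).
start, end = 214, 398
print(((end - start) / start) * 100)  # 85.98130841121495, i.e. ~85.98%
```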
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Note: This example uses mock tools instead of real APIs for demonstration purposes\n", + "def search_web_tool(query: str) -> str:\n", + " if \"2006-2007\" in query:\n", + " return \"\"\"Here are the total points scored by Miami Heat players in the 2006-2007 season:\n", + " Udonis Haslem: 844 points\n", + " Dwayne Wade: 1397 points\n", + " James Posey: 550 points\n", + " ...\n", + " \"\"\"\n", + " elif \"2007-2008\" in query:\n", + " return \"The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\"\n", + " elif \"2008-2009\" in query:\n", + " return \"The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\"\n", + " return \"No data found.\"\n", + "\n", + "\n", + "def percentage_change_tool(start: float, end: float) -> float:\n", + " return ((end - start) / start) * 100" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create the specialized agents using the {py:class}`~autogen_agentchat.agents.AssistantAgent` class.\n", + "It is important to note that the agents' {py:attr}`~autogen_agentchat.base.ChatAgent.name` and {py:attr}`~autogen_agentchat.base.ChatAgent.description` attributes are used by the model to determine the next speaker,\n", + "so it is recommended to provide meaningful names and descriptions." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n", + "\n", + "planning_agent = AssistantAgent(\n", + " \"PlanningAgent\",\n", + " description=\"An agent for planning tasks, this agent should be the first to engage when given a new task.\",\n", + " model_client=model_client,\n", + " system_message=\"\"\"\n", + " You are a planning agent.\n", + " Your job is to break down complex tasks into smaller, manageable subtasks.\n", + " Your team members are:\n", + " WebSearchAgent: Searches for information\n", + " DataAnalystAgent: Performs calculations\n", + "\n", + " You only plan and delegate tasks - you do not execute them yourself.\n", + "\n", + " When assigning tasks, use this format:\n", + " 1. 
: \n", + "\n", + " After all tasks are complete, summarize the findings and end with \"TERMINATE\".\n", + " \"\"\",\n", + ")\n", + "\n", + "web_search_agent = AssistantAgent(\n", + " \"WebSearchAgent\",\n", + " description=\"An agent for searching information on the web.\",\n", + " tools=[search_web_tool],\n", + " model_client=model_client,\n", + " system_message=\"\"\"\n", + " You are a web search agent.\n", + " Your only tool is search_tool - use it to find information.\n", + " You make only one search call at a time.\n", + " Once you have the results, you never do calculations based on them.\n", + " \"\"\",\n", + ")\n", + "\n", + "data_analyst_agent = AssistantAgent(\n", + " \"DataAnalystAgent\",\n", + " description=\"An agent for performing calculations.\",\n", + " model_client=model_client,\n", + " tools=[percentage_change_tool],\n", + " system_message=\"\"\"\n", + " You are a data analyst.\n", + " Given the tasks you have been assigned, you should analyze the data and provide results using the tools provided.\n", + " If you have not seen the data, ask for it.\n", + " \"\"\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{note}\n", + "By default, {py:class}`~autogen_agentchat.agents.AssistantAgent` returns the\n", + "tool output as the response. If your tool does not return a well-formed\n", + "string in natural language format, you may want to add a reflection step\n", + "within the agent by setting `reflect_on_tool_use=True` when creating the agent.\n", + "This will allow the agent to reflect on the tool output and provide a natural\n", + "language response.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Workflow\n", + "\n", + "1. The task is received by the {py:class}`~autogen_agentchat.teams.SelectorGroupChat` which, based on agent descriptions, selects the most appropriate agent to handle the initial task (typically the Planning Agent).\n", + "\n", + "2. The **Planning Agent** analyzes the task and breaks it down into subtasks, assigning each to the most appropriate agent using the format:\n", + " ` : `\n", + "\n", + "3. Based on the conversation context and agent descriptions, the {py:class}`~autogen_agent.teams.SelectorGroupChat` manager dynamically selects the next agent to handle their assigned subtask.\n", + "\n", + "4. The **Web Search Agent** performs searches one at a time, storing results in the shared conversation history.\n", + "\n", + "5. The **Data Analyst** processes the gathered information using available calculation tools when selected.\n", + "\n", + "6. The workflow continues with agents being dynamically selected until either:\n", + " - The Planning Agent determines all subtasks are complete and sends \"TERMINATE\"\n", + " - An alternative termination condition is met (e.g., a maximum number of messages)\n", + "\n", + "When defining your agents, make sure to include a helpful {py:attr}`~autogen_agentchat.base.ChatAgent.description` since this is used to decide which agent to select next." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Termination Conditions\n", + "\n", + "Let's use two termination conditions:\n", + "{py:class}`~autogen_agentchat.conditions.TextMentionTermination` to end the conversation when the Planning Agent sends \"TERMINATE\",\n", + "and {py:class}`~autogen_agentchat.conditions.MaxMessageTermination` to limit the conversation to 25 messages to avoid infinite loop." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "text_mention_termination = TextMentionTermination(\"TERMINATE\")\n", + "max_messages_termination = MaxMessageTermination(max_messages=25)\n", + "termination = text_mention_termination | max_messages_termination" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Selector Prompt\n", + "\n", + "{py:class}`~autogen_agentchat.teams.SelectorGroupChat` uses a model to select\n", + "the next speaker based on the conversation context.\n", + "We will use a custom selector prompt to properly align with the workflow." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "selector_prompt = \"\"\"Select an agent to perform task.\n", + "\n", + "{roles}\n", + "\n", + "Current conversation context:\n", + "{history}\n", + "\n", + "Read the above conversation, then select an agent from {participants} to perform the next task.\n", + "Make sure the planner agent has assigned tasks before other agents start working.\n", + "Only select one agent.\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Running the Team\n", + "\n", + "Let's create the team with the agents, termination conditions, and custom selector prompt." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "team = SelectorGroupChat(\n", + " [planning_agent, web_search_agent, data_analyst_agent],\n", + " model_client=model_client,\n", + " termination_condition=termination,\n", + " selector_prompt=selector_prompt,\n", + " allow_repeated_speaker=True, # Allow an agent to speak multiple turns in a row.\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we run the team with a task to find information about an NBA player." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "task = \"Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\n", + "---------- PlanningAgent ----------\n", + "To answer this question, we'll need to break it down into the following tasks:\n", + "\n", + "1. WebSearchAgent: Investigate who was the Miami Heat player with the highest points scored in the 2006-2007 NBA season.\n", + "2. WebSearchAgent: Gather statistics for the player's total rebounds during the 2007-2008 NBA season.\n", + "3. WebSearchAgent: Collect statistics for the same player's total rebounds during the 2008-2009 NBA season.\n", + "4. DataAnalystAgent: Calculate the percentage change in total rebounds for the player between the 2007-2008 and 2008-2009 seasons.\n", + "\n", + "Let's begin by gathering the necessary data. \n", + "\n", + "1. 
WebSearchAgent: Investigate who was the Miami Heat player with the highest points scored in the 2006-2007 NBA season.\n", + "---------- WebSearchAgent ----------\n", + "[FunctionCall(id='call_knPrF0edlVal8opo7AMMU8Z0', arguments='{\"query\":\"Miami Heat highest points scorer 2006-2007 NBA season\"}', name='search_web_tool')]\n", + "---------- WebSearchAgent ----------\n", + "[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_knPrF0edlVal8opo7AMMU8Z0')]\n", + "---------- WebSearchAgent ----------\n", + "Here are the total points scored by Miami Heat players in the 2006-2007 season:\n", + " Udonis Haslem: 844 points\n", + " Dwayne Wade: 1397 points\n", + " James Posey: 550 points\n", + " ...\n", + " \n", + "---------- WebSearchAgent ----------\n", + "[FunctionCall(id='call_KOP1aR8O2kLGIeAkaUa8dE8u', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_Hk9j3wvryEhlKpe0YOUgXO46', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')]\n", + "---------- WebSearchAgent ----------\n", + "[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_KOP1aR8O2kLGIeAkaUa8dE8u'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_Hk9j3wvryEhlKpe0YOUgXO46')]\n", + "---------- WebSearchAgent ----------\n", + "The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\n", + "The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\n" + ] }, - "nbformat": 4, - "nbformat_minor": 2 + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ekzhu/autogen/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py:415: UserWarning: Both tool_calls and content are present in the message. This is unexpected. 
content will be ignored, tool_calls will be used.\n", + " model_result = await self._model_client.create(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- DataAnalystAgent ----------\n", + "[FunctionCall(id='call_By2X8HJjd7K0B3qKSOti1vJ8', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')]\n", + "---------- DataAnalystAgent ----------\n", + "[FunctionExecutionResult(content='85.98130841121495', call_id='call_By2X8HJjd7K0B3qKSOti1vJ8')]\n", + "---------- DataAnalystAgent ----------\n", + "85.98130841121495\n", + "---------- PlanningAgent ----------\n", + "To summarize: \n", + "\n", + "- The Miami Heat player with the highest points in the 2006-2007 season was Dwyane Wade, with a total of 1,397 points.\n", + "- The percentage change in Dwyane Wade's total rebounds from the 2007-2008 season to the 2008-2009 season is approximately 85.98%.\n", + "\n", + "TERMINATE\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=161, completion_tokens=169), content=\"To answer this question, we'll need to break it down into the following tasks:\\n\\n1. WebSearchAgent: Investigate who was the Miami Heat player with the highest points scored in the 2006-2007 NBA season.\\n2. WebSearchAgent: Gather statistics for the player's total rebounds during the 2007-2008 NBA season.\\n3. WebSearchAgent: Collect statistics for the same player's total rebounds during the 2008-2009 NBA season.\\n4. DataAnalystAgent: Calculate the percentage change in total rebounds for the player between the 2007-2008 and 2008-2009 seasons.\\n\\nLet's begin by gathering the necessary data. \\n\\n1. 
WebSearchAgent: Investigate who was the Miami Heat player with the highest points scored in the 2006-2007 NBA season.\", type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=317, completion_tokens=28), content=[FunctionCall(id='call_knPrF0edlVal8opo7AMMU8Z0', arguments='{\"query\":\"Miami Heat highest points scorer 2006-2007 NBA season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_knPrF0edlVal8opo7AMMU8Z0')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='WebSearchAgent', models_usage=None, content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', type='ToolCallSummaryMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=410, completion_tokens=142), content=[FunctionCall(id='call_KOP1aR8O2kLGIeAkaUa8dE8u', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_Hk9j3wvryEhlKpe0YOUgXO46', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_KOP1aR8O2kLGIeAkaUa8dE8u'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_Hk9j3wvryEhlKpe0YOUgXO46')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='WebSearchAgent', models_usage=None, content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\\nThe number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', type='ToolCallSummaryMessage'), ToolCallRequestEvent(source='DataAnalystAgent', models_usage=RequestUsage(prompt_tokens=437, completion_tokens=21), content=[FunctionCall(id='call_By2X8HJjd7K0B3qKSOti1vJ8', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='DataAnalystAgent', models_usage=None, content=[FunctionExecutionResult(content='85.98130841121495', call_id='call_By2X8HJjd7K0B3qKSOti1vJ8')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='DataAnalystAgent', models_usage=None, content='85.98130841121495', type='ToolCallSummaryMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=469, completion_tokens=79), content=\"To summarize: \\n\\n- The Miami Heat player with the highest points in the 2006-2007 season was Dwyane Wade, with a total of 1,397 points.\\n- The percentage change in Dwyane Wade's total rebounds from the 2007-2008 season to the 2008-2009 season is approximately 85.98%.\\n\\nTERMINATE\", type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Use asyncio.run(...) 
if you are running this in a script.\n", + "await Console(team.run_stream(task=task))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we can see, after the Web Search Agent conducts the necessary searches and the Data Analyst Agent completes the necessary calculations, we find that Dwayne Wade was the Miami Heat player with the highest points in the 2006-2007 season, and the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons is 85.98%!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Selector Function" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Often times we want better control over the selection process. \n", + "To this end, we can set the `selector_func` argument with a custom selector function to override the default model-based selection.\n", + "For instance, we want the Planning Agent to speak immediately after any specialized agent to check the progress.\n", + "\n", + "```{note}\n", + "Returning `None` from the custom selector function will use the default model-based selection.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\n", + "---------- PlanningAgent ----------\n", + "To answer this question, we need to follow these steps: \n", + "\n", + "1. Identify the Miami Heat player with the highest points in the 2006-2007 season.\n", + "2. Retrieve the total rebounds of that player for the 2007-2008 and 2008-2009 seasons.\n", + "3. Calculate the percentage change in his total rebounds between the two seasons.\n", + "\n", + "Let's delegate these tasks:\n", + "\n", + "1. WebSearchAgent: Find the Miami Heat player with the highest points in the 2006-2007 NBA season.\n", + "2. WebSearchAgent: Retrieve the total rebounds for the identified player during the 2007-2008 NBA season.\n", + "3. WebSearchAgent: Retrieve the total rebounds for the identified player during the 2008-2009 NBA season.\n", + "4. DataAnalystAgent: Calculate the percentage change in total rebounds between the 2007-2008 and 2008-2009 seasons for the player found.\n", + "---------- WebSearchAgent ----------\n", + "[FunctionCall(id='call_Pz82ndNLSV4cH0Sg6g7ArP4L', arguments='{\"query\":\"Miami Heat player highest points 2006-2007 season\"}', name='search_web_tool')]\n", + "---------- WebSearchAgent ----------\n", + "[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_Pz82ndNLSV4cH0Sg6g7ArP4L')]\n", + "---------- WebSearchAgent ----------\n", + "Here are the total points scored by Miami Heat players in the 2006-2007 season:\n", + " Udonis Haslem: 844 points\n", + " Dwayne Wade: 1397 points\n", + " James Posey: 550 points\n", + " ...\n", + " \n", + "---------- PlanningAgent ----------\n", + "Great! Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season. Now, let's continue with the next tasks:\n", + "\n", + "2. WebSearchAgent: Retrieve the total rebounds for Dwyane Wade during the 2007-2008 NBA season.\n", + "3. 
WebSearchAgent: Retrieve the total rebounds for Dwyane Wade during the 2008-2009 NBA season.\n", + "---------- WebSearchAgent ----------\n", + "[FunctionCall(id='call_3qv9so2DXFZIHtzqDIfXoFID', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_Vh7zzzWUeiUAvaYjP0If0k1k', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')]\n", + "---------- WebSearchAgent ----------\n", + "[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_3qv9so2DXFZIHtzqDIfXoFID'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_Vh7zzzWUeiUAvaYjP0If0k1k')]\n", + "---------- WebSearchAgent ----------\n", + "The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\n", + "The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\n", + "---------- PlanningAgent ----------\n", + "Now let's calculate the percentage change in total rebounds between the 2007-2008 and 2008-2009 seasons for Dwyane Wade.\n", + "\n", + "4. DataAnalystAgent: Calculate the percentage change in total rebounds for Dwyane Wade between the 2007-2008 and 2008-2009 seasons.\n", + "---------- DataAnalystAgent ----------\n", + "[FunctionCall(id='call_FXnPSr6JVGfAWs3StIizbt2V', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')]\n", + "---------- DataAnalystAgent ----------\n", + "[FunctionExecutionResult(content='85.98130841121495', call_id='call_FXnPSr6JVGfAWs3StIizbt2V')]\n", + "---------- DataAnalystAgent ----------\n", + "85.98130841121495\n", + "---------- PlanningAgent ----------\n", + "Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring a total of 1397 points. The percentage change in his total rebounds from the 2007-2008 season (214 rebounds) to the 2008-2009 season (398 rebounds) is approximately 86.0%.\n", + "\n", + "TERMINATE\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=161, completion_tokens=192), content=\"To answer this question, we need to follow these steps: \\n\\n1. Identify the Miami Heat player with the highest points in the 2006-2007 season.\\n2. Retrieve the total rebounds of that player for the 2007-2008 and 2008-2009 seasons.\\n3. Calculate the percentage change in his total rebounds between the two seasons.\\n\\nLet's delegate these tasks:\\n\\n1. WebSearchAgent: Find the Miami Heat player with the highest points in the 2006-2007 NBA season.\\n2. WebSearchAgent: Retrieve the total rebounds for the identified player during the 2007-2008 NBA season.\\n3. WebSearchAgent: Retrieve the total rebounds for the identified player during the 2008-2009 NBA season.\\n4. 
DataAnalystAgent: Calculate the percentage change in total rebounds between the 2007-2008 and 2008-2009 seasons for the player found.\", type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=340, completion_tokens=27), content=[FunctionCall(id='call_Pz82ndNLSV4cH0Sg6g7ArP4L', arguments='{\"query\":\"Miami Heat player highest points 2006-2007 season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_Pz82ndNLSV4cH0Sg6g7ArP4L')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='WebSearchAgent', models_usage=None, content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', type='ToolCallSummaryMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=420, completion_tokens=87), content=\"Great! Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season. Now, let's continue with the next tasks:\\n\\n2. WebSearchAgent: Retrieve the total rebounds for Dwyane Wade during the 2007-2008 NBA season.\\n3. WebSearchAgent: Retrieve the total rebounds for Dwyane Wade during the 2008-2009 NBA season.\", type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=525, completion_tokens=71), content=[FunctionCall(id='call_3qv9so2DXFZIHtzqDIfXoFID', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 season\"}', name='search_web_tool'), FunctionCall(id='call_Vh7zzzWUeiUAvaYjP0If0k1k', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_3qv9so2DXFZIHtzqDIfXoFID'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_Vh7zzzWUeiUAvaYjP0If0k1k')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='WebSearchAgent', models_usage=None, content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\\nThe number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', type='ToolCallSummaryMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=569, completion_tokens=68), content=\"Now let's calculate the percentage change in total rebounds between the 2007-2008 and 2008-2009 seasons for Dwyane Wade.\\n\\n4. 
DataAnalystAgent: Calculate the percentage change in total rebounds for Dwyane Wade between the 2007-2008 and 2008-2009 seasons.\", type='TextMessage'), ToolCallRequestEvent(source='DataAnalystAgent', models_usage=RequestUsage(prompt_tokens=627, completion_tokens=21), content=[FunctionCall(id='call_FXnPSr6JVGfAWs3StIizbt2V', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='DataAnalystAgent', models_usage=None, content=[FunctionExecutionResult(content='85.98130841121495', call_id='call_FXnPSr6JVGfAWs3StIizbt2V')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='DataAnalystAgent', models_usage=None, content='85.98130841121495', type='ToolCallSummaryMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=659, completion_tokens=76), content='Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring a total of 1397 points. The percentage change in his total rebounds from the 2007-2008 season (214 rebounds) to the 2008-2009 season (398 rebounds) is approximately 86.0%.\\n\\nTERMINATE', type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:\n", + " if messages[-1].source != planning_agent.name:\n", + " return planning_agent.name\n", + " return None\n", + "\n", + "\n", + "# Reset the previous team and run the chat again with the selector function.\n", + "await team.reset()\n", + "team = SelectorGroupChat(\n", + " [planning_agent, web_search_agent, data_analyst_agent],\n", + " model_client=model_client,\n", + " termination_condition=termination,\n", + " selector_prompt=selector_prompt,\n", + " allow_repeated_speaker=True,\n", + " selector_func=selector_func,\n", + ")\n", + "\n", + "await Console(team.run_stream(task=task))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see from the conversation log that the Planning Agent always speaks immediately after the specialized agents." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## User Feedback\n", + "\n", + "We can add {py:class}`~autogen_agentchat.agents.UserProxyAgent` to the team to\n", + "provide user feedback during a run.\n", + "See [Human-in-the-Loop](./tutorial/human-in-the-loop.ipynb) for more details\n", + "about {py:class}`~autogen_agentchat.agents.UserProxyAgent`.\n", + "\n", + "To use the {py:class}`~autogen_agentchat.agents.UserProxyAgent` in the \n", + "web search example, we simply add it to the team and update the selector function\n", + "to always check for user feedback after the planning agent speaks.\n", + "If the user responds with `\"APPROVE\"`, the conversation continues, otherwise,\n", + "the planning agent tries again, until the user approves." 
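For fully unattended runs, the approval step can also be scripted rather than read from stdin. A minimal sketch, assuming your installed version of {py:class}`~autogen_agentchat.agents.UserProxyAgent` accepts an `input_func` callable (verify the parameter name and signature against your autogen_agentchat release):

```python
# Sketch: auto-approve every request instead of prompting the user on stdin.
# input_func receives the input prompt string and returns the user's reply.
auto_approve_proxy = UserProxyAgent(
    "UserProxyAgent",
    description="A proxy for the user to approve or disapprove tasks.",
    input_func=lambda prompt: "APPROVE",
)
```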
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- PlanningAgent ----------\n", + "To address the user's query, we will need to perform the following tasks:\n", + "\n", + "1. Identify the Miami Heat player with the highest points in the 2006-2007 season.\n", + "2. Find the total rebounds for that player in the 2007-2008 season.\n", + "3. Find the total rebounds for that player in the 2008-2009 season.\n", + "4. Calculate the percentage change in the total rebounds between the 2007-2008 and 2008-2009 seasons.\n", + "\n", + "Let's assign these tasks:\n", + "\n", + "1. **WebSearchAgent**: Identify the Miami Heat player with the highest points in the 2006-2007 season.\n", + " \n", + "(Task 2 and 3 depend on the result of Task 1. We'll proceed with Tasks 2 and 3 once Task 1 is complete.)\n", + "---------- UserProxyAgent ----------\n", + "approve\n", + "---------- WebSearchAgent ----------\n", + "[FunctionCall(id='call_0prr3fUnG5CtisUG7QeygW0w', arguments='{\"query\":\"Miami Heat highest points scorer 2006-2007 NBA season\"}', name='search_web_tool')]\n", + "---------- WebSearchAgent ----------\n", + "[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_0prr3fUnG5CtisUG7QeygW0w')]\n", + "---------- WebSearchAgent ----------\n", + "Here are the total points scored by Miami Heat players in the 2006-2007 season:\n", + " Udonis Haslem: 844 points\n", + " Dwayne Wade: 1397 points\n", + " James Posey: 550 points\n", + " ...\n", + " \n", + "---------- PlanningAgent ----------\n", + "Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring 1397 points.\n", + "\n", + "Next, we need to find Dwyane Wade's total rebounds for the 2007-2008 and 2008-2009 seasons:\n", + "\n", + "2. **WebSearchAgent**: Find Dwyane Wade's total rebounds for the 2007-2008 season.\n", + "3. 
**WebSearchAgent**: Find Dwyane Wade's total rebounds for the 2008-2009 season.\n", + "---------- UserProxyAgent ----------\n", + "approve\n", + "---------- WebSearchAgent ----------\n", + "[FunctionCall(id='call_fBZe80NaBfruOVGwRWbhXyRm', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 NBA season\"}', name='search_web_tool'), FunctionCall(id='call_cURYibna4fGxySiL7IYt0c3s', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 NBA season\"}', name='search_web_tool')]\n", + "---------- WebSearchAgent ----------\n", + "[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_fBZe80NaBfruOVGwRWbhXyRm'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_cURYibna4fGxySiL7IYt0c3s')]\n", + "---------- WebSearchAgent ----------\n", + "The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\n", + "The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.\n", + "---------- PlanningAgent ----------\n", + "Now that we have Dwyane Wade's total rebounds for both seasons, we can calculate the percentage change:\n", + "\n", + "4. **DataAnalystAgent**: Calculate the percentage change in Dwyane Wade's total rebounds from the 2007-2008 season (214 rebounds) to the 2008-2009 season (398 rebounds).\n", + "---------- UserProxyAgent ----------\n", + "approve\n", + "---------- DataAnalystAgent ----------\n", + "[FunctionCall(id='call_z3uog7t2x0z1Suzl5hACF9hY', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')]\n", + "---------- DataAnalystAgent ----------\n", + "[FunctionExecutionResult(content='85.98130841121495', call_id='call_z3uog7t2x0z1Suzl5hACF9hY')]\n", + "---------- DataAnalystAgent ----------\n", + "85.98130841121495\n", + "---------- PlanningAgent ----------\n", + "Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring 1397 points. His total rebounds increased from 214 in the 2007-2008 season to 398 in the 2008-2009 season, which is a percentage change of approximately 85.98%.\n", + "\n", + "TERMINATE\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Who was the Miami Heat player with the highest points in the 2006-2007 season, and what was the percentage change in his total rebounds between the 2007-2008 and 2008-2009 seasons?', type='TextMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=161, completion_tokens=166), content=\"To address the user's query, we will need to perform the following tasks:\\n\\n1. Identify the Miami Heat player with the highest points in the 2006-2007 season.\\n2. Find the total rebounds for that player in the 2007-2008 season.\\n3. Find the total rebounds for that player in the 2008-2009 season.\\n4. Calculate the percentage change in the total rebounds between the 2007-2008 and 2008-2009 seasons.\\n\\nLet's assign these tasks:\\n\\n1. **WebSearchAgent**: Identify the Miami Heat player with the highest points in the 2006-2007 season.\\n \\n(Task 2 and 3 depend on the result of Task 1. 
We'll proceed with Tasks 2 and 3 once Task 1 is complete.)\", type='TextMessage'), UserInputRequestedEvent(source='UserProxyAgent', models_usage=None, request_id='2a433f88-f886-4b39-a078-ea1acdcb2f9d', content='', type='UserInputRequestedEvent'), TextMessage(source='UserProxyAgent', models_usage=None, content='approve', type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=323, completion_tokens=28), content=[FunctionCall(id='call_0prr3fUnG5CtisUG7QeygW0w', arguments='{\"query\":\"Miami Heat highest points scorer 2006-2007 NBA season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', call_id='call_0prr3fUnG5CtisUG7QeygW0w')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='WebSearchAgent', models_usage=None, content='Here are the total points scored by Miami Heat players in the 2006-2007 season:\\n Udonis Haslem: 844 points\\n Dwayne Wade: 1397 points\\n James Posey: 550 points\\n ...\\n ', type='ToolCallSummaryMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=403, completion_tokens=112), content=\"Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring 1397 points.\\n\\nNext, we need to find Dwyane Wade's total rebounds for the 2007-2008 and 2008-2009 seasons:\\n\\n2. **WebSearchAgent**: Find Dwyane Wade's total rebounds for the 2007-2008 season.\\n3. **WebSearchAgent**: Find Dwyane Wade's total rebounds for the 2008-2009 season.\", type='TextMessage'), UserInputRequestedEvent(source='UserProxyAgent', models_usage=None, request_id='23dd4570-2391-41e9-aeea-86598499792c', content='', type='UserInputRequestedEvent'), TextMessage(source='UserProxyAgent', models_usage=None, content='approve', type='TextMessage'), ToolCallRequestEvent(source='WebSearchAgent', models_usage=RequestUsage(prompt_tokens=543, completion_tokens=73), content=[FunctionCall(id='call_fBZe80NaBfruOVGwRWbhXyRm', arguments='{\"query\": \"Dwyane Wade total rebounds 2007-2008 NBA season\"}', name='search_web_tool'), FunctionCall(id='call_cURYibna4fGxySiL7IYt0c3s', arguments='{\"query\": \"Dwyane Wade total rebounds 2008-2009 NBA season\"}', name='search_web_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='WebSearchAgent', models_usage=None, content=[FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.', call_id='call_fBZe80NaBfruOVGwRWbhXyRm'), FunctionExecutionResult(content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', call_id='call_cURYibna4fGxySiL7IYt0c3s')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='WebSearchAgent', models_usage=None, content='The number of total rebounds for Dwayne Wade in the Miami Heat season 2007-2008 is 214.\\nThe number of total rebounds for Dwayne Wade in the Miami Heat season 2008-2009 is 398.', type='ToolCallSummaryMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=586, completion_tokens=70), content=\"Now that we have Dwyane Wade's total rebounds for both seasons, we can calculate the percentage change:\\n\\n4. 
**DataAnalystAgent**: Calculate the percentage change in Dwyane Wade's total rebounds from the 2007-2008 season (214 rebounds) to the 2008-2009 season (398 rebounds).\", type='TextMessage'), UserInputRequestedEvent(source='UserProxyAgent', models_usage=None, request_id='e849d193-4ab3-4558-8560-7dbc062a0aee', content='', type='UserInputRequestedEvent'), TextMessage(source='UserProxyAgent', models_usage=None, content='approve', type='TextMessage'), ToolCallRequestEvent(source='DataAnalystAgent', models_usage=RequestUsage(prompt_tokens=655, completion_tokens=21), content=[FunctionCall(id='call_z3uog7t2x0z1Suzl5hACF9hY', arguments='{\"start\":214,\"end\":398}', name='percentage_change_tool')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='DataAnalystAgent', models_usage=None, content=[FunctionExecutionResult(content='85.98130841121495', call_id='call_z3uog7t2x0z1Suzl5hACF9hY')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='DataAnalystAgent', models_usage=None, content='85.98130841121495', type='ToolCallSummaryMessage'), TextMessage(source='PlanningAgent', models_usage=RequestUsage(prompt_tokens=687, completion_tokens=74), content='Dwyane Wade was the Miami Heat player with the highest points in the 2006-2007 season, scoring 1397 points. His total rebounds increased from 214 in the 2007-2008 season to 398 in the 2008-2009 season, which is a percentage change of approximately 85.98%.\\n\\nTERMINATE', type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "user_proxy_agent = UserProxyAgent(\"UserProxyAgent\", description=\"A proxy for the user to approve or disapprove tasks.\")\n", + "\n", + "\n", + "def selector_func_with_user_proxy(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:\n", + " if messages[-1].source != planning_agent.name and messages[-1].source != user_proxy_agent.name:\n", + " # Planning agent should be the first to engage when given a new task, or check progress.\n", + " return planning_agent.name\n", + " if messages[-1].source == planning_agent.name:\n", + " if messages[-2].source == user_proxy_agent.name and \"APPROVE\" in messages[-1].content.upper(): # type: ignore\n", + " # User has approved the plan, proceed to the next agent.\n", + " return None\n", + " # Use the user proxy agent to get the user's approval to proceed.\n", + " return user_proxy_agent.name\n", + " if messages[-1].source == user_proxy_agent.name:\n", + " # If the user does not approve, return to the planning agent.\n", + " if \"APPROVE\" not in messages[-1].content.upper(): # type: ignore\n", + " return planning_agent.name\n", + " return None\n", + "\n", + "\n", + "# Reset the previous agents and run the chat again with the user proxy agent and selector function.\n", + "await team.reset()\n", + "team = SelectorGroupChat(\n", + " [planning_agent, web_search_agent, data_analyst_agent, user_proxy_agent],\n", + " model_client=model_client,\n", + " termination_condition=termination,\n", + " selector_prompt=selector_prompt,\n", + " selector_func=selector_func_with_user_proxy,\n", + " allow_repeated_speaker=True,\n", + ")\n", + "\n", + "await Console(team.run_stream(task=task))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the user's feedback is incorporated into the conversation flow,\n", + "and the user can approve or reject the planning agent's decisions." 
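If you run this example as a standalone script rather than in a notebook, wrap the awaits in a coroutine as the `# Use asyncio.run(...)` comment above suggests; a minimal sketch:

```python
import asyncio


async def main() -> None:
    # Reset any prior state, then stream the run to the console.
    await team.reset()
    await Console(team.run_stream(task=task))


asyncio.run(main())
```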
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } From b868e32b05791ef4c7693339473414ebc6f89d24 Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Sun, 9 Feb 2025 14:39:19 +1000 Subject: [PATCH 19/36] fix: update SK adapter stream tool call processing. (#5449) ## Why are these changes needed? The current stream processing of the SK model adapter returns on the first function call chunk, but this behavior is incorrect and ends up returning an incomplete function call. The observed behavior is that the function name and arguments are split across different chunks, and this update processes the chunks accordingly. ## Related issue number Fixes the reply in #5420 ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. --------- Co-authored-by: Leonardo Pinheiro --- .../_sk_chat_completion_adapter.py | 130 ++++++-- .../models/test_sk_chat_completion_adapter.py | 315 +++++++++++++++++- 2 files changed, 399 insertions(+), 46 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py index 6cea28d0f2f5..07f3fe74802c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py @@ -1,5 +1,6 @@ import json from typing import Any, Literal, Mapping, Optional, Sequence +import warnings from autogen_core import FunctionCall from autogen_core._cancellation_token import CancellationToken @@ -18,7 +19,6 @@ from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent -from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.functions.kernel_plugin import KernelPlugin from semantic_kernel.kernel import Kernel from typing_extensions import AsyncGenerator, Union @@ -427,6 +427,28 @@ async def create( thought=thought, ) + @staticmethod + def _merge_function_call_content(existing_call: FunctionCallContent, new_chunk: FunctionCallContent) -> None: + """Helper to merge partial argument chunks from new_chunk into existing_call.""" + if isinstance(existing_call.arguments, str) and isinstance(new_chunk.arguments, str): + existing_call.arguments += new_chunk.arguments + elif isinstance(existing_call.arguments, dict) and isinstance(new_chunk.arguments, dict): + existing_call.arguments.update(new_chunk.arguments) + elif not existing_call.arguments or existing_call.arguments in ("{}", ""): + # If existing had no arguments yet, just take the new one + existing_call.arguments = new_chunk.arguments + else: + # If there's a mismatch (str
vs dict), handle as needed + warnings.warn("Mismatch in argument types during merge. Existing arguments retained.", stacklevel=2) + + # Optionally update name/function_name if newly provided + if new_chunk.name: + existing_call.name = new_chunk.name + if new_chunk.plugin_name: + existing_call.plugin_name = new_chunk.plugin_name + if new_chunk.function_name: + existing_call.function_name = new_chunk.function_name + async def create_stream( self, messages: Sequence[LLMMessage], @@ -460,6 +482,7 @@ async def create_stream( Yields: Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls. """ + kernel = self._get_kernel(extra_create_args) chat_history = self._convert_to_chat_history(messages) user_settings = self._get_prompt_settings(extra_create_args) @@ -468,54 +491,105 @@ async def create_stream( prompt_tokens = 0 completion_tokens = 0 - accumulated_content = "" + accumulated_text = "" + + # Keep track of in-progress function calls. Keyed by ID + # because partial chunks for the same function call might arrive separately. + function_calls_in_progress: dict[str, FunctionCallContent] = {} + + # Track the ID of the last function call we saw so we can continue + # accumulating chunk arguments for that call if new items have id=None + last_function_call_id: Optional[str] = None async for streaming_messages in self._sk_client.get_streaming_chat_message_contents( chat_history, settings=settings, kernel=kernel ): for msg in streaming_messages: - if not isinstance(msg, StreamingChatMessageContent): - continue - # Track token usage if msg.metadata and "usage" in msg.metadata: usage = msg.metadata["usage"] prompt_tokens = getattr(usage, "prompt_tokens", 0) completion_tokens = getattr(usage, "completion_tokens", 0) - # Check for function calls - if any(isinstance(item, FunctionCallContent) for item in msg.items): - function_calls = self._process_tool_calls(msg) + # Process function call deltas + for item in msg.items: + if isinstance(item, FunctionCallContent): + # If the chunk has a valid ID, we start or continue that ID explicitly + if item.id: + last_function_call_id = item.id + if last_function_call_id not in function_calls_in_progress: + function_calls_in_progress[last_function_call_id] = item + else: + # Merge partial arguments into existing call + existing_call = function_calls_in_progress[last_function_call_id] + self._merge_function_call_content(existing_call, item) + else: + # item.id is None, so we assume it belongs to the last known ID + if not last_function_call_id: + # No call in progress means we can't merge + # You could either skip or raise an error here + warnings.warn( + "Received function call chunk with no ID and no call in progress.", stacklevel=2 + ) + continue + + existing_call = function_calls_in_progress[last_function_call_id] + # Merge partial chunk + self._merge_function_call_content(existing_call, item) + + # Check if the model signaled tool_calls finished + if msg.finish_reason == "tool_calls" and function_calls_in_progress: + calls_to_yield: list[FunctionCall] = [] + for _, call_content in function_calls_in_progress.items(): + plugin_name = call_content.plugin_name or "" + function_name = call_content.function_name + if plugin_name: + full_name = f"{plugin_name}-{function_name}" + else: + full_name = function_name + + if isinstance(call_content.arguments, dict): + arguments = json.dumps(call_content.arguments) + else: + assert isinstance(call_content.arguments, str) + arguments = call_content.arguments or "{}" + + 
calls_to_yield.append( + FunctionCall( + id=call_content.id or "unknown_id", + name=full_name, + arguments=arguments, + ) + ) + # Yield all function calls in progress yield CreateResult( - content=function_calls, + content=calls_to_yield, finish_reason="function_calls", usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens), cached=False, ) return - # Handle text content + # Handle any plain text in the message if msg.content: - accumulated_content += msg.content + accumulated_text += msg.content yield msg.content - # Final yield if there was text content - if accumulated_content: - self._total_prompt_tokens += prompt_tokens - self._total_completion_tokens += completion_tokens - - if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1: - thought, accumulated_content = parse_r1_content(accumulated_content) - else: - thought = None - - yield CreateResult( - content=accumulated_content, - finish_reason="stop", - usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens), - cached=False, - thought=thought, - ) + # If we exit the loop without tool calls finishing, yield whatever text was accumulated + self._total_prompt_tokens += prompt_tokens + self._total_completion_tokens += completion_tokens + + thought = None + if isinstance(accumulated_text, str) and self._model_info["family"] == ModelFamily.R1: + thought, accumulated_text = parse_r1_content(accumulated_text) + + yield CreateResult( + content=accumulated_text, + finish_reason="stop", + usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens), + cached=False, + thought=thought, + ) def actual_usage(self) -> RequestUsage: return RequestUsage(prompt_tokens=self._total_prompt_tokens, completion_tokens=self._total_completion_tokens) diff --git a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py index 1b5a6ea03fa5..ce602d8fbad2 100644 --- a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py +++ b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py @@ -7,7 +7,13 @@ from autogen_core.models import CreateResult, LLMMessage, ModelFamily, ModelInfo, SystemMessage, UserMessage from autogen_core.tools import BaseTool from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion @@ -72,7 +78,7 @@ async def mock_get_chat_message_contents( id="call_UwVVI0iGEmcPwmKUigJcuuuF", function_name="calculator", plugin_name=None, - arguments="{}", + arguments='{"a": 2, "b": 2}', ) ], finish_reason=FinishReason.TOOL_CALLS, @@ -96,30 +102,89 @@ async def mock_get_streaming_chat_message_contents( **kwargs: Any, ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]: if "What is 2 + 2?" 
in str(chat_history): - # Mock response for calculator tool test - single message with function call + # Initial chunk with function call setup yield [ StreamingChatMessageContent( choice_index=0, - inner_content=None, + inner_content=ChatCompletionChunk( + id="chatcmpl-123", + choices=[ + Choice( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="call_UwVVI0iGEmcPwmKUigJcuuuF", + function=ChoiceDeltaToolCallFunction(name="calculator", arguments=""), + type="function", + ) + ], + ), + finish_reason=None, + index=0, + ) + ], + created=1736673679, + model="gpt-4o-mini", + object="chat.completion.chunk", + ), ai_model_id="gpt-4o-mini", - metadata={ - "logprobs": None, - "id": "chatcmpl-AooRjGxKtdTke46keWkBQBKg033XW", - "created": 1736673679, - "usage": {"prompt_tokens": 53, "completion_tokens": 13}, - }, role=AuthorRole.ASSISTANT, - items=[ # type: ignore + items=[ FunctionCallContent( - id="call_n8135GXc2kbiaaDdpImsB1VW", - function_name="calculator", - plugin_name=None, - arguments="", - content_type="function_call", # type: ignore + id="call_UwVVI0iGEmcPwmKUigJcuuuF", function_name="calculator", arguments="" ) ], - finish_reason=None, - function_invoke_attempt=0, + ) + ] + + # Arguments chunks + for arg_chunk in ["{", '"a"', ":", " ", "2", ",", " ", '"b"', ":", " ", "2", "}"]: + yield [ + StreamingChatMessageContent( + choice_index=0, + inner_content=ChatCompletionChunk( + id="chatcmpl-123", + choices=[ + Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, function=ChoiceDeltaToolCallFunction(arguments=arg_chunk) + ) + ] + ), + finish_reason=None, + index=0, + ) + ], + created=1736673679, + model="gpt-4o-mini", + object="chat.completion.chunk", + ), + ai_model_id="gpt-4o-mini", + role=AuthorRole.ASSISTANT, + items=[FunctionCallContent(function_name="calculator", arguments=arg_chunk)], + ) + ] + + # Final chunk with finish reason + yield [ + StreamingChatMessageContent( # type: ignore + choice_index=0, + inner_content=ChatCompletionChunk( + id="chatcmpl-123", + choices=[Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)], + created=1736673679, + model="gpt-4o-mini", + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=53, completion_tokens=13, total_tokens=66), + ), + ai_model_id="gpt-4o-mini", + role=AuthorRole.ASSISTANT, + finish_reason=FinishReason.TOOL_CALLS, + metadata={"usage": {"prompt_tokens": 53, "completion_tokens": 13}}, ) ] else: @@ -449,3 +514,217 @@ async def mock_get_streaming_chat_message_contents( assert response_chunks[-1].finish_reason == "stop" assert response_chunks[-1].content == "Hello!" assert response_chunks[-1].thought == "Reasoning..." + + +@pytest.mark.asyncio +async def test_sk_chat_completion_stream_with_multiple_function_calls() -> None: + """ + This test returns two distinct function calls via streaming, each one arriving in pieces. + We intentionally set name, plugin_name, and function_name in the later partial chunks so + that _merge_function_call_content is triggered to update them. 
+ """ + + async def mock_get_streaming_chat_message_contents( + chat_history: ChatHistory, + settings: PromptExecutionSettings, + **kwargs: Any, + ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]: + # First partial chunk for call_1 + yield [ + StreamingChatMessageContent( + choice_index=0, + inner_content=ChatCompletionChunk( + id="chunk-id-1", + choices=[ + Choice( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="call_1", + function=ChoiceDeltaToolCallFunction(name=None, arguments='{"arg1":'), + type="function", + ) + ], + ), + finish_reason=None, + index=0, + ) + ], + created=1736679999, + model="gpt-4o-mini", + object="chat.completion.chunk", + ), + ai_model_id="gpt-4o-mini", + role=AuthorRole.ASSISTANT, + items=[ + FunctionCallContent( + id="call_1", + # no plugin_name/function_name yet + name=None, + arguments='{"arg1":', + ) + ], + ) + ] + # Second partial chunk for call_1 (updates plugin_name/function_name) + yield [ + StreamingChatMessageContent( + choice_index=0, + inner_content=ChatCompletionChunk( + id="chunk-id-2", + choices=[ + Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction( + # Provide the rest of the arguments + arguments='"value1"}', + name="firstFunction", + ), + ) + ] + ), + finish_reason=None, + index=0, + ) + ], + created=1736679999, + model="gpt-4o-mini", + object="chat.completion.chunk", + ), + ai_model_id="gpt-4o-mini", + role=AuthorRole.ASSISTANT, + items=[ + FunctionCallContent( + id="call_1", plugin_name="myPlugin", function_name="firstFunction", arguments='"value1"}' + ) + ], + ) + ] + # Now partial chunk for a second call, call_2 + yield [ + StreamingChatMessageContent( + choice_index=0, + inner_content=ChatCompletionChunk( + id="chunk-id-3", + choices=[ + Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="call_2", + function=ChoiceDeltaToolCallFunction(name=None, arguments='{"arg2":"another"}'), + type="function", + ) + ], + ), + finish_reason=None, + index=0, + ) + ], + created=1736679999, + model="gpt-4o-mini", + object="chat.completion.chunk", + ), + ai_model_id="gpt-4o-mini", + role=AuthorRole.ASSISTANT, + items=[FunctionCallContent(id="call_2", arguments='{"arg2":"another"}')], + ) + ] + # Next partial chunk updates name, plugin_name, function_name for call_2 + yield [ + StreamingChatMessageContent( + choice_index=0, + inner_content=ChatCompletionChunk( + id="chunk-id-4", + choices=[ + Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, function=ChoiceDeltaToolCallFunction(name="secondFunction") + ) + ] + ), + finish_reason=None, + index=0, + ) + ], + created=1736679999, + model="gpt-4o-mini", + object="chat.completion.chunk", + ), + ai_model_id="gpt-4o-mini", + role=AuthorRole.ASSISTANT, + items=[ + FunctionCallContent( + id="call_2", + name="someFancyName", + plugin_name="anotherPlugin", + function_name="secondFunction", + ) + ], + ) + ] + # Final chunk signals finish with tool_calls + yield [ + StreamingChatMessageContent( # type: ignore + choice_index=0, + inner_content=ChatCompletionChunk( + id="chunk-id-5", + choices=[Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)], + created=1736679999, + model="gpt-4o-mini", + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ), + ai_model_id="gpt-4o-mini", + role=AuthorRole.ASSISTANT, + finish_reason=FinishReason.TOOL_CALLS, + metadata={"usage": 
{"prompt_tokens": 10, "completion_tokens": 5}}, + ) + ] + + # Mock SK client + mock_client = AsyncMock(spec=AzureChatCompletion) + mock_client.get_streaming_chat_message_contents = mock_get_streaming_chat_message_contents + + # Create adapter and kernel + kernel = Kernel(memory=NullMemory()) + adapter = SKChatCompletionAdapter(mock_client, kernel=kernel) + + # Call create_stream with no actual tools (we just test the multiple calls) + messages: list[LLMMessage] = [ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content="Call two different plugin functions", source="user"), + ] + + # Collect streaming outputs + response_chunks: list[CreateResult | str] = [] + async for chunk in adapter.create_stream(messages=messages): + response_chunks.append(chunk) + + # The final chunk should be a CreateResult with function_calls + assert len(response_chunks) > 0 + final_chunk = response_chunks[-1] + assert isinstance(final_chunk, CreateResult) + assert final_chunk.finish_reason == "function_calls" + assert isinstance(final_chunk.content, list) + assert len(final_chunk.content) == 2 # We expect 2 calls + + # Verify first call merged name + arguments + first_call = final_chunk.content[0] + assert first_call.id == "call_1" + assert first_call.name == "myPlugin-firstFunction" # pluginName-functionName + assert '{"arg1":"value1"}' in first_call.arguments + + # Verify second call also merged everything + second_call = final_chunk.content[1] + assert second_call.id == "call_2" + assert second_call.name == "anotherPlugin-secondFunction" + assert '{"arg2":"another"}' in second_call.arguments From 7fc7f383f0b9acc8b3593c7a8d06cfea54246bf9 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sat, 8 Feb 2025 20:50:12 -0800 Subject: [PATCH 20/36] Enable LLM Call Observability in AGS (#5457) It is often helpful to inspect the raw request and response to/from an LLM as agents act. This PR, does the following: - Updates TeamManager to yield LLMCallEvents from core library - Run in an async background task to listen for LLMCallEvent JIT style when a team is run - Add events to an async queue and - yield those events in addition to whatever actual agentchat team.run_stream yields. - Update the AGS UI to show those LLMCallEvents in the messages section as a team runs - Add settings panel to show/hide llm call events in messages. - Minor updates to default team image image image ## Why are these changes needed? ## Related issue number Closes #5440 ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. 
--- .../autogenstudio/datamodel/__init__.py | 2 + .../autogenstudio/datamodel/types.py | 6 + .../autogenstudio/gallery/builder.py | 4 +- .../gallery/tools/bing_search.py | 4 +- .../gallery/tools/google_search.py | 4 +- .../autogenstudio/teammanager/teammanager.py | 47 +++-- .../autogen-studio/autogenstudio/web/app.py | 5 - .../autogenstudio/web/managers/connection.py | 13 +- .../frontend/src/components/sidebar.tsx | 4 +- .../src/components/types/datamodel.ts | 8 +- .../frontend/src/components/views/atoms.tsx | 100 ++++++++-- .../views/gallery/default_gallery.json | 2 +- .../src/components/views/gallery/store.tsx | 2 +- .../views/{session => playground}/api.ts | 0 .../chat/agentflow/agentflow.tsx | 19 +- .../chat/agentflow/agentnode.tsx | 0 .../chat/agentflow/edge.tsx | 0 .../chat/agentflow/edgemessagemodal.tsx | 0 .../chat/agentflow/toolbar.tsx | 0 .../{session => playground}/chat/chat.tsx | 0 .../chat/chatinput.tsx | 0 .../chat/inputrequest.tsx | 0 .../views/playground/chat/logrenderer.tsx | 187 ++++++++++++++++++ .../chat/rendermessage.tsx | 27 ++- .../{session => playground}/chat/runview.tsx | 78 +++++--- .../{session => playground}/chat/types.ts | 0 .../views/{session => playground}/editor.tsx | 0 .../views/{session => playground}/manager.tsx | 0 .../{session => playground}/newsession.tsx | 0 .../views/{session => playground}/sidebar.tsx | 0 .../views/{session => playground}/types.ts | 0 .../src/components/views/settings/manager.tsx | 163 +++++++++++++++ .../src/components/views/settings/sidebar.tsx | 87 ++++++++ .../src/components/views/settings/store.tsx | 121 ++++++++++++ .../src/components/views/settings/types.tsx | 8 + .../components/views/team/builder/builder.tsx | 5 +- .../views/team/builder/testdrawer.tsx | 4 +- .../src/components/views/team/sidebar.tsx | 11 +- .../src/components/views/team/types.ts | 82 -------- .../frontend/src/pages/index.tsx | 4 +- .../frontend/src/pages/settings.tsx | 13 +- 41 files changed, 801 insertions(+), 209 deletions(-) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/api.ts (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/agentflow/agentflow.tsx (97%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/agentflow/agentnode.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/agentflow/edge.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/agentflow/edgemessagemodal.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/agentflow/toolbar.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/chat.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/chatinput.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/inputrequest.tsx (100%) create mode 100644 python/packages/autogen-studio/frontend/src/components/views/playground/chat/logrenderer.tsx rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/rendermessage.tsx (89%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/runview.tsx (85%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/chat/types.ts (100%) 
rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/editor.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/manager.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/newsession.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/sidebar.tsx (100%) rename python/packages/autogen-studio/frontend/src/components/views/{session => playground}/types.ts (100%) create mode 100644 python/packages/autogen-studio/frontend/src/components/views/settings/manager.tsx create mode 100644 python/packages/autogen-studio/frontend/src/components/views/settings/sidebar.tsx create mode 100644 python/packages/autogen-studio/frontend/src/components/views/settings/store.tsx create mode 100644 python/packages/autogen-studio/frontend/src/components/views/settings/types.tsx diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/__init__.py b/python/packages/autogen-studio/autogenstudio/datamodel/__init__.py index b47788ed7273..b7f6f8da56e0 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/__init__.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/__init__.py @@ -4,6 +4,7 @@ GalleryComponents, GalleryItems, GalleryMetadata, + LLMCallEventMessage, MessageConfig, MessageMeta, Response, @@ -22,4 +23,5 @@ "TeamResult", "Response", "SocketMessage", + "LLMCallEventMessage", ] diff --git a/python/packages/autogen-studio/autogenstudio/datamodel/types.py b/python/packages/autogen-studio/autogenstudio/datamodel/types.py index b35b7a054e67..3e3bdbbe43ad 100644 --- a/python/packages/autogen-studio/autogenstudio/datamodel/types.py +++ b/python/packages/autogen-studio/autogenstudio/datamodel/types.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional from autogen_agentchat.base import TaskResult +from autogen_agentchat.messages import BaseChatMessage from autogen_core import ComponentModel from pydantic import BaseModel @@ -18,6 +19,11 @@ class TeamResult(BaseModel): duration: float +class LLMCallEventMessage(BaseChatMessage): + source: str = "llm_call_event" + content: str + + class MessageMeta(BaseModel): task: Optional[str] = None task_result: Optional[TaskResult] = None diff --git a/python/packages/autogen-studio/autogenstudio/gallery/builder.py b/python/packages/autogen-studio/autogenstudio/gallery/builder.py index 0d9a36f83471..a3d4b4523bba 100644 --- a/python/packages/autogen-studio/autogenstudio/gallery/builder.py +++ b/python/packages/autogen-studio/autogenstudio/gallery/builder.py @@ -240,7 +240,7 @@ def create_default_gallery() -> Gallery: builder.add_team( websurfer_team.dump_component(), label="Web Agent Team (Operator)", - description="A group chat team that have participants takes turn to publish a message\n to all, using a ChatCompletion model to select the next speaker after each message.", + description="A team with 3 agents - a Web Surfer agent that can browse the web, a Verification Assistant that verifies and summarizes information, and a User Proxy that provides human feedback when needed.", ) builder.add_tool( @@ -347,7 +347,7 @@ def create_default_gallery() -> Gallery: builder.add_team( deep_research_team.dump_component(), label="Deep Research Team", - description="A team that performs deep research using web searches, verification, and summarization.", + description="A team with 3 agents - a Research Assistant that performs web searches and analyzes 
information, a Verifier that ensures research quality and completeness, and a Summary Agent that provides a detailed markdown summary of the research as a report to the user.", ) return builder.build() diff --git a/python/packages/autogen-studio/autogenstudio/gallery/tools/bing_search.py b/python/packages/autogen-studio/autogenstudio/gallery/tools/bing_search.py index c3e3c6e6a18d..73b59f0f0bac 100644 --- a/python/packages/autogen-studio/autogenstudio/gallery/tools/bing_search.py +++ b/python/packages/autogen-studio/autogenstudio/gallery/tools/bing_search.py @@ -12,10 +12,10 @@ async def bing_search( query: str, - num_results: int = 5, + num_results: int = 3, include_snippets: bool = True, include_content: bool = True, - content_max_length: Optional[int] = 15000, + content_max_length: Optional[int] = 10000, language: str = "en", country: Optional[str] = None, safe_search: str = "moderate", diff --git a/python/packages/autogen-studio/autogenstudio/gallery/tools/google_search.py b/python/packages/autogen-studio/autogenstudio/gallery/tools/google_search.py index b8f6800b39dc..c1fc8f0f2c20 100644 --- a/python/packages/autogen-studio/autogenstudio/gallery/tools/google_search.py +++ b/python/packages/autogen-studio/autogenstudio/gallery/tools/google_search.py @@ -11,10 +11,10 @@ async def google_search( query: str, - num_results: int = 5, + num_results: int = 3, include_snippets: bool = True, include_content: bool = True, - content_max_length: Optional[int] = 15000, + content_max_length: Optional[int] = 10000, language: str = "en", country: Optional[str] = None, safe_search: bool = True, diff --git a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py index 886c228fab10..d194e351af59 100644 --- a/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py +++ b/python/packages/autogen-studio/autogenstudio/teammanager/teammanager.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import time @@ -8,13 +9,26 @@ import yaml from autogen_agentchat.base import TaskResult, Team from autogen_agentchat.messages import AgentEvent, ChatMessage -from autogen_core import CancellationToken, Component, ComponentModel +from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel +from autogen_core.logging import LLMCallEvent -from ..datamodel.types import TeamResult +from ..datamodel.types import LLMCallEventMessage, TeamResult logger = logging.getLogger(__name__) +class RunEventLogger(logging.Handler): + """Event logger that queues LLMCallEvents for streaming""" + + def __init__(self): + super().__init__() + self.events = asyncio.Queue() + + def emit(self, record: logging.LogRecord): + if isinstance(record.msg, LLMCallEvent): + self.events.put_nowait(LLMCallEventMessage(content=str(record.msg))) + + class TeamManager: """Manages team operations including loading configs and running teams""" @@ -35,14 +49,7 @@ async def load_from_file(path: Union[str, Path]) -> dict: @staticmethod async def load_from_directory(directory: Union[str, Path]) -> List[dict]: - """Load all team configurations from a directory - - Args: - directory (Union[str, Path]): Path to directory containing config files - - Returns: - List[dict]: List of loaded team configurations - """ + """Load all team configurations from a directory""" directory = Path(directory) configs = [] valid_extensions = {".json", ".yaml", ".yml"} @@ -61,7 +68,6 @@ async def _create_team( self, team_config: Union[str, Path, 
dict, ComponentModel], input_func: Optional[Callable] = None ) -> Component: """Create team instance from config""" - # Handle different input types if isinstance(team_config, (str, Path)): config = await self.load_from_file(team_config) elif isinstance(team_config, dict): @@ -69,14 +75,12 @@ async def _create_team( else: config = team_config.model_dump() - # Use Component.load_component directly team = Team.load_component(config) for agent in team._participants: if hasattr(agent, "input_func"): agent.input_func = input_func - # TBD - set input function return team async def run_stream( @@ -85,11 +89,17 @@ async def run_stream( team_config: Union[str, Path, dict, ComponentModel], input_func: Optional[Callable] = None, cancellation_token: Optional[CancellationToken] = None, - ) -> AsyncGenerator[Union[AgentEvent | ChatMessage, ChatMessage, TaskResult], None]: + ) -> AsyncGenerator[Union[AgentEvent | ChatMessage | LLMCallEvent, ChatMessage, TeamResult], None]: """Stream team execution results""" start_time = time.time() team = None + # Setup logger correctly + logger = logging.getLogger(EVENT_LOGGER_NAME) + logger.setLevel(logging.INFO) + llm_event_logger = RunEventLogger() + logger.handlers = [llm_event_logger] # Replace all handlers + try: team = await self._create_team(team_config, input_func) @@ -102,7 +112,15 @@ async def run_stream( else: yield message + # Check for any LLM events + while not llm_event_logger.events.empty(): + event = await llm_event_logger.events.get() + yield event + finally: + # Cleanup - remove our handler + logger.handlers.remove(llm_event_logger) + # Ensure cleanup happens if team and hasattr(team, "_participants"): for agent in team._participants: @@ -127,7 +145,6 @@ async def run( return TeamResult(task_result=result, usage="", duration=time.time() - start_time) finally: - # Ensure cleanup happens if team and hasattr(team, "_participants"): for agent in team._participants: if hasattr(agent, "close"): diff --git a/python/packages/autogen-studio/autogenstudio/web/app.py b/python/packages/autogen-studio/autogenstudio/web/app.py index 35dd8fc2c96c..a9fbb0fcc543 100644 --- a/python/packages/autogen-studio/autogenstudio/web/app.py +++ b/python/packages/autogen-studio/autogenstudio/web/app.py @@ -15,11 +15,6 @@ from .initialization import AppInitializer from .routes import runs, sessions, teams, ws -# Configure logging -# logger = logging.getLogger(__name__) -# logging.basicConfig(level=logging.INFO) - - # Initialize application app_file_path = os.path.dirname(os.path.abspath(__file__)) initializer = AppInitializer(settings, app_file_path) diff --git a/python/packages/autogen-studio/autogenstudio/web/managers/connection.py b/python/packages/autogen-studio/autogenstudio/web/managers/connection.py index 67855e997404..8ea851918317 100644 --- a/python/packages/autogen-studio/autogenstudio/web/managers/connection.py +++ b/python/packages/autogen-studio/autogenstudio/web/managers/connection.py @@ -21,7 +21,7 @@ from fastapi import WebSocket, WebSocketDisconnect from ...database import DatabaseManager -from ...datamodel import Message, MessageConfig, Run, RunStatus, TeamResult +from ...datamodel import LLMCallEventMessage, Message, MessageConfig, Run, RunStatus, TeamResult from ...teammanager import TeamManager logger = logging.getLogger(__name__) @@ -111,6 +111,7 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None HandoffMessage, ToolCallRequestEvent, ToolCallExecutionEvent, + LLMCallEventMessage, ), ): await self._save_message(run_id, 
message) @@ -328,7 +329,15 @@ def _format_message(self, message: Any) -> Optional[dict]: } elif isinstance( - message, (TextMessage, StopMessage, HandoffMessage, ToolCallRequestEvent, ToolCallExecutionEvent) + message, + ( + TextMessage, + StopMessage, + HandoffMessage, + ToolCallRequestEvent, + ToolCallExecutionEvent, + LLMCallEventMessage, + ), ): return {"type": "message", "data": message.model_dump()} diff --git a/python/packages/autogen-studio/frontend/src/components/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/sidebar.tsx index b01672d81829..0b4bd3c38c9c 100644 --- a/python/packages/autogen-studio/frontend/src/components/sidebar.tsx +++ b/python/packages/autogen-studio/frontend/src/components/sidebar.tsx @@ -220,7 +220,7 @@ const Sidebar = ({ link, meta, isMobile }: SidebarProps) => { ], }) } - className="group hidden flex gap-x-3 rounded-md p-2 text-sm font-medium text-primary hover:text-accent hover:bg-secondary justify-center" + className="group flex gap-x-3 rounded-md p-2 text-sm font-medium text-primary hover:text-accent hover:bg-secondary justify-center" > @@ -248,7 +248,7 @@ const Sidebar = ({ link, meta, isMobile }: SidebarProps) => { ) : (
-
+
{" "} (
@@ -41,6 +52,7 @@ export const LoadingDots = ({ size = 8 }) => {
   );
 };
 
+
 export const TruncatableText = memo(
   ({
     content,
@@ -48,14 +60,17 @@ export const TruncatableText = memo(
     className = "",
     jsonThreshold = 1000,
     textThreshold = 500,
+    showFullscreen = true,
   }: {
     content: string;
     isJson?: boolean;
     className?: string;
     jsonThreshold?: number;
     textThreshold?: number;
+    showFullscreen?: boolean;
   }) => {
     const [isExpanded, setIsExpanded] = useState(false);
+    const [isFullscreen, setIsFullscreen] = useState(false);
     const threshold = isJson ? jsonThreshold : textThreshold;
     const shouldTruncate = content.length > threshold;
@@ -72,7 +87,7 @@ export const TruncatableText = memo(
- {/* {displayContent} */} {displayContent} {shouldTruncate && !isExpanded && ( -
+
)}
{shouldTruncate && ( -
- + + + {showFullscreen && ( + + + + )} +
+ )} + + {isFullscreen && ( +
setIsFullscreen(false)} + > +
e.stopPropagation()} > - {isExpanded ? "Show less" : "Show more"} - {isExpanded ? : } - + + + +
+ {isJson ? ( +
{content}
+ ) : ( + {content} + )} +
+
)}
diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json b/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json index 1a8b4e555435..305cf1eecc0f 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json @@ -828,7 +828,7 @@ "description": "A tool that performs Google searches using the Google Custom Search API. Requires the requests library, [GOOGLE_API_KEY, GOOGLE_CSE_ID] to be set, env variable to function.", "label": "Google Search Tool", "config": { - "source_code": "async def google_search(\n query: str,\n num_results: int = 5,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 15000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True,\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API and optionally fetch webpage content.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., en, es, fr)\n country: Optional country code for search results (e.g., us, uk)\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n - content: Webpage content in markdown (if include_content=True)\n \"\"\"\n api_key = os.getenv(\"GOOGLE_API_KEY\")\n cse_id = os.getenv(\"GOOGLE_CSE_ID\")\n\n if not api_key or not cse_id:\n raise ValueError(\"Missing required environment variables. 
Please set GOOGLE_API_KEY and GOOGLE_CSE_ID.\")\n\n num_results = min(max(1, num_results), 10)\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n params = {\n \"key\": api_key,\n \"cx\": cse_id,\n \"q\": query,\n \"num\": num_results,\n \"hl\": language,\n \"safe\": \"active\" if safe_search else \"off\",\n }\n\n if country:\n params[\"gl\"] = country\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\"https://www.googleapis.com/customsearch/v1\", params=params, timeout=10)\n response.raise_for_status()\n data = response.json()\n\n results = []\n if \"items\" in data:\n for item in data[\"items\"]:\n result = {\"title\": item.get(\"title\", \"\"), \"link\": item.get(\"link\", \"\")}\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n results.append(result)\n\n return results\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\") from e\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\") from e\n", + "source_code": "async def google_search(\n query: str,\n num_results: int = 3,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 15000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True,\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API and optionally fetch webpage content.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., en, es, fr)\n country: Optional country code for search results (e.g., us, uk)\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n - content: Webpage content in markdown (if include_content=True)\n \"\"\"\n api_key = 
os.getenv(\"GOOGLE_API_KEY\")\n cse_id = os.getenv(\"GOOGLE_CSE_ID\")\n\n if not api_key or not cse_id:\n raise ValueError(\"Missing required environment variables. Please set GOOGLE_API_KEY and GOOGLE_CSE_ID.\")\n\n num_results = min(max(1, num_results), 10)\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n params = {\n \"key\": api_key,\n \"cx\": cse_id,\n \"q\": query,\n \"num\": num_results,\n \"hl\": language,\n \"safe\": \"active\" if safe_search else \"off\",\n }\n\n if country:\n params[\"gl\"] = country\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\"https://www.googleapis.com/customsearch/v1\", params=params, timeout=10)\n response.raise_for_status()\n data = response.json()\n\n results = []\n if \"items\" in data:\n for item in data[\"items\"]:\n result = {\"title\": item.get(\"title\", \"\"), \"link\": item.get(\"link\", \"\")}\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n results.append(result)\n\n return results\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\") from e\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\") from e\n", "name": "google_search", "description": "\n Perform Google searches using the Custom Search API with optional webpage content fetching.\n Requires GOOGLE_API_KEY and GOOGLE_CSE_ID environment variables to be set.\n ", "global_imports": [ diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx b/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx index 465554b58a74..1f251958801e 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx @@ -151,7 +151,7 @@ export const useGalleryStore = create()( }, }), { - name: "gallery-storage-v3", + name: "gallery-storage-v4", } ) ); diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/api.ts b/python/packages/autogen-studio/frontend/src/components/views/playground/api.ts similarity index 100% rename from 
python/packages/autogen-studio/frontend/src/components/views/session/api.ts rename to python/packages/autogen-studio/frontend/src/components/views/playground/api.ts diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentflow.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/agentflow.tsx similarity index 97% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentflow.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/agentflow.tsx index 3e2c1f4244e3..5dc1888cc1db 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentflow.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/agentflow.tsx @@ -195,23 +195,6 @@ const createNode = ( }; } - // if (type === "task") { - // return { - // id, - // type: "agentNode", - // position: { x: 0, y: 0 }, - // data: { - // type: "task", - // label: "Task", - // description: run?.task.content || "", - // isActive: false, - // status: null, - // reason: null, - // draggable: false, - // }, - // }; - // } - return { id, type: "agentNode", @@ -567,7 +550,7 @@ const AgentFlow: React.FC = ({ teamConfig, run }) => { {settings.showGrid && } {settings.showMiniMap && } -
+
diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentnode.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/agentnode.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/agentnode.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/agentnode.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/edge.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/edge.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/edge.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/edge.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/edgemessagemodal.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/edgemessagemodal.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/edgemessagemodal.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/edgemessagemodal.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/toolbar.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/toolbar.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/agentflow/toolbar.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/agentflow/toolbar.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/chat.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/chat.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/chat.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/chatinput.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/chatinput.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/chatinput.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/chatinput.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/inputrequest.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/inputrequest.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/inputrequest.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/inputrequest.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/playground/chat/logrenderer.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/logrenderer.tsx new file mode 100644 index 000000000000..6a34f8a8e249 --- /dev/null +++ b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/logrenderer.tsx @@ -0,0 +1,187 @@ +import React, { useMemo, useState } from "react"; +import { Terminal, Maximize2, X } from 
"lucide-react"; +import { TruncatableText } from "../../atoms"; +import { Tooltip } from "antd"; + +interface LLMLogEvent { + type: "LLMCall"; + messages: { + content: string; + role: string; + name?: string; + }[]; + response: { + id: string; + choices: { + message: { + content: string; + role: string; + }; + }[]; + usage: { + completion_tokens: number; + prompt_tokens: number; + total_tokens: number; + }; + model: string; + }; + prompt_tokens: number; + completion_tokens: number; + agent_id: string; +} + +interface LLMLogRendererProps { + content: string; +} + +const formatTokens = (tokens: number) => { + return tokens >= 1000 ? `${(tokens / 1000).toFixed(1)}k` : tokens; +}; + +const FullLogView = ({ + event, + onClose, +}: { + event: LLMLogEvent; + onClose: () => void; +}) => ( +
+
e.stopPropagation()} + > + + + + +
+
+ +

LLM Call Details

+

+ {event.agent_id.split("/")[0]} • {event.response.model} •{" "} + {formatTokens(event.response.usage.total_tokens)} tokens +

+
+ +
+

Messages

+ {event.messages.map((msg, idx) => ( +
+
+ + {msg.role} {msg.name && `(${msg.name})`} + +
+ +
+ ))} +
+ +
+

Response

+
+ +
+
+ +
+
+
Model
+
{event.response.model}
+
+
+
Prompt Tokens
+
+ {event.response.usage.prompt_tokens} +
+
+
+
Completion Tokens
+
+ {event.response.usage.completion_tokens} +
+
+
+
Total Tokens
+
+ {event.response.usage.total_tokens} +
+
+
+
+
+
+); + +const LLMLogRenderer: React.FC = ({ content }) => { + const [showFullLog, setShowFullLog] = useState(false); + + const parsedContent = useMemo(() => { + try { + return JSON.parse(content) as LLMLogEvent; + } catch (e) { + console.error("Failed to parse LLM log content:", e); + return null; + } + }, [content]); + + if (!parsedContent) { + return ( +
+ + Invalid log format +
+ ); + } + + const { messages, response, agent_id } = parsedContent; + const agentName = messages[0]?.name || "Agent"; + const totalTokens = response.usage.total_tokens; + const shortAgentId = agent_id ? `${agent_id.split("/")[0]}` : ""; + + return ( + <> +
+ + + {shortAgentId ? `${shortAgentId}` : ""} • {response.model} •{" "} + {formatTokens(totalTokens)} tokens + + + + +
+ + {showFullLog && ( + setShowFullLog(false)} + /> + )} + + ); +}; + +export default LLMLogRenderer; diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/rendermessage.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/rendermessage.tsx similarity index 89% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/rendermessage.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/rendermessage.tsx index e1fe720dec5a..ec1eedbaa17c 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/session/chat/rendermessage.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/rendermessage.tsx @@ -1,5 +1,12 @@ import React, { useState, memo } from "react"; -import { User, Bot, Maximize2, Minimize2, DraftingCompass } from "lucide-react"; +import { + User, + Bot, + Maximize2, + Minimize2, + DraftingCompass, + Brain, +} from "lucide-react"; import { AgentMessageConfig, FunctionCall, @@ -7,6 +14,7 @@ import { ImageContent, } from "../../../types/datamodel"; import { ClickableImage, TruncatableText } from "../../atoms"; +import LLMLogRenderer from "./logrenderer"; const TEXT_THRESHOLD = 400; const JSON_THRESHOLD = 800; @@ -143,9 +151,14 @@ export const RenderMessage: React.FC = ({ if (!message) return null; const isUser = messageUtils.isUser(message.source); const content = message.content; + const isLLMEventMessage = message.source === "llm_call_event"; return ( -
+
= ({ ${isUser ? "text-accent" : "text-primary"} `} > - {isUser ? : } + {isUser ? ( + + ) : message.source == "llm_call_event" ? ( + + ) : ( + + )}
@@ -177,6 +196,8 @@ export const RenderMessage: React.FC = ({ ) : messageUtils.isFunctionExecutionResult(content) ? ( + ) : message.source === "llm_call_event" ? ( + ) : ( { + return messages.filter((msg) => msg.config.source !== "llm_call_event"); +}; + +export const getLastMeaningfulMessage = ( + messages: Message[] +): Message | undefined => { + return messages + .filter((msg) => msg.config.source !== "llm_call_event") + .slice(-1)[0]; +}; + +// Type guard for message arrays +export const isAgentMessage = (message: Message): boolean => { + return message.config.source !== "llm_call_event"; +}; + const RunView: React.FC = ({ run, onInputResponse, @@ -45,6 +57,18 @@ const RunView: React.FC = ({ const isActive = run.status === "active" || run.status === "awaiting_input"; const [isFlowVisible, setIsFlowVisible] = useState(true); + const showLLMEvents = useSettingsStore( + (state) => state.playground.showLLMEvents + ); + console.log("showLLMEvents", showLLMEvents); + + const visibleMessages = useMemo(() => { + if (showLLMEvents) { + return run.messages; + } + return run.messages.filter((msg) => msg.config.source !== "llm_call_event"); + }, [run.messages, showLLMEvents]); + // Replace existing scroll effect with this simpler one useEffect(() => { setTimeout(() => { @@ -56,7 +80,7 @@ const RunView: React.FC = ({ } }, 450); }, [run.messages]); // Only depend on messages changing - + // console.log("run", run); const calculateThreadTokens = (messages: Message[]) => { // console.log("messages", messages); return messages.reduce((total, msg) => { @@ -123,7 +147,7 @@ const RunView: React.FC = ({ }; const lastResultMessage = run.team_result?.task_result.messages.slice(-1)[0]; - const lastMessage = run.messages.slice(-1)[0]; + const lastMessage = getLastMeaningfulMessage(visibleMessages); return (
@@ -202,19 +226,7 @@ const RunView: React.FC = ({
{lastMessage ? ( - // - + ) : ( <> {lastResultMessage && ( @@ -228,7 +240,7 @@ const RunView: React.FC = ({ {/* Thread Section */}
- {run.messages.length > 0 && ( + {visibleMessages.length > 0 && (
@@ -262,8 +274,8 @@ const RunView: React.FC = ({
- {calculateThreadTokens(run.messages)} tokens |{" "} - {run.messages.length} messages + {calculateThreadTokens(visibleMessages)} tokens |{" "} + {visibleMessages.length} messages
@@ -290,14 +302,14 @@ const RunView: React.FC = ({ {" "} {" "}
- {run.messages.map((msg, idx) => ( + {visibleMessages.map((msg, idx) => (
))} @@ -322,7 +334,7 @@ const RunView: React.FC = ({ {/* Agent Flow Visualization */} {isFlowVisible && (
-
+
{teamConfig && ( - + )}
)} diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/chat/types.ts b/python/packages/autogen-studio/frontend/src/components/views/playground/chat/types.ts similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/chat/types.ts rename to python/packages/autogen-studio/frontend/src/components/views/playground/chat/types.ts diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/editor.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/editor.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/editor.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/editor.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/manager.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/manager.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/manager.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/manager.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/newsession.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/newsession.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/newsession.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/newsession.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/views/playground/sidebar.tsx similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/sidebar.tsx rename to python/packages/autogen-studio/frontend/src/components/views/playground/sidebar.tsx diff --git a/python/packages/autogen-studio/frontend/src/components/views/session/types.ts b/python/packages/autogen-studio/frontend/src/components/views/playground/types.ts similarity index 100% rename from python/packages/autogen-studio/frontend/src/components/views/session/types.ts rename to python/packages/autogen-studio/frontend/src/components/views/playground/types.ts diff --git a/python/packages/autogen-studio/frontend/src/components/views/settings/manager.tsx b/python/packages/autogen-studio/frontend/src/components/views/settings/manager.tsx new file mode 100644 index 000000000000..fbae94e50960 --- /dev/null +++ b/python/packages/autogen-studio/frontend/src/components/views/settings/manager.tsx @@ -0,0 +1,163 @@ +import React, { useState, useEffect } from "react"; +import { ChevronRight, RotateCcw } from "lucide-react"; +import { Switch, Button, Tooltip } from "antd"; +import { MessagesSquare } from "lucide-react"; +import { useSettingsStore } from "./store"; +import { SettingsSidebar } from "./sidebar"; +import { SettingsSection } from "./types"; +import { LucideIcon } from "lucide-react"; + +interface SettingToggleProps { + checked: boolean; + onChange: (checked: boolean) => void; + label: string; + description?: string; +} + +interface SectionHeaderProps { + title: string; + icon: LucideIcon; + onReset: () => void; +} + +const SettingToggle: React.FC = ({ + checked, + onChange, + label, + description, +}) => ( +
+
+ + {description && ( + {description} + )} +
+ +
+); + +const SectionHeader: React.FC = ({ + title, + icon: Icon, + onReset, +}) => ( +
+
+ +

{title}

+
+ +
+); + +export const SettingsManager: React.FC = () => { + const [isSidebarOpen, setIsSidebarOpen] = useState(() => { + if (typeof window !== "undefined") { + const stored = localStorage.getItem("settingsSidebar"); + return stored !== null ? JSON.parse(stored) : true; + } + return true; + }); + + const { + playground, + updatePlaygroundSettings, + resetPlaygroundSettings, + resetAllSettings, + } = useSettingsStore(); + + const sections: SettingsSection[] = [ + { + id: "playground", + title: "Playground", + icon: MessagesSquare, + content: () => ( + <> + +
+ + updatePlaygroundSettings({ showLLMEvents: checked }) + } + label={"Show LLM Events"} + description="Display detailed LLM call logs in the message thread" + /> +
+ + ), + }, + ]; + + const [currentSection, setCurrentSection] = useState( + sections[0] + ); + + useEffect(() => { + if (typeof window !== "undefined") { + localStorage.setItem("settingsSidebar", JSON.stringify(isSidebarOpen)); + } + }, [isSidebarOpen]); + + return ( +
+
+ setIsSidebarOpen(!isSidebarOpen)} + onSelectSection={setCurrentSection} + /> +
+ +
+
+
+ Settings + + {currentSection.title} +
+ + + +
+

+ Settings are automatically saved and synced across browser + sessions +

+ +
+
+
+
+ ); +}; + +export default SettingsManager; diff --git a/python/packages/autogen-studio/frontend/src/components/views/settings/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/views/settings/sidebar.tsx new file mode 100644 index 000000000000..46a1a1fc2325 --- /dev/null +++ b/python/packages/autogen-studio/frontend/src/components/views/settings/sidebar.tsx @@ -0,0 +1,87 @@ +import React from "react"; +import { Button, Tooltip } from "antd"; +import { PanelLeftClose, PanelLeftOpen } from "lucide-react"; +import { SettingsSection } from "./types"; + +interface SettingsSidebarProps { + isOpen: boolean; + sections: SettingsSection[]; + currentSection: SettingsSection; + onToggle: () => void; + onSelectSection: (section: SettingsSection) => void; +} + +export const SettingsSidebar: React.FC = ({ + isOpen, + sections, + currentSection, + onToggle, + onSelectSection, +}) => { + // Render collapsed state + if (!isOpen) { + return ( +
+
+ + + +
+
+ ); + } + + return ( +
+ {/* Header */} +
+
+ Settings + + {sections.length} + +
+ + + +
+ +
+ {sections.map((section) => ( +
+
+
onSelectSection(section)} + > +
+ + {section.title} +
+
+
+ ))} +
+
+  );
+};
diff --git a/python/packages/autogen-studio/frontend/src/components/views/settings/store.tsx b/python/packages/autogen-studio/frontend/src/components/views/settings/store.tsx
new file mode 100644
index 000000000000..9ffffd1efc54
--- /dev/null
+++ b/python/packages/autogen-studio/frontend/src/components/views/settings/store.tsx
@@ -0,0 +1,121 @@
+//settings/store.tsx
+import { create } from "zustand";
+import { persist } from "zustand/middleware";
+
+interface PlaygroundSettings {
+  showLLMEvents: boolean;
+  // Future playground settings
+  expandedMessagesByDefault?: boolean;
+  showAgentFlowByDefault?: boolean;
+}
+
+interface TeamBuilderSettings {
+  // Future teambuilder settings
+  showAdvancedOptions?: boolean;
+  defaultAgentLayout?: "grid" | "list";
+}
+
+interface GallerySettings {
+  // Future gallery settings
+  viewMode?: "grid" | "list";
+  sortBy?: "date" | "popularity";
+}
+
+interface SettingsState {
+  playground: PlaygroundSettings;
+  teamBuilder: TeamBuilderSettings;
+  gallery: GallerySettings;
+  // Actions to update settings
+  updatePlaygroundSettings: (settings: Partial<PlaygroundSettings>) => void;
+  updateTeamBuilderSettings: (settings: Partial<TeamBuilderSettings>) => void;
+  updateGallerySettings: (settings: Partial<GallerySettings>) => void;
+  // Reset functions
+  resetPlaygroundSettings: () => void;
+  resetTeamBuilderSettings: () => void;
+  resetGallerySettings: () => void;
+  resetAllSettings: () => void;
+}
+
+const DEFAULT_PLAYGROUND_SETTINGS: PlaygroundSettings = {
+  showLLMEvents: true, // Show LLM call events by default
+};
+
+const DEFAULT_TEAMBUILDER_SETTINGS: TeamBuilderSettings = {
+  showAdvancedOptions: false,
+  defaultAgentLayout: "grid",
+};
+
+const DEFAULT_GALLERY_SETTINGS: GallerySettings = {
+  viewMode: "grid",
+  sortBy: "date",
+};
+
+export const useSettingsStore = create<SettingsState>()(
+  persist(
+    (set) => ({
+      // Initial state
+      playground: DEFAULT_PLAYGROUND_SETTINGS,
+      teamBuilder: DEFAULT_TEAMBUILDER_SETTINGS,
+      gallery: DEFAULT_GALLERY_SETTINGS,
+
+      // Update functions
+      updatePlaygroundSettings: (settings) =>
+        set((state) => ({
+          playground: { ...state.playground, ...settings },
+        })),
+
+      updateTeamBuilderSettings: (settings) =>
+        set((state) => ({
+          teamBuilder: { ...state.teamBuilder, ...settings },
+        })),
+
+      updateGallerySettings: (settings) =>
+        set((state) => ({
+          gallery: { ...state.gallery, ...settings },
+        })),
+
+      // Reset functions
+      resetPlaygroundSettings: () =>
+        set((state) => ({
+          playground: DEFAULT_PLAYGROUND_SETTINGS,
+        })),
+
+      resetTeamBuilderSettings: () =>
+        set((state) => ({
+          teamBuilder: DEFAULT_TEAMBUILDER_SETTINGS,
+        })),
+
+      resetGallerySettings: () =>
+        set((state) => ({
+          gallery: DEFAULT_GALLERY_SETTINGS,
+        })),
+
+      resetAllSettings: () =>
+        set({
+          playground: DEFAULT_PLAYGROUND_SETTINGS,
+          teamBuilder: DEFAULT_TEAMBUILDER_SETTINGS,
+          gallery: DEFAULT_GALLERY_SETTINGS,
+        }),
+    }),
+    {
+      name: "ags-app-settings",
+      partialize: (state) => ({
+        playground: state.playground,
+        teamBuilder: state.teamBuilder,
+        gallery: state.gallery,
+      }),
+    }
+  )
+);
+
+// Example usage:
+/*
+import { useSettingsStore } from './store';
+
+// In a component:
+const { showLLMEvents } = useSettingsStore((state) => state.playground);
+const updatePlaygroundSettings = useSettingsStore((state) => state.updatePlaygroundSettings);
+
+// Toggle LLM events
+updatePlaygroundSettings({ showLLMEvents: !showLLMEvents });
+*/
diff --git a/python/packages/autogen-studio/frontend/src/components/views/settings/types.tsx b/python/packages/autogen-studio/frontend/src/components/views/settings/types.tsx
new file mode 100644 index 000000000000..5a24b775dafc --- /dev/null +++ b/python/packages/autogen-studio/frontend/src/components/views/settings/types.tsx @@ -0,0 +1,8 @@ +import { LucideIcon } from "lucide-react"; + +export interface SettingsSection { + id: string; + title: string; + icon: LucideIcon; + content: () => JSX.Element; +} diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx index a48c998a9c76..bbaf1bbeff5d 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/builder.tsx @@ -41,7 +41,7 @@ import { MonacoEditor } from "../../monaco"; import { NodeEditor } from "./node-editor/node-editor"; import debounce from "lodash.debounce"; import { appContext } from "../../../../hooks/provider"; -import { sessionAPI } from "../../session/api"; +import { sessionAPI } from "../../playground/api"; import TestDrawer from "./testdrawer"; const { Sider, Content } = Layout; @@ -179,7 +179,6 @@ export const TeamBuilder: React.FC = ({ } if (onChange) { - console.log("Saving team configuration", component); const teamData: Partial = team ? { ...team, @@ -288,7 +287,7 @@ export const TeamBuilder: React.FC = ({ }; const handleTestDrawerClose = () => { - console.log("TestDrawer closed"); + // console.log("TestDrawer closed"); setTestDrawerVisible(false); }; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/testdrawer.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/testdrawer.tsx index 0848abd7f5b3..134fa58406b7 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/testdrawer.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/testdrawer.tsx @@ -2,9 +2,9 @@ import React, { useContext, useEffect, useState } from "react"; import { Drawer, Button, message, Checkbox } from "antd"; import { Team, Session } from "../../../types/datamodel"; -import ChatView from "../../session/chat/chat"; +import ChatView from "../../playground/chat/chat"; import { appContext } from "../../../../hooks/provider"; -import { sessionAPI } from "../../session/api"; +import { sessionAPI } from "../../playground/api"; interface TestDrawerProps { isVisble: boolean; diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx index e60c3dcff3e9..952621012e83 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/sidebar.tsx @@ -3,11 +3,9 @@ import { Button, Tooltip } from "antd"; import { Bot, Plus, - Edit, Trash2, PanelLeftClose, PanelLeftOpen, - Calendar, Copy, GalleryHorizontalEnd, InfoIcon, @@ -15,7 +13,6 @@ import { } from "lucide-react"; import type { Team } from "../../types/datamodel"; import { getRelativeTimeString } from "../atoms"; -import { defaultTeam } from "./types"; import { useGalleryStore } from "../gallery/store"; interface TeamSidebarProps { @@ -43,8 +40,12 @@ export const TeamSidebar: React.FC = ({ }) => { const defaultGallery = useGalleryStore((state) => state.getDefaultGallery()); const createTeam = () => { - const newTeam = Object.assign({}, defaultTeam); - newTeam.component.label = "new_team_" + new 
+    const newTeam = Object.assign(
+      {},
+      { component: defaultGallery?.items.teams[0] }
+    );
+    newTeam.component.label =
+      "default_team" + new Date().getTime().toString().slice(0, 2);
     onCreateTeam(newTeam);
   };
diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/types.ts b/python/packages/autogen-studio/frontend/src/components/views/team/types.ts
index 768744eb228b..7a962a74e097 100644
--- a/python/packages/autogen-studio/frontend/src/components/views/team/types.ts
+++ b/python/packages/autogen-studio/frontend/src/components/views/team/types.ts
@@ -15,85 +15,3 @@ export interface TeamListProps {
   onDelete: (teamId: number) => void;
   isLoading?: boolean;
 }
-
-export const defaultTeamConfig: Component<TeamConfig> = {
-  provider: "autogen_agentchat.teams.RoundRobinGroupChat",
-  component_type: "team",
-  version: 1,
-  component_version: 1,
-  description:
-    "A team of agents that chat with users in a round-robin fashion.",
-  label: "General Team",
-  config: {
-    participants: [
-      {
-        provider: "autogen_agentchat.agents.AssistantAgent",
-        component_type: "agent",
-        version: 1,
-        component_version: 1,
-        config: {
-          name: "weather_agent",
-          model_client: {
-            provider: "autogen_ext.models.openai.OpenAIChatCompletionClient",
-            component_type: "model",
-            version: 1,
-            component_version: 1,
-            config: { model: "gpt-4o-mini" },
-          },
-          tools: [
-            {
-              provider: "autogen_core.tools.FunctionTool",
-              component_type: "tool",
-              version: 1,
-              component_version: 1,
-              config: {
-                source_code:
-                  'async def get_weather(city: str) -> str:\n    return f"The weather in {city} is 73 degrees and Sunny."\n',
-                name: "get_weather",
-                description: "",
-                global_imports: [],
-                has_cancellation_support: false,
-              },
-            },
-          ],
-          handoffs: [],
-          description:
-            "An agent that provides assistance with ability to use tools.",
-          system_message:
-            "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
-          reflect_on_tool_use: false,
-          tool_call_summary_format: "{result}",
-        },
-      },
-    ],
-    termination_condition: {
-      provider: "autogen_agentchat.base.OrTerminationCondition",
-      component_type: "termination",
-      version: 1,
-      component_version: 1,
-      config: {
-        conditions: [
-          {
-            provider: "autogen_agentchat.conditions.MaxMessageTermination",
-            component_type: "termination",
-            version: 1,
-            component_version: 1,
-            config: { max_messages: 10 },
-          },
-          {
-            provider: "autogen_agentchat.conditions.TextMentionTermination",
-            component_type: "termination",
-            version: 1,
-            component_version: 1,
-            config: { text: "TERMINATE" },
-          },
-        ],
-      },
-    },
-    max_turns: 1,
-  },
-};
-
-export const defaultTeam: Team = {
-  component: defaultTeamConfig,
-};
diff --git a/python/packages/autogen-studio/frontend/src/pages/index.tsx b/python/packages/autogen-studio/frontend/src/pages/index.tsx
index 87ff4e342d14..22f0118acb68 100644
--- a/python/packages/autogen-studio/frontend/src/pages/index.tsx
+++ b/python/packages/autogen-studio/frontend/src/pages/index.tsx
@@ -1,8 +1,8 @@
 import * as React from "react";
 import Layout from "../components/layout";
 import { graphql } from "gatsby";
-import ChatView from "../components/views/session/chat/chat";
-import { SessionManager } from "../components/views/session/manager";
+import ChatView from "../components/views/playground/chat/chat";
+import { SessionManager } from "../components/views/playground/manager";

 // markup
 const IndexPage = ({ data }: any) => {
diff --git a/python/packages/autogen-studio/frontend/src/pages/settings.tsx b/python/packages/autogen-studio/frontend/src/pages/settings.tsx
index 5f9310f755ae..b6472894b682 100644
--- a/python/packages/autogen-studio/frontend/src/pages/settings.tsx
+++ b/python/packages/autogen-studio/frontend/src/pages/settings.tsx
@@ -1,21 +1,14 @@
 import * as React from "react";
 import Layout from "../components/layout";
 import { graphql } from "gatsby";
-import { TriangleAlertIcon } from "lucide-react";
+import { SettingsManager } from "../components/views/settings/manager";

 // markup
 const SettingsPage = ({ data }: any) => {
   return (
-      <TriangleAlertIcon />
-      Settings
-      {" "} Work in progress ..
+      <SettingsManager />
); From 5308b76d5f8d8b102756831a20000849975534e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rich=C3=A1rd=20Gyik=C3=B3?= Date: Sun, 9 Feb 2025 06:20:00 +0100 Subject: [PATCH 21/36] Add MCP adapters to autogen-ext (#5251) Co-authored-by: Eric Zhu --- .../autogen-core/docs/src/reference/index.md | 1 + .../python/autogen_ext.tools.mcp.rst | 8 + python/packages/autogen-ext/pyproject.toml | 5 + .../_sk_chat_completion_adapter.py | 2 +- .../src/autogen_ext/tools/mcp/__init__.py | 13 + .../src/autogen_ext/tools/mcp/_base.py | 101 +++++++ .../src/autogen_ext/tools/mcp/_config.py | 22 ++ .../src/autogen_ext/tools/mcp/_factory.py | 134 +++++++++ .../src/autogen_ext/tools/mcp/_session.py | 23 ++ .../src/autogen_ext/tools/mcp/_sse.py | 104 +++++++ .../src/autogen_ext/tools/mcp/_stdio.py | 61 ++++ .../autogen-ext/tests/tools/test_mcp_tools.py | 271 ++++++++++++++++++ python/uv.lock | 59 ++++ 13 files changed, 803 insertions(+), 1 deletion(-) create mode 100644 python/packages/autogen-core/docs/src/reference/python/autogen_ext.tools.mcp.rst create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/mcp/_config.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py create mode 100644 python/packages/autogen-ext/tests/tools/test_mcp_tools.py diff --git a/python/packages/autogen-core/docs/src/reference/index.md b/python/packages/autogen-core/docs/src/reference/index.md index fdaf598c0029..762e9f07ba35 100644 --- a/python/packages/autogen-core/docs/src/reference/index.md +++ b/python/packages/autogen-core/docs/src/reference/index.md @@ -54,6 +54,7 @@ python/autogen_ext.models.replay python/autogen_ext.models.azure python/autogen_ext.models.semantic_kernel python/autogen_ext.tools.langchain +python/autogen_ext.tools.mcp python/autogen_ext.tools.graphrag python/autogen_ext.tools.code_execution python/autogen_ext.tools.semantic_kernel diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.tools.mcp.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.tools.mcp.rst new file mode 100644 index 000000000000..3c7ffd9085a5 --- /dev/null +++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.tools.mcp.rst @@ -0,0 +1,8 @@ +autogen\_ext.tools.mcp +====================== + + +.. 
automodule:: autogen_ext.tools.mcp + :members: + :undoc-members: + :show-inheritance: diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index a2ffda98a019..f3e49b6ce992 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -112,6 +112,11 @@ semantic-kernel-all = [ rich = ["rich>=13.9.4"] +mcp = [ + "mcp>=1.1.3", + "json-schema-to-pydantic>=0.2.2" +] + [tool.hatch.build.targets.wheel] packages = ["src/autogen_ext"] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py index 07f3fe74802c..26e2092dc430 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py @@ -1,6 +1,6 @@ import json -from typing import Any, Literal, Mapping, Optional, Sequence import warnings +from typing import Any, Literal, Mapping, Optional, Sequence from autogen_core import FunctionCall from autogen_core._cancellation_token import CancellationToken diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py new file mode 100644 index 000000000000..83d76fcad502 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py @@ -0,0 +1,13 @@ +from ._config import McpServerParams, SseServerParams, StdioServerParams +from ._factory import mcp_server_tools +from ._sse import SseMcpToolAdapter +from ._stdio import StdioMcpToolAdapter + +__all__ = [ + "StdioMcpToolAdapter", + "StdioServerParams", + "SseMcpToolAdapter", + "SseServerParams", + "McpServerParams", + "mcp_server_tools", +] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py new file mode 100644 index 000000000000..cf905800d6b2 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py @@ -0,0 +1,101 @@ +from abc import ABC +from typing import Any, Generic, Type, TypeVar + +from autogen_core import CancellationToken +from autogen_core.tools import BaseTool +from json_schema_to_pydantic import create_model +from mcp import Tool +from pydantic import BaseModel + +from ._config import McpServerParams +from ._session import create_mcp_server_session + +TServerParams = TypeVar("TServerParams", bound=McpServerParams) + + +class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]): + """ + Base adapter class for MCP tools to make them compatible with AutoGen. + + Args: + server_params (TServerParams): Parameters for the MCP server connection. + tool (Tool): The MCP tool to wrap. 
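+
+    The wrapped tool's JSON ``inputSchema`` is converted into a Pydantic model
+    (via ``json_schema_to_pydantic.create_model``) in the constructor below, so
+    arguments are validated against the server's schema before the call is
+    forwarded to the server.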
+ """ + + component_type = "tool" + + def __init__(self, server_params: TServerParams, tool: Tool) -> None: + self._tool = tool + self._server_params = server_params + + # Extract name and description + name = tool.name + description = tool.description or "" + + # Create the input model from the tool's schema + input_model = create_model(tool.inputSchema) + + # Use Any as return type since MCP tool returns can vary + return_type: Type[Any] = object + + super().__init__(input_model, return_type, name, description) + + async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any: + """ + Run the MCP tool with the provided arguments. + + Args: + args (BaseModel): The arguments to pass to the tool. + cancellation_token (CancellationToken): Token to signal cancellation. + + Returns: + Any: The result of the tool execution. + + Raises: + Exception: If the operation is cancelled or the tool execution fails. + """ + kwargs = args.model_dump() + + try: + async with create_mcp_server_session(self._server_params) as session: + await session.initialize() + + if cancellation_token.is_cancelled(): + raise Exception("Operation cancelled") + + result = await session.call_tool(self._tool.name, kwargs) # type: ignore + + if result.isError: + raise Exception(f"MCP tool execution failed: {result.content}") + + return result.content + except Exception as e: + raise Exception(str(e)) from e + + @classmethod + async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]": + """ + Create an instance of McpToolAdapter from server parameters and tool name. + + Args: + server_params (TServerParams): Parameters for the MCP server connection. + tool_name (str): The name of the tool to wrap. + + Returns: + McpToolAdapter[TServerParams]: An instance of McpToolAdapter. + + Raises: + ValueError: If the tool with the specified name is not found. 
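+
+        Example:
+            A minimal sketch (the server command and tool name are illustrative and
+            assume a local ``mcp-server-fetch`` install; run inside an async function):
+
+            .. code-block:: python
+
+                from autogen_ext.tools.mcp import StdioMcpToolAdapter, StdioServerParams
+
+                params = StdioServerParams(command="uvx", args=["mcp-server-fetch"])
+                adapter = await StdioMcpToolAdapter.from_server_params(params, "fetch")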
+ """ + async with create_mcp_server_session(server_params) as session: + await session.initialize() + + tools_response = await session.list_tools() + matching_tool = next((t for t in tools_response.tools if t.name == tool_name), None) + + if matching_tool is None: + raise ValueError( + f"Tool '{tool_name}' not found, available tools: {', '.join([t.name for t in tools_response.tools])}" + ) + + return cls(server_params=server_params, tool=matching_tool) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_config.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_config.py new file mode 100644 index 000000000000..3f140587d30e --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_config.py @@ -0,0 +1,22 @@ +from typing import Any, TypeAlias + +from mcp import StdioServerParameters +from pydantic import BaseModel + + +class StdioServerParams(StdioServerParameters): + """Parameters for connecting to an MCP server over STDIO.""" + + pass + + +class SseServerParams(BaseModel): + """Parameters for connecting to an MCP server over SSE.""" + + url: str + headers: dict[str, Any] | None = None + timeout: float = 5 + sse_read_timeout: float = 60 * 5 + + +McpServerParams: TypeAlias = StdioServerParams | SseServerParams diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py new file mode 100644 index 000000000000..fcb29e3c4a99 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py @@ -0,0 +1,134 @@ +from ._config import McpServerParams, SseServerParams, StdioServerParams +from ._session import create_mcp_server_session +from ._sse import SseMcpToolAdapter +from ._stdio import StdioMcpToolAdapter + + +async def mcp_server_tools( + server_params: McpServerParams, +) -> list[StdioMcpToolAdapter | SseMcpToolAdapter]: + """Creates a list of MCP tool adapters that can be used with AutoGen agents. + + This factory function connects to an MCP server and returns adapters for all available tools. + The adapters can be directly assigned to an AutoGen agent's tools list. + + Args: + server_params (McpServerParams): Connection parameters for the MCP server. + Can be either StdioServerParams for command-line tools or + SseServerParams for HTTP/SSE services. + + Returns: + list[StdioMcpToolAdapter | SseMcpToolAdapter]: A list of tool adapters ready to use + with AutoGen agents. + + Examples: + + **Local file system MCP service over standard I/O example:** + + Install the filesystem server package from npm (requires Node.js 16+ and npm). + + .. code-block:: bash + + npm install -g @modelcontextprotocol/server-filesystem + + Create an agent that can use all tools from the local filesystem MCP server. + + .. 
code-block:: python

+            import asyncio
+            from pathlib import Path
+            from autogen_ext.models.openai import OpenAIChatCompletionClient
+            from autogen_ext.tools.mcp import StdioServerParams, mcp_server_tools
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_core import CancellationToken
+
+
+            async def main() -> None:
+                # Setup server params for local filesystem access
+                desktop = str(Path.home() / "Desktop")
+                server_params = StdioServerParams(
+                    command="npx.cmd", args=["-y", "@modelcontextprotocol/server-filesystem", desktop]
+                )
+
+                # Get all available tools from the server
+                tools = await mcp_server_tools(server_params)
+
+                # Create an agent that can use all the tools
+                agent = AssistantAgent(
+                    name="file_manager",
+                    model_client=OpenAIChatCompletionClient(model="gpt-4"),
+                    tools=tools,  # type: ignore
+                )
+
+                # The agent can now use any of the filesystem tools
+                await agent.run(task="Create a file called test.txt with some content", cancellation_token=CancellationToken())
+
+
+            if __name__ == "__main__":
+                asyncio.run(main())
+
+        **Local fetch MCP service over standard I/O example:**
+
+        Install the `mcp-server-fetch` package.
+
+        .. code-block:: bash
+
+            pip install mcp-server-fetch
+
+        Create an agent that can use the `fetch` tool from the local MCP server.
+
+        .. code-block:: python
+
+            import asyncio
+
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_ext.models.openai import OpenAIChatCompletionClient
+            from autogen_ext.tools.mcp import StdioServerParams, mcp_server_tools
+
+
+            async def main() -> None:
+                # Get the fetch tool from mcp-server-fetch.
+                fetch_mcp_server = StdioServerParams(command="uvx", args=["mcp-server-fetch"])
+                tools = await mcp_server_tools(fetch_mcp_server)
+
+                # Create an agent that can use the fetch tool.
+                model_client = OpenAIChatCompletionClient(model="gpt-4o")
+                agent = AssistantAgent(name="fetcher", model_client=model_client, tools=tools, reflect_on_tool_use=True)  # type: ignore
+
+                # Let the agent fetch the content of a URL and summarize it.
+                result = await agent.run(task="Summarize the content of https://en.wikipedia.org/wiki/Seattle")
+                print(result.messages[-1].content)
+
+
+            asyncio.run(main())
+
+
+        **Remote MCP service over SSE example:**
+
+        .. code-block:: python
+
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_ext.models.openai import OpenAIChatCompletionClient
+            from autogen_ext.tools.mcp import SseServerParams, mcp_server_tools
+
+
+            async def main() -> None:
+                # Setup server params for remote service
+                server_params = SseServerParams(url="https://api.example.com/mcp", headers={"Authorization": "Bearer token"})
+
+                # Get all available tools
+                tools = await mcp_server_tools(server_params)
+
+                # Create an agent with all tools
+                agent = AssistantAgent(name="tool_user", model_client=OpenAIChatCompletionClient(model="gpt-4"), tools=tools)  # type: ignore
+
+    For more examples and detailed usage, see the samples directory in the package repository.
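+
+    Since the returned adapters are ordinary AutoGen tools, you can also keep just a
+    subset of them before handing them to an agent. A small sketch (the tool names
+    shown are assumptions about what the server exposes):
+
+    .. code-block:: python
+
+        tools = await mcp_server_tools(server_params)
+        read_only_tools = [t for t in tools if t.name in ("read_file", "list_directory")]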
+    """
+    async with create_mcp_server_session(server_params) as session:
+        await session.initialize()
+
+        tools = await session.list_tools()
+
+    if isinstance(server_params, StdioServerParams):
+        return [StdioMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools]
+    elif isinstance(server_params, SseServerParams):
+        return [SseMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools]
+    raise ValueError(f"Unsupported server params type: {type(server_params)}")
diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py
new file mode 100644
index 000000000000..964a5c9f77cf
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_session.py
@@ -0,0 +1,23 @@
+from contextlib import asynccontextmanager
+from typing import AsyncGenerator
+
+from mcp import ClientSession
+from mcp.client.sse import sse_client
+from mcp.client.stdio import stdio_client
+
+from ._config import McpServerParams, SseServerParams, StdioServerParams
+
+
+@asynccontextmanager
+async def create_mcp_server_session(
+    server_params: McpServerParams,
+) -> AsyncGenerator[ClientSession, None]:
+    """Create an MCP client session for the given server parameters."""
+    if isinstance(server_params, StdioServerParams):
+        async with stdio_client(server_params) as (read, write):
+            async with ClientSession(read_stream=read, write_stream=write) as session:
+                yield session
+    elif isinstance(server_params, SseServerParams):
+        async with sse_client(**server_params.model_dump()) as (read, write):
+            async with ClientSession(read_stream=read, write_stream=write) as session:
+                yield session
diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py
new file mode 100644
index 000000000000..7b875649a4f4
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py
@@ -0,0 +1,104 @@
+from autogen_core import Component
+from mcp import Tool
+from pydantic import BaseModel
+from typing_extensions import Self
+
+from ._base import McpToolAdapter
+from ._config import SseServerParams
+
+
+class SseMcpToolAdapterConfig(BaseModel):
+    """Configuration for the MCP tool adapter."""
+
+    server_params: SseServerParams
+    tool: Tool
+
+
+class SseMcpToolAdapter(
+    McpToolAdapter[SseServerParams],
+    Component[SseMcpToolAdapterConfig],
+):
+    """
+    Allows you to wrap an MCP tool running over Server-Sent Events (SSE) and make it available to AutoGen.
+
+    This adapter enables using MCP-compatible tools that communicate over HTTP with SSE
+    with AutoGen agents. Common use cases include integrating with remote MCP services,
+    cloud-based tools, and web APIs that implement the Model Context Protocol (MCP).
+
+    Args:
+        server_params (SseServerParams): Parameters for the MCP server connection,
+            including URL, headers, and timeouts
+        tool (Tool): The MCP tool to wrap
+
+    Examples:
+        Use a remote translation service that implements MCP over SSE to create tools
+        that allow AutoGen agents to perform translations:
+
+        .. 
code-block:: python + + import asyncio + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_ext.tools.mcp import SseMcpToolAdapter, SseServerParams + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.ui import Console + from autogen_core import CancellationToken + + + async def main() -> None: + # Create server params for the remote MCP service + server_params = SseServerParams( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your-api-key", "Content-Type": "application/json"}, + timeout=30, # Connection timeout in seconds + ) + + # Get the translation tool from the server + adapter = await SseMcpToolAdapter.from_server_params(server_params, "translate") + + # Create an agent that can use the translation tool + model_client = OpenAIChatCompletionClient(model="gpt-4") + agent = AssistantAgent( + name="translator", + model_client=model_client, + tools=[adapter], + system_message="You are a helpful translation assistant.", + ) + + # Let the agent translate some text + await Console( + agent.run_stream(task="Translate 'Hello, how are you?' to Spanish", cancellation_token=CancellationToken()) + ) + + + if __name__ == "__main__": + asyncio.run(main()) + + """ + + component_config_schema = SseMcpToolAdapterConfig + component_provider_override = "autogen_ext.tools.mcp.SseMcpToolAdapter" + + def __init__(self, server_params: SseServerParams, tool: Tool) -> None: + super().__init__(server_params=server_params, tool=tool) + + def _to_config(self) -> SseMcpToolAdapterConfig: + """ + Convert the adapter to its configuration representation. + + Returns: + SseMcpToolAdapterConfig: The configuration of the adapter. + """ + return SseMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + + @classmethod + def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self: + """ + Create an instance of SseMcpToolAdapter from its configuration. + + Args: + config (SseMcpToolAdapterConfig): The configuration of the adapter. + + Returns: + SseMcpToolAdapter: An instance of SseMcpToolAdapter. + """ + return cls(server_params=config.server_params, tool=config.tool) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py new file mode 100644 index 000000000000..629b1a46ec1a --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py @@ -0,0 +1,61 @@ +from autogen_core import Component +from mcp import Tool +from pydantic import BaseModel +from typing_extensions import Self + +from ._base import McpToolAdapter +from ._config import StdioServerParams + + +class StdioMcpToolAdapterConfig(BaseModel): + """Configuration for the MCP tool adapter.""" + + server_params: StdioServerParams + tool: Tool + + +class StdioMcpToolAdapter( + McpToolAdapter[StdioServerParams], + Component[StdioMcpToolAdapterConfig], +): + """Allows you to wrap an MCP tool running over STDIO and make it available to AutoGen. + + This adapter enables using MCP-compatible tools that communicate over standard input/output + with AutoGen agents. Common use cases include wrapping command-line tools and local services + that implement the Model Context Protocol (MCP). + + Args: + server_params (StdioServerParams): Parameters for the MCP server connection, + including command to run and its arguments + tool (Tool): The MCP tool to wrap + + See :func:`~autogen_ext.tools.mcp.mcp_server_tools` for examples. 
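+
+    Like other AutoGen components, the adapter can be serialized to a declarative
+    config and restored later, for example:
+
+    .. code-block:: python
+
+        config = adapter.dump_component()
+        adapter = StdioMcpToolAdapter.load_component(config)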
+ """ + + component_config_schema = StdioMcpToolAdapterConfig + component_provider_override = "autogen_ext.tools.mcp.StdioMcpToolAdapter" + + def __init__(self, server_params: StdioServerParams, tool: Tool) -> None: + super().__init__(server_params=server_params, tool=tool) + + def _to_config(self) -> StdioMcpToolAdapterConfig: + """ + Convert the adapter to its configuration representation. + + Returns: + StdioMcpToolAdapterConfig: The configuration of the adapter. + """ + return StdioMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + + @classmethod + def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self: + """ + Create an instance of StdioMcpToolAdapter from its configuration. + + Args: + config (StdioMcpToolAdapterConfig): The configuration of the adapter. + + Returns: + StdioMcpToolAdapter: An instance of StdioMcpToolAdapter. + """ + return cls(server_params=config.server_params, tool=config.tool) diff --git a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py new file mode 100644 index 000000000000..7af9933d3a94 --- /dev/null +++ b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py @@ -0,0 +1,271 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from autogen_core import CancellationToken +from autogen_ext.tools.mcp import ( + SseMcpToolAdapter, + SseServerParams, + StdioMcpToolAdapter, + StdioServerParams, +) +from json_schema_to_pydantic import create_model +from mcp import ClientSession, Tool + + +@pytest.fixture +def sample_tool() -> Tool: + return Tool( + name="test_tool", + description="A test tool", + inputSchema={ + "type": "object", + "properties": {"test_param": {"type": "string"}}, + "required": ["test_param"], + }, + ) + + +@pytest.fixture +def sample_server_params() -> StdioServerParams: + return StdioServerParams(command="echo", args=["test"]) + + +@pytest.fixture +def sample_sse_tool() -> Tool: + return Tool( + name="test_sse_tool", + description="A test SSE tool", + inputSchema={ + "type": "object", + "properties": {"test_param": {"type": "string"}}, + "required": ["test_param"], + }, + ) + + +@pytest.fixture +def mock_sse_session() -> AsyncMock: + session = AsyncMock(spec=ClientSession) + session.initialize = AsyncMock() + session.call_tool = AsyncMock() + session.list_tools = AsyncMock() + return session + + +@pytest.fixture +def mock_session() -> AsyncMock: + session = AsyncMock(spec=ClientSession) + session.initialize = AsyncMock() + session.call_tool = AsyncMock() + session.list_tools = AsyncMock() + return session + + +@pytest.fixture +def mock_tool_response() -> MagicMock: + response = MagicMock() + response.isError = False + response.content = {"result": "test_output"} + return response + + +@pytest.fixture +def cancellation_token() -> CancellationToken: + return CancellationToken() + + +def test_adapter_config_serialization(sample_tool: Tool, sample_server_params: StdioServerParams) -> None: + """Test that adapter can be saved to and loaded from config.""" + original_adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool) + config = original_adapter.dump_component() + loaded_adapter = StdioMcpToolAdapter.load_component(config) + + # Test that the loaded adapter has the same properties + assert loaded_adapter.name == "test_tool" + assert loaded_adapter.description == "A test tool" + + # Verify schema structure + schema = loaded_adapter.schema + assert "parameters" in schema, "Schema must have parameters" + 
params_schema = schema["parameters"] + assert isinstance(params_schema, dict), "Parameters must be a dict" + assert "type" in params_schema, "Parameters must have type" + assert "required" in params_schema, "Parameters must have required fields" + assert "properties" in params_schema, "Parameters must have properties" + + # Compare schema content + assert params_schema["type"] == sample_tool.inputSchema["type"] + assert params_schema["required"] == sample_tool.inputSchema["required"] + assert ( + params_schema["properties"]["test_param"]["type"] == sample_tool.inputSchema["properties"]["test_param"]["type"] + ) + + +@pytest.mark.asyncio +async def test_mcp_tool_execution( + sample_tool: Tool, + sample_server_params: StdioServerParams, + mock_session: AsyncMock, + mock_tool_response: MagicMock, + cancellation_token: CancellationToken, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that adapter properly executes tools through ClientSession.""" + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + monkeypatch.setattr( + "autogen_ext.tools.mcp._base.create_mcp_server_session", + lambda *args, **kwargs: mock_context, # type: ignore + ) + + mock_session.call_tool.return_value = mock_tool_response + + adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool) + result = await adapter.run( + args=create_model(sample_tool.inputSchema)(**{"test_param": "test"}), + cancellation_token=cancellation_token, + ) + + assert result == mock_tool_response.content + mock_session.initialize.assert_called_once() + mock_session.call_tool.assert_called_once() + + +@pytest.mark.asyncio +async def test_adapter_from_server_params( + sample_tool: Tool, + sample_server_params: StdioServerParams, + mock_session: AsyncMock, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that adapter can be created from server parameters.""" + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + monkeypatch.setattr( + "autogen_ext.tools.mcp._base.create_mcp_server_session", + lambda *args, **kwargs: mock_context, # type: ignore + ) + + mock_session.list_tools.return_value.tools = [sample_tool] + + adapter = await StdioMcpToolAdapter.from_server_params(sample_server_params, "test_tool") + + assert isinstance(adapter, StdioMcpToolAdapter) + assert adapter.name == "test_tool" + assert adapter.description == "A test tool" + + # Verify schema structure + schema = adapter.schema + assert "parameters" in schema, "Schema must have parameters" + params_schema = schema["parameters"] + assert isinstance(params_schema, dict), "Parameters must be a dict" + assert "type" in params_schema, "Parameters must have type" + assert "required" in params_schema, "Parameters must have required fields" + assert "properties" in params_schema, "Parameters must have properties" + + # Compare schema content + assert params_schema["type"] == sample_tool.inputSchema["type"] + assert params_schema["required"] == sample_tool.inputSchema["required"] + assert ( + params_schema["properties"]["test_param"]["type"] == sample_tool.inputSchema["properties"]["test_param"]["type"] + ) + + +@pytest.mark.asyncio +async def test_sse_adapter_config_serialization(sample_sse_tool: Tool) -> None: + """Test that SSE adapter can be saved to and loaded from config.""" + params = SseServerParams(url="http://test-url") + original_adapter = SseMcpToolAdapter(server_params=params, tool=sample_sse_tool) + config = original_adapter.dump_component() + loaded_adapter = 
SseMcpToolAdapter.load_component(config) + + # Test that the loaded adapter has the same properties + assert loaded_adapter.name == "test_sse_tool" + assert loaded_adapter.description == "A test SSE tool" + + # Verify schema structure + schema = loaded_adapter.schema + assert "parameters" in schema, "Schema must have parameters" + params_schema = schema["parameters"] + assert isinstance(params_schema, dict), "Parameters must be a dict" + assert "type" in params_schema, "Parameters must have type" + assert "required" in params_schema, "Parameters must have required fields" + assert "properties" in params_schema, "Parameters must have properties" + + # Compare schema content + assert params_schema["type"] == sample_sse_tool.inputSchema["type"] + assert params_schema["required"] == sample_sse_tool.inputSchema["required"] + assert ( + params_schema["properties"]["test_param"]["type"] + == sample_sse_tool.inputSchema["properties"]["test_param"]["type"] + ) + + +@pytest.mark.asyncio +async def test_sse_tool_execution( + sample_sse_tool: Tool, + mock_sse_session: AsyncMock, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that SSE adapter properly executes tools through ClientSession.""" + params = SseServerParams(url="http://test-url") + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_sse_session + + mock_sse_session.call_tool.return_value = MagicMock(isError=False, content={"result": "test_output"}) + + monkeypatch.setattr( + "autogen_ext.tools.mcp._base.create_mcp_server_session", + lambda *args, **kwargs: mock_context, # type: ignore + ) + + adapter = SseMcpToolAdapter(server_params=params, tool=sample_sse_tool) + result = await adapter.run( + args=create_model(sample_sse_tool.inputSchema)(**{"test_param": "test"}), + cancellation_token=CancellationToken(), + ) + + assert result == mock_sse_session.call_tool.return_value.content + mock_sse_session.initialize.assert_called_once() + mock_sse_session.call_tool.assert_called_once() + + +@pytest.mark.asyncio +async def test_sse_adapter_from_server_params( + sample_sse_tool: Tool, + mock_sse_session: AsyncMock, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that SSE adapter can be created from server parameters.""" + params = SseServerParams(url="http://test-url") + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_sse_session + monkeypatch.setattr( + "autogen_ext.tools.mcp._base.create_mcp_server_session", + lambda *args, **kwargs: mock_context, # type: ignore + ) + + mock_sse_session.list_tools.return_value.tools = [sample_sse_tool] + + adapter = await SseMcpToolAdapter.from_server_params(params, "test_sse_tool") + + assert isinstance(adapter, SseMcpToolAdapter) + assert adapter.name == "test_sse_tool" + assert adapter.description == "A test SSE tool" + + # Verify schema structure + schema = adapter.schema + assert "parameters" in schema, "Schema must have parameters" + params_schema = schema["parameters"] + assert isinstance(params_schema, dict), "Parameters must be a dict" + assert "type" in params_schema, "Parameters must have type" + assert "required" in params_schema, "Parameters must have required fields" + assert "properties" in params_schema, "Parameters must have properties" + + # Compare schema content + assert params_schema["type"] == sample_sse_tool.inputSchema["type"] + assert params_schema["required"] == sample_sse_tool.inputSchema["required"] + assert ( + params_schema["properties"]["test_param"]["type"] + == 
sample_sse_tool.inputSchema["properties"]["test_param"]["type"] + ) diff --git a/python/uv.lock b/python/uv.lock index eab21f4ab3cd..b4b1e170df60 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -2,6 +2,7 @@ version = 1 requires-python = ">=3.10, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and sys_platform == 'darwin'", + "python_version < '0'", "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'", "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", @@ -611,6 +612,10 @@ magentic-one = [ { name = "pillow" }, { name = "playwright" }, ] +mcp = [ + { name = "json-schema-to-pydantic" }, + { name = "mcp" }, +] openai = [ { name = "aiofiles" }, { name = "openai" }, @@ -694,10 +699,12 @@ requires-dist = [ { name = "graphrag", marker = "extra == 'graphrag'", specifier = ">=1.0.1" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = "~=1.70.0" }, { name = "ipykernel", marker = "extra == 'jupyter-executor'", specifier = ">=6.29.5" }, + { name = "json-schema-to-pydantic", marker = "extra == 'mcp'", specifier = ">=0.2.2" }, { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" }, { name = "markitdown", marker = "extra == 'file-surfer'", specifier = ">=0.0.1a2" }, { name = "markitdown", marker = "extra == 'magentic-one'", specifier = ">=0.0.1a2" }, { name = "markitdown", marker = "extra == 'web-surfer'", specifier = ">=0.0.1a2" }, + { name = "mcp", marker = "extra == 'mcp'", specifier = ">=1.1.3" }, { name = "nbclient", marker = "extra == 'jupyter-executor'", specifier = ">=0.10.2" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.52.2" }, { name = "openai-whisper", marker = "extra == 'video-surfer'" }, @@ -2813,6 +2820,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/2d/79a46330c4b97ee90dd403fb0d267da7b25b24d7db604c5294e5c57d5f7c/json_repair-0.30.3-py3-none-any.whl", hash = "sha256:63bb588162b0958ae93d85356ecbe54c06b8c33f8a4834f93fa2719ea669804e", size = 18951 }, ] +[[package]] +name = "json-schema-to-pydantic" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/87/af1870beea329744a759349b972b309f8c95ae21e986e387e19733b85cc9/json_schema_to_pydantic-0.2.2.tar.gz", hash = "sha256:685db8d93aa29ccd257b2803fcd9a956c527e5fb108a523cbfe8cac1239b3785", size = 34158 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/8d/3c738e4b4b041269c4a506544b90e9ba924bbd800c8d496ed3e5a6da0265/json_schema_to_pydantic-0.2.2-py3-none-any.whl", hash = "sha256:01b82d234f2b482a273e117e29d063b6b86021a250035873d6eec4b85b70e64d", size = 11396 }, +] + [[package]] name = "jsonpatch" version = "1.33" @@ -3854,6 +3873,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899 }, ] +[[package]] +name = "mcp" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "pydantic" }, + { name = "sse-starlette" }, + { name = "starlette" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/f7/60/66ebfd280b197f9a9d074c9e46cb1ac3186a32d12e6bd0425c24fe7cf7e8/mcp-1.1.3.tar.gz", hash = "sha256:af11018b8e9153cdd25f3722ec639fe7a462c00213a330fd6f593968341a9883", size = 57903 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/08/cfcfa13e41f8d27503c51a8cbf1939d720073ace92469d08655bb5de1b24/mcp-1.1.3-py3-none-any.whl", hash = "sha256:71462d6cd7c06c14689dfcf110ff22286ba1b608cfc3515c0a5cbe33d131731a", size = 36997 }, +] + [[package]] name = "mdit-py-plugins" version = "0.4.2" @@ -4308,6 +4344,7 @@ name = "nvidia-cublas-cu12" version = "12.4.5.8" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 }, { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, ] @@ -4316,6 +4353,7 @@ name = "nvidia-cuda-cupti-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 }, { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, ] @@ -4324,6 +4362,7 @@ name = "nvidia-cuda-nvrtc-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 }, { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, ] @@ -4332,6 +4371,7 @@ name = "nvidia-cuda-runtime-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 }, { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, ] @@ -4354,6 +4394,7 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and 
sys_platform != 'linux')" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, ] @@ -4362,6 +4403,7 @@ name = "nvidia-curand-cu12" version = "10.3.5.147" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 }, { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, ] @@ -4375,6 +4417,7 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, ] @@ -4386,6 +4429,7 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, ] @@ -4402,6 +4446,7 @@ name = "nvidia-nvjitlink-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 }, { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = 
"sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, ] @@ -4410,6 +4455,7 @@ name = "nvidia-nvtx-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 }, { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, ] @@ -6785,6 +6831,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dd/b1/3af5104b716c420e40a6ea1b09886cae3a1b9f4538343875f637755cae5b/sqlmodel-0.0.22-py3-none-any.whl", hash = "sha256:a1ed13e28a1f4057cbf4ff6cdb4fc09e85702621d3259ba17b3c230bfb2f941b", size = 28276 }, ] +[[package]] +name = "sse-starlette" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/a4/80d2a11af59fe75b48230846989e93979c892d3a20016b42bb44edb9e398/sse_starlette-2.2.1.tar.gz", hash = "sha256:54470d5f19274aeed6b2d473430b08b4b379ea851d953b11d7f1c4a2c118b419", size = 17376 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/e0/5b8bd393f27f4a62461c5cf2479c75a2cc2ffa330976f9f00f5f6e4f50eb/sse_starlette-2.2.1-py3-none-any.whl", hash = "sha256:6410a3d3ba0c89e7675d4c273a301d64649c03a5ef1ca101f10b47f895fd0e99", size = 10120 }, +] + [[package]] name = "stack-data" version = "0.6.3" From 9a028acf9ff88ecf0183513e1060937d8a8c18a1 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sun, 9 Feb 2025 10:12:59 -0800 Subject: [PATCH 22/36] feat: enhance Gemini model support in OpenAI client and tests (#5461) --- .../teams/_group_chat/_selector_group_chat.py | 8 +- .../tests/test_group_chat_endpoint.py | 9 +-- .../agentchat-user-guide/quickstart.ipynb | 79 +++++-------------- .../tutorial/models.ipynb | 19 ++--- .../components/model-clients.ipynb | 19 ++--- .../src/autogen_core/models/_model_client.py | 29 +++++++ .../autogen_ext/models/openai/_model_info.py | 36 +++++++++ .../models/openai/_openai_client.py | 9 +++ .../tests/models/test_openai_model_client.py | 20 +++-- 9 files changed, 116 insertions(+), 112 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index a69934a64dab..547910ce1898 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -145,13 +145,7 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st roles=roles, participants=str(participants), history=history ) select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage] - if self._model_client.model_info["family"] in [ - ModelFamily.GPT_4, - ModelFamily.GPT_4O, - ModelFamily.GPT_35, - ModelFamily.O1, - ModelFamily.O3, - ]: + if ModelFamily.is_openai(self._model_client.model_info["family"]): select_speaker_messages = 
[SystemMessage(content=select_speaker_prompt)] else: # Many other models need a UserMessage to respond to diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_endpoint.py b/python/packages/autogen-agentchat/tests/test_group_chat_endpoint.py index 301c2f21d777..e32d8fd97feb 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_endpoint.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_endpoint.py @@ -4,7 +4,7 @@ from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import SelectorGroupChat from autogen_agentchat.ui import Console -from autogen_core.models import ChatCompletionClient, ModelFamily +from autogen_core.models import ChatCompletionClient from autogen_ext.models.openai import OpenAIChatCompletionClient @@ -36,14 +36,7 @@ async def test_selector_group_chat_gemini() -> None: model_client = OpenAIChatCompletionClient( model="gemini-1.5-flash", - base_url="https://generativelanguage.googleapis.com/v1beta/openai/", api_key=api_key, - model_info={ - "vision": True, - "function_calling": True, - "json_output": True, - "family": ModelFamily.GEMINI_1_5_FLASH, - }, ) await _test_selector_group_chat(model_client) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/quickstart.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/quickstart.ipynb index 06e265b1429d..4f1a11d0464f 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/quickstart.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/quickstart.ipynb @@ -12,11 +12,7 @@ "metadata": {}, "source": [ "Via AgentChat, you can build applications quickly using preset agents.\n", - "To illustrate this, we will begin with creating a team of a single tool-use\n", - "agent that you can chat with.\n", - "\n", - "The following code uses the OpenAI model. If you haven't already, you need to\n", - "install the following package and extension:" + "To illustrate this, we will begin with creating a single tool-use agent." ] }, { @@ -29,7 +25,7 @@ }, "outputs": [], "source": [ - "pip install -U \"autogen-agentchat\" \"autogen-ext[openai,azure]\"" + "pip install -U \"autogen-ext[openai,azure]\"" ] }, { @@ -37,12 +33,14 @@ "metadata": {}, "source": [ "To use Azure OpenAI models and AAD authentication,\n", - "you can follow the instructions [here](./tutorial/models.ipynb#azure-openai)." + "you can follow the instructions [here](./tutorial/models.ipynb#azure-openai).\n", + "\n", + "To use other models, see [Models](./tutorial/models.ipynb)." 
] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -50,94 +48,53 @@ "output_type": "stream", "text": [ "---------- user ----------\n", - "What is the weather in NYC?\n", + "What is the weather in New York?\n", "---------- weather_agent ----------\n", - "[FunctionCall(id='call_vN04UiNJgqSz6g3MHt7Renig', arguments='{\"city\":\"New York City\"}', name='get_weather')]\n", - "[Prompt tokens: 75, Completion tokens: 16]\n", + "[FunctionCall(id='call_ciy1Ecys9LH201cyim10xlnQ', arguments='{\"city\":\"New York\"}', name='get_weather')]\n", "---------- weather_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York City is 73 degrees and Sunny.', call_id='call_vN04UiNJgqSz6g3MHt7Renig')]\n", - "---------- weather_agent ----------\n", - "The weather in New York City is 73 degrees and Sunny.\n", - "---------- Summary ----------\n", - "Number of messages: 4\n", - "Finish reason: Maximum number of turns 1 reached.\n", - "Total prompt tokens: 75\n", - "Total completion tokens: 16\n", - "Duration: 1.15 seconds\n", - "---------- user ----------\n", - "What is the weather in Seattle?\n", + "[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_ciy1Ecys9LH201cyim10xlnQ')]\n", "---------- weather_agent ----------\n", - "[FunctionCall(id='call_BesYutZXJIMfu2TlDZgodIEj', arguments='{\"city\":\"Seattle\"}', name='get_weather')]\n", - "[Prompt tokens: 127, Completion tokens: 14]\n", - "---------- weather_agent ----------\n", - "[FunctionExecutionResult(content='The weather in Seattle is 73 degrees and Sunny.', call_id='call_BesYutZXJIMfu2TlDZgodIEj')]\n", - "---------- weather_agent ----------\n", - "The weather in Seattle is 73 degrees and Sunny.\n", - "---------- Summary ----------\n", - "Number of messages: 4\n", - "Finish reason: Maximum number of turns 1 reached.\n", - "Total prompt tokens: 127\n", - "Total completion tokens: 14\n", - "Duration: 2.38 seconds\n" + "The weather in New York is currently 73 degrees and sunny.\n" ] } ], "source": [ "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.teams import RoundRobinGroupChat\n", "from autogen_agentchat.ui import Console\n", "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", "\n", "\n", "# Define a tool\n", "async def get_weather(city: str) -> str:\n", + " \"\"\"Get the weather for a given city.\"\"\"\n", " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", "\n", "\n", "async def main() -> None:\n", - " # Define an agent\n", - " weather_agent = AssistantAgent(\n", + " agent = AssistantAgent(\n", " name=\"weather_agent\",\n", " model_client=OpenAIChatCompletionClient(\n", - " model=\"gpt-4o-2024-08-06\",\n", + " model=\"gpt-4o\",\n", " # api_key=\"YOUR_API_KEY\",\n", " ),\n", " tools=[get_weather],\n", + " system_message=\"You are a helpful assistant.\",\n", + " reflect_on_tool_use=True,\n", " )\n", - "\n", - " # Define a team with a single agent and maximum auto-gen turns of 1.\n", - " agent_team = RoundRobinGroupChat([weather_agent], max_turns=1)\n", - "\n", - " while True:\n", - " # Get user input from the console.\n", - " user_input = input(\"Enter a message (type 'exit' to leave): \")\n", - " if user_input.strip().lower() == \"exit\":\n", - " break\n", - " # Run the team and stream messages to the console.\n", - " stream = agent_team.run_stream(task=user_input)\n", - " await Console(stream)\n", + " await Console(agent.run_stream(task=\"What is the weather in New 
York?\"))\n", "\n", "\n", "# NOTE: if running this inside a Python script you'll need to use asyncio.run(main()).\n", "await main()" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The code snippet above introduces two high level concepts in AgentChat: *Agent* and *Team*. An Agent helps us define what actions are taken when a message is received. Specifically, we use the {py:class}`~autogen_agentchat.agents.AssistantAgent` preset - an agent that can be given access to a model (e.g., LLM) and tools (functions) that it can then use to address tasks. A Team helps us define the rules for how agents interact with each other. In the {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat` team, agents respond in a sequential round-robin fashion.\n", - "In this case, we have a single agent, so the same agent is used for each round." - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What's Next?\n", "\n", - "Now that you have a basic understanding of how to define an agent and a team, consider following the [tutorial](./tutorial/models) for a walkthrough on other features of AgentChat.\n", - "\n" + "Now that you have a basic understanding of how to define an agent, consider following the [tutorial](./tutorial/models) for a walkthrough on other features of AgentChat." ] } ], @@ -157,7 +114,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb index f131894473c2..cd8427839650 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb @@ -295,33 +295,24 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "finish_reason='unknown' content='Paris\\n' usage=RequestUsage(prompt_tokens=8, completion_tokens=2) cached=False logprobs=None\n" + "finish_reason='stop' content='Paris\\n' usage=RequestUsage(prompt_tokens=7, completion_tokens=2) cached=False logprobs=None thought=None\n" ] } ], "source": [ - "import os\n", - "\n", - "from autogen_core.models import ModelFamily, UserMessage\n", + "from autogen_core.models import UserMessage\n", "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", "\n", "model_client = OpenAIChatCompletionClient(\n", - " model=\"gemini-1.5-flash\",\n", - " base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\",\n", - " api_key=os.environ[\"GEMINI_API_KEY\"],\n", - " model_info={\n", - " \"vision\": True,\n", - " \"function_calling\": True,\n", - " \"json_output\": True,\n", - " \"family\": ModelFamily.GEMINI_1_5_FLASH,\n", - " },\n", + " model=\"gemini-1.5-flash-8b\",\n", + " # api_key=\"GEMINI_API_KEY\",\n", ")\n", "\n", "response = await model_client.create([UserMessage(content=\"What is the capital of France?\", source=\"user\")])\n", diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/components/model-clients.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/components/model-clients.ipynb index c566a62c82f5..df6bb2d1ac77 100644 --- 
a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/components/model-clients.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/components/model-clients.ipynb @@ -303,33 +303,24 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "finish_reason='unknown' content='Paris\\n' usage=RequestUsage(prompt_tokens=8, completion_tokens=2) cached=False logprobs=None\n" + "finish_reason='stop' content='Paris\\n' usage=RequestUsage(prompt_tokens=7, completion_tokens=2) cached=False logprobs=None thought=None\n" ] } ], "source": [ - "import os\n", - "\n", - "from autogen_core.models import ModelFamily, UserMessage\n", + "from autogen_core.models import UserMessage\n", "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", "\n", "model_client = OpenAIChatCompletionClient(\n", - " model=\"gemini-1.5-flash\",\n", - " base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\",\n", - " api_key=os.environ[\"GEMINI_API_KEY\"],\n", - " model_info={\n", - " \"vision\": True,\n", - " \"function_calling\": True,\n", - " \"json_output\": True,\n", - " \"family\": ModelFamily.GEMINI_1_5_FLASH,\n", - " },\n", + " model=\"gemini-1.5-flash-8b\",\n", + " # api_key=\"GEMINI_API_KEY\",\n", ")\n", "\n", "response = await model_client.create([UserMessage(content=\"What is the capital of France?\", source=\"user\")])\n", diff --git a/python/packages/autogen-core/src/autogen_core/models/_model_client.py b/python/packages/autogen-core/src/autogen_core/models/_model_client.py index 4a046a3cf5f9..421cecf8cc06 100644 --- a/python/packages/autogen-core/src/autogen_core/models/_model_client.py +++ b/python/packages/autogen-core/src/autogen_core/models/_model_client.py @@ -55,6 +55,35 @@ class ModelFamily: def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily: raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.") + @staticmethod + def is_claude(family: str) -> bool: + return family in ( + ModelFamily.CLAUDE_3_HAIKU, + ModelFamily.CLAUDE_3_SONNET, + ModelFamily.CLAUDE_3_OPUS, + ModelFamily.CLAUDE_3_5_HAIKU, + ModelFamily.CLAUDE_3_5_SONNET, + ) + + @staticmethod + def is_gemini(family: str) -> bool: + return family in ( + ModelFamily.GEMINI_1_5_FLASH, + ModelFamily.GEMINI_1_5_PRO, + ModelFamily.GEMINI_2_0_FLASH, + ) + + @staticmethod + def is_openai(family: str) -> bool: + return family in ( + ModelFamily.GPT_4O, + ModelFamily.O1, + ModelFamily.O3, + ModelFamily.GPT_4, + ModelFamily.GPT_35, + ModelFamily.R1, + ) + @deprecated("Use the ModelInfo class instead ModelCapabilities.") class ModelCapabilities(TypedDict, total=False): diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py index 116ef539f448..c838df9cd554 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py @@ -134,6 +134,36 @@ "json_output": True, "family": ModelFamily.GPT_35, }, + "gemini-1.5-flash": { + "vision": True, + "function_calling": True, + "json_output": True, + "family": ModelFamily.GEMINI_1_5_FLASH, + }, + "gemini-1.5-flash-8b": { + "vision": True, + "function_calling": True, + "json_output": True, + "family": ModelFamily.GEMINI_1_5_FLASH, + }, + "gemini-1.5-pro": { + "vision": True, + "function_calling": True, + "json_output": True, + 
"family": ModelFamily.GEMINI_1_5_PRO, + }, + "gemini-2.0-flash": { + "vision": True, + "function_calling": True, + "json_output": True, + "family": ModelFamily.GEMINI_2_0_FLASH, + }, + "gemini-2.0-flash-lite-preview-02-05": { + "vision": True, + "function_calling": True, + "json_output": True, + "family": ModelFamily.GEMINI_2_0_FLASH, + }, } _MODEL_TOKEN_LIMITS: Dict[str, int] = { @@ -156,8 +186,14 @@ "gpt-3.5-turbo-instruct": 4096, "gpt-3.5-turbo-0613": 4096, "gpt-3.5-turbo-16k-0613": 16385, + "gemini-1.5-flash": 1048576, + "gemini-1.5-flash-8b": 1048576, + "gemini-1.5-pro": 2097152, + "gemini-2.0-flash": 1048576, } +GEMINI_OPENAI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/" + def resolve_model(model: str) -> str: if model in _MODEL_POINTERS: diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index 23b46cdb4891..8469f921ba73 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -3,6 +3,7 @@ import json import logging import math +import os import re import warnings from asyncio import Task @@ -1093,6 +1094,14 @@ def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]): if "add_name_prefixes" in kwargs: add_name_prefixes = kwargs["add_name_prefixes"] + # Special handling for Gemini model. + assert "model" in copied_args and isinstance(copied_args["model"], str) + if copied_args["model"].startswith("gemini-"): + if "base_url" not in copied_args: + copied_args["base_url"] = _model_info.GEMINI_OPENAI_BASE_URL + if "api_key" not in copied_args and "GEMINI_API_KEY" in os.environ: + copied_args["api_key"] = os.environ["GEMINI_API_KEY"] + client = _openai_client_from_config(copied_args) create_args = _create_args_from_config(copied_args) diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index 69d9281042c0..edb7598961fb 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -197,6 +197,18 @@ async def test_openai_chat_completion_client() -> None: assert client +@pytest.mark.asyncio +async def test_openai_chat_completion_client_with_gemini_model() -> None: + client = OpenAIChatCompletionClient(model="gemini-1.5-flash", api_key="api_key") + assert client + + +@pytest.mark.asyncio +async def test_openai_chat_completion_client_raise_on_unknown_model() -> None: + with pytest.raises(ValueError, match="model_info is required"): + _ = OpenAIChatCompletionClient(model="unknown", api_key="api_key") + + @pytest.mark.asyncio async def test_custom_model_with_capabilities() -> None: with pytest.raises(ValueError, match="model_info is required"): @@ -952,14 +964,6 @@ async def test_gemini() -> None: model_client = OpenAIChatCompletionClient( model="gemini-1.5-flash", - api_key=api_key, - base_url="https://generativelanguage.googleapis.com/v1beta/openai/", - model_info={ - "function_calling": True, - "json_output": True, - "vision": True, - "family": ModelFamily.GEMINI_1_5_FLASH, - }, ) await _test_model_client_basic_completion(model_client) await _test_model_client_with_function_calling(model_client) From 340a8e85872697f1aa524ef900c651eaf427d547 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sun, 9 Feb 2025 14:46:07 -0800 Subject: [PATCH 
23/36] Add notes on api key and modifying specifications in AGS (#5466) Add clear notes on how to specify an API key and how to modify specifications in AGS. Add diagrams explaining how to switch between the visual builder and JSON mode. ## Why are these changes needed? ![teambuilder](https://github.com/user-attachments/assets/9eede334-7f60-4c87-bec6-cf41839ba231) ## Related issue number @nour-bouzid ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. --- .../autogenstudio-user-guide/jsoneditor.jpg | 3 + .../autogenstudio-user-guide/teambuilder.jpg | 3 + .../autogenstudio-user-guide/usage.md | 67 ++++++++++++------- .../autogenstudio/gallery/builder.py | 22 ++++-- .../views/gallery/default_gallery.json | 62 +++++++++++------ .../src/components/views/gallery/store.tsx | 2 +- .../components/views/team/builder/library.tsx | 2 +- 7 files changed, 112 insertions(+), 49 deletions(-) create mode 100644 python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/jsoneditor.jpg create mode 100644 python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/teambuilder.jpg diff --git a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/jsoneditor.jpg b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/jsoneditor.jpg new file mode 100644 index 000000000000..ec22f9dbb778 --- /dev/null +++ b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/jsoneditor.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afa594c040c6f8e342bb4ebd4a140e4e8cdd7291f01025f467c44d75acc52dad +size 758268 diff --git a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/teambuilder.jpg b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/teambuilder.jpg new file mode 100644 index 000000000000..bf419d191be8 --- /dev/null +++ b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/teambuilder.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d3c49bf0e8931375ab20cb1bf8eb5d8682c1a12b9609b8c0b993d1ae76f70ec +size 734326 diff --git a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/usage.md b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/usage.md index dd2cc6d1e727..9e19f804960d 100644 --- a/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/usage.md +++ b/python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/usage.md @@ -7,16 +7,52 @@ myst: # Usage -AutoGen Studio provides a Team Builder interface where developers can define multiple components and behaviors. Users can create teams, add agents to teams, attach tools and models to agents, and define team termination conditions. -After defining a team, users can test it in the Playground view to accomplish various tasks through direct interaction. +AutoGen Studio (AGS) provides a Team Builder interface where developers can define multiple components and behaviors. Users can create teams, add agents to teams, attach tools and models to agents, and define team termination conditions.
+After defining a team, users can test it directly in the team builder view or attach it to a session for use in the Playground view. > See a video tutorial on AutoGen Studio v0.4 (02/25) - [https://youtu.be/oum6EI7wohM](https://youtu.be/oum6EI7wohM) [![A Friendly Introduction to AutoGen Studio v0.4](https://img.youtube.com/vi/oum6EI7wohM/maxresdefault.jpg)](https://www.youtube.com/watch?v=oum6EI7wohM) +## Setting Up an API Key + +Most of your agents will require an API key. You can set up an environment variable `OPENAI_API_KEY` (assuming you are using OpenAI models) and AutoGen will automatically use this for any OpenAI model clients you specify for your agents or teams. Alternatively, you can specify the API key as part of the team or agent configuration. + +See the section below on how to build an agent team, either using the visual builder or by directly editing the JSON configuration. + +## Building an Agent Team + +
+ +AutoGen Studio integrates closely with all component abstractions provided by AutoGen AgentChat, including {py:class}`~autogen_agentchat.teams`, {py:class}`~autogen_agentchat.agents`, {py:class}`~autogen_core.models`, {py:class}`~autogen_core.tools`, and termination {py:class}`~autogen_agentchat.conditions`. + +The Team Builder view in AGS provides a visual team builder that allows users to define components through either drag-and-drop functionality or by editing a JSON configuration of the team directly. + +### Using the Visual Builder + +The visual builder is enabled by default and allows users to drag and drop components from the provided Component Library to the Team Builder canvas. The team builder canvas represents a team and consists of a main team node and a set of connected agent nodes. It includes a Component Library that has a selection of components that can be added to the team or agent nodes in the canvas. + +![Team Builder](teambuilder.jpg) + +The core supported behaviours include: + +- Create a new team. This can be done by clicking on the "New Team" button in the Team Builder view or by selecting any of the existing default teams that ship with the default AGS Gallery. Once you do this, a new team node and agent node(s) will be created in the canvas. +- Drag and drop components from the library to the team or agent nodes in the canvas. + - Teams: drag in agents and termination conditions to the team node (there are specific drop zones for these components) + - Agents: drag in models and tools to the agent node (there are specific drop zones for these components) +- Editing Team/Agent Nodes: Click on the edit icon (top right) of the node to view and edit its properties. This pops up a panel that allows you to edit the fields of the node. In some cases, you will need to scroll down and click into specific sections, e.g., for an agent with a model client, you will need to click into the model client section to edit the model client properties. Once done with editing, click on the save button to save the changes. + +### Using the JSON Editor + +![JSON Editor](jsoneditor.jpg) + +AGS also lets you directly modify the JSON configuration of the team. This can be done by toggling the visual builder mode off. Once you do this, you will see the JSON configuration of the team, which you can then edit directly. + +> Did you know that you can define your agents in Python, export them to JSON, and then paste them into the JSON editor? The section below shows how to accomplish this. + ## Declarative Specification of Components -AutoGen Studio is built on the declarative specification behaviors of AutoGen AgentChat. This allows users to define teams, agents, models, tools, and termination conditions in python and then dump them into a JSON file for use in AutoGen Studio. +AutoGen Studio is built on the declarative specification behaviors of AutoGen AgentChat. This allows users to define teams, agents, models, tools, and termination conditions in Python and then dump them into a JSON file for use in AutoGen Studio. Here's an example of an agent team and how it is converted to a JSON file: @@ -98,30 +134,15 @@ print(config.model_dump_json()) } ``` -This example shows a team with a single agent, using the `RoundRobinGroupChat` type and a `TextMentionTermination` condition. - -## Building an Agent Team - -
- -AutoGen Studio integrates closely with all component abstractions provided by AutoGen AgentChat, including {py:class}`~autogen_agentchat.teams`, {py:class}`~autogen_agentchat.agents`, {py:class}`~autogen_core.models`, {py:class}`~autogen_core.tools`, and termination {py:class}`~autogen_agentchat.conditions`. - -The Team Builder interface allows users to define components through either declarative specification or drag-and-drop functionality: - -Team Builder Operations: +This example shows a team with a single agent, using the `RoundRobinGroupChat` type and a `TextMentionTermination` condition. You will also notice that the model client is an `OpenAIChatCompletionClient` where only the model name is specified. In this case, the API key is assumed to be set as an environment variable `OPENAI_API_KEY`. You can also specify the API key as part of the model client configuration. -- Create a new team - - Edit Team JSON directly (toggle visual builder mode off) or - - Use the visual builder, drag-and-drop components from the library: - - Teams: Add agents and termination conditions - - Agents: Add models and tools -- Save team configurations +To understand the full configuration of a model client, you can refer to the [AutoGen Model Clients documentation](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/components/model-clients.html). -Note: For each node in the visual builder, you can click on the edit icon (top right) to view and edit the JSON configuration. +Note that you can similarly define your model client in Python and call `dump_component()` on it to get the JSON configuration and use it to update the model client section of your team or agent configuration. ## Gallery - Sharing and Reusing Components -A Gallery is a collection of components - teams, agents, models, tools, and terminations - that can be shared and reused across projects. +AGS provides a Gallery view, where a gallery is a collection of components - teams, agents, models, tools, and terminations - that can be shared and reused across projects. Users can create a local gallery or import a gallery (from a URL, a JSON file import or simply by copying and pasting the JSON). At any given time, users can select any of the current Gallery items as a **default gallery**. This **default gallery** will be used to populate the Team Builder sidebar with components. @@ -155,7 +176,7 @@ Team configurations can be integrated into Python applications using the `TeamMa from autogenstudio.teammanager import TeamManager tm = TeamManager() -result_stream = tm.run(task="What is the weather in New York?", team_config="team.json") # or wm.run_stream(..) +result_stream = tm.run(task="What is the weather in New York?", team_config="team.json") # or tm.run_stream(..) ``` To export team configurations, use the export button in Team Builder to generate a JSON file for Python application use.
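For reference, the export flow described above can be sketched in a few lines of Python (a minimal sketch, assuming the `OpenAIChatCompletionClient` and `dump_component()` APIs shown earlier in this patch):

```python
# Minimal sketch: serialize a model client to declarative JSON for AutoGen Studio.
from autogen_ext.models.openai import OpenAIChatCompletionClient

# With no api_key argument, the client falls back to the OPENAI_API_KEY environment variable.
model_client = OpenAIChatCompletionClient(model="gpt-4o-mini")

# dump_component() returns a declarative component model; its JSON can be pasted
# into the model client section of a team or agent in the AGS JSON editor.
config = model_client.dump_component()
print(config.model_dump_json())
```

The same pattern applies to teams, agents, tools, and termination conditions, since each can be dumped to and loaded from the same declarative JSON format.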
diff --git a/python/packages/autogen-studio/autogenstudio/gallery/builder.py b/python/packages/autogen-studio/autogenstudio/gallery/builder.py index a3d4b4523bba..6582481ae9bf 100644 --- a/python/packages/autogen-studio/autogenstudio/gallery/builder.py +++ b/python/packages/autogen-studio/autogenstudio/gallery/builder.py @@ -8,6 +8,7 @@ from autogen_core.models import ModelInfo from autogen_ext.agents.web_surfer import MultimodalWebSurfer from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_ext.models.openai._openai_client import AzureOpenAIChatCompletionClient from autogenstudio.datamodel import Gallery, GalleryComponents, GalleryItems, GalleryMetadata @@ -142,8 +143,7 @@ def create_default_gallery() -> Gallery: # Create base model client base_model = OpenAIChatCompletionClient(model="gpt-4o-mini") - builder.add_model(base_model.dump_component()) - + builder.add_model(base_model.dump_component(), label="OpenAI GPT-4o Mini", description="OpenAI GPT-4o-mini") # Create Mistral vllm model mistral_vllm_model = OpenAIChatCompletionClient( model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF", @@ -152,8 +152,22 @@ ) builder.add_model( mistral_vllm_model.dump_component(), - label="Mistral-7B vllm", - description="Example on how to use the OpenAIChatCopletionClient with local models (Ollama, vllm etc).", + label="Mistral-7B Local", + description="Local Mistral-7B model client for instruction-based generation (Ollama, LMStudio).", + ) + + # Create an Azure OpenAI model client + az_model_client = AzureOpenAIChatCompletionClient( + azure_deployment="{your-azure-deployment}", + model="gpt-4o-mini", + api_version="2024-06-01", + azure_endpoint="https://{your-custom-endpoint}.openai.azure.com/", + api_key="sk-...", # For key-based authentication.
+ ) + builder.add_model( + az_model_client.dump_component(), + label="AzureOpenAI GPT-4o-mini", + description="GPT-4o Mini Azure OpenAI model client.", ) builder.add_tool( diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json b/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json index 305cf1eecc0f..e95142b52d05 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/default_gallery.json @@ -4,8 +4,8 @@ "url": null, "metadata": { "author": "AutoGen Team", - "created_at": "2025-02-06T11:52:58.243877", - "updated_at": "2025-02-06T11:52:58.497032", + "created_at": "2025-02-09T09:43:30.164372", + "updated_at": "2025-02-09T09:43:30.486369", "version": "1.0.0", "description": "A default gallery containing basic components for human-in-loop conversations", "tags": ["human-in-loop", "assistant", "web agents"], @@ -107,7 +107,8 @@ "description": "Terminate the conversation after a maximum number of messages have been exchanged.", "label": "MaxMessageTermination", "config": { - "max_messages": 10 + "max_messages": 10, + "include_agent_event": false } } ] @@ -120,7 +121,7 @@ "component_type": "team", "version": 1, "component_version": 1, - "description": "A group chat team that have participants takes turn to publish a message\n to all, using a ChatCompletion model to select the next speaker after each message.", + "description": "A team with 3 agents - a Web Surfer agent that can browse the web, a Verification Assistant that verifies and summarizes information, and a User Proxy that provides human feedback when needed.", "label": "Web Agent Team (Operator)", "config": { "participants": [ @@ -231,7 +232,8 @@ "description": "Terminate the conversation after a maximum number of messages have been exchanged.", "label": "MaxMessageTermination", "config": { - "max_messages": 20 + "max_messages": 20, + "include_agent_event": false } }, { @@ -249,7 +251,8 @@ } }, "selector_prompt": "You are the cordinator of role play game. The following roles are available:\n{roles}. Given a task, the websurfer_agent will be tasked to address it by browsing the web and providing information. The assistant_agent will be tasked with verifying the information provided by the websurfer_agent and summarizing the information to present a final answer to the user. If the task needs assistance from a human user (e.g., providing feedback, preferences, or the task is stalled), you should select the user_proxy role to provide the necessary information.\n\nRead the following conversation. Then select the next role from {participants} to play. Only return the role.\n\n{history}\n\nRead the above conversation. Then select the next role from {participants} to play. 
Only return the role.", - "allow_repeated_speaker": false + "allow_repeated_speaker": false, + "max_selector_attempts": 3 } }, { @@ -257,7 +260,7 @@ "component_type": "team", "version": 1, "component_version": 1, - "description": "A team that performs deep research using web searches, verification, and summarization.", + "description": "A team with 3 agents - a Research Assistant that performs web searches and analyzes information, a Verifier that ensures research quality and completeness, and a Summary Agent that provides a detailed markdown summary of the research as a report to the user.", "label": "Deep Research Team", "config": { "participants": [ @@ -291,7 +294,7 @@ "description": "Create custom tools by wrapping standard Python functions.", "label": "FunctionTool", "config": { - "source_code": "async def google_search(\n query: str,\n num_results: int = 5,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 15000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True,\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API and optionally fetch webpage content.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., en, es, fr)\n country: Optional country code for search results (e.g., us, uk)\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n - content: Webpage content in markdown (if include_content=True)\n \"\"\"\n api_key = os.getenv(\"GOOGLE_API_KEY\")\n cse_id = os.getenv(\"GOOGLE_CSE_ID\")\n\n if not api_key or not cse_id:\n raise ValueError(\"Missing required environment variables. 
Please set GOOGLE_API_KEY and GOOGLE_CSE_ID.\")\n\n num_results = min(max(1, num_results), 10)\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n params = {\n \"key\": api_key,\n \"cx\": cse_id,\n \"q\": query,\n \"num\": num_results,\n \"hl\": language,\n \"safe\": \"active\" if safe_search else \"off\",\n }\n\n if country:\n params[\"gl\"] = country\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\"https://www.googleapis.com/customsearch/v1\", params=params, timeout=10)\n response.raise_for_status()\n data = response.json()\n\n results = []\n if \"items\" in data:\n for item in data[\"items\"]:\n result = {\"title\": item.get(\"title\", \"\"), \"link\": item.get(\"link\", \"\")}\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n results.append(result)\n\n return results\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\") from e\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\") from e\n", + "source_code": "async def google_search(\n query: str,\n num_results: int = 3,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 10000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True,\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API and optionally fetch webpage content.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., en, es, fr)\n country: Optional country code for search results (e.g., us, uk)\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n - content: Webpage content in markdown (if include_content=True)\n \"\"\"\n api_key = 
os.getenv(\"GOOGLE_API_KEY\")\n cse_id = os.getenv(\"GOOGLE_CSE_ID\")\n\n if not api_key or not cse_id:\n raise ValueError(\"Missing required environment variables. Please set GOOGLE_API_KEY and GOOGLE_CSE_ID.\")\n\n num_results = min(max(1, num_results), 10)\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n params = {\n \"key\": api_key,\n \"cx\": cse_id,\n \"q\": query,\n \"num\": num_results,\n \"hl\": language,\n \"safe\": \"active\" if safe_search else \"off\",\n }\n\n if country:\n params[\"gl\"] = country\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\"https://www.googleapis.com/customsearch/v1\", params=params, timeout=10)\n response.raise_for_status()\n data = response.json()\n\n results = []\n if \"items\" in data:\n for item in data[\"items\"]:\n result = {\"title\": item.get(\"title\", \"\"), \"link\": item.get(\"link\", \"\")}\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n results.append(result)\n\n return results\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\") from e\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\") from e\n", "name": "google_search", "description": "\n Perform Google searches using the Custom Search API with optional webpage content fetching.\n Requires GOOGLE_API_KEY and GOOGLE_CSE_ID environment variables to be set.\n ", "global_imports": [ @@ -400,7 +403,7 @@ "config": {} }, "description": "A verification specialist who ensures research quality and completeness", - "system_message": "You are a research verification specialist.\n Your role is to:\n 1. Verify that search queries are effective and suggest improvements if needed\n 2. Explore drill downs where needed e.g, if the answer is likely in a link in the returned search results, suggest clicking on the link\n 3. Suggest additional angles or perspectives to explore. Be judicious in suggesting new paths to avoid scope creep or wasting resources, if the task appears to be addressed and we can provide a report, do this and respond with \"TERMINATE\".\n 4. 
Track progress toward answering the original question\n 5. When the research is complete, provide a detailed summary in markdown format\n \n For incomplete research, end your message with \"CONTINUE RESEARCH\". \n For complete research, end your message with APPROVED.\n \n Your responses should be structured as:\n - Progress Assessment\n - Gaps/Issues (if any)\n - Suggestions (if needed)\n - Next Steps or Final Summary", + "system_message": "You are a research verification specialist.\n Your role is to:\n 1. Verify that search queries are effective and suggest improvements if needed\n 2. Explore drill downs where needed e.g, if the answer is likely in a link in the returned search results, suggest clicking on the link\n 3. Suggest additional angles or perspectives to explore. Be judicious in suggesting new paths to avoid scope creep or wasting resources, if the task appears to be addressed and we can provide a report, do this and respond with \"TERMINATE\".\n 4. Track progress toward answering the original question\n 5. When the research is complete, provide a detailed summary in markdown format. For incomplete research, end your message with \"CONTINUE RESEARCH\". For complete research, end your message with APPROVED.\n Your responses should be structured as:\n - Progress Assessment\n - Gaps/Issues (if any)\n - Suggestions (if needed)\n - Next Steps or Final Summary", "model_client_stream": false, "reflect_on_tool_use": false, "tool_call_summary_format": "{result}" @@ -485,14 +488,16 @@ "description": "Terminate the conversation after a maximum number of messages have been exchanged.", "label": "MaxMessageTermination", "config": { - "max_messages": 30 + "max_messages": 30, + "include_agent_event": false } } ] } }, - "selector_prompt": "You are coordinating a research team by selecting the team member to speak/act next. The following team member roles are available:\n {roles}.\n The research_assistant performs searches and analyzes information.\n The verifier evaluates progress and ensures completeness.\n The summary_agent provides a detailed markdown summary of the research as a report to the user.\n\n Given the current context, select the most appropriate next speaker.\n The research_assistant should search and analyze.\n The verifier should evaluate progress and guide the research (select this role is there is a need to verify/evaluate progress). You should ONLY select the summary_agent role if the research is complete and it is time to generate a report.\n\n Base your selection on:\n 1. Current stage of research\n 2. Last speaker's findings or suggestions\n 3. Need for verification vs need for new information\n \n Read the following conversation. Then select the next role from {participants} to play. Only return the role.\n\n {history}\n\n Read the above conversation. Then select the next role from {participants} to play. ONLY RETURN THE ROLE.", - "allow_repeated_speaker": true + "selector_prompt": "You are coordinating a research team by selecting the team member to speak/act next. 
The following team member roles are available:\n {roles}.\n The research_assistant performs searches and analyzes information.\n The verifier evaluates progress and ensures completeness.\n The summary_agent provides a detailed markdown summary of the research as a report to the user.\n\n Given the current context, select the most appropriate next speaker.\n The research_assistant should search and analyze.\n The verifier should evaluate progress and guide the research (select this role is there is a need to verify/evaluate progress). You should ONLY select the summary_agent role if the research is complete and it is time to generate a report.\n\n Base your selection on:\n 1. Current stage of research\n 2. Last speaker's findings or suggestions\n 3. Need for verification vs need for new information\n Read the following conversation. Then select the next role from {participants} to play. Only return the role.\n\n {history}\n\n Read the above conversation. Then select the next role from {participants} to play. ONLY RETURN THE ROLE.", + "allow_repeated_speaker": true, + "max_selector_attempts": 3 } } ], @@ -638,8 +643,8 @@ "component_type": "model", "version": 1, "component_version": 1, - "description": "Chat completion client for OpenAI hosted models.", - "label": "OpenAIChatCompletionClient", + "description": "OpenAI GPT-4o-mini", + "label": "OpenAI GPT-4o Mini", "config": { "model": "gpt-4o-mini" } @@ -649,8 +654,8 @@ "component_type": "model", "version": 1, "component_version": 1, - "description": "Example on how to use the OpenAIChatCopletionClient with local models (Ollama, vllm etc).", - "label": "Mistral-7B vllm", + "description": "Local Mistral-7B model client for instruction-based generation (Ollama, LMStudio).", + "label": "Mistral-7B Local", "config": { "model": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF", "model_info": { @@ -661,6 +666,21 @@ }, "base_url": "http://localhost:1234/v1" } + }, + { + "provider": "autogen_ext.models.openai.AzureOpenAIChatCompletionClient", + "component_type": "model", + "version": 1, + "component_version": 1, + "description": "GPT-4o Mini Azure OpenAI model client.", + "label": "AzureOpenAI GPT-4o-mini", + "config": { + "model": "gpt-4o-mini", + "api_key": "sk-...", + "azure_endpoint": "https://{your-custom-endpoint}.openai.azure.com/", + "azure_deployment": "{your-azure-deployment}", + "api_version": "2024-06-01" + } } ], "tools": [ @@ -796,7 +816,7 @@ "description": "A tool that performs Bing searches using the Bing Web Search API. 
Requires the requests library, BING_SEARCH_KEY env variable to function.", "label": "Bing Search Tool", "config": { - "source_code": "async def bing_search(\n query: str,\n num_results: int = 5,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 15000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: str = \"moderate\",\n response_filter: str = \"webpages\",\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Bing search using the Bing Web Search API.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 50)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., 'en', 'es', 'fr')\n country: Optional market code for search results (e.g., 'us', 'uk')\n safe_search: SafeSearch setting ('off', 'moderate', or 'strict')\n response_filter: Type of results ('webpages', 'news', 'images', or 'videos')\n\n Returns:\n List[Dict[str, str]]: List of search results\n\n Raises:\n ValueError: If API credentials are invalid or request fails\n \"\"\"\n # Get and validate API key\n api_key = os.getenv(\"BING_SEARCH_KEY\", \"\").strip()\n\n if not api_key:\n raise ValueError(\n \"BING_SEARCH_KEY environment variable is not set. \" \"Please obtain an API key from Azure Portal.\"\n )\n\n # Validate safe_search parameter\n valid_safe_search = [\"off\", \"moderate\", \"strict\"]\n if safe_search.lower() not in valid_safe_search:\n raise ValueError(f\"Invalid safe_search value. Must be one of: {', '.join(valid_safe_search)}\")\n\n # Validate response_filter parameter\n valid_filters = [\"webpages\", \"news\", \"images\", \"videos\"]\n if response_filter.lower() not in valid_filters:\n raise ValueError(f\"Invalid response_filter value. 
Must be one of: {', '.join(valid_filters)}\")\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n # Build request headers and parameters\n headers = {\"Ocp-Apim-Subscription-Key\": api_key, \"Accept\": \"application/json\"}\n\n params = {\n \"q\": query,\n \"count\": min(max(1, num_results), 50),\n \"mkt\": f\"{language}-{country.upper()}\" if country else language,\n \"safeSearch\": safe_search.capitalize(),\n \"responseFilter\": response_filter,\n \"textFormat\": \"raw\",\n }\n\n # Make the request\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\n \"https://api.bing.microsoft.com/v7.0/search\", headers=headers, params=params, timeout=10\n )\n\n # Handle common error cases\n if response.status_code == 401:\n raise ValueError(\"Authentication failed. Please verify your Bing Search API key.\")\n elif response.status_code == 403:\n raise ValueError(\n \"Access forbidden. This could mean:\\n\"\n \"1. The API key is invalid\\n\"\n \"2. The API key has expired\\n\"\n \"3. You've exceeded your API quota\"\n )\n elif response.status_code == 429:\n raise ValueError(\"API quota exceeded. 
Please try again later.\")\n\n response.raise_for_status()\n data = response.json()\n\n # Process results based on response_filter\n results = []\n if response_filter == \"webpages\" and \"webPages\" in data:\n items = data[\"webPages\"][\"value\"]\n elif response_filter == \"news\" and \"news\" in data:\n items = data[\"news\"][\"value\"]\n elif response_filter == \"images\" and \"images\" in data:\n items = data[\"images\"][\"value\"]\n elif response_filter == \"videos\" and \"videos\" in data:\n items = data[\"videos\"][\"value\"]\n else:\n if not any(key in data for key in [\"webPages\", \"news\", \"images\", \"videos\"]):\n return [] # No results found\n raise ValueError(f\"No {response_filter} results found in API response\")\n\n # Extract relevant information based on result type\n for item in items:\n result = {\"title\": item.get(\"name\", \"\")}\n\n if response_filter == \"webpages\":\n result[\"link\"] = item.get(\"url\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n elif response_filter == \"news\":\n result[\"link\"] = item.get(\"url\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n result[\"date\"] = item.get(\"datePublished\", \"\")\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n elif response_filter == \"images\":\n result[\"link\"] = item.get(\"contentUrl\", \"\")\n result[\"thumbnail\"] = item.get(\"thumbnailUrl\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n\n elif response_filter == \"videos\":\n result[\"link\"] = item.get(\"contentUrl\", \"\")\n result[\"thumbnail\"] = item.get(\"thumbnailUrl\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n result[\"duration\"] = item.get(\"duration\", \"\")\n\n results.append(result)\n\n return results[:num_results]\n\n except httpx.RequestException as e:\n error_msg = str(e)\n if \"InvalidApiKey\" in error_msg:\n raise ValueError(\"Invalid API key. Please check your BING_SEARCH_KEY environment variable.\") from e\n elif \"KeyExpired\" in error_msg:\n raise ValueError(\"API key has expired. Please generate a new key.\") from e\n else:\n raise ValueError(f\"Search request failed: {error_msg}\") from e\n except json.JSONDecodeError:\n raise ValueError(\"Failed to parse API response. 
\" \"Please verify your API credentials and try again.\") from None\n except Exception as e:\n raise ValueError(f\"Unexpected error during search: {str(e)}\") from e\n", + "source_code": "async def bing_search(\n query: str,\n num_results: int = 3,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 10000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: str = \"moderate\",\n response_filter: str = \"webpages\",\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Bing search using the Bing Web Search API.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 50)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., 'en', 'es', 'fr')\n country: Optional market code for search results (e.g., 'us', 'uk')\n safe_search: SafeSearch setting ('off', 'moderate', or 'strict')\n response_filter: Type of results ('webpages', 'news', 'images', or 'videos')\n\n Returns:\n List[Dict[str, str]]: List of search results\n\n Raises:\n ValueError: If API credentials are invalid or request fails\n \"\"\"\n # Get and validate API key\n api_key = os.getenv(\"BING_SEARCH_KEY\", \"\").strip()\n\n if not api_key:\n raise ValueError(\n \"BING_SEARCH_KEY environment variable is not set. \" \"Please obtain an API key from Azure Portal.\"\n )\n\n # Validate safe_search parameter\n valid_safe_search = [\"off\", \"moderate\", \"strict\"]\n if safe_search.lower() not in valid_safe_search:\n raise ValueError(f\"Invalid safe_search value. Must be one of: {', '.join(valid_safe_search)}\")\n\n # Validate response_filter parameter\n valid_filters = [\"webpages\", \"news\", \"images\", \"videos\"]\n if response_filter.lower() not in valid_filters:\n raise ValueError(f\"Invalid response_filter value. 
Must be one of: {', '.join(valid_filters)}\")\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n # Build request headers and parameters\n headers = {\"Ocp-Apim-Subscription-Key\": api_key, \"Accept\": \"application/json\"}\n\n params = {\n \"q\": query,\n \"count\": min(max(1, num_results), 50),\n \"mkt\": f\"{language}-{country.upper()}\" if country else language,\n \"safeSearch\": safe_search.capitalize(),\n \"responseFilter\": response_filter,\n \"textFormat\": \"raw\",\n }\n\n # Make the request\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\n \"https://api.bing.microsoft.com/v7.0/search\", headers=headers, params=params, timeout=10\n )\n\n # Handle common error cases\n if response.status_code == 401:\n raise ValueError(\"Authentication failed. Please verify your Bing Search API key.\")\n elif response.status_code == 403:\n raise ValueError(\n \"Access forbidden. This could mean:\\n\"\n \"1. The API key is invalid\\n\"\n \"2. The API key has expired\\n\"\n \"3. You've exceeded your API quota\"\n )\n elif response.status_code == 429:\n raise ValueError(\"API quota exceeded. 
Please try again later.\")\n\n response.raise_for_status()\n data = response.json()\n\n # Process results based on response_filter\n results = []\n if response_filter == \"webpages\" and \"webPages\" in data:\n items = data[\"webPages\"][\"value\"]\n elif response_filter == \"news\" and \"news\" in data:\n items = data[\"news\"][\"value\"]\n elif response_filter == \"images\" and \"images\" in data:\n items = data[\"images\"][\"value\"]\n elif response_filter == \"videos\" and \"videos\" in data:\n items = data[\"videos\"][\"value\"]\n else:\n if not any(key in data for key in [\"webPages\", \"news\", \"images\", \"videos\"]):\n return [] # No results found\n raise ValueError(f\"No {response_filter} results found in API response\")\n\n # Extract relevant information based on result type\n for item in items:\n result = {\"title\": item.get(\"name\", \"\")}\n\n if response_filter == \"webpages\":\n result[\"link\"] = item.get(\"url\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n elif response_filter == \"news\":\n result[\"link\"] = item.get(\"url\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n result[\"date\"] = item.get(\"datePublished\", \"\")\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n elif response_filter == \"images\":\n result[\"link\"] = item.get(\"contentUrl\", \"\")\n result[\"thumbnail\"] = item.get(\"thumbnailUrl\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n\n elif response_filter == \"videos\":\n result[\"link\"] = item.get(\"contentUrl\", \"\")\n result[\"thumbnail\"] = item.get(\"thumbnailUrl\", \"\")\n if include_snippets:\n result[\"snippet\"] = item.get(\"description\", \"\")\n result[\"duration\"] = item.get(\"duration\", \"\")\n\n results.append(result)\n\n return results[:num_results]\n\n except httpx.RequestException as e:\n error_msg = str(e)\n if \"InvalidApiKey\" in error_msg:\n raise ValueError(\"Invalid API key. Please check your BING_SEARCH_KEY environment variable.\") from e\n elif \"KeyExpired\" in error_msg:\n raise ValueError(\"API key has expired. Please generate a new key.\") from e\n else:\n raise ValueError(f\"Search request failed: {error_msg}\") from e\n except json.JSONDecodeError:\n raise ValueError(\"Failed to parse API response. \" \"Please verify your API credentials and try again.\") from None\n except Exception as e:\n raise ValueError(f\"Unexpected error during search: {str(e)}\") from e\n", "name": "bing_search", "description": "\n Perform Bing searches using the Bing Web Search API. Requires BING_SEARCH_KEY environment variable.\n Supports web, news, image, and video searches.\n See function documentation for detailed setup instructions.\n ", "global_imports": [ @@ -828,7 +848,7 @@ "description": "A tool that performs Google searches using the Google Custom Search API. 
Requires the requests library, [GOOGLE_API_KEY, GOOGLE_CSE_ID] to be set, env variable to function.", "label": "Google Search Tool", "config": { - "source_code": "async def google_search(\n query: str,\n num_results: int = 3,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 15000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True,\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API and optionally fetch webpage content.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., en, es, fr)\n country: Optional country code for search results (e.g., us, uk)\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n - content: Webpage content in markdown (if include_content=True)\n \"\"\"\n api_key = os.getenv(\"GOOGLE_API_KEY\")\n cse_id = os.getenv(\"GOOGLE_CSE_ID\")\n\n if not api_key or not cse_id:\n raise ValueError(\"Missing required environment variables. Please set GOOGLE_API_KEY and GOOGLE_CSE_ID.\")\n\n num_results = min(max(1, num_results), 10)\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n params = {\n \"key\": api_key,\n \"cx\": cse_id,\n \"q\": query,\n \"num\": num_results,\n \"hl\": language,\n \"safe\": \"active\" if safe_search else \"off\",\n }\n\n if country:\n params[\"gl\"] = country\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\"https://www.googleapis.com/customsearch/v1\", params=params, timeout=10)\n response.raise_for_status()\n data = response.json()\n\n results = []\n if \"items\" in data:\n for item in data[\"items\"]:\n result = {\"title\": item.get(\"title\", \"\"), \"link\": item.get(\"link\", \"\")}\n if include_snippets:\n result[\"snippet\"] = item.get(\"snippet\", \"\")\n\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n 
results.append(result)\n\n return results\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\") from e\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\") from e\n", + "source_code": "async def google_search(\n query: str,\n num_results: int = 3,\n include_snippets: bool = True,\n include_content: bool = True,\n content_max_length: Optional[int] = 10000,\n language: str = \"en\",\n country: Optional[str] = None,\n safe_search: bool = True,\n) -> List[Dict[str, str]]:\n \"\"\"\n Perform a Google search using the Custom Search API and optionally fetch webpage content.\n\n Args:\n query: Search query string\n num_results: Number of results to return (max 10)\n include_snippets: Include result snippets in output\n include_content: Include full webpage content in markdown format\n content_max_length: Maximum length of webpage content (if included)\n language: Language code for search results (e.g., en, es, fr)\n country: Optional country code for search results (e.g., us, uk)\n safe_search: Enable safe search filtering\n\n Returns:\n List[Dict[str, str]]: List of search results, each containing:\n - title: Result title\n - link: Result URL\n - snippet: Result description (if include_snippets=True)\n - content: Webpage content in markdown (if include_content=True)\n \"\"\"\n api_key = os.getenv(\"GOOGLE_API_KEY\")\n cse_id = os.getenv(\"GOOGLE_CSE_ID\")\n\n if not api_key or not cse_id:\n raise ValueError(\"Missing required environment variables. Please set GOOGLE_API_KEY and GOOGLE_CSE_ID.\")\n\n num_results = min(max(1, num_results), 10)\n\n async def fetch_page_content(url: str, max_length: Optional[int] = 50000) -> str:\n \"\"\"Helper function to fetch and convert webpage content to markdown\"\"\"\n headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(url, headers=headers, timeout=10)\n response.raise_for_status()\n\n soup = BeautifulSoup(response.text, \"html.parser\")\n\n # Remove script and style elements\n for script in soup([\"script\", \"style\"]):\n script.decompose()\n\n # Convert relative URLs to absolute\n for tag in soup.find_all([\"a\", \"img\"]):\n if tag.get(\"href\"):\n tag[\"href\"] = urljoin(url, tag[\"href\"])\n if tag.get(\"src\"):\n tag[\"src\"] = urljoin(url, tag[\"src\"])\n\n h2t = html2text.HTML2Text()\n h2t.body_width = 0\n h2t.ignore_images = False\n h2t.ignore_emphasis = False\n h2t.ignore_links = False\n h2t.ignore_tables = False\n\n markdown = h2t.handle(str(soup))\n\n if max_length and len(markdown) > max_length:\n markdown = markdown[:max_length] + \"\\n...(truncated)\"\n\n return markdown.strip()\n\n except Exception as e:\n return f\"Error fetching content: {str(e)}\"\n\n params = {\n \"key\": api_key,\n \"cx\": cse_id,\n \"q\": query,\n \"num\": num_results,\n \"hl\": language,\n \"safe\": \"active\" if safe_search else \"off\",\n }\n\n if country:\n params[\"gl\"] = country\n\n try:\n async with httpx.AsyncClient() as client:\n response = await client.get(\"https://www.googleapis.com/customsearch/v1\", params=params, timeout=10)\n response.raise_for_status()\n data = response.json()\n\n results = []\n if \"items\" in data:\n for item in data[\"items\"]:\n result = {\"title\": item.get(\"title\", \"\"), \"link\": item.get(\"link\", \"\")}\n if include_snippets:\n 
result[\"snippet\"] = item.get(\"snippet\", \"\")\n\n if include_content:\n result[\"content\"] = await fetch_page_content(result[\"link\"], max_length=content_max_length)\n\n results.append(result)\n\n return results\n\n except httpx.RequestError as e:\n raise ValueError(f\"Failed to perform search: {str(e)}\") from e\n except KeyError as e:\n raise ValueError(f\"Invalid API response format: {str(e)}\") from e\n except Exception as e:\n raise ValueError(f\"Error during search: {str(e)}\") from e\n", "name": "google_search", "description": "\n Perform Google searches using the Custom Search API with optional webpage content fetching.\n Requires GOOGLE_API_KEY and GOOGLE_CSE_ID environment variables to be set.\n ", "global_imports": [ @@ -872,7 +892,8 @@ "description": "Terminate the conversation after a maximum number of messages have been exchanged.", "label": "MaxMessageTermination", "config": { - "max_messages": 10 + "max_messages": 10, + "include_agent_event": false } }, { @@ -903,7 +924,8 @@ "description": "Terminate the conversation after a maximum number of messages have been exchanged.", "label": "MaxMessageTermination", "config": { - "max_messages": 10 + "max_messages": 10, + "include_agent_event": false } } ] diff --git a/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx b/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx index 1f251958801e..6e35f26a6d15 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/gallery/store.tsx @@ -151,7 +151,7 @@ export const useGalleryStore = create()( }, }), { - name: "gallery-storage-v4", + name: "gallery-storage-v6", } ) ); diff --git a/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx b/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx index 65b277b6bfe1..40b419b052fa 100644 --- a/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx +++ b/python/packages/autogen-studio/frontend/src/components/views/team/builder/library.tsx @@ -95,7 +95,7 @@ export const ComponentLibrary: React.FC = () => { title: "Models", type: "model" as ComponentTypes, items: defaultGallery.items.components.models.map((model) => ({ - label: `${model.component_type} - ${model.config.model}`, + label: `${model.label || model.config.model}`, config: model, })), icon: , From b8c5e499b37cd496585adf9eb1cfb59e1f927204 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sun, 9 Feb 2025 21:40:02 -0800 Subject: [PATCH 24/36] doc: improve m1 docs, remove duplicates (#5460) Resolves #5358 --------- Co-authored-by: Ryan Sweet --- .../agentchat-user-guide/magentic-one.md | 35 +++++- .../src/autogen_ext/teams/magentic_one.py | 4 +- .../packages/autogen-magentic-one/README.md | 100 ++---------------- 3 files changed, 39 insertions(+), 100 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/magentic-one.md b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/magentic-one.md index 5f9b29adf04f..194edee38c1a 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/magentic-one.md +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/magentic-one.md @@ -9,16 +9,12 @@ myst: [Magentic-One](https://aka.ms/magentic-one-blog) is a generalist multi-agent system for solving open-ended web and file-based tasks across a variety of 
domains. It represents a significant step forward for multi-agent systems, achieving competitive performance on a number of agentic benchmarks (see the [technical report](https://arxiv.org/abs/2411.04468) for full details). - When originally released in [November 2024](https://aka.ms/magentic-one-blog) Magentic-One was [implemented directly on the `autogen-core` library](https://github.com/microsoft/autogen/tree/v0.4.4/python/packages/autogen-magentic-one). We have now ported Magentic-One to use `autogen-agentchat`, providing a more modular and easier to use interface. - To this end, the Magentic-One orchestrator {py:class}`~autogen_agentchat.teams.MagenticOneGroupChat` is now simply an AgentChat team, supporting all standard AgentChat agents and features. Likewise, Magentic-One's {py:class}`~autogen_ext.agents.web_surfer.MultimodalWebSurfer`, {py:class}`~autogen_ext.agents.file_surfer.FileSurfer`, and {py:class}`~autogen_ext.agents.magentic_one.MagenticOneCoderAgent` agents are now broadly available as AgentChat agents, to be used in any AgentChat workflows. - Lastly, there is a helper class, {py:class}`~autogen_ext.teams.magentic_one.MagenticOne`, which bundles all of this together as it was in the paper with minimal configuration. - Find additional information about Magentic-one in our [blog post](https://aka.ms/magentic-one-blog) and [technical report](https://arxiv.org/abs/2411.04468). ![](../../images/autogen-magentic-one-example.png) @@ -40,6 +36,7 @@ Be aware that agents may occasionally attempt risky actions, such as recruiting ## Getting started Install the required packages: + ```bash pip install autogen-agentchat autogen-ext[magentic-one,openai] @@ -49,7 +46,9 @@ playwright install --with-deps chromium If you haven't done so already, go through the AgentChat tutorial to learn about the concepts of AgentChat. -Then, you can try swapping out a {py:class}`autogen_agentchat.teams.SelectorGroupChat` with {py:class}`~autogen_agentchat.teams.MagenticOneGroupChat`. For example: +Then, you can try swapping out a {py:class}`autogen_agentchat.teams.SelectorGroupChat` with {py:class}`~autogen_agentchat.teams.MagenticOneGroupChat`. + +For example: ```python import asyncio @@ -73,6 +72,8 @@ async def main() -> None: asyncio.run(main()) ``` +To use a different model, see [Models](./tutorial/models.ipynb) for more information. + Or, use the Magentic-One agents in a team: ```{caution} @@ -101,6 +102,27 @@ async def main() -> None: asyncio.run(main()) ``` +Or, use the {py:class}`~autogen_ext.teams.magentic_one.MagenticOne` helper class +with all the agents bundled together: + +```python +import asyncio +from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_ext.teams.magentic_one import MagenticOne +from autogen_agentchat.ui import Console + + +async def example_usage(): + client = OpenAIChatCompletionClient(model="gpt-4o") + m1 = MagenticOne(client=client) + task = "Write a Python script to fetch data from an API." + result = await Console(m1.run_stream(task=task)) + print(result) + + +if __name__ == "__main__": + asyncio.run(example_usage()) +``` ## Architecture @@ -109,6 +131,7 @@ asyncio.run(main()) Magentic-One work is based on a multi-agent architecture where a lead Orchestrator agent is responsible for high-level planning, directing other agents and tracking task progress. The Orchestrator begins by creating a plan to tackle the task, gathering needed facts and educated guesses in a Task Ledger that is maintained. 
At each step of its plan, the Orchestrator creates a Progress Ledger where it self-reflects on task progress and checks whether the task is completed. If the task is not yet completed, it assigns one of Magentic-One other agents a subtask to complete. After the assigned agent completes its subtask, the Orchestrator updates the Progress Ledger and continues in this way until the task is complete. If the Orchestrator finds that progress is not being made for enough steps, it can update the Task Ledger and create a new plan. This is illustrated in the figure above; the Orchestrator work is thus divided into an outer loop where it updates the Task Ledger and an inner loop to update the Progress Ledger. Overall, Magentic-One consists of the following agents: + - Orchestrator: the lead agent responsible for task decomposition and planning, directing other agents in executing subtasks, tracking overall progress, and taking corrective actions as needed - WebSurfer: This is an LLM-based agent that is proficient in commanding and managing the state of a Chromium-based web browser. With each incoming request, the WebSurfer performs an action on the browser then reports on the new state of the web page The action space of the WebSurfer includes navigation (e.g. visiting a URL, performing a web search); web page actions (e.g., clicking and typing); and reading actions (e.g., summarizing or answering questions). The WebSurfer relies on the accessibility tree of the browser and on set-of-marks prompting to perform its actions. - FileSurfer: This is an LLM-based agent that commands a markdown-based file preview application to read local files of most types. The FileSurfer can also perform common navigation tasks such as listing the contents of directories and navigating a folder structure. @@ -122,6 +145,7 @@ While the default multimodal LLM we use for all agents is GPT-4o, Magentic-One i ## Citation ``` + @misc{fourney2024magenticonegeneralistmultiagentsolving, title={Magentic-One: A Generalist Multi-Agent System for Solving Complex Tasks}, author={Adam Fourney and Gagan Bansal and Hussein Mozannar and Cheng Tan and Eduardo Salinas and Erkang and Zhu and Friederike Niedtner and Grace Proebsting and Griffin Bassman and Jack Gerrits and Jacob Alber and Peter Chang and Ricky Loynd and Robert West and Victor Dibia and Ahmed Awadallah and Ece Kamar and Rafah Hosn and Saleema Amershi}, @@ -131,4 +155,5 @@ While the default multimodal LLM we use for all agents is GPT-4o, Magentic-One i primaryClass={cs.AI}, url={https://arxiv.org/abs/2411.04468}, } + ``` diff --git a/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py b/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py index 55f068d169d2..ec2d5192ab74 100644 --- a/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py +++ b/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py @@ -60,9 +60,9 @@ class MagenticOne(MagenticOneGroupChat): - WebSurfer: An LLM-based agent proficient in commanding and managing the state of a Chromium-based web browser. It performs actions on the browser and reports on the new state of the web page. - FileSurfer: An LLM-based agent that commands a markdown-based file preview application to read local files of most types. It can also perform common navigation tasks such as listing the contents of directories and navigating a folder structure. - Coder: An LLM-based agent specialized in writing code, analyzing information collected from other agents, or creating new artifacts. 
- - ComputerTerminal: Provides the team with access to a console shell where the Coder’s programs can be executed, and where new programming libraries can be installed. + - ComputerTerminal: Provides the team with access to a console shell where the Coder's programs can be executed, and where new programming libraries can be installed. - Together, Magentic-One’s agents provide the Orchestrator with the tools and capabilities needed to solve a broad variety of open-ended problems, as well as the ability to autonomously adapt to, and act in, dynamic and ever-changing web and file-system environments. + Together, Magentic-One's agents provide the Orchestrator with the tools and capabilities needed to solve a broad variety of open-ended problems, as well as the ability to autonomously adapt to, and act in, dynamic and ever-changing web and file-system environments. Examples: diff --git a/python/packages/autogen-magentic-one/README.md b/python/packages/autogen-magentic-one/README.md index 3b87e009d4e8..a1c14f6dd408 100644 --- a/python/packages/autogen-magentic-one/README.md +++ b/python/packages/autogen-magentic-one/README.md @@ -1,103 +1,17 @@ # Magentic-One -[Magentic-One](https://aka.ms/magentic-one-blog) is a generalist multi-agent system for solving open-ended web and file-based tasks across a variety of domains. It represents a significant step forward for multi-agent systems, achieving competitive performance on a number of agentic benchmarks (see the [technical report](https://arxiv.org/abs/2411.04468) for full details). - -When originally released in [November 2024](https://aka.ms/magentic-one-blog) Magentic-One was [implemented directly on the `autogen-core` library](https://github.com/microsoft/autogen/tree/v0.4.4/python/packages/autogen-magentic-one). We have now ported Magentic-One to use `autogen-agentchat`, providing a more modular and easier to use interface. To this end, the older implementation is deprecated, but can be accessed at [https://github.com/microsoft/autogen/tree/v0.4.4/python/packages/autogen-magentic-one](https://github.com/microsoft/autogen/tree/v0.4.4/python/packages/autogen-magentic-one). - -Moving forward, the Magentic-One orchestrator [MagenticOneGroupChat](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.teams.html#autogen_agentchat.teams.MagenticOneGroupChat) is now simply an AgentChat team, supporting all standard AgentChat agents and features. Likewise, Magentic-One's [MultimodalWebSurfer](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.agents.web_surfer.html#autogen_ext.agents.web_surfer.MultimodalWebSurfer), [FileSurfer](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.agents.file_surfer.html#autogen_ext.agents.file_surfer.FileSurfer), and [MagenticOneCoderAgent](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.agents.magentic_one.html#autogen_ext.agents.magentic_one.MagenticOneCoderAgent) agents are now broadly available as AgentChat agents, to be used in any AgentChat workflows. - -Lastly, there is a helper class, [MagenticOne](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.teams.magentic_one.html#autogen_ext.teams.magentic_one.MagenticOne), which bundles all of this together as it was in the paper with minimal configuration. - - -> [!CAUTION] -> Using Magentic-One involves interacting with a digital world designed for humans, which carries inherent risks. To minimize these risks, consider the following precautions: -> -> 1. 
**Use Containers**: Run all tasks in docker containers to isolate the agents and prevent direct system attacks. -> 2. **Virtual Environment**: Use a virtual environment to run the agents and prevent them from accessing sensitive data. -> 3. **Monitor Logs**: Closely monitor logs during and after execution to detect and mitigate risky behavior. -> 4. **Human Oversight**: Run the examples with a human in the loop to supervise the agents and prevent unintended consequences. -> 5. **Limit Access**: Restrict the agents' access to the internet and other resources to prevent unauthorized actions. -> 6. **Safeguard Data**: Ensure that the agents do not have access to sensitive data or resources that could be compromised. Do not share sensitive information with the agents. -> Be aware that agents may occasionally attempt risky actions, such as recruiting humans for help or accepting cookie agreements without human involvement. Always ensure agents are monitored and operate within a controlled environment to prevent unintended consequences. Moreover, be cautious that Magentic-One may be susceptible to prompt injection attacks from webpages. - -## Getting started - -Install the required packages: -```bash -pip install autogen-agentchat autogen-ext[magentic-one,openai] - -# If using the MultimodalWebSurfer, you also need to install playwright dependencies: -playwright install --with-deps chromium -``` - -If you haven't done so already, go through the AgentChat tutorial to learn about the concepts of AgentChat. - -Then, you can try swapping out a SelectorGroupChat with MagenticOneGroupChat. For example: - -```python -import asyncio -from autogen_ext.models.openai import OpenAIChatCompletionClient -from autogen_agentchat.agents import AssistantAgent -from autogen_agentchat.teams import MagenticOneGroupChat -from autogen_agentchat.ui import Console - - -async def main() -> None: - model_client = OpenAIChatCompletionClient(model="gpt-4o") +> Magentic-One is now available as part of the `autogen-agentchat` library. +> Please see the [user guide](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/magentic-one.html) for information. - assistant = AssistantAgent( - "Assistant", - model_client=model_client, - ) - team = MagenticOneGroupChat([assistant], model_client=model_client) - await Console(team.run_stream(task="Provide a different proof for Fermat's Last Theorem")) +> Looking for the original implementation of Magentic-One? It is available [here](https://github.com/microsoft/autogen/tree/v0.4.4/python/packages/autogen-magentic-one). +[Magentic-One](https://aka.ms/magentic-one-blog) is a generalist multi-agent system for solving open-ended web and file-based tasks across a variety of domains. It represents a significant step forward for multi-agent systems, achieving competitive performance on a number of agentic benchmarks (see the [technical report](https://arxiv.org/abs/2411.04468) for full details). -asyncio.run(main()) -``` - -Or, use the Magentic-One agents in a team: - -```{caution} -The example code may download files from the internet, execute code, and interact with web pages. Ensure you are in a safe environment before running the example code. 
-``` - -```python -import asyncio -from autogen_ext.models.openai import OpenAIChatCompletionClient -from autogen_agentchat.teams import MagenticOneGroupChat -from autogen_agentchat.ui import Console -from autogen_ext.agents.web_surfer import MultimodalWebSurfer - - -async def main() -> None: - model_client = OpenAIChatCompletionClient(model="gpt-4o") - - surfer = MultimodalWebSurfer( - "WebSurfer", - model_client=model_client, - ) - team = MagenticOneGroupChat([surfer], model_client=model_client) - await Console(team.run_stream(task="What is the UV index in Melbourne today?")) - - -asyncio.run(main()) -``` - -## Architecture - -Magentic-One work is based on a multi-agent architecture where a lead Orchestrator agent is responsible for high-level planning, directing other agents and tracking task progress. The Orchestrator begins by creating a plan to tackle the task, gathering needed facts and educated guesses in a Task Ledger that is maintained. At each step of its plan, the Orchestrator creates a Progress Ledger where it self-reflects on task progress and checks whether the task is completed. If the task is not yet completed, it assigns one of Magentic-One other agents a subtask to complete. After the assigned agent completes its subtask, the Orchestrator updates the Progress Ledger and continues in this way until the task is complete. If the Orchestrator finds that progress is not being made for enough steps, it can update the Task Ledger and create a new plan. This is illustrated in the figure above; the Orchestrator work is thus divided into an outer loop where it updates the Task Ledger and an inner loop to update the Progress Ledger. - -Overall, Magentic-One consists of the following agents: -- Orchestrator: the lead agent responsible for task decomposition and planning, directing other agents in executing subtasks, tracking overall progress, and taking corrective actions as needed -- WebSurfer: This is an LLM-based agent that is proficient in commanding and managing the state of a Chromium-based web browser. With each incoming request, the WebSurfer performs an action on the browser then reports on the new state of the web page The action space of the WebSurfer includes navigation (e.g. visiting a URL, performing a web search); web page actions (e.g., clicking and typing); and reading actions (e.g., summarizing or answering questions). The WebSurfer relies on the accessibility tree of the browser and on set-of-marks prompting to perform its actions. -- FileSurfer: This is an LLM-based agent that commands a markdown-based file preview application to read local files of most types. The FileSurfer can also perform common navigation tasks such as listing the contents of directories and navigating a folder structure. -- Coder: This is an LLM-based agent specialized through its system prompt for writing code, analyzing information collected from the other agents, or creating new artifacts. -- ComputerTerminal: Finally, ComputerTerminal provides the team with access to a console shell where the Coder’s programs can be executed, and where new programming libraries can be installed. +When originally released in [November 2024](https://aka.ms/magentic-one-blog) Magentic-One was [implemented directly on the `autogen-core` library](https://github.com/microsoft/autogen/tree/v0.4.4/python/packages/autogen-magentic-one). We have now ported Magentic-One to use `autogen-agentchat`, providing a more modular and easier to use interface. 
To this end, the older implementation is deprecated, but can be accessed at [https://github.com/microsoft/autogen/tree/v0.4.4/python/packages/autogen-magentic-one](https://github.com/microsoft/autogen/tree/v0.4.4/python/packages/autogen-magentic-one). -Together, Magentic-One’s agents provide the Orchestrator with the tools and capabilities that it needs to solve a broad variety of open-ended problems, as well as the ability to autonomously adapt to, and act in, dynamic and ever-changing web and file-system environments. +Moving forward, the Magentic-One orchestrator [MagenticOneGroupChat](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.teams.html#autogen_agentchat.teams.MagenticOneGroupChat) is now simply an AgentChat team, supporting all standard AgentChat agents and features. Likewise, Magentic-One's [MultimodalWebSurfer](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.agents.web_surfer.html#autogen_ext.agents.web_surfer.MultimodalWebSurfer), [FileSurfer](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.agents.file_surfer.html#autogen_ext.agents.file_surfer.FileSurfer), and [MagenticOneCoderAgent](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.agents.magentic_one.html#autogen_ext.agents.magentic_one.MagenticOneCoderAgent) agents are now broadly available as AgentChat agents, to be used in any AgentChat workflows. -While the default multimodal LLM we use for all agents is GPT-4o, Magentic-One is model agnostic and can incorporate heterogonous models to support different capabilities or meet different cost requirements when getting tasks done. For example, it can use different LLMs and SLMs and their specialized versions to power different agents. We recommend a strong reasoning model for the Orchestrator agent such as GPT-4o. In a different configuration of Magentic-One, we also experiment with using OpenAI o1-preview for the outer loop of the Orchestrator and for the Coder, while other agents continue to use GPT-4o. +Lastly, there is a helper class, [MagenticOne](https://microsoft.github.io/autogen/stable/reference/python/autogen_ext.teams.magentic_one.html#autogen_ext.teams.magentic_one.MagenticOne), which bundles all of this together as it was in the paper with minimal configuration. ## Citation From 7a772a2fcd38afa126829c2d4f3494383f47b23e Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 10 Feb 2025 06:57:50 +0100 Subject: [PATCH 25/36] feat: add indicator for tool failure to FunctionExecutionResult (#5428) Some LLMs receive an explicit signal about tool use failures.
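For illustration, a minimal sketch (not part of this patch's diff) of how the new flag surfaces in a tool-result message, using the `FunctionExecutionResult` and `FunctionExecutionResultMessage` types that this patch updates in `autogen_core.models`:

```python
from autogen_core.models import FunctionExecutionResult, FunctionExecutionResultMessage

# One successful call and one failed call; `is_error` gives the model an
# explicit failure signal instead of leaving it to parse "Error: ..." strings.
tool_results = FunctionExecutionResultMessage(
    content=[
        FunctionExecutionResult(call_id="1", content="42", is_error=False),
        FunctionExecutionResult(call_id="2", content="Error: division by zero", is_error=True),
    ]
)
```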
Closes #5273 Co-authored-by: Eric Zhu --- .../agents/_assistant_agent.py | 4 +-- .../tests/test_assistant_agent.py | 10 +++---- .../tests/test_group_chat.py | 4 +-- .../agentchat-user-guide/migration-guide.md | 1 + .../design-patterns/handoffs.ipynb | 8 +++-- .../src/autogen_core/models/_types.py | 1 + .../autogen_core/tool_agent/_caller_loop.py | 4 ++- .../autogen_core/tool_agent/_tool_agent.py | 2 +- .../autogen-core/tests/test_tool_agent.py | 2 +- .../agents/openai/_openai_assistant_agent.py | 30 +++++++++++-------- .../tests/models/test_openai_model_client.py | 8 ++--- 11 files changed, 43 insertions(+), 31 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index b3dda7175ae6..6bba99241652 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -557,9 +557,9 @@ async def _execute_tool_call( arguments = json.loads(tool_call.arguments) result = await tool.run_json(arguments, cancellation_token) result_as_str = tool.return_value_as_string(result) - return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id) + return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False) except Exception as e: - return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id) + return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True) async def on_reset(self, cancellation_token: CancellationToken) -> None: """Reset the assistant agent to its initialization state.""" diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 040f3e23e6ca..2d82158ef4cd 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -394,9 +394,9 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[1].models_usage.prompt_tokens == 10 assert isinstance(result.messages[2], ToolCallExecutionEvent) expected_content = [ - FunctionExecutionResult(call_id="1", content="pass"), - FunctionExecutionResult(call_id="2", content="pass"), - FunctionExecutionResult(call_id="3", content="task3"), + FunctionExecutionResult(call_id="1", content="pass", is_error=False), + FunctionExecutionResult(call_id="2", content="pass", is_error=False), + FunctionExecutionResult(call_id="3", content="task3", is_error=False), ] for expected in expected_content: assert expected in result.messages[2].content @@ -877,8 +877,8 @@ async def test_model_client_stream_with_tool_calls() -> None: FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'), ] assert message.messages[2].content == [ - FunctionExecutionResult(call_id="1", content="pass"), - FunctionExecutionResult(call_id="3", content="task"), + FunctionExecutionResult(call_id="1", content="pass", is_error=False), + FunctionExecutionResult(call_id="3", content="task", is_error=False), ] elif isinstance(message, ModelClientStreamingChunkEvent): chunks.append(message.content) diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 098d68349ca1..d12b24f04a1a 100644 --- 
a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -1134,8 +1134,8 @@ async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) - ), FunctionExecutionResultMessage( content=[ - FunctionExecutionResult(content="tool1", call_id="1"), - FunctionExecutionResult(content="tool2", call_id="2"), + FunctionExecutionResult(content="tool1", call_id="1", is_error=False), + FunctionExecutionResult(content="tool2", call_id="2", is_error=False), ] ), ] diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md index c8846859519a..fdb3ca2152d2 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md @@ -763,6 +763,7 @@ def convert_to_v04_message(message: Dict[str, Any]) -> AgentEvent | ChatMessage: FunctionExecutionResult( call_id=tool_response["tool_call_id"], content=tool_response["content"], + is_error=False, ) ) return ToolCallExecutionEvent(source="tools", content=tool_results) diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/design-patterns/handoffs.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/design-patterns/handoffs.ipynb index 31885a10e977..92d3aa51959f 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/design-patterns/handoffs.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/design-patterns/handoffs.ipynb @@ -183,7 +183,9 @@ " # Execute the tool directly.\n", " result = await self._tools[call.name].run_json(arguments, ctx.cancellation_token)\n", " result_as_str = self._tools[call.name].return_value_as_string(result)\n", - " tool_call_results.append(FunctionExecutionResult(call_id=call.id, content=result_as_str))\n", + " tool_call_results.append(\n", + " FunctionExecutionResult(call_id=call.id, content=result_as_str, is_error=False)\n", + " )\n", " elif call.name in self._delegate_tools:\n", " # Execute the tool to get the delegate agent's topic type.\n", " result = await self._delegate_tools[call.name].run_json(arguments, ctx.cancellation_token)\n", @@ -194,7 +196,9 @@ " FunctionExecutionResultMessage(\n", " content=[\n", " FunctionExecutionResult(\n", - " call_id=call.id, content=f\"Transfered to {topic_type}. Adopt persona immediately.\"\n", + " call_id=call.id,\n", + " content=f\"Transferred to {topic_type}. 
Adopt persona immediately.\",\n", + " is_error=False,\n", " )\n", " ]\n", " ),\n", diff --git a/python/packages/autogen-core/src/autogen_core/models/_types.py b/python/packages/autogen-core/src/autogen_core/models/_types.py index 239af52bf141..22b1cff3bd40 100644 --- a/python/packages/autogen-core/src/autogen_core/models/_types.py +++ b/python/packages/autogen-core/src/autogen_core/models/_types.py @@ -51,6 +51,7 @@ class FunctionExecutionResult(BaseModel): content: str call_id: str + is_error: bool | None = None class FunctionExecutionResultMessage(BaseModel): diff --git a/python/packages/autogen-core/src/autogen_core/tool_agent/_caller_loop.py b/python/packages/autogen-core/src/autogen_core/tool_agent/_caller_loop.py index 2353ca184f86..dbff8e05e552 100644 --- a/python/packages/autogen-core/src/autogen_core/tool_agent/_caller_loop.py +++ b/python/packages/autogen-core/src/autogen_core/tool_agent/_caller_loop.py @@ -62,7 +62,9 @@ async def tool_agent_caller_loop( if isinstance(result, FunctionExecutionResult): function_results.append(result) elif isinstance(result, ToolException): - function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id)) + function_results.append( + FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id, is_error=True) + ) elif isinstance(result, BaseException): raise result # Unexpected exception. generated_messages.append(FunctionExecutionResultMessage(content=function_results)) diff --git a/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py b/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py index 08d8f4b25376..2b1979b86a89 100644 --- a/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py +++ b/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py @@ -88,4 +88,4 @@ async def handle_function_call(self, message: FunctionCall, ctx: MessageContext) ) from e except Exception as e: raise ToolExecutionException(call_id=message.id, content=f"Error: {e}") from e - return FunctionExecutionResult(content=result_as_str, call_id=message.id) + return FunctionExecutionResult(content=result_as_str, call_id=message.id, is_error=False) diff --git a/python/packages/autogen-core/tests/test_tool_agent.py b/python/packages/autogen-core/tests/test_tool_agent.py index 85fcd3892c97..a0c1f9263f2c 100644 --- a/python/packages/autogen-core/tests/test_tool_agent.py +++ b/python/packages/autogen-core/tests/test_tool_agent.py @@ -61,7 +61,7 @@ async def test_tool_agent() -> None: result = await runtime.send_message( FunctionCall(id="1", arguments=json.dumps({"input": "pass"}), name="pass"), agent ) - assert result == FunctionExecutionResult(call_id="1", content="pass") + assert result == FunctionExecutionResult(call_id="1", content="pass", is_error=False) # Test raise function with pytest.raises(ToolExecutionException): diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py b/python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py index 6da6755efc00..a74fe8bacdae 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/openai/_openai_assistant_agent.py @@ -381,17 +381,14 @@ def _thread_id(self) -> str: async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str: """Execute a tool call and return the result.""" - try: - if not 
self._original_tools: - raise ValueError("No tools are available.") - tool = next((t for t in self._original_tools if t.name == tool_call.name), None) - if tool is None: - raise ValueError(f"The tool '{tool_call.name}' is not available.") - arguments = json.loads(tool_call.arguments) - result = await tool.run_json(arguments, cancellation_token) - return tool.return_value_as_string(result) - except Exception as e: - return f"Error: {e}" + if not self._original_tools: + raise ValueError("No tools are available.") + tool = next((t for t in self._original_tools if t.name == tool_call.name), None) + if tool is None: + raise ValueError(f"The tool '{tool_call.name}' is not available.") + arguments = json.loads(tool_call.arguments) + result = await tool.run_json(arguments, cancellation_token) + return tool.return_value_as_string(result) async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: """Handle incoming messages and return a response.""" @@ -463,8 +460,15 @@ async def on_messages_stream( # Execute tool calls and get results tool_outputs: List[FunctionExecutionResult] = [] for tool_call in tool_calls: - result = await self._execute_tool_call(tool_call, cancellation_token) - tool_outputs.append(FunctionExecutionResult(content=result, call_id=tool_call.id)) + try: + result = await self._execute_tool_call(tool_call, cancellation_token) + is_error = False + except Exception as e: + result = f"Error: {e}" + is_error = True + tool_outputs.append( + FunctionExecutionResult(content=result, call_id=tool_call.id, is_error=is_error) + ) # Add tool result message to inner messages tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs) diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index edb7598961fb..74c4b2c08f3f 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -349,7 +349,7 @@ async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.Mo ], source="user", ), - FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="Hello", call_id="1")]), + FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="Hello", call_id="1", is_error=False)]), ] def tool1(test: str, test2: str) -> str: @@ -902,7 +902,7 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl messages.append(AssistantMessage(content=create_result.content, source="assistant")) messages.append( FunctionExecutionResultMessage( - content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id)] + content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id, is_error=False)] ) ) create_result = await model_client.create(messages=messages) @@ -932,8 +932,8 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl messages.append( FunctionExecutionResultMessage( content=[ - FunctionExecutionResult(content="passed", call_id=create_result.content[0].id), - FunctionExecutionResult(content="failed", call_id=create_result.content[1].id), + FunctionExecutionResult(content="passed", call_id=create_result.content[0].id, is_error=False), + FunctionExecutionResult(content="failed", call_id=create_result.content[1].id, is_error=True), ] ) ) From 50d7587a4649504af3bb79ab928b2a3882a1a394 Mon Sep 17 00:00:00 2001 From: 
Leonardo Pinheiro Date: Mon, 10 Feb 2025 16:34:54 +1000 Subject: [PATCH 26/36] fix: Update SK kernel from tool to use method. (#5469) ## Why are these changes needed? The current implementation tries to recreate the metadata, but does so in an incomplete way. This PR uses SK's built-in kernel-function-from-method support (the `kernel_function` decorator) to infer the callable from `run_json`, and makes better use of the Pydantic schemas for the input and output to infer the schema of the kernel function. ## Related issue number Closes #5458 ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. --------- Co-authored-by: Leonardo Pinheiro Co-authored-by: Eric Zhu --- .../_sk_chat_completion_adapter.py | 65 +++++-- .../_kernel_function_from_tool.py | 162 ++++++------------ 2 files changed, 94 insertions(+), 133 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py index 26e2092dc430..0067ae30b5be 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py @@ -67,7 +67,7 @@ class SKChatCompletionAdapter(ChatCompletionClient): Examples: - Anthropic models: + Anthropic models with function calling: .. code-block:: bash @@ -79,13 +79,19 @@ import os from autogen_agentchat.agents import AssistantAgent - from autogen_core.models import UserMessage + from autogen_agentchat.ui import Console + from autogen_core.models import ModelFamily, UserMessage from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter from semantic_kernel import Kernel from semantic_kernel.connectors.ai.anthropic import AnthropicChatCompletion, AnthropicChatPromptExecutionSettings from semantic_kernel.memory.null_memory import NullMemory + async def get_weather(city: str) -> str: + \"\"\"Get the weather for a city.\"\"\" + return f"The weather in {city} is 75 degrees." + + async def main() -> None: sk_client = AnthropicChatCompletion( ai_model_id="claude-3-5-sonnet-20241022", @@ -96,24 +102,34 @@ async def main() -> None: temperature=0.2, ) - model_client = SKChatCompletionAdapter(sk_client, kernel=Kernel(memory=NullMemory()), prompt_settings=settings) + model_client = SKChatCompletionAdapter( + sk_client, + kernel=Kernel(memory=NullMemory()), + prompt_settings=settings, + model_info={ + "function_calling": True, + "json_output": True, + "vision": True, + "family": ModelFamily.CLAUDE_3_5_SONNET, + }, + ) # Call the model directly. - model_result = await model_client.create( - messages=[UserMessage(content="What is the capital of France?", source="User")] - ) - print(model_result) + response = await model_client.create([UserMessage(content="What is the capital of France?", source="test")]) + print(response) # Create an assistant agent with the model client.
- assistant = AssistantAgent("assistant", model_client=model_client) + assistant = AssistantAgent( + "assistant", model_client=model_client, system_message="You are a helpful assistant.", tools=[get_weather] + ) # Call the assistant with a task. - result = await assistant.run(task="What is the capital of France?") - print(result) + await Console(assistant.run_stream(task="What is the weather in Paris and London?")) asyncio.run(main()) - Google Gemini models: + + Google Gemini models with function calling: .. code-block:: bash @@ -125,7 +141,8 @@ async def main() -> None: import os from autogen_agentchat.agents import AssistantAgent - from autogen_core.models import UserMessage + from autogen_agentchat.ui import Console + from autogen_core.models import UserMessage, ModelFamily from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter from semantic_kernel import Kernel from semantic_kernel.connectors.ai.google.google_ai import ( @@ -135,24 +152,31 @@ async def main() -> None: from semantic_kernel.memory.null_memory import NullMemory + def get_weather(city: str) -> str: + \"\"\"Get the weather for a city.\"\"\" + return f"The weather in {city} is 75 degrees." + + async def main() -> None: sk_client = GoogleAIChatCompletion( - gemini_model_id="gemini-1.5-flash", + gemini_model_id="gemini-2.0-flash", api_key=os.environ["GEMINI_API_KEY"], ) settings = GoogleAIChatPromptExecutionSettings( temperature=0.2, ) + kernel = Kernel(memory=NullMemory()) + model_client = SKChatCompletionAdapter( sk_client, - kernel=Kernel(memory=NullMemory()), + kernel=kernel, prompt_settings=settings, model_info={ - "family": "gemini-1.5-flash", + "family": ModelFamily.GEMINI_2_0_FLASH, "function_calling": True, "json_output": True, - "vision": False, + "vision": True, }, ) @@ -163,14 +187,17 @@ async def main() -> None: print(model_result) # Create an assistant agent with the model client. - assistant = AssistantAgent("assistant", model_client=model_client) + assistant = AssistantAgent( + "assistant", model_client=model_client, tools=[get_weather], system_message="You are a helpful assistant." + ) # Call the assistant with a task. - result = await assistant.run(task="What is the capital of France?") - print(result) + stream = assistant.run_stream(task="What is the weather in Paris and London?") + await Console(stream) asyncio.run(main()) + Ollama models: .. 
code-block:: bash diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py b/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py index 1c2fec0e373d..8125aef9c5ce 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py @@ -1,133 +1,67 @@ -from typing import TypeVar +from typing import Any, TypeVar from autogen_core import CancellationToken from autogen_core.tools import BaseTool from pydantic import BaseModel -from semantic_kernel.exceptions import FunctionExecutionException -from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext -from semantic_kernel.functions.function_result import FunctionResult -from semantic_kernel.functions.kernel_function import KernelFunction -from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata +from semantic_kernel.functions import KernelFunctionFromMethod, kernel_function InputT = TypeVar("InputT", bound=BaseModel) OutputT = TypeVar("OutputT", bound=BaseModel) -class KernelFunctionFromTool(KernelFunction): - """ - KernelFunctionFromTool is an adapter that allows using Autogen tools as Semantic Kernel functions. - This makes it possible to integrate Autogen tools into Semantic Kernel when using the Semantic Kernel's - chat completion adapter or agent. - - By leveraging this adapter, you can: - - Convert any Autogen BaseTool into a Semantic Kernel KernelFunction - - Register the converted tool with a Semantic Kernel plugin - - Execute the tool through Semantic Kernel's function invocation mechanism - - Access tool metadata (name, description, parameters) through Semantic Kernel's metadata system - - Args: - tool (BaseTool[InputT, OutputT]): - The Autogen tool to wrap. Must be a subclass of BaseTool with Pydantic models for input/output. - plugin_name (str | None): - Optional name of the plugin this function belongs to. Defaults to None. - - Example usage: - .. 
code-block:: python - - from pydantic import BaseModel - from autogen_core.tools import BaseTool - from autogen_core import CancellationToken - from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool - from semantic_kernel.functions.kernel_plugin import KernelPlugin - from semantic_kernel.kernel import Kernel - - - # 1) Define input/output models - class CalculatorArgs(BaseModel): - a: float - b: float - - - class CalculatorResult(BaseModel): - result: float - - - # 2) Create an Autogen tool - class CalculatorTool(BaseTool[CalculatorArgs, CalculatorResult]): - def __init__(self) -> None: - super().__init__( - args_type=CalculatorArgs, - return_type=CalculatorResult, - name="calculator", - description="Add two numbers together", - ) - - async def run(self, args: CalculatorArgs, cancellation_token: CancellationToken) -> CalculatorResult: - return CalculatorResult(result=args.a + args.b) - - - # 3) Convert to Semantic Kernel function - calc_tool = CalculatorTool() - kernel_function = KernelFunctionFromTool(calc_tool, plugin_name="math") - - # 4) Add to Semantic Kernel plugin/kernel - plugin = KernelPlugin(name="math") - plugin.functions[calc_tool.name] = kernel_function - kernel = Kernel() - kernel.add_plugin(plugin) - """ - +class KernelFunctionFromTool(KernelFunctionFromMethod): def __init__(self, tool: BaseTool[InputT, OutputT], plugin_name: str | None = None): - # Build up KernelFunctionMetadata. You can also parse the tool's schema for parameters. - parameters = [ - KernelParameterMetadata( - name="args", - description="JSON arguments for the tool", - default_value=None, - type="dict", - type_object=dict, - is_required=True, + # Get the pydantic model types from the tool + args_type = tool.args_type() + return_type = tool.return_type() + + # 1) Define an async function that calls the tool + @kernel_function(name=tool.name, description=tool.description) + async def tool_method(**kwargs: dict[str, Any]) -> Any: + return await tool.run_json(kwargs, cancellation_token=CancellationToken()) + + # Parse schema for parameters + parameters_meta: list[KernelParameterMetadata] = [] + properties = tool.schema.get("parameters", {}).get("properties", {}) + + # Get the field types from the pydantic model + field_types = args_type.model_fields + + for prop_name, prop_info in properties.items(): + assert prop_name in field_types, f"Property {prop_name} not found in Tool {tool.name}" + assert isinstance(prop_info, dict), f"Property {prop_name} is not a dict in Tool {tool.name}" + + # Get the actual type from the pydantic model field + field_type = field_types[prop_name] + parameters_meta.append( + KernelParameterMetadata( + name=prop_name, + description=field_type.description or "", + default_value=field_type.get_default(), + type=prop_info.get("type", "string"), # type: ignore + type_object=field_type.annotation, + is_required=field_type.is_required(), + ) ) - ] - return_param = KernelParameterMetadata( + + # Create return parameter metadata + return_parameter = KernelParameterMetadata( name="return", - description="Result from the tool", + description=f"Result from '{tool.name}' tool", default_value=None, - type="str", - type_object=str, - is_required=False, + type="object" if issubclass(return_type, BaseModel) else "string", + type_object=return_type, + is_required=True, ) - metadata = KernelFunctionMetadata( - name=tool.name, - description=tool.description, - parameters=parameters, - return_parameter=return_param, - is_prompt=False, - is_asynchronous=True, + # Initialize the parent class 
+ super().__init__( + method=tool_method, plugin_name=plugin_name, + parameters=parameters_meta, + return_parameter=return_parameter, + additional_metadata=None, ) - super().__init__(metadata=metadata) - self._tool = tool - async def _invoke_internal(self, context: FunctionInvocationContext) -> None: - # Extract the "args" parameter from the context - if "args" not in context.arguments: - raise FunctionExecutionException("Missing 'args' in FunctionInvocationContext.arguments") - tool_args = context.arguments - - # Call your tool's run_json - result = await self._tool.run_json(tool_args, cancellation_token=CancellationToken()) - - # Wrap in a FunctionResult - context.result = FunctionResult( - function=self.metadata, - value=result, - metadata={"used_arguments": tool_args}, - ) - - async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> None: - # If you don't have a streaming mechanism in your tool, you can simply reuse _invoke_internal - # or raise NotImplementedError. For example: - await self._invoke_internal(context) + self._tool = tool From 378b5ac09a7f1d5735c137884ef7db9b663bc914 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 10 Feb 2025 11:22:23 -0800 Subject: [PATCH 27/36] Update version to 0.4.6 (#5477) --- python/packages/autogen-agentchat/pyproject.toml | 4 ++-- python/packages/autogen-core/pyproject.toml | 4 ++-- python/packages/autogen-ext/pyproject.toml | 12 ++++++------ python/uv.lock | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/packages/autogen-agentchat/pyproject.toml b/python/packages/autogen-agentchat/pyproject.toml index 95529c108676..102cd1a58d2c 100644 --- a/python/packages/autogen-agentchat/pyproject.toml +++ b/python/packages/autogen-agentchat/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "autogen-agentchat" -version = "0.4.5" +version = "0.4.6" license = {file = "LICENSE-CODE"} description = "AutoGen agents and teams library" readme = "README.md" @@ -15,7 +15,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "autogen-core==0.4.5", + "autogen-core==0.4.6", ] [tool.ruff] diff --git a/python/packages/autogen-core/pyproject.toml b/python/packages/autogen-core/pyproject.toml index 595696c5a394..21645d766e33 100644 --- a/python/packages/autogen-core/pyproject.toml +++ b/python/packages/autogen-core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "autogen-core" -version = "0.4.5" +version = "0.4.6" license = {file = "LICENSE-CODE"} description = "Foundational interfaces and agent runtime implementation for AutoGen" readme = "README.md" @@ -69,7 +69,7 @@ dev = [ "pygments", "sphinxext-rediraffe", - "autogen_ext==0.4.5", + "autogen_ext==0.4.6", # Documentation tooling "diskcache", diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index f3e49b6ce992..8474de52a684 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "autogen-ext" -version = "0.4.5" +version = "0.4.6" license = {file = "LICENSE-CODE"} description = "AutoGen extensions library" readme = "README.md" @@ -15,7 +15,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "autogen-core==0.4.5", + "autogen-core==0.4.6", ] [project.optional-dependencies] @@ -28,24 +28,24 @@ azure = [ docker = ["docker~=7.0", "asyncio_atexit>=1.0.1"] openai = ["openai>=1.52.2", 
"tiktoken>=0.8.0", "aiofiles"] file-surfer = [ - "autogen-agentchat==0.4.5", + "autogen-agentchat==0.4.6", "markitdown>=0.0.1a2", ] graphrag = ["graphrag>=1.0.1"] web-surfer = [ - "autogen-agentchat==0.4.5", + "autogen-agentchat==0.4.6", "playwright>=1.48.0", "pillow>=11.0.0", "markitdown>=0.0.1a2", ] magentic-one = [ - "autogen-agentchat==0.4.5", + "autogen-agentchat==0.4.6", "markitdown>=0.0.1a2", "playwright>=1.48.0", "pillow>=11.0.0", ] video-surfer = [ - "autogen-agentchat==0.4.5", + "autogen-agentchat==0.4.6", "opencv-python>=4.5", "ffmpeg-python", "openai-whisper", diff --git a/python/uv.lock b/python/uv.lock index b4b1e170df60..78214455c815 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -440,7 +440,7 @@ wheels = [ [[package]] name = "autogen-agentchat" -version = "0.4.5" +version = "0.4.6" source = { editable = "packages/autogen-agentchat" } dependencies = [ { name = "autogen-core" }, @@ -451,7 +451,7 @@ requires-dist = [{ name = "autogen-core", editable = "packages/autogen-core" }] [[package]] name = "autogen-core" -version = "0.4.5" +version = "0.4.6" source = { editable = "packages/autogen-core" } dependencies = [ { name = "jsonref" }, @@ -570,7 +570,7 @@ dev = [ [[package]] name = "autogen-ext" -version = "0.4.5" +version = "0.4.6" source = { editable = "packages/autogen-ext" } dependencies = [ { name = "autogen-core" }, From 9e15e9529c9a73bd91807c598427408c7b862fba Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 10 Feb 2025 11:29:25 -0800 Subject: [PATCH 28/36] doc: improve agent tutorial to include multi-modal input. (#5471) Have seen discussion on Discord regarding confusion about multi-modal support in v0.4. This change adds a small note on how to use multi-modal messages with agents. --- .../tutorial/agents.ipynb | 87 +++++++++++++++++-- 1 file changed, 80 insertions(+), 7 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb index d1b90aa787a9..928100c05696 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -75,15 +75,15 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[ToolCallRequestEvent(source='assistant', models_usage=RequestUsage(prompt_tokens=61, completion_tokens=15), content=[FunctionCall(id='call_hqVC7UJUPhKaiJwgVKkg66ak', arguments='{\"query\":\"AutoGen\"}', name='web_search')]), ToolCallExecutionEvent(source='assistant', models_usage=None, content=[FunctionExecutionResult(content='AutoGen is a programming framework for building multi-agent applications.', call_id='call_hqVC7UJUPhKaiJwgVKkg66ak')])]\n", - "source='assistant' models_usage=RequestUsage(prompt_tokens=92, completion_tokens=14) content='AutoGen is a programming framework designed for building multi-agent applications.'\n" + "[ToolCallRequestEvent(source='assistant', models_usage=RequestUsage(prompt_tokens=598, completion_tokens=16), content=[FunctionCall(id='call_9UWYM1CgE3ZbnJcSJavNDB79', arguments='{\"query\":\"AutoGen\"}', name='web_search')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant', 
models_usage=None, content=[FunctionExecutionResult(content='AutoGen is a programming framework for building multi-agent applications.', call_id='call_9UWYM1CgE3ZbnJcSJavNDB79', is_error=False)], type='ToolCallExecutionEvent')]\n", + "source='assistant' models_usage=None content='AutoGen is a programming framework for building multi-agent applications.' type='ToolCallSummaryMessage'\n" ] } ], @@ -128,6 +128,77 @@ "It follows the same interface as [Teams](./teams.ipynb) and returns a {py:class}`~autogen_agentchat.base.TaskResult` object." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi-Modal Input\n", + "\n", + "The {py:class}`~autogen_agentchat.agents.AssistantAgent` can handle multi-modal input\n", + "by providing the input as a {py:class}`~autogen_agentchat.messages.MultiModalMessage`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from io import BytesIO\n", + "\n", + "import PIL\n", + "import requests\n", + "from autogen_agentchat.messages import MultiModalMessage\n", + "from autogen_core import Image\n", + "\n", + "# Create a multi-modal message with random image and text.\n", + "pil_image = PIL.Image.open(BytesIO(requests.get(\"https://picsum.photos/300/200\").content))\n", + "img = Image(pil_image)\n", + "multi_modal_message = MultiModalMessage(content=[\"Can you describe the content of this image?\", img], source=\"user\")\n", + "img" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The image depicts a vintage car, likely from the 1930s or 1940s, with a sleek, classic design. The car seems to be customized or well-maintained, as indicated by its shiny exterior and lowered stance. It has a prominent grille and round headlights. There's a license plate on the front with the text \"FARMER BOY.\" The setting appears to be a street with old-style buildings in the background, suggesting a historical or retro theme.\n" + ] + } + ], + "source": [ + "# Use asyncio.run(...) when running in a script.\n", + "response = await agent.on_messages([multi_modal_message], CancellationToken())\n", + "print(response.chat_message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also use {py:class}`~autogen_agentchat.messages.MultiModalMessage` as a `task`\n", + "input to the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run` method." 
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -179,7 +250,8 @@ " agent.on_messages_stream(\n", " [TextMessage(content=\"Find information on AutoGen\", source=\"user\")],\n", " cancellation_token=CancellationToken(),\n", - " )\n", + " ),\n", + " output_stats=True, # Enable stats printing.\n", " )\n", "\n", "\n", @@ -241,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -287,7 +359,8 @@ "await Console(\n", " agent.on_messages_stream(\n", " [TextMessage(content=\"What's the average age of the passengers?\", source=\"user\")], CancellationToken()\n", - " )\n", + " ),\n", + " output_stats=True,\n", ")" ] }, From 8a9f4521365dba3edd89536bb2ac06f567a647fa Mon Sep 17 00:00:00 2001 From: Eitan Yarmush Date: Mon, 10 Feb 2025 15:27:27 -0500 Subject: [PATCH 29/36] Adding declarative HTTP tools to autogen ext (#5181) ## Why are these changes needed? These changes are needed because currently there's no generic way to add `tools` to autogen studio workflows using the existing DSL and schema other than inline python. This API will be quite verbose, and lacks a discovery mechanism, but it unlocks a lot of programmatic use-cases. ## Related issue number https://github.com/microsoft/autogen/issues/5170 Co-authored-by: Victor Dibia Co-authored-by: Eric Zhu --- .../autogen-core/docs/src/reference/index.md | 7 +- .../python/autogen_ext.tools.http.rst | 8 + python/packages/autogen-ext/pyproject.toml | 5 + .../src/autogen_ext/tools/http/__init__.py | 3 + .../src/autogen_ext/tools/http/_http_tool.py | 233 ++++++++++++++++++ python/packages/autogen-ext/tests/__init__.py | 0 .../autogen-ext/tests/test_worker_runtime.py | 5 +- .../autogen-ext/tests/tools/__init__.py | 0 .../tests/tools/graphrag/__init__.py | 0 .../tests/tools/{ => graphrag}/conftest.py | 0 .../{ => graphrag}/test_graphrag_tools.py | 0 .../autogen-ext/tests/tools/http/__init__.py | 0 .../autogen-ext/tests/tools/http/conftest.py | 101 ++++++++ .../tests/tools/http/test_http_tool.py | 202 +++++++++++++++ python/uv.lock | 6 + 15 files changed, 565 insertions(+), 5 deletions(-) create mode 100644 python/packages/autogen-core/docs/src/reference/python/autogen_ext.tools.http.rst create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/http/__init__.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/tools/http/_http_tool.py create mode 100644 python/packages/autogen-ext/tests/__init__.py create mode 100644 python/packages/autogen-ext/tests/tools/__init__.py create mode 100644 python/packages/autogen-ext/tests/tools/graphrag/__init__.py rename python/packages/autogen-ext/tests/tools/{ => graphrag}/conftest.py (100%) rename python/packages/autogen-ext/tests/tools/{ => graphrag}/test_graphrag_tools.py (100%) create mode 100644 python/packages/autogen-ext/tests/tools/http/__init__.py create mode 100644 python/packages/autogen-ext/tests/tools/http/conftest.py create mode 100644 python/packages/autogen-ext/tests/tools/http/test_http_tool.py diff --git a/python/packages/autogen-core/docs/src/reference/index.md b/python/packages/autogen-core/docs/src/reference/index.md index 762e9f07ba35..f9dc5c1a19b1 100644 --- a/python/packages/autogen-core/docs/src/reference/index.md +++ b/python/packages/autogen-core/docs/src/reference/index.md @@ -46,17 +46,17 @@ python/autogen_ext.agents.web_surfer python/autogen_ext.agents.file_surfer python/autogen_ext.agents.video_surfer python/autogen_ext.agents.video_surfer.tools -python/autogen_ext.auth.azure 
python/autogen_ext.teams.magentic_one python/autogen_ext.models.cache python/autogen_ext.models.openai python/autogen_ext.models.replay python/autogen_ext.models.azure python/autogen_ext.models.semantic_kernel +python/autogen_ext.tools.code_execution +python/autogen_ext.tools.graphrag +python/autogen_ext.tools.http python/autogen_ext.tools.langchain python/autogen_ext.tools.mcp -python/autogen_ext.tools.graphrag -python/autogen_ext.tools.code_execution python/autogen_ext.tools.semantic_kernel python/autogen_ext.code_executors.local python/autogen_ext.code_executors.docker @@ -65,4 +65,5 @@ python/autogen_ext.code_executors.azure python/autogen_ext.cache_store.diskcache python/autogen_ext.cache_store.redis python/autogen_ext.runtimes.grpc +python/autogen_ext.auth.azure ``` diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.tools.http.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.tools.http.rst new file mode 100644 index 000000000000..7450fb9bfe08 --- /dev/null +++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.tools.http.rst @@ -0,0 +1,8 @@ +autogen\_ext.tools.http +======================= + + +.. automodule:: autogen_ext.tools.http + :members: + :undoc-members: + :show-inheritance: diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index 8474de52a684..bf86c3e4004e 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -106,6 +106,11 @@ semantic-kernel-dapr = [ "semantic-kernel[dapr]>=1.17.1", ] +http-tool = [ + "httpx>=0.27.0", + "json-schema-to-pydantic>=0.2.0" +] + semantic-kernel-all = [ "semantic-kernel[google,hugging_face,mistralai,ollama,onnx,anthropic,usearch,pandas,aws,dapr]>=1.17.1", ] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/http/__init__.py b/python/packages/autogen-ext/src/autogen_ext/tools/http/__init__.py new file mode 100644 index 000000000000..6c276b625e3f --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/http/__init__.py @@ -0,0 +1,3 @@ +from ._http_tool import HttpTool + +__all__ = ["HttpTool"] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/http/_http_tool.py b/python/packages/autogen-ext/src/autogen_ext/tools/http/_http_tool.py new file mode 100644 index 000000000000..b66327ffa2c9 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/http/_http_tool.py @@ -0,0 +1,233 @@ +import re +from typing import Any, Literal, Optional, Type + +import httpx +from autogen_core import CancellationToken, Component +from autogen_core.tools import BaseTool +from json_schema_to_pydantic import create_model +from pydantic import BaseModel, Field +from typing_extensions import Self + + +class HttpToolConfig(BaseModel): + name: str + """ + The name of the tool. + """ + description: Optional[str] + """ + A description of the tool. + """ + scheme: Literal["http", "https"] = "http" + """ + The scheme to use for the request. + """ + host: str + """ + The host to send the request to. + """ + port: int + """ + The port to send the request to. + """ + path: str = Field(default="/") + """ + The path to send the request to. Defaults to "/". + The path can accept parameters, e.g. "/{param1}/{param2}". + These parameters will be templated from the input args; any additional parameters will be added as query parameters or to the body of the request.
+ """ + method: Optional[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] = "POST" + """ + The HTTP method to use, will default to POST if not provided. + """ + headers: Optional[dict[str, Any]] + """ + A dictionary of headers to send with the request. + """ + json_schema: dict[str, Any] + """ + A JSON Schema object defining the expected parameters for the tool. + Path parameters MUST also be included in the json_schema. They must also MUST be set to string + """ + return_type: Optional[Literal["text", "json"]] = "text" + """ + The type of response to return from the tool. + """ + + +class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]): + """A wrapper for using an HTTP server as a tool. + + Args: + name (str): The name of the tool. + description (str, optional): A description of the tool. + scheme (str): The scheme to use for the request. Must be either "http" or "https". + host (str): The host to send the request to. + port (int): The port to send the request to. + path (str, optional): The path to send the request to. Defaults to "/". + Can include path parameters like "/{param1}/{param2}" which will be templated from input args. + method (str, optional): The HTTP method to use, will default to POST if not provided. + Must be one of "GET", "POST", "PUT", "DELETE", "PATCH". + headers (dict[str, Any], optional): A dictionary of headers to send with the request. + json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool. + Path parameters must also be included in the schema and must be strings. + return_type (Literal["text", "json"], optional): The type of response to return from the tool. + Defaults to "text". + + .. note:: + This tool requires the :code:`http-tool` extra for the :code:`autogen-ext` package. + + To install: + + .. 
code-block:: bash + + pip install -U "autogen-agentchat" "autogen-ext[http-tool]" + + Example: + Simple use case:: + + import asyncio + + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.messages import TextMessage + from autogen_core import CancellationToken + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_ext.tools.http import HttpTool + + # Define a JSON schema for a base64 decode tool + base64_schema = { + "type": "object", + "properties": { + "value": {"type": "string", "description": "The base64 value to decode"}, + }, + "required": ["value"], + } + + # Create an HTTP tool for the httpbin API + base64_tool = HttpTool( + name="base64_decode", + description="base64 decode a value", + scheme="https", + host="httpbin.org", + port=443, + path="/base64/{value}", + method="GET", + json_schema=base64_schema, + ) + + + async def main(): + # Create an assistant with the base64 tool + model = OpenAIChatCompletionClient(model="gpt-4") + assistant = AssistantAgent("base64_assistant", model_client=model, tools=[base64_tool]) + + # The assistant can now use the base64 tool to decode the string + response = await assistant.on_messages( + [TextMessage(content="Can you base64 decode the value 'YWJjZGU=', please?", source="user")], + CancellationToken(), + ) + print(response.chat_message.content) + + + asyncio.run(main()) + """ + + component_type = "tool" + component_provider_override = "autogen_ext.tools.http.HttpTool" + component_config_schema = HttpToolConfig + + def __init__( + self, + name: str, + host: str, + port: int, + json_schema: dict[str, Any], + headers: Optional[dict[str, Any]] = None, + description: str = "HTTP tool", + path: str = "/", + scheme: Literal["http", "https"] = "http", + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST", + return_type: Literal["text", "json"] = "text", + ) -> None: + self.server_params = HttpToolConfig( + name=name, + description=description, + host=host, + port=port, + path=path, + scheme=scheme, + method=method, + headers=headers, + json_schema=json_schema, + return_type=return_type, + ) + + # Use regex to find all path parameters, we will need those later to template the path + path_params = {match.group(1) for match in re.finditer(r"{([^}]*)}", path)} + self._path_params = path_params + + # Create the input model from the modified schema + input_model = create_model(json_schema) + + # Use Any as return type since HTTP responses can vary + base_return_type: Type[Any] = object + + super().__init__(input_model, base_return_type, name, description) + + def _to_config(self) -> HttpToolConfig: + copied_config = self.server_params.model_copy() + return copied_config + + @classmethod + def _from_config(cls, config: HttpToolConfig) -> Self: + copied_config = config.model_copy().model_dump() + return cls(**copied_config) + + async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any: + """Execute the HTTP tool with the given arguments. 
+ + Args: + args: The validated input arguments + cancellation_token: Token for cancelling the operation + + Returns: + The response body from the HTTP call, as text or parsed JSON depending on the configured `return_type` + + Raises: + Exception: If tool execution fails + """ + + model_dump = args.model_dump() + path_params = {k: v for k, v in model_dump.items() if k in self._path_params} + # Remove path params from the model dump + for k in self._path_params: + model_dump.pop(k) + + path = self.server_params.path.format(**path_params) + + url = httpx.URL( + scheme=self.server_params.scheme, + host=self.server_params.host, + port=self.server_params.port, + path=path, + ) + async with httpx.AsyncClient() as client: + match self.server_params.method: + case "GET": + response = await client.get(url, params=model_dump) + case "PUT": + response = await client.put(url, json=model_dump) + case "DELETE": + response = await client.delete(url, params=model_dump) + case "PATCH": + response = await client.patch(url, json=model_dump) + case _: # Default case POST + response = await client.post(url, json=model_dump) + + match self.server_params.return_type: + case "text": + return response.text + case "json": + return response.json() + case _: + raise ValueError(f"Invalid return type: {self.server_params.return_type}") diff --git a/python/packages/autogen-ext/tests/__init__.py b/python/packages/autogen-ext/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index 77101d1344d3..88c9d1e0bf02 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -29,7 +29,8 @@ MessageType, NoopAgent, ) -from protos.serialization_test_pb2 import ProtoMessage + +from .protos.serialization_test_pb2 import ProtoMessage @@ -423,7 +424,7 @@ def __init__(self) -> None: self.received_messages: list[Any] = [] @event - async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None: + async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None: # type: ignore self.num_calls += 1 self.received_messages.append(message) diff --git a/python/packages/autogen-ext/tests/tools/__init__.py b/python/packages/autogen-ext/tests/tools/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/packages/autogen-ext/tests/tools/graphrag/__init__.py b/python/packages/autogen-ext/tests/tools/graphrag/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/packages/autogen-ext/tests/tools/conftest.py b/python/packages/autogen-ext/tests/tools/graphrag/conftest.py similarity index 100% rename from python/packages/autogen-ext/tests/tools/conftest.py rename to python/packages/autogen-ext/tests/tools/graphrag/conftest.py diff --git a/python/packages/autogen-ext/tests/tools/test_graphrag_tools.py b/python/packages/autogen-ext/tests/tools/graphrag/test_graphrag_tools.py similarity index 100% rename from python/packages/autogen-ext/tests/tools/test_graphrag_tools.py rename to python/packages/autogen-ext/tests/tools/graphrag/test_graphrag_tools.py diff --git a/python/packages/autogen-ext/tests/tools/http/__init__.py b/python/packages/autogen-ext/tests/tools/http/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/packages/autogen-ext/tests/tools/http/conftest.py b/python/packages/autogen-ext/tests/tools/http/conftest.py new file mode
100644 index 000000000000..74ea64a91465 --- /dev/null +++ b/python/packages/autogen-ext/tests/tools/http/conftest.py @@ -0,0 +1,101 @@ +import asyncio +from typing import Any, AsyncGenerator, Dict + +import pytest +import pytest_asyncio +import uvicorn +from autogen_core import ComponentModel +from fastapi import FastAPI +from pydantic import BaseModel, Field + + +class TestArgs(BaseModel): + query: str = Field(description="The test query") + value: int = Field(description="A test value") + + +class TestResponse(BaseModel): + result: str = Field(description="The test result") + + +# Create a test FastAPI app +app = FastAPI() + + +@app.post("/test") +async def test_endpoint(body: TestArgs) -> TestResponse: + return TestResponse(result=f"Received: {body.query} with value {body.value}") + + +@app.post("/test/{query}/{value}") +async def test_path_params_endpoint(query: str, value: int) -> TestResponse: + return TestResponse(result=f"Received: {query} with value {value}") + + +@app.put("/test/{query}/{value}") +async def test_path_params_and_body_endpoint(query: str, value: int, body: Dict[str, Any]) -> TestResponse: + return TestResponse(result=f"Received: {query} with value {value} and extra {body.get('extra')}") # type: ignore + + +@app.get("/test") +async def test_get_endpoint(query: str, value: int) -> TestResponse: + return TestResponse(result=f"Received: {query} with value {value}") + + +@app.put("/test") +async def test_put_endpoint(body: TestArgs) -> TestResponse: + return TestResponse(result=f"Received: {body.query} with value {body.value}") + + +@app.delete("/test") +async def test_delete_endpoint(query: str, value: int) -> TestResponse: + return TestResponse(result=f"Received: {query} with value {value}") + + +@app.patch("/test") +async def test_patch_endpoint(body: TestArgs) -> TestResponse: + return TestResponse(result=f"Received: {body.query} with value {body.value}") + + +@pytest.fixture +def test_config() -> ComponentModel: + return ComponentModel( + provider="autogen_ext.tools.http.HttpTool", + config={ + "name": "TestHttpTool", + "description": "A test HTTP tool", + "scheme": "http", + "path": "/test", + "host": "localhost", + "port": 8000, + "method": "POST", + "headers": {"Content-Type": "application/json"}, + "json_schema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "The test query"}, + "value": {"type": "integer", "description": "A test value"}, + }, + "required": ["query", "value"], + }, + }, + ) + + +@pytest_asyncio.fixture(scope="function") # type: ignore +async def test_server() -> AsyncGenerator[None, None]: + # Start the test server + config = uvicorn.Config(app, host="127.0.0.1", port=8000, log_level="error") + server = uvicorn.Server(config) + + # Create a task for the server + server_task = asyncio.create_task(server.serve()) + + # Wait a bit for server to start + await asyncio.sleep(0.5) # Increased sleep time to ensure server is ready + + yield + + # Cleanup + server.should_exit = True + await server_task diff --git a/python/packages/autogen-ext/tests/tools/http/test_http_tool.py b/python/packages/autogen-ext/tests/tools/http/test_http_tool.py new file mode 100644 index 000000000000..3756c48474d0 --- /dev/null +++ b/python/packages/autogen-ext/tests/tools/http/test_http_tool.py @@ -0,0 +1,202 @@ +import json + +import httpx +import pytest +from autogen_core import CancellationToken, Component, ComponentModel +from autogen_ext.tools.http import HttpTool +from pydantic import ValidationError + + +def 
test_tool_schema_generation(test_config: ComponentModel) -> None: + tool = HttpTool.load_component(test_config) + schema = tool.schema + + assert schema["name"] == "TestHttpTool" + assert "description" in schema + assert schema["description"] == "A test HTTP tool" + assert "parameters" in schema + assert schema["parameters"]["type"] == "object" + assert "properties" in schema["parameters"] + assert schema["parameters"]["properties"]["query"]["description"] == "The test query" + assert schema["parameters"]["properties"]["query"]["type"] == "string" + assert schema["parameters"]["properties"]["value"]["description"] == "A test value" + assert schema["parameters"]["properties"]["value"]["type"] == "integer" + assert "required" in schema["parameters"] + assert set(schema["parameters"]["required"]) == {"query", "value"} + + +def test_tool_properties(test_config: ComponentModel) -> None: + tool = HttpTool.load_component(test_config) + + assert tool.name == "TestHttpTool" + assert tool.description == "A test HTTP tool" + assert tool.server_params.host == "localhost" + assert tool.server_params.port == 8000 + assert tool.server_params.path == "/test" + assert tool.server_params.scheme == "http" + assert tool.server_params.method == "POST" + + +def test_component_base_class(test_config: ComponentModel) -> None: + tool = HttpTool.load_component(test_config) + assert tool.dump_component() is not None + assert HttpTool.load_component(tool.dump_component(), HttpTool) is not None + assert isinstance(tool, Component) + + +@pytest.mark.asyncio +async def test_post_request(test_config: ComponentModel, test_server: None) -> None: + tool = HttpTool.load_component(test_config) + result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken()) + + assert isinstance(result, str) + assert json.loads(result)["result"] == "Received: test query with value 42" + + +@pytest.mark.asyncio +async def test_post_request_json_return(test_config: ComponentModel, test_server: None) -> None: + # Modify config to use json return type + config = test_config.model_copy() + config.config["return_type"] = "json" + tool = HttpTool.load_component(config) + result = await tool.run_json({"query": "test query", "value": 45}, CancellationToken()) + + assert isinstance(result, dict) + assert result["result"] == "Received: test query with value 45" + + +@pytest.mark.asyncio +async def test_get_request(test_config: ComponentModel, test_server: None) -> None: + # Modify config for GET request + config = test_config.model_copy() + config.config["method"] = "GET" + tool = HttpTool.load_component(config) + + result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken()) + + assert isinstance(result, str) + assert json.loads(result)["result"] == "Received: test query with value 42" + + +@pytest.mark.asyncio +async def test_put_request(test_config: ComponentModel, test_server: None) -> None: + # Modify config for PUT request + config = test_config.model_copy() + config.config["method"] = "PUT" + tool = HttpTool.load_component(config) + + result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken()) + + assert isinstance(result, str) + assert json.loads(result)["result"] == "Received: test query with value 42" + + +@pytest.mark.asyncio +async def test_path_params(test_config: ComponentModel, test_server: None) -> None: + # Modify config to use path parameters + config = test_config.model_copy() + config.config["path"] = "/test/{query}/{value}" + tool = 
HttpTool.load_component(config) + + result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken()) + + assert isinstance(result, str) + assert json.loads(result)["result"] == "Received: test query with value 42" + + +@pytest.mark.asyncio +async def test_path_params_and_body(test_config: ComponentModel, test_server: None) -> None: + # Modify config to use path parameters and include body parameters + config = test_config.model_copy() + config.config["method"] = "PUT" + config.config["path"] = "/test/{query}/{value}" + config.config["json_schema"] = { + "type": "object", + "properties": { + "query": {"type": "string", "description": "The test query"}, + "value": {"type": "integer", "description": "A test value"}, + "extra": {"type": "string", "description": "Extra body parameter"}, + }, + "required": ["query", "value", "extra"], + } + tool = HttpTool.load_component(config) + + result = await tool.run_json({"query": "test query", "value": 42, "extra": "extra data"}, CancellationToken()) + + assert isinstance(result, str) + assert json.loads(result)["result"] == "Received: test query with value 42 and extra extra data" + + +@pytest.mark.asyncio +async def test_delete_request(test_config: ComponentModel, test_server: None) -> None: + # Modify config for DELETE request + config = test_config.model_copy() + config.config["method"] = "DELETE" + tool = HttpTool.load_component(config) + + result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken()) + + assert isinstance(result, str) + assert json.loads(result)["result"] == "Received: test query with value 42" + + +@pytest.mark.asyncio +async def test_patch_request(test_config: ComponentModel, test_server: None) -> None: + # Modify config for PATCH request + config = test_config.model_copy() + config.config["method"] = "PATCH" + tool = HttpTool.load_component(config) + + result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken()) + + assert isinstance(result, str) + assert json.loads(result)["result"] == "Received: test query with value 42" + + +@pytest.mark.asyncio +async def test_invalid_schema(test_config: ComponentModel, test_server: None) -> None: + # Create an invalid config by giving a property an incorrect type + config: ComponentModel = test_config.model_copy() + config.config["host"] = True # Incorrect type + + with pytest.raises(ValidationError): + # Should fail validation when loading the component from the invalid config + HttpTool.load_component(config) + + +@pytest.mark.asyncio +async def test_invalid_request(test_config: ComponentModel, test_server: None) -> None: + # Use an invalid URL + config = test_config.model_copy() + config.config["host"] = "fake" + tool = HttpTool.load_component(config) + + with pytest.raises(httpx.ConnectError): + await tool.run_json({"query": "test query", "value": 42}, CancellationToken()) + + +def test_config_serialization(test_config: ComponentModel) -> None: + tool = HttpTool.load_component(test_config) + config = tool.dump_component() + + assert config.config["name"] == test_config.config["name"] + assert config.config["description"] == test_config.config["description"] + assert config.config["host"] == test_config.config["host"] + assert config.config["port"] == test_config.config["port"] + assert config.config["path"] == test_config.config["path"] + assert config.config["scheme"] == test_config.config["scheme"] + assert config.config["method"] == test_config.config["method"] + assert config.config["headers"] == test_config.config["headers"] + +
+def test_config_deserialization(test_config: ComponentModel) -> None: + tool = HttpTool.load_component(test_config) + + assert tool.name == test_config.config["name"] + assert tool.description == test_config.config["description"] + assert tool.server_params.host == test_config.config["host"] + assert tool.server_params.port == test_config.config["port"] + assert tool.server_params.path == test_config.config["path"] + assert tool.server_params.scheme == test_config.config["scheme"] + assert tool.server_params.method == test_config.config["method"] + assert tool.server_params.headers == test_config.config["headers"] diff --git a/python/uv.lock b/python/uv.lock index 78214455c815..17c0b76c4d4a 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -599,6 +599,10 @@ graphrag = [ grpc = [ { name = "grpcio" }, ] +http-tool = [ + { name = "httpx" }, + { name = "json-schema-to-pydantic" }, +] jupyter-executor = [ { name = "ipykernel" }, { name = "nbclient" }, @@ -698,7 +702,9 @@ requires-dist = [ { name = "ffmpeg-python", marker = "extra == 'video-surfer'" }, { name = "graphrag", marker = "extra == 'graphrag'", specifier = ">=1.0.1" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = "~=1.70.0" }, + { name = "httpx", marker = "extra == 'http-tool'", specifier = ">=0.27.0" }, { name = "ipykernel", marker = "extra == 'jupyter-executor'", specifier = ">=6.29.5" }, + { name = "json-schema-to-pydantic", marker = "extra == 'http-tool'", specifier = ">=0.2.0" }, { name = "json-schema-to-pydantic", marker = "extra == 'mcp'", specifier = ">=0.2.2" }, { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" }, { name = "markitdown", marker = "extra == 'file-surfer'", specifier = ">=0.0.1a2" }, From 6db946bc89bc472cf85ceb98c8b471243d042219 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 10 Feb 2025 12:43:53 -0800 Subject: [PATCH 30/36] doc: enhance extensions user guide with component examples (#5480) --- .../src/user-guide/extensions-user-guide/index.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/index.md b/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/index.md index 291138af07b0..a3eb585626ba 100644 --- a/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/index.md +++ b/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/index.md @@ -24,8 +24,19 @@ create-your-own azure-container-code-executor ``` -AutoGen is designed to be extensible. The `autogen-ext` package contains many different component implementations maintained by the AutoGen project. However, we strongly encourage others to build their own components and publish them as part of the ecosytem. +AutoGen is designed to be extensible. The `autogen-ext` package contains the built-in component implementations maintained by the AutoGen project. +Examples of components include: + +- `autogen_ext.agents.*` for agent implementations like {py:class}`~autogen_ext.agents.web_surfer.MultimodalWebSurfer` +- `autogen_ext.models.*` for model clients like {py:class}`~autogen_ext.models.openai.OpenAIChatCompletionClient` and {py:class}`~autogen_ext.models.semantic_kernel.SKChatCompletionAdapter` for connecting to hosted and local models. +- `autogen_ext.tools.*` for tools like GraphRAG {py:class}`~autogen_ext.tools.graphrag.LocalSearchTool` and {py:func}`~autogen_ext.tools.mcp.mcp_server_tools`. 
+- `autogen_ext.executors.*` for executors like {py:class}`~autogen_ext.code_executors.docker.DockerCommandLineCodeExecutor` and {py:class}`~autogen_ext.code_executors.azure.ACADynamicSessionsCodeExecutor` +- `autogen_ext.runtimes.*` for agent runtimes like {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime` + +See [API Reference](../../reference/index.md) for the full list of components and their APIs. + +We strongly encourage developers to build their own components and publish them as part of the ecosystem. ::::{grid} 2 2 2 2 :gutter: 3 From 9ceb5c05a793d7456c025bb8572260759cb5a8d1 Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Mon, 10 Feb 2025 13:13:55 -0800 Subject: [PATCH 31/36] [.NET] save chatHistory in another list to avoid duplicate messages (#5478) ## Why are these changes needed? ## Related issue number Fix #4731 ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. --- dotnet/src/AutoGen.Core/Extension/AgentExtension.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs b/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs index 3bc5787a8594..b51cb0b90190 100644 --- a/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs @@ -144,13 +144,13 @@ public static async Task<IEnumerable<IMessage>> InitiateChatAsync( chatHistory.Add(msg); } - + var intermediateChatHistory = new List<IMessage>(); await foreach (var msg in agent.SendAsync(receiver, chatHistory, maxRound, ct)) { - chatHistory.Add(msg); + intermediateChatHistory.Add(msg); } - return chatHistory; + return chatHistory.Concat(intermediateChatHistory); } [Obsolete("use GroupChatExtension.SendAsync")] From 2612796681a317a8aa012368d547edf461858a89 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 10 Feb 2025 18:05:19 -0500 Subject: [PATCH 32/36] Implement control channel in python host servicer (#5427) Co-authored-by: Eric Zhu --- .../grpc/_worker_runtime_host_servicer.py | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index daa4ad65101d..621ed9511eb4 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -7,6 +7,7 @@ from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar from autogen_core import TopicId +from autogen_core._agent_id import AgentId from autogen_core._runtime_impl_helpers import SubscriptionManager from ._constants import GRPC_IMPORT_ERROR_STR @@ -100,6 +101,9 @@ def __init__(self) -> None: self._data_connections: Dict[ ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message] ] = {} + self._control_connections: Dict[ + ClientConnectionId, ChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage] + ] = {} self._agent_type_to_client_id_lock = asyncio.Lock() self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {} self._pending_responses:
Dict[ClientConnectionId, Dict[str, Future[Any]]] = {} @@ -140,7 +144,23 @@ async def OpenControlChannel( # type: ignore request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage], context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage], ) -> AsyncIterator[agent_worker_pb2.ControlMessage]: - raise NotImplementedError("Method not implemented.") + client_id = await get_client_id_or_abort(context) + + async def handle_callback(message: agent_worker_pb2.ControlMessage) -> None: + await self._receive_control_message(client_id, message) + + connection = CallbackChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage]( + request_iterator, client_id, handle_callback=handle_callback + ) + self._control_connections[client_id] = connection + logger.info(f"Client {client_id} connected.") + + try: + async for message in connection: + yield message + finally: + # Clean up the client connection. + del self._control_connections[client_id] async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None: async with self._agent_type_to_client_id_lock: @@ -182,6 +202,29 @@ async def _receive_message(self, client_id: ClientConnectionId, message: agent_w case None: logger.warning("Received empty message") + async def _receive_control_message( + self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage + ) -> None: + logger.info(f"Received message from client {client_id}: {message}") + destination = message.destination + if destination.startswith("agentid="): + agent_id = AgentId.from_str(destination[len("agentid=") :]) + target_client_id = self._agent_type_to_client_id.get(agent_id.type) + if target_client_id is None: + logger.error(f"Agent client id not found for agent type {agent_id.type}.") + return + elif destination.startswith("clientid="): + target_client_id = destination[len("clientid=") :] + else: + logger.error(f"Invalid destination {destination}") + return + + target_send_queue = self._control_connections.get(target_client_id) + if target_send_queue is None: + logger.error(f"Client {target_client_id} not found, failed to deliver message.") + return + await target_send_queue.send(message) + async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None: # Deliver the message to a client given the target agent type. async with self._agent_type_to_client_id_lock: From cd085e6b894152216f8df985073e095664755818 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Mon, 10 Feb 2025 16:29:43 -0800 Subject: [PATCH 33/36] Improve custom agentchat agent docs with model clients (gemini example) and serialization (#5468) This PR improves documentation on custom agents - Shows an example of how to create a custom agent that directly uses a model client. In this case, an example of a GeminiAssistantAgent that directly uses the Gemini SDK model client. - Shows that the CustomAgent can be easily added to any AgentChat team - Shows how the same CustomAgent can be made declarative by inheriting the Component interface and implementing the required methods (a minimal sketch of this pattern follows below).
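For illustration, here is a minimal sketch of that declarative pattern, assuming the `Component` interface from `autogen_core` and the `BaseChatAgent` base class from `autogen_agentchat` used throughout this PR; the `EchoAgent` and its config class are hypothetical examples, not part of this change:

```python
from typing import Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, TextMessage
from autogen_core import CancellationToken, Component
from pydantic import BaseModel
from typing_extensions import Self


class EchoAgentConfig(BaseModel):
    """Declarative (JSON-serializable) configuration for EchoAgent."""

    name: str
    description: str = "An agent that echoes the last message it receives."


class EchoAgent(BaseChatAgent, Component[EchoAgentConfig]):
    component_config_schema = EchoAgentConfig

    def __init__(self, name: str, description: str = "An agent that echoes the last message it receives.") -> None:
        super().__init__(name=name, description=description)

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        # Echo the last message back, or an empty string if there is none.
        content = messages[-1].content if messages else ""
        return Response(chat_message=TextMessage(content=str(content), source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass

    def _to_config(self) -> EchoAgentConfig:
        # Capture everything needed to rebuild this agent.
        return EchoAgentConfig(name=self.name, description=self.description)

    @classmethod
    def _from_config(cls, config: EchoAgentConfig) -> Self:
        return cls(name=config.name, description=config.description)


# Round-trip the agent through its declarative form.
agent = EchoAgent("echo")
restored = EchoAgent.load_component(agent.dump_component())
```

Calling `dump_component()` yields a JSON-serializable `ComponentModel`, so the same agent can later be reconstructed from configuration alone via `load_component(...)`.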
Closes #5450 --- python/packages/autogen-core/docs/src/conf.py | 1 + .../agentchat-user-guide/custom-agents.ipynb | 736 ++++++++++++++++++ .../user-guide/agentchat-user-guide/index.md | 4 +- .../agentchat-user-guide/migration-guide.md | 54 +- .../tutorial/custom-agents.ipynb | 313 -------- .../tutorial/messages.ipynb | 6 +- python/packages/autogen-ext/pyproject.toml | 4 + .../_kernel_function_from_tool.py | 2 +- python/uv.lock | 18 + 9 files changed, 796 insertions(+), 342 deletions(-) create mode 100644 python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/custom-agents.ipynb delete mode 100644 python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb diff --git a/python/packages/autogen-core/docs/src/conf.py b/python/packages/autogen-core/docs/src/conf.py index 5c4fcbb0ccfb..bfb8123da32b 100644 --- a/python/packages/autogen-core/docs/src/conf.py +++ b/python/packages/autogen-core/docs/src/conf.py @@ -178,6 +178,7 @@ "user-guide/core-user-guide/framework/command-line-code-executors.ipynb": "user-guide/core-user-guide/components/command-line-code-executors.ipynb", "user-guide/core-user-guide/framework/model-clients.ipynb": "user-guide/core-user-guide/components/model-clients.ipynb", "user-guide/core-user-guide/framework/tools.ipynb": "user-guide/core-user-guide/components/tools.ipynb", + "user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb": "user-guide/agentchat-user-guide/custom-agents.ipynb", } diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/custom-agents.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/custom-agents.ipynb new file mode 100644 index 000000000000..d738e72b60f8 --- /dev/null +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/custom-agents.ipynb @@ -0,0 +1,736 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Custom Agents\n", + "\n", + "You may have agents with behaviors that do not fall into a preset. \n", + "In such cases, you can build custom agents.\n", + "\n", + "All agents in AgentChat inherit from the {py:class}`~autogen_agentchat.agents.BaseChatAgent` \n", + "class and implement the following abstract methods and attributes:\n", + "\n", + "- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n", + "- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n", + "- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n", + "\n", + "Optionally, you can implement the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n", + "uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n", + "that calls the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` method and\n", + "yields all messages in the response."
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CountDownAgent\n", + "\n", + "In this example, we create a simple agent that counts down from a given number to zero,\n", + "and produces a stream of messages with the current count." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3...\n", + "2...\n", + "1...\n", + "Done!\n" + ] + } + ], + "source": [ + "from typing import AsyncGenerator, List, Sequence\n", + "\n", + "from autogen_agentchat.agents import BaseChatAgent\n", + "from autogen_agentchat.base import Response\n", + "from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n", + "from autogen_core import CancellationToken\n", + "\n", + "\n", + "class CountDownAgent(BaseChatAgent):\n", + " def __init__(self, name: str, count: int = 3):\n", + " super().__init__(name, \"A simple agent that counts down.\")\n", + " self._count = count\n", + "\n", + " @property\n", + " def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n", + " return (TextMessage,)\n", + "\n", + " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n", + " # Calls the on_messages_stream.\n", + " response: Response | None = None\n", + " async for message in self.on_messages_stream(messages, cancellation_token):\n", + " if isinstance(message, Response):\n", + " response = message\n", + " assert response is not None\n", + " return response\n", + "\n", + " async def on_messages_stream(\n", + " self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n", + " ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n", + " inner_messages: List[AgentEvent | ChatMessage] = []\n", + " for i in range(self._count, 0, -1):\n", + " msg = TextMessage(content=f\"{i}...\", source=self.name)\n", + " inner_messages.append(msg)\n", + " yield msg\n", + " # The response is returned at the end of the stream.\n", + " # It contains the final message and all the inner messages.\n", + " yield Response(chat_message=TextMessage(content=\"Done!\", source=self.name), inner_messages=inner_messages)\n", + "\n", + " async def on_reset(self, cancellation_token: CancellationToken) -> None:\n", + " pass\n", + "\n", + "\n", + "async def run_countdown_agent() -> None:\n", + " # Create a countdown agent.\n", + " countdown_agent = CountDownAgent(\"countdown\")\n", + "\n", + " # Run the agent with a given task and stream the response.\n", + " async for message in countdown_agent.on_messages_stream([], CancellationToken()):\n", + " if isinstance(message, Response):\n", + " print(message.chat_message.content)\n", + " else:\n", + " print(message.content)\n", + "\n", + "\n", + "# Use asyncio.run(run_countdown_agent()) when running in a script.\n", + "await run_countdown_agent()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ArithmeticAgent\n", + "\n", + "In this example, we create an agent class that can perform simple arithmetic operations\n", + "on a given integer. 
Then, we will use different instances of this agent class\n", + "in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n", + "to transform a given integer into another integer by applying a sequence of arithmetic operations.\n", + "\n", + "The `ArithmeticAgent` class takes an `operator_func` that applies an arithmetic\n", + "operation to an integer and returns the resulting integer.\n", + "In its `on_messages` method, it applies the `operator_func` to the integer in the input message,\n", + "and returns a response with the result." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, Sequence\n", + "\n", + "from autogen_agentchat.agents import BaseChatAgent\n", + "from autogen_agentchat.base import Response\n", + "from autogen_agentchat.conditions import MaxMessageTermination\n", + "from autogen_agentchat.messages import ChatMessage\n", + "from autogen_agentchat.teams import SelectorGroupChat\n", + "from autogen_agentchat.ui import Console\n", + "from autogen_core import CancellationToken\n", + "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", + "\n", + "\n", + "class ArithmeticAgent(BaseChatAgent):\n", + "    def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n", + "        super().__init__(name, description=description)\n", + "        self._operator_func = operator_func\n", + "        self._message_history: List[ChatMessage] = []\n", + "\n", + "    @property\n", + "    def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n", + "        return (TextMessage,)\n", + "\n", + "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n", + "        # Update the message history.\n", + "        # NOTE: it is possible that the messages list is empty, which means the agent was selected previously.\n", + "        self._message_history.extend(messages)\n", + "        # Parse the number in the last message.\n", + "        assert isinstance(self._message_history[-1], TextMessage)\n", + "        number = int(self._message_history[-1].content)\n", + "        # Apply the operator function to the number.\n", + "        result = self._operator_func(number)\n", + "        # Create a new message with the result.\n", + "        response_message = TextMessage(content=str(result), source=self.name)\n", + "        # Update the message history.\n", + "        self._message_history.append(response_message)\n", + "        # Return the response.\n", + "        return Response(chat_message=response_message)\n", + "\n", + "    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n", + "        pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{note}\n", + "The `on_messages` method may be called with an empty list of messages, in which\n", + "case it means the agent was called previously and is now being called again,\n", + "without any new messages from the caller.
So it is important to keep a history\n", + "of the previous messages received by the agent, and use that history to generate\n", + "the response.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with 5 instances of `ArithmeticAgent`:\n", + "\n", + "- one that adds 1 to the input integer,\n", + "- one that subtracts 1 from the input integer,\n", + "- one that multiplies the input integer by 2,\n", + "- one that divides the input integer by 2 and rounds down to the nearest integer, and\n", + "- one that returns the input integer unchanged.\n", + "\n", + "We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n", + "and set the appropriate selector settings:\n", + "\n", + "- allow the same agent to be selected consecutively to allow for repeated operations, and\n", + "- customize the selector prompt to tailor the model's response to the specific task." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Apply the operations to turn the given number into 25.\n", + "---------- user ----------\n", + "10\n", + "---------- multiply_agent ----------\n", + "20\n", + "---------- add_agent ----------\n", + "21\n", + "---------- multiply_agent ----------\n", + "42\n", + "---------- divide_agent ----------\n", + "21\n", + "---------- add_agent ----------\n", + "22\n", + "---------- add_agent ----------\n", + "23\n", + "---------- add_agent ----------\n", + "24\n", + "---------- add_agent ----------\n", + "25\n", + "---------- Summary ----------\n", + "Number of messages: 10\n", + "Finish reason: Maximum number of messages 10 reached, current message count: 10\n", + "Total prompt tokens: 0\n", + "Total completion tokens: 0\n", + "Duration: 2.40 seconds\n" + ] + } + ], + "source": [ + "async def run_number_agents() -> None:\n", + " # Create agents for number operations.\n", + " add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n", + " multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n", + " subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n", + " divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n", + " identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n", + "\n", + " # The termination condition is to stop after 10 messages.\n", + " termination_condition = MaxMessageTermination(10)\n", + "\n", + " # Create a selector group chat.\n", + " selector_group_chat = SelectorGroupChat(\n", + " [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n", + " model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n", + " termination_condition=termination_condition,\n", + " allow_repeated_speaker=True, # Allow the same agent to speak multiple times, necessary for this task.\n", + " selector_prompt=(\n", + " \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n", + " \"Current conversation history:\\n{history}\\n\"\n", + " \"Please select the most appropriate role for the next message, and only return the role name.\"\n", + " ),\n", + " )\n", + "\n", + " # Run the selector group chat with a given task and stream the response.\n", + " 
task: List[ChatMessage] = [\n", + " TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n", + " TextMessage(content=\"10\", source=\"user\"),\n", + " ]\n", + " stream = selector_group_chat.run_stream(task=task)\n", + " await Console(stream)\n", + "\n", + "\n", + "# Use asyncio.run(run_number_agents()) when running in a script.\n", + "await run_number_agents()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "From the output, we can see that the agents have successfully transformed the input integer\n", + "from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using Custom Model Clients in Custom Agents\n", + "\n", + "One of the key features of the {py:class}`~autogen_agentchat.agents.AssistantAgent` preset in AgentChat is that it takes a `model_client` argument and can use it in responding to messages. However, in some cases, you may want your agent to use a custom model client that is not currently supported (see [supported model clients](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/components/model-clients.html)) or custom model behaviours. \n", + "\n", + "You can accomplish this with a custom agent that implements *your custom model client*.\n", + "\n", + "In the example below, we will walk through an example of a custom agent that uses the [Google Gemini SDK](https://github.com/googleapis/python-genai) directly to respond to messages.\n", + "\n", + "> **Note:** You will need to install the [Google Gemini SDK](https://github.com/googleapis/python-genai) to run this example. You can install it using the following command: \n", + "\n", + "```bash\n", + "pip install google-genai\n", + "``` " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install google-genai\n", + "import os\n", + "from typing import AsyncGenerator, Sequence\n", + "\n", + "from autogen_agentchat.agents import BaseChatAgent\n", + "from autogen_agentchat.base import Response\n", + "from autogen_agentchat.messages import AgentEvent, ChatMessage\n", + "from autogen_core import CancellationToken\n", + "from autogen_core.model_context import UnboundedChatCompletionContext\n", + "from autogen_core.models import AssistantMessage, RequestUsage, UserMessage\n", + "from google import genai\n", + "from google.genai import types\n", + "\n", + "\n", + "class GeminiAssistantAgent(BaseChatAgent):\n", + " def __init__(\n", + " self,\n", + " name: str,\n", + " description: str = \"An agent that provides assistance with ability to use tools.\",\n", + " model: str = \"gemini-1.5-flash-002\",\n", + " api_key: str = os.environ[\"GEMINI_API_KEY\"],\n", + " system_message: str\n", + " | None = \"You are a helpful assistant that can respond to messages. 
Reply with TERMINATE when the task has been completed.\",\n", + " ):\n", + " super().__init__(name=name, description=description)\n", + " self._model_context = UnboundedChatCompletionContext()\n", + " self._model_client = genai.Client(api_key=api_key)\n", + " self._system_message = system_message\n", + " self._model = model\n", + "\n", + " @property\n", + " def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n", + " return (TextMessage,)\n", + "\n", + " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n", + " final_response = None\n", + " async for message in self.on_messages_stream(messages, cancellation_token):\n", + " if isinstance(message, Response):\n", + " final_response = message\n", + "\n", + " if final_response is None:\n", + " raise AssertionError(\"The stream should have returned the final result.\")\n", + "\n", + " return final_response\n", + "\n", + " async def on_messages_stream(\n", + " self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n", + " ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n", + " # Add messages to the model context\n", + " for msg in messages:\n", + " await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))\n", + "\n", + " # Get conversation history\n", + " history = [\n", + " (msg.source if hasattr(msg, \"source\") else \"system\")\n", + " + \": \"\n", + " + (msg.content if isinstance(msg.content, str) else \"\")\n", + " + \"\\n\"\n", + " for msg in await self._model_context.get_messages()\n", + " ]\n", + " # Generate response using Gemini\n", + " response = self._model_client.models.generate_content(\n", + " model=self._model,\n", + " contents=f\"History: {history}\\nGiven the history, please provide a response\",\n", + " config=types.GenerateContentConfig(\n", + " system_instruction=self._system_message,\n", + " temperature=0.3,\n", + " ),\n", + " )\n", + "\n", + " # Create usage metadata\n", + " usage = RequestUsage(\n", + " prompt_tokens=response.usage_metadata.prompt_token_count,\n", + " completion_tokens=response.usage_metadata.candidates_token_count,\n", + " )\n", + "\n", + " # Add response to model context\n", + " await self._model_context.add_message(AssistantMessage(content=response.text, source=self.name))\n", + "\n", + " # Yield the final response\n", + " yield Response(\n", + " chat_message=TextMessage(content=response.text, source=self.name, models_usage=usage),\n", + " inner_messages=[],\n", + " )\n", + "\n", + " async def on_reset(self, cancellation_token: CancellationToken) -> None:\n", + " \"\"\"Reset the assistant by clearing the model context.\"\"\"\n", + " await self._model_context.clear()" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "What is the capital of New York?\n", + "---------- gemini_assistant ----------\n", + "Albany\n", + "TERMINATE\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the capital of New York?', type='TextMessage'), TextMessage(source='gemini_assistant', models_usage=RequestUsage(prompt_tokens=46, completion_tokens=5), content='Albany\\nTERMINATE\\n', type='TextMessage')], stop_reason=None)" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gemini_assistant = 
GeminiAssistantAgent(\"gemini_assistant\")\n", + "await Console(gemini_assistant.run_stream(task=\"What is the capital of New York?\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the example above, we have chosen to provide `model`, `api_key` and `system_message` as arguments - you can choose to provide any other arguments that are required by the model client you are using or fits with your application design. \n", + "\n", + "Now, let us explore how to use this custom agent as part of a team in AgentChat." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Write a Haiku poem with 4 lines about the fall season.\n", + "---------- primary ----------\n", + "Crimson leaves cascade, \n", + "Whispering winds sing of change, \n", + "Chill wraps the fading, \n", + "Nature's quilt, rich and warm.\n", + "---------- gemini_critic ----------\n", + "The poem is good, but it has four lines instead of three. A haiku must have three lines with a 5-7-5 syllable structure. The content is evocative of autumn, but the form is incorrect. Please revise to adhere to the haiku's syllable structure.\n", + "\n", + "---------- primary ----------\n", + "Thank you for your feedback! Here’s a revised haiku that follows the 5-7-5 syllable structure:\n", + "\n", + "Crimson leaves drift down, \n", + "Chill winds whisper through the gold, \n", + "Autumn’s breath is near.\n", + "---------- gemini_critic ----------\n", + "The revised haiku is much improved. It correctly follows the 5-7-5 syllable structure and maintains the evocative imagery of autumn. APPROVE\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write a Haiku poem with 4 lines about the fall season.', type='TextMessage'), TextMessage(source='primary', models_usage=RequestUsage(prompt_tokens=33, completion_tokens=31), content=\"Crimson leaves cascade, \\nWhispering winds sing of change, \\nChill wraps the fading, \\nNature's quilt, rich and warm.\", type='TextMessage'), TextMessage(source='gemini_critic', models_usage=RequestUsage(prompt_tokens=86, completion_tokens=60), content=\"The poem is good, but it has four lines instead of three. A haiku must have three lines with a 5-7-5 syllable structure. The content is evocative of autumn, but the form is incorrect. Please revise to adhere to the haiku's syllable structure.\\n\", type='TextMessage'), TextMessage(source='primary', models_usage=RequestUsage(prompt_tokens=141, completion_tokens=49), content='Thank you for your feedback! Here’s a revised haiku that follows the 5-7-5 syllable structure:\\n\\nCrimson leaves drift down, \\nChill winds whisper through the gold, \\nAutumn’s breath is near.', type='TextMessage'), TextMessage(source='gemini_critic', models_usage=RequestUsage(prompt_tokens=211, completion_tokens=32), content='The revised haiku is much improved. It correctly follows the 5-7-5 syllable structure and maintains the evocative imagery of autumn. 
APPROVE\\n', type='TextMessage')], stop_reason=\"Text 'APPROVE' mentioned\")" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from autogen_agentchat.agents import AssistantAgent\n", + "from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n", + "from autogen_agentchat.teams import RoundRobinGroupChat\n", + "from autogen_agentchat.ui import Console\n", + "\n", + "# Create the primary agent.\n", + "primary_agent = AssistantAgent(\n", + " \"primary\",\n", + " model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n", + " system_message=\"You are a helpful AI assistant.\",\n", + ")\n", + "\n", + "# Create a critic agent based on our new GeminiAssistantAgent.\n", + "gemini_critic_agent = GeminiAssistantAgent(\n", + " \"gemini_critic\",\n", + " system_message=\"Provide constructive feedback. Respond with 'APPROVE' when your feedback has been addressed.\",\n", + ")\n", + "\n", + "\n", + "# Define a termination condition that stops the task if the critic approves or after 10 messages.\n", + "termination = TextMentionTermination(\"APPROVE\") | MaxMessageTermination(10)\n", + "\n", + "# Create a team with the primary and critic agents.\n", + "team = RoundRobinGroupChat([primary_agent, gemini_critic_agent], termination_condition=termination)\n", + "\n", + "await Console(team.run_stream(task=\"Write a Haiku poem with 4 lines about the fall season.\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the section above, we demonstrated several important concepts:\n", + "- We have developed a custom agent that uses the Google Gemini SDK to respond to messages. \n", + "- We showed that this custom agent can be used as part of the broader AgentChat ecosystem - in this case as a participant in a {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat`, as long as it inherits from {py:class}`~autogen_agentchat.agents.BaseChatAgent`.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Making the Custom Agent Declarative \n", + "\n", + "AutoGen provides a [Component](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/framework/component-config.html) interface for making the configuration of components serializable to a declarative format. This is useful for saving and loading configurations, and for sharing configurations with others. \n", + "\n", + "We accomplish this by inheriting from the `Component` class and implementing the `_from_config` and `_to_config` methods.\n", + "The declarative class can be serialized to a JSON format using the `dump_component` method, and deserialized from a JSON format using the `load_component` method."
+ ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from typing import AsyncGenerator, Sequence\n", + "\n", + "from autogen_agentchat.agents import BaseChatAgent\n", + "from autogen_agentchat.base import Response\n", + "from autogen_agentchat.messages import AgentEvent, ChatMessage\n", + "from autogen_core import CancellationToken, Component\n", + "from pydantic import BaseModel\n", + "from typing_extensions import Self\n", + "\n", + "\n", + "class GeminiAssistantAgentConfig(BaseModel):\n", + " name: str\n", + " description: str = \"An agent that provides assistance with ability to use tools.\"\n", + " model: str = \"gemini-1.5-flash-002\"\n", + " system_message: str | None = None\n", + "\n", + "\n", + "class GeminiAssistantAgent(BaseChatAgent, Component[GeminiAssistantAgentConfig]): # type: ignore[no-redef]\n", + " component_config_schema = GeminiAssistantAgentConfig\n", + " # component_provider_override = \"mypackage.agents.GeminiAssistantAgent\"\n", + "\n", + " def __init__(\n", + " self,\n", + " name: str,\n", + " description: str = \"An agent that provides assistance with ability to use tools.\",\n", + " model: str = \"gemini-1.5-flash-002\",\n", + " api_key: str = os.environ[\"GEMINI_API_KEY\"],\n", + " system_message: str\n", + " | None = \"You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.\",\n", + " ):\n", + " super().__init__(name=name, description=description)\n", + " self._model_context = UnboundedChatCompletionContext()\n", + " self._model_client = genai.Client(api_key=api_key)\n", + " self._system_message = system_message\n", + " self._model = model\n", + "\n", + " @property\n", + " def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n", + " return (TextMessage,)\n", + "\n", + " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n", + " final_response = None\n", + " async for message in self.on_messages_stream(messages, cancellation_token):\n", + " if isinstance(message, Response):\n", + " final_response = message\n", + "\n", + " if final_response is None:\n", + " raise AssertionError(\"The stream should have returned the final result.\")\n", + "\n", + " return final_response\n", + "\n", + " async def on_messages_stream(\n", + " self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n", + " ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n", + " # Add messages to the model context\n", + " for msg in messages:\n", + " await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))\n", + "\n", + " # Get conversation history\n", + " history = [\n", + " (msg.source if hasattr(msg, \"source\") else \"system\")\n", + " + \": \"\n", + " + (msg.content if isinstance(msg.content, str) else \"\")\n", + " + \"\\n\"\n", + " for msg in await self._model_context.get_messages()\n", + " ]\n", + "\n", + " # Generate response using Gemini\n", + " response = self._model_client.models.generate_content(\n", + " model=self._model,\n", + " contents=f\"History: {history}\\nGiven the history, please provide a response\",\n", + " config=types.GenerateContentConfig(\n", + " system_instruction=self._system_message,\n", + " temperature=0.3,\n", + " ),\n", + " )\n", + "\n", + " # Create usage metadata\n", + " usage = RequestUsage(\n", + " prompt_tokens=response.usage_metadata.prompt_token_count,\n", + " 
completion_tokens=response.usage_metadata.candidates_token_count,\n", + " )\n", + "\n", + " # Add response to model context\n", + " await self._model_context.add_message(AssistantMessage(content=response.text, source=self.name))\n", + "\n", + " # Yield the final response\n", + " yield Response(\n", + " chat_message=TextMessage(content=response.text, source=self.name, models_usage=usage),\n", + " inner_messages=[],\n", + " )\n", + "\n", + " async def on_reset(self, cancellation_token: CancellationToken) -> None:\n", + " \"\"\"Reset the assistant by clearing the model context.\"\"\"\n", + " await self._model_context.clear()\n", + "\n", + " @classmethod\n", + " def _from_config(cls, config: GeminiAssistantAgentConfig) -> Self:\n", + " return cls(\n", + " name=config.name, description=config.description, model=config.model, system_message=config.system_message\n", + " )\n", + "\n", + " def _to_config(self) -> GeminiAssistantAgentConfig:\n", + " return GeminiAssistantAgentConfig(\n", + " name=self.name,\n", + " description=self.description,\n", + " model=self._model,\n", + " system_message=self._system_message,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have implemented the required methods, we can dump the custom agent to a JSON format, and then load the agent back from that JSON.\n", + " \n", + " > Note: You should set the `component_provider_override` class variable to the full path of the module containing the custom agent class, e.g., `mypackage.agents.GeminiAssistantAgent`. This is used by the `load_component` method to determine how to instantiate the class. \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"provider\": \"__main__.GeminiAssistantAgent\",\n", + " \"component_type\": \"agent\",\n", + " \"version\": 1,\n", + " \"component_version\": 1,\n", + " \"description\": null,\n", + " \"label\": \"GeminiAssistantAgent\",\n", + " \"config\": {\n", + " \"name\": \"gemini_assistant\",\n", + " \"description\": \"An agent that provides assistance with ability to use tools.\",\n", + " \"model\": \"gemini-1.5-flash-002\",\n", + " \"system_message\": \"You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.\"\n", + " }\n", + "}\n", + "<__main__.GeminiAssistantAgent object at 0x11a5c5a90>\n" + ] + } + ], + "source": [ + "gemini_assistant = GeminiAssistantAgent(\"gemini_assistant\")\n", + "config = gemini_assistant.dump_component()\n", + "print(config.model_dump_json(indent=2))\n", + "loaded_agent = GeminiAssistantAgent.load_component(config)\n", + "print(loaded_agent)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next Steps \n", + "\n", + "So far, we have seen how to create custom agents, add custom model clients to agents, and make custom agents declarative. There are a few ways in which this basic sample can be extended:\n", + "\n", + "- Extend the Gemini model client to handle function calling, similar to the {py:class}`~autogen_agentchat.agents.AssistantAgent` class. https://ai.google.dev/gemini-api/docs/function-calling \n", + "- Implement a package with a custom agent and experiment with using its declarative format in a tool like [AutoGen Studio](https://microsoft.github.io/autogen/dev/user-guide/autogenstudio-user-guide/index.html)."
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md index e83288338dc7..b7d124f9e67e 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md @@ -101,7 +101,7 @@ tutorial/agents tutorial/teams tutorial/human-in-the-loop tutorial/termination -tutorial/custom-agents + tutorial/state ``` @@ -111,11 +111,13 @@ tutorial/state :hidden: :caption: Advanced +custom-agents selector-group-chat swarm magentic-one memory serialize-components + ``` ```{toctree} diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md index fdb3ca2152d2..de9991609258 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md @@ -33,29 +33,35 @@ We provide a detailed guide on how to migrate your existing codebase from `v0.2` See each feature below for detailed information on how to migrate. -- [Model Client](#model-client) -- [Model Client for OpenAI-Compatible APIs](#model-client-for-openai-compatible-apis) -- [Model Client Cache](#model-client-cache) -- [Assistant Agent](#assistant-agent) -- [Multi-Modal Agent](#multi-modal-agent) -- [User Proxy](#user-proxy) -- [Conversable Agent and Register Reply](#conversable-agent-and-register-reply) -- [Save and Load Agent State](#save-and-load-agent-state) -- [Two-Agent Chat](#two-agent-chat) -- [Tool Use](#tool-use) -- [Chat Result](#chat-result) -- [Conversion between v0.2 and v0.4 Messages](#conversion-between-v02-and-v04-messages) -- [Group Chat](#group-chat) -- [Group Chat with Resume](#group-chat-with-resume) -- [Save and Load Group Chat State](#save-and-load-group-chat-state) -- [Group Chat with Tool Use](#group-chat-with-tool-use) -- [Group Chat with Custom Selector (Stateflow)](#group-chat-with-custom-selector-stateflow) -- [Nested Chat](#nested-chat) -- [Sequential Chat](#sequential-chat) -- [GPTAssistantAgent](#gptassistantagent) -- [Long-Context Handling](#long-context-handling) -- [Observability and Control](#observability-and-control) -- [Code Executors](#code-executors) +- [Migration Guide for v0.2 to v0.4](#migration-guide-for-v02-to-v04) + - [What is `v0.4`?](#what-is-v04) + - [New to AutoGen?](#new-to-autogen) + - [What's in this guide?](#whats-in-this-guide) + - [Model Client](#model-client) + - [Use component config](#use-component-config) + - [Use model client class directly](#use-model-client-class-directly) + - [Model Client for OpenAI-Compatible APIs](#model-client-for-openai-compatible-apis) + - [Model Client Cache](#model-client-cache) + - [Assistant Agent](#assistant-agent) + - [Multi-Modal Agent](#multi-modal-agent) + - [User Proxy](#user-proxy) + - [Conversable Agent and Register Reply](#conversable-agent-and-register-reply) + - 
[Save and Load Agent State](#save-and-load-agent-state) + - [Two-Agent Chat](#two-agent-chat) + - [Tool Use](#tool-use) + - [Chat Result](#chat-result) + - [Conversion between v0.2 and v0.4 Messages](#conversion-between-v02-and-v04-messages) + - [Group Chat](#group-chat) + - [Group Chat with Resume](#group-chat-with-resume) + - [Save and Load Group Chat State](#save-and-load-group-chat-state) + - [Group Chat with Tool Use](#group-chat-with-tool-use) + - [Group Chat with Custom Selector (Stateflow)](#group-chat-with-custom-selector-stateflow) + - [Nested Chat](#nested-chat) + - [Sequential Chat](#sequential-chat) + - [GPTAssistantAgent](#gptassistantagent) + - [Long Context Handling](#long-context-handling) + - [Observability and Control](#observability-and-control) + - [Code Executors](#code-executors) The following features currently in `v0.2` will be provided in the future releases of `v0.4.*` versions: @@ -417,7 +423,7 @@ class CustomAgent(BaseChatAgent): ``` You can then use the custom agent in the same way as the {py:class}`~autogen_agentchat.agents.AssistantAgent`. -See [Custom Agent Tutorial](./tutorial/custom-agents.ipynb) +See [Custom Agent Tutorial](custom-agents.ipynb) for more details. ## Save and Load Agent State diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb deleted file mode 100644 index 5b8c4e7f24fa..000000000000 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb +++ /dev/null @@ -1,313 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Custom Agents\n", - "\n", - "You may have agents with behaviors that do not fall into a preset. \n", - "In such cases, you can build custom agents.\n", - "\n", - "All agents in AgentChat inherit from the {py:class}`~autogen_agentchat.agents.BaseChatAgent` \n", - "class and implement the following abstract methods and attributes:\n", - "\n", - "- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n", - "- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n", - "- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n", - "\n", - "Optionally, you can implement the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n", - "uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n", - "that calls the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` method and\n", - "yields all messages in the response."
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## CountDownAgent\n", - "\n", - "In this example, we create a simple agent that counts down from a given number to zero,\n", - "and produces a stream of messages with the current count." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "3...\n", - "2...\n", - "1...\n", - "Done!\n" - ] - } - ], - "source": [ - "from typing import AsyncGenerator, List, Sequence\n", - "\n", - "from autogen_agentchat.agents import BaseChatAgent\n", - "from autogen_agentchat.base import Response\n", - "from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n", - "from autogen_core import CancellationToken\n", - "\n", - "\n", - "class CountDownAgent(BaseChatAgent):\n", - " def __init__(self, name: str, count: int = 3):\n", - " super().__init__(name, \"A simple agent that counts down.\")\n", - " self._count = count\n", - "\n", - " @property\n", - " def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n", - " return (TextMessage,)\n", - "\n", - " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n", - " # Calls the on_messages_stream.\n", - " response: Response | None = None\n", - " async for message in self.on_messages_stream(messages, cancellation_token):\n", - " if isinstance(message, Response):\n", - " response = message\n", - " assert response is not None\n", - " return response\n", - "\n", - " async def on_messages_stream(\n", - " self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n", - " ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n", - " inner_messages: List[AgentEvent | ChatMessage] = []\n", - " for i in range(self._count, 0, -1):\n", - " msg = TextMessage(content=f\"{i}...\", source=self.name)\n", - " inner_messages.append(msg)\n", - " yield msg\n", - " # The response is returned at the end of the stream.\n", - " # It contains the final message and all the inner messages.\n", - " yield Response(chat_message=TextMessage(content=\"Done!\", source=self.name), inner_messages=inner_messages)\n", - "\n", - " async def on_reset(self, cancellation_token: CancellationToken) -> None:\n", - " pass\n", - "\n", - "\n", - "async def run_countdown_agent() -> None:\n", - " # Create a countdown agent.\n", - " countdown_agent = CountDownAgent(\"countdown\")\n", - "\n", - " # Run the agent with a given task and stream the response.\n", - " async for message in countdown_agent.on_messages_stream([], CancellationToken()):\n", - " if isinstance(message, Response):\n", - " print(message.chat_message.content)\n", - " else:\n", - " print(message.content)\n", - "\n", - "\n", - "# Use asyncio.run(run_countdown_agent()) when running in a script.\n", - "await run_countdown_agent()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## ArithmeticAgent\n", - "\n", - "In this example, we create an agent class that can perform simple arithmetic operations\n", - "on a given integer. 
Then, we will use different instances of this agent class\n", - "in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n", - "to transform a given integer into another integer by applying a sequence of arithmetic operations.\n", - "\n", - "The `ArithmeticAgent` class takes an `operator_func` that applies an arithmetic operation\n", - "to an integer and returns the resulting integer.\n", - "In its `on_messages` method, it applies the `operator_func` to the integer in the input message,\n", - "and returns a response with the result." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Callable, List, Sequence\n", - "\n", - "from autogen_agentchat.agents import BaseChatAgent\n", - "from autogen_agentchat.base import Response\n", - "from autogen_agentchat.conditions import MaxMessageTermination\n", - "from autogen_agentchat.messages import ChatMessage, TextMessage\n", - "from autogen_agentchat.teams import SelectorGroupChat\n", - "from autogen_agentchat.ui import Console\n", - "from autogen_core import CancellationToken\n", - "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", - "\n", - "\n", - "class ArithmeticAgent(BaseChatAgent):\n", - " def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n", - " super().__init__(name, description=description)\n", - " self._operator_func = operator_func\n", - " self._message_history: List[ChatMessage] = []\n", - "\n", - " @property\n", - " def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n", - " return (TextMessage,)\n", - "\n", - " async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n", - " # Update the message history.\n", - " # NOTE: it is possible that messages is an empty list, which means the agent was selected previously.\n", - " self._message_history.extend(messages)\n", - " # Parse the number in the last message.\n", - " assert isinstance(self._message_history[-1], TextMessage)\n", - " number = int(self._message_history[-1].content)\n", - " # Apply the operator function to the number.\n", - " result = self._operator_func(number)\n", - " # Create a new message with the result.\n", - " response_message = TextMessage(content=str(result), source=self.name)\n", - " # Update the message history.\n", - " self._message_history.append(response_message)\n", - " # Return the response.\n", - " return Response(chat_message=response_message)\n", - "\n", - " async def on_reset(self, cancellation_token: CancellationToken) -> None:\n", - " pass" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```{note}\n", - "The `on_messages` method may be called with an empty list of messages, in which\n", - "case it means the agent was called previously and is now being called again,\n", - "without any new messages from the caller. 
So it is important to keep a history\n", - "of the previous messages received by the agent, and use that history to generate\n", - "the response.\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with 5 instances of `ArithmeticAgent`:\n", - "\n", - "- one that adds 1 to the input integer,\n", - "- one that subtracts 1 from the input integer,\n", - "- one that multiplies the input integer by 2,\n", - "- one that divides the input integer by 2 and rounds down to the nearest integer, and\n", - "- one that returns the input integer unchanged.\n", - "\n", - "We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n", - "and set the appropriate selector settings:\n", - "\n", - "- allow the same agent to be selected consecutively to allow for repeated operations, and\n", - "- customize the selector prompt to tailor the model's response to the specific task." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------- user ----------\n", - "Apply the operations to turn the given number into 25.\n", - "---------- user ----------\n", - "10\n", - "---------- multiply_agent ----------\n", - "20\n", - "---------- add_agent ----------\n", - "21\n", - "---------- multiply_agent ----------\n", - "42\n", - "---------- divide_agent ----------\n", - "21\n", - "---------- add_agent ----------\n", - "22\n", - "---------- add_agent ----------\n", - "23\n", - "---------- add_agent ----------\n", - "24\n", - "---------- add_agent ----------\n", - "25\n", - "---------- Summary ----------\n", - "Number of messages: 10\n", - "Finish reason: Maximum number of messages 10 reached, current message count: 10\n", - "Total prompt tokens: 0\n", - "Total completion tokens: 0\n", - "Duration: 2.40 seconds\n" - ] - } - ], - "source": [ - "async def run_number_agents() -> None:\n", - " # Create agents for number operations.\n", - " add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n", - " multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n", - " subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n", - " divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n", - " identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n", - "\n", - " # The termination condition is to stop after 10 messages.\n", - " termination_condition = MaxMessageTermination(10)\n", - "\n", - " # Create a selector group chat.\n", - " selector_group_chat = SelectorGroupChat(\n", - " [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n", - " model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n", - " termination_condition=termination_condition,\n", - " allow_repeated_speaker=True, # Allow the same agent to speak multiple times, necessary for this task.\n", - " selector_prompt=(\n", - " \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n", - " \"Current conversation history:\\n{history}\\n\"\n", - " \"Please select the most appropriate role for the next message, and only return the role name.\"\n", - " ),\n", - " )\n", - "\n", - " # Run the selector group chat with a given task and stream the response.\n", - " 
task: List[ChatMessage] = [\n", - " TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n", - " TextMessage(content=\"10\", source=\"user\"),\n", - " ]\n", - " stream = selector_group_chat.run_stream(task=task)\n", - " await Console(stream)\n", - "\n", - "\n", - "# Use asyncio.run(run_number_agents()) when running in a script.\n", - "await run_number_agents()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "From the output, we can see that the agents have successfully transformed the input integer\n", - "from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/messages.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/messages.ipynb index 25dc78641980..a3f5a25519dd 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/messages.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/messages.ipynb @@ -97,7 +97,7 @@ "\n", "Examples of these include {py:class}`~autogen_agentchat.messages.ToolCallRequestEvent`, which indicates that a request was made to call a tool, and {py:class}`~autogen_agentchat.messages.ToolCallExecutionEvent`, which contains the results of tool calls.\n", "\n", - "Typically, events are created by the agent itself and are contained in the {py:attr}`~autogen_agentchat.base.Response.inner_messages` field of the {py:class}`~autogen_agentchat.base.Response` returned from {py:class}`~autogen_agentchat.base.ChatAgent.on_messages`. If you are building a custom agent and have events that you want to communicate to other entities (e.g., a UI), you can include these in the {py:attr}`~autogen_agentchat.base.Response.inner_messages` field of the {py:class}`~autogen_agentchat.base.Response`. We will show examples of this in [Custom Agents](./custom-agents.ipynb).\n", + "Typically, events are created by the agent itself and are contained in the {py:attr}`~autogen_agentchat.base.Response.inner_messages` field of the {py:class}`~autogen_agentchat.base.Response` returned from {py:class}`~autogen_agentchat.base.ChatAgent.on_messages`. If you are building a custom agent and have events that you want to communicate to other entities (e.g., a UI), you can include these in the {py:attr}`~autogen_agentchat.base.Response.inner_messages` field of the {py:class}`~autogen_agentchat.base.Response`. We will show examples of this in [Custom Agents](../custom-agents.ipynb).\n", "\n", "\n", "You can read about the full set of messages supported in AgentChat in the {py:mod}`~autogen_agentchat.messages` module. 
\n", @@ -107,7 +107,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "agnext", "language": "python", "name": "python3" }, @@ -121,7 +121,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index bf86c3e4004e..c5e2b62bb627 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -70,6 +70,10 @@ semantic-kernel-core = [ "semantic-kernel>=1.17.1", ] +gemini = [ + "google-genai>=1.0.0", +] + semantic-kernel-google = [ "semantic-kernel[google]>=1.17.1", ] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py b/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py index 8125aef9c5ce..0f3f93e2e7fa 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py @@ -3,8 +3,8 @@ from autogen_core import CancellationToken from autogen_core.tools import BaseTool from pydantic import BaseModel -from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata from semantic_kernel.functions import KernelFunctionFromMethod, kernel_function +from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata InputT = TypeVar("InputT", bound=BaseModel) OutputT = TypeVar("OutputT", bound=BaseModel) diff --git a/python/uv.lock b/python/uv.lock index 17c0b76c4d4a..ad56714d25a7 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -593,6 +593,9 @@ file-surfer = [ { name = "autogen-agentchat" }, { name = "markitdown" }, ] +gemini = [ + { name = "google-genai" }, +] graphrag = [ { name = "graphrag" }, ] @@ -700,6 +703,7 @@ requires-dist = [ { name = "diskcache", marker = "extra == 'diskcache'", specifier = ">=5.6.3" }, { name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" }, { name = "ffmpeg-python", marker = "extra == 'video-surfer'" }, + { name = "google-genai", marker = "extra == 'gemini'", specifier = ">=1.0.0" }, { name = "graphrag", marker = "extra == 'graphrag'", specifier = ">=1.0.1" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = "~=1.70.0" }, { name = "httpx", marker = "extra == 'http-tool'", specifier = ">=0.27.0" }, @@ -2221,6 +2225,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/fb/54deefe679b7d1c1cc81d83396fcf28ad1a66d213bddeb275a8d28665918/google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e311c64008f1f1379158158bb3f0c8d72635b9eb4f9545f8cf990c5668e59d", size = 27866 }, ] +[[package]] +name = "google-genai" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "websockets" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/30/25443d2dec5fe4efd9f05440e13735ae68b19de2dea043b3b83a91a4e14b/google_genai-1.1.0-py3-none-any.whl", hash = "sha256:c48ac44612ad6aadc0bf96b12fa4314756baa16382c890fff793bcb53e9a9cc8", size = 130299 }, +] + [[package]] name = "google-generativeai" version = "0.8.4" From f9d4a844dfe5a6f85828b25ff453a3f7aaaa53dd Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 10 Feb 2025 16:49:00 
-0800 Subject: [PATCH 34/36] Update website version v0.4.6 (#5481) --- .github/workflows/docs.yml | 3 ++- docs/switcher.json | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 1129930a64f4..ee8bffb7b5a5 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -33,7 +33,7 @@ jobs: [ # For main use the workflow target { ref: "${{github.ref}}", dest-dir: dev, uv-version: "0.5.13", sphinx-release-override: "dev" }, - { ref: "python-v0.4.5", dest-dir: stable, uv-version: "0.5.13", sphinx-release-override: "stable" }, + { ref: "python-v0.4.6", dest-dir: stable, uv-version: "0.5.13", sphinx-release-override: "stable" }, { ref: "v0.4.0.dev0", dest-dir: "0.4.0.dev0", uv-version: "0.5.11", sphinx-release-override: "" }, { ref: "v0.4.0.dev1", dest-dir: "0.4.0.dev1", uv-version: "0.5.11", sphinx-release-override: "" }, { ref: "v0.4.0.dev2", dest-dir: "0.4.0.dev2", uv-version: "0.5.11", sphinx-release-override: "" }, @@ -54,6 +54,7 @@ { ref: "v0.4.3", dest-dir: "0.4.3", uv-version: "0.5.13", sphinx-release-override: "" }, { ref: "v0.4.4", dest-dir: "0.4.4", uv-version: "0.5.13", sphinx-release-override: "" }, { ref: "python-v0.4.5", dest-dir: "0.4.5", uv-version: "0.5.13", sphinx-release-override: "" }, + { ref: "python-v0.4.6", dest-dir: "0.4.6", uv-version: "0.5.13", sphinx-release-override: "" }, ] steps: - name: Checkout diff --git a/docs/switcher.json b/docs/switcher.json index a31f2ca58fec..341ee23ffa11 100644 --- a/docs/switcher.json +++ b/docs/switcher.json @@ -1,6 +1,6 @@ [ { - "name": "0.4.5 (stable)", + "name": "0.4.6 (stable)", "version": "stable", "url": "/autogen/stable/", "preferred": true From a9db38461fed69b381cca454f7bbb25e8f503779 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 10 Feb 2025 23:55:13 -0800 Subject: [PATCH 35/36] doc: Update API doc for MCP tool to include installation instructions (#5482) --- .../autogen-ext/src/autogen_ext/tools/mcp/_factory.py | 8 ++++++++ .../autogen-ext/src/autogen_ext/tools/mcp/_sse.py | 8 ++++++++ .../autogen-ext/src/autogen_ext/tools/mcp/_stdio.py | 9 +++++++++ 3 files changed, 25 insertions(+) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py index fcb29e3c4a99..3eb8634b3698 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py @@ -12,6 +12,14 @@ async def mcp_server_tools( This factory function connects to an MCP server and returns adapters for all available tools. The adapters can be directly assigned to an AutoGen agent's tools list. + .. note:: + + To use this function, you need to install the `mcp` extra for the `autogen-ext` package. + + .. code-block:: bash + + pip install -U "autogen-ext[mcp]" + Args: server_params (McpServerParams): Connection parameters for the MCP server. Can be either StdioServerParams for command-line tools or diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py index 7b875649a4f4..252af7ce50da 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py @@ -25,6 +25,14 @@ class SseMcpToolAdapter( with AutoGen agents. 
Common use cases include integrating with remote MCP services, cloud-based tools, and web APIs that implement the Model Context Protocol (MCP). + .. note:: + + To use this class, you need to install the `mcp` extra for the `autogen-ext` package. + + .. code-block:: bash + + pip install -U "autogen-ext[mcp]" + Args: server_params (SseServerParameters): Parameters for the MCP server connection, including URL, headers, and timeouts diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py index 629b1a46ec1a..4f827785e903 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py @@ -24,6 +24,15 @@ class StdioMcpToolAdapter( with AutoGen agents. Common use cases include wrapping command-line tools and local services that implement the Model Context Protocol (MCP). + .. note:: + + To use this class, you need to install the `mcp` extra for the `autogen-ext` package. + + .. code-block:: bash + + pip install -U "autogen-ext[mcp]" + + Args: server_params (StdioServerParams): Parameters for the MCP server connection, including command to run and its arguments From 540c4fb3457c1b54e93a74f3f0a29a54c2b19c07 Mon Sep 17 00:00:00 2001 From: Andrej Kyselica Date: Tue, 11 Feb 2025 13:17:34 -0500 Subject: [PATCH 36/36] feat: DockerCommandLineCodeExecutor support for additional volume mounts, exposed host ports (#5383) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the following additional configuration options to DockerCommandLineCodeExecutor: - **extra_volumes** (Optional[Dict[str, Dict[str, str]]], optional): A dictionary of extra volumes (beyond the work_dir) to mount to the container. Defaults to None. - **extra_hosts** (Optional[Dict[str, str]], optional): A dictionary of host mappings to add to the container. (See Docker docs on extra_hosts) Defaults to None. - **init_command** (Optional[str], optional): A shell command to run before each shell operation execution. Defaults to None. ## Why are these changes needed? See linked issue below. In summary: Enable the agents to: - work with a richer set of sys admin tools on top of code execution - add support for a 'project' directory the agents can interact with that's accessible to bash tools and custom scripts ## Related issue number Closes #5363 ## Checks - [x] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. 
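## Example usage

A minimal sketch showing the three new options together (the mounted path, host mapping, and init command are illustrative values drawn from the docstring examples and tests, not required settings; a running Docker daemon is assumed):

```python
import asyncio
import tempfile

from autogen_core import CancellationToken
from autogen_core.code_executor import CodeBlock
from autogen_ext.code_executors.docker import DockerCommandLineCodeExecutor


async def main() -> None:
    with tempfile.TemporaryDirectory() as work_dir:
        async with DockerCommandLineCodeExecutor(
            work_dir=work_dir,
            # Mount an extra host directory read-only at /mnt/vol1 (beyond work_dir).
            extra_volumes={"/var/www": {"bind": "/mnt/vol1", "mode": "ro"}},
            # Add a host-to-IP mapping inside the container.
            extra_hosts={"example.com": "127.0.0.1"},
            # Shell command that runs before shell operation execution.
            init_command="echo 'ready' > /workspace/ready.txt",
        ) as executor:
            # work_dir is mounted at /workspace inside the container.
            result = await executor.execute_code_blocks(
                [CodeBlock(code="print(open('/workspace/ready.txt').read())", language="python")],
                CancellationToken(),
            )
            print(result.exit_code, result.output)


asyncio.run(main())
```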
--- .../docker/_docker_code_executor.py | 26 +++++++++-- .../test_docker_commandline_code_executor.py | 45 +++++++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/code_executors/docker/_docker_code_executor.py b/python/packages/autogen-ext/src/autogen_ext/code_executors/docker/_docker_code_executor.py index 05924e186643..1ac058a9680f 100644 --- a/python/packages/autogen-ext/src/autogen_ext/code_executors/docker/_docker_code_executor.py +++ b/python/packages/autogen-ext/src/autogen_ext/code_executors/docker/_docker_code_executor.py @@ -12,7 +12,7 @@ from hashlib import sha256 from pathlib import Path from types import TracebackType -from typing import Any, Callable, ClassVar, List, Optional, ParamSpec, Type, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, ParamSpec, Type, Union from autogen_core import CancellationToken from autogen_core.code_executor import ( @@ -88,6 +88,13 @@ class DockerCommandLineCodeExecutor(CodeExecutor): the Python process exits with atexit. Defaults to True. functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list. functions_module (str, optional): The name of the module that will be created to store the functions. Defaults to "functions". + extra_volumes (Optional[Dict[str, Dict[str, str]]], optional): A dictionary of extra volumes (beyond the work_dir) to mount to the container; + key is the host source path and value 'bind' is the container path. See Docker docs on volumes. Defaults to None. + Example: extra_volumes = {'/home/user1/': {'bind': '/mnt/vol2', 'mode': 'rw'}, '/var/www': {'bind': '/mnt/vol1', 'mode': 'ro'}} + extra_hosts (Optional[Dict[str, str]], optional): A dictionary of host mappings to add to the container. (See Docker docs on extra_hosts) Defaults to None. + Example: extra_hosts = {"kubernetes.docker.internal": "host-gateway"} + init_command (Optional[str], optional): A shell command to run before each shell operation execution. Defaults to None. + Example: init_command="kubectl config use-context docker-hub" """ SUPPORTED_LANGUAGES: ClassVar[List[str]] = [ @@ -126,6 +133,9 @@ def __init__( ] ] = [], functions_module: str = "functions", + extra_volumes: Optional[Dict[str, Dict[str, str]]] = None, + extra_hosts: Optional[Dict[str, str]] = None, + init_command: Optional[str] = None, ): if timeout < 1: raise ValueError("Timeout must be greater than or equal to 1.") @@ -157,6 +167,10 @@ def __init__( self._functions_module = functions_module self._functions = functions + self._extra_volumes = extra_volumes if extra_volumes is not None else {} + self._extra_hosts = extra_hosts if extra_hosts is not None else {} + self._init_command = init_command + # Setup could take some time so we intentionally wait for the first code block to do it. if len(functions) > 0: self._setup_functions_complete = False @@ -354,16 +368,22 @@ async def start(self) -> None: # Let the docker exception escape if this fails. 
await asyncio.to_thread(client.images.pull, self._image) + # Prepare the command (if needed) + shell_command = "/bin/sh" + command = ["-c", f"{(self._init_command)};exec {shell_command}"] if self._init_command else None + self._container = await asyncio.to_thread( client.containers.create, self._image, name=self.container_name, - entrypoint="/bin/sh", + entrypoint=shell_command, + command=command, tty=True, detach=True, auto_remove=self._auto_remove, - volumes={str(self._bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}}, + volumes={str(self._bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}, **self._extra_volumes}, working_dir="/workspace", + extra_hosts=self._extra_hosts, ) await asyncio.to_thread(self._container.start) diff --git a/python/packages/autogen-ext/tests/code_executors/test_docker_commandline_code_executor.py b/python/packages/autogen-ext/tests/code_executors/test_docker_commandline_code_executor.py index ad74236008fa..6c65835d183d 100644 --- a/python/packages/autogen-ext/tests/code_executors/test_docker_commandline_code_executor.py +++ b/python/packages/autogen-ext/tests/code_executors/test_docker_commandline_code_executor.py @@ -164,3 +164,48 @@ async def test_docker_commandline_code_executor_start_stop_context_manager() -> with tempfile.TemporaryDirectory() as temp_dir: async with DockerCommandLineCodeExecutor(work_dir=temp_dir) as _exec: pass + + +@pytest.mark.asyncio +async def test_docker_commandline_code_executor_extra_args() -> None: + if not docker_tests_enabled(): + pytest.skip("Docker tests are disabled") + + with tempfile.TemporaryDirectory() as temp_dir: + # Create a file in temp_dir to mount + host_file_path = Path(temp_dir) / "host_file.txt" + host_file_path.write_text("This is a test file.") + + container_file_path = "/container/host_file.txt" + + extra_volumes = {str(host_file_path): {"bind": container_file_path, "mode": "rw"}} + init_command = "echo 'Initialization command executed' > /workspace/init_command.txt" + extra_hosts = {"example.com": "127.0.0.1"} + + async with DockerCommandLineCodeExecutor( + work_dir=temp_dir, + extra_volumes=extra_volumes, + init_command=init_command, + extra_hosts=extra_hosts, + ) as executor: + cancellation_token = CancellationToken() + + # Verify init_command was executed + init_command_file_path = Path(temp_dir) / "init_command.txt" + assert init_command_file_path.exists() + + # Verify extra_hosts + ns_lookup_code_blocks = [ + CodeBlock(code="import socket; print(socket.gethostbyname('example.com'))", language="python") + ] + ns_lookup_result = await executor.execute_code_blocks(ns_lookup_code_blocks, cancellation_token) + assert ns_lookup_result.exit_code == 0 + assert "127.0.0.1" in ns_lookup_result.output + + # Verify the file is accessible in the volume mounted in extra_volumes + code_blocks = [ + CodeBlock(code=f"with open('{container_file_path}') as f: print(f.read())", language="python") + ] + code_result = await executor.execute_code_blocks(code_blocks, cancellation_token) + assert code_result.exit_code == 0 + assert "This is a test file." in code_result.output