-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Daniel Balsam
authored and
Daniel Balsam
committed
Apr 8, 2024
1 parent
3b67c1d
commit f186219
Showing
10 changed files
with
23,115 additions
and
569 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import sys\n", | ||
"module_path = os.path.abspath(os.path.join('../'))\n", | ||
"if module_path not in sys.path:\n", | ||
" sys.path.append(module_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"from dotenv import load_dotenv\n", | ||
"\n", | ||
"load_dotenv()\n", | ||
"\n", | ||
"from surv_ai.lib.log import logger, AgentLogLevel\n", | ||
"import logging\n", | ||
"\n", | ||
"logger.set_log_level(AgentLogLevel.OUTPUT)\n", | ||
"logging.basicConfig(level=logging.INFO)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"INFO:surv_ai:\u001b[1m\u001b[37m...Using tool: SEARCH(\"debt ceiling crisis responsibility Republicans\")...\u001b[0m\n", | ||
"ERROR:surv_ai:Could not retrieve all pages.\n", | ||
"Traceback (most recent call last):\n", | ||
" File \"/Users/danielbalsam/surv_ai/surv_ai/surv_ai/lib/tools/query/google_custom_search.py\", line 66, in _search\n", | ||
" new_records = data[\"items\"]\n", | ||
" ~~~~^^^^^^^^^\n", | ||
"KeyError: 'items'\n", | ||
"INFO:surv_ai:\u001b[1m\u001b[37m...Using tool: SEARCH(\"responsibility for impending debt ceiling crisis\")...\u001b[0m\n", | ||
"ERROR:surv_ai:Could not retrieve all pages.\n", | ||
"Traceback (most recent call last):\n", | ||
" File \"/Users/danielbalsam/surv_ai/surv_ai/surv_ai/lib/tools/query/google_custom_search.py\", line 66, in _search\n", | ||
" new_records = data[\"items\"]\n", | ||
" ~~~~^^^^^^^^^\n", | ||
"KeyError: 'items'\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from surv_ai import (\n", | ||
" GPTClient,\n", | ||
" Model,\n", | ||
" ToolBelt,\n", | ||
" GoogleCustomSearchTool,\n", | ||
" Knowledge,\n", | ||
" Survey,\n", | ||
" SurveyParameter\n", | ||
")\n", | ||
"\n", | ||
"client = GPTClient(os.environ[\"OPEN_AI_API_KEY\"])\n", | ||
"\n", | ||
"def build_parameter(news_source: str):\n", | ||
" tool_belt = ToolBelt(\n", | ||
" tools=[\n", | ||
" GoogleCustomSearchTool(\n", | ||
" google_api_key=os.environ[\"GOOGLE_API_KEY\"],\n", | ||
" google_search_engine_id=os.environ[\"GOOGLE_SEARCH_ENGINE_ID\"],\n", | ||
" n_pages=30,\n", | ||
" start_date=\"2023-05-01\",\n", | ||
" end_date=\"2023-06-01\",\n", | ||
" ),\n", | ||
" ],\n", | ||
" )\n", | ||
" base_knowledge = [\n", | ||
" Knowledge(\n", | ||
" text=f\"It is currently 2023-06-01. The included articles were published between 2023-05-01 and 2023-06-01\",\n", | ||
" source=\"Additional context\",\n", | ||
" ),\n", | ||
" ]\n", | ||
" return SurveyParameter(\n", | ||
" independent_variable=news_source,\n", | ||
" kwargs={\n", | ||
" \"client\": client,\n", | ||
" \"n_agents\": 100,\n", | ||
" \"max_knowledge_per_agent\":5,\n", | ||
" \"max_concurrency\": 10,\n", | ||
" \"tool_belt\": tool_belt,\n", | ||
" \"base_knowledge\": base_knowledge,\n", | ||
" },\n", | ||
" )\n", | ||
"\n", | ||
"news_sources = [\n", | ||
" \"nytimes.com\",\n", | ||
" \"cnn.com\",\n", | ||
" \"wsj.com\",\n", | ||
" \"foxnews.com\",\n", | ||
"]\n", | ||
"\n", | ||
"model = Model(\n", | ||
" Survey,\n", | ||
" parameters=[build_parameter(news_source) for news_source in news_sources],\n", | ||
")\n", | ||
"results = await model.build(\n", | ||
" \"Republicans are responsible for the impending debt ceiling crisis.\"\n", | ||
")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from matplotlib import pyplot\n", | ||
"\n", | ||
"variables = model.get_plot_variables(results)\n", | ||
"pyplot.scatter(*variables)\n", | ||
"pyplot.ylabel(\"Agreement (n_agents=100)\")\n", | ||
"pyplot.title(\"Republicans were responsible for the impending debt ceiling crisis.\")\n", | ||
"pyplot.ylim(0, 0.5)\n", | ||
"pyplot.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from surv_ai import (\n", | ||
" GPTClient,\n", | ||
" AnthropicClient,\n", | ||
" Model,\n", | ||
" ToolBelt,\n", | ||
" GoogleCustomSearchTool,\n", | ||
" Knowledge,\n", | ||
" Survey,\n", | ||
" SurveyParameter,\n", | ||
" LargeLanguageModelClientInterface\n", | ||
")\n", | ||
"\n", | ||
"clients = [AnthropicClient(os.environ[\"ANTHROPIC_API_KEY\"]), GPTClient(os.environ[\"OPEN_AI_API_KEY\"])]\n", | ||
"\n", | ||
"def build_parameter(client: LargeLanguageModelClientInterface):\n", | ||
" tool_belt = ToolBelt(\n", | ||
" tools=[\n", | ||
" GoogleCustomSearchTool(\n", | ||
" google_api_key=os.environ[\"GOOGLE_API_KEY\"],\n", | ||
" google_search_engine_id=os.environ[\"GOOGLE_SEARCH_ENGINE_ID\"],\n", | ||
" n_pages=20,\n", | ||
" start_date=\"2023-01-01\",\n", | ||
" end_date=\"2024-05-01\",\n", | ||
" )\n", | ||
" ],\n", | ||
" )\n", | ||
" base_knowledge = [\n", | ||
" Knowledge(\n", | ||
" text=f\"It is currently 2023-06-01. The included articles were published between 2023-01-01 and 2023-06-01\",\n", | ||
" source=\"Additional context\",\n", | ||
" ),\n", | ||
" ]\n", | ||
" return SurveyParameter(\n", | ||
" independent_variable=client.__class__.__name__,\n", | ||
" kwargs={\n", | ||
" \"client\": client,\n", | ||
" \"n_agents\": 100,\n", | ||
" \"max_knowledge_per_agent\":5,\n", | ||
" \"max_concurrency\": 3,\n", | ||
" \"tool_belt\": tool_belt,\n", | ||
" \"base_knowledge\": base_knowledge,\n", | ||
" },\n", | ||
" )\n", | ||
"\n", | ||
"model = Model(\n", | ||
" Survey,\n", | ||
" parameters=[build_parameter(client) for client in clients],\n", | ||
")\n", | ||
"results = await model.build(\n", | ||
" \"OpenAI has been irresponsible in their handling of AI technology.\"\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from matplotlib import pyplot\n", | ||
"\n", | ||
"variables = model.get_plot_variables(results)\n", | ||
"pyplot.scatter(*variables)\n", | ||
"pyplot.ylabel(\"Agreement (n_agents=100)\")\n", | ||
"pyplot.title(\"OpenAI has been irresponsible in their handling of AI technology.\")\n", | ||
"pyplot.ylim(0, 0.5)\n", | ||
"pyplot.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "superintelligence-kEKUOLhR-py3.11", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.6" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.