diff --git a/main.py b/main.py index 05d5537..e059e06 100644 --- a/main.py +++ b/main.py @@ -23,13 +23,19 @@ def main() -> int: parser = argparse.ArgumentParser(description="Run the browser agent with a query.") - parser.add_argument( + query_group = parser.add_mutually_exclusive_group(required=True) + query_group.add_argument( "--query", type=str, - required=True, help="The query for the browser agent to execute.", ) - + query_group.add_argument( + "--query_file", + type=argparse.FileType("r", encoding="utf-8"), + help=( + "Path to a file containing the query for the browser agent to execute." + ), + ) parser.add_argument( "--env", type=str, @@ -56,6 +62,18 @@ def main() -> int: ) args = parser.parse_args() + if args.query_file: + with args.query_file as f: + filename = f.name + query = f.read() + if not query.strip(): + raise ValueError( + f"Query file '{filename}' is empty or contains only whitespace." + ) + else: + # This branch is taken if --query is used. + query = args.query + if args.env == "playwright": env = PlaywrightComputer( screen_size=PLAYWRIGHT_SCREEN_SIZE, @@ -73,7 +91,7 @@ def main() -> int: with env as browser_computer: agent = BrowserAgent( browser_computer=browser_computer, - query=args.query, + query=query, model_name=args.model, ) agent.agent_loop() diff --git a/test_main.py b/test_main.py index 4bee9ff..bf38890 100644 --- a/test_main.py +++ b/test_main.py @@ -13,7 +13,7 @@ # limitations under the License. import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, mock_open import main class TestMain(unittest.TestCase): @@ -27,6 +27,7 @@ def test_main_playwright(self, mock_browser_agent, mock_playwright_computer, moc mock_args.initial_url = 'test_url' mock_args.highlight_mouse = True mock_args.query = 'test_query' + mock_args.query_file = None mock_args.model = 'test_model' mock_args.api_server = None mock_args.api_server_key = None @@ -39,7 +40,11 @@ def test_main_playwright(self, mock_browser_agent, mock_playwright_computer, moc initial_url='test_url', highlight_mouse=True ) - mock_browser_agent.assert_called_once() + mock_browser_agent.assert_called_once_with( + browser_computer=mock_playwright_computer.return_value.__enter__.return_value, + query='test_query', + model_name='test_model', + ) mock_browser_agent.return_value.agent_loop.assert_called_once() @patch('main.argparse.ArgumentParser') @@ -49,6 +54,7 @@ def test_main_browserbase(self, mock_browser_agent, mock_browserbase_computer, m mock_args = MagicMock() mock_args.env = 'browserbase' mock_args.query = 'test_query' + mock_args.query_file = None mock_args.model = 'test_model' mock_args.api_server = None mock_args.api_server_key = None @@ -62,8 +68,61 @@ def test_main_browserbase(self, mock_browser_agent, mock_browserbase_computer, m screen_size=main.PLAYWRIGHT_SCREEN_SIZE, initial_url='test_url' ) - mock_browser_agent.assert_called_once() + mock_browser_agent.assert_called_once_with( + browser_computer=mock_browserbase_computer.return_value.__enter__.return_value, + query='test_query', + model_name='test_model', + ) mock_browser_agent.return_value.agent_loop.assert_called_once() + @patch('main.argparse.ArgumentParser') + @patch('main.PlaywrightComputer') + @patch('main.BrowserAgent') + def test_main_with_query_file(self, mock_browser_agent, mock_playwright_computer, mock_arg_parser): + mock_args = MagicMock() + mock_args.env = 'playwright' + mock_args.initial_url = 'test_url' + mock_args.highlight_mouse = False + mock_args.query = None + # mock_open is used to simulate opening a file + m = mock_open(read_data='query from file') + # The file object needs a 'name' attribute for the error message + m.return_value.name = 'test_query.txt' + mock_args.query_file = m() + mock_args.model = 'test_model' + mock_args.api_server = None + mock_args.api_server_key = None + mock_arg_parser.return_value.parse_args.return_value = mock_args + + main.main() + + mock_playwright_computer.assert_called_once_with( + screen_size=main.PLAYWRIGHT_SCREEN_SIZE, + initial_url='test_url', + highlight_mouse=False + ) + mock_browser_agent.assert_called_once_with( + browser_computer=mock_playwright_computer.return_value.__enter__.return_value, + query='query from file', + model_name='test_model', + ) + mock_browser_agent.return_value.agent_loop.assert_called_once() + + @patch('main.argparse.ArgumentParser') + def test_main_with_empty_query_file(self, mock_arg_parser): + mock_args = MagicMock() + mock_args.query = None + # Simulate an empty file + m = mock_open(read_data=' ') + m.return_value.name = 'empty.txt' + mock_args.query_file = m() + mock_arg_parser.return_value.parse_args.return_value = mock_args + + with self.assertRaisesRegex( + ValueError, "Query file 'empty.txt' is empty or contains only whitespace." + ): + main.main() + + if __name__ == '__main__': unittest.main()