diff --git a/scrapy_playwright/_utils.py b/scrapy_playwright/_utils.py index 26598289..c739a0e0 100644 --- a/scrapy_playwright/_utils.py +++ b/scrapy_playwright/_utils.py @@ -1,4 +1,5 @@ import asyncio +import inspect import logging import platform import threading @@ -8,6 +9,7 @@ from playwright.async_api import Error, Page, Request, Response from scrapy.http.headers import Headers from scrapy.settings import Settings +from scrapy.utils.asyncgen import collect_asyncgen from scrapy.utils.python import to_unicode from twisted.internet.defer import Deferred from w3lib.encoding import html_body_declared_encoding, http_content_type_encoding @@ -117,6 +119,7 @@ class _ThreadedLoopAdapter: @classmethod async def _handle_coro(cls, coro, future) -> None: try: + coro = collect_asyncgen(coro) if inspect.isasyncgen(coro) else coro future.set_result(await coro) except Exception as exc: future.set_exception(exc) @@ -129,10 +132,14 @@ async def _process_queue(cls) -> None: cls._coro_queue.task_done() @classmethod - def _deferred_from_coro(cls, coro) -> Deferred: + def _ensure_future(cls, coro: Awaitable) -> asyncio.Future: future: asyncio.Future = asyncio.Future() asyncio.run_coroutine_threadsafe(cls._coro_queue.put((coro, future)), cls._loop) - return scrapy.utils.defer.deferred_from_coro(future) + return future + + @classmethod + def _deferred_from_coro(cls, coro: Awaitable) -> Deferred: + return scrapy.utils.defer.deferred_from_coro(cls._ensure_future(coro)) @classmethod def start(cls, caller_id: int) -> None: diff --git a/scrapy_playwright/handler.py b/scrapy_playwright/handler.py index b475d615..d45c5c23 100644 --- a/scrapy_playwright/handler.py +++ b/scrapy_playwright/handler.py @@ -142,7 +142,10 @@ def __init__(self, crawler: Crawler) -> None: self.config = Config.from_settings(crawler.settings) if self.config.use_threaded_loop: + logger.warning("Starting threaded loop") _ThreadedLoopAdapter.start(id(self)) + else: + logger.warning("NOT starting threaded loop") 
self.browser_launch_lock = asyncio.Lock() self.context_launch_lock = asyncio.Lock() diff --git a/scrapy_playwright/utils.py b/scrapy_playwright/utils.py new file mode 100644 index 00000000..a40a7e19 --- /dev/null +++ b/scrapy_playwright/utils.py @@ -0,0 +1,54 @@ +import functools +import inspect +from typing import Callable + +from ._utils import _ThreadedLoopAdapter + + +async def _run_async_gen(asyncgen): + async for item in asyncgen: + yield item + + +def use_threaded_loop(callback) -> Callable: + """Wrap a coroutine or async generator callback so that Playwright coroutines are executed in + the threaded event loop. + + On Windows, Playwright runs in an event loop of its own in a separate thread. + If Playwright coroutines are awaited directly, they are assigned to the main + thread's event loop, resulting in: "ValueError: The future belongs to a + different loop than the one specified as the loop argument" + + Usage: + ``` + from playwright.async_api import Page + from scrapy_playwright.utils import use_threaded_loop + + @use_threaded_loop + async def parse(self, response): + page: Page = response.meta["playwright_page"] + await page.screenshot(path="example.png", full_page=True) + ``` + """ + + if not inspect.iscoroutinefunction(callback) and not inspect.isasyncgenfunction(callback): + raise RuntimeError( + f"Cannot decorate callback '{callback.__name__}' with 'use_threaded_loop':" + " callback must be a coroutine function or an async generator" + ) + + @functools.wraps(callback) + async def async_func_wrapper(*args, **kwargs): + future = _ThreadedLoopAdapter._ensure_future(callback(*args, **kwargs)) + return await future + + @functools.wraps(callback) + async def async_gen_wrapper(*args, **kwargs): + asyncgen = _run_async_gen(callback(*args, **kwargs)) + future = _ThreadedLoopAdapter._ensure_future(asyncgen) + for item in await future: + yield item + + if inspect.isasyncgenfunction(callback): + return async_gen_wrapper + return async_func_wrapper diff --git 
a/tests/tests_asyncio/test_threaded_loop.py b/tests/tests_asyncio/test_threaded_loop.py new file mode 100644 index 00000000..a947a705 --- /dev/null +++ b/tests/tests_asyncio/test_threaded_loop.py @@ -0,0 +1,93 @@ +import platform +from unittest import TestCase + +import pytest +import scrapy +from playwright.async_api import Page +from scrapy import signals +from scrapy.crawler import CrawlerProcess +from scrapy.utils.test import get_crawler +from scrapy_playwright.utils import use_threaded_loop + +from tests.mockserver import StaticMockServer + + +class ThreadedLoopSpider(scrapy.Spider): + name = "threaded_loop" + start_url: str + + def start_requests(self): + yield scrapy.Request( + url=self.start_url, + meta={"playwright": True, "playwright_include_page": True}, + ) + + @use_threaded_loop + async def parse(self, response, **kwargs): # pylint: disable=invalid-overridden-method + """async generator""" + page: Page = response.meta["playwright_page"] + title = await page.title() + await page.close() + yield {"url": response.url, "title": title} + yield scrapy.Request( + url=response.url + "?foo=bar", + meta={"playwright": True, "playwright_include_page": True}, + callback=self.parse_2, + ) + + @use_threaded_loop + async def parse_2(self, response): + page: Page = response.meta["playwright_page"] + title = await page.title() + await page.close() + return {"url": response.url, "title": title} + + +@pytest.mark.skipif( + platform.system() != "Windows", + reason="Test threaded loop implementation only on Windows", +) +class ThreadedLoopSpiderTestCase(TestCase): + def test_threaded_loop_spider(self): + items: list = [] + + def collect_items(item): + items.append(item) + + with StaticMockServer() as server: + index_url = server.urljoin("/index.html") + crawler = get_crawler( + spidercls=ThreadedLoopSpider, + settings_dict={ + "TWISTED_REACTOR": "twisted.internet.asyncioreactor.AsyncioSelectorReactor", + "DOWNLOAD_HANDLERS": { + "http": 
"scrapy_playwright.handler.ScrapyPlaywrightDownloadHandler", + }, + "_PLAYWRIGHT_THREADED_LOOP": True, + }, + ) + crawler.signals.connect(collect_items, signals.item_scraped) + process = CrawlerProcess() + process.crawl(crawler, start_url=index_url) + process.start() + + self.assertCountEqual( + items, + [ + {"url": index_url, "title": "Awesome site"}, + {"url": index_url + "?foo=bar", "title": "Awesome site"}, + ], + ) + + def test_use_threaded_loop_non_coroutine_function(self): + with pytest.raises(RuntimeError) as exc_info: + + @use_threaded_loop + def not_a_coroutine(): + pass + + self.assertEqual( + str(exc_info.value), + "Cannot decorate callback 'not_a_coroutine' with 'use_threaded_loop':" + " callback must be a coroutine function or an async generator", + )