From 00db07de9b82fe8e24b41d4a2b7dbf7653087dd9 Mon Sep 17 00:00:00 2001 From: linuxdaemon Date: Tue, 9 Apr 2024 09:47:21 +0000 Subject: [PATCH] Clean up typing for executor pool --- cloudbot/bot.py | 6 +++++- cloudbot/util/executor_pool.py | 29 +++++++++++++++++++---------- tests/util/mock_bot.py | 7 ++++++- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/cloudbot/bot.py b/cloudbot/bot.py index abcd27da..efb8b409 100644 --- a/cloudbot/bot.py +++ b/cloudbot/bot.py @@ -160,7 +160,11 @@ def __init__( self.db_engine = create_engine(db_path) database.configure(self.db_engine) self.db_executor_pool = ExecutorPool( - 50, max_workers=1, thread_name_prefix="cloudbot-db", loop=self.loop + 50, + max_workers=1, + thread_name_prefix="cloudbot-db", + loop=self.loop, + executor_type=ThreadPoolExecutor, ) logger.debug("Database system initialised.") diff --git a/cloudbot/util/executor_pool.py b/cloudbot/util/executor_pool.py index b053ddee..3bd21782 100644 --- a/cloudbot/util/executor_pool.py +++ b/cloudbot/util/executor_pool.py @@ -1,8 +1,9 @@ -from asyncio import AbstractEventLoop import logging import os import random -from concurrent.futures import ThreadPoolExecutor +from asyncio import AbstractEventLoop +from concurrent.futures import Executor +from typing import Generic, List, Optional, Type, TypeVar from cloudbot.util.async_util import create_future @@ -26,9 +27,17 @@ def __del__(self): self.release() -class ExecutorPool: +T = TypeVar("T", bound=Executor) + + +class ExecutorPool(Generic[T]): def __init__( - self, max_executors=None, executor_type=ThreadPoolExecutor, *, loop:AbstractEventLoop, **kwargs + self, + max_executors: Optional[int] = None, + *, + executor_type: Type[T], + loop: AbstractEventLoop, + **kwargs, ) -> None: if max_executors is None: max_executors = (os.cpu_count() or 1) * 5 @@ -40,14 +49,14 @@ def __init__( self._exec_class = executor_type self._exec_args = kwargs - self._executors = [] - self._free_executors = [] + self._executors: List[T] = [] + self._free_executors: List[T] = [] self._executor_waiter = create_future(loop) - def get(self): + def get(self) -> ExecutorWrapper: return ExecutorWrapper(self, self._get()) - def _get(self): + def _get(self) -> T: if not self._free_executors: if len(self._executors) < self._max: return self._add_executor() @@ -56,10 +65,10 @@ def _get(self): return self._free_executors.pop() - def release_executor(self, executor): + def release_executor(self, executor: T) -> None: self._free_executors.append(executor) - def _add_executor(self): + def _add_executor(self) -> T: exc = self._exec_class(**self._exec_args) self._executors.append(exc) diff --git a/tests/util/mock_bot.py b/tests/util/mock_bot.py index a8041294..1b341a4b 100644 --- a/tests/util/mock_bot.py +++ b/tests/util/mock_bot.py @@ -1,4 +1,5 @@ import logging +from concurrent.futures import ThreadPoolExecutor from typing import Awaitable, Dict, Optional from watchdog.observers import Observer @@ -23,7 +24,11 @@ def __init__( ): if loop: self.db_executor_pool = ExecutorPool( - 50, max_workers=1, thread_name_prefix="cloudbot-db", loop=loop + 50, + max_workers=1, + thread_name_prefix="cloudbot-db", + loop=loop, + executor_type=ThreadPoolExecutor, ) else: self.db_executor_pool = None