1414logger = logging .getLogger (__name__ )
1515
1616
17+ DEFAULT_MAX_CONCURRENCY = 16
18+
1719def runner_exception_hook (args : threading .ExceptHookArgs ):
1820 print (args )
1921 raise args .exc_type
@@ -22,6 +24,16 @@ def runner_exception_hook(args: threading.ExceptHookArgs):
2224# set a custom exception hook
2325# threading.excepthook = runner_exception_hook
2426
27+ def as_completed (loop , coros , max_concurrency ):
28+ if max_concurrency == - 1 :
29+ return asyncio .as_completed (coros , loop = loop )
30+
31+ semaphore = asyncio .Semaphore (max_concurrency , loop = loop )
32+ async def sem_coro (coro ):
33+ async with semaphore :
34+ return await coro
35+
36+ return asyncio .as_completed ([sem_coro (c ) for c in coros ], loop = loop )
2537
2638class Runner (threading .Thread ):
2739 def __init__ (
@@ -30,26 +42,29 @@ def __init__(
3042 desc : str ,
3143 keep_progress_bar : bool = True ,
3244 raise_exceptions : bool = True ,
45+ max_concurrency : int = DEFAULT_MAX_CONCURRENCY ,
3346 ):
3447 super ().__init__ ()
3548 self .jobs = jobs
3649 self .desc = desc
3750 self .keep_progress_bar = keep_progress_bar
3851 self .raise_exceptions = raise_exceptions
39- self .futures = []
52+ self .max_concurrency = max_concurrency
4053
4154 # create task
4255 self .loop = asyncio .new_event_loop ()
43- for job in self .jobs :
44- coroutine , name = job
45- self .futures .append (self .loop .create_task (coroutine , name = name ))
56+ self .futures = as_completed (
57+ self .loop ,
58+ [coro for coro , _ in self .jobs ],
59+ self .max_concurrency )
4660
4761 async def _aresults (self ) -> t .List [t .Any ]:
4862 results = []
63+
4964 for future in tqdm (
50- asyncio . as_completed ( self .futures ) ,
65+ self .futures ,
5166 desc = self .desc ,
52- total = len (self .futures ),
67+ total = len (self .jobs ),
5368 # whether you want to keep the progress bar after completion
5469 leave = self .keep_progress_bar ,
5570 ):
@@ -85,6 +100,7 @@ class Executor:
85100 keep_progress_bar : bool = True
86101 jobs : t .List [t .Any ] = field (default_factory = list , repr = False )
87102 raise_exceptions : bool = False
103+ max_concurrency : int = DEFAULT_MAX_CONCURRENCY
88104
89105 def wrap_callable_with_index (self , callable : t .Callable , counter ):
90106 async def wrapped_callable_async (* args , ** kwargs ):
@@ -104,6 +120,7 @@ def results(self) -> t.List[t.Any]:
104120 desc = self .desc ,
105121 keep_progress_bar = self .keep_progress_bar ,
106122 raise_exceptions = self .raise_exceptions ,
123+ max_concurrency = self .max_concurrency ,
107124 )
108125 executor_job .start ()
109126 try :
0 commit comments