diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 4ca5371..0dfcf06 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -313,7 +313,7 @@ def getattr(self): class EvalFunc: """Class for a callable pyscript function.""" - def __init__(self, func_def, code_list, code_str, global_ctx): + def __init__(self, func_def, code_list, code_str, global_ctx, async_func=False): """Initialize a function calling context.""" self.func_def = func_def self.name = func_def.name @@ -338,6 +338,7 @@ def __init__(self, func_def, code_list, code_str, global_ctx): self.trigger = [] self.trigger_service = set() self.has_closure = False + self.async_func = async_func def get_name(self): """Return the function name.""" @@ -930,14 +931,18 @@ async def ast_not_implemented(self, arg, *args): name = "ast_" + arg.__class__.__name__.lower() raise NotImplementedError(f"{self.name}: not implemented ast " + name) - async def aeval(self, arg, undefined_check=True): + async def aeval(self, arg, undefined_check=True, do_await=True): """Vector to specific function based on ast class type.""" name = "ast_" + arg.__class__.__name__.lower() try: if hasattr(arg, "lineno"): self.lineno = arg.lineno self.col_offset = arg.col_offset - val = await getattr(self, name, self.ast_not_implemented)(arg) + val = ( + await getattr(self, name, self.ast_not_implemented)(arg) + if do_await + else getattr(self, name, self.ast_not_implemented)(arg) + ) if undefined_check and isinstance(val, EvalName): raise NameError(f"name '{val.name}' is not defined") return val @@ -1102,7 +1107,7 @@ async def ast_classdef(self, arg): del sym_table["__init__"] sym_table_assign[arg.name].set(type(arg.name, tuple(bases), sym_table)) - async def ast_functiondef(self, arg): + async def ast_functiondef(self, arg, async_func=False): """Evaluate function definition.""" other_dec = [] dec_name = None @@ -1158,7 +1163,7 @@ async def executor_wrap(*args, **kwargs): self.sym_table[arg.name].set(func) return - func = EvalFunc(arg, self.code_list, self.code_str, self.global_ctx) + func = EvalFunc(arg, self.code_list, self.code_str, self.global_ctx, async_func) await func.eval_defaults(self) await func.resolve_nonlocals(self) name = func.get_name() @@ -1215,7 +1220,7 @@ async def ast_lambda(self, arg): async def ast_asyncfunctiondef(self, arg): """Evaluate async function definition.""" - return await self.ast_functiondef(arg) + return await self.ast_functiondef(arg, async_func=True) async def ast_try(self, arg): """Execute try...except statement.""" @@ -2020,7 +2025,10 @@ async def ast_formattedvalue(self, arg): async def ast_await(self, arg): """Evaluate await expr.""" - return await self.aeval(arg.value) + coro = await self.aeval(arg.value, do_await=False) + if coro and asyncio.iscoroutine(coro): + return await coro + return coro async def get_target_names(self, lhs): """Recursively find all the target names mentioned in the AST tree.""" diff --git a/tests/test_unit_eval.py b/tests/test_unit_eval.py index 5983541..ca6626b 100644 --- a/tests/test_unit_eval.py +++ b/tests/test_unit_eval.py @@ -1415,6 +1415,39 @@ async def func(): """, 42, ], + [ + """ +import asyncio +async def coro(): + await asyncio.sleep(0.1) + return "done" + +await coro() +""", + "done", + ], + [ + """ +import asyncio + +@pyscript_compile +async def nested(): + await asyncio.sleep(1e-8) + return 42 + +@pyscript_compile +async def run(): + task = asyncio.create_task(nested()) + + # "task" can now be used to cancel "nested()", or + # can simply be awaited to wait until it is complete: + await task + return "done" + +await run() +""", + "done", + ], ]