From 869f7d4439423bb758b172c747456a81c9c6f023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Fern=C3=A1ndez?= <7312236+fernandezcuesta@users.noreply.github.com> Date: Mon, 17 Feb 2025 01:10:28 +0100 Subject: [PATCH 1/2] fix: #109 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jesús Fernández <7312236+fernandezcuesta@users.noreply.github.com> --- crossplane/function/runtime.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/crossplane/function/runtime.py b/crossplane/function/runtime.py index f91f392..599a806 100644 --- a/crossplane/function/runtime.py +++ b/crossplane/function/runtime.py @@ -16,6 +16,7 @@ import asyncio import os +import signal import grpc from grpc_reflection.v1alpha import reflection @@ -25,6 +26,7 @@ import crossplane.function.proto.v1beta1.run_function_pb2 as fnv1beta1 import crossplane.function.proto.v1beta1.run_function_pb2_grpc as grpcv1beta1 +GRACE_PERIOD = 5 SERVICE_NAMES = ( reflection.SERVICE_NAME, fnv1.DESCRIPTOR.services_by_name["FunctionRunnerService"].full_name, @@ -64,6 +66,10 @@ def load_credentials(tls_certs_dir: str) -> grpc.ServerCredentials: ) +async def _stop(server, timeout): # noqa: ASYNC109 + await server.stop(grace=timeout) + + def serve( function: grpcv1.FunctionRunnerService, address: str, @@ -90,6 +96,10 @@ def serve( server = grpc.aio.server() + loop.add_signal_handler( + signal.SIGTERM, lambda: asyncio.create_task(_stop(server, timeout=GRACE_PERIOD)) + ) + grpcv1.add_FunctionRunnerServiceServicer_to_server(function, server) grpcv1beta1.add_FunctionRunnerServiceServicer_to_server( BetaFunctionRunner(wrapped=function), server @@ -116,7 +126,7 @@ async def start(): try: loop.run_until_complete(start()) finally: - loop.run_until_complete(server.stop(grace=5)) + loop.run_until_complete(server.stop(grace=GRACE_PERIOD)) loop.close() From 50b187cae8336130070e22f07a7da2fac14392a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Fern=C3=A1ndez?= <7312236+fernandezcuesta@users.noreply.github.com> Date: Tue, 18 Feb 2025 14:08:44 +0100 Subject: [PATCH 2/2] chore: use signal.signal and add test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jesús Fernández <7312236+fernandezcuesta@users.noreply.github.com> --- crossplane/function/resource.py | 4 ++-- crossplane/function/runtime.py | 12 +++++++----- tests/test_runtime.py | 22 ++++++++++++++++++++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/crossplane/function/resource.py b/crossplane/function/resource.py index 039af46..c9bfc69 100644 --- a/crossplane/function/resource.py +++ b/crossplane/function/resource.py @@ -45,8 +45,8 @@ def update(r: fnv1.Resource, source: dict | structpb.Struct | pydantic.BaseModel # apiVersion is set to its default value 's3.aws.upbound.io/v1beta2' # (and not explicitly provided during initialization), it will be # excluded from the serialized output. - data['apiVersion'] = source.apiVersion - data['kind'] = source.kind + data["apiVersion"] = source.apiVersion + data["kind"] = source.kind r.resource.update(data) case structpb.Struct(): # TODO(negz): Use struct_to_dict and update to match other semantics? diff --git a/crossplane/function/runtime.py b/crossplane/function/runtime.py index 599a806..3783bbc 100644 --- a/crossplane/function/runtime.py +++ b/crossplane/function/runtime.py @@ -66,8 +66,8 @@ def load_credentials(tls_certs_dir: str) -> grpc.ServerCredentials: ) -async def _stop(server, timeout): # noqa: ASYNC109 - await server.stop(grace=timeout) +async def _stop(server, grace=GRACE_PERIOD): + await server.stop(grace=grace) def serve( @@ -96,8 +96,9 @@ def serve( server = grpc.aio.server() - loop.add_signal_handler( - signal.SIGTERM, lambda: asyncio.create_task(_stop(server, timeout=GRACE_PERIOD)) + signal.signal( + signal.SIGTERM, + lambda _, __: asyncio.create_task(_stop(server)), ) grpcv1.add_FunctionRunnerServiceServicer_to_server(function, server) @@ -126,7 +127,8 @@ async def start(): try: loop.run_until_complete(start()) finally: - loop.run_until_complete(server.stop(grace=GRACE_PERIOD)) + if server._server.is_running(): + loop.run_until_complete(server.stop(grace=GRACE_PERIOD)) loop.close() diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 26229aa..5c31269 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import dataclasses +import os +import signal import unittest import grpc @@ -52,6 +55,25 @@ class TestCase: self.assertEqual(rsp, case.want, "-want, +got") + async def test_sigterm_handling(self) -> None: + async def mock_server(): + await server.start() + await asyncio.sleep(1) + self.assertTrue(server._server.is_running(), "Server should be running") + os.kill(os.getpid(), signal.SIGTERM) + await server.wait_for_termination() + self.assertFalse( + server._server.is_running(), + "Server should have been stopped on SIGTERM", + ) + + server = grpc.aio.server() + signal.signal( + signal.SIGTERM, + lambda _, __: asyncio.create_task(runtime._stop(server)), + ) + await mock_server() + class EchoRunner(grpcv1.FunctionRunnerService): def __init__(self):