From 9cee77bda269a7f5bc5229de29b8892a2eb102e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?=
Date: Fri, 27 Dec 2024 20:36:16 +0000
Subject: [PATCH] Enable inference serving capabilities on SageMaker endpoints
 using Tornado

---
 build_artifacts/v2/v2.2/v2.2.0/Dockerfile     |   4 +-
 build_artifacts/v2/v2.2/v2.2.0/cpu.env.in     |   1 +
 .../dirs/etc/inference-server/__init__.py     |   3 +
 .../v2.2.0/dirs/etc/inference-server/serve    |   2 +
 .../dirs/etc/inference-server/server.py       | 103 ++++++++++++++++++
 .../tornado_server/__init__.py                |   1 +
 .../tornado_server/async_server.py            |  39 +++++++
 .../tornado_server/sync_server.py             |  37 +++++++
 .../etc/inference-server/utils/__init__.py    |   1 +
 .../etc/inference-server/utils/environment.py |  53 +++++++++
 .../etc/inference-server/utils/exception.py   |  16 +++
 .../dirs/etc/inference-server/utils/logger.py |  43 ++++++++
 12 files changed, 302 insertions(+), 1 deletion(-)
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/__init__.py
 create mode 100755 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/serve
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/server.py
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/__init__.py
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/async_server.py
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/sync_server.py
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/__init__.py
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/environment.py
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/exception.py
 create mode 100644 build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/logger.py

diff --git a/build_artifacts/v2/v2.2/v2.2.0/Dockerfile b/build_artifacts/v2/v2.2/v2.2.0/Dockerfile
index bc6b2fb2..dbb7dcf0 100644
--- a/build_artifacts/v2/v2.2/v2.2.0/Dockerfile
+++ b/build_artifacts/v2/v2.2/v2.2.0/Dockerfile
@@ -180,7 +180,9 @@ RUN mkdir -p $SAGEMAKER_LOGGING_DIR && \
     && ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python \
     && rm -rf ${HOME_DIR}/oss_compliance*
 
-ENV PATH="/opt/conda/bin:/opt/conda/condabin:$PATH"
+# Adding inference-server to PATH, so that Docker can run `serve` as an executable.
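+# SageMaker hosting starts the container with the `serve` command, which launches /etc/inference-server/server.py.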
+ENV PATH="/etc/inference-server:/opt/conda/bin:/opt/conda/condabin:$PATH"
 WORKDIR "/home/${NB_USER}"
 ENV SHELL=/bin/bash
 ENV OPENSSL_MODULES=/opt/conda/lib64/ossl-modules/
diff --git a/build_artifacts/v2/v2.2/v2.2.0/cpu.env.in b/build_artifacts/v2/v2.2/v2.2.0/cpu.env.in
index bb179aab..ed3675f7 100644
--- a/build_artifacts/v2/v2.2/v2.2.0/cpu.env.in
+++ b/build_artifacts/v2/v2.2/v2.2.0/cpu.env.in
@@ -11,6 +11,7 @@ conda-forge::amazon-sagemaker-jupyter-ai-q-developer[version='>=1.0.12,<2.0.0']
 conda-forge::amazon-q-developer-jupyterlab-ext[version='>=3.4.0,<4.0.0']
 conda-forge::langchain[version='>=0.2.16,<1.0.0']
 conda-forge::fastapi[version='>=0.115.2,<1.0.0']
+conda-forge::tornado[version='>=6.3.3,<7.0.0']
 conda-forge::uvicorn[version='>=0.32.0,<1.0.0']
 conda-forge::pytorch[version='>=2.4.1,<3.0.0']
 conda-forge::tensorflow[version='>=2.17.0,<3.0.0']
diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/__init__.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/__init__.py
new file mode 100644
index 00000000..0427e383
--- /dev/null
+++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+
+import utils.logger
diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/serve b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/serve
new file mode 100755
index 00000000..a6ea877b
--- /dev/null
+++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/serve
@@ -0,0 +1,2 @@
+#!/bin/bash
+python /etc/inference-server/server.py
diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/server.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/server.py
new file mode 100644
index 00000000..83f968b6
--- /dev/null
+++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/server.py
@@ -0,0 +1,103 @@
+from __future__ import absolute_import
+
+import asyncio
+import importlib.util
+import logging
+import subprocess
+import sys
+from pathlib import Path
+from utils.environment import Environment
+from utils.exception import (
+    InferenceCodeLoadException,
+    RequirementsInstallException,
+    ServerStartException
+)
+from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER
+
+logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)
+
+
+class TornadoServer:
+    def serve(self):
+        if asyncio.iscoroutinefunction(self._handler):
+            logger.info("Starting inference server in asynchronous mode...")
+            import tornado_server.async_server as inference_server
+        else:
+            logger.info("Starting inference server in synchronous mode...")
+            import tornado_server.sync_server as inference_server
+
+        try:
+            asyncio.run(inference_server.serve(self._handler, self._environment))
+        except Exception as e:
+            raise ServerStartException(e)
+
+
+class InferenceServer(TornadoServer):
+    def __init__(self):
+        self._environment = Environment()
+        logger.setLevel(self._environment.logging_level)
+        logger.debug(f"Environment: {str(self._environment)}")
+
+        self._path_to_inference_code = (
+            Path(self._environment.base_directory).joinpath(self._environment.code_directory)
+            if self._environment.code_directory else
+            Path(self._environment.base_directory)
+        )
+        logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`")
+
+        self._handler = None
+
+    def initialize(self):
+        self._install_runtime_requirements()
+        self._handler = self._load_inference_handler()
+
+    def _install_runtime_requirements(self):
+        logger.info("Installing runtime requirements...")
+
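+        # Install the model's requirements file (if present) with the same interpreter that runs this server.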
+        requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements)
+        if requirements_txt.is_file():
+            try:
+                subprocess.check_call(
+                    [sys.executable, "-m", "pip", "install", "-r", str(requirements_txt)]
+                )
+            except Exception as e:
+                raise RequirementsInstallException(e)
+        else:
+            logger.debug(f"No requirements file was found at `{str(requirements_txt)}`")
+
+    def _load_inference_handler(self) -> callable:
+        logger.info("Loading inference handler...")
+
+        inference_module_name, handle_name = self._environment.code.split(".")
+        if inference_module_name and handle_name:
+            inference_module_file = f"{inference_module_name}.py"
+            module_spec = importlib.util.spec_from_file_location(
+                inference_module_file,
+                str(self._path_to_inference_code.joinpath(inference_module_file))
+            )
+            if module_spec:
+                sys.path.insert(0, str(self._path_to_inference_code.resolve()))
+                inference_module = importlib.util.module_from_spec(module_spec)
+                module_spec.loader.exec_module(inference_module)
+                if hasattr(inference_module, handle_name):
+                    handler = getattr(inference_module, handle_name)
+                else:
+                    raise InferenceCodeLoadException(
+                        f"Handler `{handle_name}` could not be found in module `{inference_module_file}`"
+                    )
+                logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`")
+                return handler
+            else:
+                raise InferenceCodeLoadException(
+                    f"Inference code could not be found at `{str(self._path_to_inference_code.joinpath(inference_module_file))}`"
+                )
+        raise InferenceCodeLoadException(
+            f"Inference code expected in the format `<module>.<handler>` but was provided as `{self._environment.code}`"
+        )
+
+
+if __name__ == "__main__":
+    inference_server = InferenceServer()
+    inference_server.initialize()
+    inference_server.serve()
diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/__init__.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/__init__.py
new file mode 100644
index 00000000..c3961685
--- /dev/null
+++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/__init__.py
@@ -0,0 +1 @@
+from __future__ import absolute_import
diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/async_server.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/async_server.py
new file mode 100644
index 00000000..41f12a07
--- /dev/null
+++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/async_server.py
@@ -0,0 +1,39 @@
+from __future__ import absolute_import
+
+import asyncio
+import logging
+import tornado.web
+from utils.environment import Environment
+from utils.exception import AsyncInvocationsException
+from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER
+
+logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)
+
+
+class InvocationsHandler(tornado.web.RequestHandler):
+    def initialize(self, handler: callable, environment: Environment):
+        self._handler = handler
+        self._environment = environment
+
+    async def post(self):
+        try:
+            response = await self._handler(self.request)
+            self.write(response)
+        except Exception as e:
+            raise AsyncInvocationsException(e)
+
+
+class PingHandler(tornado.web.RequestHandler):
+    def get(self):
+        self.write("")
+
+
+async def serve(handler: callable, environment: Environment):
+    app = tornado.web.Application([
+        (r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)),
+        (r"/ping", PingHandler),
+    ])
+    app.listen(environment.port)
+    logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")
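+    # Block this coroutine forever; Tornado serves requests on the running asyncio event loop.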
`{environment.port}`") + await asyncio.Event().wait() diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/sync_server.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/sync_server.py new file mode 100644 index 00000000..d9dac7c7 --- /dev/null +++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/sync_server.py @@ -0,0 +1,37 @@ +from __future__ import absolute_import + +import asyncio +import logging +import tornado.web +from utils.environment import Environment +from utils.exception import SyncInvocationsException +from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER + +logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) + + +class InvocationsHandler(tornado.web.RequestHandler): + def initialize(self, handler: callable, environment: Environment): + self._handler = handler + self._environment = environment + + def post(self): + try: + self.write(self._handler(self.request)) + except Exception as e: + raise SyncInvocationsException(e) + + +class PingHandler(tornado.web.RequestHandler): + def get(self): + self.write("") + + +async def serve(handler: callable, environment: Environment): + app = tornado.web.Application([ + (r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)), + (r"/ping", PingHandler), + ]) + app.listen(environment.port) + logger.debug(f"Synchronous inference server listening on port: `{environment.port}`") + await asyncio.Event().wait() diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/__init__.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/__init__.py new file mode 100644 index 00000000..c3961685 --- /dev/null +++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/__init__.py @@ -0,0 +1 @@ +from __future__ import absolute_import diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/environment.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/environment.py new file mode 100644 index 00000000..dc6fb873 --- /dev/null +++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/environment.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import + +import json +import os +from enum import Enum + +class SageMakerInference(str, Enum): + BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY" + REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS" + CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY" + CODE = "SAGEMAKER_INFERENCE_CODE" + LOGGING_LEVEL = "SAGEMAKER_INFERENCE_LOGGING_LEVEL" + PORT = "SAGEMAKER_INFERENCE_PORT" + + +class Environment: + def __init__(self): + self._environment_variables = { + SageMakerInference.BASE_DIRECTORY: "/opt/ml/model", + SageMakerInference.REQUIREMENTS: "requirements.txt", + SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None), + SageMakerInference.CODE: os.getenv(SageMakerInference.CODE, "inference.handler"), + SageMakerInference.LOGGING_LEVEL: os.getenv(SageMakerInference.LOGGING_LEVEL, 10), + SageMakerInference.PORT: os.getenv(SageMakerInference.PORT, 8080) + } + + def __str__(self): + return json.dumps(self._environment_variables) + + @property + def base_directory(self): + return self._environment_variables.get(SageMakerInference.BASE_DIRECTORY) + + @property + def requirements(self): + return self._environment_variables.get(SageMakerInference.REQUIREMENTS) + + @property + def code_directory(self): + return 
+        self._environment_variables = {
+            SageMakerInference.BASE_DIRECTORY: "/opt/ml/model",
+            SageMakerInference.REQUIREMENTS: "requirements.txt",
+            SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None),
+            SageMakerInference.CODE: os.getenv(SageMakerInference.CODE, "inference.handler"),
+            SageMakerInference.LOGGING_LEVEL: int(os.getenv(SageMakerInference.LOGGING_LEVEL, 10)),
+            SageMakerInference.PORT: int(os.getenv(SageMakerInference.PORT, 8080))
+        }
+
+    def __str__(self):
+        return json.dumps(self._environment_variables)
+
+    @property
+    def base_directory(self):
+        return self._environment_variables.get(SageMakerInference.BASE_DIRECTORY)
+
+    @property
+    def requirements(self):
+        return self._environment_variables.get(SageMakerInference.REQUIREMENTS)
+
+    @property
+    def code_directory(self):
+        return self._environment_variables.get(SageMakerInference.CODE_DIRECTORY)
+
+    @property
+    def code(self):
+        return self._environment_variables.get(SageMakerInference.CODE)
+
+    @property
+    def logging_level(self):
+        return self._environment_variables.get(SageMakerInference.LOGGING_LEVEL)
+
+    @property
+    def port(self):
+        return self._environment_variables.get(SageMakerInference.PORT)
diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/exception.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/exception.py
new file mode 100644
index 00000000..851b95d6
--- /dev/null
+++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/exception.py
@@ -0,0 +1,16 @@
+from __future__ import absolute_import
+
+class RequirementsInstallException(Exception):
+    pass
+
+class InferenceCodeLoadException(Exception):
+    pass
+
+class ServerStartException(Exception):
+    pass
+
+class SyncInvocationsException(Exception):
+    pass
+
+class AsyncInvocationsException(Exception):
+    pass
diff --git a/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/logger.py b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/logger.py
new file mode 100644
index 00000000..c8800868
--- /dev/null
+++ b/build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/logger.py
@@ -0,0 +1,43 @@
+from __future__ import absolute_import
+
+import logging.config
+
+SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER = "sagemaker_distribution.inference_server"
+LOGGING_CONFIG = {
+    "version": 1,
+    "disable_existing_loggers": True,
+    "formatters": {
+        "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
+    },
+    "handlers": {
+        "default": {
+            "level": "DEBUG",
+            "formatter": "standard",
+            "class": "logging.StreamHandler",
+            "stream": "ext://sys.stdout",
+        },
+    },
+    "loggers": {
+        SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER: {
+            "level": "DEBUG",
+            "handlers": ["default"],
+            "propagate": True,
+        },
+        "tornado.application": {
+            "level": "DEBUG",
+            "handlers": ["default"],
+            "propagate": True,
+        },
+        "tornado.general": {
+            "level": "DEBUG",
+            "handlers": ["default"],
+            "propagate": True,
+        },
+        "tornado.access": {
+            "level": "DEBUG",
+            "handlers": ["default"],
+            "propagate": True,
+        },
+    },
+}
+logging.config.dictConfig(LOGGING_CONFIG)
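
--
Reviewer note, not part of the patch: a minimal sketch of the handler contract
server.py assumes. With the defaults above (SAGEMAKER_INFERENCE_CODE set to
"inference.handler"), the server imports /opt/ml/model/inference.py and routes
POST /invocations requests to `handler`; the module below is hypothetical and
only echoes the request payload back.

    # /opt/ml/model/inference.py (illustrative example)
    def handler(request):
        # request is a tornado.httputil.HTTPServerRequest
        return request.body  # bytes are written back as the response

    # An `async def handler(request)` is also supported; server.py detects it
    # with asyncio.iscoroutinefunction and starts the asynchronous server.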