Skip to content

Commit

Permalink
Enable inference serving capabilities on sagemaker endpoint using tor…
Browse files Browse the repository at this point in the history
…nado
  • Loading branch information
gwang111 committed Dec 27, 2024
1 parent 466239f commit bb132ea
Show file tree
Hide file tree
Showing 13 changed files with 298 additions and 1 deletion.
3 changes: 2 additions & 1 deletion build_artifacts/v2/v2.2/v2.2.0/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ RUN mkdir -p $SAGEMAKER_LOGGING_DIR && \
&& ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python \
&& rm -rf ${HOME_DIR}/oss_compliance*

ENV PATH="/opt/conda/bin:/opt/conda/condabin:$PATH"
# Adding inference-server to path, so that docker can run serve as executable.
ENV PATH="/etc/inference-server:/opt/conda/bin:/opt/conda/condabin:$PATH"
WORKDIR "/home/${NB_USER}"
ENV SHELL=/bin/bash
ENV OPENSSL_MODULES=/opt/conda/lib64/ossl-modules/
Expand Down
1 change: 1 addition & 0 deletions build_artifacts/v2/v2.2/v2.2.0/cpu.env.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ conda-forge::amazon-sagemaker-jupyter-ai-q-developer[version='>=1.0.12,<2.0.0']
conda-forge::amazon-q-developer-jupyterlab-ext[version='>=3.4.0,<4.0.0']
conda-forge::langchain[version='>=0.2.16,<1.0.0']
conda-forge::fastapi[version='>=0.115.2,<1.0.0']
conda-forge::tornado[version='>=6.3.3,<7.0.0']
conda-forge::uvicorn[version='>=0.32.0,<1.0.0']
conda-forge::pytorch[version='>=2.4.1,<3.0.0']
conda-forge::tensorflow[version='>=2.17.0,<3.0.0']
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import absolute_import

import utils.logger
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/bash
# SageMaker endpoint entry point: the Dockerfile puts /etc/inference-server on
# PATH so the container can run `serve` as an executable.
# `exec` replaces this shell with the Python process so that SIGTERM from the
# container runtime reaches the server directly, allowing a clean shutdown.
exec python /etc/inference-server/server.py
101 changes: 101 additions & 0 deletions build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import absolute_import

import asyncio
import importlib
import importlib.util
import logging
import subprocess
import sys
from pathlib import Path

from utils.environment import Environment
from utils.exception import (
    InferenceCodeLoadException,
    RequirementsInstallException,
    ServerStartException
)
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class TornadoServer:
    """Select and run the sync or async Tornado server for a loaded handler.

    NOTE(review): relies on the subclass populating `self._handler` and
    `self._environment` before `serve()` is called (see InferenceServer).
    """

    def serve(self):
        """Run the matching tornado server until the process exits.

        Raises:
            ServerStartException: wraps any error raised while serving.
        """
        handler_is_coroutine = asyncio.iscoroutinefunction(self._handler)
        if handler_is_coroutine:
            logger.info("Starting inference server in asynchronous mode...")
            from tornado_server import async_server as inference_server
        else:
            logger.info("Starting inference server in synchronous mode...")
            from tornado_server import sync_server as inference_server

        try:
            asyncio.run(inference_server.serve(self._handler, self._environment))
        except Exception as error:
            raise ServerStartException(error)


class InferenceServer(TornadoServer):
    """Concrete server: reads configuration, installs requirements, loads the handler.

    Configuration comes from the SAGEMAKER_INFERENCE_* environment variables
    (see utils.environment.Environment). `initialize()` must be called before
    `serve()` so that `self._handler` is populated.
    """

    def __init__(self):
        self._environment = Environment()
        logger.setLevel(self._environment.logging_level)
        logger.debug(f"Environment: {str(self._environment)}")

        # Inference code lives under the base directory, optionally nested in a
        # user-provided code directory.
        self._path_to_inference_code = (
            Path(self._environment.base_directory).joinpath(self._environment.code_directory)
            if self._environment.code_directory else
            Path(self._environment.base_directory)
        )
        logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`")

        self._handler = None

    def initialize(self):
        """Install runtime requirements and load the inference handler."""
        self._install_runtime_requirements()
        self._handler = self._load_inference_handler()

    def _install_runtime_requirements(self):
        """pip-install the model's requirements file, if one is present.

        Raises:
            RequirementsInstallException: if the pip subprocess fails.
        """
        logger.info("Installing runtime requirements...")

        requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements)
        if requirements_txt.is_file():
            try:
                subprocess.check_call(
                    [sys.executable, "-m", "pip", "install", "-r", str(requirements_txt)]
                )
            except Exception as e:
                raise RequirementsInstallException(e)
        else:
            logger.debug(f"No requirements file was found at `{str(requirements_txt)}`")

    def _load_inference_handler(self) -> callable:
        """Import `<module>.<handler>` from the inference code path and return the handler.

        Raises:
            InferenceCodeLoadException: if the code spec is malformed, the module
                file is missing, or the handler attribute does not exist.
        """
        logger.info("Loading inference handler...")

        code = self._environment.code
        # rpartition splits on the LAST dot, so a value without a dot reaches the
        # error below instead of crashing the tuple-unpack (the previous
        # `split(".")` raised ValueError there, and the error message referenced
        # an undefined `code` name).
        inference_module_name, _, handle_name = code.rpartition(".")
        if not (inference_module_name and handle_name):
            raise InferenceCodeLoadException(
                f"Inference code expected in the format of `<module>.<handler>` but was provided as {code}"
            )

        inference_module_file = f"{inference_module_name}.py"
        module_location = self._path_to_inference_code.joinpath(inference_module_file)
        module_spec = importlib.util.spec_from_file_location(
            inference_module_file,
            str(module_location)
        )
        # Guard on file existence too: a spec is produced even for a missing
        # .py file, which previously surfaced as an uncaught FileNotFoundError.
        if not (module_spec and module_location.is_file()):
            raise InferenceCodeLoadException(
                f"Inference code could not be found at `{str(module_location)}`"
            )

        # Make sibling modules importable by the user's inference code.
        sys.path.insert(0, str(self._path_to_inference_code.resolve()))
        # loader.load_module() is deprecated (removed in Python 3.12); use the
        # module_from_spec()/exec_module() replacement.
        inference_module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(inference_module)

        if not hasattr(inference_module, handle_name):
            raise InferenceCodeLoadException(
                f"Handler `{handle_name}` could not be found in module `{inference_module_file}`"
            )
        handler = getattr(inference_module, handle_name)
        logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`")
        return handler


if __name__ == "__main__":
    # Construct, prepare (install requirements + load handler), then serve forever.
    server = InferenceServer()
    server.initialize()
    server.serve()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import absolute_import
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import absolute_import

import asyncio
import logging
import tornado.web
from utils.environment import Environment
from utils.exception import AsyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class InvocationsHandler(tornado.web.RequestHandler):
    """Handles POST /invocations by awaiting the user's async inference handler."""

    def initialize(self, handler: callable, environment: Environment):
        # Tornado calls initialize() with the dict registered for this route
        # in serve() below.
        self._handler = handler
        self._environment = environment

    async def post(self):
        try:
            # The coroutine handler receives the raw tornado request; its return
            # value becomes the response body.
            response = await self._handler(self.request)
            self.write(response)
        except Exception as e:
            # Re-raised so tornado's error handling turns it into an error response.
            raise AsyncInvocationsException(e)


class PingHandler(tornado.web.RequestHandler):
    """Health-check endpoint: GET /ping writes an empty body (HTTP 200)."""

    def get(self):
        self.write("")


async def serve(handler: callable, environment: Environment):
    """Start the async /invocations and /ping application and block forever."""
    routes = [
        (r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)),
        (r"/ping", PingHandler),
    ]
    application = tornado.web.Application(routes)
    application.listen(environment.port)
    logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")
    # Wait on an event that is never set, keeping the loop alive indefinitely.
    await asyncio.Event().wait()
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import absolute_import

import asyncio
import logging
import tornado.web
from utils.environment import Environment
from utils.exception import SyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class InvocationsHandler(tornado.web.RequestHandler):
    """Handles POST /invocations by calling the user's sync inference handler.

    NOTE(review): the handler call runs on the event loop thread, so a slow
    handler blocks other requests (including /ping) — confirm this is intended.
    """

    def initialize(self, handler: callable, environment: Environment):
        # Tornado calls initialize() with the dict registered for this route
        # in serve() below.
        self._handler = handler
        self._environment = environment

    def post(self):
        try:
            # Handler receives the raw tornado request; its return value becomes
            # the response body.
            self.write(self._handler(self.request))
        except Exception as e:
            # Re-raised so tornado's error handling turns it into an error response.
            raise SyncInvocationsException(e)


class PingHandler(tornado.web.RequestHandler):
    """Health-check endpoint: GET /ping writes an empty body (HTTP 200)."""

    def get(self):
        self.write("")


async def serve(handler: callable, environment: Environment):
    """Start the sync /invocations and /ping application and block forever."""
    routes = [
        (r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)),
        (r"/ping", PingHandler),
    ]
    application = tornado.web.Application(routes)
    application.listen(environment.port)
    logger.debug(f"Synchronous inference server listening on port: `{environment.port}`")
    # Wait on an event that is never set, keeping the loop alive indefinitely.
    await asyncio.Event().wait()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import absolute_import
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import absolute_import

import json
import os
from enum import Enum

class SageMakerInference(str, Enum):
    """Names of the environment variables that configure the inference server."""

    BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
    REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
    CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY"
    CODE = "SAGEMAKER_INFERENCE_CODE"
    LOGGING_LEVEL = "SAGEMAKER_INFERENCE_LOGGING_LEVEL"
    PORT = "SAGEMAKER_INFERENCE_PORT"


class Environment:
    """Snapshot of the inference configuration read from the environment.

    Values captured at construction time; properties expose them with the
    types the server expects (notably int port and logging level, since
    os.getenv returns strings for variables that are set).
    """

    def __init__(self):
        self._environment_variables = {
            SageMakerInference.BASE_DIRECTORY: "/opt/ml/model",
            SageMakerInference.REQUIREMENTS: "requirements.txt",
            SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None),
            SageMakerInference.CODE: os.getenv(SageMakerInference.CODE, "inference.handler"),
            SageMakerInference.LOGGING_LEVEL: os.getenv(SageMakerInference.LOGGING_LEVEL, 10),
            SageMakerInference.PORT: os.getenv(SageMakerInference.PORT, 8080)
        }

    def __str__(self):
        return json.dumps(self._environment_variables)

    @property
    def base_directory(self):
        """Root directory containing the model and inference code."""
        return self._environment_variables.get(SageMakerInference.BASE_DIRECTORY)

    @property
    def requirements(self):
        """File name of the runtime requirements file."""
        return self._environment_variables.get(SageMakerInference.REQUIREMENTS)

    @property
    def code_directory(self):
        """Optional sub-directory of base_directory holding inference code, or None."""
        return self._environment_variables.get(SageMakerInference.CODE_DIRECTORY)

    @property
    def code(self):
        """Handler spec in `<module>.<handler>` form."""
        return self._environment_variables.get(SageMakerInference.CODE)

    @property
    def logging_level(self):
        """Logging level as an int, or a symbolic name string (e.g. "DEBUG").

        os.getenv returns a *string* when the variable is set, and
        Logger.setLevel rejects numeric strings like "20" — coerce numeric
        values to int, pass symbolic level names through unchanged.
        """
        level = self._environment_variables.get(SageMakerInference.LOGGING_LEVEL)
        try:
            return int(level)
        except (TypeError, ValueError):
            return level

    @property
    def port(self):
        """Server port as an int (tornado's listen() requires an integer)."""
        return int(self._environment_variables.get(SageMakerInference.PORT))
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import absolute_import

class RequirementsInstallException(Exception):
    """Raised when installing the model's requirements file fails."""


class InferenceCodeLoadException(Exception):
    """Raised when the inference module or handler cannot be located/loaded."""


class ServerStartException(Exception):
    """Raised when the tornado inference server fails while serving."""


class SyncInvocationsException(Exception):
    """Raised when the synchronous /invocations handler errors."""


class AsyncInvocationsException(Exception):
    """Raised when the asynchronous /invocations handler errors."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import absolute_import

import logging.config

SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER = "sagemaker_distribution.inference_server"

# Every configured logger shares the same settings: DEBUG level, a single
# stdout stream handler, and propagation enabled.
_DEBUG_STDOUT_LOGGER = {
    "level": "DEBUG",
    "handlers": ["default"],
    "propagate": True,
}

LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": True,
    "formatters": {
        "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
    },
    "handlers": {
        "default": {
            "level": "DEBUG",
            "formatter": "standard",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
        },
    },
    # Identical configuration for the server's own logger and tornado's three.
    "loggers": {
        name: dict(_DEBUG_STDOUT_LOGGER)
        for name in (
            SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER,
            "tornado.application",
            "tornado.general",
            "tornado.access",
        )
    },
}
# Apply at import time so every module that imports this one gets the config.
logging.config.dictConfig(LOGGING_CONFIG)
1 change: 1 addition & 0 deletions build_artifacts/v2/v2.2/v2.2.0/gpu.env.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ conda-forge::amazon_sagemaker_sql_editor[version='>=0.1.12,<1.0.0']
conda-forge::amazon-sagemaker-sql-magic[version='>=0.1.3,<1.0.0']
conda-forge::langchain[version='>=0.2.16,<1.0.0']
conda-forge::fastapi[version='>=0.115.2,<1.0.0']
conda-forge::tornado[version='>=6.3.3,<7.0.0']
conda-forge::uvicorn[version='>=0.32.0,<1.0.0']
conda-forge::pytorch[version='>=2.3.1,<3.0.0',build='*cuda12*']
conda-forge::tensorflow[version='>=2.17.0,<3.0.0']
Expand Down

0 comments on commit bb132ea

Please sign in to comment.