-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable inference serving capabilities on sagemaker endpoint using tornado
- Loading branch information
Showing
12 changed files
with
297 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from __future__ import absolute_import | ||
|
||
import utils.logger |
2 changes: 2 additions & 0 deletions
2
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/serve
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/bin/bash
# Entry point for the inference container: start the Tornado-based server.
# `exec` replaces this shell with the Python process so that signals sent to
# the entry point (e.g. SIGTERM on shutdown) reach the server directly and its
# exit status is reported unchanged.
exec python /etc/inference-server/server.py
101 changes: 101 additions & 0 deletions
101
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import importlib | ||
import logging | ||
import subprocess | ||
import sys | ||
from pathlib import Path | ||
from utils.environment import Environment | ||
from utils.exception import ( | ||
InferenceCodeLoadException, | ||
RequirementsInstallException, | ||
ServerStartException | ||
) | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class TornadoServer:
    """Runs an inference handler behind one of the bundled Tornado servers.

    The class reads `self._handler` and `self._environment`; it does not set
    them itself, so a subclass has to provide both before `serve` is called.
    """

    def serve(self):
        """Start the matching Tornado server and run it on a new event loop.

        A coroutine-function handler selects `tornado_server.async_server`;
        any other handler selects `tornado_server.sync_server`.

        Raises:
            ServerStartException: wraps any exception raised while the server
                coroutine is running.
        """
        handler_is_coroutine = asyncio.iscoroutinefunction(self._handler)
        if not handler_is_coroutine:
            logger.info("Starting inference server in synchronous mode...")
            import tornado_server.sync_server as server_module
        else:
            logger.info("Starting inference server in asynchronous mode...")
            import tornado_server.async_server as server_module

        try:
            server_coroutine = server_module.serve(self._handler, self._environment)
            asyncio.run(server_coroutine)
        except Exception as error:
            raise ServerStartException(error)
|
||
|
||
class InferenceServer(TornadoServer):
    """Locates, prepares and loads user-supplied inference code for serving.

    The inference code lives in `<base_directory>/<code_directory>` (or in
    `<base_directory>` when no code directory is configured). `initialize`
    installs an optional requirements file found there and loads the handler
    named by the environment's `code` setting (`<module>.<handler>`).
    """

    def __init__(self):
        """Read the environment, set the log level and resolve the code path."""
        self._environment = Environment()
        logger.setLevel(self._environment.logging_level)
        logger.debug(f"Environment: {str(self._environment)}")

        base_directory = Path(self._environment.base_directory)
        code_directory = self._environment.code_directory
        self._path_to_inference_code = (
            base_directory.joinpath(code_directory) if code_directory else base_directory
        )
        logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`")

        # Populated by `initialize`.
        self._handler = None

    def initialize(self):
        """Install runtime requirements, then load the inference handler.

        Raises:
            RequirementsInstallException: if installing the requirements fails.
            InferenceCodeLoadException: if the handler cannot be loaded.
        """
        self._install_runtime_requirements()
        self._handler = self._load_inference_handler()

    def _install_runtime_requirements(self):
        """Run `pip install -r` on the requirements file, if one exists.

        Raises:
            RequirementsInstallException: wraps any failure of the pip call.
        """
        logger.info("Installing runtime requirements...")

        requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements)
        if not requirements_txt.is_file():
            logger.debug(f"No requirements file was found at `{str(requirements_txt)}`")
            return

        try:
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "-r", str(requirements_txt)]
            )
        except Exception as e:
            raise RequirementsInstallException(e) from e

    def _load_inference_handler(self) -> callable:
        """Import the inference module from disk and return its handler.

        Returns:
            The attribute named `<handler>` of the module `<module>.py` found
            in the inference code directory.

        Raises:
            InferenceCodeLoadException: if the `code` setting is not of the
                form `<module>.<handler>`, the module file is missing, the
                module fails while being executed, or it has no such handler.
        """
        # `import importlib` alone does not guarantee the `util` submodule is
        # loaded, so import it explicitly where it is used.
        import importlib.util

        logger.info("Loading inference handler...")

        code = self._environment.code
        parts = str(code).split(".")
        # Exactly one dot with a non-empty name on each side is required.
        if len(parts) != 2 or not all(parts):
            raise InferenceCodeLoadException(
                f"Inference code expected in the format of `<module>.<handler>` but was provided as {code}"
            )
        inference_module_name, handle_name = parts

        inference_module_file = f"{inference_module_name}.py"
        inference_module_path = self._path_to_inference_code.joinpath(inference_module_file)
        not_found_message = (
            f"Inference code could not be found at `{str(inference_module_path)}`"
        )
        # spec_from_file_location does not check that the file exists, so test
        # for it explicitly to report a missing module clearly.
        if not inference_module_path.is_file():
            raise InferenceCodeLoadException(not_found_message)

        # The file name is kept as the module name (as before), which avoids
        # clashing with any already imported module called `<module>`.
        module_spec = importlib.util.spec_from_file_location(
            inference_module_file,
            str(inference_module_path)
        )
        if module_spec is None or module_spec.loader is None:
            raise InferenceCodeLoadException(not_found_message)

        # Let the inference module import its sibling modules.
        sys.path.insert(0, str(self._path_to_inference_code.resolve()))

        inference_module = importlib.util.module_from_spec(module_spec)
        sys.modules[inference_module_file] = inference_module
        try:
            module_spec.loader.exec_module(inference_module)
        except Exception as e:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(inference_module_file, None)
            raise InferenceCodeLoadException(
                f"Inference module `{inference_module_file}` could not be loaded: {e}"
            ) from e

        if not hasattr(inference_module, handle_name):
            raise InferenceCodeLoadException(
                f"Handler `{handle_name}` could not be found in module `{inference_module_file}`"
            )
        handler = getattr(inference_module, handle_name)
        logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`")
        return handler
|
||
|
||
if __name__ == "__main__":
    # Script entry point: build the server, prepare the inference code,
    # then serve requests.
    server = InferenceServer()
    server.initialize()
    server.serve()
1 change: 1 addition & 0 deletions
1
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from __future__ import absolute_import |
38 changes: 38 additions & 0 deletions
38
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/async_server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import logging | ||
import tornado.web | ||
from utils.environment import Environment | ||
from utils.exception import AsyncInvocationsException | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class InvocationsHandler(tornado.web.RequestHandler):
    """Forwards POST requests to a coroutine inference handler."""

    def initialize(self, handler: callable, environment: Environment):
        """Store the objects passed through the route's keyword arguments."""
        self._environment = environment
        self._handler = handler

    async def post(self):
        """Await the inference handler on the request and write its result.

        Raises:
            AsyncInvocationsException: wraps any exception raised by the
                handler or while writing the response.
        """
        try:
            self.write(await self._handler(self.request))
        except Exception as error:
            raise AsyncInvocationsException(error)
|
||
|
||
class PingHandler(tornado.web.RequestHandler):
    """Answers GET requests on the ping route with an empty body."""

    def get(self):
        """Write an empty response body."""
        empty_body = ""
        self.write(empty_body)
|
||
|
||
async def serve(handler: callable, environment: Environment):
    """Start the asynchronous Tornado application and keep it running.

    Args:
        handler: Inference handler passed to `InvocationsHandler`.
        environment: Provides the port to listen on; also passed to
            `InvocationsHandler`.
    """
    handler_kwargs = dict(handler=handler, environment=environment)
    routes = [
        (r"/invocations", InvocationsHandler, handler_kwargs),
        (r"/ping", PingHandler),
    ]
    application = tornado.web.Application(routes)
    application.listen(environment.port)
    logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")
    # Wait on an event that is never set so the coroutine does not return.
    never_set = asyncio.Event()
    await never_set.wait()
37 changes: 37 additions & 0 deletions
37
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/tornado_server/sync_server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import logging | ||
import tornado.web | ||
from utils.environment import Environment | ||
from utils.exception import SyncInvocationsException | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class InvocationsHandler(tornado.web.RequestHandler):
    """Forwards POST requests to a synchronous inference handler."""

    def initialize(self, handler: callable, environment: Environment):
        """Store the objects passed through the route's keyword arguments."""
        self._environment = environment
        self._handler = handler

    def post(self):
        """Call the inference handler on the request and write its result.

        Raises:
            SyncInvocationsException: wraps any exception raised by the
                handler or while writing the response.
        """
        try:
            response = self._handler(self.request)
            self.write(response)
        except Exception as error:
            raise SyncInvocationsException(error)
|
||
|
||
class PingHandler(tornado.web.RequestHandler):
    """Answers GET requests on the ping route with an empty body."""

    def get(self):
        """Write an empty response body."""
        empty_body = ""
        self.write(empty_body)
|
||
|
||
async def serve(handler: callable, environment: Environment):
    """Start the synchronous Tornado application and keep it running.

    Args:
        handler: Inference handler passed to `InvocationsHandler`.
        environment: Provides the port to listen on; also passed to
            `InvocationsHandler`.
    """
    handler_kwargs = dict(handler=handler, environment=environment)
    routes = [
        (r"/invocations", InvocationsHandler, handler_kwargs),
        (r"/ping", PingHandler),
    ]
    application = tornado.web.Application(routes)
    application.listen(environment.port)
    logger.debug(f"Synchronous inference server listening on port: `{environment.port}`")
    # Wait on an event that is never set so the coroutine does not return.
    never_set = asyncio.Event()
    await never_set.wait()
1 change: 1 addition & 0 deletions
1
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from __future__ import absolute_import |
52 changes: 52 additions & 0 deletions
52
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/environment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from __future__ import absolute_import | ||
|
||
import json | ||
import os | ||
from enum import Enum | ||
|
||
class SageMakerInference(str, Enum):
    """Names of the environment variables that configure the inference server."""

    BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
    REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
    CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY"
    CODE = "SAGEMAKER_INFERENCE_CODE"
    LOGGING_LEVEL = "SAGEMAKER_INFERENCE_LOGGING_LEVEL"
    PORT = "SAGEMAKER_INFERENCE_PORT"


class Environment:
    """Snapshot of the inference server settings, taken at construction time.

    The base directory and requirements file name are fixed. The code
    directory, code entry point, logging level and port are read from the
    process environment, falling back to defaults when unset. Environment
    values are always strings, so the logging level and port are normalised:
    the port is always an `int`, and the logging level is an `int` when
    numeric or an upper-cased level name otherwise.
    """

    def __init__(self):
        """Read the settings from the process environment.

        Raises:
            ValueError: if the configured port is not an integer.
        """
        self._environment_variables = {
            SageMakerInference.BASE_DIRECTORY: "/opt/ml/model",
            SageMakerInference.REQUIREMENTS: "requirements.txt",
            SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY.value, None),
            SageMakerInference.CODE: os.getenv(SageMakerInference.CODE.value, "inference.handler"),
            SageMakerInference.LOGGING_LEVEL: self._parse_logging_level(
                os.getenv(SageMakerInference.LOGGING_LEVEL.value, 10)
            ),
            SageMakerInference.PORT: self._parse_port(
                os.getenv(SageMakerInference.PORT.value, 8080)
            ),
        }

    @staticmethod
    def _parse_logging_level(raw_level):
        """Return an int for numeric levels, else the upper-cased level name."""
        if isinstance(raw_level, int):
            return raw_level
        level_text = str(raw_level).strip()
        try:
            # "20" must become 20: a numeric string is not a valid level name.
            return int(level_text)
        except ValueError:
            return level_text.upper()

    @staticmethod
    def _parse_port(raw_port):
        """Return the port as an int, rejecting non-numeric values."""
        try:
            return int(raw_port)
        except (TypeError, ValueError) as error:
            raise ValueError(
                f"`{SageMakerInference.PORT.value}` must be an integer port number, got {raw_port!r}"
            ) from error

    def __str__(self):
        """Return the settings as a JSON object keyed by variable name."""
        return json.dumps(
            {variable.value: value for variable, value in self._environment_variables.items()}
        )

    @property
    def base_directory(self):
        """Directory under which the inference code is looked up."""
        return self._environment_variables.get(SageMakerInference.BASE_DIRECTORY)

    @property
    def requirements(self):
        """File name of the requirements file."""
        return self._environment_variables.get(SageMakerInference.REQUIREMENTS)

    @property
    def code_directory(self):
        """Configured code directory, or `None` when unset."""
        return self._environment_variables.get(SageMakerInference.CODE_DIRECTORY)

    @property
    def code(self):
        """Inference entry point; defaults to `inference.handler`."""
        return self._environment_variables.get(SageMakerInference.CODE)

    @property
    def logging_level(self):
        """Logging level as an int or an upper-cased level name; default 10."""
        return self._environment_variables.get(SageMakerInference.LOGGING_LEVEL)

    @property
    def port(self):
        """Port as an int; default 8080."""
        return self._environment_variables.get(SageMakerInference.PORT)
16 changes: 16 additions & 0 deletions
16
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/exception.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from __future__ import absolute_import | ||
|
||
class RequirementsInstallException(Exception):
    """Raised when installing the runtime requirements fails."""
|
||
class InferenceCodeLoadException(Exception):
    """Raised when the inference module or its handler cannot be loaded."""
|
||
class ServerStartException(Exception):
    """Raised when running the inference server fails."""
|
||
class SyncInvocationsException(Exception):
    """Raised when a synchronous invocation request fails."""
|
||
class AsyncInvocationsException(Exception):
    """Raised when an asynchronous invocation request fails."""
43 changes: 43 additions & 0 deletions
43
build_artifacts/v2/v2.2/v2.2.0/dirs/etc/inference-server/utils/logger.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import absolute_import | ||
|
||
import logging.config | ||
|
||
# Name of the logger used by the inference server modules.
SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER = "sagemaker_distribution.inference_server"

# dictConfig schema: one stdout handler with a single format, attached to the
# inference server logger and to Tornado's three loggers, all at DEBUG.
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": True,
    "formatters": {
        "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
    },
    "handlers": {
        "default": {
            "level": "DEBUG",
            "formatter": "standard",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        # Every configured logger uses identical settings.
        logger_name: {
            "level": "DEBUG",
            "handlers": ["default"],
            "propagate": True,
        }
        for logger_name in (
            SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER,
            "tornado.application",
            "tornado.general",
            "tornado.access",
        )
    },
}
# Applied as an import side effect so that importing this module is enough to
# configure logging.
logging.config.dictConfig(LOGGING_CONFIG)