Merge pull request #511 from SuperCowPowers/new_model_images
New model images
Showing 31 changed files with 1,625 additions and 59 deletions.
Empty file.
@@ -0,0 +1,27 @@
FROM python:3.12-slim

# Install Vim
RUN apt-get update && apt-get install -y vim

# Copy requirements file
COPY requirements.txt /tmp/

# Install dependencies
RUN pip install --no-cache-dir -r /tmp/requirements.txt

# Add the serve script
COPY serve /usr/local/bin/
RUN chmod +x /usr/local/bin/serve

# Copy the main.py/entrypoint script
COPY main.py /opt/program/
WORKDIR /opt/program

# Make port 8080 available for the web server
EXPOSE 8080

# Define environment variable
ENV PYTHONUNBUFFERED=TRUE

# SageMaker will look for this
CMD ["serve"]
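
Once this image is built and started locally (e.g. docker run -p 8080:8080 aws-ml-images/py312-sklearn-xgb-inference:0.1 with a model mounted at /opt/ml/model), a quick smoke test can exercise the two endpoints SageMaker calls. This is a sketch, not part of the PR: the CSV payload is made up, and it assumes the bundled inference script accepts text/csv input and returns JSON.

# smoke_test.py -- minimal local check of the container's SageMaker endpoints.
# Assumes the container is running locally and mapped to port 8080; the
# payload below is a hypothetical example.
import urllib.request

BASE = "http://localhost:8080"

# Health check: main.py returns 200 once the inference module is loaded
with urllib.request.urlopen(f"{BASE}/ping") as resp:
    print("ping:", resp.status)

# Inference: POST raw bytes; Content-Type/Accept are passed through to the
# bundled inference script's input_fn/output_fn
req = urllib.request.Request(
    f"{BASE}/invocations",
    data=b"5.1,3.5,1.4,0.2\n",
    headers={"Content-Type": "text/csv", "Accept": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print("invocations:", resp.status, resp.read().decode())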
@@ -0,0 +1,150 @@
from fastapi import FastAPI, Request, Response
from contextlib import asynccontextmanager
import os
import sys
import json
import importlib.util
import logging
import subprocess
import site

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
model = None
inference_module = None


def get_inference_script(model_dir: str) -> str:
    """Retrieve the inference script name

    Args:
        model_dir (str): The directory containing the model artifacts

    Returns:
        str: The name of the inference script
    """

    # Get the path to the inference-metadata.json file
    inference_meta_path = os.path.join(model_dir, "inference-metadata.json")
    with open(inference_meta_path, "r") as f:
        config = json.load(f)
    return config["inference_script"]


def install_requirements(requirements_path):
    """Install Python dependencies from requirements file.

    Uses a persistent cache to speed up container cold starts.

    Note: Inference containers don't have root access, so we
    use the --user flag and add the user package path manually.
    """
    if os.path.exists(requirements_path):
        logger.info(f"Installing dependencies from {requirements_path}...")

        # Define a persistent cache location
        pip_cache_dir = "/opt/ml/model/.cache/pip"
        os.environ["PIP_CACHE_DIR"] = pip_cache_dir

        try:
            subprocess.check_call(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "--cache-dir",
                    pip_cache_dir,  # Enable caching
                    "--disable-pip-version-check",
                    "--no-warn-script-location",
                    "--user",
                    "-r",
                    requirements_path,
                ]
            )
            # Ensure Python can find user-installed packages
            sys.path.append(site.getusersitepackages())
            logger.info("Requirements installed successfully.")
        except subprocess.CalledProcessError as e:
            logger.error(f"Error installing requirements: {e}")
            sys.exit(1)
    else:
        logger.info(f"No requirements file found at {requirements_path}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle model loading on startup and cleanup on shutdown."""
    global model, inference_module

    # Note: SageMaker will put model.tar.gz in /opt/ml/model
    # which includes the model artifacts and inference code
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    inference_script = get_inference_script(model_dir)

    # List directory contents for debugging
    logger.info(f"Contents of {model_dir}: {os.listdir(model_dir)}")

    try:
        # Load the inference script from source_dir
        inference_script_path = os.path.join(model_dir, inference_script)
        if not os.path.exists(inference_script_path):
            raise FileNotFoundError(f"Inference script not found: {inference_script_path}")

        # Install requirements if present
        install_requirements(os.path.join(model_dir, "requirements.txt"))

        # Ensure the model directory is in the Python path
        sys.path.insert(0, model_dir)

        # Import the inference module
        logger.info(f"Importing inference module from {inference_script_path}")
        spec = importlib.util.spec_from_file_location("inference_module", inference_script_path)
        inference_module = importlib.util.module_from_spec(spec)
        sys.modules["inference_module"] = inference_module
        spec.loader.exec_module(inference_module)

        # Check if model_fn is defined
        if not hasattr(inference_module, "model_fn"):
            raise ImportError(f"Inference module {inference_script_path} does not define model_fn")

        # Load the model using model_fn
        logger.info("Calling model_fn to load the model")
        model = inference_module.model_fn(model_dir)
        logger.info(f"Model loaded successfully: {type(model)}")

    except Exception as e:
        logger.error(f"Error initializing model: {e}", exc_info=True)
        raise

    yield

    logger.info("Shutting down model server")


app = FastAPI(lifespan=lifespan)


@app.get("/ping")
def ping():
    """Health check endpoint for SageMaker."""
    # Check if the inference module is loaded
    return Response(status_code=200 if inference_module else 500)


@app.post("/invocations")
async def invoke(request: Request):
    """Inference endpoint for SageMaker."""
    content_type = request.headers.get("Content-Type", "")
    accept_type = request.headers.get("Accept", "")

    try:
        body = await request.body()
        data = inference_module.input_fn(body, content_type)
        result = inference_module.predict_fn(data, model)
        output_data, output_content_type = inference_module.output_fn(result, accept_type)
        return Response(content=output_data, media_type=output_content_type)
    except Exception as e:
        logger.error(f"Error during inference: {e}", exc_info=True)
        return Response(content=json.dumps({"error": str(e)}), status_code=500, media_type="application/json")
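
main.py delegates all model-specific work to the inference script bundled in model.tar.gz: model_fn loads the model at startup, and input_fn, predict_fn, and output_fn handle each request. A minimal sketch of such a script follows; only the four function names and signatures come from main.py above, while the model filename (model.joblib), the joblib/scikit-learn loading, and the CSV/JSON handling are illustrative assumptions, not part of this PR.

# example_inference.py -- a sketch of the script main.py imports.
# The artifact name and content types below are assumptions; main.py only
# requires the four functions with these signatures.
import json
import os
from io import BytesIO

import joblib
import pandas as pd


def model_fn(model_dir):
    """Load the model artifact from the unpacked model.tar.gz."""
    return joblib.load(os.path.join(model_dir, "model.joblib"))


def input_fn(request_body, content_type):
    """Deserialize raw request bytes into a DataFrame."""
    if content_type == "text/csv":
        return pd.read_csv(BytesIO(request_body), header=None)
    raise ValueError(f"Unsupported content type: {content_type}")


def predict_fn(data, model):
    """Run the prediction."""
    return model.predict(data)


def output_fn(prediction, accept_type):
    """Serialize predictions; return (payload, content_type)."""
    if accept_type in ("application/json", "", "*/*"):
        return json.dumps(prediction.tolist()), "application/json"
    raise ValueError(f"Unsupported accept type: {accept_type}")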
@@ -0,0 +1,7 @@
fastapi==0.115.10
uvicorn==0.34.0
scikit-learn==1.6.1
xgboost-cpu==2.1.4
pandas==2.2.3
awswrangler==3.11.0
joblib==1.4.2
@@ -0,0 +1,6 @@
#!/bin/bash

# SageMaker expects a 'serve' script to be found in the container which starts the model server.

# Start the FastAPI server using Uvicorn
exec uvicorn main:app --host 0.0.0.0 --port 8080
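
Tying the pieces together: main.py expects the unpacked model.tar.gz to contain the model artifacts, an inference-metadata.json whose "inference_script" key names the script to import, and an optional requirements.txt. A sketch of assembling such a bundle is below; the artifact and script filenames are illustrative, matching the hypothetical example_inference.py above.

# package_model.py -- sketch of building a model.tar.gz in the layout
# main.py expects. Filenames are illustrative; the hard requirements are
# inference-metadata.json and the script it points at.
import json
import tarfile

# Tell main.py which script to import
with open("inference-metadata.json", "w") as f:
    json.dump({"inference_script": "example_inference.py"}, f)

with tarfile.open("model.tar.gz", "w:gz") as tar:
    tar.add("model.joblib")            # model artifact loaded by model_fn
    tar.add("example_inference.py")    # script named in the metadata
    tar.add("inference-metadata.json")
    tar.add("requirements.txt")        # optional; installed at cold start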
@@ -0,0 +1,152 @@
#!/bin/bash
set -e

# Get the directory of this script
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)"
# Get the parent directory (project root)
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"

# AWS Account ID
AWS_ACCOUNT_ID="507740646243"

# Define repository names - used for both local and ECR images
TRAINING_REPO="aws-ml-images/py312-sklearn-xgb-training"
INFERENCE_REPO="aws-ml-images/py312-sklearn-xgb-inference"

# Local directories
TRAINING_DIR="$PROJECT_ROOT/training"
INFERENCE_DIR="$PROJECT_ROOT/inference"

# Image version
IMAGE_VERSION=${1:-"0.1"}

# Define the regions to deploy to.
REGION_LIST=("us-east-1" "us-west-2")

# Colors for output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Parse arguments
DEPLOY=false
LATEST=false
for arg in "$@"; do
    case $arg in
        --deploy)
            DEPLOY=true
            ;;
        --latest)
            LATEST=true
            ;;
        *)
            ;;
    esac
done

# Expect AWS_PROFILE to be set in the environment when deploying
if [ "$DEPLOY" = true ]; then
    : "${AWS_PROFILE:?AWS_PROFILE environment variable is not set.}"
fi

# Function to build a Docker image
build_image() {
    local dir=$1
    local repo_name=$2
    local tag=$3
    local full_name="${repo_name}:${tag}"

    echo -e "${YELLOW}Building image: ${full_name}${NC}"

    # Check if Dockerfile exists
    if [ ! -f "$dir/Dockerfile" ]; then
        echo "❌ Error: Dockerfile not found in $dir"
        return 1
    fi

    # Build the image for AMD64 architecture
    echo "Building local Docker image ${full_name} for linux/amd64..."
    docker build --platform linux/amd64 -t "$full_name" "$dir"

    echo -e "${GREEN}✅ Successfully built: ${full_name}${NC}"
    return 0
}

# Function to deploy an image to ECR
deploy_image() {
    local repo_name=$1
    local tag=$2
    local use_latest=$3
    local full_name="${repo_name}:${tag}"

    for REGION in "${REGION_LIST[@]}"; do
        echo "Processing region: ${REGION}"
        # Construct the ECR repository URL
        ECR_REPO="${AWS_ACCOUNT_ID}.dkr.ecr.${REGION}.amazonaws.com/${repo_name}"
        AWS_ECR_IMAGE="${ECR_REPO}:${tag}"

        echo "Logging in to AWS ECR in ${REGION}..."
        aws ecr get-login-password --region "${REGION}" --profile "${AWS_PROFILE}" | \
            docker login --username AWS --password-stdin "${AWS_ACCOUNT_ID}.dkr.ecr.${REGION}.amazonaws.com"

        echo "Tagging image for AWS ECR as ${AWS_ECR_IMAGE}..."
        docker tag "${full_name}" "${AWS_ECR_IMAGE}"

        echo "Pushing Docker image to AWS ECR: ${AWS_ECR_IMAGE}..."
        docker push "${AWS_ECR_IMAGE}"

        if [ "$use_latest" = true ]; then
            AWS_ECR_LATEST="${ECR_REPO}:latest"
            echo "Tagging AWS ECR image as latest: ${AWS_ECR_LATEST}..."
            docker tag "${full_name}" "${AWS_ECR_LATEST}"
            echo "Pushing Docker image to AWS ECR: ${AWS_ECR_LATEST}..."
            docker push "${AWS_ECR_LATEST}"
        fi
    done
}

# Build training image
echo "======================================"
echo "🏗️ Building training container"
echo "======================================"
build_image "$TRAINING_DIR" "$TRAINING_REPO" "$IMAGE_VERSION"

# Build inference image
echo "======================================"
echo "🏗️ Building inference container"
echo "======================================"
build_image "$INFERENCE_DIR" "$INFERENCE_REPO" "$IMAGE_VERSION"

echo "======================================"
echo -e "${GREEN}✅ All builds completed successfully!${NC}"
echo "======================================"

if [ "$DEPLOY" = true ]; then
    echo "======================================"
    echo "🚀 Deploying containers to ECR"
    echo "======================================"

    # Deploy training image
    echo "Deploying training image..."
    deploy_image "$TRAINING_REPO" "$IMAGE_VERSION" "$LATEST"

    # Deploy inference image
    echo "Deploying inference image..."
    deploy_image "$INFERENCE_REPO" "$IMAGE_VERSION" "$LATEST"

    echo "======================================"
    echo -e "${GREEN}✅ Deployment complete!${NC}"
    echo "======================================"
else
    echo "Local build complete. Use --deploy to push the images to AWS ECR in regions: ${REGION_LIST[*]}."

    # Print information about the built images
    echo "======================================"
    echo "📋 Image information:"
    echo "Training image: ${TRAINING_REPO}:${IMAGE_VERSION}"
    echo "Inference image: ${INFERENCE_REPO}:${IMAGE_VERSION}"
    echo "======================================"

    # Inform about testing option
    echo "To test these containers, run: $PROJECT_ROOT/tests/run_tests.sh ${IMAGE_VERSION}"
fi