New model images #511

Merged 35 commits on Mar 2, 2025
39282be
adding new model images for training and inference
brifordwylie Mar 1, 2025
8a9a566
adding the code/docker for the workbench model image generation (WIP)
brifordwylie Mar 1, 2025
d255315
adding the code/docker for the workbench model image generation (WIP)
brifordwylie Mar 1, 2025
49f0032
adding the code/docker for the workbench model image generation (WIP)
brifordwylie Mar 1, 2025
8874c43
making a mock_estimator class for testing the training image
brifordwylie Mar 1, 2025
115d86a
just some cleanup
brifordwylie Mar 1, 2025
72dfacf
cleanup and simplification
brifordwylie Mar 1, 2025
99a6630
cleanup and simplification
brifordwylie Mar 1, 2025
0a06066
removing some old test code
brifordwylie Mar 1, 2025
84d1db8
refactoring inference entry_point and test harness
brifordwylie Mar 1, 2025
84378f5
fixing StringIO imports
brifordwylie Mar 1, 2025
5c2d99a
improved json handling
brifordwylie Mar 1, 2025
24c34b2
change test data a bit
brifordwylie Mar 1, 2025
ad4a063
changing repo naming
brifordwylie Mar 1, 2025
091fc3d
changing InferenceImage to ModelImages
brifordwylie Mar 1, 2025
4b837cf
using new model images
brifordwylie Mar 1, 2025
2a80c66
unlocking scikit-learn version
brifordwylie Mar 1, 2025
1f6299c
switching over to 'serve' script
brifordwylie Mar 1, 2025
1cb2186
cleaning up the requirements.txt files for models since our new train…
brifordwylie Mar 1, 2025
54e5eb3
making the serve script executable
brifordwylie Mar 1, 2025
c2f8a37
refactoring the training and inference containers
brifordwylie Mar 2, 2025
5f230e3
simplifying the inference entry point
brifordwylie Mar 2, 2025
cf008ec
adding code and metadata to model dir (for pick up by inference conta…
brifordwylie Mar 2, 2025
fefca78
changing script args so they don't fail if ENV vars aren't set
brifordwylie Mar 2, 2025
951511b
changing script args so they don't fail if ENV vars aren't set
brifordwylie Mar 2, 2025
e56fe74
changing logic for copying code files/directories
brifordwylie Mar 2, 2025
f424222
PYTHONPATH doesn't work with importlib, so use sys.path
brifordwylie Mar 2, 2025
56a59ae
fixing the file/dir copy from code to model dir
brifordwylie Mar 2, 2025
8b7af33
flake8/linter cleanup
brifordwylie Mar 2, 2025
8868c88
adding install requirements.txt for inference entry point
brifordwylie Mar 2, 2025
f9b0db8
putting in a better pip install (with cache) and better ping response
brifordwylie Mar 2, 2025
9246be4
flake8/linter cleanup
brifordwylie Mar 2, 2025
9da5875
new version of rdkit
brifordwylie Mar 2, 2025
4911c05
unlocking scikit-learn version
brifordwylie Mar 2, 2025
7b44ec8
fix test
brifordwylie Mar 2, 2025
4 changes: 2 additions & 2 deletions applications/compound_explorer/requirements.txt
@@ -8,7 +8,7 @@ sagemaker >= 2.143
 cryptography>=42.0.5
 ipython>=8.17.2
 xgboost>=2.0.3
-scikit-learn >=1.4.2, <= 1.5.2
+scikit-learn >=1.5.2
 joblib>=1.3.2
 requests>=2.32.0
 plotly >= 5.18.0
@@ -18,7 +18,7 @@ dash-bootstrap-templates >= 1.3.0
 dash_ag_grid
 tabulate >= 0.9.0
 shap>=0.43.0
-rdkit>=2024.3.2
+rdkit>=2024.9.5
 mordredcommunity>=2.0.6
 networkx>=3.2
 matplotlib>=3.9.2
Empty file added model_docker_images/Readme.md
Empty file.
27 changes: 27 additions & 0 deletions model_docker_images/inference/Dockerfile
@@ -0,0 +1,27 @@
FROM python:3.12-slim

# Install Vim
RUN apt-get update && apt-get install -y vim

# Copy requirements file
COPY requirements.txt /tmp/

# Install dependencies
RUN pip install --no-cache-dir -r /tmp/requirements.txt

# Add the serve script
COPY serve /usr/local/bin/
RUN chmod +x /usr/local/bin/serve

# Copy the main.py/entrypoint script
COPY main.py /opt/program/
WORKDIR /opt/program

# Make port 8080 available for the web server
EXPOSE 8080

# Define environment variable
ENV PYTHONUNBUFFERED=TRUE

# SageMaker will look for this
CMD ["serve"]
150 changes: 150 additions & 0 deletions model_docker_images/inference/main.py
@@ -0,0 +1,150 @@
from fastapi import FastAPI, Request, Response
from contextlib import asynccontextmanager
import os
import sys
import json
import importlib.util
import logging
import subprocess
import site

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
model = None
inference_module = None


def get_inference_script(model_dir: str) -> str:
"""Retrieve the inference script name

Args:
model_dir (str): The directory containing the model artifacts

Returns:
str: The name of the inference script
"""

# Get the path to the inference-metadata.json file
inference_meta_path = os.path.join(model_dir, "inference-metadata.json")
with open(inference_meta_path, "r") as f:
config = json.load(f)
return config["inference_script"]


def install_requirements(requirements_path):
"""Install Python dependencies from requirements file.
Uses a persistent cache to speed up container cold starts.
Note: Inference containers don't have root access, so we
use the --user flag and add the user package path manually.
"""
if os.path.exists(requirements_path):
logger.info(f"Installing dependencies from {requirements_path}...")

# Define a persistent cache location
pip_cache_dir = "/opt/ml/model/.cache/pip"
os.environ["PIP_CACHE_DIR"] = pip_cache_dir

try:
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"--cache-dir",
pip_cache_dir, # Enable caching
"--disable-pip-version-check",
"--no-warn-script-location",
"--user",
"-r",
requirements_path,
]
)
# Ensure Python can find user-installed packages
sys.path.append(site.getusersitepackages())
logger.info("Requirements installed successfully.")
except subprocess.CalledProcessError as e:
logger.error(f"Error installing requirements: {e}")
sys.exit(1)
else:
logger.info(f"No requirements file found at {requirements_path}")


@asynccontextmanager
async def lifespan(app: FastAPI):
"""Handle model loading on startup and cleanup on shutdown."""
global model, inference_module

# Note: SageMaker will put model.tar.gz in /opt/ml/model
# which includes the model artifacts and inference code
model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
inference_script = get_inference_script(model_dir)

# List directory contents for debugging
logger.info(f"Contents of {model_dir}: {os.listdir(model_dir)}")

try:
# Load the inference script from source_dir
inference_script_path = os.path.join(model_dir, inference_script)
if not os.path.exists(inference_script_path):
raise FileNotFoundError(f"Inference script not found: {inference_script_path}")

# Install requirements if present
install_requirements(os.path.join(model_dir, "requirements.txt"))

# Ensure the model directory is in the Python path
sys.path.insert(0, model_dir)

# Import the inference module
logger.info(f"Importing inference module from {inference_script_path}")
spec = importlib.util.spec_from_file_location("inference_module", inference_script_path)
inference_module = importlib.util.module_from_spec(spec)
sys.modules["inference_module"] = inference_module
spec.loader.exec_module(inference_module)

# Check if model_fn is defined
if not hasattr(inference_module, "model_fn"):
raise ImportError(f"Inference module {inference_script_path} does not define model_fn")

# Load the model using model_fn
logger.info("Calling model_fn to load the model")
model = inference_module.model_fn(model_dir)
logger.info(f"Model loaded successfully: {type(model)}")

except Exception as e:
logger.error(f"Error initializing model: {e}", exc_info=True)
raise

yield

logger.info("Shutting down model server")


app = FastAPI(lifespan=lifespan)


@app.get("/ping")
def ping():
"""Health check endpoint for SageMaker."""
# Check if the inference module is loaded
return Response(status_code=200 if inference_module else 500)


@app.post("/invocations")
async def invoke(request: Request):
"""Inference endpoint for SageMaker."""
content_type = request.headers.get("Content-Type", "")
accept_type = request.headers.get("Accept", "")

try:
body = await request.body()
data = inference_module.input_fn(body, content_type)
result = inference_module.predict_fn(data, model)
output_data, output_content_type = inference_module.output_fn(result, accept_type)
return Response(content=output_data, media_type=output_content_type)
except Exception as e:
logger.error(f"Error during inference: {e}", exc_info=True)
return Response(content=json.dumps({"error": str(e)}), status_code=500, media_type="application/json")
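
Note: main.py defines the contract an inference script must satisfy — model_fn, input_fn, predict_fn, and output_fn — with the script name supplied by inference-metadata.json. A minimal sketch of a conforming script (the file name and model artifact name are assumptions, not part of this PR):

# inference.py -- minimal sketch of the contract main.py expects.
# Assumes inference-metadata.json contains: {"inference_script": "inference.py"}
import os
from io import StringIO

import joblib
import pandas as pd


def model_fn(model_dir):
    # Load the model artifact ("model.joblib" is an assumed name)
    return joblib.load(os.path.join(model_dir, "model.joblib"))


def input_fn(request_body, content_type):
    # request_body arrives as raw bytes from the /invocations endpoint
    if content_type == "text/csv":
        return pd.read_csv(StringIO(request_body.decode("utf-8")))
    raise ValueError(f"Unsupported content type: {content_type}")


def predict_fn(df, model):
    # Run the loaded model on the deserialized input
    df["prediction"] = model.predict(df)
    return df


def output_fn(df, accept_type):
    # Must return (payload, content_type), matching the call in main.py
    return df.to_csv(index=False), "text/csv"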
7 changes: 7 additions & 0 deletions model_docker_images/inference/requirements.txt
@@ -0,0 +1,7 @@
fastapi==0.115.10
uvicorn==0.34.0
scikit-learn==1.6.1
xgboost-cpu==2.1.4
pandas==2.2.3
awswrangler==3.11.0
joblib==1.4.2
6 changes: 6 additions & 0 deletions model_docker_images/inference/serve
@@ -0,0 +1,6 @@
#!/bin/bash

# SageMaker expects a 'serve' script in the container, which starts the model server.

# Start the FastAPI server using Uvicorn
exec uvicorn main:app --host 0.0.0.0 --port 8080
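
Note: for quick iteration outside the container, the same server can be started directly — a sketch, assuming the inference image's Python dependencies are installed locally and a populated model directory is at hand:

SM_MODEL_DIR=$PWD/model uvicorn main:app --host 0.0.0.0 --port 8080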
152 changes: 152 additions & 0 deletions model_docker_images/scripts/build_deploy.sh
@@ -0,0 +1,152 @@
#!/bin/bash
set -e

# Get the directory of this script
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)"
# Get the parent directory (project root)
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"

# AWS Account ID
AWS_ACCOUNT_ID="507740646243"

# Define repository names - used for both local and ECR images
TRAINING_REPO="aws-ml-images/py312-sklearn-xgb-training"
INFERENCE_REPO="aws-ml-images/py312-sklearn-xgb-inference"

# Local directories
TRAINING_DIR="$PROJECT_ROOT/training"
INFERENCE_DIR="$PROJECT_ROOT/inference"

# Image version
IMAGE_VERSION=${1:-"0.1"}

# Expect AWS_PROFILE to be set in the environment when deploying
# (match --deploy anywhere in the argument list, not just $2)
if [[ " $* " == *" --deploy "* ]]; then
    : "${AWS_PROFILE:?AWS_PROFILE environment variable is not set.}"
fi

# Define the regions to deploy to.
REGION_LIST=("us-east-1" "us-west-2")

# Colors for output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Parse arguments
DEPLOY=false
LATEST=false
for arg in "$@"; do
case $arg in
--deploy)
DEPLOY=true
;;
--latest)
LATEST=true
;;
*)
;;
esac
done

# Function to build a Docker image
build_image() {
local dir=$1
local repo_name=$2
local tag=$3
local full_name="${repo_name}:${tag}"

echo -e "${YELLOW}Building image: ${full_name}${NC}"

# Check if Dockerfile exists
if [ ! -f "$dir/Dockerfile" ]; then
echo "❌ Error: Dockerfile not found in $dir"
return 1
fi

# Build the image for AMD64 architecture
echo "Building local Docker image ${full_name} for linux/amd64..."
docker build --platform linux/amd64 -t $full_name $dir

echo -e "${GREEN}✅ Successfully built: ${full_name}${NC}"
return 0
}

# Function to deploy an image to ECR
deploy_image() {
local repo_name=$1
local tag=$2
local use_latest=$3
local full_name="${repo_name}:${tag}"

for REGION in "${REGION_LIST[@]}"; do
echo "Processing region: ${REGION}"
# Construct the ECR repository URL
ECR_REPO="${AWS_ACCOUNT_ID}.dkr.ecr.${REGION}.amazonaws.com/${repo_name}"
AWS_ECR_IMAGE="${ECR_REPO}:${tag}"

echo "Logging in to AWS ECR in ${REGION}..."
aws ecr get-login-password --region ${REGION} --profile ${AWS_PROFILE} | \
docker login --username AWS --password-stdin "${AWS_ACCOUNT_ID}.dkr.ecr.${REGION}.amazonaws.com"

echo "Tagging image for AWS ECR as ${AWS_ECR_IMAGE}..."
docker tag ${full_name} ${AWS_ECR_IMAGE}

echo "Pushing Docker image to AWS ECR: ${AWS_ECR_IMAGE}..."
docker push ${AWS_ECR_IMAGE}

if [ "$use_latest" = true ]; then
AWS_ECR_LATEST="${ECR_REPO}:latest"
echo "Tagging AWS ECR image as latest: ${AWS_ECR_LATEST}..."
docker tag ${full_name} ${AWS_ECR_LATEST}
echo "Pushing Docker image to AWS ECR: ${AWS_ECR_LATEST}..."
docker push ${AWS_ECR_LATEST}
fi
done
}

# Build training image
echo "======================================"
echo "🏗️ Building training container"
echo "======================================"
build_image "$TRAINING_DIR" "$TRAINING_REPO" "$IMAGE_VERSION"

# Build inference image
echo "======================================"
echo "🏗️ Building inference container"
echo "======================================"
build_image "$INFERENCE_DIR" "$INFERENCE_REPO" "$IMAGE_VERSION"

echo "======================================"
echo -e "${GREEN}✅ All builds completed successfully!${NC}"
echo "======================================"

if [ "$DEPLOY" = true ]; then
echo "======================================"
echo "🚀 Deploying containers to ECR"
echo "======================================"

# Deploy training image
echo "Deploying training image..."
deploy_image "$TRAINING_REPO" "$IMAGE_VERSION" "$LATEST"

# Deploy inference image
echo "Deploying inference image..."
deploy_image "$INFERENCE_REPO" "$IMAGE_VERSION" "$LATEST"

echo "======================================"
echo -e "${GREEN}✅ Deployment complete!${NC}"
echo "======================================"
else
echo "Local build complete. Use --deploy to push the images to AWS ECR in regions: ${REGION_LIST[*]}."

# Print information about the built images
echo "======================================"
echo "📋 Image information:"
echo "Training image: ${TRAINING_REPO}:${IMAGE_VERSION}"
echo "Inference image: ${INFERENCE_REPO}:${IMAGE_VERSION}"
echo "======================================"

# Inform about testing option
echo "To test these containers, run: $PROJECT_ROOT/tests/run_tests.sh ${IMAGE_VERSION}"
fi
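
Note: typical invocations would presumably look like the following (the profile name is a placeholder):

# Build both images locally with the default 0.1 tag
./scripts/build_deploy.sh

# Build version 0.2, push to us-east-1 and us-west-2, and also tag as :latest
AWS_PROFILE=my-profile ./scripts/build_deploy.sh 0.2 --deploy --latest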