-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathapp.py
63 lines (51 loc) · 1.86 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env python3
import json
import logging
import os
from aws_cdk import core
from infra.model_registry import ModelRegistry
from infra.deployment_config import DeploymentConfig
from infra.sagemaker_stack import SageMakerStack
# Module-level logger. basicConfig installs a root handler at INFO level so
# the informational messages this script emits reach the build output.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")

# Project identity and target stage are supplied as environment variables by
# the CodeBuild job that the pipeline stack runs. Any missing variable raises
# KeyError, which aborts synthesis.
project_name, project_id, stage_name = (
    os.environ[env_var]
    for env_var in (
        "SAGEMAKER_PROJECT_NAME",
        "SAGEMAKER_PROJECT_ID",
        "STAGE_NAME",
    )
)
# Create the CDK app that every stack in this script is attached to.
app = core.App()

# Longest endpoint name accepted here; longer names are rejected at synth
# time instead of failing later during deployment.
MAX_ENDPOINT_NAME_LENGTH = 63

# The endpoint name is derived from the project name and deployment stage.
endpoint_name = f"sagemaker-{project_name}-{stage_name}"
if len(endpoint_name) > MAX_ENDPOINT_NAME_LENGTH:
    # ValueError subclasses Exception, so any caller catching Exception
    # still handles this; the message text is unchanged.
    raise ValueError(
        f"SageMaker endpoint: {endpoint_name} must be less than "
        f"{MAX_ENDPOINT_NAME_LENGTH + 1} characters"
    )
# Lazy %-style arguments: the message is only formatted if INFO is enabled.
logger.info("Create endpoint: %s", endpoint_name)
# SageMaker project tags identifying the stage, project id and project name.
# They are built in this fixed order from (key, value) pairs.
tags = [
    core.CfnTag(key=tag_key, value=tag_value)
    for tag_key, tag_value in (
        ("sagemaker:deployment-stage", stage_name),
        ("sagemaker:project-id", project_id),
        ("sagemaker:project-name", project_name),
    )
]
# Load the stage-specific deployment config ("<stage_name>-config.json").
# The path is relative, so it resolves against the current working directory.
# JSON is read as UTF-8 explicitly rather than relying on the platform's
# default locale encoding.
with open(f"{stage_name}-config.json", "r", encoding="utf-8") as f:
    j = json.load(f)
# The top-level JSON object is unpacked as keyword arguments, so its keys
# must match DeploymentConfig's accepted parameters. Built after the file
# is closed.
deployment_config = DeploymentConfig(**j)

# Append A/B testing settings from the deployment config as extra tags.
# Non-string config values are converted with str() because tag values are
# strings.
tags += [
    core.CfnTag(key="ab-testing:enabled", value="true"),
    core.CfnTag(key="ab-testing:strategy", value=deployment_config.strategy),
    core.CfnTag(key="ab-testing:epsilon", value=str(deployment_config.epsilon)),
    core.CfnTag(key="ab-testing:warmup", value=str(deployment_config.warmup)),
]
# Instantiate the SageMaker stack under construct id "ab-testing-sagemaker".
# It receives the project identity, the validated endpoint name, the
# stage's deployment config and the accumulated tags.
sagemaker = SageMakerStack(
    app,
    "ab-testing-sagemaker",
    project_name=project_name,
    project_id=project_id,
    endpoint_name=endpoint_name,
    deployment_config=deployment_config,
    tags=tags,
)

# Synthesize the CDK app so the CDK CLI can deploy the generated templates.
app.synth()