Commit dd8c545

Merge shap project into databricks demo and fixes (#161)
* add shap to databricks demo
* update artifact config syntax
* one version of the readme
* another version
* update readme
* mention shap too
* rm shap project and rename databricks project
* fix older syntax
* use uv
* disable cache
1 parent 3268ae9 commit dd8c545

File tree

65 files changed: +270 −920 lines changed

airflow-cloud-composer-etl-feature-train/steps/training/model_trainer.py (+2 −1)

```diff
@@ -21,6 +21,7 @@
 from materializers import BigQueryDataset, CSVDataset
 from typing_extensions import Annotated
 from zenml import ArtifactConfig, step
+from zenml.enums import ArtifactType
 from zenml.logger import get_logger

 logger = get_logger(__name__)
@@ -31,7 +32,7 @@ def train_xgboost_model(
     dataset: Union[BigQueryDataset, CSVDataset],
 ) -> Tuple[
     Annotated[
-        xgb.Booster, ArtifactConfig(name="xgb_model", is_model_artifact=True)
+        xgb.Booster, ArtifactConfig(name="xgb_model", artifact_type=ArtifactType.MODEL)
     ],
     Annotated[Dict[str, float], "metrics"],
 ]:
```

customer-satisfaction/steps/train_model.py (+2 −1)

```diff
@@ -6,6 +6,7 @@
 from model.model_dev import ModelTrainer
 from sklearn.base import RegressorMixin
 from zenml import ArtifactConfig, step
+from zenml.enums import ArtifactType
 from zenml.client import Client

 experiment_tracker = Client().active_stack.experiment_tracker
@@ -21,7 +22,7 @@ def train_model(
     do_fine_tuning: bool = True,
 ) -> Annotated[
     RegressorMixin,
-    ArtifactConfig(name="sklearn_regressor", is_model_artifact=True),
+    ArtifactConfig(name="sklearn_regressor", artifact_type=ArtifactType.MODEL),
 ]:
     """
     Args:
```

databricks-demo/README.md (−156)

This file was deleted.

File renamed without changes.
File renamed without changes.

New file (+154 lines):
# Databricks + ZenML: End-to-End Explainable ML Project

Welcome to this end-to-end demo project that showcases how to train, deploy, and run batch inference on a machine learning model using ZenML in a Databricks environment. This setup demonstrates how ZenML can simplify the process of building reproducible, production-grade ML pipelines with minimal fuss.

## Overview

This project uses an example classification dataset (Breast Cancer) and provides three major pipelines:

1. Training Pipeline
2. Deployment Pipeline
3. Batch Inference Pipeline (with SHAP-based model explainability)

The pipelines are orchestrated via ZenML. Additionally, this setup uses:

- Databricks as the orchestrator
- MLflow for experiment tracking and model registry
- Evidently for data drift detection
- SHAP for model explainability during inference
- Slack notifications (configurable through ZenML's alerter stack components)

## Why ZenML?

ZenML is a lightweight MLOps framework for reproducible pipelines. With ZenML, you get:

- A consistent, standardized way to develop, version, and share pipelines.
- Easy integration with various cloud providers, experiment trackers, model registries, and more.
- Reproducibility and better collaboration: your pipelines and associated artifacts are automatically tracked and versioned.
- A simple command-line interface for spinning pipelines up and down with different stack components (like local or Databricks orchestrators).
- Built-in best practices for production ML, including quality gates for data drift and model performance thresholds.

## Project Structure

Here's an outline of the repository:

```
.
├── configs                    # Pipeline configuration files
│   ├── deployer_config.yaml   # Deployment pipeline config
│   ├── inference_config.yaml  # Batch inference pipeline config
│   └── train_config.yaml      # Training pipeline config
├── pipelines                  # ZenML pipeline definitions
│   ├── batch_inference.py     # Orchestrates batch inference
│   ├── deployment.py          # Deploys a model service
│   └── training.py            # Trains and promotes model
├── steps                      # ZenML steps logic
│   ├── alerts                 # Alert/notification logic
│   ├── data_quality           # Data drift and quality checks
│   ├── deployment             # Deployment step
│   ├── etl                    # ETL steps (data loading, preprocessing, splitting)
│   ├── explainability         # SHAP-based model explanations
│   ├── hp_tuning              # Hyperparameter tuning pipeline steps
│   ├── inference              # Batch inference step
│   ├── promotion              # Model promotion logic
│   └── training               # Model training and evaluation steps
├── utils                      # Helper modules
├── Makefile                   # Quick integration setup commands
├── requirements.txt           # Python dependencies
├── run.py                     # CLI to run pipelines
└── README.md                  # This file
```

## Getting Started

1. (Optional) Create and activate a Python virtual environment:
   ```bash
   python3 -m venv .venv
   source .venv/bin/activate
   ```
2. Install dependencies:
   ```bash
   make setup
   ```
   This installs the required ZenML integrations (MLflow, Slack, Evidently, Kubeflow, Kubernetes, AWS, etc.) and any library dependencies.
3. (Optional) Set up a local stack (if you want to try this outside Databricks):
   ```bash
   make install-stack-local
   ```
4. If you have Databricks properly configured in your ZenML stack (with the Databricks token secret set up, cluster name, etc.), the pipelines are orchestrated on Databricks by default.

## Running the Project

All pipeline runs happen via the CLI in `run.py`. Here are the main options:

- View available options:
  ```bash
  python run.py --help
  ```

- Run everything (train, deploy, inference) with default settings:
  ```bash
  python run.py --training --deployment --inference
  ```
  This will:
  1. Train a model and evaluate its performance
  2. Deploy the model if it meets quality criteria
  3. Run batch inference with SHAP explanations and data drift checks
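
Conceptually, the three stages chain together with a quality gate between training and deployment. The sketch below is a framework-free illustration with placeholder names and values; the project's real logic lives in its ZenML pipelines, not here:

```python
# Hypothetical, framework-free sketch of the train -> deploy -> infer flow.
# All names and values are placeholders, not the project's actual code.

def train(data):
    """Stand-in for the training pipeline: fit and score a model."""
    model = {"weights": [0.5, 0.5]}  # placeholder "model"
    accuracy = 0.92                  # placeholder evaluation score
    return model, accuracy

def deploy(model, accuracy, min_accuracy=0.8):
    """Stand-in for the deployment pipeline: gate on quality criteria."""
    if accuracy < min_accuracy:
        raise ValueError(f"accuracy {accuracy} below gate {min_accuracy}")
    return {"endpoint": "local://model-service", "model": model}

def batch_inference(service, rows):
    """Stand-in for batch inference: predict and attach explanations."""
    return [{"prediction": 1, "explanation": "per-feature attributions"} for _ in rows]

model, acc = train(data=None)
service = deploy(model, acc)
predictions = batch_inference(service, rows=[{}, {}, {}])
```

The point of the gate is that a model which fails its accuracy criteria never reaches deployment, and therefore never reaches inference.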

- Run just the training pipeline (to build or update a model):
  ```bash
  python run.py --training
  ```

- Run just the deployment pipeline (to deploy the latest staged model):
  ```bash
  python run.py --deployment
  ```

- Run just the batch inference pipeline (to generate predictions and explanations while checking for data drift):
  ```bash
  python run.py --inference
  ```

### Additional Command-Line Flags

- Disable caching:
  ```bash
  python run.py --no-cache --training
  ```

- Skip dropping NA values or skip normalization:
  ```bash
  python run.py --no-drop-na --no-normalize --training
  ```

- Drop specific columns:
  ```bash
  python run.py --training --drop-columns columnA,columnB
  ```
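
The CLI implementation inside `run.py` is not shown in this commit, but a comma-separated flag like `--drop-columns` is typically split into a list at parse time. A minimal illustrative sketch (argparse here purely for illustration; flag wiring is an assumption):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Illustrative parser only; the real run.py defines its own flags.
    parser = argparse.ArgumentParser(description="Run project pipelines")
    parser.add_argument("--training", action="store_true")
    parser.add_argument(
        "--drop-columns",
        type=lambda s: [c.strip() for c in s.split(",") if c.strip()],
        default=[],
        help="Comma-separated column names to drop before training",
    )
    return parser

args = build_parser().parse_args(["--training", "--drop-columns", "columnA,columnB"])
print(args.drop_columns)  # → ['columnA', 'columnB']
```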

- Set minimum accuracy thresholds for the training and test sets:
  ```bash
  python run.py --min-train-accuracy 0.9 --min-test-accuracy 0.8 --fail-on-accuracy-quality-gates --training
  ```
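
The actual gate lives in the project's promotion/evaluation steps. As a rough, framework-free sketch of what such an accuracy quality gate does (function and argument names here are hypothetical):

```python
def check_accuracy_gates(
    train_acc: float,
    test_acc: float,
    min_train: float = 0.9,
    min_test: float = 0.8,
    fail_on_gates: bool = True,
) -> bool:
    """Return True when both accuracy gates pass.

    With fail_on_gates=True a failing gate raises instead of merely
    reporting, mirroring a --fail-on-accuracy-quality-gates style flag.
    """
    passed = train_acc >= min_train and test_acc >= min_test
    if not passed and fail_on_gates:
        raise RuntimeError(
            f"quality gate failed: train={train_acc}, test={test_acc}"
        )
    return passed

print(check_accuracy_gates(0.95, 0.85))  # → True
```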

When you run any of these commands, ZenML will orchestrate each pipeline on the active stack (Databricks if configured) and log the results in your model registry (MLflow). If you have Slack or other alerter components configured, you'll see success/failure notifications.

## Observing Your Pipelines

ZenML offers a local dashboard that you can launch with:
```bash
zenml up
```
Check the terminal logs for the local web address (usually http://127.0.0.1:8237). You'll see pipeline runs, steps, and artifacts.

If you deployed on Databricks, you can also see the runs orchestrated in the Databricks jobs UI. The project is flexible enough to run the same pipelines locally or in the cloud without changing the code.

## Contributing & License

Contributions and suggestions are welcome. This project is licensed under the Apache License 2.0.

For questions, feedback, or support, please reach out to the ZenML community or open an issue in this repository.

---

databricks-demo/configs/train_config.yaml renamed to databricks-production-qa-demo/configs/train_config.yaml (+1)

```diff
@@ -24,6 +24,7 @@ settings:
       - mlflow
       - sklearn
       - databricks
+    python_package_installer: "uv"
   orchestrator.databricks:
     cluster_name: adas_
     node_type_id: Standard_D8ads_v5
```
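
Reconstructed from the hunk above, the relevant portion of the renamed `train_config.yaml` now looks roughly like this (the `docker`/`requirements` nesting is an assumption based on typical ZenML pipeline configs; only the lines shown in the diff are certain):

```yaml
settings:
  docker:            # assumed parent key, not visible in the diff
    requirements:
      - mlflow
      - sklearn
      - databricks
    python_package_installer: "uv"   # install packages with uv instead of pip
  orchestrator.databricks:
    cluster_name: adas_
    node_type_id: Standard_D8ads_v5
```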

databricks-demo/pipelines/batch_inference.py renamed to databricks-production-qa-demo/pipelines/batch_inference.py (+6)

```diff
@@ -29,6 +29,8 @@
 from zenml.integrations.evidently.steps import evidently_report_step
 from zenml.logger import get_logger

+from steps.explainability import explain_model
+
 logger = get_logger(__name__)


@@ -53,6 +55,10 @@ def production_line_qa_batch_inference():
         preprocess_pipeline=model.get_artifact("preprocess_pipeline"),
         target=target,
     )
+
+    ########## Model Explainability stage ##########
+    explain_model(df_inference)
+
     ########## DataQuality stage ##########
     report, _ = evidently_report_step(
         reference_dataset=model.get_artifact("dataset_trn"),
```
File renamed without changes.

New file (+1 line):

```diff
@@ -0,0 +1 @@
+from .shap_explainer import explain_model
```
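
The `explain_model` step itself (a `shap_explainer` module under the explainability steps) is not shown in this commit; it wraps SHAP. As a rough, dependency-free illustration of the kind of per-feature attribution such a step produces (everything below is a hypothetical stand-in, not the project's code):

```python
def perturbation_importance(predict, rows, n_features):
    """Crude per-feature attribution: cyclically shift one feature's
    column and measure the mean absolute change in predictions.
    SHAP computes a principled, axiomatic version of this idea."""
    n = len(rows)
    scores = []
    for j in range(n_features):
        col = [r[j] for r in rows]
        total = 0.0
        for i, row in enumerate(rows):
            perturbed = row[:j] + [col[(i + 1) % n]] + row[j + 1:]
            total += abs(predict(perturbed) - predict(row))
        scores.append(total / n)
    return scores

# Toy model that depends only on feature 0, so only feature 0 matters.
predict = lambda row: 2.0 * row[0]
rows = [[1.0, 5.0], [2.0, 5.0], [3.0, 5.0], [4.0, 5.0]]
scores = perturbation_importance(predict, rows, n_features=2)
print(scores)  # → [3.0, 0.0]
```

As expected, perturbing the unused second feature changes nothing, while perturbing the first produces a nonzero attribution.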
