A federated reinforcement learning pipeline for optimal ventilator management using the CLIF (Common Longitudinal ICU Format) data standard.
This pipeline implements a Double Deep Q-Network (DDQN) with Conservative Q-Learning (CQL) for learning optimal ventilator management policies from multi-site ICU data.
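As background, the combined objective can be sketched in a few lines of dependency-free Python: a double-DQN TD target (the online network selects the next action, the target network evaluates it) plus the CQL penalty (log-sum-exp of the Q-values over all actions minus the Q-value of the logged action). This is an illustrative toy, not the code in `code/training/cql_dqn_trainer.py`; the function names, the three-action transition, and the scalar Q-values are all hypothetical.

```python
import math

def cql_penalty(q_values, data_action):
    """Conservative Q-Learning penalty: logsumexp over all actions minus
    the Q-value of the action actually taken in the dataset. This pushes
    down over-estimated Q-values for out-of-distribution actions."""
    m = max(q_values)
    logsumexp = m + math.log(sum(math.exp(q - m) for q in q_values))
    return logsumexp - q_values[data_action]

def ddqn_target(reward, next_q_online, next_q_target, gamma=0.99, done=False):
    """Double-DQN target: the online network picks the argmax action,
    the target network evaluates it."""
    if done:
        return reward
    best_action = max(range(len(next_q_online)), key=lambda a: next_q_online[a])
    return reward + gamma * next_q_target[best_action]

# Toy transition with 3 discrete ventilator-setting actions (hypothetical)
q = [1.0, 2.0, 0.5]  # Q(s, .) from the online network
target = ddqn_target(reward=1.0, next_q_online=[0.2, 0.9, 0.1],
                     next_q_target=[0.3, 0.8, 0.2])
td_loss = (q[1] - target) ** 2                          # squared TD error for logged action a=1
loss = td_loss + 0.02 * cql_penalty(q, data_action=1)   # cql_alpha = 0.02, as in the training command
```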
- Python 3.9+
- CLIF-formatted ICU data tables
- uv package manager (https://github.com/astral-sh/uv)
```bash
# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh

# Clone the repository
git clone <repository-url>
cd CLIF-RL

# Install dependencies using uv
uv pip install -e .
```

- Copy the configuration template:

```bash
cp config/config_template.json config/config.json
```

- Edit `config/config.json` with your site's information:
```json
{
  "site_id": "yoursite",
  "site_name": "Your Site Name",
  "tables_path": "/path/to/your/CLIF/tables",
  "file_type": "parquet",
  "timezone": "US/Eastern",
  "data_path": "model_data/train.parquet",
  "standardization_path": "shared/state_standardization.pkl",
  "train_ratio": 0.8,
  "random_seed": 42
}
```

The easiest way to run the entire Phase 1 pipeline is to use the provided script:
```bash
# Make sure the script is executable
chmod +x run_phase1.sh

# Run with default settings
./run_phase1.sh
```

Each site performs the following steps:
- Data Preparation

```bash
# Create cohort
uv run python code/00_cohort.py --config config/config.json

# Create analysis dataset
uv run python code/01_analysis_data.py --config config/config.json

# Filter IMV episodes
uv run python code/02_filter_imv_episodes.py --config config/config.json

# Calculate SOFA scores
uv run python code/03_sofa_calculator.py --config config/config.json

# Prepare training data
uv run python code/04_prepare_training_data.py --config config/config.json
```

- Generate Table One

```bash
uv run python code/05_create_tableone.py --config config/config.json
```

- Train Local Model

```bash
uv run python code/06_site_training.py \
    --config config/config.json \
    --round 0 \
    --epochs 50 \
    --use_cql \
    --cql_alpha 0.02
```

- Local Validation
```bash
# Simply run the evaluation - it will automatically find your model
uv run python code/09_visualize_training.py --config config/config.json
```

- MIMIC Baseline Evaluation

Each site can evaluate how well the MIMIC-trained model generalizes to their local data.
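One simple generalization check is the agreement rate between the MIMIC model's greedy actions and the actions local clinicians actually took. The sketch below illustrates the idea with hypothetical Q-values; it is not necessarily the metric `code/13_evaluate_mimic_baseline.py` reports, and both function names are invented for illustration.

```python
def greedy_actions(q_table):
    """Greedy action per state, given one row of Q-values per transition."""
    return [max(range(len(row)), key=lambda a: row[a]) for row in q_table]

def agreement_rate(policy_actions, clinician_actions):
    """Fraction of transitions where the model's greedy action matches
    the action the local clinician logged."""
    matches = sum(p == c for p, c in zip(policy_actions, clinician_actions))
    return matches / len(clinician_actions)

# Hypothetical Q-values from the MIMIC model on 4 local transitions
q_table = [[0.1, 0.7], [0.9, 0.2], [0.3, 0.4], [0.6, 0.5]]
clinician = [1, 0, 0, 0]
rate = agreement_rate(greedy_actions(q_table), clinician)  # 3 of 4 agree
```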
```bash
# Evaluate MIMIC model on your site's data
python code/13_evaluate_mimic_baseline.py --config config/config.json
```

- Share Results

Upload the entire `output/upload_to_box/phase1/` folder to the coordinating center:

```bash
# Package Phase 1 results for sharing
tar -czf phase1_results_{site_name}.tar.gz output/upload_to_box/phase1/
```

After collecting all site weights for a round:
```bash
python code/07_federated_aggregation.py \
    --weights_dir ./collected_weights \
    --round 0 \
    --sites mimic ucmc nu rush \
    --output ./shared/global_round_0.pt
```

After receiving the global model from the coordinating center:
```bash
# Make the script executable
chmod +x run_phase2.sh

# Run complete Phase 2 pipeline
./run_phase2.sh --site ucmc --round 1
```

- Download Global Model

Place `global_round_0.pt` in the `shared/` directory.

- Train with Global Model
```bash
python code/06_site_training.py \
    --config config/config.json \
    --round 1 \
    --epochs 25 \
    --global_model shared/global_round_0.pt \
    --use_cql
```

- Evaluate Phase 2 Model
```bash
python code/09_visualize_training.py \
    --config config/config.json \
    --phase phase2 \
    --round 1
```

- Compare Your Phase 1 vs. Phase 2 (Single-Site Comparison)
```bash
# Use the dedicated single-site comparison script
python code/compare_phase1_vs_phase2.py \
    --site ucmc \
    --config config/config.json
```

- Package and Share Results

```bash
tar -czf phase2_round1_{site_name}.tar.gz \
    all_site_data/{site_name}/upload_to_box/phase2/round_1/
```

After collecting Phase 2 results from all sites:
- Multi-Site Comparison (Script 10)

```bash
python code/10_compare_models.py \
    --config config/config.json \
    --round 1 \
    --sites mimic ucmc rush nu \
    --data_dir all_site_data \
    --compare_phases \
    --output_dir output/phase2_comparison
```

- Generate Aggregate Forest Plots (Script 11)

```bash
python code/11_aggregate_results.py \
    --results_dir all_site_data \
    --sites mimic ucmc rush nu \
    --phases phase1 phase2 \
    --output_dir output/final_results
```

- Run Complete Evaluation (Script 12)

```bash
python code/12_automated_evaluation.py \
    --phase phase2 \
    --round 1 \
    --results_dir all_site_data \
    --output_dir output/phase2_report
```

All outputs are written to the `upload_to_box` folder and can be shared with the coordinating center. This allows for:
- Cross-site comparison of practice variations
- Identification of site-specific patterns
- Baseline metrics for federated learning improvements
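The coordinating center's aggregation step (script 07) follows the FedAvg pattern: average each model parameter across sites. A minimal sketch, with plain Python lists standing in for weight tensors; the real implementation operates on PyTorch state dicts, and whether sites are weighted by sample count is an assumption here, shown via the optional `site_sizes` argument.

```python
def fedavg(site_weights, site_sizes=None):
    """Average per-parameter weights across sites (FedAvg).
    site_weights: {site_name: {param_name: [floats]}}
    site_sizes:   optional {site_name: n_samples} for weighted averaging."""
    sites = list(site_weights)
    if site_sizes is None:
        site_sizes = {s: 1 for s in sites}  # unweighted: every site counts equally
    total = sum(site_sizes[s] for s in sites)
    global_weights = {}
    for p in site_weights[sites[0]]:
        n = len(site_weights[sites[0]][p])
        global_weights[p] = [
            sum(site_weights[s][p][i] * site_sizes[s] for s in sites) / total
            for i in range(n)
        ]
    return global_weights

# Two hypothetical sites with a single 2-element parameter
local = {"ucmc": {"layer.w": [1.0, 2.0]}, "rush": {"layer.w": [3.0, 4.0]}}
global_w = fedavg(local)  # each entry is the mean across sites
```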
```
code/
├── 00_cohort.py                    # Cohort identification
├── 01_analysis_data.py             # Data analysis
├── 02_filter_imv_episodes.py       # IMV filtering
├── 03_sofa_calculator.py           # SOFA calculation
├── 04_prepare_training_data.py     # Training data preparation
├── 05_create_tableone.py           # Baseline characteristics
├── 06_site_training.py             # Site training
├── 07_federated_aggregation.py     # FedAvg aggregation
├── 09_visualize_training.py        # Model evaluation & action distribution plots
├── 10_compare_models.py            # Phase 1 vs Phase 2 statistical comparison
├── 11_aggregate_results.py         # Forest plots & multi-site visualization
├── 12_automated_evaluation.py      # Complete evaluation pipeline (runs all above)
├── 13_evaluate_mimic_baseline.py   # Cross-site model evaluation
├── models/
│   └── dueling_dqn.py              # Neural network
├── training/
│   ├── ddqn_base.py                # Base DDQN
│   └── cql_dqn_trainer.py          # CQL extension
└── utils.py                        # Utilities
```
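The name `models/dueling_dqn.py` suggests the standard dueling decomposition, which combines a state-value stream V(s) with an advantage stream A(s, a). A dependency-free sketch of just that combination step (an assumption about the architecture; layer sizes and the two streams themselves are not shown):

```python
def dueling_q(value, advantages):
    """Combine dueling streams into Q-values:
    Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a').
    Subtracting the mean advantage keeps V and A identifiable."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]

# Hypothetical streams for one state with 3 discrete actions
q = dueling_q(value=1.0, advantages=[0.5, -0.5, 0.0])  # [1.5, 0.5, 1.0]
```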