Merge pull request #26 from sparks-baird/notebooks
refactor .py examples into basic-usage notebook
Showing 4 changed files with 230 additions and 51 deletions.
@@ -0,0 +1,230 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Materials Project time splits for materials generative benchmarking"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we will install the `mp_time_split` package and run through the following examples:\n",
"1. accessing the cross-validation folds and final train/test split\n",
"2. \"fitting\" a DummyGenerator model and comparing to validation data\n",
"3. evaluating cross-validated model accuracy\n",
"4. hyperparameter optimization of generator statistic(s)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We install the optional dependency `pyxtal` to generate random crystal structures within the `DummyGenerator` model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install mp_time_split[pyxtal]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Access the data and the data splits"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the `MPTimeSplit` class as the main interface with the benchmark dataset in\n",
"each of the examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mp_time_split.core import MPTimeSplit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use the default `\"TimeSeriesSplit\"` cross-validation splitting scheme. \n", | ||
"\n", | ||
"We specify `\"energy_above_hull\"` as the target which is surfaced in the `train_outputs`,\n", | ||
"`val_outputs`, and `test_outputs` `Series`-s. The target variable is excluded from the\n", | ||
"corresponding `_inputs` variables, i.e. `train_inputs`, `val_inputs`, and `test_inputs`\n", | ||
"to prevent data leakage during conditional generation, regression/classification, and\n", | ||
"hyperparameter optimization." | ||
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mpt = MPTimeSplit(mode=\"TimeSeriesSplit\", target=\"energy_above_hull\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We load the full snapshot dataset (~30 MB compressed). To load and work with a much smaller dummy\n",
"dataset (~10 kB), set `dummy=True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mpt.load(dummy=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similar to Matbench, we loop through each of the folds of the train and validation\n", | ||
"splits and can also access the final train/test split. We use the default \"" | ||
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for fold in mpt.folds:\n",
"    train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(\n",
"        fold\n",
"    )\n",
"final_inputs, test_inputs, final_outputs, test_outputs = mpt.get_test_data()\n"
]
},
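{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the outputs are returned as pandas `Series` (see above), we can take a quick look\n",
"at the final train/test split as a sanity check. (This inspection cell is a minimal\n",
"sketch and assumes the `final_outputs`/`test_outputs` objects follow the `Series`\n",
"convention described earlier.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sizes of the final train and held-out test splits (assumes pandas Series, see above)\n",
"print(len(final_outputs), len(test_outputs))\n",
"\n",
"# standard Series methods apply, e.g. peek at the first few target values\n",
"final_outputs.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \"Fit\" a DummyGenerator model and compare to validation data\n",
"\n",
"Next, we \"fit\" the `DummyGenerator` model (which generates random crystal structures\n",
"via `pyxtal`) to each fold's training structures and generate candidate structures to\n",
"compare against the validation data. The comparison itself is left as a placeholder."
]
},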
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mp_time_split.core import MPTimeSplit\n",
"from mp_time_split.utils.gen import DummyGenerator\n",
"\n",
"mpt = MPTimeSplit(target=\"energy_above_hull\")\n",
"mpt.load(dummy=True)\n",
"\n",
"for fold in mpt.folds:\n",
"    train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(\n",
"        fold\n",
"    )\n",
"    # \"fit\" the dummy generator to this fold's training structures\n",
"    dg = DummyGenerator()\n",
"    dg.fit(train_inputs)\n",
"    # generate 100 random candidate structures\n",
"    generated_structures = dg.gen(n=100)\n",
"    # a real benchmark would compare generated_structures against val_inputs here\n"
]
},
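{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate cross-validated model accuracy\n",
"\n",
"For regression-style benchmarking, we fit a baseline `DummyRegressor` (which always\n",
"predicts the mean of the training targets) on each fold and report the mean absolute\n",
"error (MAE) averaged across folds."
]
},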
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyRegressor\n",
"from sklearn.metrics import mean_absolute_error\n",
"\n",
"from mp_time_split.core import MPTimeSplit\n",
"\n",
"# baseline regressor that always predicts the mean of the training targets\n",
"model = DummyRegressor(strategy=\"mean\")\n",
"\n",
"mpt = MPTimeSplit(target=\"energy_above_hull\")\n",
"mpt.load(dummy=False)\n",
"\n",
"maes = []\n",
"for fold in mpt.folds:\n",
"    train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(\n",
"        fold\n",
"    )\n",
"    model.fit(train_inputs, train_outputs)\n",
"    predictions = model.predict(val_inputs)\n",
"    mae = mean_absolute_error(val_outputs, predictions)\n",
"    maes.append(mae)\n",
"\n",
"# average MAE across the cross-validation folds\n",
"np.mean(maes)\n"
]
},
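{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter optimization of generator statistic(s)\n",
"\n",
"Finally, we wrap the fold loop in a `fit_and_evaluate` function that takes a\n",
"`parameterization` dictionary of generator hyperparameters and returns a single\n",
"averaged metric. The `compare` function below is only a placeholder (it returns a\n",
"random number); in practice it would be replaced with a real generative benchmarking\n",
"metric, and the returned average could be fed to the hyperparameter optimizer of your\n",
"choice."
]
},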
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from mp_time_split.core import MPTimeSplit\n",
"from mp_time_split.utils.gen import DummyGenerator\n",
"\n",
"mpt = MPTimeSplit(target=\"energy_above_hull\")\n",
"mpt.load(dummy=True)\n",
"\n",
"\n",
"def compare(inputs, gen_inputs):\n",
"    # placeholder metric: replace with a real comparison between the validation\n",
"    # structures and the generated structures\n",
"    return np.random.rand()\n",
"\n",
"\n",
"def fit_and_evaluate(parameterization):\n",
"    metrics = []\n",
"    for fold in mpt.folds:\n",
"        train_inputs, val_inputs, _, _ = mpt.get_train_and_val_data(fold)\n",
"        dg = DummyGenerator(**parameterization)\n",
"        dg.fit(train_inputs)\n",
"        generated_structures = dg.gen(n=100)\n",
"        # score the generated structures against the validation structures\n",
"        metric = compare(val_inputs, generated_structures)\n",
"        metrics.append(metric)\n",
"    avg_metric = np.mean(metrics)\n",
"    return avg_metric\n",
"\n",
"\n",
"# an empty parameterization uses the DummyGenerator defaults\n",
"parameterization = {}\n",
"\n",
"fit_and_evaluate(parameterization)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
The other three changed files (the .py examples refactored into this notebook) were deleted; their contents are not shown.