This repository was archived by the owner on Jun 17, 2023. It is now read-only.

Commit

Merge pull request #26 from sparks-baird/notebooks
refactor .py examples into basic-usage notebook
sgbaird authored Jun 5, 2022
2 parents 7f20fb7 + e1eef8b commit a390f2f
Showing 4 changed files with 230 additions and 51 deletions.
230 changes: 230 additions & 0 deletions notebooks/1.0-basic-usage.ipynb
@@ -0,0 +1,230 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Materials Project time splits for materials generative benchmarking"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we will install the `mp_time_split` package and run through the following examples:\n",
"1. accessing the cross-validation folds and final train/test split\n",
"2. \"fitting\" a DummyGenerator model and comparing to validation data\n",
"3. evaluating cross-validated model accuracy\n",
"4. hyperparameter optimization of generator statistic(s)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We install the optional dependency `pyxtal` to generate random crystal structures within the `DummyGenerator` model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install mp_time_split[pyxtal]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Access the data and the data splits"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the `MPTimeSplit` class as the main interface with the benchmark dataset in\n",
"each of the examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mp_time_split.core import MPTimeSplit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use the default `\"TimeSeriesSplit\"` cross-validation splitting scheme. \n",
"\n",
"We specify `\"energy_above_hull\"` as the target which is surfaced in the `train_outputs`,\n",
"`val_outputs`, and `test_outputs` `Series`-s. The target variable is excluded from the\n",
"corresponding `_inputs` variables, i.e. `train_inputs`, `val_inputs`, and `test_inputs`\n",
"to prevent data leakage during conditional generation, regression/classification, and\n",
"hyperparameter optimization."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mpt = MPTimeSplit(mode=\"TimeSeriesSplit\", target=\"energy_above_hull\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We load the full snapshot dataset (~30 MB compressed). To load and work with a much smaller dummy\n",
"dataset (~10 kB), set `dummy=True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mpt.load(dummy=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similar to Matbench, we loop through each of the folds of the train and validation\n",
"splits and can also access the final train/test split. We use the default \""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for fold in mpt.folds:\n",
" train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(\n",
" fold\n",
" )\n",
"final_inputs, test_inputs, final_outputs, test_outputs = mpt.get_test_data()\n"
]
},
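{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (assuming the splits are returned as indexable collections such as pandas `Series`), we can print the number of entries in each split:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical sanity check (assumes the splits support len())\n",
"print(len(train_inputs), len(val_inputs))  # sizes from the last fold\n",
"print(len(final_inputs), len(test_inputs))  # final train/test split sizes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \"Fit\" a `DummyGenerator` model and compare to validation data\n",
"\n",
"The `DummyGenerator` uses `pyxtal` to produce random crystal structures. In this example we \"fit\" it on the training structures of each fold, generate candidate structures, and leave the comparison against the validation structures as a placeholder."
]
},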
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mp_time_split.core import MPTimeSplit\n",
"from mp_time_split.utils.gen import DummyGenerator\n",
"\n",
"mpt = MPTimeSplit(target=\"energy_above_hull\")\n",
"mpt.load(dummy=True)\n",
"\n",
"for fold in mpt.folds:\n",
" train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(\n",
" fold\n",
" )\n",
" dg = DummyGenerator()\n",
" dg.fit(train_inputs)\n",
" generated_structures = dg.gen(n=100)\n",
" # compare generated_structures and val_inputs\n",
" # some_code_here\n",
"\n",
"1 + 1\n"
]
},
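{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate cross-validated model accuracy\n",
"\n",
"As a regression baseline, scikit-learn's `DummyRegressor` with `strategy=\"mean\"` ignores the input structures and always predicts the mean of the training targets. We report the mean absolute error (MAE) of `energy_above_hull` averaged over the cross-validation folds."
]
},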
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyRegressor\n",
"from sklearn.metrics import mean_absolute_error\n",
"\n",
"from mp_time_split.core import MPTimeSplit\n",
"\n",
"model = DummyRegressor(strategy=\"mean\")\n",
"\n",
"mpt = MPTimeSplit(target=\"energy_above_hull\")\n",
"mpt.load(dummy=False)\n",
"\n",
"maes = []\n",
"for fold in mpt.folds:\n",
" train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(\n",
" fold\n",
" )\n",
" model.fit(train_inputs, train_outputs)\n",
" predictions = model.predict(val_inputs)\n",
" mae = mean_absolute_error(val_outputs, predictions)\n",
" maes.append(mae)\n",
"\n",
"np.mean(maes)\n"
]
},
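{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter optimization of generator statistic(s)\n",
"\n",
"Finally, we sketch how a hyperparameter optimization loop could wrap the cross-validation folds. The `compare` function below is only a placeholder that returns a random number; in a real benchmark it would compute a metric between the generated and validation structures, and `parameterization` would be supplied by an external optimizer."
]
},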
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mp_time_split.core import MPTimeSplit\n",
"from mp_time_split.utils.gen import DummyGenerator\n",
"\n",
"mpt = MPTimeSplit(target=\"energy_above_hull\")\n",
"mpt.load(dummy=True)\n",
"\n",
"def compare(inputs, gen_inputs):\n",
" inputs, gen_inputs\n",
" return np.random.rand()\n",
"\n",
"def fit_and_evaluate(parameterization):\n",
" metrics = []\n",
" for fold in mpt.folds:\n",
" train_inputs, val_inputs, _, _ = mpt.get_train_and_val_data(\n",
" fold\n",
" )\n",
" dg = DummyGenerator(**parameterization)\n",
" dg.fit(train_inputs)\n",
" generated_structures = dg.gen(n=100)\n",
" # compare generated_structures and val_inputs\n",
" metric = compare(val_inputs, generated_structures)\n",
" metrics.append(metric)\n",
" avg_metric = np.mean(metrics)\n",
" return avg_metric\n",
" \n",
"parameterization = {}\n",
"\n",
"fit_and_evaluate(parameterization)\n",
"\n",
"1 + 1\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
11 changes: 0 additions & 11 deletions notebooks/1.0_basic.py

This file was deleted.

18 changes: 0 additions & 18 deletions notebooks/1.1_generation.py

This file was deleted.

22 changes: 0 additions & 22 deletions notebooks/1.2_accuracy.py

This file was deleted.
