carlg/forest-memory #943

Open: wants to merge 32 commits into main

Commits (32)
c87b9d9  Add a memmap before Parallel (carl-offerfit, Jan 7, 2025)
672699f  Filename fix. (carl-offerfit, Jan 7, 2025)
4e9820a  Script update (carl-offerfit, Jan 8, 2025)
d0d575d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 8, 2025)
5b3db78  Remove test script from the PR (carl-offerfit, Jan 8, 2025)
b235744  Merge branch 'carl/causal-memory' of github.com:carl-offerfit/EconML … (carl-offerfit, Jan 8, 2025)
7c06e0b  Merge branch 'main' into carl/causal-memory (carl-offerfit, Jan 8, 2025)
060d418  Merge branch 'main' into carl/causal-memory (carl-offerfit, Feb 21, 2025)
362b57e  Make memmap an option, and add the reference in the doc string (carl-offerfit, Feb 21, 2025)
ebb4348  Start a notebook to demonstrate causal forest memory usage, by copyin… (carl-offerfit, Feb 21, 2025)
5d694f7  Set use_memmap in constructor (carl-offerfit, Feb 26, 2025)
d5f14ac  Merge branch 'main' into carl/causal-memory (carl-offerfit, Apr 8, 2025)
8df2c15  Shell script to load a data file and fit a CausalForestDML with diffe… (carl-offerfit, Apr 10, 2025)
5ea336f  Try memory profiler, memory usage (carl-offerfit, Apr 10, 2025)
1ef0a53  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 10, 2025)
3cb90d1  Add a printout to make sure it is working (carl-offerfit, Apr 10, 2025)
228c17d  Import cleanup (carl-offerfit, Apr 10, 2025)
6dbdd52  Add catboost and save the output of the memory test script (carl-offerfit, Apr 28, 2025)
27dc80c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 28, 2025)
e9a5405  Use default estimators (carl-offerfit, Apr 28, 2025)
059c4a0  Merge branches 'carl/causal-memory' and 'carl/causal-memory' of githu… (carl-offerfit, Apr 28, 2025)
6910ed3  Fix name conflict between input file and result file (carl-offerfit, Apr 28, 2025)
bb69fee  Make sure we remove the memory for the causal forest estimator (carl-offerfit, Apr 28, 2025)
fee0757  Limit digits (carl-offerfit, Apr 28, 2025)
e247739  Switch to separate runs for catboost and causalforest (carl-offerfit, Apr 28, 2025)
c39ef5a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 28, 2025)
2f2491c  More memory tests (carl-offerfit, May 6, 2025)
39a9d3e  Run all the memory tests (carl-offerfit, May 7, 2025)
c64d911  Better memory test that adds up the numpy arrays in the estimator (carl-offerfit, May 9, 2025)
5ff5807  Fix print statement (carl-offerfit, May 9, 2025)
e2baf3e  Merge branch 'carl/causal-memory' of github.com:carl-offerfit/EconML … (carl-offerfit, May 9, 2025)
999d030  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 9, 2025)
11 changes: 9 additions & 2 deletions econml/dml/causal_forest.py
@@ -578,6 +578,10 @@ class CausalForestDML(_BaseDML):
at depth `depth`, is re-weighted by 1 / (1 + `depth`)**2.0. See the method ``feature_importances``
for a method that allows one to change these defaults.

use_memmap : bool, default False
    Whether to use a numpy memmap to pass data to parallel training. Helps
    reduce memory overhead for large data sets. For details on memmap see:
    https://numpy.org/doc/stable/reference/generated/numpy.memmap.html

References
----------
.. [cfdml1] Athey, Susan, Julie Tibshirani, and Stefan Wager. "Generalized random forests."
@@ -619,7 +623,8 @@ def __init__(self, *,
verbose=0,
allow_missing=False,
use_ray=False,
ray_remote_func_options=None):
ray_remote_func_options=None,
use_memmap=False):

# TODO: consider whether we need more care around stateful featurizers,
# since we clone it and fit separate copies
@@ -647,6 +652,7 @@ def __init__(self, *,
self.fit_intercept = fit_intercept
self.subforest_size = subforest_size
self.n_jobs = n_jobs
self.use_memmap = use_memmap
self.verbose = verbose
super().__init__(discrete_outcome=discrete_outcome,
discrete_treatment=discrete_treatment,
@@ -698,7 +704,8 @@ def _gen_model_final(self):
n_jobs=self.n_jobs,
random_state=self.random_state,
verbose=self.verbose,
warm_start=False))
warm_start=False,
use_memmap=self.use_memmap))

def _gen_rlearner_model_final(self):
return _CausalForestFinalWrapper(self._gen_model_final(), self._gen_featurizer(),
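To illustrate the new flag end to end, here is a minimal sketch of fitting CausalForestDML with use_memmap=True. The data shapes and hyperparameters are illustrative only, not taken from this PR.

import numpy as np
from econml.dml import CausalForestDML

# Illustrative toy data; many treatment columns is the case memmap targets.
rng = np.random.default_rng(0)
X = rng.normal(size=(10_000, 20))   # covariates
T = rng.normal(size=(10_000, 5))    # multiple continuous treatments
y = rng.normal(size=10_000)         # outcome

# With use_memmap=True, fit() saves the augmented outcome array to a
# temporary .npy file and reloads it with mmap_mode='r', so the threaded
# tree-fitting workers share one on-disk copy instead of a large in-RAM array.
est = CausalForestDML(n_estimators=100, use_memmap=True)
est.fit(y, T, X=X)
print(est.effect(X[:5]))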
25 changes: 23 additions & 2 deletions econml/grf/_base_grf.py
@@ -8,7 +8,7 @@
#
# Copyright (c) 2007-2020 The scikit-learn developers.
# All rights reserved.

import gc
import numbers
from warnings import warn
from abc import ABCMeta, abstractmethod
@@ -27,6 +27,7 @@
from sklearn.utils import check_X_y
import scipy.stats
from scipy.special import erfc
import tempfile

__all__ = ["BaseGRF"]

@@ -51,6 +52,11 @@ class BaseGRF(BaseEnsemble, metaclass=ABCMeta):

Warning: This class should not be used directly. Use derived classes
instead.


use_memmap : bool, default False
    Whether to use a numpy memmap to pass data to parallel training. Helps
    reduce memory overhead for large data sets. For details on memmap see:
    https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
"""

def __init__(self,
@@ -73,7 +79,8 @@ def __init__(self,
n_jobs=-1,
random_state=None,
verbose=0,
warm_start=False):
warm_start=False,
use_memmap=False):
super().__init__(
base_estimator=GRFTree(),
n_estimators=n_estimators,
@@ -103,6 +110,7 @@
self.verbose = verbose
self.warm_start = warm_start
self.max_samples = max_samples
self.use_memmap = use_memmap

@abstractmethod
def _get_alpha_and_pointJ(self, X, T, y, **kwargs):
@@ -384,12 +392,25 @@ def fit(self, X, T, y, *, sample_weight=None, **kwargs):
s_inds = [subsample_random_state.choice(n_samples, n_samples_subsample, replace=False)
for _ in range(n_more_estimators)]

if self.use_memmap:
# Make a memmap for better performance with a large number of treatment variables
with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as temp_file:
filename = temp_file.name
print(f"BaseGRF.fit Making memmap with temp file {filename}")
np.save(filename, yaug) # Save array to disk
# Remove references to (potentially) large data before Parallel
del yaug, pointJ
gc.collect()
# Create the memmap version
yaug = np.load(filename, mmap_mode='r')

# Parallel loop: we prefer the threading backend as the Cython code
# for fitting the trees is internally releasing the Python GIL
# making threading more efficient than multiprocessing in
# that case. However, for joblib 0.12+ we respect any
# parallel_backend contexts set at a higher level,
# since correctness does not rely on using threads.

trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, backend='threading')(
delayed(t.fit)(X[s], yaug[s], self.n_y_, self.n_outputs_, self.n_relevant_outputs_,
sample_weight=sample_weight[s] if sample_weight is not None else None,
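The save-then-reload pattern in fit() can be isolated into a standalone sketch (fit_with_memmap is a hypothetical helper, not part of this PR), showing how a large array is spilled to disk and then shared read-only across joblib threads:

import gc
import tempfile

import numpy as np
from joblib import Parallel, delayed

def fit_with_memmap(arr, n_jobs=4):
    # Mirror BaseGRF.fit: write the array to a temporary .npy file, drop
    # the in-memory reference, then reload it as a read-only memmap that
    # all worker threads can share page by page.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as f:
        filename = f.name
    np.save(filename, arr)
    del arr
    gc.collect()
    mm = np.load(filename, mmap_mode='r')
    return Parallel(n_jobs=n_jobs, backend='threading')(
        delayed(np.sum)(mm[:, j]) for j in range(mm.shape[1]))

Note that, as in the diff above, the temporary file is created with delete=False and is never removed, so the on-disk copy outlives the fit; callers may want to clean it up afterward.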
5 changes: 3 additions & 2 deletions econml/grf/classes.py
@@ -359,7 +359,8 @@ def __init__(self,
n_jobs=-1,
random_state=None,
verbose=0,
warm_start=False):
warm_start=False,
use_memmap=False):
super().__init__(n_estimators=n_estimators, criterion=criterion, max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf, min_weight_fraction_leaf=min_weight_fraction_leaf,
@@ -368,7 +369,7 @@
max_samples=max_samples, min_balancedness_tol=min_balancedness_tol,
honest=honest, inference=inference, fit_intercept=fit_intercept,
subforest_size=subforest_size, n_jobs=n_jobs, random_state=random_state, verbose=verbose,
warm_start=warm_start)
warm_start=warm_start, use_memmap=use_memmap)

def fit(self, X, T, y, *, sample_weight=None):
"""
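The classes.py change just threads the flag through to BaseGRF. A short usage sketch at the GRF level, assuming the class edited in this hunk is econml.grf.CausalForest (shapes illustrative):

import numpy as np
from econml.grf import CausalForest

rng = np.random.default_rng(0)
X = rng.normal(size=(5_000, 10))  # covariates
T = rng.normal(size=(5_000, 3))   # continuous treatments
y = rng.normal(size=(5_000, 1))   # outcome

# The constructor forwards use_memmap to BaseGRF.__init__; BaseGRF.fit
# then decides whether to spill its augmented outcome array to disk.
forest = CausalForest(n_estimators=100, use_memmap=True)
forest.fit(X, T, y)
print(forest.predict(X[:5]))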
37 changes: 37 additions & 0 deletions notebooks/Causal Forest Memory Demo.ipynb
@@ -0,0 +1,37 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "initial_id",
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
""
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}