parallel execution for permutation_importance #244

Open · wants to merge 8 commits into base: master
18 changes: 11 additions & 7 deletions eli5/permutation_importance.py
@@ -15,7 +15,7 @@

 import numpy as np  # type: ignore
 from sklearn.utils import check_random_state  # type: ignore
-
+from multiprocess import Pool  # type: ignore

 def iter_shuffled(X, columns_to_shuffle=None, pre_shuffle=False,
                   random_state=None):
@@ -58,7 +58,8 @@ def get_score_importances(
         y,
         n_iter=5,  # type: int
         columns_to_shuffle=None,
-        random_state=None
+        random_state=None,
+        n_jobs=1
     ):
     # type: (...) -> Tuple[float, List[np.ndarray]]
     """
@@ -84,12 +85,15 @@
     """
     rng = check_random_state(random_state)
     base_score = score_func(X, y)
+    seed0 = rng.randint(2**32)
+    pool = Pool(n_jobs)
+    result = pool.map(
+        lambda seed: _get_scores_shufled(score_func, X, y,
+                                         columns_to_shuffle=columns_to_shuffle,
+                                         random_state=np.random.RandomState(seed)),
+        range(seed0, seed0 + n_iter))
     scores_decreases = []
-    for i in range(n_iter):
-        scores_shuffled = _get_scores_shufled(
-            score_func, X, y, columns_to_shuffle=columns_to_shuffle,
-            random_state=rng
-        )
+    for scores_shuffled in result:
         scores_decreases.append(-scores_shuffled + base_score)
     return base_score, scores_decreases

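The new dependency is the multiprocess package, a dill-based fork of the standard library multiprocessing module; unlike the standard module it can serialise the lambda handed to pool.map above. Each shuffle round gets its own seed drawn from the caller's random_state, so the n_iter iterations can be mapped over the pool independently. A minimal usage sketch of the changed function follows; the Ridge model and the toy data are illustrative assumptions, not part of the PR.

import numpy as np
from sklearn.linear_model import Ridge
from eli5.permutation_importance import get_score_importances

# Toy regression data, purely illustrative (not from the PR).
rng = np.random.RandomState(0)
X = rng.normal(size=(200, 5))
y = 3.0 * X[:, 0] + rng.normal(scale=0.1, size=200)
model = Ridge().fit(X, y)

# n_jobs=2: the n_iter shuffle rounds are mapped over two worker processes.
base_score, score_decreases = get_score_importances(
    model.score, X, y, n_iter=10, random_state=42, n_jobs=2)

# Average decrease in R^2 per feature when that feature is shuffled.
feature_importances = np.mean(score_decreases, axis=0)
print(base_score, feature_importances)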
7 changes: 5 additions & 2 deletions eli5/sklearn/permutation_importance.py
@@ -117,6 +117,8 @@ class PermutationImportance(BaseEstimator, MetaEstimatorMixin):
         Whether to fit the estimator on the whole data if cross-validation
         is used (default is True).
 
+    n_jobs : int, number of parallel jobs for shuffle iterations
+
     Attributes
     ----------
     feature_importances_ : array
@@ -142,7 +144,7 @@ class PermutationImportance(BaseEstimator, MetaEstimatorMixin):
         random state
     """
     def __init__(self, estimator, scoring=None, n_iter=5, random_state=None,
-                 cv='prefit', refit=True):
+                 cv='prefit', refit=True, n_jobs=1):
         # type: (...) -> None
         if isinstance(cv, str) and cv != "prefit":
             raise ValueError("Invalid cv value: {!r}".format(cv))
@@ -152,6 +154,7 @@ def __init__(self, estimator, scoring=None, n_iter=5, random_state=None,
         self.n_iter = n_iter
         self.random_state = random_state
         self.cv = cv
+        self.n_jobs = n_jobs
         self.rng_ = check_random_state(random_state)
 
     def _wrap_scorer(self, base_scorer, pd_columns):
@@ -228,7 +231,7 @@ def _non_cv_scores_importances(self, X, y):
 
     def _get_score_importances(self, score_func, X, y):
         return get_score_importances(score_func, X, y, n_iter=self.n_iter,
-                                     random_state=self.rng_)
+                                     random_state=self.rng_, n_jobs=self.n_jobs)
 
     @property
     def caveats_(self):
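On the scikit-learn wrapper side, n_jobs is only stored in __init__ and forwarded to get_score_importances, so the meta-estimator API is otherwise unchanged. A usage sketch, assuming the branch is installed; the diabetes dataset, SVR parameters and n_jobs value are illustrative choices, not taken from the PR.

from sklearn.datasets import load_diabetes
from sklearn.svm import SVR
from eli5.sklearn import PermutationImportance

X, y = load_diabetes(return_X_y=True)
svr = SVR(C=10).fit(X, y)

# cv='prefit' scores the already fitted model; n_jobs=2 runs the shuffle
# iterations in the process pool added by this PR.
perm = PermutationImportance(svr, n_iter=10, cv='prefit',
                             random_state=0, n_jobs=2).fit(X, y)
print(perm.feature_importances_)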
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ attrs > 16.0.0
 jinja2
 pip >= 8.1
 setuptools >= 20.7
+multiprocess
1 change: 1 addition & 0 deletions setup.py
@@ -40,6 +40,7 @@ def get_long_description():
         'typing',
         'graphviz',
         'tabulate>=0.7.7',
+        'multiprocess',
     ],
     extras_require={
         ":python_version<'3.5.6'": [
15 changes: 8 additions & 7 deletions tests/test_permutation_importance.py
@@ -42,10 +42,11 @@ def is_shuffled(X, X_sh, col):
 def test_get_feature_importances(boston_train):
     X, y, feat_names = boston_train
     svr = SVR(C=20).fit(X, y)
-    score, importances = get_score_importances(svr.score, X, y)
-    assert score > 0.7
-    importances = dict(zip(feat_names, np.mean(importances, axis=0)))
-    print(score)
-    print(importances)
-    assert importances['AGE'] > importances['NOX']
-    assert importances['B'] > importances['CHAS']
+    for n_jobs in [1, 2]:
+        score, importances = get_score_importances(svr.score, X, y, n_jobs=n_jobs)
+        assert score > 0.7
+        importances = dict(zip(feat_names, np.mean(importances, axis=0)))
+        print(score)
+        print(importances)
+        assert importances['AGE'] > importances['NOX']
+        assert importances['B'] > importances['CHAS']
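Since every iteration is seeded from a single seed0 drawn once from the shared rng, and pool.map preserves input order, the score decreases should be identical for any n_jobs given the same random_state. A hypothetical extra test of that property (not part of this PR; the test name and body are assumptions) might look like:

def test_parallel_results_match(boston_train):
    X, y, feat_names = boston_train
    svr = SVR(C=20).fit(X, y)
    # Same random_state gives the same per-iteration seeds, regardless of pool size.
    _, dec_serial = get_score_importances(svr.score, X, y,
                                          random_state=42, n_jobs=1)
    _, dec_parallel = get_score_importances(svr.score, X, y,
                                            random_state=42, n_jobs=2)
    for a, b in zip(dec_serial, dec_parallel):
        assert np.allclose(a, b)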