
[WIP] Add sample_weight to permutation importances scikit-learn interface #265

Open
wants to merge 1 commit into master

Conversation

ryanvarley (Contributor)

Sample weights are an important parameter for imbalanced problems and for some forms of bias correction, and they are part of the scikit-learn API for all classifiers and regressors. They can also have a large impact on permutation importances, and they cannot be passed through fit_params because, in the CV case, they need to be split into train / test folds alongside X and y.

I'm opening this as a PR now to start a discussion before I do any more work. At the very least, I would like to document how you can use get_score_importances and pass sample weights through score_func if you need them.
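For reference, something along these lines already works today with get_score_importances by closing over the weights in score_func (a rough sketch; the weight values and the re-use of the training data for scoring are only illustrative):

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.svm import SVC
from eli5.permutation_importance import get_score_importances

data = load_breast_cancer()
X, y = data.data, data.target

# illustrative weights: up-weight the 0 class 10x
sample_weight = np.where(y == 0, 10.0, 1.0)

svc = SVC().fit(X, y, sample_weight=sample_weight)

# score_func only receives (X, y), so close over sample_weight
def weighted_score(X, y):
    return svc.score(X, y, sample_weight=sample_weight)

base_score, score_decreases = get_score_importances(weighted_score, X, y)
feature_importances = np.mean(score_decreases, axis=0)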

As a very quick (and perhaps slightly flawed) example, suppose we calculate permutation importances on the standard data and then again with weights that simulate an imbalanced version of it (the reverse of the usual use case):

from sklearn.datasets import load_breast_cancer
from sklearn.utils import shuffle
from sklearn.svm import SVC
from eli5.sklearn import PermutationImportance
import matplotlib.pyplot as plt
import seaborn

seaborn.set()

data = load_breast_cancer()
X, y = shuffle(data.data, data.target, random_state=13)

perm = PermutationImportance(SVC(), cv=None).fit(X, y, sample_weight=None)

# But what if we had 10x more 0 labels than 1s?
sample_weight = y.copy()
sample_weight[sample_weight == 1] = 1
sample_weight[sample_weight == 0] = 10

perm10 = PermutationImportance(SVC(), cv=None).fit(X, y, sample_weight=sample_weight)

seaborn.barplot(x=perm10.feature_importances_, y=data.feature_names, label='10x', color='b')
seaborn.barplot(x=perm.feature_importances_, y=data.feature_names, label='~balanced', color='g')
plt.legend()

[Figure: bar plot comparing feature importances for the ~balanced and 10x-weighted cases]

I have seen dramatic (and less uniform) changes in feature_importances_ on real-world imbalanced data sets when doing this. I'll try to find a better example.

Issues

  • As currently implemented, this would break PermutationImportance if the wrapped estimator does not support sample_weight in its fit method (see the sketch after this list).
  • This is more of a scikit-learn issue, but sample weights can quickly get confusing if you combine them with class_weight, which some classifiers (e.g. RandomForestClassifier) support. In that case your sample weights are modified during the classifier fit to balance the classes, but not during the permutation importance scoring on the test set, which can be misleading.
  • Not tested
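For the first point, one possible guard (just a sketch of one option, not what this PR currently does) would be to check the wrapped estimator with scikit-learn's has_fit_parameter helper before forwarding sample_weight:

from sklearn.utils.validation import has_fit_parameter

def check_sample_weight_support(estimator, sample_weight):
    # Raise a clear error instead of a cryptic TypeError from fit()
    # when the wrapped estimator cannot accept sample_weight.
    if sample_weight is not None and not has_fit_parameter(estimator, "sample_weight"):
        raise ValueError(
            "%s does not support sample_weight in fit()" % type(estimator).__name__
        )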
