Skip to content

Commit

Permalink
Yellowbrick matploltib incompatibility fix; expanded tests; replaced …
Browse files Browse the repository at this point in the history
…print with warning
  • Loading branch information
SiddhantSadangi committed Mar 26, 2024
1 parent 107a31c commit 2db62dc
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.9]
python-version: ["3.8", "3.10", "3.12"]
steps:
- uses: actions/checkout@v2

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ scikit-learn = ">=0.24.1"
yellowbrick = ">=1.3"
scikit-plot = ">=0.3.7"
scipy = "<1.12" # Fixes #24 (https://github.com/neptune-ai/neptune-sklearn/issues/24)
matplotlib = { version = "<3.3", python = "<3.12"} # Fixes https://github.com/DistrictDataLabs/yellowbrick/issues/1312

# dev
pre-commit = { version = "*", optional = true }
Expand Down
35 changes: 19 additions & 16 deletions src/neptune_sklearn/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
)
from neptune.new.utils import stringify_unsupported

from warnings import warn


def create_regressor_summary(regressor, X_train, X_test, y_train, y_test, nrows=1000, log_charts=True):
"""Creates scikit-learn regressor summary.
Expand Down Expand Up @@ -455,7 +457,7 @@ def get_test_preds_proba(classifier, X_test=None, y_pred_proba=None, nrows=1000)
try:
y_pred_proba = classifier.predict_proba(X_test)
except Exception as e:
print("This classifier does not provide predictions probabilities. Error: {}".format(e))
warn(f"This classifier does not provide predictions probabilities. Error: {e}")
return

df = pd.DataFrame(data=y_pred_proba, columns=classifier.classes_)
Expand Down Expand Up @@ -590,7 +592,7 @@ def create_learning_curve_chart(regressor, X_train, y_train):
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log learning curve chart. Error: {}".format(e))
warn(f"Did not log learning curve chart. Error: {e}")

return chart

Expand Down Expand Up @@ -633,7 +635,7 @@ def create_feature_importance_chart(regressor, X_train, y_train):
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log feature importance chart. Error: {}".format(e))
warn(f"Did not log feature importance chart. Error: {e}")

return chart

Expand Down Expand Up @@ -678,7 +680,7 @@ def create_residuals_chart(regressor, X_train, X_test, y_train, y_test):
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log residuals chart. Error: {}".format(e))
warn(f"Did not log residuals chart. Error: {e}")

return chart

Expand Down Expand Up @@ -723,7 +725,7 @@ def create_prediction_error_chart(regressor, X_train, X_test, y_train, y_test):
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log prediction error chart. Error: {}".format(e))
warn(f"Did not log prediction error chart. Error: {e}")

return chart

Expand Down Expand Up @@ -765,7 +767,7 @@ def create_cooks_distance_chart(regressor, X_train, y_train):
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log cooks distance chart. Error: {}".format(e))
warn(f"Did not log cooks distance chart. Error: {e}")

return chart

Expand Down Expand Up @@ -812,7 +814,7 @@ def create_classification_report_chart(classifier, X_train, X_test, y_train, y_t
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log Classification Report chart. Error: {}".format(e))
warn(f"Did not log Classification Report chart. Error: {e}")

return chart

Expand Down Expand Up @@ -859,7 +861,7 @@ def create_confusion_matrix_chart(classifier, X_train, X_test, y_train, y_test):
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log Confusion Matrix chart. Error: {}".format(e))
warn(f"Did not log Confusion Matrix chart. Error: {e}")

return chart

Expand Down Expand Up @@ -904,7 +906,7 @@ def create_roc_auc_chart(classifier, X_train, X_test, y_train, y_test):
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log ROC-AUC chart. Error {}".format(e))
warn(f"Did not log ROC-AUC chart. Error {e}")

return chart

Expand Down Expand Up @@ -943,9 +945,9 @@ def create_precision_recall_chart(classifier, X_test, y_test, y_pred_proba=None)
try:
y_pred_proba = classifier.predict_proba(X_test)
except Exception as e:
print(
"Did not log Precision-Recall chart: this classifier does not provide predictions probabilities."
"Error {}".format(e)
warn(
f"""Did not log Precision-Recall chart: this classifier does not provide predictions probabilities.
Error {e}"""
)
return chart

Expand All @@ -955,7 +957,7 @@ def create_precision_recall_chart(classifier, X_test, y_test, y_pred_proba=None)
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log Precision-Recall chart. Error {}".format(e))
warn(f"Did not log Precision-Recall chart. Error {e}")

return chart

Expand Down Expand Up @@ -1002,7 +1004,7 @@ def create_class_prediction_error_chart(classifier, X_train, X_test, y_train, y_
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log Class Prediction Error chart. Error {}".format(e))
warn(f"Did not log Class Prediction Error chart. Error {e}")

return chart

Expand Down Expand Up @@ -1088,7 +1090,7 @@ def create_kelbow_chart(model, X, **kwargs):
chart = File.as_image(fig)
plt.close(fig)
except Exception as e:
print("Did not log KMeans elbow chart. Error {}".format(e))
warn(f"Did not log KMeans elbow chart. Error {e}")

return chart

Expand Down Expand Up @@ -1140,6 +1142,7 @@ def create_silhouette_chart(model, X, **kwargs):
charts.append(File.as_image(fig))
plt.close(fig)
except Exception as e:
print("Did not log Silhouette Coefficients chart. Error {}".format(e))
warn(f"Did not log Silhouette Coefficients chart. Error {e}")

return FileSeries(charts)
return FileSeries(charts)

0 comments on commit 2db62dc

Please sign in to comment.