Skip to content

Commit

Permalink
Monkey patched yellowbrick.regressor.CooksDistance.draw()
Browse files Browse the repository at this point in the history
  • Loading branch information
SiddhantSadangi committed Apr 3, 2024
1 parent abfc6c2 commit 36af5ef
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.10", "3.12"]
exclude:
- os: windows-latest
python-version: "3.10"

steps:
- uses: actions/checkout@v4

Expand Down
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
## neptune-sklearn 2.1.3

### Fixes
- Monkey patches [`yellowbrick.regression.CooksDistance.draw()`](https://github.com/DistrictDataLabs/yellowbrick/blob/f7a8e950bd31452ea2f5d402a1c5d519cd163fd5/yellowbrick/regressor/influence.py#L184) to remove unsupported `use_line_collection` matplotlib arg ([#28](https://github.com/neptune-ai/neptune-sklearn/pull/28))

### Changes
- Constrained matplotlib to `<3.3` on Python `>=3.12` ([#28](https://github.com/neptune-ai/neptune-sklearn/pull/28))
- Replaced `print()` with `warnings.warn()` to better capture `stderr` ([#28](https://github.com/neptune-ai/neptune-sklearn/pull/28))

## neptune-sklearn 2.1.2
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ scikit-learn = ">=0.24.1"
yellowbrick = ">=1.3"
scikit-plot = ">=0.3.7"
scipy = "<1.12" # Fixes #24 (https://github.com/neptune-ai/neptune-sklearn/issues/24)
matplotlib = { version = "<3.3", python = "<3.12"} # Fixes https://github.com/DistrictDataLabs/yellowbrick/issues/1312

# dev
pre-commit = { version = "*", optional = true }
Expand Down
36 changes: 36 additions & 0 deletions src/neptune_sklearn/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,42 @@ def create_prediction_error_chart(regressor, X_train, X_test, y_train, y_test):
return chart


def monkey_draw(self):
"""
Monkey patches `yellowbrick.regressor.CooksDistance.draw()`
to remove unsupported matplotlib argument `use_line_collection`.
Draws a stem plot where each stem is the Cook's Distance of the instance at the
index specified by the x axis. Optionaly draws a threshold line.
"""
# Draw a stem plot with the influence for each instance
_, _, baseline = self.ax.stem(
self.distance_,
linefmt=self.linefmt,
markerfmt=self.markerfmt,
# use_line_collection=True
)

# No padding on either side of the instance index
self.ax.set_xlim(0, len(self.distance_))

# Draw the threshold for most influential points
if self.draw_threshold:
label = r"{:0.2f}% > $I_t$ ($I_t=\frac {{4}} {{n}}$)".format(self.outlier_percentage_)
self.ax.axhline(
self.influence_threshold_,
ls="--",
label=label,
c=baseline.get_color(),
lw=baseline.get_linewidth(),
)

return self.ax


CooksDistance.draw = monkey_draw


def create_cooks_distance_chart(regressor, X_train, y_train):
"""Creates cooks distance chart.
Expand Down

0 comments on commit 36af5ef

Please sign in to comment.