Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SiddhantSadangi committed Jul 5, 2024
1 parent cd380be commit db90f22
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

import matplotlib as mpl
import pytest
from numpy import array_equal
from sklearn.cluster import KMeans
from sklearn.dummy import (
DummyClassifier,
DummyRegressor,
)
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV

import neptune_sklearn as npt_utils
Expand All @@ -35,20 +37,23 @@ def test_classifier_summary(iris):

def test_regressor_summary(diabetes):
with init_run() as run:
model = DummyRegressor()
model = LinearRegression()
model.fit(diabetes.x_train, diabetes.y_train)

original_coef = model.coef_

run["summary"] = npt_utils.create_regressor_summary(
model, diabetes.x_train, diabetes.x_test, diabetes.y_train, diabetes.y_test
)

assert array_equal(model.coef_, original_coef), "Original model coefficients modified."

run.wait()
validate_run(run, log_charts=True)


def test_kmeans_summary(iris):
with init_run() as run:

model = KMeans()
model.fit(iris.x)

Expand Down Expand Up @@ -77,7 +82,11 @@ def test_unsupported_object(diabetes):
)

run["regressor_summary"] = npt_utils.create_regressor_summary(
grid_cv, diabetes.x_train, diabetes.x_test, diabetes.y_train, diabetes.y_test
grid_cv,
diabetes.x_train,
diabetes.x_test,
diabetes.y_train,
diabetes.y_test,
)

run.wait()
Expand Down

0 comments on commit db90f22

Please sign in to comment.