Skip to content

Commit

Permalink
Fix MLJ Serialization, add test for single class classifiers, fix sin…
Browse files Browse the repository at this point in the history
…gle class classifiers predict (#35)

* Fix MLJ Serialization

* reformat

* add test for single class classifiers, fix single class classifiers predict
  • Loading branch information
tylerjthomas9 authored Jun 18, 2024
1 parent 445ba4b commit 13b1919
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 19 deletions.
9 changes: 5 additions & 4 deletions src/mlj_catboostclassifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,11 @@ MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true
function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool)
if fitresult[1] === nothing
# Always predict the single class
n = nrow(X_pool)
n = pyconvert(Int, X_pool.shape[0])
classes = [fitresult.single_class]
probs = ones(n, 1)
return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first)
pool = MMI.categorical([fitresult.y_first])
return MMI.UnivariateFinite(classes, probs; pool=pool)
end

model, y_first = fitresult
Expand All @@ -116,8 +117,8 @@ end
function MMI.predict_mode(mlj_model::CatBoostClassifier, fitresult, X_pool)
if fitresult[1] === nothing
# Return probability 1 for the single class
n = nrow(X_pool)
return hcat(ones(n), zeros(n))
n = pyconvert(Int, X_pool.shape[0])
return fill(fitresult.y_first, n)
end

model, y_first = fitresult
Expand Down
64 changes: 51 additions & 13 deletions src/mlj_serialization.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,41 @@
# Taken from https://github.com/JuliaAI/MLJXGBoostInterface.jl
# It is likely also not the optimal method for serializing models, but it works

"""
_persistent(booster)
_persistent(model::CatBoostModels, fitresult)
Private method.
Return a persistent (ie, Julia-serializable) representation of the
CatBoost.jl model `booster`.
CatBoost.jl model `fitresult`.
Restore the model with [`booster`](@ref)
Restore the model with [`fitresult`](@ref)
"""
function _persistent(booster)
function _persistent(::CatBoostRegressor, fitresult)
ctb_file, io = mktemp()
close(io)

booster.save_model(ctb_file)
fitresult.save_model(ctb_file)
persistent_booster = read(ctb_file)
rm(ctb_file)
return persistent_booster
end
"""
    _persistent(::CatBoostClassifier, fitresult)

Private method.

Return a persistent (ie, Julia-serializable) representation of a classifier
`fitresult`.

- If the first element of `fitresult` is `nothing` (single unique class), return
  `(nothing, fitresult.single_class, fitresult.y_first)`.
- Otherwise `fitresult` is destructured as `(model, y_first)` and the result is
  `(bytes, y_first)`, where `bytes` is the content of the file written by
  `model.save_model`.
"""
function _persistent(::CatBoostClassifier, fitresult)
    if fitresult[1] === nothing
        # Case 1: Single unique class.
        # Read the fields by name. Positional destructuring (`model, y_first = fitresult`)
        # binds the *second* element, which is `single_class` rather than `y_first`
        # when the fitresult is laid out as `(model, single_class, y_first)`.
        return (nothing, fitresult.single_class, fitresult.y_first)
    end

    # Case 2: Multiple unique classes.
    model, y_first = fitresult
    ctb_file, io = mktemp()
    close(io)
    try
        model.save_model(ctb_file)
        return (read(ctb_file), y_first)
    finally
        # Remove the temporary file even if saving or reading fails.
        rm(ctb_file; force=true)
    end
end

"""
_booster(persistent)
Expand All @@ -28,24 +45,45 @@ Private method.
Return the CatBoost.jl model which has `persistent` as its persistent
(Julia-serializable) representation. See the [`_persistent`](@ref) method.
"""
function _booster(persistent)
# Rebuild a CatBoost regressor from `persistent` by writing it to a temporary
# file and loading that file with `catboost.CatBoostRegressor().load_model`.
# The temporary file is removed before returning, including when loading fails.
function _booster(::CatBoostRegressor, persistent)
    ctb_file, io = mktemp()
    try
        write(io, persistent)
        # Close first so the content is flushed before the path is read back.
        close(io)
        return catboost.CatBoostRegressor().load_model(ctb_file)
    finally
        # Covers the case where `write` failed before the close above.
        close(io)
        rm(ctb_file; force=true)
    end
end
function _booster(::CatBoostClassifier, persistent)
ctb_file, io = mktemp()
write(io, persistent)
close(io)

booster = catboost.CatBoost().load_model(ctb_file)
booster = catboost.CatBoostClassifier().load_model(ctb_file)

rm(ctb_file)

return booster
end

function MMI.save(::CatBoostModels, fr; kw...)
(booster, a_target_element) = fr
return (_persistent(booster), a_target_element)
function MMI.save(model::CatBoostModels, fitresult; kwargs...)
return _persistent(model, fitresult)
end

# Regressor: the serializable fitresult is passed straight to `_booster`.
MMI.restore(model::CatBoostRegressor, serializable_fitresult) =
    _booster(model, serializable_fitresult)

function MMI.restore(::CatBoostModels, fr)
(persistent, a_target_element) = fr
return (_booster(persistent), a_target_element)
# Classifier: rebuild the fitresult from its serializable form.
# A leading `nothing` marks the single-class case, which is returned as a
# named tuple; otherwise the first element is handed to `_booster`.
function MMI.restore(model::CatBoostClassifier, serializable_fitresult)
    if serializable_fitresult[1] !== nothing
        # Case 2: Multiple unique classes
        persistent_booster, y_first = serializable_fitresult
        return (_booster(model, persistent_booster), y_first)
    end

    # Case 1: Single unique class
    single_class = serializable_fitresult[2]
    y_first = serializable_fitresult[3]
    return (model=nothing, single_class=single_class, y_first=y_first)
end
21 changes: 19 additions & 2 deletions test/mlj_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,24 @@
preds = MLJBase.predict(mach, X)
probs = MLJBase.predict_mode(mach, X)

serializable_fitresult = MLJBase.save(mach, mach.fitresult)
serializable_fitresult = MLJBase.save(mach, mach)
restored_fitresult = MLJBase.restore(mach, serializable_fitresult)
end

@testset "CatBoostClassifier - single class" begin
    X = (; a=[1, 4, 5, 6], b=[4, 5, 6, 7])
    y = [0, 0, 0, 0]

    # MLJ Interface
    model = CatBoostClassifier(; iterations=5)
    mach = machine(model, X, y)
    MLJBase.fit!(mach)

    # One prediction per input row, for both predict and predict_mode.
    preds = MLJBase.predict(mach, X)
    @test length(preds) == length(y)
    modes = MLJBase.predict_mode(mach, X)
    @test length(modes) == length(y)

    # `save`/`restore` dispatch on the model, so pass the model and its
    # fitresult; passing the machine would not reach the CatBoost methods.
    serializable_fitresult = MLJBase.save(model, mach.fitresult)
    @test serializable_fitresult[1] === nothing
    restored_fitresult = MLJBase.restore(model, serializable_fitresult)
    @test restored_fitresult.model === nothing
end

Expand All @@ -36,7 +53,7 @@
MLJBase.fit!(mach)
preds = MLJBase.predict(mach, X)

serializable_fitresult = MLJBase.save(mach, mach.fitresult)
serializable_fitresult = MLJBase.save(mach, mach)
restored_fitresult = MLJBase.restore(mach, serializable_fitresult)
end

Expand Down

0 comments on commit 13b1919

Please sign in to comment.