Another shot at fixing regression #71

Merged (6 commits) on Nov 28, 2023
5 changes: 4 additions & 1 deletion .github/workflows/CI.yml
@@ -10,6 +10,9 @@ on:
   pull_request:
   workflow_dispatch:
 
+env:
+  JULIA_NUM_THREADS: 2
+
 jobs:
   Test:
     name: Test
@@ -29,7 +32,7 @@ jobs:
       - uses: julia-actions/setup-julia@v1
         with:
           version: ${{ matrix.version }}
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v1.3.0
       - uses: r-lib/actions/setup-r@v2
         with:
           use-public-rspm: true
5 changes: 4 additions & 1 deletion .github/workflows/Docs.yml
@@ -7,6 +7,9 @@ on:
   pull_request:
   workflow_dispatch:
 
+env:
+  JULIA_NUM_THREADS: 2
+
 jobs:
   BuildDocs:
     permissions:
@@ -18,7 +21,7 @@ jobs:
       - uses: julia-actions/setup-julia@v1
         with:
           version: '1'
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v1.3.0
         with:
           cache-name: 'docs'
       - uses: julia-actions/julia-buildpkg@v1
1 change: 1 addition & 0 deletions README.md
@@ -87,6 +87,7 @@ and 2 classes: [1, 2].
 Note: showing only the probability for class 2 since class 1 has probability 1 - p.
 ```
 
+This is a basic example; in most cases you will want to tune the `max_depth`, `max_rules`, and `lambda` hyperparameters.
 See `?StableRulesClassifier`, `?StableRulesRegressor`, or the [API documentation](https://sirus.jl.huijzer.xyz/dev/api/) for more information about the models and their hyperparameters.
 A full guide through binary classification can be found in the [Simple Binary Classification](https://sirus.jl.huijzer.xyz/dev/binary-classification/) example.
138 changes: 138 additions & 0 deletions debug-regression.jl
@@ -0,0 +1,138 @@
### A Pluto.jl notebook ###
# v0.19.32

using Markdown
using InteractiveUtils

# ╔═╡ 44f5f78b-d317-4662-ab6d-307f0701a11b
using Revise

# ╔═╡ f67cddf9-7304-47e2-aa2f-48ec93275554
root_dir = @__DIR__

# ╔═╡ 106c3646-8d4a-11ee-0601-27fe3abdc084
# ╠═╡ show_logs = false
x = let
    using Pkg: Pkg
    Pkg.add("TestEnv")
    Pkg.activate(root_dir)
    using TestEnv
    TestEnv.activate()
    Pkg.add("CairoMakie")
    using CairoMakie
end;

# ╔═╡ ff2b378b-75a6-4e0a-b13b-6bc69442180b
let
    x
    using CategoricalArrays:
        CategoricalValue,
        CategoricalVector,
        categorical,
        unwrap
    using CSV: CSV
    using DataDeps: DataDeps, DataDep, @datadep_str
    using Documenter: DocMeta, doctest
    using MLDatasets:
        BostonHousing,
        Iris,
        Titanic
    using DataFrames:
        DataFrames,
        DataFrame,
        Not,
        dropmissing!,
        rename!,
        rename,
        select
    using DecisionTree: DecisionTree
    using MLJBase:
        CV,
        MLJBase,
        PerformanceEvaluation,
        evaluate,
        mode,
        fit!,
        machine,
        make_blobs,
        make_moons,
        make_regression,
        predict
    using MLJDecisionTreeInterface: DecisionTreeClassifier, DecisionTreeRegressor
    using MLJLinearModels: LogisticClassifier, LinearRegressor, MultinomialClassifier
    using MLJTestInterface: MLJTestInterface
    using MLJXGBoostInterface: XGBoostClassifier, XGBoostRegressor
    using Random: shuffle, seed!
    using StableRNGs: StableRNG
    using StatisticalMeasures:
        accuracy,
        auc,
        rsq
    using SIRUS
    using Statistics: mean, var
    using Tables: Tables
    using Test
end

# ╔═╡ 5e0aa5fc-c78d-4f2c-94dd-c73fbc726a8d
include(joinpath(root_dir, "test/preliminaries.jl"))

# ╔═╡ 14957f30-40c5-4445-8efc-68ea3c8e3e0f
X, y = boston();

# ╔═╡ 24366d12-3df7-4934-b57e-f489e1b76e9a
X

# ╔═╡ 693b7971-878f-40be-9989-1a37f1d8511f
y

# ╔═╡ ee87d4f9-08c6-4d25-ad68-8e69e675d3c7
fr = let
    hyper = (; rng=_rng(), max_depth=2, n_trees=100)
    measure = accuracy
    _evaluate!(results, "iris", StableForestClassifier, hyper; measure)
end

# ╔═╡ ebc2ba6f-a7da-4290-877a-59a40045d021
fr.fitted_params_per_fold[1].fitresult

# ╔═╡ f306ae6f-fa3d-484b-af78-47f012c9ac07
forest = fr.fitted_params_per_fold[2].fitresult

# ╔═╡ 932a9f80-6ec4-41e1-b550-4fca88ecb68a
rr = let
    # Increasing max_rules decreases score, which makes no sense.
    hyper = (; rng=_rng(), max_depth=2, max_rules=30, lambda=100, q=20)
    measure = rsq
    _evaluate!(results, "boston", StableRulesRegressor, hyper; measure)
end

# ╔═╡ c9b5c558-30c2-4b53-a10d-4396878d24b0
only(rr.per_fold)

# ╔═╡ 7dc325e0-07b5-4f65-9d4e-ba98c5e38dc5
rr.fitted_params_per_fold[1].fitresult

# ╔═╡ 4af16b84-683f-4a16-abb5-855a98a14263
rr.fitted_params_per_fold[1].fitresult.weights

# ╔═╡ d2d7d82b-18bb-4b7c-ac77-4b8ceaa1a8e9
forest

# ╔═╡ Cell order:
# ╠═44f5f78b-d317-4662-ab6d-307f0701a11b
# ╠═f67cddf9-7304-47e2-aa2f-48ec93275554
# ╠═106c3646-8d4a-11ee-0601-27fe3abdc084
# ╠═5e0aa5fc-c78d-4f2c-94dd-c73fbc726a8d
# ╠═ff2b378b-75a6-4e0a-b13b-6bc69442180b
# ╠═14957f30-40c5-4445-8efc-68ea3c8e3e0f
# ╠═24366d12-3df7-4934-b57e-f489e1b76e9a
# ╠═693b7971-878f-40be-9989-1a37f1d8511f
# ╠═ee87d4f9-08c6-4d25-ad68-8e69e675d3c7
# ╠═ebc2ba6f-a7da-4290-877a-59a40045d021
# ╠═f306ae6f-fa3d-484b-af78-47f012c9ac07
# ╠═932a9f80-6ec4-41e1-b550-4fca88ecb68a
# ╠═c9b5c558-30c2-4b53-a10d-4396878d24b0
# ╠═7dc325e0-07b5-4f65-9d4e-ba98c5e38dc5
# ╠═4af16b84-683f-4a16-abb5-855a98a14263
# ╠═d2d7d82b-18bb-4b7c-ac77-4b8ceaa1a8e9
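
The notebook's evaluation cells probe whether the score degrades as `max_rules` grows. A minimal sketch of that check, where `evaluate_model` is a hypothetical stand-in for the `_evaluate!` calls above and the scores are made up for illustration:

```python
# Hypothetical sketch: `evaluate_model` stands in for the notebook's
# `_evaluate!` calls; the scores below are invented for illustration.
def evaluate_model(max_rules):
    # stand-in for a cross-validated R^2 at this rule-set size
    return {10: 0.55, 30: 0.57}[max_rules]

scores = {m: evaluate_model(m) for m in (10, 30)}
# the regression under investigation: a *lower* score with *more* rules
regression_detected = scores[30] < scores[10]
```

With a healthy model, more rules should not hurt the score, so `regression_detected` stays false.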
10 changes: 9 additions & 1 deletion docs/src/basic-example.jl
@@ -89,10 +89,18 @@ y = data.survival;
 md"""
 Next, we can load the model that we want to use.
 Since Haberman's outcome column (`survival`) contains 0's and 1's, we use the `StableRulesClassifier`.
+
+Here, the hyperparameters were manually tuned.
+On this dataset:
+
+- The number of split points `q` was lowered to 4. This means that the random forest can choose from only 4 locations to split the data. With such a low number, accuracy typically goes down, but the rules become more interpretable.
+- The max tree depth `max_depth` was set to 2. This is the highest that SIRUS can go, as discussed in the original paper. The reason is that rules from greater tree depths almost certainly do not end up in the final rule set, so it is computationally cheaper not to determine them in the first place.
+- The `lambda` parameter was set to 1. This parameter is used with a ridge regression to determine the weights in the final rule set. SIRUS is very sensitive to the choice of this hyperparameter. Ensure that you try the full range from 10^-4 to 10^4 (e.g., 0.001, 0.01, ..., 100).
+- The max number of rules `max_rules` was set to 8. Typically, the model becomes more accurate with more rules, but less interpretable. If the model becomes less accurate with more rules, ensure that you have set the right `lambda`.
 """
 
 # ╔═╡ ccce5f3e-e396-4765-bf5f-6f79e905aca8
-model = StableRulesClassifier(; rng=StableRNG(1), q=4, max_depth=2, max_rules=8);
+model = StableRulesClassifier(; rng=StableRNG(1), q=4, max_depth=2, lambda=1, max_rules=8);
 
 # ╔═╡ 97c9ea2a-2897-472b-b15e-215f40049cf5
 md"""
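
The advice to try the full `lambda` range from 10^-4 to 10^4 amounts to a log-spaced grid search. A sketch of that search, where `fit_and_score` is a hypothetical stand-in for an MLJ-style cross-validated evaluation, not part of SIRUS.jl:

```python
# Log-spaced lambda grid covering the advised range 10^-4 .. 10^4.
lambdas = [10.0 ** k for k in range(-4, 5)]

def fit_and_score(lam):
    # hypothetical stand-in for a cross-validated score at this
    # regularization strength; here lambda = 1 scores best, matching
    # the value chosen in this example
    return -abs(lam - 1.0)

best_lambda = max(lambdas, key=fit_and_score)
```

In practice, each grid point would refit the model and score it with cross-validation before picking the best value.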
13 changes: 8 additions & 5 deletions docs/src/index.md
@@ -7,22 +7,25 @@ Pretty visual representation of the algorithm via an image that was generated via
 This package is a pure Julia implementation of the **S**table and **I**nterpretable **RU**le **S**ets (SIRUS) algorithm.
 The algorithm was originally created by Clément Bénard, Gérard Biau, Sébastien Da Veiga, and Erwan Scornet (Bénard et al., [2021](http://proceedings.mlr.press/v130/benard21a.html)).
 `SIRUS.jl` has implemented both classification and regression.
-Performance is generally the best on binary classification tasks.
+Performance is generally best on classification tasks, especially when tuning the `max_depth`, `max_rules`, and `lambda` hyperparameters.
 
 For R users, the original version of the SIRUS algorithm is available via [CRAN](https://cran.r-project.org/web/packages/sirus/index.html).
 The source code is available under the GPL-3 license at <https://gitlab.com/drti/sirus>.
-Compared to the R version, this Julia implementation is more easy to inspect than the original R and C++ implementation.
+Compared to the R version, this Julia implementation adds multi-class classification and is easier to inspect than the original R and C++ implementation.
 The original algorithm is implemented in about 10k lines of C++ and 2k lines of R code, whereas the Julia implementation is about 2k lines of pure Julia code.
 Furthermore, this implementation is integrated with the `MLJ.jl` machine learning ecosystem.
 With this, multiple benchmarks are executed and checked with every test run.
 The results are listed in the GitHub Actions summary.
 
 The algorithm is based on random forests.
 Random forests generally perform very well, especially on datasets with a relatively high number of features compared to the number of datapoints (Biau & Scornet, [2016](https://doi.org/10.1007/s11749-016-0481-7)).
-However, random forests are hard to interpret because of the large number of, sometimes large, trees.
+However, random forests are hard to interpret because of the typically thousands of trees in the forest.
 Interpretability methods such as SHAP alleviate this problem slightly, but still do not fully explain predictions.
-SIRUS solved this by converting the large number of trees to interpretable rules.
-These rules fully explain the predictions while remaining easy to interpret.
+Put differently, it is not possible to reproduce predictions from the feature importances that SHAP reports.
+Also, interpretability methods convert the complex model to a simplified representation.
+This causes the simplified representation to differ from the complex model and may therefore hide biases and issues related to safety and reliability.
+SIRUS solves this by simplifying the complex model and then using the simplified model for predictions.
+This ensures that the same model is used for interpretation and prediction.
 
 ## Where to Start?
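
The conversion of many shallow trees into a rule set can be illustrated with a toy sketch of the core idea (not the SIRUS.jl implementation): split paths from many depth-limited trees are counted, and only paths that recur across trees survive as rules. The paths below are made up:

```python
from collections import Counter

# Made-up root-to-node paths from many depth<=2 trees, each path a tuple
# of (feature, threshold, direction) conditions.
paths = [
    (("x1", 0.5, "<"),),
    (("x1", 0.5, "<"), ("x2", 1.2, ">=")),
    (("x1", 0.5, "<"),),
    (("x3", 2.0, "<"),),
]
counts = Counter(paths)
# keep only paths that occur often enough; these become the rule set
stable_rules = [rule for rule, count in counts.items() if count >= 2]
```

Because the surviving rules are few and short, predictions built from them remain human-readable, which is the point the text makes about interpretation and prediction using the same model.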
2 changes: 1 addition & 1 deletion src/forest.jl
@@ -378,7 +378,7 @@ function _predict(forest::StableForest, row::AbstractVector)
     if forest.algo isa Classification
         return _median(predictions)
     else
-        m = median(predictions)
+        m = mean(predictions)
         @assert m isa Number
         return m
     end
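
This hunk changes how regression forests aggregate per-tree predictions: from the median to the mean. A minimal illustration of the numerical difference, with made-up per-tree predictions:

```python
import statistics

# Made-up per-tree predictions for a single row.
predictions = [22.1, 23.4, 21.8, 25.0]
old_aggregate = statistics.median(predictions)  # previous behavior
new_aggregate = statistics.mean(predictions)    # behavior after this change
```

The mean matches how standard regression forests average their trees, while the median discards information from outlying trees.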
2 changes: 2 additions & 0 deletions src/mlj.jl
@@ -310,6 +310,8 @@ const RULES_HYPERPARAMETERS_SECTION = """
 - `lambda::Float64=$LAMBDA_DEFAULT`:
   The weights of the final rules are determined via a regularized regression over each rule as a binary feature.
   This hyperparameter specifies the strength of the ridge (L2) regularizer.
+  SIRUS is very sensitive to the choice of this hyperparameter.
+  Ensure that you try the full range from 10^-4 to 10^4 (e.g., 0.001, 0.01, ..., 100).
 """
 
 const OPERATIONS_SECTION = """
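
The docstring describes fitting rule weights with a ridge regression over binary rule activations. A closed-form sketch of that general technique (an assumption about the method, not the SIRUS.jl internals), showing why `lambda` matters: larger values shrink the weights toward zero:

```python
import numpy as np

rng = np.random.default_rng(0)
# Binary design matrix: rows are samples, columns are "rule fired" flags.
X = (rng.random((100, 5)) > 0.5).astype(float)
y = X @ np.array([0.5, -0.2, 0.8, 0.0, 0.3]) + 0.01 * rng.standard_normal(100)

def ridge_weights(X, y, lam):
    # closed-form ridge solution: (X'X + lam*I)^-1 X'y
    p = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(p), X.T @ y)

w_small = ridge_weights(X, y, 1e-4)  # nearly unregularized
w_large = ridge_weights(X, y, 1e4)   # heavily shrunk toward zero
```

At the extremes of the advised range, the fitted weight vectors differ drastically, which is consistent with the warning that SIRUS is very sensitive to this hyperparameter.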