diff --git a/src/mlj_catboostregressor.jl b/src/mlj_catboostregressor.jl index a6c083c..e79cfaf 100644 --- a/src/mlj_catboostregressor.jl +++ b/src/mlj_catboostregressor.jl @@ -44,7 +44,7 @@ MMI.@mlj_model mutable struct CatBoostRegressor <: MMI.Deterministic task_type::Union{String,Nothing} = nothing devices::Union{String,Nothing} = nothing bootstrap_type::Union{String,Nothing} = nothing - subsample::Union{Int,Nothing} = nothing + subsample::Union{Float64,Nothing} = nothing sampling_frequency::String = "PerTreeLevel"::(_ in ("PerTree", "PerTreeLevel")) sampling_unit::String = "Object"::(_ in ("Group", "Object")) gpu_cat_features_storage::String = "GpuRam"::(_ in ("CpuPinnedMemory", "GpuRam"))