-
Notifications
You must be signed in to change notification settings - Fork 73
/
xgblinear_io.go
59 lines (51 loc) · 1.64 KB
/
xgblinear_io.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package leaves
import (
"bufio"
"fmt"
"os"
"github.com/dmitryikh/leaves/internal/xgbin"
"github.com/dmitryikh/leaves/transformation"
)
// XGBLinearFromReader loads an XGBoost 'gblinear' model from `reader` and
// wraps it into an Ensemble.
//
// The model header is read first: an error is returned when the booster name
// stored in the header is not "gblinear" or when the header declares zero
// features. Errors produced by the underlying xgbin readers are returned as is.
//
// If loadTransformation is false, the ensemble is paired with the raw
// transformation built from the number of output groups. If it is true, the
// objective name from the header selects the transformation; only
// "binary:logistic" is recognized, any other objective results in an error.
func XGBLinearFromReader(reader *bufio.Reader, loadTransformation bool) (*Ensemble, error) {
	// The header carries the booster name, the objective name and the base score.
	hdr, err := xgbin.ReadModelHeader(reader)
	if err != nil {
		return nil, err
	}
	switch {
	case hdr.NameGbm != "gblinear":
		return nil, fmt.Errorf("only gblinear is supported (got %s). Use XGEnsembleFrom.. for gbtree", hdr.NameGbm)
	case hdr.Param.NumFeatures == 0:
		return nil, fmt.Errorf("zero number of features")
	}

	// The linear booster section (its parameters and weights) follows the header.
	linear, err := xgbin.ReadGBLinearModel(reader)
	if err != nil {
		return nil, err
	}

	model := new(xgLinear)
	model.BaseScore = float64(hdr.Param.BaseScore)
	model.nRawOutputGroups = int(linear.Param.NumOutputGroup)
	model.NumFeature = int(linear.Param.NumFeature)
	model.Weights = linear.Weights

	// Select the output transformation attached to the ensemble.
	var transform transformation.Transform
	switch {
	case !loadTransformation:
		transform = &transformation.TransformRaw{model.nRawOutputGroups}
	case hdr.NameObj == "binary:logistic":
		transform = &transformation.TransformLogistic{}
	default:
		return nil, fmt.Errorf("unknown transformation function '%s'", hdr.NameObj)
	}
	return &Ensemble{model, transform}, nil
}
// XGBLinearFromFile loads an XGBoost 'gblinear' model from the binary file
// `filename`.
//
// The file is opened, wrapped into a buffered reader and handed over to
// XGBLinearFromReader together with loadTransformation. The error from opening
// the file, if any, is returned unchanged. The file is closed when the
// function returns.
func XGBLinearFromFile(filename string, loadTransformation bool) (*Ensemble, error) {
	file, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	return XGBLinearFromReader(bufio.NewReader(file), loadTransformation)
}