Skip to content

Commit

Permalink
[*] use Ensemble interface for common test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitryikh committed Sep 19, 2018
1 parent 986d193 commit 7e8b4dc
Showing 1 changed file with 72 additions and 119 deletions.
191 changes: 72 additions & 119 deletions leaves_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,70 +57,51 @@ func InnerTestLGMSLTR(t *testing.T, nThreads int) {
}

func TestLGHiggs(t *testing.T) {
InnerTestLGHiggs(t, 1)
InnerTestLGHiggs(t, 2)
InnerTestLGHiggs(t, 3)
InnerTestLGHiggs(t, 4)
}

func InnerTestLGHiggs(t *testing.T, nThreads int) {
// loading test data
path := filepath.Join("testdata", "higgs_1000examples_test.libsvm")
reader, err := os.Open(path)
if err != nil {
t.Skipf("Skipping due to absence of %s", path)
}
bufReader := bufio.NewReader(reader)
csrMat, err := CSRMatFromLibsvm(bufReader, 0, true)
if err != nil {
t.Fatal(err)
}
nRows := csrMat.Rows()

filename := "lghiggs_1000examples_true_predictions.txt"
// loading model
path = filepath.Join("testdata", "lghiggs.model")
path := filepath.Join("testdata", "lghiggs.model")
model, err := LGEnsembleFromFile(path)
if err != nil {
t.Fatal(err)
}

// loading true predictions as DenseMat
path = filepath.Join("testdata", "lghiggs_1000examples_true_predictions.txt")
reader, err = os.Open(path)
if err != nil {
t.Skipf("Skipping due to absence of %s", path)
}
bufReader = bufio.NewReader(reader)
truePredictions, err := DenseMatFromCsv(bufReader, 0, false, ",", 0.0)
if err != nil {
t.Fatal(err)
}
const tolerance = 1e-12

predictions := make([]float64, nRows)
model.PredictCSR(csrMat.RowHeaders, csrMat.ColIndexes, csrMat.Values, predictions, 0, nThreads)

// compare results
if err := almostEqualFloat64Slices(truePredictions.Values, predictions, 1e-12); err != nil {
t.Fatalf("different predictions: %s", err.Error())
}
// Dense matrix
InnerTestHiggs(t, model, 1, true, filename, tolerance)
InnerTestHiggs(t, model, 2, true, filename, tolerance)
InnerTestHiggs(t, model, 3, true, filename, tolerance)
InnerTestHiggs(t, model, 4, true, filename, tolerance)

InnerTestHiggs(t, model, 1, false, filename, tolerance)
InnerTestHiggs(t, model, 2, false, filename, tolerance)
InnerTestHiggs(t, model, 3, false, filename, tolerance)
InnerTestHiggs(t, model, 4, false, filename, tolerance)
}

func TestXGHiggs(t *testing.T) {
t.Skip("have mismatch on 45 element")
filename := "xghiggs_1000examples_true_predictions.txt"
// loading model
path := filepath.Join("testdata", "xghiggs.model")
model, err := XGEnsembleFromFile(path)
if err != nil {
t.Skipf("Skipping due to absence of %s", path)
}
const tolerance = 1e-5

// Dense matrix
InnerTestXGHiggs(t, 1, true)
InnerTestXGHiggs(t, 2, true)
InnerTestXGHiggs(t, 3, true)
InnerTestXGHiggs(t, 4, true)

// CSR matrix
InnerTestXGHiggs(t, 1, false)
InnerTestXGHiggs(t, 2, false)
InnerTestXGHiggs(t, 3, false)
InnerTestXGHiggs(t, 4, false)
InnerTestHiggs(t, model, 1, true, filename, tolerance)
InnerTestHiggs(t, model, 2, true, filename, tolerance)
InnerTestHiggs(t, model, 3, true, filename, tolerance)
InnerTestHiggs(t, model, 4, true, filename, tolerance)

InnerTestHiggs(t, model, 1, false, filename, tolerance)
InnerTestHiggs(t, model, 2, false, filename, tolerance)
InnerTestHiggs(t, model, 3, false, filename, tolerance)
InnerTestHiggs(t, model, 4, false, filename, tolerance)
}

func InnerTestXGHiggs(t *testing.T, nThreads int, dense bool) {
func InnerTestHiggs(t *testing.T, model Ensemble, nThreads int, dense bool, truePredictionsFilename string, tolerance float64) {
// loading test data
path := filepath.Join("testdata", "higgs_1000examples_test.libsvm")
reader, err := os.Open(path)
Expand All @@ -145,15 +126,8 @@ func InnerTestXGHiggs(t *testing.T, nThreads int, dense bool) {
nRows = csrMat.Rows()
}

// loading model
path = filepath.Join("testdata", "xghiggs.model")
model, err := XGEnsembleFromFile(path)
if err != nil {
t.Fatal(err)
}

// loading true predictions as DenseMat
path = filepath.Join("testdata", "xghiggs_1000examples_true_predictions.txt")
path = filepath.Join("testdata", truePredictionsFilename)
reader, err = os.Open(path)
if err != nil {
t.Skipf("Skipping due to absence of %s", path)
Expand All @@ -171,7 +145,7 @@ func InnerTestXGHiggs(t *testing.T, nThreads int, dense bool) {
model.PredictCSR(csrMat.RowHeaders, csrMat.ColIndexes, csrMat.Values, predictions, 0, nThreads)
}
// compare results
if err := almostEqualFloat64Slices(truePredictions.Values, predictions, 1e-5); err != nil {
if err := almostEqualFloat64Slices(truePredictions.Values, predictions, tolerance); err != nil {
t.Fatalf("different predictions: %s", err.Error())
}
}
Expand Down Expand Up @@ -213,65 +187,35 @@ func InnerBenchmarkLGMSLTR(b *testing.B, nThreads int) {
}

func BenchmarkLGHiggs_dense_1thread(b *testing.B) {
InnerBenchmarkLGHiggs(b, 1, true)
model, err := LGEnsembleFromFile(filepath.Join("testdata", "lghiggs.model"))
if err != nil {
b.Fatal(err)
}
InnerBenchmarkHiggs(b, model, 1, true)
}

func BenchmarkLGHiggs_dense_4thread(b *testing.B) {
InnerBenchmarkLGHiggs(b, 4, true)
model, err := LGEnsembleFromFile(filepath.Join("testdata", "lghiggs.model"))
if err != nil {
b.Fatal(err)
}
InnerBenchmarkHiggs(b, model, 4, true)
}

func BenchmarkLGHiggs_csr_1thread(b *testing.B) {
InnerBenchmarkLGHiggs(b, 1, false)
}

func BenchmarkLGHiggs_csr_4thread(b *testing.B) {
InnerBenchmarkLGHiggs(b, 4, false)
}

func InnerBenchmarkLGHiggs(b *testing.B, nThreads int, dense bool) {
// loading test data
path := filepath.Join("testdata", "higgs_1000examples_test.libsvm")
reader, err := os.Open(path)
model, err := LGEnsembleFromFile(filepath.Join("testdata", "lghiggs.model"))
if err != nil {
b.Skipf("Skipping due to absence of %s", path)
}
bufReader := bufio.NewReader(reader)
var denseMat DenseMat
var csrMat CSRMat
var nRows uint32
if dense {
denseMat, err = DenseMatFromLibsvm(bufReader, 0, true)
if err != nil {
b.Fatal(err)
}
nRows = denseMat.Rows
} else {
csrMat, err = CSRMatFromLibsvm(bufReader, 0, true)
if err != nil {
b.Fatal(err)
}
nRows = csrMat.Rows()
b.Fatal(err)
}
InnerBenchmarkHiggs(b, model, 1, false)
}

// loading model
path = filepath.Join("testdata", "lghiggs.model")
model, err := LGEnsembleFromFile(path)
func BenchmarkLGHiggs_csr_4thread(b *testing.B) {
model, err := LGEnsembleFromFile(filepath.Join("testdata", "lghiggs.model"))
if err != nil {
b.Fatal(err)
}

// do benchmark
b.ResetTimer()
predictions := make([]float64, nRows)
if dense {
for i := 0; i < b.N; i++ {
model.PredictDense(denseMat.Values, denseMat.Rows, denseMat.Cols, predictions, 0, nThreads)
}
} else {
for i := 0; i < b.N; i++ {
model.PredictCSR(csrMat.RowHeaders, csrMat.ColIndexes, csrMat.Values, predictions, 0, nThreads)
}
}
InnerBenchmarkHiggs(b, model, 4, false)
}

func TestXGAgaricus_1thread(t *testing.T) {
Expand Down Expand Up @@ -327,22 +271,38 @@ func InnerTestXGAgaricus(t *testing.T, nThreads int) {
}

func BenchmarkXGHiggs_dense_1thread(b *testing.B) {
InnerBenchmarkXGHiggs(b, 1, true)
model, err := XGEnsembleFromFile(filepath.Join("testdata", "xghiggs.model"))
if err != nil {
b.Fatal(err)
}
InnerBenchmarkHiggs(b, model, 1, true)
}

func BenchmarkXGHiggs_dense_4thread(b *testing.B) {
InnerBenchmarkXGHiggs(b, 4, true)
model, err := XGEnsembleFromFile(filepath.Join("testdata", "xghiggs.model"))
if err != nil {
b.Fatal(err)
}
InnerBenchmarkHiggs(b, model, 4, true)
}

func BenchmarkXGHiggs_csr_1thread(b *testing.B) {
InnerBenchmarkXGHiggs(b, 1, false)
model, err := XGEnsembleFromFile(filepath.Join("testdata", "xghiggs.model"))
if err != nil {
b.Fatal(err)
}
InnerBenchmarkHiggs(b, model, 1, false)
}

func BenchmarkXGHiggs_csr_4thread(b *testing.B) {
InnerBenchmarkXGHiggs(b, 4, false)
model, err := XGEnsembleFromFile(filepath.Join("testdata", "xghiggs.model"))
if err != nil {
b.Fatal(err)
}
InnerBenchmarkHiggs(b, model, 4, false)
}

func InnerBenchmarkXGHiggs(b *testing.B, nThreads int, dense bool) {
func InnerBenchmarkHiggs(b *testing.B, model Ensemble, nThreads int, dense bool) {
// loading test data
path := filepath.Join("testdata", "higgs_1000examples_test.libsvm")
reader, err := os.Open(path)
Expand All @@ -367,13 +327,6 @@ func InnerBenchmarkXGHiggs(b *testing.B, nThreads int, dense bool) {
nRows = csrMat.Rows()
}

// loading model
path = filepath.Join("testdata", "xghiggs.model")
model, err := XGEnsembleFromFile(path)
if err != nil {
b.Fatal(err)
}

// do benchmark
b.ResetTimer()
predictions := make([]float64, nRows)
Expand Down

0 comments on commit 7e8b4dc

Please sign in to comment.