forked from Yancey1989/gotorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist.go
60 lines (53 loc) · 1.52 KB
/
mnist.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
package main
import (
torch "github.com/wangkuiyi/gotorch"
F "github.com/wangkuiyi/gotorch/nn/functional"
"github.com/wangkuiyi/gotorch/nn/initializer"
"github.com/wangkuiyi/gotorch/vision/datasets"
"github.com/wangkuiyi/gotorch/vision/models"
"github.com/wangkuiyi/gotorch/vision/transforms"
"log"
"time"
)
// main trains the models.MLP network on the MNIST dataset with SGD and an
// NLL loss for a fixed number of epochs. After each epoch it logs the loss of
// the last batch seen, and at the end it logs an estimated throughput in
// samples per second.
func main() {
	// mnistTrainSize is the number of samples assumed to be produced per
	// epoch; the throughput estimate below is only accurate if the loader
	// actually yields exactly this many samples each epoch.
	const (
		epochs         = 2
		batchSize      = 64
		mnistTrainSize = 60000
	)

	deviceName := "cpu"
	if torch.IsCUDAAvailable() {
		log.Println("CUDA is valid")
		deviceName = "cuda"
	} else {
		log.Println("No CUDA found; CPU only")
	}
	device := torch.NewDevice(deviceName)

	// Fixed seed so repeated runs start from the same random state.
	initializer.ManualSeed(1)

	// NOTE(review): 0.1307 / 0.3081 look like the commonly used MNIST
	// mean / std for normalization — confirm against transforms.Normalize.
	mnist := datasets.MNIST("",
		[]transforms.Transform{transforms.Normalize(0.1307, 0.3081)})

	net := models.MLP()
	net.ZeroGrad()
	net.To(device)

	// NOTE(review): positional arguments presumably mirror PyTorch's SGD
	// (lr, momentum, dampening, weight decay, nesterov) — confirm against
	// the torch.SGD signature.
	opt := torch.SGD(0.01, 0.5, 0, 0, false)
	opt.AddParameters(net.Parameters())

	startTime := time.Now()
	var lastLoss float32
	for epoch := 0; epoch < epochs; epoch++ {
		// A fresh loader is created for every epoch and closed once the
		// epoch's batches are exhausted.
		trainLoader := datasets.NewMNISTLoader(mnist, batchSize)
		for trainLoader.Scan() {
			batch := trainLoader.Batch()
			// Move inputs and labels to the training device, keeping dtypes.
			data := batch.Data.To(device, batch.Data.Dtype())
			target := batch.Target.To(device, batch.Target.Dtype())

			opt.ZeroGrad()
			pred := net.Forward(data)
			// NOTE(review): the trailing arguments presumably are an empty
			// (no-op) class-weight tensor, ignore_index=-100 and "mean"
			// reduction, matching PyTorch's nll_loss defaults — confirm.
			loss := F.NllLoss(pred, target, torch.Tensor{}, -100, "mean")
			loss.Backward()
			opt.Step()
			lastLoss = loss.Item()
		}
		// lastLoss is the loss of the most recent batch, not an epoch average.
		log.Printf("Epoch: %d, Loss: %.4f", epoch, lastLoss)
		trainLoader.Close()
	}

	throughput := float64(mnistTrainSize*epochs) / time.Since(startTime).Seconds()
	log.Printf("Throughput: %f samples/sec", throughput)

	mnist.Close()
	torch.FinishGC()
}