-
Notifications
You must be signed in to change notification settings - Fork 12
/
agd.py
51 lines (37 loc) · 1.38 KB
/
agd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import math
import torch
from torch.optim.optimizer import Optimizer
from torch.nn.init import orthogonal_
def singular_value(p: torch.Tensor) -> float:
    """Return the scale factor associated with a weight tensor's shape.

    The value is ``sqrt(p.shape[0] / p.shape[1])``. When ``p`` has exactly
    four dimensions it is further divided by
    ``sqrt(p.shape[2] * p.shape[3])``.

    Args:
        p: Tensor with at least two dimensions. Only ``p.shape`` and
            ``p.dim()`` are read; the tensor's values are never touched.

    Returns:
        The scale as a plain Python float.

    Raises:
        IndexError: If ``p`` has fewer than two dimensions.
        ZeroDivisionError: If ``p.shape[1]`` is zero.
    """
    rows, cols = p.shape[0], p.shape[1]
    sv = math.sqrt(rows / cols)
    if p.dim() == 4:
        # presumably the two trailing axes are a kernel's spatial extent
        # (e.g. a 2-D convolution weight) — TODO confirm against callers.
        sv /= math.sqrt(p.shape[2] * p.shape[3])
    return sv
class AGD(Optimizer):
    """Optimizer whose per-parameter step size is derived from gradient norms.

    Each update direction is the parameter's gradient divided by its norm
    over the first two axes and scaled by ``singular_value(p)``. The overall
    step length is ``gain * log / depth``, where ``log`` is computed in
    :meth:`step` from a weighted sum of all gradient norms.

    Attributes:
        net: The module whose parameters are optimised.
        depth: Number of parameter tensors in ``net``.
        gain: Multiplier applied to every update.
    """

    def __init__(self, net, gain=1.0):
        """Register all parameters of ``net`` with the optimizer.

        Args:
            net: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
            gain: Multiplier applied to every update in :meth:`step`.

        Raises:
            ValueError: If any parameter is one-dimensional.
        """
        params = list(net.parameters())
        for p in params:
            if p.dim() == 1:
                # ValueError is still caught by callers that handled the
                # plain Exception previously raised here.
                raise ValueError("Biases are not supported.")

        self.net = net
        self.depth = len(params)
        self.gain = gain
        super().__init__(params, defaults={})

    @torch.no_grad()
    def init_weights(self):
        """Re-initialise the parameters of ``self.net`` in place.

        Two-dimensional parameters are made orthogonal as a whole. For
        four-dimensional parameters every ``p[:, :, kx, ky]`` slice is made
        orthogonal separately. Afterwards each parameter is multiplied by
        ``singular_value(p)``.
        """
        for p in self.net.parameters():
            if p.dim() == 2:
                orthogonal_(p)
            elif p.dim() == 4:
                for kx in range(p.shape[2]):
                    for ky in range(p.shape[3]):
                        # The slice is a view, so the in-place init writes
                        # straight into ``p``.
                        orthogonal_(p[:, :, kx, ky])
            # NOTE(review): nesting was reconstructed from a source listing
            # that had lost its indentation; the rescale is taken to apply to
            # every parameter, not only the 4-D ones — confirm.
            p *= singular_value(p)

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that has a gradient.

        Parameters whose ``grad`` is ``None`` add nothing to the gradient
        summary and are left unchanged.

        Returns:
            The ``log`` factor used to scale this step, as a Python float.
        """
        # Compute each gradient norm once and reuse it in the update below.
        pending = []
        G = 0
        for p in self.net.parameters():
            if p.grad is None:
                continue
            sv = singular_value(p)
            grad_norm = p.grad.norm(dim=(0, 1), keepdim=True)
            G += sv * grad_norm.sum()
            pending.append((p, sv, grad_norm))
        G /= self.depth

        log = math.log(0.5 * (1 + math.sqrt(1 + 4 * G)))

        for p, sv, grad_norm in pending:
            # A zero norm makes the ratio infinite; nan_to_num turns that
            # into a finite number, and the all-zero gradient it multiplies
            # then yields a zero update instead of NaN.
            factor = sv / grad_norm
            p -= self.gain * log / self.depth * torch.nan_to_num(factor) * p.grad
        return log