-
Notifications
You must be signed in to change notification settings - Fork 3
/
products.py
62 lines (51 loc) · 2.21 KB
/
products.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
def vjp(y, x, v):
    """Return the vector-Jacobian product v^T J (the "Lop", left operation).

    Reverse-mode automatic differentiation yields this quantity directly.
    The autograd graph of ``y`` is retained (``retain_graph=True``), so the
    same graph can be differentiated again afterwards.

    Arguments:
        y (torch.tensor): output of the differentiated function
        x (torch.tensor): input with respect to which ``y`` is differentiated
        v (torch.tensor): vector multiplied with the Jacobian from the left

    Returns:
        tuple of torch.tensor, exactly as returned by ``torch.autograd.grad``.
    """
    return torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=v,
        retain_graph=True,
    )
def jvp(y, x, v):
    """Return the Jacobian-vector product J v (the "Rop", right operation).

    Forward-mode automatic differentiation yields this quantity directly.
    Here it is obtained with two reverse-mode passes instead (the
    "double-backward" trick):

    1. The VJP ``g(u) = J^T u`` is linear in ``u``.
    2. Its Jacobian with respect to ``u`` is therefore ``J``.
    3. A second VJP of ``g`` against ``v`` gives ``J v``.

    Arguments:
        y (torch.tensor): output of the differentiated function
        x (torch.tensor): input with respect to which ``y`` is differentiated
        v (torch.tensor): vector multiplied with the Jacobian from the right

    Returns:
        tuple of torch.tensor, as returned by ``torch.autograd.grad``.

    from: https://gist.github.com/apaszke/c7257ac04cb8debb82221764f6d117ad
    """
    # Placeholder cotangent. Its values do not affect the result because
    # J^T u is linear in u; it only needs requires_grad so we can
    # differentiate with respect to it.
    u = torch.ones_like(y, requires_grad=True)
    # First pass: build J^T u as a differentiable function of u.
    jt_u = torch.autograd.grad(
        outputs=y, inputs=x, grad_outputs=u, create_graph=True
    )
    # Second pass: differentiate J^T u w.r.t. u, contracting with v -> J v.
    return torch.autograd.grad(outputs=jt_u, inputs=u, grad_outputs=v)
def jvp_diff(y, x, v):
    """Return a differentiable Jacobian-vector product J v (the "Rop").

    Forward-mode automatic differentiation yields this quantity directly.
    Here it is obtained with two reverse-mode passes:

    1. The VJP ``g(u) = J^T u`` is linear in ``u``.
    2. Its Jacobian with respect to ``u`` is therefore ``J``.
    3. A second VJP of ``g`` against ``v`` gives ``J v``.

    Both passes record their graphs (``create_graph=True``), so the returned
    result can itself be differentiated further.

    Arguments:
        y (torch.tensor): output of the differentiated function
        x (torch.tensor): input with respect to which ``y`` is differentiated
        v (torch.tensor): vector multiplied with the Jacobian from the right

    Returns:
        tuple of torch.tensor, as returned by ``torch.autograd.grad``.

    from: https://gist.github.com/apaszke/c7257ac04cb8debb82221764f6d117ad
    """
    # Placeholder cotangent. Its values do not affect the result because
    # J^T u is linear in u.
    u = torch.ones_like(y, requires_grad=True)
    # First pass: J^T u, kept differentiable with respect to u.
    jt_u = torch.autograd.grad(
        outputs=y, inputs=x, grad_outputs=u, create_graph=True
    )
    # Second pass: J v. The graph is recorded again so the caller can
    # differentiate the product.
    return torch.autograd.grad(
        outputs=jt_u, inputs=u, grad_outputs=v, create_graph=True
    )
def unflatten_like(vector, tensor_lst):
    """Split a flat tensor into views shaped like the tensors in ``tensor_lst``.

    Consecutive chunks of ``vector`` are reshaped with ``view`` to the shapes
    of the tensors in ``tensor_lst``, in order. The results therefore share
    storage with ``vector``.

    Arguments:
        vector (torch.tensor): flat one-dimensional tensor
        tensor_lst (list or iterable): tensors whose shapes define the chunks.
            Their total number of elements must equal the number of elements
            of ``vector``. Any iterable is accepted, including a one-shot
            generator.

    Returns:
        list of torch.tensor: one view of ``vector`` per entry of
        ``tensor_lst``.

    Raises:
        ValueError: if the total number of elements in ``tensor_lst`` differs
            from the number of elements in ``vector``.
    """
    # Materialize once so a generator can be used for both the size check
    # and the split.
    templates = list(tensor_lst)
    sizes = [template.numel() for template in templates]
    total = sum(sizes)
    # Enforce the documented contract. Otherwise extra trailing elements are
    # silently dropped, and a short vector fails later with an opaque error
    # from view().
    if total != vector.numel():
        raise ValueError(
            f"vector has {vector.numel()} elements but tensor_lst "
            f"requires {total}"
        )
    out_list = []
    offset = 0
    for template, n in zip(templates, sizes):
        out_list.append(vector[offset:offset + n].view(template.shape))
        offset += n
    return out_list