# SMatirxLayer.py (forked from ayyyq/T-LSTM)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data
# Graph layer: fuses a learned semantic adjacency with the given dense adjacency.
class SentenceMatrixLayer(nn.Module):
    def __init__(self, in_size, out_size=1, p_Asem=0.6):
        super(SentenceMatrixLayer, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.p_Asem = p_Asem  # weight of the learned semantic adjacency in the fusion
        self.linear = nn.Linear(in_size * 2, out_size)  # scores the concatenated pair (x_i, x_j)

    def forward(self, x, adj, mask):
        # x:    [batch, seq_len, embed_dim]  token embeddings
        # adj:  [batch, seq_len, seq_len]    dense adjacency matrix
        # mask: [batch, seq_len]             1 for real tokens, 0 for padding
        seq_len = x.shape[1]
        # Broadcast every pair (x_i, x_j) so they can be concatenated along the last dim.
        xi = x.unsqueeze(2).expand(-1, -1, seq_len, -1)  # [batch, seq_len, seq_len, embed_dim]
        xj = x.unsqueeze(1).expand(-1, seq_len, -1, -1)  # [batch, seq_len, seq_len, embed_dim]
        # Pairwise semantic score: sigmoid(W [x_i; x_j]) -> [batch, seq_len, seq_len]
        xij = torch.sigmoid(self.linear(torch.cat((xi, xj), dim=-1))).squeeze(-1)
        # Fuse the learned semantic adjacency with the given adjacency.
        A_esm = self.p_Asem * xij + (1 - self.p_Asem) * adj
        assert mask.shape[1] == seq_len, "seq_len inconsistent"
        # Pairwise padding mask: entry (i, j) is 1 only if both tokens are real.
        mask_i = mask.unsqueeze(1).expand(-1, seq_len, -1)  # [batch, seq_len, seq_len]
        mask_j = mask.unsqueeze(2).expand(-1, -1, seq_len)  # [batch, seq_len, seq_len]
        A_mask = mask_i * mask_j
        # Replace masked entries with a small constant rather than zero.
        return A_esm.masked_fill(A_mask == 0, 1e-9)  # [batch, seq_len, seq_len]
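
## Usage sketch (not from the original repo): hypothetical sizes and random inputs,
## only to illustrate the expected input/output shapes of SentenceMatrixLayer.
# layer = SentenceMatrixLayer(in_size=16)
# x = torch.rand(2, 5, 16)                      # 2 sentences, 5 tokens, 16-dim embeddings
# adj = torch.randint(0, 2, (2, 5, 5)).float()  # dense 0/1 adjacency
# mask = torch.ones(2, 5)
# mask[1, 3:] = 0                               # second sentence padded after 3 tokens
# A = layer(x, adj, mask)
# print(A.shape)                                # torch.Size([2, 5, 5])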
##test
# edge_index = torch.tensor([[0, 1, 1, 2],
# [1, 0, 2, 1]], dtype=torch.long)
# x = torch.rand((3, 100))
# tri = torch.rand((1, 72))
# data = Data(x=x, edge_index=edge_index)
# device = torch.device('cuda')
# data = data.to(device)
# tri = tri.to(device)
# model = FRGN(100, 1)
# model.cuda()
# test = model(data)
# print(test)