forked from BUPT-GAMMA/OpenHGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DMGI.py
401 lines (316 loc) · 14.4 KB
/
DMGI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import dgl
from . import register_model, BaseModel
import torch.nn as nn
import numpy as np
import dgl.nn.pytorch as dglnn
import torch
import torch.nn.functional as F
@register_model('DMGI')
class DMGI(BaseModel):
    r"""
    Description
    -----------
    **Title:** Unsupervised Attributed Multiplex Network Embedding

    **Authors:** Chanyoung Park, Donghyun Kim, Jiawei Han, Hwanjo Yu

    DMGI was introduced in `[paper] <https://ojs.aaai.org//index.php/AAAI/article/view/5985>`_
    and parameters are defined as follows:

    Input
    ------
    meta_paths : list
        Metapaths extracted from the graph; each entry is a list of canonical
        edge types passed to ``dgl.metapath_reachable_graph``.
    sc : int
        Weight of the self-connections: ``sc`` self-loops are added to every
        metapath-reachable graph.
    category : string
        The node type whose nodes are embedded / classified.
    in_size : int
        Input feature size.
    hidden_dim : int
        Hidden units size.
    nheads : int
        Number of attention heads (only used when ``isAttn`` is True).
    dropout : float
        Dropout rate applied to the input features.
    num_nodes : int
        The number of nodes of type ``category`` in the graph.
    num_classes : int
        The number of classes of ``category``.
    isBias : bool
        If True, the GCN encoders use a learnable bias.
    isAttn : bool
        If True, combine the metapath-specific embeddings with attention
        instead of a plain average.
    isSemi : bool
        If True, ``forward`` also returns class logits for the
        semi-supervised loss.

    Parameters
    ----------
    H : torch.nn.Parameter
        The learnable consensus embedding, shape ``(1, num_nodes, hidden_dim)``.
    gcn : one single-layer GCN encoder per metapath:

        .. math::
            \begin{equation}
            \mathbf{H}^{(r)}=g_{r}\left(\mathbf{X}, \mathbf{A}^{(r)} \mid \mathbf{W}^{(r)}\right)=\sigma\left(\hat{\mathbf{D}}_{r}^{-\frac{1}{2}} \hat{\mathbf{A}}^{(r)} \hat{\mathbf{D}}_{r}^{-\frac{1}{2}} \mathbf{X} \mathbf{W}^{(r)}\right)
            \end{equation}

        where :math:`\hat{\mathbf{A}}^{(r)}=\mathbf{A}^{(r)}+w \mathbf{I}_{n}` ,
        :math:`\hat{D}_{i i}=\sum_{j} \hat{A}_{i j}`
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Build a DMGI model from the config ``args`` and the heterograph ``hg``.

        Length-two metapaths ``category -> X -> category`` are enumerated from
        every canonical edge type leaving ``args.category`` towards a different
        node type ``X``, paired with every edge type going from ``X`` back.
        """
        etypes = hg.canonical_etypes
        mps = []
        for etype in etypes:
            src, _, dst = etype
            # Start from the target category only, and skip self-relations.
            if src != args.category or src == dst:
                continue
            for dst_e in etypes:
                # The second hop must close the loop: dst -> ... -> src.
                if dst_e[0] == dst and dst_e[2] == src:
                    mps.append([etype, dst_e])
        num_nodes = hg.num_nodes(args.category)
        return cls(meta_paths=mps, sc=args.sc,
                   category=args.category, in_size=args.in_dim,
                   hidden_dim=args.hidden_dim, nheads=args.num_heads, dropout=args.dropout,
                   num_nodes=num_nodes, num_classes=args.num_classes,
                   isSemi=args.isSemi, isAttn=args.isAttn, isBias=args.isBias)

    def __init__(self, meta_paths, sc, category, in_size, hidden_dim, nheads,
                 dropout, num_nodes, num_classes, isBias, isAttn, isSemi):
        super(DMGI, self).__init__()
        self.category = category
        self.hidden_dim = hidden_dim
        self.meta_paths = meta_paths
        self.nheads = nheads
        self.isAttn = isAttn
        self.isSemi = isSemi
        self.sc = sc
        # One GCN encoder per metapath.
        self.gcn = nn.ModuleList([dglnn.GraphConv(in_feats=in_size,
                                                  out_feats=hidden_dim,
                                                  activation=nn.ReLU(),
                                                  bias=isBias,
                                                  allow_zero_in_degree=True)
                                  for _ in range(len(meta_paths))])
        self.disc = Discriminator(hidden_dim)
        self.readout = AvgReadout()
        self.readout_act_func = nn.Sigmoid()
        self.dropout = dropout
        self.num_nodes = num_nodes
        # Consensus embedding Z shared across all metapaths.
        self.H = nn.Parameter(torch.FloatTensor(1, num_nodes, hidden_dim))
        self.logistic = LogReg(hidden_dim, num_classes)
        if self.isAttn:
            # Keyword names must match Attention.__init__ (hidden_dim, num_mps,
            # num_ndoes); the previous `hid_units=` keyword raised TypeError.
            self.attn = nn.ModuleList([Attention(hidden_dim=hidden_dim,
                                                 num_mps=len(meta_paths),
                                                 num_ndoes=num_nodes)
                                       for _ in range(nheads)])
        self.init_weight()
        print("category:{}, category's classes:{}, isBias:{},"
              " isAttn:{}, isSemi:{}".format(category, num_classes, isBias, isAttn, isSemi))

    def init_weight(self):
        """Initialize the consensus embedding ``H`` with Xavier-normal values."""
        nn.init.xavier_normal_(self.H)

    def forward(self, hg, samp_bias1=None, samp_bias2=None):
        r"""
        The forward part of DMGI.

        Parameters
        ----------
        hg : object
            The dgl heterogeneous graph. ``hg.srcdata['h']`` must be indexable
            by ``self.category`` (presumably a dict of node type -> features).
        samp_bias1, samp_bias2 : torch.Tensor, optional
            Added in place to the positive / negative discriminator scores.

        Returns
        -------
        dict
            ``'logits'``: list with one discriminator score tensor per metapath;
            ``'reg_loss'``: the consensus regularization term;
            ``'semi'``: class logits of the ``category`` nodes (only if ``isSemi``).

        The relation-type specific cross entropy :math:`\mathcal{L}^{(r)}` is

        .. math::
            \begin{equation}
            \mathcal{L}^{(r)}=\sum_{v_{i} \in \mathcal{V}}^{n} \log \mathcal{D}\left(\mathbf{h}_{i}^{(r)}, \mathbf{s}^{(r)}\right)+\sum_{j=1}^{n} \log \left(1-\mathcal{D}\left(\tilde{\mathbf{h}}_{j}^{(r)}, \mathbf{s}^{(r)}\right)\right)
            \end{equation}

        where :math:`h_{i}^{(r)}` is computed by :math:`\mathbf{h}_{i}=\sigma\left(\sum_{j \in N(i)} \frac{1}{c_{i j}} \mathbf{x}_{j} \mathbf{W}\right)` ,
        :math:`s^{(r)}` is :math:`\mathbf{s}^{(r)}=\operatorname{Readout}\left(\mathbf{H}^{(r)}\right)=\sigma\left(\frac{1}{n} \sum_{i=1}^{n} \mathbf{h}_{i}^{(r)}\right)` ,
        :math:`\mathcal{D}` is a discriminator that scores patch-summary representation pairs,
        and :math:`\tilde{\mathbf{h}}_{j}^{(r)}` is obtained from a corrupted attribute
        matrix whose rows have been shuffled.
        """
        h_1_all = []
        h_2_all = []
        c_all = []
        logits = []
        result = {}
        # Row-normalized features of the target category, one entry per metapath.
        features = hg.srcdata['h']
        feats = self.normal_feat(features, self.meta_paths)
        # Corrupted (row-shuffled) features used as negative samples.
        shuf_feats = self.shuf_feats(feats)
        for idx, meta_path in enumerate(self.meta_paths):
            new_g = dgl.metapath_reachable_graph(hg, meta_path)
            # Weight the self-connections by adding `sc` self-loops.
            for _ in range(self.sc):
                new_g = dgl.add_self_loop(new_g)
            feats[idx] = F.dropout(feats[idx], self.dropout, training=self.training)
            shuf_feats[idx] = F.dropout(shuf_feats[idx], self.dropout, training=self.training)
            h_1 = self.gcn[idx](new_g, feats[idx])
            # Summary vector s^(r) = sigmoid(mean of node embeddings).
            c = self.readout_act_func(self.readout(h_1))
            h_2 = self.gcn[idx](new_g, shuf_feats[idx])
            logit = self.disc(c, h_1, h_2, samp_bias1, samp_bias2)
            h_1_all.append(h_1.unsqueeze(0))
            h_2_all.append(h_2.unsqueeze(0))
            c_all.append(c)
            logits.append(logit)
        result['logits'] = logits

        if self.isAttn:
            # Attention-based aggregation:
            #   h_i = Q({h^(r) | r in R}) = sum_r a_i^(r) h^(r)
            # where a_i^(r) is the importance of relation r in generating the
            # final embedding of node v_i:
            #   a_i^(r) = exp(q^(r) . h_i^(r)) / sum_{r'} exp(q^(r') . h_i^(r'))
            # Each head produces one aggregate; the heads are then averaged.
            h_1_all_lst = []
            h_2_all_lst = []
            for h_idx in range(self.nheads):
                h_1_all_, h_2_all_, _ = self.attn[h_idx](h_1_all, h_2_all, c_all)
                h_1_all_lst.append(h_1_all_)
                h_2_all_lst.append(h_2_all_)
            h_1_all = torch.mean(torch.cat(h_1_all_lst, 0), 0).unsqueeze(0)
            h_2_all = torch.mean(torch.cat(h_2_all_lst, 0), 0).unsqueeze(0)
        else:
            # Plain average of the metapath-specific embeddings.
            h_1_all = torch.mean(torch.cat(h_1_all, 0), 0).unsqueeze(0)
            h_2_all = torch.mean(torch.cat(h_2_all, 0), 0).unsqueeze(0)

        # Consensus regularizer:
        #   L_cs = [Z - AVG{H(r)}]^2 - [Z - AVG{~H(r)}]^2
        pos_reg_loss = ((self.H - h_1_all) ** 2).sum()
        neg_reg_loss = ((self.H - h_2_all) ** 2).sum()
        result['reg_loss'] = pos_reg_loss - neg_reg_loss

        if self.isSemi:
            # Semi-supervised extension: logits for the cross-entropy loss
            #   l_sup = -1/|Y_L| sum_{l in Y_L} sum_{i=1}^{c} Y_li ln(Yhat_li)
            # where Y_L is the set of node indices with labels.
            result['semi'] = self.logistic(self.H).squeeze(0)
        return result

    def normal_feat(self, feats, meta_paths):
        """Row-normalize the features of ``self.category`` (one copy per metapath).

        Parameters
        ----------
        feats : mapping
            Node type -> feature matrix; ``feats[self.category]`` is used.
        meta_paths : sequence
            Only its length matters: one entry is returned per metapath.

        Returns
        -------
        list of torch.Tensor
            ``len(meta_paths)`` references to the same detached, row-normalized
            matrix. Rows whose sum is zero are left as zeros. Callers must not
            modify the entries in place (``forward`` only reassigns them).
        """
        x = feats[self.category].data
        r_inv = torch.pow(x.sum(1), -1).flatten()
        r_inv[torch.isinf(r_inv)] = 0.
        # Equivalent to diag(r_inv) @ x, but uses broadcasting instead of
        # materializing an N x N diagonal matrix. Row normalization is
        # idempotent, so it is computed once rather than once per metapath.
        normalized = x * r_inv.unsqueeze(1)
        return [normalized for _ in meta_paths]

    def shuf_feats(self, feats):
        """Corrupt each feature matrix by randomly permuting its rows."""
        shuf_feats = []
        for feat in feats:
            idx = np.random.permutation(feat.shape[0])
            shuf_feats.append(feat[idx])
        return shuf_feats
# In the experiments, some relation types are more beneficial for a
# certain downstream task than others. Therefore, the attention mechanism
# can be adopted to weight them.
class Attention(nn.Module):
    r"""Per-node attention over metapath-specific embeddings.

    Parameters
    ----------
    hidden_dim : int
        Size of each metapath-specific embedding.
    num_mps : int
        Number of metapaths; one ``nn.Linear(hidden_dim, 1)`` scorer each.
    num_ndoes : int
        Number of nodes (stored as ``self.num_nodes``).
    hid_units : int, optional
        Keyword alias for ``hidden_dim``.
    """

    def __init__(self, hidden_dim=None, num_mps=None, num_ndoes=None, hid_units=None):
        super(Attention, self).__init__()
        if hidden_dim is None:
            hidden_dim = hid_units
        elif hid_units is not None and hid_units != hidden_dim:
            raise ValueError("hidden_dim and hid_units disagree: {} != {}".format(hidden_dim, hid_units))
        if hidden_dim is None or num_mps is None or num_ndoes is None:
            raise TypeError("Attention requires hidden_dim (or hid_units), num_mps and num_ndoes")
        self.num_mps = num_mps
        self.hidden_dim = hidden_dim
        # Kept as an alias: earlier code referred to `self.hid_units`.
        self.hid_units = hidden_dim
        self.num_nodes = num_ndoes
        self.A = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(num_mps)])
        self.weight_init()

    def weight_init(self):
        """Xavier-normal weights and zero biases for every scorer."""
        for i in range(self.num_mps):
            nn.init.xavier_normal_(self.A[i].weight)
            self.A[i].bias.data.fill_(0.0)

    def forward(self, feat_pos, feat_neg, summary):
        """Aggregate positive embeddings, negative embeddings and summaries.

        Returns the three aggregates; the attention weights are discarded.
        """
        feat_pos, _ = self.attn_feature(feat_pos)
        feat_neg, _ = self.attn_feature(feat_neg)
        summary, _ = self.attn_summary(summary)
        return feat_pos, feat_neg, summary

    def attn_feature(self, features):
        """Combine per-metapath node embeddings with per-node softmax weights.

        Parameters
        ----------
        features : list of torch.Tensor
            ``num_mps`` tensors, each reshapeable to ``(N, hidden_dim)``
            (as called by DMGI they are ``(1, N, hidden_dim)``).

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The aggregate of shape ``(1, N, hidden_dim)``, and the attention
            weights of shape ``(N, num_mps)`` (rows sum to 1).
        """
        if len(features) != self.num_mps:
            raise ValueError("expected {} metapath embeddings, got {}".format(self.num_mps, len(features)))
        flat = [feat.reshape(-1, self.hidden_dim) for feat in features]
        scores = torch.cat([scorer(x) for scorer, x in zip(self.A, flat)], 1)  # (N, M)
        features_attn = F.softmax(scores, -1)
        stacked = torch.stack(flat, 0)                            # (M, N, H)
        weights = features_attn.transpose(1, 0).unsqueeze(-1)     # (M, N, 1)
        aggregated = (stacked * weights).sum(0).unsqueeze(0)      # (1, N, H)
        return aggregated, features_attn

    def attn_summary(self, features):
        """Combine per-metapath summary vectors with softmax weights.

        Returns the aggregate of shape ``(1, hidden_dim)`` when each summary is
        a ``(hidden_dim,)`` vector, and the weights of shape ``(num_mps, 1)``.
        """
        features_attn = []
        for i in range(self.num_mps):
            features_attn.append(self.A[i](features[i].squeeze()))
        features_attn = F.softmax(torch.cat(features_attn), dim=-1).unsqueeze(1)
        features = torch.stack(features, 0)
        features_attn_expanded = features_attn.expand_as(features)
        features = (features * features_attn_expanded).sum(0).unsqueeze(0)
        return features, features_attn
# D is a discriminator that scores patch-summary representation pairs.
# The paper applies a simple bilinear scoring function, as it empirically
# performed best in its experiments.
class Discriminator(nn.Module):
    r"""Bilinear scorer for (node embedding, summary) pairs.

    .. math::
        \begin{equation}
        \mathcal{D}\left(\mathbf{h}_{i}^{(r)}, \mathbf{s}^{(r)}\right)=\sigma\left(\mathbf{h}_{i}^{(r) T} \mathbf{M}^{(r)} \mathbf{s}^{(r)}\right)
        \end{equation}

    where :math:`M^{(r)}` is a trainable scoring matrix. ``forward`` returns the
    raw bilinear scores; no sigmoid is applied inside this module.
    """

    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k_bilinear = nn.Bilinear(n_h, n_h, 1)
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform weights and zero bias for ``nn.Bilinear`` modules."""
        if not isinstance(m, nn.Bilinear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        """Score positive rows ``h_pl`` and negative rows ``h_mi`` against ``c``.

        Returns the positive scores followed by the negative scores,
        concatenated along dimension 0. Optional biases are added in place.
        """
        summary = c.expand_as(h_pl)
        pos_score = self.f_k_bilinear(h_pl, summary).squeeze(1)
        neg_score = self.f_k_bilinear(h_mi, summary).squeeze(1)
        if s_bias1 is not None:
            pos_score += s_bias1
        if s_bias2 is not None:
            neg_score += s_bias2
        return torch.cat((pos_score, neg_score), 0)
# Considering the efficiency of the method, we simply employ average pooling.
class AvgReadout(nn.Module):
    r"""Average-pooling readout: the mean of ``seq`` along dimension 0.

    Applied to a node-embedding matrix :math:`\mathbf{H}^{(r)}` this yields
    :math:`\frac{1}{n} \sum_{i=1}^{n} \mathbf{h}_{i}^{(r)}`; any activation
    (e.g. the sigmoid of the summary readout) is left to the caller.
    """

    def __init__(self):
        super(AvgReadout, self).__init__()

    def forward(self, seq):
        return seq.mean(dim=0)
# Logistic-regression (linear) classification head.
class LogReg(nn.Module):
    r"""Single linear layer producing unnormalized class scores.

    Parameters
    ----------
    ft_in : int
        Size of the input embedding (hidden units).
    nb_classes : int
        The number of classes of the target category.
    """

    def __init__(self, ft_in, nb_classes):
        super(LogReg, self).__init__()
        self.fc = nn.Linear(ft_in, nb_classes)
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform weights and zero bias for ``nn.Linear`` modules."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, seq):
        return self.fc(seq)