# guard.py
from typing import Union

import dgl
import torch
from torch import Tensor
from torch.utils.data import DataLoader

from graphattack import Surrogater
from graphattack.models import SGC, GCN
from graphattack.utils import remove_edges


class UniversalDefense(torch.nn.Module):
    """Base class for universal defenses that protect target nodes by
    cutting their edges to a fixed, precomputed set of anchor nodes."""

    def __init__(self, device: str = "cpu"):
        super().__init__()
        self.device = torch.device(device)
        self._anchors = None

    def forward(self, g: dgl.DGLGraph, target_nodes: Union[int, Tensor],
                k: int = 50, symmetric: bool = True) -> dgl.DGLGraph:
        """Return the defended graph, i.e., the clean graph with the
        defensive perturbation applied.

        Parameters
        ----------
        g : dgl.DGLGraph
            the graph on which the defensive perturbation is performed
        target_nodes : Union[int, Tensor]
            the target node(s) on which the defensive perturbation is performed
        k : int
            the number of anchor nodes in the defensive perturbation,
            by default 50
        symmetric : bool
            whether the resulting graph is forced to be symmetric,
            by default True

        Returns
        -------
        dgl.DGLGraph
            the defended graph with the defensive perturbation performed
            on the target nodes
        """
        edges = self.removed_edges(target_nodes, k)
        return remove_edges(g, edges, symmetric=symmetric)

    def removed_edges(self, target_nodes: Union[int, Tensor], k: int = 50) -> Tensor:
        """Return the edges to remove for the defensive perturbation
        on the target nodes.

        Parameters
        ----------
        target_nodes : Union[int, Tensor]
            the target node(s) on which the defensive perturbation is performed
        k : int
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor, shape [2, num_targets * k]
            the edges to remove, one (target, anchor) pair per column
        """
        row = torch.as_tensor(target_nodes, device=self.device).view(-1)
        col = self.anchors(k)
        # Pair every target node with every anchor node.
        row, col = row.repeat_interleave(k), col.repeat(row.size(0))
        return torch.stack([row, col], dim=0)
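
    # For intuition: with target_nodes = [2, 7] and k = 2 anchors [a0, a1],
    # `removed_edges` returns [[2, 2, 7, 7], [a0, a1, a0, a1]], i.e. every
    # (target, anchor) pair as one candidate edge to cut.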

    def anchors(self, k: int = 50) -> Tensor:
        """Return the top-k anchor nodes.

        Parameters
        ----------
        k : int, optional
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor
            the top-k anchor nodes
        """
        assert k > 0, "`k` must be positive"
        assert self._anchors is not None, "anchor nodes have not been computed yet"
        return self._anchors[:k]

    def patch(self, k: int = 50) -> Tensor:
        """Return the universal patch of the defensive perturbation.

        Parameters
        ----------
        k : int, optional
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor
            the 0-1 (boolean) universal patch where 1 denotes an edge
            to be removed
        """
        # `_anchors` is a permutation of all node ids, so its length
        # equals the number of nodes in the graph.
        num_nodes = self._anchors.size(0)
        _patch = torch.zeros(num_nodes, dtype=torch.bool, device=self.device)
        _patch[self.anchors(k=k)] = True
        return _patch
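
# The defense is "universal" in the sense that one ranked anchor list serves
# every target node: protecting a node only requires cutting its (at most k)
# edges to the top-k anchors, which is O(k) work per target.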


class GUARD(UniversalDefense, Surrogater):
    """Graph Universal Adversarial Defense (GUARD)

    Example
    -------
    >>> g = ...  # DGLGraph
    >>> splits = ...  # node splits
    >>> surrogate = GCN(num_feats, num_classes, bias=False, acts=None)
    >>> surrogate_trainer = Trainer(surrogate, device=device)
    >>> surrogate_trainer.fit(g, y_train, splits.train_nodes)
    >>> surrogate_trainer.evaluate(g, y_test, splits.test_nodes)
    >>> guard = GUARD(g.ndata['feat'], g.in_degrees())
    >>> guard.setup_surrogate(surrogate, y_train)
    >>> target_node = 1
    >>> guard(g, target_node, k=50)
    """

    def __init__(self, feat: Tensor, degree: Tensor, alpha: float = 2,
                 batch_size: int = 512, device: str = "cpu"):
        super().__init__(device=device)
        self.feat = feat.to(self.device)
        # Cast degrees to the features' dtype/device so `pow` below
        # runs in floating point.
        self.degree = degree.to(self.feat)
        self.alpha = alpha
        self.batch_size = batch_size

    def setup_surrogate(self, surrogate: torch.nn.Module,
                        victim_labels: Tensor) -> "GUARD":
        Surrogater.setup_surrogate(self, surrogate=surrogate,
                                   freeze=True, required=(SGC, GCN))
        # Collapse the linearized surrogate (SGC or linear GCN) into a
        # single weight matrix by multiplying its weight matrices
        # together, skipping 1-D bias parameters.
        W = None
        for para in self.surrogate.parameters():
            if para.ndim == 1:
                continue
            if W is None:
                W = para.detach()
            else:
                W = W @ para.detach()
        W = self.feat @ W  # per-node class scores under the linearized model
        d = self.degree.clamp(min=1)

        # Accumulate the victim-label scores in mini-batches.
        loader = DataLoader(victim_labels, pin_memory=False,
                            batch_size=self.batch_size, shuffle=False)
        w_max = W.max(1).values
        I = 0.
        for y in loader:
            I += W[:, y].sum(1)
        # Node importance: gap between the best class score and the mean
        # victim-label score, down-weighted by degree.
        I = (w_max - I / victim_labels.size(0)) / d.pow(self.alpha)
        self._anchors = torch.argsort(I, descending=True)
        return self
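
    # In formula form (a sketch of the scoring rule implemented above,
    # notation mine): for node v with degree d_v and linearized class
    # scores W[v],
    #     I(v) = (max_c W[v, c] - mean_{y in victim_labels} W[v, y]) / d_v**alpha
    # and the anchors are the nodes ranked by descending I(v).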


class DegreeGUARD(UniversalDefense):
    """Graph Universal Defense based on node degrees

    Example
    -------
    >>> g = ...  # DGLGraph
    >>> guard = DegreeGUARD(g.in_degrees())
    >>> target_node = 1
    >>> guard(g, target_node, k=50)
    """

    def __init__(self, degree: Tensor, descending: bool = False, device: str = "cpu"):
        super().__init__(device=device)
        self._anchors = torch.argsort(degree.to(self.device), descending=descending)
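
# With the default `descending=False`, the lowest-degree nodes rank first as
# anchors. This mirrors the GUARD score above, where degree appears in the
# denominator, so low-degree nodes tend to receive high importance.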


class RandomGUARD(UniversalDefense):
    """Graph Universal Defense based on random choice

    Example
    -------
    >>> g = ...  # DGLGraph
    >>> guard = RandomGUARD(g.num_nodes())
    >>> target_node = 1
    >>> guard(g, target_node, k=50)
    """

    def __init__(self, num_nodes: int, device: str = "cpu"):
        super().__init__(device=device)
        self.num_nodes = num_nodes
        self._anchors = torch.randperm(self.num_nodes, device=self.device)


if __name__ == '__main__':
    from graphattack.data import GraphAttackDataset
    from graphattack.training import Trainer
    from graphattack.training.callbacks import ModelCheckpoint
    from graphattack.utils import split_nodes

    data = GraphAttackDataset('cora', verbose=True, standardize=True)
    g = data[0]
    y = g.ndata['label']
    splits = split_nodes(y, random_state=15)

    num_feats = g.ndata['feat'].size(1)
    num_classes = data.num_classes
    y_train = y[splits.train_nodes]
    y_val = y[splits.val_nodes]
    y_test = y[splits.test_nodes]

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    g = g.to(device)

    defense = 'GUARD'

    if defense == "GUARD":
        surrogate = GCN(num_feats, num_classes, bias=False, acts=None)
        surrogate_trainer = Trainer(surrogate, device=device)
        cb = ModelCheckpoint('guard.pth', monitor='val_accuracy')
        surrogate_trainer.fit(g, y_train, splits.train_nodes, val_y=y_val,
                              val_index=splits.val_nodes, callbacks=[cb], verbose=0)
        guard = GUARD(g.ndata['feat'], g.in_degrees(), device=device)
        guard.setup_surrogate(surrogate, y_train)
    elif defense == "RandomGUARD":
        guard = RandomGUARD(g.num_nodes(), device=device)
    elif defense == "DegreeGUARD":
        guard = DegreeGUARD(g.in_degrees(), device=device)
    else:
        raise ValueError(f"Unknown defense {defense}")

    # Get the defended graph.
    defense_g = guard(g, target_nodes=1, k=50)
    # Get the anchor nodes (potential attacker nodes).
    anchors = guard.anchors(k=50)
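
    # A hedged sanity check (sketch, reusing the names defined above): for the
    # GUARD branch, the surrogate's test accuracy on the defended graph should
    # stay close to its accuracy on the clean graph, since only the edges
    # between the target node and the k anchors were removed.
    if defense == "GUARD":
        surrogate_trainer.evaluate(g, y_test, splits.test_nodes)
        surrogate_trainer.evaluate(defense_g, y_test, splits.test_nodes)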