Update negative sampling to work directly on GPU #9608
base: master
@@ -3,15 +3,16 @@
 from torch_geometric.utils import (
     batched_negative_sampling,
     contains_self_loops,
+    erdos_renyi_graph,
     is_undirected,
     negative_sampling,
     structured_negative_sampling,
     structured_negative_sampling_feasible,
     to_undirected,
 )
 from torch_geometric.utils._negative_sampling import (
-    edge_index_to_vector,
-    vector_to_edge_index,
+    edge_index_to_vector_id,
+    vector_id_to_edge_index,
 )
@@ -31,35 +32,20 @@ def is_negative(edge_index, neg_edge_index, size, bipartite):
 
 def test_edge_index_to_vector_and_vice_versa():
     # Create a fully-connected graph:
-    N = 10
-    row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)
-    col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)
+    N1, N2 = 13, 17
+    row = torch.arange(N1).view(-1, 1).repeat(1, N2).view(-1)
+    col = torch.arange(N2).view(1, -1).repeat(N1, 1).view(-1)
     edge_index = torch.stack([row, col], dim=0)
 
-    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)
-    assert population == N * N
-    assert idx.tolist() == list(range(population))
-    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=True)
-    assert is_undirected(edge_index2)
+    idx = edge_index_to_vector_id(edge_index, (N1, N2))
+    assert idx.tolist() == list(range(N1 * N2))
+    edge_index2 = torch.stack(vector_id_to_edge_index(idx, (N1, N2)), dim=0)
+    assert edge_index.tolist() == edge_index2.tolist()
 
-    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)
-    assert population == N * N - N
-    assert idx.tolist() == list(range(population))
-    mask = edge_index[0] != edge_index[1]  # Remove self-loops.
-    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)
-    assert is_undirected(edge_index2)
-    assert edge_index[:, mask].tolist() == edge_index2.tolist()
-
-    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,
-                                           force_undirected=True)
-    assert population == (N * (N + 1)) / 2 - N
-    assert idx.tolist() == list(range(population))
-    mask = edge_index[0] != edge_index[1]  # Remove self-loops.
-    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,
-                                       force_undirected=True)
-    assert is_undirected(edge_index2)
-    assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()
+    vector_id = torch.arange(N1 * N2)
+    edge_index3 = torch.stack(vector_id_to_edge_index(vector_id, (N1, N2)),
+                              dim=0)
+    assert edge_index.tolist() == edge_index3.tolist()
 
 
 def test_negative_sampling():
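An aside, not part of the diff: judging from the assertions above (every edge of a fully-connected bipartite graph built in row-major order maps to range(N1 * N2)), the renamed helpers appear to implement a plain row-major linearization of (row, col) pairs. A minimal sketch under that assumption:

    import torch

    # Assumed semantics only; the real helpers live in
    # torch_geometric.utils._negative_sampling and may differ in detail.
    def edge_index_to_vector_id(edge_index, size):
        # Map each (row, col) edge to a unique scalar id in
        # [0, size[0] * size[1]).
        row, col = edge_index[0], edge_index[1]
        return row * size[1] + col

    def vector_id_to_edge_index(vector_id, size):
        # Invert the row-major linearization; returns a (row, col) tuple,
        # which is why the tests wrap the result in torch.stack.
        return vector_id // size[1], vector_id % size[1]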
@@ -69,10 +55,6 @@ def test_negative_sampling():
     assert neg_edge_index.size(1) == edge_index.size(1)
     assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)
 
-    neg_edge_index = negative_sampling(edge_index, method='dense')
-    assert neg_edge_index.size(1) == edge_index.size(1)
-    assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)
-
Comment on lines -72 to -74

Reviewer: Why drop this test?

Author: Because the method is now inferred automatically based on the graph size. Since the graph used in this test is small, the method is always dense. This reflects the idea that the sparse method (which is based on randomly guessing negative edges) is reasonable only when the graph is sparse. To test both the sparse and the dense method, I added a new test called test_negative_sampling_with_different_edge_density.
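A hedged illustration of the idea described in this thread, not the PR's actual code (infer_sampling_method and its threshold are made-up names):

    def infer_sampling_method(num_edges, size, dense_threshold=0.3):
        # Sparse sampling guesses random node pairs and rejects existing
        # edges, which converges quickly only when most pairs are non-edges.
        density = num_edges / (size[0] * size[1])
        return 'dense' if density > dense_threshold else 'sparse'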
     neg_edge_index = negative_sampling(edge_index, num_neg_samples=2)
     assert neg_edge_index.size(1) == 2
     assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)
@@ -97,6 +79,30 @@ def test_bipartite_negative_sampling():
     assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True)
 
 
+def test_negative_sampling_with_different_edge_density():
+    for num_nodes in [10, 100, 1000]:
+        for p in [0.1, 0.3, 0.5, 0.8]:
+            for is_directed in [False, True]:
+                edge_index = erdos_renyi_graph(num_nodes, p, is_directed)
+                neg_edge_index = negative_sampling(
+                    edge_index, num_nodes, force_undirected=not is_directed)
+                assert is_negative(edge_index, neg_edge_index,
+                                   (num_nodes, num_nodes), bipartite=False)
+
+
+def test_bipartite_negative_sampling_with_different_edge_density():
+    for num_nodes in [10, 100, 1000]:
+        for p in [0.1, 0.3, 0.5, 0.8]:
+            size = (num_nodes, int(num_nodes * 1.2))
+            n_edges = int(p * size[0] * size[1])
+            row, col = torch.randint(size[0], (n_edges, )), torch.randint(
+                size[1], (n_edges, ))
+            edge_index = torch.stack([row, col], dim=0)
+            neg_edge_index = negative_sampling(edge_index, size)
+            assert is_negative(edge_index, neg_edge_index, size,
+                               bipartite=True)
+
+
 def test_batched_negative_sampling():
     edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])
     edge_index = torch.cat([edge_index, edge_index + 4], dim=1)
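An aside, not part of the diff: since the point of this PR is that negative sampling runs directly on the GPU, the expected usage is along these lines (a sketch, assuming a CUDA-capable machine):

    import torch

    from torch_geometric.utils import negative_sampling

    if torch.cuda.is_available():
        edge_index = torch.tensor([[0, 0, 1, 2], [0, 1, 2, 3]],
                                  device='cuda')
        # Negative edges are sampled on the GPU, with no implicit
        # round-trip through the CPU.
        neg_edge_index = negative_sampling(edge_index, num_nodes=4)
        assert neg_edge_index.device == edge_index.device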
@@ -153,16 +159,71 @@ def test_structured_negative_sampling():
     assert (adj & neg_adj).sum() == 0
 
     # Test with no self-loops:
     edge_index = torch.LongTensor([[0, 0, 1, 1, 2], [1, 2, 0, 2, 1]])
     i, j, k = structured_negative_sampling(edge_index, num_nodes=4,
                                            contains_neg_self_loops=False)
     neg_edge_index = torch.vstack([i, k])
     assert not contains_self_loops(neg_edge_index)
 
 
+def test_structured_negative_sampling_sparse():
+    num_nodes = 1000
+    edge_index = erdos_renyi_graph(num_nodes, 0.1)
+
+    i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes,
+                                           contains_neg_self_loops=True)
+    assert i.size(0) == edge_index.size(1)
+    assert j.size(0) == edge_index.size(1)
+    assert k.size(0) == edge_index.size(1)
+
+    assert torch.all(torch.ne(k, -1))
+    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
+    adj[i, j] = 1
+
+    neg_adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
+    neg_adj[i, k] = 1
+    assert (adj & neg_adj).sum() == 0
+
+    i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes,
+                                           contains_neg_self_loops=False)
+    assert i.size(0) == edge_index.size(1)
+    assert j.size(0) == edge_index.size(1)
+    assert k.size(0) == edge_index.size(1)
+
+    assert torch.all(torch.ne(k, -1))
+    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
+    adj[i, j] = 1
+
+    neg_adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
+    neg_adj[i, k] = 1
+    assert (adj & neg_adj).sum() == 0
+
+    neg_edge_index = torch.vstack([i, k])
+    assert not contains_self_loops(neg_edge_index)
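An aside, not part of the diff: the torch.all(torch.ne(k, -1)) assertions above suggest that -1 marks source nodes for which no valid negative destination could be found. A defensive usage sketch under that assumption:

    i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes)
    valid = k != -1  # Drop entries where no valid negative was found.
    neg_edge_index = torch.stack([i[valid], k[valid]], dim=0)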
 
 
 def test_structured_negative_sampling_feasible():
-    edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2],
-                                   [1, 2, 0, 2, 0, 1, 1]])
-    assert not structured_negative_sampling_feasible(edge_index, 3, False)
-    assert structured_negative_sampling_feasible(edge_index, 3, True)
-    assert structured_negative_sampling_feasible(edge_index, 4, False)
+    def create_ring_graph(num_nodes):
+        forward_edges = torch.stack([
+            torch.arange(0, num_nodes, dtype=torch.long),
+            (torch.arange(0, num_nodes, dtype=torch.long) + 1) % num_nodes
+        ], dim=0)
+        backward_edges = torch.stack([
+            torch.arange(0, num_nodes, dtype=torch.long),
+            (torch.arange(0, num_nodes, dtype=torch.long) - 1) % num_nodes
+        ], dim=0)
+        return torch.concat([forward_edges, backward_edges], dim=1)
+
+    # A ring of 3 nodes is always infeasible:
+    ring_3 = create_ring_graph(3)
+    assert not structured_negative_sampling_feasible(ring_3, 3, False)
+    assert not structured_negative_sampling_feasible(ring_3, 3, True)
+
+    # A ring of 4 nodes is feasible only if negative self-loops are allowed:
+    ring_4 = create_ring_graph(4)
+    assert not structured_negative_sampling_feasible(ring_4, 4, False)
+    assert structured_negative_sampling_feasible(ring_4, 4, True)
+
+    # A ring of 5 nodes is always feasible:
+    ring_5 = create_ring_graph(5)
+    assert structured_negative_sampling_feasible(ring_5, 5, False)
+    assert structured_negative_sampling_feasible(ring_5, 5, True)
Reviewer: Decorate with @withCUDA to test on both CPU and GPU. See this for an example.
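For reference, a minimal sketch of the suggested pattern, assuming the withCUDA decorator from torch_geometric.testing (the test name and body are illustrative, not from the PR):

    import torch

    from torch_geometric.testing import withCUDA
    from torch_geometric.utils import negative_sampling


    @withCUDA
    def test_negative_sampling_device(device):
        # withCUDA runs the test once per available device
        # (CPU, plus CUDA when present).
        edge_index = torch.tensor([[0, 0, 1, 2], [0, 1, 2, 3]],
                                  device=device)
        neg_edge_index = negative_sampling(edge_index, num_nodes=4)
        assert neg_edge_index.device == edge_index.device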