Update negative sampling to work directly on GPU #9608

Open

danielecastellana22 wants to merge 19 commits into master from negative_sampling_on_GPU

Commits (19):
728de8a
Update negative sampling to work directly on GPU
danielecastellana22 Aug 19, 2024
7ba4732
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
c5c3889
Add compatibility to Python 3.8
danielecastellana22 Aug 22, 2024
c6db6d7
Correct structured_edge_sampling feasibility test
danielecastellana22 Aug 22, 2024
4d764fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
f399554
Merge branch 'master' into negative_sampling_on_GPU
danielecastellana22 Aug 22, 2024
7db30a6
Merge branch 'master' into negative_sampling_on_GPU
danielecastellana22 Sep 5, 2024
10d440c
Add all type annotations
danielecastellana22 Sep 5, 2024
0926a61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
f919b2d
Add check for structured sampling feasibility. If the method fails to…
danielecastellana22 Sep 5, 2024
5e99af8
Change the default method of negative_sampling to "auto" in RandomLin…
danielecastellana22 Sep 5, 2024
afe67e2
test_add_random_edge was based on fixing the seed. Now it checks that…
danielecastellana22 Sep 5, 2024
2f3c2e7
The generation of the graph test_signed_gcn was based on randint. We …
danielecastellana22 Sep 5, 2024
d837eb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
6d466cc
Adjust the number of trials for sparse negative sampling.
danielecastellana22 Sep 5, 2024
a2f5a71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
27d6079
Adjust PEP8
danielecastellana22 Sep 5, 2024
92ff538
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
038f81a
lint issue fix.
wsad1 Sep 7, 2024
16 changes: 9 additions & 7 deletions test/nn/models/test_signed_gcn.py
@@ -8,9 +8,11 @@
 def test_signed_gcn():
     model = SignedGCN(8, 16, num_layers=2, lamb=5)
     assert str(model) == 'SignedGCN(8, 16, num_layers=2)'

-    pos_index = torch.randint(high=10, size=(2, 40), dtype=torch.long)
-    neg_index = torch.randint(high=10, size=(2, 40), dtype=torch.long)
+    N, E = 20, 40
+    all_index = torch.randperm(N * N, dtype=torch.long)
+    all_index = torch.stack([all_index // N, all_index % N], dim=0)
+    pos_index = all_index[:, :E]
+    neg_index = all_index[:, E:2 * E]

     train_pos_index, test_pos_index = model.split_edges(pos_index)
     train_neg_index, test_neg_index = model.split_edges(neg_index)
@@ -24,14 +26,14 @@ def test_signed_gcn():
         x = model.create_spectral_features(
             train_pos_index,
             train_neg_index,
-            num_nodes=10,
+            num_nodes=N,
         )
-        assert x.size() == (10, 8)
+        assert x.size() == (N, 8)
     else:
-        x = torch.randn(10, 8)
+        x = torch.randn(N, 8)

     z = model(x, train_pos_index, train_neg_index)
-    assert z.size() == (10, 16)
+    assert z.size() == (N, 16)

     loss = model.loss(z, train_pos_index, train_neg_index)
     assert loss.item() >= 0
19 changes: 6 additions & 13 deletions test/utils/test_augmentation.py
@@ -1,7 +1,6 @@
 import pytest
 import torch

-from torch_geometric import seed_everything
 from torch_geometric.utils import (
     add_random_edge,
     is_undirected,
@@ -78,26 +77,20 @@ def test_add_random_edge():
     assert out[0].tolist() == edge_index.tolist()
     assert out[1].tolist() == [[], []]

-    seed_everything(5)
+    def _edge_idx_to_set(e: torch.Tensor) -> set:
+        return {tuple(v) for v in e.tolist()}

     out = add_random_edge(edge_index, p=0.5)
-    assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 3, 1, 2],
-                               [1, 0, 2, 1, 3, 2, 0, 3, 0]]
-    assert out[1].tolist() == [[3, 1, 2], [0, 3, 0]]
+    assert _edge_idx_to_set(out[0]).isdisjoint(_edge_idx_to_set(out[1]))

-    seed_everything(6)
     out = add_random_edge(edge_index, p=0.5, force_undirected=True)
-    assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 1, 3],
-                               [1, 0, 2, 1, 3, 2, 3, 1]]
-    assert out[1].tolist() == [[1, 3], [3, 1]]
+    assert _edge_idx_to_set(out[0]).isdisjoint(_edge_idx_to_set(out[1]))
     assert is_undirected(out[0])
     assert is_undirected(out[1])

     # Test for bipartite graph:
-    seed_everything(7)
     edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [2, 3, 1, 4, 2, 1]])
     with pytest.raises(RuntimeError, match="not supported for bipartite"):
         add_random_edge(edge_index, force_undirected=True, num_nodes=(6, 5))
     out = add_random_edge(edge_index, p=0.5, num_nodes=(6, 5))
-    assert out[0].tolist() == [[0, 1, 2, 3, 4, 5, 2, 0, 2],
-                               [2, 3, 1, 4, 2, 1, 0, 4, 2]]
-    assert out[1].tolist() == [[2, 0, 2], [0, 4, 2]]
+    assert _edge_idx_to_set(out[0]).isdisjoint(_edge_idx_to_set(out[1]))
135 changes: 98 additions & 37 deletions test/utils/test_negative_sampling.py
@@ -3,15 +3,16 @@
 from torch_geometric.utils import (
     batched_negative_sampling,
     contains_self_loops,
+    erdos_renyi_graph,
     is_undirected,
     negative_sampling,
     structured_negative_sampling,
     structured_negative_sampling_feasible,
     to_undirected,
 )
 from torch_geometric.utils._negative_sampling import (
-    edge_index_to_vector,
-    vector_to_edge_index,
+    edge_index_to_vector_id,
+    vector_id_to_edge_index,
 )


@@ -31,35 +32,20 @@ def is_negative(edge_index, neg_edge_index, size, bipartite):

 def test_edge_index_to_vector_and_vice_versa():
     # Create a fully-connected graph:
-    N = 10
-    row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)
-    col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)
+    N1, N2 = 13, 17
+    row = torch.arange(N1).view(-1, 1).repeat(1, N2).view(-1)
+    col = torch.arange(N2).view(1, -1).repeat(N1, 1).view(-1)
     edge_index = torch.stack([row, col], dim=0)

-    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)
-    assert population == N * N
-    assert idx.tolist() == list(range(population))
-    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=True)
-    assert is_undirected(edge_index2)
+    idx = edge_index_to_vector_id(edge_index, (N1, N2))
+    assert idx.tolist() == list(range(N1 * N2))
+    edge_index2 = torch.stack(vector_id_to_edge_index(idx, (N1, N2)), dim=0)
     assert edge_index.tolist() == edge_index2.tolist()

-    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)
-    assert population == N * N - N
-    assert idx.tolist() == list(range(population))
-    mask = edge_index[0] != edge_index[1]  # Remove self-loops.
-    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)
-    assert is_undirected(edge_index2)
-    assert edge_index[:, mask].tolist() == edge_index2.tolist()
-
-    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,
-                                           force_undirected=True)
-    assert population == (N * (N + 1)) / 2 - N
-    assert idx.tolist() == list(range(population))
-    mask = edge_index[0] != edge_index[1]  # Remove self-loops.
-    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,
-                                       force_undirected=True)
-    assert is_undirected(edge_index2)
-    assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()
+    vector_id = torch.arange(N1 * N2)
+    edge_index3 = torch.stack(vector_id_to_edge_index(vector_id, (N1, N2)),
+                              dim=0)
+    assert edge_index.tolist() == edge_index3.tolist()


def test_negative_sampling():
Member commented:

Decorate with @withCUDA to test on CPU and GPU. See this for an example.
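A minimal sketch of such a decorated test (assuming torch_geometric.testing.withCUDA injects a device argument, as it does elsewhere in the test suite; the test name is illustrative):

```python
import torch

from torch_geometric.testing import withCUDA
from torch_geometric.utils import negative_sampling


@withCUDA
def test_negative_sampling_device(device):
    # The decorator parametrizes the test over CPU and, if available, CUDA:
    edge_index = torch.tensor([[0, 0, 1, 2], [0, 1, 2, 3]], device=device)
    neg_edge_index = negative_sampling(edge_index)

    # Sampled negatives should stay on the input device:
    assert neg_edge_index.device == edge_index.device
    assert neg_edge_index.size(1) == edge_index.size(1)
```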

@@ -69,10 +55,6 @@ def test_negative_sampling():
     assert neg_edge_index.size(1) == edge_index.size(1)
     assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)

-    neg_edge_index = negative_sampling(edge_index, method='dense')
-    assert neg_edge_index.size(1) == edge_index.size(1)
-    assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)
Comment on lines -72 to -74:

Member commented:

Why drop this test?

Author replied:

Because the method is now inferred automatically from the graph size. Since the graph used in this test is small, the chosen method is always dense. This reflects the idea that the sparse method (which randomly guesses negative edges) is reasonable only when the graph is sparse ($E \ll N^2$).

To test both the sparse and the dense methods, I added a new test called test_negative_sampling_with_different_edge_density. Actually, I think the whole test_negative_sampling function could be removed, but I left it in to make sure the old test still passes.
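For illustration, a rough sketch of a density-based dispatch of this kind (the helper name choose_method and the 0.1 threshold are hypothetical, not the PR's actual implementation):

```python
def choose_method(num_edges: int, population: int,
                  threshold: float = 0.1) -> str:
    # Rejection ("sparse") sampling only pays off when existing edges
    # cover a small fraction of all possible node pairs; otherwise,
    # enumerating the complement ("dense") is cheaper and always succeeds.
    return 'sparse' if num_edges < threshold * population else 'dense'


# The 4-node graph in the old test has 4 edges out of 16 possible pairs,
# so the dispatch always lands on 'dense':
print(choose_method(num_edges=4, population=16))  # -> 'dense'
```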


     neg_edge_index = negative_sampling(edge_index, num_neg_samples=2)
     assert neg_edge_index.size(1) == 2
     assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)
@@ -97,6 +79,30 @@ def test_bipartite_negative_sampling():
     assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True)


+def test_negative_sampling_with_different_edge_density():
+    for num_nodes in [10, 100, 1000]:
+        for p in [0.1, 0.3, 0.5, 0.8]:
+            for is_directed in [False, True]:
+                edge_index = erdos_renyi_graph(num_nodes, p, is_directed)
+                neg_edge_index = negative_sampling(
+                    edge_index, num_nodes, force_undirected=not is_directed)
+                assert is_negative(edge_index, neg_edge_index,
+                                   (num_nodes, num_nodes), bipartite=False)
+
+
+def test_bipartite_negative_sampling_with_different_edge_density():
+    for num_nodes in [10, 100, 1000]:
+        for p in [0.1, 0.3, 0.5, 0.8]:
+            size = (num_nodes, int(num_nodes * 1.2))
+            n_edges = int(p * size[0] * size[1])
+            row, col = torch.randint(size[0], (n_edges, )), torch.randint(
+                size[1], (n_edges, ))
+            edge_index = torch.stack([row, col], dim=0)
+            neg_edge_index = negative_sampling(edge_index, size)
+            assert is_negative(edge_index, neg_edge_index, size,
+                               bipartite=True)
+
+
 def test_batched_negative_sampling():
     edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])
     edge_index = torch.cat([edge_index, edge_index + 4], dim=1)
@@ -153,16 +159,71 @@ def test_structured_negative_sampling():
     assert (adj & neg_adj).sum() == 0

     # Test with no self-loops:
     edge_index = torch.LongTensor([[0, 0, 1, 1, 2], [1, 2, 0, 2, 1]])
     i, j, k = structured_negative_sampling(edge_index, num_nodes=4,
                                            contains_neg_self_loops=False)
     neg_edge_index = torch.vstack([i, k])
     assert not contains_self_loops(neg_edge_index)

+def test_structured_negative_sampling_sparse():
+    num_nodes = 1000
+    edge_index = erdos_renyi_graph(num_nodes, 0.1)
+
+    i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes,
+                                           contains_neg_self_loops=True)
+    assert i.size(0) == edge_index.size(1)
+    assert j.size(0) == edge_index.size(1)
+    assert k.size(0) == edge_index.size(1)
+
+    assert torch.all(torch.ne(k, -1))
+    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
+    adj[i, j] = 1
+
+    neg_adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
+    neg_adj[i, k] = 1
+    assert (adj & neg_adj).sum() == 0
+
+    i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes,
+                                           contains_neg_self_loops=False)
+    assert i.size(0) == edge_index.size(1)
+    assert j.size(0) == edge_index.size(1)
+    assert k.size(0) == edge_index.size(1)
+
+    assert torch.all(torch.ne(k, -1))
+    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
+    adj[i, j] = 1
+
+    neg_adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
+    neg_adj[i, k] = 1
+    assert (adj & neg_adj).sum() == 0
+
+    neg_edge_index = torch.vstack([i, k])
+    assert not contains_self_loops(neg_edge_index)


 def test_structured_negative_sampling_feasible():
-    edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2],
-                                   [1, 2, 0, 2, 0, 1, 1]])
-    assert not structured_negative_sampling_feasible(edge_index, 3, False)
-    assert structured_negative_sampling_feasible(edge_index, 3, True)
-    assert structured_negative_sampling_feasible(edge_index, 4, False)
+    def create_ring_graph(num_nodes):
+        forward_edges = torch.stack([
+            torch.arange(0, num_nodes, dtype=torch.long),
+            (torch.arange(0, num_nodes, dtype=torch.long) + 1) % num_nodes
+        ], dim=0)
+        backward_edges = torch.stack([
+            torch.arange(0, num_nodes, dtype=torch.long),
+            (torch.arange(0, num_nodes, dtype=torch.long) - 1) % num_nodes
+        ], dim=0)
+        return torch.concat([forward_edges, backward_edges], dim=1)
+
+    # A ring of 3 nodes is always infeasible:
+    ring_3 = create_ring_graph(3)
+    assert not structured_negative_sampling_feasible(ring_3, 3, False)
+    assert not structured_negative_sampling_feasible(ring_3, 3, True)
+
+    # A ring of 4 nodes is feasible only if self-loops are allowed:
+    ring_4 = create_ring_graph(4)
+    assert not structured_negative_sampling_feasible(ring_4, 4, False)
+    assert structured_negative_sampling_feasible(ring_4, 4, True)
+
+    # A ring of 5 nodes is always feasible:
+    ring_5 = create_ring_graph(5)
+    assert structured_negative_sampling_feasible(ring_5, 5, False)
+    assert structured_negative_sampling_feasible(ring_5, 5, True)
2 changes: 1 addition & 1 deletion torch_geometric/transforms/random_link_split.py
@@ -236,7 +236,7 @@ def forward(
                 size = size[0]
             neg_edge_index = negative_sampling(edge_index, size,
                                                num_neg_samples=num_neg,
-                                               method='sparse')
+                                               method='auto')

             # Adjust ratio if not enough negative edges exist
             if neg_edge_index.size(1) < num_neg: