-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
73e28aa
commit 8f58003
Showing
31 changed files
with
6,156 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# deepmind (i think) package for optimization | ||
import jax | ||
import jax.numpy as jnp | ||
import jraph | ||
import networkx as nx | ||
import numpy as np | ||
from utils import rng_sequence_from_rng | ||
|
||
def get_grid_adjacency(n_x, n_y, atol=1e-1): | ||
return nx.grid_2d_graph(n_x, n_y) # Get directed grid graph | ||
|
||
|
||
def sample_padded_grid_batch_shortest_path(rng, batch_size,feature_position,weighted, nx_min, nx_max, ny_min=None, ny_max=None): | ||
rng_seq = rng_sequence_from_rng(rng) | ||
"""Sample a batch of grid graphs with variable sizes. | ||
Args: | ||
rng: jax.random.PRNGKey(integer_seed) object, random number generator | ||
batch_size: int, number of graphs to sample | ||
nx_min: minum size along x axis | ||
nx_max: maximum size along x axis | ||
ny_min: minum size along y axis (default: nx_min) | ||
ny_max: maximum size along y axis (default: ny_max) | ||
Returns: | ||
padded graph batch that can be fed into a jraph GNN. | ||
""" | ||
ny_min = ny_min or nx_min | ||
ny_max = ny_max or nx_max | ||
max_n = ny_max * nx_max * batch_size | ||
max_e = max_n * 4 | ||
# Sample grid dimensions | ||
x_rng = next(rng_seq) | ||
y_rng = next(rng_seq) # need to "split" to advance the random number generator -- otherwise rng will be the same for all things sampled from it | ||
n_xs = jax.random.randint(x_rng, shape=(batch_size,), minval=nx_min, maxval=nx_max) | ||
n_ys = jax.random.randint(y_rng, shape=(batch_size,), minval=ny_min, maxval=ny_max) | ||
# Construct graphs with sampled dimensions. | ||
graphs = [] | ||
target = [] | ||
for n_x, n_y in zip(n_xs, n_ys): | ||
nx_graph = get_grid_adjacency(n_x, n_y) | ||
senders, receivers, node_positions, edge_displacements, n_node, n_edge, global_context = grid_networkx_to_graphstuple( | ||
nx_graph) | ||
i_end = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_node) | ||
i_start = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_node) | ||
# make it a node feature of the input graph if a node is a start/end node | ||
input_node_features = jnp.zeros((n_node, 1)) | ||
input_node_features = input_node_features.at[i_start, 0].set(1) # set start node feature | ||
input_node_features = input_node_features.at[i_end, 0].set(1) # set end node feature | ||
if feature_position: | ||
input_node_features = jnp.concatenate((input_node_features, node_positions),axis=1) | ||
if weighted: | ||
weights = add_weighted_edge(edge_displacements,n_edge,1) | ||
edge_displacement = jnp.concatenate((edge_displacements,weights), axis=1) | ||
graph = jraph.GraphsTuple(nodes=input_node_features, senders=senders, receivers=receivers, | ||
edges=edge_displacement, n_node=jnp.array([n_node], dtype=int), | ||
n_edge=jnp.array([n_edge], dtype=int), globals=global_context) | ||
nx_weighted = convert_jraph_to_networkx_graph(graph, 0) | ||
for i, j in nx_weighted.edges(): | ||
nx_weighted[i][j]['weight'] = weights[i] | ||
nodes_on_shortest_path_indexes = nx.shortest_path(nx_weighted, int(i_start[0]), int(i_end[0]), weight='weight') | ||
else: | ||
graph = jraph.GraphsTuple(nodes=input_node_features, senders=senders, receivers=receivers, | ||
edges=edge_displacements, n_node=jnp.array([n_node], dtype=int), | ||
n_edge=jnp.array([n_edge], dtype=int), globals=global_context) | ||
nx_weighted = convert_jraph_to_networkx_graph(graph, 0) | ||
nodes_on_shortest_path_indexes = nx.shortest_path(nx_weighted, int(i_start[0]), int(i_end[0])) | ||
graphs.append(graph) | ||
|
||
nodes_on_shortest_labels = jnp.zeros((n_node, 1)) | ||
for i in nodes_on_shortest_path_indexes: | ||
nodes_on_shortest_labels = nodes_on_shortest_labels.at[i].set(1) | ||
target.append(nodes_on_shortest_labels)# set start node feature | ||
targets=jnp.concatenate(target) | ||
target_pad = jnp.zeros(((max_n - len(targets)), 1)) | ||
padded_target= jnp.concatenate(( targets, target_pad), axis=0) | ||
graph_batch = jraph.batch(graphs) | ||
padded_graph_batch = jraph.pad_with_graphs(graph_batch, n_node=max_n, n_edge=max_e, n_graph=len(graphs) + 1) | ||
|
||
return padded_graph_batch, jnp.asarray( padded_target) | ||
|
||
def grid_networkx_to_graphstuple(nx_graph): | ||
"""Get edges for a grid graph.""" | ||
nx_graph = nx.DiGraph(nx_graph) | ||
node_positions = jnp.array(nx_graph.nodes) | ||
node_to_inds = {n: i for i, n in enumerate(nx_graph.nodes)} | ||
senders_receivers = [(node_to_inds[s], node_to_inds[r]) for s, r in nx_graph.edges] | ||
edge_displacements = jnp.array([np.array(r) - np.array(s) for s, r in nx_graph.edges]) | ||
senders, receivers = zip(*senders_receivers) | ||
n_node = node_positions.shape[0] | ||
n_edge = edge_displacements.shape[0] | ||
return ( | ||
jnp.array(senders, dtype=int), | ||
jnp.array(receivers, dtype=int), | ||
jnp.array(node_positions, dtype=float), | ||
jnp.array(edge_displacements, dtype=float), | ||
n_node, | ||
n_edge, | ||
jnp.zeros((1, 0), dtype=float)) | ||
|
||
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple, number_graph_batch) -> nx.Graph: | ||
nodes, edges, receivers, senders, _, _, _ = jraph_graph | ||
node_padd = 0 | ||
edges_padd = 0 | ||
for i in range(number_graph_batch): | ||
node_padd = node_padd + jraph_graph.n_node[i] | ||
edges_padd = edges_padd + jraph_graph.n_edge[i] | ||
nx_graph = nx.DiGraph() | ||
if nodes is None: | ||
for n in range(jraph_graph.n_node[number_graph_batch]): | ||
nx_graph.add_node(n) | ||
else: | ||
for n in range(jraph_graph.n_node[number_graph_batch]): | ||
nx_graph.add_node(n, node_feature=nodes[node_padd + n]) | ||
for e in range(jraph_graph.n_edge[number_graph_batch]): | ||
nx_graph.add_edge( | ||
int(senders[edges_padd + e]) - node_padd, int(receivers[edges_padd + e] - node_padd), | ||
edge_feature=edges[edges_padd + e]) | ||
return nx_graph | ||
"Need to figure out how it is working with the batch thing, how it is working" \ | ||
"Figure out as well other types of weighting on the edges " | ||
"Here I have a problem again because for one of them i have to have positive weights " \ | ||
"I think I will add it as a feature of the edges " | ||
def add_weighted_edge(edge_displacement,n_edge,sigma_on_edge_weight_noise): | ||
weights= jnp.zeros((n_edge, 1)) | ||
for k in range(n_edge): | ||
weight=np.max([sigma_on_edge_weight_noise * np.random.rand() + 1., 0.5]) | ||
weights = weights.at[k, 0].set(weight) | ||
#edge_displacement = edge_displacement.at[k,l].set(edge_displacement[k][l] + weight) # weights=sigma_on_edge_weight_noise * np.random.rand() Because nedd postiove and add as features and need ot be used by the neural networks :) | ||
return weights |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
experiment_name: 'smaller size generalisation graph with no position feature' | ||
train_on_shortest_path: True | ||
resample: True # @param | ||
wandb_on: True | ||
seed: 42 | ||
|
||
feature_position: False | ||
weighted: True | ||
|
||
batch_size: 4 | ||
nx_min: 4 | ||
nx_max: 7 | ||
|
||
batch_size_test: 4 | ||
nx_min_test: 4 | ||
nx_max_test: 7 | ||
|
||
num_hidden: 100 # @param | ||
num_layers: 2 # @param | ||
num_message_passing_steps: 3 # @param | ||
learning_rate: 0.001 # @param | ||
num_training_steps: 10 # @param | ||
# for loop | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
from config_manager import config_field | ||
from config_manager import config_template | ||
|
||
class ConfigTemplate: | ||
_grid_template = config_template.Template( | ||
fields=[ | ||
config_field.Field( | ||
name='resample', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='train_on_shortest_path', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='wandb_on', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='weighted', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='seed', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='batch_size', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='nx_min', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='nx_max', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='batch_size_test', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='nx_min_test', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='nx_max_test', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='num_hidden', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='num_layers', | ||
types=[int], | ||
), | ||
# @param | ||
config_field.Field( | ||
name='num_message_passing_steps', | ||
types=[int], | ||
), | ||
# @param | ||
config_field.Field( | ||
name='learning_rate', | ||
types=[float], | ||
), | ||
config_field.Field( | ||
name='num_training_steps', | ||
types=[float, int], | ||
), | ||
], | ||
) | ||
base_config_template = config_template.Template( | ||
fields=[ | ||
config_field.Field( | ||
name='experiment_name', | ||
types=[str, type(None)], | ||
), | ||
config_field.Field( | ||
name='resample', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='train_on_shortest_path', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='wandb_on', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='batch_size', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='nx_min', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='nx_max', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='batch_size_test', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='nx_min_test', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='nx_max_test', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='num_hidden', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='num_layers', | ||
types=[int], | ||
), | ||
# @param | ||
config_field.Field( | ||
name='num_message_passing_steps', | ||
types=[int], | ||
), | ||
config_field.Field( | ||
name='seed', | ||
types=[int], | ||
), | ||
# @param | ||
config_field.Field( | ||
name='learning_rate', | ||
types=[float], | ||
), | ||
config_field.Field( | ||
name='feature_position', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='weighted', | ||
types=[bool], | ||
), | ||
config_field.Field( | ||
name='num_training_steps', | ||
types=[float, int], | ||
), | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typing import Dict | ||
from typing import Union | ||
from config_template import ConfigTemplate | ||
from config_manager import base_configuration | ||
|
||
|
||
class GridConfig(base_configuration.BaseConfiguration): | ||
def __init__(self, config: Union[str, Dict],) -> None: | ||
super().__init__( | ||
configuration=config, | ||
template=ConfigTemplate.base_config_template, | ||
) | ||
self._validate_configuration() | ||
|
||
def _validate_configuration(self): | ||
"""Method to check for non-trivial associations | ||
in the configuration. | ||
""" | ||
pass | ||
# if self.teacher_configuration == constants.BOTH_ROTATION: | ||
# assert self.scale_forward_by_hidden == True, ( | ||
# "In both rotation regime, i.e. mean field limit, " | ||
# "need to scale forward by 1/K." | ||
# ) | ||
# else: | ||
# assert self.scale_forward_by_hidden == False, ( | ||
# "When not in both rotation regime, i.e. mean field limit, " | ||
# "no need to scale forward by 1/K." | ||
# ) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import haiku as hk | ||
import jraph | ||
import jax.numpy as jnp | ||
from plotting_utils import plot_message_passing_layers | ||
|
||
#TODO(clementine): set up object oriented GNN classes (eventually) | ||
|
||
def get_forward_function(num_hidden, num_layers,num_message_passing_steps): | ||
"""Get function that performs a forward call on a simple GNN.""" | ||
def _forward(x): | ||
"""Forward pass of a simple GNN.""" | ||
node_output_size = 1 | ||
edge_output_size = 1 | ||
|
||
# Set up MLP parameters for node/edge updates | ||
node_mlp_sizes = [num_hidden] * num_layers | ||
edge_mlp_sizes = [num_hidden] * num_layers | ||
|
||
# Map features to desired feature size. | ||
x = jraph.GraphMapFeatures( | ||
embed_edge_fn=hk.Linear(output_size=num_hidden), | ||
embed_node_fn=hk.Linear(output_size=num_hidden))(x) | ||
|
||
# Apply rounds of message passing. | ||
message_passing=[] | ||
for n in range(num_message_passing_steps): | ||
x = message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes) | ||
message_passing.append(x) | ||
|
||
# Map features to desired feature size. | ||
x = jraph.GraphMapFeatures( | ||
embed_edge_fn=hk.Linear(output_size=edge_output_size), | ||
embed_node_fn=hk.Linear(output_size=node_output_size))(x) | ||
|
||
return x , message_passing | ||
return _forward | ||
|
||
def message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes): | ||
update_edge_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=edge_mlp_sizes)) | ||
update_node_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=node_mlp_sizes)) | ||
x = jraph.GraphNetwork(update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)(x) | ||
return x | ||
|
Oops, something went wrong.