# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C0103,R1732,C0200,R1705
"""
Position-aware Graph Neural Networks (P-GNN)
References
----------
Paper: http://proceedings.mlr.press/v97/you19b/you19b.pdf
Author's code: https://github.com/JiaxuanYou/P-GNN
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/P-GNN
"""
import multiprocessing as mp
import random
from multiprocessing import get_context
import torch
from torch import nn
import torch.nn.functional as F
import networkx as nx
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score
import dgl.function as fn
class PGNN_layer(nn.Module):
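    """A single P-GNN message-passing layer.

    For every (anchor, node) edge it combines the anchor's transformed feature,
    the pre-computed "sp_dist" edge weight (reciprocal shortest-path distance),
    and the node's own transformed feature, then gathers one message per
    (node, anchor set) pair to produce a position-aware output and a
    structure-aware hidden state.
    """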
def __init__(self, input_dim, output_dim):
super(PGNN_layer, self).__init__()
self.input_dim = input_dim
self.linear_hidden_u = nn.Linear(input_dim, output_dim)
self.linear_hidden_v = nn.Linear(input_dim, output_dim)
self.linear_out_position = nn.Linear(output_dim, 1)
self.act = nn.ReLU()
def forward(self, graph, feature, anchor_eid, dists_max):
with graph.local_scope():
u_feat = self.linear_hidden_u(feature)
v_feat = self.linear_hidden_v(feature)
graph.srcdata.update({"u_feat": u_feat})
graph.dstdata.update({"v_feat": v_feat})
graph.apply_edges(fn.u_mul_e("u_feat", "sp_dist", "u_message")) # pylint: disable=E1101
graph.apply_edges(fn.v_add_e("v_feat", "u_message", "message")) # pylint: disable=E1101
messages = torch.index_select(
graph.edata["message"],
0,
torch.LongTensor(anchor_eid).to(feature.device),
)
messages = messages.reshape(
dists_max.shape[0], dists_max.shape[1], messages.shape[-1]
)
messages = self.act(messages) # n*m*d
out_position = self.linear_out_position(messages).squeeze(-1) # n*m_out
out_structure = torch.mean(messages, dim=1) # n*d
return out_position, out_structure
class PGNN(nn.Module):
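    """Two P-GNN layers with dropout; returns L2-normalized position-aware embeddings."""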
def __init__(self, input_dim, feature_dim=32, dropout=0.5):
super(PGNN, self).__init__()
self.dropout = nn.Dropout(dropout)
self.linear_pre = nn.Linear(input_dim, feature_dim)
self.conv_first = PGNN_layer(feature_dim, feature_dim)
self.conv_out = PGNN_layer(feature_dim, feature_dim)
def forward(self, data):
x = data["graph"].ndata["feat"]
graph = data["graph"]
x = self.linear_pre(x)
x_position, x = self.conv_first(graph, x, data["anchor_eid"], data["dists_max"])
x = self.dropout(x)
x_position, x = self.conv_out(graph, x, data["anchor_eid"], data["dists_max"])
x_position = F.normalize(x_position, p=2, dim=-1)
return x_position
def get_communities(remove_feature, graph):
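    """Generate the synthetic "communities" link-prediction data.

    Assumes the node IDs of ``graph`` are consecutive integers from 0, grouped
    into communities of 20 nodes. Rewires 1% of the edges as noise and labels
    every node pair within the same community as a positive pair.
    """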
community_size = 20
    # Randomly rewire 1% of the edges
node_list = list(graph.nodes)
for u, v in graph.edges():
if random.random() < 0.01:
x = random.choice(node_list)
if graph.has_edge(u, x):
continue
graph.remove_edge(u, v)
graph.add_edge(u, x)
# remove self-loops
graph.remove_edges_from(nx.selfloop_edges(graph))
edge_index = np.array(list(graph.edges))
# Add (i, j) for an edge (j, i)
edge_index = np.concatenate((edge_index, edge_index[:, ::-1]), axis=0)
edge_index = torch.from_numpy(edge_index).long().permute(1, 0)
n = graph.number_of_nodes()
label = np.zeros((n, n), dtype=int)
for u in node_list:
# the node IDs are simply consecutive integers from 0
for v in range(u):
if u // community_size == v // community_size:
label[u, v] = 1
if remove_feature:
feature = torch.ones((n, 1))
else:
rand_order = np.random.permutation(n)
feature = np.identity(n)[:, rand_order]
data = {
"edge_index": edge_index,
"feature": feature,
"positive_edges": np.stack(np.nonzero(label)),
"num_nodes": feature.shape[0],
}
return data
def to_single_directed(edges):
edges_new = np.zeros((2, edges.shape[1] // 2), dtype=int)
j = 0
for i in range(edges.shape[1]):
if edges[0, i] < edges[1, i]:
edges_new[:, j] = edges[:, i]
j += 1
return edges_new
# Ideally every node should still appear in the graph formed by the training edges
def split_edges(p, edges, data, non_train_ratio=0.2):
e = edges.shape[1]
edges = edges[:, np.random.permutation(e)]
split1 = int((1 - non_train_ratio) * e)
split2 = int((1 - non_train_ratio / 2) * e)
data.update(
{
f"{p}_edges_train": edges[:, :split1], # 80%
f"{p}_edges_val": edges[:, split1:split2], # 10%
f"{p}_edges_test": edges[:, split2:], # 10%
}
)
def to_bidirected(edges):
return np.concatenate((edges, edges[::-1, :]), axis=-1)
def get_negative_edges(positive_edges, num_nodes, num_negative_edges):
positive_edge_set = []
positive_edges = to_bidirected(positive_edges)
for i in range(positive_edges.shape[1]):
positive_edge_set.append(tuple(positive_edges[:, i]))
positive_edge_set = set(positive_edge_set)
negative_edges = np.zeros((2, num_negative_edges), dtype=positive_edges.dtype)
for i in range(num_negative_edges):
while True:
mask_temp = tuple(np.random.choice(num_nodes, size=(2,), replace=False))
if mask_temp not in positive_edge_set:
negative_edges[:, i] = mask_temp
break
return negative_edges
def get_pos_neg_edges(data, infer_link_positive=True):
if infer_link_positive:
data["positive_edges"] = to_single_directed(data["edge_index"].numpy())
split_edges("positive", data["positive_edges"], data)
    # Sample as many negative (non-existent) edges as there are positive edges
negative_edges = get_negative_edges(
data["positive_edges"],
data["num_nodes"],
num_negative_edges=data["positive_edges"].shape[1],
)
split_edges("negative", negative_edges, data)
return data
def shortest_path(graph, node_range, cutoff):
dists_dict = {}
for node in tqdm(node_range, leave=False):
dists_dict[node] = nx.single_source_shortest_path_length(graph, node, cutoff)
return dists_dict
def merge_dicts(dicts):
result = {}
for dictionary in dicts:
result.update(dictionary)
return result
def all_pairs_shortest_path(graph, cutoff=None, num_workers=4):
nodes = list(graph.nodes)
random.shuffle(nodes)
pool = mp.Pool(processes=num_workers)
interval_size = len(nodes) / num_workers
results = [
pool.apply_async(
shortest_path,
args=(
graph,
nodes[int(interval_size * i) : int(interval_size * (i + 1))],
cutoff,
),
)
for i in range(num_workers)
]
output = [p.get() for p in results]
dists_dict = merge_dicts(output)
pool.close()
pool.join()
return dists_dict
def precompute_dist_data(edge_index, num_nodes, approximate=0):
"""
Here dist is 1/real_dist, higher actually means closer, 0 means disconnected
:return:
"""
graph = nx.Graph()
edge_list = edge_index.transpose(1, 0).tolist()
graph.add_edges_from(edge_list)
n = num_nodes
dists_array = np.zeros((n, n))
dists_dict = all_pairs_shortest_path(
graph, cutoff=approximate if approximate > 0 else None
)
node_list = graph.nodes()
for node_i in node_list:
shortest_dist = dists_dict[node_i]
for node_j in node_list:
dist = shortest_dist.get(node_j, -1)
if dist != -1:
dists_array[node_i, node_j] = 1 / (dist + 1)
return dists_array
def get_dataset(graph):
# Generate graph data
data_info = get_communities(False, graph)
# Get positive and negative edges
data = get_pos_neg_edges(data_info, infer_link_positive=True)
# Pre-compute shortest path length
dists_removed = precompute_dist_data(
data["positive_edges_train"],
data["num_nodes"],
approximate=-1,
)
data["dists"] = torch.from_numpy(dists_removed).float()
data["edge_index"] = torch.from_numpy(
to_bidirected(data["positive_edges_train"])
).long()
return data
def get_anchors(n):
"""Get a list of NumPy arrays, each of them is an anchor node set"""
m = int(np.log2(n))
anchor_set_id = []
for i in range(m):
anchor_size = int(n / np.exp2(i + 1))
for _ in range(m):
anchor_set_id.append(np.random.choice(n, size=anchor_size, replace=False))
return anchor_set_id
def get_dist_max(anchor_set_id, dist):
# N x K, N is number of nodes, K is the number of anchor sets
dist_max = torch.zeros((dist.shape[0], len(anchor_set_id)))
dist_argmax = torch.zeros((dist.shape[0], len(anchor_set_id))).long()
for i in range(len(anchor_set_id)):
temp_id = torch.as_tensor(anchor_set_id[i], dtype=torch.long)
# Get reciprocal of shortest distance to each node in the i-th anchor set
dist_temp = torch.index_select(dist, 1, temp_id)
# For each node in the graph, find its closest anchor node in the set
# and the reciprocal of shortest distance
dist_max_temp, dist_argmax_temp = torch.max(dist_temp, dim=-1)
dist_max[:, i] = dist_max_temp
dist_argmax[:, i] = torch.index_select(temp_id, 0, dist_argmax_temp)
return dist_max, dist_argmax
def get_a_graph(dists_max, dists_argmax):
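    """Build the anchor-to-node message graph for one collection of anchor sets.

    Returns the edge endpoint lists, where each edge points from a node's
    closest anchor in some anchor set to the node itself, the edge IDs used to
    gather one message per (node, anchor set) pair, and the corresponding
    reciprocal shortest-path distances as edge weights.
    """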
src = []
dst = []
real_src = []
real_dst = []
edge_weight = []
dists_max = dists_max.numpy()
for i in range(dists_max.shape[0]):
# Get unique closest anchor nodes for node i across all anchor sets
        tmp_dists_argmax, tmp_dists_argmax_idx = np.unique(
            dists_argmax[i, :], return_index=True
        )
src.extend([i] * tmp_dists_argmax.shape[0])
real_src.extend([i] * dists_argmax[i, :].shape[0])
real_dst.extend(list(dists_argmax[i, :].numpy()))
dst.extend(list(tmp_dists_argmax))
edge_weight.extend(dists_max[i, tmp_dists_argmax_idx].tolist())
eid_dict = {(u, v): i for i, (u, v) in enumerate(list(zip(dst, src)))}
anchor_eid = [eid_dict.get((u, v)) for u, v in zip(real_dst, real_src)]
g = (dst, src)
return g, anchor_eid, edge_weight
def get_graphs(data, anchor_sets):
graphs = []
anchor_eids = []
dists_max_list = []
edge_weights = []
for anchor_set in tqdm(anchor_sets, leave=False):
dists_max, dists_argmax = get_dist_max(anchor_set, data["dists"])
g, anchor_eid, edge_weight = get_a_graph(dists_max, dists_argmax)
graphs.append(g)
anchor_eids.append(anchor_eid)
dists_max_list.append(dists_max)
edge_weights.append(edge_weight)
return graphs, anchor_eids, dists_max_list, edge_weights
def merge_result(outputs):
graphs = []
anchor_eids = []
dists_max_list = []
edge_weights = []
for g, anchor_eid, dists_max, edge_weight in outputs:
graphs.extend(g)
anchor_eids.extend(anchor_eid)
dists_max_list.extend(dists_max)
edge_weights.extend(edge_weight)
return graphs, anchor_eids, dists_max_list, edge_weights
def preselect_anchor(data, num_workers=4):
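    """Pre-compute anchor sets and their message graphs for 200 epochs in parallel."""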
pool = get_context("spawn").Pool(processes=num_workers)
    # Pre-compute the anchor sets: one collection of anchor sets for each of 200 epochs
anchor_set_ids = [get_anchors(data["num_nodes"]) for _ in range(200)]
interval_size = len(anchor_set_ids) / num_workers
results = [
pool.apply_async(
get_graphs,
args=(
data,
anchor_set_ids[int(interval_size * i) : int(interval_size * (i + 1))],
),
)
for i in range(num_workers)
]
output = [p.get() for p in results]
graphs, anchor_eids, dists_max_list, edge_weights = merge_result(output)
pool.close()
pool.join()
return graphs, anchor_eids, dists_max_list, edge_weights
def get_loss(p, data, out, loss_func, device, get_auc=True):
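    """Compute the link-prediction loss (and optionally the ROC-AUC) on split ``p``.

    Edge scores are dot products between the embeddings of the two endpoints;
    positive edges are labeled 1 and negative edges 0.
    """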
edge_mask = np.concatenate(
(
data[f"positive_edges_{p}"],
data[f"negative_edges_{p}"],
),
axis=-1,
)
nodes_first = torch.index_select(
out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device)
)
nodes_second = torch.index_select(
out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device)
)
pred = torch.sum(nodes_first * nodes_second, dim=-1)
label_positive = torch.ones(
[
data[f"positive_edges_{p}"].shape[1],
],
dtype=pred.dtype,
)
label_negative = torch.zeros(
[
data[f"negative_edges_{p}"].shape[1],
],
dtype=pred.dtype,
)
label = torch.cat((label_positive, label_negative)).to(device)
loss = loss_func(pred, label)
if get_auc:
auc = roc_auc_score(
label.flatten().cpu().numpy(),
torch.sigmoid(pred).flatten().data.cpu().numpy(),
)
return loss, auc
else:
return loss
def train_model(data, model, loss_func, optimizer, device, g_data):
model.train()
out = model(g_data)
loss = get_loss("train", data, out, loss_func, device, get_auc=False)
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
return g_data
def eval_model(data, g_data, model, loss_func, device):
model.eval()
out = model(g_data)
# train loss and auc
tmp_loss, auc_train = get_loss("train", data, out, loss_func, device)
loss_train = tmp_loss.cpu().data.numpy()
# val loss and auc
_, auc_val = get_loss("val", data, out, loss_func, device)
# test loss and auc
_, auc_test = get_loss("test", data, out, loss_func, device)
return loss_train, auc_train, auc_val, auc_test
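

# A minimal end-to-end usage sketch. This block is an illustrative assumption,
# not part of the upstream example's training script; the dataset choice
# (connected caveman graph), hyper-parameters, and epoch count are placeholder
# defaults chosen for demonstration.
if __name__ == "__main__":
    import dgl

    # Synthetic "communities" graph: 20 cliques of 20 nodes, node IDs 0..399.
    nx_graph = nx.connected_caveman_graph(20, 20)
    data = get_dataset(nx_graph)
    # get_communities may return a NumPy identity feature matrix; make it a tensor.
    feature = torch.as_tensor(data["feature"], dtype=torch.float)

    # One collection of anchor sets (and its message graph) per epoch.
    graphs, anchor_eids, dists_max_list, edge_weights = preselect_anchor(data)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PGNN(input_dim=feature.shape[1]).to(device)
    loss_func = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for epoch in range(len(graphs)):
        # Build the DGL graph for this epoch's anchor sets and attach features.
        g = dgl.graph(graphs[epoch])
        g.ndata["feat"] = feature
        g.edata["sp_dist"] = torch.as_tensor(
            edge_weights[epoch], dtype=torch.float
        ).unsqueeze(-1)
        g_data = {
            "graph": g.to(device),
            "anchor_eid": anchor_eids[epoch],
            "dists_max": dists_max_list[epoch],
        }
        train_model(data, model, loss_func, optimizer, device, g_data)
        loss_train, auc_train, auc_val, auc_test = eval_model(
            data, g_data, model, loss_func, device
        )
        print(
            f"Epoch {epoch}: train loss {float(loss_train):.4f}, "
            f"train AUC {auc_train:.4f}, val AUC {auc_val:.4f}, "
            f"test AUC {auc_test:.4f}"
        )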