# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=R1719,C0103,R0205,R1721,R1705,W0612
"""
SEAL
References
----------
Paper: https://arxiv.org/abs/1802.09691
Author's code: https://github.com/muhanzhang/SEAL
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/seal
"""
import argparse
import os
import os.path as osp
from copy import deepcopy
import logging
import time
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import dgl
from dgl import add_self_loop, NID
from dgl.dataloading.negative_sampler import Uniform
from dgl.nn.pytorch import GraphConv, SAGEConv, SortPooling, SumPooling
import numpy as np
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
from tqdm import tqdm
class GCN(nn.Module):
"""
GCN Model
Attributes:
num_layers(int): num of gcn layers
hidden_units(int): num of hidden units
gcn_type(str): type of gcn layer, 'gcn' for GraphConv and 'sage' for SAGEConv
        pooling_type(str): type of graph pooling to get subgraph representation;
            only 'sum' (sum pooling) is supported.
node_attributes(Tensor, optional): node attribute
edge_weights(Tensor, optional): edge weight
node_embedding(Tensor, optional): pre-trained node embedding
        use_embedding(bool, optional): whether to use node embedding. Note that if 'use_embedding' is True
            and 'node_embedding' is None, a node embedding will be randomly initialized automatically.
num_nodes(int, optional): num of nodes
dropout(float, optional): dropout rate
        max_z(int, optional): max vocab size of node labeling, default 1000.
"""
def __init__(
self,
num_layers,
hidden_units,
gcn_type="gcn",
pooling_type="sum",
node_attributes=None,
edge_weights=None,
node_embedding=None,
use_embedding=False,
num_nodes=None,
dropout=0.5,
max_z=1000,
):
super(GCN, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
self.pooling_type = pooling_type
        self.use_attribute = node_attributes is not None
        self.use_embedding = use_embedding
        self.use_edge_weight = edge_weights is not None
self.z_embedding = nn.Embedding(max_z, hidden_units)
if node_attributes is not None:
self.node_attributes_lookup = nn.Embedding.from_pretrained(node_attributes)
self.node_attributes_lookup.weight.requires_grad = False
if edge_weights is not None:
self.edge_weights_lookup = nn.Embedding.from_pretrained(edge_weights)
self.edge_weights_lookup.weight.requires_grad = False
if node_embedding is not None:
self.node_embedding = nn.Embedding.from_pretrained(node_embedding)
self.node_embedding.weight.requires_grad = False
elif use_embedding:
self.node_embedding = nn.Embedding(num_nodes, hidden_units)
initial_dim = hidden_units
if self.use_attribute:
initial_dim += self.node_attributes_lookup.embedding_dim
if self.use_embedding:
initial_dim += self.node_embedding.embedding_dim
self.layers = nn.ModuleList()
if gcn_type == "gcn":
self.layers.append(
GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)
)
for _ in range(num_layers - 1):
self.layers.append(
GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True)
)
elif gcn_type == "sage":
self.layers.append(
SAGEConv(initial_dim, hidden_units, aggregator_type="gcn")
)
for _ in range(num_layers - 1):
self.layers.append(
SAGEConv(hidden_units, hidden_units, aggregator_type="gcn")
)
else:
raise ValueError("Gcn type error.")
self.linear_1 = nn.Linear(hidden_units, hidden_units)
self.linear_2 = nn.Linear(hidden_units, 1)
if pooling_type != "sum":
raise ValueError("Pooling type error.")
self.pooling = SumPooling()
def reset_parameters(self):
for layer in self.layers:
layer.reset_parameters()
def forward(self, g, z, node_id=None, edge_id=None):
"""
Args:
g(DGLGraph): the graph
z(Tensor): node labeling tensor, shape [N, 1]
node_id(Tensor, optional): node id tensor, shape [N, 1]
edge_id(Tensor, optional): edge id tensor, shape [E, 1]
Returns:
x(Tensor): output tensor
"""
z_emb = self.z_embedding(z)
if self.use_attribute:
x = self.node_attributes_lookup(node_id)
x = torch.cat([z_emb, x], 1)
else:
x = z_emb
if self.use_edge_weight:
edge_weight = self.edge_weights_lookup(edge_id)
else:
edge_weight = None
if self.use_embedding:
n_emb = self.node_embedding(node_id)
x = torch.cat([x, n_emb], 1)
for layer in self.layers[:-1]:
x = layer(g, x, edge_weight=edge_weight)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.layers[-1](g, x, edge_weight=edge_weight)
x = self.pooling(g, x)
x = F.relu(self.linear_1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
x = self.linear_2(x)
return x
class DGCNN(nn.Module):
"""
An end-to-end deep learning architecture for graph classification.
paper link: https://muhanzhang.github.io/papers/AAAI_2018_DGCNN.pdf
Attributes:
num_layers(int): num of gcn layers
hidden_units(int): num of hidden units
k(int, optional): The number of nodes to hold for each graph in SortPooling.
gcn_type(str): type of gcn layer, 'gcn' for GraphConv and 'sage' for SAGEConv
node_attributes(Tensor, optional): node attribute
edge_weights(Tensor, optional): edge weight
node_embedding(Tensor, optional): pre-trained node embedding
        use_embedding(bool, optional): whether to use node embedding. Note that if 'use_embedding' is True
            and 'node_embedding' is None, a node embedding will be randomly initialized automatically.
num_nodes(int, optional): num of nodes
dropout(float, optional): dropout rate
        max_z(int, optional): max vocab size of node labeling, default 1000.
"""
def __init__(
self,
num_layers,
hidden_units,
k=10,
gcn_type="gcn",
node_attributes=None,
edge_weights=None,
node_embedding=None,
use_embedding=False,
num_nodes=None,
dropout=0.5,
max_z=1000,
):
super(DGCNN, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
        self.use_attribute = node_attributes is not None
        self.use_embedding = use_embedding
        self.use_edge_weight = edge_weights is not None
self.z_embedding = nn.Embedding(max_z, hidden_units)
if node_attributes is not None:
self.node_attributes_lookup = nn.Embedding.from_pretrained(node_attributes)
self.node_attributes_lookup.weight.requires_grad = False
if edge_weights is not None:
self.edge_weights_lookup = nn.Embedding.from_pretrained(edge_weights)
self.edge_weights_lookup.weight.requires_grad = False
if node_embedding is not None:
self.node_embedding = nn.Embedding.from_pretrained(node_embedding)
self.node_embedding.weight.requires_grad = False
elif use_embedding:
self.node_embedding = nn.Embedding(num_nodes, hidden_units)
initial_dim = hidden_units
if self.use_attribute:
initial_dim += self.node_attributes_lookup.embedding_dim
if self.use_embedding:
initial_dim += self.node_embedding.embedding_dim
self.layers = nn.ModuleList()
if gcn_type == "gcn":
self.layers.append(
GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)
)
for _ in range(num_layers - 1):
self.layers.append(
GraphConv(hidden_units, hidden_units, allow_zero_in_degree=True)
)
self.layers.append(GraphConv(hidden_units, 1, allow_zero_in_degree=True))
elif gcn_type == "sage":
self.layers.append(
SAGEConv(initial_dim, hidden_units, aggregator_type="gcn")
)
for _ in range(num_layers - 1):
self.layers.append(
SAGEConv(hidden_units, hidden_units, aggregator_type="gcn")
)
self.layers.append(SAGEConv(hidden_units, 1, aggregator_type="gcn"))
else:
raise ValueError("Gcn type error.")
self.pooling = SortPooling(k=k)
conv1d_channels = [16, 32]
total_latent_dim = hidden_units * num_layers + 1
conv1d_kws = [total_latent_dim, 5]
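        # conv_1 has kernel size == stride == total_latent_dim, so it reads the
        # sorted node features one node at a time; conv_2 then convolves across
        # the node axis.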
self.conv_1 = nn.Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
self.maxpool1d = nn.MaxPool1d(2, 2)
self.conv_2 = nn.Conv1d(
conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1
)
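        # Flattened size after max-pooling (k -> (k - 2) / 2 + 1 positions) and
        # conv_2 (len - 5 + 1 positions, conv1d_channels[1] channels).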
dense_dim = int((k - 2) / 2 + 1)
dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
self.linear_1 = nn.Linear(dense_dim, 128)
self.linear_2 = nn.Linear(128, 1)
def forward(self, g, z, node_id=None, edge_id=None):
"""
Args:
g(DGLGraph): the graph
z(Tensor): node labeling tensor, shape [N, 1]
node_id(Tensor, optional): node id tensor, shape [N, 1]
edge_id(Tensor, optional): edge id tensor, shape [E, 1]
Returns:
x(Tensor): output tensor
"""
z_emb = self.z_embedding(z)
if self.use_attribute:
x = self.node_attributes_lookup(node_id)
x = torch.cat([z_emb, x], 1)
else:
x = z_emb
if self.use_edge_weight:
edge_weight = self.edge_weights_lookup(edge_id)
else:
edge_weight = None
if self.use_embedding:
n_emb = self.node_embedding(node_id)
x = torch.cat([x, n_emb], 1)
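        # DGCNN concatenates the outputs of every GNN layer as node features.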
xs = [x]
for layer in self.layers:
out = torch.tanh(layer(g, xs[-1], edge_weight=edge_weight))
xs += [out]
x = torch.cat(xs[1:], dim=-1)
# SortPooling
x = self.pooling(g, x)
x = x.unsqueeze(1)
x = F.relu(self.conv_1(x))
x = self.maxpool1d(x)
x = F.relu(self.conv_2(x))
x = x.view(x.size(0), -1)
x = F.relu(self.linear_1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
x = self.linear_2(x)
return x
def parse_arguments():
"""
Parse arguments
"""
parser = argparse.ArgumentParser(description="SEAL")
parser.add_argument("--dataset", type=str, default="ogbl-ddi")
parser.add_argument("--gpu_id", type=int, default=0)
parser.add_argument("--hop", type=int, default=1)
parser.add_argument("--model", type=str, default="dgcnn")
parser.add_argument("--gcn_type", type=str, default="gcn")
parser.add_argument("--num_layers", type=int, default=3)
parser.add_argument("--hidden_units", type=int, default=32)
parser.add_argument("--sort_k", type=int, default=30)
parser.add_argument("--pooling", type=str, default="sum")
parser.add_argument("--dropout", type=str, default=0.5)
parser.add_argument("--hits_k", type=int, default=50)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--neg_samples", type=int, default=1)
parser.add_argument("--subsample_ratio", type=float, default=0.1)
parser.add_argument("--epochs", type=int, default=60)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_steps", type=int, default=5)
parser.add_argument("--num_workers", type=int, default=32)
parser.add_argument("--random_seed", type=int, default=2021)
parser.add_argument("--save_dir", type=str, default="./processed")
args = parser.parse_args()
return args
def load_ogb_dataset(dataset):
"""
Load OGB dataset
Args:
dataset(str): name of dataset (ogbl-collab, ogbl-ddi, ogbl-citation)
Returns:
graph(DGLGraph): graph
split_edge(dict): split edge
"""
dataset = DglLinkPropPredDataset(name=dataset)
split_edge = dataset.get_edge_split()
graph = dataset[0]
return graph, split_edge
def drnl_node_labeling(subgraph, src, dst):
"""
Double Radius Node labeling
    d = r(i, u) + r(i, v)
    label = 1 + min(r(i, u), r(i, v)) + (d // 2) * (d // 2 + d % 2 - 1)
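    e.g. r(i, u) = 1, r(i, v) = 2: d = 3, label = 1 + 1 + 1 * (1 + 1 - 1) = 3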
    Isolated nodes in the subgraph are labeled zero.
    Extremely large graphs may cause memory errors.
Args:
subgraph(DGLGraph): The graph
        src(int): node id of the source node in the subgraph
        dst(int): node id of the destination node in the subgraph
Returns:
z(Tensor): node labeling tensor
"""
adj = subgraph.adj_external().to_dense().numpy()
src, dst = (dst, src) if src > dst else (src, dst)
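    # Per DRNL, the distance to one target is computed on the subgraph with
    # the other target removed, so paths through it are not counted.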
idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
adj_wo_src = adj[idx, :][:, idx]
idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
adj_wo_dst = adj[idx, :][:, idx]
dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
dist2src = np.insert(dist2src, dst, 0, axis=0)
dist2src = torch.from_numpy(dist2src)
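    # src < dst, so removing src shifts dst's index down by one.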
dist2dst = shortest_path(
adj_wo_src, directed=False, unweighted=True, indices=dst - 1
)
dist2dst = np.insert(dist2dst, src, 0, axis=0)
dist2dst = torch.from_numpy(dist2dst)
dist = dist2src + dist2dst
dist_over_2, dist_mod_2 = dist // 2, dist % 2
z = 1 + torch.min(dist2src, dist2dst)
z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
z[src] = 1.0
z[dst] = 1.0
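    # Unreachable (infinite-distance) nodes produce NaN labels; zero them out.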
z[torch.isnan(z)] = 0.0
return z.to(torch.long)
def evaluate_hits(name, pos_pred, neg_pred, K):
"""
Compute hits
Args:
name(str): name of dataset
pos_pred(Tensor): predict value of positive edges
neg_pred(Tensor): predict value of negative edges
        K(int): the K in hits@K
    Returns:
        hits(float): hits@K score
"""
evaluator = Evaluator(name)
evaluator.K = K
hits = evaluator.eval(
{
"y_pred_pos": pos_pred,
"y_pred_neg": neg_pred,
}
)[f"hits@{K}"]
return hits
class GraphDataSet(Dataset):
"""
GraphDataset for torch DataLoader
"""
def __init__(self, graph_list, tensor):
self.graph_list = graph_list
self.tensor = tensor
def __len__(self):
return len(self.graph_list)
def __getitem__(self, index):
return (self.graph_list[index], self.tensor[index])
class PosNegEdgesGenerator(object):
"""
Generate positive and negative samples
Attributes:
g(dgl.DGLGraph): graph
split_edge(dict): split edge
neg_samples(int): num of negative samples per positive sample
subsample_ratio(float): ratio of subsample
        shuffle(bool): whether to shuffle the generated edges and labels
"""
def __init__(self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=True):
self.neg_sampler = Uniform(neg_samples)
self.subsample_ratio = subsample_ratio
self.split_edge = split_edge
self.g = g
self.shuffle = shuffle
def __call__(self, split_type):
if split_type == "train":
subsample_ratio = self.subsample_ratio
else:
subsample_ratio = 1
pos_edges = self.split_edge[split_type]["edge"]
if split_type == "train":
            # Adding self-loops in training avoids sampling the source node itself.
g = add_self_loop(self.g)
eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1])
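            # Corrupt each positive edge with `neg_samples` uniformly sampled
            # destination nodes.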
neg_edges = torch.stack(self.neg_sampler(g, eids), dim=1)
else:
neg_edges = self.split_edge[split_type]["edge_neg"]
pos_edges = self.subsample(pos_edges, subsample_ratio).long()
neg_edges = self.subsample(neg_edges, subsample_ratio).long()
edges = torch.cat([pos_edges, neg_edges])
labels = torch.cat(
[
torch.ones(pos_edges.size(0), 1),
torch.zeros(neg_edges.size(0), 1),
]
)
if self.shuffle:
perm = torch.randperm(edges.size(0))
edges = edges[perm]
labels = labels[perm]
return edges, labels
def subsample(self, edges, subsample_ratio):
"""
Subsample generated edges.
Args:
edges(Tensor): edges to subsample
subsample_ratio(float): ratio of subsample
Returns:
            edges(Tensor): subsampled edges
"""
num_edges = edges.size(0)
perm = torch.randperm(num_edges)
perm = perm[: int(subsample_ratio * num_edges)]
edges = edges[perm]
return edges
class EdgeDataSet(Dataset):
"""
    Auxiliary Dataset for speeding up the SEALSampler
"""
def __init__(self, edges, labels, transform):
self.edges = edges
self.transform = transform
self.labels = labels
def __len__(self):
return len(self.edges)
def __getitem__(self, index):
subgraph = self.transform(self.edges[index])
return (subgraph, self.labels[index])
class SEALSampler(object):
"""
    Sampler for SEAL as described in the paper (no-block version).
    The strategy is to sample all the k-hop neighbors around the two target nodes.
    Attributes:
        graph(DGLGraph): The graph
        hop(int): num of hops
num_workers(int): num of workers
"""
def __init__(self, graph, hop=1, num_workers=32, print_fn=print):
self.graph = graph
self.hop = hop
self.print_fn = print_fn
self.num_workers = num_workers
def sample_subgraph(self, target_nodes):
"""
Args:
target_nodes(Tensor): Tensor of two target nodes
Returns:
subgraph(DGLGraph): subgraph
"""
sample_nodes = [target_nodes]
frontiers = target_nodes
for _ in range(self.hop):
frontiers = self.graph.out_edges(frontiers)[1]
frontiers = torch.unique(frontiers)
sample_nodes.append(frontiers)
sample_nodes = torch.cat(sample_nodes)
sample_nodes = torch.unique(sample_nodes)
subgraph = dgl.node_subgraph(self.graph, sample_nodes)
        # Target nodes are relabeled in the subgraph; recover their new ids.
u_id = int(
torch.nonzero(subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False)
)
v_id = int(
torch.nonzero(subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False)
)
# remove link between target nodes in positive subgraphs.
if subgraph.has_edges_between(u_id, v_id):
link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2]
subgraph.remove_edges(link_id)
if subgraph.has_edges_between(v_id, u_id):
link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2]
subgraph.remove_edges(link_id)
z = drnl_node_labeling(subgraph, u_id, v_id)
subgraph.ndata["z"] = z
return subgraph
def _collate(self, batch):
batch_graphs, batch_labels = map(list, zip(*batch))
batch_graphs = dgl.batch(batch_graphs)
batch_labels = torch.stack(batch_labels)
return batch_graphs, batch_labels
def __call__(self, edges, labels):
subgraph_list = []
labels_list = []
edge_dataset = EdgeDataSet(edges, labels, transform=self.sample_subgraph)
self.print_fn(f"Using {self.num_workers} workers in sampling job.")
sampler = DataLoader(
edge_dataset,
batch_size=32,
num_workers=self.num_workers,
shuffle=False,
collate_fn=self._collate,
)
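        # Copy labels out of the DataLoader workers' shared memory before
        # accumulating them.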
for subgraph, label in tqdm(sampler, ncols=100):
label_copy = deepcopy(label)
subgraph = dgl.unbatch(subgraph)
del label
subgraph_list += subgraph
labels_list.append(label_copy)
return subgraph_list, torch.cat(labels_list)
class SEALData(object):
"""
1. Generate positive and negative samples
2. Subgraph sampling
Attributes:
g(dgl.DGLGraph): graph
split_edge(dict): split edge
hop(int): num of hop
neg_samples(int): num of negative samples per positive sample
subsample_ratio(float): ratio of subsample
        use_coalesce(bool): whether to coalesce the graph; graphs with multi-edges need to be coalesced
"""
def __init__(
self,
g,
split_edge,
hop=1,
neg_samples=1,
subsample_ratio=1,
prefix=None,
save_dir=None,
num_workers=32,
shuffle=True,
use_coalesce=True,
print_fn=print,
):
self.g = g
self.hop = hop
self.subsample_ratio = subsample_ratio
self.prefix = prefix
self.save_dir = save_dir
self.print_fn = print_fn
self.generator = PosNegEdgesGenerator(
g=self.g,
split_edge=split_edge,
neg_samples=neg_samples,
subsample_ratio=subsample_ratio,
shuffle=shuffle,
)
if use_coalesce:
for k, v in g.edata.items():
                g.edata[k] = v.float()  # dgl.to_simple() requires float edge data
self.g = dgl.to_simple(
g, copy_ndata=True, copy_edata=True, aggregator="sum"
)
        self.ndata = dict(self.g.ndata)
        self.edata = dict(self.g.edata)
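        # Strip features from the graph so sampled subgraphs stay lightweight;
        # the model re-attaches them through id-based embedding lookups.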
self.g.ndata.clear()
self.g.edata.clear()
self.print_fn("Save ndata and edata in class.")
self.print_fn("Clear ndata and edata in graph.")
self.sampler = SEALSampler(
graph=self.g, hop=hop, num_workers=num_workers, print_fn=print_fn
)
def __call__(self, split_type):
if split_type == "train":
subsample_ratio = self.subsample_ratio
else:
subsample_ratio = 1
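        # Cache processed subgraphs on disk, keyed by split, hop and
        # subsample ratio.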
path = osp.join(
self.save_dir or "",
f"{self.prefix}_{split_type}_{self.hop}-hop_{subsample_ratio}-subsample.bin",
)
if osp.exists(path):
self.print_fn(f"Load existing processed {split_type} files")
graph_list, data = dgl.load_graphs(path)
dataset = GraphDataSet(graph_list, data["labels"])
else:
self.print_fn(f"Processed {split_type} files not exist.")
edges, labels = self.generator(split_type)
self.print_fn(f"Generate {edges.size(0)} edges totally.")
graph_list, labels = self.sampler(edges, labels)
dataset = GraphDataSet(graph_list, labels)
dgl.save_graphs(path, graph_list, {"labels": labels})
self.print_fn(f"Save preprocessed subgraph to {path}")
return dataset
def _transform_log_level(str_level):
if str_level == "info":
return logging.INFO
elif str_level == "warning":
return logging.WARNING
elif str_level == "critical":
return logging.CRITICAL
elif str_level == "debug":
return logging.DEBUG
elif str_level == "error":
return logging.ERROR
else:
        raise KeyError(f"Unknown log level: {str_level}")
class LightLogging(object):
def __init__(self, log_path=None, log_name="lightlog", log_level="debug"):
log_level = _transform_log_level(log_level)
if log_path:
if not log_path.endswith("/"):
log_path += "/"
            os.makedirs(log_path, exist_ok=True)
if log_name.endswith("-") or log_name.endswith("_"):
log_name = (
log_path
+ log_name
                + time.strftime("%Y-%m-%d-%H-%M", time.localtime(time.time()))
+ ".log"
)
else:
log_name = (
log_path
+ log_name
+ "_"
+ time.strftime("%Y-%m-%d-%H-%M", time.localtime(time.time()))
+ ".log"
)
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt="%Y-%m-%d-%H:%M",
handlers=[
logging.FileHandler(log_name, mode="w"),
logging.StreamHandler(),
],
)
logging.info("Start Logging")
logging.info("Log file path: %s", log_name)
else:
logging.basicConfig(
level=log_level,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt="%Y-%m-%d-%H:%M",
handlers=[logging.StreamHandler()],
)
logging.info("Start Logging")
def debug(self, msg):
logging.debug(msg)
def info(self, msg):
logging.info(msg)
def critical(self, msg):
logging.critical(msg)
def warning(self, msg):
logging.warning(msg)
def error(self, msg):
logging.error(msg)
def data_prepare(graph, split_edge):
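    # Note: this assumes an ogbl-collab-style graph that carries 'feat' node
    # features and 'weight' edge weights.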
seal_data = SEALData(
g=graph,
split_edge=split_edge,
hop=1,
neg_samples=1,
subsample_ratio=0.1,
use_coalesce=True,
prefix="ogbl-collab",
save_dir="./processed",
num_workers=32,
print_fn=print,
)
node_attribute = seal_data.ndata["feat"]
edge_weight = seal_data.edata["weight"].float()
return node_attribute, edge_weight