# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
DiffPool (Differentiable Pooling)

References
----------
Paper: https://arxiv.org/abs/1806.08804
Author's code: https://github.com/RexYing/diffpool
Reference DGL implementation: https://github.com/dmlc/dgl/tree/master/examples/pytorch/diffpool
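
Example
-------
A minimal usage sketch (the batched graph, feature size, and hyper-parameters
below are illustrative assumptions, not taken from an accompanying script):

>>> import dgl
>>> import torch
>>> g = dgl.batch([dgl.rand_graph(10, 30), dgl.rand_graph(12, 40)])
>>> feat = torch.randn(g.num_nodes(), 16)
>>> model = DiffPool(n_in_feats=16, n_out_feats=2, max_n_nodes=12, pool_ratio=0.5)
>>> logits = model(g, feat)  # graph-level predictions of shape (2, 2)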
"""

import dgl
import numpy as np
import torch
import torch.nn.functional as F
from dgl.nn.pytorch import SAGEConv
from torch import nn


class DiffPool(nn.Module):
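    """Batched DiffPool model for graph-level prediction.

    Parameters, as used below: n_in_feats / n_out_feats are the input node-feature
    and output prediction sizes; max_n_nodes together with pool_ratio sets the
    number of clusters of the first pooling step; n_hidden and n_embedding are the
    hidden and embedding sizes of the GraphSAGE blocks; n_layers is the number of
    GraphSAGE layers per block (at least 3); dropout is the feature dropout of the
    SAGEConv layers; n_pooling is the number of DiffPool steps; aggregator_type is
    passed to dgl.nn.SAGEConv; concat controls whether the readouts of all blocks
    are concatenated before the final linear prediction layer.
    """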
def __init__(
self,
n_in_feats,
n_out_feats,
max_n_nodes,
n_hidden=64,
n_embedding=64,
n_layers=3,
dropout=0.0,
n_pooling=1,
aggregator_type="mean",
pool_ratio=0.1,
concat=False,
):
super(DiffPool, self).__init__()
self.link_pred = True
self.concat = concat
self.n_pooling = n_pooling
self.link_pred_loss = []
self.entropy_loss = []
# list of GNN modules before the first diffpool operation
self.gc_before_pool = nn.ModuleList()
self.diffpool_layers = nn.ModuleList()
        # lists of GNN modules, one sub-list applied after each diffpool operation
self.gc_after_pool = nn.ModuleList()
self.assign_dim = int(max_n_nodes * pool_ratio)
self.num_aggs = 1
# constructing layers before diffpool
        assert n_layers >= 3, "n_layers must be at least 3"
self.gc_before_pool.append(
SAGEConv(n_in_feats, n_hidden, aggregator_type, feat_drop=dropout, activation=F.relu)
)
for _ in range(n_layers - 2):
self.gc_before_pool.append(
SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=F.relu)
)
self.gc_before_pool.append(
SAGEConv(n_hidden, n_embedding, aggregator_type, feat_drop=dropout, activation=None)
)
assign_dims = [self.assign_dim]
if self.concat:
            # the diffpool layer receives a pool_embedding_dim node feature tensor
            # and returns a pool_embedding_dim node embedding
pool_embedding_dim = n_hidden * (n_layers - 1) + n_embedding
else:
pool_embedding_dim = n_embedding
self.first_diffpool_layer = _DiffPoolBatchedGraphLayer(
pool_embedding_dim,
self.assign_dim,
n_hidden,
dropout,
aggregator_type,
self.link_pred,
)
gc_after_per_pool = nn.ModuleList()
for _ in range(n_layers - 1):
gc_after_per_pool.append(_BatchedGraphSAGE(n_hidden, n_hidden))
gc_after_per_pool.append(_BatchedGraphSAGE(n_hidden, n_embedding))
self.gc_after_pool.append(gc_after_per_pool)
self.assign_dim = int(self.assign_dim * pool_ratio)
        # remaining pooling modules (the first diffpool layer is built above)
for _ in range(n_pooling - 1):
self.diffpool_layers.append(
_BatchedDiffPool(pool_embedding_dim, self.assign_dim, n_hidden, self.link_pred)
)
gc_after_per_pool = nn.ModuleList()
for _ in range(n_layers - 1):
gc_after_per_pool.append(_BatchedGraphSAGE(n_hidden, n_hidden))
gc_after_per_pool.append(_BatchedGraphSAGE(n_hidden, n_embedding))
self.gc_after_pool.append(gc_after_per_pool)
assign_dims.append(self.assign_dim)
self.assign_dim = int(self.assign_dim * pool_ratio)
        # prediction layer
if self.concat:
self.pred_input_dim = pool_embedding_dim * self.num_aggs * (n_pooling + 1)
else:
self.pred_input_dim = n_embedding * self.num_aggs
self.pred_layer = nn.Linear(self.pred_input_dim, n_out_feats)
# weight initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain("relu"))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, g, feat):
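        """Run the model on a batched graph.

        g is a batched DGLGraph and feat its node-feature tensor of shape
        (total_num_nodes, n_in_feats); returns graph-level predictions of
        shape (batch_size, n_out_feats).
        """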
self.link_pred_loss = []
self.entropy_loss = []
h = feat
# node feature for assignment matrix computation is the same as the
# original node feature
out_all = []
        # run the GraphSAGE blocks before pooling to get an initial node embedding
g_embedding = _gcn_forward(g, h, self.gc_before_pool, self.concat)
g.ndata["h"] = g_embedding
readout = dgl.sum_nodes(g, "h")
out_all.append(readout)
if self.num_aggs == 2:
readout = dgl.max_nodes(g, "h")
out_all.append(readout)
adj, h = self.first_diffpool_layer(g, g_embedding)
node_per_pool_graph = int(adj.size()[0] / len(g.batch_num_nodes()))
h, adj = _batch2tensor(adj, h, node_per_pool_graph)
h = _gcn_forward_tensorized(h, adj, self.gc_after_pool[0], self.concat)
readout = torch.sum(h, dim=1)
out_all.append(readout)
if self.num_aggs == 2:
readout, _ = torch.max(h, dim=1)
out_all.append(readout)
for i, diffpool_layer in enumerate(self.diffpool_layers):
h, adj = diffpool_layer(h, adj)
h = _gcn_forward_tensorized(h, adj, self.gc_after_pool[i + 1], self.concat)
readout = torch.sum(h, dim=1)
out_all.append(readout)
if self.num_aggs == 2:
readout, _ = torch.max(h, dim=1)
out_all.append(readout)
if self.concat or self.num_aggs > 1:
final_readout = torch.cat(out_all, dim=1)
else:
final_readout = readout
ypred = self.pred_layer(final_readout)
return ypred

    def loss(self, pred, label):
        # cross-entropy on the graph-level predictions, plus the auxiliary
        # link-prediction / entropy losses logged by each pooling layer
criterion = nn.CrossEntropyLoss()
loss = criterion(pred, label)
for _, value in self.first_diffpool_layer.loss_log.items():
loss += value
for diffpool_layer in self.diffpool_layers:
for _, value in diffpool_layer.loss_log.items():
loss += value
return loss


class _BatchedGraphSAGE(nn.Module):
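    """Dense GraphSAGE-style layer.

    Operates on node features of shape (batch, num_nodes, feat) together with a
    dense adjacency matrix of shape (batch, num_nodes, num_nodes), as used after
    the first pooling step where the coarsened graphs are no longer DGLGraphs.
    """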
def __init__(self, n_feat_in, n_feat_out, mean=False, add_self=False):
super().__init__()
self.bn = None
self.add_self = add_self
self.mean = mean
self.w = nn.Linear(n_feat_in, n_feat_out, bias=True)
nn.init.xavier_uniform_(self.w.weight, gain=nn.init.calculate_gain("relu"))

    def forward(self, x, adj):
        num_node_per_graph = adj.size(1)
        # build the batch-norm layer lazily (its size depends on the number of
        # pooled nodes); creating it only once keeps its parameters across calls
        if self.bn is None:
            self.bn = nn.BatchNorm1d(num_node_per_graph).to(adj.device)
if self.add_self:
adj = adj + torch.eye(num_node_per_graph).to(adj.device)
if self.mean:
adj = adj / adj.sum(-1, keepdim=True)
h_k_n = torch.matmul(adj, x)
h_k = self.w(h_k_n)
h_k = F.normalize(h_k, dim=2, p=2)
h_k = F.relu(h_k)
if self.bn is not None:
h_k = self.bn(h_k)
return h_k


class _DiffPoolAssignment(nn.Module):
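    """Computes the soft cluster-assignment matrix S with a dense GraphSAGE layer
    followed by a row-wise softmax over the clusters."""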
def __init__(self, n_feat, n_next):
super().__init__()
self.assign_mat = _BatchedGraphSAGE(n_feat, n_next)

    def forward(self, x, adj):
s_l_init = self.assign_mat(x, adj)
s_l = F.softmax(s_l_init, dim=-1)
return s_l


class _BatchedDiffPool(nn.Module):
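    """One DiffPool step on dense tensors: embeds and assigns the current nodes,
    then coarsens both the node features and the adjacency matrix. The
    link-prediction and entropy regularization terms (when enabled) are recorded
    in self.loss_log."""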
def __init__(self, n_feat, n_next, n_hid, link_pred=False, entropy=True):
super(_BatchedDiffPool, self).__init__()
self.link_pred = link_pred
self.link_pred_layer = _LinkPredLoss()
self.embed = _BatchedGraphSAGE(n_feat, n_hid)
self.assign = _DiffPoolAssignment(n_feat, n_next)
self.reg_loss = nn.ModuleList([])
self.loss_log = {}
if link_pred:
self.reg_loss.append(_LinkPredLoss())
if entropy:
self.reg_loss.append(_EntropyLoss())

    def forward(self, x, adj):
z_l = self.embed(x, adj)
s_l = self.assign(x, adj)
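        # coarsen features and adjacency with the soft assignments:
        # X' = S^T Z, A' = S^T A S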
x_next = torch.matmul(s_l.transpose(-1, -2), z_l)
a_next = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l)
for loss_layer in self.reg_loss:
loss_name = str(type(loss_layer).__name__)
self.loss_log[loss_name] = loss_layer(adj, a_next, s_l)
return x_next, a_next


class _DiffPoolBatchedGraphLayer(nn.Module):
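    """The first DiffPool step, applied to a batched DGLGraph.

    Uses two SAGEConv branches: one to embed the nodes and one to produce the
    soft assignment matrix. The batched graph is pooled into a block-diagonal
    dense adjacency matrix of size (batch_size * assign_dim) squared.
    """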
def __init__(
self,
input_dim,
assign_dim,
output_feat_dim,
dropout,
aggregator_type,
link_pred,
):
super(_DiffPoolBatchedGraphLayer, self).__init__()
self.embedding_dim = input_dim
self.assign_dim = assign_dim
self.hidden_dim = output_feat_dim
self.link_pred = link_pred
self.feat_gc = SAGEConv(
input_dim,
output_feat_dim,
aggregator_type,
feat_drop=dropout,
activation=F.relu,
)
self.pool_gc = SAGEConv(
input_dim,
assign_dim,
aggregator_type,
feat_drop=dropout,
activation=F.relu,
)
self.reg_loss = nn.ModuleList([])
self.loss_log = {}
self.reg_loss.append(_EntropyLoss())

    def forward(self, g, h):
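        """Pool the batched graph g with node features h; returns the pooled
        dense adjacency matrix and the pooled node features."""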
feat = self.feat_gc(g, h) # size = (sum_N, F_out), sum_N is num of nodes in this batch
device = feat.device
assign_tensor = self.pool_gc(g, h) # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
assign_tensor = F.softmax(assign_tensor, dim=1)
assign_tensor = torch.split(assign_tensor, g.batch_num_nodes().tolist())
assign_tensor = torch.block_diag(*assign_tensor) # size = (sum_N, batch_size * N_a)
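        # coarsen with the assignment matrix: pooled features H' = S^T H,
        # pooled adjacency A' = S^T A S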
h = torch.matmul(torch.t(assign_tensor), feat)
adj = g.adj_external(transpose=True, ctx=device)
adj_dense = adj.to_dense()
adj_new = torch.mm(torch.t(assign_tensor), torch.mm(adj_dense, assign_tensor))
if self.link_pred:
current_lp_loss = torch.norm(adj_dense - torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(
g.num_nodes(), 2
)
self.loss_log["LinkPredLoss"] = current_lp_loss
for loss_layer in self.reg_loss:
loss_name = str(type(loss_layer).__name__)
self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)
return adj_new, h


class _EntropyLoss(nn.Module):
    # Entropy of the soft assignment distribution; returns a scalar tensor.
    def forward(self, adj, a_next, s_l):
entropy = (torch.distributions.Categorical(probs=s_l).entropy()).sum(-1).mean(-1)
assert not torch.isnan(entropy)
return entropy


class _LinkPredLoss(nn.Module):
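    """Link-prediction auxiliary loss: the per-graph Frobenius norm of
    (A - S S^T), normalized by the number of adjacency entries and averaged
    over the batch."""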
def forward(self, adj, a_next, s_l):
link_pred_loss = (adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
return link_pred_loss.mean()


def _batch2tensor(batch_adj, batch_feat, node_per_pool_graph):
    """Split a block-diagonal batched adjacency matrix and the stacked node
    features into per-graph tensors; returns (feat, adj) of shapes
    (batch, N, F) and (batch, N, N), where N = node_per_pool_graph."""
batch_size = int(batch_adj.size()[0] / node_per_pool_graph)
adj_list = []
feat_list = []
for i in range(batch_size):
start = i * node_per_pool_graph
end = (i + 1) * node_per_pool_graph
adj_list.append(batch_adj[start:end, start:end])
feat_list.append(batch_feat[start:end, :])
    adj = torch.stack(adj_list, dim=0)
    feat = torch.stack(feat_list, dim=0)
return feat, adj


def _masked_softmax(matrix, mask, dim=-1, memory_efficient=True, mask_fill_value=-1e32):
    """Softmax along `dim` that ignores positions where `mask` is 0.

    Code snippet contributed by AllenNLP (https://github.com/allenai/allennlp).
    """
if mask is None:
result = torch.nn.functional.softmax(matrix, dim=dim)
else:
mask = mask.float()
while mask.dim() < matrix.dim():
mask = mask.unsqueeze(1)
if not memory_efficient:
result = torch.nn.functional.softmax(matrix * mask, dim=dim)
result = result * mask
result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
else:
            masked_matrix = matrix.masked_fill((1 - mask).bool(), mask_fill_value)
result = torch.nn.functional.softmax(masked_matrix, dim=dim)
return result


def _gcn_forward(g, h, gc_layers, cat=False):
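    """Apply a stack of SAGEConv layers to a DGLGraph; if cat is True, return the
    concatenation of every layer's output instead of only the last one."""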
block_readout = []
for gc_layer in gc_layers[:-1]:
h = gc_layer(g, h)
block_readout.append(h)
h = gc_layers[-1](g, h)
block_readout.append(h)
if cat:
block = torch.cat(block_readout, dim=1) # N x F, F = F1 + F2 + ...
else:
block = h
return block


def _gcn_forward_tensorized(h, adj, gc_layers, cat=False):
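    """Same as _gcn_forward but for dense (batch, N, F) node features and a dense
    (batch, N, N) adjacency matrix, using the batched GraphSAGE layers."""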
block_readout = []
for gc_layer in gc_layers:
h = gc_layer(h, adj)
block_readout.append(h)
if cat:
        block = torch.cat(block_readout, dim=2)  # B x N x F, F = F1 + F2 + ...
else:
block = h
return block