blob: c455cc1f9fc19fbfc1469fb04dec26d484c1b63f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C0103
"""
Deep Adaptive Graph Neural Network (DAGNN)
References
----------
Paper: https://arxiv.org/abs/2007.09296
Author's code: https://github.com/divelab/DeeperGNN
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/dagnn
"""
import dgl.function as fn
import torch
from torch import nn
from torch.nn import functional as F, Parameter
class DAGNNConv(nn.Module):
    """Adaptive propagation layer of DAGNN.

    Propagates node features ``k`` times with the in-degree normalised
    adjacency ``D^{-1/2} A D^{-1/2}``. It then combines the ``k + 1`` hop
    representations (hop 0 is the input itself) with learned per-node
    retention scores ``sigmoid(H s)``.

    Parameters
    ----------
    in_dim : int
        Feature dimension of the node representations (input and output).
    k : int
        Number of propagation steps.
    """

    def __init__(self, in_dim, k):
        super(DAGNNConv, self).__init__()
        # Projection vector that scores every hop representation of a node.
        self.s = Parameter(torch.FloatTensor(in_dim, 1))
        self.k = k
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``s`` with Xavier-uniform using the sigmoid gain."""
        gain = nn.init.calculate_gain("sigmoid")
        nn.init.xavier_uniform_(self.s, gain=gain)

    @staticmethod
    def _degree_norm(degs):
        """Return ``deg ** -0.5`` as an ``(N, 1)`` column, treating degree 0 as 1.

        Without the clamp, ``0 ** -0.5 == inf``. A node with no incoming edges
        would then inject inf/NaN into its neighbours during propagation.
        """
        return torch.pow(degs.float().clamp(min=1), -0.5).unsqueeze(1)

    def _adaptive_combine(self, results):
        """Weight and sum the per-hop representations.

        ``results`` is a list of ``k + 1`` tensors of shape ``(N, in_dim)``.
        Returns a tensor of shape ``(N, in_dim)``.
        """
        H = torch.stack(results, dim=1)  # (N, k+1, D)
        S = torch.sigmoid(torch.matmul(H, self.s))  # (N, k+1, 1)
        S = S.permute(0, 2, 1)  # (N, 1, k+1)
        # Squeeze only the singleton hop axis, so a single-node graph
        # (N == 1) or a 1-dim feature (D == 1) keeps its dimension.
        return torch.matmul(S, H).squeeze(1)

    def forward(self, graph, feats):
        """Run ``k`` normalised propagation steps and adaptively pool them.

        Parameters
        ----------
        graph :
            Graph providing ``local_scope``, ``in_degrees``, ``ndata`` and
            ``update_all``; presumably a ``DGLGraph``.
        feats : torch.Tensor
            Node features of shape ``(N, in_dim)``.

        Returns
        -------
        torch.Tensor
            Pooled node representations of shape ``(N, in_dim)``.
        """
        with graph.local_scope():
            results = [feats]
            norm = self._degree_norm(graph.in_degrees()).to(feats.device)
            for _ in range(self.k):
                # Source-side normalisation, sum-aggregate, destination-side
                # normalisation: one step of D^{-1/2} A D^{-1/2}.
                feats = feats * norm
                graph.ndata["h"] = feats
                graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))  # pylint: disable=E1101
                feats = graph.ndata["h"]
                feats = feats * norm
                results.append(feats)
            return self._adaptive_combine(results)
class MLPLayer(nn.Module):
    """A single dense layer: dropout on the input, linear map, optional activation.

    Parameters
    ----------
    in_dim : int
        Input feature size.
    out_dim : int
        Output feature size.
    bias : bool
        Whether the linear map has a bias term.
    activation : callable or None
        Applied to the linear output when truthy. ``F.relu`` also selects
        the ReLU gain for weight initialisation.
    dropout : float
        Dropout probability applied to the layer input.
    """

    def __init__(self, in_dim, out_dim, bias=True, activation=None, dropout=0):
        super(MLPLayer, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=bias)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights (ReLU gain only for ``F.relu``), zero bias."""
        uses_relu = self.activation is F.relu
        gain = nn.init.calculate_gain("relu") if uses_relu else 1.0
        nn.init.xavier_uniform_(self.linear.weight, gain=gain)
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, feats):
        """Return ``activation(linear(dropout(feats)))``, or skip the activation."""
        projected = self.linear(self.dropout(feats))
        return self.activation(projected) if self.activation else projected
class DAGNN(nn.Module):
    """DAGNN model: a two-layer MLP followed by a ``DAGNNConv`` propagation.

    Parameters
    ----------
    k : int
        Number of propagation steps in ``DAGNNConv``.
    in_dim, hid_dim, out_dim : int
        Input, hidden and output feature sizes of the MLP.
    bias : bool
        Whether both MLP layers use a bias term.
    activation : callable or None
        Activation of the first MLP layer only. The second layer is linear.
    dropout : float
        Dropout probability on the input of each MLP layer.
    """

    def __init__(
        self,
        k,
        in_dim,
        hid_dim,
        out_dim,
        bias=True,
        activation=F.relu,
        dropout=0,
    ):
        super(DAGNN, self).__init__()
        # (input size, output size, activation) of each MLP layer, in order.
        layer_specs = [
            (in_dim, hid_dim, activation),
            (hid_dim, out_dim, None),
        ]
        self.mlp = nn.ModuleList(
            [
                MLPLayer(
                    in_dim=src_dim,
                    out_dim=dst_dim,
                    bias=bias,
                    activation=act,
                    dropout=dropout,
                )
                for src_dim, dst_dim, act in layer_specs
            ]
        )
        self.dagnn = DAGNNConv(in_dim=out_dim, k=k)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, graph, feats):
        """Transform ``feats`` with the MLP, then propagate over ``graph``."""
        hidden = feats
        for mlp_layer in self.mlp:
            hidden = mlp_layer(hidden)
        return self.dagnn(graph, hidden)

    def loss(self, logits, labels):
        """Cross-entropy between ``logits`` and integer class ``labels``."""
        return self.criterion(logits, labels)

    def inference(self, graph, feats):
        """Alias for ``forward`` (gradients are not disabled here)."""
        return self.forward(graph, feats)