# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Copyright (c) 2024 by jinsong, All Rights Reserved.
# pylint: disable=E1101
"""
auto-regressive moving average (ARMA)
References
----------
Paper: https://arxiv.org/abs/1901.01343
Author's code:
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/arma
"""
import math

import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F


def glorot(tensor):
    # Glorot (Xavier) uniform initialization.
if tensor is not None:
stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    # Zero initialization (used for the bias).
if tensor is not None:
tensor.data.fill_(0)


class ARMAConv(nn.Module):
    """ARMA graph convolutional layer: K parallel stacks of T propagation steps each."""

def __init__(
self,
in_dim,
out_dim,
num_stacks,
num_layers,
activation=None,
dropout=0.0,
bias=True,
):
super(ARMAConv, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
        self.K = num_stacks  # number of parallel stacks
        self.T = num_layers  # propagation steps per stack
self.activation = activation
self.dropout = nn.Dropout(p=dropout)
        # Weights for the first propagation step (t = 0).
self.w_0 = nn.ModuleDict(
{str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)}
)
        # Shared weights for the remaining propagation steps (t > 0).
self.w = nn.ModuleDict(
{str(k): nn.Linear(out_dim, out_dim, bias=False) for k in range(self.K)}
)
        # Skip-connection weights applied to the input features.
self.v = nn.ModuleDict(
{str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)}
)
# bias
if bias:
self.bias = nn.Parameter(torch.Tensor(self.K, self.T, 1, self.out_dim))
else:
self.register_parameter("bias", None)
self.reset_parameters()

    def reset_parameters(self):
for k in range(self.K):
glorot(self.w_0[str(k)].weight)
glorot(self.w[str(k)].weight)
glorot(self.v[str(k)].weight)
zeros(self.bias)

    def forward(self, g, feats):
with g.local_scope():
init_feats = feats
            # Assume the graph is undirected, so in_degrees() equals out_degrees().
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
output = []
for k in range(self.K):
feats = init_feats
for t in range(self.T):
                    # Propagate: D^-1/2 A D^-1/2 X (symmetric normalization).
                    feats = feats * norm
                    g.ndata["h"] = feats
                    g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                    feats = g.ndata.pop("h")
                    feats = feats * norm
if t == 0:
feats = self.w_0[str(k)](feats)
else:
feats = self.w[str(k)](feats)
                    # Skip connection from the input features (with dropout),
                    # applied once per step; the original duplicated this line,
                    # which added the skip term twice.
                    feats += self.dropout(self.v[str(k)](init_feats))
if self.bias is not None:
feats += self.bias[k][t]
if self.activation is not None:
feats = self.activation(feats)
output.append(feats)
            # Average the outputs of the K parallel stacks.
            return torch.stack(output).mean(dim=0)


class ARMA4NC(nn.Module):
    """Two-layer ARMA network for node classification."""

def __init__(
self,
in_dim,
hid_dim,
out_dim,
num_stacks,
num_layers,
activation=None,
dropout=0.0,
):
super(ARMA4NC, self).__init__()
self.conv1 = ARMAConv(
in_dim=in_dim,
out_dim=hid_dim,
num_stacks=num_stacks,
num_layers=num_layers,
activation=activation,
dropout=dropout,
)
self.conv2 = ARMAConv(
in_dim=hid_dim,
out_dim=out_dim,
num_stacks=num_stacks,
num_layers=num_layers,
activation=activation,
dropout=dropout,
)
self.dropout = nn.Dropout(p=dropout)
self.criterion = nn.CrossEntropyLoss()

    def forward(self, g, feats):
feats = F.relu(self.conv1(g, feats))
feats = self.dropout(feats)
feats = self.conv2(g, feats)
return feats

    def loss(self, logits, labels):
return self.criterion(logits, labels)

    def inference(self, graph, feats):
        # Full-graph inference simply reuses the forward pass.
return self.forward(graph, feats)
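

if __name__ == "__main__":
    # Minimal smoke test -- a sketch for quick verification, not part of the
    # referenced DGL example. The graph size, feature dimension, and
    # hyperparameters below are arbitrary illustrative assumptions.
    import dgl

    num_nodes, in_dim, num_classes = 10, 16, 3
    g = dgl.rand_graph(num_nodes, 40)
    # The layer assumes an undirected graph (in-degrees == out-degrees), so
    # add reverse edges; self-loops guard against zero-degree nodes.
    g = dgl.add_self_loop(dgl.add_reverse_edges(g))
    feats = torch.randn(num_nodes, in_dim)
    labels = torch.randint(0, num_classes, (num_nodes,))

    model = ARMA4NC(
        in_dim=in_dim,
        hid_dim=32,
        out_dim=num_classes,
        num_stacks=2,
        num_layers=1,
        activation=F.relu,
        dropout=0.5,
    )
    logits = model(g, feats)
    print("logits shape:", tuple(logits.shape))  # expected: (10, 3)
    print("loss:", model.loss(logits, labels).item())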