# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C0103,R1705,R1734
"""
Bootstrapped Graph Latents (BGRL)
References
----------
Paper: https://arxiv.org/abs/2102.06514
Author's code: https://github.com/nerdslab/bgrl
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/bgrl
"""
import copy
import torch
from torch import nn
from torch.nn import BatchNorm1d
from torch.nn.functional import cosine_similarity
import dgl
from dgl.nn.pytorch.conv import GraphConv
from dgl.transforms import Compose, DropEdge, FeatMask
import numpy as np


class MLP_Predictor(nn.Module):
r"""MLP used for predictor. The MLP has one hidden layer.
Args:
input_size (int): Size of input features.
output_size (int): Size of output features.
        hidden_size (int, optional): Size of hidden layer. (default: :obj:`512`).
"""
def __init__(self, input_size, output_size, hidden_size=512):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_size, hidden_size, bias=True),
nn.PReLU(1),
nn.Linear(hidden_size, output_size, bias=True),
)
self.reset_parameters()
def forward(self, x):
return self.net(x)
def reset_parameters(self):
        # Re-initialize all Linear layers (PyTorch's default Kaiming-uniform scheme).
for m in self.modules():
if isinstance(m, nn.Linear):
m.reset_parameters()
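
# Example (illustrative sizes, not from the reference implementation): a predictor
# mapping 128-d online projections to 128-d predictions of the target projections.
#     predictor = MLP_Predictor(input_size=128, output_size=128)
#     q = predictor(torch.randn(4, 128))  # -> shape (4, 128)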


class GCN(nn.Module):
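    r"""GCN encoder: a stack of ``GraphConv -> BatchNorm1d -> PReLU`` blocks.

    Args:
        layer_sizes (list of int): Feature sizes of the layers, e.g. ``[in_dim, 256, 128]``.
        batch_norm_mm (float, optional): Momentum of the batch-norm layers. (default: :obj:`0.99`).
    """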
def __init__(self, layer_sizes, batch_norm_mm=0.99):
super(GCN, self).__init__()
self.layers = nn.ModuleList()
for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]):
self.layers.append(GraphConv(in_dim, out_dim))
self.layers.append(BatchNorm1d(out_dim, momentum=batch_norm_mm))
self.layers.append(nn.PReLU())
def forward(self, g, feats):
x = feats
for layer in self.layers:
if isinstance(layer, GraphConv):
x = layer(g, x)
else:
x = layer(x)
return x
def reset_parameters(self):
for layer in self.layers:
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
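
# Example (illustrative sizes): a two-layer encoder run on a small random graph.
#     g = dgl.add_self_loop(dgl.rand_graph(10, 30))
#     encoder = GCN([16, 32, 8])
#     h = encoder(g, torch.randn(10, 16))  # -> shape (10, 8)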


class BGRL(nn.Module):
r"""BGRL architecture for Graph representation learning.
Args:
encoder (torch.nn.Module): Encoder network to be duplicated and used in both online and target networks.
predictor (torch.nn.Module): Predictor network used to predict the target projection from the online projection.
.. note::
`encoder` must have a `reset_parameters` method, as the weights of the target network will be initialized
differently from the online network.
"""
def __init__(self, encoder, predictor):
super(BGRL, self).__init__()
# online network
self.online_encoder = encoder
self.predictor = predictor
# target network
self.target_encoder = copy.deepcopy(encoder)
# reinitialize weights
self.target_encoder.reset_parameters()
# stop gradient
for param in self.target_encoder.parameters():
param.requires_grad = False
def trainable_parameters(self):
r"""Returns the parameters that will be updated via an optimizer."""
return list(self.online_encoder.parameters()) + list(
self.predictor.parameters()
)
@torch.no_grad()
def update_target_network(self, mm):
r"""Performs a momentum update of the target network's weights.
Args:
mm (float): Momentum used in moving average update.
"""
for param_q, param_k in zip(
self.online_encoder.parameters(), self.target_encoder.parameters()
):
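            # Exponential moving average: theta_target <- mm * theta_target + (1 - mm) * theta_online.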
param_k.data.mul_(mm).add_(param_q.data, alpha=1.0 - mm)
def forward(self, graph, feat):
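        # Two stochastically augmented views of the input graph: each view drops a
        # different fraction of edges and masks a different fraction of node features.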
transform_1 = get_graph_drop_transform(
drop_edge_p=0.3, feat_mask_p=0.3
)
transform_2 = get_graph_drop_transform(
drop_edge_p=0.2, feat_mask_p=0.4
)
online_x = transform_1(graph)
target_x = transform_2(graph)
online_x, target_x = dgl.add_self_loop(online_x), dgl.add_self_loop(target_x)
online_feats, target_feats = online_x.ndata["feat"], target_x.ndata["feat"]
# forward online network
online_y1 = self.online_encoder(online_x, online_feats)
# prediction
online_q1 = self.predictor(online_y1)
# forward target network
with torch.no_grad():
target_y1 = self.target_encoder(target_x, target_feats).detach()
# forward online network 2
online_y2 = self.online_encoder(target_x, target_feats)
# prediction
online_q2 = self.predictor(online_y2)
# forward target network
with torch.no_grad():
target_y2 = self.target_encoder(online_x, online_feats).detach()
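        # Symmetric BGRL objective: maximize the cosine similarity between each online
        # prediction and the (stop-gradient) target projection of the other view,
        # i.e. minimize 2 - cos(q1, sg(y1')) - cos(q2, sg(y2')).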
loss = (
2
- cosine_similarity(online_q1, target_y1.detach(), dim=-1).mean() # pylint: disable=E1102
- cosine_similarity(online_q2, target_y2.detach(), dim=-1).mean() # pylint: disable=E1102
)
return loss
def get_embedding(self, graph, feats):
"""
Get the node embeddings from the encoder without computing gradients.
Parameters
----------
graph : dgl.DGLGraph
The input graph.
feats : torch.Tensor
Node features.
Returns
-------
torch.Tensor
Node embeddings.
"""
h = self.target_encoder(graph, feats) # Encode the node features with GCN
return h.detach() # Detach from computation graph for evaluation


def compute_representations(net, dataset, device):
    r"""Pre-computes representations for the entire dataset.

    Args:
        net (torch.nn.Module): Model that maps a ``dgl.DGLGraph`` to node
            representations; it is expected to read node features from the
            graph's ``ndata`` itself.
        dataset: Iterable of ``dgl.DGLGraph`` objects carrying ``ndata["label"]``.
        device: Device on which the forward passes are run.

    Returns:
        [torch.Tensor, torch.Tensor]: Representations and labels.
    """
net.eval()
reps = []
labels = []
if len(dataset) == 1:
g = dataset[0]
g = dgl.add_self_loop(g)
g = g.to(device)
with torch.no_grad():
reps.append(net(g))
labels.append(g.ndata["label"])
else:
for g in dataset:
# forward
g = g.to(device)
with torch.no_grad():
reps.append(net(g))
labels.append(g.ndata["label"])
reps = torch.cat(reps, dim=0)
labels = torch.cat(labels, dim=0)
return [reps, labels]


class CosineDecayScheduler:
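    r"""Linear warmup followed by cosine decay to zero.

    In BGRL this kind of schedule is typically used both for the learning rate
    and for (one minus) the target-network momentum.

    Args:
        max_val (float): Peak value reached at the end of warmup.
        warmup_steps (int): Number of linear warmup steps.
        total_steps (int): Total number of steps; after warmup the value decays
            to zero following a cosine curve.
    """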
def __init__(self, max_val, warmup_steps, total_steps):
self.max_val = max_val
self.warmup_steps = warmup_steps
self.total_steps = total_steps
def get(self, step):
if step < self.warmup_steps:
return self.max_val * step / self.warmup_steps
elif self.warmup_steps <= step <= self.total_steps:
return (
self.max_val
* (
1
+ np.cos(
(step - self.warmup_steps)
* np.pi
/ (self.total_steps - self.warmup_steps)
)
)
/ 2
)
else:
raise ValueError(
f"Step ({step}) > total number of steps ({self.total_steps})."
)
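
# Example use of CosineDecayScheduler (illustrative values):
#     sched = CosineDecayScheduler(max_val=1.0, warmup_steps=2, total_steps=10)
#     sched.get(1)   # 0.5  (linear warmup)
#     sched.get(2)   # 1.0  (peak, start of cosine decay)
#     sched.get(6)   # 0.5  (decay midpoint)
#     sched.get(10)  # 0.0  (end of schedule)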


def get_graph_drop_transform(drop_edge_p, feat_mask_p):
transforms = list()
# make copy of graph
transforms.append(copy.deepcopy)
# drop edges
if drop_edge_p > 0.0:
transforms.append(DropEdge(drop_edge_p))
# drop features
if feat_mask_p > 0.0:
transforms.append(FeatMask(feat_mask_p, node_feat_names=["feat"]))
return Compose(transforms)
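

# ---------------------------------------------------------------------------
# Minimal end-to-end usage sketch (not part of the reference implementation).
# The dataset (Cora), optimizer, learning rate, momentum and step counts below
# are illustrative assumptions; see the DGL example linked in the module
# docstring for the settings actually used.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dgl.data import CoraGraphDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Single-graph node-classification dataset; features and labels live in ndata.
    graph = CoraGraphDataset()[0].to(device)
    feat_size = graph.ndata["feat"].shape[1]

    # Online encoder (duplicated internally as the target encoder) and predictor.
    encoder = GCN([feat_size, 256, 128])
    predictor = MLP_Predictor(128, 128)
    model = BGRL(encoder, predictor).to(device)

    optimizer = torch.optim.AdamW(
        model.trainable_parameters(), lr=5e-4, weight_decay=1e-5
    )

    total_steps = 100
    lr_scheduler = CosineDecayScheduler(5e-4, 10, total_steps)
    mm_scheduler = CosineDecayScheduler(1 - 0.99, 0, total_steps)

    for step in range(1, total_steps + 1):
        # Scheduled learning rate and target-network momentum for this step.
        lr = lr_scheduler.get(step - 1)
        for group in optimizer.param_groups:
            group["lr"] = lr
        mm = 1 - mm_scheduler.get(step - 1)

        model.train()
        optimizer.zero_grad()
        loss = model(graph, graph.ndata["feat"])
        loss.backward()
        optimizer.step()
        model.update_target_network(mm)

    # Frozen embeddings from the target encoder for downstream evaluation
    # (e.g. training a linear classifier on top of them).
    model.eval()
    embeddings = model.get_embedding(dgl.add_self_loop(graph), graph.ndata["feat"])
    print(embeddings.shape)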