# blob: f80230e15788967d01868dfa3d6e2e17940d5c00 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
GRACE (Graph Contrastive Learning)
References
----------
Paper: https://arxiv.org/abs/2006.04131
Author's code: https://github.com/CRIPAC-DIG/GRACE
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/grace
"""
import dgl
import numpy as np
import torch
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
from torch import nn
class GRACE(nn.Module):
    """
    GRACE model for graph representation learning via contrastive learning.

    Two randomly augmented views of the input graph are encoded by a shared
    GCN encoder, projected by an MLP head, and compared with a symmetric
    contrastive loss built on temperature-scaled cosine similarity.

    Parameters
    ----------
    n_in_feats : int
        Number of input features per node.
    n_hidden : int
        Output dimension requested from the GCN encoder; also the input
        dimension of the projection head.
    n_out_feats : int
        Output dimension of the projection head.
    n_layers : int
        Number of GNN layers in the encoder.
    act_fn : nn.Module
        Activation function handed to the encoder.
    temp : float
        Temperature of the contrastive loss; similarities are divided by it
        before exponentiation.
    edges_removing_rate_1 : float
        Edge removal rate used when generating the first view.
    edges_removing_rate_2 : float
        Edge removal rate used when generating the second view.
    feats_masking_rate_1 : float
        Feature masking rate used when generating the first view.
    feats_masking_rate_2 : float
        Feature masking rate used when generating the second view.
    """

    def __init__(
        self,
        n_in_feats,
        n_hidden=128,
        n_out_feats=128,
        n_layers=2,
        # NOTE: the default is created once at definition time and therefore
        # shared between models; kept for interface compatibility.
        act_fn=nn.ReLU(),
        temp=0.4,
        edges_removing_rate_1=0.2,
        edges_removing_rate_2=0.4,
        feats_masking_rate_1=0.3,
        feats_masking_rate_2=0.4,
    ):
        super().__init__()
        # Shared encoder applied to both views.
        self.encoder = GCN(n_in_feats, n_hidden, act_fn, n_layers)
        # Projection head used only for the contrastive loss (see forward()).
        self.proj = MLP(n_hidden, n_out_feats)
        self.temp = temp
        self.edges_removing_rate_1 = edges_removing_rate_1
        self.edges_removing_rate_2 = edges_removing_rate_2
        self.feats_masking_rate_1 = feats_masking_rate_1
        self.feats_masking_rate_2 = feats_masking_rate_2

    @staticmethod
    def sim(z1, z2):
        """
        Compute the pairwise cosine similarity between two sets of embeddings.

        Parameters
        ----------
        z1 : torch.Tensor
            Node embeddings from the first view.
        z2 : torch.Tensor
            Node embeddings from the second view.

        Returns
        -------
        torch.Tensor
            Matrix whose entry ``(i, j)`` is the cosine similarity between
            row ``i`` of ``z1`` and row ``j`` of ``z2``.
        """
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        # Rows are unit-length, so the dot product is the cosine similarity.
        return torch.mm(z1, z2.t())

    def sim_loss(self, z1, z2):
        """
        Compute the per-node contrastive loss of ``z1`` against ``z2``.

        Parameters
        ----------
        z1 : torch.Tensor
            Node embeddings from the anchor view.
        z2 : torch.Tensor
            Node embeddings from the other view.

        Returns
        -------
        torch.Tensor
            One loss value per row of ``z1`` (not reduced).
        """
        # Similarities inside the anchor view (intra-view negatives).
        refl_sim = torch.exp(self.sim(z1, z1) / self.temp)
        # Similarities across views; the diagonal holds the positive pairs.
        between_sim = torch.exp(self.sim(z1, z2) / self.temp)
        # Denominator: every cross-view pair plus every intra-view pair
        # except a node paired with itself (the diagonal of refl_sim).
        x1 = refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()
        loss = -torch.log(between_sim.diag() / x1)
        return loss

    def loss(self, z1, z2):
        """
        Compute the symmetric contrastive loss for both views.

        Parameters
        ----------
        z1 : torch.Tensor
            Node embeddings from the first view.
        z2 : torch.Tensor
            Node embeddings from the second view.

        Returns
        -------
        torch.Tensor
            Scalar: mean over nodes of the average of both loss directions.
        """
        l1 = self.sim_loss(z1=z1, z2=z2)
        l2 = self.sim_loss(z1=z2, z2=z1)
        return (l1 + l2).mean() * 0.5

    def get_embedding(self, graph, feats):
        """
        Get the node embeddings from the encoder without computing gradients.

        The projection head is not applied; the raw encoder output is returned.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph.
        feats : torch.Tensor
            Node features.

        Returns
        -------
        torch.Tensor
            Node embeddings, detached from the autograd graph.
        """
        # Run the encoder with autograd disabled so no graph is recorded
        # only to be thrown away by the detach below.
        with torch.no_grad():
            h = self.encoder(graph, feats)
        return h.detach()

    def forward(self, graph, feats):
        """
        Perform the forward pass and compute the contrastive loss.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph.
        feats : torch.Tensor
            Node features.

        Returns
        -------
        torch.Tensor
            Contrastive loss between two augmented views of the graph.
        """
        # Each view gets its own edge-removal and feature-masking rates.
        graph1, feats1 = _generating_views(
            graph, feats, self.edges_removing_rate_1, self.feats_masking_rate_1
        )
        graph2, feats2 = _generating_views(
            graph, feats, self.edges_removing_rate_2, self.feats_masking_rate_2
        )
        # Same encoder and projector for both views.
        z1 = self.proj(self.encoder(graph1, feats1))
        z2 = self.proj(self.encoder(graph2, feats2))
        loss = self.loss(z1, z2)
        return loss
class GCN(nn.Module):
    """
    Graph Convolutional Network (GCN) for node feature transformation.

    The network is an input layer, ``n_layers - 2`` hidden layers and an
    output layer. Internal layers are ``2 * n_out_feats`` wide, and the
    activation is applied after every layer, including the output layer.

    Parameters
    ----------
    n_in_feats : int
        Number of input features per node.
    n_out_feats : int
        Number of output features per node.
    act_fn : nn.Module
        Activation function passed to every ``GraphConv`` layer.
    n_layers : int
        Total number of GCN layers; must be at least 2.

    Raises
    ------
    ValueError
        If ``n_layers`` is smaller than 2.
    """

    def __init__(self, n_in_feats, n_out_feats, act_fn, n_layers=2):
        super().__init__()
        # Explicit check instead of ``assert`` so the validation survives
        # ``python -O``; otherwise n_layers < 2 would silently build 2 layers.
        if n_layers < 2:
            raise ValueError("Number of layers should be at least 2.")
        self.n_layers = n_layers
        # Internal width is twice the requested output width.
        self.n_hidden = n_out_feats * 2
        self.input_layer = GraphConv(n_in_feats, self.n_hidden, activation=act_fn)
        self.hidden_layers = nn.ModuleList(
            [
                GraphConv(self.n_hidden, self.n_hidden, activation=act_fn)
                for _ in range(n_layers - 2)
            ]
        )
        self.output_layer = GraphConv(self.n_hidden, n_out_feats, activation=act_fn)

    def forward(self, graph, feat):
        """
        Forward pass through the GCN.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph.
        feat : torch.Tensor
            Node features.

        Returns
        -------
        torch.Tensor
            Node features after the input, hidden and output layers.
        """
        feat = self.input_layer(graph, feat)
        for hidden_layer in self.hidden_layers:
            feat = hidden_layer(graph, feat)
        return self.output_layer(graph, feat)
class MLP(nn.Module):
    """
    Two-layer perceptron that projects node embeddings into a new space.

    Parameters
    ----------
    n_in_feats : int
        Size of each input embedding.
    n_out_feats : int
        Size of each projected embedding; also the width of the hidden layer.
    """

    def __init__(self, n_in_feats, n_out_feats):
        super().__init__()
        # Input -> hidden projection.
        self.fc1 = nn.Linear(n_in_feats, n_out_feats)
        # Hidden -> output projection (same width on both sides).
        self.fc2 = nn.Linear(n_out_feats, n_out_feats)

    def forward(self, x):
        """
        Project the given embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input node embeddings.

        Returns
        -------
        torch.Tensor
            ``fc2(elu(fc1(x)))``; no activation follows the second layer.
        """
        hidden = self.fc1(x)
        hidden = F.elu(hidden)
        projected = self.fc2(hidden)
        return projected
def _generating_views(graph, feats, edges_removing_rate, feats_masking_rate):
    """
    Build one augmented view of the graph by dropping edges and masking features.

    Parameters
    ----------
    graph : dgl.DGLGraph
        The input graph.
    feats : torch.Tensor
        Node features.
    edges_removing_rate : float
        Edge removal rate forwarded to ``_get_removing_edges_idx``.
    feats_masking_rate : float
        Feature masking rate forwarded to ``_masking_node_feats``.

    Returns
    -------
    new_graph : dgl.DGLGraph
        Graph on the same nodes built from the selected edges, with
        self-loops added.
    masked_feats : torch.Tensor
        Node features with some values masked.
    """
    # Removing edges (RE): despite the helper's name, the returned indices
    # select the edges used to build the new graph.
    kept_edge_idx = _get_removing_edges_idx(graph, edges_removing_rate)
    src_nodes, dst_nodes = graph.edges()
    view_graph = dgl.graph(
        (src_nodes[kept_edge_idx], dst_nodes[kept_edge_idx]),
        num_nodes=graph.num_nodes(),
        device=graph.device,
    )
    view_graph = dgl.add_self_loop(view_graph)

    # Masking node features (MF).
    view_feats = _masking_node_feats(feats, feats_masking_rate)
    return view_graph, view_feats
def _masking_node_feats(feats, masking_rate):
    """
    Zero out randomly chosen feature columns.

    One random draw is made per feature dimension (column); a selected
    column is set to zero for every node. The input tensor is not modified.

    Parameters
    ----------
    feats : torch.Tensor
        Node features, indexed as ``[node, feature]``.
    masking_rate : float
        Threshold compared against a uniform draw per column; columns whose
        draw is below it are masked.

    Returns
    -------
    torch.Tensor
        A copy of ``feats`` with the selected columns set to zero.
    """
    n_feat_dims = feats.size(1)
    draws = torch.rand(n_feat_dims, dtype=torch.float32, device=feats.device)
    drop_cols = draws < masking_rate
    # Work on a copy so the caller's tensor stays intact.
    masked = feats.clone()
    masked[:, drop_cols] = 0
    return masked
def _get_removing_edges_idx(graph, edges_removing_rate):
    """
    Sample which edges survive random edge removal.

    Each edge is dropped independently with probability
    ``edges_removing_rate``. Despite the function's name, the result holds
    the indices of the edges that are KEPT, which is what the caller uses
    to index the source/destination tensors.

    Parameters
    ----------
    graph : dgl.DGLGraph
        The input graph.
    edges_removing_rate : float
        Probability of removing each edge.

    Returns
    -------
    torch.Tensor
        1-D integer tensor with the indices of the edges to keep, created
        on the same device as ``graph``.
    """
    n_edges = graph.num_edges()
    # Per-edge keep probability, built directly on the graph's device
    # instead of going through a NumPy array on the CPU.
    keep_probs = torch.full(
        (n_edges,),
        1.0 - edges_removing_rate,
        dtype=torch.float32,
        device=graph.device,
    )
    # 1 -> keep the edge, 0 -> drop it.
    keep_mask = torch.bernoulli(keep_probs)
    kept_idx = keep_mask.nonzero().squeeze(1)
    return kept_idx