# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=R0205,C0200,R1732
"""
General Attributed Multiplex HeTerogeneous Network Embedding (GATNE)
References
----------
Paper: https://arxiv.org/abs/1905.01669
Author's code: https://github.com/THUDM/GATNE
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/GATNE-T
"""
import math
import time
import multiprocessing
from functools import partial, reduce
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import dgl
import dgl.function as fn
class NeighborSampler(object):
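    """Multi-layer neighbor sampler that turns a batch of (head, tail, edge_type)
    training pairs into a list of DGL blocks for minibatch training.

    ``num_fanouts`` lists the number of neighbors to sample for each layer,
    ordered from the input layer to the output layer, following the usual DGL
    block convention.
    """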
def __init__(self, g, num_fanouts):
self.g = g
self.num_fanouts = num_fanouts
def sample(self, pairs):
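        """Sample blocks for a batch of (head, tail, edge_type) pairs.

        Returns the blocks (ordered from the input layer to the output layer),
        the positions of the head nodes among the unique seed nodes, and the
        tail and edge-type ids as LongTensors.
        """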
heads, tails, types = zip(*pairs)
seeds, head_invmap = torch.unique(torch.LongTensor(heads), return_inverse=True)
blocks = []
for fanout in reversed(self.num_fanouts):
sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
sampled_block = dgl.to_block(sampled_graph, seeds)
seeds = sampled_block.srcdata[dgl.NID]
blocks.insert(0, sampled_block)
return (
blocks,
torch.LongTensor(head_invmap),
torch.LongTensor(tails),
torch.LongTensor(types),
)
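
# A minimal usage sketch (not part of the original example): the sampler can be
# plugged into a torch DataLoader as its ``collate_fn`` over (head, tail, type)
# training pairs. ``g``, ``train_pairs`` and the fanout are assumed placeholders.
#
#     sampler = NeighborSampler(g, [10])
#     dataloader = torch.utils.data.DataLoader(
#         train_pairs, batch_size=64, shuffle=True, collate_fn=sampler.sample
#     )
#     for blocks, head_invmap, tails, types in dataloader:
#         ...  # feed each batch to the model and loss defined below
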
class DGLGATNE(nn.Module):
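    """GATNE model operating on DGL blocks.

    Each node gets a base embedding of size ``embedding_size`` plus one
    edge-type-specific embedding of size ``embedding_u_size`` per edge type.
    The type embeddings of sampled neighbors are aggregated per edge type and
    combined with the base embedding through an edge-type attention mechanism.
    """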
def __init__(
self,
num_nodes,
embedding_size,
embedding_u_size,
edge_types,
edge_type_count,
dim_a,
):
super(DGLGATNE, self).__init__()
self.num_nodes = num_nodes
self.embedding_size = embedding_size
self.embedding_u_size = embedding_u_size
self.edge_types = edge_types
self.edge_type_count = edge_type_count
self.dim_a = dim_a
self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
self.node_type_embeddings = Parameter(
torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
)
self.trans_weights = Parameter(
torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)
)
self.trans_weights_s1 = Parameter(
torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
)
self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))
self.reset_parameters()
def reset_parameters(self):
self.node_embeddings.data.uniform_(-1.0, 1.0)
self.node_type_embeddings.data.uniform_(-1.0, 1.0)
self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
    # forward() returns embeddings of shape [batch_size, edge_type_count, embedding_size]
def forward(self, block):
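        """Compute per-edge-type embeddings for the destination nodes of ``block``.

        For every edge type, neighbor type-embeddings are summed into the
        destination nodes, attention scores over edge types are computed, and
        the attended type embedding is projected and added to the base node
        embedding. Returns a tensor of shape
        [batch_size, edge_type_count, embedding_size].
        """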
input_nodes = block.srcdata[dgl.NID]
output_nodes = block.dstdata[dgl.NID]
batch_size = block.number_of_dst_nodes()
node_embed = self.node_embeddings
node_type_embed = []
with block.local_scope():
for i in range(self.edge_type_count):
edge_type = self.edge_types[i]
block.srcdata[edge_type] = self.node_type_embeddings[input_nodes, i]
block.dstdata[edge_type] = self.node_type_embeddings[output_nodes, i]
block.update_all(
fn.copy_u(edge_type, "m"),
fn.sum("m", edge_type), # pylint: disable=E1101
etype=edge_type,
)
node_type_embed.append(block.dstdata[edge_type])
node_type_embed = torch.stack(node_type_embed, 1)
tmp_node_type_embed = node_type_embed.unsqueeze(2).view(
-1, 1, self.embedding_u_size
)
trans_w = (
self.trans_weights.unsqueeze(0)
.repeat(batch_size, 1, 1, 1)
.view(-1, self.embedding_u_size, self.embedding_size)
)
trans_w_s1 = (
self.trans_weights_s1.unsqueeze(0)
.repeat(batch_size, 1, 1, 1)
.view(-1, self.embedding_u_size, self.dim_a)
)
trans_w_s2 = (
self.trans_weights_s2.unsqueeze(0)
.repeat(batch_size, 1, 1, 1)
.view(-1, self.dim_a, 1)
)
attention = (
F.softmax(
torch.matmul(
torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)),
trans_w_s2,
)
.squeeze(2)
.view(-1, self.edge_type_count),
dim=1,
)
.unsqueeze(1)
.repeat(1, self.edge_type_count, 1)
)
node_type_embed = torch.matmul(attention, node_type_embed).view(
-1, 1, self.embedding_u_size
)
node_embed = node_embed[output_nodes].unsqueeze(1).repeat(
1, self.edge_type_count, 1
) + torch.matmul(node_type_embed, trans_w).view(
-1, self.edge_type_count, self.embedding_size
)
last_node_embed = F.normalize(node_embed, dim=2)
return last_node_embed # [batch_size, edge_type_count, embedding_size]
class NSLoss(nn.Module):
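    """Negative-sampling (skip-gram style) loss.

    Positive (node, context) pairs are scored against a trainable output
    embedding table, and ``num_sampled`` negatives per pair are drawn from a
    log-uniform ranking distribution over the ``num_nodes`` node ids.
    """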
def __init__(self, num_nodes, num_sampled, embedding_size):
super(NSLoss, self).__init__()
self.num_nodes = num_nodes
self.num_sampled = num_sampled
self.embedding_size = embedding_size
self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Log-uniform negative-sampling weights (as in word2vec):
        # weight[k] = (log(k + 2) - log(k + 1)) / log(num_nodes + 1)
self.sample_weights = F.normalize(
torch.Tensor(
[
(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
for k in range(num_nodes)
]
),
dim=0,
)
self.reset_parameters()
def reset_parameters(self):
self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
def forward(self, input, embs, label):
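        """Return the mean negative log-likelihood over the batch.

        ``input`` holds the center node ids (used only for the batch size),
        ``embs`` the corresponding embeddings from the model, and ``label``
        the context node ids of the positive pairs.
        """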
n = input.shape[0]
log_target = torch.log(
torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1))
)
negs = torch.multinomial(
self.sample_weights, self.num_sampled * n, replacement=True
).view(n, self.num_sampled)
noise = torch.neg(self.weights[negs])
sum_log_sampled = torch.sum(
torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1
).squeeze()
loss = log_target + sum_log_sampled
return -loss.sum() / n
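
# A minimal training-step sketch under assumed hyper-parameters (the dataloader
# is the one sketched above; dimensions and learning rate are illustrative):
#
#     model = DGLGATNE(num_nodes, 200, 10, g.etypes, len(g.etypes), 20)
#     nsloss = NSLoss(num_nodes, num_sampled=5, embedding_size=200)
#     optimizer = torch.optim.Adam(
#         list(model.parameters()) + list(nsloss.parameters()), lr=1e-3
#     )
#     for blocks, head_invmap, tails, types in dataloader:
#         embs = model(blocks[0])[head_invmap, types]  # [batch, embedding_size]
#         loss = nsloss(blocks[0].dstdata[dgl.NID][head_invmap], embs, tails)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
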
def generate_pairs_parallel(walks, skip_window=None, layer_id=None):
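    """Turn a chunk of random walks into (center, context, layer_id) skip-gram
    triples using a window of ``skip_window`` nodes on each side."""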
pairs = []
for walk in walks:
walk = walk.tolist()
for i in range(len(walk)):
for j in range(1, skip_window + 1):
if i - j >= 0:
pairs.append((walk[i], walk[i - j], layer_id))
if i + j < len(walk):
pairs.append((walk[i], walk[i + j], layer_id))
return pairs
def generate_pairs(all_walks, window_size, num_workers):
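    """Generate skip-gram training pairs from per-edge-type random walks.

    ``all_walks[i]`` holds the walks for edge type ``i``; the walks are split
    across ``num_workers`` processes and deduplicated before being returned as
    an array of (center, context, layer_id) rows.
    """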
    # Each node in a walk is paired with every node within window_size // 2
    # positions on either side (see generate_pairs_parallel above).
start_time = time.time()
print(f"We are generating pairs with {num_workers} cores.")
# Start all worker processes
pool = multiprocessing.Pool(processes=num_workers)
pairs = []
skip_window = window_size // 2
for layer_id, walks in enumerate(all_walks):
block_num = len(walks) // num_workers
if block_num > 0:
walks_list = [
walks[i * block_num : min((i + 1) * block_num, len(walks))]
for i in range(num_workers)
]
else:
walks_list = [walks]
tmp_result = pool.map(
partial(
generate_pairs_parallel,
skip_window=skip_window,
layer_id=layer_id,
),
walks_list,
)
pairs += reduce(lambda x, y: x + y, tmp_result)
pool.close()
end_time = time.time()
print(f"Generate pairs end, use {end_time - start_time}s.")
return np.array([list(pair) for pair in set(pairs)])
def construct_typenodes_from_graph(graph):
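    """Return, for each edge type of ``graph``, the list of node ids that
    appear as an endpoint of at least one edge of that type."""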
nodes = []
for etype in graph.etypes:
edges = graph.edges(etype=etype)
node1, node2 = edges
node1_list = node1.cpu().numpy().tolist()
node2_list = node2.cpu().numpy().tolist()
tmp_nodes = list(set(node1_list + node2_list))
nodes.append(tmp_nodes)
return nodes
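
# End-to-end sketch (an assumption about the surrounding pipeline, not code from
# this module): the per-type node lists typically seed truncated random walks on
# each edge-type subgraph, and the walks are then turned into training pairs.
#
#     type_nodes = construct_typenodes_from_graph(g)
#     all_walks = []
#     for i, etype in enumerate(g.etypes):
#         subgraph = dgl.edge_type_subgraph(g, [etype])
#         walks, _ = dgl.sampling.random_walk(
#             subgraph, torch.LongTensor(type_nodes[i]), length=10
#         )
#         all_walks.append(walks)
#     train_pairs = generate_pairs(all_walks, window_size=5, num_workers=4)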