# blob: be2b33d4ab13a30aab723a4bad205a355352ea99 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-branches
# pylint: disable=C0304
import warnings
import dgl
import networkx as nx
import torch
from pyhugegraph.api.gremlin import GremlinManager
from pyhugegraph.client import PyHugeClient
from hugegraph_ml.data.hugegraph_dataset import HugeGraphDataset
class HugeGraph2DGL:
def __init__(
self,
url: str = "http://127.0.0.1:8080",
graph: str = "hugegraph",
user: str = "",
pwd: str = "",
graphspace: str | None = None,
):
self._client: PyHugeClient = PyHugeClient(url=url, graph=graph, user=user, pwd=pwd, graphspace=graphspace)
self._graph_germlin: GremlinManager = self._client.gremlin()
def convert_graph(
self,
vertex_label: str,
edge_label: str,
feat_key: str = "feat",
label_key: str = "label",
mask_keys: list[str] | None = None,
):
if mask_keys is None:
mask_keys = ["train_mask", "val_mask", "test_mask"]
vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"]
edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"]
graph_dgl = self._convert_graph_from_v_e(vertices, edges, feat_key, label_key, mask_keys)
return graph_dgl
def convert_hetero_graph(
self,
vertex_labels: list[str],
edge_labels: list[str],
feat_key: str = "feat",
label_key: str = "label",
mask_keys: list[str] | None = None,
):
if mask_keys is None:
mask_keys = ["train_mask", "val_mask", "test_mask"]
vertex_label_id2idx = {}
vertex_label_data = {}
# for each vertex label
for vertex_label in vertex_labels:
vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"]
if len(vertices) == 0:
warnings.warn(f"Graph has no vertices of vertex_label: {vertex_label}", Warning, stacklevel=2)
else:
vertex_ids = [v["id"] for v in vertices]
id2idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)}
vertex_label_id2idx[vertex_label] = id2idx
# extract vertex property(feat, label, mask)
vertex_label_data[vertex_label] = {}
if feat_key in vertices[0]["properties"]:
node_feats = torch.tensor([v["properties"][feat_key] for v in vertices], dtype=torch.float32)
vertex_label_data[vertex_label]["feat"] = node_feats
if label_key in vertices[0]["properties"]:
node_labels = torch.tensor([v["properties"][label_key] for v in vertices], dtype=torch.long)
vertex_label_data[vertex_label]["label"] = node_labels
if mask_keys:
for mk in mask_keys:
if mk in vertices[0]["properties"]:
mask = torch.tensor([v["properties"][mk] for v in vertices], dtype=torch.bool)
vertex_label_data[vertex_label][mk] = mask
# build hetero graph from edges
edge_data_dict = {}
for edge_label in edge_labels:
edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"]
if len(edges) == 0:
warnings.warn(f"Graph has no edges of edge_label: {edge_label}", Warning, stacklevel=2)
else:
src_vertex_label = edges[0]["outVLabel"]
src_idx = [vertex_label_id2idx[src_vertex_label][e["outV"]] for e in edges]
dst_vertex_label = edges[0]["inVLabel"]
dst_idx = [vertex_label_id2idx[dst_vertex_label][e["inV"]] for e in edges]
edge_data_dict[(src_vertex_label, edge_label, dst_vertex_label)] = (src_idx, dst_idx)
# add vertex properties data
hetero_graph = dgl.heterograph(edge_data_dict)
for vertex_label in vertex_labels:
for prop in vertex_label_data[vertex_label]:
hetero_graph.nodes[vertex_label].data[prop] = vertex_label_data[vertex_label][prop]
return hetero_graph
def convert_graph_dataset(
self,
graph_vertex_label: str,
vertex_label: str,
edge_label: str,
feat_key: str = "feat",
label_key: str = "label",
):
# get graph vertices
graph_vertices = self._graph_germlin.exec(f"g.V().hasLabel('{graph_vertex_label}')")["data"]
graphs = []
max_n_nodes = 0
graph_labels = []
for graph_vertex in graph_vertices:
graph_id = graph_vertex["id"]
label = graph_vertex["properties"][label_key]
graph_labels.append(label)
# get this graph's vertices and edges
vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}').has('graph_id', {graph_id})")["data"]
edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}').has('graph_id', {graph_id})")["data"]
graph_dgl = self._convert_graph_from_v_e(vertices, edges, feat_key)
graphs.append(graph_dgl)
# record max num of node
max_n_nodes = max(max_n_nodes, graph_dgl.number_of_nodes())
# record dataset info
graphs_info = {
"n_graphs": len(graph_vertices),
"max_n_nodes": max_n_nodes,
"n_feat_dim": graphs[0].ndata["feat"].size()[1],
"n_classes": len(set(graph_labels)),
}
dataset_dgl = HugeGraphDataset(graphs=graphs, labels=graph_labels, info=graphs_info)
return dataset_dgl
def convert_graph_nx(
self,
vertex_label: str,
edge_label: str,
):
vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"]
edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"]
graph_nx = self._convert_graph_from_v_e_nx(vertices=vertices, edges=edges)
return graph_nx
def convert_graph_with_edge_feat(
self,
vertex_label: str,
edge_label: str,
node_feat_key: str = "feat",
edge_feat_key: str = "edge_feat",
label_key: str = "label",
mask_keys: list[str] | None = None,
):
if mask_keys is None:
mask_keys = ["train_mask", "val_mask", "test_mask"]
vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"]
edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"]
graph_dgl = self._convert_graph_from_v_e_with_edge_feat(
vertices, edges, edge_feat_key, node_feat_key, label_key, mask_keys
)
return graph_dgl
def convert_graph_ogb(self, vertex_label: str, edge_label: str, split_label: str):
vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"]
edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"]
graph_dgl, vertex_id_to_idx = self._convert_graph_from_ogb(vertices, edges, "feat", "year", "weight")
edges_split = self._graph_germlin.exec(f"g.E().hasLabel('{split_label}')")["data"]
split_edge = self._convert_split_edge_from_ogb(edges_split, vertex_id_to_idx)
return graph_dgl, split_edge
def convert_hetero_graph_bgnn(
self,
vertex_labels: list[str],
edge_labels: list[str],
feat_key: str = "feat",
label_key: str = "class",
cat_key: str = "cat_features",
mask_keys: list[str] | None = None,
):
if mask_keys is None:
mask_keys = ["train_mask", "val_mask", "test_mask"]
vertex_label_id2idx = {}
vertex_label_data = {}
# for each vertex label
for vertex_label in vertex_labels:
vertices = self._graph_germlin.exec(f"g.V().hasLabel('{vertex_label}')")["data"]
if len(vertices) == 0:
warnings.warn(f"Graph has no vertices of vertex_label: {vertex_label}", Warning, stacklevel=2)
else:
vertex_ids = [v["id"] for v in vertices]
id2idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)}
vertex_label_id2idx[vertex_label] = id2idx
# extract vertex property(feat, label, mask)
vertex_label_data[vertex_label] = {}
if feat_key in vertices[0]["properties"]:
node_feats = torch.tensor(
[v["properties"][feat_key] for v in vertices],
dtype=torch.int32,
)
vertex_label_data[vertex_label]["feat"] = node_feats
if label_key in vertices[0]["properties"]:
node_labels = torch.tensor(
[v["properties"][label_key] for v in vertices],
dtype=torch.float64,
)
vertex_label_data[vertex_label]["class"] = node_labels
if cat_key in vertices[0]["properties"]:
node_cat = torch.tensor(
[v["properties"][cat_key] for v in vertices],
dtype=torch.int32,
)
vertex_label_data[vertex_label]["cat_features"] = node_cat
if mask_keys:
for mk in mask_keys:
if mk in vertices[0]["properties"]:
mask = torch.tensor(
[v["properties"][mk] for v in vertices],
dtype=torch.bool,
)
vertex_label_data[vertex_label][mk] = mask
# build hetero graph from edges
edge_data_dict = {}
for edge_label in edge_labels:
edges = self._graph_germlin.exec(f"g.E().hasLabel('{edge_label}')")["data"]
if len(edges) == 0:
warnings.warn(f"Graph has no edges of edge_label: {edge_label}", Warning, stacklevel=2)
else:
src_vertex_label = edges[0]["outVLabel"]
src_idx = [vertex_label_id2idx[src_vertex_label][e["outV"]] for e in edges]
dst_vertex_label = edges[0]["inVLabel"]
dst_idx = [vertex_label_id2idx[dst_vertex_label][e["inV"]] for e in edges]
edge_data_dict[(src_vertex_label, edge_label, dst_vertex_label)] = (
src_idx,
dst_idx,
)
# add vertex properties data
hetero_graph = dgl.heterograph(edge_data_dict)
for vertex_label in vertex_labels:
for prop in vertex_label_data[vertex_label]:
hetero_graph.nodes[vertex_label].data[prop] = vertex_label_data[vertex_label][prop]
return hetero_graph
@staticmethod
def _convert_graph_from_v_e(vertices, edges, feat_key=None, label_key=None, mask_keys=None):
if len(vertices) == 0:
warnings.warn("This graph has no vertices", Warning, stacklevel=2)
return dgl.graph(())
vertex_ids = [v["id"] for v in vertices]
vertex_id_to_idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)}
src_idx = [vertex_id_to_idx[e["outV"]] for e in edges]
dst_idx = [vertex_id_to_idx[e["inV"]] for e in edges]
graph_dgl = dgl.graph((src_idx, dst_idx))
if feat_key and feat_key in vertices[0]["properties"]:
node_feats = [v["properties"][feat_key] for v in vertices]
graph_dgl.ndata["feat"] = torch.tensor(node_feats, dtype=torch.float32)
if label_key and label_key in vertices[0]["properties"]:
node_labels = [v["properties"][label_key] for v in vertices]
graph_dgl.ndata["label"] = torch.tensor(node_labels, dtype=torch.long)
if mask_keys:
for mk in mask_keys:
if mk in vertices[0]["properties"]:
node_masks = [v["properties"][mk] for v in vertices]
mask = torch.tensor(node_masks, dtype=torch.bool)
graph_dgl.ndata[mk] = mask
return graph_dgl
@staticmethod
def _convert_graph_from_v_e_nx(vertices, edges):
if len(vertices) == 0:
warnings.warn("This graph has no vertices", Warning, stacklevel=2)
return nx.Graph(())
vertex_ids = [v["id"] for v in vertices]
vertex_id_to_idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)}
new_vertex_ids = [vertex_id_to_idx[id] for id in vertex_ids]
edge_list = [(edge["outV"], edge["inV"]) for edge in edges]
new_edge_list = [(vertex_id_to_idx[src], vertex_id_to_idx[dst]) for src, dst in edge_list]
graph_nx = nx.Graph()
graph_nx.add_nodes_from(new_vertex_ids)
graph_nx.add_edges_from(new_edge_list)
return graph_nx
@staticmethod
def _convert_graph_from_v_e_with_edge_feat(
vertices,
edges,
edge_feat_key,
node_feat_key=None,
label_key=None,
mask_keys=None,
):
if len(vertices) == 0:
warnings.warn("This graph has no vertices", Warning, stacklevel=2)
return dgl.graph(())
vertex_ids = [v["id"] for v in vertices]
vertex_id_to_idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)}
src_idx = [vertex_id_to_idx[e["outV"]] for e in edges]
dst_idx = [vertex_id_to_idx[e["inV"]] for e in edges]
graph_dgl = dgl.graph((src_idx, dst_idx))
if node_feat_key and node_feat_key in vertices[0]["properties"]:
node_feats = [v["properties"][node_feat_key] for v in vertices]
graph_dgl.ndata["feat"] = torch.tensor(node_feats, dtype=torch.int64)
if edge_feat_key and edge_feat_key in edges[0]["properties"]:
edge_feats = [e["properties"][edge_feat_key] for e in edges]
graph_dgl.edata["feat"] = torch.tensor(edge_feats, dtype=torch.int64)
if label_key and label_key in vertices[0]["properties"]:
node_labels = [v["properties"][label_key] for v in vertices]
graph_dgl.ndata["label"] = torch.tensor(node_labels, dtype=torch.long)
if mask_keys:
for mk in mask_keys:
if mk in vertices[0]["properties"]:
node_masks = [v["properties"][mk] for v in vertices]
mask = torch.tensor(node_masks, dtype=torch.bool)
graph_dgl.ndata[mk] = mask
return graph_dgl
@staticmethod
def _convert_graph_from_ogb(vertices, edges, feat_key, year_key, weight_key):
if len(vertices) == 0:
warnings.warn("This graph has no vertices", Warning, stacklevel=2)
return dgl.graph(())
vertex_ids = [v["id"] for v in vertices]
vertex_id_to_idx = {vertex_id: idx for idx, vertex_id in enumerate(vertex_ids)}
src_idx = [vertex_id_to_idx[e["outV"]] for e in edges]
dst_idx = [vertex_id_to_idx[e["inV"]] for e in edges]
graph_dgl = dgl.graph((src_idx, dst_idx))
if feat_key and feat_key in vertices[0]["properties"]:
node_feats = [v["properties"][feat_key] for v in vertices[0 : graph_dgl.number_of_nodes()]]
graph_dgl.ndata["feat"] = torch.tensor(node_feats, dtype=torch.float32)
if year_key and year_key in edges[0]["properties"]:
year = [e["properties"][year_key] for e in edges]
graph_dgl.edata["year"] = torch.tensor(year, dtype=torch.int64)
if weight_key and weight_key in edges[0]["properties"]:
weight = [e["properties"][weight_key] for e in edges]
graph_dgl.edata["weight"] = torch.tensor(weight, dtype=torch.int64)
return graph_dgl, vertex_id_to_idx
@staticmethod
def _convert_split_edge_from_ogb(edges, vertex_id_to_idx):
train_edge_list = []
train_year_list = []
train_weight_list = []
valid_edge_list = []
valid_year_list = []
valid_weight_list = []
valid_edge_neg_list = []
test_edge_list = []
test_year_list = []
test_weight_list = []
test_edge_neg_list = []
for edge in edges:
if edge["properties"]["train_edge_mask"] == 1:
train_edge_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]])
if edge["properties"]["train_year_mask"] != -1:
train_year_list.append(edge["properties"]["train_year_mask"])
if edge["properties"]["train_weight_mask"] != -1:
train_weight_list.append(edge["properties"]["train_weight_mask"])
if edge["properties"]["valid_edge_mask"] == 1:
valid_edge_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]])
if edge["properties"]["valid_year_mask"] != -1:
valid_year_list.append(edge["properties"]["valid_year_mask"])
if edge["properties"]["valid_weight_mask"] != -1:
valid_weight_list.append(edge["properties"]["valid_weight_mask"])
if edge["properties"]["valid_edge_neg_mask"] == 1:
valid_edge_neg_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]])
if edge["properties"]["test_edge_mask"] == 1:
test_edge_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]])
if edge["properties"]["test_year_mask"] != -1:
test_year_list.append(edge["properties"]["test_year_mask"])
if edge["properties"]["test_weight_mask"] != -1:
test_weight_list.append(edge["properties"]["test_weight_mask"])
if edge["properties"]["test_edge_neg_mask"] == 1:
test_edge_neg_list.append([vertex_id_to_idx[edge["outV"]], vertex_id_to_idx[edge["inV"]]])
split_edge = {
"train": {
"edge": torch.tensor(train_edge_list),
"weight": torch.tensor(train_weight_list),
"year": torch.tensor(train_year_list),
},
"valid": {
"edge": torch.tensor(valid_edge_list),
"weight": torch.tensor(valid_weight_list),
"year": torch.tensor(valid_year_list),
"edge_neg": torch.tensor(valid_edge_neg_list),
},
"test": {
"edge": torch.tensor(test_edge_list),
"weight": torch.tensor(test_weight_list),
"year": torch.tensor(test_year_list),
"edge_neg": torch.tensor(test_edge_neg_list),
},
}
return split_edge
if __name__ == "__main__":
    # Demo: run every converter once, in order, against the example datasets. Each call
    # needs a reachable HugeGraph server that already holds the named labels.
    converter = HugeGraph2DGL()
    demo_runs = (
        (converter.convert_graph, {"vertex_label": "CORA_vertex", "edge_label": "CORA_edge"}),
        (
            converter.convert_graph_dataset,
            {
                "graph_vertex_label": "MUTAG_graph_vertex",
                "vertex_label": "MUTAG_vertex",
                "edge_label": "MUTAG_edge",
            },
        ),
        (
            converter.convert_hetero_graph,
            {
                "vertex_labels": ["ACM_paper_v", "ACM_author_v", "ACM_field_v"],
                "edge_labels": ["ACM_ap_e", "ACM_fp_e", "ACM_pa_e", "ACM_pf_e"],
            },
        ),
        (converter.convert_graph_nx, {"vertex_label": "CAVEMAN_vertex", "edge_label": "CAVEMAN_edge"}),
        (
            converter.convert_graph_with_edge_feat,
            {"vertex_label": "CORA_edge_feat_vertex", "edge_label": "CORA_edge_feat_edge"},
        ),
        (
            converter.convert_graph_ogb,
            {
                "vertex_label": "ogbl-collab_vertex",
                "edge_label": "ogbl-collab_edge",
                "split_label": "ogbl-collab_split_edge",
            },
        ),
        (converter.convert_hetero_graph_bgnn, {"vertex_labels": ["AVAZU__N_v"], "edge_labels": ["AVAZU__E_e"]}),
        (
            converter.convert_hetero_graph,
            {
                "vertex_labels": ["AMAZONGATNE__N_v"],
                "edge_labels": ["AMAZONGATNE_1_e", "AMAZONGATNE_2_e"],
            },
        ),
    )
    for convert, kwargs in demo_runs:
        convert(**kwargs)