# blob: 5527fba5f63c9c029a4b452a08b7e34b0774a3ac [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
import torch
from dgl.data import CoraGraphDataset, GINDataset
from hugegraph_ml.data.hugegraph2dgl import HugeGraph2DGL
from hugegraph_ml.utils.dgl2hugegraph_utils import load_acm_raw
class TestHugegraph2dDGL(unittest.TestCase):
    """Compare graphs produced by ``HugeGraph2DGL`` with reference DGL data.

    Each test converts data through ``HugeGraph2DGL`` and checks its structure
    against the same dataset loaded locally (Cora and MUTAG through DGL, ACM
    through ``load_acm_raw``).

    NOTE(review): presumably the CORA/MUTAG/ACM data must already be stored
    under the labels used below in the HugeGraph instance that
    ``HugeGraph2DGL`` connects to -- confirm against the test environment.
    """

    # Node types of the reference ACM graph -> vertex labels requested from HugeGraph.
    ntype_map = {
        "author": "ACM_author_v",
        "field": "ACM_field_v",
        "paper": "ACM_paper_v",
    }
    # Edge types of the reference ACM graph -> edge labels requested from HugeGraph.
    etype_map = {
        "ap": "ACM_ap_e",
        "fp": "ACM_fp_e",
        "pa": "ACM_pa_e",
        "pf": "ACM_pf_e",
    }

    @classmethod
    def setUpClass(cls):
        """Load the reference datasets once for the whole class.

        The tests only read these objects, so they are shared instead of being
        reloaded before every test method.
        """
        super().setUpClass()
        cls.cora_data = CoraGraphDataset()[0]
        cls.mutag_dataset = GINDataset(name="MUTAG", self_loop=True)
        cls.acm_data = load_acm_raw()

    def test_convert_graph(self):
        """The converted CORA graph matches the reference Cora graph."""
        hg2d = HugeGraph2DGL()
        graph = hg2d.convert_graph(vertex_label="CORA_vertex", edge_label="CORA_edge")
        expected = self.cora_data

        self.assertEqual(graph.number_of_nodes(), expected.number_of_nodes(), "Number of nodes does not match.")
        self.assertEqual(graph.number_of_edges(), expected.number_of_edges(), "Number of edges does not match.")
        self.assertEqual(
            graph.ndata["feat"].shape, expected.ndata["feat"].shape, "Node feature dimensions do not match."
        )
        self.assertEqual(
            graph.ndata["label"].unique().shape[0],
            expected.ndata["label"].unique().shape[0],
            "Number of classes does not match.",
        )

        # Report every mismatching mask, not only the first one.
        masks = (("train_mask", "Train"), ("val_mask", "Validation"), ("test_mask", "Test"))
        for mask_key, mask_name in masks:
            with self.subTest(mask=mask_key):
                self.assertTrue(
                    torch.equal(graph.ndata[mask_key], expected.ndata[mask_key]),
                    f"{mask_name} mask does not match.",
                )

    def test_convert_graph_dataset(self):
        """The converted MUTAG dataset matches the reference ``GINDataset``."""
        hg2d = HugeGraph2DGL()
        dataset_dgl = hg2d.convert_graph_dataset(
            graph_vertex_label="MUTAG_graph_vertex",
            vertex_label="MUTAG_vertex",
            edge_label="MUTAG_edge",
        )
        expected = self.mutag_dataset

        self.assertEqual(len(dataset_dgl), len(expected), "Number of graphs does not match.")
        self.assertEqual(dataset_dgl.info["n_graphs"], len(expected), "Number of graphs does not match.")
        self.assertEqual(
            dataset_dgl.info["n_classes"], expected.gclasses, "Number of graph classes does not match."
        )
        self.assertEqual(
            dataset_dgl.info["n_feat_dim"], expected.dim_nfeats, "Node feature dimensions do not match."
        )

    def test_convert_hetero_graph(self):
        """The converted ACM graph has the reference node/edge types and counts."""
        hg2d = HugeGraph2DGL()
        hetero_graph = hg2d.convert_hetero_graph(
            vertex_labels=["ACM_paper_v", "ACM_author_v", "ACM_field_v"],
            edge_labels=["ACM_ap_e", "ACM_fp_e", "ACM_pa_e", "ACM_pf_e"],
        )

        # One sub-test per node type so a single bad type does not hide the others.
        for ntype in self.acm_data.ntypes:
            with self.subTest(ntype=ntype):
                mapped_ntype = self.ntype_map[ntype]
                self.assertIn(
                    mapped_ntype,
                    hetero_graph.ntypes,
                    f"Node type {ntype} is missing in converted graph.",
                )
                acm_node_count = self.acm_data.num_nodes(ntype)
                hetero_node_count = hetero_graph.num_nodes(mapped_ntype)
                self.assertEqual(
                    acm_node_count,
                    hetero_node_count,
                    f"Node count for type {ntype} does not match: {acm_node_count} != {hetero_node_count}",
                )

        # Canonical edge types are (source type, edge type, destination type) triples;
        # translate each component to its HugeGraph label before comparing.
        for c_etypes in self.acm_data.canonical_etypes:
            with self.subTest(etype=c_etypes):
                src_type, edge_type, dst_type = c_etypes
                mapped_c_etypes = (
                    self.ntype_map[src_type],
                    self.etype_map[edge_type],
                    self.ntype_map[dst_type],
                )
                self.assertIn(
                    mapped_c_etypes,
                    hetero_graph.canonical_etypes,
                    f"Edge type {mapped_c_etypes} is missing in converted graph.",
                )
                acm_edge_count = self.acm_data.num_edges(etype=c_etypes)
                hetero_edge_count = hetero_graph.num_edges(etype=mapped_c_etypes)
                self.assertEqual(
                    acm_edge_count,
                    hetero_edge_count,
                    f"Edge count for type {mapped_c_etypes} does not match: {acm_edge_count} != {hetero_edge_count}",
                )