# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C0301
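"""Node-classification task trained with ClusterGCN subgraph sampling (DGL)."""
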
from typing import Literal
import dgl
import torch
from dgl import DGLGraph
from torch import nn
from tqdm import trange
from hugegraph_ml.utils.early_stopping import EarlyStopping


class NodeClassifyWithSample:
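    """Node-classification task that trains a model on ClusterGCN-sampled subgraphs.

    The wrapped ``model`` is expected to expose ``forward(sg, feats)``,
    ``inference(sg, feats)`` and ``loss(logits, labels)``; the graph must carry
    ``feat``, ``label``, ``train_mask``, ``val_mask`` and ``test_mask`` in ``ndata``.
    """
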
def __init__(self, graph: DGLGraph, model: nn.Module):
self.graph = graph
self._model = model
        self.gpu = -1  # -1 selects CPU; set a GPU index and update _device to train on CUDA
        self._device = "cpu"
self._early_stopping = None
self._is_trained = False
self.num_partitions = 100
self.batch_size = 100
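        # ClusterGCN sampling: partition the graph into num_partitions clusters
        # (via METIS), then let the DataLoader batch partition IDs and yield the
        # induced subgraphs for mini-batch training.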
self.sampler = dgl.dataloading.ClusterGCNSampler(
graph,
self.num_partitions,
)
self.dataloader = dgl.dataloading.DataLoader(
self.graph,
torch.arange(self.num_partitions).to(self._device),
self.sampler,
device=self._device,
batch_size=self.batch_size,
shuffle=True,
drop_last=False,
num_workers=0,
use_uva=False,
)
self._check_graph()

    def _check_graph(self):
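        """Validate that the graph carries every node attribute this task relies on."""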
required_node_attrs = ["feat", "label", "train_mask", "val_mask", "test_mask"]
for attr in required_node_attrs:
if attr not in self.graph.ndata:
raise ValueError(f"Graph is missing required node attribute '{attr}' in ndata.")

    def train(
        self,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        n_epochs: int = 200,
        patience: float = float("inf"),
        early_stopping_monitor: Literal["loss", "accuracy"] = "loss",
    ):
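        """Mini-batch training over ClusterGCN partitions, with early stopping."""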
        early_stopping = EarlyStopping(patience=patience, monitor=early_stopping_monitor)
        self._early_stopping = early_stopping
        # Move the model to the training device
        self._model.to(self._device)
# Get node features, labels, masks and move to device
feats = self.graph.ndata["feat"].to(self._device)
labels = self.graph.ndata["label"].to(self._device)
train_mask = self.graph.ndata["train_mask"].to(self._device)
val_mask = self.graph.ndata["val_mask"].to(self._device)
optimizer = torch.optim.Adam(self._model.parameters(), lr=lr, weight_decay=weight_decay)
# Training model
loss_fn = nn.CrossEntropyLoss()
epochs = trange(n_epochs)
for epoch in epochs:
            # Train: evaluate_sg() switches the model to eval mode, so re-enable
            # train mode at every step
            for it, sg in enumerate(self.dataloader):
                self._model.train()
                sg_feats = feats[sg.ndata["_ID"]]
                sg_labels = labels[sg.ndata["_ID"]]
                sg_train_mask = train_mask[sg.ndata["_ID"]].bool()
                logits = self._model(sg, sg_feats)
                train_loss = loss_fn(logits[sg_train_mask], sg_labels[sg_train_mask])
optimizer.zero_grad()
train_loss.backward()
optimizer.step()
# validation
valid_metrics = self.evaluate_sg(
sg=sg,
sg_feats=sg_feats,
labels=labels,
val_mask=val_mask,
)
# logs
epochs.set_description(
f"epoch {epoch} | it {it} | train loss {train_loss.item():.4f} "
f"| val loss {valid_metrics['loss']:.4f}"
)
# early stopping
early_stopping(valid_metrics[early_stopping.monitor], self._model)
torch.cuda.empty_cache()
if early_stopping.early_stop:
break
        early_stopping.load_best_model(self._model)
        self._is_trained = True

    def evaluate_sg(self, sg, sg_feats, labels, val_mask):
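        """Compute validation loss and accuracy on the val-masked nodes of one subgraph."""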
self._model.eval()
        sg_val_mask = val_mask[sg.ndata["_ID"]].bool()
        sg_val_labels = labels[sg_val_mask]
        with torch.no_grad():
            sg_val_logits = self._model.inference(sg, sg_feats)[sg_val_mask]
val_loss = self._model.loss(sg_val_logits, sg_val_labels)
_, predicted = torch.max(sg_val_logits, dim=1)
accuracy = (predicted == sg_val_labels).sum().item() / len(sg_val_labels)
return {"accuracy": accuracy, "loss": val_loss.item()}

    def evaluate(self):
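        """Evaluate on test-masked nodes across all cluster subgraphs; returns accuracy and summed loss."""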
        self._model.eval()
        test_mask = self.graph.ndata["test_mask"].to(self._device)
        feats = self.graph.ndata["feat"].to(self._device)
        labels = self.graph.ndata["label"].to(self._device)
test_logits = []
test_labels = []
total_loss = 0
with torch.no_grad():
            for sg in self.dataloader:
sg_feats = feats[sg.ndata["_ID"]]
sg_labels = labels[sg.ndata["_ID"]]
                sg_test_mask = test_mask[sg.ndata["_ID"]].bool()
                sg_test_labels = sg_labels[sg_test_mask]
                sg_test_logits = self._model.inference(sg, sg_feats)[sg_test_mask]
loss = self._model.loss(sg_test_logits, sg_test_labels)
total_loss += loss
test_logits.append(sg_test_logits)
test_labels.append(sg_test_labels)
        # Concatenate per-subgraph results; comparing against only the first
        # batch's labels would misreport accuracy
        test_logits = torch.cat(test_logits, dim=0)
        test_labels = torch.cat(test_labels, dim=0)
        _, predicted = torch.max(test_logits, dim=1)
        accuracy = (predicted == test_labels).sum().item() / len(test_labels)
        return {"accuracy": accuracy, "total_loss": total_loss.item()}