blob: edc316bc2a28f680299ae597cfa5c8042129b4e4 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Literal
import dgl
import torch
from torch import nn
from tqdm import trange
from hugegraph_ml.data.hugegraph_dataset import HugeGraphDataset
from hugegraph_ml.utils.early_stopping import EarlyStopping
class GraphClassify:
    """Train and evaluate a graph-level classifier on a ``HugeGraphDataset``.

    ``train`` randomly splits the dataset into train/validation/test subsets
    (70% / remainder / 10%), fits the model with Adam plus early stopping on a
    validation metric, and ``evaluate`` reports metrics on the test subset.
    """

    def __init__(self, dataset: HugeGraphDataset, model: nn.Module):
        """Store the dataset and model; data loaders are created by ``train``.

        Args:
            dataset: Graph dataset to split; its loaders yield
                ``(batched_graph, labels)`` pairs whose graphs carry node
                features under ``ndata["feat"]``.
            model: Module called as ``model(batched_graph, feats)`` that also
                exposes a ``loss(logits, labels)`` method.
        """
        self.dataset = dataset
        self._train_dataloader = None
        self._valid_dataloader = None
        self._test_dataloader = None
        self._model = model
        self._device = ""  # set by train()
        self._early_stopping = None  # set by train()

    def _move_batch(self, batched_graph, labels):
        """Move one batch to the active device; return ``(graph, feats, labels)``."""
        batched_graph = batched_graph.to(self._device)
        # The graph (and thus its features) is already on the target device, so
        # cast with .float() instead of the legacy torch.FloatTensor constructor,
        # which only accepts CPU float32 tensors and fails for CUDA inputs.
        feats = batched_graph.ndata["feat"].float()
        labels = labels.long().to(self._device)
        return batched_graph, feats, labels

    def _evaluate(self, dataloader):
        """Compute accuracy and loss of the model over ``dataloader``.

        The model is put in eval mode and no gradients are tracked.

        Returns:
            dict with ``"accuracy"`` (fraction of argmax predictions equal to the
            labels) and ``"loss"`` (sum of per-batch ``model.loss`` values divided
            by the number of samples). Both are NaN if the loader yields no
            samples, e.g. a test split that rounded down to zero graphs.
        """
        self._model.eval()
        correct = 0
        total = 0
        total_loss = 0.0
        with torch.no_grad():
            for batched_graph, labels in dataloader:
                batched_graph, feats, labels = self._move_batch(batched_graph, labels)
                logits = self._model(batched_graph, feats)
                correct += (torch.argmax(logits, dim=1) == labels).sum().item()
                total += labels.size(0)
                total_loss += self._model.loss(logits, labels).item()
        if total == 0:
            return {"accuracy": float("nan"), "loss": float("nan")}
        # NOTE(review): this is a per-sample mean only if model.loss sums over the
        # batch; with mean reduction the value is scaled by ~1/batch_size — confirm.
        return {"accuracy": correct / total, "loss": total_loss / total}

    def _train_epoch(self, optimizer, clip):
        """Run one optimisation pass over the training loader.

        Returns:
            ``(loss, accuracy)`` for the epoch, aggregated like ``_evaluate``.
        """
        self._model.train()
        correct = 0
        total = 0
        total_loss = 0.0
        for batched_graph, labels in self._train_dataloader:
            batched_graph, feats, labels = self._move_batch(batched_graph, labels)
            self._model.zero_grad()
            logits = self._model(batched_graph, feats)
            correct += (torch.argmax(logits, dim=1) == labels).sum().item()
            total += labels.size(0)
            loss = self._model.loss(logits, labels)
            total_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(self._model.parameters(), clip)
            optimizer.step()
        # total > 0 is guaranteed because train() rejects an empty training split.
        return total_loss / total, correct / total

    def train(
        self,
        batch_size: int = 20,
        lr: float = 1e-3,
        weight_decay: float = 0,
        n_epochs: int = 200,
        patience: float = float("inf"),
        early_stopping_monitor: Literal["loss", "accuracy"] = "loss",
        clip: float = 2.0,
        gpu: int = -1,
    ):
        """Split the dataset, then train the model with early stopping.

        Args:
            batch_size: Graphs per batch for all three data loaders.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            n_epochs: Maximum number of training epochs.
            patience: Passed to ``EarlyStopping``; the default ``inf`` never
                triggers an early stop on patience alone.
            early_stopping_monitor: Key of the validation metrics (``"loss"`` or
                ``"accuracy"``) fed to ``EarlyStopping`` after each epoch.
            clip: Maximum total gradient norm for ``clip_grad_norm_``.
            gpu: CUDA device index; ``-1`` (or no CUDA available) uses the CPU.

        Raises:
            ValueError: if the dataset has fewer than 2 graphs, which would leave
                the 70% training split empty.
        """
        # default 7-2-1 train-valid-test: train and test sizes are floored and
        # validation takes the remainder, so the sizes always sum to len(dataset)
        n_graphs = len(self.dataset)
        train_size = int(n_graphs * 0.7)
        test_size = int(n_graphs * 0.1)
        valid_size = n_graphs - train_size - test_size
        if train_size == 0:
            raise ValueError(
                f"dataset has {n_graphs} graph(s); at least 2 are needed for a non-empty training split"
            )

        self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu"
        self._early_stopping = EarlyStopping(patience=patience, monitor=early_stopping_monitor)
        self._model.to(self._device)

        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            self.dataset, (train_size, valid_size, test_size)
        )
        self._train_dataloader = dgl.dataloading.GraphDataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self._valid_dataloader = dgl.dataloading.GraphDataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        self._test_dataloader = dgl.dataloading.GraphDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Only parameters that require gradients are handed to the optimizer.
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self._model.parameters()), lr=lr, weight_decay=weight_decay
        )

        epochs = trange(n_epochs)
        for epoch in epochs:
            train_loss, train_acc = self._train_epoch(optimizer, clip)
            # validation
            valid_metrics = self._evaluate(self._valid_dataloader)
            epochs.set_description(
                f"epoch {epoch} | train loss {train_loss:.4f} | val loss {valid_metrics['loss']:.4f} | "
                f"train acc {train_acc:.4f} | val acc {valid_metrics['accuracy']:.4f}"
            )
            # early stopping
            self._early_stopping(valid_metrics[self._early_stopping.monitor], self._model)
            if self._early_stopping.early_stop:
                break
        # Restore the model state EarlyStopping recorded (see EarlyStopping.load_best_model).
        self._early_stopping.load_best_model(self._model)

    def evaluate(self):
        """Return ``{"accuracy", "loss"}`` of the model on the held-out test split.

        Raises:
            RuntimeError: if called before ``train``, which creates the test split.
        """
        if self._test_dataloader is None:
            raise RuntimeError("GraphClassify.evaluate() called before train(); no test split exists yet")
        return self._evaluate(self._test_dataloader)