# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=E0401,C0301
import torch
from dgl import DGLGraph
from sklearn.metrics import recall_score, roc_auc_score
from torch import nn
from torch.nn.functional import softmax


class DetectorCaregnn:
    """Trainer/evaluator wrapper around a CARE-GNN-style fraud-detection model.

    The wrapped model is expected to return a pair of logits (GNN branch,
    similarity branch) from ``forward(graph, feat)`` and to expose an
    ``RLModule`` hook that is called once per epoch.
    """

    def __init__(self, graph: DGLGraph, model: nn.Module):
        self.graph = graph
        self._model = model
        self._device = "cpu"
    def train(
        self,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        n_epochs: int = 200,
        gpu: int = -1,
    ):
        """Train the model for ``n_epochs``; ``gpu=-1`` keeps everything on CPU."""
        self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu"
        self._model.to(self._device)
        self.graph = self.graph.to(self._device)
        labels = self.graph.ndata["label"].to(self._device)
        feat = self.graph.ndata["feature"].to(self._device)
        train_mask = self.graph.ndata["train_mask"]
        val_mask = self.graph.ndata["val_mask"]
        train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze(1).to(self._device)
        val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze(1).to(self._device)
        # Labeled fraud nodes in the training split feed the per-epoch RL hook.
        rl_idx = torch.nonzero(train_mask.to(self._device) & labels.bool(), as_tuple=False).squeeze(1)
        # Re-weight the loss by inverse class frequency to counter class imbalance.
        _, cnt = torch.unique(labels, return_counts=True)
        loss_fn = torch.nn.CrossEntropyLoss(weight=1.0 / cnt)
        optimizer = torch.optim.Adam(self._model.parameters(), lr=lr, weight_decay=weight_decay)
        for epoch in range(n_epochs):
            self._model.train()
            logits_gnn, logits_sim = self._model(self.graph, feat)
            # Combined objective: the similarity branch is weighted 2x relative
            # to the GNN branch.
            tr_loss = loss_fn(logits_gnn[train_idx], labels[train_idx]) + 2 * loss_fn(
                logits_sim[train_idx], labels[train_idx]
            )
            tr_recall = recall_score(
                labels[train_idx].cpu(),
                logits_gnn.data[train_idx].argmax(dim=1).cpu(),
            )
            tr_auc = roc_auc_score(
                labels[train_idx].cpu(),
                softmax(logits_gnn, dim=1).data[train_idx][:, 1].cpu(),
            )
            val_loss = loss_fn(logits_gnn[val_idx], labels[val_idx]) + 2 * loss_fn(
                logits_sim[val_idx], labels[val_idx]
            )
            val_recall = recall_score(labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu())
            val_auc = roc_auc_score(
                labels[val_idx].cpu(),
                softmax(logits_gnn, dim=1).data[val_idx][:, 1].cpu(),
            )
            optimizer.zero_grad()
            tr_loss.backward()
            optimizer.step()
            # Per-epoch hook for the model's reinforcement-learning neighbor selector.
            self._model.RLModule(self.graph, epoch, rl_idx)
            print(
                f"Epoch {epoch}: train loss {tr_loss.item():.4f}, recall {tr_recall:.4f}, "
                f"auc {tr_auc:.4f} | val loss {val_loss.item():.4f}, "
                f"recall {val_recall:.4f}, auc {val_auc:.4f}"
            )

    def evaluate(self):
        """Score the model on the test split; returns loss, recall, and ROC-AUC."""
        labels = self.graph.ndata["label"].to(self._device)
        feat = self.graph.ndata["feature"].to(self._device)
        test_mask = self.graph.ndata["test_mask"]
        test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze(1).to(self._device)
        _, cnt = torch.unique(labels, return_counts=True)
        loss_fn = torch.nn.CrossEntropyLoss(weight=1.0 / cnt)
        self._model.eval()
        with torch.no_grad():
            logits_gnn, logits_sim = self._model(self.graph, feat)
            test_loss = loss_fn(logits_gnn[test_idx], labels[test_idx]) + 2 * loss_fn(
                logits_sim[test_idx], labels[test_idx]
            )
            test_recall = recall_score(labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu())
            test_auc = roc_auc_score(
                labels[test_idx].cpu(),
                softmax(logits_gnn, dim=1)[test_idx][:, 1].cpu(),
            )
        return {"recall": test_recall, "auc": test_auc, "loss": test_loss.item()}