# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=E1101,C0103

"""
CAmouflage-REsistant GNN (CARE-GNN)

References
----------
Paper: https://arxiv.org/abs/2008.08692
Author's code: https://github.com/YingtongDou/CARE-GNN
DGL code: https://github.com/dmlc/dgl/tree/master/examples/pytorch/caregnn
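
Examples
--------
A minimal, illustrative usage sketch; the toy graph, relation names, and
feature size below are hypothetical::

    import dgl
    import torch as th

    g = dgl.heterograph({
        ("user", "rel1", "user"): (th.tensor([0, 1, 2]), th.tensor([1, 2, 0])),
        ("user", "rel2", "user"): (th.tensor([0, 2, 1]), th.tensor([2, 1, 0])),
    })
    feat = th.randn(g.num_nodes(), 16)
    model = CAREGNN(in_dim=16, num_classes=2, edges=g.canonical_etypes)
    logits, sim = model(g, feat)
    model.RLModule(g, epoch=0, idx=th.arange(g.num_nodes()))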
| """ |
| |
| import dgl.function as fn |
| import numpy as np |
| import torch as th |
| from torch import nn |
| |
| |
| class CAREConv(nn.Module): |
| """One layer of CARE-GNN.""" |

    def __init__(
        self,
        in_dim,
        out_dim,
        num_classes,
        edges,
        activation=None,
        step_size=0.02,
    ):
        super().__init__()

        self.activation = activation
        self.step_size = step_size
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_classes = num_classes
        self.edges = edges
        self.dist = {}

        self.linear = nn.Linear(self.in_dim, self.out_dim)
        self.MLP = nn.Linear(self.in_dim, self.num_classes)

        self.p = {}
        self.last_avg_dist = {}
        self.f = {}
        self.cvg = {}
        for etype in edges:
            self.p[etype] = 0.5
            self.last_avg_dist[etype] = 0
            self.f[etype] = []
            self.cvg[etype] = False

    def _calc_distance(self, edges):
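        """Edge-wise l1-distance between the endpoints' label-aware
        embeddings: D(v, u) = ||tanh(MLP(h_v)) - tanh(MLP(h_u))||_1.
        """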
        # formula 2
        d = th.norm(
            th.tanh(self.MLP(edges.src["h"])) - th.tanh(self.MLP(edges.dst["h"])),
            1,
            1,
        )
        return {"d": d}

    def _top_p_sampling(self, g, p):
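        """For each node, keep the ceil(p * in_degree) in-edges with the
        smallest distance (the most similar neighbors) and return the
        selected edge IDs.
        """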
        # This implementation is inefficient; optimizing it requires
        # dgl.sampling.select_top_p, requested in DGL issue #3100.
        dist = g.edata["d"]
        neigh_list = []
        for node in g.nodes():
            edges = g.in_edges(node, form="eid")
            num_neigh = th.ceil(g.in_degrees(node) * p).int().item()
            neigh_dist = dist[edges]
            if neigh_dist.shape[0] > num_neigh:
                neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[
                    :num_neigh
                ]
            else:
                # p <= 1 guarantees num_neigh <= in-degree here, so keep all in-edges
                neigh_index = np.arange(neigh_dist.shape[0])
            neigh_list.append(edges[neigh_index])
        return th.cat(neigh_list)

    def forward(self, g, feat):
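        """Per relation, mean-aggregate features over the top-p sampled
        neighbors, then combine the per-relation results weighted by each
        relation's threshold p and apply a linear projection.
        """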
        with g.local_scope():
            g.ndata["h"] = feat

            hr = {}
            for etype in g.canonical_etypes:
                g.apply_edges(self._calc_distance, etype=etype)
                self.dist[etype] = g.edges[etype].data["d"]
                sampled_edges = self._top_p_sampling(g[etype], self.p[etype])

                # formula 8: intra-relation aggregation (mean over the sampled neighbors)
                g.send_and_recv(
                    sampled_edges,
                    fn.copy_u("h", "m"),
                    fn.mean("m", f"h_{etype[1]}"),
                    etype=etype,
                )
                hr[etype] = g.ndata[f"h_{etype[1]}"]
                if self.activation is not None:
                    hr[etype] = self.activation(hr[etype])

            # formula 9: inter-relation aggregation as a p-weighted sum
            # plus a residual connection on the input features
            p_tensor = th.Tensor(list(self.p.values())).view(-1, 1, 1).to(g.device)
            h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)
            h_homo += feat
            if self.activation is not None:
                h_homo = self.activation(h_homo)

            return self.linear(h_homo)


class CAREGNN(nn.Module):
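    """CAmouflage-REsistant GNN for fraud detection on a multi-relation graph.

    Parameters
    ----------
    in_dim : int
        Input feature size.
    num_classes : int
        Number of output classes.
    hid_dim : int
        Hidden feature size.
    edges : list of tuple
        Canonical edge types of the input heterograph.
    num_layers : int
        Number of CAREConv layers.
    activation : callable, optional
        Activation applied inside each layer.
    step_size : float
        Action step size of the RL threshold update.
    """
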
    def __init__(
        self,
        in_dim,
        num_classes,
        hid_dim=64,
        edges=None,
        num_layers=2,
        activation=None,
        step_size=0.02,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.num_classes = num_classes
        self.edges = edges
        self.activation = activation
        self.step_size = step_size
        self.num_layers = num_layers

        self.layers = nn.ModuleList()

        if self.num_layers == 1:
            # Single layer
            self.layers.append(
                CAREConv(
                    self.in_dim,
                    self.num_classes,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

        else:
            # Input layer
            self.layers.append(
                CAREConv(
                    self.in_dim,
                    self.hid_dim,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

            # num_layers - 2 hidden layers
            for _ in range(self.num_layers - 2):
                self.layers.append(
                    CAREConv(
                        self.hid_dim,
                        self.hid_dim,
                        self.num_classes,
                        self.edges,
                        activation=self.activation,
                        step_size=self.step_size,
                    )
                )

            # Output layer
            self.layers.append(
                CAREConv(
                    self.hid_dim,
                    self.num_classes,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

    def forward(self, graph, feat):
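        """Run all layers on the full graph.

        Returns the final-layer output (logits over ``num_classes``) and
        the label-aware similarity scores from the first layer's MLP
        (formula 4), for use in the similarity loss term.
        """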
        # For full-graph training, the whole graph is used directly.
        # formula 4: label-aware similarity scores from the first layer's MLP
        sim = th.tanh(self.layers[0].MLP(feat))

        # Forward pass through all CARE-GNN layers
        for layer in self.layers:
            feat = layer(graph, feat)

        return feat, sim

    def RLModule(self, graph, epoch, idx):
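        """One step of the reinforcement-learning threshold update
        (formulas 5-7), e.g. called once per epoch on the training
        node indices.

        For every layer and relation that has not yet converged,
        formula 5 averages the distance over the in-edges of ``idx``;
        formula 6 lowers ``p`` by ``step_size`` if the average distance
        grew and raises it otherwise, keeping ``p`` within (0, 1]; and
        formula 7 marks the relation as converged once the last ten
        actions roughly cancel out (checked once ``epoch >= 9``).
        """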
        for layer in self.layers:
            for etype in self.edges:
                if not layer.cvg[etype]:
                    # formula 5
                    eid = graph.in_edges(idx, form="eid", etype=etype)
                    avg_dist = th.mean(layer.dist[etype][eid])

                    # formula 6
                    if layer.last_avg_dist[etype] < avg_dist:
                        if layer.p[etype] - self.step_size > 0:
                            layer.p[etype] -= self.step_size
                            layer.f[etype].append(-1)
                    else:
                        if layer.p[etype] + self.step_size <= 1:
                            layer.p[etype] += self.step_size
                            layer.f[etype].append(+1)
                    layer.last_avg_dist[etype] = avg_dist

                    # formula 7
                    if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
                        layer.cvg[etype] = True