[Model] Implemented SubgraphX Explainer for Homogeneous graph (dmlc#5315)

* subgraphx commit

* nits

* newline eof added

* lint fix

* test script updated to use default values

* lint fix

* graphs used in test cases updated to a small graph

* lint formatted

* test parameters adjusted to complete the test under 20s

* lint fixes

---------

Co-authored-by: kxm180046 <[email protected]>
2 people authored and paoxiaode committed Mar 24, 2023
1 parent 1eb5deb commit 70dd2a0
Showing 4 changed files with 409 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/api/python/nn-pytorch.rst
@@ -124,6 +124,7 @@ Utility Modules
~dgl.nn.pytorch.sparse_emb.NodeEmbedding
~dgl.nn.pytorch.explain.GNNExplainer
~dgl.nn.pytorch.explain.HeteroGNNExplainer
+~dgl.nn.pytorch.explain.SubgraphX
~dgl.nn.pytorch.utils.LabelPropagation
~dgl.nn.pytorch.graph_transformer.DegreeEncoder
~dgl.nn.pytorch.utils.LaplacianPosEnc
3 changes: 2 additions & 1 deletion python/dgl/nn/pytorch/explain/__init__.py
@@ -1,4 +1,5 @@
"""Torch modules for explanation models."""
# pylint: disable= no-member, arguments-differ, invalid-name

-from .gnnexplainer import GNNExplainer, HeteroGNNExplainer
+from .gnnexplainer import *
+from .subgraphx import *
366 changes: 366 additions & 0 deletions python/dgl/nn/pytorch/explain/subgraphx.py
@@ -0,0 +1,366 @@
"""Torch Module for SubgraphX"""
import math

import networkx as nx
import numpy as np
import torch
import torch.nn as nn

from ....base import NID
from ....convert import to_networkx
from ....subgraph import node_subgraph
from ....transforms.functional import remove_nodes

__all__ = ["SubgraphX"]


class MCTSNode:
    r"""Monte Carlo Tree Search Node

    Parameters
    ----------
    nodes : Tensor
        The node IDs of the graph that are associated with this tree node
    """

    def __init__(self, nodes):
        self.nodes = nodes
        self.num_visit = 0
        self.total_reward = 0.0
        self.immediate_reward = 0.0
        self.children = []

    def __repr__(self):
        r"""Get the string representation of the node.

        Returns
        -------
        str
            The string representation of the node
        """
        return str(self.nodes)


class SubgraphX(nn.Module):
    r"""SubgraphX from `On Explainability of Graph Neural Networks via
    Subgraph Explorations <https://arxiv.org/abs/2102.05152>`__

    It identifies the most important subgraph from the original graph that
    plays a critical role in GNN-based graph classification. It employs
    Monte Carlo tree search (MCTS) to efficiently explore different subgraphs
    for explanation and uses Shapley values as the measure of subgraph
    importance.

    Parameters
    ----------
    model : nn.Module
        The GNN model to explain that tackles multiclass graph classification

        * Its forward function must have the form
          :attr:`forward(self, graph, nfeat)`.
        * The output of its forward function is the logits.
    num_hops : int
        Number of message passing layers in the model
    coef : float, optional
        This hyperparameter controls the trade-off between exploration and
        exploitation. A higher value encourages the algorithm to explore
        relatively unvisited nodes. Default: 10.0
    high2low : bool, optional
        If True, it will use the "High2low" strategy for pruning actions,
        expanding children nodes from high degree to low degree when extending
        the children nodes in the search tree. Otherwise, it will use the
        "Low2high" strategy. Default: True
    num_child : int, optional
        This is the number of children nodes to expand when extending the
        children nodes in the search tree. Default: 12
    num_rollouts : int, optional
        This is the number of rollouts for MCTS. Default: 20
    node_min : int, optional
        This is the threshold to define a leaf node based on the number of
        nodes in a subgraph. Default: 3
    shapley_steps : int, optional
        This is the number of steps for Monte Carlo sampling in estimating
        Shapley values. Default: 100
    log : bool, optional
        If True, it will log the progress. Default: False
    """

    def __init__(
        self,
        model,
        num_hops,
        coef=10.0,
        high2low=True,
        num_child=12,
        num_rollouts=20,
        node_min=3,
        shapley_steps=100,
        log=False,
    ):
        super().__init__()
        self.num_hops = num_hops
        self.coef = coef
        self.high2low = high2low
        self.num_child = num_child
        self.num_rollouts = num_rollouts
        self.node_min = node_min
        self.shapley_steps = shapley_steps
        self.log = log

        self.model = model

    def shapley(self, subgraph_nodes):
        r"""Compute Shapley value with Monte Carlo approximation.

        Parameters
        ----------
        subgraph_nodes : Tensor
            The node IDs of the subgraph that are associated with this
            tree node

        Returns
        -------
        float
            Shapley value
        """
        num_nodes = self.graph.num_nodes()
        subgraph_nodes = subgraph_nodes.tolist()

        # Obtain neighboring nodes of the subgraph g_i, P'.
        local_region = subgraph_nodes
        for _ in range(self.num_hops - 1):
            in_neighbors, _ = self.graph.in_edges(local_region)
            _, out_neighbors = self.graph.out_edges(local_region)
            neighbors = torch.cat([in_neighbors, out_neighbors]).tolist()
            local_region = list(set(local_region + neighbors))

        split_point = num_nodes
        coalition_space = list(set(local_region) - set(subgraph_nodes)) + [
            split_point
        ]
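        # ``split_point`` is a placeholder ID acting as a pivot: in each
        # random permutation below, the nodes drawn before it form the
        # sampled coalition S, so every permutation yields one Monte Carlo
        # sample of the marginal contribution f(S + g_i) - f(S).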

        marginal_contributions = []
        device = self.feat.device
        for _ in range(self.shapley_steps):
            permuted_space = np.random.permutation(coalition_space)
            split_idx = int(np.where(permuted_space == split_point)[0])

            selected_nodes = permuted_space[:split_idx]

            # Mask for coalition set S_i
            exclude_mask = torch.ones(num_nodes)
            exclude_mask[local_region] = 0.0
            exclude_mask[selected_nodes] = 1.0

            # Mask for set S_i and g_i
            include_mask = exclude_mask.clone()
            include_mask[subgraph_nodes] = 1.0

            exclude_feat = self.feat * exclude_mask.unsqueeze(1).to(device)
            include_feat = self.feat * include_mask.unsqueeze(1).to(device)

            with torch.no_grad():
                exclude_probs = self.model(
                    self.graph, exclude_feat, **self.kwargs
                ).softmax(dim=-1)
                exclude_value = exclude_probs[:, self.target_class]
                include_probs = self.model(
                    self.graph, include_feat, **self.kwargs
                ).softmax(dim=-1)
                include_value = include_probs[:, self.target_class]
            marginal_contributions.append(include_value - exclude_value)

        return torch.cat(marginal_contributions).mean().item()

    def get_mcts_children(self, mcts_node):
        r"""Get the children of the MCTS node for the search.

        Parameters
        ----------
        mcts_node : MCTSNode
            Node in MCTS

        Returns
        -------
        list
            Children nodes after pruning
        """
        if len(mcts_node.children) > 0:
            return mcts_node.children

        subg = node_subgraph(self.graph, mcts_node.nodes)
        node_degrees = subg.out_degrees() + subg.in_degrees()
        k = min(subg.num_nodes(), self.num_child)
        chosen_nodes = torch.topk(
            node_degrees, k, largest=self.high2low
        ).indices
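        # Each child is obtained by deleting one of the chosen nodes from
        # the current subgraph and keeping only the largest weakly connected
        # component, so every child corresponds to a connected subgraph;
        # ``self.mcts_node_maps`` deduplicates tree nodes across rollouts.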

        mcts_children_maps = dict()

        for node in chosen_nodes:
            new_subg = remove_nodes(
                subg, node.to(subg.idtype), store_ids=True
            )
            # Get the largest weakly connected component in the subgraph.
            nx_graph = to_networkx(new_subg.cpu())
            largest_cc_nids = list(
                max(nx.weakly_connected_components(nx_graph), key=len)
            )
            # Map to the original node IDs.
            largest_cc_nids = new_subg.ndata[NID][largest_cc_nids].long()
            largest_cc_nids = subg.ndata[NID][largest_cc_nids].sort().values
            if str(largest_cc_nids) not in self.mcts_node_maps:
                child_mcts_node = MCTSNode(largest_cc_nids)
                self.mcts_node_maps[str(child_mcts_node)] = child_mcts_node
            else:
                child_mcts_node = self.mcts_node_maps[str(largest_cc_nids)]

            if str(child_mcts_node) not in mcts_children_maps:
                mcts_children_maps[str(child_mcts_node)] = child_mcts_node

        mcts_node.children = list(mcts_children_maps.values())
        for child_mcts_node in mcts_node.children:
            if child_mcts_node.immediate_reward == 0:
                child_mcts_node.immediate_reward = self.shapley(
                    child_mcts_node.nodes
                )

        return mcts_node.children

    def mcts_rollout(self, mcts_node):
        r"""Perform an MCTS rollout.

        Parameters
        ----------
        mcts_node : MCTSNode
            Starting node for MCTS

        Returns
        -------
        float
            Reward for visiting the node this time
        """
        if len(mcts_node.nodes) <= self.node_min:
            return mcts_node.immediate_reward

        children_nodes = self.get_mcts_children(mcts_node)
        children_visit_sum = sum([child.num_visit for child in children_nodes])
        children_visit_sum_sqrt = math.sqrt(children_visit_sum)
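        # Upper-confidence-style selection: average reward so far
        # (exploitation) plus an exploration bonus proportional to the
        # child's immediate reward and to
        # sqrt(total sibling visits) / (1 + child visits), scaled by coef.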
        chosen_child = max(
            children_nodes,
            key=lambda c: c.total_reward / max(c.num_visit, 1)
            + self.coef
            * c.immediate_reward
            * children_visit_sum_sqrt
            / (1 + c.num_visit),
        )
        reward = self.mcts_rollout(chosen_child)
        chosen_child.num_visit += 1
        chosen_child.total_reward += reward

        return reward

    def explain_graph(self, graph, feat, target_class, **kwargs):
        r"""Find the most important subgraph from the original graph for the
        model to classify the graph into the target class.

        Parameters
        ----------
        graph : DGLGraph
            A homogeneous graph
        feat : Tensor
            The input node feature of shape :math:`(N, D)`, where :math:`N`
            is the number of nodes and :math:`D` is the feature size
        target_class : int
            The target class to explain
        kwargs : dict
            Additional arguments passed to the GNN model

        Returns
        -------
        Tensor
            Nodes that represent the most important subgraph

        Examples
        --------

        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.functional as F
        >>> from dgl.data import GINDataset
        >>> from dgl.dataloading import GraphDataLoader
        >>> from dgl.nn import GraphConv, AvgPooling, SubgraphX
        >>> # Define the model
        >>> class Model(nn.Module):
        ...     def __init__(self, in_dim, n_classes, hidden_dim=128):
        ...         super().__init__()
        ...         self.conv1 = GraphConv(in_dim, hidden_dim)
        ...         self.conv2 = GraphConv(hidden_dim, n_classes)
        ...         self.pool = AvgPooling()
        ...
        ...     def forward(self, g, h):
        ...         h = F.relu(self.conv1(g, h))
        ...         h = self.conv2(g, h)
        ...         return self.pool(g, h)
        >>> # Load dataset
        >>> data = GINDataset('MUTAG', self_loop=True)
        >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
        >>> # Train the model
        >>> feat_size = data[0][0].ndata['attr'].shape[1]
        >>> model = Model(feat_size, data.gclasses)
        >>> criterion = nn.CrossEntropyLoss()
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
        >>> for bg, labels in dataloader:
        ...     logits = model(bg, bg.ndata['attr'])
        ...     loss = criterion(logits, labels)
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()
        >>> # Initialize the explainer
        >>> explainer = SubgraphX(model, num_hops=2)
        >>> # Explain the prediction for graph 0
        >>> graph, l = data[0]
        >>> graph_feat = graph.ndata.pop("attr")
        >>> g_nodes_explain = explainer.explain_graph(graph, graph_feat,
        ...                                           target_class=l)
        """
        self.model.eval()
        assert graph.num_nodes() > self.node_min, (
            f"The number of nodes in the graph, {graph.num_nodes()}, "
            f"should be greater than {self.node_min}."
        )

        self.graph = graph
        self.feat = feat
        self.target_class = target_class
        self.kwargs = kwargs

        # Record all tree nodes explored in MCTS
        self.mcts_node_maps = dict()

        root = MCTSNode(graph.nodes())
        self.mcts_node_maps[str(root)] = root

        for i in range(self.num_rollouts):
            if self.log:
                print(
                    f"Rollout {i}/{self.num_rollouts}, "
                    f"{len(self.mcts_node_maps)} subgraphs have been explored."
                )
            self.mcts_rollout(root)

        best_leaf = None
        best_immediate_reward = float("-inf")
        for mcts_node in self.mcts_node_maps.values():
            if len(mcts_node.nodes) > self.node_min:
                continue

            if mcts_node.immediate_reward > best_immediate_reward:
                best_leaf = mcts_node
                best_immediate_reward = best_leaf.immediate_reward

        return best_leaf.nodes
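The node IDs returned by explain_graph can be materialized as a graph with dgl.node_subgraph. A minimal usage sketch, continuing the docstring example above (graph and g_nodes_explain come from that example):

>>> import dgl
>>> explanation_subgraph = dgl.node_subgraph(graph, g_nodes_explain)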