Code for the NeurIPS 2021 paper "Towards Multi-Grained Explainability for Graph Neural Networks"

Overview

ReFine: Multi-Grained Explainability for GNNs

We are working hard to update the code, but it may take a while due to our tight schedule. Thank you for your patience!

Installation

Requirements

  • CPU or NVIDIA GPU, Linux, Python 3.7
  • PyTorch, various Python packages

Main Packages

  1. PyTorch Geometric. Official Download.
# We use TORCH version 1.6.0
CUDA=cu101
TORCH=1.6.0 
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-geometric
  2. Visual Genome. Google Drive Download. This is used for preprocessing the VG-5 dataset and visualizing the generated explanations. Manually download it to the same directory as data/. (This package can also be installed via pip or the API, but we find that route slow.)
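
After installation, a quick import check (an editor's sketch, not part of the original instructions) confirms that the pinned versions load correctly:

# Editor's sanity check: confirm the pinned versions import correctly.
import torch
import torch_geometric

print(torch.__version__)            # expect 1.6.0
print(torch.version.cuda)           # expect 10.1 for CUDA=cu101
print(torch_geometric.__version__)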

Datasets

  1. The processed data for BA-3motif is available in the data/ folder.
  2. The MNIST and Mutagenicity datasets will be downloaded automatically when training the models.
  3. We select and label 4444 graphs from https://visualgenome.org/ to construct the VG-5 dataset. The graphs are labeled with five classes: stadium, street, farm, surfing, forest. In each graph, the nodes are object regions, while the edges indicate relationships between object nodes. (A sketch of this representation follows the directory layout below.)

Download the dataset from Google Drive and arrange the directory as:

data
├── BA3
└── VG
    └── raw

Please remember to cite Visual Genome (bibtex) if you use our VG-5 dataset.
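
For orientation, here is an editor's sketch (not code from this repository) of how a VG-5 style graph maps onto a torch_geometric.data.Data object; the node count, feature dimension, and label below are illustrative assumptions:

import torch
from torch_geometric.data import Data

# Hypothetical VG-5 style sample: 4 object regions, 3 relationship edges.
x = torch.randn(4, 512)                 # node features, one row per region (dim assumed)
edge_index = torch.tensor([[0, 1, 2],   # source regions
                           [1, 2, 3]])  # target regions
y = torch.tensor([2])                   # graph label in {0, ..., 4}, e.g. "farm"

graph = Data(x=x, edge_index=edge_index, y=y)
print(graph)  # Data(edge_index=[2, 3], x=[4, 512], y=[1])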

Training GNNs

cd gnns/
python ba3motif_gnn.py --epoch 100 --num_unit 2 --batch_size 128

The trained GNNs will be saved in param/gnns.
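
To inspect what was saved, something like the following works (an editor's sketch; the filename is hypothetical, so list param/gnns/ to find the files your run actually produced):

import torch

# Hypothetical checkpoint name; check param/gnns/ for the real one.
ckpt = torch.load("param/gnns/ba3motif_net.pt", map_location="cpu")
# A state dict's keys name the saved layers:
print(list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))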

Explaining the Predictions

Code is coming soon.

Evaluation & Visualization

Code is coming soon.

Citation

Please cite our paper if you find the repository useful.

@inproceedings{2021refine,
  title={Towards Multi-Grained Explainability for Graph Neural Networks},
  author={Wang, Xiang and Wu, Ying-Xin and Zhang, An and He, Xiangnan and Chua, Tat-Seng},
  booktitle={Proceedings of the 35th Conference on Neural Information Processing Systems},
  year={2021}
}
Comments
  • ReFine training

    Hello, thanks for making the explainer available! I was trying to run the training step and got output like this (running python refine_train.py --dataset ba3 --hid 50 --epoch 25 --ratio 0.4 --lr 1e-4):

    2021-12-22 11:22:08,770 - refine_train.py[line:95] - INFO: number of graphs(train): 2185
    2021-12-22 11:22:08,775 - refine_train.py[line:95] - INFO: number of graphs(val):  398
    2021-12-22 11:22:08,777 - refine_train.py[line:95] - INFO: number of graphs(test):  397
    2021-12-22 11:23:00,702 - refine_train.py[line:158] - INFO: Epoch: 1, LR: 0.00010, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:23:00,709 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.32 0.35 0.37 0.52 0.77 0.96 1.  ]] ACC-AUC: 0.521 Mean P: nan
    2021-12-22 11:23:49,479 - refine_train.py[line:158] - INFO: Epoch: 2, LR: 0.00010, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:24:37,940 - refine_train.py[line:158] - INFO: Epoch: 3, LR: 0.00010, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:24:37,943 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.31 0.34 0.37 0.56 0.76 0.97 1.  ]] ACC-AUC: 0.523 Mean P: nan
    2021-12-22 11:25:26,978 - refine_train.py[line:158] - INFO: Epoch: 4, LR: 0.00002, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:26:15,519 - refine_train.py[line:158] - INFO: Epoch: 5, LR: 0.00002, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:26:15,522 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.31 0.36 0.37 0.52 0.73 0.97 1.  ]] ACC-AUC: 0.519 Mean P: nan
    2021-12-22 11:27:03,904 - refine_train.py[line:158] - INFO: Epoch: 6, LR: 0.00002, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:27:52,747 - refine_train.py[line:158] - INFO: Epoch: 7, LR: 0.00002, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:27:52,750 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.31 0.39 0.38 0.52 0.77 0.97 0.99]] ACC-AUC: 0.526 Mean P: nan
    2021-12-22 11:28:41,475 - refine_train.py[line:158] - INFO: Epoch: 8, LR: 0.00001, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:29:29,920 - refine_train.py[line:158] - INFO: Epoch: 9, LR: 0.00001, Ratio: 0.40, Train Loss: nan, Val Loss: nan
    2021-12-22 11:29:29,923 - refine_train.py[line:163] - INFO: ACC:[[0.31 0.31 0.31 0.31 0.35 0.37 0.51 0.74 0.97 1.  ]] ACC-AUC: 0.517 Mean P: nan
    

    Is it normal that the train loss and validation loss are nan? Also, the ACC-AUC does not seem to improve meaningfully, so I'm wondering where the issue is. (Editor's note: a generic NaN-debugging sketch follows this item.)

    opened by Cupcee 9
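
    Editor's note on the NaN question above: a common first step is PyTorch's anomaly detection, and a frequent NaN source in mask-based objectives is log(0), for which clamping is a standard remedy. A generic sketch (not this repository's code):

    import torch

    # Debug-only (slow): reports the op that first yields a NaN/Inf gradient.
    torch.autograd.set_detect_anomaly(True)

    p = torch.tensor([0.0, 0.5, 1.0], requires_grad=True)
    # log(0) = -inf would poison the loss; clamping away from 0 is a common fix.
    loss = -torch.log(p.clamp(min=1e-8)).mean()
    loss.backward()
    print(torch.isfinite(loss).item())  # True after clamping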
  • BA3-Motif cannot be created correctly.

    When I try to train another GNN on the BA3-Motif dataset, the process function in the BA3Motif class throws an exception.

    Concretely, in 'ba3motif_dataset.py' line 67:

    torch.save(self.collate(data_list[800:]), self.processed_paths[0])

    KeyError: 15

    The collate function always raises the KeyError above. (Editor's note: a possible cause is sketched after this item.)

    opened by Esperanto-mega 6
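
    Editor's note: in the PyG version pinned in this README (1.x), InMemoryDataset.collate assumes every Data object in the list carries the same attribute set, so one plausible cause of the KeyError is a graph missing an attribute. A toy check (an editor's sketch):

    import torch
    from torch_geometric.data import Data

    # Toy data_list where one graph lacks edge_index, breaking collate's
    # same-attributes assumption. (PyG 1.x: `keys` is a property; in 2.x use d.keys().)
    data_list = [
        Data(x=torch.randn(3, 4), edge_index=torch.zeros(2, 2, dtype=torch.long)),
        Data(x=torch.randn(3, 4)),
    ]
    ref = set(data_list[0].keys)
    for i, d in enumerate(data_list):
        if set(d.keys) != ref:
            print(i, set(d.keys) ^ ref)  # index and mismatched attributes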
  • Question about the EdgeMask

    Hi Ying-Xin,

    Thanks for your excellent work and concise codes!

    I am puzzled by the mask function used in both PG-Explainer and ReFine. Specifically, the mask function consists of ARMAConv operations and MLPs in common.py, but I haven't found a description of ARMAConv in the paper. So I am wondering why the ARMAConv operations are necessary here.

    https://github.com/Wuyxin/ReFine/blob/d8ec5482e8ce2eb37308f1fae1f2dcabe1f33716/explainers/common.py#L36

    Looking forward to your reply, much thanks! :)

    opened by zealscott 4
  • Advice on the hyperparameters for training GIN on Mutagenicity

    Hi,

    Could you provide me with the hyperparameters for training GIN on Mutagenicity? I cannot reproduce a model with 100% test accuracy.

    Thank you, Zhaoning

    opened by ZhaoningYu1996 4
  • How to deal with HeteroData objects in ReFine?

    I encounter problems when I use the HAN model for graph classification tasks. I read the Mutag dataset code, which handles datasets as torch_geometric.data.Data objects, but the HAN model uses HeteroData. I want to know whether ReFine can handle HeteroData objects. (Editor's note: a possible workaround is sketched after this item.)

    opened by 1054518207 3
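
    Editor's note: the dataset code here operates on homogeneous torch_geometric.data.Data objects. One possible workaround (an editor's suggestion, requiring PyG >= 2.0, which is newer than the version pinned in this README) is to flatten the heterogeneous graph first:

    import torch
    from torch_geometric.data import HeteroData

    # Minimal heterogeneous graph: two node types, one relation.
    hetero = HeteroData()
    hetero['author'].x = torch.randn(2, 8)
    hetero['paper'].x = torch.randn(3, 8)
    hetero['author', 'writes', 'paper'].edge_index = torch.tensor([[0, 1],
                                                                   [0, 2]])

    homo = hetero.to_homogeneous()  # node_type / edge_type kept as attributes
    print(homo)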
  • When the test dataset has no labels, how will ReFine predict?

    I read the ReFine model code; it's an excellent project. I want to know: when the test dataset has no labels, how will ReFine predict (in the case of Mutag)? I see that the code at https://github.com/Wuyxin/ReFine/blob/main/explainers/refine.py#L118, in the get_mask() method, uses labels when computing the mask.

    opened by Aksox 2
  • ACC-AUC on BA3-Motif

    Hi, I want to ask something regarding the performance on BA3-motif.

    You reported that ReFine achieves an ACC-AUC of 0.630. However, in your log file, the highest ACC-AUC of ReFine is only 0.612. Is anything wrong with the uploaded log?

    opened by smiles724 2
  • Performance of other baseline methods

    Hi, I notice that you provide scripts for many other explanation methods. Did you test the performance of those approaches? If you don't mind, could you share their full results? Thanks.

    opened by smiles724 2
  • How to construct VG-5 Dataset.

    I have downloaded visual_genome from https://github.com/ranjaykrishna/visual_genome_python_driver

    and image_data.json & synsets.json from https://visualgenome.org/api/v0/api_home.html.

    But where can I get the per-image image_id.json files, such as 35.json? Hoping for your reply.

    opened by Esperanto-mega 2
  • Performance of backbone GNN

    Hi, thanks for sharing the code, which is well written. I re-implemented the GNN on all four datasets. I found that the GNNs on MNIST/VG/BA3 perform close to what is reported in the paper. However, the test-set accuracy is only 80% for MUTAG, while you claim it can achieve 100% accuracy.

    I know this does not affect the effectiveness of your method. But since I want to follow your work, it would be good to figure out whether the backbone GNN is strong enough on MUTAG. What do you think of this gap?

    opened by smiles724 1
  • Advice on the ACC-AUC metric in the paper

    Hi,

    Does the ACC-AUC metric in your paper mean the ROC-AUC of predictions on the generated subgraphs? How do you calculate ACC-AUC? (Editor's note: one reading of the metric is sketched after this item.)

    Thank you, Zhaoning

    opened by ZhaoningYu1996 1
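
    Editor's note: judging from the training logs earlier on this page, ACC-AUC appears to be the area under the accuracy curve measured at ten selection ratios, i.e. roughly the mean of the ten ACC values; this is an inference from the logged numbers, so confirm it against the evaluation code. A sketch:

    import numpy as np

    # ACC at selection ratios 0.1 ... 1.0, copied from the first log line above.
    acc = np.array([0.31, 0.31, 0.31, 0.32, 0.35, 0.37, 0.52, 0.77, 0.96, 1.0])

    # 0.522 here vs. 0.521 in the log; the small gap is consistent with the
    # displayed ACC entries being rounded to two decimals.
    print(round(acc.mean(), 3))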
  • Failed to reproduce the results on the Mutag dataset

    Hello,

    I retrained the downstream model following the documentation (e.g. python mutag_gnn.py) and also the ReFine model using the recommended command: python refine_train.py --dataset mutag --hid 100 --epoch 100 --ratio 0.4 --lr 1e-3 --batch_size 64

    However, at the end of training I get an ACC-AUC of 0.822, which is quite far from the reported 0.955. Am I missing something?

    2022-07-01 18:01:23,919 - refine_train.py[line:163] - INFO: ACC:[[0.86 0.81 0.78 0.78 0.77 0.79 0.77 0.8 0.86 1. ]] ACC-AUC: 0.822 Mean P: 0.099
    2022-07-01 18:01:44,651 - refine_train.py[line:158] - INFO: Epoch: 98, LR: 0.00001, Ratio: 0.40, Train Loss: -163.027, Val Loss: -1.704

    opened by ciortanmadalina 1
  • Strange behavior with randomness

    On the first run, e.g. if I clone the repo and train with python3 refine_train.py --dataset ba3 --hid 50 --epoch 1 --ratio 0.4 --lr 1e-4, I always get the same ACC-AUC of 0.518. On the second and all subsequent runs, the same command gives me an ACC-AUC of 0.490.

    This happens with any number of epochs but is easiest to verify with just one. It seems like something is not quite working with the random seed on the first run (although the first run is still seeded, as it produces the same result of 0.518 every time), and once training has run once, the seed starts working. I even cloned the repo several more times, and the same pattern always appeared.

    I tried to fix this myself but couldn't. This isn't really critical to fix, but I think it's worth mentioning, as it caused quite a bit of confusion for me when testing the code.

    The Torch version is 1.8.0 because I couldn't get 1.6.0 to work with torch-scatter; otherwise the setup is the same as in the README. (Editor's note: a possible cause and a generic seeding recipe follow this item.)

    opened by Cupcee 4
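
    Editor's note: one plausible cause is that the one-time dataset processing on the first run consumes RNG state before training, while later runs load the cached processed files; this is a guess, not a confirmed diagnosis. A standard full-seeding recipe (generic, not this repository's code), called right before training:

    import random

    import numpy as np
    import torch

    def seed_everything(seed: int = 0) -> None:
        # Seed every RNG a typical PyTorch training run touches.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    seed_everything(0)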
  • ReFine training on other GNNs

    Hi,

    I am trying to use ReFine on other GNNs such as GCN, but I cannot achieve good results. Could you tell me which settings I should change to achieve the best performance?

    opened by ZhaoningYu1996 0
Owner
Shirley (Ying-Xin) Wu
Senior Undergraduate @ LDS, School of Data Science. [email protected]