Overview

GAM


A PyTorch implementation of Graph Classification Using Structural Attention (KDD 2018).

Abstract

Graph classification is a problem with practical applications in many different domains. To solve this problem, one usually calculates certain graph statistics (i.e., graph features) that help discriminate between graphs of different classes. When calculating such features, most existing approaches process the entire graph. In a graphlet-based approach, for instance, the entire graph is processed to get the total count of different graphlets or subgraphs. In many real-world applications, however, graphs can be noisy with discriminative patterns confined to certain regions in the graph only. In this work, we study the problem of attention-based graph classification. The use of attention allows us to focus on small but informative parts of the graph, avoiding noise in the rest of the graph. We present a novel RNN model, called the Graph Attention Model (GAM), that processes only a portion of the graph by adaptively selecting a sequence of “informative” nodes. Experimental results on multiple real-world datasets show that the proposed method is competitive against various well-known methods in graph classification even though our method is limited to only a portion of the graph.
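To make the adaptive node-selection idea concrete, the sketch below walks a toy graph by sampling each move from attention over the labels of the current node's neighbors. It is only an illustration: attention_walk and the fixed label_scores table are hypothetical stand-ins for the learned RNN step network, not the repository's API.

  import networkx as nx
  import numpy as np

  def attention_walk(graph, labels, label_scores, steps, start, seed=0):
      # Turn raw per-label scores over the current neighborhood into a
      # probability distribution and sample the next node from it.
      np.random.seed(seed)
      node = start
      visited = [node]
      for _ in range(steps):
          neighbors = list(graph.neighbors(node))
          scores = np.array([label_scores[labels[v]] for v in neighbors], dtype=float)
          probabilities = scores / scores.sum()
          node = neighbors[np.random.choice(len(neighbors), p=probabilities)]
          visited.append(node)
      return visited

  # Toy usage on the example graph from the Datasets section below.
  graph = nx.Graph([(0, 1), (0, 4), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)])
  labels = {0: 2, 1: 3, 2: 2, 3: 3, 4: 4}
  label_scores = {2: 1.0, 3: 2.0, 4: 0.5}  # hypothetical fixed scores
  print(attention_walk(graph, labels, label_scores, steps=5, start=0))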

This repository provides an implementation for GAM as described in the paper:

Graph Classification Using Structural Attention. John Boaz Lee, Ryan Rossi, and Xiangnan Kong. KDD, 2018. [Paper]

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx           2.4
tqdm               4.28.1
numpy              1.15.4
pandas             0.23.4
texttable          1.5.0
argparse           1.1.0
sklearn            0.20.0
torch              1.2.0
torchvision        0.3.0

Datasets

The code takes graphs for training from an input folder where each graph is stored as a JSON file. Graphs used for testing are also stored as JSON files. Every node id, node label, and class has to be indexed from 0. Keys of dictionaries and nested dictionaries are stored as strings in order to make JSON serialization possible.

For example, these JSON files have the following key-value structure:

{"target": 1,
 "edges": [[0, 1], [0, 4], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]],
 "labels": {"0": 2, "1": 3, "2": 2, "3": 3, "4": 4},
 "inverse_labels": {"2": [0, 2], "3": [1, 3], "4": [4]}}

The **target key** has an integer value, which is the ID of the target class (e.g. Carcinogenicity). The **edges key** has an edge list value for the graph of interest. The **labels key** has a dictionary value; node labels are stored in it as key-value pairs (e.g. node - atom pairs). The **inverse_labels key** has a key for each node label, and the values are lists containing the nodes that have that specific node label.
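As a hedged illustration of how such files can be produced, the helper below (graph_to_gam_json is hypothetical, not part of the repository) derives inverse_labels from labels and stringifies the dictionary keys before serialization:

  import json

  def graph_to_gam_json(edges, node_labels, target, path):
      # node_labels maps integer node ids (indexed from 0) to integer
      # node labels (also indexed from 0).
      inverse_labels = {}
      for node, label in node_labels.items():
          inverse_labels.setdefault(str(label), []).append(node)
      payload = {
          "target": target,
          "edges": edges,
          "labels": {str(node): label for node, label in node_labels.items()},
          "inverse_labels": inverse_labels,
      }
      with open(path, "w") as handle:
          json.dump(payload, handle)

  # Reproduces the example above; assumes input/train/ already exists.
  graph_to_gam_json(
      edges=[[0, 1], [0, 4], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]],
      node_labels={0: 2, 1: 3, 2: 2, 3: 3, 4: 4},
      target=1,
      path="input/train/example.json",
  )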

Options

Training a GAM model is handled by the src/main.py script, which provides the following command line arguments.

Input and output options

  --train-graph-folder   STR    Training graphs folder.      Default is `input/train/`.
  --test-graph-folder    STR    Testing graphs folder.       Default is `input/test/`.
  --prediction-path      STR    Path to store labels.        Default is `output/erdos_predictions.csv`.
  --log-path             STR    Log json path.               Default is `logs/erdos_gam_logs.json`. 

Model options

  --repetitions          INT         Number of scoring runs.                  Default is 10. 
  --batch-size           INT         Number of graphs processed per batch.    Default is 32. 
  --time                 INT         Time budget.                             Default is 20. 
  --step-dimensions      INT         Neurons in step layer.                   Default is 32. 
  --combined-dimensions  INT         Neurons in shared layer.                 Default is 64. 
  --epochs               INT         Number of GAM training epochs.           Default is 10. 
  --learning-rate        FLOAT       Learning rate.                           Default is 0.001.
  --gamma                FLOAT       Discount rate.                           Default is 0.99. 
  --weight-decay         FLOAT       Weight decay.                            Default is 10^-5. 
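
These flags are read by the parameter_parser function in src/parser.py; a minimal argparse sketch matching the defaults listed above (the repository's actual parser may differ in details) could look like this:

  import argparse

  def parameter_parser():
      # Sketch of a command line parser for the input/output and model
      # options documented above; defaults mirror the tables.
      parser = argparse.ArgumentParser(description="Run GAM.")
      parser.add_argument("--train-graph-folder", type=str, default="input/train/")
      parser.add_argument("--test-graph-folder", type=str, default="input/test/")
      parser.add_argument("--prediction-path", type=str, default="output/erdos_predictions.csv")
      parser.add_argument("--log-path", type=str, default="logs/erdos_gam_logs.json")
      parser.add_argument("--repetitions", type=int, default=10)
      parser.add_argument("--batch-size", type=int, default=32)
      parser.add_argument("--time", type=int, default=20)
      parser.add_argument("--step-dimensions", type=int, default=32)
      parser.add_argument("--combined-dimensions", type=int, default=64)
      parser.add_argument("--epochs", type=int, default=10)
      parser.add_argument("--learning-rate", type=float, default=0.001)
      parser.add_argument("--gamma", type=float, default=0.99)
      parser.add_argument("--weight-decay", type=float, default=10**-5)
      return parser.parse_args()

argparse exposes the dashed flags as attributes with underscores, e.g. args.batch_size.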

Examples

The following commands learn a neural network, make predictions, create logs, and write the predictions and logs to disk.

Training a GAM model on the default dataset. Saving predictions and logs at default paths.

python src/main.py

Training a GAM model for 100 epochs with a batch size of 512.

python src/main.py --epochs 100 --batch-size 512

Setting a high time budget for the agent.

python src/main.py --time 128

Training a model with some custom learning rate and epoch number.

python src/main.py --learning-rate 0.001 --epochs 200

License


Comments
  • ValueError: probabilities contain NaN

    ValueError: probabilities contain NaN

    This problem occurred when I used my own dataset. Why did it happen, and how can I solve it? I have converted my data into the same JSON format as the input folder.

    opened by yy1335574510 8
  • ValueError: probabilities contain NaN

    ValueError: probabilities contain NaN

    Hi, thank you very much for your excellent work. I created my own dataset following the .json file format in your code, and this problem occurred after training on part of my data. Comparing our data, mine contains a disconnected graph with isolated nodes; is this the reason the code fails in the middle of training? Training started.

    Epoch:   0%|          | 0/10 [00:00<?, ?it/s]
    (Loss=4.1643):   4%|▍  | 1/25 [00:01<00:29, 1.22s/it]
    (Loss=1.1765):   8%|▊  | 2/25 [00:02<00:28, 1.26s/it]
    (Loss=-1.8768): 12%|█▏ | 3/25 [00:03<00:26, 1.20s/it]
    (Loss=-1.1337): 16%|█▌ | 4/25 [00:05<00:26, 1.27s/it]
    Traceback (most recent call last):
      File "D:/GAM-master/src/main.py", line 19, in <module>
        main()
      File "D:/GAM-master/src/main.py", line 14, in main
        model.fit()
      File "D:\GAM-master\src\gam.py", line 272, in fit
        self.epoch_loss = self.epoch_loss + self.process_batch(batches[batch])
      File "D:\GAM-master\src\gam.py", line 240, in process_batch
        batch_loss = self.process_graph(graph_path, batch_loss)
      File "D:\GAM-master\src\gam.py", line 221, in process_graph
        predictions, node, attention_score = self.model(data, graph, features, node)
      File "D:\anaconda\envs\spyder\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
        result = self.forward(*input, **kwargs)
      File "D:\GAM-master\src\gam.py", line 180, in forward
        self.state, node, attention_score = self.step_block(data, graph, features, node)
      File "D:\anaconda\envs\spyder\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
        result = self.forward(*input, **kwargs)
      File "D:\GAM-master\src\gam.py", line 97, in forward
        feature_row, node, attention_score = self.make_step(node, graph, features,
      File "D:\GAM-master\src\gam.py", line 78, in make_step
        label = self.sample_node_label(orig_neighbors, graph, features)
      File "D:\GAM-master\src\gam.py", line 66, in sample_node_label
        label = np.random.choice(np.arange(len(self.identifiers)), p=normalized_attention_spread)
      File "mtrand.pyx", line 928, in numpy.random.mtrand.RandomState.choice
    ValueError: probabilities contain NaN

    Process finished with exit code 1

    Here's an example from my .json file:

    {"target": 1,"edges":[[0,2],[0,8],[0,12],[0,13],[0,14],[1,2],[1,8],[2,3],[2,8],[3,5],[6,14],[11,12],[11,13]],"labels":{"0":113,"1":106,"2":78,"3":102,"4":91,"5":63,"6":68,"7":76,"8":73,"9":119,"10":87,"11":102,"12":110,"13":117,"14":78,"15":124},"inverse_labels":{"113":[0],"106":[1],"78":[2,14],"102":[3,11],"91":[4],"63":[5],"68":[6],"76":[7],"73":[8],"119":[9],"87":[10],"110":[12],"117":[13],"124":[15]}}

    Sorry for disturbing you, thank you very much for your help !
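
    A quick way to narrow this down is a sanity check over the input files: an isolated node has no neighborhood to attend over, so the attention distribution cannot be normalized. A minimal sketch, with check_graph as a hypothetical helper rather than part of the repository:

    import json
    import networkx as nx

    def check_graph(path):
        # Flag isolated nodes and labels missing from inverse_labels,
        # two conditions that can leave nothing to normalize over.
        with open(path) as handle:
            data = json.load(handle)
        graph = nx.from_edgelist(data["edges"])
        problems = []
        for node_key, label in data["labels"].items():
            node = int(node_key)
            if node not in graph or graph.degree(node) == 0:
                problems.append("node {} is isolated".format(node))
            if str(label) not in data["inverse_labels"]:
                problems.append("label {} missing from inverse_labels".format(label))
        return problems

    print(check_graph("input/train/example.json"))  # path is an assumption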

    opened by XieTianLi 2
  • [IndexError] Mismatched node label indexing?

    [IndexError] Mismatched node label indexing?

    Hi, thanks for sharing the codes.

    I want to test GAM on my own dataset and encountered an IndexError. This is my simplified dataset: https://drive.google.com/file/d/1P2Z0i86ffF0dMq4OXcRPR2Icj4pvZ1WP/view?usp=sharing

    Is there any problem with the generated dataset? The node labels are actually the node degrees, and the size of each graph is always 50.

    I have tried to debug and found the most relevant code: lines 45-75 in src/gam.py. I suspect that the indexing of labels in self.identifiers, labels, and inverse_labels is mismatched. In detail, the label returned by the function sample_node_label is not in the same domain as labels.keys(); they refer to different mappings of the labels.

    Actually, I changed labels[str(label)] in lines 73 and 74 to label, and there is no bug afterwards. But I am not sure if the change is correct.

    Could you help me with this debug? Thank you!

    opened by silent567 2
  • Unhashable type dict

    Unhashable type dict

    I tried to run the code with default parameters and dataset, and on different systems I got the same error:

    Epoch:   0%|                                                                                    | 0/10 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "src/main.py", line 17, in <module>
        main()
      File "src/main.py", line 12, in main
        model.fit()
      File "C:\Users\acecreamu\GAM\src\gam.py", line 251, in fit
        self.epoch_loss = self.epoch_loss + self.process_batch(batches[batch])
      File "C:\Users\acecreamu\GAM\src\gam.py", line 221, in process_batch
        batch_loss = self.process_graph(graph_path, batch_loss)
      File "C:\Users\acecreamu\GAM\src\gam.py", line 202, in process_graph
        predictions, node, attention_score = self.model(data, graph, features, node)
      File "C:\Users\acecreamu\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
        result = self.forward(*input, **kwargs)
      File "C:\Users\acecreamu\GAM\src\gam.py", line 166, in forward
        self.state, node, attention_score = self.step_block(data, graph, features, node)
      File "C:\Users\acecreamu\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
        result = self.forward(*input, **kwargs)
      File "C:\Users\acecreamu\GAM\src\gam.py", line 88, in forward
        feature_row, node, attention_score = self.make_step(node, graph, features, data["labels"], data["inverse_labels"])
      File "C:\Users\acecreamu\GAM\src\gam.py", line 69, in make_step
        original_neighbors = set(nx.neighbors(graph, node))
      File "C:\Users\acecreamu\Anaconda3\lib\site-packages\networkx\classes\function.py", line 64, in neighbors
        return G.neighbors(n)
      File "C:\Users\acecreamu\Anaconda3\lib\site-packages\networkx\classes\graph.py", line 1266, in neighbors
        return iter(self._adj[n])
    TypeError: unhashable type: 'dict'
    

    So the problem is in the line original_neighbors = set(nx.neighbors(graph, node)), but I have no clue how to solve it.

    opened by acecreamu 2
  • In gam.py, what are self.epoch_loss = 0 and self.nodes_processed = 0?

    In gam.py, what are self.epoch_loss = 0 and self.nodes_processed = 0?

    Hi, I would like to ask about self.epoch_loss = 0 and self.nodes_processed = 0. What do they mean?

    Since you initialize them to 0, and a divisor cannot be 0 at runtime, how should these values be set? Thank you.

    opened by camellia2 1
  • Error when running the code: "The number of graph classes is: 0."

    Error when running the code: "The number of graph classes is: 0."

    Can you help me solve this error when I run the code? Thank you very much.

      Collecting unique node labels.
      0it [00:00, ?it/s]
      The number of graph classes is: 0.
      Training started.
      0it [00:00, ?it/s]
      ZeroDivisionError: division by zero
      Epoch: 0%|

    opened by YoungBx 1
  • Create own dataset

    Create own dataset

    Hello,

    I would like to test your code with my own dataset; however, I did not understand its format, especially the labels key. In your README you say that

    the labels key has a dictionary value; node labels are stored in it as key-value pairs (e.g. node - atom pairs)

    Can the labels be strings? If the labels must be numbers, why are the labels needed as a dictionary? I mean, couldn't it be a list? It is all confusing for me. Could you explain it, please? Perhaps you could create an image of the graph you are representing in the README example; I think an image would clarify how the JSON can be made (see the sketch below).
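
    In the meantime, the example graph from the README can be reconstructed and inspected in a few lines, which may serve in place of an image (a minimal sketch using networkx):

    import networkx as nx

    # The graph encoded by the README's example JSON.
    graph = nx.Graph([(0, 1), (0, 4), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)])
    labels = {0: 2, 1: 3, 2: 2, 3: 3, 4: 4}  # node id -> integer node label

    for node in sorted(graph.nodes()):
        print("node {} (label {}) -- neighbors {}".format(
            node, labels[node], sorted(graph.neighbors(node))))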

    opened by HenriqueVarellaEhrenfried 1
  • "inverse_labels"

    Excuse me, which part of the paper does "inverse_labels" in the program correspond to? I downloaded the NCI_1 dataset and did not see it there. I hope you can answer this confusing question for me, thank you.

    opened by try-to-anything 1
  • ImportError: cannot import name 'parameter_parser'

    ImportError: cannot import name 'parameter_parser'

    I had this issue with import:

    File "src/main.py", line 1, in <module>
        from parser import parameter_parser
    ImportError: cannot import name 'parameter_parser'
    

    I solved it by renaming parser.py (e.g. to my_parser.py) and calling from my_parser import parameter_parser (the name clashes with Python's built-in parser module). No problems anymore; I am just posting this in case it is helpful for others.

    opened by acecreamu 1
  • Graph embeddings

    Graph embeddings

    When I execute the script with defaults, it performs the learning and graph classification tasks, etc., but how could I explicitly get the graph embeddings for my training/test examples?

    Thanks

    opened by leoguti85 0
  • How to use GAM for edge prediction?

    How to use GAM for edge prediction?

    Hello,

    I am going through your paper implementation and it's really exciting. I am getting good results on the datasets which are already in the repository.

    However, I was thinking of extending this network and using it for link prediction / edge classification / relation prediction tasks.

    So suppose if my dataset looks like this:

    bond_one           bond_two              relations
    
    a_b                 b_c                 a_d_r , a_b , a_r .... etc ( relations can be many )
    c_v                 r_a                 c_r_a , c_b_a , r_r_r..etc
    

    Now if bond_one and bond_two are nodes in a graph and relations are edges, can we use the GAM model for solving this kind of problem?

    Or do you have any idea how we can use GAM model for this kind of task.

    Looking forward to hearing from you.

    Thank you! Keep writing, keep sharing :)

    opened by Abhinav43 0