GAT

Graph Attention Networks (Veličković et al., ICLR 2018): https://arxiv.org/abs/1710.10903

[Figure: t-SNE of the learned GAT layer features + attention coefficients on Cora]

Overview

Here we provide the implementation of a Graph Attention Network (GAT) layer in TensorFlow, along with a minimal execution example (on the Cora dataset). The repository is organised as follows:

  • data/ contains the necessary dataset files for Cora;
  • models/ contains the implementation of the GAT network (gat.py);
  • pre_trained/ contains a pre-trained Cora model (achieving 84.4% accuracy on the test set);
  • utils/ contains:
    • an implementation of an attention head, along with an experimental sparse version (layers.py);
    • preprocessing subroutines (process.py);
    • preprocessing utilities for the PPI benchmark (process_ppi.py).

Finally, execute_cora.py puts all of the above together and may be used to execute a full training run on Cora.
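
For orientation, here is a minimal sketch of how these pieces fit together, mirroring the GAT.inference call signature that also appears in the comments below; the Cora dimensions and 0.6 dropout rates are stated for illustration and may not match execute_cora.py exactly:

    import tensorflow as tf
    from models import GAT  # models/gat.py

    nb_nodes, ft_size, nb_classes = 2708, 1433, 7   # Cora dimensions

    # GAT expects batched 3D inputs: [batch, nodes, features].
    ftr_in  = tf.placeholder(tf.float32, shape=(1, nb_nodes, ft_size))
    bias_in = tf.placeholder(tf.float32, shape=(1, nb_nodes, nb_nodes))

    logits = GAT.inference(inputs=ftr_in, nb_classes=nb_classes,
                           nb_nodes=nb_nodes, training=True,
                           attn_drop=0.6, ffd_drop=0.6,
                           bias_mat=bias_in,
                           hid_units=[8], n_heads=[8, 1],
                           residual=False, activation=tf.nn.elu)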

Sparse version

An experimental sparse version is also available, working only when the batch size is equal to 1. The sparse model may be found at models/sp_gat.py.

You may execute a full training run of the sparse model on Cora through execute_cora_sparse.py.
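
The sparse pipeline keeps the adjacency matrix as a SciPy sparse matrix rather than densifying it. As a rough sketch of the kind of conversion involved (not the exact preprocessing in utils/process.py):

    import numpy as np
    import scipy.sparse as sp
    import tensorflow as tf

    adj = sp.random(2708, 2708, density=0.001, format='coo')  # stand-in adjacency
    indices = np.vstack([adj.row, adj.col]).astype(np.int64).T
    adj_tensor = tf.SparseTensor(indices=indices,
                                 values=adj.data.astype(np.float32),
                                 dense_shape=adj.shape)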

Dependencies

The script has been tested running under Python 3.5.2, with the following packages installed (along with their dependencies):

  • numpy==1.14.1
  • scipy==1.0.0
  • networkx==2.1
  • tensorflow-gpu==1.6.0

In addition, CUDA 9.0 and cuDNN 7 have been used.

Reference

If you make use of the GAT model in your research, please cite the following in your manuscript:

@article{
  velickovic2018graph,
  title="{Graph Attention Networks}",
  author={Veli{\v{c}}kovi{\'{c}}, Petar and Cucurull, Guillem and Casanova, Arantxa and Romero, Adriana and Li{\`{o}}, Pietro and Bengio, Yoshua},
  journal={International Conference on Learning Representations},
  year={2018},
  url={https://openreview.net/forum?id=rJXMpikCZ},
  note={accepted as poster},
}

For getting started with GATs, as well as graph representation learning in general, we highly recommend the pytorch-GAT repository by Aleksa Gordić. It ships with an inductive (PPI) example as well.

GAT is a popular method for graph representation learning, with optimised implementations within virtually all standard GRL libraries, such as PyTorch Geometric and DGL.

We recommend using either one of those (depending on your favoured framework), as their implementations have been more readily battle-tested.
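
For instance, a minimal sketch with PyTorch Geometric's GATConv (check the PyG documentation for the exact API of your installed version):

    import torch
    from torch_geometric.nn import GATConv

    conv = GATConv(in_channels=1433, out_channels=8, heads=8, dropout=0.6)
    x = torch.randn(2708, 1433)                  # node features
    edge_index = torch.tensor([[0, 1], [1, 0]])  # COO edge list, shape [2, num_edges]
    out = conv(x, edge_index)                    # -> [2708, 8 * 8] (heads concatenated)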

Early on post-release, two unofficial ports of the GAT model (to Keras and to PyTorch; see, e.g., https://github.com/Diego999/pyGAT in the comments below) quickly surfaced. To honour the effort of their developers as early adopters of the GAT layer, we leave pointers to them here.

License

MIT

Comments
  • Extract features from Graph attention network

    I am trying to extract only the features from a graph attention network. I was using a GCN as the feature extractor, and I want to replace it with GAT:

    gc1 = GraphConvolution(input_dim=300, output_dim=1024, name='first_layer')(features_matrix, adj_matrix)
    gc2 = GraphConvolution(input_dim=1024, output_dim=10, name='second_layer')(gc1, adj_matrix)
    

    where the GraphConvolution layer is defined as:

    class GraphConvolution():
        """Basic graph convolution layer for undirected graphs without edge labels."""
        def __init__(self, input_dim, output_dim, name, dropout=0., act=tf.nn.relu):
            self.name = name
            self.vars = {}
            with tf.variable_scope(self.name + '_vars'):
                self.vars['weights'] = weight_variable_glorot(input_dim, output_dim, name='weights')
            self.dropout = dropout
            self.act = act

        def __call__(self, inputs, adj):
            with tf.name_scope(self.name):
                x = inputs
                x = tf.nn.dropout(x, 1 - self.dropout)  # TF1 dropout: second argument is keep_prob
                x = tf.matmul(x, self.vars['weights'])  # X W
                x = tf.matmul(adj, x)                   # A (X W)
                outputs = self.act(x)
            return outputs
    

    Now, to replace the GCN layer with GAT, I tried this:

    from gat import GAT
    
    # GAT expects 3D input: [batch, nodes, features]
    features_matrix = tf.expand_dims(features_matrix, axis=0)
    adj_matrix      = tf.expand_dims(adj_matrix, axis=0)
    
    gat_logits = GAT.inference( inputs = features_matrix, 
                                     nb_classes  = 10, 
                                     nb_nodes    = 22, 
                                     training    = True,
                                     attn_drop   = 0.0, 
                                     ffd_drop    = 0.0,
                                     bias_mat    = adj_matrix,
                                     hid_units   = [8], 
                                     n_heads     = [8, 1],
                                     residual    = False, 
                                     activation  = tf.nn.elu)
    

    Now I want to use just the logits from GAT as features, and they should be learned as well, so I set training = True.

    With GCN features I was getting around 90% accuracy, but with GAT features I cannot get above 80%, whereas it should increase the accuracy compared to GCN.

    Is there anything I am missing in the network, or are my hyperparameters simply not comparable to the ones I was using with GCN?

    @PetarV- @gcucurull Can you suggest how I can extract features from GAT, and, if I am doing it the correct way, why I am not getting good accuracy?

    Thank you

    opened by monk1337 13
  • What is the difference between transductive and inductive in GNNs?

    It seems that for a GNN (graph neural network) in the transductive setting, we input the whole graph, mask the labels of the validation data, and predict the labels for that validation data.

    But it seems that in the inductive setting we also input the whole graph (just sampled into batches), mask the labels of the validation data, and predict them in the same way.

    Thank you very much. @PetarV-

    opened by guotong1988 5
  • Some questions about GAT

    Hello, recently when I read the GAT paper again I found something confusing, and I hope you can help. The coefficient $\alpha_{i,j}$ is decided by the features of nodes i and j, learned under supervision from the training-set nodes with labels. But what if the two endpoints of an edge belong to the training set and the test set, respectively? Theoretically, a node without a label cannot be updated by gradient descent. How does GAT work in this case? Thanks a lot!

    opened by junkangwu 4
  • Is the attention used in the code the same as the one in the paper?

    In the function attn_head() (in utils/layers.py) I find:

    # simplest self-attention possible
    f_1 = tf.layers.conv1d(seq_fts, 1, 1)
    f_2 = tf.layers.conv1d(seq_fts, 1, 1)
    logits = f_1 + tf.transpose(f_2, [0, 2, 1])
    coefs = tf.nn.softmax(tf.nn.leaky_relu(logits) + bias_mat)

    In my understanding, this code computes $f_1 W_1 + f_2 W_2$, but in the paper the chosen attention mechanism uses concatenation, with $W_1 = W_2 = W$. Did I get something wrong?
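
    (A quick numerical check of the decomposition this question hinges on, as a minimal NumPy sketch; the names a1, a2 and Wh are illustrative, not from the repository. Splitting the attention vector a = [a1; a2] turns the concatenation form into a sum of two per-node terms, which is what the two conv1d calls compute over the shared seq_fts:)

    import numpy as np

    rng = np.random.default_rng(0)
    Wh = rng.normal(size=(5, 8))   # shared transformed features W h (cf. seq_fts)
    a1 = rng.normal(size=8)        # first half of the attention vector a
    a2 = rng.normal(size=8)        # second half of the attention vector a

    # Paper's form: e_ij = a^T [W h_i || W h_j]
    a = np.concatenate([a1, a2])
    e_concat = np.array([[a @ np.concatenate([Wh[i], Wh[j]])
                          for j in range(5)] for i in range(5)])

    # Code's form: e_ij = f_1[i] + f_2[j]
    f1, f2 = Wh @ a1, Wh @ a2
    e_sum = f1[:, None] + f2[None, :]

    assert np.allclose(e_concat, e_sum)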

    opened by hapoyige 4
  • Confusion about some of the code

    Hi, thank you for sharing the code! I find the code below (in the sp_attn_head function, in utils/layers.py) a bit hard to understand, since it does not correspond directly to the description in the paper:

    f_1 = tf.layers.conv1d(seq_fts, 1, 1)
    f_2 = tf.layers.conv1d(seq_fts, 1, 1)
    
    f_1 = tf.reshape(f_1, (nb_nodes, 1))
    f_2 = tf.reshape(f_2, (nb_nodes, 1))
    
    f_1 = adj_mat*f_1
    f_2 = adj_mat * tf.transpose(f_2, [1,0])
    
    logits = tf.sparse_add(f_1, f_2)
    
    lrelu = tf.SparseTensor(indices=logits.indices,
                            values=tf.nn.leaky_relu(logits.values),
                            dense_shape=logits.dense_shape)
    
    coefs = tf.sparse_softmax(lrelu)
    

    seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False) can be regarded as multiplying by a weight matrix W (converting F features to F’ features). However, I do not understand the following steps:

    f_1 = tf.layers.conv1d(seq_fts, 1, 1)
    f_2 = tf.layers.conv1d(seq_fts, 1, 1)
    f_1 = adj_mat*f_1
    f_2 = adj_mat * tf.transpose(f_2, [1,0])
    logits = tf.sparse_add(f_1, f_2)
    

    I am aware that the code is different since it operates at the “matrix” level instead of the node level. But I cannot see how the attention mechanism is achieved by these steps. Can you help explain a bit?
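
    (For reference, a minimal NumPy sketch of the pattern being asked about, assuming adj_mat is a 0/1 adjacency mask: the sparse construction keeps e_ij = f_1[i] + f_2[j] only on existing edges, which is exactly what the dense version achieves with its additive bias mask:)

    import numpy as np

    rng = np.random.default_rng(0)
    n = 4
    adj = (rng.random((n, n)) < 0.5).astype(float)   # stand-in 0/1 adjacency
    f1 = rng.normal(size=(n, 1))                     # per-node term for targets
    f2 = rng.normal(size=(n, 1))                     # per-node term for sources

    # Sparse-style construction: mask each term, then add (cf. tf.sparse_add).
    logits_sparse = adj * f1 + adj * f2.T

    # Dense-style construction: broadcast the sum, then mask.
    logits_dense = (f1 + f2.T) * adj

    assert np.allclose(logits_sparse, logits_dense)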

    Many thanks, Weixin.

    opened by DexterZeng 4
  • PyTorch implementation

    Hi,

    I was really interested in using your model, but unfortunately it is in TensorFlow (or Keras, for the unofficial implementation). I propose here a PyTorch version (https://github.com/Diego999/pyGAT), with which I obtained accuracies between 83.6% and 84.6% on transductive learning for the Cora task, in case you would like to add it to your README (I didn't want to open a pull request just for this).

    Best,

    opened by Diego999 4
  • Multi-graph node classification

    @gcucurull @PetarV- Hello, I have 20 objects, that is, there are 20 graphs. I divide the dataset according to graphs, but GAT seems to be able to handle only one graph, and the batch size can only be set to 1. What can I do to perform multi-graph node classification?

    opened by tanjia123456 3
  • How to understand “dropping all structural information”

    Hi,

    I'm trying to study GAT, and awesome works!

    I'm wondering how to understand "without ... depending on knowing the graph structure upfront", and "dropping all structural information".

    I also see "we only compute e_ij for nodes j ∈ N_i, where N_i is some neighborhood of node i in the graph" in the paper, and I also see the function adj_to_bias() in the code, which requires knowing the adjacency matrix (the edges).

    So my understanding is that we do need to know the graph structure (edge information) at the beginning. Thanks.
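
    (For context, a minimal sketch of what an adjacency-derived attention mask looks like, assuming the usual construction of zeros for edges and self-loops and a large negative constant elsewhere, so the softmax assigns near-zero attention outside each neighbourhood. This mirrors the + bias_mat term in the attn_head() snippet quoted above; the exact adj_to_bias() code may differ:)

    import numpy as np

    def adjacency_to_bias(adj, big_negative=-1e9):
        """0 where an edge (or self-loop) exists, a large negative value elsewhere."""
        mask = adj + np.eye(adj.shape[0])   # allow each node to attend to itself
        return np.where(mask > 0, 0.0, big_negative)

    adj = np.array([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
    print(adjacency_to_bias(adj))   # 0 on edges/diagonal, -1e9 elsewhere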

    opened by guoyejun 3
  • Question about Pubmed dataset results

    Hello, I am working on GAT. It's an excellent idea to introduce attention to graphs.

    But I have some questions about the results on the Pubmed dataset, for which you reported "79.0 ± 0.3%".

    In the paper, it says:

    we found that Pubmed’s training set size (60 examples) required slight changes to the GAT architecture: we have applied K = 8 output attention heads (instead of one), and strengthened the L2 regularization to λ = 0.001. Otherwise, the architecture matches the one used for Cora and Citeseer.
    ...
    
    Both models are initialized using Glorot initialization (Glorot & Bengio, 2010) and trained to minimize cross-entropy on the training nodes using the Adam SGD optimizer (Kingma & Ba, 2014) with an initial learning rate of 0.01 for Pubmed, and 0.005 for all other datasets.
    

    Since you said "we use 500 additional nodes for validation purposes (the same ones as used by Kipf & Welling (2017))", I copy the data from gcn.

    And I changed the hyperparameters to:

    dataset = 'pubmed'
    lr = 0.01  # learning rate
    l2_coef = 0.001  # weight decay
    hid_units = [8] # numbers of hidden units per each attention head in each layer
    n_heads = [8, 8] # additional entry for the output layer
    

    and ran it 100 times (here is the code difference). I get 77.7 ± 0.8%, which does not pass a t-test:

    (79.0 - 77.7) / (0.8 / sqrt(100)) = 16.25 > 1.984  (critical t value at significance 0.05)
    

    Is there anything wrong with the experiment parameters? Can you help me reproduce the Pubmed dataset results?

    Thank you

    opened by hotchilipowder 3
  • Question about your sparse implementation

    Hi GAT folks,

    Awesome work! I tried out your sparse implementation. What I did was: (1) replace the attn_head function with sp_attn_head; (2) use a sparse adjacency matrix: I removed the adj.todense() and process.adj_to_bias function calls and replaced them with sparse versions using scipy.

    I found I cannot achieve the expected accuracy with the sparse implementation on Cora. Here is the log:

    Training: loss = 1.94517, acc = 0.15000 | Val: loss = 1.94614, acc = 0.20600 | Time: nan (graph/s)
    Training: loss = 1.93820, acc = 0.20714 | Val: loss = 1.94221, acc = 0.25000 | Time: nan (graph/s)
    Training: loss = 1.92800, acc = 0.28571 | Val: loss = 1.93924, acc = 0.26600 | Time: nan (graph/s)
    Training: loss = 1.91733, acc = 0.23571 | Val: loss = 1.93592, acc = 0.24400 | Time: nan (graph/s)
    Training: loss = 1.90697, acc = 0.30714 | Val: loss = 1.93212, acc = 0.21600 | Time: nan (graph/s)
    Training: loss = 1.89323, acc = 0.30714 | Val: loss = 1.92840, acc = 0.17800 | Time: 17.03976 (graph/s)
    Training: loss = 1.88492, acc = 0.26429 | Val: loss = 1.92444, acc = 0.18400 | Time: 17.02856 (graph/s)
    Training: loss = 1.87695, acc = 0.35000 | Val: loss = 1.91930, acc = 0.18000 | Time: 17.09180 (graph/s)
    Training: loss = 1.86555, acc = 0.28571 | Val: loss = 1.91408, acc = 0.18200 | Time: 17.05338 (graph/s)
    Training: loss = 1.84816, acc = 0.29286 | Val: loss = 1.90931, acc = 0.17800 | Time: 17.03552 (graph/s)
    Training: loss = 1.85963, acc = 0.25000 | Val: loss = 1.90346, acc = 0.18800 | Time: 16.99722 (graph/s)
    Training: loss = 1.86400, acc = 0.21429 | Val: loss = 1.89660, acc = 0.20400 | Time: 17.03414 (graph/s)
    Training: loss = 1.82015, acc = 0.32143 | Val: loss = 1.89007, acc = 0.22400 | Time: 16.99219 (graph/s)
    Training: loss = 1.80568, acc = 0.39286 | Val: loss = 1.88342, acc = 0.25800 | Time: 17.00136 (graph/s)
    Training: loss = 1.80814, acc = 0.35714 | Val: loss = 1.87724, acc = 0.27600 | Time: 16.99366 (graph/s)
    Training: loss = 1.80206, acc = 0.38571 | Val: loss = 1.87083, acc = 0.32000 | Time: 17.00094 (graph/s)
    Training: loss = 1.77896, acc = 0.41429 | Val: loss = 1.86466, acc = 0.34600 | Time: 17.05043 (graph/s)
    Training: loss = 1.76743, acc = 0.40714 | Val: loss = 1.85916, acc = 0.38600 | Time: 17.03179 (graph/s)
    Training: loss = 1.76884, acc = 0.38571 | Val: loss = 1.85294, acc = 0.45000 | Time: 17.02546 (graph/s)
    Training: loss = 1.76213, acc = 0.50000 | Val: loss = 1.84764, acc = 0.48800 | Time: 16.93918 (graph/s)
    Training: loss = 1.76706, acc = 0.45000 | Val: loss = 1.84279, acc = 0.52800 | Time: 16.97079 (graph/s)
    Training: loss = 1.75194, acc = 0.47143 | Val: loss = 1.83775, acc = 0.54800 | Time: 16.96184 (graph/s)
    Training: loss = 1.69834, acc = 0.55000 | Val: loss = 1.83297, acc = 0.56800 | Time: 16.95823 (graph/s)
    Training: loss = 1.71937, acc = 0.52857 | Val: loss = 1.82791, acc = 0.58200 | Time: 16.96505 (graph/s)
    Training: loss = 1.71782, acc = 0.50714 | Val: loss = 1.82320, acc = 0.58600 | Time: 16.96362 (graph/s)
    Training: loss = 1.69105, acc = 0.56429 | Val: loss = 1.81782, acc = 0.60000 | Time: 16.95437 (graph/s)
    Training: loss = 1.67340, acc = 0.60000 | Val: loss = 1.81338, acc = 0.59200 | Time: 16.78093 (graph/s)
    Training: loss = 1.70836, acc = 0.55000 | Val: loss = 1.80828, acc = 0.58800 | Time: 16.76095 (graph/s)
    Training: loss = 1.71126, acc = 0.56429 | Val: loss = 1.80237, acc = 0.58800 | Time: 16.78002 (graph/s)
    Training: loss = 1.66770, acc = 0.62857 | Val: loss = 1.79581, acc = 0.60400 | Time: 16.75709 (graph/s)
    Training: loss = 1.63414, acc = 0.62143 | Val: loss = 1.78960, acc = 0.60200 | Time: 16.76527 (graph/s)
    Training: loss = 1.64903, acc = 0.59286 | Val: loss = 1.78325, acc = 0.61000 | Time: 16.71849 (graph/s)
    Training: loss = 1.62723, acc = 0.55714 | Val: loss = 1.77702, acc = 0.61000 | Time: 16.72862 (graph/s)
    Training: loss = 1.65727, acc = 0.55714 | Val: loss = 1.77083, acc = 0.63000 | Time: 16.74734 (graph/s)
    Training: loss = 1.60982, acc = 0.61429 | Val: loss = 1.76514, acc = 0.60600 | Time: 16.73516 (graph/s)
    Training: loss = 1.56368, acc = 0.59286 | Val: loss = 1.75920, acc = 0.60800 | Time: 16.74822 (graph/s)
    Training: loss = 1.59711, acc = 0.59286 | Val: loss = 1.75289, acc = 0.58600 | Time: 16.73207 (graph/s)
    Training: loss = 1.56515, acc = 0.60714 | Val: loss = 1.74649, acc = 0.58600 | Time: 16.74241 (graph/s)
    Training: loss = 1.60043, acc = 0.55000 | Val: loss = 1.74041, acc = 0.58200 | Time: 16.73666 (graph/s)
    Training: loss = 1.57450, acc = 0.62143 | Val: loss = 1.73533, acc = 0.57800 | Time: 16.73286 (graph/s)
    Training: loss = 1.57213, acc = 0.57857 | Val: loss = 1.73015, acc = 0.56600 | Time: 16.73614 (graph/s)
    Training: loss = 1.55529, acc = 0.56429 | Val: loss = 1.72659, acc = 0.56600 | Time: 16.73717 (graph/s)
    Training: loss = 1.55898, acc = 0.55714 | Val: loss = 1.72352, acc = 0.56200 | Time: 16.73034 (graph/s)
    Training: loss = 1.55415, acc = 0.55000 | Val: loss = 1.72040, acc = 0.55200 | Time: 16.72506 (graph/s)
    Training: loss = 1.55050, acc = 0.52143 | Val: loss = 1.71850, acc = 0.53800 | Time: 16.73672 (graph/s)
    Training: loss = 1.47474, acc = 0.63571 | Val: loss = 1.71621, acc = 0.51600 | Time: 16.74091 (graph/s)
    Training: loss = 1.56495, acc = 0.50714 | Val: loss = 1.71540, acc = 0.49600 | Time: 16.73061 (graph/s)
    Training: loss = 1.51994, acc = 0.55714 | Val: loss = 1.71458, acc = 0.47800 | Time: 16.74436 (graph/s)
    Training: loss = 1.54271, acc = 0.52143 | Val: loss = 1.71304, acc = 0.46600 | Time: 16.75906 (graph/s)
    Training: loss = 1.58519, acc = 0.45000 | Val: loss = 1.71244, acc = 0.45000 | Time: 16.76833 (graph/s)
    Training: loss = 1.57245, acc = 0.50714 | Val: loss = 1.71141, acc = 0.44200 | Time: 16.76075 (graph/s)
    Training: loss = 1.62070, acc = 0.47857 | Val: loss = 1.70944, acc = 0.44400 | Time: 16.76779 (graph/s)
    Training: loss = 1.63155, acc = 0.50714 | Val: loss = 1.70797, acc = 0.47200 | Time: 16.76185 (graph/s)
    Training: loss = 1.56914, acc = 0.47143 | Val: loss = 1.70734, acc = 0.49400 | Time: 16.76988 (graph/s)
    Training: loss = 1.53856, acc = 0.52143 | Val: loss = 1.70701, acc = 0.48600 | Time: 16.73263 (graph/s)
    Training: loss = 1.46632, acc = 0.58571 | Val: loss = 1.70669, acc = 0.49000 | Time: 16.73773 (graph/s)
    Training: loss = 1.45926, acc = 0.60714 | Val: loss = 1.70659, acc = 0.50800 | Time: 16.75432 (graph/s)
    Training: loss = 1.48121, acc = 0.57857 | Val: loss = 1.70474, acc = 0.50800 | Time: 16.76423 (graph/s)
    Training: loss = 1.50514, acc = 0.55714 | Val: loss = 1.70184, acc = 0.50000 | Time: 16.77102 (graph/s)
    Training: loss = 1.50490, acc = 0.50000 | Val: loss = 1.69951, acc = 0.49400 | Time: 16.78197 (graph/s)
    Training: loss = 1.51039, acc = 0.53571 | Val: loss = 1.69778, acc = 0.49400 | Time: 16.78962 (graph/s)
    Training: loss = 1.45282, acc = 0.57143 | Val: loss = 1.69572, acc = 0.49400 | Time: 16.79195 (graph/s)
    Training: loss = 1.43123, acc = 0.55000 | Val: loss = 1.69297, acc = 0.49800 | Time: 16.79870 (graph/s)
    Training: loss = 1.51627, acc = 0.49286 | Val: loss = 1.68947, acc = 0.50000 | Time: 16.77023 (graph/s)
    Training: loss = 1.46445, acc = 0.53571 | Val: loss = 1.68655, acc = 0.50200 | Time: 16.77020 (graph/s)
    Training: loss = 1.49241, acc = 0.49286 | Val: loss = 1.68367, acc = 0.51000 | Time: 16.78139 (graph/s)
    Training: loss = 1.52911, acc = 0.47857 | Val: loss = 1.68273, acc = 0.52000 | Time: 16.79249 (graph/s)
    Training: loss = 1.48992, acc = 0.57143 | Val: loss = 1.68100, acc = 0.53000 | Time: 16.78402 (graph/s)
    Training: loss = 1.43546, acc = 0.57143 | Val: loss = 1.67985, acc = 0.53600 | Time: 16.72668 (graph/s)
    Training: loss = 1.48215, acc = 0.52857 | Val: loss = 1.67853, acc = 0.54000 | Time: 16.72524 (graph/s)
    Training: loss = 1.47648, acc = 0.55000 | Val: loss = 1.67829, acc = 0.53800 | Time: 16.72739 (graph/s)
    Training: loss = 1.44751, acc = 0.58571 | Val: loss = 1.67800, acc = 0.54000 | Time: 16.73214 (graph/s)
    Training: loss = 1.40865, acc = 0.56429 | Val: loss = 1.67713, acc = 0.54000 | Time: 16.73742 (graph/s)
    Training: loss = 1.47875, acc = 0.50714 | Val: loss = 1.67518, acc = 0.54000 | Time: 16.73667 (graph/s)
    Training: loss = 1.40626, acc = 0.52143 | Val: loss = 1.67420, acc = 0.53200 | Time: 16.73576 (graph/s)
    Training: loss = 1.46455, acc = 0.49286 | Val: loss = 1.67260, acc = 0.54000 | Time: 16.69977 (graph/s)
    Training: loss = 1.42937, acc = 0.55000 | Val: loss = 1.66953, acc = 0.54200 | Time: 16.70049 (graph/s)
    Training: loss = 1.44192, acc = 0.55000 | Val: loss = 1.66651, acc = 0.53800 | Time: 16.70967 (graph/s)
    Training: loss = 1.44210, acc = 0.55714 | Val: loss = 1.66280, acc = 0.53400 | Time: 16.72562 (graph/s)
    Training: loss = 1.36144, acc = 0.61429 | Val: loss = 1.65898, acc = 0.53000 | Time: 16.72490 (graph/s)
    Training: loss = 1.51469, acc = 0.53571 | Val: loss = 1.65483, acc = 0.52400 | Time: 16.71665 (graph/s)
    Training: loss = 1.41710, acc = 0.55000 | Val: loss = 1.65153, acc = 0.52400 | Time: 16.71838 (graph/s)
    Training: loss = 1.42846, acc = 0.56429 | Val: loss = 1.64860, acc = 0.51600 | Time: 16.71081 (graph/s)
    Training: loss = 1.48258, acc = 0.47143 | Val: loss = 1.64704, acc = 0.50600 | Time: 16.71550 (graph/s)
    Training: loss = 1.39769, acc = 0.59286 | Val: loss = 1.64604, acc = 0.50200 | Time: 16.72110 (graph/s)
    Training: loss = 1.41342, acc = 0.58571 | Val: loss = 1.64720, acc = 0.49800 | Time: 16.72464 (graph/s)
    Training: loss = 1.37405, acc = 0.55000 | Val: loss = 1.64852, acc = 0.48800 | Time: 16.73652 (graph/s)
    Training: loss = 1.36246, acc = 0.56429 | Val: loss = 1.64919, acc = 0.48400 | Time: 16.73602 (graph/s)
    Training: loss = 1.35721, acc = 0.57857 | Val: loss = 1.65088, acc = 0.48200 | Time: 16.73687 (graph/s)
    Training: loss = 1.46561, acc = 0.52857 | Val: loss = 1.65400, acc = 0.48000 | Time: 16.70509 (graph/s)
    Training: loss = 1.41449, acc = 0.53571 | Val: loss = 1.65670, acc = 0.48200 | Time: 16.70738 (graph/s)
    Training: loss = 1.46798, acc = 0.48571 | Val: loss = 1.65805, acc = 0.47800 | Time: 16.71920 (graph/s)
    Training: loss = 1.39453, acc = 0.57143 | Val: loss = 1.66067, acc = 0.47400 | Time: 16.71041 (graph/s)
    Training: loss = 1.40467, acc = 0.53571 | Val: loss = 1.66297, acc = 0.46800 | Time: 16.70864 (graph/s)
    Training: loss = 1.41027, acc = 0.49286 | Val: loss = 1.66376, acc = 0.47000 | Time: 16.70883 (graph/s)
    Training: loss = 1.46268, acc = 0.50000 | Val: loss = 1.66714, acc = 0.46800 | Time: 16.70941 (graph/s)
    Training: loss = 1.41257, acc = 0.55714 | Val: loss = 1.66891, acc = 0.47200 | Time: 16.71804 (graph/s)
    Training: loss = 1.40972, acc = 0.50714 | Val: loss = 1.66983, acc = 0.47400 | Time: 16.72160 (graph/s)
    Training: loss = 1.42352, acc = 0.53571 | Val: loss = 1.67032, acc = 0.47200 | Time: 16.71816 (graph/s)
    Training: loss = 1.38608, acc = 0.53571 | Val: loss = 1.67717, acc = 0.46000 | Time: 16.72157 (graph/s)
    Training: loss = 1.43882, acc = 0.52857 | Val: loss = 1.68334, acc = 0.45200 | Time: 16.72999 (graph/s)
    Training: loss = 1.49979, acc = 0.49286 | Val: loss = 1.69646, acc = 0.43600 | Time: 16.73467 (graph/s)
    Training: loss = 1.51738, acc = 0.50000 | Val: loss = 1.71791, acc = 0.42800 | Time: 16.74103 (graph/s)
    Training: loss = 1.41980, acc = 0.54286 | Val: loss = 1.73596, acc = 0.41200 | Time: 16.74587 (graph/s)
    Training: loss = 1.49115, acc = 0.45000 | Val: loss = 1.75199, acc = 0.38800 | Time: 16.74358 (graph/s)
    Training: loss = 1.50618, acc = 0.49286 | Val: loss = 1.77033, acc = 0.38400 | Time: 16.72982 (graph/s)
    Training: loss = 1.54672, acc = 0.52143 | Val: loss = 1.80930, acc = 0.37800 | Time: 16.72799 (graph/s)
    Training: loss = 1.67387, acc = 0.47143 | Val: loss = 1.86923, acc = 0.35000 | Time: 16.73573 (graph/s)
    Training: loss = 1.69535, acc = 0.45714 | Val: loss = 1.94518, acc = 0.32400 | Time: 16.73458 (graph/s)
    Training: loss = 1.81820, acc = 0.41429 | Val: loss = 2.03584, acc = 0.31200 | Time: 16.73830 (graph/s)
    Training: loss = 1.93426, acc = 0.29286 | Val: loss = 2.14381, acc = 0.29800 | Time: 16.73484 (graph/s)
    Training: loss = 1.97562, acc = 0.27143 | Val: loss = 2.27155, acc = 0.28400 | Time: 16.73221 (graph/s)
    Training: loss = 2.17724, acc = 0.30714 | Val: loss = 2.42012, acc = 0.26800 | Time: 16.71115 (graph/s)
    Training: loss = 2.60620, acc = 0.28571 | Val: loss = 2.56814, acc = 0.24800 | Time: 16.71131 (graph/s)
    Training: loss = 3.08253, acc = 0.28571 | Val: loss = 2.75077, acc = 0.21800 | Time: 16.71770 (graph/s)
    Training: loss = 3.59296, acc = 0.23571 | Val: loss = 2.96115, acc = 0.21000 | Time: 16.71862 (graph/s)
    Training: loss = 4.29261, acc = 0.27857 | Val: loss = 3.19888, acc = 0.20200 | Time: 16.71416 (graph/s)
    Training: loss = 3.61426, acc = 0.19286 | Val: loss = 3.46225, acc = 0.19200 | Time: 16.71198 (graph/s)
    Training: loss = 4.62925, acc = 0.17857 | Val: loss = 3.75955, acc = 0.18400 | Time: 16.71600 (graph/s)
    Training: loss = 5.59345, acc = 0.20714 | Val: loss = 4.08395, acc = 0.17800 | Time: 16.71734 (graph/s)
    Training: loss = 6.05617, acc = 0.17143 | Val: loss = 4.44487, acc = 0.17400 | Time: 16.71906 (graph/s)
    Training: loss = 5.78760, acc = 0.18571 | Val: loss = 4.83142, acc = 0.16800 | Time: 16.71927 (graph/s)
    Training: loss = 6.78114, acc = 0.20000 | Val: loss = 5.25684, acc = 0.16400 | Time: 16.72466 (graph/s)
    Training: loss = 8.34020, acc = 0.18571 | Val: loss = 5.70460, acc = 0.16400 | Time: 16.73279 (graph/s)
    Training: loss = 8.18466, acc = 0.15714 | Val: loss = 6.18899, acc = 0.16400 | Time: 16.73642 (graph/s)
    Training: loss = 10.17127, acc = 0.20000 | Val: loss = 6.70358, acc = 0.16200 | Time: 16.72992 (graph/s)
    Training: loss = 8.01908, acc = 0.17143 | Val: loss = 7.24232, acc = 0.16400 | Time: 16.73301 (graph/s)
    Training: loss = 10.54279, acc = 0.15714 | Val: loss = 7.80857, acc = 0.12200 | Time: 16.72627 (graph/s)
    Training: loss = 11.20833, acc = 0.14286 | Val: loss = 8.40816, acc = 0.12200 | Time: 16.71242 (graph/s)
    Training: loss = 13.13751, acc = 0.14286 | Val: loss = 9.03840, acc = 0.12200 | Time: 16.71696 (graph/s)
    Training: loss = 15.50717, acc = 0.12857 | Val: loss = 9.70438, acc = 0.12600 | Time: 16.72127 (graph/s)
    Training: loss = 14.83573, acc = 0.14286 | Val: loss = 10.38955, acc = 0.12400 | Time: 16.72117 (graph/s)
    Training: loss = 16.32450, acc = 0.14286 | Val: loss = 11.10900, acc = 0.12800 | Time: 16.71955 (graph/s)
    Training: loss = 20.88581, acc = 0.13571 | Val: loss = 11.85103, acc = 0.13000 | Time: 16.72041 (graph/s)
    Training: loss = 18.14150, acc = 0.15714 | Val: loss = 12.62478, acc = 0.12800 | Time: 16.71870 (graph/s)
    Training: loss = 20.19287, acc = 0.12857 | Val: loss = 13.42502, acc = 0.12600 | Time: 16.71075 (graph/s)
    Training: loss = 20.17040, acc = 0.13571 | Val: loss = 14.25172, acc = 0.12800 | Time: 16.71072 (graph/s)
    Training: loss = 21.68529, acc = 0.13571 | Val: loss = 15.11031, acc = 0.13000 | Time: 16.71549 (graph/s)
    Training: loss = 24.35412, acc = 0.12857 | Val: loss = 15.99096, acc = 0.13000 | Time: 16.70976 (graph/s)
    Training: loss = 26.90386, acc = 0.09286 | Val: loss = 16.90097, acc = 0.12800 | Time: 16.71122 (graph/s)
    Training: loss = 32.60686, acc = 0.12857 | Val: loss = 17.83447, acc = 0.12400 | Time: 16.71450 (graph/s)
    Training: loss = 23.71290, acc = 0.10714 | Val: loss = 18.79685, acc = 0.12200 | Time: 16.71559 (graph/s)
    Training: loss = 32.05572, acc = 0.12143 | Val: loss = 19.79207, acc = 0.12400 | Time: 16.71759 (graph/s)
    Training: loss = 37.51984, acc = 0.17857 | Val: loss = 20.82151, acc = 0.12600 | Time: 16.71863 (graph/s)
    Training: loss = 43.44203, acc = 0.10714 | Val: loss = 21.87587, acc = 0.12400 | Time: 16.71384 (graph/s)
    Training: loss = 41.57386, acc = 0.12857 | Val: loss = 22.96205, acc = 0.12400 | Time: 16.71532 (graph/s)
    Training: loss = 32.44603, acc = 0.12857 | Val: loss = 24.04948, acc = 0.12800 | Time: 16.72206 (graph/s)
    Training: loss = 32.47784, acc = 0.15714 | Val: loss = 25.16518, acc = 0.12400 | Time: 16.72132 (graph/s)
    Training: loss = 57.02023, acc = 0.14286 | Val: loss = 26.32536, acc = 0.08800 | Time: 16.72422 (graph/s)
    Training: loss = 47.02181, acc = 0.13571 | Val: loss = 27.52316, acc = 0.09000 | Time: 16.72099 (graph/s)
    Training: loss = 60.71332, acc = 0.12857 | Val: loss = 28.76072, acc = 0.09200 | Time: 16.71417 (graph/s)
    Training: loss = 52.61975, acc = 0.13571 | Val: loss = 30.03493, acc = 0.09000 | Time: 16.70092 (graph/s)
    Training: loss = 55.17526, acc = 0.14286 | Val: loss = 31.34324, acc = 0.09200 | Time: 16.70220 (graph/s)
    Training: loss = 72.76334, acc = 0.12143 | Val: loss = 32.69006, acc = 0.09200 | Time: 16.70329 (graph/s)
    Training: loss = 42.62173, acc = 0.17143 | Val: loss = 34.06517, acc = 0.09200 | Time: 16.70729 (graph/s)
    Training: loss = 68.27650, acc = 0.15714 | Val: loss = 35.47098, acc = 0.09200 | Time: 16.70868 (graph/s)
    Training: loss = 53.12449, acc = 0.14286 | Val: loss = 36.86567, acc = 0.09200 | Time: 16.70579 (graph/s)
    Training: loss = 75.30608, acc = 0.14286 | Val: loss = 38.30105, acc = 0.09200 | Time: 16.70677 (graph/s)
    Training: loss = 66.39566, acc = 0.13571 | Val: loss = 39.77269, acc = 0.09200 | Time: 16.70885 (graph/s)
    Training: loss = 90.00805, acc = 0.15000 | Val: loss = 41.29017, acc = 0.09000 | Time: 16.70979 (graph/s)
    Training: loss = 74.48537, acc = 0.15714 | Val: loss = 42.83427, acc = 0.09000 | Time: 16.70017 (graph/s)
    Training: loss = 83.59474, acc = 0.14286 | Val: loss = 44.40656, acc = 0.09000 | Time: 16.70253 (graph/s)
    Training: loss = 100.54999, acc = 0.12143 | Val: loss = 46.01003, acc = 0.09000 | Time: 16.70330 (graph/s)
    Training: loss = 78.89310, acc = 0.15000 | Val: loss = 47.63050, acc = 0.09000 | Time: 16.70234 (graph/s)
    Training: loss = 84.41219, acc = 0.16429 | Val: loss = 49.28119, acc = 0.09000 | Time: 16.70904 (graph/s)
    Training: loss = 88.26729, acc = 0.17857 | Val: loss = 50.97437, acc = 0.08800 | Time: 16.70679 (graph/s)
    Training: loss = 65.11741, acc = 0.14286 | Val: loss = 52.67675, acc = 0.08800 | Time: 16.70200 (graph/s)
    Training: loss = 85.14694, acc = 0.18571 | Val: loss = 54.41297, acc = 0.08800 | Time: 16.69457 (graph/s)
    Training: loss = 97.87077, acc = 0.15714 | Val: loss = 56.18945, acc = 0.08800 | Time: 16.69361 (graph/s)
    Training: loss = 98.05456, acc = 0.13571 | Val: loss = 57.99624, acc = 0.08800 | Time: 16.68741 (graph/s)
    Training: loss = 107.37295, acc = 0.13571 | Val: loss = 59.82062, acc = 0.08800 | Time: 16.68968 (graph/s)
    Training: loss = 114.99556, acc = 0.14286 | Val: loss = 61.67920, acc = 0.08800 | Time: 16.68897 (graph/s)
    Training: loss = 135.17931, acc = 0.12857 | Val: loss = 63.58669, acc = 0.08800 | Time: 16.68689 (graph/s)
    Training: loss = 109.73069, acc = 0.15000 | Val: loss = 65.50701, acc = 0.08800 | Time: 16.68364 (graph/s)
    Training: loss = 142.71088, acc = 0.11429 | Val: loss = 67.47314, acc = 0.08800 | Time: 16.67883 (graph/s)
    Training: loss = 122.98055, acc = 0.11429 | Val: loss = 69.45295, acc = 0.08800 | Time: 16.67969 (graph/s)
    Training: loss = 122.06322, acc = 0.12857 | Val: loss = 71.47054, acc = 0.08800 | Time: 16.67910 (graph/s)
    Training: loss = 130.83694, acc = 0.12857 | Val: loss = 73.51564, acc = 0.08800 | Time: 16.67657 (graph/s)
    Training: loss = 147.36774, acc = 0.12143 | Val: loss = 75.58735, acc = 0.08800 | Time: 16.67618 (graph/s)
    Training: loss = 128.03278, acc = 0.15714 | Val: loss = 77.68494, acc = 0.08800 | Time: 16.67729 (graph/s)
    Training: loss = 155.61414, acc = 0.13571 | Val: loss = 79.81355, acc = 0.08600 | Time: 16.67960 (graph/s)
    Training: loss = 149.44771, acc = 0.15714 | Val: loss = 81.98518, acc = 0.08600 | Time: 16.67523 (graph/s)
    Training: loss = 143.26579, acc = 0.12143 | Val: loss = 84.19203, acc = 0.08600 | Time: 16.67501 (graph/s)
    Training: loss = 147.99545, acc = 0.14286 | Val: loss = 86.41908, acc = 0.08800 | Time: 16.67924 (graph/s)
    Training: loss = 153.45032, acc = 0.13571 | Val: loss = 88.67501, acc = 0.08800 | Time: 16.68069 (graph/s)
    

    You can see that the accuracy drops suddenly after reaching ~60% on the validation set. Have you met similar problems? Did I miss anything?

    Thank you, Minjie

    opened by jermainewang 3
  • Weighted adjacency graphs

    Weighted adjacency graphs

    Hi,

    I am currently working on a project where we have weighted directed graphs. I know that the current model supports directed graphs, i.e. the neighbourhood is defined by the incoming nodes.

    However, is it possible to use weighted adjacency graphs? My intuition says no, as the edge weights are learned through the attention mechanism. However, in the associated blog there is a footnote stating that GAT can be trivially extended to these cases (https://petar-v.com/GAT/#fn:1).

    So if it is possible, how would one do this?
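
    (One plausible construction, offered purely as an illustration and not taken from the repository or the paper: fold positive edge weights into the pre-softmax bias as log-weights, so that softmax(e_ij + log w_ij) scales each attention coefficient by its edge weight w_ij before renormalisation:)

    import numpy as np

    def weighted_bias(weighted_adj, big_negative=-1e9):
        """log(w_ij) where an edge exists, a large negative value elsewhere."""
        bias = np.full(weighted_adj.shape, big_negative)
        edges = weighted_adj > 0
        bias[edges] = np.log(weighted_adj[edges])
        return bias

    w_adj = np.array([[0.0, 2.0, 0.0],
                      [0.5, 0.0, 1.0],
                      [0.0, 4.0, 0.0]])
    print(weighted_bias(w_adj))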

    • Joakim Haurum
    opened by JoakimHaurum 2
  • How to get node embeddings?

    Hi, I would like to ask how to use GAT to generate node embeddings for downstream tasks, like link prediction. And if the nodes have no features, how should they be initialised?

    I'd appreciate your help!

    opened by linz2000 0
  • Why use undirected graphs for transductive learning?

    Thanks for your great work! Excuse me, I have a question about the transductive experiments.

    According to the paper, GAT can be applied to directed graphs. I wonder why you use the undirected versions of Cora, Citeseer and Pubmed for the experiments rather than the original directed ones? As we all know, citation networks are directed.

    opened by qncsn2016 0
  • Adapting the code to a dataset from another domain

    I would like to adapt this code to a dataset from a different domain.

    The other dataset is a single file with the following shape:

    column1, column2, column3, ... column9, label
    1500,    45,      1,       ...          0
    933,     22,      0,       ...          0
    1234,    30,      1,       ...          1
    1112,    23,      0,       ...          0
    ...

    Please let me know how to adapt this code to the dataset above.

    opened by kimsijin33 0
  • How to batch large datasets to apply GAT?

    Hi, thanks for the great repo.

    I want to apply GAT to a very large dataset and train it on a GPU, so I need to provide my data in small batches. However, the algorithm uses the full graph. I was wondering how I could run GAT with small batches?

    opened by taherhekmatfar 0
  • Need help in creating evaluation metrics

    Hi @PetarV-, I'm new to machine learning, so I humbly request your help in creating evaluation metrics such as precision, recall and F1 score for this model. I was able to train the model on my own dataset. Could you please help me with how to construct the evaluation metrics?
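
    (A minimal sketch of one common way to compute such metrics with scikit-learn, assuming you have test-set predictions; the arrays here are illustrative stand-ins for np.argmax(logits, axis=-1) and the true labels:)

    import numpy as np
    from sklearn.metrics import classification_report

    y_true = np.array([0, 1, 2, 1, 0, 2])   # ground-truth classes
    y_pred = np.array([0, 2, 2, 1, 0, 1])   # stand-in for np.argmax(logits, -1)

    # Per-class and averaged precision, recall and F1.
    print(classification_report(y_true, y_pred, digits=3))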

    opened by Joey0538 0