Graph Attention Networks



Graph Attention Networks (Veličković et al., ICLR 2018):

GAT layer t-SNE + Attention coefficients on Cora


Here we provide the implementation of a Graph Attention Network (GAT) layer in TensorFlow, along with a minimal execution example (on the Cora dataset). The repository is organised as follows:

  • data/ contains the necessary dataset files for Cora;
  • models/ contains the implementation of the GAT network (;
  • pre_trained/ contains a pre-trained Cora model (achieving 84.4% accuracy on the test set);
  • utils/ contains:
    • an implementation of an attention head, along with an experimental sparse version (;
    • preprocessing subroutines (;
    • preprocessing utilities for the PPI benchmark (

Finally, puts all of the above together and may be used to execute a full training run on Cora.

Sparse version

An experimental sparse version is also available, working only when the batch size is equal to 1. The sparse model may be found at models/

You may execute a full training run of the sparse model on Cora through


The script has been tested running under Python 3.5.2, with the following packages installed (along with their dependencies):

  • numpy==1.14.1
  • scipy==1.0.0
  • networkx==2.1
  • tensorflow-gpu==1.6.0

In addition, CUDA 9.0 and cuDNN 7 have been used.


If you make use of the GAT model in your research, please cite the following in your manuscript:

  title="{Graph Attention Networks}",
  author={Veli{\v{c}}kovi{\'{c}}, Petar and Cucurull, Guillem and Casanova, Arantxa and Romero, Adriana and Li{\`{o}}, Pietro and Bengio, Yoshua},
  journal={International Conference on Learning Representations},
  note={accepted as poster},

For getting started with GATs, as well as graph representation learning in general, we highly recommend the pytorch-GAT repository by Aleksa Gordić. It ships with an inductive (PPI) example as well.

GAT is a popular method for graph representation learning, with optimised implementations within virtually all standard GRL libraries:

We recommend using either one of those (depending on your favoured framework), as their implementations have been more readily battle-tested.

Early on post-release, two unofficial ports of the GAT model to various frameworks quickly surfaced. To honour the effort of their developers as early adopters of the GAT layer, we leave pointers to them here.



  • Extract features from Graph attention network


    I am trying to use a graph attention network purely as a feature extractor. I was using a GCN as the feature extractor and want to replace it with GAT:

    gc1 = GraphConvolution(input_dim=300, output_dim=1024, name='first_layer')(features_matrix, adj_matrix)
    gc2 = GraphConvolution(input_dim=1024, output_dim=10, name='second_layer')(gc1, adj_matrix)

    where the GraphConvolution layer is defined as:

    import tensorflow as tf  # TensorFlow 1.x API

    class GraphConvolution():
        """Basic graph convolution layer for undirected graph without edge labels."""
        def __init__(self, input_dim, output_dim, name, dropout=0., act=tf.nn.relu):
            self.name = name
            self.vars = {}
            with tf.variable_scope(self.name + '_vars'):
                # weight_variable_glorot is a Glorot-initialisation helper that is not shown here
                self.vars['weights'] = weight_variable_glorot(input_dim, output_dim, name='weights')
            self.dropout = dropout
            self.act = act

        def __call__(self, inputs, adj):
            with tf.name_scope(self.name):
                x = inputs
                # in TF 1.x the second argument of tf.nn.dropout is keep_prob
                x = tf.nn.dropout(x, 1 - self.dropout)
                x = tf.matmul(x, self.vars['weights'])  # feature transform
                x = tf.matmul(adj, x)                   # aggregate over neighbours
                outputs = self.act(x)
            return outputs

    To replace the GCN layers with GAT, I tried this:

    from gat import GAT
    # GAT expects 3D input of shape [batch, nodes, features]
    features_matrix     = tf.expand_dims(features_matrix, axis = 0)
    adj_matrix          = tf.expand_dims(adj_matrix, axis = 0)
    gat_logits = GAT.inference( inputs = features_matrix, 
                                     nb_classes  = 10, 
                                     nb_nodes    = 22, 
                                     training    = True,
                                     attn_drop   = 0.0, 
                                     ffd_drop    = 0.0,
                                     bias_mat    = adj_matrix,
                                     hid_units   = [8], 
                                     n_heads     = [8, 1],
                                     residual    = False, 
                                     activation  = tf.nn.elu)

    I want to use just the logits from GAT as features, and the network should learn these features too, so I set training = True.

    With GCN features I was getting around 90% accuracy, but with GAT features I cannot get above 80%, although I expected GAT to improve on GCN.

    Is there anything I am missing in the network, or are my hyperparameters not comparable to the ones I was using with GCN?

    @PetarV- @gcucurull Can you suggest how I can extract features from GAT and, if I am doing it the correct way, why I am not getting good accuracy?

    Thank you

    opened by monk1337 13
  • What is difference between transductive and inductive in GNN?


    It seems that in the transductive setting of a GNN (graph neural network), we input the whole graph, mask the labels of the validation data, and predict the labels for the validation data.

    But it seems that in the inductive setting, we also input the whole graph (sampled into batches), mask the labels of the validation data, and predict the labels for the validation data.

    Thank you very much. @PetarV-

    opened by guotong1988 5
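
For the question above, the distinction as it is usually drawn can be sketched with toy data. In the transductive case (as on Cora) there is a single graph whose features and edges are all visible during training, and masks only decide which labels are used. In the inductive case (as in the PPI example mentioned earlier) whole graphs are held out and never seen during training. The node counts, graph names, and the `make_mask` helper are illustrative assumptions, not code from this repository.

```python
def make_mask(indices, size):
    """Return a boolean list with True at the given node indices."""
    chosen = set(indices)
    return [i in chosen for i in range(size)]

# Transductive setting: a single graph with 6 nodes. All features and edges are
# fed to the model; the masks only decide which labels enter the loss or the metric.
n_nodes = 6
train_mask = make_mask([0, 1], n_nodes)
val_mask = make_mask([2, 3], n_nodes)
test_mask = make_mask([4, 5], n_nodes)

# Inductive setting: whole graphs are held out, so the model never sees the
# features or edges of the test graphs while training.
graph_ids = ["g0", "g1", "g2", "g3", "g4"]
train_graphs, val_graphs, test_graphs = graph_ids[:3], graph_ids[3:4], graph_ids[4:]

print(train_mask)
print(test_graphs)
```

In both settings labels are hidden at evaluation time; what differs is whether the evaluated nodes' features and edges were part of the training input.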
  • some question about GAT


    Hello, when I recently re-read the GAT paper, I came across a question that confused me, and I hope you can help. The coefficient $\alpha_{i,j}$ is determined by the features of node i and node j, learned under supervision from the training-set nodes and their labels. But suppose the two endpoints of an edge belong to the training set and the test set respectively. In theory, a node without a label cannot contribute to learning by gradient descent. In that case, how does GAT work? Thanks a lot!

    opened by junkangwu 4
  • Does the attention used in codes the same with the one in paper?


    In function attn_head() (in utils/ I find:

    # simplest self-attention possible
    f_1 = tf.layers.conv1d(seq_fts, 1, 1)
    f_2 = tf.layers.conv1d(seq_fts, 1, 1)
    logits = f_1 + tf.transpose(f_2, [0, 2, 1])
    coefs = tf.nn.softmax(tf.nn.leaky_relu(logits) + bias_mat)

    In my understanding, the code computes $$f_1 W_1 + f_2 W_2$$ whereas in the paper the chosen attention mechanism uses concatenation and $$W_1 = W_2 = W$$. Did I get something wrong?

    opened by hapoyige 4
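
A small numerical check may help with the question above. The sketch assumes the paper's scoring function a^T [W h_i || W h_j] and splits the attention vector a into two halves a1 and a2, so that the score becomes a1^T W h_i + a2^T W h_j. Under that assumption, two width-1 projections of the already-transformed features (the roles played by f_1 and f_2 in the snippet) reproduce the concatenation form, ignoring the bias terms that tf.layers.conv1d adds by default. Plain Python lists stand in for tensors, and all names and values are illustrative.

```python
import random

def dot(u, v):
    """Inner product of two equally long lists."""
    return sum(a * b for a, b in zip(u, v))

random.seed(0)
n_nodes, n_feats = 4, 3

# Rows of W h_i, i.e. the already-transformed node features (stands in for seq_fts).
wh = [[random.uniform(-1, 1) for _ in range(n_feats)] for _ in range(n_nodes)]

# Attention vector a of length 2 * n_feats, split into its two halves.
a = [random.uniform(-1, 1) for _ in range(2 * n_feats)]
a1, a2 = a[:n_feats], a[n_feats:]

# Concatenation form: a^T [W h_i || W h_j]  (list "+" concatenates the two rows)
concat_logits = [[dot(a, wh[i] + wh[j]) for j in range(n_nodes)]
                 for i in range(n_nodes)]

# Split form: one scalar per node from each half, combined by broadcasting.
f_1 = [dot(a1, row) for row in wh]
f_2 = [dot(a2, row) for row in wh]
split_logits = [[f_1[i] + f_2[j] for j in range(n_nodes)]
                for i in range(n_nodes)]

max_gap = max(abs(concat_logits[i][j] - split_logits[i][j])
              for i in range(n_nodes) for j in range(n_nodes))
print(max_gap)
```

The shared matrix W is applied once, before the split, which is why both halves operate on the same transformed features.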
  • Confusion about some codes


    Hi, thank you for sharing the code! I find the code below (in def is a bit hard to understand, since it does not correspond directly to the original description in the text:

    f_1 = tf.layers.conv1d(seq_fts, 1, 1)
    f_2 = tf.layers.conv1d(seq_fts, 1, 1)
    f_1 = tf.reshape(f_1, (nb_nodes, 1))
    f_2 = tf.reshape(f_2, (nb_nodes, 1))
    f_1 = adj_mat*f_1
    f_2 = adj_mat * tf.transpose(f_2, [1,0])
    logits = tf.sparse_add(f_1, f_2)
    lrelu = tf.SparseTensor(indices=logits.indices, 
    coefs = tf.sparse_softmax(lrelu)

    seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False) can be regarded as multiplying by a weight matrix W (converting F features to F’ features). However, I do not understand the following steps:

    f_1 = tf.layers.conv1d(seq_fts, 1, 1)
    f_2 = tf.layers.conv1d(seq_fts, 1, 1)
    f_1 = adj_mat*f_1
    f_2 = adj_mat * tf.transpose(f_2, [1,0])
    logits = tf.sparse_add(f_1, f_2)

    I am aware that the code differs because it operates at the “matrix” level instead of the node level. But I cannot see how these steps implement the attention mechanism. Can you help explain a bit?

    Many thanks, Weixin.

    opened by DexterZeng 4
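
One way to read the sparse steps quoted above is sketched here in plain Python. It assumes the adjacency matrix stores unit weights, so multiplying it element-wise by a broadcast column vector f_1 leaves f_1[i] at every stored entry (i, j), and multiplying it by the transposed f_2 leaves f_2[j] there. Their sparse sum is then f_1[i] + f_2[j] on edges only. The toy graph, the values, and the LeakyReLU slope of 0.2 are illustrative assumptions.

```python
import math

# Toy graph: stored (row, col) entries of a sparse adjacency matrix with unit weights.
edges = [(0, 0), (0, 1), (1, 1), (1, 2), (2, 0), (2, 2)]
f_1 = [0.5, -1.0, 2.0]   # one scalar per node, used for the row index i
f_2 = [0.1, 0.2, 0.3]    # one scalar per node, used for the column index j

# adj_mat * f_1: broadcasting the column vector gives entry (i, j) the value f_1[i]
left = {(i, j): f_1[i] for i, j in edges}
# adj_mat * transpose(f_2): broadcasting the row vector gives entry (i, j) the value f_2[j]
right = {(i, j): f_2[j] for i, j in edges}
# sparse add: logits exist only where an edge exists
logits = {e: left[e] + right[e] for e in edges}

def leaky_relu(x, alpha=0.2):
    return x if x > 0 else alpha * x

# Row-wise softmax over the stored entries only: non-edges never receive a coefficient.
coefs = {}
for i in range(3):
    row = {e: math.exp(leaky_relu(v)) for e, v in logits.items() if e[0] == i}
    total = sum(row.values())
    for e, v in row.items():
        coefs[e] = v / total

print(sorted(coefs))
```

Because only stored entries take part in the softmax, no additive mask is needed in this formulation.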
  • Pytorch implementation



    I was really interested in using your model, but unfortunately it is in TensorFlow (or Keras for the unofficial implementation). I am sharing a PyTorch version here ( with which I obtained between 83.6% and 84.6% accuracy on the transductive Cora task. Feel free to add it to your README if you like (I didn't want to open a pull request for this).


    opened by Diego999 4
  • multi-graph node classification


    @gcucurull @PetarV- Hello, I have 20 objects, that is, 20 graphs. I split the dataset by graph, but GAT seems to handle only one graph at a time, and the batch size can only be set to 1. How can I do multi-graph node classification?

    opened by tanjia123456 3
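
A generic technique that is sometimes used for this situation (offered here as an assumption, not as this repository's method) is to merge several graphs into one large disconnected graph whose adjacency matrix is block-diagonal. A batch of size 1 then carries all graphs at once, and no edges cross between them. The helper name `block_diagonal` is hypothetical.

```python
def block_diagonal(adjs):
    """Merge square adjacency matrices (lists of lists) into one block-diagonal matrix."""
    total = sum(len(a) for a in adjs)
    merged = [[0] * total for _ in range(total)]
    offset = 0
    for a in adjs:
        for i, row in enumerate(a):
            for j, value in enumerate(row):
                merged[offset + i][offset + j] = value
        offset += len(a)  # the next graph starts after this one's nodes
    return merged

graph_a = [[1, 1],
           [1, 1]]
graph_b = [[1, 0, 1],
           [0, 1, 1],
           [1, 1, 1]]

merged = block_diagonal([graph_a, graph_b])
for row in merged:
    print(row)
```

The node feature matrices would have to be stacked in the same order so that row k of the features matches row k of the merged adjacency matrix.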
  • how to understand “dropping all structural information”



    I'm studying GAT. Awesome work!

    I'm wondering how to understand "without ... depending on knowing the graph structure upfront" and "dropping all structural information".

    The paper also says "we only compute e_ij for nodes j ∈ N_i, where N_i is some neighborhood of node i in the graph", and the code has a function adj_to_bias() that requires the adjacency matrix (the edges).

    So my understanding is that we do need to know the graph structure (edge information) at the beginning. Thanks.

    opened by guoyejun 3
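
The role of the adjacency matrix in the question above can be illustrated with a masked softmax. The sketch assumes the bias holds 0 for connected pairs and a large negative number (-1e9 here) for unconnected pairs, so that adding it to the logits drives the softmax weight of non-neighbours to zero. The function `adjacency_to_bias` is a hypothetical stand-in written for this example; it is not the repository's adj_to_bias().

```python
import math

def adjacency_to_bias(adj, neg=-1e9):
    """0.0 where an edge (or self-loop) exists, a large negative value elsewhere."""
    return [[0.0 if a > 0 else neg for a in row] for row in adj]

def softmax(row):
    """Numerically stable softmax over one list of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# 3-node path graph with self-loops: node 0 and node 2 are not connected.
adj = [[1, 1, 0],
       [1, 1, 1],
       [0, 1, 1]]
# Unnormalised attention logits for every pair, including unconnected ones.
logits = [[0.3, -0.2, 1.5],
          [0.1, 0.4, -0.7],
          [2.0, 0.0, 0.5]]

bias = adjacency_to_bias(adj)
coefs = [softmax([l + b for l, b in zip(l_row, b_row)])
         for l_row, b_row in zip(logits, bias)]

for row in coefs:
    print([round(c, 3) for c in row])
```

The edges are therefore needed as input; what the layer does not require is any fixed, precomputed weighting of those edges.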
  • Question about pubmed dataset results


    Hello, I am working on GAT. It's an excellent idea to introduce attention to graphs.

    But I have some questions about the results on the Pubmed dataset, which you report as "79.0 +- 0.3%".

    The paper says:

    we found that Pubmed’s training set size (60 examples) required slight changes to the GAT architecture: we have applied K = 8 output attention heads (instead of one), and strengthened the L2 regularization to λ = 0.001. Otherwise, the architecture matches the one used for Cora and Citeseer.
    Both models are initialized using Glorot initialization (Glorot & Bengio, 2010) and trained to minimize cross-entropy on the training nodes using the Adam SGD optimizer (Kingma & Ba, 2014) with an initial learning rate of 0.01 for Pubmed, and 0.005 for all other datasets.

    Since you said "we use 500 additional nodes for validation purposes (the same ones as used by Kipf & Welling (2017))", I copied the data from gcn.

    I then changed the hyperparameters to:

    dataset = 'pubmed'
    lr = 0.01  # learning rate
    l2_coef = 0.001  # weight decay
    hid_units = [8] # numbers of hidden units per each attention head in each layer
    n_heads = [8, 8] # additional entry for the output layer

    I ran this 100 times (here is the code difference) and got "0.777 +- 0.8%", which does not pass a t-test against the reported result:

    (79.0-77.7)/(0.8/\sqrt(100)) = 16.25 > 1.984 (sig 0.05)

    Is there anything wrong with the experiment parameters? Can you help me reproduce the Pubmed results?

    Thank you

    opened by hotchilipowder 3
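
The arithmetic in the post above can be reproduced as a one-sample test statistic. This sketch assumes that 0.8 is the standard deviation (in percentage points) over the 100 runs and that the reported 79.0 is treated as a fixed reference value; it takes the critical value 1.984 directly from the post. It does not account for the +- 0.3 uncertainty of the reported number.

```python
import math

reported_mean = 79.0   # accuracy reported for Pubmed, in percent
observed_mean = 77.7   # mean accuracy over the reproduction runs, in percent
observed_std = 0.8     # standard deviation over the runs (assumed meaning of "+- 0.8%")
n_runs = 100

# Standard error of the mean, then the test statistic.
standard_error = observed_std / math.sqrt(n_runs)
t_stat = (reported_mean - observed_mean) / standard_error

critical_value = 1.984  # two-sided threshold at the 0.05 level, as quoted in the post
print(round(t_stat, 2), t_stat > critical_value)
```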
  • Question about your sparse implementation


    Hi GAT folks,

    Awesome work! I tried out your sparse implementation. What I did: (1) replaced attn_head with the sp_attn_head function; (2) used a sparse adj matrix, removing the adj.todense() and process.adj_to_bias function calls and replacing them with a sparse version using scipy.

    I found that I cannot reach the expected accuracy on Cora with the sparse implementation. Here is the log:

    Training: loss = 1.94517, acc = 0.15000 | Val: loss = 1.94614, acc = 0.20600 | Time: nan (graph/s)
    Training: loss = 1.93820, acc = 0.20714 | Val: loss = 1.94221, acc = 0.25000 | Time: nan (graph/s)
    Training: loss = 1.92800, acc = 0.28571 | Val: loss = 1.93924, acc = 0.26600 | Time: nan (graph/s)
    Training: loss = 1.91733, acc = 0.23571 | Val: loss = 1.93592, acc = 0.24400 | Time: nan (graph/s)
    Training: loss = 1.90697, acc = 0.30714 | Val: loss = 1.93212, acc = 0.21600 | Time: nan (graph/s)
    Training: loss = 1.89323, acc = 0.30714 | Val: loss = 1.92840, acc = 0.17800 | Time: 17.03976 (graph/s)
    Training: loss = 1.88492, acc = 0.26429 | Val: loss = 1.92444, acc = 0.18400 | Time: 17.02856 (graph/s)
    Training: loss = 1.87695, acc = 0.35000 | Val: loss = 1.91930, acc = 0.18000 | Time: 17.09180 (graph/s)
    Training: loss = 1.86555, acc = 0.28571 | Val: loss = 1.91408, acc = 0.18200 | Time: 17.05338 (graph/s)
    Training: loss = 1.84816, acc = 0.29286 | Val: loss = 1.90931, acc = 0.17800 | Time: 17.03552 (graph/s)
    Training: loss = 1.85963, acc = 0.25000 | Val: loss = 1.90346, acc = 0.18800 | Time: 16.99722 (graph/s)
    Training: loss = 1.86400, acc = 0.21429 | Val: loss = 1.89660, acc = 0.20400 | Time: 17.03414 (graph/s)
    Training: loss = 1.82015, acc = 0.32143 | Val: loss = 1.89007, acc = 0.22400 | Time: 16.99219 (graph/s)
    Training: loss = 1.80568, acc = 0.39286 | Val: loss = 1.88342, acc = 0.25800 | Time: 17.00136 (graph/s)
    Training: loss = 1.80814, acc = 0.35714 | Val: loss = 1.87724, acc = 0.27600 | Time: 16.99366 (graph/s)
    Training: loss = 1.80206, acc = 0.38571 | Val: loss = 1.87083, acc = 0.32000 | Time: 17.00094 (graph/s)
    Training: loss = 1.77896, acc = 0.41429 | Val: loss = 1.86466, acc = 0.34600 | Time: 17.05043 (graph/s)
    Training: loss = 1.76743, acc = 0.40714 | Val: loss = 1.85916, acc = 0.38600 | Time: 17.03179 (graph/s)
    Training: loss = 1.76884, acc = 0.38571 | Val: loss = 1.85294, acc = 0.45000 | Time: 17.02546 (graph/s)
    Training: loss = 1.76213, acc = 0.50000 | Val: loss = 1.84764, acc = 0.48800 | Time: 16.93918 (graph/s)
    Training: loss = 1.76706, acc = 0.45000 | Val: loss = 1.84279, acc = 0.52800 | Time: 16.97079 (graph/s)
    Training: loss = 1.75194, acc = 0.47143 | Val: loss = 1.83775, acc = 0.54800 | Time: 16.96184 (graph/s)
    Training: loss = 1.69834, acc = 0.55000 | Val: loss = 1.83297, acc = 0.56800 | Time: 16.95823 (graph/s)
    Training: loss = 1.71937, acc = 0.52857 | Val: loss = 1.82791, acc = 0.58200 | Time: 16.96505 (graph/s)
    Training: loss = 1.71782, acc = 0.50714 | Val: loss = 1.82320, acc = 0.58600 | Time: 16.96362 (graph/s)
    Training: loss = 1.69105, acc = 0.56429 | Val: loss = 1.81782, acc = 0.60000 | Time: 16.95437 (graph/s)
    Training: loss = 1.67340, acc = 0.60000 | Val: loss = 1.81338, acc = 0.59200 | Time: 16.78093 (graph/s)
    Training: loss = 1.70836, acc = 0.55000 | Val: loss = 1.80828, acc = 0.58800 | Time: 16.76095 (graph/s)
    Training: loss = 1.71126, acc = 0.56429 | Val: loss = 1.80237, acc = 0.58800 | Time: 16.78002 (graph/s)
    Training: loss = 1.66770, acc = 0.62857 | Val: loss = 1.79581, acc = 0.60400 | Time: 16.75709 (graph/s)
    Training: loss = 1.63414, acc = 0.62143 | Val: loss = 1.78960, acc = 0.60200 | Time: 16.76527 (graph/s)
    Training: loss = 1.64903, acc = 0.59286 | Val: loss = 1.78325, acc = 0.61000 | Time: 16.71849 (graph/s)
    Training: loss = 1.62723, acc = 0.55714 | Val: loss = 1.77702, acc = 0.61000 | Time: 16.72862 (graph/s)
    Training: loss = 1.65727, acc = 0.55714 | Val: loss = 1.77083, acc = 0.63000 | Time: 16.74734 (graph/s)
    Training: loss = 1.60982, acc = 0.61429 | Val: loss = 1.76514, acc = 0.60600 | Time: 16.73516 (graph/s)
    Training: loss = 1.56368, acc = 0.59286 | Val: loss = 1.75920, acc = 0.60800 | Time: 16.74822 (graph/s)
    Training: loss = 1.59711, acc = 0.59286 | Val: loss = 1.75289, acc = 0.58600 | Time: 16.73207 (graph/s)
    Training: loss = 1.56515, acc = 0.60714 | Val: loss = 1.74649, acc = 0.58600 | Time: 16.74241 (graph/s)
    Training: loss = 1.60043, acc = 0.55000 | Val: loss = 1.74041, acc = 0.58200 | Time: 16.73666 (graph/s)
    Training: loss = 1.57450, acc = 0.62143 | Val: loss = 1.73533, acc = 0.57800 | Time: 16.73286 (graph/s)
    Training: loss = 1.57213, acc = 0.57857 | Val: loss = 1.73015, acc = 0.56600 | Time: 16.73614 (graph/s)
    Training: loss = 1.55529, acc = 0.56429 | Val: loss = 1.72659, acc = 0.56600 | Time: 16.73717 (graph/s)
    Training: loss = 1.55898, acc = 0.55714 | Val: loss = 1.72352, acc = 0.56200 | Time: 16.73034 (graph/s)
    Training: loss = 1.55415, acc = 0.55000 | Val: loss = 1.72040, acc = 0.55200 | Time: 16.72506 (graph/s)
    Training: loss = 1.55050, acc = 0.52143 | Val: loss = 1.71850, acc = 0.53800 | Time: 16.73672 (graph/s)
    Training: loss = 1.47474, acc = 0.63571 | Val: loss = 1.71621, acc = 0.51600 | Time: 16.74091 (graph/s)
    Training: loss = 1.56495, acc = 0.50714 | Val: loss = 1.71540, acc = 0.49600 | Time: 16.73061 (graph/s)
    Training: loss = 1.51994, acc = 0.55714 | Val: loss = 1.71458, acc = 0.47800 | Time: 16.74436 (graph/s)
    Training: loss = 1.54271, acc = 0.52143 | Val: loss = 1.71304, acc = 0.46600 | Time: 16.75906 (graph/s)
    Training: loss = 1.58519, acc = 0.45000 | Val: loss = 1.71244, acc = 0.45000 | Time: 16.76833 (graph/s)
    Training: loss = 1.57245, acc = 0.50714 | Val: loss = 1.71141, acc = 0.44200 | Time: 16.76075 (graph/s)
    Training: loss = 1.62070, acc = 0.47857 | Val: loss = 1.70944, acc = 0.44400 | Time: 16.76779 (graph/s)
    Training: loss = 1.63155, acc = 0.50714 | Val: loss = 1.70797, acc = 0.47200 | Time: 16.76185 (graph/s)
    Training: loss = 1.56914, acc = 0.47143 | Val: loss = 1.70734, acc = 0.49400 | Time: 16.76988 (graph/s)
    Training: loss = 1.53856, acc = 0.52143 | Val: loss = 1.70701, acc = 0.48600 | Time: 16.73263 (graph/s)
    Training: loss = 1.46632, acc = 0.58571 | Val: loss = 1.70669, acc = 0.49000 | Time: 16.73773 (graph/s)
    Training: loss = 1.45926, acc = 0.60714 | Val: loss = 1.70659, acc = 0.50800 | Time: 16.75432 (graph/s)
    Training: loss = 1.48121, acc = 0.57857 | Val: loss = 1.70474, acc = 0.50800 | Time: 16.76423 (graph/s)
    Training: loss = 1.50514, acc = 0.55714 | Val: loss = 1.70184, acc = 0.50000 | Time: 16.77102 (graph/s)
    Training: loss = 1.50490, acc = 0.50000 | Val: loss = 1.69951, acc = 0.49400 | Time: 16.78197 (graph/s)
    Training: loss = 1.51039, acc = 0.53571 | Val: loss = 1.69778, acc = 0.49400 | Time: 16.78962 (graph/s)
    Training: loss = 1.45282, acc = 0.57143 | Val: loss = 1.69572, acc = 0.49400 | Time: 16.79195 (graph/s)
    Training: loss = 1.43123, acc = 0.55000 | Val: loss = 1.69297, acc = 0.49800 | Time: 16.79870 (graph/s)
    Training: loss = 1.51627, acc = 0.49286 | Val: loss = 1.68947, acc = 0.50000 | Time: 16.77023 (graph/s)
    Training: loss = 1.46445, acc = 0.53571 | Val: loss = 1.68655, acc = 0.50200 | Time: 16.77020 (graph/s)
    Training: loss = 1.49241, acc = 0.49286 | Val: loss = 1.68367, acc = 0.51000 | Time: 16.78139 (graph/s)
    Training: loss = 1.52911, acc = 0.47857 | Val: loss = 1.68273, acc = 0.52000 | Time: 16.79249 (graph/s)
    Training: loss = 1.48992, acc = 0.57143 | Val: loss = 1.68100, acc = 0.53000 | Time: 16.78402 (graph/s)
    Training: loss = 1.43546, acc = 0.57143 | Val: loss = 1.67985, acc = 0.53600 | Time: 16.72668 (graph/s)
    Training: loss = 1.48215, acc = 0.52857 | Val: loss = 1.67853, acc = 0.54000 | Time: 16.72524 (graph/s)
    Training: loss = 1.47648, acc = 0.55000 | Val: loss = 1.67829, acc = 0.53800 | Time: 16.72739 (graph/s)
    Training: loss = 1.44751, acc = 0.58571 | Val: loss = 1.67800, acc = 0.54000 | Time: 16.73214 (graph/s)
    Training: loss = 1.40865, acc = 0.56429 | Val: loss = 1.67713, acc = 0.54000 | Time: 16.73742 (graph/s)
    Training: loss = 1.47875, acc = 0.50714 | Val: loss = 1.67518, acc = 0.54000 | Time: 16.73667 (graph/s)
    Training: loss = 1.40626, acc = 0.52143 | Val: loss = 1.67420, acc = 0.53200 | Time: 16.73576 (graph/s)
    Training: loss = 1.46455, acc = 0.49286 | Val: loss = 1.67260, acc = 0.54000 | Time: 16.69977 (graph/s)
    Training: loss = 1.42937, acc = 0.55000 | Val: loss = 1.66953, acc = 0.54200 | Time: 16.70049 (graph/s)
    Training: loss = 1.44192, acc = 0.55000 | Val: loss = 1.66651, acc = 0.53800 | Time: 16.70967 (graph/s)
    Training: loss = 1.44210, acc = 0.55714 | Val: loss = 1.66280, acc = 0.53400 | Time: 16.72562 (graph/s)
    Training: loss = 1.36144, acc = 0.61429 | Val: loss = 1.65898, acc = 0.53000 | Time: 16.72490 (graph/s)
    Training: loss = 1.51469, acc = 0.53571 | Val: loss = 1.65483, acc = 0.52400 | Time: 16.71665 (graph/s)
    Training: loss = 1.41710, acc = 0.55000 | Val: loss = 1.65153, acc = 0.52400 | Time: 16.71838 (graph/s)
    Training: loss = 1.42846, acc = 0.56429 | Val: loss = 1.64860, acc = 0.51600 | Time: 16.71081 (graph/s)
    Training: loss = 1.48258, acc = 0.47143 | Val: loss = 1.64704, acc = 0.50600 | Time: 16.71550 (graph/s)
    Training: loss = 1.39769, acc = 0.59286 | Val: loss = 1.64604, acc = 0.50200 | Time: 16.72110 (graph/s)
    Training: loss = 1.41342, acc = 0.58571 | Val: loss = 1.64720, acc = 0.49800 | Time: 16.72464 (graph/s)
    Training: loss = 1.37405, acc = 0.55000 | Val: loss = 1.64852, acc = 0.48800 | Time: 16.73652 (graph/s)
    Training: loss = 1.36246, acc = 0.56429 | Val: loss = 1.64919, acc = 0.48400 | Time: 16.73602 (graph/s)
    Training: loss = 1.35721, acc = 0.57857 | Val: loss = 1.65088, acc = 0.48200 | Time: 16.73687 (graph/s)
    Training: loss = 1.46561, acc = 0.52857 | Val: loss = 1.65400, acc = 0.48000 | Time: 16.70509 (graph/s)
    Training: loss = 1.41449, acc = 0.53571 | Val: loss = 1.65670, acc = 0.48200 | Time: 16.70738 (graph/s)
    Training: loss = 1.46798, acc = 0.48571 | Val: loss = 1.65805, acc = 0.47800 | Time: 16.71920 (graph/s)
    Training: loss = 1.39453, acc = 0.57143 | Val: loss = 1.66067, acc = 0.47400 | Time: 16.71041 (graph/s)
    Training: loss = 1.40467, acc = 0.53571 | Val: loss = 1.66297, acc = 0.46800 | Time: 16.70864 (graph/s)
    Training: loss = 1.41027, acc = 0.49286 | Val: loss = 1.66376, acc = 0.47000 | Time: 16.70883 (graph/s)
    Training: loss = 1.46268, acc = 0.50000 | Val: loss = 1.66714, acc = 0.46800 | Time: 16.70941 (graph/s)
    Training: loss = 1.41257, acc = 0.55714 | Val: loss = 1.66891, acc = 0.47200 | Time: 16.71804 (graph/s)
    Training: loss = 1.40972, acc = 0.50714 | Val: loss = 1.66983, acc = 0.47400 | Time: 16.72160 (graph/s)
    Training: loss = 1.42352, acc = 0.53571 | Val: loss = 1.67032, acc = 0.47200 | Time: 16.71816 (graph/s)
    Training: loss = 1.38608, acc = 0.53571 | Val: loss = 1.67717, acc = 0.46000 | Time: 16.72157 (graph/s)
    Training: loss = 1.43882, acc = 0.52857 | Val: loss = 1.68334, acc = 0.45200 | Time: 16.72999 (graph/s)
    Training: loss = 1.49979, acc = 0.49286 | Val: loss = 1.69646, acc = 0.43600 | Time: 16.73467 (graph/s)
    Training: loss = 1.51738, acc = 0.50000 | Val: loss = 1.71791, acc = 0.42800 | Time: 16.74103 (graph/s)
    Training: loss = 1.41980, acc = 0.54286 | Val: loss = 1.73596, acc = 0.41200 | Time: 16.74587 (graph/s)
    Training: loss = 1.49115, acc = 0.45000 | Val: loss = 1.75199, acc = 0.38800 | Time: 16.74358 (graph/s)
    Training: loss = 1.50618, acc = 0.49286 | Val: loss = 1.77033, acc = 0.38400 | Time: 16.72982 (graph/s)
    Training: loss = 1.54672, acc = 0.52143 | Val: loss = 1.80930, acc = 0.37800 | Time: 16.72799 (graph/s)
    Training: loss = 1.67387, acc = 0.47143 | Val: loss = 1.86923, acc = 0.35000 | Time: 16.73573 (graph/s)
    Training: loss = 1.69535, acc = 0.45714 | Val: loss = 1.94518, acc = 0.32400 | Time: 16.73458 (graph/s)
    Training: loss = 1.81820, acc = 0.41429 | Val: loss = 2.03584, acc = 0.31200 | Time: 16.73830 (graph/s)
    Training: loss = 1.93426, acc = 0.29286 | Val: loss = 2.14381, acc = 0.29800 | Time: 16.73484 (graph/s)
    Training: loss = 1.97562, acc = 0.27143 | Val: loss = 2.27155, acc = 0.28400 | Time: 16.73221 (graph/s)
    Training: loss = 2.17724, acc = 0.30714 | Val: loss = 2.42012, acc = 0.26800 | Time: 16.71115 (graph/s)
    Training: loss = 2.60620, acc = 0.28571 | Val: loss = 2.56814, acc = 0.24800 | Time: 16.71131 (graph/s)
    Training: loss = 3.08253, acc = 0.28571 | Val: loss = 2.75077, acc = 0.21800 | Time: 16.71770 (graph/s)
    Training: loss = 3.59296, acc = 0.23571 | Val: loss = 2.96115, acc = 0.21000 | Time: 16.71862 (graph/s)
    Training: loss = 4.29261, acc = 0.27857 | Val: loss = 3.19888, acc = 0.20200 | Time: 16.71416 (graph/s)
    Training: loss = 3.61426, acc = 0.19286 | Val: loss = 3.46225, acc = 0.19200 | Time: 16.71198 (graph/s)
    Training: loss = 4.62925, acc = 0.17857 | Val: loss = 3.75955, acc = 0.18400 | Time: 16.71600 (graph/s)
    Training: loss = 5.59345, acc = 0.20714 | Val: loss = 4.08395, acc = 0.17800 | Time: 16.71734 (graph/s)
    Training: loss = 6.05617, acc = 0.17143 | Val: loss = 4.44487, acc = 0.17400 | Time: 16.71906 (graph/s)
    Training: loss = 5.78760, acc = 0.18571 | Val: loss = 4.83142, acc = 0.16800 | Time: 16.71927 (graph/s)
    Training: loss = 6.78114, acc = 0.20000 | Val: loss = 5.25684, acc = 0.16400 | Time: 16.72466 (graph/s)
    Training: loss = 8.34020, acc = 0.18571 | Val: loss = 5.70460, acc = 0.16400 | Time: 16.73279 (graph/s)
    Training: loss = 8.18466, acc = 0.15714 | Val: loss = 6.18899, acc = 0.16400 | Time: 16.73642 (graph/s)
    Training: loss = 10.17127, acc = 0.20000 | Val: loss = 6.70358, acc = 0.16200 | Time: 16.72992 (graph/s)
    Training: loss = 8.01908, acc = 0.17143 | Val: loss = 7.24232, acc = 0.16400 | Time: 16.73301 (graph/s)
    Training: loss = 10.54279, acc = 0.15714 | Val: loss = 7.80857, acc = 0.12200 | Time: 16.72627 (graph/s)
    Training: loss = 11.20833, acc = 0.14286 | Val: loss = 8.40816, acc = 0.12200 | Time: 16.71242 (graph/s)
    Training: loss = 13.13751, acc = 0.14286 | Val: loss = 9.03840, acc = 0.12200 | Time: 16.71696 (graph/s)
    Training: loss = 15.50717, acc = 0.12857 | Val: loss = 9.70438, acc = 0.12600 | Time: 16.72127 (graph/s)
    Training: loss = 14.83573, acc = 0.14286 | Val: loss = 10.38955, acc = 0.12400 | Time: 16.72117 (graph/s)
    Training: loss = 16.32450, acc = 0.14286 | Val: loss = 11.10900, acc = 0.12800 | Time: 16.71955 (graph/s)
    Training: loss = 20.88581, acc = 0.13571 | Val: loss = 11.85103, acc = 0.13000 | Time: 16.72041 (graph/s)
    Training: loss = 18.14150, acc = 0.15714 | Val: loss = 12.62478, acc = 0.12800 | Time: 16.71870 (graph/s)
    Training: loss = 20.19287, acc = 0.12857 | Val: loss = 13.42502, acc = 0.12600 | Time: 16.71075 (graph/s)
    Training: loss = 20.17040, acc = 0.13571 | Val: loss = 14.25172, acc = 0.12800 | Time: 16.71072 (graph/s)
    Training: loss = 21.68529, acc = 0.13571 | Val: loss = 15.11031, acc = 0.13000 | Time: 16.71549 (graph/s)
    Training: loss = 24.35412, acc = 0.12857 | Val: loss = 15.99096, acc = 0.13000 | Time: 16.70976 (graph/s)
    Training: loss = 26.90386, acc = 0.09286 | Val: loss = 16.90097, acc = 0.12800 | Time: 16.71122 (graph/s)
    Training: loss = 32.60686, acc = 0.12857 | Val: loss = 17.83447, acc = 0.12400 | Time: 16.71450 (graph/s)
    Training: loss = 23.71290, acc = 0.10714 | Val: loss = 18.79685, acc = 0.12200 | Time: 16.71559 (graph/s)
    Training: loss = 32.05572, acc = 0.12143 | Val: loss = 19.79207, acc = 0.12400 | Time: 16.71759 (graph/s)
    Training: loss = 37.51984, acc = 0.17857 | Val: loss = 20.82151, acc = 0.12600 | Time: 16.71863 (graph/s)
    Training: loss = 43.44203, acc = 0.10714 | Val: loss = 21.87587, acc = 0.12400 | Time: 16.71384 (graph/s)
    Training: loss = 41.57386, acc = 0.12857 | Val: loss = 22.96205, acc = 0.12400 | Time: 16.71532 (graph/s)
    Training: loss = 32.44603, acc = 0.12857 | Val: loss = 24.04948, acc = 0.12800 | Time: 16.72206 (graph/s)
    Training: loss = 32.47784, acc = 0.15714 | Val: loss = 25.16518, acc = 0.12400 | Time: 16.72132 (graph/s)
    Training: loss = 57.02023, acc = 0.14286 | Val: loss = 26.32536, acc = 0.08800 | Time: 16.72422 (graph/s)
    Training: loss = 47.02181, acc = 0.13571 | Val: loss = 27.52316, acc = 0.09000 | Time: 16.72099 (graph/s)
    Training: loss = 60.71332, acc = 0.12857 | Val: loss = 28.76072, acc = 0.09200 | Time: 16.71417 (graph/s)
    Training: loss = 52.61975, acc = 0.13571 | Val: loss = 30.03493, acc = 0.09000 | Time: 16.70092 (graph/s)
    Training: loss = 55.17526, acc = 0.14286 | Val: loss = 31.34324, acc = 0.09200 | Time: 16.70220 (graph/s)
    Training: loss = 72.76334, acc = 0.12143 | Val: loss = 32.69006, acc = 0.09200 | Time: 16.70329 (graph/s)
    Training: loss = 42.62173, acc = 0.17143 | Val: loss = 34.06517, acc = 0.09200 | Time: 16.70729 (graph/s)
    Training: loss = 68.27650, acc = 0.15714 | Val: loss = 35.47098, acc = 0.09200 | Time: 16.70868 (graph/s)
    Training: loss = 53.12449, acc = 0.14286 | Val: loss = 36.86567, acc = 0.09200 | Time: 16.70579 (graph/s)
    Training: loss = 75.30608, acc = 0.14286 | Val: loss = 38.30105, acc = 0.09200 | Time: 16.70677 (graph/s)
    Training: loss = 66.39566, acc = 0.13571 | Val: loss = 39.77269, acc = 0.09200 | Time: 16.70885 (graph/s)
    Training: loss = 90.00805, acc = 0.15000 | Val: loss = 41.29017, acc = 0.09000 | Time: 16.70979 (graph/s)
    Training: loss = 74.48537, acc = 0.15714 | Val: loss = 42.83427, acc = 0.09000 | Time: 16.70017 (graph/s)
    Training: loss = 83.59474, acc = 0.14286 | Val: loss = 44.40656, acc = 0.09000 | Time: 16.70253 (graph/s)
    Training: loss = 100.54999, acc = 0.12143 | Val: loss = 46.01003, acc = 0.09000 | Time: 16.70330 (graph/s)
    Training: loss = 78.89310, acc = 0.15000 | Val: loss = 47.63050, acc = 0.09000 | Time: 16.70234 (graph/s)
    Training: loss = 84.41219, acc = 0.16429 | Val: loss = 49.28119, acc = 0.09000 | Time: 16.70904 (graph/s)
    Training: loss = 88.26729, acc = 0.17857 | Val: loss = 50.97437, acc = 0.08800 | Time: 16.70679 (graph/s)
    Training: loss = 65.11741, acc = 0.14286 | Val: loss = 52.67675, acc = 0.08800 | Time: 16.70200 (graph/s)
    Training: loss = 85.14694, acc = 0.18571 | Val: loss = 54.41297, acc = 0.08800 | Time: 16.69457 (graph/s)
    Training: loss = 97.87077, acc = 0.15714 | Val: loss = 56.18945, acc = 0.08800 | Time: 16.69361 (graph/s)
    Training: loss = 98.05456, acc = 0.13571 | Val: loss = 57.99624, acc = 0.08800 | Time: 16.68741 (graph/s)
    Training: loss = 107.37295, acc = 0.13571 | Val: loss = 59.82062, acc = 0.08800 | Time: 16.68968 (graph/s)
    Training: loss = 114.99556, acc = 0.14286 | Val: loss = 61.67920, acc = 0.08800 | Time: 16.68897 (graph/s)
    Training: loss = 135.17931, acc = 0.12857 | Val: loss = 63.58669, acc = 0.08800 | Time: 16.68689 (graph/s)
    Training: loss = 109.73069, acc = 0.15000 | Val: loss = 65.50701, acc = 0.08800 | Time: 16.68364 (graph/s)
    Training: loss = 142.71088, acc = 0.11429 | Val: loss = 67.47314, acc = 0.08800 | Time: 16.67883 (graph/s)
    Training: loss = 122.98055, acc = 0.11429 | Val: loss = 69.45295, acc = 0.08800 | Time: 16.67969 (graph/s)
    Training: loss = 122.06322, acc = 0.12857 | Val: loss = 71.47054, acc = 0.08800 | Time: 16.67910 (graph/s)
    Training: loss = 130.83694, acc = 0.12857 | Val: loss = 73.51564, acc = 0.08800 | Time: 16.67657 (graph/s)
    Training: loss = 147.36774, acc = 0.12143 | Val: loss = 75.58735, acc = 0.08800 | Time: 16.67618 (graph/s)
    Training: loss = 128.03278, acc = 0.15714 | Val: loss = 77.68494, acc = 0.08800 | Time: 16.67729 (graph/s)
    Training: loss = 155.61414, acc = 0.13571 | Val: loss = 79.81355, acc = 0.08600 | Time: 16.67960 (graph/s)
    Training: loss = 149.44771, acc = 0.15714 | Val: loss = 81.98518, acc = 0.08600 | Time: 16.67523 (graph/s)
    Training: loss = 143.26579, acc = 0.12143 | Val: loss = 84.19203, acc = 0.08600 | Time: 16.67501 (graph/s)
    Training: loss = 147.99545, acc = 0.14286 | Val: loss = 86.41908, acc = 0.08800 | Time: 16.67924 (graph/s)
    Training: loss = 153.45032, acc = 0.13571 | Val: loss = 88.67501, acc = 0.08800 | Time: 16.68069 (graph/s)

    You can see that the accuracy drops suddenly after reaching ~60% on the validation set. Have you run into similar problems? Did I miss anything?

    Thank you, Minjie

    opened by jermainewang 3
  • Weighted adjacency graphs



    I am currently working on a project where we have weighted directed graphs. I know that the current model supports directed graphs, i.e. the neighborhood is defined by the incoming nodes.

    However, is it possible to use weighted adjacency graphs? My intuition says no, as the edge weights are learned through the attention mechanism. However, a footnote in the associated blog post states that GAT can be trivially extended to include these cases.

    So if it is possible, how would one do this?

    Joakim Haurum
    opened by JoakimHaurum 2
  • How to get node embeddings?


    Hi, I would like to ask how to use GAT to generate node embeddings for downstream tasks such as link prediction. And if nodes have no features, how should they be initialized?

    I'd appreciate your help!

    opened by linz2000 0
  • Why using undirected graph for transductive learning?


    Thanks for your great work! I have a question about the transductive experiments.

    According to the paper, GAT can be applied to directed graphs. Why do you use the undirected versions of Cora, Citeseer and Pubmed for the experiments rather than the original directed ones? As we all know, citation networks are directed.

    opened by qncsn2016 0
  • transform to other scope dataset


    I want to adapt this code to a dataset from a different domain.

    The dataset is a single file with the following layout:

    column1, column2, column3, ..., column9, label
    1500, 45, 1, ...., 0
    933, 22, 0, ...., 0
    1234, 30, 1, ...., 1
    1112, 23, 0, ...., 0
    ...

    Please let me know how to use this code with the dataset above.

    opened by kimsijin33 0
  • How to batch large datasets to apply GAT?


    Hi, thanks for the great repo.

    I want to apply GAT to a very large dataset and train it on a GPU, so I need to feed my data in small batches. However, the algorithm uses the full data. How could I run GAT with small batches?

    opened by taherhekmatfar 0
  • Need help in creating evaluation metrics


    Hi @PetarV-, I'm new to machine learning, so I would appreciate your help in creating evaluation metrics such as precision, recall and F1 score for this model in your code. I was able to train the model on my own dataset. Could you please explain how to construct these evaluation metrics?

    opened by Joey0538 0
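
As a starting point for the metrics requested above, per-class precision, recall and F1 can be computed from integer class labels in plain Python. The sketch assumes that `y_pred` holds the argmax of the model's logits for the evaluated (for example, test-masked) nodes and `y_true` the matching ground-truth classes. Both lists and the helper `per_class_metrics` are illustrative and not part of this repository.

```python
def per_class_metrics(y_true, y_pred, n_classes):
    """Return a list of (precision, recall, f1) tuples, one per class."""
    metrics = []
    for c in range(n_classes):
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == c and p == c)  # correctly predicted as c
        fp = sum(1 for t, p in pairs if t != c and p == c)  # wrongly predicted as c
        fn = sum(1 for t, p in pairs if t == c and p != c)  # missed members of c
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        metrics.append((precision, recall, f1))
    return metrics

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

metrics = per_class_metrics(y_true, y_pred, n_classes=3)
macro_f1 = sum(f1 for _, _, f1 in metrics) / len(metrics)
for c, (p, r, f1) in enumerate(metrics):
    print(c, round(p, 3), round(r, 3), round(f1, 3))
```

Macro averaging, as in `macro_f1`, weights every class equally; other averaging schemes may suit imbalanced datasets better.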
Petar Veličković
Senior Research Scientist
