Code for KDD'20 "Generative Pre-Training of Graph Neural Networks"

GPT-GNN: Generative Pre-Training of Graph Neural Networks



GPT-GNN is a pre-training framework that initializes GNNs by generative pre-training. It can be applied to large-scale and heterogeneous graphs.

See our KDD 2020 paper Generative Pre-Training of Graph Neural Networks for more details.

Overview

The key package is GPT_GNN, which contains the high-level GPT-GNN pre-training framework, base GNN models, and the base graph data structure and data loader.

To illustrate how to apply the GPT_GNN framework to arbitrary graphs, we provide examples of pre-training on both heterogeneous (OAG) and homogeneous (Reddit) graphs. Both are large-scale.

Within each example_* package, there is a pretrain_*.py file for pre-training a GNN on the given graph, and also multiple finetune_*.py files for training and validating on downstream tasks.

DataSet

For the Open Academic Graph (OAG), we provide a heterogeneous graph containing highly-cited CS papers (8.1G) spanning 1900-2020. You can download the preprocessed graph via this link. We split the data by time: pre-training (t < 2014), training (2014 <= t < 2017), validation (t = 2017), and testing (2018 <= t). As we use raw text for the attribute generation task on OAG, we also provide a pre-trained word2vec model via this link.
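
For concreteness, the split can be expressed as a small helper (a sketch only; the function and label names are illustrative, not the repo's exact code):

    def split_by_time(year):
        # Route a paper to a subset based on its publication year,
        # following the time split described above.
        if year < 2014:
            return 'pretrain'   # pre-training graph
        elif year < 2017:
            return 'train'      # downstream training
        elif year == 2017:
            return 'valid'      # validation
        else:
            return 'test'       # 2018 and later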

If you want to process the raw data directly, you can download it via this link. After downloading, run preprocess_OAG.py to extract features and store them in our data structure.

For Reddit, we simply download the preprocessed graph using the pyG.datasets API and then convert it into our own data structure using preprocess_reddit.py. We randomly split the data into different sets.
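
For reference, the download step looks roughly as follows (a sketch assuming pyG's Reddit dataset class; the root path and split ratios are illustrative, and preprocess_reddit.py performs the actual conversion):

    import torch
    from torch_geometric.datasets import Reddit

    # Download the preprocessed Reddit graph through pyG's datasets API.
    dataset = Reddit(root='./data/Reddit')
    data = dataset[0]  # one large homogeneous graph

    # Random node split; the 70/10/20 ratios here are only an example.
    perm = torch.randperm(data.num_nodes)
    train_idx = perm[:int(0.7 * data.num_nodes)]
    valid_idx = perm[int(0.7 * data.num_nodes):int(0.8 * data.num_nodes)]
    test_idx  = perm[int(0.8 * data.num_nodes):]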

Setup

This implementation is based on pytorch_geometric. To install all the necessary dependencies, simply run:

pip install -r requirements.txt

Usage

We first introduce the arguments that control the hyperparameters. There are three main types: arguments for pre-training, for the dataset, and for the model and optimization.

For pre-training, we provide arguments to control different modules for attribute and edge generation tasks:

  --attr_ratio                     FLOAT   The ratio (0~1) of the attribute generation loss.    Default is 0.5.
  --attr_type                      STR     Type of attribute decoder ['text' or 'vec'].         Default is 'vec'.
  --neg_samp_num                   INT     Number of negative samples for edge generation.
  --queue_size                     INT     Max size of the adaptive embedding queue.            Default is 256.
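
For intuition, --attr_ratio presumably weights the attribute and edge generation objectives along these lines (a sketch with dummy values, not the repo's exact code):

    # Hypothetical combination of the two pre-training losses,
    # weighted by --attr_ratio (dummy stand-in values).
    attr_ratio = 0.5
    attr_loss, edge_loss = 1.2, 0.8   # stand-ins for the computed losses
    loss = attr_ratio * attr_loss + (1 - attr_ratio) * edge_loss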

For datasets, we provide arguments to control mini-batch sampling:

  --data_dir                       STR     The address of the preprocessed graph.
  --pretrain_model_dir             STR     The address for storing the pre-trained models.
  --sample_depth                   INT     How many hops (layers) to sample within a mini-batch subgraph.  Default is 6.
  --sample_width                   INT     How many nodes to sample per layer per type.         Default is 128.
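
The sampler shipped in GPT_GNN is more sophisticated (per-type and budget-based), but the toy sketch below shows what the two arguments bound, namely expansion depth and per-layer width (all names are illustrative):

    import random

    def sample_subgraph(seeds, neighbors, sample_depth=6, sample_width=128):
        # Toy breadth-limited expansion: grow the subgraph `sample_depth`
        # times, keeping at most `sample_width` new nodes per layer.
        sampled, frontier = set(seeds), set(seeds)
        for _ in range(sample_depth):
            candidates = {n for v in frontier for n in neighbors.get(v, [])} - sampled
            if not candidates:
                break
            frontier = set(random.sample(sorted(candidates),
                                         min(sample_width, len(candidates))))
            sampled |= frontier
        return sampled

    # Example on a tiny adjacency map:
    print(sample_subgraph([0], {0: [1, 2], 1: [3], 2: [3], 3: []},
                          sample_depth=2, sample_width=2))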

For both pre-training and fine-tuning, we provide arguments to control model and optimizer hyperparameters. We highlight some key arguments below:

  --conv_name                      STR     Name of the GNN filter (model).                      Default is hgt.
  --scheduler                      STR     Name of the learning-rate scheduler.                 Default is cycle (for pre-training) and cosine (for fine-tuning).
  --n_hid                          INT     Hidden dimension size.                               Default is 400.
  --n_layers                       INT     Number of GNN layers.                                Default is 3.
  --prev_norm                      BOOL    Whether to use layer-norm on previous layers.        Default is False.
  --last_norm                      BOOL    Whether to use layer-norm on the last layer.         Default is False.
  --max_lr                         FLOAT   Maximum learning rate.                               Default is 1e-3 (for pre-training) and 5e-4 (for fine-tuning).
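
For intuition, an optimizer/scheduler setup matching these defaults might look as follows (a sketch only; the dummy model and step count are placeholders, and the repo's training loop may differ):

    import torch

    model = torch.nn.Linear(400, 400)   # stand-in for the real GNN
    total_steps = 1000                  # arbitrary, for illustration

    # Pre-training default: a cyclical schedule peaking at max_lr = 1e-3.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
                                                    total_steps=total_steps)

    # The fine-tuning default would instead pair lr = 5e-4 with a cosine schedule:
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)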

The following command pre-trains a 3-layer HGT over OAG-CS:

python pretrain_OAG.py --attr_type text --conv_name hgt --n_layers 3 --pretrain_model_dir /datadrive/models/gta_all_cs3

The following command uses the pre-trained model as initialization and fine-tunes it on the paper-field classification task, using 10% of the training and validation data:

python finetune_OAG_PF.py --use_pretrain --pretrain_model_dir /datadrive/models/gta_all_cs3 --n_layer 3 --data_percentage 0.1
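
Conceptually, --use_pretrain initializes the downstream GNN from the saved weights before fine-tuning, roughly as in this sketch (the placeholder model and loading details are illustrative, not the repo's exact code):

    import torch

    gnn = torch.nn.Linear(400, 400)  # stand-in for the 3-layer HGT
    state_dict = torch.load('/datadrive/models/gta_all_cs3', map_location='cpu')
    # strict=False skips pre-training-only heads absent from the fine-tuning model.
    gnn.load_state_dict(state_dict, strict=False)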

Pre-trained Models

  1. The 3-layer HGT model pre-trained over OAG-CS under Time-Transfer Setting via this link
  2. The 3-layer HGT model pre-trained over Reddit via this link

Citation

Please consider citing the following paper when using our code for your application.

@inproceedings{gpt_gnn,
  title={GPT-GNN: Generative Pre-Training of Graph Neural Networks},
  author={Ziniu Hu and Yuxiao Dong and Kuansan Wang and Kai-Wei Chang and Yizhou Sun},
  booktitle={Proceedings of the 26th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
  year={2020}
}

This implementation is mainly based on the pyHGT API.

Comments
  • Error when running example_reddit/pretrain_reddit.py on CPU

    When I run example_reddit/pretrain_reddit.py on CPU, I encounter this error. Any advice is welcome:

    ...
    File "/usr/local/lib/python3.5/dist-packages/pandas/core/generic.py", line 5063, in __getattr__
      return object.__getattribute__(self, name)
    File "pandas/_libs/properties.pyx", line 65, in pandas._libs.properties.AxisProperty.__get__
    File "/usr/local/lib/python3.5/dist-packages/pandas/core/generic.py", line 5063, in __getattr__
      return object.__getattribute__(self, name)
    RecursionError: maximum recursion depth exceeded while calling a Python object

    opened by herbertguoqi 11
  • About the dataset in GPT_GNN

    I notice that you actually have three categories (CS/Med/NN) in OAG, which are available in the preprocessed graphs. I am interested in the whole dataset covering all three categories. Could you provide the raw data for Med and NN, as you did for CS? Thanks for your help in advance.

    opened by SKD621 6
  • pretrain_OAG.py hgt

    Thanks for your great work. I tried to run pretrain_OAG.py and got this error:

    Traceback (most recent call last):
      File "/home/tiaoban/chen_qian_yu/gptgnn/GPT-GNN/example_OAG/pretrain_OAG.py", line 228, in <module>
        node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), edge_index.to(device), edge_type.to(device))
      File "/home/tiaoban/anaconda3/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tiaoban/chen_qian_yu/gptgnn/GPT-GNN/example_OAG/GPT_GNN/model.py", line 191, in forward
        meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
      File "/home/tiaoban/anaconda3/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tiaoban/chen_qian_yu/gptgnn/GPT-GNN/example_OAG/GPT_GNN/conv.py", line 169, in forward
        return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time)
      File "/home/tiaoban/anaconda3/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tiaoban/chen_qian_yu/gptgnn/GPT-GNN/example_OAG/GPT_GNN/conv.py", line 55, in forward
        edge_type=edge_type, edge_time=edge_time)
      File "/home/tiaoban/anaconda3/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 233, in propagate
        kwargs)
      File "/home/tiaoban/anaconda3/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 156, in __collect__
        self.__set_size__(size, dim, data)
      File "/home/tiaoban/anaconda3/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 118, in __set_size__
        size[dim] = src.size(self.node_dim)
    IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)

    Are there any problems with pretrain_OAG.py? Thanks!

    opened by qyc-98 4
  • Can I fine-tune directly on the GPT-GNN model?

    hi @acbull

    Does that mean I must take some categories for pre-training and then use other categories for fine-tuning? I wonder whether I can fine-tune on the GPT-GNN model directly, like fine-tuning a pre-trained BERT: just modify the data to match its input format and then fine-tune for the downstream task.

    thanks.

    opened by yangxia605 3
  • Error when using han or hetgnn

    Hi, when I fine-tune directly (without pre-training) on models based on han or hetgnn, I encounter this error:

    Start Loading Graph Data...
    Finish Loading Graph Data!
    Data Preparation: 91.0s
    Traceback (most recent call last):
      File "finetune_OAG_PF.py", line 252, in <module>
        res = classifier.forward(node_rep[x_ids])
    TypeError: 'NoneType' object is not subscriptable

    Do you have any ideas for the reasons? Thank you!

    opened by bangann 3
  • How to create the node permutation?

    Hi, @acbull. Thank you so much for sharing the code. I have a question on the graph preprocessing: may I ask how to determine the order (permutation) of nodes in a sampled subgraph? What criterion do we refer to? :-)

    opened by voladorlu 2
  • About ablation studies on base GNNs

    Hi,

    Thank you for your good paper!

    The caption of Tab. 2 says that the experiments included there are under the combined transfer setting, where the accuracy for GPT-GNN + HGT is 0.407.

    However, according to Tab. 1, GPT-GNN + HGT under this setting achieves an accuracy of 0.393; rather, 0.407 is the result of GPT-GNN + HGT under the field transfer setting.

    Could you please clarify the setting used in Tab.2?

    Additionally, I am also wondering how you evaluated GAE, GraphSAGE (unsp.), and Graph Infomax in the experiments in Tab. 1. As far as I know, the GNN encoders used in those three papers are GCN, the GraphSAGE architecture, and GIN, which are designed to learn from homogeneous graphs. May I ask whether you also used these encoders (GCN, GraphSAGE, GIN) for the related experiments in Tab. 1? If so, could you please elaborate on how you applied them to OAG and Amazon, which are heterogeneous graphs?

    Thank you!

    opened by Kqiii 2
  • About the format of the data set?

    hi @acbull, my case is that I should use a sample of the training data to build a graph, not all of it; that is to say, use the internal elements of the sample to construct a heterogeneous graph. I was wondering whether I can use GPT-GNN to fine-tune it and then do a classification task downstream?

    thanks!!!

    opened by yangxia605 2
  • How to use GPT-GNN on a graph?

    hi @acbull, thanks for your excellent work!!! I want to know how to fine-tune GPT-GNN on my own established graph. Are there any examples or guide files?

    opened by yangxia605 2
  • Results of example data

    Hi, could you please provide the result of GPT-GNN on the OAG-CS data? I'm wondering what is the expected result if I run your code directly on the example data you provided. It would be great if a result for baseline (no pre-train) could also be provided. Thank you!

    opened by bangann 2
  • f1_score

    Hi acbull, when I use finetune_reddit.py, I get the error name 'f1_score' is not defined at line 203, and I didn't actually find an f1_score definition in example_reddit. So could you tell me about the details of f1_score? Thanks!
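
    (Most likely f1_score refers to scikit-learn's metric; assuming that intent, a probable fix is to add the import at the top of finetune_reddit.py:)

        from sklearn.metrics import f1_score  # presumably the intended metric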

    opened by Mobzhang 1
  • How to generate vfi_vector.csv

    Hello, I really appreciate your work, and I want to use my own dataset to pre-train the model. But when I preprocess the data, I don't have the file 'vfi_vector.csv'. So I want to ask how to generate 'vfi_vector.csv' for my own work?

    opened by byronBBL 0
  • Can't download the pre-trained model for Reddit

    When I open the link https://drive.google.com/file/d/1Ja4PJT2bkFH0qgoWXjGBjByIFPco4h-S/view?usp=sharing, it says I need access. How can I successfully download the pre-trained model? Thank you!

    opened by jiaruHithub 2
  • About downloading the full OAG dataset (OAG_ALL)

    Sorry to interrupt.

    Where can I download your full dataset or the processed full dataset? The link for the full dataset, 'https://www.openacademic.ai/oag/' (which you mentioned in other issues), seems unreachable nowadays.

    Thanks for your help in advance.

    opened by CLIS-237 1
  • Raise 'IndexError: index out of range in self' when I use HGTConv with my own dataset

    In my work, I build the heterogeneous graph first, then I apply HGTConv as shown in the tutorial.

    The custom KG is shown as:

    HeteroData(
      symptom={ x=[39, 128] },
      component={ x=[19, 128] },
      reason={ x=[17, 128] },
      solution={ x=[18, 128] },
      (symptom, take_place, component)={ edge_index=[2, 38] },
      (symptom, cause_by, reason)={ edge_index=[2, 33] },
      (symptom, how_to_fit, solution)={ edge_index=[2, 33] },
      (component, component_parallel, component)={ edge_index=[2, 17] }
    )

    And my code is as follows:


    class HGT(torch.nn.Module):
        def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
            super().__init__()

            self.lin_dict = torch.nn.ModuleDict()
            for node_type in data_IKG.node_types:
                self.lin_dict[node_type] = Linear(-1, hidden_channels)

            self.convs = torch.nn.ModuleList()
            for _ in range(num_layers):
                conv = HGTConv(hidden_channels, hidden_channels, data_IKG.metadata(),
                               num_heads, group='sum', cached=False)
                self.convs.append(conv)

            self.lin = Linear(hidden_channels, out_channels)

        def forward(self, x_dict, edge_index_dict):
            for node_type, x in x_dict.items():
                x_dict[node_type] = self.lin_dict[node_type](x).relu_()

            for conv in self.convs:
                print(1)
                x_dict = conv(x_dict, edge_index_dict)
                print(2)
            return self.lin(x_dict['symptom'])

    model_IKG = HGT(hidden_channels=64, out_channels=5, num_heads=1, num_layers=1)

    with torch.no_grad():  # Initialize lazy modules.
        out = model_IKG(data_IKG.x_dict, data_IKG.edge_index_dict)

    Then it raises:

    Input In [324], in HGT.forward(self, x_dict, edge_index_dict)
         29     for conv in self.convs:
         30         print(1)
    ---> 31         x_dict = conv(x_dict, edge_index_dict)
         32         print(2)
         33     return self.lin(x_dict['symptom'])


    Then the error message is:


    File ~/.virtualenvs/xlq/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
       1098 # If we don't have any hooks, we want to skip the rest of the logic in
       1099 # this function, and just call forward.
       1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1101         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1102     return forward_call(*input, **kwargs)
       1103 # Do not call functions when jit is used
       1104 full_backward_hooks, non_full_backward_hooks = [], []

    File ~/.virtualenvs/xlq/lib/python3.8/site-packages/torch_geometric/nn/conv/hgt_conv.py:159, in HGTConv.forward(self, x_dict, edge_index_dict)
        156 v = (v_dict[src_type].transpose(0, 1) @ m_rel).transpose(1, 0)
        158 # propagate_type: (k: Tensor, q: Tensor, v: Tensor, rel: Tensor)
    --> 159 out = self.propagate(edge_index, k=k, q=q_dict[dst_type], v=v,
        160                      rel=self.p_rel[edge_type], size=None)
        161 out_dict[dst_type].append(out)
        163 # Iterate over node-types:

    File ~/.virtualenvs/xlq/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py:309, in MessagePassing.propagate(self, edge_index, size, **kwargs)
        306 for arg in decomp_args:
        307     kwargs[arg] = decomp_kwargs[arg][i]
    --> 309 coll_dict = self.__collect__(self.__user_args__, edge_index,
        310                              size, kwargs)
        312 msg_kwargs = self.inspector.distribute('message', coll_dict)
        313 for hook in self._message_forward_pre_hooks.values():

    File ~/.virtualenvs/xlq/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py:202, in MessagePassing.__collect__(self, args, edge_index, size, kwargs)
        200 if isinstance(data, Tensor):
        201     self.__set_size__(size, dim, data)
    --> 202 data = self.__lift__(data, edge_index, dim)
        204 out[arg] = data
        206 if isinstance(edge_index, Tensor):

    File ~/.virtualenvs/xlq/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py:172, in MessagePassing.__lift__(self, src, edge_index, dim)
        170 if isinstance(edge_index, Tensor):
        171     index = edge_index[dim]
    --> 172     return src.index_select(self.node_dim, index)
        173 elif isinstance(edge_index, SparseTensor):
        174     if dim == 1:

    IndexError: index out of range in self


    Please help me fix the problem; much appreciated.

    opened by rickyqiao 0
  • IndexError: "index out of range in self" in training on custom dataset

    Hi, I was trying to use HGTConv on a custom graph with 5 different node types, but I kept running into IndexError: index out of range in self when node_type contains only the target node type.

    Error messages:

    node_type = tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
    node_type.shape = (129,)
    edge_index = tensor([[ 0, 1, 14, 3, 4, 5, 6, 7, ...,
                          82, 184, 188, 189, 189, 190, 190, 192]])
    edge_index.shape = (2, 353)

    IndexError: index out of range in self
    When calling: self.propagate(edge_index, node_inp=node_inp, node_type=node_type, edge_type=edge_type, edge_time=edge_time)
    Call ended by exception

    meta_xs = gc(meta_xs, node_type_id, edge_index, edge_type, edge_time)
    IndexError: index out of range in self
    When calling: gc(meta_xs, node_type_id, edge_index, edge_type, edge_time)
    Call ended by exception

    I was looking at https://github.com/pyg-team/pytorch_geometric/issues/2073, where the suggestion is that removing the cached=True argument from the GCNConv layer can solve the index error,

    and at https://github.com/pyg-team/pytorch_geometric/issues/1631, which suggests setting add_self_loops=False in GATConv(..., add_self_loops=False), but there is no such argument in HGTConv.

    opened by leon-cas 1