GemNet model in PyTorch, as proposed in "GemNet: Universal Directional Graph Neural Networks for Molecules" (NeurIPS 2021)

Overview

GemNet: Universal Directional Graph Neural Networks for Molecules

Reference implementation in PyTorch of the geometric message passing neural network (GemNet). You can find its original TensorFlow 2 implementation in another repository. GemNet is a model for predicting the overall energy and the forces acting on the atoms of a molecule. It was proposed in the paper:

GemNet: Universal Directional Graph Neural Networks for Molecules
by Johannes Klicpera, Florian Becker, Stephan Günnemann
Published at NeurIPS 2021.

Run the code

Adjust config.yaml (or config_seml.yaml) to your needs. This repository contains notebooks for training the model (train.ipynb) and for generating predictions on a molecule loaded from ASE (predict.ipynb). It also contains a script for training the model on a cluster with Sacred and SEML (train_seml.py).
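
For orientation, here is a minimal sketch of the prediction flow that predict.ipynb implements. The file names and the commented-out model call are placeholders, not the notebook's exact API:

    import yaml
    from ase.io import read

    # Load the hyperparameters from the config file mentioned above.
    with open("config.yaml") as f:
        config = yaml.safe_load(f)

    # Load a molecule with ASE, as predict.ipynb does ("molecule.xyz" is a
    # hypothetical input file).
    atoms = read("molecule.xyz")
    positions = atoms.get_positions()
    atomic_numbers = atoms.get_atomic_numbers()

    # Placeholder for the actual model construction and prediction;
    # see predict.ipynb for the real API.
    # model = GemNet(**config["model"])
    # energy, forces = model(inputs)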

Compute scaling factors

You can either use the precomputed scaling factors (in scaling_factors.json) or compute them yourself by running fit_scaling.py. Scaling factors are used to ensure a consistent scale of activations at initialization. They are the same for all GemNet variants.
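
To illustrate the principle (a sketch only, not the repository's fit_scaling.py): record an activation tensor from one forward pass of the untrained model and pick the factor that rescales it to unit standard deviation.

    import torch

    def fit_scaling_factor(activations: torch.Tensor) -> float:
        # The factor that brings the recorded activations back to unit
        # standard deviation at initialization.
        return 1.0 / activations.std().item()

    # Hypothetical usage: "acts" stands in for activations recorded during
    # one forward pass over a reference batch.
    acts = torch.randn(128, 64) * 3.0
    factor = fit_scaling_factor(acts)  # roughly 1/3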

Contact

Please contact [email protected] if you have any questions.

Cite

Please cite our paper if you use the model or this code in your own work:

@inproceedings{klicpera_gemnet_2021,
  title = {GemNet: Universal Directional Graph Neural Networks for Molecules},
  author = {Klicpera, Johannes and Becker, Florian and G{\"u}nnemann, Stephan},
  booktitle = {Conference on Neural Information Processing Systems (NeurIPS)},
  year = {2021}
}
Comments
  • Package installation throws errors


    There were several issues preventing the installation of gemnet as a package:

    • The setup.py file referred to a gemnet_pytorch package even though the folder name is gemnet
    • python, cudatoolkit, pytorch, and torch_geometric are not valid pip package names.

    Most of these should be fixed by 1922774fbf8e21179de5c64d7490fe8dc28fa720
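
    For readers who hit the same problem before that commit, a sketch of what a fixed setup.py could look like (illustrative only; the dependency list is an assumption, and environment-level dependencies such as python and cudatoolkit are omitted because they are not pip packages):

        from setuptools import setup, find_packages

        setup(
            name="gemnet",
            packages=find_packages(),  # picks up the gemnet/ folder
            # Only PyPI-installable names: "torch" is the pip name, not "pytorch".
            install_requires=["torch", "numpy", "scipy"],
        )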

    opened by n-gao 3
  • Some questions regarding reducing the size of both input and model?


    Hi, thanks for sharing the code of GemNet; it is wonderful work on energy prediction.

    However, I intend to apply your model to macromolecules such as proteins instead of small molecules. As you know, proteins have far more atoms, which unavoidably requires much more GPU memory. To keep GPU memory under control, I have to limit the size of the model inputs, so I would like to ask for advice on how to reduce the input size.

    A straightforward option is to decrease the cutoff distance so that fewer edges are generated, but I believe that is not good practice. Can you suggest other solutions (other hyperparameters to adjust)?

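    As a rough illustration of why the cutoff dominates memory (a standalone sketch, not the repository's graph construction): for roughly uniformly distributed atoms, the number of edges grows with the cube of the cutoff distance.

        import numpy as np

        # Hypothetical uniformly distributed atoms in a 20 Å box.
        rng = np.random.default_rng(0)
        pos = rng.random((1000, 3)) * 20.0
        dists = np.linalg.norm(pos[:, None] - pos[None, :], axis=-1)

        for cutoff in (3.0, 5.0, 7.0):
            n_edges = int((dists < cutoff).sum()) - len(pos)  # exclude self-pairs
            print(f"cutoff {cutoff} Å -> {n_edges} directed edges")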

    opened by smiles724 3
  • Question about pretrained weights


    Hi there,

    Thank you for providing the codebase. I am using the GemNet-T pretrained weights and get better performance on my downstream task, but I can't find how the GemNet-T pretrained weights were obtained.

    Can you tell me which dataset and target were used for pretraining?

    opened by okin1234 2
  • The dihedral angle definition


    GemNet is great work on molecular representation learning. I have a question about the definition of the dihedral angle "cabd": the definition in your paper does not seem to take the direction of the dihedral angle into account. What I mean is that the two dihedral angles in the picture below are the same under your definition, but we know they are different. This difference may have an influence on the message passing and on your model.

    [image]

    opened by wang1215789 1
  • Reproducing Questions


    Hi there,

    Thank you for providing the codebase. I am trying to reproduce GemNet on COLL and MD17, and have two questions below:

    1. For COLL, what is the total training time? On my end it seems to be quite long (300-400 hours), so I just want to double-check with you.
    2. For MD17, could you provide the config.yaml and scaling_factors.json files for each task?

    Any help is appreciated.

    opened by chao1224 1
  • Question about the gradient of positions


    Hi, great work, and thanks for sharing the code. I have a small question regarding the gradient of inputs["R"]. As far as I know, inputs["R"] represents the positions of the atoms. Why do you set its requires_grad to True?

            if not self.direct_forces:
                inputs["R"].requires_grad = True
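
    For context, a minimal standalone sketch (not the repository's code) of why this is needed: when forces are not predicted by a direct output head, they are computed as the negative gradient of the predicted energy with respect to the atom positions, and autograd can only provide that gradient if R tracks gradients.

        import torch

        R = torch.randn(5, 3, requires_grad=True)  # hypothetical atom positions
        energy = (R ** 2).sum()                    # stand-in for the model's energy
        forces = -torch.autograd.grad(energy, R)[0]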
    
    opened by smiles724 1
  • A potential bug in data_provider.py


    Thank you for your awesome work! When using the proposed model in another downstream task, I found a potential bug in https://github.com/TUM-DAML/gemnet_pytorch/blob/master/gemnet/training/data_provider.py#L31-L49: it can cause the model to be evaluated on the training data rather than on the test subset.

    For example, our indices are [n_train: n_train+n_val] when using split="val". Since shuffle=False, the idx sampler is SequentialSampler(Subset(data_container, indices)). Notice that this sampler produces indices in the range [0, n_val).

    However,

            super().__init__(
                data_container, ## here is the full set.
                sampler=batch_sampler,
                collate_fn=lambda x: collate(x, data_container),
                pin_memory=True,  # load on CPU push to GPU
                **kwargs
            )
    

    As shown in this code snippet, the dataset passed to the DataLoader is the full dataset. The data loader's iterator then takes samples according to the indices provided by the sampler. As illustrated above, the sampler produces indices in the range [0, n_val), so it actually takes data from a subset of the training split.
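
    The behavior is easy to reproduce in isolation (illustrative, using torch.utils.data directly):

        import torch
        from torch.utils.data import SequentialSampler, Subset, TensorDataset

        data = TensorDataset(torch.arange(10))  # stand-in for the full container
        indices = list(range(6, 9))             # e.g. the "val" split
        sampler = SequentialSampler(Subset(data, indices))
        # Yields [0, 1, 2]: valid within the Subset, but if handed to a
        # DataLoader that holds the full dataset, they select training samples.
        print(list(sampler))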

    I noticed that this problem is avoided in train.ipynb by using two data containers. However, this part could be confusing for users who want to apply the model to other datasets.

    I tried to fix it as:

    import torch
    from torch.utils.data import (
        DataLoader, Subset, BatchSampler, SequentialSampler, SubsetRandomSampler
    )

    class CustomDataLoader(DataLoader):
        def __init__(
            self, data_container, batch_size, indices, shuffle, seed=None, **kwargs
        ):
            if shuffle:
                generator = torch.Generator()
                if seed is not None:
                    generator.manual_seed(seed)
                idx_sampler = SubsetRandomSampler(indices, generator)
            else:
                idx_sampler = SequentialSampler(Subset(data_container, indices))

            batch_sampler = BatchSampler(
                idx_sampler, batch_size=batch_size, drop_last=False
            )
            # Note: there is a bug here if we do not use Subset.
            # A sequential sampler on the subset yields indices like (0, 1, 2, ...),
            # but the DataLoader interprets them against the dataset it holds.
            # Without the Subset, it would read data from the training split.
            dataset = data_container if shuffle else Subset(data_container, indices)

            super().__init__(
                dataset,
                sampler=batch_sampler,
                collate_fn=data_container.collate_fn,
                pin_memory=True,  # load on CPU, push to GPU
                **kwargs
            )
    opened by gyfastas 1
  • QM9 dataset


    Hi, I notice you ran experiments on several extensive datasets. However, most of your baseline models report their performance on the popular QM9 dataset. Could you please provide results on QM9 so that we can get a clearer picture of how well your model performs? Thanks.

    opened by smiles724 1
  • Error at line 287 in basis_layers.py


    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/Users/ngoccuongnguyen/GitHub/gemnet_pytorch/gemnet/model/gemnet.py", line 14, in <module>
        from .layers.basis_layers import BesselBasisLayer, SphericalBasisLayer, TensorBasisLayer
      File "/Users/ngoccuongnguyen/GitHub/gemnet_pytorch/gemnet/model/layers/basis_layers.py", line 287
        Kmax = if sph.shape[0]==0 else torch.max(torch.max(Kidx + 1), torch.tensor(0))
               ^
    SyntaxError: invalid syntax
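
    Presumably the conditional expression on that line is missing its true-branch value; a fix along these lines would be consistent with the error message (an assumption, not a confirmed patch):

        Kmax = 0 if sph.shape[0] == 0 else torch.max(torch.max(Kidx + 1), torch.tensor(0))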

    opened by exapde 1
Owner
Data Analytics and Machine Learning Group