Public Implementation of ChIRo from "Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations"

Related tags

Deep Learning ChIRo
Overview

Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations

ScreenShot

This directory contains the model architectures and experimental setups used for ChIRo, SchNet, DimeNet++, and SphereNet on the four tasks considered in the preprint:

Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations

These four tasks are:

  1. Contrastive learning to cluster conformers of different stereoisomers in a learned latent space
  2. Classification of chiral centers as R/S
  3. Classification of the sign (+/-; l/d) of rotated circularly polarized light
  4. Ranking enantiomers by their docking scores in an enantiosensitive protein pocket.

The exact data splits used for tasks (1), (2), and (4) can be downloaded from:

https://figshare.com/s/e23be65a884ce7fc8543

See the appendix of "Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations" for details on how the datasets for task (3) were extracted and filtered from the commercial Reaxys database.


This directory is organized as follows:

  • Subdirectory model/ contains the implementation of ChIRo.

    • model/alpha_encoder.py contains the network architecture of ChIRo

    • model/embedding_functions.py contains the featurization of the input conformers (RDKit mol objects) for ChIRo.

    • model/datasets_samplers.py contains the Pytorch / Pytorch Geometric data samplers used for sampling conformers in each training batch.

    • model/train_functions.py and model/train_models.py contain supporting training/inference loops for each experiment with ChIRo.

    • model/optimization_functions.py contains the loss functions used in the experiments with ChIRo.

    • Subdirectory model/gnn_3D/ contains the implementations of SchNet, DimeNet++, and SphereNet used for each experiment.

      • model/gnn_3D/schnet.py contains the publicly available code for SchNet, with adaptations for readout.
      • model/gnn_3D/dimenet_pp.py contains the publicly available code for DimeNet++, with adaptations for readout.
      • model/gnn_3D/spherenet.py contains the publicly available code for SphereNet, with adaptations for readout.
      • model/gnn_3D/train_functions.py and model/gnn_3D/train_models.py contain the training/inference loops for each experiment with SchNet, DimeNet++, or SphereNet.
      • model/gnn_3D/optimization_functions.py contains the loss functions used in the experiments with SchNet, DimeNet++, or SphereNet.
  • Subdirectory params_files/ contains the hyperparameters used to define exact network initializations for ChIRo, SchNet, DimeNet++, and SphereNet for each experiment. The parameter .json files are specified with a random seed = 1, and the first fold of cross validation for the l/d classifcation task. For the experiments specified in the paper, we use random seeds = 1,2,3 when repeating experiments across three training/test trials.

  • Subdirectory training_scripts/ contains the python scripts to run each of the four experiments, for each of the four 3D models ChIRo, SchNet, DimeNet++, and SphereNet. Before running each experiment, move the corresponding training script to the parent directory.

  • Subdirectory hyperopt/ contains hyperparameter optimization scripts for ChIRo using Raytune.

  • Subdirectory experiment_analysis/ contains jupyter notebooks for analyzing results of each experiment.

  • Subdirectory paper_results/ contains the parameter files, model parameter dictionaries, and loss curves for each experiment reported in the paper.


To run each experiment, first create a conda environment with the following dependencies:

  • python = 3.8.6
  • pytorch = 1.7.0
  • torchaudio = 0.7.0
  • torchvision = 0.8.1
  • torch-geometric = 1.6.3
  • torch-cluster = 1.5.8
  • torch-scatter = 2.0.5
  • torch-sparce = 0.6.8
  • torch-spline-conv = 1.2.1
  • numpy = 1.19.2
  • pandas = 1.1.3
  • rdkit = 2020.09.4
  • scikit-learn = 0.23.2
  • matplotlib = 3.3.3
  • scipy = 1.5.2
  • sympy = 1.8
  • tqdm = 4.58.0

Then, download the datasets (with exact training/validation/test splits) from https://figshare.com/s/e23be65a884ce7fc8543 and place them in a new directory final_data_splits/

You may then run each experiment by calling:

python training_{experiment}_{model}.py params_files/params_{experiment}_{model}.json {path_to_results_directory}/

For instance, you can run the docking experiment for ChIRo with a random seed of 1 (editable in the params .json file) by calling:

python training_binary_ranking.py params_files/params_binary_ranking_ChIRo.json results_binary_ranking_ChIRo/

After training, this will create a results directory containing model checkpoints, best model parameter dictionaries, and results on the test set (if applicable).

You might also like...
Dcf-game-infrastructure-public - Contains all the components necessary to run a DC finals (attack-defense CTF) game from OOO

dcf-game-infrastructure All the components necessary to run a game of the OOO DC

Official public repository of paper
Official public repository of paper "Intention Adaptive Graph Neural Network for Category-Aware Session-Based Recommendation"

Intention Adaptive Graph Neural Network (IAGNN) This is the official repository of paper Intention Adaptive Graph Neural Network for Category-Aware Se

A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI

MPItrampoline MPI wrapper library: MPI trampoline library: MPI integration tests: MPI is the de-facto standard for inter-node communication on HPC sys

ALBERT-pytorch-implementation - ALBERT pytorch implementation

ALBERT-pytorch-implementation developing... 모델의 개념이해를 돕기 위한 구현물로 현재 변수명을 상세히 적었고

Numenta Platform for Intelligent Computing is an implementation of Hierarchical Temporal Memory (HTM), a theory of intelligence based strictly on the neuroscience of the neocortex.

NuPIC Numenta Platform for Intelligent Computing The Numenta Platform for Intelligent Computing (NuPIC) is a machine intelligence platform that implem

PyTorch implementation of neural style transfer algorithm
PyTorch implementation of neural style transfer algorithm

neural-style-pt This is a PyTorch implementation of the paper A Neural Algorithm of Artistic Style by Leon A. Gatys, Alexander S. Ecker, and Matthias

PyTorch implementation of DeepDream algorithm
PyTorch implementation of DeepDream algorithm

neural-dream This is a PyTorch implementation of DeepDream. The code is based on neural-style-pt. Here we DeepDream a photograph of the Golden Gate Br

The project is an official implementation of our CVPR2019 paper
The project is an official implementation of our CVPR2019 paper "Deep High-Resolution Representation Learning for Human Pose Estimation"

Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019) News [2020/07/05] A very nice blog from Towards Data Science introd

Image-to-Image Translation with Conditional Adversarial Networks (Pix2pix) implementation in keras

pix2pix-keras Pix2pix implementation in keras. Original paper: Image-to-Image Translation with Conditional Adversarial Networks (pix2pix) Paper Author

Comments
  • question about `get_local_structure_map`

    question about `get_local_structure_map`

    hello Keir,

    Nice work on ChIRo! we've been exploring its application for the task of protein-ligand affinity prediction at my startup.

    I have gotten the model to train and it seems to be working, but I have a question about the way get_local_structure_map is defined. I noticed that it's meant to work on batches of psi_indices. Is there any specific reason why this was done? Would it be possible to define a similar function that just generates the input features LS_map and alpha_indices for one molecule at a time? This is more from an engineering perspective.

    https://github.com/keiradams/ChIRo/blob/2686d3a1801db8fb3dec10b61fb0cb0cec047c7c/model/train_functions.py#L22-L37

    Thanks!

    opened by linminhtoo 2
  • spherenet xyz_to_dat repeat_interleave issue

    spherenet xyz_to_dat repeat_interleave issue

    Hi, first of all many thanks for open-sourcing this!

    I was checking out your implementation of the SphereNet architecture, and I noticed that I cannot successfully perform a forward pass of the network. Namely, the problem appears to lie in the xyz_to_dat function defined within that submodule.

    I start with 3D coordinates in cartesian space and compute edge indices according torch_geometric.nn.pool.radius_graph:

    In [13]: g.coords
    Out[13]: 
    tensor([[ 4.5877,  1.2124,  0.9045],
            [ 3.4939,  2.0393,  0.7065],
            [ 2.5577,  1.7294, -0.2622],
            [ 2.6985,  0.5842, -1.0372],
            [ 1.6548,  0.2541, -2.0763],
            [ 0.3290,  0.2897, -1.5940],
     ...]])
    
    In [14]: g.coords.shape
    Out[14]: torch.Size([41, 3])
    
    
    In [16]: edge_index = radius_graph(g.coords, r=5.0, batch=torch.zeros(g.atomids.size(0), dtype=int))
    
    In [17]: edge_index
    Out[17]: 
    tensor([[20,  1, 19,  ...,  2,  4,  3],
            [ 0,  0,  0,  ..., 40, 40, 40]])
    
    In [18]: edge_index.shape
    Out[18]: torch.Size([2, 937])
    
    

    I then try and call the xyz_to_dat routine present in the forward pass, and I encounter this error:

    In [20]:     out = xyz_to_dat(pos=g.coords, edge_index=edge_index, num_nodes=g.atomids.size(0), use_torsion=True)
        ...: 
    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-20-e5d12c7af566> in <module>
    ----> 1 out = xyz_to_dat(pos=g.coords, edge_index=edge_index, num_nodes=g.atomids.size(0), use_torsion=True)
    
    ~net.py in xyz_to_dat(pos, edge_index, num_nodes, use_torsion)
        114     repeat = num_triplets - 1
        115     num_triplets_t = num_triplets.repeat_interleave(repeat)
    --> 116     idx_i_t = idx_i.repeat_interleave(num_triplets_t)
        117     idx_j_t = idx_j.repeat_interleave(num_triplets_t)
        118     idx_k_t = idx_k.repeat_interleave(num_triplets_t)
    
    RuntimeError: repeats must have the same size as input along dim
    
    

    I'm wondering whether the problem is on my side and I'm misunderstanding how these functions shoud be called. Any help is appreciated!

    opened by josejimenezluna 2
  • faster positive/negative sample map generation

    faster positive/negative sample map generation

    Thanks for open-sourcing this cool project.

    I was playing around and found Sample_Map_To_Positives and Sample_Map_To_Negatives are pretty slow, so I did some optimization. Now it feels significantly faster because there is no need to loop over subsets of the frame.

    Does my change look right to you?

    opened by wwang2 1
Owner
null
The first public PyTorch implementation of Attentive Recurrent Comparators

arc-pytorch PyTorch implementation of Attentive Recurrent Comparators by Shyam et al. A blog explaining Attentive Recurrent Comparators Visualizing At

Sanyam Agarwal 150 Oct 14, 2022
Rank 1st in the public leaderboard of ScanRefer (2021-03-18)

InstanceRefer InstanceRefer: Cooperative Holistic Understanding for Visual Grounding on Point Clouds through Instance Multi-level Contextual Referring

null 63 Dec 7, 2022
A public available dataset for road boundary detection in aerial images

Topo-boundary This is the official github repo of paper Topo-boundary: A Benchmark Dataset on Topological Road-boundary Detection Using Aerial Images

Zhenhua Xu 79 Jan 4, 2023
atmaCup #11 の Public 4th / Pricvate 5th Solution のリポジトリです。

#11 atmaCup 2021-07-09 ~ 2020-07-21 に行われた #11 [初心者歓迎! / 画像編] atmaCup のリポジトリです。結果は Public 4th / Private 5th でした。 フレームワークは PyTorch で、実装は pytorch-image-m

Tawara 12 Apr 7, 2022
UIUCTF 2021 Public Challenge Repository

UIUCTF-2021-Public UIUCTF 2021 Public Challenge Repository Notes: every challenge folder contains a challenge.yml file in the format for ctfcli, CTFd'

SIGPwny 15 Nov 3, 2022
GradAttack is a Python library for easy evaluation of privacy risks in public gradients in Federated Learning

GradAttack is a Python library for easy evaluation of privacy risks in public gradients in Federated Learning, as well as corresponding mitigation strategies.

null 129 Dec 30, 2022
Public repository of the 3DV 2021 paper "Generative Zero-Shot Learning for Semantic Segmentation of 3D Point Clouds"

Generative Zero-Shot Learning for Semantic Segmentation of 3D Point Clouds Björn Michele1), Alexandre Boulch1), Gilles Puy1), Maxime Bucher1) and Rena

valeo.ai 15 Dec 22, 2022
Public repository created to store my custom-made tools for Just Dance (UbiArt Engine)

Woody's Just Dance Tools Public repository created to store my custom-made tools for Just Dance (UbiArt Engine) Development and updates Almost all of

Wodson de Andrade 8 Dec 24, 2022
Genshin-assets - 👧 Public documentation & static assets for Genshin Impact data.

genshin-assets This repo provides easy access to the Genshin Impact assets, primarily for use on static sites. Sources Genshin Optimizer - An Artifact

Zerite Development 5 Nov 22, 2022
Public scripts, services, and configuration for running a smart home K3S network cluster

makerhouse_network Public scripts, services, and configuration for running MakerHouse's home network. This network supports: TODO features here For mo

Scott Martin 1 Jan 15, 2022