Public Implementation of ChIRo from "Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations"

Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations


This directory contains the model architectures and experimental setups used for ChIRo, SchNet, DimeNet++, and SphereNet on the four tasks considered in the preprint:

Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations

These four tasks are:

  1. Contrastive learning to cluster conformers of different stereoisomers in a learned latent space
  2. Classification of chiral centers as R/S
  3. Classification of the sign (+/-; l/d) of optical rotation (the rotation of plane-polarized light)
  4. Ranking enantiomers by their docking scores in an enantiosensitive protein pocket
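
Task (4) compares the two enantiomers of each molecule. As a minimal sketch, assuming each test example is a pair of enantiomers with one docking score and one model prediction per enantiomer, a pairwise ranking accuracy could be computed as below. The function name `pairwise_ranking_accuracy` and the tuple layout are hypothetical, and the notebooks in experiment_analysis/ may score this task differently. Because only orders are compared, the metric works for either sign convention, as long as predictions and docking scores use the same one.

```python
from typing import Sequence, Tuple

# Hypothetical layout: one tuple per enantiomer pair,
# (true_score_a, true_score_b, predicted_a, predicted_b).
Pair = Tuple[float, float, float, float]


def pairwise_ranking_accuracy(pairs: Sequence[Pair]) -> float:
    """Fraction of enantiomer pairs whose predicted order matches the docking-score order.

    Pairs with tied true docking scores carry no ranking signal and are skipped.
    Returns NaN if no pair can be counted.
    """
    correct = 0
    counted = 0
    for true_a, true_b, pred_a, pred_b in pairs:
        true_diff = true_a - true_b
        if true_diff == 0:
            continue  # neither enantiomer is preferred
        counted += 1
        # The predicted order is correct when both differences have the same sign.
        if true_diff * (pred_a - pred_b) > 0:
            correct += 1
    return correct / counted if counted else float("nan")


if __name__ == "__main__":
    example = [
        (-7.2, -6.8, -0.9, -0.4),  # predicted order matches
        (-5.1, -5.9, -0.2, -0.6),  # predicted order matches
        (-6.0, -6.5, -0.8, -0.3),  # predicted order is reversed
    ]
    print(pairwise_ranking_accuracy(example))  # 0.6666666666666666
```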

The exact data splits used for tasks (1), (2), and (4) can be downloaded from:

See the appendix of "Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations" for details on how the datasets for task (3) were extracted and filtered from the commercial Reaxys database.

This directory is organized as follows:

  • Subdirectory model/ contains the implementation of ChIRo.

    • model/ contains the network architecture of ChIRo

    • model/ contains the featurization of the input conformers (RDKit mol objects) for ChIRo.

    • model/ contains the Pytorch / Pytorch Geometric data samplers used for sampling conformers in each training batch.

    • model/ and model/ contain supporting training/inference loops for each experiment with ChIRo.

    • model/ contains the loss functions used in the experiments with ChIRo.

    • Subdirectory model/gnn_3D/ contains the implementations of SchNet, DimeNet++, and SphereNet used for each experiment.

      • model/gnn_3D/ contains the publicly available code for SchNet, with adaptations for readout.
      • model/gnn_3D/ contains the publicly available code for DimeNet++, with adaptations for readout.
      • model/gnn_3D/ contains the publicly available code for SphereNet, with adaptations for readout.
      • model/gnn_3D/ and model/gnn_3D/ contain the training/inference loops for each experiment with SchNet, DimeNet++, or SphereNet.
      • model/gnn_3D/ contains the loss functions used in the experiments with SchNet, DimeNet++, or SphereNet.

  • Subdirectory params_files/ contains the hyperparameters that define the exact network initializations for ChIRo, SchNet, DimeNet++, and SphereNet in each experiment. The parameter .json files are set to a random seed of 1 and, for the l/d classification task, to the first fold of cross-validation. For the experiments reported in the paper, we use random seeds 1, 2, and 3 to repeat each experiment across three training/test trials.

  • Subdirectory training_scripts/ contains the Python scripts to run each of the four experiments for each of the four 3D models: ChIRo, SchNet, DimeNet++, and SphereNet. Before running an experiment, move the corresponding training script to the parent directory.

  • Subdirectory hyperopt/ contains hyperparameter optimization scripts for ChIRo using Ray Tune.

  • Subdirectory experiment_analysis/ contains Jupyter notebooks for analyzing the results of each experiment.

  • Subdirectory paper_results/ contains the parameter files, model parameter dictionaries, and loss curves for each experiment reported in the paper.
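
Since the params files fix the random seed at 1, repeating an experiment with seeds 1, 2, and 3 means running it once per seed. A minimal sketch of generating one params file per seed is shown below. The seed key name `random_seed`, the helper `write_seed_variants`, and the `_seed{N}` file suffix are assumptions for illustration; check the key actually used in the .json files under params_files/ before relying on it.

```python
import json
from pathlib import Path
from typing import Iterable, List, Union

# Assumption: the name of the seed entry in params_files/*.json is not given here;
# replace "random_seed" with the key the files actually use.
SEED_KEY = "random_seed"


def write_seed_variants(params_path: Union[str, Path], seeds: Iterable[int] = (1, 2, 3)) -> List[Path]:
    """Write one copy of a params .json file per seed, next to the original, and return their paths."""
    params_path = Path(params_path)
    params = json.loads(params_path.read_text())
    if SEED_KEY not in params:
        raise KeyError(f"{params_path} has no '{SEED_KEY}' entry; adjust SEED_KEY")
    written = []
    for seed in seeds:
        variant = {**params, SEED_KEY: seed}  # shallow copy in which only the seed changes
        out_path = params_path.with_name(f"{params_path.stem}_seed{seed}.json")
        out_path.write_text(json.dumps(variant, indent=4))
        written.append(out_path)
    return written
```

For example, `write_seed_variants("params_files/params_binary_ranking_ChIRo.json")` would leave the original file untouched and write three seed-specific copies, each of which can be passed to the training script with its own results directory.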

To run each experiment, first create a conda environment with the following dependencies:

  • python = 3.8.6
  • pytorch = 1.7.0
  • torchaudio = 0.7.0
  • torchvision = 0.8.1
  • torch-geometric = 1.6.3
  • torch-cluster = 1.5.8
  • torch-scatter = 2.0.5
  • torch-sparse = 0.6.8
  • torch-spline-conv = 1.2.1
  • numpy = 1.19.2
  • pandas = 1.1.3
  • rdkit = 2020.09.4
  • scikit-learn = 0.23.2
  • matplotlib = 3.3.3
  • scipy = 1.5.2
  • sympy = 1.8
  • tqdm = 4.58.0
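
Because these versions are pinned exactly, it can help to verify the environment before training. The sketch below compares installed pip distributions against the list above using the standard-library `importlib.metadata` (available from Python 3.8). The helper names are hypothetical. Python and rdkit are left out: the interpreter version can be read from `sys.version`, and rdkit installed from conda may not expose pip metadata. For the same reason, treat a "not found" entry as a prompt to check manually rather than proof that the package is absent.

```python
from importlib import metadata  # standard library from Python 3.8 onward
from typing import Dict, Iterable, Optional

# Versions pinned in the list above, keyed by pip distribution name ("torch" is pytorch).
PINNED: Dict[str, str] = {
    "torch": "1.7.0",
    "torchaudio": "0.7.0",
    "torchvision": "0.8.1",
    "torch-geometric": "1.6.3",
    "torch-cluster": "1.5.8",
    "torch-scatter": "2.0.5",
    "torch-sparse": "0.6.8",
    "torch-spline-conv": "1.2.1",
    "numpy": "1.19.2",
    "pandas": "1.1.3",
    "scikit-learn": "0.23.2",
    "matplotlib": "3.3.3",
    "scipy": "1.5.2",
    "sympy": "1.8",
    "tqdm": "4.58.0",
}


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Look up each distribution's installed version, or None if no metadata is found."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def find_mismatches(pinned: Dict[str, str], installed: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Return {name: problem} for every pinned package that is missing or at another version."""
    problems: Dict[str, str] = {}
    for name, wanted in pinned.items():
        have = installed.get(name)
        if have is None:
            problems[name] = f"not found (want {wanted})"
        elif have.split("+")[0] != wanted:  # ignore local build tags such as "+cu110"
            problems[name] = f"found {have}, want {wanted}"
    return problems


if __name__ == "__main__":
    for name, problem in find_mismatches(PINNED, installed_versions(PINNED)).items():
        print(f"{name}: {problem}")
```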

Then, download the datasets (with exact training/validation/test splits) from and place them in a new directory final_data_splits/

You may then run each experiment by calling:

python training_{experiment}_{model}.py params_files/params_{experiment}_{model}.json {path_to_results_directory}/

For instance, you can run the docking experiment for ChIRo with a random seed of 1 (editable in the params .json file) by calling:

python training_binary_ranking_ChIRo.py params_files/params_binary_ranking_ChIRo.json results_binary_ranking_ChIRo/

After training, this will create a results directory containing model checkpoints, best model parameter dictionaries, and results on the test set (if applicable).
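
The two steps above (moving the training script out of training_scripts/ and calling it with the command template) can be wrapped in a small launcher. This is a sketch, not part of the repository: the helper names are hypothetical, it assumes it is run from the repository root (the directory containing params_files/ and training_scripts/), and the only experiment/model pair confirmed by this README is `binary_ranking` / `ChIRo`. It uses `sys.executable` so the script runs under the interpreter of the active conda environment.

```python
import shutil
import subprocess
import sys
from pathlib import Path
from typing import List, Union


def build_command(experiment: str, model: str, results_dir: str) -> List[str]:
    """Fill in: python training_{experiment}_{model}.py params_files/params_{experiment}_{model}.json {results}/"""
    return [
        sys.executable,  # the interpreter of the active (conda) environment
        f"training_{experiment}_{model}.py",
        f"params_files/params_{experiment}_{model}.json",
        results_dir.rstrip("/") + "/",  # the results directory is passed with a trailing slash
    ]


def stage_script(experiment: str, model: str, repo_root: Union[str, Path] = ".") -> Path:
    """Move training_scripts/training_{experiment}_{model}.py up to repo_root unless it is already there."""
    root = Path(repo_root)
    name = f"training_{experiment}_{model}.py"
    target = root / name
    if not target.exists():
        shutil.move(str(root / "training_scripts" / name), str(target))
    return target


def run_experiment(experiment: str, model: str, results_dir: str) -> None:
    """Stage the script, then run it from the current directory (assumed to be the repository root)."""
    stage_script(experiment, model)
    subprocess.run(build_command(experiment, model, results_dir), check=True)
```

With these helpers, `run_experiment("binary_ranking", "ChIRo", "results_binary_ranking_ChIRo")` reproduces the docking example command.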

  • question about `get_local_structure_map`

    Hello Keir,

    Nice work on ChIRo! We've been exploring its application to protein-ligand affinity prediction at my startup.

    I have gotten the model to train and it seems to be working, but I have a question about the way get_local_structure_map is defined. I noticed that it's meant to work on batches of psi_indices. Is there any specific reason why this was done? Would it be possible to define a similar function that just generates the input features LS_map and alpha_indices for one molecule at a time? This is more from an engineering perspective.


    opened by linminhtoo
  • spherenet xyz_to_dat repeat_interleave issue

    Hi, first of all many thanks for open-sourcing this!

    I was checking out your implementation of the SphereNet architecture, and I noticed that I cannot successfully perform a forward pass of the network. Namely, the problem appears to lie in the xyz_to_dat function defined within that submodule.

    I start with 3D Cartesian coordinates and compute edge indices using torch_geometric.nn.pool.radius_graph:

    In [13]: g.coords
    tensor([[ 4.5877,  1.2124,  0.9045],
            [ 3.4939,  2.0393,  0.7065],
            [ 2.5577,  1.7294, -0.2622],
            [ 2.6985,  0.5842, -1.0372],
            [ 1.6548,  0.2541, -2.0763],
            [ 0.3290,  0.2897, -1.5940],
    In [14]: g.coords.shape
    Out[14]: torch.Size([41, 3])
    In [16]: edge_index = radius_graph(g.coords, r=5.0, batch=torch.zeros(g.atomids.size(0), dtype=int))
    In [17]: edge_index
    tensor([[20,  1, 19,  ...,  2,  4,  3],
            [ 0,  0,  0,  ..., 40, 40, 40]])
    In [18]: edge_index.shape
    Out[18]: torch.Size([2, 937])

    I then try to call the xyz_to_dat routine used in the forward pass, and I encounter this error:

    In [20]: out = xyz_to_dat(pos=g.coords, edge_index=edge_index, num_nodes=g.atomids.size(0), use_torsion=True)
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-20-e5d12c7af566> in <module>
    ----> 1 out = xyz_to_dat(pos=g.coords, edge_index=edge_index, num_nodes=g.atomids.size(0), use_torsion=True)

    in xyz_to_dat(pos, edge_index, num_nodes, use_torsion)
        114     repeat = num_triplets - 1
        115     num_triplets_t = num_triplets.repeat_interleave(repeat)
    --> 116     idx_i_t = idx_i.repeat_interleave(num_triplets_t)
        117     idx_j_t = idx_j.repeat_interleave(num_triplets_t)
        118     idx_k_t = idx_k.repeat_interleave(num_triplets_t)
    RuntimeError: repeats must have the same size as input along dim

    I'm wondering whether the problem is on my side and I'm misunderstanding how these functions should be called. Any help is appreciated!

    opened by josejimenezluna
  • faster positive/negative sample map generation

    Thanks for open-sourcing this cool project.

    I was playing around and found Sample_Map_To_Positives and Sample_Map_To_Negatives are pretty slow, so I did some optimization. Now it feels significantly faster because there is no need to loop over subsets of the frame.

    Does my change look right to you?

    opened by wwang2
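
The last issue proposes removing the per-row loops in Sample_Map_To_Positives and Sample_Map_To_Negatives. As a minimal sketch of that general idea (group every row once, then look up each row's groups instead of filtering the whole table per row), assume each row is one conformer tagged with a stereoisomer ID and a stereo-agnostic graph key; that positives for a row are the other conformers of the same stereoisomer; and that negatives are conformers of other stereoisomers sharing the same graph. These definitions, the helper `build_sample_maps`, and any column names you feed it are assumptions, not the repository's actual implementation.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

IndexMap = Dict[int, List[int]]


def build_sample_maps(stereo_ids: Sequence[str], graph_keys: Sequence[str]) -> Tuple[IndexMap, IndexMap]:
    """Return (positives, negatives), each mapping a row index to a list of row indices.

    Assumed definitions (hypothetical, for illustration):
      positives: other rows with the same stereoisomer ID (other conformers),
      negatives: rows with the same graph key but a different stereoisomer ID.
    """
    if len(stereo_ids) != len(graph_keys):
        raise ValueError("stereo_ids and graph_keys must have the same length")

    # One pass groups every row by stereoisomer and by graph.
    by_stereo: Dict[str, List[int]] = defaultdict(list)
    by_graph: Dict[str, List[int]] = defaultdict(list)
    for row, (sid, gkey) in enumerate(zip(stereo_ids, graph_keys)):
        by_stereo[sid].append(row)
        by_graph[gkey].append(row)

    # Each row then scans only its own (small) groups, not the whole table.
    positives: IndexMap = {}
    negatives: IndexMap = {}
    for row, (sid, gkey) in enumerate(zip(stereo_ids, graph_keys)):
        positives[row] = [r for r in by_stereo[sid] if r != row]
        negatives[row] = [r for r in by_graph[gkey] if stereo_ids[r] != sid]
    return positives, negatives


if __name__ == "__main__":
    pos, neg = build_sample_maps(["A", "A", "B", "B", "C"], ["g1", "g1", "g1", "g1", "g2"])
    print(pos)  # {0: [1], 1: [0], 2: [3], 3: [2], 4: []}
    print(neg)  # {0: [2, 3], 1: [2, 3], 2: [0, 1], 3: [0, 1], 4: []}
```

With a pandas DataFrame, the two sequences can be passed as column lists, e.g. `build_sample_maps(df[id_col].tolist(), df[graph_col].tolist())`, where `id_col` and `graph_col` are whatever columns hold the stereoisomer and graph identifiers.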
