PyTorch implementation of NIPS 2017 paper Dynamic Routing Between Capsules

Overview

PyTorch implementation of the NIPS 2017 paper Dynamic Routing Between Capsules by Sara Sabour, Nicholas Frosst and Geoffrey E. Hinton.

The hyperparameters and data augmentation strategy strictly follow the paper.

Requirements

Only PyTorch with torchvision is required (tested on PyTorch 0.2.0 and 0.3.0). Jupyter and matplotlib are required to run the notebook with visualizations.

Usage

Train the model by running:

python net.py

Optional arguments and default values:

  --batch-size N          input batch size for training (default: 128)
  --test-batch-size N     input batch size for testing (default: 1000)
  --epochs N              number of epochs to train (default: 250)
  --lr LR                 learning rate (default: 0.001)
  --no-cuda               disables CUDA training
  --seed S                random seed (default: 1)
  --log-interval N        how many batches to wait before logging training
                          status (default: 10)
  --routing_iterations    number of iterations for routing algorithm (default: 3)
  --with_reconstruction   use the reconstruction layers and reconstruction loss

The MNIST dataset will be downloaded automatically.
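
For example, to train with the reconstruction network and the paper's 3 routing iterations (the configuration used for the results below):

    python net.py --with_reconstruction --routing_iterations 3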

Results

The network trained with reconstruction and 3 routing iterations on the MNIST dataset achieves 99.65% accuracy on the test set. The test loss was still slightly decreasing at the end of training, so the accuracy could probably be improved with more training and a more careful learning rate schedule.

Visualizations

We can create visualizations of digit reconstructions from DigitCaps (see Figure 3 in the paper):

Reconstructions

We can also visualize what each dimension of a digit capsule represents (Section 5.1, Figure 4 in the paper).

Below, each row shows the reconstruction when one of the 16 dimensions in the DigitCaps representation is tweaked by intervals of 0.05 in the range [−0.25, 0.25].

Perturbations

We can see what individual dimensions represent for digit 7, e.g. dim 6 - stroke thickness, dim 11 - digit width, dim 15 - vertical shift.
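
A minimal sketch of how such a perturbation grid can be produced. Here digit_caps stands for the (1, 10, 16) DigitCaps output for one image, and decode for a reconstruction call in the spirit of net.py's ReconstructionNet; both names are hypothetical and for illustration only.

    import torch

    def perturbation_row(digit_caps, decode, target, dim):
        # Tweak one of the 16 dimensions of the target digit's capsule by
        # -0.25, -0.20, ..., 0.25 and decode each perturbed vector.
        images = []
        for delta in torch.arange(-0.25, 0.26, 0.05):
            perturbed = digit_caps.clone()
            perturbed[0, target, dim] += delta
            images.append(decode(perturbed, torch.tensor([target])))
        return images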

Visualization examples are provided in a Jupyter notebook (reconstruction_visualization.ipynb).

Comments
  • Error using with_reconstruction flag

    When using the reconstruction flag, I get the following error:

        net.py:28: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
          c = F.softmax(self.b)
        net.py:38: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
          c = F.softmax(b_batch.view(-1, output_caps)).view(-1, input_caps, output_caps, 1)
        Traceback (most recent call last):
          File "net.py", line 274, in <module>
            train(epoch)
          File "net.py", line 230, in train
            output, probs = model(data, target)
          File "/home/khan/anaconda3/envs/kt/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
            result = self.forward(*input, **kwargs)
          File "net.py", line 137, in forward
            reconstruction = self.reconstruction_net(x, target)
          File "/home/khan/anaconda3/envs/kt/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
            result = self.forward(*input, **kwargs)
          File "net.py", line 119, in forward
            mask.scatter_(1, target.data.view(-1, 1), 1.)
        RuntimeError: scatter_() received an invalid combination of arguments - got (int, torch.cuda.LongTensor, float), but expected one of:
          * (int dim, Variable index, Variable src) didn't match because some of the arguments have invalid types: (int, torch.cuda.LongTensor, float)
          * (int dim, Variable index, float value) didn't match because some of the arguments have invalid types: (int, torch.cuda.LongTensor, float)

    (A possible fix for the scatter_ call is sketched after this list.)
    opened by w4zir 4
  • RuntimeError: cuda runtime error (2) : out of memory

    I have an NVIDIA GTX 1060 with 6GB of GPU memory and 16GB of CPU RAM. I get the following error after the 1st epoch:

        THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1512387374934/work/torch/lib/THC/generic/THCStorage.cu line=58 error=2 : out of memory
        Traceback (most recent call last):
          File "net.py", line 275, in <module>
            test_loss = test()
          File "net.py", line 260, in test
            output, probs = model(data)
          File "/home/khan/anaconda3/envs/kt/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
            result = self.forward(*input, **kwargs)
          File "net.py", line 101, in forward
            x = self.digitCaps(x)
          File "/home/khan/anaconda3/envs/kt/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
            result = self.forward(*input, **kwargs)
          File "net.py", line 62, in forward
            u_predict = caps_output.matmul(self.weights)
          File "/home/khan/anaconda3/envs/kt/lib/python3.6/site-packages/torch/autograd/variable.py", line 386, in matmul
            return torch.matmul(self, other)
          File "/home/khan/anaconda3/envs/kt/lib/python3.6/site-packages/torch/functional.py", line 219, in matmul
            tensor2_expanded = tensor2.expand(*(expand_batch_portion + tensor2_exp_size[-2:]))
          File "/home/khan/anaconda3/envs/kt/lib/python3.6/site-packages/torch/autograd/variable.py", line 280, in contiguous
            self.data = self.data.contiguous()
        RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1512387374934/work/torch/lib/THC/generic/THCStorage.cu:58

    (A sketch of a memory-friendly test loop appears after this list.)

    opened by w4zir 2
  • logits bij not initialized to 0

    Hi, for every routing call the logits b_ij need to be reset to 0, as per the paper. In the code they are initialized only once, at the beginning. Or am I missing something here? (A sketch of routing with per-call zeroed logits appears after this list.)

    opened by vijaykumar01 1
  • Can't load model

    After training, I encountered the following problem during the reconstruction_visualization process:

    KeyError: 'unexpected key "conv1.weight" in state_dict'

    opened by ryujaehun 1
  • There is a mistake when I run reconstruction_visualization.ipynb

    When I run reconstruction_visualization.ipynb in Jupyter Notebook, cell In [4] fails at reconstruction_net = ReconstructionNet(16, 10) with:

        FileNotFoundError: [Errno 2] No such file or directory: '229_model_dict_3routing_reconstructionTrue.pth'

    How should I correct this?

    opened by Runkun-Lu 0
  • Bug fix and change of defaults

    1. Fix a mismatch between the default and actual value of the lr param.
    2. Some changes to default values for lower memory use.
    3. Compatibility change from .data[0] to .item() for newer PyTorch versions.
    opened by comsaint 0
  • why use target/label to reconstruct image when test?

    At inference time, why is the label used to select the capsule fed to the reconstruction module, rather than the model's prediction? If the prediction is wrong, won't the reconstruction module reconstruct the wrong class of image?

    opened by emobuAgent 0
  • Deprecated

    The run produces these deprecation warnings and then crashes:

        /home/raul/Desktop/CapsNet-pytorch/net.py:28: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
          c = F.softmax(self.b)
        /home/raul/Desktop/CapsNet-pytorch/net.py:38: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
          c = F.softmax(b_batch.view(-1, output_caps)).view(-1, input_caps, output_caps, 1)
        Traceback (most recent call last):
          File "/home/raul/Desktop/CapsNet-pytorch/net.py", line 275, in <module>
            train(epoch)
          File "/home/raul/Desktop/CapsNet-pytorch/net.py", line 244, in train
            100. * batch_idx / len(train_loader), loss.data[0]))
        IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

    (The PyTorch >= 0.4 changes are sketched after this list.)

    opened by raul1968 2
  • squash

    I have a question: according to the paper, should the squash function only be applied after the weighted sum of the predictions u-hat? In this code there is also a squash after the primary capsule layer, which really confuses me:

        class PrimaryCapsLayer(nn.Module):
            def __init__(self, input_channels, output_caps, output_dim, kernel_size, stride):
                super(PrimaryCapsLayer, self).__init__()
                # input_channels = 256, output_caps = 32, output_dim = 8, kernel_size = 9, stride = 2
                self.conv = nn.Conv2d(input_channels, output_caps * output_dim,
                                      kernel_size=kernel_size, stride=stride)
                self.input_channels = input_channels
                self.output_caps = output_caps
                self.output_dim = output_dim

            def forward(self, input):
                out = self.conv(input)
                N, C, H, W = out.size()
                out = out.view(N, self.output_caps, self.output_dim, H, W)
                # will output N x OUT_CAPS x OUT_DIM
                out = out.permute(0, 1, 3, 4, 2).contiguous()
                out = out.view(out.size(0), -1, out.size(4))
                out = squash(out)  # QUESTION: why squash here?
                return out

    opened by yimzhai3 0
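
For the scatter_() error above: on PyTorch 0.3, mask is a Variable while target.data is a raw tensor, so the argument types do not match. Below is a minimal sketch of the masking step written for current PyTorch; the helper name one_hot_mask is hypothetical and the shapes are assumed to follow net.py.

    import torch

    def one_hot_mask(target, num_classes=10):
        # target: LongTensor of shape (batch,) holding class indices
        mask = torch.zeros(target.size(0), num_classes, device=target.device)
        # pass the index tensor itself, not target.data
        mask.scatter_(1, target.view(-1, 1), 1.0)
        return mask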
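
For the out-of-memory crash: the traceback shows it happens in test(), where autograd state is kept but never used. A hedged sketch of a test loop that avoids this on current PyTorch (on 0.2/0.3 the equivalent was wrapping inputs in Variable(..., volatile=True)); model and test_loader are assumed to be set up as in net.py, and probs is assumed to be (batch, 10) class scores:

    import torch

    def test(model, test_loader, device):
        model.eval()
        correct = 0
        with torch.no_grad():  # no graph is built, so activations are freed eagerly
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output, probs = model(data)
                correct += (probs.max(dim=1)[1] == target).sum().item()
        return correct / len(test_loader.dataset)

Alternatively, simply lowering --test-batch-size from its default of 1000 reduces peak memory.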
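
On the b_ij initialization question: in the paper's routing procedure the logits start at zero on every forward pass. A minimal sketch of the routing loop with the logits created fresh per call; this function is an illustration under assumed shapes, not the repository's code:

    import torch
    import torch.nn.functional as F

    def squash(s, dim=-1):
        # squash(s) = (|s|^2 / (1 + |s|^2)) * s / |s|, as in the paper
        norm_sq = (s ** 2).sum(dim=dim, keepdim=True)
        return (norm_sq / (1.0 + norm_sq)) * s / (norm_sq.sqrt() + 1e-8)

    def routing(u_predict, iterations=3):
        # u_predict: (batch, input_caps, output_caps, output_dim)
        batch, input_caps, output_caps, _ = u_predict.size()
        # logits are re-created as zeros on every call, matching the paper
        b = torch.zeros(batch, input_caps, output_caps, device=u_predict.device)
        v = None
        for _ in range(iterations):
            c = F.softmax(b, dim=2).unsqueeze(-1)             # coupling coefficients
            s = (c * u_predict).sum(dim=1)                    # weighted sum over input capsules
            v = squash(s)                                     # (batch, output_caps, output_dim)
            b = b + (u_predict * v.unsqueeze(1)).sum(dim=-1)  # agreement update
        return v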
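
On the deprecation crash: two small changes bring the code to PyTorch >= 0.4. Give every F.softmax call an explicit dim, and read scalar losses with .item() instead of .data[0]. A runnable toy illustration (the tensors are stand-ins, not the repository's):

    import torch
    import torch.nn.functional as F

    b = torch.zeros(4, 10)   # stand-in for the routing logits
    c = F.softmax(b, dim=1)  # explicit dim silences the UserWarning

    loss = (c ** 2).sum()
    print(loss.item())       # .item() replaces loss.data[0]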