Overview

CapsNet-Tensorflow

A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

(Figure: capsule vs. traditional neuron comparison)

Notes:

  1. The current version supports the MNIST and Fashion-MNIST datasets. Test accuracy is currently 99.64% on MNIST and 90.60% on Fashion-MNIST; see details in the Results section
  2. See dist_version for multi-GPU support
  3. Here (on Zhihu) is an article explaining my understanding of the paper; it may be helpful for understanding the code.

Important:

If you need to apply the CapsNet model to your own datasets, or to build a new model from CapsNet's basic blocks, please follow my new project CapsLayer, an advanced library for capsule theory. It aims to integrate capsule-relevant technologies, provide relevant analysis tools, develop related application examples, and promote the development of capsule theory. For example, you can easily use capsule layer blocks in your code with the APIs capsLayer.layers.fully_connected and capsLayer.layers.conv2d.

Requirements

  • Python
  • NumPy
  • Tensorflow>=1.3
  • tqdm (for displaying training progress info)
  • scipy (for saving images)

Usage

Step 1. Download this repository with git or click the download ZIP button.

$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow

Step 2. Download the MNIST or Fashion-MNIST dataset. In this step, you have two choices:

  • a) Automatic downloading with the download_data.py script
$ python download_data.py   # for the MNIST dataset
$ python download_data.py --dataset fashion-mnist --save_to data/fashion-mnist   # for the Fashion-MNIST dataset
  • b) Manual downloading with wget or another tool; move and extract the dataset into the data/mnist or data/fashion-mnist directory, for example:
$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
$ gunzip data/mnist/*.gz

Step 3. Start training (the MNIST dataset is used by default):

$ python main.py
$ # or train on the Fashion-MNIST dataset
$ python main.py --dataset fashion-mnist
$ # To monitor the training process, open TensorBoard with this command
$ tensorboard --logdir=logdir
$ # or use the `tail` command on a Linux system
$ tail -f results/val_acc.csv

Step 4. Calculate test accuracy

$ python main.py --is_training=False
$ # or for the Fashion-MNIST dataset
$ python main.py --dataset fashion-mnist --is_training=False

Note: The default batch size is 128 and the default number of epochs is 50. You may need to modify config.py or pass command-line parameters to suit your case, e.g. to set the batch size to 48 and run a test summary every 200 steps: python main.py --test_sum_freq=200 --batch_size=48

Results

The plots below were produced with TensorBoard and my plotting script plot_acc.R.

  • training loss

(Plots: total_loss, margin_loss, reconstruction_loss)
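
For reference, these three curves follow the loss defined in the paper: total_loss = margin_loss + 0.0005 * reconstruction_loss. Below is a minimal sketch of those definitions (variable names are illustrative, not the repository's exact code):

    import tensorflow as tf

    def capsnet_loss(v_lengths, labels_onehot, images, reconstructions,
                     m_plus=0.9, m_minus=0.1, lam=0.5, recon_scale=0.0005):
        # Margin loss (Eq. 4 of the paper): the capsule for a present class
        # should be longer than m_plus, capsules for absent classes shorter
        # than m_minus; lam down-weights the absent-class term.
        L_k = (labels_onehot * tf.square(tf.maximum(0., m_plus - v_lengths)) +
               lam * (1. - labels_onehot) * tf.square(tf.maximum(0., v_lengths - m_minus)))
        margin_loss = tf.reduce_mean(tf.reduce_sum(L_k, axis=1))

        # Reconstruction loss: summed squared error between the input pixels
        # and the decoder output, scaled by 0.0005 so it does not dominate.
        flat = tf.reshape(images, [tf.shape(images)[0], -1])
        recon_loss = tf.reduce_mean(tf.reduce_sum(tf.square(reconstructions - flat), axis=1))
        return margin_loss + recon_scale * recon_loss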

Here are the models I trained, my talk slides, and related material:

Baidu Netdisk (password: ahjs)

  • The best val error (using reconstruction)

| Routing iterations | 1    | 3    | 4    |
|--------------------|------|------|------|
| val error (%)      | 0.36 | 0.36 | 0.41 |
| Paper (%)          | 0.29 | 0.25 | -    |

(Plot: test accuracy)

My brief comments on capsules

  1. A new kind of neural unit: vector in, vector out, rather than scalar in, scalar out (see the squash sketch below)
  2. The routing algorithm is similar to an attention mechanism
  3. In any case, work with great potential and much to build on
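
To make point 1 concrete, here is a minimal sketch of the squashing nonlinearity (Eq. 1 of the paper), which turns a capsule's total input vector into an output vector whose length lies in [0, 1); the [..., vec_len, 1] tensor layout is this repository's convention:

    import tensorflow as tf

    def squash(s, axis=-2, epsilon=1e-9):
        # v = (|s|^2 / (1 + |s|^2)) * (s / |s|): short vectors shrink toward
        # zero length, long vectors approach unit length. epsilon guards the
        # division when |s| is zero.
        squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
        scale = squared_norm / (1. + squared_norm)
        return scale * s / tf.sqrt(squared_norm + epsilon)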

My WeChat:

(Image: WeChat QR code)

Reference

Sara Sabour, Nicholas Frosst, Geoffrey E. Hinton. "Dynamic Routing Between Capsules." NIPS 2017. https://arxiv.org/abs/1710.09829

Comments
  • Why b_IJ is shared between single batch examples.

    Forgive me if I got this wrong but it seems like the b_IJ are shared between all examples within a single batch (see reduce_sum and the shape).

    I didn't see any mention of the batches in the paper, so I have assumed that there is a separate set of b_IJ weights for every batch. Why do you think that it's better to share those variables?

    Edit: I've corrected the statement:

    b_IJ are shared between all batches

    to:

    b_IJ are shared between all examples within a single batch

    which is what I originally meant.

    opened by pkubik 8
  • Dear man

    I think we should make a WeChat group here for people interested in this kind of subject. My WeChat is bn31201. I hope you'll add me so we can have some deep discussions.

    opened by nb312 8
  • Training on different input dimensions than MNIST

    Thanks for writing the code so soon after the article was released. I'm trying to change the structure so that the capsule network can be trained on any image (x, y, z), but I am having trouble restructuring the code. Can you help me identify which lines need to be modified? I am guessing all lines with ... 28, 28, 1) -> ... 32, 32, 3) for CIFAR-10, but I am still not able to make it work.

    Thank you again πŸ‘

    opened by servetcoskun 5
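
    The main shape changes for CIFAR-10 can be sketched as follows (the feature-map arithmetic follows the paper's two conv layers; this repository's variable names may differ):

        import tensorflow as tf

        # MNIST input:    [batch, 28, 28, 1]
        # CIFAR-10 input: [batch, 32, 32, 3]
        X = tf.placeholder(tf.float32, [128, 32, 32, 3])

        # With the paper's layers (a 9x9 VALID conv, then a 9x9 stride-2 VALID
        # conv), 32 -> 24 -> 8, so the primary-capsule grid becomes 8x8 instead
        # of 6x6: 8 * 8 * 32 = 2048 primary capsules rather than 1152. Every
        # hard-coded 1152, and the 784-pixel reconstruction target, must
        # change accordingly.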
  • Routing algorithm

    To the owner and all other visitors:

    I do not mean to be offensive, but I have decided to speak out about my understanding of this routing algorithm, as I have not seen any correct implementation so far.

    A correct implementation of the routing algorithm should be treated like the dynamic RNN in TensorFlow. In other words, if you implement it statically and do 3 iterations, the two capsule layers are actually 6 such layers: the primary layer performs line 4 and outputs to the digits layer, the digits layer performs lines 5, 6, and 7 with b_ij updated, and then control loops back to the primary layer again. A dynamic implementation would need tf.while_loop.

    What confuses me, or stops me from implementing it myself, is that I am not sure how the weights and biases associated with the conv units are updated. I assume that, besides the weights and biases associated with the capsules, each individual conv unit inside still carries its own parameters. Maybe I missed this in the paper.

    Feel free to correct me if you believe I am wrong. Thanks.

    opened by bshao001 5
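
    As the comment suggests, the routing loop can be written with tf.while_loop instead of static unrolling. Below is a minimal sketch under this repository's tensor layout (shapes match the MNIST setup; an illustration, not the repository's implementation), reusing the squash function sketched earlier:

        import tensorflow as tf

        def routing(u_hat, num_iters=3):
            # u_hat: prediction vectors, [batch, 1152, 10, 16, 1].
            # Routing logits b_ij start at zero: [batch, 1152, 10, 1, 1].
            b_ij = tf.zeros_like(u_hat[:, :, :, :1, :])
            num_in_caps = u_hat.get_shape().as_list()[1]

            def body(i, b_ij):
                c_ij = tf.nn.softmax(b_ij, dim=2)   # couplings over output capsules
                s_j = tf.reduce_sum(c_ij * u_hat, axis=1, keep_dims=True)
                v_j = squash(s_j)                   # squash as sketched earlier
                # Agreement <u_hat, v_j> updates the logits (lines 5-7 of Procedure 1).
                v_tiled = tf.tile(v_j, [1, num_in_caps, 1, 1, 1])
                return i + 1, b_ij + tf.matmul(u_hat, v_tiled, transpose_a=True)

            _, b_ij = tf.while_loop(lambda i, b: i < num_iters, body,
                                    [tf.constant(0), b_ij])
            # One final forward pass with the converged coupling coefficients.
            c_ij = tf.nn.softmax(b_ij, dim=2)
            s_j = tf.reduce_sum(c_ij * u_hat, axis=1, keep_dims=True)
            return squash(s_j)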
  • Cannot evaluate the model when using python main.py --is_training=False

    Evaluating/testing the trained model using python main.py --is_training=False gives the following error: ValueError: Can't load save_path when it is None.

    opened by kumarlamichhane 4
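
    This error usually means the saver was handed a None checkpoint path, typically because tf.train.latest_checkpoint found no checkpoint in the log directory. A minimal sketch of the usual guard (the directory name is an assumption; the repository reads its paths from config.py):

        import tensorflow as tf

        def restore_latest(sess, saver, ckpt_dir='logdir'):
            # tf.train.latest_checkpoint returns None when ckpt_dir holds no
            # checkpoint, which is what makes saver.restore raise
            # "Can't load save_path when it is None".
            ckpt = tf.train.latest_checkpoint(ckpt_dir)
            if ckpt is None:
                raise ValueError('No checkpoint found in %r; train first.' % ckpt_dir)
            saver.restore(sess, ckpt)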
  • [Question] Could the CapsNet unit apply to other more complex architecture ?

    Hi!

    I'm a student interested in speech synthesis with neural networks. I suppose CapsNet might improve the quality of synthesized speech, so I am trying to apply this great program to another program that generates artificial speech with a neural network.

    I would like to ask whether CapsNet could replace other popular neural networks such as CNNs.

    Thank you for answering.

    opened by rild 3
  • Relu activation in PrimaryCap?

    tf.contrib.layers.conv2d applies a ReLU activation by default, but the PrimaryCaps convolution should not include a ReLU activation before the neurons are grouped into capsules and squashed, or did I miss something in the paper?

    https://github.com/naturomics/CapsNet-Tensorflow/blob/894c79cd8434d7de7784006c1646ff3107bb2a4f/capsLayer.py#L59

    opened by oargueta3 3
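
    For context, tf.contrib.layers.conv2d does apply a ReLU by default (its activation_fn parameter defaults to tf.nn.relu), so if PrimaryCaps should apply only the squash nonlinearity, the ReLU has to be disabled explicitly. A sketch with the paper's PrimaryCaps hyperparameters:

        import tensorflow as tf

        inputs = tf.placeholder(tf.float32, [128, 20, 20, 256])  # Conv1 output
        # 32 capsule types x 8 dims = 256 channels, 9x9 kernel, stride 2, as in
        # the paper; activation_fn=None removes the default ReLU so that only
        # the squash function is applied after grouping into capsules.
        primary = tf.contrib.layers.conv2d(inputs, num_outputs=256, kernel_size=9,
                                           stride=2, padding='VALID',
                                           activation_fn=None)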
  • Is your squashing input dimensions correct?

    If squashing is done per capsule, then why are the input dimensions to it 32, 1152, 8, 1, where 32 is the batch size? Shouldn't it be 32, 668, 32, 1?

    opened by isaacgerg 3
  • Why num_outputs is mandatory?

    https://github.com/naturomics/CapsNet-Tensorflow/blob/1e0668037447e89aca9173a688ce1965bf6a43c1/capsNet.py#L45

    Why is num_outputs set here when it will not be used?

    opened by thibo73800 3
  • Note to Huadong

    Hi Huadong,

    I've been running successful tests of CapsNets with PyTorch and would like to compare notes with you. Maybe we can take our discussion offline? My email is: firstname.lastname[@]gmail.com

    Let me know!

    Tarry

    opened by TarrySingh 2
  • Why average b_ij across examples?

    https://github.com/naturomics/CapsNet-Tensorflow/blob/master/capsLayer.py#L151

                # then matmul in the last two dims: [16, 1].T x [16, 1] => [1, 1], reduce mean in the
                # batch_size dim, resulting in [1, 1152, 10, 1, 1]
                v_J_tiled = tf.tile(v_J, [1, 1152, 1, 1, 1])
                u_produce_v = tf.matmul(u_hat, v_J_tiled, transpose_a=True)
                assert u_produce_v.get_shape() == [cfg.batch_size, 1152, 10, 1, 1]
                b_IJ += tf.reduce_sum(u_produce_v, axis=0, keep_dims=True)
    

    Why would you need to average b across the batch dimension? I don't see why that would be good, since it makes the model batch-size dependent. If this is mentioned in the paper or another source, could you point out where and send a link? Appreciated.

    opened by JerrikEph 2
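
    For what it's worth, the variant the commenter describes is a small change: give b_IJ a batch axis and drop the reduction over axis 0, so each example keeps its own routing logits. A minimal sketch with the MNIST shapes (no claim here about which variant trains better):

        import tensorflow as tf

        batch_size, num_in, num_out = 128, 1152, 10
        # Stand-in for the agreement term u_produce_v from the snippet above.
        u_produce_v = tf.random_normal([batch_size, num_in, num_out, 1, 1])

        # Per-example routing logits: [batch_size, 1152, 10, 1, 1].
        b_IJ = tf.zeros([batch_size, num_in, num_out, 1, 1])
        b_IJ += u_produce_v   # no reduce_sum over the batch axis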
  • How to save model in SavedModel or frozen_graph format?

    Is there a way to save this model in SavedModel or frozen-graph format? I am trying to do inference with TensorRT, which only supports these two formats.

    opened by Azuresonance 0
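
    A TF 1.x graph can be frozen with tf.graph_util.convert_variables_to_constants; a minimal sketch (output_node_names is model-specific and assumed here; SavedModel export via tf.saved_model is the other option):

        import tensorflow as tf

        def freeze(sess, output_node_names, out_path='frozen_graph.pb'):
            # Replace variables with constants so the graph is self-contained,
            # then serialize it; output_node_names lists the inference outputs.
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph.as_graph_def(), output_node_names)
            with tf.gfile.GFile(out_path, 'wb') as f:
                f.write(graph_def.SerializeToString())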
  • about a problem in gpu_version

    Thanks for your code. I finished running the normal version, but when I used dist_version an issue occurred: the error "AssertionError: assert not np.isnan(loss_value)" was raised at line 132 of distributed_train.py. What is going on? Can you help me?

    opened by Questdream 0
  • a question

    1. What is the use of the line Y = valY[:num_val_batch * cfg.batch_size].reshape((-1, 1)) in main.py?

    2. Can we use the code for an RGB image dataset?

    3. What is the use of the channel parameter in the CapsNet function?

    opened by NilakshanKunananthaseelan 1
Owner

Huadong Liao (Explore Nature from an Omics Perspective)