Code for the Paper: Conditional Variational Capsule Network for Open Set Recognition

Overview

Conditional Variational Capsule Network for Open Set Recognition

This repository hosts the official code related to "Conditional Variational Capsule Network for Open Set Recognition", Y. Guo, G. Camporese, W. Yang, A. Sperduti, L. Ballan, arXiv:2104.09159, 2021. [Download]

If you use the code/models hosted in this repository, please cite the following paper and give a star to the repo:

@misc{guo2021conditional,
      title={Conditional Variational Capsule Network for Open Set Recognition}, 
      author={Yunrui Guo and Guglielmo Camporese and Wenjing Yang and Alessandro Sperduti and Lamberto Ballan},
      year={2021},
      eprint={2104.09159},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Updates

  • [2021/04/09] - The code is online.
  • [2021/07/22] - The paper has been accepted to ICCV 2021!

Install

Once you have cloned the repo, all the commands below should be run inside the main project folder cvaecaposr:

# Clone the repo
$ git clone https://github.com/guglielmocamporese/cvaecaposr.git

# Go to the project directory
$ cd cvaecaposr

To run the code you need to have conda installed (version >= 4.9.2).

Furthermore, all the requirements for running the code are specified in the environment.yaml file and can be installed with:

# Install the conda env
$ conda env create --file environment.yaml

# Activate the conda env
$ conda activate cvaecaposr
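
After activating the environment, you can quickly check that the main dependency resolves (a simple sanity check, not a repo script):

# Sanity check: the import should succeed inside the activated env
$ python -c "import torch; print(torch.__version__)"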

Dataset Splits

You can find the dataset splits for all the datasets we have used (i.e. for MNIST, SVHN, CIFAR10, CIFAR+10, CIFAR+50 and TinyImageNet) in the splits.py file.

When you run the code, the datasets are automatically downloaded into the ./data folder, and the dataset split is selected via the --split_num argument passed to main.py (more on how to run the code in the Experiments section below).
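
To inspect a split programmatically, a minimal sketch follows (the module-level name splits and the (known, unknown) tuple layout are assumptions for illustration; check splits.py for the actual structure):

# Hypothetical sketch: names and structure assumed, see splits.py for the real ones
from splits import splits

known, unknown = splits["mnist"][0]  # index 0 would correspond to --split_num 0 (assumed)
print("known classes:", known)
print("unknown classes:", unknown)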

Model Checkpoints

You can download the model checkpoints using the download_checkpoints.sh script in the scripts folder by running:

# Make the script executable
$ chmod +x ./scripts/download_checkpoints.sh

# Download model checkpoints
$ ./scripts/download_checkpoints.sh

After the download you will find the model checkpoints in the ./checkpoints folder:

  • ./checkpoints/mnist.ckpt
  • ./checkpoints/svhn.ckpt
  • ./checkpoints/cifar10.ckpt
  • ./checkpoints/cifar+10.ckpt
  • ./checkpoints/cifar+50.ckpt
  • ./checkpoints/tiny_imagenet.ckpt

The size of each checkpoint file is between ~370 MB and ~670 MB.
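
If you want to peek inside a downloaded checkpoint before running the full test pipeline, a minimal sketch with plain PyTorch is shown below (the key names are the usual PyTorch Lightning ones and may differ in practice):

# Minimal sketch: a Lightning checkpoint is a regular torch pickle
import torch

ckpt = torch.load("./checkpoints/mnist.ckpt", map_location="cpu")
print(list(ckpt.keys()))                       # typically includes 'state_dict', 'epoch', ...
print(len(ckpt["state_dict"]), "tensors in the state_dict")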

Experiments

For all the experiments we used a GeForce RTX 2080 Ti GPU (11 GB of memory).

For training you will need ~7300 MiB of GPU memory, whereas for testing ~5000 MiB.
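
Before launching, you can verify that the visible GPU has enough memory with a minimal snippet (plain PyTorch, not part of the repo):

# Print the visible GPU and its total memory
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: {props.total_memory / 2**30:.1f} GiB")
else:
    print("No CUDA GPU visible.")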

Train

The CVAECapOSR model can be trained using the main.py program. Here we report an example of a training command for the mnist experiment:

# Train
$ python main.py \
      --data_base_path "./data" \
      --dataset "mnist" \
      --val_ratio 0.2 \
      --seed 1234 \
      --batch_size 32 \
      --split_num 0 \
      --z_dim 128 \
      --lr 5e-5 \
      --t_mu_shift 10.0 \
      --t_var_scale 0.1 \
      --alpha 1.0 \
      --beta 0.01 \
      --margin 10.0 \
      --in_dim_caps 16 \
      --out_dim_caps 32 \
      --checkpoint "" \
      --epochs 100 \
      --mode "train"

For simplicity we provide all the training scripts for the different datasets in the scripts folder. Specifically, you will find:

  • train_mnist.sh
  • train_svhn.sh
  • train_cifar10.sh
  • train_cifar+10.sh
  • train_cifar+50.sh
  • train_tinyimagenet.sh

that you can run as follows:

# Make the script executable
$ chmod +x ./scripts/train_{dataset}.sh # set the dataset name

# Run training
$ ./scripts/train_{dataset}.sh # set the dataset name

All the temporary files of the training stage (model checkpoints, tensorboard metrics, ...) are created in ./tmp/{dataset}/version_{version_number}/, where dataset is specified in the train_{dataset}.sh script and version_number is an integer that is tracked and incremented automatically so that training logs are never overwritten (each run creates its own files under a new version folder).
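
You can monitor the logged metrics during or after training with tensorboard (assuming it is available in the activated environment), pointing it at the folder pattern above:

# Inspect training metrics
$ tensorboard --logdir ./tmp/mnist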

Test

The CVAECapOSR model can be tested using the main.py program. Here we report an example of a test command for the mnist experiment:

# Test
$ python main.py \
      --data_base_path "./data" \
      --dataset "mnist" \
      --val_ratio 0.2 \
      --seed 1234 \
      --batch_size 32 \
      --split_num 0 \
      --z_dim 128 \
      --lr 5e-5 \
      --t_mu_shift 10.0 \
      --t_var_scale 0.1 \
      --alpha 1.0 \
      --beta 0.01 \
      --margin 10.0 \
      --in_dim_caps 16 \
      --out_dim_caps 32 \
      --checkpoint "checkpoints/mnist.ckpt" \
      --mode "test"

For simplicity we provide all the test scripts for the different datasets in the scripts folder. Specifically, you will find:

  • test_mnist.sh
  • test_svhn.sh
  • test_cifar10.sh
  • test_cifar+10.sh
  • test_cifar+50.sh
  • test_tinyimagenet.sh

that you can run as follows:

# Make the script executable
$ chmod +x ./scripts/test_{dataset}.sh # set the dataset name

# Run test
$ ./scripts/test_{dataset}.sh # set the dataset name
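
The test run evaluates the model on the selected split and prints the open-set AUROC (reported as test_auroc at the end of the run, as in the log shown in the Comments section below).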

Model Reconstruction

Here we report the reconstructions of some test samples produced by the model after training.

MNIST: [reconstruction figure]
SVHN: [reconstruction figure]
CIFAR10: [reconstruction figure]
TinyImageNet: [reconstruction figure]

Comments
  • About open-set face recognition

    Hi Guglielmo Camporese, congratulations on your work! Can I apply this model to the open-set face recognition problem, for example on the LFW dataset, by replacing ResNet34 with VGG-Face? Can you give me some advice on this? Many thanks.

    opened by thanhuitha 1
  • Why did the experiment select the known class as the positive sample?

    I noticed that in the experimental code you use the known classes as positive samples when computing the related metrics. This does not seem consistent with the open-set/OOD literature: most open-set open-source codebases choose the unknown classes as the positive samples, e.g. the OSRCI method you compare against.

    opened by leyiweb 1
  • Issues about the performance

    Hi, thanks for your nice work.

    I ran the training script provided in this repo and did not change any code. However, there is a significant performance gap between the code and your paper (for example, 0.588 vs. 0.715 AUROC on tiny_imagenet).

    Should I tune some learning parameters to increase accuracy? I have tried adjusting the lr and the number of epochs but it does not help. I am looking forward to your suggestions on this.

    (cvaecaposr) xx@xxx:~/cvaecaposr$ sh ./scripts/train_tinyimagenet.sh
    {
        "data_base_path": "./data",
        "val_ratio": 0.2,
        "seed": 1234,
        "known_classes": [
            2,
            3,
            13,
            30,
            44,
            45,
            64,
            66,
            76,
            101,
            111,
            121,
            128,
            130,
            136,
            158,
            167,
            170,
            187,
            193
        ],
        "unknown_classes": [
            0,
            1,
            4,
            5,
            6,
            7,
            8,
            9,
            10,
            11,
            12,
            14,
            15,
            16,
            17,
            18,
            19,
            20,
            21,
            22,
            23,
            24,
            25,
            26,
            27,
            28,
            29,
            31,
            32,
            33,
            34,
            35,
            36,
            37,
            38,
            39,
            40,
            41,
            42,
            43,
            46,
            47,
            48,
            49,
            50,
            51,
            52,
            53,
            54,
            55,
            56,
            57,
            58,
            59,
            60,
            61,
            62,
            63,
            65,
            67,
            68,
            69,
            70,
            71,
            72,
            73,
            74,
            75,
            77,
            78,
            79,
            80,
            81,
            82,
            83,
            84,
            85,
            86,
            87,
            88,
            89,
            90,
            91,
            92,
            93,
            94,
            95,
            96,
            97,
            98,
            99,
            100,
            102,
            103,
            104,
            105,
            106,
            107,
            108,
            109,
            110,
            112,
            113,
            114,
            115,
            116,
            117,
            118,
            119,
            120,
            122,
            123,
            124,
            125,
            126,
            127,
            129,
            131,
            132,
            133,
            134,
            135,
            137,
            138,
            139,
            140,
            141,
            142,
            143,
            144,
            145,
            146,
            147,
            148,
            149,
            150,
            151,
            152,
            153,
            154,
            155,
            156,
            157,
            159,
            160,
            161,
            162,
            163,
            164,
            165,
            166,
            168,
            169,
            171,
            172,
            173,
            174,
            175,
            176,
            177,
            178,
            179,
            180,
            181,
            182,
            183,
            184,
            185,
            186,
            188,
            189,
            190,
            191,
            192,
            194,
            195,
            196,
            197,
            198,
            199
        ],
        "split_num": 0,
        "batch_size": 32,
        "num_workers": 0,
        "dataset": "tiny_imagenet",
        "z_dim": 128,
        "lr": 5e-05,
        "t_mu_shift": 10.0,
        "t_var_scale": 0.01,
        "alpha": 1.0,
        "beta": 0.01,
        "margin": 10.0,
        "in_dim_caps": 16,
        "out_dim_caps": 32,
        "checkpoint": "",
        "mode": "train",
        "epochs": 100
    }
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
    
      | Name    | Type      | Params
    --------------------------------------
    0 | enc     | ResNet34  | 21.3 M
    1 | vae_cap | VaeCap    | 23.5 M
    2 | fc      | Linear    | 10.5 M
    3 | dec     | Decoder   | 760 K
    4 | t_mean  | Embedding | 51.2 K
    5 | t_var   | Embedding | 51.2 K
    --------------------------------------
    56.1 M    Trainable params
    0         Non-trainable params
    56.1 M    Total params
    224.552   Total estimated model params size (MB)
    /xx/anaconda3/envs/cvaecaposr/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 80 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      warnings.warn(*args, **kwargs)
    /xx/anaconda3/envs/cvaecaposr/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 80 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      warnings.warn(*args, **kwargs)
    Epoch 20: 100% 312/313 [00:38<00:00,  8.21it/s, loss=4.99e+03, v_num=0, train_acc=0.938, validation_acc=0.456]
    Epoch 21: reducing learning rate of group 0 to 2.5000e-05.
    Epoch 28: 100% 312/313 [00:37<00:00,  8.35it/s, loss=3.31e+03, v_num=0, train_acc=0.906, validation_acc=0.460]
    Epoch 29: reducing learning rate of group 0 to 1.2500e-05.
    Epoch 42: 100% 312/313 [00:38<00:00,  8.20it/s, loss=2.32e+03, v_num=0, train_acc=0.906, validation_acc=0.459]
    Epoch 43: reducing learning rate of group 0 to 6.2500e-06.
    Epoch 48: 100% 312/313 [00:37<00:00,  8.35it/s, loss=1.88e+03, v_num=0, train_acc=0.969, validation_acc=0.474]
    Epoch 49: reducing learning rate of group 0 to 3.1250e-06.
    Epoch 54: 100% 312/313 [00:37<00:00,  8.25it/s, loss=2.41e+03, v_num=0, train_acc=0.938, validation_acc=0.465]
    Epoch 55: reducing learning rate of group 0 to 1.5625e-06.
    Epoch 60: 100% 312/313 [00:37<00:00,  8.35it/s, loss=1.72e+03, v_num=0, train_acc=1.000, validation_acc=0.468]
    Epoch 61: reducing learning rate of group 0 to 7.8125e-07.
    Epoch 66: 100% 312/313 [00:37<00:00,  8.35it/s, loss=1.88e+03, v_num=0, train_acc=1.000, validation_acc=0.471]
    Epoch 67: reducing learning rate of group 0 to 3.9063e-07.
    Epoch 72: 100% 312/313 [00:38<00:00,  8.21it/s, loss=1.62e+03, v_num=0, train_acc=1.000, validation_acc=0.466]
    Epoch 73: reducing learning rate of group 0 to 1.9531e-07.
    Epoch 78: 100% 312/313 [00:37<00:00,  8.35it/s, loss=1.15e+03, v_num=0, train_acc=1.000, validation_acc=0.472]
    Epoch 79: reducing learning rate of group 0 to 9.7656e-08.
    Epoch 84: 100% 312/313 [00:40<00:00,  7.73it/s, loss=1.48e+03, v_num=0, train_acc=0.969, validation_acc=0.470]
    Epoch 85: reducing learning rate of group 0 to 4.8828e-08.
    Epoch 90: 100% 312/313 [00:38<00:00,  8.04it/s, loss=1.68e+03, v_num=0, train_acc=0.938, validation_acc=0.472]
    Epoch 91: reducing learning rate of group 0 to 2.4414e-08.
    Epoch 96: 100% 312/313 [00:38<00:00,  8.09it/s, loss=1.82e+03, v_num=0, train_acc=0.938, validation_acc=0.471]
    Epoch 97: reducing learning rate of group 0 to 1.2207e-08.
    Epoch 99: 100% 313/313 [00:38<00:00,  8.10it/s, loss=1.76e+03, v_num=0, train_acc=1.000, validation_acc=0.468]
    Saving latest checkpoint...
    Epoch 99: 100% 313/313 [00:40<00:00,  7.73it/s, loss=1.76e+03, v_num=0, train_acc=1.000, validation_acc=0.468]
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
    /xx/anaconda3/envs/cvaecaposr/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 80 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      warnings.warn(*args, **kwargs)
    Testing: 100% 312/313 [00:23<00:00, 13.07it/s]
    /xx/anaconda3/envs/cvaecaposr/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
      warnings.warn(*args, **kwargs)
    Testing: 100% 313/313 [00:23<00:00, 13.10it/s]
    --------------------------------------------------------------------------------
    DATALOADER:0 TEST RESULTS
    {'test_auroc': 0.5880855321884155}
    
    
    opened by zl9501 3
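
A side note on the metric question raised in the second comment above: with fixed scores, swapping which class is treated as positive maps AUROC to 1 - AUROC, and negating the scores restores the original value. A minimal illustration with scikit-learn (not repo code):

# Illustration: the choice of positive class only flips AUROC
from sklearn.metrics import roc_auc_score

y_known = [1, 1, 0, 0]          # 1 = known class treated as positive
scores = [0.9, 0.6, 0.4, 0.2]   # higher score = more "known"
print(roc_auc_score(y_known, scores))                                 # 1.0
print(roc_auc_score([1 - y for y in y_known], scores))                # 0.0 (= 1 - AUROC)
print(roc_auc_score([1 - y for y in y_known], [-s for s in scores]))  # 1.0 again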