The code for Expectation-Maximization Attention Networks for Semantic Segmentation (ICCV'2019 Oral)


EMANet

News

  • The bug in loading the pretrained model is now fixed. I have updated the .pth. To use it, download it again.
  • EMANet-101 gets 80.99 mIoU on the PASCAL VOC dataset (thanks to SenseTime's servers). So, with a classic backbone (ResNet) instead of newer ones (WideResNet, HRNet), EMANet still achieves top performance.
  • EMANet-101 (OHEM) gets 81.14 mIoU on Cityscapes val with single-scale inference, and 81.9 on the test server with multi-scale inference.

Background

This repository is for Expectation-Maximization Attention Networks for Semantic Segmentation (to appear in ICCV 2019, Oral presentation),

by Xia Li, Zhisheng Zhong, Jianlong Wu, Yibo Yang, Zhouchen Lin and Hong Liu from Peking University.

The source code is now available!

Citation

If you find EMANet useful in your research, please consider citing:

@inproceedings{li19,
    author={Xia Li and Zhisheng Zhong and Jianlong Wu and Yibo Yang and Zhouchen Lin and Hong Liu},
    title={Expectation-Maximization Attention Networks for Semantic Segmentation},
    booktitle={International Conference on Computer Vision},   
    year={2019},   
}


Introduction

The self-attention mechanism has been widely used for various tasks. It is designed to compute the representation of each position as a weighted sum of the features at all positions, so it can capture long-range relations for computer vision tasks. However, it is computationally expensive, since the attention maps are computed w.r.t. all other positions. In this paper, we formulate the attention mechanism in an expectation-maximization manner and iteratively estimate a much more compact set of bases upon which the attention maps are computed. By a weighted summation over these bases, the resulting representation is low-rank and deprecates noisy information from the input. The proposed Expectation-Maximization Attention (EMA) module is robust to the variance of the input and is also friendly in memory and computation. Moreover, we set up bases maintenance and normalization methods to stabilize its training procedure. We conduct extensive experiments on popular semantic segmentation benchmarks including PASCAL VOC, PASCAL Context, and COCO Stuff, on which we set new records.

Figure: the EMA Unit.
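In short, instead of forming an N×N attention map (roughly O(N²C) for N positions and C channels), EMA alternates an E-step and an M-step against only K ≪ N bases, at roughly O(NKC) per iteration. The sketch below is an illustrative re-implementation of that loop in PyTorch under my own names and shape conventions; it is not the module shipped in this repo, but it follows the E-step / M-step / reconstruction structure described above.

```python
import torch
import torch.nn.functional as F

def ema_attention(x, mu, T=3):
    """One forward pass of Expectation-Maximization Attention (illustrative sketch).

    x:  (B, C, N) flattened feature map, with N = H * W positions
    mu: (B, C, K) the K bases, assumed L2-normalized along the channel dim
    Returns the low-rank reconstruction (B, C, N) and the updated bases.
    """
    for _ in range(T):
        # E-step: responsibilities of every position w.r.t. every basis
        z = torch.bmm(x.transpose(1, 2), mu)           # (B, N, K)
        z = F.softmax(z, dim=2)
        # M-step: each basis becomes a weighted average of the features
        z_ = z / (z.sum(dim=1, keepdim=True) + 1e-6)   # normalize weights over positions
        mu = torch.bmm(x, z_)                          # (B, C, K)
        mu = F.normalize(mu, dim=1)                    # L2-norm stabilization of the bases
    # Reconstruction: re-estimate every position from the compact bases
    x_tilde = torch.bmm(mu, z.transpose(1, 2))         # (B, C, N)
    return x_tilde, mu
```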

Design

As so many peers have starred this repo, I feel great pressure and have tried to release the code with high quality. That's why I didn't release it until today (Aug 22, 2019). It's known that designing a code structure is not easy, and different designs suit different usages. Here, I aim at making research on semantic segmentation, especially on PASCAL VOC, easier. So I removed unnecessary encapsulation as much as possible and left fewer than 10 Python files. To be honest, the global variables in settings.py are not a good design for a large project, but for research they offer great flexibility. I hope you can understand that.

For research, I recommend separating each experiment into its own folder. Each folder contains the whole project and should be named after the experiment settings, such as 'EMANet101.moving_avg.l2norm.3stages'. This way, you can keep track of all the experiments and find their differences with just the 'diff' command.
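The 'moving_avg' part of that name refers to the bases maintenance mentioned in the introduction: during training, a global copy of the bases is kept and updated as a running average of the per-batch estimates rather than being learned by backpropagation. The snippet below is only a sketch of the idea; the momentum value, the attribute names, and the exact reduction over the batch are my assumptions, not necessarily what this repo does.

```python
import torch

@torch.no_grad()
def update_global_bases(global_mu, batch_mu, momentum=0.9):
    """Moving-average maintenance of the bases between training iterations (sketch).

    global_mu: (1, C, K) buffer shared across iterations
    batch_mu:  (B, C, K) bases produced by the last M-step on the current batch
    """
    batch_mean = batch_mu.mean(dim=0, keepdim=True)                # average over the batch
    global_mu.mul_(momentum).add_(batch_mean, alpha=1 - momentum)  # running-average update
    return global_mu
```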

Usage

  1. Install the libraries listed in 'requirements.txt'.
  2. Download the images and labels of PASCAL VOC and SBD, and decompress them together.
  3. Download the pretrained ResNet-50 and ResNet-101, unzip them, and put them into the 'models' folder.
  4. Change 'DATA_ROOT' in settings.py to where you placed the dataset (see the sketch after this list).
  5. Run `sh clean.sh` to clear the models and logs from the last experiment.
  6. Run `python train.py` for training and `sh tensorboard.sh` for visualization in your browser.
  7. Or download the pretrained model, put it into the 'models' folder, and skip step 6.
  8. Run `python eval.py` for validation.
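For step 4 (and for the single-GPU question in the comments below), the configuration lives in plain module-level globals in settings.py. The fragment below is only a sketch of the two fields referenced in this README and the issues; the values are placeholders and the real file contains more options.

```python
# settings.py -- illustrative sketch, not the full file
DATA_ROOT = '/path/to/VOCdevkit'   # step 4: point this at the decompressed VOC + SBD data
DEVICES = [0, 1, 2, 3]             # GPU ids handed to torch.nn.DataParallel;
                                   # set this to [0] if you only have a single GPU
```

Editing DEVICES to match the GPUs actually present on your machine should also resolve the "Invalid device id" error reported in the comments below.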

Ablation Studies

The following results are taken from the paper. With this repo, it is not unusual to get even higher performance; if you do, please share it in the issues. For now, this repo only provides single-scale (SS) inference. I may release the code for MS and Flip later.

Tab 1. Detailed comparisons with Deeplabs. All results are achieved with the ResNet-101 backbone and output stride 8. FLOPs and memory are computed with an input size of 513×513. SS: single-scale input during test. MS: multi-scale input. Flip: adding left-right flipped input. EMANet (256) and EMANet (512) denote EMANet with 256 and 512 input channels for EMA, respectively.

| Method | SS | MS+Flip | FLOPs | Memory | Params |
| --- | --- | --- | --- | --- | --- |
| ResNet-101 | - | - | 190.6G | 2.603G | 42.6M |
| DeeplabV3 | 78.51 | 79.77 | +63.4G | +66.0M | +15.5M |
| DeeplabV3+ | 79.35 | 80.57 | +84.1G | +99.3M | +16.3M |
| PSANet | 78.51 | 79.77 | +56.3G | +59.4M | +18.5M |
| EMANet(256) | 79.73 | 80.94 | +21.1G | +12.3M | +4.87M |
| EMANet(512) | 80.05 | 81.32 | +43.1G | +22.1M | +10.0M |

Note that the majority of EMANet's overhead comes from the 3x3 convs before and after the EMA module. The EMA module itself costs only about 1/3 of a 3x3 conv in computation, and has even fewer parameters than a 1x1 conv.

Comparisons with SOTAs

Note that, for validation on the 'val' set, you only need to train 30k iterations on the 'trainaug' set. But for testing on the evaluation server, you should first pretrain on COCO, then train 30k iterations on 'trainaug', and another 30k iterations on the 'trainval' set.

Tab 2. Comparisons on the PASCAL VOC test dataset.

| Method | Backbone | mIoU(%) |
| --- | --- | --- |
| GCN | ResNet-152 | 83.6 |
| RefineNet | ResNet-152 | 84.2 |
| Wide ResNet | WideResNet-38 | 84.9 |
| PSPNet | ResNet-101 | 85.4 |
| DeeplabV3 | ResNet-101 | 85.7 |
| PSANet | ResNet-101 | 85.7 |
| EncNet | ResNet-101 | 85.9 |
| DFN | ResNet-101 | 86.2 |
| Exfuse | ResNet-101 | 86.2 |
| IDW-CNN | ResNet-101 | 86.3 |
| SDN | DenseNet-161 | 86.6 |
| DIS | ResNet-101 | 86.8 |
| EMANet101 | ResNet-101 | 87.7 |
| DeeplabV3+ | Xception-65 | 87.8 |
| Exfuse | ResNeXt-131 | 87.9 |
| MSCI | ResNet-152 | 88.0 |
| EMANet152 | ResNet-152 | 88.2 |

Code Borrowed From

  • RESCAN
  • Pytorch-Encoding
  • Synchronized-BN

Comments
  • RuntimeError:  self.net.module.load_state_dict(obj['net'])


    Using the pretrained model throws this error:

    RuntimeError: Error(s) in loading state_dict for EMANet: Missing key(s) in state_dict: "extractor.4.0.conv1.weight", "extractor.4.0.bn1.weight", "extractor.4.0.bn1.bias", "extractor.4.0.bn1.running_mean", "extractor.4.0.bn1.running_var", "extractor.4.0.conv2.weight", ...... Unexpected key(s) in state_dict: "layer1.0.0.conv1.weight", "layer1.0.0.bn1.weight", "layer1.0.0.b

    Can someone help?

    opened by taotaoyuhust 8
  • could run in one gpu


    (base) pf@pf-System-Product-Name:~/EMANet$ python train.py
    2019-12-06 21:37:49,527 - INFO - set log dir as ./logdir
    2019-12-06 21:37:49,528 - INFO - set model dir as ./models
    Traceback (most recent call last):
    File "train.py", line 181, in main()
    File "train.py", line 146, in main sess = Session(dt_split='trainaug')
    File "train.py", line 93, in init self.net = DataParallel(self.net, device_ids=settings.DEVICES)
    File "/home/pf/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 131, in init _check_balance(self.device_ids)
    File "/home/pf/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 18, in _check_balance dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
    File "/home/pf/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 18, in dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
    File "/home/pf/anaconda3/lib/python3.6/site-packages/torch/cuda/init.py", line 301, in get_device_properties raise AssertionError("Invalid device id")
    AssertionError: Invalid device id

    opened by biexiangduo 4
  • The difference between the code and the original paper.


    Hi, thank you for releasing the code for EMANet. I found a difference between the code and the paper. The difference lies in the formulation of Equation 13 (in the paper). In the paper, the M-step (bases reconstruction) is formulated as in the equation image attached to the original issue. However, in the code, the M-step is formulated as: mu = torch.bmm(x, z_)
    Actually, mu = torch.bmm(x, z_) is the weighted summation of X. However, Equation 13 (in the paper) is not the weighted summation of X. Is anything wrong in the paper?

    opened by zhouyuan888888 3
  • Fail to reproduce your result: EMANet(512) 80.05%?


    Hi @XiaLiPKU ! This work is wonderful and thanks so much for releasing the code.

    May I ask a question? I used your pretrained model to evaluate on the val set and got 80.50% mIoU with single-scale testing, but when I trained this model from scratch, I could only get 79.44%, which is supposed to be 80.05%.

    I just followed your default settings (pretrained ResNet weights, batch size 16, 4 GPUs, 30k iterations, and so on).

    Are there any other special techniques you adopted to get this final model?

    Looking forward to your reply!

    opened by Euphoria16 3
  • Can't train the model?


    (base) davis@davis-MS-7B17:~/Network/EMANet-master$ python train.py
    2019-08-31 13:50:14,703 - INFO - set log dir as ./logdir
    2019-08-31 13:50:14,703 - INFO - set model dir as ./models
    2019-08-31 13:50:17,131 - ERROR - No checkpoint ./models/latest.pth!

    The training stops there, so I have to interrupt it from the keyboard... Does anybody know how to solve this?

    opened by fdujay 3
  • Output stride


    Hi, first of all, thanks for your paper. You mention that for some nets the stride is 16 while for others it is 8. However, there is nothing on how you recover the output back to the original size. Do you use bilinear upsampling? If so, don't you have problems with borders and fine structures when using such a steep upsampling method?

    opened by arc144 3
  • How long does it take to train EMANet with a Resnet-101 backbone?


    Hello,

    Thank you for publishing the code for your excellent work. I was wondering how long it takes to train EMANet with a ResNet-101 backbone, both when the number of input channels is 256 and when it is 512. How many GPUs did you use to achieve this training time?

    Thank you in advance :)

    opened by ghost 2
  • Is this a bug or trick?  image = (image - settings.MEAN) / settings.MEAN


    https://github.com/XiaLiPKU/EMANet/blob/9a492d8aaad297e15eac044b3bb9583e63ffa3a3/dataset.py#L19

    I suspect this line should be

    image = (image - settings.MEAN) / settings.STD.
    

    Or is this line a trick?

    opened by gasvn 2
  • ResNet18 pretrained model


    Hi,

    Thanks for providing the pre-trained ResNet-50 and ResNet-101 models. Do you have a pre-trained ResNet-18 model that replaces the first 7x7 conv with three 3x3 convs? I have searched for it for a long time but unfortunately couldn't find it. If you have saved this model, could you please share it with me? Many thanks in advance.

    opened by Mayy1994 1
  • selection of K


    In my opinion, besides T, the selection of K is also important (as in GMM or k-means). I didn't see any ablation study on the effect of different K's; did you do any experiments?

    Intuitively, I have the impression that mu represents different features for different classes, so the first K I would try is the number of classes (e.g. 19 for Cityscapes). Can you explain how you decided to use K=64?

    As the visualization of the responsibilities shows, different z's tend to represent different classes, so won't having K > the number of classes make some z's very close to each other, eventually making them redundant?

    Thanks.

    opened by kwea123 1
  • I had some trouble, could you help me?


    Thanks for your reply!!! Following your ground-truth format, I made the ground truth for my dataset. But during training there was a problem, which I've pasted below. Can you help me? Maybe my dataset is too messy and its boundaries are not obvious. What advice would you offer?

    RuntimeError: CUDA error: an illegal memory access was encountered
    terminate called after throwing an instance of 'c10::Error'
    what(): CUDA error: an illegal memory access was encountered (insert_events at /pytorch/c10/cuda/CUDACachingAllocator.cpp:564)
    frame #0: std::function<std::string ()>::operator()() const + 0x11 (0x7f5345247441 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libc10.so)
    frame #1: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x2a (0x7f5345246d7a in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libc10.so)
    frame #2: + 0x13652 (0x7f534261a652 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
    frame #3: c10::TensorImpl::release_resources() + 0x50 (0x7f5345237ce0 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libc10.so)
    frame #4: + 0x30facb (0x7f52f071aacb in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
    frame #5: + 0x376d60 (0x7f52f0781d60 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
    frame #6: + 0x3128ea (0x7f52f071d8ea in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
    frame #7: torch::autograd::deleteFunction(torch::autograd::Function*) + 0xa2 (0x7f52f071d9a2 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
    frame #8: std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() + 0xa2 (0x7f5330b81bb2 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
    frame #9: + 0x14216b (0x7f5330ba516b in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
    frame #10: + 0x1421d9 (0x7f5330ba51d9 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
    frame #11: torch::autograd::Variable::Impl::release_resources() + 0x1b (0x7f52f0d5708b in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
    frame #12: + 0x1420bb (0x7f5330ba50bb in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
    frame #13: + 0x3c30f4 (0x7f5330e260f4 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
    frame #14: + 0x3c3141 (0x7f5330e26141 in /home/r/.conda/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
    frame #15: + 0x19aa5e (0x55791a64ba5e in /home/r/.conda/envs/pytorch/bin/python3)
    frame #16: + 0xf1b77 (0x55791a5a2b77 in /home/r/.conda/envs/pytorch/bin/python3)
    frame #17: + 0xf1a07 (0x55791a5a2a07 in /home/r/.conda/envs/pytorch/bin/python3)
    frame #18: + 0xf1a1d (0x55791a5a2a1d in /home/r/.conda/envs/pytorch/bin/python3)
    frame #19: + 0xf1a1d (0x55791a5a2a1d in /home/r/.conda/envs/pytorch/bin/python3)
    frame #20: PyDict_SetItem + 0x3da (0x55791a5e963a in /home/r/.conda/envs/pytorch/bin/python3)
    frame #21: PyDict_SetItemString + 0x4f (0x55791a5f065f in /home/r/.conda/envs/pytorch/bin/python3)
    frame #22: PyImport_Cleanup + 0x99 (0x55791a655d89 in /home/r/.conda/envs/pytorch/bin/python3)
    frame #23: Py_FinalizeEx + 0x61 (0x55791a6c0231 in /home/r/.conda/envs/pytorch/bin/python3)
    frame #24: Py_Main + 0x35e (0x55791a6ca57e in /home/r/.conda/envs/pytorch/bin/python3)
    frame #25: main + 0xee (0x55791a59488e in /home/r/.conda/envs/pytorch/bin/python3)
    frame #26: __libc_start_main + 0xf0 (0x7f5348fdd830 in /lib/x86_64-linux-gnu/libc.so.6)
    frame #27: + 0x1c3160 (0x55791a674160 in /home/r/.conda/envs/pytorch/bin/python3)

    opened by RIKOYUKI 1
  • about voc12


    Hi, I submitted results on the val set and the test set to the official website for evaluation, but the two results differ by four points. How can I reduce this gap?

    opened by xLuge 0
  • About the BN!!!


    I added another block to replace the EMAU, but I get some warnings. I guess the BN library you used is not suitable for my block.

    2020-07-30 20:05:26,727 - INFO - step: 1 loss: 2.429 lr: 0.009
    WARNING batched routines are designed for small sizes. It might be better to use the Native/Hybrid classical routines if you want good performance.
    =========================================================================================
    WARNING batched routines are designed for small sizes. It might be better to use the Native/Hybrid classical routines if you want good performance.
    =========================================================================================
    WARNING batched routines are designed for small sizes. It might be better to use the Native/Hybrid classical routines if you want good performance.
    =========================================================================================
    WARNING batched routines are designed for small sizes. It might be better to use the Native/Hybrid classical routines if you want good performance.
    2020-07-30 20:05:29,586 - INFO - step: 2 loss: 2.398 lr: 0.009

    opened by zyxu1996 1
  • The pre-trained Resnet-50 and Resnet-101 can't be downloaded


    When I click the link, I get: 'This XML file does not appear to have any style information associated with it. The document tree is shown below.' How can I solve this?

    opened by caozhe1011 2