Code for our paper "Multi-scale Guided Attention for Medical Image Segmentation"

Overview

Medical Image Segmentation with Guided Attention

This repository contains the code of our paper:
"'Multi-scale self-guided attention for medical image segmentation'", which has been recently accepted at the Journal of Biomedical And Health Informatics (JBHI).

Abstract

Even though convolutional neural networks (CNNs) are driving progress in medical image segmentation, standard models still have some drawbacks. First, the use of multi-scale approaches, i.e., encoder-decoder architectures, leads to a redundant use of information, where similar low-level features are extracted multiple times at multiple scales. Second, long-range feature dependencies are not efficiently modeled, resulting in non-optimal discriminative feature representations associated with each semantic class. In this paper, we attempt to overcome these limitations with the proposed architecture by capturing richer contextual dependencies based on the use of guided self-attention mechanisms. This approach is able to integrate local features with their corresponding global dependencies, as well as highlight interdependent channel maps in an adaptive manner. Further, the additional loss between different modules guides the attention mechanisms to neglect irrelevant information and focus on more discriminant regions of the image by emphasizing relevant feature associations. We evaluate the proposed model in the context of abdominal organ segmentation on magnetic resonance imaging (MRI). A series of ablation experiments supports the importance of these attention modules in the proposed architecture. In addition, compared to other state-of-the-art segmentation networks, our model yields better segmentation performance, increasing the accuracy of the predictions while reducing the standard deviation. This demonstrates the efficiency of our approach to generate precise and reliable automatic segmentations of medical images.

Design of the Proposed Model

(Figure: architecture of the proposed multi-scale guided attention model)

Results

(Figure: qualitative segmentation results)

Requirements

  • The code is written in Python 3.6 and requires PyTorch 1.1.0
  • Install the dependencies using pip install -r requirements.txt

Preparing your data

You have to split your data into three folders: train/val/test. Each folder must contain two sub-folders, Img and GT, holding the PNG files for the images and their corresponding ground-truth masks. The naming of these images is important: the code that temporarily saves results, e.g., to compute the 3D DSC, relies on the file names.

Specifically, the convention we follow for the names is as follows:

  • Subj_Xslice_Y.png, where X is the subject number (or ID) and Y is the slice number within the whole volume. Do not zero-pad the numbers, i.e., the first slice should be 1, not 01.
  • The corresponding mask must be named in the same way as the image.

An example of a sample image is included in the dataset folder; a small sketch of the naming convention follows below.
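
To make the convention concrete, here is a small sketch that parses a file name back into its subject and slice numbers. The regex is our reading of the pattern above, not the repository's actual parser:

    import re

    def parse_name(fname):
        """Subj_Xslice_Y.png -> (subject X, slice Y), both unpadded integers."""
        m = re.fullmatch(r"Subj_(\d+)slice_(\d+)\.png", fname)
        if m is None:
            raise ValueError("unexpected file name: " + fname)
        return int(m.group(1)), int(m.group(2))

    # parse_name("Subj_3slice_12.png") -> (3, 12)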

Running the code

Note: Set the data path appropriately in src/main.py before running the code.

To run the code, simply use the following script:

bash train.sh

If you use this code for your research, please consider citing our paper:

@article{sinha2020multi,
  title={Multi-scale self-guided attention for medical image segmentation},
  author={Sinha, Ashish and Dolz, Jose},
  journal={IEEE Journal of Biomedical and Health Informatics},
  year={2020}
}

Comments
  • Results not matching with paper?

    Hi, I am trying to reproduce the results but cannot get close to those reported in the paper. Did you use the validation set or the test set for the results reported in Table 4? Also, could you provide the training, validation, and test splits?

    opened by samra-irshad 7
  • Segmentation target and output size mismatch during loss calculation

    Hey, thanks for sharing your research. I was trying to train the model using your code but could not, due to the following issue:

    ValueError: Expected input batch_size (1) to match target batch_size (256).

    I am using batch size 1, and the sizes of the tensors fed into loss0 = CE_loss(outputs0, Segmentation_class) are:

    Segmentation_class torch.Size([256, 256])
    outputs0 torch.Size([1, 5, 256, 256])
    

    It seems the segmentation target needs to be properly converted before the loss is computed (one possible fix is sketched below).

    opened by samra-irshad 4
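
    For readers who hit the same error: with nn.CrossEntropyLoss, an output of shape [1, 5, 256, 256] expects a target of shape [1, 256, 256], so a [256, 256] target is misread as a batch of 256. A minimal sketch of one possible fix, reusing the names from the comment above (our assumption, not the authors' official patch):

        import torch
        import torch.nn as nn

        CE_loss = nn.CrossEntropyLoss()
        outputs0 = torch.randn(1, 5, 256, 256)                 # [batch, classes, H, W]
        Segmentation_class = torch.randint(0, 5, (256, 256))   # [H, W] class indices

        # Add the missing batch dimension: [H, W] -> [1, H, W]
        loss0 = CE_loss(outputs0, Segmentation_class.long().unsqueeze(0))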
  • Issue in common/utils.py

    On line 33, the following file location is requested: path_GT = './DataSet_Challenge/GT_Nifti/Val_1'. What exactly does this correspond to? You instructed us to put the slices in Img/ and the masks in GT/ for each of train, val, and test, so kindly clarify what this path refers to. It would be a great help if you could host the entire dataset you are using, in the exact format, on a drive link. @sinAshish @josedolz

    opened by DefUs3r 3
  • Evaluation metrics (dice coefficient, volume similarity and mean surface distance) on 2D or 3D segmentations?

    Hi, I am trying to reproduce the results given in the paper. Initially I assumed the Dice scores and volume similarity were calculated on 2D segmentation results, but I saw you mention in the paper that 'Since inter-slice distances and x-y spacing for each individual scan are not provided, we report these results on voxels.' Does that mean the Dice scores, volume similarity, and MSD were calculated on 3D segmentations (i.e., first reconstructing the 3D segmentation and then evaluating)? Also, I could not find the formulation for volume similarity and MSD; it would be great if you could add these evaluation metrics to your source code.

    opened by samra-irshad 3
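
    For context on this question: a 3D DSC is usually obtained by stacking a subject's per-slice predictions into a volume before comparing, as in the sketch below (the standard Dice formulation, not code from this repository):

        import numpy as np

        def dice_3d(pred_slices, gt_slices, label):
            """3D Dice for one class: stack a subject's 2D slices into volumes first."""
            pred = np.stack(pred_slices) == label   # [num_slices, H, W] boolean volume
            gt = np.stack(gt_slices) == label
            intersection = np.logical_and(pred, gt).sum()
            denom = pred.sum() + gt.sum()
            return 2.0 * intersection / denom if denom > 0 else 1.0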
  • Questions about picture size and output channels

    Thank you for sharing your code!

    I have been reading your paper and code these days, and I found that 7 subjects in the CHAOS dataset (T1 DUAL, In Phase) have an image size of 288×288, which differs from the other 13 subjects. So I'm confused about whether to crop or resize them to 256×256.

    Besides, I found that the output channels in your my_stacked_danet.py file are set to 5:

        self.predict4 = nn.Conv2d(64, 5, kernel_size=1)
        self.predict3 = nn.Conv2d(64, 5, kernel_size=1)
        self.predict2 = nn.Conv2d(64, 5, kernel_size=1)
        self.predict1 = nn.Conv2d(64, 5, kernel_size=1)

        self.predict4_2 = nn.Conv2d(64, 5, kernel_size=1)
        self.predict3_2 = nn.Conv2d(64, 5, kernel_size=1)
        self.predict2_2 = nn.Conv2d(64, 5, kernel_size=1)
        self.predict1_2 = nn.Conv2d(64, 5, kernel_size=1)
    

    But if I'm right, there are 4 classes in the dataset, so here comes another question: why is it 5? (One possible reading is sketched after this comment.)

    Looking forward to your guidance in your spare time, thanks! TAT

    opened by JasonmorrowGuo 3
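
    A likely reading (our assumption, not a confirmed answer from the authors): segmentation networks conventionally add a background class, so the 4 CHAOS organs (liver, right kidney, left kidney, spleen) plus background give 5 output channels:

        import torch.nn as nn

        NUM_ORGANS = 4                 # liver, right kidney, left kidney, spleen
        NUM_CLASSES = NUM_ORGANS + 1   # +1 for the background class

        # each prediction head maps 64 features to one logit per class
        predict = nn.Conv2d(64, NUM_CLASSES, kernel_size=1)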
  • Model complexity

    Hello, and congratulations on your work!

    I studied your paper and decided to reproduce the models in TensorFlow; however, I get an elevated number of parameters for a 128*128 input: about 106 million with ResNet101 in my implementation. Moreover, when I check the number of parameters of the ResNeXt101 you use in PyTorch, I get about 88 million for the ResNeXt alone. Since this doesn't correspond to Table VIII of your appendix at all, I wonder what the numbers in that table correspond to exactly. Anyway, thanks for your contribution.

    Joris Fournel

    opened by Jorisfournel 2
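
    For anyone cross-checking parameter counts, the usual PyTorch idiom is a one-liner over model.parameters() (a generic sketch; swap in DAF_stack() from src/models to count this repository's model):

        import torch
        from torchvision.models import resnet101

        model = resnet101()  # stand-in backbone for illustration
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("%.1fM trainable parameters" % (num_params / 1e6))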
  • AttributeError: cannot assign module before Module.__init__() call

    Hi @sinAshish, I'm trying to run your code on my dataset, but I got the following error. (I think the problem may be in __init__() in runTraining(args).) Do you have any idea how to solve it? Thanks!

    docker@warriors:[/Desktop/CodeFolder/attention/multi_scale_guided_attention]$ bash train.sh

    ~~~~~~~~ Starting the training... ~~~~~~

    Dataset: ./DataSet/
    ~~~~~~~~~~~ Creating the DAF Stacked model ~~~~~~~~~~

        Traceback (most recent call last):
          File "src/main.py", line 349, in <module>
            runTraining(args)
          File "src/main.py", line 89, in runTraining
            net = DAF_stack()
          File "/Desktop/CodeFolder/attention/multi_scale_guided_attention/src/models/my_stacked_danet.py", line 61, in __init__
            self.pam_attention_1_1 = PAM_CAM_Layer(64, True)
          File "/Desktop/CodeFolder/attention/multi_scale_guided_attention/src/models/attention.py", line 175, in __init__
            nn.PReLU()
          File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 565, in __setattr__
            "cannot assign module before Module.__init__() call")
        AttributeError: cannot assign module before Module.__init__() call

    opened by sugerpopo 2
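
    This is PyTorch's standard complaint when an nn.Module subclass assigns a submodule before calling the parent constructor. A minimal illustration of the rule (a generic sketch, not the repository's actual PAM_CAM_Layer):

        import torch.nn as nn

        class LayerSketch(nn.Module):      # hypothetical stand-in for PAM_CAM_Layer
            def __init__(self, in_ch):
                super().__init__()         # must run BEFORE any self.<module> assignment
                self.act = nn.PReLU()      # assigning this first raises the AttributeError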
  • About the code before softmax in CAM_Module

    Hi, thanks for sharing this awesome project :)

    Here's a question I had while reading your source code. In CAM_Module, there is one line of code before the softmax function that doesn't exist in PAM_Module. As I understand it, you take the maximum value of each channel vector (computed as query dot key) and subtract every value from it. But then a larger number means a more irrelevant channel, right? Sorry, I cannot understand this; could you kindly explain it to me? Thanks a lot!

    (link to the CAM_Module source)

    opened by kaneyxx 2
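
    For reference, the line in question (as it appears in DANet's CAM_Module, which a later comment notes this repository follows) amounts to the following:

        import torch
        import torch.nn as nn

        softmax = nn.Softmax(dim=-1)
        proj_query = torch.randn(2, 8, 16)   # [B, C, N], hypothetical shapes
        proj_key = torch.randn(2, 16, 8)     # [B, N, C]

        energy = torch.bmm(proj_query, proj_key)   # [B, C, C] channel affinities
        # Per-row max minus the energies, then softmax:
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = softmax(energy_new)
        # Because the row-wise max is constant along dim=-1 and softmax is
        # shift-invariant, this equals softmax(-energy): higher raw affinity
        # gets a smaller weight, which is exactly what the comment asks about.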
  • Parameters for training on other models (UNet, DANet, PAN, DAF)

    Hi, I can see in your paper that you compared the performance of your model with several other models (UNet, DANet, PAN, DAF); however, I am not sure whether you used the same training parameters (batch size, learning rate, total number of epochs, etc.) when training these models on the CHAOS dataset. Are the training parameters the same for all models?

    opened by samra-irshad 1
  • For the dataset split

    Hi, thank you for your contribution, which inspires me a lot. Could you provide the CHAOS dataset split lists (i.e., train/val/test in three folds) for reproduction? Thank you!

    opened by JiYuanFeng 1
  • Looks like the code of your attention module is exactly the same as the attention module of DANet?

    The code of your attention module looks exactly the same as the attention module of DANet, and it seems to differ from the structure described in the paper.

    opened by ustczhouyu 1
  • About the pre-trained weights

    We are sorry to report that we re-trained your code on the BraTS and CHAOS datasets in our experiments but did not obtain the results reported in the paper. Could you release the trained weights? That would help us report your results more accurately.

    opened by Joyies 2
  • How to generate png data?

    Hi! I am interested in this repo, but the released code does not mention how to generate the PNG data. Could you share the dataset preprocessing code? Thanks!

    opened by ee-hash 1
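
    Since the preprocessing code is not released, here is one hedged way to slice a volume into PNGs following the naming convention above. nibabel and the min-max normalization are our assumptions; CHAOS MRI actually ships as DICOM, for which pydicom would be the analogous route:

        import os
        import imageio
        import nibabel as nib
        import numpy as np

        def volume_to_pngs(nifti_path, out_dir, subject_id):
            """Slice a NIfTI volume along its last axis into Subj_Xslice_Y.png files."""
            vol = nib.load(nifti_path).get_fdata()
            vol = (255 * (vol - vol.min()) / (vol.max() - vol.min() + 1e-8)).astype(np.uint8)
            os.makedirs(out_dir, exist_ok=True)
            for y in range(vol.shape[-1]):  # slice numbering starts at 1, no zero padding
                name = "Subj_%dslice_%d.png" % (subject_id, y + 1)
                imageio.imwrite(os.path.join(out_dir, name), vol[..., y])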
Owner
Ashish Sinha
Working on understanding vision 🖖