Contextual Attention Network: Transformer Meets U-Net

Overview

Contextual attention network for medical image segmentation, with state-of-the-art results on skin lesion segmentation and multiple myeloma cell segmentation. This method incorporates the transformer module into a U-Net structure to capture long-range dependencies together with rich local information. If this code helps with your research, please consider citing the following paper:

R. Azad, Moein Heidari, Yuli Wu and Dorit Merhof, "Contextual Attention Network: Transformer Meets U-Net", download link.

@article{reza2022contextual,
  title={Contextual Attention Network: Transformer Meets U-Net},
  author={Azad, Reza and Heidari, Moein and Wu, Yuli and Merhof, Dorit},
  journal={arXiv preprint arXiv:2203.01932},
  year={2022}
}

Please consider starring the repository if you find it useful. Thanks!

Updates

This code has been implemented in Python using the PyTorch library and tested on Ubuntu, though it should be compatible with related environments. The following environment and libraries are needed to run the code:

  • Python 3
  • PyTorch

Run Demo

To train the deep model and evaluate it on each dataset, follow the steps below:
1- Download the ISIC 2018 train dataset from this link and extract both the training dataset and ground truth folders inside the dataset_isic18 folder.
2- Run Prepare_ISIC2018.py for data preparation and for dividing the data into train, validation, and test sets.
3- Run train_skin.py to train the model using the training and validation sets. The model will be trained for 100 epochs, and the weights that perform best on the validation set will be saved.
4- For performance calculation and producing segmentation results, run evaluate_skin.py. It reports the performance measures and saves the related results in the results folder.

Notice: For training and evaluating on ISIC 2017 and PH2, follow the steps below:

ISIC 2017- Download the ISIC 2017 train dataset from this link and extract both the training dataset and ground truth folders inside the dataset_isic18\7 folder. Then run Prepare_ISIC2017.py for data preparation and for dividing the data into train, validation, and test sets.
PH2- Download the PH2 dataset from this link and extract it, then run Prepare_ph2.py for data preparation and for dividing the data into train, validation, and test sets.
Follow steps 3 and 4 for model training and performance estimation. For the PH2 dataset, you need to first train the model on the ISIC 2017 dataset and then fine-tune the trained model on the PH2 dataset.
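As a rough illustration of what the data preparation step produces, the sketch below shuffles image/mask arrays and writes train, validation, and test splits as .npy files. The function name `split_and_save`, the split sizes, and the `_val`/`_test` file names are assumptions made for illustration; only `data_train.npy` and `mask_train.npy` are names mentioned elsewhere on this page, and the actual Prepare_*.py scripts may work differently.

```python
import numpy as np
from pathlib import Path


def split_and_save(images, masks, out_dir, n_train, n_val, seed=0):
    """Shuffle image/mask pairs and save train/val/test splits as .npy files.

    Hypothetical helper: file names other than data_train.npy and
    mask_train.npy are assumptions, not taken from the repository.
    """
    if len(images) != len(masks):
        raise ValueError("images and masks must have the same length")
    if n_train + n_val > len(images):
        raise ValueError("n_train + n_val exceeds the number of samples")

    # One shared permutation keeps every image aligned with its mask.
    order = np.random.default_rng(seed).permutation(len(images))
    splits = {
        "train": order[:n_train],
        "val": order[n_train:n_train + n_val],
        "test": order[n_train + n_val:],
    }

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for name, idx in splits.items():
        np.save(out_dir / f"data_{name}.npy", images[idx])
        np.save(out_dir / f"mask_{name}.npy", masks[idx])
    return {name: len(idx) for name, idx in splits.items()}
```

Creating the output directory before saving avoids depending on the current working directory, which is one way a loader can end up looking for files in a folder that was never written.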

Quick Overview

Diagram of the proposed method

Perceptual visualization of the proposed Contextual Attention module.

Results

To evaluate the performance of the proposed method, two challenging tasks in medical image segmentation have been considered. The results of the proposed approach are illustrated below.

Task 1: Skin Lesion Segmentation

Performance Comparison on Skin Lesion Segmentation

In order to compare the proposed method with state-of-the-art approaches on skin lesion segmentation, we considered the ISIC 2017 dataset.

| Methods (on ISIC 2017) | Dice-Score | Sensitivity | Specificity | Accuracy |
|---|---|---|---|---|
| Ronneberger et al. U-net | 0.8159 | 0.8172 | 0.9680 | 0.9164 |
| Oktay et al. Attention U-net | 0.8082 | 0.7998 | 0.9776 | 0.9145 |
| Lei et al. DAGAN | 0.8425 | 0.8363 | 0.9716 | 0.9304 |
| Chen et al. TransU-net | 0.8123 | 0.8263 | 0.9577 | 0.9207 |
| Asadi et al. MCGU-Net | 0.8927 | 0.8502 | 0.9855 | 0.9570 |
| Valanarasu et al. MedT | 0.8037 | 0.8064 | 0.9546 | 0.9090 |
| Wu et al. FAT-Net | 0.8500 | 0.8392 | 0.9725 | 0.9326 |
| Azad et al. Proposed TMUnet | 0.9164 | 0.9128 | 0.9789 | 0.9660 |

For more results on the ISIC 2018 and PH2 datasets, please refer to the paper.
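For readers who want to reproduce the four columns of the table above on their own predictions, here is a minimal sketch of how Dice, sensitivity, specificity, and accuracy can be computed from a binary confusion matrix. The function name `segmentation_metrics` and the 0.5 threshold are assumptions; this is not necessarily how evaluate_skin.py computes its numbers.

```python
import numpy as np


def segmentation_metrics(pred, target, threshold=0.5):
    """Pixel-wise metrics for binary segmentation (illustrative sketch)."""
    pred = np.asarray(pred) >= threshold
    target = np.asarray(target) >= threshold

    # Confusion-matrix counts over all pixels.
    tp = np.sum(pred & target)
    tn = np.sum(~pred & ~target)
    fp = np.sum(pred & ~target)
    fn = np.sum(~pred & target)

    def ratio(num, den):
        # Return 0.0 instead of dividing by zero for degenerate inputs.
        return float(num) / float(den) if den else 0.0

    return {
        "dice": ratio(2 * tp, 2 * tp + fp + fn),
        "sensitivity": ratio(tp, tp + fn),
        "specificity": ratio(tn, tn + fp),
        "accuracy": ratio(tp + tn, tp + tn + fp + fn),
    }
```

If every count except the true negatives comes out as zero, the usual suspects are a prediction that never crosses the threshold or masks stored on a different scale (for example 0/255 versus 0/1).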

Skin Lesion Segmentation results on test data

Skin lesion segmentation results: (a) Input images. (b) Ground truth. (c) U-net. (d) Gated Axial-Attention. (e) Proposed method without a contextual attention module. (f) Proposed method.

Multiple Myeloma Cell Segmentation

Performance Evaluation on the Multiple Myeloma Cell Segmentation task

| Methods | mIOU |
|---|---|
| Frequency recalibration U-Net | 0.9392 |
| XLAB Insights | 0.9360 |
| DSC-IITISM | 0.9356 |
| Multi-scale attention deeplabv3+ | 0.9065 |
| U-Net | 0.7665 |
| Baseline | 0.9172 |
| Proposed | 0.9395 |
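The mIOU column can be approximated with a short sketch like the one below. It assumes that mIOU means the intersection-over-union of the foreground mask averaged over images; the benchmark may define the average differently (for example over classes), and the function name `mean_iou` is hypothetical.

```python
import numpy as np


def mean_iou(preds, targets, threshold=0.5):
    """Mean per-image foreground IoU (assumed definition, see lead-in)."""
    ious = []
    for pred, target in zip(preds, targets):
        p = np.asarray(pred) >= threshold
        t = np.asarray(target) >= threshold
        union = np.sum(p | t)
        # Two empty masks agree perfectly, so count them as IoU = 1.
        if union == 0:
            ious.append(1.0)
        else:
            ious.append(float(np.sum(p & t)) / float(union))
    return float(np.mean(ious))
```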

Multiple Myeloma Cell Segmentation results

Model weights

You can download the learned weights for each dataset in the following table.

| Dataset | Learned weights |
|---|---|
| ISIC 2018 | TMUnet |
| ISIC 2017 | TMUnet |
| PH2 | TMUnet |

Query

All implementations are done by Reza Azad and Moein Heidari. For any query please contact us for more information.

rezazad68@gmail.com
moeinheidari7829@gmail.com
Comments
  • About the evaluate result

    Hi! You have done splendid work, and thanks for your code. I have a question about the evaluation result. I used the learned weights you provided to evaluate your model, without changing any code, but the confusion matrix is always 0 (screenshot: TMUNet). It would be very kind of you to answer my question. Thanks.

    opened by GuoQingqing 1
  • About the boundary loss

    Hi, you have done splendid work, and thanks for your code. I have a question about the boundary loss. In your code, the boundary loss does not use the boundary: loss_boundary = criteria(msk_pred, msk). Also, when I ran the code, I found that the boundary extracted by the function "def Bextraction (img)" is not the edge, as shown in the attached image.

    It would be very kind of you to answer my question. Thanks.

    opened by GuoQingqing 1
  • FileNotFoundError when running ISIC2018 demo

    I want to run the demo on the ISIC2018 dataset. I followed the README, but in the third step, when I run train_skin.py, this error occurs:

    Traceback (most recent call last):
      File "train_skin.py", line 36, in <module>
        train_dataset = isic_loader(path_Data = data_path, train = True)
      File "/kaggle/working/TMUnet/loader.py", line 44, in __init__
        self.data   = np.load(path_Data+'data_train.npy')
      File "/opt/conda/lib/python3.7/site-packages/numpy/lib/npyio.py", line 417, in load
        fid = stack.enter_context(open(os_fspath(file), "rb"))
    FileNotFoundError: [Errno 2] No such file or directory: './processed_data/isic18/data_train.npy'
    

    I think this is because Prepare_ISIC2018.py does not set the output path correctly: data_train.npy, mask_train.npy, and the other generated .npy files are saved in the main directory.

    opened by alibalapour 1
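A possible workaround for the path mismatch reported above, sketched under the assumption that the generated files match `data_*.npy` and `mask_*.npy` and sit in the directory where the preparation script was run. The destination `./processed_data/isic18` is taken from the traceback; the helper name `move_prepared_arrays` is hypothetical and not part of the repository.

```python
import shutil
from pathlib import Path


def move_prepared_arrays(src_dir=".", dst_dir="./processed_data/isic18"):
    """Move data_*.npy and mask_*.npy files from src_dir into dst_dir."""
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    moved = []
    # glob is non-recursive, so files already inside dst_dir are not touched.
    for pattern in ("data_*.npy", "mask_*.npy"):
        for path in sorted(src.glob(pattern)):
            shutil.move(str(path), str(dst / path.name))
            moved.append(path.name)
    return moved
```

Calling `move_prepared_arrays()` from the repository root before running train_skin.py should place the arrays where the loader in the traceback looks for them, provided the file-name assumption holds.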
  • I have a question

    Hello author, it is an honor to read your article. While reading your paper, I had a question. You mention in the data processing section that the images are resized to 256×256, but in FAT-Net, which you compare against, the image resolution is 224×224. The comparison results you report do not seem to reflect this difference. Why is that?

    opened by hxp2396 1
  • CUDNN_STATUS_NOT_INITIALIZED error

    .
    .
    .
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [54,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [55,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [56,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [57,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [58,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [59,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [60,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [61,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [62,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [63,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    Traceback (most recent call last):
      File "train_skin.py", line 79, in <module>
        tloss.backward()
      File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
      File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
        allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
    

    I faced this error when I ran the ISIC2018 demo you provided. I think the input of BCELoss is not valid, leading to the error. Note that my notebook's torch and torchvision are updated.

    opened by alibalapour 0
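The `input_val >= zero && input_val <= one` assertion in the log above is raised when a value passed as the input (prediction) of the binary cross-entropy loss lies outside [0, 1] or is NaN. A small diagnostic such as the following, run on predictions and masks converted to NumPy arrays, can help narrow down the cause. The function name `check_bce_inputs` is hypothetical, and checking the target as well is an extra sanity check (for example for masks stored as 0/255) rather than something the assertion itself covers.

```python
import numpy as np


def check_bce_inputs(pred, target):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    for name, values in (("prediction", pred), ("target", target)):
        arr = np.asarray(values, dtype=np.float64)
        if arr.size == 0:
            continue
        if np.isnan(arr).any():
            problems.append(f"{name} contains NaN")
        elif arr.min() < 0.0 or arr.max() > 1.0:
            problems.append(
                f"{name} range [{arr.min():g}, {arr.max():g}] is outside [0, 1]"
            )
    return problems
```

If the prediction is out of range, the model output is probably missing a sigmoid or has diverged to NaN; if only the target is out of range, rescaling the masks to [0, 1] during data preparation is the more likely fix.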
  • AttributeError: module 'scipy.misc' has no attribute 'imread'

    Hi, thank you for your excellent work. When I set up the virtual environment following requirement.txt and ran Prepare_ISIC2017.py, it reported AttributeError: module 'scipy.misc' has no attribute 'imread'. The version of SciPy I used is 1.4.1. Could you tell me how to solve this problem?

    opened by Huaqitao 0
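`scipy.misc.imread` was deprecated in SciPy 1.0 and removed in SciPy 1.2, so it is not available in 1.4.1. SciPy's deprecation notice pointed to `imageio.imread` as a replacement; another option, sketched below, is a small Pillow-based reader. The name `imread_rgb` is hypothetical, and forcing RGB output is an assumption about what the preparation scripts expect.

```python
import numpy as np
from PIL import Image


def imread_rgb(path):
    """Read an image file into an (H, W, 3) uint8 NumPy array."""
    with Image.open(path) as img:
        # convert("RGB") normalizes grayscale, palette, and RGBA inputs.
        return np.array(img.convert("RGB"))
```

In a preparation script, a call such as `scipy.misc.imread(path)` could then be replaced with `imread_rgb(path)`, after checking that downstream code does not rely on a different channel layout.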
Owner
Reza Azad
Deep Learning and Computer Vision Researcher