Contextual Attention Network: Transformer Meets U-Net

Overview

Contextual Attention Network for medical image segmentation, with state-of-the-art results on skin lesion segmentation and multiple myeloma cell segmentation. The method incorporates a transformer module into a U-Net structure to concurrently capture long-range dependencies and rich local information. If this code helps with your research, please consider citing the following paper:

Reza Azad, Moein Heidari, Yuli Wu and Dorit Merhof, "Contextual Attention Network: Transformer Meets U-Net", arXiv preprint arXiv:2203.01932, 2022.

@article{reza2022contextual,
  title={Contextual Attention Network: Transformer Meets U-Net},
  author={Azad, Reza and Heidari, Moein and Wu, Yuli and Merhof, Dorit},
  journal={arXiv preprint arXiv:2203.01932},
  year={2022}
}

Please consider starring the repository if you find it useful. Thanks!

Updates

This code has been implemented in Python using the PyTorch library and tested on Ubuntu, though it should be compatible with similar environments. The following environment and libraries are needed to run the code:

  • Python 3
  • PyTorch
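Before running the demo, a quick way to confirm the environment is to check the Python version and whether PyTorch can be imported. The snippet below is a minimal sketch using only the standard library; `check_environment` is a hypothetical helper, not part of the repository, and it does not test for a particular PyTorch version because none is stated above.

```python
import importlib.util
import sys


def check_environment() -> dict:
    """Report the Python version and whether PyTorch (and CUDA) are usable.

    Hypothetical helper, not part of the repository.
    """
    report = {
        "python": ".".join(str(v) for v in sys.version_info[:3]),
        "python3": sys.version_info.major == 3,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "torch_version": None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check still runs without PyTorch

        report["torch_version"] = torch.__version__
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If `torch_installed` is False, install PyTorch before continuing with the steps below; `cuda_available` only indicates whether a GPU can be used.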

Run Demo

To train the deep model and evaluate it on each dataset, follow the steps below:
1- Download the ISIC 2018 training dataset from this link and extract both the training-data and ground-truth folders inside dataset_isic18.
2- Run Prepare_ISIC2018.py for data preparation and for dividing the data into training, validation and test sets.
3- Run train_skin.py to train the model using the training and validation sets. The model will be trained for 100 epochs, and the weights that perform best on the validation set will be saved.
4- To compute the performance measures and produce segmentation results, run evaluate_skin.py. It will report the performance measures and save the related results in the results folder.
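Step 2 divides the data into training, validation and test sets. A minimal NumPy sketch of such a split is shown below. The 70/10/20 ratios, the fixed seed and the function name `split_dataset` are assumptions for illustration; Prepare_ISIC2018.py may use different proportions or a different procedure.

```python
import numpy as np


def split_dataset(images, masks, ratios=(0.7, 0.1, 0.2), seed=0):
    """Shuffle paired images/masks and split them into train/val/test.

    Hypothetical helper; the default ratios are an assumption.
    """
    if len(images) != len(masks):
        raise ValueError("images and masks must have the same length")
    if not np.isclose(sum(ratios), 1.0):
        raise ValueError("ratios must sum to 1")

    n = len(images)
    # One permutation shared by images and masks keeps each pair aligned.
    order = np.random.default_rng(seed).permutation(n)
    n_train = int(round(ratios[0] * n))
    n_val = int(round(ratios[1] * n))

    idx_train = order[:n_train]
    idx_val = order[n_train:n_train + n_val]
    idx_test = order[n_train + n_val:]  # the remainder goes to the test set

    return {
        "train": (images[idx_train], masks[idx_train]),
        "val": (images[idx_val], masks[idx_val]),
        "test": (images[idx_test], masks[idx_test]),
    }


if __name__ == "__main__":
    images = np.zeros((10, 64, 64, 3), dtype=np.uint8)  # dummy data for the demo
    masks = np.zeros((10, 64, 64), dtype=np.uint8)
    splits = split_dataset(images, masks)
    print({name: len(x) for name, (x, _) in splits.items()})
```

With 10 samples and the default ratios this yields 7 training, 1 validation and 2 test samples. Each split could then be written to disk with `np.save` for the training script to load.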

Notice: For training and evaluating on ISIC 2017 and PH2, follow the steps below:

ISIC 2017- Download the ISIC 2017 training dataset from this link and extract both the training-data and ground-truth folders inside dataset_isic18\7, then run Prepare_ISIC2017.py for data preparation and for dividing the data into training, validation and test sets.
PH2- Download the PH2 dataset from this link, extract it, and then run Prepare_ph2.py for data preparation and for dividing the data into training, validation and test sets.
Follow steps 3 and 4 for model training and performance estimation. For the PH2 dataset, you first need to train the model on the ISIC 2017 dataset and then fine-tune the trained model on the PH2 dataset.
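Step 3 keeps only the weights that perform best on the validation set, and PH2 fine-tuning starts from such a checkpoint. The selection logic can be sketched framework-agnostically as below. `BestCheckpoint`, the `save_fn` callback and the rule "lower validation loss is better" are assumptions for illustration; in a PyTorch script, `save_fn` would typically call `torch.save(model.state_dict(), path)`.

```python
from typing import Callable, Optional


class BestCheckpoint:
    """Track the best validation loss and save only when it improves.

    Hypothetical helper; assumes a lower validation loss is better.
    """

    def __init__(self, save_fn: Callable[[int], None]):
        self.save_fn = save_fn  # called with the epoch number on improvement
        self.best_loss: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def update(self, epoch: int, val_loss: float) -> bool:
        """Return True (and save) if val_loss is the best seen so far."""
        if self.best_loss is None or val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.save_fn(epoch)
            return True
        return False


if __name__ == "__main__":
    saved = []
    tracker = BestCheckpoint(save_fn=saved.append)
    fake_val_losses = [0.9, 0.7, 0.8, 0.6, 0.65]  # made-up numbers for the demo
    for epoch, loss in enumerate(fake_val_losses):
        tracker.update(epoch, loss)
    print(tracker.best_epoch, tracker.best_loss, saved)  # 3 0.6 [0, 1, 3]
```

In a 100-epoch run, `update` would be called once per epoch after validation; the last file written is the best checkpoint, which is the natural starting point for fine-tuning on another dataset.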

Quick Overview

Diagram of the proposed method

Perceptual visualization of the proposed Contextual Attention module.
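The overview describes adding a transformer module to a U-Net to capture long-range dependencies. As an illustration of that general idea only, and not the paper's Contextual Attention module, the sketch below applies single-head scaled dot-product self-attention across all spatial positions of a feature map in NumPy. The function name, the random projection matrices and the residual connection are assumptions.

```python
import numpy as np


def spatial_self_attention(feat, w_q, w_k, w_v):
    """Single-head self-attention over the H*W positions of a (C, H, W) map.

    Every output position is a weighted sum over *all* positions, which is
    how attention captures long-range dependencies that a small convolution
    kernel cannot reach in a single layer.
    """
    c, h, w = feat.shape
    tokens = feat.reshape(c, h * w).T            # (N, C) with N = H*W
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(k.shape[1])       # (N, N) pairwise similarities
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=1, keepdims=True)      # softmax over positions
    out = tokens + attn @ v                      # residual connection (assumption)
    return out.T.reshape(c, h, w), attn


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    c, d = 8, 4
    feat = rng.standard_normal((c, 16, 16))
    out, attn = spatial_self_attention(
        feat,
        rng.standard_normal((c, d)),
        rng.standard_normal((c, d)),
        rng.standard_normal((c, c)),
    )
    print(out.shape, attn.shape)  # (8, 16, 16) (256, 256)
```

The (N, N) attention matrix grows quadratically with the number of pixels, which is one reason transformer blocks are usually applied to downsampled U-Net features rather than full-resolution images.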

Results

To evaluate the performance of the proposed method, two challenging medical image segmentation tasks were considered. The results of the proposed approach are presented below.

Task 1: Skin Lesion Segmentation

Performance Comparison on Skin Lesion Segmentation

To compare the proposed method with state-of-the-art approaches on skin lesion segmentation, we considered the ISIC 2017 dataset.

Method (on ISIC 2017)           Dice score  Sensitivity  Specificity  Accuracy
Ronneberger et al., U-Net       0.8159      0.8172       0.9680       0.9164
Oktay et al., Attention U-Net   0.8082      0.7998       0.9776       0.9145
Lei et al., DAGAN               0.8425      0.8363       0.9716       0.9304
Chen et al., TransUNet          0.8123      0.8263       0.9577       0.9207
Asadi et al., MCGU-Net          0.8927      0.8502       0.9855       0.9570
Valanarasu et al., MedT         0.8037      0.8064       0.9546       0.9090
Wu et al., FAT-Net              0.8500      0.8392       0.9725       0.9326
Azad et al., Proposed TMUnet    0.9164      0.9128       0.9789       0.9660
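The table reports Dice score, sensitivity, specificity and accuracy. These follow the standard confusion-matrix definitions for binary masks: Dice = 2TP / (2TP + FP + FN), Sensitivity = TP / (TP + FN), Specificity = TN / (TN + FP), Accuracy = (TP + TN) / (TP + TN + FP + FN). Below is a minimal NumPy sketch; the 0.5 threshold, treating any nonzero ground-truth value as foreground, and the function name are assumptions, and evaluate_skin.py may compute the measures differently (for example per image rather than over all test pixels).

```python
import numpy as np


def binary_segmentation_metrics(pred_prob, gt, threshold=0.5):
    """Dice, sensitivity, specificity and accuracy from pixel-wise counts.

    pred_prob: predicted foreground probabilities in [0, 1].
    gt: ground-truth mask; any nonzero value counts as foreground.
    """
    pred = np.asarray(pred_prob) >= threshold  # binarize predictions
    gt = np.asarray(gt) != 0                   # works for {0,1} and {0,255} masks

    tp = np.sum(pred & gt)
    tn = np.sum(~pred & ~gt)
    fp = np.sum(pred & ~gt)
    fn = np.sum(~pred & gt)

    eps = 1e-8  # avoids division by zero on empty masks
    return {
        "dice": 2 * tp / (2 * tp + fp + fn + eps),
        "sensitivity": tp / (tp + fn + eps),
        "specificity": tn / (tn + fp + eps),
        "accuracy": (tp + tn) / (tp + tn + fp + fn + eps),
    }


if __name__ == "__main__":
    gt = np.array([[0, 255], [255, 0]])
    prob = np.array([[0.1, 0.9], [0.3, 0.2]])
    m = binary_segmentation_metrics(prob, gt)
    print({k: round(float(v), 3) for k, v in m.items()})
    # {'dice': 0.667, 'sensitivity': 0.5, 'specificity': 1.0, 'accuracy': 0.75}
```

Binarizing both inputs first matters: comparing raw probabilities or 0/255 masks directly against {0, 1} labels can leave some confusion-matrix cells empty.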

For more results on the ISIC 2018 and PH2 datasets, please refer to the paper.

Skin lesion segmentation results on test data

Skin lesion segmentation results: (a) Input images. (b) Ground truth. (c) U-Net. (d) Gated Axial-Attention. (e) Proposed method without the contextual attention module. (f) Proposed method.

Task 2: Multiple Myeloma Cell Segmentation

Performance Evaluation on the Multiple Myeloma Cell Segmentation Task

Method                              mIoU
Frequency recalibration U-Net       0.9392
XLAB Insights                       0.9360
DSC-IITISM                          0.9356
Multi-scale attention DeepLabv3+    0.9065
U-Net                               0.7665
Baseline                            0.9172
Proposed                            0.9395
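The myeloma results are reported as mean intersection over union (mIoU). One common definition computes, for each class c, IoU_c = TP_c / (TP_c + FP_c + FN_c) and averages over classes; the sketch below follows that definition for integer label maps. Whether the reported numbers average over classes, images or cell instances is not stated here, so the averaging scheme and the function name `mean_iou` are assumptions.

```python
import numpy as np


def mean_iou(pred, gt, num_classes):
    """Mean IoU over classes present in the prediction or the ground truth."""
    pred = np.asarray(pred).ravel()
    gt = np.asarray(gt).ravel()
    ious = []
    for c in range(num_classes):
        p, g = pred == c, gt == c
        union = np.sum(p | g)  # TP + FP + FN for class c
        if union == 0:
            continue  # class absent from both maps: skip rather than score it
        ious.append(np.sum(p & g) / union)  # TP / (TP + FP + FN)
    return float(np.mean(ious)) if ious else float("nan")
```

For example, with ground truth [[0, 0], [1, 1]] and prediction [[0, 1], [1, 1]], class 0 has IoU 1/2 and class 1 has IoU 2/3, giving an mIoU of about 0.583.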

Multiple Myeloma Cell Segmentation results

Multiple Myeloma Cell Segmentation result

Model weights

You can download the learned weights for each dataset in the following table.

Dataset      Learned weights
ISIC 2018    TMUnet
ISIC 2017    TMUnet
PH2          TMUnet

Query

All implementations were done by Reza Azad and Moein Heidari. For any questions, please contact us:

rezazad68@gmail.com
moeinheidari7829@gmail.com
Issues
  • About the evaluation result

    Hi! You have done splendid work, and thanks for your code. I have a question about the evaluation results. I used the learned weights you provided to evaluate your model without changing any code, but I found that the confusion matrix always contains 0, as shown in the attached screenshot (TMUNet). It would be very kind of you to answer my question. Thanks.

    opened by GuoQingqing
  • About the boundary loss

    Hi, you have done splendid work, and thanks for your code. I have a question about the boundary loss. In your code, the boundary loss does not actually use the boundary: loss_boundary = criteria(msk_pred, msk). Also, when I ran the code, I found that the boundary extracted by the function "def Bextraction (img)" is not the edge, as shown in the attached image.

    It would be very kind of you to answer my question. Thanks.

    opened by GuoQingqing
  • I have a question

    Hello author, it is an honor to read your article. While reading your paper, I had a question. You mentioned in the data processing section that the images are resized to 256×256, but FAT-Net, which you compared against, uses an image resolution of 224×224, yet the comparison results do not seem to have changed. Why is this?

    opened by hxp2396
  • CUDNN_STATUS_NOT_INITIALIZED error

    .
    .
    .
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [54,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [55,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [56,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [57,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [58,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [59,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [60,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [61,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [62,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    /pytorch/aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [487,0,0], thread: [63,0,0] Assertion `input_val >= zero && input_val <= one` failed.
    Traceback (most recent call last):
      File "train_skin.py", line 79, in <module>
        tloss.backward()
      File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
      File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
        allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
    

    I encountered this error when running the ISIC 2018 demo you provided. I think the input to BCELoss is not valid (the assertion requires values in [0, 1]), which leads to the error. Note that the torch and torchvision packages in my notebook are up to date.

    opened by alibalapour
  • FileNotFoundError when running ISIC2018 demo

    I want to run the demo on the ISIC 2018 dataset. I followed the instructions in the README, but in step 3, when I run train_skin.py, the following error occurs:

    Traceback (most recent call last):
      File "train_skin.py", line 36, in <module>
        train_dataset = isic_loader(path_Data = data_path, train = True)
      File "/kaggle/working/TMUnet/loader.py", line 44, in __init__
        self.data   = np.load(path_Data+'data_train.npy')
      File "/opt/conda/lib/python3.7/site-packages/numpy/lib/npyio.py", line 417, in load
        fid = stack.enter_context(open(os_fspath(file), "rb"))
    FileNotFoundError: [Errno 2] No such file or directory: './processed_data/isic18/data_train.npy'
    

    I think this is because the paths in Prepare_ISIC2018.py are not set correctly: data_train.npy, mask_train.npy, and the other generated .npy files are saved in the main directory.

    opened by alibalapour
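The boundary-loss issue above concerns how the edge of a mask is extracted. The repository's Bextraction function is not reproduced here; as a generic alternative, the boundary of a binary mask can be taken as its morphological gradient (dilation minus erosion). The sketch below implements a 3×3 version in plain NumPy; the function name `mask_boundary` and the 3×3 square neighbourhood are assumptions, not the repository's method.

```python
import numpy as np


def mask_boundary(mask):
    """Boundary of a binary mask via a 3x3 morphological gradient.

    boundary = dilate(mask) AND NOT erode(mask): the pixels whose 3x3
    neighbourhood contains both foreground and background.
    """
    m = np.asarray(mask) != 0
    h, w = m.shape
    # Pad with background so objects touching the border still get an edge.
    p = np.pad(m, 1, mode="constant", constant_values=False)
    # The 9 shifted views give every pixel's 3x3 neighbourhood.
    shifts = [p[i:i + h, j:j + w] for i in range(3) for j in range(3)]
    dilated = np.logical_or.reduce(shifts)   # any foreground neighbour
    eroded = np.logical_and.reduce(shifts)   # all neighbours foreground
    return (dilated & ~eroded).astype(np.uint8)
```

For a filled 3×3 square inside a 5×5 image, this marks 24 of the 25 pixels: every pixel except the square's centre, giving a band two pixels wide that straddles the object's edge. A thinner edge can be obtained by using only `m & ~eroded` (the inner ring).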