Pytorch Code for "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"

Overview

Medical-Transformer

Pytorch Code for the paper "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"

About this repo:

This repo hosts the code for the following networks:

  1. Gated Axial Attention U-Net
  2. MedT

Introduction

The majority of existing Transformer-based network architectures proposed for vision applications require large-scale datasets to train properly. However, compared to the datasets for vision applications, the number of data samples available for medical imaging is relatively low, making it difficult to efficiently train transformers for medical applications. To this end, we propose a Gated Axial-Attention model which extends the existing architectures by introducing an additional control mechanism in the self-attention module. Furthermore, to train the model effectively on medical images, we propose a Local-Global training strategy (LoGo) which further improves the performance. Specifically, we operate on the whole image and on patches to learn global and local features, respectively. The proposed Medical Transformer (MedT) uses the LoGo training strategy on the Gated Axial Attention U-Net.
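
A minimal, self-contained sketch of one gated axial-attention pass (along a single axis) is shown below. This is not the repository's axialnet.py implementation; the module name, tensor shapes, positional-embedding layout and gate initial values are simplifying assumptions, and only the overall structure (relative positional terms for queries, keys and values, each scaled by a learnable gate) follows the paper's description.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GatedAxialAttention1D(nn.Module):
        # Illustrative sketch only: NOT the repository's axialnet.py module.
        # One self-attention pass along a single spatial axis, with learnable
        # gates (g_q, g_k, g_v1, g_v2) controlling how much the relative
        # positional terms contribute.
        def __init__(self, dim, heads=8, axis_len=128):
            super().__init__()
            assert dim % heads == 0
            self.heads = heads
            self.dim_head = dim // heads
            self.to_qkv = nn.Conv1d(dim, dim * 3, kernel_size=1, bias=False)
            # Relative positional embeddings for the query, key and value terms.
            self.rel_q = nn.Parameter(torch.randn(self.dim_head, axis_len, axis_len) * 0.02)
            self.rel_k = nn.Parameter(torch.randn(self.dim_head, axis_len, axis_len) * 0.02)
            self.rel_v = nn.Parameter(torch.randn(self.dim_head, axis_len, axis_len) * 0.02)
            # Gates (initial values here are assumptions, not the repo's exact settings).
            self.g_q = nn.Parameter(torch.tensor(0.1))
            self.g_k = nn.Parameter(torch.tensor(0.1))
            self.g_v1 = nn.Parameter(torch.tensor(1.0))
            self.g_v2 = nn.Parameter(torch.tensor(0.1))

        def forward(self, x):
            # x: (batch, channels, axis_len) -- one image axis at a time.
            b, c, n = x.shape
            qkv = self.to_qkv(x).reshape(b, self.heads, 3 * self.dim_head, n)
            q, k, v = qkv.split(self.dim_head, dim=2)  # each (b, heads, dim_head, n)

            # Content-content logits plus gated content-position logits.
            logits = torch.einsum('bhdi,bhdj->bhij', q, k)
            logits = logits + self.g_q * torch.einsum('bhdi,dij->bhij', q, self.rel_q)
            logits = logits + self.g_k * torch.einsum('bhdj,dij->bhij', k, self.rel_k)
            attn = F.softmax(logits, dim=-1)

            # Gated mixture of content values and positional values.
            out = self.g_v1 * torch.einsum('bhij,bhdj->bhdi', attn, v)
            out = out + self.g_v2 * torch.einsum('bhij,dij->bhdi', attn, self.rel_v)
            return out.reshape(b, c, n)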

Using the code:

  • Clone this repository:
git clone https://github.com/jeya-maria-jose/Medical-Transformer
cd Medical-Transformer

The code is stable with Python 3.6.10 and PyTorch 1.4.0.

To install all the dependencies using conda:

conda env create -f environment.yml
conda activate medt

To install all the dependencies using pip:

pip install -r requirements.txt

Links for downloading the public datasets:

  1. GLAS Dataset - Link (Original) | Link (Resized)
  2. MoNuSeG Dataset - Link (Original)
  3. Brain Anatomy US dataset from the paper will be made public soon!

Using the Code for your dataset

Dataset Preparation

Prepare the dataset in the following format for easy use of the code. The train and test folders should each contain two subfolders: img and label. Make sure each image and its corresponding segmentation mask are placed under these folders and share the same file name so they can be matched easily. Please adapt the data loaders to your needs if you prefer not to prepare the dataset in this format (a minimal loader sketch is given after the layout below).

Train Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
Validation Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
Test Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
  • The ground truth images should have pixels corresponding to the labels. Example: In case of binary segmentation, the pixels in the GT should be 0 or 255.
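
Below is a minimal sketch of a paired image/mask loader for the layout above, including binarisation of 0/255 ground-truth masks. It is only an illustration under assumed defaults (RGB inputs, 128 x 128 resizing, PNG files); the repository's actual loader lives in utils.py and may differ in transforms and channel handling.

    import os
    import numpy as np
    import torch
    from PIL import Image
    from torch.utils.data import Dataset

    class SimplePairedSegDataset(Dataset):
        # Illustrative loader for the img/label folder layout described above.
        def __init__(self, root, img_size=128):
            self.img_dir = os.path.join(root, 'img')
            self.label_dir = os.path.join(root, 'label')
            self.names = sorted(os.listdir(self.img_dir))
            self.img_size = img_size

        def __len__(self):
            return len(self.names)

        def __getitem__(self, idx):
            name = self.names[idx]
            image = Image.open(os.path.join(self.img_dir, name)).convert('RGB')
            mask = Image.open(os.path.join(self.label_dir, name)).convert('L')
            image = image.resize((self.img_size, self.img_size), Image.BILINEAR)
            mask = mask.resize((self.img_size, self.img_size), Image.NEAREST)

            # Scale the image to [0, 1] and move channels first.
            image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
            # Binarise the ground truth: pixels are expected to be 0 or 255.
            mask = (np.array(mask) > 127).astype(np.int64)
            return image, torch.from_numpy(mask)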

Training Command:

python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no"
Change --modelname to "MedT" or "logo" to train those models instead.
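
For example, to train MedT with otherwise identical placeholder arguments:

python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "MedT" --learning_rate 0.001 --imgsize 128 --gray "no"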

Testing Command:

python test.py --loaddirec "./saved_model_path/model_name.pth" --val_dataset "test dataset directory" --direc 'path for results to be saved' --batch_size 1 --modelname "gatedaxialunet" --imgsize 128 --gray "no"

The results, including the predicted segmentation maps, will be placed in the results folder along with the model weights. Run the performance metrics code in MATLAB to calculate the F1 score and mIoU.
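
For a quick sanity check without MATLAB, a minimal Python sketch of binary F1 (Dice) and IoU for a single predicted mask could look like the following. This is an illustrative stand-in, not the official evaluation code used for the paper's numbers.

    import numpy as np

    def f1_and_iou(pred, gt):
        # pred, gt: 2-D arrays where foreground pixels are > 0.
        pred = (np.asarray(pred) > 0).astype(np.float64)
        gt = (np.asarray(gt) > 0).astype(np.float64)
        tp = (pred * gt).sum()          # true positives
        fp = (pred * (1 - gt)).sum()    # false positives
        fn = ((1 - pred) * gt).sum()    # false negatives
        f1 = 2 * tp / (2 * tp + fp + fn + 1e-8)   # Dice / F1 score
        iou = tp / (tp + fp + fn + 1e-8)          # intersection over union
        return f1, iou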

Notes:

Note that these experiments were conducted on an Nvidia Quadro 8000 GPU with 48 GB of memory.

Acknowledgement:

The dataloader code is inspired by pytorch-UNet. The axial attention code is developed from axial-deeplab.

Citation:

To add

Open an issue or mail me directly in case of any queries or suggestions.

Comments
  • test error: lower accuracy

    Thank you for your excellent contribution; I have reproduced the experiments from the paper. An error occurred during testing, and the test result differs considerably from the one reported in the paper. Could you please update test.py?

    opened by hgmlu 11
  • when i run the command that python train.py --train_dataset

    When I run the command python train.py --train_dataset "E:\Medical-Transformer-main\datasets\train" --val_dataset "E:\Medical-Transformer-main\datasets\val" --direc 'E:\Medical-Transformer-main\path' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no", the following error is displayed:

    Total_params: 1326850
    Traceback (most recent call last):
      File "train.py", line 130, in <module>
        for batch_idx, (X_batch, y_batch, *rest) in enumerate(dataloader):
      File "D:\anaconda1\envs\Medical-transformer\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
        data = self._next_data()
      File "D:\anaconda1\envs\Medical-transformer\lib\site-packages\torch\utils\data\dataloader.py", line 561, in _next_data
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
      File "D:\anaconda1\envs\Medical-transformer\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "D:\anaconda1\envs\Medical-transformer\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "E:\Medical-Transformer-main\utils.py", line 156, in __getitem__
        mask[mask<=127] = 0
    TypeError: '<=' not supported between instances of 'NoneType' and 'int'

    How can this error be resolved? Is it a problem with the dataset preparation?

    opened by 123xu223 6
  • Larger training image input size results in a "cuda out of memory" error

    Hi, sorry to bother you, I really appreciate your work. I recently trained the axialunet with input size 473 and got a CUDA out-of-memory error. I also noticed that you seem to have commented out the decoder1, decoder2, and decoder3 code in medt_net; could you please tell me why? Is it because of the memory consumption problem? Thanks in advance.

    good first issue 
    opened by weiaicunzai 4
  • Problem with the MoNuSegTrainingData labels

    I could not find the labels in the Annotations folder; it contains only .xml files. Please help me obtain the labels so that I can train. Thank you.

    opened by tphankr 3
  • How to use test.py?

    Dear @jeya-maria-jose, can you help me? I have a problem with test.py:

    (medt) D:\segmentation\Medical-Transformer-main>python test.py --loaddirec "resultfinal_model.pth" --val_dataset "data_6449_s256/test" --direc '290/gatedaxialunet.pth' --batch_size 1 --modelname "gatedaxialunet" --imgsize 256 --gray "no"
    None
    Traceback (most recent call last):
      File "test.py", line 82, in <module>
        train_dataset = ImageToImage2D(args.train_dataset, tf_train)
      File "D:\segmentation\Medical-Transformer-main\utils.py", line 131, in __init__
        self.input_path = os.path.join(dataset_path, 'img')
      File "C:\Users\DELL\anaconda3\envs\medt\lib\ntpath.py", line 76, in join
        path = os.fspath(path)
    TypeError: expected str, bytes or os.PathLike object, not NoneType

    opened by NguyenDangBinh 3
  • The question of the gated axial attention formula

    Excuse me, for the gated axial-attention formula I can understand the gating parameters at the three positions Gqr, Gkr and Gr, but why do you add a gating parameter on v, which has no positional information? Thank you!

    opened by wang-yu-xuan 3
  • Image Size

    Thank you for your excellent code. Could you tell me what image size you used to train your network? If I want to use my own images, for example of size 512 x 512, what parameters should I modify? When I tried to re-implement your project, I encountered some dimension issues, such as:

    RuntimeError: einsum() operands do not broadcast with remapped shapes [original->remapped]: [1024, 8, 1, 512]->[1024, 8, 512, 1, 1] [1, 128, 128]->[1, 1, 128, 128, 1]

    Could you give me some hints?

    opened by jm-R152 3
  • Image size when training with the Glas dataset

    Thanks for your great work; I need some help training on the Glas dataset. Could you please tell me the image size you used for training? The images in the Glas dataset have different sizes (775 x 522, 589 x 453, 567 x 430), which need to be converted to the same shape before being sent to the network, and 128 x 128 seems too small for the Glas dataset. Thanks.

    good first issue 
    opened by weiaicunzai 3
  • Gated Mechanism

    Why do the gated parameters all have requires_grad=False?

        # Priority on encoding

        ## Initial values 
    
        self.f_qr = nn.Parameter(torch.tensor(0.1),  requires_grad=False)
        self.f_kr = nn.Parameter(torch.tensor(0.1),  requires_grad=False)
        self.f_sve = nn.Parameter(torch.tensor(0.1),  requires_grad=False)
        self.f_sv = nn.Parameter(torch.tensor(1.0),  requires_grad=False)
    
    opened by QiaoSiBo 2
  • About utils.py in lib.models: code problem

    When reading your code I found that, in axialnet.py, AxialAttention_wopos(nn.Module) and AxialAttention_dynamic(nn.Module) call the qkv_transform() method from lib.models.utils.py. But in lib.models.utils.py, qkv_transform(nn.Conv1d) contains only comments. Is this a problem with the code, or is it intentionally left without content?

    opened by Jx-Tan 2
  • Hello, the loss value has been 0 since the first epoch of training. How can I solve it?

    libpng warning: iCCP: known incorrect sRGB profile
    epoch [0/400], loss:0.2294
    epoch [1/400], loss:0.0000
    epoch [2/400], loss:0.0000
    ...
    epoch [12/400], loss:0.0000

    opened by 123xu223 2
  • Question: the einsum() function reports an error

    RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [277, 8, 1, 234]->[277, 8, 234, 1, 1] [1, 64, 64]->[1, 1, 64, 64, 1]

    opened by orderer0001 2
  • Question about TransWeather

    Hey @jeya-maria-jose,

    The TransWeather work in CVPR 2022 is really interesting.

    However, I found that a file named utils.py is missing, and how to organize the train/test dataset is also not clear.

    May I ask if you can take care of these issues?

    Bests,

    opened by Amazingren 1
  • input data

    I used the GLAS dataset, but my training results are black; I think this may be caused by the dataset. I want to know how I should prepare or process the dataset. If possible, could you please share how your dataset is set up? Thanks!

    opened by wangmengyao123 1
  • RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0

    When I run the command python train.py --train_dataset "Traindata" --val_dataset "Traindata" --direc 'result' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no", the following error occurs. Can you help me fix this issue?

    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 522 and 453 in dimension 2 at /opt/conda/conda-bld/pytorch_1579022034529/work/aten/src/TH/generic/THTensor.cpp:612

    opened by Deerzh 1
  • pred images are black

    I found that the output tensors are all < 0, so after tmp[tmp>=0.5] = 1, tmp[tmp<0.5] = 0 the saved masks are black. Can you give some advice? Thanks.

    opened by scwuchung 0
Owner
Jeya Maria Jose
PhD Student at Johns Hopkins University.