Medical Image Segmentation using Squeeze-and-Expansion Transformers

Introduction

This repository contains the code for the IJCAI 2021 paper 'Medical Image Segmentation using Squeeze-and-Expansion Transformers'.

Installation

This repository is based on PyTorch 1.7.

To evaluate SETR, you need to install mmcv by following the instructions at https://github.com/fudan-zvg/SETR/.

Usage Example

python3.7 train2d.py --task refuge --split all --net segtran --bb resnet101 --translayers 3 --layercompress 1,1,2,2 --maxiter 10000

python3.7 test2d.py --task refuge --split all --ds valid2 --net segtran --bb resnet101 --translayers 3 --layercompress 1,1,2,2 --cpdir ../model/segtran-refuge-train,valid,test,drishiti,rim-05101448 --iters 7000
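
The --layercompress option takes a comma-separated list of compression ratios. The sketch below is only an inference from the log quoted in the first issue under Comments, where translayer_compress_ratios=[1.0, 1.0, 2.0, 2.0] with an input feature dimension of 2048 yields translayer_dims [2048, 2048, 1024, 512]. It assumes each ratio divides the previous dimension cumulatively; the function name translayer_dims_from_ratios is hypothetical and is not taken from the repository code.

```python
def translayer_dims_from_ratios(in_feat_dim, layercompress):
    """Hypothetical helper: derive per-layer feature dims from compress ratios.

    Assumption: each ratio divides the previous dimension cumulatively.
    This matches the log quoted on this page, but it is not copied from
    the repository code.
    """
    ratios = [float(r) for r in layercompress.split(",")]
    dims = []
    dim = float(in_feat_dim)
    for ratio in ratios:
        dim = dim / ratio          # compress relative to the previous layer
        dims.append(int(dim))
    return dims


if __name__ == "__main__":
    # Same values as in the usage example: ResNet-101 features (2048 channels)
    print(translayer_dims_from_ratios(2048, "1,1,2,2"))  # [2048, 2048, 1024, 512]
```

Under this assumption, --translayers 3 with four ratios gives four dimensions: one for the input features and one for the output of each transformer layer.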

Acknowledgement

The "receptivefield" folder is from https://github.com/fornaxai/receptivefield/, with minor edits and bug fixes.

The "MNet_DeepCDR" folder is from https://github.com/HzFu/MNet_DeepCDR, with minor customizations.

The "efficientnet" folder is from https://github.com/lukemelas/EfficientNet-PyTorch, with minor customizations.

The "networks/setr" folder is a slimmed-down version of https://github.com/fudan-zvg/SETR/, with a few custom config files.

There are a few baseline models under networks/ that were originally implemented in various GitHub repos. I won't acknowledge them individually here.

Some code under "dataloaders/" (esp. 3D image preprocessing) was borrowed from https://github.com/yulequan/UA-MT.

Citation

If you find our code useful, please consider citing our paper:

@InProceedings{segtran,
author="Li, Shaohua
and Sui, Xiuchao
and Luo, Xiangde
and Xu, Xinxing
and Liu, Yong
and Goh, Rick Siow Mong",
title="Medical Image Segmentation using Squeeze-and-Expansion Transformers",
booktitle="The 30th International Joint Conference on Artificial Intelligence (IJCAI)",
year="2021",
}
Comments
  • Great project! But I encountered some problems about test

    When I run test2d.py, the following error occurred:

    python3 test2d.py --task fundus --split all --ds valid2 --net segtran --bb resnet101 --translayers 3 --layercompress 1,1,2,2 --cpdir ../model/segtran-fundus-train,valid,test,drishti,rim-05011826 --iters 9500 --outorigsize

    'fundus' mean/std loaded from 'fundus-cropped-gray0.5-stats.json'
    'all' 400 samples of size 576 chosen (total 800) in '../data/fundus/valid2'
    'args' orig in-feat: 2048, in-feat: 2048, out-feat: 512, in-scheme: AN, out-scheme: AN, translayer_dims: [2048, 2048, 1024, 512]
    Namespace(ablate_multihead=False, attn_clip=500, backbone_type='resnet101', batch_size=8, bb_feat_upsize=True, binarize=False, calc_flop=False, checkpoint_dir='../model/segtran-fundus-train,valid,test,drishti,rim-05011826', debug=False, device='cuda', do_remove_frag=False, ds_class='SegCrop', ds_name='valid2', ds_split='all', eval_robustness=False, gpu='0', gray_alpha=0.5, has_FFN_in_squeeze=False, in_fpn_layers='34', in_fpn_scheme='AN', in_fpn_use_bn=False, iters='9500', job_name='fundus-valid2', mean=[0.578, 0.429, 0.318], mid_type='shared', mince_channel_props=None, mince_scales=None, net='segtran', num_attractors=256, num_classes=3, num_modalities=0, num_modes=4, num_translayers=3, num_workers=4, orig_input_size=(576, 576), out_fpn_layers='1234', out_fpn_scheme='AN', out_origsize=True, output_upscale=2.0, patch_size=(288, 288), polyformer_mode=None, pos_bias_radius=7, pos_code_type='lsinu', pos_code_weight=1.0, qk_have_bias=True, reload_mask=False, reshape_mask_type=None, robust_aug_degrees=[0.5, 1.5], robust_aug_types=None, robust_ref_cp_path=None, robust_sample_num=120, robustness_augs=None, sample_num=-1, save_ext='png', save_features_img_count=0, save_results=True, std=[0.184, 0.162, 0.144], task_name='fundus', test_interp=None, tie_qk_scheme='none', trans_output_type='private', translayer_compress_ratios=[1.0, 1.0, 2.0, 2.0], use_exclusive_masks=False, use_global_bias=False, use_mince_transformer=False, use_pretrained=True, use_squeezed_transformer=True, verbose_output=False, vis_layers=None, vis_mode=None)
    Segtran Fusion Encoder with 3 trans-layers
    Learnable Sinusoidal positional encoding
    Fusion0-in-squeeze: v_has_bias: False, has_FFN: False, has_input_skip: False
    Fusion0-in-squeeze in_feat_dim: 2048, feat_dim: 2048, qk_have_bias: True
    Fusion0-squeeze-out: v_has_bias: False, has_FFN: True, has_input_skip: False
    Fusion0-squeeze-out in_feat_dim: 2048, feat_dim: 2048, qk_have_bias: True
    Fusion1-in-squeeze: v_has_bias: False, has_FFN: False, has_input_skip: False
    Fusion1-in-squeeze in_feat_dim: 2048, feat_dim: 2048, qk_have_bias: True
    Fusion1-squeeze-out: v_has_bias: False, has_FFN: True, has_input_skip: False
    Fusion1-squeeze-out in_feat_dim: 2048, feat_dim: 1024, qk_have_bias: True
    Fusion2-in-squeeze: v_has_bias: False, has_FFN: False, has_input_skip: False
    Fusion2-in-squeeze in_feat_dim: 1024, feat_dim: 1024, qk_have_bias: True
    Fusion2-squeeze-out: v_has_bias: False, has_FFN: True, has_input_skip: False
    Fusion2-squeeze-out in_feat_dim: 1024, feat_dim: 512, qk_have_bias: True
    Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /root/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
    **resnet101 created**
    Parameter Count: 172737073
    **args[backbone_type]=resnet101, checkpoint args[backbone_type]=eff-b4, inconsistent!**

    In test2d.py I found:

    parser.add_argument('--bb', dest='backbone_type', type=str, default='eff-b4', help='Segtran backbone')

    In resnet.py I also see that, when test2d.py is running, it downloads resnet101 from the web. But I still get this error and I don't know how to solve it. Could you tell me how to solve this problem? Thank you!

    opened by Lemonweier 14
  • brats dice seems low?

    Dear Shaohua, while checking the training process I see that the Dice score is around 9% after percent of total epochs. The total training time on 8 NVIDIA RTX 8000 GPUs is estimated at around 25 hours. What could cause such a low Dice value? Waiting to hear from you. Thanks.

    opened by bhralzz 11
  • Are checkpoints of this model available?

    Hi, thanks for your excellent work!

    I want to re-implement this model. Could you provide the trained models that achieve the SOTA results reported in your paper, so that we can use them directly for inference?

    Especially the models for REFUGE and BraTS, where Segtran performs extraordinarily well.

    Best,

    opened by MengzhangLI 9
  • how to test on other brats dataset such as 2020?

    Hi dear Askerlee, first of all, thank you for your valuable code. Please explain how I can test this code on the BraTS 2020 dataset. I placed that dataset in the brats path and set the parameters in train3d.py, but the .h5 files do not exist in the case paths. Waiting to hear from you, thanks.

    opened by bhralzz 7
  • Baseline comparison for 2D dataset

    Hi Dr. Lee, do you have the code that you used to compare with various baselines in section 5.2 (list of other models) for 2D dataset? Is there a similar comparison for the 3D dataset?

    opened by wshi8 6
  • about ModuleNotFoundError

    Hi, sir. I moved the project to a new computer and now get ModuleNotFoundError: No module named 'networks.segtran2d'. It would be great if you could give me a hint. Thank you!

    opened by Lemonweier 5
  • about calling test from another file.py

    Hi Dear Shaohua

    Is there any part of this code that lets me set just the trained model path and a new sample path from an outside file.py, and then get the result from these parameters without going through argument parsing?

    opened by bhralzz 4
  • Brats iter-8000 checkpoint is not loading on test

    Hi Dr Lee,

    I got this error when I tried to run the test command using your recently updated Brats iter_8000 checkpoint, any advice?

    python3 test3d.py --task brats --split all --bs 5 --ds 2019valid --net segtran --attractors 1024 --translayers 1 --cpdir ./ --iters 8000

    Traceback (most recent call last):
      File "test3d.py", line 426, in <module>
        allcls_avg_metric = test_calculate_metric(iter_nums)
      File "test3d.py", line 350, in test_calculate_metric
        load_model(net, args, checkpoint_path)
      File "test3d.py", line 311, in load_model
        if (k not in ignored_keys) and (args2.__dict__[k] != cp_args[k]):
    KeyError: 'qk_have_bias'

    The error is raised at "load_model(net, args, checkpoint_path)" in the following code:

    for iter_num in iter_nums:
        if args.checkpoint_dir:
            checkpoint_path = os.path.join(args.checkpoint_dir, 'iter_' + str(iter_num) + '.pth')
            load_model(net, args, checkpoint_path)
    
    opened by wshi8 4
  • test3d.py failure

    Hi Dr Lee,

    I tried to run test3d.py on a set of 110 test images a few times, but it always exits unexpectedly after around 70 images, without any error message. So far, it has never finished all 110 images.

    Any advice?

    Best, Wendy

    opened by wshi8 4
  • Model training

    Hello, I am trying to reproduce the model training with 2019 Brats datasets (LGG and HGG). I am using BS = 1 due to memory limitation.

    I have trained for 10k+ iterations now, and the loss curves look like this. Is this trending in the right direction, as expected?

    [Screenshot: training loss curves, 2021-08-30]
    opened by wshi8 4
  • about nnunet

    Dear sir, we want to use the nnU-Net code on fundus (2D images), but we have run into some problems. Could you give us some guidance? Thank you! We also noticed that you mentioned: 'It is primarily designed for 3D tasks, but can also handle 2D images after converting them to pseudo-3D.' Here are our settings and traceback. The parameters are:

    --nproc_per_node=1 --master_port=7152 /data/sementation/code/train2d.py --task fundus --ds train --split train --translayers 3 --layercompress 1,1,2,2 --net nnunet --bb resnet101 --maxiter 5000 --bs 32 --noqkbias

    Traceback (most recent call last):
      File "<input>", line 1, in <module>
      File "/root/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
        pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
      File "/root/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "/data/sementation/code/train2d.py", line 1434, in <module>
        outputs = net(image_batch)
      File "/data/hliu/anaconda3/install/envs/segtran-master1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/data/hliu/anaconda3/install/envs/segtran-master1/lib/python3.8/site-packages/nnunet/network_architecture/generic_UNet.py", line 400, in forward
        x = torch.cat((x, skips[-(u + 1)]), dim=1)
    RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 6 but got size 5 for tensor number 1 in the list.

    opened by Lemonweier 3
  • data problem on Polymorphic Transformers

    Hi, please help me with a few questions.

    1. I found that "python3 train2d.py --task refuge --ds train,valid,test --split all --maxiter 10000 --net unet-scratch" should use "--task fundus"; otherwise it reports errors.
    2. The polyp data downloaded from https://github.com/DengPingFan/PraNet (search for "testing data") is not complete. Some images are missing; I guess they are included in their training data. But their training data mixes several datasets together into two folders (image and mask), so should I manually select the images for our folders? Could you have a look?

    Many thanks in advance.

    opened by kathyliu579 14
  • Train Polyformer (source) Problem

    [Screenshot, 2022-05-09] I got the error AttributeError: 'EasyDict' object has no attribute 'use_mince_transformer', as shown in the screenshot. I tried to fix it but it didn't work. Please tell me how to fix it. Thank you.

    opened by nguyenlecong 18
  • problem about dataset

    Thanks for your great project! I downloaded the REFUGE dataset, but the names and the number of image files differ from those used in this repo. For example, the training part has 360 files with names 0001.jpg, ...

    opened by rezashaemi 10
  • Great project! But I encountered some problems.

    When I train on the polyp datasets, the following error occurred (logs below):

    10 epochs, 1002 itertations each.
      0%|                                          | 0/10 [00:00<?, ?it/s]
    
    Image scales: 8x8. Voxels: [1, 1600, 1792]
    outputs shape: torch.Size([1, 2, 320, 320])
    mask_batch shape: torch.Size([1, 3, 320, 320])
      0%|                                          | 0/10 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "train2d.py", line 1100, in <module>
        mask_batch.permute([0, 2, 3, 1]))
      File "/root/anaconda3/envs/python377/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/root/anaconda3/envs/python377/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 617, in forward
        reduction=self.reduction)
      File "/root/anaconda3/envs/python377/lib/python3.7/site-packages/torch/nn/functional.py", line 2433, in binary_cross_entropy_with_logits
        raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
    ValueError: Target size (torch.Size([1, 320, 320, 3])) must be the same as input size (torch.Size([1, 320, 320, 2]))
    

    Could you tell me how to solve this problem? Thank you!

    opened by Byronnar 12
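
Two of the issues above arise when the command-line arguments are compared with the arguments stored in a checkpoint: the backbone mismatch (resnet101 vs. eff-b4) and the KeyError: 'qk_have_bias' raised by a direct cp_args[k] lookup. A minimal sketch of a more tolerant comparison follows. It is an illustration only, not the repository's load_model: the function name diff_checkpoint_args and the plain-dict inputs are assumptions, and keys that the checkpoint does not store are reported instead of raising.

```python
_MISSING = object()  # sentinel: distinguishes "key absent" from a stored None


def diff_checkpoint_args(run_args, cp_args, ignored_keys=()):
    """Hypothetical helper: compare run-time args with checkpoint args.

    Returns (mismatched, missing). `mismatched` maps a key to the pair
    (run value, checkpoint value); `missing` lists keys that are absent
    from the checkpoint args.
    """
    mismatched = {}
    missing = []
    for key, run_value in run_args.items():
        if key in ignored_keys:
            continue
        cp_value = cp_args.get(key, _MISSING)  # no KeyError on absent keys
        if cp_value is _MISSING:
            missing.append(key)
        elif cp_value != run_value:
            mismatched[key] = (run_value, cp_value)
    return mismatched, missing


if __name__ == "__main__":
    # Toy values modeled on the two issues; not real checkpoint contents.
    run_args = {"backbone_type": "resnet101", "qk_have_bias": True, "gpu": "0"}
    cp_args = {"backbone_type": "eff-b4", "gpu": "1"}
    mismatched, missing = diff_checkpoint_args(run_args, cp_args, ignored_keys={"gpu"})
    print(mismatched)  # {'backbone_type': ('resnet101', 'eff-b4')}
    print(missing)     # ['qk_have_bias']
```

With an argparse Namespace, vars(args) gives the plain dict this sketch expects.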