Swin-Unet

Overview

The code for the work "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation" (https://arxiv.org/abs/2105.05537): a validation of a U-shaped, pure-Transformer segmentation network built on the Swin Transformer.

1. Download the pre-trained Swin Transformer model (Swin-T) and place it at "./pretrained_ckpt/swin_tiny_patch4_window7_224.pth", the path train.py loads from.
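
A quick, optional sanity check (not part of this repo) that the checkpoint is in place and readable; the path is the one train.py prints in its startup log:

    import torch

    # Hypothetical check: confirm the Swin-T weights deserialize before training.
    state = torch.load('./pretrained_ckpt/swin_tiny_patch4_window7_224.pth', map_location='cpu')
    print(list(state.keys())[:5] if isinstance(state, dict) else type(state))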

2. Prepare data

  • The datasets we used are provided by the TransUnet authors. Please see "./datasets/README.md" for details, or send an email to jienengchen01 AT gmail.com to request the preprocessed data. If you use the preprocessed data, please use it for research purposes only and do not redistribute it (per TransUnet's license).

3. Environment

  • Please prepare an environment with python=3.7, then run "pip install -r requirements.txt" to install the dependencies.

4. Train/Test

  • Run the training script on the Synapse dataset. We used a batch size of 24; if you do not have enough GPU memory, the batch size can be reduced to 12 or 6 to save memory.

  • Train

sh train.sh or python train.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --root_path your DATA_DIR --max_epochs 150 --output_dir your OUT_DIR --img_size 224 --base_lr 0.05 --batch_size 24

  • Test

sh test.sh or python test.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --is_savenii --volume_path your DATA_DIR --output_dir your OUT_DIR --max_epoch 150 --base_lr 0.05 --img_size 224 --batch_size 24

References

Citation

@misc{cao2021swinunet,
      title={Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation}, 
      author={Hu Cao and Yueyue Wang and Joy Chen and Dongsheng Jiang and Xiaopeng Zhang and Qi Tian and Manning Wang},
      year={2021},
      eprint={2105.05537},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}
Comments
  • How is Unet trained?

    Hi there! Thank you for your innovative work. I want to ask how your UNet and Attention-UNet baselines were trained; I tried many times and could not reach the same performance.

    Looking forward to your reply. Best wishes!

    opened by a313071162 4
  • Unable to adapt to different resolutions

    Hi, thanks for your contribution. The model cannot adapt to resolutions other than 224 for both height and width (a possible explanation is sketched below). The error typically reads

    RuntimeError: shape '[xxx, xxx, xxx, xxx, xxx]' is invalid for input of size xxx
    
    opened by KruskalLin 4
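
    A plausible explanation, offered as an aside rather than the authors' answer: with the swin_tiny_patch4_window7_224 config, the token grid must remain divisible by the 7×7 attention window at every stage, while patch merging halves the grid between stages, which only works out for certain side lengths such as 224. A back-of-the-envelope check, assuming patch size 4, window 7, and 4 stages:

    # Hypothetical helper, not a function from this repo.
    def fits_swin_tiny(img_size, patch_size=4, window=7, num_stages=4):
        if img_size % patch_size:
            return False
        side = img_size // patch_size      # tokens per side after patch embedding
        for stage in range(num_stages):
            if side % window:              # the window-partition reshape fails otherwise
                return False
            if stage < num_stages - 1:
                if side % 2:               # patch merging needs an even grid
                    return False
                side //= 2
        return True

    print(fits_swin_tiny(224))  # True
    print(fits_swin_tiny(256))  # False, matching the reported RuntimeError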
  • different data

    Hello @HuCaoFighting, I want to use swin_unet for a segmentation task, but my data end in .png or .jpg rather than .nii.gz, so I need to change dataset_synapse.py. I tried many times but failed; could you give me some advice (one possible sketch follows this comment)? Thanks!

    opened by FrankWuuu 3
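
    A hedged sketch of one way to do this; the directory layout, the "image"/"label"/"case_name" keys, and the transform hook are assumptions modeled on dataset_synapse.py, not a drop-in from this repo:

    import os
    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset

    class PngSegDataset(Dataset):
        """Reads image/mask PNG (or JPG) pairs that share a filename."""

        def __init__(self, image_dir, mask_dir, transform=None):
            self.image_dir, self.mask_dir = image_dir, mask_dir
            self.names = sorted(os.listdir(image_dir))
            self.transform = transform

        def __len__(self):
            return len(self.names)

        def __getitem__(self, idx):
            name = self.names[idx]
            # Grayscale image scaled to [0, 1]; integer class IDs in the mask.
            image = np.asarray(Image.open(os.path.join(self.image_dir, name)).convert("L"),
                               dtype=np.float32) / 255.0
            label = np.asarray(Image.open(os.path.join(self.mask_dir, name)), dtype=np.int64)
            sample = {"image": image, "label": label,
                      "case_name": os.path.splitext(name)[0]}
            if self.transform:
                sample = self.transform(sample)  # e.g. the repo's RandomGenerator
            return sample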
  • ACDC dataset preprocessing

    Hi! Thanks for your nice work! I'd like to ask what preprocessing was applied to the ACDC dataset; I could not find relevant information in TransUNet or Swin-Unet.

    opened by zhouzhenghong-gt 3
  • Great job! But pretrained_ckpt load error

    When I train with my own datasets, the following error occurs (a possible workaround is sketched below):

    => merge config from configs/swin_tiny_patch4_window7_224_lite.yaml
    SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:9
    ---final upsample expand_first---
    pretrained_path:./pretrained_ckpt/swin_tiny_patch4_window7_224.pth
    Traceback (most recent call last):
      File "train.py", line 96, in <module>
        net.load_from(config)
      File "/mnt/e/projects/Sementic_Segmentation/Swin-Unet-PyTorch/networks/vision_transformer.py", line 58, in load_from
        pretrained_dict = torch.load(pretrained_path, map_location=device)
      File "/home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/serialization.py", line 527, in load
        with _open_zipfile_reader(f) as opened_zipfile:
      File "/home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/serialization.py", line 224, in __init__
        super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
    RuntimeError: version_ <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /pytorch/caffe2/serialize/inline_container.cc:132, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2. Your PyTorch installation may be too old. (init at /pytorch/caffe2/serialize/inline_container.cc:132)
    frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7fd2bbfab193 in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libc10.so)
    frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1f5b (0x7fd2bf1339eb in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libtorch.so)
    frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::string const&) + 0x64 (0x7fd2bf134c04 in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libtorch.so)
    frame #3: <unknown function> + 0x6c6536 (0x7fd307246536 in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
    frame #4: <unknown function> + 0x295a74 (0x7fd306e15a74 in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
    <omitting python frames>
    frame #33: __libc_start_main + 0xf3 (0x7fd319cf10b3 in /lib/x86_64-linux-gnu/libc.so.6)
    

    What should I do? Thank you!

    opened by Byronnar 1
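
    A possible workaround, stated as an assumption rather than the authors' guidance: the assertion means the checkpoint was serialized by a newer PyTorch than the installed one can read, so either upgrade PyTorch or, from any environment with a recent PyTorch, re-save the file in the legacy format that older readers accept (the output filename is illustrative):

    import torch

    # Load with a recent PyTorch, then write the legacy (non-zipfile) format.
    state = torch.load('./pretrained_ckpt/swin_tiny_patch4_window7_224.pth', map_location='cpu')
    torch.save(state, './pretrained_ckpt/swin_tiny_patch4_window7_224_legacy.pth',
               _use_new_zipfile_serialization=False)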
  • config.py: _C.MODEL.SWIN.QK_SCALE = None

    Traceback (most recent call last):
      File "train.py", line 10, in <module>
        from config import get_config
      File "C:\Users\Downloads\Swin-Unet-main\config.py", line 72, in <module>
        _C.MODEL.SWIN.QK_SCALE = None
      File "C:\anaconda\envs\swin\lib\site-packages\yacs\config.py", line 158, in __setattr__
        type(value), name, _VALID_TYPES
      File "C:\anaconda\envs\swin\lib\site-packages\yacs\config.py", line 521, in _assert_with_logging
        assert cond, msg
    AssertionError: Invalid type <class 'NoneType'> for key QK_SCALE; valid types = {<class 'float'>, <class 'tuple'>, <class 'str'>, <class 'list'>, <class 'bool'>, <class 'int'>}

    opened by Lsysslsyss 0
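
    A likely cause, noted as an observation rather than an official fix: older yacs releases do not list NoneType among their valid config value types, while yacs 0.1.8 (the version the upstream Swin Transformer code pins) accepts None, so "pip install yacs==0.1.8" usually resolves this. A minimal reproduction:

    from yacs.config import CfgNode as CN

    # Raises the AssertionError above on yacs < 0.1.8; works on yacs 0.1.8.
    _C = CN()
    _C.QK_SCALE = None
    print(_C.QK_SCALE)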
  • Test gets stuck

    I'm trying to reproduce the experiment. The training stage looks good, but the test stage has been stuck at the same point for around half an hour. Is this normal? I also wonder how long the evaluation takes to finish. Thanks!

    opened by hsparrow 0
  • Focal-Unet --> a low-complexity, easy-to-use UNet implementation for small-dataset domains

    Dear researchers, please also consider checking our new UNet implementation of the proposed focal modulation block for small-dataset problems, especially in medical domains: https://github.com/givkashi/Focal-Unet

    https://www.researchgate.net/publication/366423296_Focal-UNet_UNet-like_Focal_Modulation_for_Medical_Image_Segmentation

    Thanks.

    opened by mohammadrezanaderi4 0
  • einops.rearrange replacement question

    I want to translate this network to TorchSharp (the language is C#). In class PatchExpand:

       x = x.view(B, H, W, C);
       x1 = x.reshape(B, H * 2, W * 2, C / 4);

       x2 = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1: 2, p2: 2, c: C / 4)
    

    Can x1 replace x2, or are there other options? (See the sketch below.)

    opened by lindadamama 1
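
    A quick check, offered as a sketch rather than the authors' answer: a plain reshape is not equivalent, because rearrange interleaves the p1/p2 factors into the spatial grid. The view/permute/reshape sequence below reproduces the rearrange in PyTorch and maps one-to-one onto TorchSharp's view/permute/reshape:

    import torch
    from einops import rearrange

    B, H, W, C = 1, 2, 2, 8
    x = torch.arange(B * H * W * C, dtype=torch.float32).reshape(B, H, W, C)

    x1 = x.reshape(B, H * 2, W * 2, C // 4)                 # the proposed shortcut
    x2 = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
    x3 = (x.view(B, H, W, 2, 2, C // 4)                     # split channels into (p1, p2, c)
           .permute(0, 1, 3, 2, 4, 5)                       # b h p1 w p2 c
           .reshape(B, H * 2, W * 2, C // 4))               # merge into the doubled grid

    print(torch.equal(x1, x2))  # False: a bare reshape scrambles the layout
    print(torch.equal(x3, x2))  # True: this sequence matches rearrange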
  • Training time

    First, I would like to thank you for your contribution. I was wondering how many GPUs you used and how long it takes to train the network to reach the results of the paper. Thanks in advance for the answer.

    opened by cugwu 0
  • test metric is very poor

    During training, the parameters I used were: accumulation_steps=None, amp_opt_level='O1', base_lr=0.01, batch_size=6, cache_mode='part', cfg='./configs/swin_tiny_patch4_window7_224_lite.yaml', dataset='Synapse', deterministic=1, eval=False, img_size=224, is_pretrain=True, is_savenii=False, list_dir='./lists/lists_Synapse', max_epochs=150, max_iterations=30000, num_classes=14, opts=None, output_dir='./output_dir', resume=None, seed=1234, tag=None, test_save_dir='../predictions', throughput=False, use_checkpoint=False, volume_path='../data/Synapse/test_vol_h5', z_spacing=1, zip=False), and I evaluated with epoch_149.pth.

    The difference from the method in the paper is num_classes=14, because the Synapse dataset has 13 classes plus background. However, I got mean_dice: 0.000471 and mean_hd95: 56.415742. I don't know the reason; can somebody help me? (A possible cause is sketched below.)

    opened by wangru1026 0
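
    A possible cause, flagged as an assumption rather than a confirmed answer: the TransUnet-preprocessed Synapse release used by this repo keeps 8 abdominal organs plus background, i.e. 9 labels (the training log quoted above also prints num_classes:9), so training with num_classes=14 mismatches the label maps. A quick diagnostic; the train_npz path and 'label' key follow TransUnet's published data layout:

    import glob
    import numpy as np

    # Collect every label value that actually occurs in the training slices.
    labels = set()
    for path in glob.glob('../data/Synapse/train_npz/*.npz'):
        labels.update(np.unique(np.load(path)['label']).tolist())
    print(sorted(labels))  # 9 distinct values here would contradict num_classes=14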