Official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer"

Overview

[AAAI2022] UCTransNet

This repo is the official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer", which was accepted at AAAI 2022.


We propose a Channel Transformer module (CTrans) and use it to replace the skip connections in the original U-Net; hence the name "U-CTrans-Net".

Requirements

Install from the requirements.txt using:

pip install -r requirements.txt

Usage

1. Data Preparation

1.1. GlaS and MoNuSeg Datasets

The original data can be downloaded from the following links:

Then prepare the datasets in the following format for easy use of the code:

├── datasets
    ├── GlaS
    │   ├── Test_Folder
    │   │   ├── img
    │   │   └── labelcol
    │   ├── Train_Folder
    │   │   ├── img
    │   │   └── labelcol
    │   └── Val_Folder
    │       ├── img
    │       └── labelcol
    └── MoNuSeg
        ├── Test_Folder
        │   ├── img
        │   └── labelcol
        ├── Train_Folder
        │   ├── img
        │   └── labelcol
        └── Val_Folder
            ├── img
            └── labelcol
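
As a convenience, the layout above can be checked or created with a short script. This is only a sketch: the helper names (`expected_dirs`, `missing_dirs`, `create_layout`) are hypothetical and not part of this repo, and the root folder name `datasets` is taken from the tree above.

```python
from pathlib import Path

# Names copied from the directory tree above.
DATASETS = ("GlaS", "MoNuSeg")
SPLITS = ("Train_Folder", "Val_Folder", "Test_Folder")
SUBDIRS = ("img", "labelcol")


def expected_dirs(root):
    """Return every directory the layout above requires under `root`."""
    root = Path(root)
    return [root / d / s / sub for d in DATASETS for s in SPLITS for sub in SUBDIRS]


def missing_dirs(root):
    """Return the required directories that do not exist yet."""
    return [p for p in expected_dirs(root) if not p.is_dir()]


def create_layout(root):
    """Create any missing directories; existing content is left untouched."""
    for path in expected_dirs(root):
        path.mkdir(parents=True, exist_ok=True)


if __name__ == "__main__":
    # Report what is still missing under ./datasets (assumed root).
    for path in missing_dirs("datasets"):
        print("missing:", path)
```

Running the script before training lists any folder that still has to be filled with images (`img`) or masks (`labelcol`).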

1.2. Synapse Dataset

The Synapse dataset we used is provided by TransUNet's authors. Please go to https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md for details.

2. Training

As mentioned in the paper, we introduce two strategies to optimize UCTransNet.

The first step is to change the settings in Config.py, which holds all configuration options, including the learning rate and batch size.

2.1. Joint Training

We optimize the convolution parameters in U-Net and the CTrans parameters together with a single loss. Run:

python train_model.py

2.2. Pre-training

Our method only replaces the skip connections in U-Net, so the parameters of a trained U-Net can be used as part of the pre-trained weights.

First train a classical U-Net using /nets/UNet.py, then use its weights to initialize the training of UCTransNet, so that the CTrans module receives better initial features.

This strategy can speed up convergence and, in some cases, may improve the final segmentation performance.
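
One possible way to reuse U-Net weights is to copy only the parameters whose name and shape match the target model and leave the CTrans parameters at their fresh initialization. The sketch below illustrates that idea under stated assumptions: `select_transferable` is a hypothetical helper, not a function from this repo, and it assumes the U-Net and UCTransNet layers share parameter names.

```python
def select_transferable(pretrained_state, target_state):
    """Keep pretrained entries whose name and shape both match the target.

    Both arguments map parameter names to objects with a `.shape`
    attribute, such as the dict returned by `state_dict()` in PyTorch.
    """
    transferable = {}
    for name, tensor in pretrained_state.items():
        if name not in target_state:
            continue  # parameter does not exist in the target model
        if tuple(tensor.shape) != tuple(target_state[name].shape):
            continue  # same name but incompatible shape
        transferable[name] = tensor
    return transferable


# Hypothetical usage with PyTorch (file name and variables are assumptions):
#   unet_state = torch.load("unet_pretrained.pth", map_location="cpu")
#   weights = select_transferable(unet_state, model.state_dict())
#   model.load_state_dict(weights, strict=False)
```

Passing `strict=False` to `load_state_dict` lets the parameters that have no pre-trained counterpart keep their initial values.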

3. Testing

3.1. Get Pre-trained Models

Here we provide pre-trained weights for GlaS and MoNuSeg. If you do not want to train the models yourself, you can download them from the following links:

3.2. Test the Model and Visualize the Segmentation Results

First, set the session name in Config.py to the one used in the training phase. Then run:

python test_model.py

You can get the Dice and IoU scores and the visualization results.
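
For reference, the two scores can be illustrated on a pair of binary masks. This is a minimal sketch of the standard definitions, not the code in test_model.py, which may differ (for example in smoothing terms or in how scores are averaged over images). The function name `dice_and_iou` and the convention for two empty masks are assumptions.

```python
def dice_and_iou(pred, target):
    """Compute Dice and IoU for two binary masks given as flat 0/1 sequences."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    pred_area = sum(1 for p in pred if p)
    target_area = sum(1 for t in target if t)
    union = pred_area + target_area - intersection
    # Assumed convention: two empty masks count as a perfect match.
    if union == 0:
        return 1.0, 1.0
    dice = 2 * intersection / (pred_area + target_area)
    iou = intersection / union
    return dice, iou


# One overlapping pixel, two foreground pixels in each mask.
pred = [1, 1, 0, 0]
target = [1, 0, 1, 0]
dice, iou = dice_and_iou(pred, target)
print(f"Dice={dice:.3f} IoU={iou:.3f}")  # Dice=0.500 IoU=0.333
```

For a single pair of masks the two scores are related by Dice = 2·IoU / (1 + IoU), so Dice is never lower than IoU.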

4. Reproducibility

In our code, we carefully set the random seed and put cuDNN in 'deterministic' mode to eliminate randomness. However, some factors may still cause different training results, e.g., the CUDA version, the GPU type, and the number of GPUs. The GPU used in our experiments is an NVIDIA A40 (48G) and the CUDA version is 11.2.

In multi-GPU settings in particular, the upsampling operation is a major source of nondeterminism. See https://pytorch.org/docs/stable/notes/randomness.html for more details.
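
Seeding and the deterministic cuDNN mode typically look like the sketch below. It is an illustration of the usual PyTorch settings, not a copy of this repo's code; the function name `seed_everything` is hypothetical. NumPy and PyTorch are imported lazily so that the standard-library part still runs where they are not installed.

```python
import random


def seed_everything(seed):
    """Seed the Python RNG and, if installed, NumPy and PyTorch."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)
    # Safe to call without a GPU; it is ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

Even with these settings, some operations remain nondeterministic on GPU, which is why the factors listed above can still change the results.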

When training, we suggest training the model twice to verify whether the randomness has been eliminated. Because we use an early stopping strategy, randomness can change the final performance significantly.
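
A simple way to compare two runs is to log one metric per epoch (for example the validation Dice) and look for the first epoch where the logs differ. The helper below is a hypothetical sketch; how the per-epoch values are collected is left to the reader, and the tolerance is an assumption.

```python
import math


def first_divergence(run_a, run_b, tol=0.0):
    """Return the first epoch index where two metric logs differ, else None.

    A length mismatch (e.g. early stopping fired at different epochs) is
    reported at the index where the shorter log ends.
    """
    for epoch, (a, b) in enumerate(zip(run_a, run_b)):
        if not math.isclose(a, b, rel_tol=0.0, abs_tol=tol):
            return epoch
    if len(run_a) != len(run_b):
        return min(len(run_a), len(run_b))
    return None


# Made-up per-epoch values, for illustration only.
print(first_divergence([0.90, 0.80, 0.70], [0.90, 0.81, 0.70]))  # 1
```

If the function returns None with `tol=0.0`, the two runs produced identical logs and early stopping will trigger at the same epoch in both.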

Reference

Citations

If this code is helpful for your study, please cite:

@misc{wang2021uctransnet,
      title={UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer}, 
      author={Haonan Wang and Peng Cao and Jiaqi Wang and Osmar R. Zaiane},
      year={2021},
      eprint={2109.04335},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Contact

Haonan Wang ([email protected])

Comments
  • How to perform multi-classification tasks?

    I tried to modify n_labels in config.py and got the following error. How can I fix it? Is there a problem with the mask I entered? What kind of mask should I use?

    opened by l1uw3n 10
  • Question about the "CCT" block

    Hello author, thanks for your code. As I understand it, you just project the patches into the shapes "Batch x 196 x 64, Batch x 196 x 128, Batch x 196 x 256, Batch x 196 x 512". "196" is the number of patches. 64, 128, 256 and 512 are the embedding dimensions; they merely equal the channel dimensions but are not the real channel dimension. So, could you tell me the difference between the CCT operation and the original self-attention? Looking forward to your reply.

    opened by JackHeroFly 4
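  • Hello

    I set up the environment according to the requirements you provided and downloaded the pre-trained model UCTransNet-MoNuSeg.pth.tar. When I run it, it shows RuntimeError: don't know how to restore data location of torch.FloatStorage (tagged with CUDA). My CUDA version is 11.1.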

    opened by Tmoork1996 3
  • Can you provide the GlaS dataset split you used for training?

    Hello author, can you provide the GlaS dataset split you used when training on GlaS? My email is [email protected]. If possible, I would be very grateful.

    opened by jane442 2
  • Early stop

    I trained the UNet in your implementation. Test result: dice_pred 0.7823487865522568 iou_pred 0.649269364395218

    UNet is even better than the UCTransNet model I trained, and slightly worse than the UCTransNet model you provide. UCTransNet trained from the default config: dice_pred 0.7601493846752263 iou_pred 0.6263514063702118

    Both training runs hit early stopping (around epoch 120; the default is 2000). What's the problem?

    opened by h1063135843 2
  • RuntimeError: The size of tensor a (1600) must match the size of tensor b (196) at non-singleton dimension 1

    I tried to train the model but got this error:

    UCTransNet
    transformer head num: 4
    transformer layers num: 4
    transformer expand ratio: 4
    Let's use 2 GPUs!
    log dir:
    
    ========= Epoch [1/2001] =========
    Test_session_09.13_13h31
    Training with batch size : 4
    Traceback (most recent call last):
      File ".\train.py", line 187, in <module>
        model = main_loop(model_type=config.model_name, tensorboard=True)
      File ".\train.py", line 135, in main_loop
        train_one_epoch(train_loader, model, criterion, optimizer, writer, epoch, None, model_type, logger)
      File "D:\Projects\Crack\UCTransNet\Train_one_epoch.py", line 70, in train_one_epoch
        preds = model(images)
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\parallel\data_parallel.py", line 168, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\parallel\data_parallel.py", line 178, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\parallel\parallel_apply.py", line 86, in parallel_apply
        output.reraise()
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_utils.py", line 425, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\parallel\parallel_apply.py", line 61, in _worker
        output = module(*input, **kwargs)
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\Projects\Crack\UCTransNet\models\UCTransNet.py", line 123, in forward
        x1,x2,x3,x4,att_weights = self.mtc(x1,x2,x3,x4)
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\Projects\Crack\UCTransNet\models\CTrans.py", line 348, in forward
        emb1 = self.embeddings_1(en1)
      File "C:\Users\techno v\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\Projects\Crack\UCTransNet\models\CTrans.py", line 43, in forward
        embeddings = x + self.position_embeddings
    RuntimeError: The size of tensor a (1600) must match the size of tensor b (196) at non-singleton dimension 1
    
    opened by SeyedAliRezMousavi 2
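  • About the experimental results

    Using the MoNuSeg training, validation and test sets you provided, I trained twice on each of two different servers (Nvidia 3060Ti and Tesla T4). Dice stays between 76% and 78%, and IoU between 62% and 64%. However, when I switch to the GlaS dataset, Dice is 47%-48% and IoU stays at 36%-37%? That is way off!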

    opened by Max-Well-Wang 1
  • What is the number of channels for the Synapse dataset?

    Hi all, thank you for sharing your code. I have a question about the implementation for the Synapse dataset. The number of channels is defined as n_channels=3 in the config.py file. However, this doesn't make sense for the Synapse dataset, as each CT image has a varying number of slices, from 85 to 198.

    Did you define the number of channels as 1 for Synapse dataset, as in the config.py of UDTransNet ?

    Best, Melike

    opened by ilteralp 1
  • Possible typo in encoding and token size

    Hi all, thanks a lot for sharing your code. I think there might be a small typo in embedding size and token size. The embedding size $E_i$ is defined as follows in "Multi-scale Feature Embedding" subsection,

    However, I think it should be,

    It's defined correctly for $E_5$ in Figure 2 as below,

    The same applies to token size $T_i$ as well.

    opened by ilteralp 1
  • About the details of three times five-fold CV

    Hello, author. I would like to know the details of how you did the three times five-fold cross-validation.

    1. Did you mix the training set and the test set and then do five-fold cross-validation? Or did you do five-fold cross-validation on the training set and then test the best model on the test set each time?
    2. Would it be convenient for you to provide the dataset splits you used for GlaS and MoNuSeg?

    Thanks

    opened by Schneey 1
  • Train model

    I ran into the following problem during training. How should I handle it? Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.

    opened by Minami77 0