This repo holds code for TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation

Overview

TransUNet

Usage

1. Download Google pre-trained ViT models

wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz &&
mkdir -p ../model/vit_checkpoint/imagenet21k &&
mv {MODEL_NAME}.npz ../model/vit_checkpoint/imagenet21k/{MODEL_NAME}.npz

2. Prepare data

Please see "./datasets/README.md" for details, or email jienengchen01 AT gmail.com to request the preprocessed data. If you use the preprocessed data, please use it for research purposes only and do not redistribute it.

3. Environment

Please prepare an environment with python=3.7, and then use the command "pip install -r requirements.txt" for the dependencies.

4. Train/Test

  • Run the training script on the Synapse dataset. The batch size can be reduced to 12 or 6 to save memory (decrease base_lr linearly as well); both settings reach similar performance.
CUDA_VISIBLE_DEVICES=0 python train.py --dataset Synapse --vit_name R50-ViT-B_16
  • Run the test script on the Synapse dataset. It supports testing on both 2D images and 3D volumes.
python test.py --dataset Synapse --vit_name R50-ViT-B_16

Reference

Citations

@article{chen2021transunet,
  title={TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation},
  author={Chen, Jieneng and Lu, Yongyi and Yu, Qihang and Luo, Xiangde and Adeli, Ehsan and Wang, Yan and Lu, Le and Yuille, Alan L. and Zhou, Yuyin},
  journal={arXiv preprint arXiv:2102.04306},
  year={2021}
}
Comments
  • AttributeError:

    AttributeError: "'skip_channels'"

    Hi, when I run test.py, I get the following error:

    File "/home/wf/anaconda3/envs/trunet/lib/python3.7/site-packages/ml_collections/config_dict/config_dict.py", line 883, in __getitem__
        field = self._fields[key]
    KeyError: 'skip_channels'

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/wf/anaconda3/envs/trunet/lib/python3.7/site-packages/ml_collections/config_dict/config_dict.py", line 807, in __getattr__
        return self[attribute]
      File "/home/wf/anaconda3/envs/trunet/lib/python3.7/site-packages/ml_collections/config_dict/config_dict.py", line 889, in __getitem__
        raise KeyError(self._generate_did_you_mean_message(key, str(e)))
    KeyError: "'skip_channels'"

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/wf/anaconda3/envs/project_TransUNet_2/project_TransUNet/TransUNet/test.py", line 118, in <module>
        net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
      File "/home/wf/anaconda3/envs/project_TransUNet_2/project_TransUNet/TransUNet/networks/vit_seg_modeling.py", line 378, in __init__
        self.decoder = DecoderCup(config)
      File "/home/wf/anaconda3/envs/project_TransUNet_2/project_TransUNet/TransUNet/networks/vit_seg_modeling.py", line 343, in __init__
        skip_channels = self.config.skip_channels
      File "/home/wf/anaconda3/envs/trunet/lib/python3.7/site-packages/ml_collections/config_dict/config_dict.py", line 809, in __getattr__
        raise AttributeError(e)
    AttributeError: "'skip_channels'"

    Then I changed skip_channels = self.config.skip_channels (line 343 of vit_seg_modeling.py) to skip_channels = [512,256,64,16], and I got the following new error:

    RuntimeError: Error(s) in loading state_dict for VisionTransformer: Unexpected key(s) in state_dict: "transformer.embeddings.hybrid_model.root.conv.weight", size mismatch for transformer.embeddings.patch_embeddings.weight: copying a param with shape torch.Size([768, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([768, 3, 16, 16])

    For me, the cause of this is unclear. Does anyone have an idea?

    opened by wangfan120 17
  • You must reference the code you copied

    You must reference the code you copied

    First, most of your code is taken from (https://github.com/jeonsworld/ViT-pytorch), which is owned by @jeonsworld. Second, the entire idea of your paper is taken from (https://arxiv.org/abs/2012.15840). In your paper, you have not mentioned that they were the first to propose this architecture and that your work is derived from theirs.

    This is very unprofessional. I hope the famous people in your paper already know about your conduct.

    opened by ching-sui1995 6
  • image preprocess code

    image preprocess code

    Thanks for your great contribution! Would you like to share your image preprocessing code, i.e. normalizing the 3D images, extracting 2D slices from the 3D volumes, and saving them to .npz files? Thanks!

    opened by AlphaJia 3
  • Request for the preprocessed data

    Request for the preprocessed data

    @Beckschen, hello, I am trying to extend an existing continual semantic segmentation method to 3D medical datasets. The Synapse dataset (Medical Image Segmentation on Synapse multi-organ CT) seems a good choice for our further experiments. However, we cannot achieve satisfactory baseline performance because we lack the 3D data preprocessing. It would be very nice if you could share the preprocessed data. My email: [email protected]

    opened by ciuzaak 2
  • Hi, a question about the models' output

    Hi, a question about the models' output

    Thanks for your work. I have some problems using it for another task. I modified the dataloader so that my inputs are img (1,3,512,512) and label (1,1,512,512). During training, output = model(image_patch) has shape (1,2,512,512), and computing the loss raises: RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 1, 512, 512].

    args: vit_name = R50-ViT-B_16, n_skip=3, vit_patches_size=16

    Can you give me some advice? Thank you!

    opened by JoyChen1998 2
  • ModuleNotFoundError: No module named 'datasets.dataset_synapse'

    ModuleNotFoundError: No module named 'datasets.dataset_synapse'

    Hello, I use the preprocessed data (requested as mentioned in the Readme) and get the following error:

    TransUNet$ python test.py --dataset Synapse --vit_name R50-ViT-B_16
    Traceback (most recent call last):
      File "test.py", line 12, in <module>
        from datasets.dataset_synapse import Synapse_dataset
    ModuleNotFoundError: No module named 'datasets.dataset_synapse'

    For me, the cause of this is unclear. Does anyone have an idea?

    opened by andife 2
  • Wrong Position Embedding Size

    Wrong Position Embedding Size

    In the original implementation, the position embedding has shape [n_patches+1, hidden_size] to accommodate the additional class token:

    https://github.com/jeonsworld/ViT-pytorch/blob/a786151f6ceed00e97ab526772916faec5efb8ed/models/modeling.py#L150

    In your implementation, you removed the class token, and your position embedding has shape [n_patches, hidden_size]:

    https://github.com/Beckschen/TransUNet/blob/main/networks/vit_seg_modeling.py#L149

    When I tried to load the pre-trained model with your changes (the new size), I got a size mismatch error:

    model.load_from(np.load(args.pretrained_dir)), in load_from
        self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
    RuntimeError: The size of tensor a (196) must match the size of tensor b (170) at non-singleton dimension 1

    Would you please explain how the pretrained checkpoint can be loaded based on these changes ?

    opened by Siyuan89 2
  • Running the test.py module

    Running the test.py module

    Hi! I am very interested in your paper and solution for the medical image segmentation.

    First question: I ran test.py and it reports a problem at line 122: there is no such file or directory '../model/TU_Synapse224/TU_pretrain_R50-ViT-B_16_skip3_bs24_224/epoch_29.pth'. I hope you can help me with this. I ran it on a single image, in case that matters.

    Second question: why is there a need to clip the image to [-125, 275]? The image I downloaded is a grayscale image of dtype uint8.

    opened by SM-93 2
  • How to add an extra channel to the ResNet50 model architecture?

    How to add an extra channel to the ResNet50 model architecture?

    TransUNet/networks/vit_seg_modeling_resnet_skip.py

    When I tried to add a 4th channel, it threw this error:

    RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[2, 4, 256, 256] to have 3 channels, but got 4 channels instead

    opened by prakharsdev 1
  • Problem with test performance

    Problem with test performance

    Hello, thanks a lot for your work. I have a question about the test results on the Synapse dataset. I trained the network with R50-ViT-B_16, base_lr=0.01, batch_size=24, n_skip=3, img_size=224, max_epochs=300, vit_patches_size=16. The only difference from the paper's method is num_classes=14, because the Synapse dataset has 13 classes plus background. After 300 epochs, I used <python test.py --dataset Synapse --vit_name R50-ViT-B_16 --max_epochs 300> to test. However, I got <mean_dice : 0.131854 mean_hd95 : 83.012894>, which is far from the result in the paper. I don't know the reason. Can somebody help me?

    opened by dianexuli 1
  • AttributeError: Can't pickle local object 'trainer_synapse.<locals>.worker_init_fn'

    AttributeError: Can't pickle local object 'trainer_synapse.<locals>.worker_init_fn'

    I got the error below. Any ideas?

    <__main__.Args object at 0x000002995A0011C0>
    <__main__.Args object at 0x000002995A0011C0>
    The length of train set is: 2211
    93 iterations per epoch. 13950 max iterations
    93 iterations per epoch. 13950 max iterations
      0%| | 0/150 [00:00<?, ?it/s]

    AttributeError                            Traceback (most recent call last)
    C:\Users\60962~1.LPI\AppData\Local\Temp/ipykernel_19332/2385906778.py in <module>
         90
         91 trainer = {'Synapse': trainer_synapse,}
    ---> 92 trainer[dataset_name](args, net, snapshot_path)

    C:\Users\60962~1.LPI\AppData\Local\Temp/ipykernel_19332/1972247546.py in trainer_synapse(args, model, snapshot_path)
         50 iterator = tqdm(range(max_epoch), ncols=70)
         51 for epoch_num in iterator:
    ---> 52 for i_batch, sampled_batch in enumerate(trainloader):
         53 image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
         54 image_batch, label_batch = image_batch.cuda(), label_batch.cuda()

    D:\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __iter__(self)
        366 return self._iterator
        367 else:
    --> 368 return self._get_iterator()
        369
        370 @property

    D:\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in _get_iterator(self)
        312 else:
        313 self.check_worker_number_rationality()
    --> 314 return _MultiProcessingDataLoaderIter(self)
        315
        316 @property

    D:\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __init__(self, loader)
        925 # before it starts, and __del__ tries to join but will get:
        926 # AssertionError: can only join a started process.
    --> 927 w.start()
        928 self._index_queues.append(index_queue)
        929 self._workers.append(w)

    D:\anaconda3\lib\multiprocessing\process.py in start(self)
        119 'daemonic processes are not allowed to have children'
        120 _cleanup()
    --> 121 self._popen = self._Popen(self)
        122 self._sentinel = self._popen.sentinel
        123 # Avoid a refcycle if the target function holds an indirect

    D:\anaconda3\lib\multiprocessing\context.py in _Popen(process_obj)
        222 @staticmethod
        223 def _Popen(process_obj):
    --> 224 return _default_context.get_context().Process._Popen(process_obj)
        225
        226 class DefaultContext(BaseContext):

    D:\anaconda3\lib\multiprocessing\context.py in _Popen(process_obj)
        325 def _Popen(process_obj):
        326 from .popen_spawn_win32 import Popen
    --> 327 return Popen(process_obj)
        328
        329 class SpawnContext(BaseContext):

    D:\anaconda3\lib\multiprocessing\popen_spawn_win32.py in __init__(self, process_obj)
         91 try:
         92 reduction.dump(prep_data, to_child)
    ---> 93 reduction.dump(process_obj, to_child)
         94 finally:
         95 set_spawning_popen(None)

    D:\anaconda3\lib\multiprocessing\reduction.py in dump(obj, file, protocol)
         58 def dump(obj, file, protocol=None):
         59 '''Replacement for pickle.dump() using ForkingPickler.'''
    ---> 60 ForkingPickler(file, protocol).dump(obj)
         61
         62 #

    AttributeError: Can't pickle local object 'trainer_synapse.<locals>.worker_init_fn'

    opened by alqurri77 1
  • Focal-Unet --> low complexity, easy to use unet implementation for small dataset domains

    Focal-Unet --> low complexity, easy to use unet implementation for small dataset domains

    Dear researchers, please also consider checking our U-Net implementation of the proposed focal modulation block for small-dataset problems, especially in the medical domain. Our model shows state-of-the-art results compared to the base Swin-Unet and TransUNet models in terms of Dice score on the Synapse dataset. https://github.com/givkashi/Focal-Unet Thanks.

    opened by mohammadrezanaderi4 0
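The skip_channels issue above ends in a follow-up state_dict error. One way to narrow down that kind of error is to compare parameter names and shapes directly. In that error, the checkpoint carries `hybrid_model` keys and a 1024-channel patch embedding. This suggests it was saved from an R50 hybrid model, while the network being built is a plain ViT.

The hypothetical helper below works on plain {name: shape} dictionaries. It is a diagnostic sketch, not code from this repository.

```python
def diff_shapes(checkpoint: dict, model: dict):
    """Compare {name: shape} maps and return (missing, unexpected, mismatched)."""
    ckpt_keys, model_keys = set(checkpoint), set(model)
    missing = sorted(model_keys - ckpt_keys)       # in the model, absent from the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)    # in the checkpoint, unknown to the model
    mismatched = sorted(
        (k, tuple(checkpoint[k]), tuple(model[k]))
        for k in ckpt_keys & model_keys
        if tuple(checkpoint[k]) != tuple(model[k])
    )
    return missing, unexpected, mismatched
```

With PyTorch, the two dictionaries could be built as follows:

* Checkpoint: `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`
* Model: `{k: tuple(v.shape) for k, v in net.state_dict().items()}`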
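The question about the models' output above hits an error in the loss, not in the model. Pixel-wise cross-entropy expects class-index targets of shape (N, H, W), but the label there has shape (N, 1, H, W). In PyTorch, the usual fix is `label_batch.squeeze(1).long()`.

The NumPy sketch below mirrors that shape contract to make it concrete. `prepare_target` and `spatial_cross_entropy` are hypothetical illustrations, not the repository's loss code.

```python
import numpy as np


def prepare_target(label: np.ndarray) -> np.ndarray:
    """Turn an (N, 1, H, W) mask into (N, H, W) integer class indices."""
    if label.ndim == 4 and label.shape[1] == 1:
        label = label[:, 0]
    return label.astype(np.int64)


def spatial_cross_entropy(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean pixel-wise cross-entropy; logits (N, C, H, W), target (N, H, W)."""
    if target.ndim != 3:
        raise ValueError(f"expected a 3D target (N, H, W), got shape {target.shape}")
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class at every pixel.
    picked = np.take_along_axis(log_probs, target[:, None], axis=1)
    return float(-picked.mean())
```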
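For the ModuleNotFoundError issue above, it helps to check which `datasets` package Python actually resolves, and importlib can report this without importing it. One possible cause, not confirmed in the thread, is that another installed package named `datasets` shadows the repository's folder. `describe_module` is a hypothetical helper.

```python
import importlib.util


def describe_module(name: str) -> str:
    """Report where a top-level module `name` would be imported from."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return f"{name}: not found on sys.path"
    if spec.origin and spec.origin != "namespace":
        return f"{name}: {spec.origin}"
    # Namespace packages (directories without __init__.py) have no origin file.
    return f"{name}: namespace package at {list(spec.submodule_search_locations or [])}"


if __name__ == "__main__":
    # Run this from the TransUNet directory, where test.py lives.
    print(describe_module("datasets"))
```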
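The Wrong Position Embedding Size issue above concerns reusing a ViT position embedding after removing the class token. The common approach is to drop the embedding's first row and resample the remaining square grid to the new patch grid.

The NumPy sketch below illustrates that idea with plain bilinear resampling. `adapt_posemb` and `resize_grid` are hypothetical names, and the repository's own load_from may do this differently. Both grids are assumed to be square, so a token count such as 170 (not a perfect square) is rejected.

```python
import numpy as np


def resize_grid(grid: np.ndarray, new_size: int) -> np.ndarray:
    """Bilinearly resample a (gs, gs, D) grid to (new_size, new_size, D)."""
    old = grid.shape[0]
    # Map each new index onto the old grid, with the corners aligned.
    coords = np.linspace(0.0, old - 1, new_size)
    i0 = np.floor(coords).astype(int)
    i1 = np.minimum(i0 + 1, old - 1)
    w = coords - i0
    # Interpolate along rows, then along columns.
    rows = grid[i0] * (1 - w)[:, None, None] + grid[i1] * w[:, None, None]
    return rows[:, i0] * (1 - w)[None, :, None] + rows[:, i1] * w[None, :, None]


def adapt_posemb(posemb: np.ndarray, n_patches_new: int) -> np.ndarray:
    """(1, 1 + gs_old**2, D) with a class token -> (1, n_patches_new, D) without it."""
    grid = posemb[0, 1:]  # drop the class-token row
    gs_old = int(round(np.sqrt(grid.shape[0])))
    gs_new = int(round(np.sqrt(n_patches_new)))
    if gs_old * gs_old != grid.shape[0] or gs_new * gs_new != n_patches_new:
        raise ValueError("this sketch only handles square patch grids")
    grid = grid.reshape(gs_old, gs_old, -1)
    if gs_new != gs_old:
        grid = resize_grid(grid, gs_new)
    return grid.reshape(1, gs_new * gs_new, -1)
```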
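This relates to the question above about adding an extra channel to ResNet50. The weight in that error, [64, 3, 7, 7], belongs to a stem convolution built for 3 input channels. A 4-channel input therefore needs two changes: a layer with 4 input channels, and a pretrained kernel extended with a 4th input slice.

One common heuristic, not taken from this repository, initializes the extra slice with the mean of the existing RGB kernels. It is sketched below in NumPy. With PyTorch, the result could then be copied into the new layer's weight via `torch.from_numpy`.

```python
import numpy as np


def expand_input_channels(weight: np.ndarray, new_in: int) -> np.ndarray:
    """Grow a conv kernel of shape (out, in, kH, kW) to `new_in` input channels."""
    out_c, in_c, kh, kw = weight.shape
    if new_in < in_c:
        raise ValueError("new_in must be >= the existing number of input channels")
    # Initialize every extra input channel with the mean of the pretrained ones.
    extra = np.repeat(weight.mean(axis=1, keepdims=True), new_in - in_c, axis=1)
    return np.concatenate([weight, extra], axis=1)
```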
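This relates to the "Can't pickle local object" issue above. On Windows, DataLoader workers are started with the spawn method, so everything passed to them must be picklable, including worker_init_fn. A function defined inside trainer_synapse is not picklable.

A common workaround, sketched below, is to define the function at module level and bind the seed with functools.partial. This assumes the seed is the only value the original function captured.

Two further points:

* When the trainer lives in a Jupyter notebook, as in this traceback, the function may also need to move into an importable .py file.
* Setting num_workers=0 avoids worker processes altogether.

```python
import functools
import pickle
import random


def worker_init_fn(worker_id: int, seed: int) -> None:
    """Seed each DataLoader worker differently; defined at module level, so picklable."""
    random.seed(seed + worker_id)


def make_worker_init_fn(seed: int):
    """Bind the seed without a closure, so the result survives pickling."""
    return functools.partial(worker_init_fn, seed=seed)
```

The returned object would be passed as `DataLoader(..., worker_init_fn=make_worker_init_fn(seed))`, where `seed` stands for whatever seed the training script uses.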
Owner
CS Ph.D student @ Johns Hopkins University, CCVL
TransGAN: Two Transformers Can Make One Strong GAN

[Preprint] "TransGAN: Two Transformers Can Make One Strong GAN", Yifan Jiang, Shiyu Chang, Zhangyang Wang

VITA 1.5k Jan 7, 2023
This repository holds the code for the paper "Deep Conditional Gaussian Mixture Model for Constrained Clustering".

Deep Conditional Gaussian Mixture Model for Constrained Clustering. This repository holds the code for the paper Deep Conditional Gaussian Mixture Mod

null 17 Oct 30, 2022
This repository holds code and data for our PETS'22 article 'From "Onion Not Found" to Guard Discovery'.

From "Onion Not Found" to Guard Discovery (PETS'22) This repository holds the code and data for our PETS'22 paper titled 'From "Onion Not Found" to Gu

Lennart Oldenburg 3 May 4, 2022
Medical Image Segmentation using Squeeze-and-Expansion Transformers

Medical Image Segmentation using Squeeze-and-Expansion Transformers Introduction This repository contains the code of the IJCAI'2021 paper 'Medical Im

askerlee 172 Dec 20, 2022
Multi-atlas segmentation (MAS) is a promising framework for medical image segmentation

Multi-atlas segmentation (MAS) is a promising framework for medical image segmentation. Generally, MAS methods register multiple atlases, i.e., medical images with corresponding labels, to a target image;

NanYoMy 13 Oct 9, 2022
A Strong Baseline for Image Semantic Segmentation

A Strong Baseline for Image Semantic Segmentation Introduction This project is an open source semantic segmentation toolbox based on PyTorch. It is ba

Clark He 49 Sep 20, 2022
This framework implements the data poisoning method found in the paper Adversarial Examples Make Strong Poisons

Adversarial poison generation and evaluation. This framework implements the data poisoning method found in the paper Adversarial Examples Make Strong

null 31 Nov 1, 2022
This repo is developed for Strong Baseline For Vehicle Re-Identification in Track 2 Ai-City-2021 Challenges

A STRONG BASELINE FOR VEHICLE RE-IDENTIFICATION This paper is accepted to the IEEE Conference on Computer Vision and Pattern Recognition Workshop(CVPR

Cybercore Co. Ltd 78 Dec 29, 2022
Copy Paste positive polyp using poisson image blending for medical image segmentation

Copy Paste positive polyp using poisson image blending for medical image segmentation According poisson image blending I've completely used it for bio

Phạm Vũ Hùng 2 Oct 19, 2021
The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dealing with medical images.

The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dealing with medical images.

MIC-DKFZ 1.2k Jan 4, 2023
Build a medical knowledge graph based on Unified Language Medical System (UMLS)

UMLS-Graph Build a medical knowledge graph based on Unified Language Medical System (UMLS) Requisite Install MySQL Server 5.6 and import UMLS data int

Donghua Chen 6 Dec 25, 2022
Semi Supervised Learning for Medical Image Segmentation, a collection of literature reviews and code implementations.

Semi-supervised-learning-for-medical-image-segmentation. Recently, semi-supervised image segmentation has become a hot topic in medical image computin

Healthcare Intelligence Laboratory 1.3k Jan 3, 2023
Code base for "On-the-Fly Test-time Adaptation for Medical Image Segmentation"

On-the-Fly Adaptation Official Pytorch Code base for On-the-Fly Test-time Adaptation for Medical Image Segmentation Paper Introduction One major probl

Jeya Maria Jose 17 Nov 10, 2022
GAN encoders in PyTorch that could match PGGAN, StyleGAN v1/v2, and BigGAN. Code also integrates the implementation of these GANs.

MTV-TSA: Adaptable GAN Encoders for Image Reconstruction via Multi-type Latent Vectors with Two-scale Attentions. This is the official code release fo

owl 37 Dec 24, 2022
Final project code: Implementing MAE with downscaled encoders and datasets, for ESE546 FA21 at University of Pennsylvania

546 Final Project: Masked Autoencoder Haoran Tang, Qirui Wu 1. Training To train the network, please run mae_pretraining.py. Please modify folder path

Haoran Tang 0 Apr 22, 2022
[ICCV 2021] A Simple Baseline for Semi-supervised Semantic Segmentation with Strong Data Augmentation

[ICCV 2021] A Simple Baseline for Semi-supervised Semantic Segmentation with Strong Data Augmentation

CodingMan 45 Dec 12, 2022
CoTr: Efficiently Bridging CNN and Transformer for 3D Medical Image Segmentation

CoTr: Efficient 3D Medical Image Segmentation by bridging CNN and Transformer This is the official pytorch implementation of the CoTr: Paper: CoTr: Ef

null 218 Dec 25, 2022
[CVPR'21] FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space

FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space by Quande Liu, Cheng Chen, Ji

Quande Liu 178 Jan 6, 2023
A collection of loss functions for medical image segmentation

A collection of loss functions for medical image segmentation

Jun 3.1k Jan 3, 2023