Code for A Volumetric Transformer for Accurate 3D Tumor Segmentation

Overview

VT-UNet

This repo contains the supported PyTorch code and configuration files to reproduce the 3D medical image segmentation results of VT-UNet.

VT-UNet Architecture

Environment

Prepare an environment with Python 3.8, then run the command "pip install -r requirements.txt" to install the dependencies.

Data Preparation

  • For experiments we used four datasets:

  • File structure

     BRATS2021
      |---Data
      |   |--- RSNA_ASNR_MICCAI_BraTS2021_TrainingData
      |   |   |--- BraTS2021_00000
      |   |   |   |--- BraTS2021_00000_flair...
      |   
      |              
      |   
      |
     VT-UNet
      |---train.py
      |---test.py
      |---pretrained_ckpt
      |---saved_model
      ...
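If the dataset is not where the training code expects it, training stops early (the AssertionError on base_folder.exists() in the issues below is one example). A quick way to catch a misplaced or incomplete copy is to check the tree above directly. The following is a minimal sketch, assuming the BraTS 2021 naming pattern <case>_<modality>.nii.gz with flair, t1, t1ce, t2 and seg files per case. check_brats_layout is a hypothetical helper, not part of the repo; the folder the training code actually reads is configured inside the repo itself (for example through get_brats_folder in dataset/brats.py).

```python
from pathlib import Path

# Assumed BraTS 2021 modality suffixes; adjust if your copy is named differently.
MODALITIES = ("flair", "t1", "t1ce", "t2")


def check_brats_layout(root, require_seg=True):
    """Return {case_id: [missing file names]} for a BRATS2021-style folder.

    Hypothetical helper: root is the BRATS2021 directory from the tree above,
    and case folders are expected under Data/RSNA_ASNR_MICCAI_BraTS2021_TrainingData.
    """
    data_dir = Path(root) / "Data" / "RSNA_ASNR_MICCAI_BraTS2021_TrainingData"
    if not data_dir.is_dir():
        raise FileNotFoundError(f"expected case folders under {data_dir}")
    suffixes = MODALITIES + (("seg",) if require_seg else ())
    problems = {}
    for case_dir in sorted(p for p in data_dir.iterdir() if p.is_dir()):
        # e.g. BraTS2021_00000/BraTS2021_00000_flair.nii.gz
        expected = [f"{case_dir.name}_{s}.nii.gz" for s in suffixes]
        missing = [name for name in expected if not (case_dir / name).is_file()]
        if missing:
            problems[case_dir.name] = missing
    return problems
```

For example, check_brats_layout("BRATS2021") returns an empty dict when every case folder contains all five files. Passing require_seg=False skips the label file, which is useful for data that ships without segmentations, such as the BraTS validation cases.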
    

Pre-Trained Weights

Pre-Trained Base Model For BraTS 2021

Train/Test

  • Train: run the training script on the BraTS 2021 training dataset with the base model configuration.
python train.py --cfg configs/vt_unet_base.yaml --num_classes 3 --epochs 350
  • Test: run the test script on the BraTS 2021 training dataset.
python test.py --cfg configs/vt_unet_base.yaml --num_classes 3
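Both commands pass --num_classes 3. BraTS results are usually reported on three nested regions built from the original labels: enhancing tumor (ET), tumor core (TC) and whole tumor (WT); one of the issues below asks about this merging in brats.py. The following is a minimal NumPy sketch of that conversion, assuming the standard BraTS label values 1 (necrotic/non-enhancing core), 2 (edema) and 4 (enhancing tumor). The ET, TC, WT channel order and the name labels_to_regions are assumptions for illustration, not taken from the repo.

```python
import numpy as np

# Assumed BraTS label convention: 0 = background, 1 = necrotic/non-enhancing
# tumor core, 2 = peritumoral edema, 4 = enhancing tumor.
REGIONS = {
    "ET": (4,),       # enhancing tumor
    "TC": (1, 4),     # tumor core = core labels + enhancing tumor
    "WT": (1, 2, 4),  # whole tumor = every tumor label
}


def labels_to_regions(seg):
    """Stack binary ET, TC and WT masks along a new leading channel axis."""
    seg = np.asarray(seg)
    masks = [np.isin(seg, labels) for labels in REGIONS.values()]
    return np.stack(masks).astype(np.uint8)
```

Because the regions overlap (ET is inside TC, which is inside WT), each channel is an independent binary mask rather than one class in a softmax, which is one reason a 3-channel output with per-channel sigmoid is a common choice for this setup.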

Acknowledgements

This repository makes liberal use of code from open_brats2020, Swin Transformer, Video Swin Transformer and Swin-Unet.

References

Citing VT-UNet

    @misc{peiris2021volumetric,
      title={A Volumetric Transformer for Accurate 3D Tumor Segmentation}, 
      author={Himashi Peiris and Munawar Hayat and Zhaolin Chen and Gary Egan and Mehrtash Harandi},
      year={2021},
      eprint={2111.13300},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
    }
Comments
  • Data preprocess

    Hi @himashi92

    Thanks for this awesome work and repo. I am trying to replicate this on the Liver data from the Medical Segmentation Decathlon. When I run the train command for the base training, CUDA_VISIBLE_DEVICES=0 nohup vtunet_train 3d_fullres vtunetTrainerV2_vtunet_tumor_base 3 0 &> base.out &

    I get the following error: KeyError: 'BRATS_001'. The full output is:

    stage: 0 {'batch_size': 2, 'num_pool_per_axis': [5, 5, 5], 'patch_size': array([128, 128, 128]), 'median_patient_size_in_voxels': array([195, 207, 207]), 'current_spacing': array([2.473119 , 1.89831205, 1.89831205]), 'original_spacing': array([1. , 0.76757812, 0.76757812]), 'do_dummy_2D_data_aug': False, 'pool_op_kernel_sizes': [[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 'conv_kernel_sizes': [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]}

    stage: 1 {'batch_size': 2, 'num_pool_per_axis': [5, 5, 5], 'patch_size': array([128, 128, 128]), 'median_patient_size_in_voxels': array([482, 512, 512]), 'current_spacing': array([1. , 0.76757812, 0.76757812]), 'original_spacing': array([1. , 0.76757812, 0.76757812]), 'do_dummy_2D_data_aug': False, 'pool_op_kernel_sizes': [[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 'conv_kernel_sizes': [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]}

    I am using stage 1 from these plans
    I am using batch dice + CE loss

    I am using data from this folder: /home/VT-UNet/VTUNet/DATASET/vtunet_preprocessed/Task003_Liver/vtunetData_plans_v2.1
    ###############################################
    loading dataset
    loading all case properties
    2022-08-16 04:36:37.549490: Creating new 5-fold cross-validation split...
    2022-08-16 04:36:37.550782: Desired fold for training: 0
    2022-08-16 04:36:37.550879: This split has 387 training and 73 validation cases.
    Traceback (most recent call last):
      File "/home/VT-UNet/transegvenv/bin/vtunet_train", line 33, in <module>
        sys.exit(load_entry_point('vtunet', 'console_scripts', 'vtunet_train')())
      File "/home/VT-UNet/VTUNet/vtunet/run/run_training.py", line 134, in main
        trainer.initialize(not validation_only)
      File "/home/VT-UNet/VTUNet/vtunet/training/network_training/vtunetTrainerV2_vtunet_liver_base.py", line 90, in initialize
        self.dl_tr, self.dl_val = self.get_basic_generators()
      File "/home/VT-UNet/VTUNet/vtunet/training/network_training/vtunetTrainer.py", line 401, in get_basic_generators
        self.do_split()
      File "/home/VT-UNet/VTUNet/vtunet/training/network_training/vtunetTrainerV2_vtunet_liver_base.py", line 410, in do_split
        self.dataset_tr[i] = self.dataset[i]
    KeyError: 'BRATS_001'

    Any idea how to fix this? Separately, I also get an error when I try to run vtunet_train_3d.

    opened by b3r-prog 7
  • dataset

    Hello, from your paper I see that you divided the 1251 cases into training, validation and test sets for your experiments, rather than taking part in the 2021 challenge and evaluating on its 251 validation cases. Is this a convincing split?

    opened by H-CODE6 7
  • Ask for data help

    Hello, could you provide the processing code for liver or pancreatic tumors? The code shows how the BraTS 2021 data is processed, but I need to segment gastric cancer data, which is in DICOM (DCM) format. Thank you for sharing.

    opened by huying12 7
  • Training & Testing On MSD Data

    First of all, thanks @himashi92 for this amazing work. Since we do not have access to the BraTS data and want to run your model on the MSD dataset given here, it would be very helpful if you could help me with that. What should the directory tree look like before running, and what changes to the code are needed? It would be even more helpful if you could add your code for running on MSD data to this repo.

    opened by abdur75648 7
  • Still overflowing GPU VRAM with reduced batch size

    Hi! I'm still overflowing my VRAM (RTX 3090, 24 GB) whenever I try to train, even after reducing my batch size to 1. Any ideas on what I can do?

    opened by Chadkowski 5
  • RuntimeError: CUDA out of memory

    Traceback (most recent call last):
      File "train.py", line 314, in <module>
        main(arguments)
      File "train.py", line 166, in main
        segs_S1 = model_1(inputs_S1)
      File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vision_transformer.py", line 49, in forward
        return self.swin_unet(x)
      File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 1118, in forward
        x, x_downsample, v_values_1, k_values_1, q_values_1, v_values_2, k_values_2, q_values_2 = self.forward_features(
      File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 960, in forward_features
        x, v1, k1, q1, v2, k2, q2 = layer(x, i)
      File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 726, in forward
        x, v1, k1, q1 = blk(x, attn_mask, None, None, None)
      File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 392, in forward
        x, x2, v, k, q = self.forward_part1(x, mask_matrix, prev_v, prev_k, prev_q, is_decoder)
      File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 340, in forward_part1
        attn_windows, cross_attn_windows, v, k, q = self.attn(x_windows, mask=attn_mask, prev_v=prev_v, prev_k=prev_k,
      File "/home/yusongli/.bin/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/yusongli/_project/shidaoai/task/01_seg/VT-Unet/vtunet/vt_unet.py", line 191, in forward
        attn = q @ k.transpose(-2, -1)
    RuntimeError: CUDA out of memory. Tried to allocate 9.02 GiB (GPU 0; 23.70 GiB total capacity; 11.37 GiB already allocated; 8.42 GiB free; 13.10 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
    

    Notice: I modified the dataloader to fit my own dataset. I've searched for many solutions but failed to solve this. Very sad :-( Could you please help me? Maybe the model is too large to train? I don't know. Thanks!

    opened by ysl2 5
  • Error when run train.py

    Hi, I am trying to reproduce the results and got the following error. Any help with this? Thanks

    AssertionError                            Traceback (most recent call last)
    ~/Desktop/projects/VT-UNet-main/train.py in <module>
        300     arguments = parser.parse_args()
        301     os.environ['CUDA_VISIBLE_DEVICES'] = arguments.devices
    --> 302     main(arguments)

    ~/Desktop/projects/VT-UNet-main/train.py in main(args)
        109     optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
        110
    --> 111     full_train_dataset, l_val_dataset, bench_dataset = get_datasets(args.seed, fold_number=args.fold)
        112     train_loader = torch.utils.data.DataLoader(full_train_dataset, batch_size=args.batch_size, shuffle=True,
        113     num_workers=args.workers, pin_memory=True, drop_last=True)

    ~/Desktop/projects/VT-UNet-main/dataset/brats.py in get_datasets(seed, on, fold_number, normalisation)
         90     base_folder = pathlib.Path(get_brats_folder(on)).resolve()
         91     print(base_folder)
    ---> 92     assert base_folder.exists()
         93     patients_dir = sorted([x for x in base_folder.iterdir() if x.is_dir()])
         94

    AssertionError:

    opened by shathaa1983 3
  • About Results on BraTS

    The paper <A Volumetric Transformer for Accurate 3D Tumor Segmentation> evaluates VT-UNet's performance on the BraTS 2021 dataset, but the MICCAI paper <A Robust Volumetric Transformer for Accurate 3D Tumor Segmentation> evaluates VT-UNet on the MSD BraTS task. Why is VT-UNet's performance on BraTS 2021 no longer evaluated and reported in the MICCAI paper?

    opened by auroua 2
  • how to predict on Brats Validation case?

    First of all, thank you for your excellent work. I want to ask how to use this code to predict a single case. Each BraTS validation case only has four files and no seg file. I want to use this code to predict every validation case. Looking forward to your reply.

    opened by wangnanv5 2
  • Data processing questions about Brats21

    Thank you very much for your excellent work. I noticed that your brats.py code merges the different labels into ET, TC and WT, which many data processing examples, such as MONAI, also do. But I see that your paper still classifies and evaluates according to the competition labels, right? Looking forward to your reply.

    opened by Breeze-Zero 2
  • Inference failed due to missing postprocessing.json file

    Dear author, I find your paper very interesting and I'm new to the medical imaging field. I am trying to reproduce your model's results for version 1 and version 2.

    In version 1, I couldn't find the HD95 code, so I copied it from here. It gives wrong HD95 values during evaluation. Could you please provide the HD95 code for version 1?

    In version 2, when I run the inference code it gives an error. To resolve this, I ran consolidate_postprocessing_simple.py to compute the postprocessing.json file, but it says fold_0, fold_1, fold_2, etc. are missing. I trained the model for fold=0 as described in the instructions. Could you please look into this?

    opened by mustansarfiaz 1
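The last issue above asks about HD95 (the 95th-percentile Hausdorff distance). For reference, the following is a minimal pure-Python sketch of one common definition: compute, for every boundary point of each mask, the distance to the nearest boundary point of the other mask, pool both sets of distances, and take their 95th percentile. It assumes boundary voxel coordinates have already been extracted and uses brute-force nearest-neighbour search (O(N·M)), so it is only suitable for checking small cases. Other implementations instead take the larger of the two directed 95th percentiles, so values can differ between tools; hd95 here is a hypothetical helper, not the repo's evaluation code.

```python
import math


def _scale(points, spacing):
    """Convert voxel indices to physical coordinates using the voxel spacing."""
    return [tuple(c * s for c, s in zip(p, spacing)) for p in points]


def _directed_distances(src, dst):
    """Distance from each point in src to its nearest point in dst (brute force)."""
    return [min(math.dist(p, q) for q in dst) for p in src]


def _percentile(values, pct):
    """Percentile with linear interpolation between ranks (NumPy's default rule)."""
    s = sorted(values)
    pos = (len(s) - 1) * pct / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def hd95(surface_a, surface_b, spacing=(1.0, 1.0, 1.0)):
    """95th-percentile symmetric Hausdorff distance between two boundary point sets.

    surface_a, surface_b: sequences of boundary voxel coordinates, in the same
    axis order as spacing (e.g. (z, y, x) with spacing in mm).
    """
    if not surface_a or not surface_b:
        raise ValueError("HD95 is undefined when either surface is empty")
    a, b = _scale(surface_a, spacing), _scale(surface_b, spacing)
    # Pool both directed distance sets, then take their 95th percentile.
    pooled = _directed_distances(a, b) + _directed_distances(b, a)
    return _percentile(pooled, 95)
```

Passing the real voxel spacing matters: BraTS volumes are usually resampled to 1 mm isotropic, but other datasets (such as the Liver spacing shown in the first issue) are not, and HD95 is normally reported in millimetres.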
Owner
Himashi Amanda Peiris
Former Senior Software Engineer at Pearson. Currently a PhD candidate at Monash University