TransMorph: Transformer for Medical Image Registration

Overview

keywords: Vision Transformer, Swin Transformer, convolutional neural networks, image registration

This is a PyTorch implementation of my paper:

Chen, Junyu, et al. "TransMorph: Transformer for Medical Image Registration." arXiv, 2021.

TransMorph DIR Variants:

There are four TransMorph variants: TransMorph, TransMorph-diff, TransMorph-bspl, and TransMorph-Bayes.
Training and inference scripts are in TransMorph/, and the models are contained in TransMorph/model/.

  1. TransMorph: A hybrid Transformer-ConvNet network for image registration.
  2. TransMorph-diff: A probabilistic TransMorph that ensures a diffeomorphism.
  3. TransMorph-bspl: A B-spline TransMorph that ensures a diffeomorphism.
  4. TransMorph-Bayes: A Bayesian uncertainty TransMorph that produces registration uncertainty estimates.
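
Diffeomorphic variants such as the VoxelMorph-diff baseline obtain a diffeomorphic transform by integrating a stationary velocity field (SVF) with "scaling and squaring"; we assume TransMorph-diff follows the same idea. The 1D NumPy sketch below only illustrates the technique and is not the repository's implementation: the function name scaling_and_squaring_1d and the num_steps=7 default are assumptions.

```python
import numpy as np


def scaling_and_squaring_1d(velocity, x, num_steps=7):
    """Integrate a 1D stationary velocity field v into a displacement u.

    The returned u defines phi(x) = x + u(x), an approximation of the
    time-1 flow of v. velocity and x are 1D arrays of equal length, and
    x must be increasing (required by np.interp).
    """
    # Scaling: divide the velocity so the first displacement is tiny.
    u = velocity / (2 ** num_steps)
    for _ in range(num_steps):
        # Squaring: phi <- phi o phi, i.e. u(x) <- u(x) + u(x + u(x)).
        # np.interp evaluates u at the warped positions by linear
        # interpolation and clamps to the edge values outside the grid.
        u = u + np.interp(x + u, x, u)
    return u


x = np.linspace(0.0, 2.0 * np.pi, 201)
u = scaling_and_squaring_1d(0.8 * np.sin(x), x)
# phi(x) = x + u(x) should stay strictly increasing, i.e. invertible in 1D.
print(bool(np.all(np.diff(x + u) > 0)))  # True
```

Because each squaring step composes an increasing map with itself, the result stays invertible, which is the property the 3D versions rely on.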

TransMorph Affine Model:

The scripts for the TransMorph affine model are in the TransMorph_affine/ folder.

train_xxx.py and infer_xxx.py are the training and inference scripts for TransMorph models.
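
The training scripts wrap the dataset in a PyTorch DataLoader (see the DataLoader(train_set, ..., shuffle=True, ...) call quoted in the issues). If the data path is wrong, the dataset is empty: with shuffle=True, PyTorch's RandomSampler raises "num_samples should be a positive integer value, but got num_samples=0", and with shuffle=False the loops likely run zero iterations, so variables assigned inside them (such as def_out) are never defined. A minimal pre-flight check, assuming the data are *.pkl files in one directory; list_training_files is a hypothetical helper, not part of the repository:

```python
import glob
import os
import tempfile


def list_training_files(data_dir, pattern="*.pkl"):
    """Return sorted dataset files, failing early if the glob matches nothing."""
    files = sorted(glob.glob(os.path.join(data_dir, pattern)))
    if not files:
        # An empty list would otherwise reach DataLoader(..., shuffle=True),
        # whose RandomSampler rejects num_samples=0.
        raise FileNotFoundError(
            f"no files matching {pattern!r} in {os.path.abspath(data_dir)!r}; "
            "check the data directory configured in the training script"
        )
    return files


# Usage with a throwaway directory standing in for the real data folder.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("subject_001.pkl", "subject_002.pkl"):
        open(os.path.join(tmp, name), "wb").close()
    print([os.path.basename(f) for f in list_training_files(tmp)])
```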

Baseline Models:

We compared TransMorph with eight baseline registration methods and four Transformer architectures.
Baseline registration methods:

  1. SyN (ANTsPy)
  2. NiftyReg
  3. LDDMM
  4. deedsBCV
  5. VoxelMorph-1 & -2
  6. CycleMorph
  7. MIDIR

Baseline Transformer architectures:

  1. PVT
  2. nnFormer
  3. CoTr
  4. ViT-V-Net

Training and inference scripts for the baseline models will be available in the near future!

Dataset:

Due to restrictions, we cannot distribute our brain MRI data. However, several brain MRI datasets are publicly available online: IXI, ADNI, OASIS, ABIDE, etc. Note that these datasets may not include labels (segmentations). To generate labels, you can use FreeSurfer, an open-source software package for normalizing and segmenting brain MRI images. Some useful FreeSurfer commands are collected in: Brain MRI preprocessing and subcortical segmentation using FreeSurfer.
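
Segmentation labels are what make an overlap-based evaluation possible: the warped moving label map is compared with the fixed label map structure by structure. A minimal NumPy sketch of per-label Dice, assuming integer label maps with 0 as background; dice_per_label is a hypothetical helper, not the repository's evaluation code:

```python
import numpy as np


def dice_per_label(fixed_seg, warped_seg, labels=None):
    """Per-structure Dice = 2|A ∩ B| / (|A| + |B|) between two label maps.

    Label 0 is treated as background and skipped when labels is None.
    """
    if labels is None:
        labels = np.union1d(np.unique(fixed_seg), np.unique(warped_seg))
        labels = labels[labels != 0]
    scores = {}
    for lab in labels:
        a = fixed_seg == lab
        b = warped_seg == lab
        denom = a.sum() + b.sum()
        # A label absent from both maps has no defined overlap.
        scores[int(lab)] = float(2.0 * np.logical_and(a, b).sum() / denom) if denom else float("nan")
    return scores


fixed = np.array([[1, 1, 0], [2, 2, 0]])
warped = np.array([[1, 0, 0], [2, 2, 2]])
print(dice_per_label(fixed, warped))  # label 1: 2/3, label 2: 0.8
```

Averaging these per-structure scores over structures and test pairs gives a single summary number.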

Citation:

If you find this code useful in your research, please consider citing:

@misc{chen2021transmorph,
title={TransMorph: Transformer for Medical Image Registration}, 
author={Junyu Chen and Yufan He and Eric C. Frey and Ye Li and Yong Du},
year={2021},
eprint={2111.10480},
archivePrefix={arXiv},
primaryClass={eess.IV}
}

TransMorph Architecture:
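
The Swin Transformer encoder computes self-attention inside local windows, so at every stage the token grid must either be a multiple of the window size or be padded to one. The window size (5, 6, 7) appears in the error reported for a 192 × 192 × 208 input in the issues, and the nnf.pad issue concerns the padding order. The pure-Python sketch below assumes patch_size 4, window_size (5, 6, 7), four stages, and 2× downsampling per stage (our reading of the default configuration, not confirmed by this page). It lists the stages whose grid is not divisible and builds a correctly ordered pad tuple. All function names are hypothetical.

```python
def stage_resolutions(img_size, patch_size=4, num_stages=4):
    """Token-grid size per stage: patch embedding, then 2x patch merging.

    Ceil division mimics padding of sizes that do not divide evenly.
    """
    res = [-(-s // patch_size) for s in img_size]
    stages = [tuple(res)]
    for _ in range(num_stages - 1):
        res = [-(-r // 2) for r in res]
        stages.append(tuple(res))
    return stages


def incompatible_stages(img_size, window_size=(5, 6, 7), patch_size=4, num_stages=4):
    """(stage, grid) pairs whose grid is not a multiple of the window size."""
    return [
        (i, res)
        for i, res in enumerate(stage_resolutions(img_size, patch_size, num_stages))
        if any(r % w for r, w in zip(res, window_size))
    ]


def window_pad(res, window_size):
    """Pad tuple for torch.nn.functional.pad on a (B, C, H, W, T) tensor.

    F.pad takes (before, after) pairs starting from the LAST dimension, so
    the order is T, then W, then H. For a channels-last (B, H, W, T, C)
    tensor, prepend (0, 0) for C.
    """
    (h, w, t), (wh, ww, wt) = res, window_size
    return (0, (wt - t % wt) % wt, 0, (ww - w % ww) % ww, 0, (wh - h % wh) % wh)


print(incompatible_stages((160, 192, 224)))     # []
print(incompatible_stages((192, 192, 208))[0])  # (0, (48, 48, 52))
print(window_pad((48, 48, 52), (5, 6, 7)))      # (0, 4, 0, 0, 0, 2)
```

Choosing an input size for which incompatible_stages returns an empty list avoids depending on the padding path altogether.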

Example Results:

Qualitative comparisons:

Uncertainty Estimate by TransMorph-Bayes:

Quantitative Results:

Inter-patient Brain MRI:

XCAT-to-CT:

Reference:

Swin Transformer
easyreg
MIDIR
VoxelMorph

Comments
  • Questions about the test results on the IXI dataset

    We are having a bit of a problem running the code you provided on the IXI dataset.

    We loaded the pre-trained IXI model you provided (the 0.744 one), but the result was only around 0.6.

    Is any preprocessing required for the IXI dataset?

    Thanks.

    opened by zhujunkun 11
  • Deformation field not smooth

    Hi Chen,

    First off, thank you for making the code and datasets for OASIS available! I have been trying to apply the pre-trained model on the OASIS data you provided. For the output, I convert the deformation field in the .npz file into a .nii file in order to visualize the result in ITK-SNAP. However, the result does not seem very smooth which is frustrating.

    Attached is a screenshot of the deformation field.

    I am wondering what is causing this issue. Thank you in advance for your help, and I look forward to hearing back from you!

    opened by kvttt 6
  • The training loss of the affine registration network keeps going up

    Thank you so much for sharing your outstanding work, which has helped me a lot. However, when I train the affine registration network on my own data, the training loss keeps going up. I have checked my training inputs and they look fine. I set the batch size to 1 and the learning rate to 0.00001. Do you know what the problem might be?

    opened by XiaotianJia 5
  • Training problems on the IXI dataset

    Hi,

    I encountered some difficulties training the models. I tried to train "VoxelMorph", "CycleMorph", and "TransMorph" on the preprocessed IXI dataset, but all of them hit the same error:

    ValueError: num_samples should be a positive integer value, but got num_samples=0

    This error is raised in train.py at:

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

    I found that I could avoid it by setting shuffle=False, but then a new error appeared:

    UnboundLocalError: local variable 'def_out' referenced before assignment

    This one is raised in train.py at "pred_fig = comput_fig(def_out)", and I can't resolve it. Perhaps I should not modify the shuffle parameter in the source code at all. What is the cause of the first error, and how can I fix it? I have tried both my laptop and my workstation, so I rule out hardware as the cause.

    Best

    Yongxin

    opened by Yongxin1120 4
  • Deformation field close to zero.

    Hi Junyu,

    Sorry to bother you again. Last time you suggested that I look at the diffeomorphic variants. Lately, I have been looking specifically at VoxelMorph-diff among the baseline models you included. However, after training converges, the resulting deformation fields are very close to zero, i.e., the value at each voxel falls roughly in the range 0.001 to 0.02. When I visualize the deformation field in ITK-SNAP (see the attached image), the grid lines look straight, which is drastically different from the deformation fields produced by the non-diffeomorphic variants. I am wondering if this is normal.

    At first, I suspected that I was saving the stationary velocity field (SVF) instead of the deformation field, but that is not the case. I also suspected that something was wrong with the "scaling and squaring" integration step. However, the deformation field is still close to zero even after I replace the for-loop with the VecInt function here: https://github.com/voxelmorph/voxelmorph/blob/a746f77098962da1be9e6a03dacc3ef6c90d5244/voxelmorph/torch/layers.py#L51-L68.

    After all these attempts to understand this behavior, my working assumption is that diffeomorphic variants such as VoxelMorph-diff tend to produce near-zero deformation fields. I would really appreciate it if you could provide some insight on this or confirm my observation.

    Again, I enjoyed reading your paper and your code!

    opened by kvttt 4
  • Problem in the validation step: not enough values to unpack from pkload

    Hi, junyuchen.

    I downloaded the dataset from Google Drive (1.44 GB) and ran train.py.

    The training phase ran successfully, but the validation phase fails. In __getitem__ of the class JHUBrainInferDataset, it fails at: x, y, x_seg, y_seg = pkload(path).

    pkload() returns only 2 values, not 4, so the line above throws an exception.

    opened by clinton81 4
  • Evaluation on datasets oriented according to MNI152

    Dear author,

    I have tried to evaluate the models trained on the IXI dataset on a dataset that was originally oriented according to the MNI152 space. Since the models were trained with images oriented to the OASIS space, should I expect a drop in performance because of this, or should the method provide competitive results regardless of the initial alignment?

    Best regards.

    opened by moniquisma 3
  • Visualization of SyN's deformation field

    Hi, @junyuchen245

    I'm sorry to trouble you. I have some questions about how to visualize the deformation field (flow) generated by the SyN algorithm.

    My idea is to save the deformation field to a file and visualize it in ITK-SNAP, but the result seems to be wrong.

    This is my main code for saving the deformation flow:

    f_img_set = sitk.ReadImage("./fixed.nii.gz")
    f_img = ants.image_read("./fixed.nii.gz")
    m_img = ants.image_read("./moving.nii.gz")
    mytx = ants.registration(fixed=f_img, moving=m_img, type_of_transform='SyN')
    flow = np.array(nib_load(mytx['fwdtransforms'][0]), dtype='float32', order='C')
    flow_save = torch.from_numpy(flow).to(device).float()
    save_image(flow_save.permute(3, 0, 1, 2, 4)[np.newaxis, ...], f_img_set, fileName + "_flow.nii.gz")

    where the function save_image() is defined as follows:

    def save_image(img, ref_img, name):
        img = sitk.GetImageFromArray(img[0, 0, ...].cpu().detach().numpy())
        img.SetOrigin(ref_img.GetOrigin())
        img.SetDirection(ref_img.GetDirection())
        img.SetSpacing(ref_img.GetSpacing())
        sitk.WriteImage(img, os.path.join('./result/', name))

    After this operation, I found that the saved deformation field (flow.nii.gz) looks a little strange. Next, I tried to warp the moving image with this saved field using an STN function (I am confident that this function works correctly for deformation fields generated by deep learning-based methods), but the resulting warped image differs from the one generated by ANTs. In the attached comparison, the left image is warped with the saved deformation field, and the right image is warped by the ANTs function ants.apply_transforms.

    I have tried many things but still cannot find the problem, so I am asking you for help. Thank you!

    Yours, Tzayuan

    opened by tzayuan 3
  • nnf.pad error

    Hi Junyu,

    Thanks for sharing your interesting work! It has helped my project a lot. However, I found some small errors related to the nnf.pad function in your code, so I made the modifications below (the original lines are commented out):

    x = nnf.pad(x, (0, 0, pad_f, pad_h, pad_t, pad_b, pad_l, pad_r)) # Line 219
    # x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h))
            
    if pad_input: #Line 285
        x = nnf.pad(x, (0, 0, 0, T % 2, 0, W % 2, 0, H % 2))
        # x = nnf.pad(x, (0, 0, 0, W % 2, 0, H % 2, 0, T % 2))
    
    _, _, H, W, T = x.size() # Line 439
    if T % self.patch_size[2] != 0:
        x = nnf.pad(x, (0, self.patch_size[2] - T % self.patch_size[2]))
    if W % self.patch_size[1] != 0:
        x = nnf.pad(x, (0, 0, 0, self.patch_size[1] - W % self.patch_size[1]))
    if H % self.patch_size[0] != 0:
        x = nnf.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
    # if W % self.patch_size[1] != 0:
    #     x = nnf.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
    # if H % self.patch_size[0] != 0:
    #     x = nnf.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
    # if T % self.patch_size[0] != 0:
    #     x = nnf.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - T % self.patch_size[0]))
    

    Best, Woaka

    opened by wangkaiwan 3
  • Keep getting a CUDA out-of-memory RuntimeError while reproducing the results on the IXI dataset

    RuntimeError: CUDA out of memory. Tried to allocate 420.00 MiB (GPU 1; 11.91 GiB total capacity; 10.87 GiB already allocated; 316.25 MiB free; 11.07 GiB reserved in total by PyTorch)

    I keep getting the above error. I tried freeing the cache and printing the memory usage summary, but I don't understand what each entry means:

    |===========================================================================|
    |                  PyTorch CUDA memory summary, device ID 1                 |
    |---------------------------------------------------------------------------|
    |            CUDA OOMs: 1            |        cudaMalloc retries: 1         |
    |===========================================================================|
    |        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
    |---------------------------------------------------------------------------|
    | Allocated memory      |   10656 MB |   10932 MB |   19094 MB |    8437 MB |
    |       from large pool |   10627 MB |   10903 MB |   19040 MB |    8412 MB |
    |       from small pool |      29 MB |      31 MB |      53 MB |      24 MB |
    |---------------------------------------------------------------------------|
    | Active memory         |   10656 MB |   10932 MB |   19094 MB |    8437 MB |
    |       from large pool |   10627 MB |   10903 MB |   19040 MB |    8412 MB |
    |       from small pool |      29 MB |      31 MB |      53 MB |      24 MB |
    |---------------------------------------------------------------------------|
    | GPU reserved memory   |   11328 MB |   11328 MB |   11328 MB |       0 B  |
    |       from large pool |   11294 MB |   11294 MB |   11294 MB |       0 B  |
    |       from small pool |      34 MB |      34 MB |      34 MB |       0 B  |
    |---------------------------------------------------------------------------|
    | Non-releasable memory |  322610 KB |     940 MB |    4760 MB |    4445 MB |
    |       from large pool |  317682 KB |     936 MB |    4707 MB |    4396 MB |
    |       from small pool |    4928 KB |       4 MB |      53 MB |      48 MB |
    |---------------------------------------------------------------------------|
    | Allocations           |     449    |     456    |     845    |     396    |
    |       from large pool |     192    |     198    |     405    |     213    |
    |       from small pool |     257    |     258    |     440    |     183    |
    |---------------------------------------------------------------------------|
    | Active allocs         |     449    |     456    |     845    |     396    |
    |       from large pool |     192    |     198    |     405    |     213    |
    |       from small pool |     257    |     258    |     440    |     183    |
    |---------------------------------------------------------------------------|
    | GPU reserved segments |     107    |     107    |     107    |       0    |
    |       from large pool |      90    |      90    |      90    |       0    |
    |       from small pool |      17    |      17    |      17    |       0    |
    |---------------------------------------------------------------------------|
    | Non-releasable allocs |      49    |      50    |     345    |     296    |
    |       from large pool |      37    |      37    |     203    |     166    |
    |       from small pool |      12    |      14    |     142    |     130    |
    |===========================================================================|
    

    I was specifically running train_TransMorph.py. One suggestion was to reduce the batch size, but it is already set to 1. It might be possible to delete unused variables and free their memory, along with a few other things suggested on the PyTorch forum, but I am not yet confident about changing the training loop.

    There is also a known issue, the inability to allocate fragmented blocks, which was fixed in https://github.com/pytorch/pytorch/pull/44742. I am not sure which PyTorch version includes this fix and am following up on that.

    In the meantime, do you have any thoughts on how to resolve this or on a workaround? Also, is it possible to know how much peak memory training requires?

    Thanks

    opened by bhosalems 3
  • Error

    Hi,

    Great paper, and thanks for sharing your source code! It works with the default input size, but I get an error when I try a different input size. Note that I also changed the size in the config file. Do you have any idea what I am doing wrong?

    Example:

    #x = torch.rand((1,2, 160, 192, 224))
    x = torch.rand((1,2, 192, 192, 208))

    out, flow = model.forward(x)

    print(out.size())
    print(flow.size())

    Error:

    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], L // window_size[2], window_size[2], C)
    RuntimeError: shape '[1, 10, 5, 8, 6, 7, 7, 6]' is invalid for input of size 808704

    opened by ppegiosk 3
  • About the hyperparameter for CT image registration.

    Hi, I would like to ask which regularization hyperparameter you used for VoxelMorph and TransMorph in CT image registration, and which similarity loss function you used. Looking forward to your answer.

    opened by MingR-Ma 0
  • RaFD dataset preprocessing

    Hi,

    Thanks for your wonderful work. I'd like to use TransMorph as a baseline. I noticed that your code uses PKL files (RaFD -> TransMorph2D -> train_TransMorph.py). My question is: how did you preprocess the raw RaFD dataset into the pkl file format required by your code?

    I hope to hear from you. You can also send the preprocessing code to my email ([email protected]). Thank you.

    Jun

    opened by Jun-electrophysiology 2
  • Problems running traditional registration methods such as NiftyReg and SyN on the IXI dataset

    Dear author, I am interested in how the traditional registration methods provided by your team, such as NiftyReg and SyN, perform on the IXI dataset, but I have run into some difficulties trying to reproduce the results. Could you please provide a running example showing how to run the traditional registration code you provide, such as NiftyReg and SyN? (This is not covered in the documentation.)

    opened by Yongxin1120 0
  • Memory problems during model training and testing

    Hi Junyu

    Thanks for your help. Recently we encountered some memory problems during model training and testing. Our workstation uses an NVIDIA RTX A2000 graphics card with 12 GB of video memory, but we still get an out-of-memory error, as shown in the attached picture. What GPU configuration is required to run this model, and is there a way to reduce the amount of video memory it needs?

    Best regards

    Yongxin

    opened by Yongxin1120 0
  • Something wrong in the code

    First of all, thank you for sharing your code.

    I'm trying to perform image registration on a public medical dataset:

    https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=39879146

    I want to compare how deep learning and classical DIR methods differ from commercial software when applied to real RT data rather than phantoms.

    I created a dataset following these instructions:

    https://github.com/junyuchen245/Preprocessed_IXI_Dataset

    I resampled the CT images from (512, 512, 180~200) to (256, 256, 160) because of GPU memory limits (2× A6000).

    I also modified dataset.py, utils.py, train.py, infer.py, and config.py inside each DIR algorithm folder (see the attached modified file.pptx).

    Each model now produces results, but the traditional method gives better results.

    I'm confused, so I'm posting this question.

    Is there anything else that needs to be modified in the code?

    opened by nightandweather 0
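
Two of the issues above (a field that does not look smooth, and a diffeomorphic field that looks nearly zero) are easier to discuss with a quantitative measure than with a grid rendering in ITK-SNAP. A common one is the voxel-wise Jacobian determinant of phi(x) = x + u(x): values near 1 mean little local volume change, and values at or below 0 mean folding. A minimal NumPy sketch, assuming a (3, D, H, W) displacement in voxel units; the repository's flow tensors may carry a batch dimension or a different channel order, and jacobian_determinant is a hypothetical helper:

```python
import numpy as np


def jacobian_determinant(disp):
    """Voxel-wise det(J) of phi(x) = x + u(x) for a displacement u.

    disp has shape (3, D, H, W) in voxel units; channel i is the
    displacement along array axis i. det(J) <= 0 marks folding.
    """
    # grads[i][j] = d u_i / d x_j (central differences, one-sided at borders)
    grads = [np.gradient(disp[i]) for i in range(3)]
    jac = np.empty(disp.shape[1:] + (3, 3))
    for i in range(3):
        for j in range(3):
            jac[..., i, j] = grads[i][j] + (1.0 if i == j else 0.0)
    return np.linalg.det(jac)


# A uniform 10% expansion: u(x) = 0.1 * x, so det(J) = 1.1 ** 3 everywhere.
grid = np.indices((4, 5, 6)).astype(float)
det = jacobian_determinant(0.1 * grid)
print(np.allclose(det, 1.331), float((det <= 0).mean()))  # True 0.0
```

The fraction (det <= 0).mean() summarizes folding in one number, and a field whose determinant stays very close to 1 everywhere is consistent with the near-zero displacements described above.
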
Owner
Junyu Chen
Ph.D. candidate in the Department of Electrical and Computer Engineering & the Department of Radiology and Radiological Science @ Johns Hopkins University
Registration Loss Learning for Deep Probabilistic Point Set Registration

RLLReg This repository contains a Pytorch implementation of the point set registration method RLLReg. Details about the method can be found in the 3DV

Felix Järemo Lawin 35 Nov 2, 2022
The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dealing with medical images.

MIC-DKFZ 1.2k Jan 4, 2023
Build a medical knowledge graph based on Unified Language Medical System (UMLS)

UMLS-Graph Build a medical knowledge graph based on Unified Language Medical System (UMLS) Requisite Install MySQL Server 5.6 and import UMLS data int

Donghua Chen 6 Dec 25, 2022
CoTr: Efficiently Bridging CNN and Transformer for 3D Medical Image Segmentation

CoTr: Efficient 3D Medical Image Segmentation by bridging CNN and Transformer This is the official pytorch implementation of the CoTr: Paper: CoTr: Ef

null 218 Dec 25, 2022
The codes for the work "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation"

Swin-Unet The codes for the work "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation"(https://arxiv.org/abs/2105.05537). A validatio

null 869 Jan 7, 2023
MISSFormer: An Effective Medical Image Segmentation Transformer

MISSFormer Code for paper "MISSFormer: An Effective Medical Image Segmentation Transformer". Please read our preprint at the following link: paper_add

Fong 22 Dec 24, 2022
GeoTransformer - Geometric Transformer for Fast and Robust Point Cloud Registration

Geometric Transformer for Fast and Robust Point Cloud Registration PyTorch imple

Zheng Qin 220 Jan 5, 2023
Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

ImageProcessingTransformer Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

null 61 Jan 1, 2023
Breaking the Dilemma of Medical Image-to-image Translation

Breaking the Dilemma of Medical Image-to-image Translation Supervised Pix2Pix and unsupervised Cycle-consistency are two modes that dominate the field

Kid Liet 86 Dec 21, 2022
Copy Paste positive polyp using poisson image blending for medical image segmentation

Copy Paste positive polyp using poisson image blending for medical image segmentation According poisson image blending I've completely used it for bio

Phạm Vũ Hùng 2 Oct 19, 2021
VSR-Transformer - This paper proposes a new Transformer for video super-resolution (called VSR-Transformer).

VSR-Transformer By Jiezhang Cao, Yawei Li, Kai Zhang, Luc Van Gool This paper proposes a new Transformer for video super-resolution (called VSR-Transf

Jiezhang Cao 225 Nov 13, 2022
Framework for joint representation learning, evaluation through multimodal registration and comparison with image translation based approaches

CoMIR: Contrastive Multimodal Image Representation for Registration Framework ?? Registration of images in different modalities with Deep Learning ??

Methods for Image Data Analysis - MIDA 55 Dec 9, 2022
Deep learning image registration library for PyTorch

TorchIR: Pytorch Image Registration TorchIR is a image registration library for deep learning image registration (DLIR). I have integrated several ide

Bob de Vos 40 Dec 16, 2022
Automated image registration. Registrationimation was too much of a mouthful.

alignimation Automated image registration. Registrationimation was too much of a mouthful. This repo contains the code used for my blog post Alignimat

Ethan Rosenthal 9 Oct 13, 2022
A multi-scale unsupervised learning for deformable image registration

A multi-scale unsupervised learning for deformable image registration Shuwei Shao, Zhongcai Pei, Weihai Chen, Wentao Zhu, Xingming Wu and Baochang Zha

ShuweiShao 2 Apr 13, 2022
A variational Bayesian method for similarity learning in non-rigid image registration (CVPR 2022)

A variational Bayesian method for similarity learning in non-rigid image registration We provide the source code and the trained models used in the re

daniel grzech 14 Nov 21, 2022
Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch

Transformer in Transformer Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image c

Phil Wang 272 Dec 23, 2022
This repo holds code for TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation

TransUNet This repo holds code for TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation Usage

null 1.4k Jan 4, 2023
[CVPR'21] FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space

FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space by Quande Liu, Cheng Chen, Ji

Quande Liu 178 Jan 6, 2023