COVID-VIT: Classification of Covid-19 from CT chest images based on vision transformer models

Overview

COVID-ViT

This code was written in response to the MIA-COV19 competition on classifying COVID versus non-COVID volumetric chest CT datasets.

Pre-trained models for ViT and DenseNet can be downloaded from https://drive.google.com/drive/folders/1nBI02F-8Y0hFeN10CMu9Svj4xRViN5xt?usp=sharing.

Both 2D and 3D versions of the training and test code are provided. Classification based on 2D slices appears to perform better. The final score is subject-based: for a given dataset, if 25% or more of its slices are classified as COVID, the subject is classified as COVID. Otherwise, the patient is classified as normal. This threshold (e.g. 25%) can be determined at the validation stage.
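
The subject-level rule above can be sketched as follows. This is a minimal illustration, not code from the repository: it assumes each subject's per-slice predictions from the 2D model are already available as 0/1 labels (1 = COVID), and the function names `classify_subject` and `pick_threshold` are hypothetical. Selecting the threshold by subject-level accuracy on the validation set is also an assumption; another metric such as F1 could be used instead.

```python
from typing import Sequence, Tuple


def classify_subject(slice_preds: Sequence[int], threshold: float = 0.25) -> int:
    """Return 1 (COVID) if at least `threshold` of the slices are predicted COVID, else 0 (normal)."""
    if len(slice_preds) == 0:
        raise ValueError("subject has no slice predictions")
    covid_fraction = sum(slice_preds) / len(slice_preds)
    return int(covid_fraction >= threshold)


def pick_threshold(
    subjects: Sequence[Sequence[int]],
    labels: Sequence[int],
    candidates: Sequence[float],
) -> Tuple[float, float]:
    """Return (threshold, accuracy) for the candidate with the best subject-level
    accuracy on validation data; ties go to the earliest candidate."""
    if len(subjects) != len(labels) or len(labels) == 0:
        raise ValueError("need one label per subject and at least one subject")
    best_t, best_acc = candidates[0], -1.0
    for t in candidates:
        # Count subjects whose thresholded prediction matches the ground-truth label.
        correct = sum(classify_subject(s, t) == y for s, y in zip(subjects, labels))
        acc = correct / len(labels)
        if acc > best_acc:
            best_t, best_acc = t, acc
    return best_t, best_acc


print(classify_subject([1, 0, 0, 0]))     # 1 (exactly 25% of slices are COVID)
print(classify_subject([1, 0, 0, 0, 0]))  # 0 (20% of slices are COVID)
```

Using `>=` mirrors the "25% or more" wording, so a subject sitting exactly on the threshold is labelled COVID.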

The ViT implementation is heavily based on vit-pytorch (https://github.com/lucidrains/vit-pytorch) and is provided both as Jupyter notebooks and as Python scripts.

The DenseNet-CT is built upon https://github.com/UCSD-AI4H/COVID-CT.

More details are in the paper on arXiv (https://arxiv.org/): Xiaohong Gao, Yu Qian, Alice Gao, "COVID-VIT: Classification of Covid-19 from CT chest images based on vision transformer models".

Comments
  • xg_vit_model_covid_2d.pt

    Does this link, https://drive.google.com/drive/folders/1nBI02F-8Y0hFeN10CMu9Svj4xRViN5xt, contain xg_vit_model_covid_2d.pt? I only see xg_vit_model_covid_3d.pt and xg_vit_model_covid.pt. I downloaded the latter and found it was not the right one. Where can I find xg_vit_model_covid_2d.pt?

    opened by antoniaaaaaaaaaaaa 1
  • Loading pretrained weights into the 2D and 3D models

    Hello, I got this error when loading the weights with load_state_dict

    RuntimeError: Error(s) in loading state_dict for ViT3: size mismatch for pos_embedding: copying a param with shape torch.Size([1, 3137, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1569, 1024]).

    What's wrong? The 2D model has the same problem.

    opened by AsmaBaccouche 0
  • Unable to achieve 76.6% accuracy

    I managed to run your code and start training from the pre-trained model. However, I am getting the same results (about 50% accuracy) as shown in the Jupyter notebook:

    Epoch : 1 - loss : 0.6947 - acc: 0.4957 - val_loss : 0.6930 - val_acc: 0.5084
    Epoch : 2 - loss : 0.6947 - acc: 0.4881 - val_loss : 0.6930 - val_acc: 0.5084
    Epoch : 3 - loss : 0.6948 - acc: 0.5003 - val_loss : 0.6942 - val_acc: 0.5088
    Epoch : 4 - loss : 0.6941 - acc: 0.5060 - val_loss : 0.6931 - val_acc: 0.5088
    Epoch : 5 - loss : 0.6951 - acc: 0.4868 - val_loss : 0.6934 - val_acc: 0.5093
    Epoch : 6 - loss : 0.6944 - acc: 0.5146 - val_loss : 0.6936 - val_acc: 0.4912
    Epoch : 7 - loss : 0.6947 - acc: 0.4924 - val_loss : 0.6935 - val_acc: 0.4907
    Epoch : 8 - loss : 0.6949 - acc: 0.4954 - val_loss : 0.6930 - val_acc: 0.5093
    Epoch : 9 - loss : 0.6945 - acc: 0.5010 - val_loss : 0.6966 - val_acc: 0.4921
    Epoch : 10 - loss : 0.6949 - acc: 0.4874 - val_loss : 0.6934 - val_acc: 0.5093
    Epoch : 11 - loss : 0.6941 - acc: 0.5056 - val_loss : 0.6971 - val_acc: 0.5084
    Epoch : 12 - loss : 0.6946 - acc: 0.5023 - val_loss : 0.6949 - val_acc: 0.4907
    Epoch : 13 - loss : 0.6945 - acc: 0.4954 - val_loss : 0.6933 - val_acc: 0.4916
    Epoch : 14 - loss : 0.6942 - acc: 0.5030 - val_loss : 0.6958 - val_acc: 0.4907
    Epoch : 15 - loss : 0.6935 - acc: 0.5126 - val_loss : 0.6965 - val_acc: 0.5079
    Epoch : 16 - loss : 0.6957 - acc: 0.4967 - val_loss : 0.6935 - val_acc: 0.4907
    Epoch : 17 - loss : 0.6941 - acc: 0.5023 - val_loss : 0.6932 - val_acc: 0.5088
    Epoch : 18 - loss : 0.6948 - acc: 0.4973 - val_loss : 0.6930 - val_acc: 0.5084
    Epoch : 19 - loss : 0.6936 - acc: 0.5053 - val_loss : 0.6957 - val_acc: 0.4912
    Epoch : 20 - loss : 0.6945 - acc: 0.4904 - val_loss : 0.6934 - val_acc: 0.5079
    Epoch : 21 - loss : 0.6943 - acc: 0.4940 - val_loss : 0.6931 - val_acc: 0.5088
    Epoch : 22 - loss : 0.6950 - acc: 0.4957 - val_loss : 0.6941 - val_acc: 0.5093
    Epoch : 23 - loss : 0.6942 - acc: 0.4930 - val_loss : 0.6937 - val_acc: 0.4912
    Epoch : 24 - loss : 0.6942 - acc: 0.4950 - val_loss : 0.6930 - val_acc: 0.5079
    Epoch : 25 - loss : 0.6947 - acc: 0.4957 - val_loss : 0.6930 - val_acc: 0.5079
    Epoch : 26 - loss : 0.6939 - acc: 0.4904 - val_loss : 0.6972 - val_acc: 0.5084
    Epoch : 27 - loss : 0.6941 - acc: 0.5070 - val_loss : 0.6930 - val_acc: 0.5088

    Could you let me know what changes are needed to achieve the 76.6% accuracy reported in the paper?

    opened by nabeel3133 2
  • Invalid number of patches for 32x224x224

    In vit_pytorch/vit_3D.py, the formula on line 90 that calculates the number of patches is invalid for a data sample of size 32x224x224. When the rearrange function is called on line 121, the result is a tensor of size [4, 3136, 512], which means the number of patches should be 3136. The formula on line 90 gives: num_patches = (image_size // patch_size) ** 2 * 2 = (224 // 8) ** 2 * 2 = 1568. You can change the formula to: num_patches = (image_size // patch_size) ** 2 * 4 = (224 // 8) ** 2 * 4 = 3136.

    opened by nabeel3133 0
  • Questions about preprocessing images

    Hello, thank you very much for sharing your code on GitHub. I am a college student who has just started learning medical image processing, and I have run into some problems that I hope you can help with. I added unpreprocessed images to the dataset and found that training did not work well; some images could not even be run. While reading your paper, I noticed that preprocessing the images took 12 hours. Could you provide all the preprocessed images, or describe how to preprocess them? Thanks!

    opened by whiteBAI-97 2
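
The pos_embedding mismatch and the 32x224x224 patch-count issue above appear to share one cause. In vit-pytorch's ViT the positional embedding has num_patches + 1 rows (one extra for the class token), so the checkpoint's 3137 rows correspond to 3136 patches and the current model's 1569 rows to 1568. The sketch below reproduces that arithmetic. It assumes the depth factor comes from splitting 32 slices into depth-8 patches (32 // 8 = 4); vit_3D.py itself is not shown here, so the helper names and the `frame_patch_size` parameter are hypothetical. `shape_mismatches` compares plain name-to-shape dictionaries.

```python
def num_patches_3d(image_size: int, patch_size: int, frames: int, frame_patch_size: int) -> int:
    """Count the patches when a frames x image_size x image_size volume is cut into
    frame_patch_size x patch_size x patch_size blocks."""
    if image_size % patch_size or frames % frame_patch_size:
        raise ValueError("volume dimensions must be divisible by the patch dimensions")
    return (image_size // patch_size) ** 2 * (frames // frame_patch_size)


def pos_embedding_rows(num_patches: int) -> int:
    """One row per patch plus one for the prepended class token."""
    return num_patches + 1


def shape_mismatches(ckpt_shapes: dict, model_shapes: dict) -> dict:
    """Map each parameter name present in both dicts to (checkpoint_shape, model_shape) where they differ."""
    shared = sorted(ckpt_shapes.keys() & model_shapes.keys())
    return {k: (ckpt_shapes[k], model_shapes[k]) for k in shared if ckpt_shapes[k] != model_shapes[k]}


hard_coded = (224 // 8) ** 2 * 2         # formula on line 90 of vit_3D.py, per the issue
general = num_patches_3d(224, 8, 32, 8)  # assumes 32 slices split into depth-8 patches
print(hard_coded, pos_embedding_rows(hard_coded))  # 1568 1569
print(general, pos_embedding_rows(general))        # 3136 3137

print(shape_mismatches(
    {"pos_embedding": (1, 3137, 1024)},  # shape reported for the checkpoint
    {"pos_embedding": (1, 1569, 1024)},  # shape reported for the current model
))
# {'pos_embedding': ((1, 3137, 1024), (1, 1569, 1024))}
```

With PyTorch, the shape dictionaries can be built with `{k: tuple(v.shape) for k, v in model.state_dict().items()}`, and the same comprehension over the loaded checkpoint dictionary. Under these assumptions, a model built with the factor of 4 would expect the same 3137-row table the checkpoint contains; this is an inference from the reported shapes, not something the repository confirms.
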
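
One way to read the training log in the "Unable to achieve 76.6% accuracy" issue: for two classes, a model that assigns probability 0.5 to every sample has a cross-entropy loss of -ln(0.5) = ln 2 ≈ 0.6931. The logged losses (0.6930 to 0.6972) sit right at that value, and the accuracies hover around 0.49 to 0.51, which suggests the network is producing near-uniform outputs rather than learning. The check below makes that comparison; the function names and the 0.01 tolerance are illustrative choices, and the check does not identify the underlying cause.

```python
import math


def cross_entropy(p_true: float) -> float:
    """Cross-entropy for one sample, given the probability assigned to its true class."""
    return -math.log(p_true)


def near_chance(loss: float, num_classes: int = 2, tol: float = 0.01) -> bool:
    """True if a mean cross-entropy loss is within tol of uniform guessing, ln(num_classes)."""
    return abs(loss - math.log(num_classes)) <= tol


# Uniform guessing over 2 classes gives p_true = 0.5 for every sample.
print(f"{cross_entropy(0.5):.4f}")  # 0.6931

# Final-epoch loss and val_loss from the log in the issue above.
print(near_chance(0.6941), near_chance(0.6930))  # True True
```
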
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Phil Wang 12.2k Nov 28, 2022
So-ViT: Mind Visual Tokens for Vision Transformer

So-ViT: Mind Visual Tokens for Vision Transformer        Introduction This repository contains the source code under PyTorch framework and models trai

Jiangtao Xie 45 Nov 22, 2022
A PyTorch Implementation of ViT (Vision Transformer)

ViT - Vision Transformer This is an implementation of ViT - Vision Transformer by Google Research Team through the paper "An Image is Worth 16x16 Word

Quan Nguyen 7 May 11, 2022
Official implement of Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer

Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer This repository contains the PyTorch code for Evo-ViT. This work proposes a slow-fas

YifanXu 52 Nov 21, 2022
Implementing Vision Transformer (ViT) in PyTorch

Lightning-Hydra-Template A clean and scalable template to kickstart your deep learning project ⚡ Click on Use this template to initialize new re

null 2 Dec 24, 2021
TorchXRayVision: A library of chest X-ray datasets and models.

torchxrayvision A library for chest X-ray datasets and models. Including pre-trained models. (promo video about the project) Motivation: While the

Machine Learning and Medicine Lab 552 Nov 23, 2022
This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

null 1 Dec 24, 2021
PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

MAE for Self-supervised ViT Introduction This is an unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-sup

null 36 Oct 30, 2022
As-ViT: Auto-scaling Vision Transformers without Training

As-ViT: Auto-scaling Vision Transformers without Training [PDF] Wuyang Chen, Wei Huang, Xianzhi Du, Xiaodan Song, Zhangyang Wang, Denny Zhou In ICLR 2

VITA 68 Sep 5, 2022
This project uses ViT to perform image classification tasks on the CIFAR10 dataset.

Vision-Transformer-Multiprocess-DistributedDataParallel-Apex Introduction This project uses ViT to perform image classification tasks on DATA set CIFA

Kaicheng Yang 3 Jun 3, 2022
vit for few-shot classification

Few-Shot ViT Requirements PyTorch (>= 1.9) TorchVision timm (latest) einops tqdm numpy scikit-learn scipy argparse tensorboardx Pretrained Checkpoints

Martin Dong 23 Nov 20, 2022
Alex Pashevich 61 Nov 17, 2022
An image dataset contains 490 images for training (400 cars and 90 boats) and another 21 images for testing

SVM Data An image dataset contains 490 images for training (400 cars and 90 boats) and another 21 images for testing. Preprocess

Achraf Rahouti 3 Nov 30, 2021
This is the official Pytorch implementation of "Lung Segmentation from Chest X-rays using Variational Data Imputation", Raghavendra Selvan et al. 2020

README This is the official Pytorch implementation of "Lung Segmentation from Chest X-rays using Variational Data Imputation", Raghavendra Selvan et a

Raghav 41 Nov 21, 2022
A simple REST API that classifies whether a pneumonia infection is Normal, Pneumonia Virus or Pneumonia Bacteria from a chest X-ray image.

This is a simple REST API that classifies whether a pneumonia infection is Normal, Pneumonia Virus or Pneumonia Bacteria from a chest X-ray image.

crispengari 3 Jan 8, 2022
Multi-Scale Vision Longformer: A New Vision Transformer for High-Resolution Image Encoding

Vision Longformer This project provides the source code for the vision longformer paper. Multi-Scale Vision Longformer: A New Vision Transformer for H

Microsoft 205 Nov 12, 2022
Implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT : Cross-Attention Multi-Scale Vision Transformer for Image Classification This is an unofficial PyTorch implementation of CrossViT: Cross-Att

Rishikesh (ऋषिकेश) 103 Nov 25, 2022
Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. ArXiv If

International Business Machines 163 Nov 13, 2022
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Swin Transformer 1.4k Nov 18, 2022