A PyTorch implementation of "CoAtNet: Marrying Convolution and Attention for All Data Sizes".

CoAtNet

Overview

This is a PyTorch implementation of CoAtNet, as described in "CoAtNet: Marrying Convolution and Attention for All Data Sizes" (arXiv, 2021).

👉 Check out MobileViT if you are interested in other Convolution + Transformer models.

Usage

import torch
from coatnet import coatnet_0

img = torch.randn(1, 3, 224, 224)
net = coatnet_0()
out = net(img)

Try out other block combinations mentioned in the paper:

import torch
from coatnet import CoAtNet

img = torch.randn(1, 3, 224, 224)

num_blocks = [2, 2, 3, 5, 2]            # L
channels = [64, 96, 192, 384, 768]      # D
block_types = ['C', 'T', 'T', 'T']      # 'C' for MBConv, 'T' for Transformer

net = CoAtNet((224, 224), 3, num_blocks, channels, block_types=block_types)
out = net(img)
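
The first argument to CoAtNet is the input resolution, and a model built for one resolution may not accept another. An issue reported against this repository describes a 512×512 image failing in the attention block of a model built for 224×224, with a size mismatch of 1024 versus 196. The sketch below reproduces those two numbers under the assumption (not stated in this README) that the first Transformer stage runs at 1/16 of the input resolution. `attention_tokens` is a hypothetical helper for illustration, not part of `coatnet`.

```python
def attention_tokens(image_size, stride=16):
    """Number of spatial positions an attention stage sees at a given stride.

    stride=16 is an assumption inferred from the sizes in the error message.
    """
    height, width = image_size
    return (height // stride) * (width // stride)


built_for = attention_tokens((224, 224))  # positions the model was built for
fed_with = attention_tokens((512, 512))   # positions a 512x512 input produces
print(built_for, fed_with)  # prints: 196 1024
```

If that assumption holds, constructing the network with the size you intend to feed, for example `CoAtNet((512, 512), 3, num_blocks, channels, block_types=block_types)`, should keep the two sizes in agreement.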

Citation

@article{dai2021coatnet,
  title={CoAtNet: Marrying Convolution and Attention for All Data Sizes},
  author={Dai, Zihang and Liu, Hanxiao and Le, Quoc V and Tan, Mingxing},
  journal={arXiv preprint arXiv:2106.04803},
  year={2021}
}

Credits

Code adapted from MobileNetV2 and ViT.

Comments
  • An error occurs when an image of 512 size is given as input.

    Hello. I really appreciate your project.

    However, the following error occurs in the Attention class when a 512-size image is given as input:

    dots = dots + relative_bias
    RuntimeError: The size of tensor a (1024) must match the size of tensor b (196) at non-singleton dimension 3.

    https://github.com/chinhsuanwu/coatnet-pytorch/blob/d3ef1c3e4d6dfcc0b5f731e46774885686062452/coatnet.py#L155

    Why does this error occur? How should I edit the code if I want to use a different image size?

    Thank you!

    opened by jeongHwarr 2
  • About the stochastic depth

    Hi, I can't find the code for stochastic depth in your implementation.

    I added stochastic depth myself and trained a CoAtNet-Tiny on ImageNet-1k, but got 79.27% top-1 accuracy.

    Have you reproduced the results reported in the paper?

    opened by JiaquanYe 2
  • About the attention model

    Hi, I noticed that the value of self.relative_bias_table is always all 0. In that case, is the following lookup actually meaningless, since its result is all 0 as well?

    relative_bias = self.relative_bias_table.gather(0, self.relative_index.repeat(1, self.heads))

    Thanks!

    opened by mlxu995 2
  • About the # params

    Hey, I tried your implementation and found that the calculated number of parameters differs somewhat from the paper. I am curious about the reason; could you please help me out?

    Take coatnet_0 for example: the calculated result is 17789624 (17789624 / 2^20 = 16.97), whereas the paper reports 25M parameters.

    Thanks in advance

    opened by HeYDwane3 1
  • About the Wi-j

    Thanks for sharing. We want to confirm whether relative_coords is a learnable parameter or a constant in CoAtNet.

            coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
            coords = torch.flatten(torch.stack(coords), 1)
    
            relative_coords = coords[:, :, None] - coords[:, None, :]
            relative_coords[0] += self.ih - 1
            relative_coords[1] += self.iw - 1
            relative_coords[0] *= 2 * self.iw - 1
            relative_coords = rearrange(relative_coords, 'c h w -> h w c')
            relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
            self.register_buffer("relative_index", relative_index)
    
    opened by RioLLee 0
  • Models seem not converging.

    Hi,

    I tried to train CoAtNet_0 on Tiny ImageNet from cs231n (200 classes). The model does not seem to converge.

    Could it be that the implementation is not 100% correct, for example in the positional embedding indexing part? I went through the code and I think the other components should be correct.

    The positional embedding indexing is the one part I do not fully understand. Do you have a reference for how it is implemented?

    opened by bsun0802 1
  • Hello, first off, really appreciate your work!

    Unfortunately, I'm overfitting on a custom dataset even with coatnet_0. Is there a workaround? Thank you very much!

    opened by learningelectric 0
  • Any pretrained network for fine-tuning?

    Hello,

    I'm training the model from scratch on my custom dataset and the convergence is very slow. So, can you share any pretrained networks if possible?

    opened by yashnsn 2
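
On the question raised in the "About the Wi-j" thread: in the snippet quoted there, relative_index is derived only from torch.arange and stored with register_buffer, which suggests a fixed lookup index rather than a trainable weight. The plain-Python sketch below restates the same arithmetic without PyTorch to show what the index contains. The standalone function `relative_index` and the 2×2 grid are illustrative assumptions, not code from the repository.

```python
def relative_index(ih, iw):
    """Map every (query, key) position pair on an ih x iw grid to one integer.

    Mirrors the quoted snippet: shift each offset to be non-negative, then
    combine (dy, dx) into a single row number of a bias table.
    """
    coords = [(y, x) for y in range(ih) for x in range(iw)]  # row-major grid
    index = []
    for yi, xi in coords:
        row = []
        for yj, xj in coords:
            dy = yi - yj + ih - 1          # in [0, 2*ih - 2]
            dx = xi - xj + iw - 1          # in [0, 2*iw - 2]
            row.append(dy * (2 * iw - 1) + dx)
        index.append(row)
    return index


for row in relative_index(2, 2):
    print(row)
# prints:
# [4, 3, 1, 0]
# [5, 4, 2, 1]
# [7, 6, 4, 3]
# [8, 7, 5, 4]
```

Every entry lies between 0 and (2*ih - 1) * (2*iw - 1) - 1, so the index can address a table with one row per possible (dy, dx) offset. The self.relative_bias_table mentioned in the attention thread is what this index is used to gather from.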