Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch

Overview

Transformer in Transformer

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch.

Install

$ pip install transformer-in-transformer

Usage

import torch
from transformer_in_transformer import TNT

tnt = TNT(
    image_size = 256,       # size of image
    patch_dim = 512,        # dimension of patch token
    pixel_dim = 24,         # dimension of pixel token
    patch_size = 16,        # patch size
    pixel_size = 4,         # pixel size
    depth = 6,              # depth
    num_classes = 1000,     # output number of classes
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1        # feedforward dropout
)

img = torch.randn(2, 3, 256, 256)
logits = tnt(img) # (2, 1000)

Citations

@misc{han2021transformer,
    title   = {Transformer in Transformer}, 
    author  = {Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
    year    = {2021},
    eprint  = {2103.00112},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Only works if pixel_size**2 == patch_size?

    Only works if pixel_size**2 == patch_size?

    Hi, is this only supposed to work if

    pixel_size**2 == patch_size 
    

    ?. When setting the patch_size to any number that doesn't fulfill the equation this error occurs:

    --> 146         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        147 
        148         for pixel_attn, pixel_ff, pixel_to_patch_residual, patch_attn, patch_ff in self.layers:
    
    RuntimeError: The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 1
    

    The error came when running

    tnt = TNT(
        image_size = 128,       # size of image
        patch_dim = 256,        # dimension of patch token
        pixel_dim = 24,         # dimension of pixel token
        patch_size = 16,        # patch size
        pixel_size = 2,         # pixel size
        depth = 6,              # depth
        heads = 1,
        num_classes = 2,     # output number of classes
        attn_dropout = 0.1,     # attention dropout
        ff_dropout = 0.1        # feedforward dropout,
    )
    img = torch.randn(2, 3, 128, 128)
    logits = tnt(img)
    

    Since I am completely new to einops its quite hard for me to debug :D Thanks

    opened by PhilippMarquardt 1
  • Not sure what is wrong!

    Not sure what is wrong!


    RuntimeError Traceback (most recent call last) in 14 15 img = torch.randn(1, 3, 256, 256) ---> 16 logits = tnt(img) # (2, 1000)

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1109 or _global_forward_hooks or _global_forward_pre_hooks): -> 1110 return forward_call(*input, **kwargs) 1111 # Do not call functions when jit is used 1112 full_backward_hooks, non_full_backward_hooks = [], []

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/transformer_in_transformer/tnt.py in forward(self, x) 159 patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b) 160 --> 161 patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d') 162 pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d') 163

    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

    opened by RisabBiswas 0
  • patch_tokens vs patch_pos_emb

    patch_tokens vs patch_pos_emb

    Hi!

    I'm trying to understand your TNT implementation and one thing that got me a bit confused is why there are 2 parameters patch_tokens and patch_pos_emb which seems to have the same purpose - to encode patch position. Isn't one of them redundant?

    self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    ...
    patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
    patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
    
    opened by stas-sl 0
  • Inconsistent model  params with MindSpore src code

    Inconsistent model params with MindSpore src code

    There's no function or readme description of TNT-S/TNT-B model in this codebase. Something like :

    def tnt_b(num_class):
        return TNT(img_size=384,
                   patch_size=16,
                   num_channels=3,
                   embedding_dim=640,
                   num_heads=10,
                   num_layers=12,
                   hidden_dim=640*4,
                   stride=4,
                   num_class=num_class)
    

    And heads number of inner block should be 4.... https://github.com/lucidrains/transformer-in-transformer/blob/main/transformer_in_transformer/tnt.py#L135

    Wondering if anyone reproduce the paper reported results with this codebase??

    opened by WongChen 0
  • Why the loss become NaN?

    Why the loss become NaN?

    It is a great project. I am very interested in Transformer in Transformer model. I had use your model to train on Vehicle-1M dataset. Vehicle-1M is a fine graied visual classification dataset. When I use this model the loss become NaN after some batch iteration. I had decrease the learning rate of AdamOptimizer and clipping the graident torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2) . But the loss still will become NaN sometimes. It seems that gradients are not big but they are in the same direction for many iterations. How to solve it?

    opened by yt7589 3
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
U-2-Net: U Square Net - Modified for paired image training of style transfer

U2-Net: U Square Net Modified for paired image training of style transfer This is an unofficial repo making use of the code which was made available b

Doron Adler 43 Oct 3, 2022
PAIRED in PyTorch 🔥

PAIRED This codebase provides a PyTorch implementation of Protagonist Antagonist Induced Regret Environment Design (PAIRED), which was first introduce

UCL DARK Lab 46 Dec 12, 2022
FPGA: Fast Patch-Free Global Learning Framework for Fully End-to-End Hyperspectral Image Classification

FPGA & FreeNet Fast Patch-Free Global Learning Framework for Fully End-to-End Hyperspectral Image Classification by Zhuo Zheng, Yanfei Zhong, Ailong M

Zhuo Zheng 92 Jan 3, 2023
Code for paper PairRE: Knowledge Graph Embeddings via Paired Relation Vectors.

PairRE Code for paper PairRE: Knowledge Graph Embeddings via Paired Relation Vectors. This implementation of PairRE for Open Graph Benchmak datasets (

Alipay 65 Dec 19, 2022
LLVIP: A Visible-infrared Paired Dataset for Low-light Vision

LLVIP: A Visible-infrared Paired Dataset for Low-light Vision Project | Arxiv | Abstract It is very challenging for various visual tasks such as image

CVSM Group -  email: czhu@bupt.edu.cn 377 Jan 7, 2023
Pytorch Implementation for NeurIPS (oral) paper: Pixel Level Cycle Association: A New Perspective for Domain Adaptive Semantic Segmentation

Pixel-Level Cycle Association This is the Pytorch implementation of our NeurIPS 2020 Oral paper Pixel-Level Cycle Association: A New Perspective for D

null 87 Oct 19, 2022
Simple-Image-Classification - Simple Image Classification Code (PyTorch)

Simple-Image-Classification Simple Image Classification Code (PyTorch) Yechan Kim This repository contains: Python3 / Pytorch code for multi-class ima

Yechan Kim 8 Oct 29, 2022
The implementation of ICASSP 2020 paper "Pixel-level self-paced learning for super-resolution"

Pixel-level Self-Paced Learning for Super-Resolution This is an official implementaion of the paper Pixel-level Self-Paced Learning for Super-Resoluti

Elon Lin 41 Dec 15, 2022
Implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT : Cross-Attention Multi-Scale Vision Transformer for Image Classification This is an unofficial PyTorch implementation of CrossViT: Cross-Att

Rishikesh (ऋषिकेश) 103 Nov 25, 2022
Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. ArXiv If

International Business Machines 168 Dec 29, 2022
This is an official implementation of "Polarized Self-Attention: Towards High-quality Pixel-wise Regression"

Polarized Self-Attention: Towards High-quality Pixel-wise Regression This is an official implementation of: Huajun Liu, Fuqiang Liu, Xinyi Fan and Don

DeLightCMU 212 Jan 8, 2023
HyperSeg: Patch-wise Hypernetwork for Real-time Semantic Segmentation Official PyTorch Implementation

: We present a novel, real-time, semantic segmentation network in which the encoder both encodes and generates the parameters (weights) of the decoder. Furthermore, to allow maximal adaptivity, the weights at each decoder block vary spatially. For this purpose, we design a new type of hypernetwork, composed of a nested U-Net for drawing higher level context features

Yuval Nirkin 182 Dec 14, 2022
PyTorch implementation of adversarial patch

adversarial-patch PyTorch implementation of adversarial patch This is an implementation of the Adversarial Patch paper. Not official and likely to hav

Jamie Hayes 172 Nov 29, 2022
Per-Pixel Classification is Not All You Need for Semantic Segmentation

MaskFormer: Per-Pixel Classification is Not All You Need for Semantic Segmentation Bowen Cheng, Alexander G. Schwing, Alexander Kirillov [arXiv] [Proj

Facebook Research 1k Jan 8, 2023
Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning, CVPR 2021

Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning By Zhenda Xie*, Yutong Lin*, Zheng Zhang, Yue Ca

Zhenda Xie 293 Dec 20, 2022
Patch2Pix: Epipolar-Guided Pixel-Level Correspondences [CVPR2021]

Patch2Pix for Accurate Image Correspondence Estimation This repository contains the Pytorch implementation of our paper accepted at CVPR2021: Patch2Pi

Qunjie Zhou 199 Nov 29, 2022
Pixel-level Crack Detection From Images Of Levee Systems : A Comparative Study

PIXEL-LEVEL CRACK DETECTION FROM IMAGES OF LEVEE SYSTEMS : A COMPARATIVE STUDY G

Manisha Panta 2 Jul 23, 2022
DPT: Deformable Patch-based Transformer for Visual Recognition (ACM MM2021)

DPT This repo is the official implementation of DPT: Deformable Patch-based Transformer for Visual Recognition (ACM MM2021). We provide code and model

CASIA-IVA-Lab 111 Dec 21, 2022
Image Classification - A research on image classification and auto insurance claim prediction, a systematic experiments on modeling techniques and approaches

A research on image classification and auto insurance claim prediction, a systematic experiments on modeling techniques and approaches

null 0 Jan 23, 2022