Explainability for Vision Transformers (in PyTorch)

Overview

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

Currently implemented:

  • Attention Rollout.

  • Gradient Attention Rollout for class specific explainability. This is our attempt to further build upon and improve Attention Rollout.

  • Attention Flow is work in progress (TBD).

Includes some tweaks and tricks to get it working:

  • Different Attention Head fusion methods.
  • Removing the lowest attentions.

Usage

  • From code:
import torch
from vit_grad_rollout import VITAttentionGradRollout

model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True)
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
mask = grad_rollout(input_tensor, category_index=243)
  • From the command line:
python vit_explain.py --image_path <path_to_image> --head_fusion <mean/min/max> --discard_ratio <ratio between 0 and 1> --category_index <category_index>

If category_index isn't specified, Attention Rollout will be used; otherwise Gradient Attention Rollout will be used.
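
For example (the image path below is just a placeholder for your own file):

# Class-specific Gradient Attention Rollout for ImageNet category 243
python vit_explain.py --image_path path/to/image.png --head_fusion max --discard_ratio 0.9 --category_index 243

# Plain Attention Rollout, since --category_index is omitted
python vit_explain.py --image_path path/to/image.png --head_fusion mean --discard_ratio 0.8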

Note that by default this uses the 'Tiny' model from "Training data-efficient image transformers & distillation through attention" (DeiT), hosted on torch hub.

Where did the Transformer pay attention in this image?

[Figure: input image, vanilla Attention Rollout, and Attention Rollout with discard_ratio + max fusion]

Gradient Attention Rollout for class specific explainability

The attention that flows through the transformer passes along information belonging to different classes. Attention Rollout lets us see which locations the network paid attention to, but it tells us nothing about whether it ended up using those locations for the final classification.

We can multiply the attention with the gradient of the target class output, and take the average among the attention heads (while masking out negative attentions) to keep only attention that contributes to the target category (or categories).
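
A minimal sketch of that computation (illustrative names only; in the repository, VITAttentionGradRollout collects the per-layer attentions and their gradients for you):

import torch

# Illustrative sketch of gradient-weighted attention rollout.
# `attentions` and `gradients` hold one tensor per transformer block, each of shape
# (batch, heads, tokens, tokens): the block's attention matrix and the gradient of the
# target category score with respect to it.
def grad_rollout(attentions, gradients):
    tokens = attentions[0].size(-1)
    identity = torch.eye(tokens)
    result = torch.eye(tokens)
    for attention, grad in zip(attentions, gradients):
        fused = (attention * grad).mean(dim=1)   # weight by the gradient, average the heads
        fused = fused.clamp(min=0)               # mask out negative attentions
        a = (fused + identity) / 2               # account for residual connections
        a = a / a.sum(dim=-1, keepdim=True)      # re-normalize every row
        result = torch.matmul(a, result)         # roll the attention out across layers
    # Attention flowing from the class token (assumed at index 0) to the image patches
    return result[:, 0, 1:]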

Where does the Transformer see a Dog (category 243), and a Cat (category 282)?

Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87)?

Tricks and Tweaks to get this working

Filtering the lowest attentions in every layer

--discard_ratio

Removes noise by keeping the strongest attentions.
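
Roughly, for a fused attention map of shape (batch, tokens, tokens), this could look like the sketch below (illustrative only, not the exact repository code):

import torch

# Rough sketch: zero out the weakest attention values, keeping only the strongest
# (1 - discard_ratio) fraction. The real implementation has some extra handling,
# e.g. around the class token.
def discard_lowest(fused, discard_ratio=0.9):
    flat = fused.reshape(fused.size(0), -1).clone()
    num_discard = int(flat.size(-1) * discard_ratio)
    _, indices = flat.topk(num_discard, dim=-1, largest=False)  # indices of the weakest entries
    flat.scatter_(-1, indices, 0)                               # set them to zero
    return flat.reshape_as(fused)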

Results for different values of --discard_ratio:

Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads, but empirically it looks like taking the minimum value, or the maximum value combined with --discard_ratio, works better.

--head_fusion

[Figure: input image compared with mean fusion and min fusion]
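
For reference, the fusion modes selected by --head_fusion can be sketched like this for an attention tensor of shape (batch, heads, tokens, tokens) (again illustrative, not the exact repository code):

# Sketch of the three head fusion strategies.
def fuse_heads(attention, head_fusion="max"):
    if head_fusion == "mean":
        return attention.mean(dim=1)        # average response across heads
    elif head_fusion == "max":
        return attention.max(dim=1).values  # strongest head response per position
    elif head_fusion == "min":
        return attention.min(dim=1).values  # most conservative head response
    else:
        raise ValueError(f"Unsupported head fusion: {head_fusion}")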

References

Requirements

pip install timm

Comments
  • Code for Google's ViT and complete example

    Hi @jacobgil!

    Thank you for this amazing piece of work. I was wondering if you plan to open-source the code to try out your experiments on Google's ViT (An Image is Worth ...) as well. If it's already there inside the repo, could you point me to it?

    Update: I was able to use timm and make use of the ViT model it comes with:

    import timm
    from vit_grad_rollout import VITAttentionGradRollout

    timm_vit_model = timm.create_model('vit_large_patch16_384', pretrained=True)
    timm_vit_model.eval()
    roller = VITAttentionGradRollout(timm_vit_model, discard_ratio=0.9)
    mask = roller(x.unsqueeze(0), label_idx)
    

    However, I am still a bit unsure as to how to actually visualize the mask. Could you help?

    opened by sayakpaul 2
  • No separate `head_fusion` strategies for Gradient Attention Rollout?

    Should there be a condition to allow users to pass the head_fusion method in grad_rollout()?

    Something like -

    ...
    for attention, grad in zip(attentions, gradients):
        weights = grad
        if head_fusion == "mean":
            attention_heads_fused = (attention * weights).mean(axis=1)
        elif head_fusion == "max":
            attention_heads_fused = (attention * weights).max(axis=1)[0]
        elif head_fusion == "min":
            attention_heads_fused = (attention * weights).min(axis=1)[0]
        else:
            raise ValueError("Attention head fusion type not supported")

        attention_heads_fused[attention_heads_fused < 0] = 0
    ...
    
    opened by sayakpaul 1
  • What do we do when the image shape is (256, 128)?

    How do we change the code when the input image is reshaped to, for example, (256, 128)? It seems that the code only supports square images.

    opened by Songyihu 0
  • Error in tensor size mismatch

    Hi @jacobgil, I am using this project with my Swin Transformer, but it gives the following error:

         35             # print("a : ",a)
         36             # print(a.size())
    ---> 37             a = a / a.sum(dim=-1)
         38 
         39             result = torch.matmul(a, result)
    
    RuntimeError: The size of tensor a (49) must match the size of tensor b (64) at non-singleton dimension 1
    
    Please suggest a solution for the same.
    opened by SURAJ28092001 1
  • cv2 is bgr whereas matplotlib is rgb

    There's a tiny mistake in the function show_mask_on_image. The heatmap you get from cv2 is in BGR format, so you need to convert it to RGB before adding it to the image:

    def show_mask_on_image(img, mask):
        ...
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        cam = heatmap + np.float32(img)
        ...
    
    opened by samrere 0
  • Shape mismatch for timm vit model

    While applying attention rollout on a fine-tuned timm ViT model (base_patch32_224), I am getting the following error with an input tensor of shape torch.Size([1, 3, 224, 224]):

    RuntimeError                              Traceback (most recent call last)
    in ()
    ----> 1 mask_1 = attention_rollout(test_image_1_tensor)

    8 frames
    in reshape_transform(tensor, height, width)
          1 def reshape_transform(tensor, height=7, width=7):
          2     result = tensor[:, 1 : , :].reshape(tensor.size(0),
    ----> 3         height, width, tensor.size(2))
          4
          5     # Bring the channels to the first dimension,

    RuntimeError: shape '[1, 7, 7, 7]' is invalid for input of size 37583

    Kindly advise on how to properly apply this to the model, as I'm facing the same issue with FullGrad in pytorch-grad-cam on the same model.

    opened by Hammad-Mir 0
  • Evaluating Explanation

    Hello @jacobgil ,

    Do you have any recommendations on evaluating explainability methods using tools like Quantus (https://arxiv.org/pdf/2202.06861.pdf)?

    Best Regards, @jaiswati

    opened by jaiswati 0
  • normalize (sum to 1) attention score seems not right

    Hi, thanks for sharing this nice work.

    I noticed that you normalize the attention scores (each row sums to 1) as described in the original Attention Rollout paper.

    I = torch.eye(attention_heads_fused.size(-1))
    a = (attention_heads_fused + 1.0*I)/2
    a = a / a.sum(dim=-1)
    

    But it seems that when dividing by the row sums, keepdim=True should be applied to ensure that each row actually sums to 1 after normalization.

    a = a / a.sum(dim=-1,keepdim=True)
    

    Maybe I'm wrong; please double-check this issue. Thanks.

    opened by jihwanp 0