PyTorch implementation of Interpretable Explanations of Black Boxes by Meaningful Perturbation

Overview

The paper: https://arxiv.org/abs/1704.03296

What makes the deep learning network think the image label is 'pug, pug-dog' and 'tabby, tabby cat':

Dog Cat

A perturbation of the dog that caused the dog category score to vanish:

Perturbed

What makes the deep learning network think the image label is 'flute, transverse flute':

Flute


Usage: python explain.py <path_to_image>

This is a PyTorch implementation of "Interpretable Explanations of Black Boxes by Meaningful Perturbation" by Ruth Fong and Andrea Vedaldi, with some deviations.

This uses VGG19 from torchvision. It will be downloaded when used for the first time.

This learns a mask of pixels that explains the result of a black box. The mask is learned by posing an optimization problem and solving directly for the mask values.

This is different from other visualization techniques like Grad-CAM, which use heuristics such as high positive gradient values as an indication of relevance to the network score.

In our case the black box is the VGG19 model, but this can use any differentiable model.


How it works

Equation

Taken from the paper https://arxiv.org/abs/1704.03296

The goal is to solve for a mask that explains why the network output a given score for a certain category.

We create a low resolution (28x28) mask, and use it to perturb the input image to a deep learning network.

The perturbation combines a blurred version of the image, the regular image, and the up-sampled mask.

Wherever the mask contains low values, the input image will become more blurry.
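
The blending step can be sketched in NumPy as follows. This is a minimal illustration, not the repository's code: the helper names `upsample_nearest` and `perturb` are hypothetical, nearest-neighbor up-sampling stands in for whatever interpolation the real script uses, and a constant image stands in for a real blur. It follows the convention stated above, where low mask values reveal the blurred image.

```python
import numpy as np

def upsample_nearest(mask, factor):
    """Repeat each low-resolution mask cell `factor` times along both axes."""
    return np.repeat(np.repeat(mask, factor, axis=0), factor, axis=1)

def perturb(image, blurred, low_res_mask):
    """Blend so that mask == 1 keeps the image and mask == 0 shows the blur."""
    factor = image.shape[0] // low_res_mask.shape[0]
    mask = upsample_nearest(low_res_mask, factor)[..., None]  # add channel axis
    return mask * image + (1.0 - mask) * blurred

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3)).astype(np.float32)
blurred = np.full_like(image, image.mean())   # crude stand-in for a real blur
low_res_mask = np.ones((28, 28), dtype=np.float32)
low_res_mask[10:18, 10:18] = 0.0              # "delete" a central patch

perturbed = perturb(image, blurred, low_res_mask)
```

With a 28x28 mask and a 224x224 image, each mask cell covers an 8x8 pixel block, so the zeroed cells replace pixels 80 to 143 in both directions with the blurred values and leave the rest of the image untouched.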

We want to optimize for the following properties:

  1. When using the mask to blend the input image and its blurred version, the score of the target category should drop significantly. The evidence of the category should be removed!
  2. The mask should be sparse. Ideally the mask should be the minimal possible mask that drops the category score. This translates to an L1(1 - mask) term in the cost function.
  3. The mask should be smooth. This translates to a total variation regularization in the cost function.
  4. The mask shouldn't over-fit the network. Since the network activations might contain a lot of noise, it can be easy for the mask to just learn random values that cause the score to drop without being visually coherent. In addition to the other terms, this translates to solving for a lower resolution 28x28 mask.
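
The four properties above can be combined into a single loss, as in the toy sketch below. Everything here is an assumption made for illustration: `black_box` is a hand-written differentiable score rather than VGG19, the image is 16x16 with a 4x4 mask instead of 28x28, a constant image stands in for the blur, and the coefficients `l1_coeff`, `tv_coeff`, the exponent `beta`, and the Adam settings are illustrative values, not necessarily those used by `explain.py`.

```python
import torch
import torch.nn.functional as F

def black_box(x):
    """Toy differentiable 'classifier': score rises with top-left brightness."""
    return torch.sigmoid(8.0 * (x[..., :8, :8].mean() - 0.5))

def tv_norm(mask, beta=3.0):
    """Total variation: penalize differences between neighboring mask cells."""
    m = mask[0, 0]
    rows = (m[1:, :] - m[:-1, :]).abs().pow(beta).mean()
    cols = (m[:, 1:] - m[:, :-1]).abs().pow(beta).mean()
    return rows + cols

def blend(mask, image, blurred):
    """Up-sample the low-resolution mask and mix the image with its blur."""
    up = F.interpolate(mask, scale_factor=4, mode="nearest")
    return up * image + (1 - up) * blurred

image = torch.zeros(1, 1, 16, 16)
image[..., :8, :8] = 1.0                                # the "evidence"
blurred = torch.full_like(image, image.mean().item())   # stand-in for a blur

mask = torch.ones(1, 1, 4, 4, requires_grad=True)       # low-resolution mask
optimizer = torch.optim.Adam([mask], lr=0.1)
l1_coeff, tv_coeff = 0.01, 0.2                          # illustrative weights

initial_score = black_box(image).item()
for _ in range(100):
    loss = (black_box(blend(mask, image, blurred))      # 1. drop the score
            + l1_coeff * (1 - mask).abs().mean()        # 2. sparse deletion
            + tv_coeff * tv_norm(mask))                 # 3. smooth mask
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        mask.clamp_(0, 1)                               # keep values in [0, 1]

with torch.no_grad():
    final_score = black_box(blend(mask, image, blurred)).item()
print(initial_score, final_score)
```

Property 4 shows up only implicitly: the mask has 16 free values instead of 256, so it cannot fit pixel-level noise.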

Deviations from the paper

The paper uses a Gaussian kernel whose sigma is modulated by the value of the mask. This is computationally costly, since the mask values are updated at every iteration, meaning we would need a different kernel for every mask pixel at every iteration.

Initially I tried approximating this by first filtering the image with a bank of Gaussian kernels of varying sigma. Then, during optimization, each input image pixel would use its quantized mask value to select the appropriate filter bank output pixel (high mask value -> lower channel).

This was done using the PyTorch gather/index_select functions. But it turns out that gather and index_select in PyTorch are not differentiable with respect to the indices.
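
The limitation can be seen in a small experiment. The sketch below uses a made-up 3x4 "filter bank" and index tensor (both hypothetical), and shows that `torch.gather` passes gradients to the selected values while the integer index tensor cannot carry gradients at all.

```python
import torch

# Hypothetical filter bank: 3 blur levels (rows) for 4 pixels (columns).
bank = torch.arange(12, dtype=torch.float32).reshape(3, 4).requires_grad_()
# Quantized mask values choose one blur level per pixel.
levels = torch.tensor([[0, 2, 1, 0]])

picked = torch.gather(bank, 0, levels)   # shape (1, 4)
picked.sum().backward()

# Gradients reach the selected bank entries...
print(bank.grad)

# ...but an integer index tensor cannot require gradients.
try:
    levels.requires_grad_()
    index_can_require_grad = True
except RuntimeError:
    index_can_require_grad = False
print(index_can_require_grad)
```

Because the selection itself is a discrete choice, no gradient can flow back from the score to the mask values that produced the indices, which is why this approach could not be used to learn the mask.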

Instead, we compute a perturbed image once, and then blend the image and the perturbed image using the mask, so that low mask values reveal the perturbed image:

input_image = mask * image + (1 - mask) * perturbed_image

And it works well in practice.

The perturbed image here is the average of the Gaussian-blurred and median-blurred images, but this can be changed to many other combinations (try it out and find something better!).

Also, Gaussian noise with a sigma of 0.2 is now added to the preprocessed image at each iteration, inspired by Google's SmoothGrad.
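
A minimal NumPy sketch of the noise step, assuming a zero-filled placeholder for the preprocessed image and a hypothetical helper `add_noise`. The point is that a fresh noise sample is drawn at every iteration and the clean input is left unmodified.

```python
import numpy as np

rng = np.random.default_rng(0)
preprocessed = np.zeros((3, 224, 224), dtype=np.float32)  # placeholder input

def add_noise(x, sigma=0.2):
    """Return a copy of x with fresh zero-mean Gaussian noise added."""
    return x + rng.normal(0.0, sigma, size=x.shape).astype(x.dtype)

for step in range(3):              # one fresh noise sample per iteration
    noisy = add_noise(preprocessed)
    # ... the noisy input would be blended with the mask and scored here ...
```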

Comments
  • Adding license?

    Thank you for this excellent contribution! However, it is not clear what your views are on using and modifying this code. Would you mind adding a license?

    opened by hungaroring 2
  • Interpretation of the Heatmap

    Hey Jacob,

    Thanks for this great and easy-to-use implementation of the meaningful perturbation algorithm.

    I just have a small question. Could you please help me understand the interpretation of the heatmap. As per my understanding, the heatmap is the same thing as the mask (values between 0 and 1) representing the minimum region in the image that you would suppress to reduce the prediction score maximally.

    My understanding that the mask is same as the heatmap is based on the following lines from your code:

    mask = 1 - mask
    heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    

    Could you please confirm whether my understanding is correct or is it something else?

    Thanks, Naman

    opened by bnaman50 1
  • fix for issue #1

    I fixed issue #1. The cause is that cv2.randn() doesn't return an array; it fills the array passed as its argument with random values. https://docs.opencv.org/2.4/modules/core/doc/operations_on_arrays.html#randn

    I hope this is of some help.

    opened by bonzeboy 1
  • Missing reference for operation in the code

    Hi, thank you for your work! Playing with the code, I noticed that adding noise has a lot of impact on the result, and I was wondering if you could add the reference for this technique in the description. "Also now gaussian noise with a sigma of 0.2 is added to the preprocssed image at each iteration, inspired by google's SmoothGradient." I found only this in your description, but the name is too generic to find what you were referring to. Could you point me to the paper/article? Thanks!

    opened by godimarcovr 1
  • Run Error

    Dear @jacobgil ,

    Thanks for your great work! However, when I run python explain.py examples/flute.jpg, something goes wrong:

    $ python explain.py examples/flute.jpg 
    Category with highest probability 558
    Optimizing.. 
    /usr/local/lib/python2.7/dist-packages/torch/nn/modules/upsampling.py:183: UserWarning: nn.UpsamplingBilinear2d is deprecated. Use nn.Upsample instead.
      warnings.warn("nn.UpsamplingBilinear2d is deprecated. Use nn.Upsample instead.")
    Traceback (most recent call last):
      File "explain.py", line 132, in <module>
        noise = noise + cv2.randn(noise, 0, 0.2)
    TypeError: unsupported operand type(s) for +: 'float' and 'NoneType'
    

    Maybe it is caused by some Python packages having different versions from the ones you used, I guess. Do you know how to fix it?

    Thanks! Kun

    opened by wk910930 1
  • blur_img

    blurred_img1 = cv2.GaussianBlur(img, (11, 11), 5)
    blurred_img2 = np.float32(cv2.medianBlur(original_img, 11))/255
    blurred_img_numpy = (blurred_img1 + blurred_img2) / 2
    

    I wonder why this method was chosen?

    opened by Yung-zi 0
  • Some questions about the implementation

    Thank you for this great and straightforward implementation of this technique! I was going through the code to understand a bit more of how it works, and I'm curious about some parts:

    There seem to be two blurred images:

    blurred_img1 = cv2.GaussianBlur(img, (11, 11), 5)
    blurred_img2 = np.float32(cv2.medianBlur(original_img, 11))/255
    blurred_img_numpy = (blurred_img1 + blurred_img2) / 2
    

    Further ahead, only blurred_img2 is used for training the mask:

    img = preprocess_image(img)
    blurred_img = preprocess_image(blurred_img2)
    mask = numpy_to_torch(mask_init)
    

    But for heatmap generation we use the average of the two blurs: save(upsampled_mask, original_img, blurred_img_numpy)

    I was a bit confused by this. Is there a reason why it is done in this way?

    opened by palatos 0
  • How is \tau in Eq(4) reflected in the implementation?

    Very nice work! However, I was wondering how E_{\tau} in the equation is reflected in the implementation. It seems that only one (unshifted) image is used in the code?

    opened by zechenghe 0
Owner
Jacob Gildenblat
Machine learning / Computer Vision.