Layer-wise Relevance Propagation (LRP) in PyTorch
A basic, unsupervised implementation of Layer-wise Relevance Propagation (Bach et al., Montavon et al.) in PyTorch for VGG networks from PyTorch's Model Zoo. This tutorial served as a starting point. The code aims to be easy to understand and easy to extend to other network architectures.
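For readers new to LRP: the generic propagation rule redistributes a layer's output relevance back to its inputs in proportion to their contributions a_j * w_jk. The following single-layer sketch is purely illustrative and not part of this repository's API:

```python
import torch

def lrp_linear(a, w, r_out, eps=1e-9):
    """Generic LRP rule for one linear layer (illustrative sketch).

    a:     input activations, shape (in_features,)
    w:     weight matrix, shape (in_features, out_features)
    r_out: relevance of the layer's outputs, shape (out_features,)
    """
    z = a @ w              # forward pre-activations, shape (out_features,)
    s = r_out / (z + eps)  # relevance normalized by total contributions
    c = w @ s              # redistribute back to the inputs
    return a * c           # relevance of the inputs

# Tiny example: with an identity weight matrix, relevance passes straight through.
a = torch.tensor([1.0, 2.0])
w = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
r_out = torch.tensor([0.5, 0.5])
r_in = lrp_linear(a, w, r_out)
```

Note that the rule is conservative: the total relevance entering the layer equals the total relevance leaving it.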
I also added a novel relevance propagation filter to this implementation, resulting in much crisper heatmaps (see my blog for more information). If you use it, please don't forget to cite this implementation.
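The exact filter is described on the blog. Purely as an illustration of the general idea, a filter of this kind could keep only the largest relevance scores per sample and zero out the rest (the function name and the 50% value below are my assumptions, not this repository's code):

```python
import torch

def relevance_filter(r: torch.Tensor, top_k_percent: float) -> torch.Tensor:
    """Keep only the largest relevance scores per sample, zeroing the rest.

    Illustrative sketch only; see the blog post for the actual filter.
    r: relevance tensor of shape (batch, ...).
    """
    size = r.size()
    r_flat = r.flatten(start_dim=1)
    k = max(1, int(top_k_percent * r_flat.size(-1)))
    top_k = torch.topk(r_flat, k=k, dim=-1)
    out = torch.zeros_like(r_flat)
    out.scatter_(dim=-1, index=top_k.indices, src=top_k.values)
    return out.view(size)

r = torch.tensor([[0.1, 0.9, 0.05, 0.3]])
r_filtered = relevance_filter(r, top_k_percent=0.5)  # keeps the two largest scores
```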
This implementation is already reasonably fast and therefore also suitable for projects that use LRP in real time. Using an RTX 2080 Ti graphics card, I reach 53 FPS with the VGG-16 network.
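To reproduce such a throughput measurement on your own hardware, a generic timing loop like the one below can be used (the helper name and iteration counts are my own choices; pass `LRPModel(vgg16)` as `fn` for the actual measurement):

```python
import time
import torch

def measure_fps(fn, x, n_warmup=10, n_iters=100):
    """Time repeated forward passes of fn(x) and report frames per second."""
    for _ in range(n_warmup):  # warm-up iterations (not timed)
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work before timing
    t0 = time.time()
    for _ in range(n_iters):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return n_iters / (time.time() - t0)

# Small stand-in model so the snippet runs anywhere; swap in LRPModel(vgg16).
model = torch.nn.Conv2d(3, 8, kernel_size=3)
x = torch.rand(1, 3, 64, 64)
fps = measure_fps(model, x, n_warmup=2, n_iters=10)
```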
If I find the time, I will provide a more model-agnostic implementation. I also welcome pull requests improving this implementation.
You can find more information about this implementation on my blog.
To run
Running LRP for a VGG-like network is fairly straightforward:
import torch
import torchvision
from src.lrp import LRPModel
x = torch.rand(size=(1, 3, 224, 224))  # dummy input batch
model = torchvision.models.vgg16(pretrained=True)
lrp_model = LRPModel(model)
r = lrp_model.forward(x)  # relevance scores for the input
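The shape of the returned relevance tensor depends on the model; assuming it matches the input shape, a common way to turn it into a 2D heatmap is to sum over the color channels and normalize (shown here with a stand-in tensor):

```python
import torch

# Stand-in for the relevance returned by lrp_model.forward(x); that the
# output has the input's shape (1, 3, 224, 224) is an assumption.
r = torch.rand(1, 3, 224, 224)

# Sum relevance over the color channels and normalize to [0, 1] for display.
heatmap = r.sum(dim=1).squeeze(0)
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-9)
# heatmap can now be shown with e.g. matplotlib: plt.imshow(heatmap, cmap="hot")
```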
Examples
The examples show heatmaps computed with the z-plus rule, without and with the additional relevance filter.
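For reference, the z-plus rule redistributes relevance using only the positive part of the weights. A minimal single-layer sketch (illustrative only, not this repository's exact code):

```python
import torch

def lrp_z_plus(a, w, r_out, eps=1e-9):
    """z-plus rule for one linear layer: redistribute relevance using
    only the positive part of the weights (illustrative sketch)."""
    w_pos = w.clamp(min=0)  # keep positive weights only
    z = a @ w_pos           # positive pre-activations
    s = r_out / (z + eps)   # normalized relevance
    c = w_pos @ s           # redistribute back to the inputs
    return a * c

a = torch.tensor([1.0, 2.0])
w = torch.tensor([[1.0, -1.0], [0.5, 1.0]])
r_out = torch.tensor([1.0, 1.0])
r_in = lrp_z_plus(a, w, r_out)
```

For non-negative activations the rule remains conservative: the relevance entering the layer sums to the relevance leaving it.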
TODOs
- Add support for other network architectures (model-agnostic).
- Add control of layer parameters via a config file.
- Add control of the relevance filter via a config file.
License
MIT
Citation
@misc{blogpost,
  title={Layer-wise Relevance Propagation for PyTorch},
  author={Fabi, Kai},
  howpublished={\url{https://github.com/KaiFabi/PyTorchRelevancePropagation}},
  year={2021}
}