Improving Deep Network Debuggability via Sparse Decision Layers
This repository contains the code for our paper:
Leveraging Sparse Linear Layers for Debuggable Deep Networks
Eric Wong*, Shibani Santurkar*, Aleksander Madry
Paper: http://arxiv.org/abs/2105.04857
Blog posts: Part1 and Part2
@article{wong2021leveraging,
  title={Leveraging Sparse Linear Layers for Debuggable Deep Networks},
  author={Wong, Eric and Santurkar, Shibani and M{\k{a}}dry, Aleksander},
  journal={arXiv preprint arXiv:2105.04857},
  year={2021}
}
Getting started
Our code relies on the MadryLab public robustness library, as well as the glm_saga library, which will be automatically installed when you follow the instructions below. The glm_saga library contains a standalone implementation of our sparse GLM solver.
- Clone our repo:
  git clone https://github.com/microsoft/DebuggableDeepNetworks.git
- Set up the lucent submodule using:
  git submodule update --init --recursive
- We recommend using conda for dependencies:
  conda env create -f environment.yml
  conda activate debuggable
Training sparse decision layers
Contents:
- main.py fits a sparse decision layer on top of the deep features of the specified pre-trained (language/vision) deep network
- helpers/ has some helper functions for loading datasets, models, and features
- language/ has some additional code for handling language models and datasets
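To make the idea of a sparse decision layer concrete, here is a minimal sketch that fits an elastic-net regularized softmax classifier on fixed features. It is not the glm_saga solver and does not reproduce main.py: it uses plain proximal gradient descent in NumPy, random features stand in for deep features, and the names soft_threshold and fit_sparse_decision_layer as well as every hyperparameter value are assumptions made for illustration.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||v||_1: shrink every entry toward zero by t."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def fit_sparse_decision_layer(feats, labels, n_classes, lam=0.1, alpha=0.99,
                              lr=0.5, n_steps=500):
    """Elastic-net softmax regression on fixed features (hypothetical helper).

    Minimizes: mean cross-entropy
               + lam * (alpha * ||W||_1 + 0.5 * (1 - alpha) * ||W||_F^2)
    using proximal gradient descent. The bias is left unregularized.
    """
    n, d = feats.shape
    W = np.zeros((n_classes, d))
    b = np.zeros(n_classes)
    onehot = np.eye(n_classes)[labels]
    for _ in range(n_steps):
        logits = feats @ W.T + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        err = (probs - onehot) / n
        # Gradient of the smooth part: cross-entropy plus the ridge term.
        grad_W = err.T @ feats + lam * (1.0 - alpha) * W
        grad_b = err.sum(axis=0)
        # Gradient step, then soft-thresholding for the L1 term.
        W = soft_threshold(W - lr * grad_W, lr * lam * alpha)
        b = b - lr * grad_b
    return W, b


# Toy stand-in for deep features: only the first 3 of 20 features matter.
rng = np.random.default_rng(0)
feats = rng.standard_normal((300, 20))
labels = feats[:, :3].argmax(axis=1)

W, b = fit_sparse_decision_layer(feats, labels, n_classes=3)
print("nonzero weights:", np.count_nonzero(W), "of", W.size)
```

The soft-thresholding step is what sets weights exactly to zero, so each class ends up depending on only a few features. Increasing lam in this sketch gives a sparser layer, usually at some cost in accuracy.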
To run the settings in our paper, you can use the following commands:
# Sentiment classification
python main.py --dataset sst --dataset-path --dataset-type language --model-path barissayil/bert-sentiment-analysis-sst --arch bert --out-path ./tmp/sst/ --cache
# Toxic comment classification (biased)
python main.py --dataset jigsaw-toxic --dataset-path --dataset-type language --model-path unitary/toxic-bert --arch bert --out-path ./tmp/jigsaw-toxic/ --cache --balance
# Toxic comment classification (unbiased)
python main.py --dataset jigsaw-alt-toxic --dataset-path --dataset-type language --model-path unitary/unbiased-toxic-roberta --arch roberta --out-path ./tmp/unbiased-jigsaw-toxic/ --cache --balance
# Places-10
python main.py --dataset places-10 --dataset-path --dataset-type vision --model-path --arch resnet50 --out-path ./tmp/places/ --cache
# ImageNet
python main.py --dataset imagenet --dataset-path --dataset-type vision --model-path --arch resnet50 --out-path ./tmp/imagenet/ --cache
Interpreting deep features
After fitting a sparse GLM with one of the above commands, you can inspect and visualize the resulting features with the notebooks we provide. See inspect_vision_models.ipynb and inspect_language_models.ipynb for the vision and language settings, respectively.
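As a rough illustration of why a sparse decision layer is easy to inspect, the sketch below lists, for each class, the few features that carry nonzero weight. It is not taken from the notebooks: the helper top_features_per_class, the toy weight matrix, and the class names are all hypothetical.

```python
import numpy as np


def top_features_per_class(W, class_names, k=5):
    """For each class, return up to k (feature_index, weight) pairs,
    ordered by decreasing |weight| and skipping exactly-zero weights."""
    summary = {}
    for c, name in enumerate(class_names):
        row = W[c]
        nonzero = np.flatnonzero(row)
        order = nonzero[np.argsort(-np.abs(row[nonzero]))][:k]
        summary[name] = [(int(j), float(row[j])) for j in order]
    return summary


# Toy weight matrix with shape (n_classes, n_features); most entries are zero.
W = np.array([[0.0, 1.5, 0.0, -0.2],
              [0.7, 0.0, 0.0, 0.0]])

summary = top_features_per_class(W, ["positive", "negative"])
for name, pairs in summary.items():
    print(name, pairs)
```

With only a handful of nonzero entries per class, the list for each class is short enough to examine feature by feature.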