Codebase for "ProtoAttend: Attention-Based Prototypical Learning."

Overview

Codebase for "ProtoAttend: Attention-Based Prototypical Learning."

Authors: Sercan O. Arik and Tomas Pfister

Paper: Sercan O. Arik and Tomas Pfister, "ProtoAttend: Attention-Based Prototypical Learning" Link: https://arxiv.org/abs/1902.06292

We propose a novel inherently interpretable machine learning method that bases decisions on few relevant examples that we call prototypes. Our method, ProtoAttend, can be integrated into a wide range of neural network architectures including pre-trained models. It utilizes an attention mechanism that relates the encoded representations to samples in order to determine prototypes. The resulting model outperforms state of the art in three high impact problems without sacrificing accuracy of the original model: (1) it enables high-quality interpretability that outputs samples most relevant to the decision-making (i.e. a sample-based interpretability method); (2) it achieves state of the art confidence estimation by quantifying the mismatch across prototype labels; and (3) it obtains state of the art in distribution mismatch detection. All this can be achieved with minimal additional test time and a practically viable training time computational cost.

This codebase exemplifies the ProtoAttend training and evaluation pipeline for Fashion-MNIST dataset, using ResNet as the image encoder model.

To run the training pipeline, simply use python3 main_protoattend.py. The results and visualizations will be ported to Tensorboard.

To modify the experiment to other datasets and models:

  • Implement data batching and preprocessing functions (modify input_data.py and data iterators like iter_train etc.).
  • Integrate the encoder model function suitable for the data type (modify cnn_encoder in model.py).
  • Reoptimize the learning hyperparameters for the new dataset.
You might also like...
Official codebase for Legged Robots that Keep on Learning: Fine-Tuning Locomotion Policies in the Real World
Official codebase for Legged Robots that Keep on Learning: Fine-Tuning Locomotion Policies in the Real World

Legged Robots that Keep on Learning Official codebase for Legged Robots that Keep on Learning: Fine-Tuning Locomotion Policies in the Real World, whic

An Image Captioning codebase

An Image Captioning codebase This is a codebase for image captioning research. It supports: Self critical training from Self-critical Sequence Trainin

Ranger - a synergistic optimizer using RAdam (Rectified Adam), Gradient Centralization and LookAhead in one codebase
Ranger - a synergistic optimizer using RAdam (Rectified Adam), Gradient Centralization and LookAhead in one codebase

Ranger-Deep-Learning-Optimizer Ranger - a synergistic optimizer combining RAdam (Rectified Adam) and LookAhead, and now GC (gradient centralization) i

Codebase for the self-supervised goal reaching benchmark introduced in the LEXA paper
Codebase for the self-supervised goal reaching benchmark introduced in the LEXA paper

LEXA Benchmark Codebase for the self-supervised goal reaching benchmark introduced in the LEXA paper (Discovering and Achieving Goals via World Models

This codebase is the official implementation of Test-Time Classifier Adjustment Module for Model-Agnostic Domain Generalization (NeurIPS2021, Spotlight)

Test-Time Classifier Adjustment Module for Model-Agnostic Domain Generalization This codebase is the official implementation of Test-Time Classifier A

Official codebase for "B-Pref: Benchmarking Preference-BasedReinforcement Learning" contains scripts to reproduce experiments.

B-Pref Official codebase for B-Pref: Benchmarking Preference-BasedReinforcement Learning contains scripts to reproduce experiments. Install conda env

Using this codebase as a tool for my own research. Making some modifications to the original repo for my own purposes.

For SwapNet Create a list.txt file containing all the images to process. This can be done with the GNU find command: find path/to/input/folder -name '

Codebase for Amodal Segmentation through Out-of-Task andOut-of-Distribution Generalization with a Bayesian Model

Codebase for Amodal Segmentation through Out-of-Task andOut-of-Distribution Generalization with a Bayesian Model

PySlowFast: video understanding codebase from FAIR for reproducing state-of-the-art video models.
PySlowFast: video understanding codebase from FAIR for reproducing state-of-the-art video models.

PySlowFast PySlowFast is an open source video understanding codebase from FAIR that provides state-of-the-art video classification models with efficie

Owner
47
47
A general 3D Object Detection codebase in PyTorch.

Det3D is the first 3D Object Detection toolbox which provides off the box implementations of many 3D object detection algorithms such as PointPillars, SECOND, PIXOR, etc, as well as state-of-the-art methods on major benchmarks like KITTI(ViP) and nuScenes(CBGS).

Benjin Zhu 1.4k Jan 5, 2023
Official codebase for Pretrained Transformers as Universal Computation Engines.

universal-computation Overview Official codebase for Pretrained Transformers as Universal Computation Engines. Contains demo notebook and scripts to r

Kevin Lu 210 Dec 28, 2022
AOT-GAN for High-Resolution Image Inpainting (codebase for image inpainting)

AOT-GAN for High-Resolution Image Inpainting Arxiv Paper | AOT-GAN: Aggregated Contextual Transformations for High-Resolution Image Inpainting Yanhong

Multimedia Research 214 Jan 3, 2023
This is the codebase for Diffusion Models Beat GANS on Image Synthesis.

This is the codebase for Diffusion Models Beat GANS on Image Synthesis.

OpenAI 3k Dec 26, 2022
Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling.

Decision Transformer Lili Chen*, Kevin Lu*, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas†, and Igor M

Kevin Lu 1.4k Jan 7, 2023
Codebase for the Summary Loop paper at ACL2020

Summary Loop This repository contains the code for ACL2020 paper: The Summary Loop: Learning to Write Abstractive Summaries Without Examples. Training

Canny Lab @ The University of California, Berkeley 44 Nov 4, 2022
This is the codebase for the ICLR 2021 paper Trajectory Prediction using Equivariant Continuous Convolution

Trajectory Prediction using Equivariant Continuous Convolution (ECCO) This is the codebase for the ICLR 2021 paper Trajectory Prediction using Equivar

Spatiotemporal Machine Learning 45 Jul 22, 2022
A weakly-supervised scene graph generation codebase. The implementation of our CVPR2021 paper ``Linguistic Structures as Weak Supervision for Visual Scene Graph Generation''

README.md shall be finished soon. WSSGG 0 Overview 1 Installation 1.1 Faster-RCNN 1.2 Language Parser 1.3 GloVe Embeddings 2 Settings 2.1 VG-GT-Graph

Keren Ye 35 Nov 20, 2022
X-modaler is a versatile and high-performance codebase for cross-modal analytics.

X-modaler X-modaler is a versatile and high-performance codebase for cross-modal analytics. This codebase unifies comprehensive high-quality modules i

null 910 Dec 28, 2022
Codebase for Diffusion Models Beat GANS on Image Synthesis.

Codebase for Diffusion Models Beat GANS on Image Synthesis.

Katherine Crowson 128 Dec 2, 2022