PyTorch implementation of "Distilling a Neural Network Into a Soft Decision Tree"

Overview

Soft-Decision-Tree

Soft-Decision-Tree is a PyTorch implementation of Distilling a Neural Network Into a Soft Decision Tree, a paper published on arXiv about distilling a trained neural network into a decision-tree model. As the authors put it: "If we could take the knowledge acquired by the neural net and express the same knowledge in a model that relies on hierarchical decisions instead, explaining a particular decision would be much easier."
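In the paper's model, each inner node routes an input to its right subtree with a probability produced by a learned linear filter followed by a sigmoid, scaled by an inverse temperature. A minimal sketch of one such node (the names `InnerNode` and `input_dim` are illustrative, not taken from this repository):

```python
import torch
import torch.nn as nn

class InnerNode(nn.Module):
    """One inner node of a soft decision tree: routes inputs softly."""
    def __init__(self, input_dim, beta=1.0):
        super().__init__()
        self.fc = nn.Linear(input_dim, 1)   # learned filter w and bias b
        self.beta = beta                    # inverse temperature from the paper

    def forward(self, x):
        # Probability of taking the right branch: sigmoid(beta * (x.w + b))
        return torch.sigmoid(self.beta * self.fc(x))

node = InnerNode(input_dim=784)             # e.g. a flattened MNIST image
p_right = node(torch.randn(4, 784))         # shape (4, 1), values in (0, 1)
```

A full tree stacks these nodes and multiplies the branch probabilities along each root-to-leaf path to weight the leaf distributions.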

Requirements

Result

I achieved 92.95% test accuracy on MNIST after 40 epochs, without exploring the hyper-parameters much (the paper achieved 94.45%). Higher accuracy might be achievable with a hyper-parameter search or by training for more epochs (if you manage it, please let me know :) )

Usage

$ python main.py

Comments
  • where is the 0.5 test?

    Thank you for sharing this code. I was wondering why you commented out that forward section in the soft decision tree, and why you are not using the 0.5 test to choose the left or right node?

    Just another question: the output dimension in args is the number of classes, right?

    Best regards

    opened by Oussamab21 5
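For context on the "0.5 test": the paper distinguishes soft inference (average the leaf distributions, weighted by path probability) from hard inference (at each node, follow the branch whose probability exceeds 0.5). A hedged sketch of the two decision rules on a toy one-node, two-leaf tree (all names and values are illustrative):

```python
import torch

def soft_route(p_right, left_leaf, right_leaf):
    # Soft inference: mix the leaf distributions by path probability.
    return (1 - p_right) * left_leaf + p_right * right_leaf

def hard_route(p_right, left_leaf, right_leaf):
    # Hard inference ("0.5 test"): commit to the more probable branch.
    return right_leaf if p_right > 0.5 else left_leaf

left = torch.tensor([0.9, 0.1])    # leaf class distributions (toy values)
right = torch.tensor([0.2, 0.8])

soft = soft_route(0.7, left, right)   # weighted mixture of both leaves
hard = hard_route(0.7, left, right)   # exactly the right leaf
```

Commenting out a forward path that implements one of these rules would switch the model between the two behaviours, which may be what the question is pointing at.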
  • exponential increase in the temporal scale

    Really great implementation! I have a question about it. In the last paragraph of the section on regularizers, the authors mention an 'exponential increase in the temporal scale of the window used to compute the running average'. Is this feature implemented in this codebase? I didn't find it. Thanks :)

    opened by weiguowilliam 0
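As I read the paper, that phrase describes an exponentially weighted moving average of each inner node's routing probability, where the effective averaging window grows exponentially with the node's depth. A hedged sketch of that idea (my interpretation, not code from this repository; `base_decay` is an illustrative parameter):

```python
def depth_decay(depth, base_decay=0.9):
    # The effective window of an EWMA with decay c is roughly 1/(1-c);
    # halving (1-c) per level doubles the window, i.e. the temporal
    # scale grows exponentially with depth.
    return 1.0 - (1.0 - base_decay) / (2 ** depth)

def update_running_avg(avg, alpha, depth):
    # alpha: this batch's mean routing probability at the node.
    c = depth_decay(depth)
    return c * avg + (1.0 - c) * alpha
```

The running average would then feed the paper's balance regularizer, which penalizes nodes whose average routing drifts far from 0.5.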
  • loss is nan

    I tried to use my own dataset; my data is a table with many discrete values, such as 0, 1, 2. I found that the loss is nan:

    Train Epoch: 1 [0/49626 (0%)]	Loss: nan, Accuracy: 28/64 (43.0000%)
    Train Epoch: 1 [640/49626 (1%)]	Loss: nan, Accuracy: 30/64 (46.0000%)
    Train Epoch: 1 [1280/49626 (3%)]	Loss: nan, Accuracy: 36/64 (56.0000%)
    Train Epoch: 1 [1920/49626 (4%)]	Loss: nan, Accuracy: 35/64 (54.0000%)
    Train Epoch: 1 [2560/49626 (5%)]	Loss: nan, Accuracy: 40/64 (62.0000%)
    Train Epoch: 1 [3200/49626 (6%)]	Loss: nan, Accuracy: 29/64 (45.0000%)
    Train Epoch: 1 [3840/49626 (8%)]	Loss: nan, Accuracy: 33/64 (51.0000%)
    Train Epoch: 1 [4480/49626 (9%)]	Loss: nan, Accuracy: 28/64 (43.0000%)
    Train Epoch: 1 [5120/49626 (10%)]	Loss: nan, Accuracy: 30/64 (46.0000%)
    Train Epoch: 1 [5760/49626 (12%)]	Loss: nan, Accuracy: 28/64 (43.0000%)
    Train Epoch: 1 [6400/49626 (13%)]	Loss: nan, Accuracy: 38/64 (59.0000%)
    Train Epoch: 1 [7040/49626 (14%)]	Loss: nan, Accuracy: 36/64 (56.0000%)
    Train Epoch: 1 [7680/49626 (15%)]	Loss: nan, Accuracy: 27/64 (42.0000%)
    Train Epoch: 1 [8320/49626 (17%)]	Loss: nan, Accuracy: 29/64 (45.0000%)
    Train Epoch: 1 [8960/49626 (18%)]	Loss: nan, Accuracy: 31/64 (48.0000%)
    Train Epoch: 1 [9600/49626 (19%)]	Loss: nan, Accuracy: 32/64 (50.0000%)
    Train Epoch: 1 [10240/49626 (21%)]	Loss: nan, Accuracy: 34/64 (53.0000%)
    Train Epoch: 1 [10880/49626 (22%)]	Loss: nan, Accuracy: 34/64 (53.0000%)
    Train Epoch: 1 [11520/49626 (23%)]	Loss: nan, Accuracy: 29/64 (45.0000%)
    Train Epoch: 1 [12160/49626 (24%)]	Loss: nan, Accuracy: 34/64 (53.0000%)
    Train Epoch: 1 [12800/49626 (26%)]	Loss: nan, Accuracy: 39/64 (60.0000%)
    Train Epoch: 1 [13440/49626 (27%)]	Loss: nan, Accuracy: 34/64 (53.0000%)
    
    
    opened by w5688414 0
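A common cause of nan loss in this kind of model is taking the log of a path probability or leaf distribution that has underflowed to exactly 0 (for example, a saturated sigmoid). A hedged sketch of the usual guard, clamping probabilities before the log (a generic fix, not confirmed against this repository's loss code):

```python
import torch
import torch.nn.functional as F

def safe_nll(probs, targets, eps=1e-8):
    # Clamp away exact zeros so log() cannot produce -inf, which
    # turns into nan once it mixes into gradients or averages.
    probs = probs.clamp(min=eps)
    return F.nll_loss(probs.log(), targets)

probs = torch.tensor([[1.0, 0.0], [0.5, 0.5]])  # second row's class 0 is fine,
targets = torch.tensor([1, 0])                  # but first target has prob 0
loss = safe_nll(probs, targets)                 # finite thanks to the clamp
```

Normalizing the discrete input features (rather than feeding raw 0/1/2 codes) can also help keep the sigmoids out of saturation.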
  • Will this work out of the box with [16] shaped input?

    Hi, I don't want to classify images, just 16-element tensors per datum. Will this work out of the box with that? If not, where should I look to change it?

    opened by ghost 4
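In principle the tree only assumes a flat feature vector, since each inner node is a linear filter over the flattened input, so 16-element inputs should mainly require matching the input-dimension argument (check how `main.py` sets it; the name `input_dim` below is illustrative):

```python
import torch
import torch.nn as nn

input_dim = 16                             # 16 features per datum
filt = nn.Linear(input_dim, 1)             # one inner node's filter
batch = torch.randn(32, input_dim)         # a batch of 32 samples
p_right = torch.sigmoid(filt(batch))       # per-sample routing probability
```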
  • hello, about bigger input sizes?

    First of all, thank you for such good code. I want to ask: when I use a larger input, such as 224 × 224 × 3, I find that training has no effect. Is it necessary to change some parts of the code?

    opened by 1448643857 0
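Part of the difficulty may simply be scale: if inputs are flattened as in the MNIST setup, every inner node learns a linear filter over 150,528 features, and a tree of depth d holds 2^d − 1 such filters. A quick back-of-the-envelope check (depth 8 is just an example):

```python
input_dim = 224 * 224 * 3               # flattened RGB image: 150528
depth = 8                               # example tree depth
inner_nodes = 2 ** depth - 1            # 255 inner nodes
params = inner_nodes * (input_dim + 1)  # weights + bias per node
```

With tens of millions of parameters in plain linear filters and no convolutional structure, optimization on raw large images tends to stall, so lowering the learning rate or distilling from a convolutional net's outputs (as the paper does) may matter more than code changes.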
Owner
Kim Heecheol
University of Tokyo, Intelligent systems & Informatics Lab.