Demonstration of knowledge transfer and generalization with distillation

Overview

Distilling-the-Knowledge-in-a-Neural-Network

This is an implementation of part of the paper "Distilling the Knowledge in a Neural Network" (https://arxiv.org/abs/1503.02531).

The teacher network has two hidden layers with 1200 units each. It is trained on MNIST with data augmentation and achieves 108 test errors.
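As a rough illustration, a teacher of this shape could be written in PyTorch as below; the ReLU activations and dropout rates are assumptions for the sketch, since only the layer sizes are stated above (the notebook may differ).

    import torch.nn as nn

    # Sketch of a 784 -> 1200 -> 1200 -> 10 MLP for 28x28 MNIST digits.
    # ReLU and the dropout rates are illustrative assumptions, not values
    # taken from distill_basic_teacher.ipynb.
    teacher = nn.Sequential(
        nn.Flatten(),
        nn.Dropout(0.2),           # input dropout (assumed)
        nn.Linear(28 * 28, 1200),
        nn.ReLU(),
        nn.Dropout(0.5),           # hidden dropout (assumed)
        nn.Linear(1200, 1200),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(1200, 10),       # class logits
    )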

The student network has one hidden layer with 400 units. No regularization is used to train the student apart from weight regularization. Without distillation, it achieves 181 test errors; with distillation, the test errors drop to 134. This demonstrates knowledge transfer from teacher to student, helping the student generalize better.
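The transfer itself comes from training the student on the teacher's temperature-softened output distribution in addition to the hard labels. A minimal sketch of such a loss is shown below; the temperature and weighting values are illustrative, not the ones used in the notebooks.

    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
        # Soft-target term: KL divergence between the temperature-softened
        # teacher and student distributions. Scaling by T*T keeps its gradient
        # magnitude comparable to the hard-label term, as suggested in the paper.
        soft = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='batchmean',
        ) * (T * T)
        # Hard-label term: ordinary cross-entropy on the true labels.
        hard = F.cross_entropy(student_logits, labels)
        return alpha * soft + (1 - alpha) * hard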

Training and testing the teacher and student networks

To train the teacher network, run all cells of distill_basic_teacher.ipynb. To train the student network, run all cells of distill_basic_student.ipynb. Modify the second cell of each notebook according to GPU availability.
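That second cell presumably selects the device used for training; the variable name fast_device appears in the traceback quoted in the comments below, but the exact contents shown here are an assumed stand-in for the notebook cell.

    import torch

    # Pick the GPU when one is available, otherwise fall back to the CPU.
    # The name fast_device matches the training code quoted in the comments;
    # everything else here is an assumption about the notebook cell.
    use_gpu = torch.cuda.is_available()
    fast_device = torch.device('cuda') if use_gpu else torch.device('cpu')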

Comments
  • transforms.Normalize() is not working in teacher training file

    In distill_basic_teacher.ipynb, the transforms.Normalize call for both training_dataset and test_dataset is set incorrectly: it normalizes the input image as if it had shape (2, 28, 28), while MNIST images are single-channel (a corrected transform is sketched after these comments).

    opened by aryanasadianuoit 1
  • List index out of range

    Hello,

    I followed the instruction and trained the teacher model successfully. However, when I tried to do the same with the student model it showed me this error:

    nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
    ------------------
    # plt.rcParams['figure.figsize'] = [10, 5]
    weight_decay_scatter = ([math.log10(h['weight_decay']) if h['weight_decay'] > 0 else -6 for h in hparams_list])
    dropout_scatter = [int(h['dropout_input'] == 0.2) for h in hparams_list]
    colors = []
    print('hparams_list has a length of ', len(hparams_list))
    for i in range(len(hparams_list)):
        cur_hparam_tuple = utils.hparamDictToTuple(hparams_list[i])
        colors.append(results_no_distill[cur_hparam_tuple]['val_acc'][-1])
    
    marker_size = 100
    fig, ax = plt.subplots()
    plt.scatter(weight_decay_scatter, dropout_scatter, marker_size, c=colors, edgecolors='black')
    plt.colorbar()
    print('weight_decay_scatter has a length of ', len(weight_decay_scatter))
    print('colors has a length of ', len(colors))
    print('dropout_scatter has a length of ', len(dropout_scatter))
    for i in range(len(weight_decay_scatter)):
        ax.annotate(str('%0.4f' % (colors[i], )), (weight_decay_scatter[i], dropout_scatter[i]))
    plt.show()
    
    ---------------------------------------------------------------------------
    IndexError                                Traceback (most recent call last)
    <ipython-input-1-945da500c1fb> in <module>
          6 for i in range(len(hparams_list)):
          7     cur_hparam_tuple = utils.hparamDictToTuple(hparams_list[i])
    ----> 8     colors.append(results_no_distill[cur_hparam_tuple]['val_acc'][-1])
          9 
         10 marker_size = 100
    
    IndexError: list index out of range
    

    I am unsure what caused this. I tried to debug by printing the length of each list/tuple involved, but the lengths were never printed before the error.

    opened by alexzhang0825 0
  • Training the Teacher model issues

    Hi, I am facing issues while training the basic_teacher model.

    Attaching the logs here-

    Training with hparams dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05
    /usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:123: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

    RuntimeError                              Traceback (most recent call last)
    in ()
         30     train_val_loader, None,
         31     print_every=print_every,
    ---> 32     fast_device=fast_device)
         33 save_path = checkpoints_path + utils.hparamToString(hparam) + '_final.tar'
         34 torch.save({'results' : results[hparam_tuple],

    4 frames
    /usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
        393     # (https://bugs.python.org/issue2651), so we work around it.
        394     msg = KeyErrorMessage(msg)
    --> 395     raise self.exc_type(msg)

    RuntimeError: Caught RuntimeError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py", line 97, in __getitem__
        img = self.transform(img)
      File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 61, in __call__
        img = t(img)
      File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 212, in __call__
        return F.normalize(tensor, self.mean, self.std, self.inplace)
      File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 298, in normalize
        tensor.sub_(mean).div_(std)
    RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [2, 28, 28]

    opened by swapanj162 0
  • Reproduce the teacher accuracy

    Hi @shriramsb. Unfortunately, I can't reproduce the reported teacher accuracy with your code. Would you please take a look at it? Maybe the architecture is not appropriate.

    opened by aryanasadianuoit 3
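
On the Normalize reports above: the RuntimeError ("output with shape [1, 28, 28] doesn't match the broadcast shape [2, 28, 28]") indicates a two-value mean/std being applied to single-channel MNIST tensors. A minimal sketch of a corrected transform follows, assuming the commonly used single-channel MNIST statistics rather than values from the notebooks.

    from torchvision import transforms

    # A single-element mean and std broadcast correctly over (1, 28, 28) tensors.
    # 0.1307 / 0.3081 are the widely quoted MNIST statistics and are an assumed
    # choice here, not taken from distill_basic_teacher.ipynb.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])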