Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"

Overview

Dataset Distillation by Matching Training Trajectories

Project Page | Paper


Teaser image

This repo contains code for training expert trajectories and distilling synthetic data from our Dataset Distillation by Matching Training Trajectories paper (CVPR 2022). Please see our project page for more results.

Dataset Distillation by Matching Training Trajectories
George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A. Efros, Jun-Yan Zhu
CMU, MIT, UC Berkeley
CVPR 2022

The task of "Dataset Distillation" is to learn a small number of synthetic images such that a model trained on this set alone will have similar test performance as a model trained on the full real dataset.

Our method distills the synthetic dataset by directly optimizing the fake images to induce similar network training dynamics as the full, real dataset. We train "student" networks for many iterations on the synthetic data, measure the error in parameter space between the "student" and "expert" networks trained on real data, and back-propagate through all the student network updates to optimize the synthetic pixels.
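Sketched in PyTorch, the objective looks roughly like this. This is a hedged illustration of the idea only, assuming a functional forward pass net_fn(images, params); the actual implementation lives in distill.py:

import torch

def trajectory_matching_loss(syn_images, syn_labels, net_fn, criterion,
                             start_params, target_params, syn_lr, n_steps):
    # start_params: expert checkpoint at epoch t; target_params: epoch t + M
    student = [p.detach().clone().requires_grad_(True) for p in start_params]
    for _ in range(n_steps):
        logits = net_fn(syn_images, student)
        loss = criterion(logits, syn_labels)
        # create_graph=True lets gradients flow back through every student update
        grads = torch.autograd.grad(loss, student, create_graph=True)
        student = [p - syn_lr * g for p, g in zip(student, grads)]
    # parameter-space error, normalized by how far the expert itself moved
    num = sum(((s - t) ** 2).sum() for s, t in zip(student, target_params))
    den = sum(((p - t) ** 2).sum() for p, t in zip(start_params, target_params))
    return num / den

Minimizing this loss with respect to syn_images (and, optionally, syn_lr) is what optimizes the synthetic pixels.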

Wearable ImageNet: Synthesizing Tileable Textures

Teaser image

Instead of treating our synthetic data as individual images, we can instead encourage every random crop (with circular padding) on a larger canvas of pixels to induce a good training trajectory. This results in class-based textures that are continuous around their edges.
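A random crop with circular padding can be sketched as a random roll of the canvas followed by a corner crop (an illustrative sketch, not necessarily the repo's exact implementation):

import torch

def random_toroidal_crop(canvas, crop_size):
    # canvas: (C, H, W) tensor of synthetic texture pixels
    _, h, w = canvas.shape
    dy = int(torch.randint(h, (1,)))
    dx = int(torch.randint(w, (1,)))
    # rolling before cropping is equivalent to cropping with circular padding:
    # crops that cross an edge wrap around to the opposite side
    rolled = torch.roll(canvas, shifts=(dy, dx), dims=(1, 2))
    return rolled[:, :crop_size, :crop_size]

Because torch.roll is differentiable, gradients from the trajectory-matching loss still reach every pixel of the canvas.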

Given these tileable textures, we can apply them to areas that require such properties, such as clothing patterns.

Visualizations made using FAB3D

Getting Started

First, download our repo:

git clone https://github.com/GeorgeCazenavette/mtt-distillation.git
cd mtt-distillation

For an express installation, we include .yaml files.

If you have an RTX 30XX GPU (or newer), run

conda env create -f requirements_11_3.yaml

If you have an RTX 20XX GPU (or older), run

conda env create -f requirements_10_2.yaml

You can then activate your conda environment with

conda activate distillation
Quadro Users Take Note:

torch.nn.DataParallel seems to not work on Quadro A5000 GPUs, and this may extend to other Quadro cards.

If you experience indefinite hanging during training, try running the process with only 1 GPU by prepending CUDA_VISIBLE_DEVICES=0 to the command.
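For example, to run the CIFAR-100 distillation command shown below on a single GPU:

CUDA_VISIBLE_DEVICES=0 python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}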

Generating Expert Trajectories

Before doing any distillation, you'll need to generate some expert trajectories using buffer.py.

The following command will train 100 ConvNet models on CIFAR-100 with ZCA whitening for 50 epochs each:

python buffer.py --dataset=CIFAR100 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

We used 50 epochs with the default learning rate for all of our experts. Worse (but still interesting) results can be obtained faster by training fewer experts via --num_experts. Note that experts need only be trained once and can be re-used for multiple distillation experiments.
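As an aside, the --zca flag applies ZCA whitening to the input images before training. A minimal sketch of the transform, assuming flattened images and an assumed regularizer eps (illustrative only; the repo has its own preprocessing pipeline):

import torch

def zca_whiten(X, eps=0.1):
    # X: (N, D) flattened images; returns ZCA-whitened images of the same shape
    X = X - X.mean(dim=0, keepdim=True)
    cov = X.T @ X / X.shape[0]                     # (D, D) covariance
    eigvals, eigvecs = torch.linalg.eigh(cov)
    # W = U (Lambda + eps*I)^(-1/2) U^T
    W = eigvecs @ torch.diag((eigvals + eps).rsqrt()) @ eigvecs.T
    return X @ W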

Distillation by Matching Training Trajectories

The following command will then use the buffers we just generated to distill CIFAR-100 down to just 1 image per class:

python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

ImageNet

Our method can also distill subsets of ImageNet into low-support synthetic sets.

When generating expert trajectories with buffer.py or distilling the dataset with distill.py, you must designate a named subset of ImageNet with the --subset flag.

For example,

python distill.py --dataset=ImageNet --subset=imagefruit --model=ConvNetD5 --ipc=1 --res=128 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagefruit subset (at 128x128 resolution) into 10 synthetic images, one per class.

To register your own ImageNet subset, you can add it to the Config class at the top of utils.py.

Simply create a list with the desired class IDs and add it to the dictionary.
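A hedged sketch of what that could look like (the subset name and class IDs below are hypothetical; follow the format of the existing entries in utils.py):

class Config:
    # ... existing subsets, each a list of ImageNet class IDs ...

    # hypothetical custom subset: any ten IDs from the gist linked below
    imagecustom = [1, 9, 22, 88, 130, 145, 207, 235, 250, 270]

    dict = {
        # ... existing entries ...
        "imagecustom": imagecustom,
    }

You could then pass --subset=imagecustom to buffer.py and distill.py.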

This gist contains a list of all 1k ImageNet classes and their corresponding numbers.

Texture Distillation

You can also use the same set of expert trajectories (except those using ZCA) to distill classes into toroidal textures by simply adding the --texture flag.

For example,

python distill.py --texture --dataset=ImageNet --subset=imagesquawk --model=ConvNetD5 --ipc=1 --res=256 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagesquawk subset (at 256x256 resolution) into 10 synthetic textures, one per class.

Acknowledgments

We would like to thank Alexander Li, Assaf Shocher, Gokul Swamy, Kangle Deng, Ruihan Gao, Nupur Kumari, Muyang Li, Gaurav Parmar, Chonghyuk Song, Sheng-Yu Wang, and Bingliang Zhang as well as Simon Lucey's Vision Group at the University of Adelaide for their valuable feedback. This work is supported, in part, by the NSF Graduate Research Fellowship under Grant No. DGE1745016 and grants from J.P. Morgan Chase, IBM, and SAP. Our code is adapted from https://github.com/VICO-UoE/DatasetCondensation.

Related Work

  1. Tongzhou Wang et al. "Dataset Distillation", in arXiv preprint 2018
  2. Bo Zhao et al. "Dataset Condensation with Gradient Matching", in ICLR 2021
  3. Bo Zhao and Hakan Bilen. "Dataset Condensation with Differentiable Siamese Augmentation", in ICML 2021
  4. Timothy Nguyen et al. "Dataset Meta-Learning from Kernel Ridge-Regression", in ICLR 2021
  5. Timothy Nguyen et al. "Dataset Distillation with Infinitely Wide Convolutional Networks", in NeurIPS 2021
  6. Bo Zhao and Hakan Bilen. "Dataset Condensation with Distribution Matching", in arXiv preprint 2021
  7. Kai Wang et al. "CAFE: Learning to Condense Dataset by Aligning Features", in CVPR 2022

Reference

If you find our code useful for your research, please cite our paper.

@inproceedings{cazenavette2022distillation,
  title={Dataset Distillation by Matching Training Trajectories},
  author={George Cazenavette and Tongzhou Wang and Antonio Torralba and Alexei A. Efros and Jun-Yan Zhu},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}
Comments
  • how to use images?

    Hello, I want to know how to use the distilled images. I used them to train a new network, but the accuracy was terrible (10% on CIFAR-10). Can these images be used to train a new network? If not, what is the meaning of these images? If they can, could you share the network architecture you used?

    opened by Fduxiaozhige 7
  • values for max_start_epoch

    Hi there, I can see that max_start_epoch is set to 20. However, during the generation of the expert trajectories, train_epochs is 50, which means that during distillation we don't use most of the saved checkpoints (>20+3). My questions:

    1. Is there any reason to choose max_start_epoch as 20 rather than 50?
    2. Can we set train_epochs to a lower value to reduce training time?
    opened by ankanbhunia 7
  • about hyperparameter: learning rate for updating condensed samples

    Hello, George! First of all, I must say that this is very nice work.

    I have some questions about the hyperparameter lr_img used for updating the condensed samples. The paper does not mention how lr_img was chosen. Also, I ran the experiment with 10 images per class on CIFAR-10 following Table 6 and only obtained 58.50% accuracy. Should I modify other hyperparameters?

    opened by Alan-Qin 5
  • Negative LR

    Hi! Thank you for your great work.

    When I was distilling with my own dataset, the loss became very large (iter = 0490) and the learning rate went negative.

    Could you help me figure out what is happening here? What hyperparameters should be adjusted in such a case? Can we implement anything in code to prevent a negative LR?

    Thank you!

    Evaluate 5 random ConvNetD4, mean = 0.2429 std = 0.0080
    -------------------------
    [2022-08-14 00:29:04] iter = 0400, loss = 1.2390
    [2022-08-14 00:29:12] iter = 0410, loss = 1.3564
    [2022-08-14 00:29:19] iter = 0420, loss = 1.5845
    [2022-08-14 00:29:27] iter = 0430, loss = 0.9945
    [2022-08-14 00:29:35] iter = 0440, loss = 1.4876
    [2022-08-14 00:29:43] iter = 0450, loss = 1.0734
    [2022-08-14 00:29:51] iter = 0460, loss = 1.9312
    [2022-08-14 00:29:58] iter = 0470, loss = 1.0497
    [2022-08-14 00:30:06] iter = 0480, loss = 16.3134
    [2022-08-14 00:30:14] iter = 0490, loss = 23.7197
    -------------------------
    Evaluation
    model_train = ConvNetD4, model_eval = ConvNetD4, iteration = 500
    DSA augmentation strategy:  color_crop_cutout_flip_scale_rotate
    DSA augmentation parameters:
     {'aug_mode': 'S', 'prob_flip': 0.5, 'ratio_scale': 1.2, 'ratio_rotate': 15.0, 'ratio_crop_pad': 0.125, 'ratio_cutout': 0.5, 'ratio_noise': 0.05, 'brightness': 1.0, 'saturation': 2.0, 'contrast': 0.5, 'batchmode': False, 'latestseed': -1}
    Traceback (most recent call last):
      File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/distill.py", line 496, in <module>
        main(args)
      File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/distill.py", line 227, in main
        _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args, texture=args.texture)
      File "/media/ntu/volume1/home/s121md302_06/workspace/code/mtt-distillation/utils.py", line 400, in evaluate_synset
        optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
      File "/media/ntu/volume1/home/s121md302_06/anaconda3/envs/distillation/lib/python3.9/site-packages/torch/optim/sgd.py", line 91, in __init__
        raise ValueError("Invalid learning rate: {}".format(lr))
    ValueError: Invalid learning rate: -0.00048201243043877184
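
    Not an official fix, but one possible guard (a sketch, assuming the learned synthetic learning rate is a tensor such as syn_lr) is to clamp it after each optimizer step:

    with torch.no_grad():
        # hypothetical safeguard: keep the learned LR strictly positive
        syn_lr.clamp_(min=1e-8)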
    
    opened by c-liangyu 3
  • A question about backbone networks

    Hi I've taken great interest in your work and am trying to experiment on various environments.

    From Table 1 in the paper, the ConvNet used as a baseline reaches only 56.2% accuracy even when trained on the full CIFAR-100 training set, which is considerably lower than SOTA classification models that exceed 90% accuracy.

    Since the performance of the baseline model (the expert trajectories trained on the full dataset) serves as an upper bound for the student network trained on the synthetic dataset, I was wondering if you ever experimented with more complex networks like WideResNet50 for training the expert networks. If you haven't, do you have any guesses as to what the outcome would be?

    Thanks a bunch.

    opened by imesu2378 3
  • Where did you get the 36.1% accuracy from the paper "Dataset Distillation with Infinitely Wide Convolutional Networks"?

    Thanks for your great idea and detailed work, and I hope you are enjoying your day so far.

    I have a question regarding your paper "Dataset Distillation by Matching Training Trajectories". In the third sentence from the bottom of the Introduction, you state that you beat the SOTA "Dataset Distillation with Infinitely Wide Convolutional Networks" at its reported accuracy of 36.1%/46.5%. However, the accuracy stated in that paper is actually 64.7%/80.6%.

    Is that a small mistake? If not, could you help me find where in that paper those accuracies appear? Thank you and best regards!

    opened by NiaLiu 3
  • Expert trajectory performance

    Thanks for your work! I've got a question. When training the expert trajectories on CIFAR-10 with buffer.py, I only got test accuracies of around 0.79 with and 0.77 without --zca after 50 epochs. However, Table 1 in your paper reports that the full dataset can reach 0.84 accuracy on CIFAR-10. Is there any mistake I've made here?

    opened by 1215481871 3
  • GPU requirement

    Thanks for your great work on distilled datasets. I was wondering about your hardware setup for CIFAR-100, Tiny ImageNet, and ImageNet (subsets). How long did you need for your results (generating the experts and the distillation step)?

    opened by rave78 2
  • The clip value

    Hi, thanks for your great work! I am curious about the clip_val. Why did you choose 2.5? Why is clipping needed? Could you please explain a little? Thanks! And when training with the distilled data, we don't need clipping, right?

    # clip synthetic pixels to within clip_val standard deviations of the mean
    for clip_val in [2.5]:
        std = torch.std(images_train)
        mean = torch.mean(images_train)
        upsampled = torch.clip(images_train, min=mean - clip_val * std, max=mean + clip_val * std)
    
    opened by tao-bai 2
  • A question for the paper

    I am very interested in your work, but I have a question: can you directly train a randomly initialized network with the synthetic dataset? If 10-500 images can train a robust network, that's incredible. Or do you have to use the raw dataset to help distill the images while training the network? Can you tell me the answer directly?

    opened by alittleCVer 2
  • Reproduce cross-architecture performance

    Hi George, Thanks for your inspiring and great work.

    I would like to reproduce the cross-architecture accuracy, but I'm having difficulty getting an accuracy comparable to the one listed in the paper. I think I might be missing some details. Could you please type out the command you used to produce the cross-architecture performance on CIFAR-10 with 10 img/cls?

    Here are the commands I used:

    First step: python buffer.py --dataset=CIFAR10 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path=buffer --data_path=data

    Second step: python3 distill.py --dataset=CIFAR10 --ipc=10 --syn_steps=30 --expert_epochs=2 --max_start_epoch=15 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path=buffer --data_path=data --eval_mode='M' --eval_it=1 --Iteration=300

    Is there something I'm missing here? In addition, did you change the parser argument epoch_eval_train when you produced the SOTA cross-architecture results?

    Thank you! Looking forward to your reply!

    opened by NiaLiu 1
  • question about learning rate

    Hi, sorry to bother you. I want to know why the lr_img learning rate is set to 1000. How did you determine 1000? Usually, learning rates like 0.1, 0.01, or 0.001 are used.

    Also, if I just change lr_img to 100 or a smaller value, the loss becomes NaN.

    Can you tell me how this works and give some advice about setting the hyperparameters? (I know the given hyperparameters reproduce the results in the paper, but I want to use other networks to distill datasets, and I think the hyperparameters must be modified when using different networks.)

    opened by harrylee999 1
  • How did you get x̄ ± s in Table 1

    Hi George,

    Thanks for your great work, and sorry to bother you again.

    I have another question regarding the accuracy values shown in Table 1. I assume there are two possible ways to get those numbers: 1) train on the synthetic data for a certain number of steps (e.g., 9000 steps), then test the accuracy on the test dataset; 2) test on the test dataset every 100 steps of training on the synthetic dataset, then take the maximum accuracy.

    The second way is not valid, since the test dataset should only be used once at the end. So did you use the first method to get the accuracy? If so, how many steps did you take?

    Thank you, and hope you have a great day! Dai

    opened by NiaLiu 2
  • Question about imagenette

    Hello, I am very interested in MTT. When I use your code's imagenette subset, I find it is different from the original Imagenette. Am I right?

    See imagenette in https://github.com/fastai/imagenette

    opened by yaolu-zjut 1
  • Trouble distilling with VGG networks

    Hi,

    I encountered exploding Synthetic-LR and zero-gradient loss issues when using VGG models.

    • including VGG11, VGG13, VGG16

    Similar issues have been reported.

    However, my experiments run fine with ConvNet and ResNet18 using the similar scripts given below.

    Here is a snippet of the scripts:

    SCRIPT_NAME=VGG13
    MODEL=VGG13
    DATASET=CIFAR10
    IMAGE_PER_CLASS=1
    
    python buffer.py --dataset=$DATASET --model=$MODEL --train_epochs=50 --num_experts=100 --buffer_path=$BUFFER_PATH --data_path=$DATA_PATH >> ./results/buffer_$SCRIPT_NAME.txt
    
    python distill.py --dataset=$DATASET --model=$MODEL --ipc=$IMAGE_PER_CLASS --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path=$BUFFER_PATH --data_path=$DATA_PATH >> ./results/distill_$SCRIPT_NAME.txt
    
    • VGG11 uses CIFAR100 and --ipc=10
    • none of the models use --zca whitening

    The Synthetic-LR goes either extremely positive or extremely negative at the very beginning.

    Thank you.

    opened by ArmandXiao 4