DrNAS: Dirichlet Neural Architecture Search

About

Code accompanying the ICLR 2021 paper DrNAS: Dirichlet Neural Architecture Search.
Xiangning Chen*, Ruochen Wang*, Minhao Cheng*, Xiaocheng Tang, Cho-Jui Hsieh

This code is based on the implementation of NAS-Bench-201 and PC-DARTS.

This paper proposes a novel differentiable architecture search method by formulating it as a distribution learning problem. We treat the continuously relaxed architecture mixing weights as random variables modeled by a Dirichlet distribution. With recently developed pathwise derivatives, the Dirichlet parameters can be optimized with a gradient-based optimizer in an end-to-end manner. This formulation improves generalization and induces stochasticity that naturally encourages exploration of the search space. Furthermore, to alleviate the large memory consumption of differentiable NAS, we propose a simple yet effective progressive learning scheme that enables searching directly on large-scale tasks, eliminating the gap between the search and evaluation phases. Extensive experiments demonstrate the effectiveness of our method. Specifically, we obtain a test error of 2.46% on CIFAR-10 and 23.7% on ImageNet under the mobile setting. On NAS-Bench-201, we also achieve state-of-the-art results on all three datasets and provide insights for the effective design of neural architecture search algorithms.
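
At its core, the search treats each edge's operation mixing weights as a sample from a learnable Dirichlet distribution. As a rough, self-contained sketch (not the repository's actual code; the cell sizes and the softplus parameterization below are assumptions for illustration), this is how pathwise gradients flow through the sampled weights in PyTorch:

```python
import torch
import torch.nn.functional as F
from torch.distributions.dirichlet import Dirichlet

# Hypothetical cell sizes: 14 edges, 8 candidate operations per edge.
NUM_EDGES, NUM_OPS = 14, 8

# Learnable Dirichlet concentration parameters (one simplex per edge).
beta = torch.nn.Parameter(torch.zeros(NUM_EDGES, NUM_OPS))

def sample_arch_weights() -> torch.Tensor:
    """Draw architecture mixing weights from the Dirichlet distribution.

    Dirichlet.rsample() uses PyTorch's pathwise (implicit reparameterization)
    gradient, so a loss computed from the sampled weights can back-propagate
    into the concentration parameters.
    """
    concentration = F.softplus(beta) + 1e-4    # keep concentrations strictly positive
    return Dirichlet(concentration).rsample()  # shape: (NUM_EDGES, NUM_OPS)

weights = sample_arch_weights()  # one mixing-weight simplex per edge
loss = weights.sum()             # placeholder for the supernet training loss
loss.backward()                  # gradients reach beta through rsample()
```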

Results

On NAS-Bench-201

The table below shows the test accuracy on the NAS-Bench-201 space. We achieve state-of-the-art results on all three datasets. On CIFAR-100, DrNAS even reaches the global optimum with zero variance!

| Method | CIFAR-10 (test) | CIFAR-100 (test) | ImageNet-16-120 (test) |
| --- | --- | --- | --- |
| ENAS | 54.30 ± 0.00 | 10.62 ± 0.27 | 16.32 ± 0.00 |
| DARTS | 54.30 ± 0.00 | 38.97 ± 0.00 | 18.41 ± 0.00 |
| SNAS | 92.77 ± 0.83 | 69.34 ± 1.98 | 43.16 ± 2.64 |
| PC-DARTS | 93.41 ± 0.30 | 67.48 ± 0.89 | 41.31 ± 0.22 |
| DrNAS (ours) | 94.36 ± 0.00 | 73.51 ± 0.00 | 46.34 ± 0.00 |
| optimal | 94.37 | 73.51 | 47.31 |

For every search process, we sample 100 architectures from the current Dirichlet distribution and plot their accuracy range along with the architecture selected by the Dirichlet mean (solid line). The figure below shows that the accuracy range of the sampled architectures starts very wide but narrows gradually during the search phase. This indicates that DrNAS encourages exploration in the early stages and then gradually reduces it towards the end, as the algorithm becomes more and more confident in its current choice. Moreover, the performance of our selected architectures consistently matches the best performance of the sampled architectures, indicating the effectiveness of DrNAS.
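
As a companion to the sketch above (again an illustration under assumed shapes, not the repository's evaluation code), drawing candidate architectures and selecting the one implied by the Dirichlet mean could look like this; the samples give the plotted accuracy band, while the mean-selected architecture corresponds to the solid line:

```python
# Reuses beta, F, and Dirichlet from the previous sketch.
with torch.no_grad():
    concentration = F.softplus(beta) + 1e-4
    dist = Dirichlet(concentration)

    # 100 sampled architectures: for each draw, keep the argmax operation per edge.
    sampled_ops = dist.sample((100,)).argmax(dim=-1)  # shape: (100, NUM_EDGES)

    # Architecture selected by the Dirichlet mean, alpha / alpha.sum()
    # (equivalently dist.mean): one operation index per edge.
    mean_weights = concentration / concentration.sum(dim=-1, keepdim=True)
    selected_ops = mean_weights.argmax(dim=-1)        # shape: (NUM_EDGES,)
```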

On DARTS Space (CIFAR-10)

DrNAS achieves an average test error of 2.46%, ranking among the best of recent NAS results.

| Method | Test Error (%) | Params (M) | Search Cost (GPU days) |
| --- | --- | --- | --- |
| ENAS | 2.89 | 4.6 | 0.5 |
| DARTS | 2.76 ± 0.09 | 3.3 | 1.0 |
| SNAS | 2.85 ± 0.02 | 2.8 | 1.5 |
| PC-DARTS | 2.57 ± 0.07 | 3.6 | 0.1 |
| DrNAS (ours) | 2.46 ± 0.03 | 4.1 | 0.6 |

On DARTS Space (ImageNet)

DrNAS can perform a direct search on ImageNet and achieves a top-1 test error below 24.0%!

| Method | Top-1 Error (%) | Params (M) | Search Cost (GPU days) |
| --- | --- | --- | --- |
| DARTS* | 26.7 | 4.7 | 1.0 |
| SNAS* | 27.3 | 4.3 | 1.5 |
| PC-DARTS | 24.2 | 5.3 | 3.8 |
| DSNAS | 25.7 | - | - |
| DrNAS (ours) | 23.7 | 5.7 | 4.6 |

* not a direct search

Usage

Architecture Search

Search on NAS-Bench-201 Space (three datasets to choose from):

  • Data preparation: Please first download the NAS-Bench-201 benchmark file and prepare the API by following this repository (a rough sketch of loading the API appears after this list).

  • cd 201-space && python train_search.py

  • With progressive pruning: cd 201-space && python train_search_progressive.py
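
A rough sketch of the data-preparation step above, using the nas_201_api package; the benchmark file path is a placeholder, and the exact query helpers may differ between nas_201_api versions:

```python
# Minimal sketch: load the NAS-Bench-201 benchmark file and query it.
from nas_201_api import NASBench201API as API

api = API("/path/to/NAS-Bench-201.pth")  # placeholder path; loading is slow
print(len(api))                          # number of architectures in the space

# Query the recorded results of one architecture by its index.
info = api.query_by_index(0)
```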

Search on DARTS Space:

  • Data preparation: For a direct search on ImageNet, we follow PC-DARTS and sample 10% and 2.5% of the images from each class as the training and validation sets, respectively (a rough subsampling sketch appears after this list).

  • CIFAR-10: cd DARTS-space && python train_search.py

  • ImageNet: cd DARTS-space && python train_search_imagenet.py
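
A rough sketch of the per-class subsampling mentioned in the data-preparation step above (the path and seed are placeholders; the repository may implement this differently):

```python
import random
from collections import defaultdict

from torchvision.datasets import ImageFolder

# Placeholder path to the full ImageNet training set.
dataset = ImageFolder("/path/to/imagenet/train")

# Group sample indices by class label.
by_class = defaultdict(list)
for idx, (_, label) in enumerate(dataset.samples):
    by_class[label].append(idx)

# Keep 10% of each class for search-training and 2.5% for search-validation.
train_idx, valid_idx = [], []
rng = random.Random(0)
for indices in by_class.values():
    rng.shuffle(indices)
    n_train = int(0.10 * len(indices))
    n_valid = int(0.025 * len(indices))
    train_idx.extend(indices[:n_train])
    valid_idx.extend(indices[n_train:n_train + n_valid])

# These index lists can drive torch.utils.data.Subset or
# torch.utils.data.SubsetRandomSampler when building the search dataloaders.
```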

Architecture Evaluation

  • CIFAR-10: cd DARTS-space && python train.py --cutout --auxiliary

  • ImageNet: cd DARTS-space && python train_imagenet.py --auxiliary

Reference

If you find this code useful in your research, please cite:

@inproceedings{chen2021drnas,
    title={Dr{NAS}: Dirichlet Neural Architecture Search},
    author={Xiangning Chen and Ruochen Wang and Minhao Cheng and Xiaocheng Tang and Cho-Jui Hsieh},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=9FWas6YbmB3}
}

Comments
  • Raw_id in configure_optimizer

    Hi!

    Thank you for this code! While trying it, I ran into the following error when running configure_optimizer(optimizer_old, optimizer_new):

    state_old = optimizer_old.state_dict()['state'][p.raw_id]
    KeyError: 140435163178672
    

    In general, I do not see how the raw_id can be mapped to a dictionary key.

    Thanks for any pointers/help!

    opened by martinferianc 6
  • setting init_channels=16 in train_search.py of DARTS search space gives an error

    The error appears at:

    File ".../DrNAS/DARTS-space/operations.py", line 189, in __init__
        assert C_out % 2 == 0
    

    So what value of init_channels is expected/allowed? Thanks!

    opened by optyang 5
  • Reproducing nb201 results

    Hi,

    Thank you for the great work!

    I was trying to reproduce the NB201 results. While I was able to reproduce the results for CIFAR-10 and ImageNet16-120, I couldn't do the same for CIFAR-100. I am running the train_search.py file and just changing the dataset argument.

    The results I am getting for 100 epochs using the default hyperparameters provided in the script are:

    • NB201 test accuracy cifar10: 94.36%
    • NB201 test accuracy cifar100: 70.47%
    • NB201 test accuracy imagenet16-120: 46.34%

    Any help would be appreciated!

    opened by gurizab 4
  • Key Error

    After the 25th epoch, I receive a KeyError:

    Traceback (most recent call last):
      File "train_search.py", line 238, in <module>
        main()
      File "train_search.py", line 151, in main
        optimizer = configure_optimizer(optimizer, torch.optim.SGD(
      File "../net2wider.py", line 84, in configure_optimizer
        state_old = optimizer_old.state_dict()['state'][p.raw_id]
    KeyError: 139737672243200

    Any idea how to solve this? Is that a version conflict issue?

    opened by edixiong 4
  • Version issue and replication

    Hello!

    I am trying to replicate the DrNAS results for the NAS-Bench-201 space in Table 4 of your paper. I want to generate the test accuracy for CIFAR-10 (94.36% in Table 4 of the paper).

    1. Are you using progressive training (train_search_progressive.py) to get the results in Table 4? Additionally, you have provided instructions for the evaluation phase in the DARTS space, but what script do you use for the evaluation phase on NAS-Bench-201?
    2. Could you share which versions of pytorch, cudatoolkit, torchvision, tensorboard, etc. you used? It would be great if you could share the environment you used to get the results.

    Thanks a lot for your help!

    opened by sumegha1024 2
  • Error in configure_optimizer

    Hey, I'm trying to run 201-space/train_search_progressive.py, but there seems to be a problem in the configure_optimizer function and it encounters the following error. Can you help me figure this out?

    Traceback (most recent call last):
      File "/content/drive/MyDrive/DrNAS-master/201-space/train_search_progressive.py", line 336, in <module>
        main()
      File "/content/drive/MyDrive/DrNAS-master/201-space/train_search_progressive.py", line 251, in main
        weight_decay=args.weight_decay))
      File "/content/drive/MyDrive/DrNAS-master/net2wider.py", line 84, in configure_optimizer
        state_old = optimizer_old.state_dict()['state'][p.raw_id]
    KeyError: 140076590179920

    Reading this link, I believe the problem might be due to the PyTorch version I'm using (1.10); I don't know what version is required for your code.

    opened by jahdkaran 2
  • where is the regularization of the Dirichlet distribution parameter \beta as stated in the paper

    Hello @xiangning-chen, thanks for your released code. I found that the code does not implement the regularization of the Dirichlet distribution parameter \beta as stated in the paper; am I understanding it wrongly? Also, I could not find the pathwise gradient estimator in the code.

    Please help me with this issue. Thanks!

    opened by d12306 1
  • Mobile Setting for ImageNet

    Thanks for the great work!

    1. Is train_search_imagenet.py under the mobile setting?
    2. Could you explain the details about the mobile setting?
      1. I only found "input image size is 224×224" and "the number of multiply-add operations in the model is restricted to be less than 600M". Is that all?
      2. Is this setting applied during search or during training from scratch, and how is it applied? I could not see any FLOPs constraint in either phase.

    Thank you!

    opened by chenwydj 1
  • Reproducing results

    Hi,

    Thank you for the great work and for open sourcing your code!

    I have tried reproducing the results for NB201 and the DARTS search space.

    On the NB201 search space using cd 201-space && python train_search.py, I can reproduce the results from your paper as follows:

    • cifar10 : 94.360000
    • cifar100 : 73.510000
    • Imagenet16-120 : 46.340000

    However, on the DARTS search space, when searching on CIFAR-10 using cd DARTS-space && python train_search.py, I am not able to obtain the same genotype as the one mentioned in the repo. With the new genotype I get an error of 2.89±0.091.

    Secondly, when performing evaluation on the DARTS search space using cd DARTS-space && python train.py --cutout --auxiliary and the DrNAS_cifar10 genotype from the repo, I obtain an error of 2.67±0.090, which is higher than the 2.46±0.03 reported in the paper.

    Any help in replicating the results on the DARTS search space would be greatly appreciated! Thanks!

    opened by rheasukthanker 0
  • Reproducing the results

    Hi author,

    I am having difficulty reproducing the results on CIFAR-10. The paper claims a test error of 2.46±0.03 with 600 epochs, but when I evaluate with the provided 'DrNAS_cifar10' genotype, I only get an accuracy of 94.86 with 600 epochs and 97.44 with 1200 epochs. The default parameters in the code seem to match those claimed in the paper, so did I miss something here?

    I saw on the PC-DARTS GitHub page that there is a lot of randomness in training on CIFAR-10, so the results are not stable; is that also the case here?

    Thank you for your response.

    opened by edixiong 1