Quantization library for PyTorch. It supports low-precision and mixed-precision quantization, with hardware implementation through TVM.

Overview



HAWQ: Hessian AWare Quantization

HAWQ is an advanced quantization library written for PyTorch. HAWQ enables low-precision and mixed-precision uniform quantization, with direct hardware implementation through TVM.

For more details, please see the HAWQ papers listed under Related Works below.

Installation

  • PyTorch version >= 1.4.0
  • Python version >= 3.6
  • For training new models, you'll also need NVIDIA GPUs and NCCL
  • To install HAWQ and develop locally:
git clone https://github.com/Zhen-Dong/HAWQ.git
cd HAWQ
pip install -r requirements.txt

Getting Started

Quantization-Aware Training

An example of running uniform 8-bit quantization-aware training for ResNet50 on ImageNet:

export CUDA_VISIBLE_DEVICES=0
python quant_train.py -a resnet50 --epochs 1 --lr 0.0001 --batch-size 128 --data /path/to/imagenet/ --pretrained --save-path /path/to/checkpoints/ --act-range-momentum=0.99 --wd 1e-4 --data-percentage 0.0001 --fix-BN --checkpoint-iter -1 --quant-scheme uniform8

The commands for other quantization schemes and for other networks are shown in the model zoo.
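
For intuition, the uniform schemes simulate low-precision arithmetic during training by rounding tensors onto a small integer grid and dequantizing them again. Below is a minimal sketch of symmetric uniform fake quantization; it is illustrative only (the fake_quantize helper is not HAWQ's API, whose quantization modules are more complete, e.g. per-channel scales and momentum-tracked activation ranges):

import torch

def fake_quantize(x, num_bits=8):
    # Map x onto a symmetric integer grid with 2^num_bits levels and back,
    # so downstream layers see the rounding error during training.
    qmax = 2 ** (num_bits - 1) - 1                  # 127 for 8 bits
    scale = x.abs().max().clamp(min=1e-8) / qmax    # per-tensor scale
    x_int = torch.round(x / scale).clamp(-qmax - 1, qmax)
    return x_int * scale                            # dequantized value

w = torch.randn(64, 64)
print((w - fake_quantize(w)).abs().max())           # error is at most scale / 2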

Inference Acceleration

The TVM benchmark code and deployment instructions are in the tvm_benchmark directory; as noted there, the inference speed-up procedure targets the NVIDIA T4 GPU.

Experimental Results

The tables below reproduce Table I and Table II of HAWQ-V3: Dyadic Neural Network Quantization.

ResNet18 on ImageNet

Model    | Quantization    | Model Size (MB) | BOPS (G) | Accuracy (%) | Inference Speed (batch=8, ms) | Download
ResNet18 | Floating Points | 44.6            | 1858     | 71.47        | 9.7 (1.0x)                    | resnet18_baseline
ResNet18 | W8A8            | 11.1            | 116      | 71.56        | 3.3 (3.0x)                    | resnet18_uniform8
ResNet18 | Mixed Precision | 6.7             | 72       | 70.22        | 2.7 (3.6x)                    | resnet18_bops0.5
ResNet18 | W4A4            | 5.8             | 34       | 68.45        | 2.2 (4.4x)                    | resnet18_uniform4

ResNet50 on ImageNet

Model    | Quantization    | Model Size (MB) | BOPS (G) | Accuracy (%) | Inference Speed (batch=8, ms) | Download
ResNet50 | Floating Points | 97.8            | 3951     | 77.72        | 26.2 (1.0x)                   | resnet50_baseline
ResNet50 | W8A8            | 24.5            | 247      | 77.58        | 8.5 (3.1x)                    | resnet50_uniform8
ResNet50 | Mixed Precision | 18.7            | 154      | 75.39        | 6.9 (3.8x)                    | resnet50_bops0.5
ResNet50 | W4A4            | 13.1            | 67       | 74.24        | 5.8 (4.5x)                    | resnet50_uniform4
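
As a rough sanity check on the BOPS column (assuming, as in HAWQ-V3, that a multiply-accumulate at weight/activation bit-widths b_w and b_a costs about b_w * b_a bit operations), W8A8 should cut floating-point BOPS by roughly (32*32)/(8*8) = 16x, which matches both tables:

print(1858 / 116)   # ResNet18: FP32 BOPS / W8A8 BOPS ~= 16.0
print(3951 / 247)   # ResNet50: FP32 BOPS / W8A8 BOPS ~= 16.0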

More results for different quantization schemes and different models (along with the corresponding commands and important notes) are available in the model zoo.
To download the quantized models through wget, please refer to the simple command in the model zoo.
Checkpoints in the model zoo are saved in floating-point precision. To shrink the memory size, BitPack can be applied to the weight_integer tensors, or directly to the quantized_checkpoint.pth.tar file.
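
BitPack defines its own storage format; purely to illustrate the idea, here is a minimal sketch (with illustrative names, not BitPack's API) of packing signed 4-bit weight integers two per byte, which is where the saving comes from: half a byte per weight instead of four bytes for float32, an 8x reduction.

import numpy as np

def pack_int4(values):
    # Pack pairs of signed 4-bit integers (range [-8, 7]) into uint8 bytes.
    assert len(values) % 2 == 0
    u = (np.asarray(values, dtype=np.int8) & 0x0F).astype(np.uint8)
    return (u[0::2] << 4) | u[1::2]

def unpack_int4(packed):
    # Recover the signed values: the arithmetic right shift sign-extends
    # each nibble back to int8.
    hi = packed.astype(np.int8) >> 4
    lo = (packed << 4).astype(np.int8) >> 4
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2], out[1::2] = hi, lo
    return out

w = np.array([-8, 7, 3, -1], dtype=np.int8)
assert (unpack_int4(pack_int4(w)) == w).all()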

Related Works

  • HAWQ-V3: Dyadic Neural Network Quantization (ICML 2021)
  • HAWQ-V2: Hessian Aware trace-Weighted Quantization of Neural Networks (NeurIPS 2020)
  • HAWQ: Hessian AWare Quantization of Neural Networks with Mixed-Precision (ICCV 2019)

License

HAWQ is released under the MIT license.

Comments
  • Similar running time with INT8 and INT4

    Hi, I ran resnet50 with uniform8 and uniform4, but they have similar running times.

    I ran INT8 and INT4 as follows:

    #!/bin/bash
    
    run_inference() {
            bit_config=$1
            num_layers=$2
    
            printf "%s\n" "$bit_config"
    
            # First run: generate the CUDA code for this bit config.
            python test_resnet_inference_time.py --bit-config $bit_config --num-layers $num_layers
    
            cp ./debug_output/resnet_generated.cu ./debug_output/resnet_manual.cu
    
            # Patch the generated kernels: shrink the inner loop bounds from 8 to 1.
            sed -i 's/h_w_fused_n_fused_i_fused_nn_fused_ii_fused_inner < 8;/h_w_fused_n_fused_i_fused_nn_fused_ii_fused_inner < 1;/g' ./debug_output/resnet_manual.cu
            sed -i 's/ax0_ax1_fused_ax2_fused_ax3_fused_inner < 8;/ax0_ax1_fused_ax2_fused_ax3_fused_inner < 1;/g' ./debug_output/resnet_manual.cu
    
            sleep 5
            # Second run: time inference using the patched (manual) CUDA code.
            python test_resnet_inference_time.py --bit-config $bit_config --num-layers $num_layers --manual-code
    }
    
    run_inference "bit_config_resnet50_uniform4"   50
    run_inference "bit_config_resnet50_uniform8"   50
    

    However, the running times are similar in manual mode:

    Performed inference in 17.05ms (std = 0.15) for 8 samples
    Average per sample inference time: 2.13ms
    

    and

    Performed inference in 20.49ms (std = 0.27) for 8 samples
    Average per sample inference time: 2.56ms
    
    opened by haibao-yu 11
  • Issue on dict_keys

    I found an issue trying to run your model on TVM.

    When I tried to run python hawq_utils_resnet50.py --model-dir ./data/resnet18_uniform4/ (assuming I want to run the resnet18, uniform4-based trained model), this error appears:

    Traceback (most recent call last):
      File "/home/kjk2020/tvm-newHAWQ/tvm_benchmark/hawq_utils_resnet50.py", line 483, in <module>
        weight_integer = model['weight_integer']
    KeyError: 'weight_integer'

    and when I print out the keys of the model, I get dict_keys(['epoch', 'arch', 'state_dict', 'best_acc1', 'optimizer']), which does not include weight_integer.

    Also, I just wanted to run step '6. Measure inference time (with uniform int4/int8 or custom mixed-precision bit configs in bit_config.py)', but I get errors because some information (the "all_impls" variable in "/--/tvm/python/tvm/relay/backend/compile_engine.py:150") is empty, and following those error traces (the earlier empty parts) leads to external API functions. Here's the error:

    Traceback (most recent call last):
      File "/home/kjk2020/tvm-newHAWQ/tvm_benchmark/test_resnet_inference_time.py", line 235, in <module>
        graph, lib, params = relay.build(func, target=TARGET_NAME, params=params)
      File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/build_module.py", line 251, in build
        graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
      File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/build_module.py", line 120, in build
        self._build(mod, target, target_host)
      File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 219, in __call__
        raise get_last_ffi_error()

    KeyError: 'Traceback (most recent call last):\n [bt] (8) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)+0x8e) [0x7f63bba4269e]\n [bt] (7) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x91) [0x7f63bba47651]\n [bt] (6) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>)+0x27) [0x7f63bba44627]\n [bt] (5) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)+0x43) [0x7f63bb8e9d73]\n [bt] (4) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)+0x745) [0x7f63bb8ed215]\n [bt] (3) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call<tvm::RelayExpr, 3, tvm::RelayExpr ()(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr ( const&)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)+0x210) [0x7f63bb875bf0]\n [bt] (2) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::RelayExpr tvm::relay::LayoutRewritertvm::relay::alter_op_layout::AlterTransformMemorizer(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)+0xa45) [0x7f63bb8736f5]\n [bt] (1) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::vector<tvm::RelayExpr, std::allocatortvm::RelayExpr > const&)+0x773) [0x7f63bb8711f3]\n [bt] (0) /home/kjk2020/tvm-newHAWQ/tvm/build/libtvm.so(+0x13434fb) [0x7f63bbb174fb]\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun\n rv = local_pyfunc(*pyargs)\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/op/nn/_nn.py", line 98, in alter_op_layout_conv2d\n return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)\n File "", line 2, in conv2d_alter_layout\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/target/generic_func.py", line 267, in dispatch_func\n return dispatch_dict[k](*args, **kwargs)\n File "/home/kjk2020/tvm-newHAWQ/tvm/topi/python/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout\n relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)\n File "/home/kjk2020/tvm-newHAWQ/tvm/python/tvm/relay/backend/compile_engine.py", line 229, in select_implementation\n return best_plevel_impl, outputs[best_plevel_impl]\nKeyError: None'

    opened by wowow11111 4
  • ModelZoo file formats

    When you unpack the model files downloaded from the model zoo, some of them have different file formats. The instructions read as if they should all contain a checkpoint.pth.tar file, but, for example, the resnet18_baseline.tar.gz, resnet18_uniform8, and resnet50_baseline archives contain just a resnet.pth file.

    What do I do with these file formats? Is it okay to just change the format?

    opened by wowow11111 4
  • No module named 'bit_config'

    Hi,

    I tried to run quant_train.py but I get the error below. How can I solve it?

    Traceback (most recent call last):
      File "quant_train.py", line 22, in <module>
        from bit_config import *
    ModuleNotFoundError: No module named 'bit_config'

    Thx, Lei

    opened by leiwen83 3
  • Questions with HAWQ/tvm_benchmark

    Hi, I'm trying to run the test in "tvm_benchmark/test_resnet_inference.py" on a Tesla V100 and compare the result with a Tesla T4. However, I encountered some errors when building with tvm.relay [relay.build(..)]. I know this is a natural consequence, as the README states that the procedure targets the NVIDIA T4 GPU for inference speed-up. But my questions are:

    • Which part of the code creates the GPU device dependency? I guess it is due to the int4 configuration in the code and the use of a specific int4 branch of TVM. Am I right about this?
    • Even so, I still have questions about the errors I got, because the error was an nvcc compile error on the temporary .cu file with a type error. Does nvcc have this much GPU dependency?

    [screenshot: nvcc compile error]

    • Additionally, it worked fine on my T4 GPU server in the same environment; only the device differs.

    I would appreciate your reply; any pointers would be helpful.

    opened by adwwsd 2
  • Issue about default HAWQ

    Hi, I've been working on running HAWQ on my machine, and I can now run 'test_resnet_inference_time.py' completely. So I'm now taking a model from the model zoo and running it on GPU following the instructions in your repo. (Ultimately, I want to run HAWQ on VTA, the TVM-based NPU.)

    I re-downloaded the baseline, followed the steps you gave, and now have a few questions. First of all, except for 'resnet18_uniform8', the models downloadable from the model zoo do not contain a 'quantized_checkpoint.pth.tar' file but only a 'checkpoint.pth.tar' file, which leads to a [No such file or directory] error. Also, 'hawq_utils_resnet50.py' is hard-coded for resnet50.

    So, what is the difference between checkpoint and quantized_checkpoint? Is it okay to just change quantized_checkpoint to checkpoint in 'hawq_utils_resnet50.py'?

    If I do, the earlier error (the dict_keys error) occurs. How do I convert the parameters, as in '3. change PyTorch parameters to TVM format', for the models that only contain checkpoint.pth.tar?

    opened by wowow11111 2
  • Scale Parameter with Gradient

    Hi, I want to combine your HAWQ-V3 with QNN methods that learn the scale parameters through custom gradients, like PACT, QIL, and LSQ.

    I wonder why you didn't try learning the scale parameters with gradients.

    Is there a problem with training, or something else?

    I would appreciate your reply.

    opened by thuako 2
  • Cannot create Compute Engine Instance in Google Cloud

    Is there anyone else who cannot create a Compute Engine instance in Google Cloud with Tesla GPUs, getting "The zone 'projects/graceful-castle-301212/zones/us-central1-a' does not have enough resources available to fulfill the request. Try a different zone, or try again later."? I have been blocked here for several days.

    opened by haibao-yu 1
  • Add qonnx support

    This adds support for exporting quantized models to the QONNX format. All layers are supported in this format. While not many model architectures were tested, an MLP and a CNN were heavily exercised with varying settings, such as symmetric and asymmetric quantization, bit-width, etc. A README.md was added to showcase the simplicity of exporting models and executing the onnx graph.

    An additional directory labeled export has been added to utils. All other files and directories are left unchanged. After exporting models, we perform a post-export optimization step that requires additional Python packages: onnx, qonnx, and onnxoptimizer.

    This PR does not include support for BinaryQuant nodes. This has been left for a future PR, along with further onnx optimizations.
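
    The post-export optimization step depends on the three packages named above; they can be installed in one line:

    pip install onnx qonnx onnxoptimizer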

    opened by jicampos 0
  • Question About bit_config.py

    Great work! I'm wondering why the re-quant op before stage1 is always configured to 16 bits, as 'quant_act_int32': 16. From my perspective, configuring it the same as 'stage1.unit1.quant_act' would seem to make no difference.

    opened by rainyBJ 0
  • I have an Issue in test_resnet_inference.py

    I tried to run test_resnet_inference.py, but I get a TypeError: 'IntImm'. How can I solve it?

    (qt) dmsl3@dmsl3:~/jh/HAWQ/tvm_benchmark$ python test_resnet_inference.py --model-dir ./fix_y/
    File synset.txt exists, skip.
    Traceback (most recent call last):
      File "/home/dmsl3/jh/HAWQ/tvm_benchmark/test_resnet_inference.py", line 127, in <module>
        graph, lib, params = relay.build(func, target=TARGET_NAME, params=params)
      File "/home/dmsl3/tvm/python/tvm/relay/build_module.py", line 251, in build
        graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
      File "/home/dmsl3/tvm/python/tvm/relay/build_module.py", line 114, in build
        target = _update_target(target)
      File "/home/dmsl3/tvm/python/tvm/relay/build_module.py", line 47, in _update_target
        dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type)
      File "/home/dmsl3/tvm/python/tvm/runtime/ndarray.py", line 240, in context
        return TVMContext(dev_type, dev_id)
      File "/home/dmsl3/tvm/python/tvm/_ffi/runtime_ctypes.py", line 175, in __init__
        self.device_type = device_type
    TypeError: 'IntImm' object cannot be interpreted as an integer

    opened by AlwaysHoon 0
  • Shift operation in TVM

    Thank you for your perfect work!

    I am wondering how the shift operation in PyTorch corresponds to TVM, since I can't find the relevant operation in the TVM code.

    Thank you very much!
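
    For context: HAWQ-V3 approximates each floating-point rescaling factor by a dyadic number m / 2^c, so on integer-only hardware requantization reduces to an integer multiply followed by a right shift. A minimal sketch of the idea in numpy (names are illustrative, not HAWQ's or TVM's API):

    import numpy as np

    def dyadic_requantize(acc_int32, scale, shift_bits=16):
        # Approximate `scale` by the dyadic number m / 2^shift_bits, then
        # rescale using only an integer multiply and an arithmetic right shift.
        m = int(round(scale * (1 << shift_bits)))
        return (acc_int32.astype(np.int64) * m) >> shift_bits

    acc = np.array([12345, -6789], dtype=np.int32)
    print(dyadic_requantize(acc, scale=0.0123))  # ~= acc * 0.0123 -> [151, -84]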

    opened by zkkli 0
  • Can't load provided checkpoints

    I downloaded the baseline and quantized .pth files for resnet18 and resnet50, but when I try to load them I get the following errors:

    python quant_train.py -a resnet50 --epochs 1 --lr 0.0001 --batch-size 128 --data data/imagenet/ --pretrained --save-path ./checkpoints/ --act-range-momentum=0.99 --wd 1e-4 --data-percentage 0.0001 --fix-BN --checkpoint-iter -1 --quant-scheme uniform8 --resume ./HAWQ/loaded_models/resnet50_baseline/resnet50.pth

    Traceback (most recent call last):
      File "quant_train.py", line 766, in <module>
        main()
      File "quant_train.py", line 205, in main
        main_worker(args.gpu, ngpus_per_node, args)
      File "quant_train.py", line 242, in main_worker
        checkpoint = torch.load(args.resume)['state_dict']
    KeyError: 'state_dict'

    python quant_train.py -a resnet18 --epochs 1 --lr 0.0001 --batch-size 128 --data data/imagenet/ --pretrained --save-path ./checkpoints/ --act-range-momentum=0.99 --wd 1e-4 --data-percentage 0.0001 --fix-BN --checkpoint-iter -1 --quant-scheme uniform8 --resume "/workspace/LyginE/projects/paradigma/quantization/HAWQ/loaded_models/resnet18_uniform8/quantized_checkpoint.pth.tar" --resume-quant

    Traceback (most recent call last):
      File "quant_train.py", line 766, in <module>
        main()
      File "quant_train.py", line 205, in main
        main_worker(args.gpu, ngpus_per_node, args)
      File "quant_train.py", line 307, in main_worker
        checkpoint = torch.load(args.resume)['state_dict']
    KeyError: 'state_dict'

    opened by L-ED 0
  • About quantization scheme

    Hello, and thanks for your work. I am a newcomer to quantization, and I am confused about the quantization schemes. There seem to be many configurations: uniform, bops, model size, latency, etc. Could you please explain the differences between these models and how to train them? https://github.com/Zhen-Dong/HAWQ/blob/main/model_zoo.md

    opened by NoLookDefense 0