The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"

Overview

Hierarchical Token Semantic Audio Transformer

Introduction

The code repository for "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection", presented at ICASSP 2022.

In this paper, we devise a model, HTS-AT, by combining a Swin Transformer with a token-semantic module and adapting it to audio classification and sound event detection tasks. HTS-AT is an efficient and lightweight audio transformer with a hierarchical structure and only 30 million parameters. It achieves new state-of-the-art (SOTA) results on AudioSet and ESC-50, and equals the SOTA on Speech Command V2. It also achieves better event localization performance than previous CNN-based models.

HTS-AT Architecture

Classification Results on AudioSet, ESC-50, and Speech Command V2 (mAP)

HTS-AT Classification Result

Localization/Detection Results on DESED dataset (F1-Score)

HTS-AT Localization Result

Getting Started

Install Requirements

pip install -r requirements.txt

Download and Processing Datasets

  • In config.py:
    change the variable "dataset_path" to your AudioSet path
    change the variable "desed_folder" to your DESED path
    change classes_num to 527
  • ./create_index.sh
    // remember to change the paths in the script
    // more information about this script: https://github.com/qiuqiangkong/audioset_tagging_cnn
  • python main.py save_idc
    // count the number of samples in each class and save the npy files
  • Open the jupyter notebook at esc-50/prep_esc50.ipynb and follow it to process ESC-50
  • Open the jupyter notebook at scv2/prep_scv2.ipynb and follow it to process Speech Command V2
  • python conver_desed.py
    // produce the npy data files for DESED

Set the Configuration File: config.py

The script config.py contains all the configurations you need to set before running the code. Please read the introductory comments in the file and adjust the settings. Most importantly, if you want to train/test your model on AudioSet, you need to set:

dataset_path = "your processed audioset folder"
dataset_type = "audioset"
balanced_data = True
loss_type = "clip_bce"
sample_rate = 32000
hop_size = 320 
classes_num = 527

If you want to train/test your model on ESC-50, you need to set:

dataset_path = "your processed ESC-50 folder"
dataset_type = "esc-50"
loss_type = "clip_ce"
sample_rate = 32000
hop_size = 320 
classes_num = 50

If you want to train/test your model on Speech Command V2, you need to set:

dataset_path = "your processed SCV2 folder"
dataset_type = "scv2"
loss_type = "clip_bce"
sample_rate = 16000
hop_size = 160
classes_num = 35
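
As a rough sanity check on these numbers, the spectrogram length per clip follows directly from sample_rate and hop_size. A minimal sketch, assuming 10-second clips for AudioSet/ESC-50 and 1-second clips for Speech Command V2 (the 256 x 4 = 1024-frame target comes up again in the issues below):

# frames per clip ~= clip_seconds * sample_rate / hop_size
def frames_per_clip(clip_seconds, sample_rate, hop_size):
    return int(clip_seconds * sample_rate // hop_size)

print(frames_per_clip(10, 32000, 320))  # AudioSet / ESC-50: 1000 frames
print(frames_per_clip(1, 16000, 160))   # Speech Command V2: 100 frames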

If you want to test your model on DESED, you need to set:

resume_checkpoint = "Your checkpoint on AudioSet"
heatmap_dir = "localization results output folder"
test_file = "output heatmap name"
fl_local = True
fl_dataset = "Your DESED npy file"

Training and Evaluation

Notice: our model runs in DDP mode and requires at least two GPUs. If you want to use a single GPU for training and evaluation, you need to manually change sed_model.py and main.py (see the sketch below).
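
For reference, a minimal sketch of how the all-gather step could be guarded so it degrades gracefully on one GPU. This is an assumption based on the tracebacks reported in the issues below, not the repository's exact code:

import torch
import torch.distributed as dist

def gather_preds(pred):
    # In DDP mode, collect predictions from every rank;
    # on a single GPU, dist is never initialized, so return the local tensor.
    if dist.is_available() and dist.is_initialized():
        gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
        dist.all_gather(gather_pred, pred)
        return torch.cat(gather_pred, dim=0)
    return pred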

All scripts are run via main.py:

Train: CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py train

Test: CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py test

Ensemble Test: CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py esm_test 
// See config.py for settings of ensemble testing

Weight Average: python main.py weight_average
// See config.py for settings of weight averaging
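
Weight averaging here presumably follows the usual recipe of averaging the parameter tensors of several saved checkpoints; a minimal sketch under that assumption (the checkpoint paths are hypothetical):

import torch

def average_checkpoints(ckpt_paths):
    # Sum the parameters of each Lightning checkpoint, then divide.
    avg = None
    for path in ckpt_paths:
        state = torch.load(path, map_location="cpu")["state_dict"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(ckpt_paths) for k, v in avg.items()}

# hypothetical usage:
# state_dict = average_checkpoints(["epoch_10.ckpt", "epoch_11.ckpt"])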

Localization on DESED

CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py test
// make sure that fl_local=True in config.py
python fl_evaluate.py
// organize and gather the localization results
fl_evaluate_f1.ipynb
// Follow the notebook to produce the results

Model Checkpoints

We provide model checkpoints for the three datasets (and additionally the DESED dataset) in this link. Feel free to download and test them.

Citing

@inproceedings{htsat-ke2022,
  author = {Ke Chen and Xingjian Du and Bilei Zhu and Zejun Ma and Taylor Berg-Kirkpatrick and Shlomo Dubnov},
  title = {HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection},
  booktitle = {{ICASSP} 2022}
}

Our work is based on the Swin Transformer, a well-known transformer model for image classification.

Issues
  • The mAP following the AudioSet recipe is very low

    Hi, I downloaded the model checkpoint files you provided through Google Drive and followed the README AudioSet evaluation step (Test: CUDA_VISIBLE_DEVICES=1,2,3,4 python main.py test). I expected performance similar to your paper, but I got a very low mAP. The eval set I used has 18,887 samples. I would like to know your eval set size if possible. Attached are the single-model evaluation (HTSAT_AudioSet_Saved_1.ckpt) results. Thanks.

    opened by kimsojeong1225 11
  • Hello, I have the following problem when running the code. What is the reason?

    RuntimeError: upsample_bicubic2d_backward_out_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation

    opened by ykingliu 6
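
    As the error message itself suggests, one workaround (not necessarily the maintainers' fix) is to relax the determinism flag wherever torch.use_deterministic_algorithms is set in the training script:

    import torch

    # Warn instead of raising for ops (such as upsample_bicubic2d's backward)
    # that lack a deterministic CUDA implementation. Requires PyTorch >= 1.11.
    torch.use_deterministic_algorithms(True, warn_only=True)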
  • Bad results on ESC-50

    I have one GPU, so I changed some code in model.py, sed_model.py, and main.py,

    and set config.py like this:

    dataset_type = "esc-50"
    loss_type = "clip_ce"
    sample_rate = 32000
    classes_num = 50

    Then I only get ACC: 0.55.

    To work around an init_process_group error, I changed:

    deterministic=False
    dist.init_process_group(
        backend="nccl", init_method="tcp://localhost:23456", rank=0, world_size=1
    )

    The code runs, but I don't get the same results as in your paper.

    opened by wac81 6
  • Masking and differences in mix-up strategy

    Hi Ke,

    Thanks for the great work and for open-sourcing the code!

    I'd like to build on your excellent codebase, and I have a few questions regarding the details:

    1. I couldn't find any information about a padding mask. Is it not used in the model?
    2. The mix-up procedure seems to be a bit different from AST:
       2.1 In AST, they mix up the raw waveform before applying transforms, while in HTS-AT, you compute fbanks first and then mix up the fbanks.
       2.2 In AST, the mix-up waveform is randomly sampled from the entire dataset, while you sample within the current batch.
       2.3 In AST, they also mix up labels using lambda * label1 + (1-lambda) * label2, while HTS-AT does not mix labels.
       Not sure if these three differences make a big difference in performance, but I'm curious about your thoughts.

    Thanks, Puyuan

    opened by jasonppy 4
  • Different-length audio input in infer mode

    Hi, thanks for the interesting work!

    I have a question about the infer mode in htsat.py. During training, the audio input is always 10 seconds long. At inference time, the model needs to handle variable-length audio input, which could be longer or shorter than 10 seconds, but for the infer mode in htsat.py:

    if infer_mode:
        # in infer mode, we need to handle different-length audio input
        frame_num = x.shape[2]
        target_T = int(self.spec_size * self.freq_ratio)
        repeat_ratio = math.floor(target_T / frame_num)
        x = x.repeat(repeats=(1, 1, repeat_ratio, 1))
        x = self.reshape_wav2img(x)

    What if the input length frame_num > target_T (which should be 256 x 4 = 1024 here)? In that case, repeat_ratio will be 0. So how does the model process audio longer than 10 seconds at inference?

    opened by CaptainPrice12 3
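
    One common way to handle clips longer than target_T (an assumption, not necessarily the authors' intended behavior) is sliding-window inference with averaged window predictions; a rough sketch reusing the names from the snippet above:

    import torch

    def infer_long_audio(model, x, target_T=1024):
        # x: (batch, 1, frames, mel_bins); any remainder shorter than
        # target_T at the end of the clip is dropped in this sketch.
        frame_num = x.shape[2]
        if frame_num <= target_T:
            return model(x)
        windows = [x[:, :, t:t + target_T, :]
                   for t in range(0, frame_num - target_T + 1, target_T)]
        return torch.stack([model(w) for w in windows], dim=0).mean(dim=0)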
  • Learning rate

    Hi, I think the lr in the code differs from the paper. The paper says the learning rate is 0.05, 0.1, 0.2 in the first three epochs, but the code's lr is 2e-5, 5e-5, 1e-4 (config.py lr_rate=[0.02,0.05,0.1]). I set lr_rate=[50,100,200] to match the paper, but training then gives bad results. I would like to know the right way to reproduce the results indicated in the paper.

    opened by kimsojeong1225 3
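
    Reading the numbers quoted above, lr_rate looks like a list of multiplicative warm-up factors applied to a base learning rate (a base of 1e-3 would reproduce the 2e-5 / 5e-5 / 1e-4 sequence); a sketch of that interpretation, with the base value assumed:

    base_lr = 1e-3                # assumed base value
    lr_rate = [0.02, 0.05, 0.1]   # from config.py, as quoted above

    for epoch, scale in enumerate(lr_rate):
        print(f"epoch {epoch}: lr = {base_lr * scale:.0e}")
    # epoch 0: lr = 2e-05; epoch 1: lr = 5e-05; epoch 2: lr = 1e-04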
  • Error in training

    Hi RetroCirce,

    I get the following error when running the ESC-50 training script. How can I solve it?

    RuntimeError: upsample_bicubic2d_backward_out_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation. Epoch 0: 0%| | 0/65 [00:04<?, ?it/s]

    opened by joewale 2
  • Training and evaluation on a single GPU

    Hi RetroCirce, I want to use a single GPU for training and evaluation. How should I manually change the code in sed_model.py and main.py?

    Looking forward to your reply. Thanks.
    
    opened by joewale 2
  • Stopped during AudioSet training

    Hi, thanks for your good study. I tried training on AudioSet using your code. After 10% of epoch 2, training stops making progress (it always happens at this epoch). It doesn't shut down; it just stops working. Can I know why this happens? I attached the log monitor.

    opened by kimsojeong1225 2
  • Pre-trained .ckpt file

    Hello, I saw that in the data-preparation part of htsat_esc_training.ipynb, after downloading the ESC-50 files it asks you to download a file named htsat_audioset_pretrain.ckpt, but I tried and could not download it. The corresponding ID is 1OK8a5XuMVLyeVKF117L8pfxeZYdfSDZv. I would like to know how to download this file without using gdown.

    opened by dong-0412 1
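
    For what it's worth, a Google Drive file can often be fetched without gdown through the uc endpoint; a rough sketch (large files may additionally require Drive's confirmation-token step, which this does not handle):

    import urllib.request

    # the ID below is the one quoted in this issue
    file_id = "1OK8a5XuMVLyeVKF117L8pfxeZYdfSDZv"
    url = f"https://drive.google.com/uc?export=download&id={file_id}"
    urllib.request.urlretrieve(url, "htsat_audioset_pretrain.ckpt")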
  • bug

    Global seed set to 970131
    each batch size: 8
    INFO:root:total dataset size: 400
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    Testing: 98%|█████████▊| 49/50 [00:03<00:00, 22.58it/s]
    Traceback (most recent call last):
      File "E:/HTS/HTS/main.py", line 429, in <module>
        main()
      File "E:/HTS/HTS/main.py", line 417, in main
        test()
      File "E:/HTS/HTS/main.py", line 247, in test
        trainer.test(model, datamodule=audioset_data)
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 911, in test
        return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 685, in _call_and_handle_interrupt
        return trainer_fn(*args, **kwargs)
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 954, in _test_impl
        results = self._run(model, ckpt_path=self.tested_ckpt_path)
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1199, in _run
        self._dispatch()
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1275, in _dispatch
        self.training_type_plugin.start_evaluating(self)
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 206, in start_evaluating
        self._results = trainer.run_stage()
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1286, in run_stage
        return self._run_evaluate()
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1334, in _run_evaluate
        eval_loop_results = self._evaluation_loop.run()
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\loops\base.py", line 151, in run
        output = self.on_run_end()
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py", line 131, in on_run_end
        self._evaluation_epoch_end(outputs)
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py", line 231, in _evaluation_epoch_end
        model.test_epoch_end(outputs)
      File "E:\HTS\HTS\sed_model.py", line 201, in test_epoch_end
        gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\torch\distributed\distributed_c10d.py", line 748, in get_world_size
        return _get_group_size(group)
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\torch\distributed\distributed_c10d.py", line 274, in _get_group_size
        default_pg = _get_default_group()
      File "C:\Users\Administrator\anaconda3\envs\torchtf\lib\site-packages\torch\distributed\distributed_c10d.py", line 358, in _get_default_group
        raise RuntimeError("Default process group has not been initialized, "
    RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
    Process finished with exit code 1

    I ran the test process on the ESC-50 dataset with a single GPU and got the result above. After I commented out the code around gather_pred, gather_target, and the other dist.xxx calls, it works.

    I want to know whether I can run the code on a single GPU with the changes above.

    opened by dong-0412 1
  • Learning rate for small datasets

    Hi, thank you for your great work.

    Here, you mentioned that the warm-up part should be adjusted according to the dataset. Could you give some advice for small datasets? For example, I have approximately 15K samples; how should I set lr_scheduler_epoch and lr_rate?

    I compared the AudioSet config and the ESC config, but the mentioned parameters are the same.

    opened by EmreOzkose 0