The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"


Hierarchical Token Semantic Audio Transformer


The Code Repository for "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection", in ICASSP 2022.

In this paper, we devise a model, HTS-AT, by combining a Swin Transformer with a token-semantic module and adapt it to audio classification and sound event detection tasks. HTS-AT is an efficient and lightweight audio transformer with a hierarchical structure and only 30 million parameters. It achieves new state-of-the-art (SOTA) results on AudioSet and ESC-50, and equals the SOTA on Speech Command V2. It also achieves better event localization performance than previous CNN-based models.

HTS-AT Architecture

Classification Results on AudioSet (mAP), ESC-50 and Speech Command V2 (accuracy)


Localization/Detection Results on the DESED dataset (F1-Score)


Getting Started

Install Requirements

pip install -r requirements.txt

Download and Processing Datasets

Change the variable "dataset_path" to your AudioSet path
Change the variable "desed_folder" to your DESED path
Change classes_num to 527
./ # 
// remember to change the paths in the script
// more information about this script is in

python save_idc 
// count the number of samples in each class and save the npy files
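The `save_idc` step above is described only as counting the samples in each class and saving `.npy` files. A minimal sketch of that idea, assuming the labels are available as a multi-hot matrix of shape `(num_samples, classes_num)`; the helpers `build_class_index` and `save_class_index` and the output filename are hypothetical and are not the repository's implementation:

```python
import os
import tempfile

import numpy as np

def build_class_index(labels):
    """Hypothetical helper: per-class sample indices and counts from a multi-hot matrix.

    labels: array of shape (num_samples, classes_num) with 0/1 entries.
    """
    idc = [np.flatnonzero(labels[:, c]) for c in range(labels.shape[1])]
    counts = np.array([len(ix) for ix in idc], dtype=np.int64)
    return idc, counts

def save_class_index(idc, path):
    """Hypothetical helper: store the ragged per-class index lists as a 1-D object array."""
    packed = np.empty(len(idc), dtype=object)
    for c, ix in enumerate(idc):
        packed[c] = ix  # assign one by one so NumPy does not build a 2-D array
    np.save(path, packed, allow_pickle=True)

# Toy labels: 4 samples, 3 classes (AudioSet would use classes_num = 527).
labels = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0],
                   [0, 0, 1]])
idc, counts = build_class_index(labels)
print(counts)  # [2 2 2]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "class_idc.npy")  # hypothetical filename
    save_class_index(idc, path)
    reloaded = np.load(path, allow_pickle=True)
    print(reloaded[0])  # [0 2]
```

Storing indices rather than only counts makes class-balanced sampling (as with `balanced_data = True` below) a matter of picking a class and then an index from its list.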
Open the jupyter notebook at esc-50/prep_esc50.ipynb and process it
Open the jupyter notebook at scv2/prep_scv2.ipynb and process it
// will produce the npy data files

Set the Configuration File:

The script contains all the configuration settings needed to run the code. Please read the introductory comments in the file and adjust the settings accordingly. Most importantly, if you want to train/test your model on AudioSet, you need to set:

dataset_path = "your processed audioset folder"
dataset_type = "audioset"
balanced_data = True
loss_type = "clip_bce"
sample_rate = 32000
hop_size = 320 
classes_num = 527

If you want to train/test your model on ESC-50, you need to set:

dataset_path = "your processed ESC-50 folder"
dataset_type = "esc-50"
loss_type = "clip_ce"
sample_rate = 32000
hop_size = 320 
classes_num = 50

If you want to train/test your model on Speech Command V2, you need to set:

dataset_path = "your processed SCV2 folder"
dataset_type = "scv2"
loss_type = "clip_bce"
sample_rate = 16000
hop_size = 160
classes_num = 35
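The three classification settings above share one property: 32000 / 320 and 16000 / 160 both give 100 spectrogram frames per second (a 10 ms hop). A small sketch that gathers the listed values into one place and checks this; the `PRESETS` dict and the `frames_per_second` helper are illustrative only and are not part of the repository's configuration file:

```python
# Per-dataset settings copied from the lists above (illustrative container only).
PRESETS = {
    "audioset": {"dataset_type": "audioset", "balanced_data": True, "loss_type": "clip_bce",
                 "sample_rate": 32000, "hop_size": 320, "classes_num": 527},
    "esc-50": {"dataset_type": "esc-50", "loss_type": "clip_ce",
               "sample_rate": 32000, "hop_size": 320, "classes_num": 50},
    "scv2": {"dataset_type": "scv2", "loss_type": "clip_bce",
             "sample_rate": 16000, "hop_size": 160, "classes_num": 35},
}

def frames_per_second(preset):
    """Spectrogram frames per second of audio for one preset (hypothetical helper)."""
    return preset["sample_rate"] / preset["hop_size"]

for name, preset in PRESETS.items():
    print(f"{name}: {frames_per_second(preset):.0f} frames/s, {preset['classes_num']} classes")
```

If you change `sample_rate` for a new dataset, scaling `hop_size` by the same factor keeps one second of audio at 100 frames in this sketch.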

If you want to test your model on DESED, you need to set:

resume_checkpoint = "Your checkpoint on AudioSet"
heatmap_dir = "localization results output folder"
test_file = "output heatmap name"
fl_local = True
fl_dataset = "Your DESED npy file"

Train and Evaluation

Notice: Our model runs in DDP mode and requires at least two GPUs. If you want to use a single GPU for training and evaluation, you need to manually change and

All scripts are run by

Train: CUDA_VISIBLE_DEVICES=1,2,3,4 python train

Test: CUDA_VISIBLE_DEVICES=1,2,3,4 python test

Ensemble Test: CUDA_VISIBLE_DEVICES=1,2,3,4 python esm_test 
// See for settings of ensemble testing

Weight Average: python weight_average
// See for settings of weight averaging
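Ensemble testing and weight averaging are different operations: an ensemble averages the predictions of several models, while weight averaging averages the parameters of several checkpoints into a single model. A NumPy sketch of both ideas, with checkpoints represented as plain dicts of arrays (a stand-in for PyTorch state dicts); the function names are hypothetical and this is not the repository's `esm_test` or `weight_average` code:

```python
import numpy as np

def ensemble_predict(prob_list):
    """Hypothetical helper: average class probabilities from several models (same shape each)."""
    return np.mean(np.stack(prob_list, axis=0), axis=0)

def average_weights(state_dicts):
    """Hypothetical helper: element-wise mean of parameters across checkpoints with identical keys."""
    keys = state_dicts[0].keys()
    return {k: np.mean([sd[k] for sd in state_dicts], axis=0) for k in keys}

# Two toy "models" predicting 2 clips x 3 classes.
p1 = np.array([[0.9, 0.1, 0.0], [0.2, 0.7, 0.1]])
p2 = np.array([[0.7, 0.2, 0.1], [0.4, 0.5, 0.1]])
print(ensemble_predict([p1, p2]))

# Two toy checkpoints, each with one weight matrix and one bias vector.
ckpt_a = {"fc.weight": np.ones((2, 2)), "fc.bias": np.zeros(2)}
ckpt_b = {"fc.weight": 3 * np.ones((2, 2)), "fc.bias": np.full(2, 2.0)}
print(average_weights([ckpt_a, ckpt_b])["fc.weight"])
```

Weight averaging matches parameters by key and shape, so it only applies to checkpoints of the same architecture; an ensemble can combine any models that predict the same classes, at the cost of running each one at test time.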

Localization on DESED

CUDA_VISIBLE_DEVICES=1,2,3,4 python test
// make sure that fl_local=True in
// organize and gather the localization results
// Follow the notebook to produce the results
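The localization output is a frame-level heatmap of per-class probabilities over time. One common way to turn such a heatmap into onset/offset events is to threshold each class track and merge consecutive active frames. The sketch below assumes a `(frames, classes)` probability array, a threshold of 0.5 and a caller-supplied frame hop; these values and the `heatmap_to_events` name are assumptions, not the post-processing in the repository's notebook:

```python
import numpy as np

def heatmap_to_events(probs, threshold=0.5, hop_seconds=0.01):
    """Hypothetical helper: (frames, classes) probabilities -> list of (class, onset_s, offset_s).

    hop_seconds is an assumed frame duration; set it to match the heatmap's real resolution.
    """
    events = []
    active = probs >= threshold
    for c in range(probs.shape[1]):
        # Pad with False so every active run has both a rising and a falling edge.
        track = np.concatenate([[False], active[:, c], [False]]).astype(np.int8)
        edges = np.diff(track)
        onsets = np.flatnonzero(edges == 1)    # first active frame of each run
        offsets = np.flatnonzero(edges == -1)  # first inactive frame after each run
        for on, off in zip(onsets, offsets):
            events.append((c, float(on * hop_seconds), float(off * hop_seconds)))
    return events

heatmap = np.array([[0.1, 0.9],
                    [0.8, 0.9],
                    [0.7, 0.2],
                    [0.2, 0.6]])
for event in heatmap_to_events(heatmap):
    print(event)
```

Practical systems often add per-class thresholds or median filtering before this step to suppress very short spurious events.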

Model Checkpoints:

We provide the model checkpoints for three datasets (and additionally for the DESED dataset) in this link. Feel free to download and test them.


  author = {Ke Chen and Xingjian Du and Bilei Zhu and Zejun Ma and Taylor Berg-Kirkpatrick and Shlomo Dubnov},
  title = {HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection},
  booktitle = {{ICASSP} 2022}

Our work is based on Swin Transformer, a well-known transformer model for image classification.

    The mAP following the AudioSet recipe is very low

    Hi, I downloaded the model checkpoint files you provided through Google Drive and followed the AudioSet evaluation steps in the README (Test: CUDA_VISIBLE_DEVICES=1,2,3,4 python test). I expected performance similar to your paper, but I got a very low mAP. My evaluation set has 18,887 samples. Could you share the size of your evaluation set, if possible? Attached is a picture of the single-model evaluation (HTSAT_AudioSet_Saved_1.ckpt) results. Thanks.

    opened by kimsojeong1225 11
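For comparing the mAP figures discussed in this issue: AudioSet mAP is usually the mean, over classes, of per-class average precision. A NumPy sketch of that metric, computing AP as the mean precision at the rank of each positive clip (this matches scikit-learn's `average_precision_score` when there are no tied scores); the function names are illustrative and this is not the repository's evaluation code:

```python
import numpy as np

def average_precision(scores, targets):
    """AP for one class: mean precision at the rank of each positive sample (assumes no ties)."""
    order = np.argsort(-scores)
    hits = targets[order].astype(bool)
    if not hits.any():
        return np.nan  # undefined when the class has no positives in the set
    ranks = np.arange(1, len(hits) + 1)
    precision_at_k = np.cumsum(hits) / ranks
    return float(precision_at_k[hits].mean())

def mean_average_precision(scores, targets):
    """mAP over classes; scores and targets have shape (num_clips, num_classes)."""
    aps = [average_precision(scores[:, c], targets[:, c]) for c in range(scores.shape[1])]
    return float(np.nanmean(aps))

scores = np.array([[0.9, 0.1], [0.8, 0.7], [0.1, 0.6]])
targets = np.array([[1, 0], [0, 1], [1, 0]])
print(mean_average_precision(scores, targets))
```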
    Hello, I get the following error when running the code. What is the cause?

    RuntimeError: upsample_bicubic2d_backward_out_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at to help us prioritize adding deterministic support for this operation

    opened by ykingliu 6
    Getting bad results on ESC-50

    I have one GPU, so I changed some code in and and

    and set it up like this:

    dataset_type = "esc-50"
    loss_type = "clip_ce"
    sample_rate = 32000
    classes_num = 50

    Then I only get ACC: 0.55.

    To work around an init_process_group error, I changed:

    deterministic=False
    dist.init_process_group(
        backend="nccl", init_method="tcp://localhost:23456", rank=0, world_size=1
    )

    The code runs, but I don't get the same results as in your paper.

    opened by wac81 6
    masking and difference in mix-up strategy

    Hi Ke,

    Thanks for the great work and open sourcing the code!

    I'd like to build from your excellent codebase, and I have a few questions regarding the details:

    1. I couldn't find any information about a padding mask. Is it not used in the model?
    2. The mix-up procedure seems to be a bit different from AST's:
       2.1 In AST, they mix up the raw waveforms before applying transforms, while in HTS-AT you compute fbanks first and then mix up the fbanks.
       2.2 In AST, the mix-up waveform is randomly sampled from the entire dataset, while you sample within the current batch.
       2.3 In AST, they also mix up the labels using lambda * label1 + (1-lambda) * label2, while HTS-AT does not mix labels.
    I'm not sure whether these three differences make a big difference in performance, but I'm curious about your thoughts.

    Thanks, Puyuan

    opened by jasonppy 4
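To make the label-mixing difference raised in the mix-up question concrete, here is a minimal NumPy sketch of in-batch mix-up on feature maps that also mixes the labels as `lam * y1 + (1 - lam) * y2`. It combines the in-batch, post-feature-extraction variant described in the question with AST-style label mixing; the `alpha` value and the `batch_mixup` name are assumptions for illustration, not code from either repository:

```python
import numpy as np

def batch_mixup(features, labels, alpha=0.5, rng=None):
    """Hypothetical helper: mix each sample with a randomly permuted partner from the same batch.

    features: (batch, ...) array, e.g. log-mel fbanks; labels: (batch, classes) multi-hot.
    Returns mixed features, mixed (soft) labels, and the mixing coefficient.
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)                # mixing weight drawn from Beta(alpha, alpha)
    perm = rng.permutation(features.shape[0])   # partner indices within the batch
    mixed_x = lam * features + (1.0 - lam) * features[perm]
    mixed_y = lam * labels + (1.0 - lam) * labels[perm]
    return mixed_x, mixed_y, lam

x = np.arange(8, dtype=float).reshape(4, 2)  # toy batch of 4 "fbanks"
y = np.eye(4)                                # one label per sample
mx, my, lam = batch_mixup(x, y, rng=np.random.default_rng(0))
print(lam, my.sum(axis=1))
```

Without the label line, the target stays that of the first sample while the input contains both, which is one reason the two recipes may behave differently.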
    Learning rate for small datasets

    Hi, thank you for your great work.

    Here, you mentioned that the warm-up part should be adjusted according to the dataset. Could you give some advice for small datasets? For example, I have approximately 15K samples; how should I set lr_scheduler_epoch and lr_rate?

    I compared the AudioSet config and the ESC-50 config, but the parameters mentioned are the same.

    opened by EmreOzkose 3
    Different length audio input for infer mode

    Hi, thanks for the interesting work!

    I have a question about the infer mode in When training, the length of audio input will always be 10 seconds. When inference, the model needs to handle variable-length audio input which could be longer or shorter than 10 seconds, but for the infer mode in the,

     if infer_mode:
                # in infer mode. we need to handle different length audio input
                frame_num = x.shape[2]
                target_T = int(self.spec_size * self.freq_ratio)
                repeat_ratio = math.floor(target_T / frame_num)
                x = x.repeat(repeats=(1,1,repeat_ratio,1))
                x = self.reshape_wav2img(x) 

    What if the length of input frame_num > target_T (should be 256x4=1024 here)? If so, repeat_ratio will be 0. So how does the model process the audio longer than 10 seconds for inference here?

    opened by CaptainPrice12 3
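One way to cover both cases raised in this question, sketched with NumPy on a `(frames, mel_bins)` array: tile short inputs up to `target_T` frames and zero-pad the remainder, and split long inputs into `target_T`-frame windows that can be scored separately and then aggregated (for example by max or mean). This is a hypothetical helper, not how the repository's `infer_mode` branch behaves; `target_T = 1024` follows the value quoted above:

```python
import numpy as np

def to_fixed_windows(x, target_T=1024):
    """Hypothetical helper: map a (frames, mel_bins) array to (num_windows, target_T, mel_bins)."""
    frame_num = x.shape[0]
    if frame_num <= target_T:
        # Short input: repeat whole copies (floor(target_T / frame_num) >= 1), then zero-pad.
        repeat_ratio = target_T // frame_num
        tiled = np.tile(x, (repeat_ratio, 1))
        tiled = np.pad(tiled, ((0, target_T - tiled.shape[0]), (0, 0)))
        return tiled[np.newaxis]
    # Long input: zero-pad to a multiple of target_T, then cut into consecutive windows.
    num_windows = -(-frame_num // target_T)  # ceiling division
    padded = np.pad(x, ((0, num_windows * target_T - frame_num), (0, 0)))
    return padded.reshape(num_windows, target_T, x.shape[1])

short_windows = to_fixed_windows(np.ones((300, 64)))
long_windows = to_fixed_windows(np.ones((2500, 64)))
print(short_windows.shape, long_windows.shape)  # (1, 1024, 64) (3, 1024, 64)
```

Taking the max over window scores keeps a short event that appears in only one window; the mean is smoother but can dilute such events.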
    Learning rate

    Hi, I think the learning rate in the code differs from the paper. The paper shows learning rates of 0.05, 0.1, and 0.2 in the first three epochs, but the code uses 2e-5, 5e-5, and 1e-4 (lr_rate=[0.02,0.05,0.1]). I changed lr_rate to [50,100,200] to match the paper, but training then gives bad results. Which setting is correct for reproducing the results reported in the paper?

    opened by kimsojeong1225 3
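The numbers in this issue are consistent with the per-epoch learning rate being a base rate multiplied by the `lr_rate` entry for that warm-up epoch, with a base rate of 1e-3: 1e-3 × [0.02, 0.05, 0.1] gives [2e-5, 5e-5, 1e-4], and 1e-3 × [50, 100, 200] gives the quoted paper values [0.05, 0.1, 0.2]. The base rate and the multiplicative relationship are inferred from these figures, not confirmed in the configuration file. A sketch of such a warm-up schedule; the `warmup_lr` name and the "hold the last scale after warm-up" rule are assumptions:

```python
def warmup_lr(epoch, base_lr=1e-3, lr_rate=(0.02, 0.05, 0.1)):
    """Hypothetical schedule: base_lr * lr_rate[epoch] for a 0-indexed warm-up epoch.

    After the warm-up epochs this sketch simply keeps the last scale (an assumption).
    """
    scale = lr_rate[min(epoch, len(lr_rate) - 1)]
    return base_lr * scale

print([warmup_lr(e) for e in range(4)])
print([warmup_lr(e, lr_rate=(50, 100, 200)) for e in range(3)])
```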
    Question about AudioSet and finetune learning rate.

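    Hi Ke Chen, great work! I have two main questions. The first is about AudioSet. From reading your source code, I guess you used the dataset that Qiuqiang Kong shared on Baidu Netdisk. If so, did you have any problems extracting it? For me, extraction reports errors (with both WinRAR and 7-Zip, although WinRAR still manages to extract it), and loading the data also raises errors (probably due to corrupted files). After excluding the corrupted files, I still could not reproduce the results of another work (the metrics were poor), so I suspect the problem is related to not extracting the archives correctly. If you used this shared dataset, how did you extract it? The second question is about fine-tuning from the ImageNet-pretrained Swin Transformer: what learning rate did you use, and is the learning strategy the same as when training from scratch? The paper seems to describe only the learning rate and strategy for training from scratch.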


    opened by MichaelLynn1996 2
    How can I test model?

    Hello, in "htsat_esc_training.ipynb", how can I find test_data? Or should I create the test_data myself? I'm not familiar with pytorch-lightning.

    opened by yyssxxx 2
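ESC-50 ships with five predefined folds, and the usual protocol is 5-fold cross-validation: each fold is held out once for evaluation while the other four are used for training. A minimal sketch of that split, assuming a list of `(filename, fold)` records with folds numbered 1 to 5 as in the dataset's metadata; the `split_by_fold` name, the toy filenames, and the mapping from the notebook's 0-indexed `esc_fold` to the dataset's folds are assumptions:

```python
def split_by_fold(records, esc_fold):
    """Hypothetical helper: fold esc_fold + 1 is held out, the remaining folds are training data."""
    held_out = esc_fold + 1  # assumed: esc_fold 0-4 corresponds to dataset folds 1-5
    train_set = [r for r in records if r[1] != held_out]
    test_set = [r for r in records if r[1] == held_out]
    return train_set, test_set

# Toy records (filename, fold); real ESC-50 has 2,000 clips, 400 per fold.
records = [("a.wav", 1), ("b.wav", 2), ("c.wav", 3), ("d.wav", 4), ("e.wav", 5), ("f.wav", 1)]
train_set, test_set = split_by_fold(records, esc_fold=0)
print([name for name, _ in test_set])  # ['a.wav', 'f.wav']
```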
    TypeError: cannot pickle 'module' object

    I am running htsat_esc_training.ipynb and getting this error on my PC.

    Python version: 3.9.12. Installed all requirements from requirements.txt. Ran the notebook in VSCode. No changes to the code.

    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    trainer properties:
    gpus: 1
    max_epochs: 100
    auto_lr_find: True
    accelerator: <pytorch_lightning.accelerators.gpu.GPUAccelerator object at 0x0000016B08410790>
    num_sanity_val_steps: 0
    resume_from_checkpoint: None
    gradient_clip_val: 1.0


    TypeError Traceback (most recent call last) Cell In [26], line 3 1 # Training the model 2 # You can set different fold index by setting 'esc_fold' to any number from 0-4 in ----> 3, audioset_data)

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path) 735 rank_zero_deprecation( 736 " is deprecated in v1.4 and will be removed in v1.6." 737 " Use instead. HINT: added 's'" 738 ) 739 train_dataloaders = train_dataloader --> 740 self._call_and_handle_interrupt( 741 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path 742 )

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs) 675 r""" 676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict) 677 as all errors should funnel through them (...) 682 **kwargs: keyword arguments to be passed to trainer_fn 683 """ 684 try: --> 685 return trainer_fn(*args, **kwargs) 686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 687 except KeyboardInterrupt as exception:

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 775 # TODO: ckpt_path only in v1.7 776 ckpt_path = ckpt_path or self.resume_from_checkpoint --> 777 self._run(model, ckpt_path=ckpt_path) 779 assert self.state.stopped 780 = False

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in Trainer._run(self, model, ckpt_path) 1196 self.checkpoint_connector.resume_end() 1198 # dispatch start_training or start_evaluating or start_predicting -> 1199 self._dispatch() 1201 # plugin will finalized fitting (e.g. ddp_spawn will load trained model) 1202 self._post_dispatch()

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in Trainer._dispatch(self) 1277 self.training_type_plugin.start_predicting(self) 1278 else: -> 1279 self.training_type_plugin.start_training(self)

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\plugins\training_type\, in TrainingTypePlugin.start_training(self, trainer) 200 def start_training(self, trainer: "pl.Trainer") -> None: 201 # double dispatch to initiate the training loop --> 202 self._results = trainer.run_stage()

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in Trainer.run_stage(self) 1287 if self.predicting: 1288 return self._run_predict() -> 1289 return self._run_train()

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in Trainer._run_train(self) 1317 self.fit_loop.trainer = self 1318 with torch.autograd.set_detect_anomaly(self._detect_anomaly): -> 1319

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\, in, *args, **kwargs) 143 try: 144 self.on_advance_start(*args, **kwargs) --> 145 self.advance(*args, **kwargs) 146 self.on_advance_end() 147 self.restarting = False

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\, in FitLoop.advance(self) 231 data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader) 233 with self.trainer.profiler.profile("run_training_epoch"): --> 234 236 # the global step is manually decreased here due to backwards compatibility with existing loggers 237 # as they expect that the same step is used when logging epoch end metrics even when the batch loop has 238 # finished. this means the attribute does not exactly track the number of optimizer steps applied. 239 # TODO(@carmocca): deprecate and rename so users don't get confused 240 self.global_step -= 1

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\, in, *args, **kwargs) 136 return self.on_skip() 138 self.reset() --> 140 self.on_run_start(*args, **kwargs) 142 while not self.done: 143 try:

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\epoch\, in TrainingEpochLoop.on_run_start(self, data_fetcher, **kwargs) 138 self.trainer.fit_loop.epoch_progress.increment_started() 140 self._reload_dataloader_state_dict(data_fetcher) --> 141 self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\loops\, in _update_dataloader_iter(data_fetcher, batch_idx) 118 """Attach the dataloader.""" 119 if not isinstance(data_fetcher, DataLoaderIterDataFetcher): 120 # restore iteration --> 121 dataloader_iter = enumerate(data_fetcher, batch_idx) 122 else: 123 dataloader_iter = iter(data_fetcher)

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\utilities\, in AbstractDataFetcher.iter(self) 196 self.reset() 197 self.dataloader_iter = iter(self.dataloader) --> 198 self._apply_patch() 199 self.prefetching(self.prefetch_batches) 200 return self

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\utilities\, in AbstractDataFetcher._apply_patch(self) 130 loader._lightning_fetcher = self 131 patch_dataloader_iterator(loader, iterator, self) --> 133 apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn)

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\utilities\, in AbstractDataFetcher.loader_iters(self) 178 raise MisconfigurationException("The dataloader_iter isn't available outside the iter context.") 180 if isinstance(self.dataloader, CombinedLoader): --> 181 loader_iters = self.dataloader_iter.loader_iters 182 else: 183 loader_iters = [self.dataloader_iter]

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in CombinedLoaderIterator.loader_iters(self) 535 """Get the _loader_iters and create one if it is None.""" 536 if self._loader_iters is None: --> 537 self._loader_iters = self.create_loader_iters(self.loaders) 539 return self._loader_iters

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\trainer\, in CombinedLoaderIterator.create_loader_iters(loaders) 568 """Create and return a collection of iterators from loaders. 569 570 Args: (...) 574 a collections of iterators 575 """ 576 # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences --> 577 return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\pytorch_lightning\utilities\, in apply_to_collection(data, dtype, function, wrong_dtype, include_none, *args, **kwargs) 93 # Breaking condition 94 if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): ---> 95 return function(data, *args, **kwargs) 97 elem_type = type(data) 99 # Recursively apply to collection items

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\torch\utils\data\, in DataLoader.iter(self) 442 return self._iterator 443 else: --> 444 return self._get_iterator()

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\torch\utils\data\, in DataLoader._get_iterator(self) 388 else: 389 self.check_worker_number_rationality() --> 390 return _MultiProcessingDataLoaderIter(self)

    File c:\Users\jonat\source\repos\HTS-Audio-Transformer-main\HTSATvenv\lib\site-packages\torch\utils\data\, in _MultiProcessingDataLoaderIter.init(self, loader) 1070 w.daemon = True 1071 # NB: Process.start() actually take some time as it needs to 1072 # start a process and pass the arguments over via a pipe. 1073 # Therefore, we only add a worker to self._workers list after 1074 # it started, so that we do not call .join() if program dies 1075 # before it starts, and del tries to join but will get: 1076 # AssertionError: can only join a started process. -> 1077 w.start() 1078 self._index_queues.append(index_queue) 1079 self._workers.append(w)

    File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\, in BaseProcess.start(self) 118 assert not _current_process._config.get('daemon'),
    119 'daemonic processes are not allowed to have children' 120 _cleanup() --> 121 self._popen = self._Popen(self) 122 self._sentinel = self._popen.sentinel 123 # Avoid a refcycle if the target function holds an indirect 124 # reference to the process object (see bpo-30775)

    File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\, in Process._Popen(process_obj) 222 @staticmethod 223 def _Popen(process_obj): --> 224 return _default_context.get_context().Process._Popen(process_obj)

    File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\, in SpawnProcess._Popen(process_obj) 324 @staticmethod 325 def _Popen(process_obj): 326 from .popen_spawn_win32 import Popen --> 327 return Popen(process_obj)

    File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\, in Popen.init(self, process_obj) 91 try: 92 reduction.dump(prep_data, to_child) ---> 93 reduction.dump(process_obj, to_child) 94 finally: 95 set_spawning_popen(None)

    File ~\AppData\Local\Programs\Python\Python39\lib\multiprocessing\, in dump(obj, file, protocol) 58 def dump(obj, file, protocol=None): 59 '''Replacement for pickle.dump() using ForkingPickler.''' ---> 60 ForkingPickler(file, protocol).dump(obj)

    TypeError: cannot pickle 'module' object

    Which version of PyTorch would you recommend? pip3 install torch torchvision torchaudio --extra-index-url

    opened by JonathanFL 2
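On Windows, PyTorch's DataLoader starts its worker processes with the `spawn` start method, which pickles the worker arguments, including the dataset (the `SpawnProcess` and `ForkingPickler` frames at the end of the traceback above). Any attribute that holds a Python module cannot be pickled, which produces exactly this `TypeError`. The sketch below reproduces the failure with the standard library and shows one workaround, copying plain values out of the module; the dataset classes are hypothetical, and whether the notebook's dataset actually holds a module (for example a config module) is an assumption. Setting `num_workers=0` on the DataLoader, which avoids worker processes entirely, is another option.

```python
import math  # stands in for any module, e.g. a config module passed to a dataset
import pickle

class AudioDataset:
    """Hypothetical dataset that keeps a reference to a module (not picklable)."""
    def __init__(self, config_module):
        self.config = config_module

class PicklableAudioDataset:
    """Hypothetical variant that copies only the plain values it needs out of the module."""
    def __init__(self, config_module, keys):
        self.config = {k: getattr(config_module, k) for k in keys}

def try_pickle(obj):
    """Return 'ok' if obj pickles, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
        return "ok"
    except TypeError as err:
        return f"TypeError: {err}"

print(try_pickle(AudioDataset(math)))
print(try_pickle(PicklableAudioDataset(math, ["pi"])))
```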

    error in training


    I get the following error when running the ESC-50 training script. How can I solve it?

    RuntimeError: upsample_bicubic2d_backward_out_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at to help us prioritize adding deterministic support for this operation. Epoch 0: 0%| | 0/65 [00:04<?, ?it/s]

    opened by joewale 2
    How can I run this project with one GPU?

    Hi. I am trying to test this model on Google Speech Commands, but after the testing progress bar completes, I get this error: Default process group has not been initialized, please make sure to call init_process_group

    I googled this error and found that it happens because of 'SyncBatchNorm' with 1 GPU, and that I should replace those layers with normal ones. In your code I found the 'sync_batchnorm' parameter in pl.Trainer(), set its value to 'False', and ran the test command again, but it didn't work.

    Could anyone please tell me how I can run this project on 1 GPU? Thanks.

    opened by saeedmaroof 0
    How to finetune on strong label dataset?

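    Hello, great work! I have some questions about fine-tuning on a strongly labeled dataset. In issue 25 you mentioned "need to extract different output of HST-AT (I believe it is the last second layer feature-map output)". Does this "last second layer" refer to the output of the token-semantic module? You also mentioned "the interpolation and resolution of the output may be different from the input localization time resolution ----- in that you need to find a way to align them." In your code, the output time axis is interpolated to a length of 1024; does that count as one way of handling this? I would be very grateful if you could answer my questions!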

    opened by wengstA 0
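The alignment asked about in this issue amounts to resampling the model's frame-level output along time so that it matches the resolution of the strong labels. A minimal NumPy sketch using linear interpolation; the 1024-frame target comes from the question, while the `resample_time_axis` name, the choice of linear interpolation, and the toy source length are assumptions (the repository may interpolate differently):

```python
import numpy as np

def resample_time_axis(frames, target_len=1024):
    """Hypothetical helper: linearly resample a (T, classes) array to (target_len, classes)."""
    src_len = frames.shape[0]
    # Positions of the target frames expressed in source-frame coordinates.
    src_pos = np.linspace(0, src_len - 1, num=target_len)
    return np.stack([np.interp(src_pos, np.arange(src_len), frames[:, c])
                     for c in range(frames.shape[1])], axis=1)

out = resample_time_axis(np.random.default_rng(0).random((32, 10)))  # toy 32-frame output
print(out.shape)  # (1024, 10)
```

The same function can go the other way, downsampling label tracks to the model's native output length, which avoids inventing intermediate predictions.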
