HyperSeg: Patch-wise Hypernetwork for Real-time Semantic Segmentation - Official PyTorch Implementation

Overview

Teaser image: example segmentations on the PASCAL VOC dataset.

This repository contains the source code for the real-time semantic segmentation method described in the paper:

HyperSeg: Patch-wise Hypernetwork for Real-time Semantic Segmentation
Conference on Computer Vision and Pattern Recognition (CVPR), 2021
Yuval Nirkin, Lior Wolf, Tal Hassner
Paper

Abstract: We present a novel, real-time, semantic segmentation network in which the encoder both encodes and generates the parameters (weights) of the decoder. Furthermore, to allow maximal adaptivity, the weights at each decoder block vary spatially. For this purpose, we design a new type of hypernetwork, composed of a nested U-Net for drawing higher level context features, a multi-headed weight generating module which generates the weights of each block in the decoder immediately before they are consumed, for efficient memory utilization, and a primary network that is composed of novel dynamic patch-wise convolutions. Despite the usage of less-conventional blocks, our architecture obtains real-time performance. In terms of the runtime vs. accuracy trade-off, we surpass state of the art (SotA) results on popular semantic segmentation benchmarks: PASCAL VOC 2012 (val. set) and real-time semantic segmentation on Cityscapes, and CamVid.

Installation

Install the following packages:

conda install pytorch torchvision cudatoolkit=11.1 -c pytorch -c conda-forge
pip install opencv-python ffmpeg-python

Add the parent directory of the repository to PYTHONPATH.
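
For example, if the repository was cloned to ~/dev/hyperseg (an illustrative path, not a requirement), add its parent directory on Linux/macOS:

export PYTHONPATH=$PYTHONPATH:~/dev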

Models

Template   | Dataset    | Resolution | mIoU (%)    | FPS  | Link
HyperSeg-L | PASCAL VOC | 512x512    | 80.6 (val)  | -    | download
HyperSeg-M | CityScapes | 1024x512   | 76.2 (val)  | 36.9 | download
HyperSeg-S | CityScapes | 1536x768   | 78.2 (val)  | 16.1 | download
HyperSeg-S | CamVid     | 768x576    | 78.4 (test) | 38.0 | download
HyperSeg-L | CamVid     | 1024x768   | 79.1 (test) | 16.6 | -

The models' FPS was measured on an NVIDIA GeForce GTX 1080 Ti GPU.

Either download the models into the repository's weights directory or adjust the model variable in the test configuration files.

Datasets

Dataset    | # Images | Classes | Resolution    | Link
PASCAL VOC | 10,582   | 21      | up to 500x500 | auto downloaded
CityScapes | 5,000    | 19      | 2048x1024     | download
CamVid     | 701      | 12      | 960x720       | download

Either download the datasets into the repository's data directory or adjust the data_dir variable in the configuration files.

Training

To train the HyperSeg-M model on Cityscapes, set the exp_dir and data_dir paths in cityscapes_efficientnet_b1_hyperseg-m.py and run:

python configs/train/cityscapes_efficientnet_b1_hyperseg-m.py
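
For reference, setting those two variables near the top of the configuration file might look like this (the paths below are placeholders, not repository defaults):

exp_dir = 'checkpoints/cityscapes/cityscapes_efficientnet_b1_hyperseg-m'  # output directory for checkpoints and logs
data_dir = 'data/cityscapes'  # root directory of the Cityscapes dataset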

Testing

Testing a model after training

For example, to test the HyperSeg-M model on the Cityscapes validation set:

python test.py 'checkpoints/cityscapes/cityscapes_efficientnet_b1_hyperseg-m' \
-td "hyper_seg.datasets.cityscapes.CityscapesDataset('data/cityscapes',split='val',mode='fine')" \
-it "seg_transforms.LargerEdgeResize([512,1024])"

Testing a pretrained model

For example, to test the PASCAL VOC HyperSeg-L model using the available test configuration:

python configs/test/vocsbd_efficientnet_b3_hyperseg-l.py
Comments
  • TypeError: an integer is required (got type tuple)

    Hi, thanks for sharing your work. There is an error when I run the sample training code.

    TRAINING: Epoch: 1 / 360; LR: 1.0e-03; losses: [total: 3.1942 (3.1942); ] bench: [iou: 0.0106 (0.0106); ] : 0%| | 1/1000 [00:07<1:59:20, 7.17s/batches]

    Traceback (most recent call last):
      File "/home/tt/zyj_ws/hyperseg/configs/train/cityscapes_efficientnet_b1_hyperseg-m.py", line 43, in <module>
        main(exp_dir, train_dataset=train_dataset, val_dataset=val_dataset, train_img_transforms=train_img_transforms,
      File "/home/tt/zyj_ws/hyperseg/train.py", line 248, in main
        epoch_loss, epoch_iou = proces_epoch(train_loader, train=True)
      File "/home/tt/zyj_ws/hyperseg/train.py", line 104, in proces_epoch
        for i, (input, target) in enumerate(pbar):
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/tqdm/std.py", line 1185, in __iter__
        for obj in iterable:
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
        data = self._next_data()
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1179, in _next_data
        return self._process_data(data)
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1225, in _process_data
        data.reraise()
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/_utils.py", line 429, in reraise
        raise self.exc_type(msg)
    TypeError: Caught TypeError in DataLoader worker process 1.
    Original Traceback (most recent call last):
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
        data = fetcher.fetch(index)
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/home/tt/zyj_ws/hyperseg/datasets/cityscapes.py", line 220, in __getitem__
        image, target = self.transforms(image, target)
      File "/home/tt/zyj_ws/hyperseg/datasets/seg_transforms.py", line 78, in __call__
        input = list(t(*input))
      File "/home/tt/zyj_ws/hyperseg/datasets/seg_transforms.py", line 334, in __call__
        lbl = F.pad(lbl, (int(self.size[1] - lbl.size[0]), 0), self.lbl_fill, self.padding_mode).copy()
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 426, in pad
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py", line 153, in pad
        image = ImageOps.expand(img, border=padding, **opts)
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/PIL/ImageOps.py", line 403, in expand
        draw.rectangle((0, 0, width - 1, height - 1), outline=color, width=border)
      File "/home/tt/anaconda3/envs/zyjenv/lib/python3.9/site-packages/PIL/ImageDraw.py", line 259, in rectangle
        self.draw.draw_rectangle(xy, ink, 0, width)
    TypeError: an integer is required (got type tuple)

    According to information I found online, this may be a problem in the transforms. Can you give me some suggestions on how to fix it?

    opened by consideragain 12
  • ValueError: Unknown resampling filter (InterpolationMode.BICUBIC).

    I have some problems when I train the model on my own data, and I don't know whether it is because I made a mistake when loading my data. Here are the details:

    Traceback (most recent call last):
      File "E:\zjy\hyperseg-main\hyperseg\train.py", line 254, in main
        epoch_loss, epoch_iou = proces_epoch(val_loader, train=False)
      File "E:\zjy\hyperseg-main\hyperseg\train.py", line 104, in proces_epoch
        for i, (input, target) in enumerate(pbar):
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\tqdm\std.py", line 1130, in __iter__
        for obj in iterable:
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
        data = self._next_data()
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\dataloader.py", line 1203, in _next_data
        return self._process_data(data)
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\dataloader.py", line 1229, in _process_data
        data.reraise()
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\_utils.py", line 438, in reraise
        raise exception
    ValueError: Caught ValueError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\_utils\worker.py", line 287, in _worker_loop
        data = fetcher.fetch(index)
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "E:\zjy\hyperseg-main\hyperseg\datasets\Massachusetts.py", line 104, in __getitem__
        img, target = self.transforms(img, target)
      File "E:\zjy\hyperseg-main\hyperseg\datasets\seg_transforms.py", line 78, in __call__
        input = list(t(*input))
      File "E:\zjy\hyperseg-main\hyperseg\datasets\seg_transforms.py", line 199, in __call__
        img = larger_edge_resize(img, self.size, self.interpolation)
      File "E:\zjy\hyperseg-main\hyperseg\datasets\seg_transforms.py", line 174, in larger_edge_resize
        return img.resize(size[::-1], interpolation)
      File "E:\anaconda3\envs\zhangjiaying\lib\site-packages\PIL\Image.py", line 1861, in resize
        raise ValueError(
    ValueError: Unknown resampling filter (InterpolationMode.BICUBIC). Use Image.NEAREST (0), Image.LANCZOS (1), Image.BILINEAR (2), Image.BICUBIC (3), Image.BOX (4) or Image.HAMMING (5)

    opened by JoyeeZhang 9
  • ModuleNotFoundError:No module named ‘hyperseg’

    I encountered this error while training the model, and when I deleted all of the "hyperseg" references in the code, the error was no longer reported. Is this the right thing to do?

    opened by Rebufleming 5
  • No module named 'hyper_seg'

    Because the training failed, I tried downloading cityscapes_efficientnet_b1_hyperseg-m.pth to hyperseg/weights, renamed it to model_best.pth, and ran:

    python test.py 'weights' \
    -td "hyper_seg.datasets.cityscapes.CityscapesDataset('data/cityscapes',split='val',mode='fine')" \
    -it "seg_transforms.LargerEdgeResize([512,1024])" \
    --gpus 0

    Looking forward to your reply!

    opened by leo-hao 5
  • Dataset not found or incomplete.

    (liuhaomag) cv428@428:~/data/LH/hyperseg$ python configs/train/cityscapes_efficientnet_b1_hyperseg-m.py
    Traceback (most recent call last):
      File "configs/train/cityscapes_efficientnet_b1_hyperseg-m.py", line 46, in <module>
        scheduler=scheduler, pretrained=pretrained, model=model, criterion=criterion, batch_scheduler=batch_scheduler)
      File "/home/cv428/data/LH/hyperseg/train.py", line 187, in main
        train_dataset = obj_factory(train_dataset, transforms=train_transforms)
      File "/home/cv428/data/LH/hyperseg/utils/obj_factory.py", line 57, in obj_factory
        return obj_exp(*args, **kwargs)
      File "/home/cv428/data/LH/hyperseg/datasets/cityscapes.py", line 156, in __init__
        raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
    RuntimeError: Dataset not found or incomplete. Please make sure all required folders for the specified "split" and "mode" are inside the "root" directory
     
    
    opened by leo-hao 5
  • Saving for C++ inference

    First, great work!

    I was trying to train in Python and save it for C++ inference. The classic approach doesn't work:

    annotation_script_module = torch.jit.script(model)
    annotation_script_module.save("my_path")
    

    Do you have any suggestion on how to do it?
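
    A possible direction, sketched with heavy caveats: torch.jit.trace records the operations for one concrete input instead of compiling the Python code, so it sometimes succeeds where scripting fails, but it may not faithfully capture HyperSeg's dynamically generated, patch-wise weights. The input shape below is only an illustrative placeholder.

    import torch

    model.eval()
    example = torch.randn(1, 3, 512, 1024)    # dummy input; match your training resolution
    traced = torch.jit.trace(model, example)  # trace the forward pass for this single input
    traced.save("hyperseg_traced.pt")         # can be loaded from C++ with torch::jit::load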

    opened by Roios 4
  • MIoU on Pascal VOC is about 0.7 after 50 epochs

    Thank you for sharing your work. I have tried to retrain the model on the Pascal VOC dataset, but the mIoU is only about 0.7 after 50 epochs and it hardly increases. May I know whether this is expected, or should I continue training for 160 epochs?

    opened by dansonc 3
  • About Test bug

    While using the VOC dataset test script, a bug occurs: [Errno 2] No such file or directory: 'data/vocsbd/VOCdevkit/VOC2012/JPEGImages/2007_000033.jpg'. How should I deal with it?

    opened by 870572761 2
  • Question about freezing updates for model

    Hello, thank you so much for releasing your code! I am currently training the model on a custom dataset, and I was wondering if it is possible to freeze the gradient updates for a select number of classes? I am currently zeroing out the gradients from the decoder level 4 BatchNorm3 layer because that is the only layer that has the size of the number of classes within my custom dataset, but when I attempt this, it appears that the weights of the classes that I try to freeze are still changing.
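
    A generic PyTorch sketch of this kind of per-class gradient masking, under assumptions: classifier below is a placeholder for whichever layer has one row/channel per class, and the frozen class indices are illustrative. Note that even with zeroed gradients, optimizer state such as momentum or weight decay can still move those parameters, which may explain weights that keep changing.

    import torch

    frozen_classes = [3, 7]                    # placeholder indices of the classes to keep fixed

    def mask_grad(grad):
        grad = grad.clone()
        grad[frozen_classes] = 0               # zero the gradient rows belonging to frozen classes
        return grad

    # classifier stands in for the per-class layer in your setup
    classifier.weight.register_hook(mask_grad)
    if classifier.bias is not None:
        classifier.bias.register_hook(mask_grad)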

    opened by ma53ma 2
  • CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have begun a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15-year-old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsanitized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks whether all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.

    If you have further questions you may contact us through this project's lead researcher, Kasimir Schulz.

    opened by TrellixVulnTeam 1
  • The difference between the Resize(512, 1024) and LargerEdgeResize(512, 1024)

    Hello! @YuvalNirkin,

    I would like to know: what is the difference between the torchvision.transforms.Resize function and your hyperseg.datasets.seg_transforms.LargerEdgeResize function?

    The results of global accuracy and mIoU are slightly different when I use these two transformations separately in the test phase.

    Looking forward to your reply.

    opened by J-JunChen 1
  • Another bug in the CamVid training script

    After solving the CamVid dataset bug, I found another bug in the CamVid training script. When running it, the error "AttributeError: 'RandomRotation' object has no attribute 'resample'" occurs. Note that the CamVid test script runs fine. I don't know how to deal with this.

    opened by 870572761 1
  • Assertion `t >= 0 && t < n_classes` failed.

    Sorry to bother you. I have encountered a problem: the assertion t >= 0 && t < n_classes fails when I try to use your model on my own dataset. I have changed the number of classes for my dataset, but I still face the problem. Can you help me? Here is my change:

    self.classes = list(range(2))
    self.image_classes = calc_classes_per_image(self.masks, 2, cache_file)

    opened by xings-sdnu 1
  • signal_index is not getting updated

    The variable signal_index is not getting updated inside the init_signal2weights function in hyperseg_v1_0.py. Hence, it assigns signal_index = 0 for every meta block in the decoder.

    Because of this, inside the apply_signal2weights function of each meta block we cannot use all the channels of the input signal. Essentially we are only using [0:signal_channels] of the whole signal, where signal_channels of each block is always less than the number of input signal channels.

            # Inside apply_signal2weights function
            w = self.signal2weights(s[:, self.signal_index:self.signal_index + self.signal_channels])[:, :self.hyper_params]
    

    So, is the signal_index supposed to be zero for every meta block?

    opened by Dhamodhar-DDR 1
Owner
Yuval Nirkin, Deep Learning Researcher