Instance Segmentation by Jointly Optimizing Spatial Embeddings and Clustering Bandwidth

Overview

This codebase implements the loss function described in:

Instance Segmentation by Jointly Optimizing Spatial Embeddings and Clustering Bandwidth
Davy Neven, Bert De Brabandere, Marc Proesmans, and Luc Van Gool
Conference on Computer Vision and Pattern Recognition (CVPR), June 2019

Our network architecture is a multi-branched version of ERFNet and uses the Lovasz-hinge loss for maximizing the IoU of each instance.

License

This software is released under a Creative Commons license that allows personal and research use only. For a commercial license, please contact the authors. You can view a license summary here.

Getting started

This codebase showcases the proposed loss function on car instance segmentation using the Cityscapes dataset.

Prerequisites

Dependencies:

  • PyTorch 1.1
  • Python 3.6.8 (or higher)
  • Cityscapes + scripts (if you want to evaluate the model)

Training

Training consists of two steps. We first train on 512x512 crops around each object, which avoids spending computation on background patches. Afterwards, we fine-tune on larger patches (1024x1024) to account for bigger objects and background features that are not present in the smaller crops.

To generate these crops, run:

$ CITYSCAPES_DIR=/path/to/cityscapes/ python utils/generate_crops.py
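
The cropping idea can be pictured roughly as follows. This is a minimal sketch, not the contents of utils/generate_crops.py: the helper name crop_box and the clamping behavior at image borders are assumptions made for illustration.

```python
def crop_box(center_y, center_x, crop_size, height, width):
    """Return (top, left) of a crop_size x crop_size window centred on an object.

    Hypothetical helper: the window is shifted back inside the image when the
    object lies close to a border, so the crop never leaves the image.
    """
    top = int(min(max(center_y - crop_size / 2, 0), height - crop_size))
    left = int(min(max(center_x - crop_size / 2, 0), width - crop_size))
    return top, left


# Cityscapes frames are 1024 x 2048 pixels (height x width).
print(crop_box(500, 1000, 512, 1024, 2048))  # object in the middle: (244, 744)
print(crop_box(10, 2040, 512, 1024, 2048))   # object near the top-right corner: (0, 1536)
```

Centering the window on the object keeps most of each training crop on foreground pixels, which is the motivation given above for the first training step.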

Afterwards start training:

$ CITYSCAPES_DIR=/path/to/cityscapes/ python train.py

Options can be changed in train_config.py; for example, set display=True to enable visualization.

Testing

You can download a pretrained model here. Save the file in src/pretrained_models/, or adapt the path in test_config.py.

To test the model on the Cityscapes validation set, run:

$ CITYSCAPES_DIR=/path/to/cityscapes/ python test.py

The pretrained model gets 56.4 AP on the car validation set.

Acknowledgement

This work was supported by Toyota and was carried out at the TRACE Lab at KU Leuven (Toyota Research on Automated Cars in Europe - Leuven).

Comments
  • why don't use GT in instance seed loss calculation

    seed loss

                seed_loss += self.foreground_weight * torch.sum(
                    torch.pow(seed_map[in_mask] - dist[in_mask].detach(), 2))
    

    We usually compute a loss between a prediction and the ground truth (GT), but in this loss function both seed_map and dist are predictions. Why isn't the GT used in the instance seed loss calculation? Shouldn't it be seed_loss += self.foreground_weight * torch.sum(torch.pow(gt[in_mask] - dist[in_mask].detach(), 2))
    or seed_loss += self.foreground_weight * torch.sum(torch.pow(seed_map[in_mask] - gt[in_mask].detach(), 2))?

    opened by xiaojiaobingliang14 5
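
For readers puzzling over the same point, the sketch below shows one way to read the quoted line. It is a simplified, single-instance illustration with assumed names (seed_loss_sketch is hypothetical), not the repository's loss. The ground truth enters through the boolean mask in_mask, which selects the pixels of the instance, while the regression target for those pixels is the detached Gaussian value dist. Pixels outside the mask are pushed toward zero.

```python
import torch


def seed_loss_sketch(seed_map, dist, in_mask, foreground_weight=1.0):
    """Single-instance seed loss sketch (hypothetical helper).

    seed_map: predicted seed scores, float tensor of shape (1, H, W)
    dist:     Gaussian output for this instance, same shape, values in [0, 1]
    in_mask:  boolean ground-truth mask of the instance, same shape
    """
    # Foreground: regress the seed score to the Gaussian value. detach() stops
    # this term from sending gradients back into whatever produced dist.
    fg = torch.sum((seed_map[in_mask] - dist[in_mask].detach()) ** 2)
    # Background (here: everything outside the instance): regress to zero.
    bg = torch.sum(seed_map[~in_mask] ** 2)
    return foreground_weight * fg + bg


in_mask = torch.tensor([[[True, True], [False, False]]])
dist = torch.tensor([[[0.9, 0.6], [0.1, 0.0]]], requires_grad=True)
seed_map = torch.tensor([[[0.5, 0.6], [0.2, 0.0]]], requires_grad=True)

loss = seed_loss_sketch(seed_map, dist, in_mask)
loss.backward()
print(loss.item())  # (0.5 - 0.9)^2 + 0.2^2 = 0.2, up to float rounding
print(dist.grad)    # None: the detached target receives no gradient
```

Under this reading, the seed map learns to predict how close each pixel's embedding lands to its instance center, so the GT mask decides which pixels are compared rather than serving as the target value itself.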
  • Gaussian calculation

    Hi, I'm reading the my_loss.py file and have a question regarding the Gaussian calculation.

                   s = torch.exp(s*10)
                   # calculate gaussian
                   dist = torch.exp(-1*torch.sum(
                       torch.pow(spatial_emb - center, 2)*s, 0, keepdim=True))
    

    According to Equation 5 in the paper, shouldn't this be: dist = torch.exp(-1*torch.sum( torch.pow(spatial_emb - center, 2)/(2*s**2), 0, keepdim=True))? Am I missing something? Thank you.

    opened by shikunyu8 4
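
A possible reading of the quoted code, offered as an assumption rather than the authors' stated intent: if the network output is treated as a log-scale precision, so that s stands for 1/(2σ²) instead of σ, then multiplying by s matches the form of Equation 5. The sketch below checks this numerically in plain Python for a single squared distance; the function names and example values are made up for illustration.

```python
import math


def gaussian_paper(sq_dist, sigma):
    """exp(-d^2 / (2 * sigma^2)), with d^2 the squared embedding-to-center distance."""
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def gaussian_code(sq_dist, raw):
    """Form used in the quoted snippet: s = exp(10 * raw), then exp(-d^2 * s)."""
    s = math.exp(10.0 * raw)  # always positive, whatever the sign of raw
    return math.exp(-sq_dist * s)


sigma = 0.05
# Raw network value that would correspond to this sigma under the assumption
# s = 1 / (2 * sigma^2).
raw = math.log(1.0 / (2.0 * sigma ** 2)) / 10.0

for sq_dist in (0.0, 0.001, 0.01):
    assert math.isclose(gaussian_paper(sq_dist, sigma), gaussian_code(sq_dist, raw))
```

The exponential keeps s positive without any clamping, which is one practical reason to predict a log-scale quantity instead of σ directly.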
  • Regarding the Lovasz Hinge loss

    Hello,

    I have more of a question here rather than an issue. I was curious about this line of code. As I understand, dist is a variable that lies between 0 (pixel embedding is very far from the instance center) and 1 (pixel embedding lies atop the instance center). Also by that logic,

    -1 <= dist*2 - 1 <= 1
    

    Could you share any intuition on why this dist*2-1 is preferred instead of just using dist, as the first argument to the lovasz hinge class? Does it have better convergence properties for the loss in your opinion, for example? Thank you!

    opened by lmanan 3
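
One way to see what the rescaling does, assuming the hinge error used inside the Lovasz hinge has the usual form 1 - logit * sign, with sign = +1 for pixels inside the instance and -1 outside. The helper below is a hypothetical illustration, not code from the repository.

```python
def hinge_error(logit, is_foreground):
    """Per-pixel hinge error 1 - logit * sign, clipped at zero."""
    sign = 1.0 if is_foreground else -1.0
    return max(0.0, 1.0 - logit * sign)


for dist in (0.0, 0.5, 1.0):
    rescaled = 2.0 * dist - 1.0  # maps [0, 1] onto [-1, 1]
    print(
        dist,
        hinge_error(rescaled, True), hinge_error(rescaled, False),  # with dist*2-1
        hinge_error(dist, True), hinge_error(dist, False),          # with raw dist
    )
```

With the rescaled input, a perfectly classified pixel on either side reaches zero error and the decision boundary sits at dist = 0.5. With the raw dist, a background pixel keeps an error of at least 1 even when dist is 0.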
  • Computing var_loss

    Hello Davy @davyneven

    I was wondering if instead of saying here:

     var_loss = var_loss + torch.mean(torch.pow(sigma_in - s.detach(), 2))
    

    One should rather say:

     var_loss = var_loss + torch.mean(torch.pow(sigma_in - s[..., 0].detach(), 2))
    

    I suggest this because sigma_in is of shape 2 x N and s is of shape 2 x 1 x 1, and subtracting two tensors of different shapes could lead to strange consequences (maybe?).

    For example:

    >>> import numpy as np
    >>> import torch
    >>> sigma_in = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]])
    >>> sigma_in = torch.from_numpy(sigma_in)  # shape is 2 x 3
    >>> sigma_in
    tensor([[1., 2., 3.],
            [2., 4., 6.]], dtype=torch.float64)
    >>> s = sigma_in.mean(1).view(2, 1, 1)  # shape is 2 x 1 x 1
    >>> s
    tensor([[[2.]],
            [[4.]]], dtype=torch.float64)
    >>> result = sigma_in - s.detach()  # broadcasts to shape 2 x 2 x 3
    >>> result
    tensor([[[-1.,  0.,  1.],
             [ 0.,  2.,  4.]],
            [[-3., -2., -1.],
             [-2.,  0.,  2.]]], dtype=torch.float64)
    >>> result_edited = sigma_in - s[..., 0].detach()  # shape is 2 x 3
    >>> result_edited
    tensor([[-1.,  0.,  1.],
            [-2.,  0.,  2.]], dtype=torch.float64)
    

    Just wanted to ask if the current way is intended. Wouldn't we want the correct margin bandwidth dimension to be subtracted instead of all (2 x 2) subtractions? Thank you!

    opened by lmanan 2
  • Post-processing (Clustering) Slow

    Thank you for publishing the code!

    I ran the test using your pretrained model on the car class but found out that the avg post-processing time is about 391 ms per image on a Titan Xp GPU which is much slower than the number reported in the paper. May I ask if I missed anything? Thanks!

    opened by rayhou0710 2
  • Did you use multi-scale test?

    Hi,

    Could you clarify if you used single scale or multi scale test to get the 27.6 AP on Cityscapes dataset? I could not find details about it in the paper.

    Thanks!

    opened by bowenc0221 2
  • seed loss after downsampling

    Hi @davyneven, thanks for your work. I am having trouble obtaining the seed map after the model downsamples the input image to 1/4 of its original size. I read through the paper. Should I also change the loss function for the seed map?

    Thanks!

    opened by jyang68sh 1
  • Regarding variable `xym`

    Hello,

    Thank you for an excellent implementation and publication. I am learning quite a lot looking at your code and how you packaged this project. :+1:

    One question I wanted to run by you concerns the line of code that creates the state dict entry xym. I suspect that the cat order should be reversed, since this variable is later accessed as [channel, height, width]. My suggestion is that:

    # coordinate map
    xm = torch.linspace(0, 2, 2048).view(1, 1, -1).expand(1, 1024, 2048)
    ym = torch.linspace(0, 1, 1024).view(1, -1, 1).expand(1, 1024, 2048)
    xym = torch.cat((xm, ym), 0)
    

    should become

    # coordinate map
    xm = torch.linspace(0, 2, 2048).view(1, 1, -1).expand(1, 1024, 2048)
    ym = torch.linspace(0, 1, 1024).view(1, -1, 1).expand(1, 1024, 2048)
    xym = torch.cat((ym, xm), 0)
    

    and this would lead to equivalent changes in this line of code as well. I might be interpreting this completely wrongly, but just wanted to check with you. Thank you for your time!

    opened by lmanan 1
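
The layout in question can be checked on a small grid. The sketch below uses a 4 x 8 map instead of 1024 x 2048 (the variables h and w and their values are chosen only for illustration) and keeps the original cat order. It suggests that the result is already laid out as [channel, height, width], and that the cat order only decides whether channel 0 holds the x or the y coordinate.

```python
import torch

h, w = 4, 8  # small stand-ins for 1024 and 2048

# Same construction as the quoted snippet, on a smaller grid.
xm = torch.linspace(0, 2, w).view(1, 1, -1).expand(1, h, w)  # varies along width
ym = torch.linspace(0, 1, h).view(1, -1, 1).expand(1, h, w)  # varies along height
xym = torch.cat((xm, ym), 0)

print(xym.shape)     # torch.Size([2, 4, 8]) -> [channel, height, width]
print(xym[0, 0, :])  # channel 0: x coordinate, runs from 0 to 2 along a row
print(xym[1, :, 0])  # channel 1: y coordinate, runs from 0 to 1 down a column
```

Whichever order is used, the code that adds predicted offsets to xym has to interpret its channels in the same (x, y) or (y, x) order.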
  • Training problem

    Hi, I have a problem while running train.py. Do you know how to solve it?

    
    lr_scheduler.py:122: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
      "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
    learning rate: 0.0005
      0%|                                                                  | 0/187 [00:03<?, ?it/s]
    Traceback (most recent call last):
      File "train.py", line 187, in <module>
        train_loss = train(epoch)
      File "train.py", line 108, in train
        loss = criterion(output, instances, class_labels, **args['loss_w'])
      File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
        output.reraise()
      File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/_utils.py", line 394, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
        output = module(*input, **kwargs)
      File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
        result = self.forward(*input, **kwargs)
      File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/my_loss.py", line 101, in forward
        lovasz_hinge(dist*2-1, in_mask)
      File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 90, in lovasz_hinge
        for log, lab in zip(logits, labels))
      File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 231, in mean
        acc = next(l)
      File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 90, in <genexpr>
        for log, lab in zip(logits, labels))
      File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 112, in lovasz_hinge_flat
        grad = lovasz_grad(gt_sorted)
      File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 26, in lovasz_grad
        union = gts.float() + (1 - gt_sorted).float().cumsum(0)
      File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/tensor.py", line 394, in __rsub__
        return _C._VariableFunctions.rsub(self, other)
    RuntimeError: Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.
    
    
    opened by ftlong6666 1
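
The last frame of the traceback points at 1 - gt_sorted being evaluated on a boolean tensor, which the installed PyTorch version rejects, as the error message states. A common workaround, sketched below under the assumption that gt_sorted holds the sorted ground-truth labels as booleans, is to cast to float before subtracting. The function body is a reconstruction for illustration and may differ from the repository's lovasz_losses.py.

```python
import torch


def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension w.r.t. sorted errors (sketch).

    gt_sorted: 1-D tensor of ground-truth labels (bool or 0/1), sorted by
    decreasing error.
    """
    gt = gt_sorted.float()  # cast first, so `1 - gt` is float arithmetic
    p = len(gt)
    gts = gt.sum()
    intersection = gts - gt.cumsum(0)
    union = gts + (1 - gt).cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:  # turn cumulative Jaccard values into per-position increments
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


grad = lovasz_grad(torch.tensor([True, True, False, True]))
print(grad)
```

An alternative with the same effect is to convert the mask to float at the call site, before it is passed to lovasz_hinge.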
  • How to choose the value of n_sigma

    It looks like a one-channel sigma map works well, and from the code we can see that the sigma map can have multiple channels according to n_sigma. Does this make a difference, and how should one choose a proper value?

    opened by Croooooow 1
  • How to visualize offset vectors?

    I want to visualize the predicted offset vectors as shown in Figure 2, but there is no description of how the offset vectors were visualized. I'm studying the paper and the code carefully, and I would really appreciate your help.

    opened by gymoon10 0
  • sigma different with paper

    https://github.com/davyneven/SpatialEmbeddings/blob/213b5fa9900513818365d98e7db9f2067633d94d/src/criterions/my_loss.py#L93 Why is the sigma calculated here different from Equation 5 in the paper?

    opened by linhaoqi027 1
  • Correct number of workers for val_dataset

    Carried out a very minor code update here in train.py:

    Correctly read num_workers from the val_dataset dictionary (instead of the train_dataset dictionary, as it was earlier). The new line of code is here:

    val_dataset_it = torch.utils.data.DataLoader(val_dataset, batch_size=args['val_dataset']['batch_size'], shuffle=True, drop_last=True, num_workers=args['val_dataset']['workers'], pin_memory=True if args['cuda'] else False)

    opened by lmanan 0
  • Ablation Experiments

    I'm trying to reproduce the ablation experiments, but the results are not good.

    1. These experiments were done with a single-class model, right? If so, should I use the cropped dataset obtained by generate_crops.py? For example, when I train the person class, should I first train with the (512,512) cropped dataset (OBJ_ID=26) and then train with the (1024,1024) cropped dataset (OBJ_ID=26)?

    2. How is the cluster_with_gt() function used? Is it used for the ablation experiments?

    Any help would be great, thanks!!

    opened by TakeruIto 0
  • Best Loss Weight

    Hi, in the config file, the loss weights are 'w_inst':1,'w_var':1,'w_seed':10. When training this way, I can't reproduce your results on cars because my seed map is poor.

    Did you modify the weights during training? For example, first set the weights to 1, 1, 10 to optimize the seed map, then set them to 1, 10, 1 to optimize the sigma map.

    Thank you very much!

    opened by charlotte12l 3
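
For context, the three weights are presumably combined as a weighted sum of the instance, variance and seed terms. The sketch below illustrates that assumption with made-up loss values; total_loss is a hypothetical helper, and only the key names w_inst, w_var and w_seed come from the config quoted above.

```python
def total_loss(instance_loss, var_loss, seed_loss, w_inst=1, w_var=1, w_seed=10):
    """Weighted sum of the three loss terms (assumed combination rule)."""
    return w_inst * instance_loss + w_var * var_loss + w_seed * seed_loss


# Made-up per-term values, only to show how the weights shift the balance.
terms = dict(instance_loss=0.40, var_loss=0.05, seed_loss=0.02)

default = total_loss(**terms, w_inst=1, w_var=1, w_seed=10)
var_heavy = total_loss(**terms, w_inst=1, w_var=10, w_seed=1)
print(default, var_heavy)
```

Under this assumption, changing a weight rescales the gradient of its term relative to the others, so a two-stage weight schedule like the one asked about would shift which branch dominates the updates.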