Regularized Frank-Wolfe for Dense CRFs: Generalizing Mean Field and Beyond

Overview

CRF - Conditional Random Fields

A library for dense conditional random fields (CRFs).

This is the official accompanying code for the paper Regularized Frank-Wolfe for Dense CRFs: Generalizing Mean Field and Beyond published at NeurIPS 2021 by Đ.Khuê Lê-Huu and Karteek Alahari. Please cite this paper if you use any part of this code, using the following BibTeX entry:

@inproceedings{lehuu2021regularizedFW,
  title={Regularized Frank-Wolfe for Dense CRFs: Generalizing Mean Field and Beyond},
  author={L\^e-Huu, \DJ.Khu\^e and Alahari, Karteek},
  booktitle={Neural Information Processing Systems (NeurIPS)},
  year={2021}
}

Currently the code is messy and undocumented, and we apologize for that. We will make an effort to fix this soon. To facilitate maintenance, the code and pre-trained models for the semantic segmentation task will be made available in a separate repository.

Installation

git clone https://github.com/netw0rkf10w/CRF.git
cd CRF
python setup.py install
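If `import CRF` fails after installation, it can help to ask Python where, if anywhere, it finds a top-level module named CRF, without actually importing it. This is a small standard-library sketch. `locate` is an illustrative name and is not part of this package. Keep in mind that the script's directory (or the current directory in an interactive session) comes first on `sys.path`, so the result can depend on where Python is started.

```python
import importlib.util


def locate(name):
    """Return where Python would import `name` from, or None if it cannot be found."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None  # not importable from the current sys.path
    if spec.origin not in (None, "namespace"):
        return spec.origin  # path to the module file or the package's __init__.py
    # Namespace packages have no single origin; report their search locations instead.
    return list(spec.submodule_search_locations or [])


print(locate("CRF"))  # a path (or list of paths) if installed, None otherwise
```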

Usage

After installing the package, you can create a CRF layer as follows:

import CRF

params = CRF.FrankWolfeParams(scheme='fixed', # constant stepsize
            stepsize=1.0,
            regularizer='l2',
            lambda_=1.0, # regularization weight
            lambda_learnable=False,
            x0_weight=0.5, # useful for training, set to 0 if inference only
            x0_weight_learnable=False)

crf = CRF.DenseGaussianCRF(classes=21,
                alpha=160,
                beta=0.05,
                gamma=3.0,
                spatial_weight=1.0,
                bilateral_weight=1.0,
                compatibility=1.0,
                init='potts',
                solver='fw',
                iterations=5,
                params=params)

Detailed documentation on the available options will be added later.
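The paper's title refers to generalizing mean field. As we read it, a regularized Frank-Wolfe step replaces the usual linear minimization over the probability simplex with argmin_s <grad E(x), s> + lambda * r(s) and then moves with a stepsize: x <- x + stepsize * (s - x). With a negative-entropy regularizer, lambda = 1 and stepsize 1, this is exactly one parallel mean-field update. With regularizer='l2', the step becomes a Euclidean projection onto the simplex. The NumPy sketch below illustrates this on a tiny CRF with an explicit n x n pairwise matrix. It is not the library's implementation, which works on images with Gaussian kernels. The helper names, the "entropy" option name and the energy convention are our own assumptions.

```python
import numpy as np


def softmax(v):
    # Row-wise softmax, shifted for numerical stability.
    z = v - v.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def project_simplex(v):
    # Euclidean projection of each row of v onto the probability simplex
    # (sort-based algorithm); this solves the l2-regularized step.
    n, k = v.shape
    u = np.sort(v, axis=1)[:, ::-1]
    css = np.cumsum(u, axis=1) - 1.0
    idx = np.arange(1, k + 1)
    support = u - css / idx > 0
    rho = k - np.argmax(support[:, ::-1], axis=1)  # support size per row
    theta = css[np.arange(n), rho - 1] / rho
    return np.maximum(v - theta[:, None], 0.0)


def regularized_fw(unary, kernel, compat, regularizer="l2", lambda_=1.0,
                   stepsize=1.0, iterations=5):
    """Toy regularized Frank-Wolfe for a small dense CRF (hypothetical helper).

    unary:  (n, L) unary potentials, lower is better
    kernel: (n, n) symmetric pairwise weights with zero diagonal
    compat: (L, L) symmetric label compatibility, e.g. Potts = 1 - eye(L)
    """
    x = softmax(-unary)  # initial soft labelling
    for _ in range(iterations):
        # Gradient of E(x) = <unary, x> + 1/2 * sum_ij k_ij x_i^T C x_j
        grad = unary + kernel @ x @ compat
        # Regularized linear-minimization step: argmin_s <grad, s> + lambda_ * r(s)
        if regularizer == "entropy":
            s = softmax(-grad / lambda_)
        elif regularizer == "l2":
            s = project_simplex(-grad / lambda_)
        else:
            raise ValueError(f"unknown regularizer: {regularizer}")
        x = x + stepsize * (s - x)  # fixed-stepsize update
    return x


# Toy usage: 6 "pixels", 3 labels, random symmetric affinities, Potts compatibility.
rng = np.random.default_rng(0)
unary = rng.normal(size=(6, 3))
w = rng.random((6, 6))
kernel = (w + w.T) / 2
np.fill_diagonal(kernel, 0.0)
potts = 1.0 - np.eye(3)
x_l2 = regularized_fw(unary, kernel, potts, regularizer="l2", lambda_=1.0)
labels = x_l2.argmax(axis=1)
```

Calling the same function with regularizer="entropy", lambda_=1.0 and stepsize=1.0 reproduces the classic mean-field iteration. Other regularizers or stepsizes give different members of the same family.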

Below is an example of how to use this layer together with a CNN, in the form of a simple CNN-CRF module:

import torch

class CNNCRF(torch.nn.Module):
    """
    Simple CNN-CRF model
    """
    def __init__(self, cnn, crf):
        super().__init__()
        self.cnn = cnn
        self.crf = crf

    def forward(self, x):
        """
        x is a batch of input images
        """
        logits = self.cnn(x)
        logits = self.crf(x, logits)
        return logits

# Create a CNN-CRF model from given `cnn` and `crf`
# This is a PyTorch module that can be used in a usual way
model = CNNCRF(cnn, crf)
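The CRF layer receives the input images x as well as the logits. In the standard dense Gaussian CRF of Krähenbühl and Koltun, the pairwise term combines two kernels. A bilateral (appearance) kernel depends on pixel positions and colours, and a spatial (smoothness) kernel depends on positions only. The image is therefore needed to compute the pairwise weights. The naive NumPy sketch below makes several assumptions: alpha and gamma are the positional bandwidths (in pixels), beta is the colour bandwidth (in the image's own intensity units), and spatial_weight and bilateral_weight scale the two kernels. We have not checked this mapping against the library's source, and all helper names are hypothetical. The library approximates this filtering with a permutohedral lattice; building the full matrix, as done here, is only feasible for tiny images.

```python
import numpy as np


def dense_crf_features(image, alpha, beta, gamma):
    """Per-pixel feature vectors for the two Gaussian kernels (hypothetical helper).

    image: (C, H, W) array, e.g. RGB. Returns (bilateral, spatial) features
    with shapes (H*W, 2 + C) and (H*W, 2), pixels in row-major order.
    """
    c, h, w = image.shape
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    pos = np.stack([ys, xs], axis=-1).reshape(-1, 2).astype(float)
    colors = image.reshape(c, -1).T.astype(float)
    bilateral = np.concatenate([pos / alpha, colors / beta], axis=1)
    spatial = pos / gamma
    return bilateral, spatial


def gaussian_kernel(features):
    # k_ij = exp(-||f_i - f_j||^2 / 2), with zero diagonal (no self-interaction).
    sq = ((features[:, None, :] - features[None, :, :]) ** 2).sum(axis=-1)
    k = np.exp(-0.5 * sq)
    np.fill_diagonal(k, 0.0)
    return k


def dense_kernel(image, alpha, beta, gamma, spatial_weight=1.0, bilateral_weight=1.0):
    # Naive O((H*W)^2) pairwise matrix: weighted sum of the bilateral and spatial kernels.
    bilateral, spatial = dense_crf_features(image, alpha, beta, gamma)
    return (bilateral_weight * gaussian_kernel(bilateral)
            + spatial_weight * gaussian_kernel(spatial))


# Tiny random "image" with values in [0, 1], using the parameter values from the example above.
img = np.random.default_rng(0).random((3, 4, 5))
K = dense_kernel(img, alpha=160, beta=0.05, gamma=3.0)
```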

Acknowledgements

The CUDA implementation of the permutohedral lattice is due to https://github.com/MiguelMonteiro/permutohedral_lattice. An initial version of our permutohedral layer was based on https://github.com/Fettpet/pytorch-crfasrnn.

Comments
  • Forward pass issue

    Thanks for sharing the great work.

    I have some questions regarding the self.crf(x, logits) usage in your tutorial code.

    Is x a normalized image tensor, or should it be in the range 0-255?

    Also, is the pre-trained CRF available by any chance? Thanks a lot.

    opened by WeiChihChern 9
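The README does not state whether the layer expects images in [0, 1] or in [0, 255]. Suppose the colour term has the usual Gaussian form exp(-||I_i - I_j||^2 / (2 * beta^2)). Then multiplying every intensity by a factor c is equivalent to multiplying beta by c. For instance, beta=0.05 on [0, 1] images corresponds to beta = 0.05 * 255 = 12.75 on [0, 255] images. This plain-Python check assumes that kernel form; the function names are illustrative only.

```python
import math


def color_affinity(i1, i2, beta):
    # Colour factor of a Gaussian bilateral kernel: exp(-||I_i - I_j||^2 / (2 beta^2)).
    sq = sum((a - b) ** 2 for a, b in zip(i1, i2))
    return math.exp(-sq / (2.0 * beta ** 2))


def rescale_beta(beta, scale):
    # Multiplying every intensity by `scale` is equivalent to multiplying beta by `scale`.
    return beta * scale


p, q = (0.20, 0.40, 0.60), (0.25, 0.35, 0.60)  # two RGB colours in [0, 1]
p255 = tuple(255 * v for v in p)
q255 = tuple(255 * v for v in q)
beta01 = 0.05
beta255 = rescale_beta(beta01, 255)  # about 12.75
print(math.isclose(color_affinity(p, q, beta01), color_affinity(p255, q255, beta255)))  # True
```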
  • The losses remain the same

    import torch

    class CNNCRF(torch.nn.Module):
        """
        Simple CNN-CRF model
        """
        def __init__(self, cnn, crf):
            super().__init__()
            self.cnn = cnn
            self.crf = crf

        def forward(self, x):
            """
            x is a batch of input images
            """
            logits = self.cnn(x)
            logits = self.crf(x, logits)
            return logits

    # Create a CNN-CRF model from given `cnn` and `crf`
    # This is a PyTorch module that can be used in a usual way
    model = CNNCRF(cnn, crf)

    First I trained a UNet and saved the model; then I loaded the trained UNet and trained the UNet and the CRF together. I found that the loss got stuck at 0.693147. Do you have any suggestions?

    opened by 18972441546 5
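For reference, 0.693147 is ln 2. This is exactly the cross-entropy of a two-class prediction that puts probability 0.5 on each class, whatever the target. A loss frozen at this value therefore usually indicates that the outputs have collapsed to a uniform prediction, for example equal logits for both classes. A quick check in plain Python; cross_entropy is an illustrative helper and is not part of this package.

```python
import math


def cross_entropy(logits, target):
    # Numerically stable log-softmax followed by the negative log-likelihood of `target`.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


# Equal logits mean a uniform 0.5 / 0.5 prediction.
print(round(cross_entropy([0.0, 0.0], 0), 6))  # 0.693147
print(round(math.log(2), 6))  # 0.693147
```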
  • Installation Issue

    Hi, thanks for the great work.

    I tried to install the package via python3 setup.py install; however, when I import CRF I get "No module named 'CRF'".

    Do you have any insight regarding this? Thanks.

    Here is what the CRF-0.0.01-py3.6-linux-x86_64.egg directory contains on my machine: [screenshot]

    opened by WeiChihChern 5
  • Support for distributed training

    Hello, thank you for this amazing work!

    I want to use this CRF package for distributed training. Is it OK to use it to train a segmentation model across many GPUs or processes? For example, does it contain any operations, like batch normalization, that need to perform computations across the batch?

    opened by Ending2015a 2
  • how to deal with the size of x and logits

    import torch

    class CNNCRF(torch.nn.Module):
        """
        Simple CNN-CRF model
        """
        def __init__(self, cnn, crf):
            super().__init__()
            self.cnn = cnn
            self.crf = crf

        def forward(self, x):
            """
            x is a batch of input images
            """
            logits = self.cnn(x)
            logits = self.crf(x, logits)
            return logits

    # Create a CNN-CRF model from given `cnn` and `crf`
    # This is a PyTorch module that can be used in a usual way
    model = CNNCRF(cnn, crf)

    Following your usage example: my input image has shape (1, 3, 512, 512) and the label has shape (1, 2, 512, 512), where the 2 classes include the background. The logits output by the CNN also have shape (1, 2, 512, 512). So tensors of shapes (1, 3, 512, 512) and (1, 2, 512, 512) are passed to self.crf. Is this correct?

    opened by 18972441546 1
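The README does not document the expected input shapes. Judging only from the usage example, where self.crf(x, logits) receives the CNN's input and output, a plausible assumption is that the image and the logits share batch size and spatial size, and that the number of logit channels equals the classes argument (21 in the README example, 2 in this question). This hypothetical pure-Python helper checks that assumption. It is not part of the library.

```python
def check_crf_inputs(image_shape, logits_shape, num_classes):
    """Hypothetical sanity check for the (image, logits) pair passed to the CRF layer.

    Assumes NCHW layouts: image (B, C_img, H, W) and logits (B, num_classes, H, W).
    """
    if len(image_shape) != 4 or len(logits_shape) != 4:
        raise ValueError("expected 4-D NCHW shapes")
    b_img, _, h_img, w_img = image_shape
    b_log, c_log, h_log, w_log = logits_shape
    if b_img != b_log:
        raise ValueError(f"batch size mismatch: {b_img} vs {b_log}")
    if (h_img, w_img) != (h_log, w_log):
        raise ValueError(f"spatial size mismatch: {(h_img, w_img)} vs {(h_log, w_log)}")
    if c_log != num_classes:
        raise ValueError(f"logits have {c_log} channels, CRF expects {num_classes} classes")
    return True


# Shapes from the question, with the CRF assumed to be built with classes=2.
print(check_crf_inputs((1, 3, 512, 512), (1, 2, 512, 512), num_classes=2))  # True
```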
  • Use of FW CRF for post-processing benchmarking

    Hi, First of all, thanks for the impressive work.

    We are experimenting with a novel post-processing method (aiming at CVPR) and would like to benchmark it against Frank-Wolfe dense CRFs (FWCRF). To that end, we are not sure whether we can use FWCRF directly as is (following Section 5.2) for inference on Pascal VOC (as defined below, with alpha, beta and gamma adjusted according to Section E.1), or whether we need to train it first on the training data of the respective dataset.

    fw_params = frankwolfecrf.FrankWolfeParams(scheme='fixed',  # constant stepsize
                                               stepsize=1.0,
                                               regularizer='l2',
                                               lambda_=1.0,  # regularization weight
                                               lambda_learnable=False,
                                               x0_weight=0,
                                               x0_weight_learnable=False)

    fw_crf = frankwolfecrf.DenseGaussianCRF(classes=21,
                                            alpha=80,
                                            beta=13,
                                            gamma=3,
                                            spatial_weight=1.0,
                                            bilateral_weight=1.0,
                                            compatibility=1.0,
                                            init='potts',
                                            solver='fw',
                                            iterations=5,
                                            params=fw_params)
    prediction_fwcrf = fw_crf(images, base_model_logits)
    

    Thanks a lot in advance for getting back to us.

    Best wishes, Lukas

    opened by lukaszbinden 3
Owner
Đ.Khuê Lê-Huu
[cvpr22] Perturbed and Strict Mean Teachers for Semi-supervised Semantic Segmentation

PS-MT [cvpr22] Perturbed and Strict Mean Teachers for Semi-supervised Semantic Segmentation by Yuyuan Liu, Yu Tian, Yuanhong Chen, Fengbei Liu, Vasile

Yuyuan Liu 132 Jan 3, 2023
Official code for "Mean Shift for Self-Supervised Learning"

MSF Official code for "Mean Shift for Self-Supervised Learning" Requirements Python >= 3.7.6 PyTorch >= 1.4 torchvision >= 0.5.0 faiss-gpu >= 1.6.1 In

UMBC Vision 44 Nov 21, 2022
A pytorch implementation of MBNET: MOS PREDICTION FOR SYNTHESIZED SPEECH WITH MEAN-BIAS NETWORK

Pytorch-MBNet A pytorch implementation of MBNET: MOS PREDICTION FOR SYNTHESIZED SPEECH WITH MEAN-BIAS NETWORK Training To train a new model, please ru

null 46 Dec 28, 2022
Unet network with mean teacher for altrasound image segmentation

null 5 Nov 21, 2022
Home repository for the Regularized Greedy Forest (RGF) library. It includes original implementation from the paper and multithreaded one written in C++, along with various language-specific wrappers.

Regularized Greedy Forest Regularized Greedy Forest (RGF) is a tree ensemble machine learning method described in this paper. RGF can deliver better r

RGF-team 364 Dec 28, 2022
Code for the paper "Adversarially Regularized Autoencoders (ICML 2018)" by Zhao, Kim, Zhang, Rush and LeCun

ARAE Code for the paper "Adversarially Regularized Autoencoders (ICML 2018)" by Zhao, Kim, Zhang, Rush and LeCun https://arxiv.org/abs/1706.04223 Disc

Junbo (Jake) Zhao 399 Jan 2, 2023
Two-Stage Peer-Regularized Feature Recombination for Arbitrary Image Style Transfer

Two-Stage Peer-Regularized Feature Recombination for Arbitrary Image Style Transfer Paper on arXiv Public PyTorch implementation of two-stage peer-reg

NNAISENSE 38 Oct 14, 2022
(IEEE TIP 2021) Regularized Densely-connected Pyramid Network for Salient Instance Segmentation

RDPNet IEEE TIP 2021: Regularized Densely-connected Pyramid Network for Salient Instance Segmentation PyTorch training and testing code are available.

Yu-Huan Wu 41 Oct 21, 2022
Disagreement-Regularized Imitation Learning

Due to a normalization bug the expert trajectories have lower performance than the rl_baseline_zoo reported experts. Please see the following link in

Kianté Brantley 25 Apr 28, 2022
R-Drop: Regularized Dropout for Neural Networks

R-Drop: Regularized Dropout for Neural Networks R-drop is a simple yet very effective regularization method built upon dropout, by minimizing the bidi

null 756 Dec 27, 2022
This repository is the official implementation of Unleashing the Power of Contrastive Self-Supervised Visual Models via Contrast-Regularized Fine-Tuning (NeurIPS21).

Core-tuning This repository is the official implementation of ``Unleashing the Power of Contrastive Self-Supervised Visual Models via Contrast-Regular

vanint 18 Dec 17, 2022
Graph Regularized Residual Subspace Clustering Network for hyperspectral image clustering

Yaoming Cai 5 Jul 18, 2022
Flexible-CLmser: Regularized Feedback Connections for Biomedical Image Segmentation

Flexible-CLmser: Regularized Feedback Connections for Biomedical Image Segmentation The skip connections in U-Net pass features from the levels of enc

Boheng Cao 1 Dec 29, 2021
Code for the paper: On Pathologies in KL-Regularized Reinforcement Learning from Expert Demonstrations

Non-Parametric Prior Actor-Critic (N-PPAC) This repository contains the code for On Pathologies in KL-Regularized Reinforcement Learning from Expert D

Cong Lu 5 May 13, 2022
Official Pytorch implementation of "Beyond Static Features for Temporally Consistent 3D Human Pose and Shape from a Video", CVPR 2021

TCMR: Beyond Static Features for Temporally Consistent 3D Human Pose and Shape from a Video Qualtitative result Paper teaser video Introduction This r

Hongsuk Choi 215 Jan 6, 2023
Code Repo for the ACL21 paper "Common Sense Beyond English: Evaluating and Improving Multilingual LMs for Commonsense Reasoning"

Common Sense Beyond English: Evaluating and Improving Multilingual LMs for Commonsense Reasoning This is the Github repository of our paper, "Common S

INK Lab @ USC 19 Nov 30, 2022
Scenic: A Jax Library for Computer Vision and Beyond

Scenic Scenic is a codebase with a focus on research around attention-based models for computer vision. Scenic has been successfully used to develop c

Google Research 1.6k Dec 27, 2022
BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond

BasicVSR BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond Ported from https://github.com/xinntao/BasicSR Dependencie

Holy Wu 8 Jun 7, 2022
Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context Code in both PyTorch and TensorFlow

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context This repository contains the code in both PyTorch and TensorFlow for our paper

Zhilin Yang 3.3k Jan 6, 2023