Tensorflow implementation and notebooks for Implicit Maximum Likelihood Estimation

Overview

tf-imle

TensorFlow 2 and PyTorch implementation and Jupyter notebooks for Implicit Maximum Likelihood Estimation (I-MLE), proposed in the NeurIPS 2021 paper "Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions".

I-MLE is also available as a PyTorch library: https://github.com/uclnlp/torch-imle

Introduction

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear programming (ILP) solvers, as well as complex discrete probability distributions in standard deep learning architectures. The figure below illustrates the setting I-MLE was developed for. The upstream component is a standard neural network that maps some input to the input parameters of a discrete combinatorial optimization algorithm or a discrete probability distribution, depicted as the black box. In the forward pass, the discrete component is executed and its discrete output is fed into a downstream neural network. With I-MLE it is possible to estimate gradients of a loss function with respect to the input parameters of the discrete component; these gradients are used during backpropagation to update the parameters of the upstream neural network.

Illustration of the problem addressed by I-MLE

The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and possibly intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution: Vanilla MLE reduces the KL divergence between the current distribution and the empirical distribution. Since we do not have access to such an empirical distribution in our setting, we have to design surrogate empirical distributions, which we term target distributions. Here we propose two families of target distributions which are widely applicable and work well in practice.
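
To make the first ingredient concrete, here is a minimal sketch of Perturb-and-MAP in its simplest form, the Gumbel-max trick for a categorical distribution. It is not taken from this repository: the function names (`sample_gumbel`, `perturb_and_map`, `empirical_frequencies`) are hypothetical, and it uses only the Python standard library rather than TensorFlow. It adds independent Gumbel(0, 1) noise to each logit and returns the argmax, which yields a sample from the softmax distribution over the logits.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one standard Gumbel(0, 1) sample via inverse transform sampling."""
    u = rng.random()
    # clamp away from 0 and 1 so that both logarithms are finite
    u = min(max(u, 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def perturb_and_map(logits, rng):
    """Perturb each logit with Gumbel noise and return the argmax index (the MAP state)."""
    perturbed = [theta + sample_gumbel(rng) for theta in logits]
    return max(range(len(logits)), key=lambda i: perturbed[i])


def softmax(logits):
    """Numerically stable softmax, used here only as the reference distribution."""
    m = max(logits)
    exps = [math.exp(t - m) for t in logits]
    total = sum(exps)
    return [e / total for e in exps]


def empirical_frequencies(logits, num_samples, seed=0):
    """Estimate how often each index is selected by Perturb-and-MAP."""
    rng = random.Random(seed)
    counts = [0] * len(logits)
    for _ in range(num_samples):
        counts[perturb_and_map(logits, rng)] += 1
    return [c / num_samples for c in counts]


example_logits = [1.0, 0.0, -1.0]
freqs = empirical_frequencies(example_logits, num_samples=20000)
probs = softmax(example_logits)
```

With enough samples, `freqs` should be close to `probs`. For structured problems such as k-subset selection, the argmax is replaced by a call to the combinatorial solver, and the noise distribution can be chosen to suit the problem.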

Requirements:

TensorFlow 2 implementation:

  • tensorflow==2.3.0 or tensorflow-gpu==2.3.0
  • numpy==1.18.5
  • matplotlib==3.1.1
  • scikit-learn==0.24.1
  • tensorflow-probability==0.7.0

PyTorch implementation:

Example: I-MLE as a Layer

The following is an instance of I-MLE implemented as a Keras layer. In this class, the optimization problem is computing the k-subset configuration (selecting the k largest entries), the target distribution is based on perturbation-based implicit differentiation, and the Perturb-and-MAP noise perturbations are drawn from the sum-of-gamma distribution.

import tensorflow as tf
from tensorflow.keras import backend as K

class IMLESubsetkLayer(tf.keras.layers.Layer):
    
    def __init__(self, k, _tau=10.0, _lambda=10.0):
        super(IMLESubsetkLayer, self).__init__()
        # average number of 1s in a solution to the optimization problem
        self.k = k
        # the temperature at which we want to sample
        self._tau = _tau
        # the perturbation strength (here we use a target distribution based on perturbation-based implicit differentiation)
        self._lambda = _lambda  
        # the samples we store for the backward pass
        self.samples = None 
        
    @tf.function
    def sample_sum_of_gamma(self, shape):
        
        s = tf.map_fn(fn=lambda t: tf.random.gamma(shape, 1.0/self.k, self.k/t), 
                  elems=tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))   
        # now add the samples
        s = tf.reduce_sum(s, 0)
        # the log(m) term
        s = s - tf.math.log(10.0)
        # divide by k --> each s[c] has k samples whose sum is distributed as Gumbel(0, 1)
        s = self._tau * (s / self.k)

        return s
    
    @tf.function
    def sample_discrete_forward(self, logits): 
        self.samples = self.sample_sum_of_gamma(tf.shape(logits))
        gamma_perturbed_logits = logits + self.samples
        # gamma_perturbed_logits is the input to the combinatorial opt algorithm
        # the next two lines can be replaced by a custom black-box algorithm call
        threshold = tf.expand_dims(tf.nn.top_k(gamma_perturbed_logits, self.k, sorted=True)[0][:,-1], -1)
        y = tf.cast(tf.greater_equal(gamma_perturbed_logits, threshold), tf.float32)
        
        return y
    
    @tf.function
    def sample_discrete_backward(self, logits):     
        gamma_perturbed_logits = logits + self.samples
        # gamma_perturbed_logits is the input to the combinatorial opt algorithm
        # the next two lines can be replaced by a custom black-box algorithm call
        threshold = tf.expand_dims(tf.nn.top_k(gamma_perturbed_logits, self.k, sorted=True)[0][:,-1], -1)
        y = tf.cast(tf.greater_equal(gamma_perturbed_logits, threshold), tf.float32)
        return y
    
    @tf.custom_gradient
    def subset_k(self, logits, k):

        # sample discretely with perturb and map
        z_train = self.sample_discrete_forward(logits)
        # compute the top-k discrete values
        threshold = tf.expand_dims(tf.nn.top_k(logits, self.k, sorted=True)[0][:,-1], -1)
        z_test = tf.cast(tf.greater_equal(logits, threshold), tf.float32)
        # at training time we sample, at test time we take the argmax
        z_output = K.in_train_phase(z_train, z_test)
        
        def custom_grad(dy):

            # we perturb (implicit diff) and then reuse the stored noise sample for perturb-and-MAP
            map_dy = self.sample_discrete_backward(logits - (self._lambda*dy))
            # we now compute the gradients as the difference (I-MLE gradients)
            grad = tf.math.subtract(z_train, map_dy)
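            # return one value per input of subset_k: the I-MLE gradient for logits,
            # and k itself as a placeholder (k is a constant and receives no meaningful gradient)
            return grad, k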

        return z_output, custom_grad
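
The gradient computed in `custom_grad` can be traced by hand on a tiny example. The sketch below is an illustrative re-statement in plain Python, not part of the library: the helper names `top_k_indicator` and `imle_gradient` are hypothetical, the noise is set to zero to keep the result deterministic, and the solver picks exactly k entries (the layer above uses a threshold comparison, which can select more than k entries when scores tie).

```python
def top_k_indicator(scores, k):
    """Return a 0/1 list marking the k largest scores (the MAP state of the k-subset problem)."""
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    chosen = set(top)
    return [1.0 if i in chosen else 0.0 for i in range(len(scores))]


def imle_gradient(logits, noise, dy, k, lam):
    """I-MLE gradient estimate with a perturbation-based target distribution.

    logits: input parameters of the discrete component
    noise:  the noise sample drawn in the forward pass (reused for the target)
    dy:     gradient of the downstream loss with respect to the discrete output
    lam:    perturbation strength (plays the role of _lambda in the layer above)
    """
    # forward pass: MAP state of the noise-perturbed logits
    z = top_k_indicator([t + e for t, e in zip(logits, noise)], k)
    # target: move the logits against the downstream gradient, reuse the same noise
    z_target = top_k_indicator(
        [t - lam * g + e for t, g, e in zip(logits, dy, noise)], k
    )
    # the gradient estimate is the difference of the two MAP states
    return [a - b for a, b in zip(z, z_target)]


logits = [2.0, 1.0, 0.5, -1.0]
noise = [0.0, 0.0, 0.0, 0.0]   # zero noise, for a deterministic illustration only
dy = [1.0, 0.0, -1.0, 0.0]     # the loss would decrease with less of item 0 and more of item 2
grad = imle_gradient(logits, noise, dy, k=2, lam=10.0)
```

Here the forward pass selects items 0 and 1, while the target selects items 1 and 2, so `grad` is positive for item 0 and negative for item 2. A gradient-descent step on the logits therefore lowers the score of item 0 and raises that of item 2, which is the direction the downstream gradient asks for.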

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  series    = {Proceedings of Machine Learning Research},
  publisher = {{PMLR}},
  year      = {2021}
}
Owner
NEC Laboratories Europe
Research software developed at NEC Laboratories Europe