PGMax
PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs) and hardware-accelerated, differentiable loopy belief propagation (LBP) in JAX.
- General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
- LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.
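The sketch below illustrates both ideas in plain JAX; it is not PGMax's API. It hand-codes parallel sum-product LBP in log space for a hypothetical loop of three variables with 2, 3, and 2 states, then applies vmap over a batch of evidence and grad with respect to the log potentials. The names N_STATES, EDGES, and run_lbp are illustrative assumptions. In PGMax, such functions are generated from a factor graph specification rather than written by hand.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical toy model (NOT the PGMax API): three variables with different
# numbers of states, joined in a loop by pairwise factors.
N_STATES = (2, 3, 2)
EDGES = ((0, 1), (1, 2), (0, 2))


def neighbors(i):
    """Variables that share a pairwise factor with variable i."""
    return [b if a == i else a for (a, b) in EDGES if i in (a, b)]


def factor_table(log_pots, i, j):
    """Log-potential table indexed [x_i, x_j], whichever way the edge is stored."""
    if (i, j) in log_pots:
        return log_pots[(i, j)]
    return log_pots[(j, i)].T


def run_lbp(log_pots, evidence, num_iters=20):
    """Parallel sum-product LBP in log space; returns per-variable marginals."""
    # One message per directed edge, initialized to uniform (zeros in log space).
    msgs = {(i, j): jnp.zeros(N_STATES[j])
            for (a, b) in EDGES for (i, j) in ((a, b), (b, a))}
    for _ in range(num_iters):
        new_msgs = {}
        for (i, j) in msgs:
            # Evidence at i plus all incoming messages except the one from j.
            incoming = evidence[i] + sum(msgs[(k, i)] for k in neighbors(i) if k != j)
            m = logsumexp(incoming[:, None] + factor_table(log_pots, i, j), axis=0)
            new_msgs[(i, j)] = m - logsumexp(m)  # normalize for numerical stability
        msgs = new_msgs
    beliefs = [evidence[i] + sum(msgs[(k, i)] for k in neighbors(i))
               for i in range(len(N_STATES))]
    return [jax.nn.softmax(b) for b in beliefs]


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
log_pots = {e: jax.random.normal(k, (N_STATES[e[0]], N_STATES[e[1]]))
            for e, k in zip(EDGES, jax.random.split(k1, len(EDGES)))}
evidence = [jnp.zeros(n) for n in N_STATES]
marginals = run_lbp(log_pots, evidence)

# vmap: run LBP on a batch of 4 evidence settings with shared potentials.
batch_evidence = [jax.random.normal(k, (4, n))
                  for k, n in zip(jax.random.split(k2, len(N_STATES)), N_STATES)]
batch_marginals = jax.vmap(run_lbp, in_axes=(None, 0))(log_pots, batch_evidence)


# grad: differentiate through all LBP iterations w.r.t. the log potentials.
def loss(lp):
    return -jnp.log(run_lbp(lp, evidence)[1][0])


grads = jax.grad(loss)(log_pots)
```

Because run_lbp is an ordinary function of JAX arrays, vmap and grad apply to it directly. Unrolling a fixed number of Python-loop iterations keeps the sketch short; for many iterations, jax.lax.fori_loop or jax.lax.scan would keep compile times down.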
See our blog post and companion paper for more details.
Installation
Install from PyPI
pip install pgmax
Install latest version from GitHub
pip install git+https://github.com/vicariousinc/PGMax.git
Developer
git clone https://github.com/vicariousinc/PGMax.git
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python3 -
cd PGMax
poetry shell
poetry install
pre-commit install
Install on GPU
By default, the commands above install JAX for CPU. If you have access to a GPU, follow the official JAX installation instructions to install JAX with GPU support.
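One quick way to confirm which backend JAX picked up after installation is sketched below. It uses only standard JAX calls (jax.default_backend and jax.devices) and does not involve PGMax itself.

```python
import jax
import jax.numpy as jnp

# "cpu" means the default CPU-only install is active; "gpu" (or "tpu")
# means an accelerator-enabled jaxlib was found.
backend = jax.default_backend()
print("JAX backend:", backend)
print("Devices:", jax.devices())

# A tiny jitted computation placed on the default device.
x = jax.jit(lambda a: (a * 2.0).sum())(jnp.arange(4.0))
print(float(x))  # 12.0
```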
Getting Started
Here are a few self-contained Colab notebooks to help you get started on using PGMax:
- Tutorial on basic PGMax usage
- Implementing max-product LBP for Recursive Cortical Networks
- End-to-end differentiable LBP for gradient-based PGM training
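To give a flavor of the max-product variant that the RCN notebook uses, here is a minimal sketch of max-product message passing (max-sum in log space) on a hypothetical four-variable chain. It is neither the notebook's code nor PGMax's API; N_STATES, max_product_chain, and brute_force_map are illustrative names. On a tree such as a chain, max-product is exact, so the decoded assignment can be checked against brute-force enumeration. On loopy graphs the same kind of update gives only an approximate answer.

```python
import itertools

import jax
import jax.numpy as jnp

# Hypothetical chain model (NOT the PGMax API) with varying numbers of states.
N_STATES = (3, 2, 4, 2)


def max_product_chain(unaries, pairwise):
    """Max-sum on a chain; decodes each variable by argmax of its max-marginal.

    unaries[i] has shape (n_i,); pairwise[i] has shape (n_i, n_{i+1}).
    """
    n = len(unaries)
    fwd = [jnp.zeros_like(u) for u in unaries]  # messages arriving from the left
    bwd = [jnp.zeros_like(u) for u in unaries]  # messages arriving from the right
    for i in range(1, n):
        fwd[i] = jnp.max((unaries[i - 1] + fwd[i - 1])[:, None] + pairwise[i - 1], axis=0)
    for i in range(n - 2, -1, -1):
        bwd[i] = jnp.max(pairwise[i] + (unaries[i + 1] + bwd[i + 1])[None, :], axis=1)
    beliefs = [u + f + b for u, f, b in zip(unaries, fwd, bwd)]
    return [int(jnp.argmax(b)) for b in beliefs]


def brute_force_map(unaries, pairwise):
    """Exact MAP by enumerating every joint assignment (fine for tiny models)."""
    def score(xs):
        s = sum(unaries[i][x] for i, x in enumerate(xs))
        s += sum(pairwise[i][xs[i], xs[i + 1]] for i in range(len(xs) - 1))
        return float(s)
    return list(max(itertools.product(*(range(len(u)) for u in unaries)), key=score))


keys = jax.random.split(jax.random.PRNGKey(0), 2 * len(N_STATES) - 1)
unaries = [jax.random.normal(keys[i], (n,)) for i, n in enumerate(N_STATES)]
pairwise = [jax.random.normal(keys[len(N_STATES) + i], (N_STATES[i], N_STATES[i + 1]))
            for i in range(len(N_STATES) - 1)]

map_lbp = max_product_chain(unaries, pairwise)
map_exact = brute_force_map(unaries, pairwise)
print(map_lbp, map_exact)
```

The only change from sum-product is replacing logsumexp with max when marginalizing over a neighbor's states; decoding then takes the argmax of each variable's max-marginal belief.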
Citing PGMax
Please consider citing our companion paper if you use PGMax in your work:
@article{zhou2022pgmax,
author = {Zhou, Guangyao and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
journal = {arXiv preprint arXiv:2202.04110},
year = {2022}
}
The first two authors contributed equally.