Cross-framework Python Package for Evaluation of Latent-based Generative Models
Latte
Latte (for LATent Tensor Evaluation) is a cross-framework Python package for the evaluation of latent-based generative models. Latte supports the calculation of disentanglement and controllability metrics in both PyTorch (via TorchMetrics) and TensorFlow (via Keras metrics).
Installation
For developers working on a local clone, cd to the repo and replace latte-metrics with . (a single dot) in the commands below. For example, pip install .[tests]
pip install latte-metrics # core (numpy only)
pip install latte-metrics[pytorch] # with torchmetrics wrapper
pip install latte-metrics[keras] # with tensorflow wrapper
pip install latte-metrics[tests] # for testing
Running tests locally
pip install .[tests]
pytest tests/ --cov=latte
Examples
Functional API
import latte
from latte.functional.disentanglement.mutual_info import mig
import numpy as np
latte.seed(42)
z = np.random.randn(16, 8)
a = np.random.randn(16, 2)
mutual_info_gap = mig(z, a, discrete=False, reg_dim=[4, 3])
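The mig call above returns the Mutual Information Gap (MIG). As a rough guide to what that number measures, here is a minimal pure-Python sketch of the standard MIG definition: for each attribute, take the gap between the two latent dimensions with the highest mutual information with that attribute, divide it by the attribute's entropy, and average over attributes. This is not latte's implementation. The equal-width binning, the n_bins=10 default, and the helper names (discretize, entropy, mutual_info, mig_sketch) are our own assumptions, and latte-specific options such as reg_dim are not modeled.

```python
import math
import random
from collections import Counter


def discretize(values, n_bins=10):
    """Map a 1-D list of floats to integer bin ids using equal-width bins (assumed scheme)."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0] * len(values)
    width = (hi - lo) / n_bins
    # min(..., n_bins - 1) puts the maximum value into the last bin instead of a new one
    return [min(int((v - lo) / width), n_bins - 1) for v in values]


def entropy(labels):
    """Plug-in Shannon entropy (nats) of a discrete label sequence."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def mutual_info(x, y):
    """Plug-in mutual information I(X; Y) in nats between two discrete label sequences."""
    n = len(x)
    pxy, px, py = Counter(zip(x, y)), Counter(x), Counter(y)
    # p(x,y) / (p(x) p(y)) simplifies to c * n / (count(x) * count(y))
    return sum((c / n) * math.log(c * n / (px[a] * py[b])) for (a, b), c in pxy.items())


def mig_sketch(z, a, n_bins=10):
    """Mutual Information Gap for z (samples x latents) and a (samples x attributes).

    Needs at least two latent dimensions. Both z and a are discretized,
    i.e. attributes are treated as continuous (like discrete=False above).
    """
    n_latents, n_attrs = len(z[0]), len(a[0])
    z_cols = [discretize([s[j] for s in z], n_bins) for j in range(n_latents)]
    gaps = []
    for k in range(n_attrs):
        a_col = discretize([s[k] for s in a], n_bins)
        h = entropy(a_col)
        if h == 0:
            continue  # constant attribute: normalization undefined, skip it
        mis = sorted((mutual_info(zc, a_col) for zc in z_cols), reverse=True)
        gaps.append((mis[0] - mis[1]) / h)
    return sum(gaps) / len(gaps) if gaps else float("nan")


# Toy data: latent 0 tracks the single attribute, latents 1 and 2 are pure noise.
random.seed(42)
a = [[random.gauss(0, 1)] for _ in range(500)]
z = [[s[0] + 0.05 * random.gauss(0, 1), random.gauss(0, 1), random.gauss(0, 1)] for s in a]
z_noise = [[random.gauss(0, 1) for _ in range(3)] for _ in range(500)]
print(mig_sketch(z, a), mig_sketch(z_noise, a))
```

With one latent tracking the attribute, the score lands well above the noise-only score. Plug-in estimates like these are biased upward for small samples or many bins, so absolute values from this sketch should not be compared directly with latte's output.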
Modular API
import latte
from latte.metrics.core.disentanglement import MutualInformationGap
import numpy as np
latte.seed(42)
mig = MutualInformationGap()
# ...
# initialize data and model
# ...
for data, attributes in dataloader:  # hypothetical iterable yielding (data, attributes) batches
    recon, z = model(data)
    mig.update_state(z, attributes)
mig_val = mig.compute()
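In the modular example, update_state is called once per batch and compute once at the end. Our reading, which is an assumption rather than something stated here, is that the metric object keeps every batch and evaluates the metric on the pooled samples. That matters because metrics estimated from the whole sample, such as MIG, are generally not the average of per-batch values. The hypothetical BufferedMetric class and pearson_dim0 toy metric below sketch this accumulate-then-compute pattern; they are not latte classes.

```python
import math


def pearson_dim0(z, a):
    """Toy metric (hypothetical): Pearson correlation of latent dim 0 with attribute 0."""
    x = [s[0] for s in z]
    y = [s[0] for s in a]
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y))
    sx = math.sqrt(sum((xi - mx) ** 2 for xi in x))
    sy = math.sqrt(sum((yi - my) ** 2 for yi in y))
    return cov / (sx * sy)


class BufferedMetric:
    """Hypothetical stateful wrapper: buffer every batch, evaluate once on the pooled samples."""

    def __init__(self, fn):
        self.fn = fn
        self.reset_state()

    def reset_state(self):
        self._z, self._a = [], []

    def update_state(self, z_batch, a_batch):
        # extend (not append) so the buffers stay flat lists of samples
        self._z.extend(z_batch)
        self._a.extend(a_batch)

    def compute(self):
        if not self._z:
            raise ValueError("compute() called before any update_state()")
        return self.fn(self._z, self._a)


# Two batches of (latents, attributes); each batch is a list of per-sample lists.
batches = [
    ([[0.0], [1.0]], [[0.0], [2.0]]),
    ([[2.0], [3.0]], [[4.0], [5.0]]),
]
metric = BufferedMetric(pearson_dim0)
for z_batch, a_batch in batches:
    metric.update_state(z_batch, a_batch)
pooled = metric.compute()
per_batch_mean = sum(pearson_dim0(zb, ab) for zb, ab in batches) / len(batches)
print(pooled, per_batch_mean)
```

Each two-sample batch here is perfectly correlated, so averaging per-batch scores gives 1.0, while the pooled score is about 0.99. The same caveat applies to histogram-based estimates such as MIG.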
TorchMetrics API
import latte
from latte.metrics.torch.disentanglement import MutualInformationGap
import torch
latte.seed(42)
mig = MutualInformationGap()
# ...
# initialize data and model
# ...
for data, attributes in dataloader:  # hypothetical iterable yielding (data, attributes) batches
    recon, z = model(data)
    mig.update(z, attributes)
mig_val = mig.compute()
Keras Metric API
import latte
from latte.metrics.keras.disentanglement import MutualInformationGap
from tensorflow import keras as tfk
latte.seed(42)
mig = MutualInformationGap()
# ...
# initialize data and model
# ...
for data, attributes in dataloader:  # hypothetical iterable yielding (data, attributes) batches
    recon, z = model(data)
    mig.update_state(z, attributes)
mig_val = mig.result()
Documentation
https://latte.readthedocs.io/en/latest
Supported metrics
Metric | Latte Functional | Latte Modular | TorchMetrics | Keras Metric |
---|---|---|---|---|
Disentanglement Metrics | ||||
Interpolatability Metrics | ||||
Bundled metric modules
Metric Bundle | Latte Functional | Latte Modular | TorchMetrics | Keras Metric | Included |
---|---|---|---|---|---|
Dependency-aware Disentanglement | | | | | MIG, DMIG, XMIG, DLIG |
LIAD-based Interpolatability | | | | | Smoothness, Monotonicity |
Cite
For individual metrics, please cite the original paper that proposed each metric.
If you find our package useful, please cite our repository and arXiv preprint as:
@article{watcharasupat2021latte,
  author = {Watcharasupat, Karn N. and Lee, Junyoung and Lerch, Alexander},
  title = {{Latte: Cross-framework Python Package for Evaluation of Latent-based Generative Models}},
  eprint = {2112.10638},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG},
  url = {https://github.com/karnwatcharasupat/latte},
  doi = {10.5281/zenodo.5786402}
}