# Beta Divergence Loss - PyTorch Implementation
This repository provides a PyTorch implementation of the beta-divergence loss.
## Dependencies
This package is written in Python and requires Python (version >= 3.9 recommended) to run. In addition to a working PyTorch installation, it relies on the following libraries and version numbers:
## Installation
To install the latest stable release, use pip:
```bash
$ pip install pytorch-beta-divergence
```
## Usage
The `nn.py` module contains two beta-divergence implementations: a general beta-divergence between two 2-dimensional matrices or tensors, and a beta-divergence specific to non-negative matrix factorization (NMF). Import both implementations as follows:
```python
# Import PyTorch beta-divergence implementations
from torch_beta_div.nn import *
```
### Beta-divergence between two matrices
To calculate the beta-divergence between a matrix `A` and a target or reference matrix `B`, use the `BetaDivLoss` loss function. It can be instantiated and used as follows:
```python
# Instantiate beta-divergence loss object
beta_div_loss = BetaDivLoss(beta=0, reduction='mean')

# Calculate beta-divergence loss between matrix A and target matrix B
loss = beta_div_loss(input=A, target=B)
```
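For context, the following is a minimal end-to-end sketch. The tensor names, shapes, and the import-by-name are illustrative assumptions, not part of the documented API:

```python
import torch
from torch_beta_div.nn import BetaDivLoss  # assumes BetaDivLoss can be imported by name

# Illustrative non-negative input and target matrices (shapes chosen arbitrarily);
# the small offset keeps entries strictly positive, which is safer for beta=0
A = torch.rand(64, 32) + 0.1  # input matrix
B = torch.rand(64, 32) + 0.1  # target / reference matrix

# beta=0 corresponds to the Itakura-Saito divergence (see "Choosing beta value" below)
beta_div_loss = BetaDivLoss(beta=0, reduction='mean')
loss = beta_div_loss(input=A, target=B)
print(loss.item())
```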
### NMF beta-divergence between data matrix and reconstruction
To calculate the NMF-specific beta-divergence between a data matrix `X` and the matrix product of a scores matrix `H` and a components matrix `W`, use the `NMFBetaDivLoss` loss function. It can be instantiated and used as follows:
```python
# Instantiate NMF beta-divergence loss object
nmf_beta_div_loss = NMFBetaDivLoss(beta=0, reduction='mean')

# Calculate beta-divergence loss between data matrix X (target or
# reference matrix) and matrix product of H and W
loss = nmf_beta_div_loss(X=X, H=H, W=W)
```
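A rough sketch of how the shapes fit together in an NMF setting follows. The dimensions, the non-negativity construction, and the import-by-name are assumptions made for illustration:

```python
import torch
from torch_beta_div.nn import NMFBetaDivLoss  # assumes NMFBetaDivLoss can be imported by name

n_samples, n_features, n_components = 100, 50, 10

# Non-negative data matrix X and non-negative factors H (scores) and W (components),
# so that the reconstruction H @ W has the same shape as X
X = torch.rand(n_samples, n_features)
H = torch.rand(n_samples, n_components, requires_grad=True)
W = torch.rand(n_components, n_features, requires_grad=True)

nmf_beta_div_loss = NMFBetaDivLoss(beta=1, reduction='mean')  # beta=1: Kullback-Leibler divergence

# Loss between X and its reconstruction H @ W; gradients flow back to H and W
loss = nmf_beta_div_loss(X=X, H=H, W=W)
loss.backward()
```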
### Choosing beta value
When instantiating a beta-divergence loss object, choose the value of `beta` based on the data type and application. Integer values of `beta` correspond to the following divergences and loss functions:
- `beta = 0`: Itakura-Saito divergence
- `beta = 1`: Kullback-Leibler divergence
- `beta = 2`: mean-squared error
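For reference, the element-wise beta-divergence between an entry `x` of the input and the corresponding entry `y` of the target is commonly defined as follows; the exact reduction over elements used by this package (e.g. summing versus averaging) may differ:

$$
d_\beta(x \mid y) =
\begin{cases}
\frac{x^\beta + (\beta - 1)\, y^\beta - \beta\, x\, y^{\beta - 1}}{\beta (\beta - 1)} & \beta \neq 0, 1 \\
x \log \frac{x}{y} - x + y & \beta = 1 \\
\frac{x}{y} - \log \frac{x}{y} - 1 & \beta = 0
\end{cases}
$$

For `beta = 2` this reduces to $\frac{1}{2}(x - y)^2$, i.e. squared error up to a constant factor, which is why it is listed above as mean-squared error.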
## Issue Tracking and Reports
Please use the GitHub issue tracker associated with this repository to file bug reports and ask general questions about the package or project.