ssnt-loss
ℹ️ This is a WIP project. The implementation is still being tested.
A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction" https://arxiv.org/abs/1609.08194.
Usage
There are two versions: a normal version and a memory-efficient version. They should give the same output; please inform me if they don't (a sanity-check sketch follows the minimal example below).
def ssnt_loss_mem(
log_probs: Tensor,
targets: Tensor,
log_p_choose: Tensor,
source_lengths: Tensor,
target_lengths: Tensor,
neg_inf: float = -1e4,
reduction="mean",
):
"""The memory efficient implementation concatenates along the targets
dimension to reduce wasted computation on padding positions.
Assuming the summation of all targets in the batch is T_flat, then
the original B x T x ... tensor is reduced to T_flat x ...
The input tensors can be obtained by using target mask:
Example:
>>> target_mask = targets.ne(pad) # (B, T)
>>> targets = targets[target_mask] # (T_flat,)
>>> log_probs = log_probs[target_mask] # (T_flat, S, V)
Args:
log_probs (Tensor): Word prediction log-probs; should be the output of log_softmax.
Tensor with shape (T_flat, S, V),
where T_flat is the summation of all target lengths,
S is the maximum number of input frames, and V is
the size of the label vocabulary.
targets (Tensor): Tensor with shape (T_flat,) representing the
reference target labels for all samples in the minibatch.
log_p_choose (Tensor): Emission log-probs; should be the output of F.logsigmoid.
Tensor with shape (T_flat, S),
where T_flat is the summation of all target lengths
and S is the maximum number of input frames.
source_lengths (Tensor): Tensor with shape (N,) representing the
number of frames for each sample in the minibatch.
target_lengths (Tensor): Tensor with shape (N,) representing the
length of the transcription for each sample in the minibatch.
neg_inf (float, optional): The constant representing -inf used for masking.
Default: -1e4
reduction (string, optional): Specifies the reduction to apply. Supports "mean" / "sum".
Default: "mean".
"""
Minimal example
import torch
import torch.nn as nn
import torch.nn.functional as F
from ssnt_loss import ssnt_loss_mem, lengths_to_padding_mask
B, S, H, T, V = 2, 100, 256, 10, 2000
# model
transcriber = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
predictor = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
joiner_trans = nn.Linear(H, V, bias=False).cuda()
joiner_alpha = nn.Sequential(
nn.Linear(H, 1, bias=True),
nn.Tanh()
).cuda()
# inputs
src_embed = torch.rand(B, S, H).cuda().requires_grad_()
tgt_embed = torch.rand(B, T, H).cuda().requires_grad_()
targets = torch.randint(0, V, (B, T)).cuda()
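# rescale the random lengths so that the longest sample spans the full S (resp. T)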
adjust = lambda x, goal: x * goal // x.max()
source_lengths = adjust(torch.randint(1, S+1, (B,)).cuda(), S)
target_lengths = adjust(torch.randint(1, T+1, (B,)).cuda(), T)
# forward
src_feats, (h1, c1) = transcriber(src_embed.transpose(1, 0))
tgt_feats, (h2, c2) = predictor(tgt_embed.transpose(1, 0))
# memory efficient joint
mask = ~lengths_to_padding_mask(target_lengths)
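# mask: (B, T) bool, True at non-padded target positions; indexing with it flattens B x T into T_flat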
lattice = F.relu(
src_feats.transpose(0, 1).unsqueeze(1) + tgt_feats.transpose(0, 1).unsqueeze(2)
)[mask]
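# lattice: (T_flat, S, H) joint features for every (target, source) pair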
log_alpha = F.logsigmoid(joiner_alpha(lattice)).squeeze(-1)
lattice = joiner_trans(lattice).log_softmax(-1)
# normal ssnt loss
loss = ssnt_loss_mem(
lattice,
targets[mask],
log_alpha,
source_lengths=source_lengths,
target_lengths=target_lengths,
reduction="sum"
) / (B*T)
loss.backward()
print(loss.item())
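To sanity-check that the two versions give the same output, something along the following lines could be used. This is only a sketch: it assumes the normal (non-memory-efficient) version is exported as ssnt_loss and accepts the padded (B, T, S, V) / (B, T, S) tensors with otherwise identical arguments; check the package's actual API.
from ssnt_loss import ssnt_loss  # assumed export for the normal (padded) version

# recompute the joint lattice without flattening, keeping the padded shape
lattice_full = F.relu(
    src_feats.transpose(0, 1).unsqueeze(1) + tgt_feats.transpose(0, 1).unsqueeze(2)
)  # (B, T, S, H)
log_alpha_full = F.logsigmoid(joiner_alpha(lattice_full)).squeeze(-1)  # (B, T, S)
lattice_full = joiner_trans(lattice_full).log_softmax(-1)  # (B, T, S, V)
loss_full = ssnt_loss(
    lattice_full,
    targets,
    log_alpha_full,
    source_lengths=source_lengths,
    target_lengths=target_lengths,
    reduction="sum"
) / (B*T)
print(torch.allclose(loss_full, loss))  # expected: True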
Note
This implementation is based on the simplifying derivation proposed for monotonic attention, which uses parallelized cumsum and cumprod operations to compute the alignment. Given the similarity between SSNT and monotonic attention, the forward variable alpha(i, j) can be computed in the same way.
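For reference, here is a rough sketch of that recurrence in probability space (the actual implementation presumably works in log space, using the neg_inf constant for masking). p[i, j] denotes the probability of emitting target token i at source frame j and alpha is the forward variable; the function below is illustrative only and not part of the package's API.
import torch

# alpha[i, j] = p[i, j] * sum_{k <= j} alpha[i-1, k] * prod_{l=k}^{j-1} (1 - p[i, l])
# The sum over the source axis is vectorized with cumprod / cumsum; only the
# loop over target steps i remains sequential.
def forward_variable(p, eps=1e-10):
    T, S = p.shape
    alpha = p.new_zeros(T, S)
    for i in range(T):
        # exclusive cumprod along the source axis: cp[j] = prod_{l < j} (1 - p[i, l])
        cp = torch.cumprod(torch.cat([p.new_ones(1), 1.0 - p[i, :-1]]), dim=0)
        if i == 0:
            # base case: the "previous" emission sits just before the first frame
            prev = torch.cat([p.new_ones(1), p.new_zeros(S - 1)])
        else:
            prev = alpha[i - 1]
        alpha[i] = p[i] * cp * torch.cumsum(prev / cp.clamp(min=eps), dim=0)
    return alpha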
Feel free to contact me if there are bugs in the code.