I follow monotonic attention here: https://arxiv.org/pdf/1704.00784.pdf.
In tensorflow, it work well. (source code here: https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py. )
But in pytorch, it cannot work. Here is my source code. Could you take a look, please?
def safe_cumprod(x, exclusive=False, max_value=1):
"""
exclusive=True: cumprod(x) = [1, x1, x1*x2, x1*x2*x3, ...]
exclusive=False: cumprod(x) = [x1, x1*x2, x1*x2*x3, ...]
Args:
x (torch.Tensor): shape of [batch, input_dim]
exclusive ():
max_value (): clip max value
Returns:
"""
tiny = float(np.finfo(np.float32).tiny)
clip_x = torch.clamp(x, tiny, max_value)
cumprod_x = torch.exp(torch.cumsum(torch.log(clip_x), dim=1))
if exclusive is True:
return F.pad(cumprod_x, (1, 0, 0, 0), value=1)[:, :-1]
else:
return cumprod_x
class BahdanauAttention(nn.Module):
def __init__(self, dim):
super(BahdanauAttention, self).__init__()
self.query_layer = nn.Linear(dim, dim, bias=False)
self.tanh = nn.Tanh()
self.v = Parameter(torch.Tensor(1, dim))
self.reset_parameters()
def reset_parameters(self):
fan_in, fan_out = self.v.size()
scale = 1 / max(1., (fan_in + fan_out) / 2.)
limit = math.sqrt(3.0 * scale)
self.v.data.uniform_(-limit, limit)
def _alignment_probability(self, score, previous_alignment=None):
return F.softmax(score, dim=1)
def forward(self, query, processed_memory):
"""
Args:
query: (batch, 1, dim) or (batch, dim)
processed_memory: (batch, max_time, dim)
"""
if query.dim() == 2:
# insert time-axis for broadcasting
query = query.unsqueeze(1)
# (batch, 1, dim)
processed_query = self.query_layer(query)
# (batch, max_time, 1)
alignment = F.linear(self.tanh(processed_query + processed_memory), self.v)
# (batch, max_time)
return alignment.squeeze(-1)
class BahdanauMonoAttention(BahdanauAttention):
"""BahdanauMonoAttention
"""
def __init__(self, dim):
super(BahdanauMonoAttention, self).__init__(dim)
self.score_bias = Parameter(torch.Tensor(1))
self.reset_parameters()
def reset_parameters(self):
self.score_bias.data.zero_()
def forward(self, query, processed_memory):
return super(BahdanauMonoAttention, self).forward(query, processed_memory) + self.score_bias
def _alignment_probability(self, score, previous_alignment=None):
"""
_mono_score, https://arxiv.org/pdf/1704.00784.pdf
Args:
score (): shape of [batch, encoder_length]
previous_alignment (): shape of [batch, encoder_length]
Returns:
"""
#score += Variable(torch.FloatTensor(np.random.randn(*score.shape) * 2).cuda())
p_choose_i = F.sigmoid(score)
cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, exclusive=True, max_value=1)
attention = p_choose_i * cumprod_1mp_choose_i * torch.cumsum(
previous_alignment / torch.clamp(cumprod_1mp_choose_i, 1e-10, 1.), dim=1)
return attention
def get_mask_from_lengths(memory, memory_lengths):
"""Get mask tensor from list of length
Args:
memory: (batch, max_time, dim)
memory_lengths: array like
"""
mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
for idx, l in enumerate(memory_lengths):
mask[idx][:l] = 1
return ~mask
class AttentionWrapper(nn.Module):
def __init__(self, rnn_cell, attention_mechanism,
score_mask_value=-float("inf")):
super(AttentionWrapper, self).__init__()
self.rnn_cell = rnn_cell
self.attention_mechanism = attention_mechanism
self.score_mask_value = score_mask_value
def forward(self, query, attention, cell_state, memory, previous_alignment=None,
processed_memory=None, mask=None, memory_lengths=None):
if processed_memory is None:
processed_memory = memory
if memory_lengths is not None and mask is None:
mask = get_mask_from_lengths(memory, memory_lengths)
# Concat input query and previous attention context
cell_input = torch.cat((query, attention), -1)
# Feed it to RNN
cell_output = self.rnn_cell(cell_input, cell_state)
# Alignment
# (batch, max_time)
alignment = self.attention_mechanism(cell_output, processed_memory)
if mask is not None:
mask = mask.view(query.size(0), -1)
alignment.data.masked_fill_(mask, self.score_mask_value)
# Normalize attention weight
# alignment = F.softmax(alignment, dim=-1)
alignment = self.attention_mechanism._alignment_probability(alignment, previous_alignment)
# Attention context vector
# (batch, 1, dim)
attention = torch.bmm(alignment.unsqueeze(1), memory)
# (batch, dim)
attention = attention.squeeze(1)
return cell_output, attention, alignment
class Decoder(nn.Module):
def __init__(self, in_dim, r, use_mono=True):
super(Decoder, self).__init__()
self.in_dim = in_dim
self.r = r
self.prenet = Prenet(in_dim, sizes=[256, 128])
# (prenet_out + attention context) -> output
if use_mono is True:
attention_mechanism = BahdanauMonoAttention(256)
else:
attention_mechanism = BahdanauAttention(256)
self.attention_rnn = AttentionWrapper(
nn.GRUCell(256 + 128, 256),
attention_mechanism
)
self.memory_layer = nn.Linear(256, 256, bias=False)
self.project_to_decoder_in = nn.Linear(512, 256)
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
self.proj_to_mel = nn.Linear(256, in_dim * r)
self.max_decoder_steps = 200
def forward(self, encoder_outputs, inputs=None, memory_lengths=None):
"""
Decoder forward step.
If decoder inputs are not given (e.g., at testing time), as noted in
Tacotron paper, greedy decoding is adapted.
Args:
encoder_outputs: Encoder outputs. (B, T_encoder, dim)
inputs: Decoder inputs. i.e., mel-spectrogram. If None (at eval-time),
decoder outputs are used as decoder inputs.
memory_lengths: Encoder output (memory) lengths. If not None, used for
attention masking.
"""
B = encoder_outputs.size(0)
T_encoder = encoder_outputs.size(1)
processed_memory = self.memory_layer(encoder_outputs)
if memory_lengths is not None:
mask = get_mask_from_lengths(processed_memory, memory_lengths)
else:
mask = None
# Run greedy decoding if inputs is None
greedy = inputs is None
if inputs is not None:
# Grouping multiple frames if necessary
if inputs.size(-1) == self.in_dim:
inputs = inputs.view(B, inputs.size(1) // self.r, -1)
assert inputs.size(-1) == self.in_dim * self.r
T_decoder = inputs.size(1)
# go frames
initial_input = Variable(
encoder_outputs.data.new(B, self.in_dim).zero_())
# Init decoder states
attention_rnn_hidden = Variable(
encoder_outputs.data.new(B, 256).zero_())
decoder_rnn_hiddens = [Variable(
encoder_outputs.data.new(B, 256).zero_())
for _ in range(len(self.decoder_rnns))]
current_attention = Variable(
encoder_outputs.data.new(B, 256).zero_())
# Time first (T_decoder, B, in_dim)
if inputs is not None:
inputs = inputs.transpose(0, 1)
outputs = []
alignments = []
t = 0
current_input = initial_input
previous_alignment = Variable(
encoder_outputs.data.new(B, T_encoder).zero_())
previous_alignment[:, 0] = 1.0
while True:
if t > 0:
current_input = outputs[-1] if greedy else inputs[t - 1]
current_input = current_input[:, -self.in_dim:]
# Prenet
current_input = self.prenet(current_input)
# Attention RNN
attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
current_input, current_attention, attention_rnn_hidden,
encoder_outputs, previous_alignment=previous_alignment,
processed_memory=processed_memory, mask=mask)
previous_alignment = alignment
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((attention_rnn_hidden, current_attention), -1))
# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
decoder_input, decoder_rnn_hiddens[idx])
# Residual connectinon
decoder_input = decoder_rnn_hiddens[idx] + decoder_input
output = decoder_input
output = self.proj_to_mel(output)
outputs += [output]
alignments += [alignment]
t += 1
if greedy:
if t > 1 and is_end_of_frames(output):
break
elif t > self.max_decoder_steps:
print("Warning! doesn't seems to be converged")
break
else:
if t >= T_decoder:
break
assert greedy or len(outputs) == T_decoder
# Back to batch first
alignments = torch.stack(alignments).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
return outputs, alignments
@r9y9
wontfix discussion