The calculation of the loss function is deceptive: it lets the model fit very quickly and appear to produce very good translations. Unfortunately this is not real, and the model cannot spell out the German sentence by itself. It can only complete an i-token German sentence when it is given the first (i-1) gold tokens, which means it cannot generate a whole sentence from just a start-of-sentence tag.
It seems reasonable enough to use this loss to train the network, but unreasonable to use it to assess the network's translation ability, though I have yet to train this network to its full potential.
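To make the pitfall concrete, here is a tiny sketch of what that per-step evaluation actually asks of the model (the token ids and special-token ids are made up for illustration):

import torch

SOS, EOS = 1, 2                                  # assumed special-token ids, for illustration only
german = torch.tensor([[SOS, 11, 42, 7, EOS]])   # made-up ids standing in for "<sos> Ich bin müde <eos>"
for i in range(german.shape[1] - 1):
    prefix = german[:, :i+1]                     # the gold prefix of length i+1
    target = german[:, i+1]                      # the single next token the model has to guess
    # The model never has to recover from its own earlier mistakes, which is
    # why the loss can look very good without any real generation ability.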
###################
## Original leave-last-token-out decoder.
## Not sure what the exact error in this calculation is,
## but maybe it is because the model sees the mask token directly?
##
# Output German, one token at a time
all_outs = torch.tensor([], requires_grad=True).to(device)
for i in range(item["german"].shape[1] - 1):
    out = model(item["german"][:, :i+1])          # feed the gold prefix of length i+1
    all_outs = torch.cat((all_outs, out), dim=1)  # collect the per-step predictions
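If the problem really is that the model sees tokens it should not, the standard fix in a transformer decoder is a causal (subsequent-position) mask. Whether model() already applies one internally is not visible from this snippet, so the following is only a sketch of the usual pattern:

import torch

def causal_mask(seq_len, device):
    # True above the diagonal means "future position: do not attend to it"
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)

# Inside the attention layer, masked positions would get -inf before the softmax, e.g.:
# attention_weights = attention_weights.masked_fill(causal_mask(L, device), float("-inf"))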
# ###################
# My variation of the leave-last-token-out decoder, used at training
# output_vocab_size = german_vocab_len
g = item["german"].shape
x = torch.zeros([g[0], g[1]], dtype=torch.long).to(device)   # starts as all zeros; gold tokens are revealed one position per step
all_outs = torch.tensor([], requires_grad=True).to(device)
for i in range(item["german"].shape[1] - 1):
    xx = torch.zeros([g[0], g[1]], dtype=torch.long).to(device)
    out = model(x)                                # predict before revealing the next gold token
    xx[:, i:i+1] = item["german"][:, i:i+1]       # copy in the gold token at position i
    x = x + xx                                    # reveal it for the following step
    all_outs = torch.cat((all_outs, out), dim=1)
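For completeness, this is roughly how all_outs would feed a cross-entropy loss. It is only a sketch, not the actual training loop: it assumes all_outs ends up as [batch, seq_len-1, vocab], that each step's output lines up with the next gold German token, and that 0 is the padding id.

import torch.nn.functional as F

targets = item["german"][:, 1:]                       # assumed alignment: step i predicts token i+1
loss = F.cross_entropy(
    all_outs.reshape(-1, all_outs.size(-1)),          # [batch * (seq_len-1), vocab]
    targets.reshape(-1),                              # [batch * (seq_len-1)]
    ignore_index=0,                                   # assuming 0 is the padding/mask id
)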
# ###################
# My variation of a beam-search-style decoder (strictly speaking it is greedy,
# i.e. beam width 1: each step commits to the single most likely token)
model.encode(item["english"][:, 1:-1])
g = item["german"].shape
x = torch.zeros([g[0], g[1]], dtype=torch.long).to(device)
all_outs = torch.tensor([], requires_grad=True).to(device)
for i in range(item["german"].shape[1] - 1):
    out = model(x)                                # predict from the model's own previous guesses
    x[:, i:i+1] = out.argmax(axis=-1)             # write the predicted token back into the input
    all_outs = torch.cat((all_outs, out), dim=1)
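The same loop can also run without the gold German entirely, which is exactly the ability questioned at the top of this section. This is a sketch under a few assumptions: a start-of-sentence id sos_id, a chosen max_len, and model(x) returning next-token logits exactly as in the loop above.

max_len = 50                                          # assumed maximum output length
sos_id = 1                                            # assumed id of the start-of-sentence tag
model.encode(item["english"][:, 1:-1])
x = torch.zeros([item["english"].shape[0], max_len], dtype=torch.long).to(device)
x[:, 0] = sos_id
for i in range(1, max_len):
    out = model(x)                                    # assumed to return next-token logits, as above
    x[:, i:i+1] = out.argmax(axis=-1)                 # greedily commit to the most likely token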
I found this glitch while fiddling with the attention layer at its core: zeroing out the attention weights did no harm to the performance of a last-token-only model. In
sub_layers.py:
attention_weights = F.softmax(attention_weights, dim=2)
attention_weights = attention_weights * 0.   ## Try this!
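For context, those two lines would sit inside something like the following scaled dot-product attention. This is a guess at the surrounding code in sub_layers.py; only the two lines quoted above are from the actual file.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # q, k, v: [batch, seq_len, d_k] (assumed shapes)
    d_k = q.size(-1)
    attention_weights = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5
    attention_weights = F.softmax(attention_weights, dim=2)
    attention_weights = attention_weights * 0.   # the experiment: every position attends to nothing
    return torch.matmul(attention_weights, v)    # all zeros, so any remaining performance cannot come from attention

If the model still scores well with the attention output forced to zero, whatever it relies on presumably flows through the residual and feed-forward paths alone, which would support the suspicion that the last-token-only evaluation is simply too easy.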