Hi,
Thanks for your excellent work.
I found the batch-padding implementation in your code confusing. The code is:
```python
def pad_batch(h_node, batch, max_input_len, get_mask=False):
    num_batch = batch[-1] + 1
    num_nodes = []
    masks = []
    for i in range(num_batch):
        mask = batch.eq(i)
        masks.append(mask)
        num_node = mask.sum()
        num_nodes.append(num_node)
    # logger.info(max(num_nodes))
    max_num_nodes = min(max(num_nodes), max_input_len)
    padded_h_node = h_node.data.new(max_num_nodes, num_batch, h_node.size(-1)).fill_(0)
    src_padding_mask = h_node.data.new(num_batch, max_num_nodes).fill_(0).bool()

    for i, mask in enumerate(masks):
        num_node = num_nodes[i]
        if num_node > max_num_nodes:
            num_node = max_num_nodes
        padded_h_node[-num_node:, i] = h_node[mask][-num_node:]
        src_padding_mask[i, : max_num_nodes - num_node] = True  # [b, s]

    if get_mask:
        return padded_h_node, src_padding_mask, num_nodes, masks, max_num_nodes
    return padded_h_node, src_padding_mask
```
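For concreteness, here is a hypothetical toy call (I am assuming the PyG-style convention where `batch` holds the graph index of each node; the shapes and numbers are made up):

```python
import torch

# Toy input: two graphs with 3 and 2 nodes, hidden dim 8.
h_node = torch.randn(5, 8)             # [total_num_nodes, hidden_dim]
batch = torch.tensor([0, 0, 0, 1, 1])  # graph index of each node

padded_h_node, src_padding_mask = pad_batch(h_node, batch, max_input_len=3)
print(padded_h_node.shape)  # torch.Size([3, 2, 8])
print(src_padding_mask)
# tensor([[False, False, False],   <- graph 0 fills max_num_nodes: all False
#         [ True, False, False]])  <- graph 1: one padded slot in front
```

Note that graph 0's mask row is entirely False, which is exactly the case that breaks the pooling below.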
I think the masking line `src_padding_mask[i, : max_num_nodes - num_node] = True` should instead be:
```python
src_padding_mask = h_node.data.new(num_batch, max_num_nodes).fill_(1).bool()
src_padding_mask[i, : max_num_nodes - num_node] = False
```
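With that change, each mask row would hold True exactly at the valid-node positions, so summing it counts real nodes. A small sanity check on the toy example above (numbers are illustrative):

```python
import torch

# With the proposed fill_(1)/False change, True marks valid nodes:
fixed_mask = torch.tensor([[True,  True, True],   # graph 0: 3 valid nodes
                           [False, True, True]])  # graph 1: 2 valid nodes
print(fixed_mask.sum(-1, keepdim=True))  # tensor([[3], [2]]) -- never zero
```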
This is because, in the pooling part, the original code can make the denominator in this line zero:
```python
h_graph = transformer_out.sum(0) / src_padding_mask.sum(-1, keepdim=True)
```
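A minimal sketch of the failure (shapes are hypothetical; `transformer_out` just stands in for the encoder output with the same leading dims as `padded_h_node`):

```python
import torch

# Mask from the toy example above: graph 0 has no True entries.
src_padding_mask = torch.tensor([[False, False, False],
                                 [True,  False, False]])
transformer_out = torch.randn(3, 2, 8)  # [seq_len, batch, hidden_dim]

denom = src_padding_mask.sum(-1, keepdim=True)
print(denom)                       # tensor([[0], [1]]) -- zero for graph 0
h_graph = transformer_out.sum(0) / denom
print(torch.isinf(h_graph).any())  # tensor(True): graph 0's embedding blows up
```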