Hi,
I tried to apply this method to the linear attention of http://proceedings.mlr.press/v119/katharopoulos20a/katharopoulos20a.pdf with the following code:
# PermuteFormer - P: apply the position-dependent feature permutation to q and k
q = q.gather(-1, self.permutation[:, :, :q.shape[2]].expand_as(q))
k = k.gather(-1, self.permutation[:, :, :k.shape[2]].expand_as(k))
# Apply the feature map to the queries and keys
Q = torch.nn.functional.elu(q) + 1
K = torch.nn.functional.elu(k) + 1
# PermuteFormer - r: scale Q by ratio**i and K by (1/ratio)**i, so Q_i . K_j picks up a relative factor ratio**(i-j)
Q *= (self.ratio.unsqueeze(-1) ** torch.arange(Q.shape[2], device=Q.device).unsqueeze(0)).unsqueeze(-1)
K *= ((1 / self.ratio).unsqueeze(-1) ** torch.arange(K.shape[2], device=K.device).unsqueeze(0)).unsqueeze(-1)
if mask is not None:
    K.masked_fill_(mask.unsqueeze(1).unsqueeze(-1), 0.0)
# Compute the KV matrix
KV = torch.einsum("nhsd,nhsm->nhmd", K, v)
# Compute the normalizer
Z = 1/(torch.einsum("nhld,nhd->nlh", Q, K.sum(dim=2))+self.eps)
# Finally compute and return the new values
V = torch.einsum("nhld,nhmd,nlh->nlhm", Q, KV, Z)
But I always get NaNs after 1 to 5 training steps.
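For reference, here is roughly how I set up the inputs when testing the snippet above. The shapes and placeholder values are only my assumptions (read off the einsum strings), not from any reference implementation:
import torch

# batch n, heads h, sequence length s, feature dim d, value dim m (assumed)
n, h, s, d, m = 2, 4, 1024, 32, 32
q = torch.randn(n, h, s, d)
k = torch.randn(n, h, s, d)
v = torch.randn(n, h, s, m)

# placeholder stand-ins for the module attributes; the real values come from my model
permutation = torch.argsort(torch.rand(1, h, s, d), dim=-1)  # one feature permutation per position
ratio = torch.full((h,), 0.99)                               # per-head decay ratio in (0, 1)
mask = torch.zeros(n, s, dtype=torch.bool)                   # True marks padded key positions
eps = 1e-6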
From my perspective, this may be caused by this step:
Q *= (self.ratio.unsqueeze(-1) ** torch.arange(Q.shape[2], device=Q.device).unsqueeze(0)).unsqueeze(-1)
K *= ((1 / self.ratio).unsqueeze(-1) ** torch.arange(K.shape[2], device=K.device).unsqueeze(0)).unsqueeze(-1)
which multiplies Q by self.ratio ** i (a very small number) and K by (1 / self.ratio) ** i (a very big number) when the position index i is large, so the entries of K grow exponentially with sequence length.
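To illustrate, a quick check of those scale factors (assuming self.ratio is around 0.99; the exact value in my model may differ):
import torch

ratio = torch.tensor(0.99)            # assumed value, for illustration only
pos = torch.arange(0, 4096, 512, dtype=torch.float32)

q_factor = ratio ** pos               # what Q gets multiplied by at each position
k_factor = (1.0 / ratio) ** pos       # what K gets multiplied by at each position

for p, qf, kf in zip(pos.tolist(), q_factor.tolist(), k_factor.tolist()):
    print(f"pos {int(p):5d}  Q factor {qf:.3e}  K factor {kf:.3e}")
# By position 3584 the K factor is already above 1e15; it would overflow to inf in
# float16 shortly after position ~1100, while the Q factor shrinks toward zero.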
Am I integrating it the correct way? Or do you have any suggestions for this?
Thanks.