Great job!
However, I think there may be two small issues in your code.
The first one is in bridge.py (line 24 and line 53). The following two lines of code don't seem to be equivalent:
x = x.reshape(b, c, h*w).transpose(1,2).unsqueeze(1)
x = x.contiguous().view(b, h * w, c).unsqueeze(1)
Maybe the first line is the correct one?
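To illustrate why I think they differ, here is a minimal sketch (not taken from bridge.py, and assuming x is a standard (b, c, h, w) feature map): the first version actually permutes the channel and spatial axes, while the second only reinterprets the contiguous memory, so the contents end up different even though the shapes match.

```python
import torch

# Minimal sketch, assuming x is a (b, c, h, w) feature map.
b, c, h, w = 1, 2, 2, 2
x = torch.arange(b * c * h * w, dtype=torch.float32).reshape(b, c, h, w)

# Version 1: transpose swaps the channel and spatial axes,
# so each of the h*w rows holds the c-dim feature of one pixel.
v1 = x.reshape(b, c, h * w).transpose(1, 2).unsqueeze(1)

# Version 2: view only reinterprets the contiguous (b, c, h, w) memory,
# so channel values from different pixels get mixed within a row.
v2 = x.contiguous().view(b, h * w, c).unsqueeze(1)

print(v1.shape, v2.shape)   # both torch.Size([1, 1, 4, 2])
print(torch.equal(v1, v2))  # False -- same shape, different contents
```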
The second one is in config.py. According to the original paper, on page 13:
Figure 7. Visualization of cross attention on the two-way bridge: Mobile→Former and Mobile←Former. Mobile-Former-294M is used, which includes 6 tokens (each corresponds to a column) and 11 Mobile-Former blocks (block 2–12) across 4 stages. Each block has two attention heads that are visualized in two rows. Attention in Mobile→Former (left half) is normalized over pixels, showing the focused region per token. Attention in Mobile←Former (right half) is normalized over tokens, showing the contribution per token at each pixel.
But in config.py, some stages are configured with only one head.
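For reference, this is a purely hypothetical sketch of what I would expect if every block followed the paper's two-head description (the dict name and keys below are made up, not the actual structure of config.py):

```python
# Hypothetical illustration only; the real keys/structure in config.py differ.
# Per the Figure 7 caption, every block of Mobile-Former-294M uses 2 heads
# for the Mobile->Former / Mobile<-Former cross attention.
cross_attention_heads = {
    'stage1': 2,
    'stage2': 2,
    'stage3': 2,
    'stage4': 2,
}
```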
I'm not sure whether the above is correct. Looking forward to your reply!