Hi,
Thanks for a great implementation!
I wanted to clarify one point where your code differs from the code proposed in the article itself. In your code, you pre-normalize the inputs, i.e. they are passed through LayerNorm before the FFT. The code presented in the article instead has:
class FNetEncoderBlock(nn.Module):
  fourier_layer: FourierTransformLayer
  ff_layer: FeedForwardLayer

  @nn.compact
  def __call__(self, x, deterministic):
    mixing_output = self.fourier_layer(x)
    x = nn.LayerNorm(1e-12, name="mixing_layer_norm")(x + mixing_output)
    feed_forward_output = self.ff_layer(x, deterministic)
    return nn.LayerNorm(
        1e-12, name="output_layer_norm")(x + feed_forward_output)
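As far as I can tell, this applies LayerNorm after each residual addition (post-normalization), i.e. in the opposite order to your implementation.
Am I mistaken, or is this indeed a bug?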
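To make the difference concrete, here is a minimal NumPy sketch of both orderings. It is my own simplification, not the article's code or yours: LayerNorm has no learnable scale/bias, and `feed_forward` with weights `w1`, `w2` is a hypothetical stand-in for `FeedForwardLayer`. The Fourier mixing follows the paper's description of keeping the real part of a DFT over the sequence and hidden dimensions.

```python
import numpy as np

def layer_norm(x, eps=1e-12):
    # Normalize each token over the hidden (last) axis.
    # Assumption: learnable scale/bias are omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def fourier_mixing(x):
    # FNet-style token mixing: 2D DFT over (sequence, hidden) axes, keep the real part.
    return np.fft.fft2(x, axes=(-2, -1)).real

def feed_forward(x, w1, w2):
    # Hypothetical two-layer ReLU MLP standing in for FeedForwardLayer.
    return np.maximum(x @ w1, 0.0) @ w2

def post_norm_block(x, w1, w2):
    # Order in the article's listing: residual add first, then LayerNorm.
    x = layer_norm(x + fourier_mixing(x))
    return layer_norm(x + feed_forward(x, w1, w2))

def pre_norm_block(x, w1, w2):
    # Pre-normalization: LayerNorm the sublayer input, then add the residual.
    x = x + fourier_mixing(layer_norm(x))
    return x + feed_forward(layer_norm(x), w1, w2)

rng = np.random.default_rng(0)
seq_len, hidden, ff_dim = 4, 8, 16
x = rng.normal(size=(seq_len, hidden))
w1 = rng.normal(size=(hidden, ff_dim)) * 0.1
w2 = rng.normal(size=(ff_dim, hidden)) * 0.1

post = post_norm_block(x, w1, w2)
pre = pre_norm_block(x, w1, w2)

# The two orderings give different outputs for the same input and weights.
print(np.allclose(post, pre))
# Post-norm output is normalized per token: mean ~0 along the hidden axis.
print(np.allclose(post.mean(axis=-1), 0.0, atol=1e-6))
```

In the post-norm version every token vector leaves the block normalized (mean 0, variance close to 1), whereas the pre-norm version returns an un-normalized residual stream, so the two variants are not interchangeable as drop-in replacements.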