# RealFormer-Pytorch

Implementation of RealFormer using pytorch. Includes comparison with classical Transformer on image classification task (ViT) wrt CIFAR-10 dataset.

Original Paper of the model : https://arxiv.org/abs/2012.11747

## So how are RealFormers at vision tasks?

Run the train.py with

```
model = ViR(
image_pix = 32,
patch_pix = 4,
class_cnt = 10,
layer_cnt = 4
)
```

to Test how RealFormer works on CIFAR-10 dataset compared to just classical ViT, which is

```
model = ViT(
image_pix = 32,
patch_pix = 4,
class_cnt = 10,
layer_cnt = 4
)
```

... which is of course, much, much smaller version of ViT compared to the origianl ones ().

## Results

**Model : layers = 4, hidden_dim = 128, feedforward_dim = 512, head_cnt = 4**

Trained 10 epochs

After 10'th epoch, Realformer achieves 65.45% while Transformer achieves 64.59% RealFormer seems to consistently have about 1% greater accuracy, which seems reasonable (as the papaer suggested simillar result)

**Model : layers = 8, hidden_dim = 128, feedforward_dim = 512, head_cnt = 4**

Having 4 more layers obviously improves in general, and still, RealFormer consistently wins in terms of accuracy (68.3% vs 66.3%). Notice that **larger the model, bigger the difference** seems to follow here too. (I wonder how much of difference it would make on ViT-Large)

When it comes to computation time, there was almost zero difference. (I guess adding residual attention score is O(L^2) operation, compared to matrix multiplication in softmax which is O(L^2 * D))

## Conclusion

Use RealFormer. It benifits with almost zero additional resource!

## To make a custom RealFormer for other tasks

Its not a pip package, but you can use the ResEncoderBlock module in the models.py to make a Encoder Only Transformer like the following :

```
import ResEncoderBlock from models
def RealFormer(nn.Module):
...
def __init__(self, ...):
...
self.mains = nn.Sequential(*[ResEncoderBlock(emb_s = 32, head_cnt = 8, dp1 = 0.1, dp2 = 0.1) for _ in range(layer_cnt)])
...
def forward(self, x):
...
prev = None
for resencoder in self.mains:
x, prev = resencoder(x, prev = prev)
...
return x
```

If you're not really clear what is going on or what to do, request me to make this a pip package.