# tokenlearner-pytorch

Unofficial PyTorch implementation of TokenLearner by Ryoo et al. from Google AI (abs, pdf).
## Installation

You can install TokenLearner via pip:

```bash
pip install tokenlearner-pytorch
```
## Usage

You can import the `TokenLearner` class from the `tokenlearner_pytorch` package. You can use this layer with a Vision Transformer, an MLP-Mixer, or a Video Vision Transformer, as done in the paper.

```python
import torch
from tokenlearner_pytorch import TokenLearner

tklr = TokenLearner(S=8)
x = torch.rand(512, 32, 32, 3)  # [B, H, W, C] feature map
y = tklr(x)  # [512, 8, 3]
```
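For a rough picture of the ViT-style usage mentioned above, here is a minimal sketch (not the paper's configuration) where the learned tokens are fed to a standard `nn.TransformerEncoder`; the channel size of 3 and the encoder hyperparameters are placeholder assumptions:

```python
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner

# Placeholder dimensions: C=3 channels, S=8 learned tokens.
tklr = TokenLearner(S=8)
encoder_layer = nn.TransformerEncoderLayer(d_model=3, nhead=1, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

x = torch.rand(512, 32, 32, 3)  # [B, H, W, C] feature map
tokens = tklr(x)                # [512, 8, 3] learned tokens
out = encoder(tokens)           # [512, 8, 3] after standard self-attention blocks
```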
You can also use `TokenLearner` and `TokenFuser` together with multi-head self-attention, as done in the paper:

```python
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser

mhsa = nn.MultiheadAttention(3, 1)  # expects [S, B, C] by default
tklr = TokenLearner(S=8)
tkfr = TokenFuser(H=32, W=32, C=3, S=8)

x = torch.rand(512, 32, 32, 3)  # a batch of images
y = tklr(x)                     # [512, 8, 3]
y = y.permute(1, 0, 2)          # [8, 512, 3] for nn.MultiheadAttention
y, _ = mhsa(y, y, y)            # ignore attn weights
y = y.permute(1, 0, 2)          # back to [512, 8, 3]
out = tkfr(y, x)                # [512, 32, 32, 3]
```
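`TokenFuser` takes both the processed tokens and the original feature map and maps the `S` tokens back onto the full `H × W` grid, which is why the output shape matches the input. If you're on PyTorch 1.9 or newer, an equivalent sketch avoids the explicit permutes by building the attention layer with `batch_first=True` (same placeholder shapes as above):

```python
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser

mhsa = nn.MultiheadAttention(3, 1, batch_first=True)  # consumes [B, S, C] directly
tklr = TokenLearner(S=8)
tkfr = TokenFuser(H=32, W=32, C=3, S=8)

x = torch.rand(512, 32, 32, 3)  # a batch of images
y = tklr(x)                     # [512, 8, 3]
y, _ = mhsa(y, y, y)            # ignore attn weights
out = tkfr(y, x)                # [512, 32, 32, 3]
```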
## TODO

- Add support for the temporal dimension `T`
- Implement `TokenFuser` with ViT
- Implement `TokenFuser` with ViViT
## Contributions
If I've made any errors or you have any suggestions, feel free to raise an Issue or PR. All contributions welcome!!