# TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?
A PyTorch implementation of *TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?* [1-2]. Unlike another unofficial PyTorch implementation [3], this version closely follows the official implementation [4] and the TensorFlow implementation [5], and aims to stay consistent with them.
## Usage
You can access the `TokenLearner` and `TokenLearnerModuleV11` classes from the `tokenlearner` file. You can use these layers with a Vision Transformer, MLP-Mixer, or Video Vision Transformer, as done in the paper.
```python
import torch
from tokenlearner import TokenLearner

tklr = TokenLearner(in_channels=128, num_tokens=8, use_sum_pooling=False)

x = torch.ones(256, 32, 32, 128)  # [bs, h, w, c]
y1 = tklr(x)
print(y1.shape)  # [256, 8, 128]
```
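In the paper, TokenLearner is inserted partway through a Vision Transformer so that the later blocks operate only on the learned tokens. The snippet below is a minimal sketch of that pattern, not the paper's exact model: it uses `torch.nn.TransformerEncoderLayer` as a stand-in for ViT blocks and reshapes the patch-token sequence back to its spatial grid before `TokenLearner`, since the layer above takes a `[bs, h, w, c]` input.

```python
import torch
import torch.nn as nn
from tokenlearner import TokenLearner

dim, heads = 128, 4
# Stand-in ViT blocks (hypothetical; the paper uses its own encoder blocks)
blocks_before = nn.ModuleList([nn.TransformerEncoderLayer(dim, heads, batch_first=True) for _ in range(2)])
blocks_after = nn.ModuleList([nn.TransformerEncoderLayer(dim, heads, batch_first=True) for _ in range(2)])
tklr = TokenLearner(in_channels=dim, num_tokens=8, use_sum_pooling=False)

x = torch.ones(4, 32 * 32, dim)  # [bs, n_patches, c] patch tokens
for blk in blocks_before:
    x = blk(x)

# Reshape tokens back to a spatial grid, then reduce them to 8 learned tokens
x = tklr(x.reshape(4, 32, 32, dim))  # [4, 8, 128]

for blk in blocks_after:
    x = blk(x)  # the remaining blocks now run on only 8 tokens
print(x.shape)  # [4, 8, 128]
```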
You can also use `TokenLearnerModuleV11`, which aligns with the official implementation.
```python
import torch
from tokenlearner import TokenLearnerModuleV11

tklr_v11 = TokenLearnerModuleV11(in_channels=128, num_tokens=8, num_groups=4, dropout_rate=0.)
tklr_v11.eval()  # switch to eval mode to disable dropout

x = torch.ones(256, 32, 32, 128)  # [bs, h, w, c]
y2 = tklr_v11(x)
print(y2.shape)  # [256, 8, 128]
```
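For video, the paper applies TokenLearner to every frame and concatenates the learned tokens over time, so a clip of `t` frames contributes `t * num_tokens` tokens instead of `t * h * w`. The snippet below is a rough sketch of that per-frame usage with the `TokenLearnerModuleV11` instance above; the `[bs, t, h, w, c]` video layout is an assumption made for illustration.

```python
import torch
from tokenlearner import TokenLearnerModuleV11

tklr_v11 = TokenLearnerModuleV11(in_channels=128, num_tokens=8, num_groups=4, dropout_rate=0.)
tklr_v11.eval()  # disable dropout for this example

video = torch.ones(2, 16, 32, 32, 128)  # [bs, t, h, w, c] per-frame features (assumed layout)
bs, t, h, w, c = video.shape

# Run the module on every frame, then concatenate the learned tokens over time
tokens = tklr_v11(video.reshape(bs * t, h, w, c))  # [bs * t, 8, c]
tokens = tokens.reshape(bs, t * 8, c)
print(tokens.shape)  # [2, 128, 128]
```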
## References
[1] TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?; Ryoo et al.; arXiv 2021; https://arxiv.org/abs/2106.11297
[2] TokenLearner: Adaptive Space-Time Tokenization for Videos; Ryoo et al.; NeurIPS 2021; https://openreview.net/forum?id=z-l1kpDXs88