# clip-text-decoder

Generate text captions for images from their CLIP embeddings. Includes PyTorch model code and an example training script.
## Example Predictions
Example captions were computed with the pretrained model mentioned below.

- "A man riding a wave on top of a surfboard."
- "A baseball player is swinging a bat at a ball."
- "A dog running across a field with a frisbee."
## Installation
Install for easier access to the following objects/classes:

- `clip_text_decoder.datasets.ClipCocoCaptionsDataset`
- `clip_text_decoder.models.ClipDecoder`
- `clip_text_decoder.models.ClipDecoderInferenceModel`
- `clip_text_decoder.tokenizer.Tokenizer`
The `train.py` script will not be available in the installed package, since it's located in the root directory. To train new models, either clone this repository or recreate `train.py` locally.
Using `pip`:

```bash
pip install clip-text-decoder
```
From source:
git clone https://github.com/fkodom/clip-text-decoder.git
cd clip-text-decoder
pip install .
NOTE: You'll also need to install openai/CLIP to encode images with CLIP. It's also required by `ClipCocoCaptionsDataset` to build the captions dataset the first time (the result is cached for subsequent calls).

```bash
pip install "clip @ git+https://github.com/openai/CLIP.git"
```
The CLIP dependency can't be included in the PyPI package, because CLIP itself isn't published on PyPI.
## Training
Launch your own training session using the provided script (`train.py`):

```bash
python train.py --max-epochs 5
```
Training CLI arguments, along with their default values:

```
--max-epochs 5        # (int)
--num-layers 6        # (int)
--dim-feedforward 256 # (int)
--precision 16        # (16 or 32)
--seed 0              # (int)
```
## Inference
The training script will produce a `model.zip` archive, containing the `Tokenizer` and trained model parameters. To perform inference with it:
```python
import clip
import torch
from PIL import Image

from clip_text_decoder.model import ClipDecoderInferenceModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = ClipDecoderInferenceModel.load("path/to/model.zip").to(device)
clip_model, clip_preprocessor = clip.load("ViT-B/32", device=device, jit=False)

# Create a blank dummy image
dummy_image = Image.new("RGB", (224, 224))
preprocessed = clip_preprocessor(dummy_image).to(device)
# Add a batch dimension using '.unsqueeze(0)'
encoded = clip_model.encode_image(preprocessed.unsqueeze(0))
text = model(encoded)

print(text)
# Probably some nonsense, because we used a dummy image.
```
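To caption a real image instead, swap out the dummy image. This short sketch reuses `clip_model`, `clip_preprocessor`, `model`, and `device` from the snippet above; the image path is just a placeholder:

```python
from PIL import Image

# Placeholder path -- substitute one of your own images.
image = Image.open("path/to/image.jpg").convert("RGB")
preprocessed = clip_preprocessor(image).to(device)
encoded = clip_model.encode_image(preprocessed.unsqueeze(0))
print(model(encoded))
```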
## Pretrained Models
A pretrained CLIP decoder is hosted in my Google Drive, and can be downloaded with:
```python
from clip_text_decoder.model import ClipDecoderInferenceModel

model = ClipDecoderInferenceModel.download_pretrained()
```
To cache the pretrained model locally, so that it's not re-downloaded each time:
```python
model = ClipDecoderInferenceModel.download_pretrained("/path/to/model.zip")
```
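For reference, a rough end-to-end sketch using the pretrained decoder. The CLIP setup mirrors the inference example above, and the image path is a placeholder:

```python
import clip
import torch
from PIL import Image

from clip_text_decoder.model import ClipDecoderInferenceModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the pretrained decoder and load the CLIP image encoder.
model = ClipDecoderInferenceModel.download_pretrained().to(device)
clip_model, clip_preprocessor = clip.load("ViT-B/32", device=device, jit=False)

# Placeholder path -- substitute one of your own (COCO-style) images.
image = Image.open("path/to/image.jpg").convert("RGB")
encoded = clip_model.encode_image(clip_preprocessor(image).unsqueeze(0).to(device))
print(model(encoded))
```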
## Shortcomings
- Only works well with COCO-style images. If you go outside the distribution of COCO objects, you'll get nonsense text captions.
- The model was only trained for a relatively short time. Even within the COCO domain, you'll occasionally see incorrect captions, and quite a few captions have poor grammar or repetitive descriptors.