MAGMA - a GPT-style multimodal model that can understand any combination of images and language

MAGMA -- Multimodal Augmentation of Generative Models through Adapter-based Finetuning


Constantin (CoEich), Mayukh (Mayukhdeb), Sid (sdtblck)


Constantin Eichenberg, Sidney Black, Samuel Weinbach, Aleph Alpha

Letitia Parcalabescu, Anette Frank, Heidelberg University


Large-scale pretraining is fast becoming the norm in Vision-Language (VL) modeling. However, prevailing VL approaches are limited by the requirement for labeled data and the use of complex multi-step pretraining objectives. We present MAGMA - a simple method for augmenting generative language models with additional modalities using adapter-based finetuning. Building on Frozen, we train a series of VL models that autoregressively generate text from arbitrary combinations of visual and textual input. The pretraining is entirely end-to-end using a single language modeling objective, simplifying optimization compared to previous approaches. Importantly, the language model weights remain unchanged during training, allowing for transfer of encyclopedic knowledge and in-context learning abilities from language pretraining. MAGMA outperforms Frozen on open-ended generative tasks, achieving state-of-the-art results on the OKVQA benchmark and competitive results on a range of other popular VL benchmarks, while pretraining on 0.2% of the number of samples used to train SimVLM.

Paper on arXiv:

Examples (via Aleph Alpha playground)

Photos | Text & Technical
A man covering a woman's eyes to hide a present | A hand-drawn treasure map
A fallen tree is blocking a road | A software architecture

Model design

MAGMA model design

About the repository

In this repository we share the main parts of the codebase for training and inference of our MAGMA VL model. The main use of the repo is for downloading our pretrained weights and interacting with the model. We include a script for data parallel training with Deepspeed for finetuning our models or training a MAGMA model from scratch.


Make sure PyTorch (Ver >= 1.9.0) and Torchvision are installed. See

You can pip install from the git repository with:

pip install git+

Make sure that you also download the config:

mkdir configs; wget -O configs/MAGMA_v1.yml

Or if you've cloned the repo, you can install all further requirements by:

pip install -r requirements.txt


We also publish the model checkpoint that has been used for the publication. It is hosted on our infrastructure and downloads automatically. It can be downloaded manually here:

This checkpoint can also be played around with on a space managed by Heath Mitchell, AK, and Stella Biderman. (This is a 3rd party space, not managed by Aleph Alpha.)

Loading a model for inference

Downloads the checkpoint file into checkpoint_path if it's not already present.

from magma import Magma
from magma.image_input import ImageInput

model = Magma.from_checkpoint(
    config_path = "configs/MAGMA_v1.yml",
    checkpoint_path = "./",
    device = 'cuda:0'
)

inputs = [
    ## supports urls and path/to/image
    ImageInput('path/to/image'),  ## placeholder image path
    'Describe the painting:'
]

## returns a tensor of shape: (1, 149, 4096)
embeddings = model.preprocess_inputs(inputs)

## returns a list of length embeddings.shape[0] (batch size)
output = model.generate(
    embeddings = embeddings,
    max_steps = 6,
    temperature = 0.7,
    top_k = 0,
)

print(output[0]) ##  A cabin on a lake

Converting datasets to our format

To convert an image-caption dataset to our dataset class magma.datasets.ImgCptDataset, we suggest:

from magma.datasets.convert_datasets import convert_dataset

def my_dataset_iterator():
    """
    Implement an iterator for your dataset that for every datapoint yields a tuple
    image_path, {"captions": [...], "metadata": {...}, }, where image_path is the path to the image as a Path object, captions is a list of caption strings and metadata is an optional field.
    """

if __name__ == "__main__":
    convert_dataset(data_dir="/target/directory", ds_iterator=my_dataset_iterator())
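
As a minimal sketch of such an iterator, assume the captions are already available as a mapping from image file name to a list of caption strings. The names iter_image_captions, image_dir, and captions_by_file are hypothetical and not part of the repository; only the shape of the yielded tuples follows the description above.

```python
from pathlib import Path
from typing import Dict, Iterator, List, Tuple


def iter_image_captions(
    image_dir: str, captions_by_file: Dict[str, List[str]]
) -> Iterator[Tuple[Path, dict]]:
    """Yield (image_path, {"captions": [...], "metadata": {...}}) tuples.

    Hypothetical helper: `captions_by_file` maps an image file name to its
    caption strings. Entries without captions are skipped.
    """
    root = Path(image_dir)
    for file_name, captions in sorted(captions_by_file.items()):
        if not captions:
            continue
        yield root / file_name, {
            "captions": list(captions),
            "metadata": {"source_file": file_name},
        }


# Small in-memory example; no image files are read here.
examples = list(
    iter_image_captions(
        "images",
        {"b.jpg": ["a dog", "a brown dog"], "a.jpg": ["a cat"], "c.jpg": []},
    )
)
```

An iterator built this way could then be passed as ds_iterator to convert_dataset, as in the snippet above.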

How to train MAGMA

Run the training with:

deepspeed --config path_to_my_config

To continue training from a deepspeed checkpoint, provide the checkpoint directory in the "load" config parameter.

WARNING: By default, instantiating magma via the init method instead of from_checkpoint loads the pretrained CLIP weights but not the pretrained gpt-j weights. For training MAGMA from scratch, download the gpt-j weights from this repo: and include them in the state dict after initializing the MAGMA model.

  • AssertionError: Parameter with name: lm.transformer.wte.weight occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour

    I'd like to rerun the code using gpt-neo125M or gpt2-med instead of gpt-neo2.7B, and I'm getting this error:

    AssertionError: Parameter with name: lm.transformer.wte.weight occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour.

    Any idea why this issue exists for other language models?

    opened by monajati 6
  • No module named 'magma.transformers'

    I just downloaded the "magma-master" file and followed the instructions (I think), but when I try to run it I get errors. It seems some parts are missing?

    First I get:

    (magma) c:\Python\magma-master>python
    Traceback (most recent call last):
      File "c:\Python\magma-master\", line 4, in <module>
        from magma.language_model import get_language_model
    ImportError: cannot import name 'get_language_model' from 'magma.language_model' (c:\Python\magma-master\magma\>

    Looking at the code, it seems get_language_model is not used anyway, so I commented out line 4. But after that there is a similar failure:

    (magma) c:\Python\magma-master>python
    Traceback (most recent call last):
      File "c:\Python\magma-master\", line 25, in <module>
        from magma.transformers import GPTJForCausalLM
    ModuleNotFoundError: No module named 'magma.transformers'

    And here GPTJForCausalLM is used right in the next line. Looking at there is just nothing like GPTJForCausalLM in there at all. Seems like something is missing here completely?

    Best Tuxius

    opened by Tuxius 6
  • Automatic model download doesn't work

    The automatically downloaded checkpoint file ends up containing:

    <!DOCTYPE html><html><head><title>Google Drive - Virus scan warning</title><meta http-equiv="content-type" content="text/html; charset=utf-8"/><style nonce="t256RPQHLynZvvCq0ggl7w">/* Copyright 2022 Google Inc. All Rights Reserved. */
    .goog-inline-block{position:relative;display:-moz-inline-box;display:inline-block}* html .goog-inline-block,*:first-child+html .goog-inline-block{display:inline}.goog-link-button{position:relative;color:#15c;text-decoration:underline;cursor:pointer}.goog-link-button-disabled{color:#ccc;text-decoration:none;cursor:default}body{color:#222;font:normal 13px/1.4 arial,sans-serif;margin:0}.grecaptcha-badge{visibility:hidden}.uc-main{padding-top:50px;text-align:center}#uc-dl-icon{display:inline-block;margin-top:16px;padding-right:1em;vertical-align:top}#uc-text{display:inline-block;max-width:68ex;text-align:left}.uc-error-caption,.uc-warning-caption{color:#222;font-size:16px}#uc-download-link{text-decoration:none}.uc-name-size a{color:#15c;text-decoration:none}.uc-name-size a:visited{color:#61c;text-decoration:none}.uc-name-size a:active{color:#d14836;text-decoration:none}.uc-footer{color:#777;font-size:11px;padding-bottom:5ex;padding-top:5ex;text-align:center}.uc-footer a{color:#15c}.uc-footer a:visited{color:#61c}.uc-footer a:active{color:#d14836}.uc-footer-divider{color:#ccc;width:100%}</style><link rel="icon" href="null"/></head><body><div class="uc-main"><div id="uc-dl-icon" class="image-container"><div class="drive-sprite-aux-download-file"></div></div><div id="uc-text"><p class="uc-warning-caption">Google Drive can't scan this file for viruses.</p><p class="uc-warning-subcaption"><span class="uc-name-size"><a href="/open?id=1EiAY3IcKWmGADaLDzdG25ykQghUwza6L"></a> (12G)</span> is too large for Google to scan for viruses. Would you still like to download this file?</p><form id="downloadForm" action=";export=download&amp;confirm=t" method="post"><input type="submit" id="uc-download-link" class="goog-inline-block jfk-button jfk-button-action" value="Download anyway"/></form></div></div><div class="uc-footer"><hr class="uc-footer-divider"></div></body></html>


    Traceback (most recent call last):
      File "/home/ubuntu/magma/", line 4, in <module>
        model = Magma.from_checkpoint(
      File "/home/ubuntu/magma/magma/", line 292, in from_checkpoint
        sd = torch.load(checkpoint_path, map_location=torch.device("cpu"))
      File "/usr/local/share/miniconda/lib/python3.9/site-packages/torch/", line 593, in load
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
      File "/usr/local/share/miniconda/lib/python3.9/site-packages/torch/", line 762, in _legacy_load
        magic_number = pickle_module.load(f, **pickle_load_args)
    _pickle.UnpicklingError: invalid load key, '<'.
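
The load key '<' in the error above is the first character of the HTML page that was saved in place of the checkpoint. A minimal sketch of a pre-check that could catch this before the file is handed to torch.load is shown below; looks_like_html is a hypothetical helper, not something provided by the repository, and the two temporary files only stand in for a real download.

```python
import os
import tempfile


def looks_like_html(path: str, probe_bytes: int = 64) -> bool:
    """Return True if the file starts with '<' after leading whitespace."""
    with open(path, "rb") as f:
        head = f.read(probe_bytes)
    return head.lstrip().startswith(b"<")


# Demonstration on two temporary files standing in for a download.
with tempfile.TemporaryDirectory() as tmp:
    # An HTML warning page saved under a checkpoint-like name.
    html_path = os.path.join(tmp, "download.pt")
    with open(html_path, "wb") as f:
        f.write(b"<!DOCTYPE html><html><head><title>warning</title></head></html>")

    # Arbitrary binary bytes standing in for real weights.
    binary_path = os.path.join(tmp, "weights.pt")
    with open(binary_path, "wb") as f:
        f.write(b"PK\x03\x04" + b"\x00" * 16)

    html_detected = looks_like_html(html_path)
    binary_detected = looks_like_html(binary_path)
```

If such a check returns True, the file would need to be fetched again (for example manually) rather than loaded.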

    Possibly related to

    opened by Heath123 5
  • Mismatching LM shape between 50400 (pre-trained pt) and 50258 (gpt-2)

    Thanks for this wonderful work 😍
    I load, but it shows that the shape of LM is different:

    size mismatch for lm.lm_head.weight: copying a param with shape torch.Size([50400, 4096]) from checkpoint, the shape in current model is torch.Size([50258, 4096])

    I guess it is because of the resize_token_embeddings here.
    I also tried to truncate the additional dimension,

    sd["lm.lm_head.weight"] = sd["lm.lm_head.weight"][:50258, :]
    sd["lm.lm_head.bias"] = sd["lm.lm_head.bias"][:50258]

    but the result of seems weird 😂

    bondankeNM Drama fixtures Sergey
    Fantasticheddar AUTHOR hob sealedunction

    Super thanks for the help!
    opened by tsujuifu 3
  • Torch Size Mismatch

    Hey guys!

    I had a quick issue while loading Magma from the checkpoint, and I was wondering if anyone encountered or knows how to solve the problem.

    RuntimeError: Error(s) in loading state_dict for Magma: size mismatch for lm.lm_head.weight: copying a param with shape torch.Size([50400, 4096]) from checkpoint, the shape in current model is torch.Size([50258, 4096]).

    It seems like the size of the checkpoint model differs from the size of the model it is expecting from the rest of the code.

    Thank you so much--this model looks super cool and I'm excited to use it!

    opened by harshagundala 1
  • Subsequent inference calls produce less good results

    Following the code in or to perform inference by calling model.preprocess_inputs(…) followed by model.generate(…) produces good results the first time the pair is called, but poor results for subsequent pairs of calls.

    The reason is that model = Magma.from_checkpoint(…) loads the model with inconsistent training/eval settings. is True but is False. The first call to model.preprocess_inputs(…) works correctly as the image encoder has training False and so its Batch Normalisation steps work correctly. The call to model.generate(…) records the training state on entry and restores it on exit, which because is True puts the whole model into training state. Subsequent calls to model.preprocess_inputs(…) then don't perform Batch Normalisation steps correctly.

    The play space at has this problem too.

    The fix is to add model.eval() after model = Magma.from_checkpoint(…), setting the whole model to a consistent eval state.
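
The mechanism described in this issue can be illustrated without PyTorch. The sketch below uses a toy stand-in class (ToyModule, generate, and load_inconsistent are all hypothetical names, not the MAGMA code) whose train() applies the flag recursively, as torch.nn.Module.train does, and whose generate() records and restores the top-level training state.

```python
class ToyModule:
    """Stand-in for a module with a `training` flag and optional children."""

    def __init__(self, children=()):
        self.training = True
        self.children = list(children)

    def train(self, mode=True):
        # Like torch.nn.Module.train, the flag is applied recursively.
        self.training = mode
        for child in self.children:
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


def generate(model):
    """Record the top-level state, run in eval mode, then restore it."""
    was_training = model.training
    model.eval()
    # ... sampling would happen here ...
    model.train(was_training)


def load_inconsistent():
    # Hypothetical loader: the encoder is in eval mode, the wrapper is not.
    encoder = ToyModule()
    encoder.eval()
    model = ToyModule(children=[encoder])
    return model, encoder


# Without the fix, the first generate() call flips the encoder to training.
model, encoder = load_inconsistent()
before_first_call = encoder.training
generate(model)
after_first_call = encoder.training

# With model.eval() right after loading, the state stays consistent.
fixed_model, fixed_encoder = load_inconsistent()
fixed_model.eval()
generate(fixed_model)
after_fix = fixed_encoder.training
```

In this toy setup the encoder starts in eval mode, ends up in training mode after the first generate() call, and stays in eval mode once eval() is called on the whole model after loading.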

    opened by steve-barlow 1
  • size mismatch for lm.lm_head.weight: copying a param with shape torch.Size([50400, 4096]) from checkpoint, the shape in current model is torch.Size([50258, 4096]).

    This error is very strange. I didn't change any code, yet the model and the config seem to mismatch. Has anyone else run into the same problem?

    opened by UCCME 1
  • top_p argument is used like 1-top_p

    For example, top_p=0.999 gives you nearly deterministic sampling, not nearly on-distribution sampling.

    I was confused why I was getting much less diverse samples with top_p=0.95 than I got with top_p turned off.

    I found the cause in these lines:

    threshold is set to top_p here:

    Suppose eg threshold is 0.95. Then 1-threshold is 0.05.

    So we remove all tokens where the cumulative probs are > 0.05, which is most of the tokens -- we are really doing top-p sampling with top_p=0.05 (in the usual convention), not the intended top_p=0.95.
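
The difference between the two conventions can be sketched in plain Python. The function below is not the repository's sampling code; nucleus_keep is a hypothetical helper that implements top-p filtering in the usual convention, and the probabilities are made up for illustration.

```python
def nucleus_keep(probs, top_p):
    """Indices kept by top-p filtering in the usual convention.

    Tokens are sorted by descending probability and kept until their
    cumulative probability first reaches top_p.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


probs = [0.4, 0.3, 0.2, 0.1]

# Intended behaviour for top_p=0.95: nearly the whole distribution survives.
usual = nucleus_keep(probs, 0.95)

# What thresholding at 1 - top_p amounts to: only the top token survives.
inverted = nucleus_keep(probs, 1 - 0.95)
```

With these made-up probabilities, the usual convention keeps all four tokens, while the inverted threshold keeps only the most likely one, which matches the nearly deterministic sampling reported above.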

    opened by nostalgebraist 1
  • (#9) Improved inference interface

    Contains the following changes:

    1. The model, tokenizer, and transforms are now contained under a unified wrapper, Magma(), which can be used as shown below:

    from multimodal_fewshot import Magma

    magma = Magma(
        checkpoint_path = '', ## downloads automatically if not present in this path
        config_path = 'configs/MAGMA_v1.yml',
    )

    2. Image inputs are now handled by ImageInput(), which supports both urls and local image paths.

    inputs = [
        ImageInput('path/to/image'), ## placeholder: url or local image path
        'Describe the painting:'
    ]

    3. Magma() supports both low level and high level inference.

    ## forward pass
    embeddings = magma.preprocess_inputs(inputs = inputs) ## returns a torch tensor of shape (1, sequence_length, hidden_dim)
    outputs = magma(embeddings) ## output logits shape: torch.Size([1, 150, 50400])

    ## high level inference
    completion = magma.generate(inputs = inputs, num_tokens = 4, topk = 1)
    ## completion: "A cabin on a lake"
    opened by Mayukhdeb 1
  • Improved inference interface

    Implement an interface like the one Mayukh suggested:

    from magma import Magma 
    from magma.image import Image, ImageFromURL  ## to easily load/use images
    model, tokenizer = Magma(checkpoint = '', config = 'config.yml', device = 'cuda:0')
    inputs = [
        'Where is this ? A: Egypt',
        'Where is this ? A:'
    ]
    embeddings = tokenizer.tokenize(inputs).to(model.device)
    output = model.forward(embeddings, output_attentions = True)
    logits = output.logits ## tensor of shape [1, len_seq, len_vocab]
    attentions = output.attentions ## list of tensors
    ## this already exists
    generated_text = model.generate(embeddings, n_steps = 10, *args)
    opened by CoEich 1
  • Remove dataset builders and old classes in multimodal_fewshot.datasets

    • Remove scripts that download the various datasets
    • Keep the ImgCptDataset base class and classification wrappers in, remove "old" classes
    • Keep the convert_dataset function in (maybe slightly refactor)
    opened by CoEich 1
  • how did you calculate the bleu score

    Hi, thanks for the awesome project. I noticed that the BLEU@4 and CIDEr scores reported in Table 1 are ~10 and ~50, respectively, on the MS COCO dataset (zero-shot; after fine-tuning the scores increase to 31 and 90+). These fall far behind traditional baselines like AoA and CLIP-ViL, which usually achieve ~40 BLEU-4 and 120+ CIDEr. I am wondering whether the difference is due to the evaluation setup: did you use the evaluation in coco-caption, or did you calculate the scores yourself?

    opened by TobiasLee 0
  • fix inference_step

    inference_step passes inference=True to model_engine. However, the __forward__ of the Magma model does not accept this parameter, which will cause an error during training. I fix it by simply copying the inference code from

    opened by Fireblossom 3
Aleph Alpha GmbH
