Learning to Prompt for Continual Learning

Overview

Learning to Prompt for Continual Learning (L2P) Official Jax Implementation

L2P is a novel continual learning technique which learns to dynamically prompt a pre-trained model to learn tasks sequentially under different task transitions. Different from mainstream rehearsal-based or architecture-based methods, L2P requires neither a rehearsal buffer nor test-time task identity. L2P can be generalized to various continual learning settings including the most challenging and realistic task-agnostic setting. L2P consistently outperforms prior state-of-the-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer.
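
As a rough, illustrative sketch (not the code in this repository): a frozen pre-trained ViT produces a query feature for each input, the query is matched against a pool of learnable prompt keys by cosine similarity, and the top-k matching prompts are prepended to the token embeddings before the transformer blocks. The names query_fn, keys, prompts, and top_k below are assumptions for illustration only.

import jax.numpy as jnp

def select_prompts(query_fn, params, x, keys, prompts, top_k=5):
    """Instance-wise prompt selection from a shared prompt pool (sketch only)."""
    q = query_fn(params, x)                          # query features, [batch, dim]
    q = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + 1e-8)
    k = keys / (jnp.linalg.norm(keys, axis=-1, keepdims=True) + 1e-8)
    sim = q @ k.T                                    # cosine similarity, [batch, pool_size]
    idx = jnp.argsort(-sim, axis=-1)[:, :top_k]      # indices of the top-k matching prompts
    selected = prompts[idx]                          # [batch, top_k, prompt_len, dim]
    selected = selected.reshape(q.shape[0], -1, prompts.shape[-1])
    # Surrogate term that pulls the selected keys toward their queries.
    key_loss = jnp.mean(1.0 - jnp.take_along_axis(sim, idx, axis=-1))
    return selected, key_loss

In the paper's design, only the prompts, the keys, and the classifier head are trained; the pre-trained backbone stays frozen.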

Code is written by Zifeng Wang. Acknowledgement to https://github.com/google-research/nested-transformer.

This is not an officially supported Google product.

Environment setup

pip install -r requirements.txt

Getting pretrained ViT model

The ViT-B/16 model used in this paper can be downloaded here: https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz

Instructions on running L2P

We provide the configuration files to train and evaluate L2P on multiple benchmarks in configs.

To run our method on the Split CIFAR-100 dataset (class-incremental setting):

python main.py --my_config configs/cifar100_l2p.py --workdir=./cifar100_l2p --my_config.init_checkpoint=<ViT-saved-path/ViT-B_16.npz>

To run our method on the more complex Gaussian Scheduled CIFAR-100 dataset (task-agnostic setting):

python main.py --my_config configs/cifar100_gaussian_l2p.py --workdir=./cifar100_gaussian_l2p --my_config.init_checkpoint=<ViT-saved-path/ViT-B_16.npz>

Note: we run our experiments using 8 V100 GPUs or 4 TPUs, and we specify a per-device batch size of 16 in the config files, which corresponds to a total batch size of 128.
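
For reference, a quick way to sanity-check the effective batch size on your hardware (illustrative snippet; the variable names are not taken from this codebase):

import jax

per_device_batch_size = 16  # value set in the provided config files
global_batch_size = per_device_batch_size * jax.device_count()
print(global_batch_size)    # e.g. 128 with 8 local devices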

Visualize results

We use TensorBoard to visualize results. For example, if the working directory specified when running L2P is workdir=./cifar100_l2p, the command to check the results is:

tensorboard --logdir ./cifar100_l2p

Here are the important metrics to keep track of, and their corresponding meanings:

Metric       Description
accuracy_n   Accuracy of the n-th task
forgetting   Average forgetting up until the current task
avg_acc      Average evaluation accuracy up until the current task
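
As an illustrative sketch (not the repository's evaluation code), these metrics can be related to an accuracy matrix acc, where acc[i, j] is the accuracy on task j measured after training on task i:

import numpy as np

def avg_acc(acc, t):
    """Average accuracy over tasks 0..t after training on task t."""
    return float(np.mean(acc[t, : t + 1]))

def forgetting(acc, t):
    """Average drop from each earlier task's best accuracy to its accuracy after task t."""
    if t == 0:
        return 0.0
    drops = [np.max(acc[:t, j]) - acc[t, j] for j in range(t)]
    return float(np.mean(drops))

The exact logging in this codebase may differ slightly; this is the standard continual-learning definition these metric names usually refer to.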

Cite

@article{wang2021learning,
  title={Learning to Prompt for Continual Learning},
  author={Zifeng Wang and Zizhao Zhang and Chen-Yu Lee and Han Zhang and Ruoxi Sun and Xiaoqi Ren and Guolong Su and Vincent Perot and Jennifer Dy and Tomas Pfister},
  journal={arXiv preprint arXiv:2112.08654},
  year={2021}
}
Comments
  • ImportError: cannot import name 'resnet_v1' from 'models' (unknown location)

    In the File "/media/iiau/LiGong4T2/zwb/code/l2p/train_continual.py", line 42, I can see:

    from models import resnet_v1

    but in the "./models" folder I can't find resnet_v1, so I get an error like this:

    ImportError: cannot import name 'resnet_v1' from 'models' (unknown location)

    How to solve this problem? Thanks!

    opened by hubblezhang 3
  • Confusion about the ImageNet-R dataset

    Hi, Thanks for your nice work!

    I am currently trying to follow your Split ImageNet-R benchmark but have encountered some problems. First, I am not able to download the ImageNet-R dataset; I get errors like this:

    2022-10-25 16:43:39.764732: E tensorflow/core/platform/cloud/curl_http_request.cc:614] The transmission of request 0x668a800 (URI: https://www.googleapis.com/storage/v1/b/tfds-data/o/dataset_info%2Fimagenet_r%2F0.2.0?fields=size%2Cgeneration%2Cupdated) has been stuck at 0 of 0 bytes for 61 seconds and will be aborted. CURL timing information: lookup time: 0.004201 (No error), connect time: 0 (No error), pre-transfer time: 0 (No error), start-transfer time: 0 (No error)

    It seems like a networking issue, so I accessed the URI (https://www.googleapis.com/storage/v1/b/tfds-data/o/dataset_info%2Fimagenet_r%2F0.2.0?fields=size%2Cgeneration%2Cupdated) through my web browser, but I got this response:

    { "error": { "code": 404, "message": "No such object: tfds-data/dataset_info/imagenet_r/0.2.0", "errors": [ { "message": "No such object: tfds-data/dataset_info/imagenet_r/0.2.0", "domain": "global", "reason": "notFound" } ] } }

    It seems that the dataset is not accessible at this time.

    Second, I tried to understand the code you provided in libml/input_pipeline.py, looking for two things: (1) how the training and testing sets are split from the original dataset, and (2) the training order of the classes in the continual learning tasks.

    For (1), I found that the code uses a TFDS split (lines 414-424 of input_pipeline.py). I am not familiar with TFDS; is this a stratified split (based on classes) or just a split over the whole data?

    For (2), I found that the code permutes the class order here, if config.continual.rand_seed is set. However, in configs/imr_dualprompt.py, config.continual.rand_seed is set to -1. Does this mean that DualPrompt conducts experiments in natural-number order (i.e., 0, 1, 2, 3, ...)?

    Hope you could kindly help me :)

    Looking forward to your early reply and thanks again!

    opened by TOM-tym 2
  • ReadMe issues

    I am trying to reproduce the results of the tables in your nice paper, and I don't think the code you have provided allows me to reproduce the results of Table 3, which uses CORe50. I was curious to know when the code that produces the results of Table 3 will be available.

    Sorry to take your time with this, and hope you are doing well.

    opened by adocaj 2
  • Reproduce issue

    Dear author.

    Thank you for your great work.

    I'm having a little problem with reproducing L2P.

    First, please update the environment setup in README.md. The link in the README for adjusting the JAX version does not support the CUDA version; it seems to have changed to:

    https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

    The other thing is that even with batch_size set to 1, I can't run the code completely on my 4 A5000 GPUs because of an out-of-memory issue. Could you also provide smaller models such as ViT-Tiny or ViT-Small?

    I look forward to hearing from you.

    Thank you.

    opened by JH-LEE-KR 2
  • Possible information leakage from pretrained model

    Dear author,

    Thank you for your excellent work!

    I am a little curious about the pre-trained model: it is trained on the entire ImageNet-21k dataset and is kept fixed during continual training. But could this lead to information leakage?

    Take the class-incremental setting as an example: I think all 100 classes of CIFAR-100 can be found in ImageNet-21k, so it is possible that the model has already learned all the features necessary for CIFAR-100. But in practice the model is expected to learn new features, and we cannot assume the classes in new tasks have already been observed by the backbone, right?

    Have you tried removing the CIFAR-100 classes from ImageNet and pre-training a model, or evaluating the model on some datasets disjoint from ImageNet?

    Thank you very much!

    opened by BinahHu 2
  • About optionally diversifying prompt-selection

    Thanks for the great idea and the result!

    As the title says, I'd like to know how to use the optional prompt-selection diversification. I don't see where the arguments for this method are used, nor do I see an implementation of it in ./models/prompt.py.

    I would also like to ask how the selection frequency of each prompt is normalized into a penalty factor; I don't see a specific description in the paper.

    opened by Dicer-Zz 2
  • How to reproduce the results with replay?

    Hi, thanks for the great work!

    With the script you provided, I successfully reproduced the L2P result (on the Split CIFAR-100 dataset) without replay.

    I am now trying to reproduce the results with replay, but even with a replay buffer storing 50 samples/class, I got much lower results (acc 80.10 / forgetting 9.13) compared to the reported ones (acc 86.31 / forgetting 5.83).

    It doesn't make sense that using replay leads to higher forgetting, so I guess I might be missing something.

    Since there are no examples of how to use replay in the given configuration file (i.e., cifar100_l2p.py), I added some lines to handle replay, as below.

      # add for replay
      config.continual.replay = ml_collections.ConfigDict()
      config.continual.replay.num_samples_per_class = 50
      config.continual.replay_no_mask = True
      config.continual.replay_reverse_mask = False
      config.continual.replay.include_new_task = True
      config.continual.review_trick = False
      config.continual.num_review_steps = -1
      config.continual.num_review_epochs = 5
    

    I guess review_trick is for fine-tuning the model with a balanced dataset.

    Strangely, when I set review_trick=True, I got a much lower result (especially very low learning accuracy).

    And when I set review_trick=False, the model keeps updating with replayed samples, but it still shows much lower accuracy (acc 80.10 / forgetting 9.13).

    Do you have any advice on what I am missing or what to modify in your code?

    Or could you share the correct configuration files to reproduce the results of Table 1 using a replay buffer?

    Thank you.

    opened by whieya 2
  • Question about the paper's comparison.

    I've read the paper and have some small questions about its comparison to other methods like EWC, ER, and DER++, because I may have missed the details of how these methods are run with the pre-trained ViT. Which parts of the ViT are trained or fine-tuned in the Upper-bound and in those methods, and which parts use pre-trained weights? I guess only the classifier is trained, but I need some confirmation.

    opened by wzlk655 2
  • constraint_coefficient of surrogate loss for pulling selected keys closer

    The constraint_coefficient of the surrogate loss is 0.5 in the paper, but in the code it is set to -0.1.

    Which is correct? Which works better?

    Thank you.

    opened by kimwongyuda 1
  • RESOURCE_EXHAUSTED: Out of memory while trying to allocate # bytes.

    Hi, I am trying to run the model on the CIFAR-100 dataset and I am getting the following error. I have 4 Tesla V100 GPUs.

    2022-08-05 10:00:26.833817: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.00MiB (rounded to 9437184)requested by op 
    2022-08-05 10:00:26.835182: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] *********************************************************************************x**************x***
    2022-08-05 10:00:26.835281: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9437184 bytes.
    BufferAssignment OOM Debugging.
    

    The complete running logs can be found here. Please help me with solving the issue.

    ===============

    For your information, I was getting a RuntimeError: Visible devices cannot be modified after being initialized error. Hence, I added the following code snippet in main.py from https://www.tensorflow.org/guide/gpu, and it solved the issue.

    """Main file for running the example."""
    
    import os
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    
    imports ...
    
    FLAGS = flags.FLAGS
    ...
    
    def main(argv):
      del argv
    
      # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
      # it unavailable to JAX.
      # tf.config.experimental.set_visible_devices([], "GPU")
    
      gpus = tf.config.list_physical_devices('GPU')
      if gpus:
        # Restrict TensorFlow to only use the first GPU
        try:
          tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
          logical_gpus = tf.config.list_logical_devices('GPU')
          print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
        except RuntimeError as e:
          # Visible devices must be set before GPUs have been initialized
          print(e)
    
      # if gpus:
      #   # Create 2 virtual GPUs with 1GB memory each
      #   try:
      #     tf.config.set_logical_device_configuration(
      #         gpus[0],
      #         [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
      #         tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
      #     logical_gpus = tf.config.list_logical_devices('GPU')
      #     print(len(gpus), "Physical GPU,", len(logical_gpus), "Logical GPUs")
      #   except RuntimeError as e:
      #     # Virtual devices must be set before GPUs have been initialized
      #     print(e)
    
    
      if FLAGS.exp_id:
         ...
    
    opened by vgthengane 0
  • Questions about the pre-trained ViT

    Dear authors,

    Thanks for your great work on building CIL learners with pre-trained models. I have a simple question regarding the pre-trained ViT. I noticed there are several versions of pre-trained ViT available. Given that the current repo suggests downloading the pre-trained model from https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz, does this mean the model is pre-trained on ImageNet-21k?

    Thx in advance.

    opened by zhoudw-zdw 0
  • Question regarding on FT-seq-Frozen

    In the 'Learning to Prompt for Continual Learning' paper, I understand 'FT-seq-Frozen' in Table 1 as naive prompt tuning on the input token features.

    To implement the FT-seq-Frozen setting on CIFAR-100, I set the prompt pool_size to 1. The result shows Acc@1 81.49 with forgetting 6.3667.

    Is there any point I missed? How did you set the hyperparameters for FT-seq-Frozen? Specifically, did you set the argument 'train_mask = False' for FT-seq-Frozen?

    opened by jcy132 0
  • Question regarding the average and last accuracy.

    Dear authors,

    Thank you for open-sourcing your work. I am slightly confused about the metrics you used for evaluation. Here is my understanding from your README:

    1. accuracy_n: accuracy evaluated on only the n-th task after training on the n-th task.
    2. forgetting: average forgetting up until the current task.
    3. avg_acc: average evaluation accuracy after training on the n-th task.
    For (3), after training on the n-th task, we get
    acc_per_task = [a1, a2, ..., an];
    avg_acc = average(acc_per_task);
    

    Is it right?

    opened by vgthengane 0
  • Inference

    Hi, thanks for the interesting work.

    I have one question regarding the choice of prompt during the testing.

    It seems that both DualPrompt and L2P use batch mode during testing, and each batch chooses the same prompt via majority voting. However, the assumption that every test sample in a batch comes from the same task is questionable. Do you have any thoughts about this? Looking forward to hearing from you!

    Bests, Yuansheng

    opened by YUZ128pitt 6
  • Questions about the reproducibility of the code and the results of the paper

    I sincerely question the reproducibility of the code and the results of the paper, based on the issues in this repo:

    1. I have seen people using the same V100 GPUs as the authors but still unable to run the code due to OOM errors, with no reply forthcoming (#1, #20).
    2. My own reproduction of the code does not achieve the results shown in the left panel of Figure 3 in the paper, and I do not understand how catastrophic forgetting can fail to occur given such statistics. Even without using the optional prompt-selection diversification, I cannot obtain this statistic, same as in #18 and #24. That issue comes with detailed statistics logs: looking at the histogram records, only four prompts were selected and all tasks share these prompts, which I think will inevitably cause catastrophic forgetting.
    3. The use of a pre-trained ViT may have caused an information leak (#11).
    4. The given requirements.txt does not directly install the required runtime environment, and even when it does, the code only runs on the CPU (#1).

    And, for myself:

    1. This code is really hard to run on my RTX 3090 GPU; even after a lot of effort, and with no error reported, the program gets stuck at training step 5 of the first task.
    2. I have not seen anyone in the issues who has successfully reproduced the results.

    I very sincerely hope that the author will answer the above questions.

    opened by Dicer-Zz 9