An Extendible (General) Continual Learning Framework based on PyTorch - official codebase of Dark Experience for General Continual Learning

Overview

Mammoth - An Extendible (General) Continual Learning Framework for PyTorch

NEWS

STAY TUNED: We are working on an update of this repository to include the codebase of our extended paper Class-Incremental Continual Learning into the eXtended DER-verse.


Official repository of Dark Experience for General Continual Learning: a Strong, Simple Baseline

(Figures: Sequential MNIST, Sequential CIFAR-10, Sequential TinyImagenet, Permuted MNIST, Rotated MNIST, MNIST-360)

Setup

  • Use ./utils/main.py to run experiments.
  • Use argument --load_best_args to use the best hyperparameters from the paper.
  • New models can be added to the models/ folder.
  • New datasets can be added to the datasets/ folder.

Models

  • Gradient Episodic Memory (GEM)
  • A-GEM
  • A-GEM with Reservoir (A-GEM-R)
  • Experience Replay (ER)
  • Meta-Experience Replay (MER)
  • Function Distance Regularization (FDR)
  • Greedy gradient-based Sample Selection (GSS)
  • Hindsight Anchor Learning (HAL)
  • Incremental Classifier and Representation Learning (iCaRL)
  • online Elastic Weight Consolidation (oEWC)
  • Synaptic Intelligence
  • Learning without Forgetting
  • Progressive Neural Networks
  • Dark Experience Replay (DER)
  • Dark Experience Replay++ (DER++)
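DER and DER++ differ from plain Experience Replay in what they replay. Below is a minimal NumPy sketch of the two objectives as the paper describes them. It assumes logits arrive as `(batch, classes)` arrays. `alpha` weights the MSE between the network's current logits on buffer samples and the logits stored when those samples were inserted. `beta` (DER++ only) weights an extra cross-entropy term on a second buffer minibatch. The function names are illustrative and are not Mammoth's API.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels under softmax(logits)."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def der_loss(stream_logits, stream_labels, buf_logits_now, buf_logits_stored, alpha):
    """DER: CE on the incoming batch + alpha * MSE to the stored (dark) logits."""
    ce = cross_entropy(stream_logits, stream_labels)
    diff = np.asarray(buf_logits_now, dtype=float) - np.asarray(buf_logits_stored, dtype=float)
    return ce + alpha * float(np.mean(diff ** 2))


def derpp_loss(stream_logits, stream_labels, buf_logits_now, buf_logits_stored,
               buf2_logits_now, buf2_labels, alpha, beta):
    """DER++: the DER objective + beta * CE on a second buffer minibatch."""
    return (der_loss(stream_logits, stream_labels, buf_logits_now, buf_logits_stored, alpha)
            + beta * cross_entropy(buf2_logits_now, buf2_labels))
```

Setting `beta = 0` reduces `derpp_loss` to `der_loss`. This makes it easy to see that `alpha` controls logit matching and `beta` controls label replay.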

Datasets

Class-IL / Task-IL settings

  • Sequential MNIST
  • Sequential CIFAR-10
  • Sequential Tiny ImageNet
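The same trained network can be scored in both settings. Class-IL takes the argmax over all classes. Task-IL, where the task identity is known at test time, restricts the argmax to that task's classes. Below is a minimal NumPy sketch. It assumes each task owns a contiguous, equally sized block of class indices, for example five 2-class tasks for Sequential CIFAR-10. The function names are hypothetical.

```python
import numpy as np


def mask_to_task(logits, task_id, classes_per_task):
    """Set every logit outside task `task_id`'s class block to -inf."""
    logits = np.asarray(logits, dtype=float)
    start, end = task_id * classes_per_task, (task_id + 1) * classes_per_task
    masked = np.full_like(logits, -np.inf)
    masked[:, start:end] = logits[:, start:end]
    return masked


def predict(logits, task_id=None, classes_per_task=2):
    """Class-IL prediction when task_id is None, Task-IL prediction otherwise."""
    if task_id is None:
        return np.asarray(logits).argmax(axis=1)
    return mask_to_task(logits, task_id, classes_per_task).argmax(axis=1)


logits = np.array([[0.1, 0.2, 0.9, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
print(predict(logits))             # [2]  (Class-IL: best class overall)
print(predict(logits, task_id=0))  # [1]  (Task-IL: best class within task 0)
```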

Domain-IL settings

  • Permuted MNIST
  • Rotated MNIST

General Continual Learning setting

  • MNIST-360

Citing this work

@inproceedings{buzzega2020dark,
 author = {Buzzega, Pietro and Boschini, Matteo and Porrello, Angelo and Abati, Davide and Calderara, Simone},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
 pages = {15920--15930},
 publisher = {Curran Associates, Inc.},
 title = {Dark Experience for General Continual Learning: a Strong, Simple Baseline},
 volume = {33},
 year = {2020}
}
Comments
  • Proper Benchmarking for DER and DER++

    Proper Benchmarking for DER and DER++

    Hi!

    I'm looking to add DER / DER++ results as a baseline in my paper. I'm interested in the setting where only a single pass is made through the dataset (i.e. num_epochs == 1). Do you have any advice or suggestions on hyperparameter selection in this setting, specifically alpha, beta and the learning rate? Thanks in advance :)

    And great job on the library! I'm usually not a fan of such libraries, as I find them too convoluted, but this one is super minimal and easy to play with 👌

    opened by cl-for-life 6
  • Doubt related to transform in the buffer.py

    Doubt related to transform in the buffer.py

    Hello,

    I have a question about der.py and buffer.py, specifically about the transforms applied for data augmentation. The following are the transformations used in DER for Split CIFAR-10:

    Compose(
        ToPILImage()
        Compose(
        RandomCrop(size=(32, 32), padding=4)
        RandomHorizontalFlip(p=0.5)
        ToTensor()
        Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2615))
    ))
    

    While we store the samples in the buffer, we always save the non-augmented inputs and the corresponding logits, as shown in the following snippet.

    self.buffer.add_data(examples=not_aug_inputs, logits=outputs.data)
    
        def add_data(self, examples, labels=None, logits=None, task_labels=None):
            if not hasattr(self, 'examples'):
                self.init_tensors(examples, labels, logits, task_labels)
    
            for i in range(examples.shape[0]):
                index = reservoir(self.num_seen_examples, self.buffer_size)
                self.num_seen_examples += 1
                if index >= 0:
                    self.examples[index] = examples[i].to(self.device)
                    if labels is not None:
                        self.labels[index] = labels[i].to(self.device)
                    if logits is not None:
                        self.logits[index] = logits[i].to(self.device)
                    if task_labels is not None:
                        self.task_labels[index] = task_labels[i].to(self.device)
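    The reservoir helper called by add_data is not part of the quoted code. add_data writes at index when index >= 0 and skips the example otherwise. From that contract, the helper presumably implements standard reservoir sampling. The sketch below is an assumption written from that contract and is not the actual buffer.py code. fill_buffer is a hypothetical driver that mimics the add_data loop on a plain list.

```python
import random


def reservoir(num_seen_examples, buffer_size, rng=random):
    """Return the buffer slot for the next example, or -1 to discard it."""
    # While the buffer is not full, the new example takes the next free slot.
    if num_seen_examples < buffer_size:
        return num_seen_examples
    # Otherwise draw uniformly from 0..num_seen_examples (inclusive), so the
    # example is kept with probability buffer_size / (num_seen_examples + 1).
    rand = rng.randint(0, num_seen_examples)
    return rand if rand < buffer_size else -1


def fill_buffer(stream, buffer_size, seed=0):
    """Run reservoir sampling over an iterable, mirroring the add_data loop."""
    rng = random.Random(seed)
    buffer, num_seen = [], 0
    for item in stream:
        index = reservoir(num_seen, buffer_size, rng)
        num_seen += 1
        if index >= 0:
            if index < len(buffer):
                buffer[index] = item  # overwrite a randomly chosen older example
            else:
                buffer.append(item)   # buffer still filling up
    return buffer
```

    With this scheme, every element of the stream is retained with the same probability. The buffer therefore stays a uniform sample of everything seen so far.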
    
    

    Now, when we call get_all_elements or get_data:

        def get_data(self, size: int, transform: transforms=None) -> Tuple:
            """
            Random samples a batch of size items.
            :param size: the number of requested items
            :param transform: the transformation to be applied (data augmentation)
            :return:
            """
            if size > min(self.num_seen_examples, self.examples.shape[0]):
                size = min(self.num_seen_examples, self.examples.shape[0])
    
            choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]),
                                      size=size, replace=False)
            if transform is None: transform = lambda x: x
            ret_tuple = (torch.stack([transform(ee.cpu())
                                for ee in self.examples[choice]]).to(self.device),)
            for attr_str in self.attributes[1:]:
                if hasattr(self, attr_str):
                    attr = getattr(self, attr_str)
                    ret_tuple += (attr[choice],)
    
            return ret_tuple
    

    It applies the set of transformations to the non-augmented examples.

    Now my question is the following: when we request elements from the buffer, is the transformation applied on top of the non-augmented images still the same one that generated the corresponding logits? Since the transforms are stochastic (crop/flip), they seem to produce a different example from the originally transformed input that generated the logits.

    It would be great if you could answer my query, thanks!!

    opened by bhattg 4
  • Gradient computation in case of multi-head (task incremental setting)

    Gradient computation in case of multi-head (task incremental setting)

    Hi,

    Thanks for the great baseline. I was looking at the A-GEM implementation in models/agem.py, in particular the training pipeline for the multi-headed case, and I have a few doubts about the code.

    During the forward pass in a multi-headed setting (Task-IL), when we have, say, classes [0, 1, 2, 3, 4] in task 0, why is there no masking? From what I've understood of the code flow, the following calls happen:

    (I am only mentioning the steps for A-GEM, so I omit the steps that A-GEM does not need.)

    for epoch in range(args.n_epochs):
        for i, data in enumerate(train_loader):
            if hasattr(dataset.train_loader.dataset, 'logits'):
                pass  # ignoring this
            else:
                inputs, labels, not_aug_inputs = data
                inputs, labels = inputs.to(model.device), labels.to(
                    model.device)
                not_aug_inputs = not_aug_inputs.to(model.device)
                loss = model.observe(inputs, labels, not_aug_inputs)
    
    

    When model.observe(inputs, labels, not_aug_inputs) is called, it executes the following steps:

    self.zero_grad()
    p = self.net.forward(inputs)
    loss = self.loss(p, labels)  # Why is there no masking on the output p for the classes of the other tasks?
    loss.backward()
    self.opt.step()
    return loss.item()
    

    As I mention in the comment on line 3, why is there no masking on the output p for the classes corresponding to the other tasks? Without masking, the weights corresponding to the classes of the other tasks will also change, which defeats the purpose of a multi-head setup.

    Similarly, when computing the reference gradient (in the observe function in agem.py), why is there no masking corresponding to the tasks in the memory buffer?

    if not self.buffer.is_empty():
        store_grad(self.parameters, self.grad_xy, self.grad_dims)
        buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size, transform=self.transform)
        self.net.zero_grad()
        buf_outputs = self.net.forward(buf_inputs)
        penalty = self.loss(buf_outputs, buf_labels)
        penalty.backward()
    

    All of these things are present in the original A-GEM source code (TensorFlow implementation), lines 392 to 426.
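    For context, the A-GEM update uses the reference gradient g_ref, computed on memory samples, to check the current gradient g. If dot(g, g_ref) < 0, g is projected as g - (dot(g, g_ref) / dot(g_ref, g_ref)) * g_ref, so its dot product with g_ref becomes zero. Otherwise g is used unchanged. Below is a NumPy sketch on flattened gradient vectors. agem_project is a hypothetical name, not the function in agem.py.

```python
import numpy as np


def agem_project(grad, grad_ref):
    """Project grad when it conflicts with the memory (reference) gradient."""
    grad = np.asarray(grad, dtype=float)
    grad_ref = np.asarray(grad_ref, dtype=float)
    dot = float(np.dot(grad, grad_ref))
    if dot >= 0:
        # No conflict: the update does not increase the memory loss (to first order).
        return grad
    # Remove the component of grad along grad_ref.
    return grad - (dot / float(np.dot(grad_ref, grad_ref))) * grad_ref
```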

    Please let me know your thoughts on my concerns. Thanks for the PyTorch implementations.

    opened by bhattg 4
  • Mask shape error in xder code implementation

    Mask shape error in xder code implementation

    Hello. Thank you for your recent work and for sharing xder code.

    While trying to reproduce your work on my own, I encountered the following error:


    File ".../models/xder.py", line 143, in update_logits
        gt_values = old[torch.arange(len(gt)), gt]
    IndexError: The shape of the mask [32] at index 0 does not match the shape of the indexed tensor [32, 10] at index 1

    Have you ever encountered this problem? How can I fix it? Thank you.
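    The failing line uses advanced indexing to pick old[i, gt[i]] for every row i. That only works if gt holds integer class indices. PyTorch treats a boolean index tensor (and, at least in some versions, a uint8 one) as a mask, which would raise a mask-shape error like this one. Whether that is the cause here is an assumption. Below is a NumPy sketch of the intended gather with a dtype check; gather_gt_logits is a hypothetical name. In PyTorch, casting with gt.long() before indexing would serve the same purpose.

```python
import numpy as np


def gather_gt_logits(old, gt):
    """Return old[i, gt[i]] for every row i of a (batch, classes) array."""
    gt = np.asarray(gt)
    # A boolean array would be interpreted as a mask, not as class indices.
    if not np.issubdtype(gt.dtype, np.integer):
        raise TypeError(f"expected integer class indices, got {gt.dtype}")
    return old[np.arange(len(gt)), gt]


old = np.arange(12).reshape(4, 3)
print(gather_gt_logits(old, [2, 0, 1, 2]).tolist())  # [2, 3, 7, 11]
```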

    opened by oyt9306 3
  • ICarl seq-tinyimg accuracy

    ICarl seq-tinyimg accuracy

    Hi! I have a question about the results reported in the paper "Dark Experience for General Continual Learning: a Strong, Simple Baseline". When I run:

    python ./utils/main.py --model=icarl --dataset=seq-tinyimg --load_best_args --buffer_size=200

    the accuracy is around 14%, but in the paper it is around 7%. Have I done something wrong? Thank you in advance.

    opened by gab709 2
  • Resolving python package errors

    Resolving python package errors

    Hi, thanks for this awesome library.

    When I tried running

    python utils/main.py --model=icarl --dataset=seq-cifar10 --buffer_size=500

    I got the following error:

    ~/workspace/mammoth$ python utils/main.py --model=icarl --dataset=seq-cifar10 --buffer_size=500
    Traceback (most recent call last):
      File "utils/main.py", line 7, in <module>
        from datasets import NAMES as DATASET_NAMES
    ImportError: cannot import name 'NAMES'
    

    I could resolve this by adding:

    import os
    import sys

    # Assumes the script is launched from the repository root (mammoth/).
    conf_path = os.getcwd()
    sys.path.append(conf_path)
    for subdir in ('datasets', 'backbone', 'models'):
        sys.path.append(os.path.join(conf_path, subdir))
    

    to mammoth/utils/main.py.

    I have updated my fork with it.

    I just thought I'd bring it up here, so that no one else has to fret over this issue.

    Thanks!

    opened by JosephKJ 2
  • iCARL method not working.

    iCARL method not working.

    Hi, I was using the library and it seems like the iCARL method is not working. I figured out the reason: when selecting the samples to store in the buffer, the normalization transform has to be applied to the whole batch of images, but transform.Normalize only works on a single image. An easy fix is to replace this line https://github.com/aimagelab/mammoth/blob/0b6f1434b06ed9e53a10c120d1efef8015184b1c/models/icarl.py#L70 with feats = self.net(torch.cat([norm_trans(aa).unsqueeze(0) for aa in not_norm_x]), returnt='features')

    This might not be the most efficient way but it works. I hope this helps!
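    A vectorized alternative to normalizing image by image is to broadcast the per-channel mean and standard deviation over a whole (N, C, H, W) batch in one operation. The same broadcasting rules apply to PyTorch tensors. Below is a NumPy sketch. It assumes inputs are already scaled to [0, 1] and uses the CIFAR-10 statistics from the Normalize transform quoted in the buffer.py issue above. normalize_batch is a hypothetical helper.

```python
import numpy as np

CIFAR10_MEAN = np.array([0.4914, 0.4822, 0.4465])
CIFAR10_STD = np.array([0.247, 0.2435, 0.2615])


def normalize_batch(x, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Normalize an (N, C, H, W) batch channel-wise in a single operation."""
    # Reshape the statistics to (1, C, 1, 1) so they broadcast over N, H and W.
    return (x - mean.reshape(1, -1, 1, 1)) / std.reshape(1, -1, 1, 1)
```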

    opened by prateeky2806 1
  • How do I provide the continual learning setting?

    How do I provide the continual learning setting?

    Hi, thank you for the amazing paper and repository. I am trying to run some experiments on CIFAR100 in the task-incremental setting with 10 different 10-way classification tasks. Can you please tell me how to specify the continual learning setting for different methods? For example, I see that der is compatible with the incremental and task-incremental settings, so how do I specify which setting I want to run experiments in? A few example commands would be very helpful, as I don't see an argparse argument for this.

    Thanks, Prateek

    opened by prateeky2806 1
  • Gradients of fc1.weight are always zero during the training

    Gradients of fc1.weight are always zero during the training

    How to reproduce? Add the following lines of code at line 107 of training.py to view the gradient values stored in the parameters of the fc1 layer after each task. Only the bias is being updated.

    for name, param in model.named_parameters():
        if 'net.fc1' in name:
            print(name, param.grad) 
    

    Test script on the seq-mnist dataset using the der model:

    python3 utils/main.py --model der --dataset seq-mnist --lr 0.03 --batch_size=16 --n_epochs=1 --buffer_size=500 --minibatch_size=16 --alpha=0.1

    I also tested the derpp, agem and lwf models, and the rot-mnist dataset.

    opened by shafeef901 1
  • Question regarding EWC_ON

    Question regarding EWC_ON

    Dear authors, I reproduced your results for ewc_on using your code, and I got 19.49% on CIFAR-10 (the same as yours). However, I see that the forgetting is truly catastrophic, i.e., performance on previous tasks drops to zero once the model moves past those tasks. Please see the attached graphs for more info.

    (Graphs: acc_mean, acc_task01, acc_task02)

    Please let me know whether this is due to the implementation or is the expected behavior of ewc_on.
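    For reference, online EWC keeps a single running estimate of the (diagonal) Fisher information. It decays the previous estimate by a factor gamma before adding the new task's estimate. It then penalizes squared deviations from the parameters reached at the end of the last task, weighted by that estimate. Below is a NumPy sketch on flattened parameter vectors. lam and gamma are hypothetical names for the regularization strength and decay factor, and constant factors vary between implementations.

```python
import numpy as np


def update_fisher(fisher, new_fisher, gamma):
    """Online EWC: decay the running Fisher estimate, then add the new one."""
    new_fisher = np.asarray(new_fisher, dtype=float)
    if fisher is None:  # first task: nothing to decay yet
        return new_fisher
    return gamma * np.asarray(fisher, dtype=float) + new_fisher


def ewc_penalty(params, anchor_params, fisher, lam):
    """Fisher-weighted squared distance from the end-of-last-task parameters."""
    diff = np.asarray(params, dtype=float) - np.asarray(anchor_params, dtype=float)
    return lam * float(np.sum(np.asarray(fisher, dtype=float) * diff ** 2))
```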

    opened by bhat-prashant 1
  • On the Joint method

    On the Joint method

    Hi, thanks for the code base; I want to use it in my work. However, I ran into a problem when running the joint method. Unless I am misunderstanding, it looks like the "observe" function is not defined correctly. Could you please help with it? Thanks.

    opened by MiaoyunZhao 1
Owner
AImageLab