training script for space time memory network

Overview

Training Script for Space Time Memory Network

This codebase implements training code for the Space Time Memory Network (STM) with some cyclic features.

sample results

Requirement

Python packages

  • torch
  • python-opencv
  • pillow
  • yaml
  • imgaug
  • yacs
  • progress
  • nvidia-dali (optional)

GPU support

  • GPU Memory >= 12GB
  • CUDA >= 10.0

Data

See DATASET.md for details on how our prepared dataset is organized.

Release

We provide pre-trained models with different backbones. The results below were validated on DAVIS17-val with gradient correction.

model      backbone  data backend  J     F     J & F  link          FPS
STM-Cycle  Resnet18  DALI          65.3  70.8  68.1   Google Drive  14.8
STM-Cycle  Resnet50  PIL           70.5  76.3  73.4   Google Drive  9.3
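
For reading the table: on DAVIS, the J & F column is the mean of the region similarity J and the contour accuracy F. The short check below recomputes it for the two rows above. The helper name `j_and_f` is ours for illustration and is not part of the codebase.

```python
def j_and_f(j_mean: float, f_mean: float) -> float:
    """Return the DAVIS global score: the average of J and F."""
    return (j_mean + f_mean) / 2.0


# (J, F) pairs copied from the table above.
rows = {
    "STM-Cycle / Resnet18 / DALI": (65.3, 70.8),
    "STM-Cycle / Resnet50 / PIL": (70.5, 76.3),
}

for name, (j, f) in rows.items():
    print(f"{name}: J&F = {j_and_f(j, f):.2f}")
```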

Running

Append the root folder to the search path of the Python interpreter:

export PYTHONPATH=${PYTHONPATH}:./

To train the STM network, run the following command:

python3 train.py --cfg config.yaml OPTION_KEY OPTION_VAL
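
The trailing `OPTION_KEY OPTION_VAL` arguments override entries from the config file. The sketch below shows one way such key/value pairs can be collected and applied using only the standard library. It is an illustration under assumptions, not the parser used by train.py; the function names and the default values are hypothetical, and the keys (`train_batch`, `learning_rate`, `solver`) are used only as examples.

```python
import argparse
import ast


def parse_cli(argv):
    """Parse --cfg plus any trailing KEY VALUE override pairs."""
    parser = argparse.ArgumentParser("override sketch")
    parser.add_argument("--cfg", default="config.yaml")
    # Everything after the known options is kept as a flat list.
    parser.add_argument("opts", nargs=argparse.REMAINDER)
    return parser.parse_args(argv)


def apply_overrides(config, opts):
    """Return a copy of config with KEY VALUE pairs from opts applied."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    merged = dict(config)
    for key, raw in zip(opts[0::2], opts[1::2]):
        if key not in merged:
            raise KeyError(f"unknown option: {key}")
        try:
            value = ast.literal_eval(raw)  # numbers, lists, tuples, ...
        except (ValueError, SyntaxError):
            value = raw  # keep plain strings such as 'sgd'
        merged[key] = value
    return merged


# Hypothetical defaults standing in for the contents of config.yaml.
defaults = {"train_batch": 4, "learning_rate": 1e-5, "solver": "adam"}
args = parse_cli(["--cfg", "config.yaml", "train_batch", "2", "solver", "sgd"])
config = apply_overrides(defaults, args.opts)
print(config)
```

Rejecting unknown keys makes a misspelled option fail immediately instead of being silently ignored.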

To test the STM network, run the following command:

python3 test.py --cfg config.yaml initial ${PATH_TO_MODEL} OPTION_KEY OPTION_VAL

The test results will be saved as indexed PNG files in ${ROOT}/${output_dir}/${valset}.

To run a segmentation demo, run the following command:

python3 demo/demo.py --cfg demo/demo.yaml OPTION_KEY OPTION_VAL

The segmentation results will be saved at ${output_dir}.

Acknowledgement

This codebase borrows code and structure from the official STM repository.

Reference

The codebase is built on the following works:

@InProceedings{Oh_2019_ICCV,
author = {Oh, Seoung Wug and Lee, Joon-Young and Xu, Ning and Kim, Seon Joo},
title = {Video Object Segmentation Using Space-Time Memory Networks},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2019}
}

@InProceedings{Li_2020_NeurIPS,
author = {Li, Yuxi and Xu, Ning and Peng, Jinlong and See, John and Lin, Weiyao},
title = {Delving into the Cyclic Mechanism in Semi-supervised Video Object Segmentation},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2020}
}
Comments
  • When using my own dataset to train the model, the loss cannot be reduced.

    Can you give me some advice?


    My class reading the dataset (semi-supervised vos):

    
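    # Imports needed by this snippet. BaseData, ROOT, convert_mask, MAX_TRAINING_SKIP,
    # MAX_TRAINING_OBJ and DATA_CONTAINER come from the codebase's dataset module.
    import os
    import random
    
    import numpy as np
    import torch
    from PIL import Image
    
    
    class MyDataset(BaseData):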
    
        def __init__(self, train=True, sampled_frames=3,
                     transform=None, max_skip=5, increment=5, samples_per_video=12):
            print(" ==>> Using MyDataset <<== ")
    
            data_dir = os.path.join(ROOT, 'mydata_dir')
    
            if train:
                dbfile = os.path.join(data_dir, 'ImageSets', 'train_valid.txt')
            else:
                dbfile = os.path.join(data_dir, 'ImageSets', 'test.txt')
    
            self.imgdir = os.path.join(data_dir, 'JPEGImages')
            self.annodir = os.path.join(data_dir, 'Annotations')
    
            self.root = data_dir
            self.max_obj = 0
    
            # extract annotation information
            self.videos = []
            with open(dbfile, 'r') as f:
                video_name = f.readline()
                while video_name:
                    video_name = video_name.strip()
                    self.videos.append(video_name)
                    objn = np.array(Image.open(os.path.join(self.annodir, video_name, '00000.png')).convert('P')).max()
                    self.max_obj = max(objn, self.max_obj)
                    video_name = f.readline()
            print(" ==>> Length of Trainset: {}".format(len(self.videos)))
    
            self.samples_per_video = samples_per_video
            self.sampled_frames = sampled_frames
            self.length = samples_per_video * len(self.videos)
            self.max_skip = max_skip
            self.increment = increment
    
            self.transform = transform
            self.train = train
    
        def increase_max_skip(self):
            self.max_skip = min(self.max_skip + self.increment, MAX_TRAINING_SKIP)
    
        def set_max_skip(self, max_skip):
            self.max_skip = max_skip
    
        def __getitem__(self, idx):
    
            video_name = self.videos[(idx // self.samples_per_video)]
    
            imgfolder = os.path.join(self.imgdir, video_name)
            annofolder = os.path.join(self.annodir, video_name)
    
            frames = [name[:5] for name in os.listdir(imgfolder)]
            frames.sort()
            nframes = len(frames)
    
            num_obj = 0
            while num_obj == 0:
    
                if self.train:
                    last_sample = -1
                    sample_frame = []
    
                    nsamples = min(self.sampled_frames, nframes)
                    for i in range(nsamples):
                        if i == 0:
                            last_sample = random.sample(range(0, nframes - nsamples + 1), 1)[0]
                        else:
                            last_sample = random.sample(
                                range(last_sample + 1, min(last_sample + self.max_skip + 1, nframes - nsamples + i + 1)),
                                1)[0]
                        sample_frame.append(frames[last_sample])
                    mask = [np.array(Image.open(os.path.join(annofolder, name + '.png'))) for name in sample_frame]
                else:
                    sample_frame = frames
                    mask = []
                    for i, name in enumerate(sample_frame):
                        if i == 0:
                            mask.append(np.array(Image.open(os.path.join(annofolder, name + '.png'))))
                        else:
                            mask.append(np.ones_like(mask[0]) * 255)
                frame = [np.array(Image.open(os.path.join(imgfolder, name + '.jpg'))) for name in sample_frame]
    
                # clear dirty data
                for msk in mask:
                    msk[msk == 255] = 0
    
                num_obj = mask[0].max()
    
            # if self.train:
            #     num_obj = min(num_obj, MAX_TRAINING_OBJ)
    
            info = dict(
                name=video_name,
                palette=Image.open(os.path.join(annofolder, frames[0] + '.png')).getpalette(),
                size=frame[0].shape[:2],
                frame_index_list=sample_frame,
            )
    
            mask = [convert_mask(msk, self.max_obj) for msk in mask]
    
            if self.transform is None:
                raise RuntimeError('Lack of proper transformation')
    
            frame, mask = self.transform(frame, mask, False)
    
            if self.train:
                num_obj = 0
                for i in range(1, MAX_TRAINING_OBJ + 1):
                    if torch.sum(mask[0, i]) > 0:
                        num_obj += 1
                    else:
                        break
    
            return frame, mask, num_obj, info
    
        def __len__(self):
    
            return self.length
    
    DATA_CONTAINER['MyDataset'] = MyDataset
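
To check the frame sampling in `__getitem__` in isolation, the index selection can be pulled out into a small function. The sketch below mirrors the training branch of the loop above using only the standard library. The function name `sample_frame_indices` is hypothetical, and `randrange` stands in for `random.sample(range(...), 1)[0]`, which draws from the same range.

```python
import random


def sample_frame_indices(nframes, sampled_frames, max_skip, rng=random):
    """Pick increasing frame indices whose gaps are at most max_skip (>= 1)."""
    nsamples = min(sampled_frames, nframes)
    indices = []
    last = -1
    for i in range(nsamples):
        if i == 0:
            # Leave room for the remaining nsamples - 1 frames.
            last = rng.randrange(0, nframes - nsamples + 1)
        else:
            # Stay within max_skip of the previous frame and keep
            # enough frames at the end for the samples still to come.
            upper = min(last + max_skip + 1, nframes - nsamples + i + 1)
            last = rng.randrange(last + 1, upper)
        indices.append(last)
    return indices


# Example with the values from the config below: 4 frames, max_skip 3.
print(sample_frame_indices(nframes=20, sampled_frames=4, max_skip=3,
                           rng=random.Random(0)))
```

Printing a few samples per video is a cheap way to confirm that the chosen frames and their masks line up before looking at the loss.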
    

    and the config:

    from easydict import EasyDict
    
    OPTION = EasyDict()
    
    # ------------------------------------------ data configuration ---------------------------------------------
    OPTION.trainset = 'MyDataset'
    OPTION.valset = 'MyDataset'
    OPTION.datafreq = [5, 1]  # unused
    OPTION.input_size = (384, 384)  # input image size
    OPTION.sampled_frames = 4  # min sampled time length while training
    # OPTION.max_skip = [5, 3]  # max skip time length while training
    OPTION.max_skip = 3  # max skip time length while training
    OPTION.samples_per_video = 2  # sample numbers per video
    
    # ----------------------------------------- model configuration ---------------------------------------------
    OPTION.keydim = 128
    OPTION.valdim = 512
    OPTION.save_freq = 5
    OPTION.epochs_per_increment = 5
    
    # ---------------------------------------- training configuration -------------------------------------------
    OPTION.epochs = 120
    OPTION.train_batch = 4
    OPTION.learning_rate = 0.00001
    OPTION.gamma = 0.1
    OPTION.momentum = (0.9, 0.999)
    OPTION.solver = 'adam'  # 'sgd' or 'adam'
    OPTION.weight_decay = 5e-4
    OPTION.iter_size = 1
    OPTION.milestone = []  # epochs at which to decay the learning rate
    OPTION.loss = 'both'  # 'ce' or 'iou' or 'both'
    OPTION.mode = 'recurrent'  # 'mask' (ground-truth masks are stored in the memory) or 'recurrent' (the recurrent training scheme of the original paper) or 'threshold'
    OPTION.iou_threshold = 0.65  # used only for 'threshold' training
    
    # ---------------------------------------- testing configuration --------------------------------------------
    OPTION.epoch_per_test = 1
    
    # ------------------------------------------- other configuration -------------------------------------------
    OPTION.checkpoint = 'mydataset'
    OPTION.initial = '/home/lart/Coding/STM/STM_weights.pth'  # path to initialize the backbone
    # OPTION.initial = ''  # path to initialize the backbone
    OPTION.resume = ''  # path to restart from the checkpoint
    OPTION.gpu_id = '0'  # default gpu-id (if not specified in cmd)
    OPTION.workers = 4
    OPTION.save_indexed_format = True  # set True to save indexed format png file, otherwise segmentation with original image
    OPTION.output_dir = 'output'
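
With these settings, `max_skip` starts at 3 and grows during training: `increase_max_skip` in the dataset class adds `increment` (5 by default in `MyDataset`) every `epochs_per_increment` epochs, capped at `MAX_TRAINING_SKIP`. The sketch below prints the resulting schedule. The cap value of 25 is an assumption made for illustration; the real constant is defined in the codebase, and `max_skip_schedule` is a hypothetical helper.

```python
def max_skip_schedule(epochs, start_skip, increment, epochs_per_increment, cap):
    """Return the max_skip value in effect during each epoch."""
    skip = start_skip
    schedule = []
    for epoch in range(epochs):
        schedule.append(skip)  # value used while training this epoch
        if (epoch + 1) % epochs_per_increment == 0:
            skip = min(skip + increment, cap)  # same rule as increase_max_skip
    return schedule


MAX_TRAINING_SKIP = 25  # assumed value, not taken from the codebase

schedule = max_skip_schedule(
    epochs=120,
    start_skip=3,
    increment=5,
    epochs_per_increment=5,
    cap=MAX_TRAINING_SKIP,
)
print(schedule[:30])
```

If a custom dataset has short videos, a large `max_skip` is clipped by the number of available frames, so the later part of this schedule may have no effect.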
    

    Here, I use the pretrained weight file STM_weights.pth released by the STM authors to initialize the model.

    My train.py:

    import argparse
    import os
    import os.path as osp
    import random
    import time
    from collections import OrderedDict
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data as data
    
    from libs.dataset.data import DATA_CONTAINER, multibatch_collate_fn
    from libs.dataset.transform import TrainTransform, TestTransform
    # from libs.models.models import STM
    from libs.models.models_cp import STM
    from libs.utils.logger import Logger, AverageMeter
    from libs.utils.loss import *
    from libs.utils.utility import save_checkpoint, adjust_learning_rate
    from options import OPTION as opt
    
    MAX_FLT = 1e6
    
    
    def parse_args():
        parser = argparse.ArgumentParser('Training Mask Segmentation')
        parser.add_argument('--gpu', default='', type=str, help='set gpu id to train the network, split with comma')
        return parser.parse_args()
    
    
    def main():
        start_epoch = 0
        random.seed(0)
    
        args = parse_args()
        # Use GPU
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu if args.gpu != '' else str(opt.gpu_id)
        use_gpu = torch.cuda.is_available() and (args.gpu != '' or int(opt.gpu_id)) >= 0
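        # Fall back to opt.gpu_id when --gpu is not given; int('') would raise ValueError.
        gpu_ids = [int(val) for val in (args.gpu or str(opt.gpu_id)).split(',')]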
    
        if not os.path.isdir(opt.checkpoint):
            os.makedirs(opt.checkpoint)
    
        # Data
        print('==> Preparing dataset')
    
        input_dim = opt.input_size
    
        train_transformer = TrainTransform(size=input_dim)
        test_transformer = TestTransform(size=input_dim)
    
        try:
            if isinstance(opt.trainset, list):
                print("[DATASET] List {}".format(opt.trainset))
                datalist = []
                for dataset, freq, max_skip in zip(opt.trainset, opt.datafreq, opt.max_skip):
                    ds = DATA_CONTAINER[dataset](
                        train=True,
                        sampled_frames=opt.sampled_frames,
                        transform=train_transformer,
                        max_skip=max_skip,
                        samples_per_video=opt.samples_per_video
                    )
                    datalist += [ds] * freq
    
                trainset = data.ConcatDataset(datalist)
    
            else:
                print("[DATASET] {}".format(opt.trainset))
                max_skip = opt.max_skip[0] if isinstance(opt.max_skip, list) else opt.max_skip
                trainset = DATA_CONTAINER[opt.trainset](
                    train=True,
                    sampled_frames=opt.sampled_frames,
                    transform=train_transformer,
                    max_skip=max_skip,
                    samples_per_video=opt.samples_per_video
                )
        except KeyError as ke:
            print('[ERROR] invalid dataset name is encountered. The current acceptable datasets are:')
            print(list(DATA_CONTAINER.keys()))
            exit()
    
        # testset = DATA_CONTAINER[opt.valset](
        #     train=False,
        #     transform=test_transformer,
        #     samples_per_video=1
        # )
    
        trainloader = data.DataLoader(trainset, batch_size=opt.train_batch, shuffle=True, num_workers=opt.workers,
                                      collate_fn=multibatch_collate_fn, drop_last=True)
    
        # testloader = data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=opt.workers,
        #                              collate_fn=multibatch_collate_fn)
        # Model
        print("==> creating model")
    
        net = STM(opt.keydim, opt.valdim, 'train',
                  mode=opt.mode, iou_threshold=opt.iou_threshold)
        print('    Total params: %.2fM' % (sum(p.numel() for p in net.parameters()) / 1000000.0))
        net.eval()
        if use_gpu:
            net = net.cuda()
    
        assert opt.train_batch % len(gpu_ids) == 0
        net = nn.DataParallel(net, device_ids=gpu_ids, dim=0)
    
        # set training parameters
        for p in net.parameters():
            p.requires_grad = True
    
        criterion = None
        celoss = cross_entropy_loss
    
        if opt.loss == 'ce':
            criterion = celoss
        elif opt.loss == 'iou':
            criterion = mask_iou_loss
        elif opt.loss == 'both':
            criterion = lambda pred, target, obj: celoss(pred, target, obj) + mask_iou_loss(pred, target, obj)
        else:
            raise TypeError('unknown training loss %s' % opt.loss)
    
        optimizer = None
    
        if opt.solver == 'sgd':
    
            optimizer = optim.SGD(net.parameters(), lr=opt.learning_rate,
                                  momentum=opt.momentum[0], weight_decay=opt.weight_decay)
        elif opt.solver == 'adam':
    
            optimizer = optim.Adam(net.parameters(), lr=opt.learning_rate,
                                   betas=opt.momentum, weight_decay=opt.weight_decay)
        else:
            raise TypeError('unknown solver type %s' % opt.solver)
    
        # Resume
        title = 'STM'
        minloss = float('inf')
    
        opt.checkpoint = osp.join(osp.join(opt.checkpoint, opt.valset))
        if not osp.exists(opt.checkpoint):
            os.mkdir(opt.checkpoint)
    
        if opt.resume:
            # Load checkpoint.
            print('==> Resuming from checkpoint {}'.format(opt.resume))
            assert os.path.isfile(opt.resume), 'Error: no checkpoint directory found!'
            # opt.checkpoint = os.path.dirname(opt.resume)
            checkpoint = torch.load(opt.resume)
            minloss = checkpoint['minloss']
            start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            skips = checkpoint['max_skip']
    
            try:
                if isinstance(skips, list):
                    for idx, skip in enumerate(skips):
                        trainloader.dataset.datasets[idx].set_max_skip(skip)
                else:
                    trainloader.dataset.set_max_skip(skips)
            except:
                print('[Warning] Initializing max skip fail')
    
            logger = Logger(os.path.join(opt.checkpoint, opt.mode + '_log.txt'), resume=True)
        else:
            if opt.initial:
                print('==> Initialize model with weight file {}'.format(opt.initial))
                weight = torch.load(opt.initial)
                if isinstance(weight, OrderedDict):
                    net.module.load_param(weight)
                else:
                    net.module.load_param(weight['state_dict'])
    
            logger = Logger(os.path.join(opt.checkpoint, opt.mode + '_log.txt'), resume=False)
            start_epoch = 0
    
        logger.set_items(['Epoch', 'LR', 'Train Loss'])
    
        # Train and val
        for epoch in range(start_epoch):
            adjust_learning_rate(optimizer, epoch, opt)
    
        for epoch in range(start_epoch, opt.epochs):
    
            print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, opt.epochs, opt.learning_rate))
            adjust_learning_rate(optimizer, epoch, opt)
    
            net.module.phase = 'train'
            train_loss = train(trainloader,
                               model=net,
                               criterion=criterion,
                               optimizer=optimizer,
                               epoch=epoch,
                               use_cuda=use_gpu,
                               iter_size=opt.iter_size,
                               mode=opt.mode,
                               threshold=opt.iou_threshold)
    
            # no test
            # if (epoch + 1) % opt.epoch_per_test == 0:
            #     net.module.phase = 'test'
            #     test_loss = test(testloader,
            #                      model=net.module,
            #                      criterion=criterion,
            #                      epoch=epoch,
            #                      use_cuda=use_gpu)
    
            # append logger file
            logger.log(epoch + 1, opt.learning_rate, train_loss)
    
            # adjust max skip
            if (epoch + 1) % opt.epochs_per_increment == 0:
                if isinstance(trainloader.dataset, data.ConcatDataset):
                    for dataset in trainloader.dataset.datasets:
                        dataset.increase_max_skip()
                else:
                    trainloader.dataset.increase_max_skip()
    
            # save model
            is_best = train_loss <= minloss
            minloss = min(minloss, train_loss)
            skips = [ds.max_skip for ds in trainloader.dataset.datasets] \
                if isinstance(trainloader.dataset, data.ConcatDataset) \
                else trainloader.dataset.max_skip
    
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'loss': train_loss,
                'minloss': minloss,
                'optimizer': optimizer.state_dict(),
                'max_skip': skips,
            }, epoch + 1, is_best, checkpoint=opt.checkpoint, filename=opt.mode)
    
        logger.close()
    
        print('minimum loss:')
        print(minloss)
    
    
    def train(trainloader, model, criterion, optimizer, epoch, use_cuda, iter_size, mode, threshold):
        # switch to train mode
    
        data_time = AverageMeter()
        loss = AverageMeter()
    
        end = time.time()
    
        # bar = Bar('Processing', max=len(trainloader))
        optimizer.zero_grad()
    
        for batch_idx, data in enumerate(trainloader):
            frames, masks, objs, infos = data
            # measure data loading time
            data_time.update(time.time() - end)
    
            if use_cuda:
                frames = frames.cuda()
                masks = masks.cuda()
                objs = objs.cuda()
    
            objs[objs == 0] = 1
    
            N, T, C, H, W = frames.size()
            max_obj = masks.shape[2] - 1
    
            total_loss = 0.0
            out = model(frame=frames, mask=masks, num_objects=objs)
            for idx in range(N):
                for t in range(1, T):
                    gt = masks[idx, t:t + 1]
                    pred = out[idx, t - 1: t]
                    No = objs[idx].item()
    
                    total_loss = total_loss + criterion(pred, gt, No)
    
            total_loss = total_loss / (N * (T - 1))
    
            # record loss
            if total_loss.item() > 0.0:
                loss.update(total_loss.item(), 1)
    
            # compute gradient and do SGD step (divided by accumulated steps)
            total_loss /= iter_size
            total_loss.backward()
    
            if (batch_idx + 1) % iter_size == 0:
                optimizer.step()
                model.zero_grad()
    
            # measure elapsed time
            end = time.time()
    
            # plot progress
            log = '({batch}/{size}/{epoch}) Name: {name} Idx: {idx} |Data: {data:.3f}s |Loss: {loss:.5f}'.format(
                batch=batch_idx + 1,
                size=len(trainloader),
                epoch=epoch,
                name=[infos[i]['name'] for i in range(opt.train_batch)],
                idx=[infos[i]['frame_index_list'] for i in range(opt.train_batch)],
                data=data_time.val,
                loss=loss.avg
            )
            print(log)
    
        return loss.avg
    
    
    if __name__ == '__main__':
        main()
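
The `iter_size` handling in `train()` is gradient accumulation: each loss is divided by `iter_size`, gradients are summed over `iter_size` batches, and the optimizer steps only on every `iter_size`-th batch. The toy sketch below reproduces that control flow with a single scalar weight and plain gradient descent standing in for Adam. Everything in it (the loss, the names, the numbers) is hypothetical and only meant to show when updates happen.

```python
def grad(w, x):
    """Gradient of the toy loss (w - x)^2 with respect to w."""
    return 2.0 * (w - x)


def train_epoch(w, batches, lr, iter_size):
    """Run one epoch with gradient accumulation; return (new_w, num_updates)."""
    accumulated = 0.0  # plays the role of the .grad buffers
    updates = 0
    for batch_idx, x in enumerate(batches):
        # Mirrors: total_loss /= iter_size; total_loss.backward()
        accumulated += grad(w, x) / iter_size
        if (batch_idx + 1) % iter_size == 0:
            w -= lr * accumulated  # mirrors optimizer.step()
            accumulated = 0.0      # mirrors model.zero_grad()
            updates += 1
    # Gradients from leftover batches are dropped here, just as
    # optimizer.zero_grad() clears them at the start of the next epoch.
    return w, updates


w, updates = train_epoch(w=0.0, batches=[1.0, 3.0, 5.0, 7.0, 9.0], lr=0.1, iter_size=2)
print(w, updates)
```

With `iter_size = 1`, as in the config above, every batch triggers an update and no gradients are dropped.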
    
    opened by lartpang 0
  • How to implement Multi-GPU for training

    I tried to use data parallelism to train on multiple GPUs, but it doesn't work: the model only runs on my first GPU. When I replace the model with a ResNet that has no extra tensor operations, it works normally, so maybe the tensor operations break data parallelism. Could you tell me how to parallelize the model with DataParallel or DistributedDataParallel?

    opened by nku-zhichengzhang 3
  • An error occurred during evaluation on the Davis dataset #36

    Hi all,

    This is the command I ran, and I got an unexpected error. I set the path to DAVIS correctly.

        python evaluation_method.py --task semi-supervised --results_path my_results/semi-supervised
        Evaluating sequences for the semi-supervised task...
        0%| | 0/30 [00:00<?, ?it/s]bike-packing frame 00001 not found!
        The frames have to be indexed PNG files placed inside the corespondent sequence folder.
        The indexes have to match with the initial frame.
        IOError: No such file or directory
        0%| | 0/30 [00:00<?, ?it/s]

    Could anyone help me figure out the problem?
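
The message says the evaluator could not find frame 00001 for bike-packing under the results path. A quick way to see which frames are absent is to compare the results folder with the annotation folder, sequence by sequence. The sketch below assumes both folders contain one sub-folder per sequence holding `NNNNN.png` files. The function name and the directory names in the demo are hypothetical, and the demo runs on a temporary directory tree rather than on real DAVIS data.

```python
import tempfile
from pathlib import Path


def find_missing_frames(results_dir, annotations_dir, sequences):
    """Map each sequence to the annotated frame names missing from results."""
    missing = {}
    for seq in sequences:
        expected = sorted(p.name for p in (Path(annotations_dir) / seq).glob("*.png"))
        produced = {p.name for p in (Path(results_dir) / seq).glob("*.png")}
        absent = [name for name in expected if name not in produced]
        if absent:
            missing[seq] = absent
    return missing


# Self-contained demo on a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ("00000.png", "00001.png", "00002.png"):
        path = root / "Annotations" / "bike-packing" / name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    # Pretend the method only wrote the first frame.
    out = root / "my_results" / "bike-packing" / "00000.png"
    out.parent.mkdir(parents=True)
    out.touch()
    report = find_missing_frames(root / "my_results", root / "Annotations", ["bike-packing"])
    print(report)
```

An empty report would point away from missing files and toward the file format, since the evaluator also requires indexed PNGs.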

    opened by longmalongma 0
  • Can the training batch size be increased?

    When I increase the training batch size, the following error appears:

        masks = torch.stack([sample[1] for sample in batch])
        RuntimeError: stack expects each tensor to be equal size, but got [3, 9, 240, 427] at entry 0 and [3, 13, 240, 427] at entry 1

    The error points to data.py (line 63). Training only runs with a batch size of 1. I'm hoping for your reply. Thanks a lot.
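
The two shapes in the message differ only in the second dimension (9 vs. 13), which suggests that the samples carry different numbers of object channels. One possible workaround is to zero-pad every mask to the largest channel count before stacking. The sketch below shows the idea with NumPy arrays instead of the repository's tensors. It is not the fix adopted by the codebase, and `pad_object_channels` is a hypothetical helper.

```python
import numpy as np


def pad_object_channels(masks):
    """Zero-pad (T, K_i, H, W) arrays along the object axis, then stack them."""
    max_k = max(m.shape[1] for m in masks)
    padded = []
    for m in masks:
        extra = max_k - m.shape[1]
        # Pad only at the end of axis 1 (the object channels).
        pad_width = ((0, 0), (0, extra), (0, 0), (0, 0))
        padded.append(np.pad(m, pad_width, mode="constant"))
    return np.stack(padded)


# Shapes taken from the error message above.
a = np.ones((3, 9, 240, 427), dtype=np.float32)
b = np.ones((3, 13, 240, 427), dtype=np.float32)
batch = pad_object_channels([a, b])
print(batch.shape)
```

The per-sample object count still has to be passed along separately so that the loss ignores the padded channels.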

    opened by Sunyankun 1
  • about pre-train code

    First of all, thank you for your training code. However, the accuracy obtained this way is only about 69.4. Do you have the pre-training code? If it is convenient for you, please release it.

    opened by tian961214 2
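  • Can the training batch size be set larger than 1?

    I have also reproduced STM myself and would like to ask you a few questions.

    1. I think STM has a major drawback at test time: processing large-resolution videos requires too much GPU memory, while resizing them to a smaller scale noticeably hurts accuracy. How do you deal with this?

    2. At test time, if the video is long, the number of memory features keeps growing as inference proceeds, which also causes a large overhead. How do you solve this problem?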

    opened by pengqianli 1
Owner
Yuxi Li