Temporal Segment Networks (TSN) in PyTorch

Overview

TSN-Pytorch

We have released MMAction, a full-fledged action understanding toolbox based on PyTorch. It includes implementations of TSN as well as other SOTA frameworks for various tasks. The lessons we learned in this repo have been incorporated into MMAction to make it better. We highly recommend switching to it. This repo will remain here for historical reference.

Note: always use git clone --recursive https://github.com/yjxiong/tsn-pytorch to clone this project; otherwise you will not be able to use the Inception-series CNN architectures.

This is a reimplementation of temporal segment networks (TSN) in PyTorch. All settings are kept identical to the original Caffe implementation.

For optical flow extraction and video list generation, you still need to use the original TSN codebase.
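The training and test commands below take these video lists as positional arguments. The following is a minimal sketch for sanity-checking a list before training. It assumes one video per line in the whitespace-separated form `<frame_dir> <num_frames> <label>`. That layout is an assumption here, although the `int(self._data[2])` failure quoted in the comments is consistent with the label being the third field. `parse_video_list` and `VideoEntry` are hypothetical names and are not part of this repo.

```python
from typing import List, NamedTuple, Tuple


class VideoEntry(NamedTuple):
    frame_dir: str
    num_frames: int
    label: int


def parse_video_list(lines) -> Tuple[List[VideoEntry], List[Tuple[int, str]]]:
    """Parse '<frame_dir> <num_frames> <label>' lines (hypothetical helper).

    Returns the well-formed entries plus (line_number, reason) for every
    line that a three-field loader would be expected to fail on.
    """
    entries, errors = [], []
    for lineno, raw in enumerate(lines, start=1):
        fields = raw.split()
        if len(fields) != 3:
            # Blank lines, missing labels, or paths containing spaces land here.
            errors.append((lineno, f"expected 3 fields, got {len(fields)}"))
            continue
        frame_dir, num_frames, label = fields
        try:
            entries.append(VideoEntry(frame_dir, int(num_frames), int(label)))
        except ValueError:
            errors.append((lineno, "num_frames and label must be integers"))
    return entries, errors


# Usage (the file name is a placeholder):
#   with open("ucf101_rgb_train_list.txt") as f:
#       entries, errors = parse_video_list(f)
#   for lineno, reason in errors:
#       print(lineno, reason)
```

The parser splits on any whitespace. As a result, a frame directory containing spaces shows up as a wrong field count instead of silently shifting the label.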

Training

To train a new model, use the main.py script.

The command to reproduce the original TSN experiments for the RGB modality on UCF101 is:

python main.py ucf101 RGB <ucf101_rgb_train_list> <ucf101_rgb_val_list> \
   --arch BNInception --num_segments 3 \
   --gd 20 --lr 0.001 --lr_steps 30 60 --epochs 80 \
   -b 128 -j 8 --dropout 0.8 \
   --snapshot_pref ucf101_bninception_ 
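The `--num_segments` flag sets how many snippets TSN samples from each video. The following is a rough sketch of that sampling under the usual TSN scheme: split the usable frame range into equal segments and draw one random snippet start from each, using 0-based frame indices. `sample_segment_offsets` and its short-video fallback are illustrative and are not this repo's code.

```python
import random


def sample_segment_offsets(num_frames, num_segments, new_length=1, rng=random):
    """Return one 0-based snippet start per segment (hypothetical helper).

    Each snippet covers frames [start, start + new_length - 1], so the usable
    range of starts has num_frames - new_length + 1 positions.
    """
    usable = num_frames - new_length + 1
    avg = usable // num_segments
    if avg > 0:
        # One random start inside each equal-length segment of the usable range.
        return [k * avg + rng.randrange(avg) for k in range(num_segments)]
    if usable > 0:
        # Fewer usable starts than segments: draw with replacement, then sort.
        return sorted(rng.randrange(usable) for _ in range(num_segments))
    # Video shorter than one snippet: fall back to starting at frame 0.
    return [0] * num_segments
```

Subtracting `new_length - 1` from the frame count keeps every multi-frame snippet, such as a stack of optical-flow frames, inside the video. With `new_length=1` the subtraction changes nothing.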

For flow models:

python main.py ucf101 Flow <ucf101_flow_train_list> <ucf101_flow_val_list> \
   --arch BNInception --num_segments 3 \
   --gd 20 --lr 0.001 --lr_steps 190 300 --epochs 340 \
   -b 128 -j 8 --dropout 0.7 \
   --snapshot_pref ucf101_bninception_ --flow_pref flow_  

For RGB-diff models:

python main.py ucf101 RGBDiff <ucf101_rgb_train_list> <ucf101_rgb_val_list> \
   --arch BNInception --num_segments 7 \
   --gd 40 --lr 0.001 --lr_steps 80 160 --epochs 180 \
   -b 128 -j 8 --dropout 0.8 \
   --snapshot_pref ucf101_bninception_ 
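Each command above processes `-b` videos per batch and `--num_segments` snippets per video, so the base CNN sees batch × segments inputs per step. The following is a back-of-the-envelope sketch of the input tensor size alone. It assumes float32 values and 224×224 crops, and `input_batch_bytes` is a hypothetical helper. Activations and gradients usually need far more memory than the input, so treat the result as a lower bound. Flow and RGBDiff snippets stack several frames, which raises `channels` accordingly.

```python
def input_batch_bytes(batch_size, num_segments, channels,
                      height=224, width=224, bytes_per_value=4):
    """Size of one float32 input batch in bytes (activations not included).

    TSN feeds batch_size * num_segments snippets through the base CNN, each
    with `channels` input planes of height x width values.
    """
    snippets = batch_size * num_segments
    return snippets * channels * height * width * bytes_per_value


# RGB settings from the UCF101 command: -b 128, --num_segments 3, 3 channels.
rgb = input_batch_bytes(batch_size=128, num_segments=3, channels=3)
print(f"{rgb / 2**20:.1f} MiB")  # prints "220.5 MiB"
```

Because the total scales linearly with `-b` and `--num_segments`, lowering the batch size is the simplest lever when a single GPU runs out of memory.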

Testing

After training, PyTorch will have saved checkpoints, for example ucf101_bninception_rgb_checkpoint.pth.

Use the following command to test its performance in the standard TSN testing protocol:

python test_models.py ucf101 RGB <ucf101_rgb_val_list> ucf101_bninception_rgb_checkpoint.pth \
   --arch BNInception --save_scores <score_file_name>

Or for flow models:

python test_models.py ucf101 Flow <ucf101_flow_val_list> ucf101_bninception_flow_checkpoint.pth \
   --arch BNInception --save_scores <score_file_name> --flow_pref flow_

Comments
  • dataloader runtime errors

    python main.py ucf101 RGB ucf101_trainlist01new.txt ucf101_testlist01new.txt --gpus 1 --arch BNInception --num_segments 3 --gd 20 --lr 0.001 --lr_steps 30 60 --epochs 80 -b 128 -j 8 --dropout 0.8

    Initializing TSN with base model: BNInception.
    TSN Configurations:
    input_modality: RGB
    num_segments: 3
    new_length: 1
    consensus_module: avg
    dropout_ratio: 0.8

    /home/ytan/miniconda3/lib/python3.6/site-packages/torch/nn/modules/module.py:360: UserWarning: src is not broadcastable to dst, but they have the same number of elements. Falling back to deprecated pointwise behavior.
      own_state[name].copy_(param)
    group: first_conv_weight has 1 params, lr_mult: 1, decay_mult: 1
    group: first_conv_bias has 1 params, lr_mult: 2, decay_mult: 0
    group: normal_weight has 69 params, lr_mult: 1, decay_mult: 1
    group: normal_bias has 69 params, lr_mult: 2, decay_mult: 0
    group: BN scale/shift has 2 params, lr_mult: 1, decay_mult: 0
    Freezing BatchNorm2D except the first one.
    Traceback (most recent call last):
      File "main.py", line 301, in
        main()
      File "main.py", line 124, in main
        train(train_loader, model, criterion, optimizer, epoch)
      File "main.py", line 157, in train
        for i, (input, target) in enumerate(train_loader):
      File "/home/ytan/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 201, in __next__
        return self._process_next_batch(batch)
      File "/home/ytan/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 221, in _process_next_batch
        raise batch.exc_type(batch.exc_msg)
    RuntimeError: Traceback (most recent call last):
      File "/home/ytan/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 40, in _worker_loop
        samples = collate_fn([dataset[i] for i in batch_indices])
      File "/home/ytan/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 109, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/home/ytan/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 109, in
        return [default_collate(samples) for samples in transposed]
      File "/home/ytan/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 89, in default_collate
        storage = batch[0].storage()._new_shared(numel)
      File "/home/ytan/miniconda3/lib/python3.6/site-packages/torch/storage.py", line 113, in _new_shared
        return cls._new_using_fd(size)
    RuntimeError: unable to write to file </torch_476_615100490> at /opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/TH/THAllocator.c:271

    Any suggestions?

    opened by ntuyt 20
  • [solved] RGBDiff - no progress

    Hi,

    I am trying to reproduce your results; however, I found that the current official RGBDiff implementation does not learn. Setup: Python 3.6, PyTorch 0.3.1, CUDA 9.0.

    Interestingly, the training loop looks correct, so could this be a model architecture bug?

    opened by Scitator 18
  • size mismatch

    .local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for BNInception:
        size mismatch for conv1_7x7_s2_bn.running_var: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for conv1_7x7_s2_bn.bias: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for conv1_7x7_s2_bn.weight: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        ...

    opened by zhangtao22 12
  • Normalization in RGBDiff model

    In main.py, no normalization is applied for RGBDiff:

    if args.modality != 'RGBDiff':
          normalize = GroupNormalize(input_mean, input_std)
    else:
          normalize = IdentityTransform()
    

    Is there a reason for this? I also found that test_models.py still applies normalization for RGBDiff, so I got incorrect test results when I first tried it. That problem was solved by switching to IdentityTransform. Which one did you use for your final results, IdentityTransform or GroupNormalize?

    Thank you.

    opened by cmhungsteve 10
  • Error: return int(self._data[2]) IndexError: list index out of range

    I am getting this error. I uploaded my ucf101_rgb_train_list.txt because I suspect something is wrong with it.

    Traceback (most recent call last):
      File "C:\Users\mab73\Desktop\tsn-pytorch-master\main.py", line 301, in
        main()
      File "C:\Users\mab73\Desktop\tsn-pytorch-master\main.py", line 124, in main
        train(train_loader, model, criterion, optimizer, epoch)
      File "C:\Users\mab73\Desktop\tsn-pytorch-master\main.py", line 157, in train
        for i, (input, target) in enumerate(train_loader):
      File "C:\Users\mab73\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 336, in __next__
        return self._process_next_batch(batch)
      File "C:\Users\mab73\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 357, in _process_next_batch
        raise batch.exc_type(batch.exc_msg)
    IndexError: Traceback (most recent call last):
      File "C:\Users\mab73\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 106, in _worker_loop
        samples = collate_fn([dataset[i] for i in batch_indices])
      File "C:\Users\mab73\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 106, in
        samples = collate_fn([dataset[i] for i in batch_indices])
      File "C:\Users\mab73\Desktop\tsn-pytorch-master\dataset.py", line 101, in __getitem__
        return self.get(record, segment_indices)
      File "C:\Users\mab73\Desktop\tsn-pytorch-master\dataset.py", line 115, in get
        return process_data, record.label
      File "C:\Users\mab73\Desktop\tsn-pytorch-master\dataset.py", line 23, in label
        return int(self._data[2])
    IndexError: list index out of range

    opened by Mohamad73 8
  • Problems with the test_models.py

    Hi,

    I have trained the RGB models for all 3 splits but I am facing some issues with the test_models.py program.

    • On line 48, the model is constructed with two arguments (rnn=args.rnn, rnn_mem_size=args.rnn_mem_size) that are not valid.
    • If I remove these arguments and run it again, I get a "list index out of range" error on line 123.

    Here is the error stack trace

    model epoch 80 best prec@1: 83.4522855911
    Traceback (most recent call last):
      File "test_models.py", line 123, in <module>
        rst = eval_video((i, data, label))
      File "test_models.py", line 111, in eval_video
        rst = net(input_var).data.cpu().numpy().copy()
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 224, in __call__
        result = self.forward(*input, **kwargs)
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 56, in forward
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 67, in scatter
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 30, in scatter_kwargs
        inputs = scatter(inputs, target_gpus, dim)
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 25, in scatter
        return scatter_map(inputs)
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 18, in scatter_map
        return tuple(zip(*map(scatter_map, obj)))
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 15, in scatter_map
        return Scatter(target_gpus, dim=dim)(obj)
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/parallel/_functions.py", line 59, in forward
        streams = [_get_stream(device) for device in self.target_gpus]
      File "/export/home/utsav/.local/lib/python2.7/site-packages/torch/nn/parallel/_functions.py", line 85, in _get_stream
        if _streams[device] is None:
    IndexError: list index out of range
    
    opened by utsavgarg 8
  • Running out of memory

    I was trying to run training for UCF101 RGB split 1, but the model seems to run out of memory. I am using a GPU with 16 GB of VRAM. What is the memory requirement?

    opened by utsavgarg 8
  • Network is unreachable

    When I run the training script, I encounter the following error:

    Downloading: "https://yjxiong.blob.core.windows.net/models/bn_inception-9f5701afb96c8044.pth" to /mnt/lustre/ganweihao/.torch/models/bn_inception-9f5701afb96c8044.pth

    Initializing TSN with base model: BNInception.
    TSN Configurations:
    input_modality: RGB
    num_segments: 3
    new_length: 1
    consensus_module: avg
    dropout_ratio: 0.8

    Traceback (most recent call last):
      File "main.py", line 301, in
        main()
      File "main.py", line 35, in main
        consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn)
      File "/mnt/lustre/ganweihao/codes/tsn-pytorch/models.py", line 39, in __init__
        self._prepare_base_model(base_model)
      File "/mnt/lustre/ganweihao/codes/tsn-pytorch/models.py", line 96, in _prepare_base_model
        self.base_model = getattr(tf_model_zoo, base_model)()
      File "/mnt/lustre/ganweihao/codes/tsn-pytorch/tf_model_zoo/bninception/pytorch_load.py", line 35, in __init__
        self.load_state_dict(torch.utils.model_zoo.load_url(weight_url))
      File "/mnt/lustre/ganweihao/anaconda3/envs/python27/lib/python2.7/site-packages/torch/utils/model_zoo.py", line 56, in load_url
        _download_url_to_file(url, cached_file, hash_prefix)
      File "/mnt/lustre/ganweihao/anaconda3/envs/python27/lib/python2.7/site-packages/torch/utils/model_zoo.py", line 61, in _download_url_to_file
        u = urlopen(url)
      File "/mnt/lustre/ganweihao/anaconda3/envs/python27/lib/python2.7/urllib2.py", line 154, in urlopen
        return opener.open(url, data, timeout)
      File "/mnt/lustre/ganweihao/anaconda3/envs/python27/lib/python2.7/urllib2.py", line 429, in open
        response = self._open(req, data)
      File "/mnt/lustre/ganweihao/anaconda3/envs/python27/lib/python2.7/urllib2.py", line 447, in _open
        '_open', req)
      File "/mnt/lustre/ganweihao/anaconda3/envs/python27/lib/python2.7/urllib2.py", line 407, in _call_chain
        result = func(*args)
      File "/mnt/lustre/ganweihao/anaconda3/envs/python27/lib/python2.7/urllib2.py", line 1241, in https_open
        context=self._context)
      File "/mnt/lustre/ganweihao/anaconda3/envs/python27/lib/python2.7/urllib2.py", line 1198, in do_open
        raise URLError(err)
    urllib2.URLError: <urlopen error [Errno 101] Network is unreachable>

    Any idea to solve this? Many thanks.

    opened by gwh0112 6
  • List index out of range in test_models.py

    Hello, I did the training part, but when I try to run the test with test_models.py, I get the error below. Any help is appreciated.

    python test_models.py ucf101 RGB ucf101_rgb_val_list ucf101_bninception__rgb_model_best.pth --arch BNInception --save_scores score_file_name

    Initializing TSN with base model: BNInception.
    TSN Configurations:
    input_modality: RGB
    num_segments: 1
    new_length: 1
    consensus_module: avg
    dropout_ratio: 0.7

    /home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py:514: UserWarning: src is not broadcastable to dst, but they have the same number of elements. Falling back to deprecated pointwise behavior.
      own_state[name].copy_(param)
    model epoch 40 best prec@1: 78.0861750075
    Freezing BatchNorm2D except the first one.
    Traceback (most recent call last):
      File "test_models.py", line 128, in
        rst = eval_video((i, data, label))
      File "test_models.py", line 116, in eval_video
        rst = net(input_var).data.cpu().numpy().copy()
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 357, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 69, in forward
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 80, in scatter
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 38, in scatter_kwargs
        inputs = scatter(inputs, target_gpus, dim) if inputs else []
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 31, in scatter
        return scatter_map(inputs)
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 18, in scatter_map
        return list(zip(*map(scatter_map, obj)))
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 15, in scatter_map
        return Scatter.apply(target_gpus, None, dim, obj)
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/parallel/_functions.py", line 73, in forward
        streams = [_get_stream(device) for device in ctx.target_gpus]
      File "/home/mohamad/anaconda2/lib/python2.7/site-packages/torch/nn/parallel/_functions.py", line 100, in _get_stream
        if _streams[device] is None:
    IndexError: list index out of range

    opened by Mohamad73 5
  • about flow accuracy

    On UCF101 (split 1), my flow model reaches only 81% accuracy, although I used the official code to extract the flow; the paper reports 87%. Can you tell me what is wrong? @yjxiong

    opened by ll490187880 5
  • Error in loading weights for BNInception module

    Hi Xiong, thanks for your great code. I ran into trouble when running test_models.py following your instructions:

    ucf101 RGB data/ucf101_splits kinetics_rgb.pth --arch BNInception --save_scores kinetics_flow_ucf101_scores.txt
    

    where kinetics_rgb.pth is the checkpoint shared by @shuangshuangguo.

    The errors look like:

    RuntimeError: Error(s) in loading state_dict for BNInception:
    ...
    While copying the parameter named "inception_4b_1x1_bn.running_mean", whose dimensions in the model are torch.Size([192]) and whose dimensions in the checkpoint are torch.Size([1, 192]).
    ...
    

    That is, each module's weight in the checkpoint has one more dimension than the model expects.

    I am a beginner with PyTorch. I hope you can answer my question. Thank you!

    opened by pengzhenghao 5
  • f**k-est implementation

    DO U really Think this implementation is good ? this is the fk implementation ever seen, the Logic and Organization of project is just like a shit, fk the repo, rubbish, garbage. ..!!!

    opened by Ontheway361 3
  • Why subtract 'new_length' to calculate 'average_duration' ?

    Hi. Why do you subtract 'new_length' from 'num_frames' and then divide by 'num_segments' to calculate 'average_duration' ? Can we not directly divide 'num_frames' by 'num_segments' to get the 'average_duration' ? Thanks!

    https://github.com/yjxiong/tsn-pytorch/blob/2f0468e049760530b4995bf72c8a512623929c39/dataset.py#L66

    opened by Gateway2745 0
  • Training the TSN model on custom dataset - couldn't implement as mentioned in paper

    I am trying to train an action detection model using TSN on a different, custom dataset. The original TSN paper states that, along with RGB data, RGB differences and optical flow are passed into the model. Moreover, the entire video is divided into snippets, and the model's prediction is based on their results. I couldn't find that kind of implementation here or in the MMAction repository. Please help me understand how I can implement the action detection task on a custom dataset as described in the paper. Thanks.

    opened by anuraganand1838 1
  • video live test?

    Hello, many thanks for sharing your code, which looks fantastic. Could you explain the procedure for testing on live video, such as a webcam feed? Thank you.

    opened by Mkarami3 0
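Two of the test_models.py reports above fail with `IndexError: list index out of range` inside the scatter step of `DataParallel` (`_get_stream`). One plausible cause, not confirmed in those threads, is that `device_ids` lists a GPU index that the machine does not have. The following is a minimal sketch under that assumption. `usable_device_ids` is a hypothetical helper that drops non-existent ids before the model is wrapped.

```python
def usable_device_ids(requested, num_visible_gpus):
    """Drop GPU ids that do not exist; fall back to [0] if any GPU is visible.

    Hypothetical helper. With PyTorch, num_visible_gpus would typically come
    from torch.cuda.device_count(), and the result would be passed as
    torch.nn.DataParallel(net.cuda(ids[0]), device_ids=ids).
    """
    ids = [d for d in requested if 0 <= d < num_visible_gpus]
    if not ids and num_visible_gpus > 0:
        ids = [0]
    return ids
```

Filtering the ids up front turns an obscure scatter-time error into a visible configuration choice.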
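In the "Network is unreachable" report, the log shows `torch.utils.model_zoo.load_url` downloading `bn_inception-9f5701afb96c8044.pth` into a `.torch/models/` directory. In PyTorch versions of that era, `load_url` skips the download when a file with the same name already exists in its model directory. One workaround for an offline machine is therefore to fetch the file elsewhere and copy it there. The sketch below only computes the expected location. `cached_weight_path` is a hypothetical helper, and the actual cache directory depends on the PyTorch version and its environment settings.

```python
import os
from urllib.parse import urlparse


def cached_weight_path(url, model_dir):
    """Path where a load_url-style cache would look for the file behind url.

    The file name is the last component of the URL path; model_dir is the
    cache directory printed in the "Downloading: ... to ..." log line.
    """
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)
```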
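The two size-mismatch reports above involve BatchNorm parameters that are stored as `[1, C]` while the model expects `[C]`. Older logs in this thread show PyTorch warning "src is not broadcastable to dst, but they have the same number of elements. Falling back to deprecated pointwise behavior". That suggests older versions copied such tensors anyway and newer versions reject them. This reading is an inference and was not confirmed by the maintainers. The following is a minimal sketch, assuming that reshaping `(1, *shape)` to `shape` is the only change needed. `leading_one_mismatches` is hypothetical and lists the parameters to fix; the commented usage shows where a `.view` would go.

```python
def leading_one_mismatches(ckpt_shapes, model_shapes):
    """Return {name: model_shape} for checkpoint entries shaped (1, *model_shape).

    Both arguments map parameter names to shapes (tuples or torch.Size).
    Entries missing from the model or mismatched in any other way are left
    alone, so load_state_dict still reports them.
    """
    fixes = {}
    for name, shape in ckpt_shapes.items():
        target = model_shapes.get(name)
        if target is None:
            continue
        shape, target = tuple(shape), tuple(target)
        if shape == (1,) + target:
            fixes[name] = target
    return fixes


# Usage sketch with PyTorch (not executed here; `state` is the loaded dict and
# `model` is the network whose load_state_dict raised the size mismatch):
#   fixes = leading_one_mismatches({k: v.shape for k, v in state.items()},
#                                  {k: v.shape for k, v in model.state_dict().items()})
#   for name, target in fixes.items():
#       state[name] = state[name].view(target)
#   model.load_state_dict(state)
```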