Multi-Stage Spatial-Temporal Convolutional Neural Network (MS-GCN)

This code implements MS-GCN, the skeleton-based action segmentation model from "Automated freezing of gait assessment with marker-based motion capture and multi-stage spatial-temporal graph convolutional neural networks" and "Skeleton-based action segmentation with multi-stage spatial-temporal graph convolutional neural networks", arXiv 2022 (in review).

It was originally developed for freezing of gait (FOG) assessment on a proprietary dataset. Recently, we have also achieved high skeleton-based action segmentation performance on public datasets, e.g. HuGaDB, LARa, PKU-MMD v2, TUG.

Requirements

Tested on Ubuntu 16.04 with PyTorch 1.10.1. Models were trained on an NVIDIA Tesla K80.

The c3d data preparation script requires Biomechanical-Toolkit. For installation instructions, please refer to the following issue.

Content

  • data_prep/ -- Data preparation scripts.
  • main.py -- Main script. I suggest working with this interactively with an IDE. Please provide the dataset and train/predict arguments, e.g. --dataset=fog_example --action=train.
  • batch_gen.py -- Batch loader.
  • label_eval.py -- Compute metrics and save prediction results.
  • model.py -- Train/predict script.
  • models/ -- Location for saving the trained models.
  • models/ms_gcn.py -- The MS-GCN model.
  • models/net_utils/ -- Scripts to partition the graph for the various datasets. For more information about the partitioning, please refer to the section Graph representations. For more information about spatial-temporal graphs, please refer to ST-GCN.
  • data/ -- Location for the processed datasets. For more information, please refer to the 'FOG' example.
  • data/signals -- Scripts for computing the feature representations. Used for datasets that provide spatial features per joint, e.g. FOG, TUG, and PKU-MMD v2. For more information, please refer to the section Graph representations.
  • results/ -- Location for saving the results.
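
The two arguments mentioned for main.py can be handled with a small argparse parser. The sketch below is an assumption about how such a parser might look, not the repository's actual code: build_parser is a hypothetical helper, and only the --dataset and --action names and their example values come from the list above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser for the two arguments named in the README."""
    parser = argparse.ArgumentParser(description="MS-GCN train/predict (sketch)")
    # Dataset name, e.g. fog_example or hugadb (values from the README examples).
    parser.add_argument("--dataset", default="fog_example")
    # Whether to train a model or generate per-sample predictions.
    parser.add_argument("--action", choices=["train", "predict"], default="train")
    return parser


# Passing the arguments as a list is convenient in an interactive IDE session.
args = build_parser().parse_args(["--dataset=fog_example", "--action=train"])
print(args.dataset, args.action)
```

Passing an explicit list to parse_args avoids relying on the command line, which fits the suggestion to work with the script interactively.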

Data

After processing the dataset (scripts are dataset specific), each processed dataset should be placed in the data folder. We provide an example for a motion capture dataset that is in c3d format. For this particular example, we extract 9 joints in 3D:

  • data_prep/read_frame.py -- Import the joints and action labels from the c3d and save both in a separate csv.
  • data_prep/gen_data/ -- Import the csv, construct the input, and save to npy for training. For more information about the input and label shape, please refer to the section Problem statement.

Please refer to the example in data/example/ for more information on how to structure the files for training/prediction.
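
As a rough illustration of the "construct the input" step, the sketch below reshapes a table of per-frame joint coordinates into a 4D array. The (C, T, V, M) layout (channels, frames, joints, subjects) is an assumption taken from the shape comments quoted in the issue reports further down this page; frames_to_input and the x, y, z column order are hypothetical and may differ from the actual scripts.

```python
import numpy as np


def frames_to_input(frames: np.ndarray, num_joints: int = 9, num_coords: int = 3) -> np.ndarray:
    """Reshape a (T, num_joints * num_coords) table into a (C, T, V, M) array with M = 1."""
    num_frames = frames.shape[0]
    # (T, V*C) -> (T, V, C): assumes each row stores x, y, z per joint, joint after joint.
    per_joint = frames.reshape(num_frames, num_joints, num_coords)
    # (T, V, C) -> (C, T, V), then add a trailing axis for the single subject.
    return per_joint.transpose(2, 0, 1)[..., np.newaxis]


# Synthetic stand-in for a csv with 100 frames of 9 joints in 3D.
frames = np.arange(100 * 27, dtype=np.float32).reshape(100, 27)
sample = frames_to_input(frames)
print(sample.shape)  # (3, 100, 9, 1)
```

The resulting array could then be written with np.save for training; the file naming and folder layout should follow the example in the data folder.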

Pre-trained models

Pre-trained models are provided for HuGaDB, PKU-MMD, and LARa. To reproduce the results from the paper:

  • Download each dataset from its respective repository.
  • See the "Data" section for more information on how to prepare the datasets.
  • Place the pre-trained models in models/, e.g. models/hugadb.
  • Ensure that the correct graph representation is chosen in ms_gcn.
  • Comment out features = get_features(features) in model (only for lara and hugadb).
  • Specify the correct sampling rate, e.g. downsampling factor of 4 for lara.
  • Run main to generate the per-sample predictions with proper arguments, e.g. --dataset=hugadb --action=predict.
  • Run label_eval with proper arguments, e.g. --dataset=hugadb.
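
The sampling rate step can be pictured as keeping every n-th frame of both the input and its labels. This is a minimal sketch under the assumption that the input is laid out as (C, T, V, M) with one label per frame; downsample is a hypothetical helper, and the repository may implement this step differently.

```python
import numpy as np


def downsample(features: np.ndarray, labels: np.ndarray, sample_rate: int = 4):
    """Keep every sample_rate-th frame of a (C, T, V, M) input and its (T,) labels."""
    return features[:, ::sample_rate], labels[::sample_rate]


# Synthetic example: 1000 frames, 9 joints in 3D, one subject.
features = np.zeros((3, 1000, 9, 1), dtype=np.float32)
labels = np.arange(1000)
features_ds, labels_ds = downsample(features, labels, sample_rate=4)
print(features_ds.shape, labels_ds.shape)  # (3, 250, 9, 1) (250,)
```

Using the same stride for features and labels keeps them aligned frame by frame, which matters when the predictions are later compared against the ground truth.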

Acknowledgements

The MS-GCN model and code are heavily based on ST-GCN and MS-TCN. We thank the authors for publicly releasing their code.

License

MIT

Comments
  • Hyperparameter Setting

    Hi, when I use the HuGaDB dataset with bz=8 and sample_rate=4, I cannot reach the results of the paper; my results are much worse. May I ask how you set these parameters? Looking forward to your reply. Thank you.

    opened by Forgetmmmm 12
  • Dimension mismatch

    Hello, thanks for sharing the code. When I use the LARa dataset, I encounter the following problem:

        Training subject: 1
        Traceback (most recent call last):
          File "/media/hz/A2/cmz/code/MS-GCN/main.py", line 68, in <module>
            trainer.train(model_dir, batch_gen, num_epochs=num_epochs, batch_size=bz, learning_rate=lr, device=device)
          File "/media/hz/A2/cmz/code/MS-GCN/model.py", line 36, in train
            batch_input, batch_target, mask, weight = batch_gen.next_batch(batch_size)
          File "/media/hz/A2/cmz/code/MS-GCN/batch_gen.py", line 69, in next_batch
            batch_input_tensor[i, :, :np.shape(batch_input[i])[1], :, :] = torch.from_numpy(batch_input[i])
        RuntimeError: The expanded size of the tensor (9) must match the existing size (19) at non-singleton dimension 2. Target sizes: [6, 24000, 9, 1]. Tensor sizes: [12, 24000, 19, 1]

    Do you know how to fix it? Looking forward to your reply.

    opened by Forgetmmmm 4
  • Lara Dataset

    Hi Benjamin, me here again :) Another question, haha! According to the paper "LARa: Creating a Dataset for Human Activity Recognition in Logistics Using Semantic Attributes", LARa has 39 joints. However, the norm_data_csv in OMoCap.zip (downloaded from your LARa link) has 132/6 = 22 joints, and you report 17 joints. This confuses me a lot. Did I make a mistake somewhere?

    opened by hitcbw 2
  • About read_frame.py files

    Hello, I would like to ask whether the input of read_frame.py is a video and the output is the joint data of the people in the video. Also, what does the first dimension (6) of the downloaded feature file represent?

    opened by Forgetmmmm 1
  • action segmentation in PKU_MMD

    Hi Benjamin, thanks for your exciting work. With your help I have adapted the code to the PKU_MMD dataset, thank you very much. However, I noticed that some actions in PKU_MMD overlap, e.g. several frames have two labels. How does MS-GCN handle this situation? As far as I know, MS-TCN is a single-label classifier.

    opened by hitcbw 1
  • function getfeatures() error

    Hi, many thanks for sharing the code. I ran into trouble when executing the function getfeatures() -> get_relative_coordinates. My understanding is that this function calculates the displacements relative to 4 root nodes (joint indices 4, 8, 12, 16, as defined for NTU_RGBD).

    The error appears when executing coords_diff = (sample.transpose((2, 0, 1, 3)) - ref_loc).transpose((1, 2, 0, 3)): NumPy cannot broadcast shapes (25,3,1838,1) and (3,1838,4,1).

    The code then continues with # Shape: 4*C, t, V, M followed by rel_coords = np.vstack(rel_coords), so I modified the code as:

        for ref in references:
            ref_loc = sample[:, :, ref, :]
            coords_diff = (sample.transpose((2, 0, 1, 3)) - ref_loc).transpose((1, 2, 0, 3))
            rel_coords.append(coords_diff)
    

    The code runs correctly then. However,

    # Shape: 4*C, t, V, M 
    rel_coords = np.vstack(rel_coords)
    # Shape: C, T, V, M
    final_sample[:, start:end, :, :] = rel_coords
    

    cannot run correctly, because an array of shape [4*C, t, V, M] cannot be assigned to one of shape [C, t, V, M]. Could you help me solve this problem? Thanks.

    opened by hitcbw 1
  • How to get the datasets?

    Hi,

    Many thanks for sharing the code and I really like this interesting work. However, I can't find a link to download the corresponding dataset. How do I access the dataset?

    Best regards

    opened by lyhisme 1
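
The shape problem described in the get_relative_coordinates report above can be reproduced and avoided with plain NumPy. The sketch below is one possible reading of the intended computation, not the repository's implementation: relative_coordinates is a hypothetical helper, the reference joints 4, 8, 12, 16 and the (3, 1838, 25, 1) shape are taken from the report, and the destination array is assumed to need 4*C channels.

```python
import numpy as np


def relative_coordinates(sample: np.ndarray, references=(4, 8, 12, 16)) -> np.ndarray:
    """Stack joint coordinates relative to each reference joint along the channel axis.

    sample has shape (C, T, V, M); the result has shape (len(references) * C, T, V, M).
    """
    rel_coords = []
    for ref in references:
        # Slicing keeps the joint axis, so (C, T, 1, M) broadcasts against (C, T, V, M).
        ref_loc = sample[:, :, ref:ref + 1, :]
        rel_coords.append(sample - ref_loc)
    return np.concatenate(rel_coords, axis=0)


rng = np.random.default_rng(0)
sample = rng.random((3, 1838, 25, 1))
rel = relative_coordinates(sample)

# The destination is allocated with 4 * C channels so the assignment fits.
final_sample = np.zeros((rel.shape[0],) + sample.shape[1:])
final_sample[:, 0:sample.shape[1], :, :] = rel
print(final_sample.shape)  # (12, 1838, 25, 1)
```

Indexing one reference joint at a time sidesteps the broadcast error, and sizing the destination from the stacked result avoids the mismatch between 4*C and C channels.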
Owner
Benjamin Filtjens
PhD Student working towards at-home freezing of gait detection https://orcid.org/0000-0003-2609-6883