
# Spatial Temporal Transformer Network

## Introduction

This repository contains the implementation of the model presented in the following papers:

Spatial Temporal Transformer Network for Skeleton-Based Action Recognition, Chiara Plizzari, Marco Cannici, Matteo Matteucci, arXiv

Spatial Temporal Transformer Network for Skeleton-Based Action Recognition, Chiara Plizzari, Marco Cannici, Matteo Matteucci, Pattern Recognition. ICPR International Workshops and Challenges, 2021, Proceedings

Skeleton-based action recognition via spatial and temporal transformer networks, Chiara Plizzari, Marco Cannici, Matteo Matteucci, Computer Vision and Image Understanding, Volumes 208-209, 2021, 103219, ISSN 1077-3142, CVIU


Visualizations of Spatial Transformer logits

The heatmaps are 25 x 25 matrices, where each row and each column represents a body joint. An element in position (i, j) represents the correlation between joint i and joint j, resulting from self-attention.
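For intuition, the sketch below shows how such a 25 x 25 map can arise from scaled dot-product self-attention over the joints of a single frame. It is an illustration only: the variable names, feature sizes and random inputs are made up and do not reproduce the repository's Spatial Transformer code.

```python
# Illustrative only: scaled dot-product self-attention over 25 body joints,
# producing a 25 x 25 logit/attention matrix like the heatmaps above.
import numpy as np

num_joints, channels, d_k = 25, 64, 32
x = np.random.randn(num_joints, channels)            # one frame: a feature vector per joint

w_q = np.random.randn(channels, d_k)                 # query projection (random stand-in for learned weights)
w_k = np.random.randn(channels, d_k)                 # key projection

q, k = x @ w_q, x @ w_k
logits = q @ k.T / np.sqrt(d_k)                      # (25, 25): entry (i, j) relates joint i to joint j

attention = np.exp(logits - logits.max(axis=-1, keepdims=True))
attention /= attention.sum(axis=-1, keepdims=True)   # row-wise softmax
print(attention.shape)                               # (25, 25)
```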


## Prerequisites

## Run mode

```
python3 main.py
```
**Training**: set in /config/st_gcn/nturgbd/train.yaml:

- Training: True

**Testing**: set in /config/st_gcn/nturgbd/train.yaml:

- Training: False

### Data generation

We performed our experiments on three datasets: **NTU-RGB+D 60**, **NTU-RGB+D 120** and **Kinetics**.

#### NTU-RGB+D

The data can be downloaded from [their website](http://rose1.ntu.edu.sg/datasets/actionrecognition.asp). You need to download the **3D Skeletons** only (5.8G for NTU-60 + 4.5G for NTU-120). Once downloaded, use the following to generate the joint data for NTU-60:
```
python3 ntu_gendata.py
```
If you want to generate the data and preprocess them in one step, run directly:
```
python3 preprocess.py
```
In order to generate bones, you need to run:
```
python3 ntu_gen_bones.py
```
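For intuition (this is not the script's actual code), bone vectors are commonly obtained as the difference between each joint and its parent joint. A minimal sketch under that assumption; the parent list below is a toy chain, not the real NTU-RGB+D hierarchy used by ntu_gen_bones.py:

```python
# Hypothetical sketch: derive bone vectors from joint coordinates.
import numpy as np

num_joints = 25
joints = np.random.randn(3, 300, num_joints, 2)    # one sample, layout (C=3, T, V, M)
parents = list(range(-1, num_joints - 1))          # toy chain: joint i hangs off joint i-1

bones = np.zeros_like(joints)
for j, p in enumerate(parents):
    if p >= 0:                                     # the root joint keeps a zero bone vector
        bones[:, :, j, :] = joints[:, :, j, :] - joints[:, :, p, :]
```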
The joint information and bone information can be merged through:
```
python3 ntu_merge_joint_bones.py
```
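Conceptually, the merge just stacks the two modalities along the channel axis, which is why the second-order configuration described later in this README uses channels: 6. A minimal sketch under an assumed (N, C, T, V, M) layout, with random arrays standing in for the generated joint and bone files:

```python
# Assumed layout (N, C, T, V, M): N samples, C=3 coordinates, T frames, V joints, M bodies.
import numpy as np

joint_data = np.random.randn(8, 3, 300, 25, 2)   # stand-in for the generated joint array
bone_data = np.random.randn(8, 3, 300, 25, 2)    # stand-in for the generated bone array

merged = np.concatenate([joint_data, bone_data], axis=1)
print(merged.shape)                              # (8, 6, 300, 25, 2) -> 6 input channels
```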
For NTU-120, the samples are split between training and testing in a different way, so you need to run:
```
python3 ntu120_gendata.py
```
If you want to generate the NTU-120 data and preprocess them in one step, use:
```
python3 preprocess_120.py
```
#### Kinetics

[Kinetics](https://deepmind.com/research/open-source/open-source-datasets/kinetics/) is a dataset for video action recognition, consisting of raw video data only. The corresponding skeletons were extracted with OpenPose and are available for download at [GoogleDrive](https://drive.google.com/open?id=1SPQ6FmFsjGg3f59uCWfdUWI-5HJM_YhZ) (7.5G). From the raw skeletons, generate the dataset by running:
```
python3 kinetics_gendata.py
```
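After generation, it can be useful to sanity-check the output. The snippet below assumes the common ST-GCN-style layout of a data .npy with shape (N, C, T, V, M) plus a label pickle; the file names and the pickle contents are assumptions, not the scripts' guaranteed output format.

```python
# Assumed output names/format; adjust paths to wherever the generation script wrote its files.
import pickle
import numpy as np

data = np.load("train_data.npy", mmap_mode="r")     # expected shape: (N, C, T, V, M)
with open("train_label.pkl", "rb") as f:
    sample_names, labels = pickle.load(f)           # assumed: (list of sample names, list of labels)

print(data.shape, len(labels))
```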
### Spatial Transformer Stream

The Spatial Transformer implementation corresponds to ST-TR/code/st_gcn/net/spatial_transformer.py. To run the spatial transformer stream (S-TR stream), set in /config/st_gcn/nturgbd/train.yaml:

- attention: True
- tcn_attention: False
- only_attention: True
- all_layers: False

### Temporal Transformer Stream

The Temporal Transformer implementation corresponds to ST-TR/code/st_gcn/net/temporal_transformer.py. To run the temporal transformer stream (T-TR stream), set in /config/st_gcn/nturgbd/train.yaml:

- attention: False
- tcn_attention: True
- only_attention: True
- all_layers: False

### To merge S-TR and T-TR (ST-TR)

The scores resulting from the S-TR stream and the T-TR stream are combined to produce the final ST-TR score by running:
```
python3 ensemble.py
```
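The combination is a late fusion of per-class scores. The sketch below is a hypothetical stand-in for ensemble.py: it assumes each stream dumped its test scores to a pickle (the paths and the pickle contents are placeholders) and simply sums the two score vectors per sample before taking the argmax.

```python
# Hypothetical late-fusion sketch; ensemble.py is the script actually used in this repo.
import pickle
import numpy as np

with open("s_tr_scores.pkl", "rb") as f:        # placeholder: per-sample scores from the S-TR stream
    s_tr = dict(pickle.load(f))
with open("t_tr_scores.pkl", "rb") as f:        # placeholder: per-sample scores from the T-TR stream
    t_tr = dict(pickle.load(f))
with open("test_label.pkl", "rb") as f:         # placeholder: (sample names, ground-truth labels)
    names, labels = pickle.load(f)

correct = sum(
    int((np.array(s_tr[n]) + np.array(t_tr[n])).argmax() == y)
    for n, y in zip(names, labels)
)
print("ST-TR accuracy:", correct / len(labels))
```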
### Adaptive Configuration (AGCN)

In order to run the T-TR-agcn and ST-TR-agcn configurations, please set agcn: True.

### Different ST-TR configurations

Set in /config/st_gcn/nturgbd/train.yaml:

- only_attention: False, to use ST-TR as an augmentation procedure to ST-GCN (refer to Sec. V(E) "Effect of Augmenting Convolution with Self-Attention")
- all_layers: True, to apply ST-TR on all layers; otherwise it is applied from the 4th layer on (refer to Sec. V(D) "Effect of Applying Self-Attention to Feature Extraction")
- both attention: True and tcn_attention: True, to combine SSA and TSA on a single stream (refer to Sec. V(F) "Effect of combining SSA and TSA on one stream")
- more_channels: True, to assign to each head more channels than dk/Nh
- n: used if more_channels is set to True, in order to assign to each head dk*num/Nh channels

To set the block dimensions of the windowed version of the Temporal Transformer:

- dim_block1, dim_block2, dim_block3, respectively, to set the block dimension where the output channels are equal to 64, 128 and 256

### Second order information

Set in /config/st_gcn/nturgbd/train.yaml:

- channels: 6, because the channel dimension holds both the joint coordinates (3) and the bone coordinates (3)
- double_channel: True, since in this configuration the channels in each layer are also doubled

### Pre-trained Models

Pre-trained models for the configurations presented in the paper are provided in the checkpoint_ST-TR folder. Please note that the \*_bones_\*.pth configurations correspond to models trained with joint+bones information, while the others are trained with joints only.
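To use one of the released checkpoints, a minimal PyTorch loading sketch is shown below; the file name is a placeholder (pick any .pth from checkpoint_ST-TR), and the handling of a wrapped state_dict is an assumption about how the checkpoints were saved.

```python
# Placeholder path; replace with an actual file from checkpoint_ST-TR.
import torch

ckpt = torch.load("checkpoint_ST-TR/some_model.pth", map_location="cpu")

# The checkpoint may be a bare state_dict or a dict wrapping one (assumption).
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

# Inspect a few parameter names/shapes before calling model.load_state_dict(state_dict).
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```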
### Citation

Please cite one of the following papers if you use this code for your research:

```
@article{plizzari2021skeleton,
  title={Skeleton-based action recognition via spatial and temporal transformer networks},
  author={Plizzari, Chiara and Cannici, Marco and Matteucci, Matteo},
  journal={Computer Vision and Image Understanding},
  volume={208},
  pages={103219},
  year={2021},
  publisher={Elsevier}
}

@inproceedings{plizzari2021spatial,
  title={Spatial temporal transformer network for skeleton-based action recognition},
  author={Plizzari, Chiara and Cannici, Marco and Matteucci, Matteo},
  booktitle={Pattern Recognition. ICPR International Workshops and Challenges: Virtual Event, January 10--15, 2021, Proceedings, Part III},
  pages={694--701},
  year={2021},
  organization={Springer}
}
```
## Contact

:pushpin: If you have any questions, do not hesitate to contact me at chiara.plizzari@mail.polimi.it. I will be glad to clarify your doubts!

Note: we include LICENSE, LICENSE_1 and LICENSE_2 in this repository since parts of the code have been derived, respectively, from https://github.com/yysijie/st-gcn, https://github.com/leaderj1001/Attention-Augmented-Conv2d and https://github.com/kenziyuliu/Unofficial-DGNN-PyTorch/blob/master/README.md