End-to-End Human-Gaze-Target Detection with Transformers

An unofficial PyTorch implementation of the paper "End-to-End Human-Gaze-Target Detection with Transformers".

Prerequisites

Environment and dependencies

A pip requirements file lists all the dependencies. We recommend installing them inside a conda environment.

# Clone project and submodules
git clone --recursive https://github.com/francescotonini/human-gaze-target-detection-transformer.git
cd human-gaze-target-detection-transformer

# Create conda environment
conda create -n human-gaze-target-detection-transformer python=3.9
conda activate human-gaze-target-detection-transformer

# Install requirements
pip install -r requirements.txt

(Optional) Set up wandb

cp .env.example .env

# Add token to .env

Dataset preprocessing

The code expects the datasets to be placed under the data/ folder. To use a different location, override the paths.data_dir parameter in a local configuration file:

cat <<EOT >> configs/local/default.yaml
# @package _global_

paths:
  data_dir: "{PATH TO DATASETS}"
EOT
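
The heredoc above appends to the file, so running it twice leaves two copies of the override. As a minimal sketch, assuming the local file should contain only the header and the paths.data_dir key shown above, the same file can be written from Python so that repeated runs overwrite rather than append. The function name write_local_config is hypothetical and not part of the repository.

```python
from pathlib import Path


def write_local_config(data_dir, config_path="configs/local/default.yaml"):
    """Write the local override shown above, replacing any previous copy."""
    path = Path(config_path)
    # Create configs/local/ if it does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    content = (
        "# @package _global_\n"
        "\n"
        "paths:\n"
        f'  data_dir: "{data_dir}"\n'
    )
    # write_text overwrites, so repeated calls never duplicate the keys.
    path.write_text(content)
    return path
```

Calling write_local_config("/path/to/datasets") from the repository root produces the same file content as a single run of the heredoc.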

The implementation requires face annotations for "auxiliary faces", i.e. the faces not annotated by GazeFollow or VideoAttentionTarget. Run the following scripts to extract them.

# GazeFollow
python scripts/gazefollow_get_aux_faces.py --dataset_path /path/to/gazefollow --subset train
python scripts/gazefollow_get_aux_faces.py --dataset_path /path/to/gazefollow --subset test

# VideoAttentionTarget
cp data/videoattentiontarget_extended/*.csv /path/to/videoattentiontarget

python scripts/videoattentiontarget_get_aux_faces.py --dataset_path /path/to/videoattentiontarget --subset train
python scripts/videoattentiontarget_get_aux_faces.py --dataset_path /path/to/videoattentiontarget --subset test
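
The four extraction runs above differ only in the script, the dataset path, and the subset. The sketch below builds them as argument lists, assuming the script names and flags shown above. The helper aux_face_commands and its parameters are hypothetical; it only constructs the commands and does not run them.

```python
def aux_face_commands(gazefollow_path, videoattentiontarget_path):
    """Build the auxiliary-face extraction commands for both datasets."""
    datasets = [
        ("scripts/gazefollow_get_aux_faces.py", gazefollow_path),
        ("scripts/videoattentiontarget_get_aux_faces.py", videoattentiontarget_path),
    ]
    commands = []
    for script, dataset_path in datasets:
        # Each dataset is processed once per subset, as in the commands above.
        for subset in ("train", "test"):
            commands.append(
                ["python", script, "--dataset_path", str(dataset_path), "--subset", subset]
            )
    return commands
```

Each list can be passed to subprocess.run(cmd, check=True) from the repository root. For VideoAttentionTarget, the CSV copy step shown above still has to be done first.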

Training

We provide configurations for training on GazeFollow and VideoAttentionTarget (see configs/experiment/).

# GazeFollow
python src/train.py experiment=hgttr_gazefollow

# VideoAttentionTarget
python src/train.py experiment=hgttr_videoattentiontarget +model.net_pretraining={URL/PATH TO GAZEFOLLOW WEIGHTS}

Evaluation

The same experiment configurations are used for evaluation. Pass either a checkpoint via ckpt_path or model weights via +model.net_pretraining.

# GazeFollow
python src/eval.py experiment=hgttr_gazefollow ckpt_path={PATH TO CHECKPOINT}
# or
python src/eval.py experiment=hgttr_gazefollow +model.net_pretraining={URL/PATH TO WEIGHTS}

# VideoAttentionTarget
python src/eval.py experiment=hgttr_videoattentiontarget ckpt_path={PATH TO CHECKPOINT}
# or
python src/eval.py experiment=hgttr_videoattentiontarget +model.net_pretraining={URL/PATH TO WEIGHTS}
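
Each evaluation command above takes one of two weight sources: ckpt_path or +model.net_pretraining. As a sketch under that reading, the hypothetical helper below (not part of the repository) builds the command and rejects calls that pass both sources or neither. Whether src/eval.py itself accepts both at once is not stated here.

```python
def eval_command(experiment, ckpt_path=None, weights=None):
    """Build an evaluation command from exactly one weight source."""
    if (ckpt_path is None) == (weights is None):
        raise ValueError("pass exactly one of ckpt_path or weights")
    cmd = ["python", "src/eval.py", f"experiment={experiment}"]
    if ckpt_path is not None:
        cmd.append(f"ckpt_path={ckpt_path}")
    else:
        # The leading "+" is Hydra syntax for adding a key that the
        # composed config does not already define.
        cmd.append(f"+model.net_pretraining={weights}")
    return cmd
```

For example, eval_command("hgttr_gazefollow", ckpt_path="{PATH TO CHECKPOINT}") yields the first GazeFollow command listed above as an argument list.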

Checkpoints

We provide model weights for GazeFollow at this URL and VideoAttentionTarget at this URL.

Acknowledgements

This code is based on PyTorch Lightning, Hydra, and the official DETR implementation.

Cite us

If you use this code implementation or our trained weights in your research, please cite us:

@inproceedings{tonini2023objectaware,
  title={Object-aware Gaze Target Detection},
  author={Tonini, Francesco and Dall'Asen, Nicola and Beyan, Cigdem and Ricci, Elisa},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={21860--21869},
  year={2023}
}

and the original paper:

@inproceedings{tu2022end,
  title={End-to-end human-gaze-target detection with transformers},
  author={Tu, Danyang and Min, Xiongkuo and Duan, Huiyu and Guo, Guodong and Zhai, Guangtao and Shen, Wei},
  booktitle={2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month={June},
  year={2022},
}