An unofficial PyTorch implementation of the paper "End-to-End Human-Gaze-Target Detection with Transformers".
We provide a pip requirements file with all the dependencies and recommend installing them inside a conda environment.
# Clone project and submodules
git clone --recursive https://github.com/francescotonini/human-gaze-target-detection-transformer.git
cd human-gaze-target-detection-transformer
# Create conda environment
conda create -n human-gaze-target-detection-transformer python=3.9
conda activate human-gaze-target-detection-transformer
# Install requirements
pip install -r requirements.txt
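As a quick sanity check (not part of the official setup), you can verify that the environment is working by importing PyTorch:
# (optional) Check that the environment is working
python -c "import torch; print(torch.__version__)"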
# (optional) Set up wandb
cp .env.example .env
# Add your wandb token to .env
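As a sketch, the .env file holds your wandb credentials. The exact keys this repo expects come from .env.example, but wandb itself reads the WANDB_API_KEY variable:
# Example .env content (key name is what wandb reads; check .env.example for the keys this repo expects)
WANDB_API_KEY=<your wandb token>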
The code expects the datasets to be placed under the data/ folder. You can change this by modifying the data_dir parameter in the configuration files.
cat <<EOT >> configs/local/default.yaml
# @package _global_
paths:
  data_dir: "{PATH TO DATASETS}"
EOT
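With the default data_dir, a layout such as the following works (the folder names here are an assumption; any names are fine as long as they match the dataset paths you pass to the scripts below):
# (example) Expected dataset folders under the default data_dir
mkdir -p data/gazefollow data/videoattentiontarget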
The implementation requires face annotations ("auxiliary faces", i.e. those not annotated by GazeFollow or VideoAttentionTarget). Therefore, you need to run the following scripts to extract them.
# GazeFollow
python scripts/gazefollow_get_aux_faces.py --dataset_path /path/to/gazefollow --subset train
python scripts/gazefollow_get_aux_faces.py --dataset_path /path/to/gazefollow --subset test
# VideoAttentionTarget
cp data/videoattentiontarget_extended/*.csv /path/to/videoattentiontarget
python scripts/videoattentiontarget_get_aux_faces.py --dataset_path /path/to/videoattentiontarget --subset train
python scripts/videoattentiontarget_get_aux_faces.py --dataset_path /path/to/videoattentiontarget --subset test
We provide configurations to train on GazeFollow and VideoAttentionTarget (see configs/experiment/).
# GazeFollow
python src/train.py experiment=hgttr_gazefollow
# VideoAttentionTarget
python src/train.py experiment=hgttr_videoattentiontarget +model.net_pretraining={URL/PATH TO GAZEFOLLOW WEIGHTS}
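Since the project builds on Hydra and PyTorch Lightning (see below), training options can also be overridden from the command line. As a sketch, assuming the standard lightning-hydra-template config layout (the exact key names may differ in this repo):
# (example) Override trainer options from the command line; key names are an assumption
python src/train.py experiment=hgttr_gazefollow trainer.max_epochs=40 trainer.devices=1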
The same experiment configurations are also used to evaluate the model.
# GazeFollow
python src/eval.py experiment=hgttr_gazefollow ckpt_path={PATH TO CHECKPOINT}
# or
python src/eval.py experiment=hgttr_gazefollow +model.net_pretraining={URL/PATH TO WEIGHTS}
# VideoAttentionTarget
python src/eval.py experiment=hgttr_videoattentiontarget ckpt_path={PATH TO CHECKPOINT}
# or
python src/eval.py experiment=hgttr_videoattentiontarget +model.net_pretraining={URL/PATH TO WEIGHTS}
We provide model weights for GazeFollow at this URL and VideoAttentionTarget at this URL.
This code is based on PyTorch Lightning, Hydra, and the official DETR implementation.
If you use this implementation or our trained weights in your research, please cite us:
@inproceedings{tonini2023objectaware,
  title={Object-aware Gaze Target Detection},
  author={Tonini, Francesco and Dall'Asen, Nicola and Beyan, Cigdem and Ricci, Elisa},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={21860--21869},
  year={2023}
}
and the original paper:
@inproceedings{tu2022end,
  title={End-to-end human-gaze-target detection with transformers},
  author={Tu, Danyang and Min, Xiongkuo and Duan, Huiyu and Guo, Guodong and Zhai, Guangtao and Shen, Wei},
  booktitle={2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month={June},
  year={2022}
}