SignCL: Improving Gloss-free Sign Language Translation by Reducing Representation Density (NeurIPS'24)

This is the official code repository for the paper 'Improving Gloss-free Sign Language Translation by Reducing Representation Density'.

Overview

SignCL is a PyTorch module designed to enhance sign language translation models by encouraging them to learn more discriminative feature representations. Through contrastive learning, it pulls the visual representations of sign gestures with identical semantics closer together while pushing those with different semantics farther apart. The module can be integrated into both the pretraining and finetuning stages of a sign language translation model. Experiments demonstrate that SignCL significantly reduces representation density and improves performance across various translation frameworks.
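For intuition, the sketch below shows what a temporal contrastive loss of this flavor can look like. It is a minimal illustration, not the repository's SignCL implementation: it assumes that neighbouring frames carry the same sign's semantics (positives) while frames at least `margin` steps apart belong to different signs (negatives), and `ToySignCL` and its internals are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySignCL(nn.Module):
    """A toy temporal contrastive loss in the spirit of SignCL (illustration only)."""

    def __init__(self, max_distance: float = 32.0):
        super().__init__()
        self.max_distance = max_distance  # cap on how far negatives are pushed apart

    def forward(self, frames_feature: torch.Tensor, margin: int = 10) -> torch.Tensor:
        # frames_feature: (batch, num_frames, dim); assumes num_frames > margin + 1
        B, T, _ = frames_feature.shape
        batch = torch.arange(B, device=frames_feature.device)
        idx = torch.randint(0, T - 1, (B,), device=frames_feature.device)

        anchor = frames_feature[batch, idx]
        positive = frames_feature[batch, idx + 1]             # temporal neighbour: likely the same sign
        negative = frames_feature[batch, (idx + margin) % T]  # distant frame: likely a different sign

        pull = F.pairwise_distance(anchor, positive)                              # attract positives
        push = F.relu(self.max_distance - F.pairwise_distance(anchor, negative))  # repel negatives, capped at max_distance
        return (pull + push).mean()
```

The actual SignCL criterion additionally controls sampling via `pos_samples` and `neg_samples` (see the usage example below).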

Representation Density and Performance Drop

We consistently observe a negative relationship between representation density and translation performance. In our experiments, a 26% increase in representation density corresponds to a 39% drop in BLEU score.
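Here, representation density refers to how tightly the features of semantically different sign gestures crowd together in feature space. As a rough, hypothetical proxy (not the paper's exact metric), the mean pairwise cosine similarity of pooled gesture features gives a feel for this effect:

```python
import torch
import torch.nn.functional as F


def mean_pairwise_cosine(features: torch.Tensor) -> float:
    """Rough density proxy: higher mean cosine similarity means denser features.

    features: (num_samples, dim) pooled gesture features. Illustrative only.
    """
    x = F.normalize(features, dim=-1)                           # unit-normalize each feature
    sim = x @ x.T                                               # cosine similarity matrix
    mask = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)  # drop self-similarity
    return sim[mask].mean().item()
```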


Installation

To use SignCL, ensure you have PyTorch installed (`pip install torch`). The `SignCL` module itself is provided in this repository and is imported as `from sign_cl import SignCL`.

Usage

Here's a step-by-step guide to integrating SignCL into your sign language translation model.

```python
# 1. Instantiate the contrastive loss criterion
cl_criterion = SignCL()

# 2. Encode the input frames
frames_feature = model.encoder(src_input)

# 3. Set the margin adaptively from the frame/text length ratio
margin = min(20, max(10, int(num_frames // text_length * 2.3)))

# 4. Compute the contrastive loss
cl_loss = cl_criterion(frames_feature, margin=margin)

# 5. Add it to the original training objective
total_loss = lambda_ * cl_loss + original_loss
```
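As a quick sanity check of step 3, here is the margin computation worked through with illustrative values (120 frames and 15 text tokens are assumptions for the example, not values from the paper):

```python
# Worked example of the adaptive margin (illustrative values):
num_frames, text_length = 120, 15
margin = min(20, max(10, int(num_frames // text_length * 2.3)))
# 120 // 15 = 8; 8 * 2.3 = 18.4; int(18.4) = 18; clamped to [10, 20] -> 18
print(margin)  # 18
```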

A. Usage example in your framework:

```python
import torch
import torch.nn as nn
import torch.optim as optim

from sign_cl import SignCL

# Define the contrastive loss criterion
cl_criterion = SignCL(max_distance=32.0, pos_samples=2, neg_samples=4)

# Assume you have a model, data loader, and other necessary components
model = YourSignLanguageModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example training loop
for epoch in range(num_epochs):
    for batch in data_loader:
        src_input, text_input = batch['src'], batch['text']

        # Forward pass
        frames_feature = model.encoder(src_input)
        num_frames = frames_feature.size(1)
        text_length = len(text_input)  # assuming text_input is the corresponding text

        margin = min(20, max(10, int(num_frames // text_length * 2.3)) * 2)
        cl_loss = cl_criterion(frames_feature, margin=margin)

        original_loss = ...  # compute your original loss here
        lambda_ = 0.01       # weight for the contrastive loss, adjust as necessary
        total_loss = lambda_ * cl_loss + original_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item()}")
```

B. Usage example for GFSLT-VLP:

This example code was modified from the GFSLT-VLP GitHub repository. Please refer to their homepage to set up the environment and dataset.

To execute, use the following command:

```bash
bash examples/scripts.sh
```

This script runs the training and evaluation process, demonstrating how to integrate the SignCL loss into the GFSLT-VLP framework. We have also included our self-reproduced results and log.txt on the CSL-Daily dataset (see link).

Table 1: Enhancing GFSLT-VLP by reducing representation density on the CSL-Daily test set.

Citation

If you find this work useful for your research, please cite the following paper:

@inproceedings{ye2024improving,
  title={Improving Gloss-free Sign Language Translation by Reducing Representation Density},
  author={Ye, Jinhui and Wang, Xing and Jiao, Wenxiang and Liang, Junwei and Xiong, Hui},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}