
DCFormer

This repository contains the code in both PyTorch and Jax for our paper.

Improving Transformers with Dynamically Composable Multi-Head Attention\
Da Xiao, Qingye Meng, Shengping Li, Xingyuan Yuan\
ICML 2024 (oral)


About

We propose Dynamically Composable Multi-Head Attention (DCMHA), a parameter- and computation-efficient attention architecture that tackles the shortcomings of Multi-Head Attention (MHA) and increases the expressive power of the model by dynamically composing attention heads. At the core of DCMHA is a Compose function that transforms the attention score and weight matrices in an input-dependent way. DCMHA can be used as a drop-in replacement for MHA in any transformer architecture to obtain the corresponding DCFormer.
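
For intuition, the snippet below is a minimal PyTorch sketch of head composition: per-head attention scores are mixed across heads with an input-dependent matrix before softmax. It is illustrative only; the module name, the way the mixing matrix is produced, and the fact that only pre-softmax scores are composed are simplifications of the paper's Compose function (which also transforms the post-softmax attention weights and uses a more efficient parameterization), not the code in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ComposedAttentionSketch(nn.Module):
    """Toy illustration of dynamically composing attention heads."""

    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0
        self.h, self.d = num_heads, dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)
        # Maps each (head-averaged) query vector to an h x h head-mixing matrix.
        self.compose_proj = nn.Linear(self.d, num_heads * num_heads, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.h, self.d).transpose(1, 2)  # (B, h, T, d)
        k = k.view(B, T, self.h, self.d).transpose(1, 2)
        v = v.view(B, T, self.h, self.d).transpose(1, 2)

        scores = q @ k.transpose(-2, -1) / self.d ** 0.5   # (B, h, T, T)

        # Input-dependent composition: every query position gets its own
        # h x h matrix that mixes the scores of all heads; the residual
        # form perturbs, rather than replaces, the original scores.
        mix = self.compose_proj(q.mean(dim=1)).view(B, T, self.h, self.h)
        scores = scores + torch.einsum('bqij,bjqk->biqk', mix, scores)

        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        weights = F.softmax(scores.masked_fill(~causal, float('-inf')), dim=-1)

        y = (weights @ v).transpose(1, 2).reshape(B, T, C)
        return self.out(y)

# Usage: drop-in for a standard attention block of the same shape.
# attn = ComposedAttentionSketch(dim=256, num_heads=8)
# y = attn(torch.randn(2, 16, 256))  # -> (2, 16, 256)
```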

In practice, we train DCFormer on TPU for efficiency and then run inference on GPU for convenience, so we open-source the Jax training code and the PyTorch inference code in two separate folders:

- Jax
- PyTorch

Synthetic tasks

The synthetic tasks dataset used in the paper is located at data/synthetic_dataset.jsonl. Each line in the file contains one in-context learning sample; the words in brackets [] are compared against the model's prediction to calculate accuracy. For example:

```
< Elizabeth has jeans. Steven has strawberries. Steven has a fox. >. Steven does not have a kind of clothing
< Karen has a taxi. Karen has a pig. Paul has a cocktail. >. Karen does not have a kind of drink
< Ruth has a sweater. Steven has a taxi. Ruth has a handgun. >. Ruth does not have a kind of [ vehicle ]
...
```
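
For reference, here is a small sketch of how the file could be loaded and scored, assuming each JSON line stores the sample text under a "text" key (the real schema may differ) and with `model_generate` standing in for whatever generation function you use:

```python
import json
import re

def load_synthetic_samples(path="data/synthetic_dataset.jsonl"):
    """Yield (prompt, answer) pairs from the synthetic tasks dataset.

    Assumes each JSON line holds the sample text under a "text" key
    (hypothetical field name); the bracketed span [ ... ] is the
    answer to compare against the model's continuation.
    """
    with open(path) as f:
        for line in f:
            record = json.loads(line)
            text = record["text"] if isinstance(record, dict) else record
            match = re.search(r"\[\s*(.+?)\s*\]", text)
            if match is None:
                continue
            yield text[:match.start()].rstrip(), match.group(1)

# Accuracy: fraction of samples whose completion starts with the bracketed word.
# correct = sum(model_generate(p).strip().startswith(a)  # model_generate is hypothetical
#               for p, a in load_synthetic_samples())
```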