We propose an alternative to classical attention that scales linearly with the number of tokens and is based on high-order moments.
The HoMM scheme is as follows: given a query token $x_q$ and a set of context tokens $x_c$, we first use a projection $ho$ to map each context token to a high-dimensional space, where the high-order moments are computed recursively (by chunking, taking element-wise products of the chunks, and averaging over the tokens). $x_q$ is projected into the same high-dimensional space with a projection $s$. The element-wise product of the two corresponds to $x_q$ selecting the information it needs from the high-order moments of $x_c$. The result is then projected back to the same space as $x_q$ and added to the original tokens via a residual connection.
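To make the computation concrete, here is a minimal PyTorch sketch of such a layer. It is only an illustration of the scheme described above, not the repo's actual implementation: the class and parameter names (`HoMMLayer`, `ho`, `s`, `out`) are made up for this example.

```python
# Minimal sketch of a HoMM-style layer (illustrative only, not the repo's code).
import torch
import torch.nn as nn

class HoMMLayer(nn.Module):
    def __init__(self, dim, order=2, order_expand=4):
        super().__init__()
        self.order = order
        self.hidden = dim * order_expand
        self.ho = nn.Linear(dim, order * self.hidden)  # maps context tokens to the high-dim space
        self.s = nn.Linear(dim, self.hidden)           # maps the query to the same space (selection)
        self.out = nn.Linear(self.hidden, dim)         # projects back to the token dimension

    def forward(self, x_q, x_c):
        # x_q: (batch, n_queries, dim), x_c: (batch, n_context, dim)
        chunks = self.ho(x_c).chunk(self.order, dim=-1)  # `order` chunks of size hidden
        prod = chunks[0]
        for c in chunks[1:]:
            prod = prod * c                              # element-wise products build the high-order terms
        moments = prod.mean(dim=1, keepdim=True)         # average over context tokens -> linear in tokens
        selected = self.s(x_q) * moments                 # the query selects what it needs from the moments
        return x_q + self.out(selected)                  # residual connection

# usage: cost grows linearly with the number of context tokens
x_q = torch.randn(2, 16, 128)
x_c = torch.randn(2, 196, 128)
y = HoMMLayer(128)(x_q, x_c)  # shape (2, 16, 128)
```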
/!\ Help welcome: DM me on Twitter (https://twitter.com/david_picard), open an issue, or email me!
Easy targets if you want to contribute
This repo uses Hydra for handling configs; look at src/configs to edit them. Here are two example training runs, one for supervised classification on ImageNet and one for MAE pre-training:
```bash
python src/train.py data.dataset_builder.data_dir=path_to_imagenet seed=3407 model.network.dim=128 data.size=224 model.network.kernel_size=32 model.network.nb_layers=12 model.network.order=2 model.network.order_expand=4 model.network.ffw_expand=4 model.network.dropout=0.0 model.optimizer.optim.weight_decay=0.01 model.optimizer.optim.lr=1e-3 data.full_batch_size=1024 trainer.max_steps=300000 model.lr_scheduler.warmup_steps=10000 computer.num_workers=8 computer.precision=bf16-mixed data/additional_train_transforms=randaugment data.additional_train_transforms.randaugment_p=0.1 data.additional_train_transforms.randaugment_magnitude=6 model.train_batch_preprocess.apply_transform_prob=1.0 checkpoint_dir="./checkpoints/"
```

```bash
python src/train.py --config-name train_mae data.dataset_builder.data_dir=path_to_dataset seed=3407 model.network.dim=128 data.size=256 model.network.kernel_size=16 model.network.nb_layers=8 model.network.order=4 model.network.order_expand=8 model.network.ffw_expand=4 model.network.dropout=0.0 model.optimizer.optim.weight_decay=0.01 model.optimizer.optim.lr=1e-3 data.full_batch_size=256 trainer.max_steps=300000 model.lr_scheduler.warmup_steps=10000 computer.num_workers=8 computer.precision=bf16-mixed data/additional_train_transforms=randaugment data.additional_train_transforms.randaugment_p=0.1 data.additional_train_transforms.randaugment_magnitude=6 model.train_batch_preprocess.apply_transform_prob=1.0 checkpoint_dir="./checkpoints/"
```
For the GNN experiments:
- `python src/train_gnn.py` — train the GNN model
- `python src/optimize_hps_gnn.py` — run the hyper-parameter optimization
- `src/gnn_homm_nb.ipynb` — accompanying notebook
- `src/configs/train_gnn.yml` and `src/configs/hp_opt_gnn.yml` — the corresponding configs
Results on ImageNet, with the following parameters:
dim | order | order_expand | accuracy (%) | FLOPs | # params |
---|---|---|---|---|---|
320 | 1 | 8 | 43.6 | 2.6G | 26M |
320 | 2 | 4 | 47.6 | 2.6G | 26M |
320 | 4 | 2 | 46.1 | 2.6G | 26M |
256 | 2 | 8 | 47.9 | 2.9G | 29M |
256 | 4 | 4 | 46.1 | 2.9G | 29M |
Clearly, using second-order moments makes a big difference (+4 points over first order at dim 320), while going up to fourth order does not help. It is also better to use a higher dimension with a lower expansion factor than the opposite: at equal order, the dim-320 models match or nearly match the dim-256 ones while using fewer FLOPs and parameters.