davidpicard / HoMM

High order Moment Models

We propose an alternative to classical attention that scales linearly with the number of tokens and is based on high order moments.

[Figure: HoMM scheme]

The HoMM scheme is as follows. Given a query token $x_q$ and a set of context tokens $x_c$, we first use a projection $ho$ to map each context token to a high-dimensional space, where the high-order moments are computed recursively (by chunking, taking element-wise products, and then averaging over the tokens). $x_q$ is projected into the same high-dimensional space with a projection $s$. The element-wise product of the two corresponds to $x_q$ selecting the information it needs from the high-order moments of $x_c$. The result is then projected back to the same space as $x_q$ and added to the original tokens via a residual connection.
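The steps above can be sketched for a single query token as follows. This is a minimal NumPy illustration, not the repository's implementation: the function name `homm_sketch`, the weight names `w_ho`, `w_s`, `w_out`, the shapes, and in particular the choice to sum the moments of all orders into one vector are assumptions made for the example.

```python
import numpy as np


def homm_sketch(x_q, x_c, w_ho, w_s, w_out, order):
    """Illustrative HoMM step for one query token (hypothetical helper).

    x_q:   (d,)            query token
    x_c:   (n, d)          context tokens
    w_ho:  (d, order * D)  projection "ho" to the high-dimensional space
    w_s:   (d, D)          projection "s" applied to the query
    w_out: (D, d)          projection back to the token space
    """
    # Project every context token, then split the result into `order` chunks.
    chunks = np.split(x_c @ w_ho, order, axis=-1)  # each chunk: (n, D)

    # Recursion: the running product of k chunks, averaged over the tokens,
    # plays the role of the k-th order moment.
    running = chunks[0]
    moments = [running.mean(axis=0)]
    for chunk in chunks[1:]:
        running = running * chunk
        moments.append(running.mean(axis=0))

    # Assumption: the moments of all orders are summed into one vector.
    context = np.sum(moments, axis=0)  # (D,)

    # The query selects from the moments by element-wise product.
    selected = (x_q @ w_s) * context  # (D,)

    # Project back and add the residual connection.
    return x_q + selected @ w_out  # (d,)


rng = np.random.default_rng(0)
d, D, n, order = 8, 16, 5, 2
x_q = rng.normal(size=d)
x_c = rng.normal(size=(n, d))
w_ho = rng.normal(size=(d, order * D))
w_s = rng.normal(size=(d, D))
w_out = rng.normal(size=(D, d))

out = homm_sketch(x_q, x_c, w_ho, w_s, w_out, order)
print(out.shape)  # (8,)
```

In this sketch the moments depend only on the context tokens, so they can be computed once in a single pass over the `n` tokens and reused for every query, which is where the linear scaling in the number of tokens comes from.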

/!\ Help welcome: DM me on twitter (https://twitter.com/david_picard), or submit an issue, or email me!

Changelog

Fix me

Easy targets if you want to contribute

Currently testing on

Launching a classification training run

This repo supports hydra for handling configs. Look at src/configs to edit them. Here is an example of a training run:

python src/train.py data.dataset_builder.data_dir=path_to_imagenet seed=3407 model.network.dim=128  data.size=224 model.network.kernel_size=32 model.network.nb_layers=12 model.network.order=2 model.network.order_expand=4 model.network.ffw_expand=4  model.network.dropout=0.0 model.optimizer.optim.weight_decay=0.01 model.optimizer.optim.lr=1e-3 data.full_batch_size=1024 trainer.max_steps=300000 model.lr_scheduler.warmup_steps=10000 computer.num_workers=8 computer.precision=bf16-mixed data/additional_train_transforms=randaugment data.additional_train_transforms.randaugment_p=0.1 data.additional_train_transforms.randaugment_magnitude=6 model.train_batch_preprocess.apply_transform_prob=1.0 checkpoint_dir="./checkpoints/"
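To make the command line easier to read, here is a rough sketch of how dotted `key=value` overrides map onto a nested config. It is plain Python written for illustration and is not how Hydra is implemented; `apply_overrides` and `parse_value` are hypothetical helpers. Config group selections that use a slash, such as `data/additional_train_transforms=randaugment`, pick a whole config file and are not covered by this sketch.

```python
def apply_overrides(config, overrides):
    """Apply `a.b.c=value` overrides to a nested dict, in place."""
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Create intermediate levels on demand.
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw_value)
    return config


def parse_value(raw):
    """Best-effort conversion of a command-line string to int or float."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw  # leave anything else as a string


config = apply_overrides({}, [
    "seed=3407",
    "model.network.dim=128",
    "model.network.order=2",
    "model.optimizer.optim.lr=1e-3",
    "computer.precision=bf16-mixed",
])
print(config["model"]["network"])  # {'dim': 128, 'order': 2}
```

Each dot in a key descends one level, so `model.network.order=2` sets the `order` field of the `network` entry under `model`.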

Launching an MAE training run

python src/train.py --config-name train_mae data.dataset_builder.data_dir=path_to_dataset seed=3407 model.network.dim=128  data.size=256 model.network.kernel_size=16 model.network.nb_layers=8 model.network.order=4 model.network.order_expand=8 model.network.ffw_expand=4  model.network.dropout=0.0 model.optimizer.optim.weight_decay=0.01 model.optimizer.optim.lr=1e-3 data.full_batch_size=256 trainer.max_steps=300000 model.lr_scheduler.warmup_steps=10000 computer.num_workers=8 computer.precision=bf16-mixed data/additional_train_transforms=randaugment data.additional_train_transforms.randaugment_p=0.1 data.additional_train_transforms.randaugment_magnitude=6 model.train_batch_preprocess.apply_transform_prob=1.0 checkpoint_dir="./checkpoints/"

GAT-HoMM: a Graph Neural Network with HoMM Attention

TODO:

Ablation

On ImageNet, with the following parameters:

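| dim | order (o) | order_expand (oe) | acc (%) | FLOPs | # params |
|-----|-----------|-------------------|---------|-------|----------|
| 320 | 1 | 8 | 43.6 | 2.6G | 26M |
| 320 | 2 | 4 | 47.6 | 2.6G | 26M |
| 320 | 4 | 2 | 46.1 | 2.6G | 26M |
| 256 | 2 | 8 | 47.9 | 2.9G | 29M |
| 256 | 4 | 4 | 46.1 | 2.9G | 29M |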

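Clearly, having the second order makes a big difference (43.6 to 47.6 at dim 320). Having the fourth order not so much: accuracy is 46.1 at both dimensions, below the second-order runs. It's better to have a higher dimension and lower expansion than the contrary: dim 320 with expansion 4 is within 0.3 points of dim 256 with expansion 8 while using fewer FLOPs and parameters.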
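The comparisons behind these remarks can be checked with a few lines of arithmetic. The rows below are copied from the ablation table; the variable names are only for this example.

```python
# Rows copied from the ablation table:
# (dim, order, order_expand, acc, GFLOPs, params in millions)
rows = [
    (320, 1, 8, 43.6, 2.6, 26),
    (320, 2, 4, 47.6, 2.6, 26),
    (320, 4, 2, 46.1, 2.6, 26),
    (256, 2, 8, 47.9, 2.9, 29),
    (256, 4, 4, 46.1, 2.9, 29),
]
acc = {(dim, order, expand): a for dim, order, expand, a, _, _ in rows}

# Effect of the order at a fixed dimension and budget.
gain_second_order = acc[(320, 2, 4)] - acc[(320, 1, 8)]
gain_fourth_order = acc[(320, 4, 2)] - acc[(320, 2, 4)]

# Higher dimension with lower expansion versus the contrary, at order 2.
gap_wide_vs_narrow = acc[(320, 2, 4)] - acc[(256, 2, 8)]

print(f"order 1 -> 2 at dim 320: {gain_second_order:+.1f}")
print(f"order 2 -> 4 at dim 320: {gain_fourth_order:+.1f}")
print(f"dim 320/oe 4 vs dim 256/oe 8: {gap_wide_vs_narrow:+.1f}")
```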