ethanbar11 / ssm_2d

More dimensions = More fun

A 2-Dimensional State Space Layer for Spatial Inductive Bias

This repository is a PyTorch implementation of the paper.

Installation

Verify that PyTorch is installed (both PyTorch 1.x and 2.x are supported). Then run:

pip install -r requirements.txt
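
To quickly confirm that a compatible PyTorch build is visible to Python, a small check like the one below can be used (an illustrative snippet, not part of the repository):

# check_torch.py -- illustrative sanity check, not part of the repository
import torch

major = int(torch.__version__.split(".")[0])
assert major in (1, 2), f"Expected PyTorch 1.x or 2.x, found {torch.__version__}"
print(f"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")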

Reproducing the Experiments

Supported datasets: CIFAR100 and T-IMNET (Tiny ImageNet), selected with the --dataset flag.

ViT Backbones

Baselines:

python main.py --model vit --dataset CIFAR100
python main.py --model vit --dataset T-IMNET
python main.py --model swin --dataset CIFAR100 --embed_dim 96
python main.py --model swin --dataset T-IMNET --embed_dim 96
python main.py --model mega --dataset CIFAR100 --ema {choice}

Our Runs

python main.py --model vit --dataset CIFAR100 --no_pos_embedding --use_mix_ffn --ema ssm_2d --normalize --n_ssm=2 --ndim 16 --directions_amount 2 --seed 0
python main.py --model swin --dataset CIFAR100 --ema {choice} --use_mega_gating --embed_dim 96
python main.py --model mega --dataset CIFAR100 --ema ssm_2d --n_ssm 8 --ndim 16
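
The SSM-related flags above (--ndim, --n_ssm, --directions_amount) presumably control the per-SSM state size, the number of independent SSMs, and the number of scan directions over the image grid. As a rough illustration of what a 2D state-space recurrence computes, here is a minimal, unoptimized sketch; the class name Naive2DSSM is hypothetical, and this is not the repository's implementation:

# naive_ssm_2d.py -- minimal sketch of a 2D state-space recurrence.
# Hypothetical illustration only; NOT the layer implemented in this repository.
import torch
import torch.nn as nn

class Naive2DSSM(nn.Module):
    # h[i, j] = A1 h[i-1, j] + A2 h[i, j-1] + B x[i, j]
    # y[i, j] = C . h[i, j]
    def __init__(self, ndim: int = 16):
        super().__init__()
        self.ndim = ndim
        # Transition matrices kept close to 0.5 * I so the recurrence stays stable;
        # a real layer would parameterize these more carefully.
        self.A1 = nn.Parameter(0.5 * torch.eye(ndim) + 0.01 * torch.randn(ndim, ndim))
        self.A2 = nn.Parameter(0.5 * torch.eye(ndim) + 0.01 * torch.randn(ndim, ndim))
        self.B = nn.Parameter(torch.randn(ndim))
        self.C = nn.Parameter(torch.randn(ndim))

    def forward(self, x):
        # x: (batch, height, width), one scalar channel per pixel.
        bsz, H, W = x.shape
        zeros = x.new_zeros(bsz, self.ndim)
        prev_row = [zeros] * W           # states of the row above, h[i-1, j]
        rows = []
        for i in range(H):
            row, left = [], zeros        # left holds h[i, j-1]
            for j in range(W):
                h = (prev_row[j] @ self.A1.T      # vertical neighbour
                     + left @ self.A2.T           # horizontal neighbour
                     + x[:, i, j, None] * self.B) # input injection
                row.append(h)
                left = h
            prev_row = row
            rows.append(torch.stack(row, dim=1))  # (batch, W, ndim)
        h_grid = torch.stack(rows, dim=1)         # (batch, H, W, ndim)
        return h_grid @ self.C                    # (batch, H, W)

layer = Naive2DSSM(ndim=16)
print(layer(torch.randn(2, 8, 8)).shape)  # torch.Size([2, 8, 8])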

ConvNeXt

ConvNeXt for small datasets

To reproduce the original results from https://juliusruseckas.github.io/ml/convnext-cifar10.html (note that the original results use a batch size of 128):

python main.py --model convnext-small --dataset CIFAR100 --lr 1e-3 --batch_size 128 --weight-decay 1e-1
python main.py --model convnext-small --dataset T-IMNET --lr 1e-3 --batch_size 128 --weight-decay 1e-1

SSM Real:

python main.py --model convnext-small --dataset CIFAR100 --lr 1e-3 --batch_size 128 --weight-decay 1e-1 --ema ssm_2d --ssm_kernel_size 9 --n_ssm 2 --directions_amount 4 --ndim 16
python main.py --model convnext-small --dataset T-IMNET --lr 1e-3 --batch_size 128 --weight-decay 1e-1 --ema ssm_2d --ssm_kernel_size 13 --n_ssm 2 --directions_amount 2 --ndim 16

SSM Complex:

python main.py --model convnext-small --dataset CIFAR100 --lr 1e-3 --batch_size 128 --weight-decay 1e-1 --ema ssm_2d --ssm_kernel_size 9 --n_ssm 2 --directions_amount 2 --ndim 16 --complex_ssm
python main.py --model convnext-small --dataset T-IMNET --lr 1e-3 --batch_size 128 --weight-decay 1e-1 --ema ssm_2d --ssm_kernel_size 7 --n_ssm 2 --directions_amount 2 --ndim 16 --complex_ssm
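
For convenience, the ConvNeXt runs above can also be launched from a small driver script like the sketch below (the file name run_convnext_sweep.py is hypothetical; it simply replays the SSM Real commands listed above via subprocess):

# run_convnext_sweep.py -- hypothetical convenience launcher for the commands above.
import subprocess

COMMON = ["python", "main.py", "--model", "convnext-small",
          "--lr", "1e-3", "--batch_size", "128", "--weight-decay", "1e-1"]

# (dataset, SSM flags) pairs taken verbatim from the SSM Real commands above.
RUNS = [
    ("CIFAR100", ["--ema", "ssm_2d", "--ssm_kernel_size", "9",
                  "--n_ssm", "2", "--directions_amount", "4", "--ndim", "16"]),
    ("T-IMNET", ["--ema", "ssm_2d", "--ssm_kernel_size", "13",
                 "--n_ssm", "2", "--directions_amount", "2", "--ndim", "16"]),
]

for dataset, extra in RUNS:
    cmd = COMMON + ["--dataset", dataset] + extra
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, check=True)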