BICLab / Spike-Driven-Transformer

Official implementation of "Spike-driven Transformer" (NeurIPS 2023)
https://openreview.net/forum?id=9FmolyOHi5
Apache License 2.0

The config for CIFAR-10 in the paper #1

FengLiXue closed this issue 1 year ago

FengLiXue commented 1 year ago

Hi, first of all, thanks for this great work. I have tried all the configs in the conf/cifar10 folder, but none of them reaches the accuracy reported in the paper. Could you provide the detailed configuration that reproduces the paper's accuracy? Thank you very much!

jkhu29 commented 1 year ago

https://github.com/BICLab/Spike-Driven-Transformer/blob/main/conf/cifar10/2_512_300E_t4.yml is the config we used.

If the model fails to reach 95% accuracy on CIFAR-10 after you train with this config, I would recommend increasing the number of layers to 4. That said, you shouldn't actually need to: it may be a random-seed problem, or an issue with the hardware and software environment, such as an old CUDA version.

FengLiXue commented 1 year ago

Thanks for the reply. I tried conf/cifar10/2_256_300E_t4.yml, 2_512_300E_t4.yml and 2_512_300E_t4_TET.yml, and I also increased the number of layers to 4, but every run reaches only about 85% accuracy. That is a large gap from 95%, so it doesn't seem like a random-seed problem. I used an RTX 3090 GPU with CUDA 11.6 and PyTorch 1.13.0. I don't know what I did wrong. If you could help, I would be very grateful!

jkhu29 commented 1 year ago

Maybe you can share your TensorBoard file so we can analyze your loss and accuracy curves together.

jkhu29 commented 1 year ago

[Screenshot: 2023-08-16_15-29]

You can see that the model accuracy has reached 86.35% at epoch 27. This is the result I just got by running this repository on two RTX 3090 GPUs.

jkhu29 commented 1 year ago

The launch command is:

CUDA_VISIBLE_DEVICES=0,1 /usr/bin/python3 -m torch.distributed.launch --nproc_per_node=2 --master_port 29501 train.py -c conf/cifar10/2_512_300E_t4.yml --model sdt --spike-mode lif

FengLiXue commented 1 year ago

@jkhu29 First, I am very grateful for the work you have put into my problem. I have put the training log here: [Screenshot: training log]. As the log shows, once the accuracy reaches 84%, it improves very slowly, which confuses me. Thank you very much again!

jkhu29 commented 1 year ago

The batch size you used is 64. CIFAR-10 has 50,000 training images and 10,000 test images, so the log should show [0/780], not 3298. I guess something is wrong with your data; please follow the PyTorch documentation to set up the dataset.
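To make the maintainer's arithmetic concrete, here is a minimal Python sketch. It assumes the training loader drops the incomplete final batch. That is an assumption about this repository, but it is the setting that reproduces the 780 quoted above (50000 // 64 - 1 = 780). For multi-GPU runs it further assumes a per-process batch size and DistributedSampler-style splitting. The helper names `last_batch_index`, `implied_sample_range`, and `check_split_sizes` are hypothetical and not part of the repository.

```python
from typing import Tuple


def last_batch_index(num_samples: int, batch_size: int,
                     world_size: int = 1, drop_last: bool = True) -> int:
    """Return N in a "[0/N]" progress log: the index of the last batch per epoch.

    Assumptions (not taken from the repository): each process gets
    ceil(num_samples / world_size) samples, as with PyTorch's DistributedSampler
    defaults, and batch_size is the per-process batch size.
    """
    per_rank = (num_samples + world_size - 1) // world_size  # integer ceil
    if drop_last:
        num_batches = per_rank // batch_size  # incomplete final batch is discarded
    else:
        num_batches = (per_rank + batch_size - 1) // batch_size  # last batch may be smaller
    return num_batches - 1


def implied_sample_range(last_index: int, batch_size: int) -> Tuple[int, int]:
    """Single process, drop_last=True: dataset sizes whose epoch ends at last_index."""
    num_batches = last_index + 1
    low = num_batches * batch_size
    high = low + batch_size - 1  # one more sample would complete another full batch
    return low, high


def check_split_sizes(train_set, test_set, expected=(50000, 10000)):
    """Raise ValueError unless the train/test splits have the standard CIFAR-10 sizes."""
    sizes = (len(train_set), len(test_set))
    if sizes != expected:
        raise ValueError(f"expected {expected} train/test images, got {sizes}")
    return sizes


print(last_batch_index(50000, 64))     # 780
print(implied_sample_range(3298, 64))  # (211136, 211199)
```

Under the single-process assumption, a last index of 3298 at batch size 64 implies a training set of roughly 211,000 images, about four times the size of CIFAR-10's training split. That is consistent with the maintainer's diagnosis of a data problem. To run the size check on real data, one could pass `torchvision.datasets.CIFAR10(root, train=True)` and `torchvision.datasets.CIFAR10(root, train=False)` to `check_split_sizes`, with `root` pointing at the dataset directory.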

FengLiXue commented 1 year ago

Thank you very much for pointing out my mistake. I did indeed make a mistake with the dataset, and I am very sorry for taking up so much of your time. Once again, I would like to express my gratitude for your work.

jkhu29 commented 1 year ago

Great, I hope it works for you. If your problem has been solved, please close this issue; you can reopen it if you run into any further problems.