choasma / HSIC-bottleneck

The HSIC Bottleneck: Deep Learning without Back-Propagation
https://arxiv.org/abs/1908.01580
MIT License

Accuracy of hsicsolve is zero? #7

Closed Yocodeyo closed 4 years ago

Yocodeyo commented 4 years ago

Hi, may I know why I get an accuracy of zero for all 5 epochs when I run 'run_hsicbt -cfg config/hsicsolve-beta.yaml -tt hsictrain -dc mnist -lr 0.001 -s 6 -ld 200 -dt 5'? It should print the accuracy of the unformatted training, which shouldn't be zero, right? Thank you!

choasma commented 4 years ago

hi @Yocodeyo, thanks for running this. Actually, you should easily reach high accuracy in the first epoch on MNIST, assuming the config file is unchanged. For instance, this is what I just trained:

$ > run_hsicbt -cfg config/hsicsolve-beta.yaml -tt hsictrain -dc mnist -lr 0.001 -s 6 -ld 200 -dt 5
Loaded  [config/hsicsolve-beta.yaml]
# # # # # # # # # # # # # # # # # # # #
#     HSIC-Bottleneck training
# # # # # # # # # # # # # # # # # # # #
Train Epoch: 1 [ 60160/60160 (100%)] H_hx:5.7293 H_hy:7.9651: 100%|█████████████████| 235/235.0 [00:36<00:00,  6.38it/s]
Epoch - [0001]: Training Acc: 96.69
Epoch - [0001]: Testing  Acc: 96.26

I suggest running the following code snippet to see what's going wrong in your failing case. It computes the argmax of the activations for each class and saves the result as a figure you can inspect visually.

from hsicbt.utils import plot

# path to the activations saved during training
your_activation = "./assets/activation/hsic-solve-hsictrain-mnist.npy"
plot.plot_activation_distribution(your_activation, 'just for test')
# [7, 6, 5, 4, 3, 2, 8, 1, 0, 9] <-- the unformatted order printed;
# if any digit is repeated here, something is wrong
plot.save_figure('/tmp/out.png')

Let me know how it goes after this debugging. Thanks! (I might be wrong here as well.)

Yocodeyo commented 4 years ago

Thank you so much for your prompt response!

I cloned this repository without making any changes to the code and tried running it again; this is the output:

$ run_hsicbt -cfg config/hsicsolve-beta.yaml -tt hsictrain -dc mnist -lr 0.001 -s 6 -ld 200 -dt 5
Loaded  [config/hsicsolve-beta.yaml]
# # # # # # # # # # # # # # # # # # # #
#     HSIC-Bottleneck training
# # # # # # # # # # # # # # # # # # # #
Train Epoch: 1 [ 60160/60160 (100%)] H_hx:5.7292 H_hy:7.9651: 100%|███████████████████| 235/235 
[00:44<00:00,  5.32it/s]
Epoch - [0001]: Training Acc: 0.00
Epoch - [0001]: Testing  Acc: 0.00

For your information, I think the training process itself is going well, as the accuracy reported by format training is quite high.

And thank you for the suggested snippet! I got [7, 6, 5, 4, 3, 2, 8, 1, 0, 9] from it. That output looks correct, so it's still confusing why the accuracy isn't being reported properly.

choasma commented 4 years ago

Unfortunately I can't reproduce your result, but here is a checklist I'd recommend for debugging: 1) inspect the unformatted order and the result image via the plot module (which you've already done); 2) check the misc.get_accurarcy_hsic function in the misc module, which is invoked from engine.py. The idea of that function is to collect all the activations and find the argmax for each class, one by one. (Sorry, it is quite ugly.) 3) If train_acc is correct inside engine.py, then the problem is probably in the print_highlight function I wrote for colored printing. Maybe switch it to Python's built-in print.
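To make step 2 concrete, here is a minimal sketch of the idea behind that accuracy computation: map each true class to the activation unit that most often wins the argmax, then score predictions under that mapping. The function name, shapes, and toy data below are my own assumptions for illustration, not the repo's actual code.

```python
import numpy as np

def unformatted_accuracy(activations, labels, num_classes=10):
    """Sketch: each class is assigned the activation unit that is most
    often its argmax; accuracy is measured under that assignment."""
    # winning activation unit for each sample
    pred_units = activations.argmax(axis=1)
    mapping = np.full(num_classes, -1)
    for c in range(num_classes):
        # unit most frequently associated with class c
        units, counts = np.unique(pred_units[labels == c], return_counts=True)
        mapping[c] = units[counts.argmax()]
    num_correct = (pred_units == mapping[labels]).sum()
    # note the explicit float conversions before dividing
    return float(num_correct) / float(labels.shape[0])

# toy check: 3 classes, each cleanly tied to one unit via one-hot activations
acts = np.eye(3)[[0, 1, 2, 0, 1, 2]]
labels = np.array([0, 1, 2, 0, 1, 2])
print(unformatted_accuracy(acts, labels, num_classes=3))  # 1.0
```

If the mapping contains repeated units (the "repeated digits" case from the plot above), two classes collapse onto the same unit and accuracy drops accordingly.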

Again, I appreciate any suggestions and help. If you find something suspicious, please let me know. Thanks!

Yocodeyo commented 4 years ago

I found the cause of the problem! It's related to how division is handled in Python 2.7. In the misc.get_accurarcy_hsic function, both num_correct.shape[0] and out.shape[0] are ints, so accuracy = float(num_correct.shape[0]/out.shape[0]) performs integer division first, yielding float(0) = 0.0. Changing it to accuracy = float(num_correct.shape[0])/float(out.shape[0]) solves the problem. Thanks a lot for your help! And this HSIC-Bottleneck is really an interesting piece of work :)
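For anyone hitting the same thing, the pitfall is easy to demonstrate in isolation. In Python 2, dividing two ints truncates toward zero, so converting to float after the division is too late; Python 3's // operator reproduces that truncating behavior. The counts below are made-up illustrative numbers.

```python
# Python 2 integer-division pitfall behind the zero accuracy.
num_correct = 9626   # e.g. correct predictions out of 10000 test samples
total = 10000

# Python-2 style: int/int truncates to 0 first, then becomes 0.0
buggy = float(num_correct // total)

# the fix: convert to float first, then divide
fixed = float(num_correct) / float(total)

print(buggy)  # 0.0
print(fixed)  # 0.9626
```

The same fix also works in Python 2 via `from __future__ import division`, which makes `/` behave like true division.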

choasma commented 4 years ago

Ah!! That's a very embarrassing mistake. I have applied your change to that line on your behalf. Here is the fixed commit (link). Sorry the code isn't optimized and has lots of hard-coded things. Let me know if you run into any other obstacles. Really appreciated.