machine-discovery / deer

Parallelizing non-linear sequential models over the sequence length
BSD 3-Clause "New" or "Revised" License

Help replicating the scifar experiment? #38

Open xaviergonzalez opened 1 month ago

xaviergonzalez commented 1 month ago

Would it be possible to clarify which hyperparameter settings (including choice of hardware) led to the 90.25% Multi-head GRU accuracy on sequential CIFAR reported in Table 2 of the paper (https://arxiv.org/pdf/2309.12252)?

Looking at the configs currently in the GitHub repo (config0.yaml and config_rnn2.yaml), they seem to correspond to different parts of the discussion in Appendix B.4: https://github.com/machine-discovery/deer/tree/main/experiments/05_rnn_scifar/configs

In particular, config0.yaml seems to match in terms of optimization hyperparameters, while config_rnn2.yaml seems to match in terms of model architecture hyperparameters (except that the architecture used is RNNNet2, which is stated to "follow the architecture from https://github.com/thjashin/multires-conv/blob/main/classification.py").
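For anyone else trying to track this down, a quick way to see exactly where the two configs diverge is to diff them key by key. A minimal sketch, assuming flat key/value configs (the example values below are made up for illustration and are not the actual contents of config0.yaml / config_rnn2.yaml):

```python
def diff_configs(a: dict, b: dict) -> dict:
    """Return {key: (value_in_a, value_in_b)} for keys whose values
    differ between the two configs; missing keys show up as None."""
    out = {}
    for key in sorted(set(a) | set(b)):
        va, vb = a.get(key), b.get(key)
        if va != vb:
            out[key] = (va, vb)
    return out

# Hypothetical hyperparameters for illustration only.
config0 = {"lr": 3e-4, "nsteps": 300_000, "nhiddens": 64}
config_rnn2 = {"lr": 3e-4, "nsteps": 500_000, "nhiddens": 128}

print(diff_configs(config0, config_rnn2))
# → {'nhiddens': (64, 128), 'nsteps': (300000, 500000)}
```

The real configs are nested YAML, so in practice one would load both with yaml.safe_load and recurse into sub-dicts, but the idea is the same.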

This is confusing because MultiresNet (Shi et al., 2023), at 93.15%, is listed as a separate entry in Table 2.

When I run the code with config0.yaml, I get 79% test accuracy instead of the reported 90%, and the log records 182,154 parameters whereas Appendix B.4 of the paper reports 1,347,082 parameters.
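The parameter count is a cheap way to confirm whether the right config is being run, since 182,154 vs 1,347,082 is an unambiguous difference. A generic sketch of counting parameters over a nested dict of arrays (the shapes below are made up for illustration; the real check would walk the JAX parameter pytree from the training state):

```python
import numpy as np

def count_params(tree) -> int:
    """Sum the number of elements over all arrays in a nested
    dict-of-dicts parameter tree (a simplified pytree)."""
    if isinstance(tree, dict):
        return sum(count_params(v) for v in tree.values())
    return int(np.asarray(tree).size)

# Hypothetical parameter tree with made-up shapes, for illustration only.
params = {
    "embedding": {"kernel": np.zeros((3, 64))},
    "gru": {"w_ih": np.zeros((64, 192)), "w_hh": np.zeros((64, 192))},
    "head": {"kernel": np.zeros((64, 10)), "bias": np.zeros(10)},
}

print(f"{count_params(params):,} parameters")
# → 25,418 parameters
```

With actual JAX params the same thing is a one-liner via jax.tree_util.tree_leaves, summing leaf.size over the leaves.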

As a final minor note, the text in Section 4.4 is not in agreement with Table 2. The text says "Table 2 shows that multi-head GRU can achieve 89.35% test accuracy," whereas the table reports 90.25% accuracy.

It would be a great help for replicability if the expected test accuracy and the hyperparameters (including choice of hardware) needed to attain this accuracy could be clarified.

mfkasim1 commented 1 month ago

If I remember correctly, we ran the experiment for quite a long time, so between config0.yaml and config_rnn2.yaml, it's the latter one with 500k steps. About my comment "following the architecture from ...": I think it's because I copied some parts of the architecture (like the arrangement of embedding + linear + time mixing in a block), but from a quick glance the two architectures are not the same. MultiresNet uses convolution while RNNNet2 uses a nonlinear RNN.

xaviergonzalez commented 1 month ago

Thank you for clarifying. What hardware did you use?