Hi,
I'm trying to run the JEM training algorithm in train_wrn_ebm.py, using:
python train_wrn_ebm.py --lr .0001 --dataset cifar10 --optimizer adam --p_x_weight 1.0 --p_y_given_x_weight 1.0 --p_x_y_weight 0.0 --sigma .03 --width 10 --depth 28 --save_dir /YOUR/SAVE/DIR --plot_uncond --warmup_iters 1000
However, it's taking ~2.2 s/iteration, which works out to at least ~80 hours of total training time (assuming at least 700 steps per epoch at a train batch size of 64 on CIFAR-10), rather than the 36 hours stated in the paper (https://arxiv.org/pdf/1912.03263.pdf, pg. 4). I'm running on a p3.2xlarge instance on AWS. Could you please help explain the discrepancy?
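For reference, here's the back-of-envelope arithmetic behind my estimate (the iteration time is from my run, and the epoch count is a placeholder assumption for illustration, not a value taken from the repo's defaults):

```python
# Rough wall-clock estimate; all inputs are observations/assumptions
# from my setup, not values read from train_wrn_ebm.py.
sec_per_iter = 2.2               # observed on a p3.2xlarge (single V100)
steps_per_epoch = 50_000 // 64   # CIFAR-10 train set / batch size 64 -> 781

hours_per_epoch = sec_per_iter * steps_per_epoch / 3600
print(f"{hours_per_epoch:.2f} h/epoch")        # ~0.48 h/epoch

n_epochs = 170                   # hypothetical epoch count, for illustration
print(f"{hours_per_epoch * n_epochs:.0f} h")   # ~81 h at these settings
```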
Thanks!