apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Distributed training -- where is the final model? Why are validation scores different? #1410

Closed svohara closed 8 years ago

svohara commented 8 years ago

First of all, I would like to express that I appreciate the significant development effort required to produce mxnet, and I am very excited to use this tool and hopefully to contribute features as time goes on. However, there are areas that could use more polish. My reason for using mxnet is for distributed training (cluster/AWS). Since few of the most popular packages do a good job in this regard, it is the single most important differentiating factor, in my opinion, for using mxnet.

As such, I seem to be having a disconnect between what I expect the distributed training process should produce (and how validation results should be reported per epoch) and what I observe actually happens. I'm using ssh based distributed training on AWS. Training cifar10 on 3 nodes, for example, generates checkpoints for each of the three nodes. Training seems to be successful, yet there is no obvious "final model" that is produced and the three node-specific models that are saved by the process at the last epoch yield different validation scores.

Issue 1 The user expects that there will be an obvious final model produced as a result of the training process, and there is not with how the examples/tutorials are structured.

Issue 2 The user expects that after parameter synchronization, the models on each of the distributed nodes are, well, synchronized (meaning identical parameters). They appear not to be. If they were, then after each synchronization event, the validation scores should be the same.

Expectations These two expectations may be the result of a misunderstanding on how distributed training works and what a parameter server actually synchronizes. I outline my understanding below, and if this is way off, then I would ask for some user-friendly high-level documentation with a training flow chart of some sort that makes it clear what each artifact represents and what to do at the end of the process to generate the single "final" model that I desire to use at run-time for predictions on new data.

My basic understanding is that, when using "dist-sync", at regular points in time (e.g., each epoch, but maybe more frequently? each mini-batch?), the parameter server collects the parameters from each of the nodes, performs some sort of aggregation to generate a consolidated/updated parameter set, and then distributes this update to each of the nodes. At this point in time, after synchronization but before continuing to train, the models at each of the nodes should be identical. After the synchronization, each node continues the training process independently until the next synchronization point. The process continues until the training completes, and then a final synchronization should yield a single consolidated model with parameters aggregated from the nodes.
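To make this concrete, below is a toy numpy sketch of my mental model of dist-sync (this is not mxnet code; the helper function and the plain-average aggregation rule are just my assumptions):

```python
import numpy as np

K = 3                                         # number of worker nodes
params = [np.zeros(10) for _ in range(K)]     # each node's local copy of the parameters

def local_training_step(p, node_id):
    # stand-in for one epoch (or mini-batch) of SGD on node_id's data shard
    return p - 0.1 * np.random.RandomState(node_id).randn(10)

for sync_round in range(5):
    # 1. each node trains independently on its own shard
    params = [local_training_step(p, i) for i, p in enumerate(params)]
    # 2. the parameter server aggregates the node parameters (here: a plain average)
    consolidated = np.mean(params, axis=0)
    # 3. the consolidated parameters are pushed back to every node, so immediately
    #    after a sync all K copies are identical
    params = [consolidated.copy() for _ in range(K)]

# at the end, any node's copy (or the consolidated array) would be "the" final model
```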

Observations Here is where what actually happens clashes with my expectations:

  1. When done, I am left with K different models (param files) and symbols (json), where K is the number of nodes. I would expect that a distributed training job should end up producing one final consolidated model, just as if I had only used single-node training, but of course, faster. The other artifacts produced by the separate nodes are useful, but when all is done, it should be obvious which is my final model that I should use for prediction.
  2. When I load the K different models that result, and test their performance on the validation data, I get different results! I expect that after syncing via the parameter server, the final three models would be the same. If not, at what point can/should I checkpoint in order to get the model after the parameter synchronization has completed?
  3. When one observes the training output of the K nodes, at each epoch when the validation accuracy is reported, each of the K models has a different value. Is this because the validation data has also been partitioned and thus the K models are evaluating on different data? (If so, I think this should at least be optional -- partition training but not the presumably smaller validation data).
  4. If the answer to 3 is NOT that the validation data was partitioned, and in fact all K models are evaluating on the same validation set, then this becomes a similar question to 2 above. Why don't we, during the dist-sync process, end up with K identical models after the synchronization? This is what I intuitively expect, but it would seem the validation results are produced prior to synchronization. Why is this good/desired? Doesn't the user want to know, at the end of each epoch, how the model (in general, after updates from the K nodes) is performing?

Bottom Line In summary, the mechanism by which we train the model (local, distributed, sync/async, multi-GPU, CPUs, etc.) is irrelevant at the highest level of reporting performance at each epoch and at the end of the process. The name of the final model params and symbol files produced should be independent of the training mechanism.

Anyone writing a wrapper or automation tool around using mxnet to train a deep net would want the outputs to be the same format regardless of mechanism. There should not be extra undocumented work that the user must perform in order to take the results of distributed training and produce the desired final model.

mli commented 8 years ago

hi stephen

many thanks for the feedback. correct me if i'm wrong, but the main question you're asking is why there are n different models even when using synchronized training, e.g. dist_sync.

the answer is that all parameters except for the batch normalization layer should be identical. the issue with batch normalization is that it maintains a running estimate of the mean and variance of its inputs, which depends on the local data each node sees. we don't do a synchronization between nodes to average these values during training.

as a result, we have several different models at the end even for dist_sync. we can either pick one of them, or do a model average. we would very much appreciate it if you have a better idea.
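for reference, a small illustration with the symbol API of where those running statistics live; they are auxiliary states, not trainable arguments (the layer names here are just for the example):

```python
import mxnet as mx

data = mx.sym.Variable('data')
net = mx.sym.Convolution(data=data, num_filter=8, kernel=(3, 3), name='conv')
net = mx.sym.BatchNorm(data=net, name='bn')

# trainable parameters, which the kvstore synchronizes across nodes
print(net.list_arguments())         # ['data', 'conv_weight', 'conv_bias', 'bn_gamma', 'bn_beta']

# running statistics, which stay node-local during training
print(net.list_auxiliary_states())  # ['bn_moving_mean', 'bn_moving_var']
```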

svohara commented 8 years ago

Thank you for your very speedy response. However, I suggest we try to make the situation better. I assume that my high level understanding of how the dist-sync process works is essentially correct, given that you made no corrections, so I'm continuing with this in mind.

  1. Your response suggests that the user deploy an ensemble if they trained with a distributed net (perhaps only if batch norm was used in the model design?). Thus if they trained over 28 nodes, e.g., you are recommending they deploy an ensemble of 28 models, so they can average the predictions of each? And thus, the more nodes they use in training, the worse the deployment to a run-time application becomes? As I stated in my original post, the mechanism by which we train the model would ideally not affect the deployment/usage/interpretation of that model. I suppose we could always advise the user to simply pick the final model with the highest validation score, but then why not automate this to produce the final model?
  2. If you know that the batch norm layer will create extra effort/issues on behalf of the users, I suggest that this should be pointed out clearly in the documentation and tutorials. I suggest the distributed training docs/tutorials should not end at: "yay! It ran to completion!", but rather: "now that training is completed, here's how we deploy the model to perform prediction on new images".
  3. Why not allow the batch norm params to be synchronized, at least as an option? The normalization would be better, it would seem, if it were generalized across the entire training set by communicating each node's normalization stats to the PS and then aggregating and updating.
  4. Even if the above is not advisable, how about at the very end of the process, we aggregate the batch norm values for use in the final model for run-time prediction? After all, the idea is to use statistics over the entire training corpus as an estimate to help us center new/unseen data. The more data the normalization stats are derived from, the better it should generalize. (I guess this simply restates the previous point, but makes the concession that maybe we aggregate the BN stats only at the very end, if for some reason, that's much preferred for accuracy or speed reasons.)
  5. Finally, I interpret your answer to mean that if I had no batch norm in my deep net design, then I should observe identical validation scores at each epoch in the distributed training, and that my K final models would in fact be identical, so picking any one as the final model would suffice. (I will test this and see if this assertion holds.)

tqchen commented 8 years ago

On the batchnorm question: the batchnorm statistics are only needed at prediction time; they are not part of the model parameters used during training, and the related (trainable) parameters are synchronized.

So the batchnorm statistics only need to be aggregated once, after training is done.

mli commented 8 years ago

my answers to your comments:

  1. a post-processing step to average the models into a single one will work (a rough sketch is included after this list), or even simply picking one model should be good enough. for large datasets such as imagenet, the model variance is not so big
  2. can you help clarify the documentation?
  3. refer to tianqi's comments
  4. that should work, i think we already have such a script @antinucleon
  5. yeah, i tested that months ago. if you cannot reproduce it then there should be a bug
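a rough sketch of such an averaging step, assuming the per-node checkpoint naming from the logs in this thread and that the files have been copied locally (the epoch number below is made up):

```python
import mxnet as mx

K = 3
checkpoints = ['cifar10_lenet-%d-0010.params' % k for k in range(K)]

# each file loads as a dict of 'arg:*' / 'aux:*' NDArrays
loaded = [mx.nd.load(f) for f in checkpoints]

averaged = {}
for key in loaded[0]:
    acc = loaded[0][key].copy()
    for d in loaded[1:]:
        acc += d[key]
    # for dist_sync the 'arg:*' entries should already be (nearly) identical, so this
    # mainly consolidates the 'aux:*' batch-norm running statistics
    averaged[key] = acc / K

mx.nd.save('cifar10_lenet-final.params', averaged)
```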
svohara commented 8 years ago

Follow up to my original observation 3 -- training output shows different validation accuracy

Below is part of the output for distributed training of cifar10 data using a simple LeNet network, thus no batch norm. You can see that each of the 3 nodes reports different validation accuracy at each epoch. This output still contradicts my expectations, but it may be explained if the validation scores are computed prior to synchronization (which I feel is a mistake, but perhaps there are implementation reasons why this is best).

[16:14:22] src/io/./iter_normalize.h:103: Load mean image from s3://sandbox-svo/datasets/cifar/data/mean.bin
[16:14:22] src/io/./iter_normalize.h:103: Load mean image from s3://sandbox-svo/datasets/cifar/data/mean.bin
2016-02-04 16:14:23,146 Node[2] Start training with [gpu(0)]
2016-02-04 16:14:23,247 Node[0] Start training with [gpu(0)]
[16:14:23] src/io/./iter_normalize.h:103: Load mean image from s3://sandbox-svo/datasets/cifar/data/mean.bin
2016-02-04 16:14:23,374 Node[1] Start training with [gpu(0)]
2016-02-04 16:14:27,286 Node[1] Epoch[0] Batch [50] Speed: 5842.33 samples/sec  Train-accuracy=0.302500
2016-02-04 16:14:27,364 Node[0] Epoch[0] Batch [50] Speed: 5774.41 samples/sec  Train-accuracy=0.290781
2016-02-04 16:14:27,384 Node[2] Epoch[0] Batch [50] Speed: 5633.65 samples/sec  Train-accuracy=0.303594
2016-02-04 16:14:28,417 Node[0] Epoch[0] Batch [100]    Speed: 6075.84 samples/sec  Train-accuracy=0.335859
2016-02-04 16:14:28,438 Node[2] Epoch[0] Batch [100]    Speed: 6076.09 samples/sec  Train-accuracy=0.330078
2016-02-04 16:14:28,340 Node[1] Epoch[0] Batch [100]    Speed: 6073.98 samples/sec  Train-accuracy=0.339922
2016-02-04 16:14:29,816 Node[1] Epoch[0] Batch [150]    Speed: 4336.60 samples/sec  Train-accuracy=0.366406
2016-02-04 16:14:29,896 Node[0] Epoch[0] Batch [150]    Speed: 4328.20 samples/sec  Train-accuracy=0.363698
2016-02-04 16:14:29,917 Node[2] Epoch[0] Batch [150]    Speed: 4327.61 samples/sec  Train-accuracy=0.359010
2016-02-04 16:14:29,977 Node[1] Epoch[0] Train-accuracy=0.368339
2016-02-04 16:14:29,978 Node[1] Epoch[0] Time cost=4.133
2016-02-04 16:14:30,078 Node[2] Epoch[0] Train-accuracy=0.362179
2016-02-04 16:14:30,078 Node[2] Epoch[0] Time cost=4.142
2016-02-04 16:14:30,058 Node[0] Epoch[0] Train-accuracy=0.365685
2016-02-04 16:14:30,058 Node[0] Epoch[0] Time cost=4.143
2016-02-04 16:14:30,764 Node[0] Epoch[0] Validation-accuracy=0.438657
2016-02-04 16:14:30,729 Node[1] Epoch[0] Validation-accuracy=0.436921
2016-02-04 16:14:30,915 Node[2] Epoch[0] Validation-accuracy=0.422776
2016-02-04 16:14:31,047 Node[0] Saved checkpoint to "s3://sandbox-svo/datasets/cifar/models/cifar10_lenet-0-0001.params"
2016-02-04 16:14:31,068 Node[1] Saved checkpoint to "s3://sandbox-svo/datasets/cifar/models/cifar10_lenet-1-0001.params"
2016-02-04 16:14:31,226 Node[2] Saved checkpoint to "s3://sandbox-svo/datasets/cifar/models/cifar10_lenet-2-0001.params"
2016-02-04 16:14:32,651 Node[0] Epoch[1] Batch [50] Speed: 5146.04 samples/sec  Train-accuracy=0.435937
2016-02-04 16:14:32,671 Node[2] Epoch[1] Batch [50] Speed: 6033.58 samples/sec  Train-accuracy=0.426875
2016-02-04 16:14:32,574 Node[1] Epoch[1] Batch [50] Speed: 5467.13 samples/sec  Train-accuracy=0.442031
2016-02-04 16:14:33,723 Node[2] Epoch[1] Batch [100]    Speed: 6084.47 samples/sec  Train-accuracy=0.434141
2016-02-04 16:14:33,626 Node[1] Epoch[1] Batch [100]    Speed: 6084.57 samples/sec  Train-accuracy=0.444609
2016-02-04 16:14:33,703 Node[0] Epoch[1] Batch [100]    Speed: 6083.52 samples/sec  Train-accuracy=0.443594
2016-02-04 16:14:35,067 Node[1] Epoch[1] Batch [150]    Speed: 4441.59 samples/sec  Train-accuracy=0.452448
2016-02-04 16:14:35,146 Node[0] Epoch[1] Batch [150]    Speed: 4435.86 samples/sec  Train-accuracy=0.447865
2016-02-04 16:14:35,168 Node[2] Epoch[1] Batch [150]    Speed: 4432.12 samples/sec  Train-accuracy=0.445156
2016-02-04 16:14:35,228 Node[1] Epoch[1] Train-accuracy=0.454026
2016-02-04 16:14:35,228 Node[1] Epoch[1] Time cost=4.159
2016-02-04 16:14:35,328 Node[2] Epoch[1] Train-accuracy=0.447917
2016-02-04 16:14:35,328 Node[2] Epoch[1] Time cost=4.102
2016-02-04 16:14:35,308 Node[0] Epoch[1] Train-accuracy=0.448768
2016-02-04 16:14:35,308 Node[0] Epoch[1] Time cost=4.261
2016-02-04 16:14:35,617 Node[1] Epoch[1] Validation-accuracy=0.499399
2016-02-04 16:14:35,718 Node[0] Epoch[1] Validation-accuracy=0.490685
2016-02-04 16:14:35,799 Node[2] Epoch[1] Validation-accuracy=0.477464
2016-02-04 16:14:35,855 Node[1] Saved checkpoint to "s3://sandbox-svo/datasets/cifar/models/cifar10_lenet-1-0002.params"
2016-02-04 16:14:35,966 Node[0] Saved checkpoint to "s3://sandbox-svo/datasets/cifar/models/cifar10_lenet-0-0002.params"
2016-02-04 16:14:36,387 Node[2] Saved checkpoint to "s3://sandbox-svo/datasets/cifar/models/cifar10_lenet-2-0002.params"

Follow up to my original observation 2 -- k final models are different

After the training using LeNet completed, I tested each of the three saved models on the validation data and confirmed that they do produce the exact same output, so the final models are identical. Thus the usage of batch norm is indeed the culprit for my earlier observation, and with no batch norm, all is well.

Follow up to discussion on synchronizing batch norm

When we load the params file, there are "arg params" and "aux params". I believe the batch norm parameters are found in the "aux params", correct? Is it a true statement that any aux_params will NOT be synchronized by the parameter server, and thus will cause differences between the distributed models during the training process?

@tqchen, why not synchronize the batch norm variables during training? The paper that introduced the batch norm procedure didn't discuss what to do in the case of distributed training, but it would seem that if we shared the parameters during learning across the nodes that we'd have a better estimate at each learning stage from which to center the data. Since we are using data parallelism, and each of k nodes only sees 1/k of the data, especially when k is large, wouldn't we need to share the data statistics across the nodes in order to not have poorly estimated values at some of the nodes?
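For reference, here is the kind of check I plan to run to see which entries actually differ between two of the node checkpoints from the log above (assuming the files have been copied locally, and that the saved dict keys are prefixed with "arg:" and "aux:"):

```python
import mxnet as mx

a = mx.nd.load('cifar10_lenet-0-0002.params')
b = mx.nd.load('cifar10_lenet-1-0002.params')

for key in sorted(a):
    kind, name = key.split(':', 1)   # keys look like 'arg:conv_weight' or 'aux:bn_moving_mean'
    max_diff = float(mx.nd.max(mx.nd.abs(a[key] - b[key])).asscalar())
    print('%-3s %-25s max abs diff = %.6g' % (kind, name, max_diff))
```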

mli commented 8 years ago

the reason you see different validation accuracy during training is that each node only runs validation on a part of the validation data. assuming there are n nodes, we partition the validation data into n parts, and node i only reports the results on part i.