google-research / mixmatch

Apache License 2.0

How to use multiple GPUs for training? #17

Closed: vozhuo closed this 4 years ago

vozhuo commented 4 years ago

I have 4 GPUs and I noticed that libml/utils.py has some functions (para_list, para_mean, etc.) for parallel training, but they don't seem to be used. I'm not familiar with multi-GPU training. How can I use these functions? Should I modify mixmatch.py?

david-berthelot commented 4 years ago

Indeed, they are not used. I wrote these functions in case I needed multiple GPUs, but since training was fast enough on a single one, I ended up not using them. That said, the change should be easy, something like:

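```python
# Call the classifier through utils.para_cat (from libml/utils.py) instead of directly.
logits = utils.para_cat(lambda x: classifier(x, training=True), images)
```
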
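
Judging by their names, `para_cat` and `para_mean` appear to follow the usual data-parallel pattern: split the batch into one shard per GPU, run the function on each shard, then either concatenate per-example outputs (such as logits) or average scalar outputs (such as losses). The following is a plain-Python illustration of that pattern under this assumption, not the actual code in libml/utils.py. `split_batch`, `para_cat_sketch`, `para_mean_sketch` and `num_shards` are hypothetical names, and each shard stands in for one GPU.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")

# Hypothetical sketch of the split/run/combine pattern; NOT the libml/utils.py code.


def split_batch(batch: Sequence[T], num_shards: int) -> List[List[T]]:
    """Split a batch into `num_shards` equal, contiguous shards (one per device)."""
    if num_shards <= 0 or len(batch) % num_shards != 0:
        raise ValueError("batch size must be divisible by a positive number of shards")
    size = len(batch) // num_shards
    return [list(batch[i * size:(i + 1) * size]) for i in range(num_shards)]


def para_cat_sketch(fn: Callable[[List[T]], List[R]],
                    batch: Sequence[T], num_shards: int) -> List[R]:
    """Run `fn` on each shard and concatenate the per-example outputs in order."""
    outputs: List[R] = []
    for shard in split_batch(batch, num_shards):
        outputs.extend(fn(shard))  # each call would run on its own GPU in practice
    return outputs


def para_mean_sketch(fn: Callable[[List[T]], float],
                     batch: Sequence[T], num_shards: int) -> float:
    """Run `fn` on each shard and average the scalar results (e.g. per-shard losses)."""
    results = [fn(shard) for shard in split_batch(batch, num_shards)]
    return sum(results) / len(results)


def toy_classifier(xs: List[int]) -> List[int]:
    """Hypothetical stand-in for a classifier: one 'logit' per example."""
    return [2 * x for x in xs]


def toy_loss(xs: List[int]) -> float:
    """Hypothetical stand-in for a loss: the mean over the shard."""
    return sum(xs) / len(xs)


if __name__ == "__main__":
    images = list(range(8))
    print(para_cat_sketch(toy_classifier, images, num_shards=4))  # [0, 2, 4, 6, 8, 10, 12, 14]
    print(para_mean_sketch(toy_loss, images, num_shards=4))  # 3.5
```

Requiring the batch size to be divisible by the number of shards keeps the shards equal-sized, so the mean of the per-shard means equals the mean over the whole batch.
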
vozhuo commented 4 years ago

Thank you!