yahoo / spivak

Apache License 2.0

DDP training #8

Open lpc-eol opened 1 year ago

lpc-eol commented 1 year ago

Thank you for your excellent work. You used a single V100 GPU for training. Will the program support distributed training? We are trying to use multiple 4090 GPUs on the same machine to reproduce the experiments.

jvbsoares commented 1 year ago

Thank you for your interest in our work! Indeed, we have always trained with a single V100 GPU. We have not experimented with distributed training for this model. Our implementation is based on the Keras Model class, so it's possible that it could support distributed training with some small changes. We haven't actually tried to implement it, so we're not really sure.

During training, the Keras Model is created from some function calls within load_or_create_trainer() at https://github.com/yahoo/spivak/blob/master/spivak/application/model_creation.py#L200

The related model.compile() and model.fit() calls are inside the DefaultTrainer class: https://github.com/yahoo/spivak/blob/master/spivak/models/trainer.py#L49

Please let me know if you end up trying the distributed training. Thank you!
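Since the implementation is based on the Keras Model class, one plausible route (untested here, and not part of the spivak codebase) is TensorFlow's built-in `tf.distribute.MirroredStrategy` for single-machine multi-GPU training. The sketch below is a generic Keras illustration, not the repo's actual code: the idea would be to move the model-building and compile() calls that happen inside load_or_create_trainer() and DefaultTrainer under a strategy scope.

```python
# Hedged sketch: single-machine multi-GPU training with MirroredStrategy.
# This is a toy stand-in model; the actual spivak model creation would
# need to be wrapped in strategy.scope() in a similar way.
import numpy as np
import tensorflow as tf

# MirroredStrategy replicates the model on all visible GPUs (falling back
# to CPU when none are available) and averages gradients across replicas.
strategy = tf.distribute.MirroredStrategy()

# Model construction and compile() must happen inside strategy.scope().
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(8,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

# fit() itself is unchanged; the global batch is split across replicas,
# so the batch size is typically scaled by num_replicas_in_sync.
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(
    x, y,
    batch_size=8 * strategy.num_replicas_in_sync,
    epochs=1,
    verbose=0,
)
```

One caveat: any custom callbacks, metrics, or per-batch state in the trainer may need review, since each replica sees only its shard of the global batch.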