Adds support for distributed data parallel training with Spark. The logic is as follows:
The user provides a Spark dataframe and a configuration specifying the number of nodes and the number of GPUs per node.
We run one task per GPU in the cluster, so the dataframe is repartitioned into that many partitions.
We save the partitioned dataframe and collect the names of the generated parquet files.
Each task computes its global rank, loads the corresponding file and trains on it.
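The rank-to-file mapping described above can be sketched as follows. This is only an illustration under stated assumptions: the rank formula (`node_rank * gpus_per_node + local_rank`), the helper names and the parquet file names are hypothetical and not taken from this PR.

```python
from typing import Dict, List, Tuple


def global_rank(node_rank: int, local_rank: int, gpus_per_node: int) -> int:
    # Assumed convention: ranks are numbered node by node, GPU by GPU.
    return node_rank * gpus_per_node + local_rank


def assign_files(
    parquet_files: List[str], num_nodes: int, gpus_per_node: int
) -> Dict[Tuple[int, int], str]:
    """Map each (node_rank, local_rank) pair to exactly one parquet file."""
    world_size = num_nodes * gpus_per_node
    files = sorted(parquet_files)  # fixed order so every task agrees on the mapping
    if len(files) != world_size:
        raise ValueError(
            f"expected {world_size} files (one per GPU), got {len(files)}"
        )
    return {
        (node, gpu): files[global_rank(node, gpu, gpus_per_node)]
        for node in range(num_nodes)
        for gpu in range(gpus_per_node)
    }


if __name__ == "__main__":
    # Hypothetical file names for a 2-node cluster with 2 GPUs per node.
    names = [f"part-{i:05d}.parquet" for i in range(4)]
    for task, path in assign_files(names, num_nodes=2, gpus_per_node=2).items():
        print(task, "->", path)
```

Sorting the file names gives every task the same view of the list, so no two ranks pick the same partition.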
There were a couple of challenges:
The final model is serialized and sent back to the driver, so it must not hold anything that can't be pickled. To avoid pickling errors, we remove the _trainer attribute (and with it the trainer property) from the model.
The models' save method relied on the trainer's save_checkpoint method. Since the trainer is no longer available, this PR implements very simple save and load methods that store only the init params and the weights (which also makes the files smaller). The premise is that loading a model for inference doesn't require everything else a checkpoint contains. To maintain backward compatibility, the saved files use the same names as PyTorch Lightning (hyper_parameters and state_dict).
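A minimal sketch of the trainer-free save/load idea, assuming a toy model class. `TinyModel`, `save_model` and `load_model` are hypothetical names, and `pickle` stands in for whatever serializer the real implementation uses; only the `hyper_parameters` and `state_dict` key names come from the description above.

```python
import pickle
from pathlib import Path


class TinyModel:
    """Hypothetical stand-in for a model with init params and weights."""

    def __init__(self, h: int, input_size: int):
        self.hparams = {"h": h, "input_size": input_size}
        self.weights = {"linear.weight": [0.0] * (h * input_size)}

    def state_dict(self) -> dict:
        return dict(self.weights)

    def load_state_dict(self, state: dict) -> None:
        self.weights = dict(state)


def save_model(model: TinyModel, path: Path) -> None:
    # Store only what is needed to rebuild the model for inference,
    # under the same key names a Lightning checkpoint uses.
    content = {
        "hyper_parameters": model.hparams,
        "state_dict": model.state_dict(),
    }
    with open(path, "wb") as f:
        pickle.dump(content, f)  # stand-in serializer (assumption)


def load_model(path: Path) -> TinyModel:
    with open(path, "rb") as f:
        content = pickle.load(f)
    # Rebuild from the init params, then restore the weights.
    model = TinyModel(**content["hyper_parameters"])
    model.load_state_dict(content["state_dict"])
    return model
```

Because the file holds just these two entries, no trainer object is needed to write or read it.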
This PR also makes the following change, which isn't strictly necessary and could be made in a separate PR:
Ensures that the original aliases are preserved when saving and loading models. Currently a loaded model uses the default alias, so if an AutoNHITS is trained with the alias 'my_model', then after loading it and making predictions the column is named 'NHITS' instead of 'my_model'.
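To illustrate the alias behavior, here is a toy sketch in which the alias travels with the saved init params. `ToyModel`, `to_saved` and `from_saved` are hypothetical; the sketch assumes the output column is named after the model's string representation, which is an assumption and not something this description states.

```python
from typing import Optional


class ToyModel:
    """Hypothetical model whose output column is named after repr(model)."""

    def __init__(self, h: int, alias: Optional[str] = None):
        self.h = h
        self.alias = alias

    def __repr__(self) -> str:
        # Fall back to the class name when no alias was given.
        return self.alias if self.alias is not None else type(self).__name__


def to_saved(model: ToyModel) -> dict:
    # Keeping the alias among the saved init params is what preserves it.
    return {"hyper_parameters": {"h": model.h, "alias": model.alias}}


def from_saved(content: dict) -> ToyModel:
    return ToyModel(**content["hyper_parameters"])


if __name__ == "__main__":
    restored = from_saved(to_saved(ToyModel(h=12, alias="my_model")))
    print(repr(restored))
```

If the alias were left out of the saved params, the restored model would fall back to the class name, which mirrors the issue described above.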