Skip extra dataset state load. Previously, when loading a checkpoint with train_dataloader set, the dataset_state was restored as part of the checkpoint load. If train_dataloader was then set again, the setter called load_state_dict with a state_dict whose value was None. This fixes the check in the train_dataloader setter so that the extra load is skipped.
What does this PR do?
Skip extra dataset state load. Previously, when loading a checkpoint with `train_dataloader` set, the `dataset_state` was restored as part of the checkpoint load. If `train_dataloader` was then set again, the setter called `load_state_dict` with a `state_dict` whose value was `None`. This fixes the check in the `train_dataloader` setter so that the extra load is skipped.
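
The behavior described above can be illustrated with a minimal, self-contained sketch. It is not the code changed in this PR: the `State`, `Loader`, and `CountingDataset` classes, the `_pending` dict, and its `"train"` key are all hypothetical names chosen for illustration, and the exact shape of the original check is an assumption. The sketch only shows the idea of guarding on the stored value rather than on the container that holds it.

```python
class CountingDataset:
    """Hypothetical dataset that records every state it is asked to restore."""

    def __init__(self):
        self.load_calls = []

    def load_state_dict(self, state_dict):
        self.load_calls.append(state_dict)


class Loader:
    """Hypothetical stand-in for a dataloader wrapping a dataset."""

    def __init__(self, dataset):
        self.dataset = dataset


class State:
    """Hypothetical trainer state; names are assumptions, not the real API."""

    def __init__(self):
        self._train_dataloader = None
        # Assumed layout: a dict whose value is None once the state is consumed.
        self._pending = {"train": None}

    def load_dataset_state(self, dataset_state):
        # Called while loading a checkpoint.
        self._pending["train"] = dataset_state
        self._restore_train()

    def _restore_train(self):
        pending = self._pending["train"]
        # Guard on the value itself. Checking only that the dict exists
        # would pass even when the value is None and trigger an extra load.
        if self._train_dataloader is None or pending is None:
            return
        self._train_dataloader.dataset.load_state_dict(pending)
        self._pending["train"] = None  # mark the state as consumed

    @property
    def train_dataloader(self):
        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, loader):
        self._train_dataloader = loader
        self._restore_train()  # no-op when nothing is pending


dataset = CountingDataset()
state = State()
state.train_dataloader = Loader(dataset)
state.load_dataset_state({"sample_in_epoch": 5})  # restores once
state.train_dataloader = Loader(dataset)          # skipped: nothing pending
assert dataset.load_calls == [{"sample_in_epoch": 5}]
```

In this sketch, setting `train_dataloader` a second time does not reach `load_state_dict`, because the pending value was cleared after the first restore.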