Closed: irasin closed this issue 1 year ago
When using tp=4 parallelism, I wonder why load_state_dict is called here only when 'get_local_rank(ParallelMode.MODEL) == 0'. If so, the remaining processes will load an empty model_state, right?
The code is as below.

When using tensor-parallel models, the parameters are sharded, while the weights in the state dict are not. Thus, we load the complete weights only at rank 0 and scatter the corresponding shards to each tensor-parallel rank.

Got it~ Thanks a lot
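For anyone landing on this thread later, here is a minimal sketch of the pattern described in the answer (rank 0 loads the full state dict, every tensor-parallel rank receives only its shard). It assumes plain torch.distributed with a single process group of size tp and that dist.init_process_group has already been called; the helper names scatter_sharded_param and load_tp_checkpoint are hypothetical, not ColossalAI's actual API, and real TP code also handles replicated (unsharded) parameters, which would be broadcast rather than scattered.

```python
# Hypothetical sketch, not ColossalAI's implementation.
from typing import Optional

import torch
import torch.distributed as dist


def scatter_sharded_param(param: torch.Tensor, full_weight: Optional[torch.Tensor]) -> None:
    """param is this rank's shard, already allocated with the sharded shape;
    full_weight is the complete tensor from the state dict on rank 0, None elsewhere."""
    world_size = dist.get_world_size()
    if dist.get_rank() == 0:
        # Split along dim 0 for illustration (assumes dim 0 divisible by world_size);
        # real TP code picks the split dim per layer, e.g. dim 0 for column-parallel
        # linears and dim 1 for row-parallel ones.
        shards = [s.contiguous() for s in full_weight.chunk(world_size, dim=0)]
    else:
        shards = None
    buf = torch.empty_like(param)
    dist.scatter(buf, scatter_list=shards, src=0)  # rank 0 sends shards[i] to rank i
    with torch.no_grad():
        param.copy_(buf)


def load_tp_checkpoint(model: torch.nn.Module, ckpt_path: str) -> None:
    # Only rank 0 actually reads the checkpoint file, mirroring the answer above;
    # the other ranks pass None and receive their shard via scatter.
    state_dict = torch.load(ckpt_path, map_location="cpu") if dist.get_rank() == 0 else None
    for name, p in model.named_parameters():
        full = state_dict[name] if state_dict is not None else None
        scatter_sharded_param(p, full)
```

So the non-zero ranks never read the file at all: their model_state is intentionally empty, and their parameters are filled by the scatter from rank 0, which avoids having every process hold the full unsharded weights in memory at once.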