hpcaitech / EnergonAI

Large-scale model inference.
Apache License 2.0

question about load model state_dict in multi-gpus #213

Closed. irasin closed this issue 1 year ago.

irasin commented 1 year ago

The code is as follows.

# Imports needed to run this snippet (colossalai paths per the version EnergonAI targets);
# the helpers load_state_dict, is_using_pp, partition_pipeline_parallel_state_dict,
# remove_prefix and broadcast_model are defined elsewhere in the library.
from time import time
from typing import Callable, Optional

import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


def load_checkpoint(file,
                    model: torch.nn.Module,
                    strict: bool = True,
                    preprocess_fn: Optional[Callable[[dict], dict]] = None,
                    **kwargs):
    """Loads model weights from a checkpoint file.

    Args:
        file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
            object containing a file name.
        model (:class:`torch.nn.Module`): Model to load saved weights and buffers into.
        strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
            of the checkpoint match the names of parameters and buffers in model, defaults to True.
        preprocess_fn (Callable, optional): Function applied to the loaded state dict before it is partitioned.

    Raises:
        RuntimeError: Raised if the model cannot successfully be recuperated.
    """
    start = time()
    # Only the first rank of the model parallel group reads the full checkpoint from disk.
    if gpc.get_local_rank(ParallelMode.MODEL) == 0:
        model_state = load_state_dict(file)
        if preprocess_fn:
            model_state = preprocess_fn(model_state)
    else:
        # Every other rank starts from an empty state dict and receives its weights later.
        model_state = dict()
    dist.barrier()
    print(f'Load file time: {time()-start:.3f} s')
    # Pipeline parallelism: keep only the keys owned by this pipeline stage.
    if is_using_pp():
        model_state = partition_pipeline_parallel_state_dict(model, model_state, **kwargs)
    # Strip an optional key prefix before loading.
    if kwargs.get('prefix'):
        model_state = remove_prefix(model_state, kwargs['prefix'])

    model.load_state_dict(model_state, strict=strict)
    # Propagate the weights loaded at rank 0 to the other model parallel ranks.
    broadcast_model(model)
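
(For context, a call under tensor parallelism might look like the following; the checkpoint path and prefix are illustrative, not taken from the issue, and strict=False is presumably needed because non-zero ranks start from an empty state dict.)

# Every rank in the model parallel group calls this, but only rank 0 reads the file.
load_checkpoint('checkpoints/model.pt', model, strict=False, prefix='model.')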

When using tp=4 tensor parallelism, I wonder why the state dict is loaded only when gpc.get_local_rank(ParallelMode.MODEL) == 0. If so, the rest of the processes will load an empty model_state, right?

kurisusnowdeng commented 1 year ago

When using tensor parallelism, the model's parameters are sharded across ranks, while the weights in the checkpoint's state dict are not. Thus, we load the complete weights only at rank 0 and scatter the corresponding shards to each tensor parallel rank.
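
A minimal sketch of that pattern with plain torch.distributed (illustrative only: the helper name is hypothetical, and EnergonAI's actual broadcast_model may work differently). It assumes a parameter sharded along dim 0 across the tensor parallel group, with the process group already initialized:

import torch
import torch.distributed as dist

def scatter_tp_shard(param: torch.nn.Parameter, full_weight=None, src: int = 0):
    """Hypothetical helper: rank `src` holds the full weight loaded from disk;
    afterwards every rank's `param` holds only its own shard (split on dim 0)."""
    tp_rank = dist.get_rank()
    tp_size = dist.get_world_size()
    if tp_rank == src:
        # Split the full tensor into one contiguous shard per rank.
        shards = [s.contiguous() for s in full_weight.chunk(tp_size, dim=0)]
    else:
        shards = None  # non-source ranks only receive their shard
    # Requires a backend that supports scatter (e.g. gloo, or a recent NCCL).
    dist.scatter(param.data, scatter_list=shards, src=src)

Ranks other than src never need the full tensor in memory, which is why the other processes can safely start from an empty model_state.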

irasin commented 1 year ago

Got it~ Thanks a lot