dvlab-research / LLaMA-VID

Official Implementation for LLaMA-VID: An Image is Worth 2 Tokens in Large Language Models
Apache License 2.0

About ZERO3 #75

Closed: xxtars closed this issue 3 months ago

xxtars commented 3 months ago

Hello, thank you very much for open-sourcing this work.

I ran into some problems during training. Each of my GPUs has only 40GB of memory. In stage 2, even with zero2_offload and --per_device_train_batch_size set to 1, I still hit OOM, so I plan to try zero3 or zero3_offload.
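
For readers unfamiliar with the options, a ZeRO-3 setup with CPU offload can be sketched as a Python dict written out to JSON. This is a minimal sketch, not the zero3.json shipped with LLaVA or LLaMA-VID; it assumes the Hugging Face Trainer integration resolves the `"auto"` values from the command-line arguments, and `dump_config` is a hypothetical helper name.

```python
import json
from typing import Optional

# Sketch of a ZeRO-3 config with CPU offload (not the zero3.json from LLaVA/LLaMA-VID).
# "auto" values are filled in by the Hugging Face Trainer's DeepSpeed integration.
ZERO3_OFFLOAD = {
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "train_batch_size": "auto",
    "zero_optimization": {
        "stage": 3,
        # keep optimizer states and partitioned parameters in pinned CPU memory
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
        # consolidate full 16-bit weights when saving instead of keeping only per-rank shards
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}


def dump_config(path: Optional[str] = None) -> str:
    """Serialize the config; optionally write it to `path` (hypothetical helper)."""
    text = json.dumps(ZERO3_OFFLOAD, indent=4)
    if path is not None:
        with open(path, "w") as f:
            f.write(text)
    return text
```

With the Hugging Face Trainer the written file would then be passed as `--deepspeed <path>`. Offloading parameters trades GPU memory for CPU-GPU transfers, so each step is expected to be slower than with plain zero3.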

However, I ran into new issues while loading the model. I used the zero3.json provided by LLaVA, and the problems appear when loading the Q-Former. First, I load "bert-base-uncased" through transformers:

mm_model = BertLMHeadModelQF.from_pretrained(
    "bert-base-uncased", config=encoder_config
)

I'm not very familiar with DeepSpeed, and I'm not sure whether zero3 handles this part of the loading process. Later, when the pretrained Q-Former weights are loaded, an error occurs:

self.vlm_att_projector.load_state_dict(get_w(att_projector_weights, 'vlm_att_projector'))

Error: bert.encoder.layer.0.attention.self.query.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([0]).

Do you know how to handle this? Any help would be greatly appreciated.
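
The `torch.Size([0])` in that message is the ZeRO-3 placeholder: each rank stores only a shard of every parameter, and the tensor visible on the module stays empty until the parameter is gathered. Hugging Face `from_pretrained` typically gathers parameters itself when it detects ZeRO-3, but a direct `load_state_dict` call does not. The stdlib-only toy below is purely illustrative and is not DeepSpeed's code; `ToyParam`, `toy_gathered` and `toy_copy` are hypothetical names. It shows why the plain copy fails and why copying inside a gather context, which is what `deepspeed.zero.GatheredParameters` provides, passes the shape check.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Tuple


@dataclass
class ToyParam:
    """Stand-in for a ZeRO-3 partitioned parameter (illustration only)."""
    full_shape: Tuple[int, ...]  # shape the model code expects, e.g. (768, 768)
    gathered: bool = False       # True only while the full tensor is materialized

    @property
    def shape(self) -> Tuple[int, ...]:
        # Outside a gather, the visible tensor is an empty placeholder, like torch.Size([0]).
        return self.full_shape if self.gathered else (0,)


@contextmanager
def toy_gathered(param: ToyParam) -> Iterator[ToyParam]:
    """Toy analogue of GatheredParameters: materialize on entry, re-partition on exit."""
    param.gathered = True
    try:
        yield param
    finally:
        param.gathered = False


def toy_copy(param: ToyParam, ckpt_shape: Tuple[int, ...]) -> None:
    """Mimic the shape check that load_state_dict performs before copying."""
    if param.shape != ckpt_shape:
        raise RuntimeError(
            f"copying a param with shape {ckpt_shape} from checkpoint, "
            f"the shape in current model is {param.shape}"
        )
```

Calling `toy_copy(ToyParam((768, 768)), (768, 768))` raises an error ending in "the shape in current model is (0,)", whereas the same call made inside `with toy_gathered(param):` succeeds, and the parameter goes back to its placeholder shape afterwards.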

jimchenhub commented 2 months ago

I encountered the same problem. Have you solved it?

xxtars commented 2 months ago

@jimchenhub I asked about this in the DeepSpeed repo. The suggested fix seems to work so far (tested for a few iterations), but I haven't verified it over a full training stage. You might give it a try.

BlueBlueFF commented 2 months ago

Did you run into this error? RuntimeError: still have inflight params

xxtars commented 2 months ago

Hello, I did not encounter this error. I'm sharing my loading code here, but I cannot guarantee that the final training result is correct. I hope it helps.

from collections import OrderedDict
from typing import List

import deepspeed
import torch.nn as nn


def maybe_zero3_load_state_dict(module: nn.Module, state_dict):
    def check_zero3_optimization(model):
        # ZeRO-3 tags every partitioned parameter with a `ds_id` attribute
        return any(hasattr(param, 'ds_id') for param in model.parameters())

    if not check_zero3_optimization(module):
        # no ZeRO-3 partitioning: the regular path works
        module.load_state_dict(state_dict)
        return

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module: nn.Module, local_state_dict, prefix=""):
        # because zero3 puts placeholders in model params, this context
        # manager gathers (unpartitions) the params of the current layer only,
        # rank 0 loads them from the state dict, and on exit the modified
        # params are broadcast from rank 0 and re-partitioned again
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                module._load_from_state_dict(local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)

        # recurse into children, passing each one only the keys under its prefix
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + "."
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)

    load(module, state_dict)
    # note: missing/unexpected keys and size mismatches are collected on rank 0 only and are not raised
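
Because `load` gathers only the current module's own parameters (`recurse=False`) and leaves the `with` block before recursing, at most one module's parameters are materialized in full at a time, and each child receives only the keys under its prefix. That routing can be traced without DeepSpeed. The sketch below uses a hypothetical nested-tuple stand-in for the module tree; `Tree` and `route_keys` are illustrative names, not part of any library.

```python
from typing import Dict, List, Tuple

# Hypothetical module tree: (names of the module's own parameters, {child name: subtree})
Tree = Tuple[List[str], Dict[str, "Tree"]]


def route_keys(tree: Tree, state_dict: Dict[str, object], prefix: str = "") -> Dict[str, List[str]]:
    """Return, per module prefix, the checkpoint keys that module loads itself.

    Mirrors the recursion in `load`: a module handles only its direct parameters
    and hands each child the subset of keys starting with `prefix + name + "."`.
    """
    own_params, children = tree
    routed = {prefix: [prefix + p for p in own_params if prefix + p in state_dict]}
    for name, child in children.items():
        child_prefix = prefix + name + "."
        child_sd = {k: v for k, v in state_dict.items() if k.startswith(child_prefix)}
        routed.update(route_keys(child, child_sd, child_prefix))
    return routed
```

For a tree `(["query_tokens"], {"bert": ([], {"pooler": (["weight", "bias"], {})})})` and a state dict with the keys `query_tokens`, `bert.pooler.weight` and `bert.pooler.bias`, the root loads `query_tokens`, `bert.` loads nothing itself, and `bert.pooler.` loads the two pooler keys.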
zhuqiangLu commented 2 months ago

Hi, could you share more details on how you enabled zero3 offload? I tried to fine-tune this model with zero3 offload, but it gets stuck at the beginning of training.

xxtars commented 2 months ago

@zhuqiangLu Sorry, I also hit a hang in stage 2 with zero3, although stage 1 trains fine with zero3. So I'm not sure whether the problem lies in the parameter loading or in DeepSpeed. If you have any updates, please let me know. Thanks.

zhuqiangLu commented 1 month ago

It turns out the hang happens because zero3 does not support an imbalanced load: if the batch size varies across GPUs, zero3 gets stuck. So I had to use zero2 offload for stage 2 and freeze the Q-Former. Now I'm running into another OOM at step 990 (total batch size 16). Still investigating.
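
One plausible way to picture such a hang: ZeRO-3 all-gathers parameters before a module runs, and an all-gather is a collective that every rank must join. If one rank issues a collective that the others never issue (for example because its batch takes a different path through the model), it waits forever. The stdlib toy below is an illustration under that assumption, not DeepSpeed code: `threading.Barrier` stands in for the collective, a timeout turns the hang into a reportable outcome, and `simulate_collectives` is a hypothetical name.

```python
import threading
from typing import Dict, List


def simulate_collectives(steps_per_rank: List[int], timeout: float = 0.5) -> Dict[int, str]:
    """Toy model: every rank must meet at a collective once per step.

    If one rank runs more steps than the others, its extra collective never gets
    partners; the barrier timeout raises BrokenBarrierError instead of hanging.
    """
    barrier = threading.Barrier(len(steps_per_rank), timeout=timeout)
    outcome: Dict[int, str] = {}

    def worker(rank: int, steps: int) -> None:
        completed = 0
        try:
            for _ in range(steps):
                barrier.wait()  # all ranks must reach this call for the step to proceed
                completed += 1
            outcome[rank] = "ok"
        except threading.BrokenBarrierError:
            outcome[rank] = f"stuck after {completed} steps"

    threads = [threading.Thread(target=worker, args=(r, s)) for r, s in enumerate(steps_per_rank)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outcome
```

With `simulate_collectives([3, 3])` both ranks finish, while `simulate_collectives([2, 3])` leaves rank 1 "stuck after 2 steps". In this picture the remedy is to keep the sequence of collectives identical on every rank, which is consistent with the zero2 offload plus frozen Q-Former workaround above, since ZeRO-2 does not gather parameters during the forward pass.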