BAAI-DCAI / M3D

M3D: Advancing 3D Medical Image Analysis with Multi-Modal Large Language Models
MIT License

Problem with the seg module when fine-tuning with LoRA #20

Open Skylight-Lark opened 1 week ago

Skylight-Lark commented 1 week ago

Hi @baifanxxx: I'm running into an issue when fine-tuning with LoRA. The forward pass of the SegVol class hangs when the image is passed to image_encoder, which ends in NCCL communication timeouts. Below is the relevant part of the code:

import numpy as np
import torch.nn as nn


class SegVol(nn.Module):
    def __init__(self,
                 image_encoder,
                 mask_decoder,
                 prompt_encoder,
                 roi_size,
                 patch_size):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        # Number of patches along (D, H, W) produced by the ViT image encoder.
        self.feat_shape = np.array(roi_size) / np.array(patch_size)

    def forward(self, image, text_emb=None, text=None, boxes=None, points=None):
        bs = image.shape[0]
        img_shape = (image.shape[2], image.shape[3], image.shape[4])
        print("image shape is ", image.shape)  # debug: printed before the hang
        image_embedding, _ = self.image_encoder(image)

        # Reshape tokens (bs, N, C) into a feature volume (bs, C, d, h, w).
        image_embedding = image_embedding.transpose(1, 2).view(
            bs, -1,
            int(self.feat_shape[0]),
            int(self.feat_shape[1]),
            int(self.feat_shape[2])
        )
        print("image_embedding is ", image_embedding)  # debug: not printed when the hang occurs

        # forward_decoder is defined elsewhere in SegVol (omitted from this excerpt).
        logits = self.forward_decoder(image_embedding, img_shape,
                                      text_emb=text_emb,
                                      text=text,
                                      boxes=boxes,
                                      points=points)

        return logits
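The `view` in `forward` assumes the encoder returns exactly one token per patch, i.e. prod(roi_size / patch_size) tokens. A minimal NumPy sketch of the same reshape follows; the sizes `(32, 256, 256)`, `(4, 16, 16)` and the hidden size 768 are assumptions for illustration, and `tokens_to_volume` is a hypothetical helper. It only checks shapes, so it cannot reproduce a distributed hang, but it helps rule out a shape problem and prints a shape rather than a whole tensor.

```python
import numpy as np

# Assumed values for illustration only; check them against your own SegVol config.
ROI_SIZE = (32, 256, 256)   # (D, H, W) of the input volume
PATCH_SIZE = (4, 16, 16)    # 3D patch size of the ViT image encoder
EMBED_DIM = 768             # hypothetical encoder hidden size


def tokens_to_volume(tokens, roi_size=ROI_SIZE, patch_size=PATCH_SIZE):
    """NumPy analogue of transpose(1, 2).view(...): (bs, N, C) -> (bs, C, d, h, w)."""
    feat_shape = [r // p for r, p in zip(roi_size, patch_size)]
    bs, n_tokens, channels = tokens.shape
    expected = int(np.prod(feat_shape))
    if n_tokens != expected:
        # Fail with a clear message instead of an opaque reshape error.
        raise ValueError(f"encoder returned {n_tokens} tokens, expected {expected} for {feat_shape}")
    # Swap the token and channel axes, then split the token axis into the 3D patch grid.
    return tokens.transpose(0, 2, 1).reshape(bs, channels, *feat_shape)


tokens = np.zeros((2, 2048, EMBED_DIM), dtype=np.float32)
volume = tokens_to_volume(tokens)
print(volume.shape)  # (2, 768, 8, 16, 16)
```

NumPy's `transpose(0, 2, 1)` plays the role of torch's `transpose(1, 2)`, and `reshape` is used because NumPy has no `view` with that meaning.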

Questions:

  1. What might be causing the process to hang when passing the image tensor to the image_encoder?
  2. How can I resolve the NCCL communication timeout issue in this context?
  3. Is there a recommended strategy to debug this kind of issue (e.g., reducing tensor size, tuning NCCL parameters, or checking GPU utilization)?
  4. Do you use ZeRO-3 or ZeRO-2 when fine-tuning?
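For question 3, a common first step is to turn on NCCL and PyTorch distributed logging and shorten the collective timeout, so a hang fails quickly with a per-rank log. The sketch below only sets standard environment variables (`NCCL_DEBUG`, `NCCL_DEBUG_SUBSYS`, `TORCH_DISTRIBUTED_DEBUG`); the helper name and the 5-minute timeout are assumptions. `DETAIL` mode adds overhead, so it is meant for debugging runs. The variables have to be set before the process group is created (with Accelerate, before constructing `Accelerator`), or exported in the launching shell.

```python
import os
from datetime import timedelta
from typing import Dict, MutableMapping

# Standard NCCL / PyTorch debug switches, read when the process group and
# NCCL communicators are created.
DEBUG_ENV = {
    "NCCL_DEBUG": "INFO",                 # per-rank NCCL log, including errors
    "NCCL_DEBUG_SUBSYS": "INIT,COLL",     # also log collective calls
    "TORCH_DISTRIBUTED_DEBUG": "DETAIL",  # consistency checks for mismatched collectives (slow)
}

# Short timeout so a hang surfaces quickly; the value is an arbitrary choice.
DEBUG_TIMEOUT = timedelta(minutes=5)


def enable_distributed_debug(env: MutableMapping[str, str] = os.environ) -> Dict[str, str]:
    """Set the debug variables unless already set; return the effective values."""
    for key, value in DEBUG_ENV.items():
        env.setdefault(key, value)
    return {key: env[key] for key in DEBUG_ENV}


print(enable_distributed_debug())
# Then, in plain PyTorch (requires torch):
#   torch.distributed.init_process_group("nccl", timeout=DEBUG_TIMEOUT)
# or, with Accelerate:
#   Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=DEBUG_TIMEOUT)])
```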

Any insights or suggestions would be appreciated!

baifanxxx commented 3 days ago

Hi,

Sorry for the late reply; I've been a bit busy lately. I've also run into NCCL communication problems, especially when training larger models such as Llama-8B. My guess is that a batch can contain both text and segmentation samples. The segmentation samples need the additional segmentation module, so processing time differs across samples in the same batch and the computation becomes unbalanced. This is only a suspicion. One possible solution is to decouple the text tasks from the segmentation tasks as much as possible; for larger models, the text and segmentation tasks can be trained separately.
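One possible way to decouple the two task types inside a single run (an assumption, not something the M3D code is known to do) is to group samples so that every global step contains only one task type: then either every rank runs the segmentation module in that step or none does. A minimal pure-Python sketch; the `task` field and the `"text"`/`"seg"` labels are hypothetical.

```python
import random
from typing import Dict, List


def homogeneous_global_batches(samples: List[Dict[str, str]], per_rank_batch: int,
                               world_size: int, seed: int = 0) -> List[List[List[int]]]:
    """Return steps; each step holds one list of sample indices per rank, all of one task."""
    rng = random.Random(seed)
    by_task: Dict[str, List[int]] = {}
    for idx, sample in enumerate(samples):
        by_task.setdefault(sample["task"], []).append(idx)

    global_batch = per_rank_batch * world_size
    steps: List[List[List[int]]] = []
    for indices in by_task.values():
        rng.shuffle(indices)
        # Drop the ragged tail so every step fills all ranks with the same task.
        for start in range(0, len(indices) - global_batch + 1, global_batch):
            chunk = indices[start:start + global_batch]
            steps.append([chunk[r * per_rank_batch:(r + 1) * per_rank_batch]
                          for r in range(world_size)])
    rng.shuffle(steps)  # interleave text steps and seg steps across the epoch
    return steps


samples = [{"task": "text"}] * 8 + [{"task": "seg"}] * 8
for step in homogeneous_global_batches(samples, per_rank_batch=2, world_size=2):
    print(step, {samples[i]["task"] for rank in step for i in rank})
```

How the per-rank lists reach each process depends on the training framework; the leftover samples that do not fill a full global batch are dropped here for simplicity.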

Another approach is multi-stage fine-tuning instead of a single stage: for example, first instruction-tune the model on the text datasets, then instruction-tune it on the segmentation datasets alone, so that the two data types are never mixed.
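The two-stage schedule can be sketched as follows: stage 2 resumes from the stage-1 checkpoint, and because it sees only segmentation samples, every rank runs the segmentation module on every step. `run_stage` is a hypothetical placeholder for a real fine-tuning run, and the `task` field is assumed.

```python
from typing import Dict, List, Optional, Tuple

Sample = Dict[str, str]


def split_into_stages(samples: List[Sample]) -> Tuple[List[Sample], List[Sample]]:
    """Stage 1: text instruction data only. Stage 2: segmentation data only."""
    text_stage = [s for s in samples if s["task"] == "text"]
    seg_stage = [s for s in samples if s["task"] == "seg"]
    return text_stage, seg_stage


def run_stage(name: str, data: List[Sample], resume_from: Optional[str] = None) -> str:
    # Hypothetical placeholder for one complete fine-tuning run; returns a checkpoint name.
    print(f"stage {name}: {len(data)} samples, resuming from {resume_from}")
    return f"ckpt-{name}"


samples = [{"task": "text", "id": "a"}, {"task": "seg", "id": "b"}, {"task": "text", "id": "c"}]
text_stage, seg_stage = split_into_stages(samples)
ckpt = run_stage("text", text_stage)                  # stage 1 on text data
ckpt = run_stage("seg", seg_stage, resume_from=ckpt)  # stage 2 starts from stage 1
```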

If you manage to solve this problem with a better method, please share it with us. Thank you.

Skylight-Lark commented 2 days ago

Hi,

Yes, I think the reason is, as you said, that the segmentation module is not computed for every batch, so when using DDP the seg module may receive no gradients on some ranks. With parameter-sharding techniques such as DeepSpeed ZeRO-3, which shards all of the model's parameters, this leads to NCCL communication problems. My solution is to use FSDP through Accelerate and wrap only the LlamaDecoderLayer modules of the LLM backbone with the FSDP sharding strategy. Below is my Accelerate configuration:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: "no"
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
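A toy model of why this configuration can avoid the hang (a simplification, not how DeepSpeed or FSDP are implemented internally): under ZeRO-3-style sharding, a module's parameters are gathered only when that module runs, so a rank whose batch has no segmentation sample skips an all-gather that the other ranks are waiting in. With `fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer`, modules that are not decoder layers, including the seg module, stay in the root FSDP unit, which every rank gathers on every forward pass. The unit names below are hypothetical.

```python
from typing import List


# Toy model: each string is one all-gather a rank issues during a forward pass.
def forward_collectives(batch: List[str], seg_in_root_unit: bool) -> List[str]:
    """Collectives issued by one rank for one micro-batch of task labels."""
    ops = ["all_gather:root"]                # root unit: gathered at the start of forward
    ops.append("all_gather:decoder_layers")  # LLM decoder layers run for every sample
    if "seg" in batch and not seg_in_root_unit:
        # Seg module sharded on its own: only ranks with a seg sample gather it.
        ops.append("all_gather:seg_module")
    return ops


def step_would_hang(rank_batches: List[List[str]], seg_in_root_unit: bool) -> bool:
    """Ranks that issue different collectives can block each other until the NCCL timeout."""
    schedules = [forward_collectives(b, seg_in_root_unit) for b in rank_batches]
    return any(s != schedules[0] for s in schedules)


mixed_step = [["text", "seg"], ["text", "text"]]  # rank 0 has a seg sample, rank 1 does not
print(step_would_hang(mixed_step, seg_in_root_unit=False))  # True: ZeRO-3-like per-module gathering
print(step_would_hang(mixed_step, seg_in_root_unit=True))   # False: seg params gathered with the root unit
```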