OpenGVLab / Ask-Anything

[CVPR2024 Highlight][VideoChatGPT] ChatGPT with video understanding! And many more supported LMs such as miniGPT4, StableLM, and MOSS.
MIT License
2.85k stars, 230 forks

Using Gradient Accumulation #174

Open bexxnaz opened 1 month ago

bexxnaz commented 1 month ago

Hello, thanks for your great work. I need to use gradient accumulation over batches due to RAM constraints. The training loop iterates over two modalities, and I am concerned about the implications of accumulating gradients in this scenario. Is it possible, and is it recommended, to use gradient accumulation when multiple modalities come from the same iterator?

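    import torch

    # Assumes scaler = torch.cuda.amp.GradScaler(enabled=config.fp16) was created before
    # the loop and optimizer.zero_grad() was called before the first micro-batch.
    with torch.cuda.amp.autocast(enabled=config.fp16):
        loss_dict = model(image, text)
        # Divide so the summed gradients equal the average over the micro-batches
        loss = sum(loss_dict.values()) / config.accumulate_grad_batches
    scaler.scale(loss).backward()  # gradients add up in .grad until zero_grad()
    accumulated_batches += 1
    if accumulated_batches % config.accumulate_grad_batches == 0:
        if config.optimizer.max_grad_norm > 0:
            scaler.unscale_(optimizer)  # clip the unscaled gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
        scaler.step(optimizer)  # skips the update if inf/NaN gradients were found
        scaler.update()
        optimizer.zero_grad()  # Reset gradients only after optimizer step
        scheduler.step()  # Step the scheduler as per your original strategy
        accumulated_batches = 0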
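
As a sanity check on the loss scaling, here is a minimal CPU-only sketch (plain PyTorch without AMP; the toy linear model, data sizes, and accumulation factor are hypothetical, not taken from Ask-Anything). It checks that dividing each micro-batch loss by the number of accumulation steps and calling backward() on each one reproduces the gradient of a single full-batch pass.

```python
import copy

import torch

torch.manual_seed(0)
x = torch.randn(16, 8)  # hypothetical toy data: 16 samples, 8 features
y = torch.randn(16, 1)
loss_fn = torch.nn.MSELoss()  # mean reduction, like most training losses

model_full = torch.nn.Linear(8, 1)  # hypothetical toy model
model_accum = copy.deepcopy(model_full)  # identical initial weights

# Reference: a single backward pass over the full batch of 16 samples.
loss_fn(model_full(x), y).backward()

# Gradient accumulation: 4 micro-batches of 4 samples, each loss divided by 4.
accumulate_grad_batches = 4
model_accum.zero_grad()
for xb, yb in zip(x.chunk(accumulate_grad_batches), y.chunk(accumulate_grad_batches)):
    loss = loss_fn(model_accum(xb), yb) / accumulate_grad_batches
    loss.backward()  # .grad is summed across backward() calls

grads_match = all(
    torch.allclose(p_full.grad, p_acc.grad, atol=1e-6)
    for p_full, p_acc in zip(model_full.parameters(), model_accum.parameters())
)
print(grads_match)  # True
```

The equivalence assumes mean-reduced losses and equally sized micro-batches. If the two modalities use different batch sizes, this scaling weights each micro-batch equally rather than each sample, so the accumulated gradient is a per-micro-batch average, not an exact full-batch one. With AMP, scaler.scale(loss).backward() replaces loss.backward(), and scaler.unscale_(optimizer) should run before gradient clipping.
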
Andy1621 commented 1 month ago

Good question! Could you try using a smaller batch size instead? In my experience, the results are similar to those with a larger batch.

Besides, if you want to use gradient accumulation, I think mixing multiple modalities in one iterator is a good baseline, although some papers argue that using a single modality per iterator works better. If you want to implement that, a simple strategy is to split the input data by modality manually.
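
One possible reading of "a single modality per iterator" under gradient accumulation is that every optimizer step averages micro-batches from only one modality. The sketch below groups batches that way, round-robin over modalities. It assumes one data loader per modality; single_modality_windows and the toy string "batches" are hypothetical and not part of the repository.

```python
from itertools import islice


def single_modality_windows(loaders, accum_steps):
    """Yield (modality, [micro_batch, ...]) with accum_steps micro-batches
    drawn from one loader at a time, cycling over the modalities.

    `loaders` maps a modality name to any iterable of batches (hypothetical).
    """
    iters = {name: iter(loader) for name, loader in loaders.items()}
    while iters:
        for name in list(iters):  # iterate over a copy so we can delete entries
            window = list(islice(iters[name], accum_steps))
            if len(window) < accum_steps:
                # Drop an incomplete tail so every optimizer step averages
                # the same number of micro-batches, then retire this loader.
                del iters[name]
                continue
            yield name, window


# Toy usage with hypothetical string "batches".
loaders = {"image": [f"img{i}" for i in range(5)], "video": [f"vid{i}" for i in range(4)]}
for modality, window in single_modality_windows(loaders, accum_steps=2):
    print(modality, window)
# image ['img0', 'img1']
# video ['vid0', 'vid1']
# image ['img2', 'img3']
# video ['vid2', 'vid3']
```

In a training loop, each window would feed the accumulation code from the question: one scaled backward() per micro-batch, then a single optimizer step per window. Dropping incomplete tails is a design choice that keeps the loss scaling exact; keeping them would require dividing by the actual window length instead.
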

bexxnaz commented 1 month ago

Thank you for your response. I have another question about extra_num_query_tokens. Have you tested setting this parameter to 0? How does the resulting compression of visual tokens affect performance?

Andy1621 commented 1 month ago

There is an ablation study on this in our paper: using 0 extra query tokens leads to poorer performance on MVBench.
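
To picture what extra_num_query_tokens controls, assume a Q-Former-style design: learnable query tokens cross-attend to the visual features, and the extra queries are concatenated to the base ones. Under that assumption, the number of visual tokens handed to the LLM equals the total number of queries, however many frames or patches go in. This is a single cross-attention layer, not the actual Q-Former. QueryCompressor, its dimensions, and the default of 32 base queries are hypothetical.

```python
import torch
from torch import nn


class QueryCompressor(nn.Module):
    """Hypothetical, simplified stand-in for a Q-Former: learnable queries
    cross-attend to visual tokens, so the output length depends only on the
    number of queries, not on the number of input visual tokens."""

    def __init__(self, dim=64, num_query_tokens=32, extra_num_query_tokens=0, num_heads=4):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.randn(1, num_query_tokens, dim) * 0.02)
        # Assumed design: extra queries are appended to the base queries.
        self.extra_query_tokens = nn.Parameter(torch.zeros(1, extra_num_query_tokens, dim))
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, visual_tokens):
        batch = visual_tokens.shape[0]
        queries = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
        queries = queries.expand(batch, -1, -1)
        # Queries attend over all visual tokens (keys and values).
        out, _ = self.cross_attn(queries, visual_tokens, visual_tokens)
        return out  # (batch, num_query_tokens + extra_num_query_tokens, dim)


visual_tokens = torch.randn(2, 8 * 256, 64)  # hypothetical: 8 frames x 256 patches, dim 64
for extra in (0, 64):
    out = QueryCompressor(extra_num_query_tokens=extra)(visual_tokens)
    print(extra, tuple(out.shape))
# 0 (2, 32, 64)
# 64 (2, 96, 64)
```

With extra_num_query_tokens=0 the LLM receives fewer visual tokens per video. That is the stronger compression which the ablation mentioned above associates with poorer MVBench performance.
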
