Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Parameters and gradients are not logged by WandB under FSDP strategy #17512

Open weicao1990 opened 1 year ago

weicao1990 commented 1 year ago

Bug description

I find that when using the FSDP strategy, the model parameters and gradients are not logged by WandB. However, everything works as expected if I switch from FSDP to the native DDP strategy.

Since the gradients are hooked by wandb.run.watch, I am not sure whether this is a Lightning issue or a WandB issue.
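For concreteness, the kind of setup being described is roughly the following (a sketch only; the toy module, data, and hyperparameters are illustrative and not taken from the original report):

```python
# Sketch of the setup described above: WandbLogger.watch(...) plus the FSDP
# strategy. The model, dataset, and hyperparameters are placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    model = ToyModule()
    logger = WandbLogger(project="fsdp-watch-demo")
    # Ask wandb to log both parameter and gradient histograms via its hooks.
    logger.watch(model, log="all", log_freq=10)

    dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
    trainer = pl.Trainer(
        strategy="fsdp",  # with "ddp" the parameter/gradient histograms appear
        accelerator="gpu",
        devices=2,
        max_epochs=1,
        logger=logger,
    )
    trainer.fit(model, DataLoader(dataset, batch_size=32))
```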

What version are you seeing the problem on?

master

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version: 2.1.0dev (updated to bleeding-edge)
#- WandB Version: 0.14.0
```

More info

No response

lendle commented 1 year ago

I reported this to wandb a couple of months ago. Here's a workaround I've been using, which I think should also work when not using FSDP:

In your LightningModule, implement:

```python
from torch.distributed.fsdp import FullyShardedDataParallel
from lightning.pytorch.loggers import WandbLogger
from wandb.wandb_torch import TorchHistory  # TorchHistory lives here as of wandb 0.14.x


def on_before_optimizer_step(self, optimizer):
    # Only log on the steps where Lightning would flush logs anyway.
    if self.trainer._logger_connector.should_update_logs:
        # Temporarily gather the full (unsharded) parameters and, with
        # with_grads=True, their gradients on rank 0, without writing back.
        # get_configured_model() is a helper that returns the FSDP-wrapped module.
        with FullyShardedDataParallel.summon_full_params(
            self.get_configured_model(),
            rank0_only=True,
            writeback=False,
            with_grads=True,
        ):
            if self.global_rank == 0 and isinstance(self.logger, WandbLogger):
                torch_history: TorchHistory = self.logger.experiment._torch
                for name, param in self.model.named_parameters():
                    torch_history.log_tensor_stats(param, f"parameters/{name}")
                    if param.requires_grad:
                        torch_history.log_tensor_stats(param.grad, f"gradients/{name}")
```
stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

lendle commented 1 year ago

Bump

konstantinjdobler commented 1 year ago

Logging parameters with FSDP is working for me without any monkey patching, but gradients are not being logged despite using `.watch(..., log="all")`.

I'm calling `fabric.logger.watch` after having called `fabric.setup_module(model)`.
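
For concreteness, a sketch of that ordering (assuming the PyTorch Lightning WandbLogger is passed to Fabric, as the comment above suggests; the toy model, FSDP arguments, and project name are illustrative placeholders):

```python
# Illustrative sketch of the ordering described above; the toy model,
# FSDP configuration, and project name are placeholders.
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(project="fsdp-watch-demo")
fabric = Fabric(accelerator="gpu", devices=2, strategy=FSDPStrategy(), loggers=[logger])
fabric.launch()

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

# setup_module() wraps the model with FSDP first ...
model = fabric.setup_module(model)
optimizer = fabric.setup_optimizers(torch.optim.SGD(model.parameters(), lr=0.1))

# ... then watch() is called on the already-wrapped module.
# Parameter histograms show up in wandb, but gradient histograms do not.
fabric.logger.watch(model, log="all", log_freq=10)
```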