Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
Apache License 2.0

wandb logger 'global_step' affects other logger #1485

Closed: olineumann closed this issue 4 years ago

olineumann commented 4 years ago

🐛 Bug

The wandb logger adds a 'global_step' entry to the metric dict, which then appears in all other loggers (e.g. TensorBoard). The wandb logger is the only one that adds 'global_step' to the metrics, and I don't think it is necessary. Another side effect is that 'global_step' is also added to otherwise empty dicts, which are then logged and result in strange graphs like this:


or this


I also wrote a simple logger class to print out metrics. I got this output:

Step  0
{'global_step': 0}
Step  10
{'global_step': 10}
Step  190
{'global_step': 190}
Step  200
{'global_step': 200}
Step  0
{'val/mse': 0.01713273860514164, 'train/mse': 0.04259789362549782, 'global_step': 0}
Step  207
{'global_step': 207}
Step  217
{'global_step': 217}
Step  397
{'global_step': 397}
Step  407
{'global_step': 407}
Step  1
{'val/mse': 0.013123581185936928, 'train/mse': 0.01449404377490282, 'global_step': 1}
Step  414
{'global_step': 414}
Step  424
{'global_step': 424}
Step  604
{'global_step': 604}
Step  614
{'global_step': 614}
Step  2
{'val/mse': 0.012394818477332592, 'train/mse': 0.012575697153806686, 'global_step': 2}


Step  5
{'val/mse': 0.012411396019160748, 'train/mse': 0.011899641714990139, 'global_step': 5}
Step  1242
{'global_step': 1242}
Step  1252
{'global_step': 1252}
Step  1432
{'global_step': 1432}
Step  1442
{'global_step': 1442}
Step  6
{'val/mse': 0.01244258601218462, 'train/mse': 0.011944737285375595, 'global_step': 6}
Step  1449
{'global_step': 1449}
Step  1459
{'global_step': 1459}
Step  1639
{'global_step': 1639}
Step  1649
{'global_step': 1649}
Step  7
{'val/mse': 0.01261985208839178, 'train/mse': 0.011924241669476032, 'global_step': 7}
Step  1656
{'global_step': 1656}
Step  1666
{'global_step': 1666}
Step  1846
{'global_step': 1846}
Step  1856
{'global_step': 1856}
Step  8
{'val/mse': 0.012863481417298317, 'train/mse': 0.011850016191601753, 'global_step': 8}
Step  1863
{'global_step': 1863}
Step  1873
Step  2053
{'global_step': 2053}
Step  2063
{'global_step': 2063}
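
For reference, a print-only logger of the kind that produced the output above could look roughly like this. It is a hypothetical sketch, not the original MyLogger: the name `PrintLogger` is made up, and the class does not subclass Lightning's logger base class. It only mirrors the `log_metrics(metrics, step)` call shape.

```python
class PrintLogger:
    """Hypothetical stand-in for a print-only logger (not the original MyLogger)."""

    def __init__(self):
        # Keep a copy of everything received so it can be inspected later.
        self.records = []

    def log_metrics(self, metrics, step):
        # Print exactly what arrives, without modifying the incoming dict.
        self.records.append((step, dict(metrics)))
        print("Step ", step)
        print(metrics)


logger = PrintLogger()
logger.log_metrics({"global_step": 10}, step=10)
logger.log_metrics({"val/mse": 0.0171, "train/mse": 0.0426, "global_step": 0}, step=0)
```

Because such a logger never adds keys itself, any 'global_step' entry it prints must have been put into the dict before the dict reached it.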

Also note: I set max_epochs to 10, so I expected 10 measurements, but the last one is missing. This could be handled in another issue.

To Reproduce

Steps to reproduce the behavior:

  1. Use training_epoch_end and validation_epoch_end to log metrics like {'log': {'loss': loss}} (see code below)
  2. Run training with wandb logger and one more logger of your choice.
  3. See global_step graphs.

Code sample

Important LightningModule Methods:

# imports used by the methods below
import torch
import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    # calculate the model prediction for the given batch
    # and the resulting loss
    x, y = batch
    y_hat = self(x)
    loss = F.mse_loss(y_hat, y)
    return {
        "loss": loss
    }

def training_epoch_end(self, outputs):
    # average the per-batch losses and log once per epoch, with the epoch as step
    loss_mean = torch.stack([x["loss"] for x in outputs]).mean().item()
    return {
        "log": {
            "train/mse": loss_mean,
            "step": self.current_epoch
        }
    }

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    return {'val_loss': F.mse_loss(y_hat, y)}

def validation_epoch_end(self, outputs):
    # average the per-batch validation losses; "val_loss" is monitored by the callbacks
    val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().item()
    return {
        "val_loss": val_loss_mean,
        "log": {
            "val/mse": val_loss_mean,
            "step": self.current_epoch
        }
    }


from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# TerminalCallback, MyLogger and name are user-defined and not shown here
clbk_terminal = TerminalCallback()
checkpoint = ModelCheckpoint(filepath="ckpts/" + name + "{_val_loss:.5f}_{epoch:03d}", prefix="BasicNN_",
                             monitor="val_loss", verbose=False, save_top_k=3, save_weights_only=True)
earlystopping = EarlyStopping(monitor="val_loss", patience=25, verbose=True)
loggers = [
    WandbLogger(project="nwp-energy-load", name=name, log_model=True),
    TensorBoardLogger(save_dir="tb_logs", name=name, version=0),
    MyLogger()  # only prints metrics; can also be ignored
]

trainer = Trainer(gpus=-1, max_epochs=10, progress_bar_refresh_rate=0, logger=loggers, log_save_interval=1, row_log_interval=10,
                  callbacks=[], early_stop_callback=earlystopping, checkpoint_callback=checkpoint)

Expected behavior

Is 'global_step' needed in wandb logger? If so, it should not affect other loggers. Also if there is nothing to log (e.g. in training_step) the logger should log nothing.
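
The "log nothing when there is nothing to log" behavior can be sketched as follows. This is a minimal illustration under assumptions, not Lightning's actual code: `RecordingLogger` and `dispatch_metrics` are hypothetical names, and only the `log_metrics(metrics, step)` call shape is taken from the issue.

```python
class RecordingLogger:
    """Hypothetical logger that only records the calls it receives."""

    def __init__(self):
        self.calls = []

    def log_metrics(self, metrics, step):
        self.calls.append((step, metrics))


def dispatch_metrics(loggers, metrics, step):
    """Forward metrics to every logger; do nothing when there is nothing to log."""
    if not metrics:  # empty dict (or None): skip all loggers
        return False
    for logger in loggers:
        logger.log_metrics(metrics, step)
    return True


recorder = RecordingLogger()
skipped = dispatch_metrics([recorder], {}, step=10)                 # nothing is forwarded
logged = dispatch_metrics([recorder], {"val/mse": 0.017}, step=0)   # forwarded once
```

With a guard like this, a training step that returns no 'log' entries would not produce a data point in any logger.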


Arch Linux, Python 3.8.2, PyTorch 1.4.0, PyTorch Lightning 0.7.3

github-actions[bot] commented 4 years ago

Hi! Thanks for your contribution! Great first issue!

olineumann commented 4 years ago

I looked over the code and found the issue. Maybe an earlier version of the wandb Python API did not accept step as a parameter of the log function, so the step was added to the metric dict instead (and since the dict was not copied, this affected the other loggers).
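
The aliasing effect described here can be reproduced without wandb or Lightning. The sketch below uses hypothetical classes (`MutatingLogger`, `CopyingLogger`, `PlainLogger`) and a hypothetical `log_to_all` helper; it assumes, as in the issue, that every logger receives the same dict object and that the mutating logger comes first in the list.

```python
class MutatingLogger:
    """Adds 'global_step' to the dict it receives (the problematic pattern)."""

    def log_metrics(self, metrics, step):
        metrics["global_step"] = step  # mutates the caller's dict in place
        self.last = metrics


class CopyingLogger:
    """Adds 'global_step' to a private copy, leaving the caller's dict alone."""

    def log_metrics(self, metrics, step):
        payload = dict(metrics)  # shallow copy is enough for flat metric dicts
        payload["global_step"] = step
        self.last = payload


class PlainLogger:
    """Stores a snapshot of whatever it receives."""

    def log_metrics(self, metrics, step):
        self.last = dict(metrics)


def log_to_all(loggers, metrics, step):
    # Every logger gets the SAME dict object, so in-place edits leak onward.
    for logger in loggers:
        logger.log_metrics(metrics, step)


plain = PlainLogger()
log_to_all([MutatingLogger(), plain], {"train/mse": 0.04}, step=10)
leaked = plain.last  # contains 'global_step' although PlainLogger never added it

plain = PlainLogger()
log_to_all([CopyingLogger(), plain], {"train/mse": 0.04}, step=10)
clean = plain.last  # only the metrics that were actually logged
```

Copying inside the logger that needs the extra key keeps the change local, so the other loggers are not affected.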

Also I think that empty metric dicts could be skipped in the base logger. You can see my fix in the following commit:

Borda commented 4 years ago

@olineumann nice, mind sending a PR? It seems it could be one-click only :]

rohitgr7 commented 4 years ago

@Borda wandb is now giving the warning wandb: WARNING Adding to old History rows isn't currently supported. Step 25 < 38 and is not logging anything when I use the WandbLogger with k-fold cross-validation. I am reusing the same wandb_logger instance several times with different train_dl and valid_dl. Since the steps repeat in each fold, nothing is logged after the first fold completes, even though the log keys are completely different. For now I have to create a separate logger for each fold. Is there any way to make it work with a single instance?
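
One possible direction, sketched under assumptions rather than taken from wandb or Lightning, is to keep the step passed to a single logger instance increasing across folds by adding a per-fold offset. `MonotonicStepHistory` is a toy model of a history that drops rows whose step is lower than the latest one (loosely based on the warning text above), and `FoldStepOffset` is a hypothetical helper.

```python
class MonotonicStepHistory:
    """Toy model: rows with a step lower than the latest step are dropped."""

    def __init__(self):
        self.rows = []
        self.last_step = -1

    def log(self, metrics, step):
        if step < self.last_step:
            return False  # dropped, analogous to the warning in the comment above
        self.rows.append((step, metrics))
        self.last_step = step
        return True


class FoldStepOffset:
    """Hypothetical helper mapping per-fold steps to ever-increasing steps."""

    def __init__(self):
        self.offset = 0
        self.max_seen = -1

    def __call__(self, local_step):
        global_step = self.offset + local_step
        self.max_seen = max(self.max_seen, global_step)
        return global_step

    def next_fold(self):
        # Start the next fold right after the highest step used so far.
        self.offset = self.max_seen + 1


# Without an offset, the second fold restarts at step 0 and is dropped.
naive = MonotonicStepHistory()
naive_results = [naive.log({f"fold{f}/loss": 0.1}, s) for f in range(2) for s in range(3)]

# With an offset, both folds are accepted by the same history.
shifted = MonotonicStepHistory()
to_global = FoldStepOffset()
shifted_results = []
for fold in range(2):
    for local_step in range(3):
        shifted_results.append(shifted.log({f"fold{fold}/loss": 0.1}, to_global(local_step)))
    to_global.next_fold()
```

The trade-off is that the x-axis no longer restarts at zero for each fold, so per-fold curves appear one after another instead of overlapping.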

olineumann commented 4 years ago

I also noticed that, after upgrading to the current version, 'epoch' now appears in the metric dict without me logging any 'epoch'. It comes from trainer.logging:

Because of this, the metrics are logged with step=global_step, and when I then log something in my LightningModule with step=epoch, it crashes for the same reason. The only way to solve it is to pass a log dict containing only 'step': epoch in training_step and validation_step.

What would be the best solution? I think:

Does that sound useful/logical? I think this should work, but I'm also tired. I can create a PR in the next few days (Fri/Sat) if wanted.

Borda commented 4 years ago

@olineumann I would prefer the second as we do not want to affect the other loggers

olineumann commented 4 years ago

I just wanted to fix it and pulled the current master from GitHub. It seems to be fixed already.

thomassajot commented 2 years ago

Is this issue solved? I am experiencing similar chart issues in Wandb:
