Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

WandbLogger crashes when used with TPU VM #18051

Open alstonlo opened 1 year ago

alstonlo commented 1 year ago

Bug description

On a TPU VM, using WandbLogger causes training to crash. I am using the nightly build, which I know comes with "no guarantees", so apologies in advance if this is already being worked on (I wasn't able to find any relevant issues or PRs). I am also unsure why this error occurs, and whether it is an issue with Lightning or WandB.

What version are you seeing the problem on?

master

How to reproduce the bug

import lightning.pytorch as pl
import lightning.pytorch.loggers
import torch
import torch.backends.cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class LinearDataModule(pl.LightningDataModule):

    def __init__(self):
        super().__init__()
        w = torch.randn([128])
        eps = 0.01 * torch.randn([2400, 1])
        self.X = torch.randn([2400, 128])
        self.Y = torch.sum(w * self.X, dim=-1, keepdim=True) + eps

    def loader(self):
        return DataLoader(
            TensorDataset(self.X, self.Y),
            batch_size=100,
            num_workers=4,
            shuffle=True,
        )

    def train_dataloader(self):
        return self.loader()

    def val_dataloader(self):
        return self.loader()

class LinearRegression(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(128, 1)

    def step(self, batch, split):
        X, y = batch
        loss = F.mse_loss(self.proj(X), y)
        self.log(f"{split}/loss", loss, sync_dist=(split != "train"), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

def train():
    pl.seed_everything(100, workers=True)

    data = LinearDataModule()
    model = LinearRegression()

    trainer = pl.Trainer(
        accelerator="tpu",
        devices=8,
        enable_checkpointing=False,
        precision="bf16-mixed",
        logger=pl.loggers.WandbLogger(project="tpu_debug"),
        max_epochs=100,
        enable_progress_bar=True,
    )

    trainer.fit(model=model, datamodule=data)

if __name__ == "__main__":
    train()

The above code was written to a file train.py and run with

PJRT_DEVICE=TPU python3 -m train

Error messages and logs

Global seed set to 100
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: .... Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.5
wandb: Run data is saved locally in ./wandb/run-20230710_212512-3vghnzj8
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run fine-music-19
wandb: ⭐️ View project at https://wandb.ai/...
wandb: 🚀 View run at https://wandb.ai/.../runs/3vghnzj8
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/.../.local/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 112, in run
    shandler(sreq)
  File "/.../.local/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 152, in server_inform_attach
    dict(self._mux._streams[stream_id]._settings),
KeyError: '3vghnzj8'
[... the same SockSrvRdThr traceback ending in KeyError: '3vghnzj8' is repeated for each of the remaining worker threads ...]
wandb: ERROR Unable to attach to run 3vghnzj8
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
    replica_results = list(
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
    return fn()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 341, in __call__
    self.fn(global_ordinal(), *self.args, **self.kwargs)
  File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/xla.py", line 128, in _wrapping_function
    trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 210, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "/usr/lib/python3.8/copy.py", line 210, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 205, in _deepcopy_list
    append(deepcopy(a, memo))
  File "/usr/lib/python3.8/copy.py", line 161, in deepcopy
    rv = reductor(4)
  File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py", line 356, in __getstate__
    _ = self.experiment
  File "/.../.local/lib/python3.8/site-packages/lightning/fabric/loggers/logger.py", line 114, in experiment
    return fn(self)
  File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py", line 398, in experiment
    self._experiment = wandb._attach(attach_id)
  File "/.../.local/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 877, in _attach
    raise UsageError(f"Unable to attach to run {attach_id}")
wandb.errors.UsageError: Unable to attach to run 3vghnzj8
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/...", line 76, in <module>
    train()
  File "/...", line 72, in train
    trainer.fit(model=model, datamodule=data)
  File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 536, in fit
    call._call_and_handle_interrupt(
  File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 41, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/xla.py", line 88, in launch
    process_context = xmp.spawn(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 386, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 365, in spawn
    _run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
wandb.errors.UsageError: Unable to attach to run 3vghnzj8
[... the identical _RemoteTraceback and "wandb.errors.UsageError: Unable to attach to run 3vghnzj8" traceback is printed a second time ...]

Environment

More info

If I instead train without a logger, no error occurs and the script runs normally.

cc @carmocca @JackCaoG @steventk-g @Liyang90 @awaelchli @morganmcg1 @borisdayma @scottire @parambharat

rejuvyesh commented 1 year ago

Very likely caused by #17818. I'm seeing this with multi-GPU as well, so it's likely not TPU-related.

awaelchli commented 1 year ago

@rejuvyesh Why do you think it is very likely caused by #17818? Can you git-bisect or provide me with a code example for multi-GPU? I appreciate the help.

EDIT: I ran the above code example with accelerator="cuda" and couldn't see any issues.

awaelchli commented 1 year ago

@alstonlo Thanks for the report. I don't see anything wrong with the code example. My (uneducated) guess is that launching under the PJRT runtime and wandb's feature for attaching to a run in a subprocess don't play well together.

Since you have access to the TPU machine, could I ask what happens if you comment out these three lines of code in Lightning: https://github.com/Lightning-AI/lightning/blob/00496da92d9e7d17c81f51c9abfb54583ba2817f/src/lightning/pytorch/loggers/wandb.py#L354-L356

Does it work then?
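
For reference, the traceback above suggests those lines are a __getstate__ hook that touches self.experiment so the wandb run is created before the logger gets deep-copied into the spawned workers. Below is a minimal runnable sketch of that pattern with a toy stand-in for wandb; ToyWandbLogger and its internals are made up for illustration and are not Lightning's actual source.

# Toy illustration of the eager-init-on-pickle pattern (illustrative only).
import copy

class ToyWandbLogger:
    def __init__(self):
        self._experiment = None

    @property
    def experiment(self):
        if self._experiment is None:
            # In the real logger this is where wandb.init() / wandb._attach()
            # would run.
            self._experiment = object()
            print("experiment initialized")
        return self._experiment

    def __getstate__(self):
        # The lines in question: touch self.experiment so the run exists
        # before the logger is copied into spawned worker processes.
        _ = self.experiment
        state = self.__dict__.copy()
        state["_experiment"] = None  # the live run handle is not copyable
        return state

logger = ToyWandbLogger()
copy.deepcopy(logger)  # triggers __getstate__ and hence the experiment init

Commenting those lines out skips this eager initialization step.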

rejuvyesh commented 1 year ago

@awaelchli Haven't done a git bisect yet, but downgrading to 2.0.4 fixed the issue for us. I will attempt one once we have more time; my hunch was that it's the only major change to that code path recently.

alstonlo commented 1 year ago

Only semi-related to the current issue, but rerunning the same script with the nightly build (as of now) raises an error. This is due to the local tpu variable in xla.py not being defined when _XLA_GREATER_EQUAL_2_1 is false.
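
For illustration, a hypothetical reduction of that failure mode (the function name and flag value below are stand-ins, not the actual xla.py source): the variable is only bound inside the guarded branch, so using it afterwards fails when the guard is False.

# Hypothetical sketch of the reported bug pattern; not Lightning's real code.
_XLA_GREATER_EQUAL_2_1 = False  # assumed flag value for illustration

def _check_tpu():
    if _XLA_GREATER_EQUAL_2_1:
        tpu = "query TPU devices the 2.1+ way"
    # when the flag is False, `tpu` is never assigned on this path
    return tpu  # raises UnboundLocalError if the flag is False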

carmocca commented 1 year ago

@alstonlo My bad! Let me fix that quickly

carmocca commented 1 year ago

Opened https://github.com/Lightning-AI/lightning/pull/18085

alstonlo commented 1 year ago

Thanks!

@awaelchli I have installed lightning directly from #18085 and commented out the suggested lines. The training script runs but no WandB run is ever created and nothing is logged to WandB.

carmocca commented 1 year ago

One way to reduce the surface of issues would be to do

import lightning as L
from lightning.pytorch.loggers.wandb import WandbLogger

def fn(fabric, logger):
    ...

logger = WandbLogger()
fabric = L.Fabric(accelerator="tpu")
fabric.launch(fn, logger)

alstonlo commented 1 year ago

While trying to find a solution for this issue, I think I may have stumbled upon another potential bug (which I suspect may be causing this issue, but I am not sure). For context, I noticed that if I added the following to the LightningModule:

import torch_xla.core.xla_model as xm
from lightning.pytorch.utilities.rank_zero import rank_zero_only

class LinearRegression(pl.LightningModule):

    def setup(self, stage):
        print(f"{rank_zero_only.rank = }, {self.trainer.global_rank = }, {xm.get_ordinal() = }")

then there was a mismatch between rank_zero_only.rank and self.trainer.global_rank (and xm.get_ordinal() agrees with the latter). I think this issue is caused by an interaction between rank_zero_only and xm.rendezvous() (which is called at various points of the Trainer setup). The following is a minimal example:

# debug.py
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from lightning.pytorch.utilities.rank_zero import rank_zero_only

def f(index):
    rank_zero_only.rank = xm.get_ordinal()
    xm.rendezvous("barrier")
    print(f"{rank_zero_only.rank = }, {xm.get_ordinal() = }")

if __name__ == "__main__":
    xmp.spawn(f, args=tuple())

$ PJRT_DEVICE=TPU python3 -m debug
rank_zero_only.rank = 5, xm.get_ordinal() = 4
rank_zero_only.rank = 3, xm.get_ordinal() = 2
rank_zero_only.rank = 1, xm.get_ordinal() = 0
rank_zero_only.rank = 7, xm.get_ordinal() = 6
rank_zero_only.rank = 5, xm.get_ordinal() = 5
rank_zero_only.rank = 1, xm.get_ordinal() = 1
rank_zero_only.rank = 7, xm.get_ordinal() = 7
rank_zero_only.rank = 3, xm.get_ordinal() = 3

If I comment out the xm.rendezvous("barrier") line, then I get

$ PJRT_DEVICE=TPU python3 -m debug
rank_zero_only.rank = 4, xm.get_ordinal() = 4
rank_zero_only.rank = 5, xm.get_ordinal() = 5
rank_zero_only.rank = 2, xm.get_ordinal() = 2
rank_zero_only.rank = 3, xm.get_ordinal() = 3
rank_zero_only.rank = 0, xm.get_ordinal() = 0
rank_zero_only.rank = 1, xm.get_ordinal() = 1
rank_zero_only.rank = 6, xm.get_ordinal() = 6
rank_zero_only.rank = 7, xm.get_ordinal() = 7

If I had instead assigned xm.get_ordinal() to a local variable like so:

def f(index):
    tmp = xm.get_ordinal()
    xm.rendezvous("barrier")
    print(f"{tmp = } {xm.get_ordinal() = }")

then tmp and xm.get_ordinal() match, so I think this is an issue with rank_zero_only.rank.

Liyang90 commented 1 year ago

While trying to find a solution for this issue, I think I may have stumbled upon another potential bug (which I suspect may be causing this issue, but I am not sure). [...] then tmp and xm.get_ordinal() match, so I think this is an issue with rank_zero_only.rank.

The xmp.spawn() on v3 TPUs is multi-process and multi-threaded: there are 4 processes, one per chip, and 2 threads per process, one for each core in a chip. So the rank_zero_only object is shared between 2 threads, and modifying it in one thread changes rank_zero_only.rank for both. Without xm.rendezvous("barrier"), the printed values only look right because they are transient; if you sleep for 5 seconds and print again, they match the wrong values shown above.
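
A minimal sketch of this sharing problem with plain Python threads (no TPU needed); the names mimic rank_zero_only but everything here is simplified for illustration:

# Two threads in one process write to the same shared object, so the last
# writer wins for both -- mimicking rank_zero_only.rank under PJRT's
# thread-per-core launch. Purely illustrative.
import threading
import time

class RankState:
    rank = 0  # one shared value per process, not per thread

rank_zero_only = RankState()

def worker(ordinal):
    rank_zero_only.rank = ordinal  # both threads write the shared attribute
    time.sleep(0.1)                # stand-in for xm.rendezvous("barrier")
    print(f"ordinal={ordinal} sees rank_zero_only.rank={rank_zero_only.rank}")

threads = [threading.Thread(target=worker, args=(i,)) for i in (0, 1)]
for t in threads:
    t.start()
for t in threads:
    t.join()

Both threads end up reporting the same rank, which is exactly the mismatch shown in the print output above.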

This is actually the reason why trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs)) is needed in the Lightning code: the objects shared between the threads need to be decoupled.

@will-cromar

s22chan commented 11 months ago

Isn't this a matter of the init being delayed until after forking? The following change stops wandb from initializing 4 times (on a vx-8) and ending up with mixed stream ids.

@@ -59,12 +59,14 @@
     data = LinearDataModule()
     model = LinearRegression()

+    logger=pl.loggers.WandbLogger(project="tpu_debug")
+    logger.experiment
     trainer = pl.Trainer(
         accelerator="tpu",
         devices=8,
         enable_checkpointing=False,
         precision="bf16-mixed",
-        logger=pl.loggers.WandbLogger(project="tpu_debug"),
+        logger=logger,
         max_epochs=100,
         enable_progress_bar=True,
     )

However, there will also be 4 of these warnings from trying to create a new session:

.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py:391: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
  rank_zero_warn(

and it will hang (possibly related: https://docs.wandb.ai/guides/integrations/lightning#how-to-use-multiple-gpus-with-lightning-and-wb).

I'm not sure what the proper patch would be within lightning.

s22chan commented 11 months ago

After debugging this for a bit, the issue is that you have to wandb.login before the fit (before the forks?). eg:

@@ -59,12 +59,14 @@
     data = LinearDataModule()
     model = LinearRegression()

+    import wandb
+    wandb.login()
     trainer = pl.Trainer(
         accelerator="tpu",
         devices=8,
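
For clarity, the same change applied to the train() function from the original report would look roughly like this (a sketch, reusing LinearDataModule and LinearRegression from the repro above; wandb.login() with no arguments picks up the credentials stored by the wandb login CLI):

import wandb

def train():
    pl.seed_everything(100, workers=True)

    data = LinearDataModule()
    model = LinearRegression()

    # Workaround sketch: log in from the main process, before the XLA
    # launcher forks/spawns the per-device workers.
    wandb.login()

    trainer = pl.Trainer(
        accelerator="tpu",
        devices=8,
        enable_checkpointing=False,
        precision="bf16-mixed",
        logger=pl.loggers.WandbLogger(project="tpu_debug"),
        max_epochs=100,
        enable_progress_bar=True,
    )
    trainer.fit(model=model, datamodule=data)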

s22chan commented 11 months ago

As an aside, I had a (user) issue with consolidating everything under one run:

TLDR

Either set things up on Google's TPU VMs via:

python3 -m pip install --upgrade pip
python3 -m pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip wandb -U

or the _WANDB_GREATER_EQUAL_0_12_10 check needs to be less strict

More details

A failure of the _WANDB_GREATER_EQUAL_0_12_10 check skips the pickling hack that unifies the runs.

(Pdb) p _WANDB_GREATER_EQUAL_0_12_10
ContextualVersionConflict: (urllib3 1.25.8 (/usr/lib/python3/dist-packages), Requirement.parse('urllib3>=1.26.11; python_version >= "3.6"'), {'sentry-sdk'}). HINT: Try running `pip install -U 'wandb>=0.12.10'`

pip on a fresh Google --version=tpu-vm-pt-2.0 TPU VM is 20.0.2, so it doesn't have requirements backtracking. The VM also installs pip via apt, so running python3 -m pip install --upgrade pip doesn't update the pip on the default PATH.

This is my user error (not updating and using the right pip, and not fixing all the environment warnings), but the RequirementCache class might be a bit too strict: it could check only whether the version number is satisfied, rather than whether all of the sub-requirements are also satisfied.
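
For illustration, a looser check of the kind suggested could compare just the installed wandb version and ignore sub-dependency conflicts. A sketch (wandb_at_least is a made-up helper, not part of Lightning or wandb):

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

def wandb_at_least(minimum: str = "0.12.10") -> bool:
    # Compare only wandb's own version number; a broken sub-dependency
    # (e.g. the urllib3 conflict above) does not flip this to False.
    try:
        return Version(version("wandb")) >= Version(minimum)
    except PackageNotFoundError:
        return False

print(wandb_at_least())  # True for wandb==0.15.7 regardless of urllib3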

awaelchli commented 11 months ago

Hi @s22chan

After debugging this for a bit, the issue is that you have to wandb.login before the fit (before the forks?). eg:

I recommend that you do wandb login in the command line instead (one time only). Then you will be automatically logged in whenever you call wandb in Python.

Regarding the other issue:

We have this trick in the logger to init the experiment when processes get launched (see comment in the code): https://github.com/Lightning-AI/lightning/blob/6511ac28759718a524dd00e627c186fb6baea763/src/lightning/pytorch/loggers/wandb.py#L349-L356

It would be very helpful if you could check whether this code path gets triggered or not in your case. I didn't fully understand your comment about _WANDB_GREATER_EQUAL_0_12_10. Are you saying you have wandb>=0.12.10 installed, yet the check failed and defaulted to False? If so, we could consider setting this version as the minimum required version, so we don't have to check it in the first place.

s22chan commented 11 months ago

@awaelchli sorry if the messages were a bit scattered yesterday.

I recommend that you do wandb login in the command line instead (one time only).

I've already done that. Calling wandb.login() before the fork/spawn is required to avoid a data race between the two TPU threads launched on rank 0 during logger init, which is what leads to the originally reported crash.

@alstonlo is suggesting that many of the rank_zero mechanisms in place for logging/profiling (and possibly other things) don't work in a TPU scenario after the PJRT change, because there are now two threads that share rank 0.

I didn't fully understand your comment about _WANDB_GREATER_EQUAL_0_12_10.

Wandb was wandb==0.15.7, but because of a conflict in urllib3 (a sub-dependency of wandb), the bool cast of RequirementCache fails. This is not at all obvious to a user.

carlesoctav commented 3 months ago

any updates on this issue?

s22chan commented 2 weeks ago

Related: https://github.com/Lightning-AI/pytorch-lightning/issues/19035 (not wandb, but logging and data races on the threads)