Open alstonlo opened 1 year ago
Very likely caused by #17818. I'm seeing this with multi-gpu as well and it's likely not TPU related.
@rejuvyesh Why do you think that it is very likely caused by #17818? Can you git-bisect or provide me with a code example for multi-gpu? I appreciate the help.
EDIT: I ran the above code example with accelerator="cuda" and couldn't see any issues.
@alstonlo Thanks for the report. I don't see anything wrong with the code example. My uneducated guess is that maybe it has to do with launching with the PJRT runtime and the feature in wandb for attaching to a run in a subprocess not working well together.
Since you have access to the TPU machine, could I ask you, what happens if you comment out these three lines of code in Lightning: https://github.com/Lightning-AI/lightning/blob/00496da92d9e7d17c81f51c9abfb54583ba2817f/src/lightning/pytorch/loggers/wandb.py#L354-L356
Will it work?
@awaelchli Haven't done a git bisect yet, but downgrading to 2.0.4 fixed the issue for us. Will attempt one once we have more time. My hunch came from that being the only major change to that code path.
Only semi-related to the current issue, but rerunning the same script with the nightly build (as of now) raises an error. This is due to the local tpu variable in xla.py not being defined when _XLA_GREATER_EQUAL_2_1 is false.
@alstonlo My bad! Let me fix that quickly
Thanks!
@awaelchli I have installed lightning directly from #18085 and commented out the suggested lines. The training script runs but no WandB run is ever created and nothing is logged to WandB.
One way to reduce the surface of issues would be to do:
import lightning as L
from lightning.pytorch.loggers.wandb import WandbLogger

def fn(fabric, logger):
    ...

logger = WandbLogger()
fabric = L.Fabric(accelerator="tpu")
fabric.launch(fn, logger)
While trying to find a solution for this issue, I think I may have stumbled upon another potential bug (which I suspect may be causing this issue, but I am not sure). For context, I noticed that if I added the following to the LightningModule:
import torch_xla.core.xla_model as xm
from lightning.pytorch.utilities.rank_zero import rank_zero_only

class LinearRegression(pl.LightningModule):
    def setup(self, stage):
        print(f"{rank_zero_only.rank = }, {self.trainer.global_rank = }, {xm.get_ordinal() = }")
then there was a mismatch between rank_zero_only.rank and self.trainer.global_rank (and xm.get_ordinal() agrees with the latter). I think this issue is caused by an interaction between rank_zero_only and xm.rendezvous() (which is called at various points of the Trainer setup). The following is a minimal example:
# debug.py
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from lightning.pytorch.utilities.rank_zero import rank_zero_only

def f(index):
    rank_zero_only.rank = xm.get_ordinal()
    xm.rendezvous("barrier")
    print(f"{rank_zero_only.rank = }, {xm.get_ordinal() = }")

if __name__ == "__main__":
    xmp.spawn(f, args=tuple())
$ PJRT_DEVICE=TPU python3 -m debug
rank_zero_only.rank = 5, xm.get_ordinal() = 4
rank_zero_only.rank = 3, xm.get_ordinal() = 2
rank_zero_only.rank = 1, xm.get_ordinal() = 0
rank_zero_only.rank = 7, xm.get_ordinal() = 6
rank_zero_only.rank = 5, xm.get_ordinal() = 5
rank_zero_only.rank = 1, xm.get_ordinal() = 1
rank_zero_only.rank = 7, xm.get_ordinal() = 7
rank_zero_only.rank = 3, xm.get_ordinal() = 3
If I comment out the xm.rendezvous("barrier") line, then I get
$ PJRT_DEVICE=TPU python3 -m debug
rank_zero_only.rank = 4, xm.get_ordinal() = 4
rank_zero_only.rank = 5, xm.get_ordinal() = 5
rank_zero_only.rank = 2, xm.get_ordinal() = 2
rank_zero_only.rank = 3, xm.get_ordinal() = 3
rank_zero_only.rank = 0, xm.get_ordinal() = 0
rank_zero_only.rank = 1, xm.get_ordinal() = 1
rank_zero_only.rank = 6, xm.get_ordinal() = 6
rank_zero_only.rank = 7, xm.get_ordinal() = 7
If I had instead assigned xm.get_ordinal() to a local variable like so:
def f(index):
    tmp = xm.get_ordinal()
    xm.rendezvous("barrier")
    print(f"{tmp = } {xm.get_ordinal() = }")
then tmp and xm.get_ordinal() match, so I think this is an issue with rank_zero_only.rank.
The xmp.spawn() on v3 TPUs is multi-process and multi-threaded. There are 4 processes for the 4 chips, and 2 threads in each process, one for each core in a chip. So the rank_zero_only object is shared between 2 threads, which is why modifying it in one thread causes both threads to see the same rank_zero_only.rank value. Without xm.rendezvous("barrier"), the printed values seem right, but that is only transient: if you sleep for 5 seconds and print again, they would be the same as the wrong ones.
This is actually the reason why trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs)) is needed in the Lightning code. The shared objects between the threads need to be decoupled.
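The sharing effect described above can be reproduced without a TPU. The following is a minimal sketch in plain Python, not Lightning or torch_xla code: the `rank_zero_only` function here is a hypothetical stand-in that only carries a `rank` attribute, `threading.Barrier` plays the role of `xm.rendezvous("barrier")`, and `threading.local` is shown as one possible way to decouple per-thread state (an assumption for illustration, not what Lightning does).

```python
import threading

# Hypothetical stand-in for Lightning's rank_zero_only: a plain function
# carrying a `rank` attribute that every thread in the process shares.
def rank_zero_only(fn):
    return fn

rank_zero_only.rank = None

barrier = threading.Barrier(2)   # plays the role of xm.rendezvous("barrier")
per_thread = threading.local()   # one independent namespace per thread

shared_seen = {}
local_seen = {}

def worker(ordinal):
    rank_zero_only.rank = ordinal    # both threads write the same attribute
    per_thread.rank = ordinal        # each thread writes its own slot
    barrier.wait()                   # both writes have happened once this returns
    shared_seen[ordinal] = rank_zero_only.rank
    local_seen[ordinal] = per_thread.rank

threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"{shared_seen = }")  # both entries hold whichever ordinal was written last
print(f"{local_seen = }")   # each entry matches its own ordinal
```

After the barrier, both threads read the same shared value, which matches the pairs of duplicated ranks in the TPU output above, while the thread-local copy stays correct in each thread.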
@will-cromar
Isn't this a matter of a delayed init after forking? This fixes wandb from initializing 4 times (on a vx-8) and having mixed stream ids.
@@ -59,12 +59,14 @@
data = LinearDataModule()
model = LinearRegression()
+ logger=pl.loggers.WandbLogger(project="tpu_debug")
+ logger.experiment
trainer = pl.Trainer(
accelerator="tpu",
devices=8,
enable_checkpointing=False,
precision="bf16-mixed",
- logger=pl.loggers.WandbLogger(project="tpu_debug"),
+ logger=logger,
max_epochs=100,
enable_progress_bar=True,
)
However, there will also be 4 of these warnings from trying to create a new session:
.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py:391: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
rank_zero_warn(
and it will hang (possibly related: https://docs.wandb.ai/guides/integrations/lightning#how-to-use-multiple-gpus-with-lightning-and-wb).
I'm not sure what the proper patch would be within lightning.
After debugging this for a bit, the issue is that you have to call wandb.login before the fit (before the forks?), e.g.:
@@ -59,12 +59,14 @@
data = LinearDataModule()
model = LinearRegression()
+ import wandb
+ wandb.login()
trainer = pl.Trainer(
accelerator="tpu",
devices=8,
As an aside, I had a (user) issue with consolidating everything under one run:
Either set things up on Google's TPU VMs via:
python3 -m pip install --upgrade pip
python3 -m pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip wandb -U
or the _WANDB_GREATER_EQUAL_0_12_10 check needs to be less strict.
The _WANDB_GREATER_EQUAL_0_12_10 check failure skips the pickling hack to unify the runs.
(Pdb) p _WANDB_GREATER_EQUAL_0_12_10
ContextualVersionConflict: (urllib3 1.25.8 (/usr/lib/python3/dist-packages), Requirement.parse('urllib3>=1.26.11; python_version >= "3.6"'), {'sentry-sdk'}). HINT: Try running `pip install -U 'wandb>=0.12.10'`
pip on a fresh Google --version=tpu-vm-pt-2.0 TPU VM is 20.0.2, so it doesn't have requirements backtracking.
The VM also installed pip via apt, so doing python3 -m pip install --upgrade pip doesn't update the pip on the default PATH.
This is my user error in not updating and using the right pip, and in not fixing all the env warnings, but maybe the RequirementCache class is a bit too strict, and it should just check whether the version number is satisfied rather than whether all the sub-requirements are also satisfied.
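To illustrate what a version-only check could look like, here is a minimal sketch using only the standard library. It is not how RequirementCache works; the helper names `version_at_least` and `is_installed_at_least` are hypothetical, and the parsing only handles plain numeric versions such as 0.15.7 (a real implementation would more likely rely on a proper version parser).

```python
from importlib import metadata

def _numeric_prefix(version: str) -> tuple:
    """Parse the leading numeric components of a version, e.g. '0.15.7' -> (0, 15, 7)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def version_at_least(installed: str, minimum: str) -> bool:
    """Compare version numbers only, ignoring any sub-dependency state."""
    a, b = _numeric_prefix(installed), _numeric_prefix(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))   # pad so that '1.0' equals '1.0.0'
    b += (0,) * (width - len(b))
    return a >= b

def is_installed_at_least(dist_name: str, minimum: str) -> bool:
    """Return True if the distribution is installed at `minimum` or newer."""
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return False
    return version_at_least(installed, minimum)

print(version_at_least("0.15.7", "0.12.10"))
print(is_installed_at_least("wandb", "0.12.10"))
```

With a check of this shape, an unrelated conflict such as the urllib3 one in the traceback above would not turn the wandb version flag to False.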
Hi @s22chan
After debugging this for a bit, the issue is that you have to wandb.login before the fit (before the forks?). eg:
I recommend that you do wandb login in the command line instead (one time only). Then you will be automatically logged in whenever you call wandb in Python.
Regarding the other issue:
We have this trick in the logger to init the experiment when processes get launched (see comment in the code): https://github.com/Lightning-AI/lightning/blob/6511ac28759718a524dd00e627c186fb6baea763/src/lightning/pytorch/loggers/wandb.py#L349-L356
It would be very helpful if you could check whether this code path gets triggered or not in your case.
I didn't fully understand your comment about _WANDB_GREATER_EQUAL_0_12_10. Are you saying you have wandb>=0.12.10 installed, yet the check failed and defaulted to False? If so, we could consider setting this version as the minimum required version, so we don't have to check it in the first place.
@awaelchli sorry if the messages were a bit scattered yesterday.
I recommend that you do wandb login in the command line instead (one time only).
I've already done that. The wandb.login() before the fork/spawn is required to avoid a data race between the two TPU threads launched on rank 0 for the logger init, which leads to the originally reported crash.
@alstonlo is inferring that many of the rank_zero mechanisms in place for logging/profiling (/other?) don't work in a TPU scenario with the PJRT change, because there are now two threads that have rank 0.
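As a rough illustration of the race (not a proposed patch for Lightning or wandb), the sketch below uses a hypothetical `LazyRun` class as a stand-in for a logger whose run is created lazily. Two threads that both believe they are rank 0 ask for the run at the same moment, and a per-process lock makes sure the expensive init happens only once.

```python
import threading

class LazyRun:
    """Hypothetical stand-in for a logger whose run is created on first access."""

    def __init__(self):
        self._lock = threading.Lock()
        self._run = None
        self.init_calls = 0

    def _create_run(self):
        # Placeholder for the real, expensive initialization.
        self.init_calls += 1
        return object()

    @property
    def experiment(self):
        if self._run is None:              # fast path once the run exists
            with self._lock:               # only one thread may initialize
                if self._run is None:      # re-check after acquiring the lock
                    self._run = self._create_run()
        return self._run

logger = LazyRun()
start = threading.Barrier(2)   # release both "rank 0" threads together
seen = []

def worker():
    start.wait()
    seen.append(logger.experiment)

threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(logger.init_calls)
print(seen[0] is seen[1])
```

Calling wandb.login() before the spawn, as described above, sidesteps the same problem from the outside by doing the one-time setup before any of the threads exist.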
I didn't fully understand your comment about _WANDB_GREATER_EQUAL_0_12_10.
Wandb was wandb==0.15.7, but because of a conflict in urllib3 (which is a sub-dependency of wandb), the bool cast from RequirementCache fails. This is super not obvious as a user.
any updates on this issue?
related: https://github.com/Lightning-AI/pytorch-lightning/issues/19035 (not wandb but logging and dataraces on the threads)
Bug description
On a TPU VM, using WandbLogger causes training to crash. I am using the nightly build which I know states "no guarantees", so apologies in advance if this is currently being worked on (I wasn't able to find any relevant issues or PRs). I am also unsure of why this error is occurring, and whether it is an issue with Lightning or WandB.
What version are you seeing the problem on?
master
How to reproduce the bug
The above code was written to a file train.py and run with
Error messages and logs
Environment
v3-8
tpu-vm-pt-2.0
More info
If I train without a logger instead, then no error occurs and the script proceeds normally.
cc @carmocca @JackCaoG @steventk-g @Liyang90 @awaelchli @morganmcg1 @borisdayma @scottire @parambharat