Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Manual optimization doesn't work with multiple TPUs with `pytorch-lightning: 1.7.1` #14249

Closed. dinhanhx closed this issue 2 years ago.

dinhanhx commented 2 years ago

🐛 Bug

As the title says, manual optimization works on a single TPU core but fails with multiple cores, where it raises AssertionErrors from this line: https://github.com/Lightning-AI/lightning/blob/acd4805f1a284e513272d150de6f98f27a0489b3/src/pytorch_lightning/loops/optimization/manual_loop.py#L110
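
For reference, the linked line is the training_step dispatch inside the manual optimization loop (quoted from the traceback below):

training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())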

dinhanhx@t1v-n-8b0bf8c6-w-0:~/storage/projects/boring$ conda activate torch-12
(torch-12) dinhanhx@t1v-n-8b0bf8c6-w-0:~/storage/projects/boring$ python3 boring.py
2022-08-17 15:38:55.853913: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-08-17 15:38:55.853977: E tensorflow/core/framework/op_kernel.cc:1676] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
(the same TPURoundRobin / TpuHandleToProtoKey message pair is printed again by each of the spawned TPU processes, timestamps 15:39:25 through 15:39:54)
/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1894: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 0    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0/2 0:00:00 • -:--:-- 0.00it/s
pt-xla-profiler: TransferFromServerTime too frequent: 5 counts during 1 steps
(the pt-xla-profiler warning above is printed once per TPU process)
Exception in device=TPU:7:
Exception in device=TPU:4:
Exception in device=TPU:2:
Exception in device=TPU:3:
Exception in device=TPU:1:
Exception in device=TPU:5:
Exception in device=TPU:6:
Each failing process prints the same traceback:
Traceback (most recent call last):
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/xla.py", line 157, in wrapped
    fn(rank, *_args)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/xla.py", line 107, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1168, in _run
    results = self._run_stage()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1254, in _run_stage
    return self._run_train()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in _run_train
    self.fit_loop.run()
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 270, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 89, in advance
    outputs = self.manual_loop.run(kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/home/dinhanhx/storage/miniconda3/envs/torch-12/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/manual_loop.py", line 110, in advance
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
AssertionError

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LinearLR

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import RichProgressBar

import torch_xla.core.xla_model as xm
from torch.utils.data.distributed import DistributedSampler

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
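        # NOTE: `sync_dist_group` normally takes a torch.distributed process
        # group; passing True here looks unintended (an observation, not from
        # the original report).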
        self.log("train_loss", loss, sync_dist=True, sync_dist_group=True, rank_zero_only=True)
        self.manual_backward(loss)
        opt.step()
        sch = self.lr_schedulers()
        sch.step()
        # return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = LinearLR(opt, start_factor=0.5, total_iters=4)
        return [opt], [scheduler]

def run():
    ds = RandomDataset(32, 64)
    sampler = DistributedSampler(
            ds, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
        )
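    # NOTE: this DistributedSampler is created but never passed to the
    # DataLoaders below; Lightning's TPU spawn strategy typically injects
    # its own distributed sampler anyway.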

    train_data = DataLoader(ds, batch_size=2)
    val_data = DataLoader(ds, batch_size=2)
    test_data = DataLoader(ds, batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        logger=CSVLogger("csvlogs"), 
        accelerator='tpu', devices=8, 
        callbacks=[RichProgressBar()], 
        strategy="tpu_spawn_debug"
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

if __name__ == "__main__":
    run()

Run with:

export TPU_LOG_DIR="disabled"
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
python boring.py

Expected behavior

It runs to completion on all 8 TPU cores, just as it does on a single core.

Environment

* CUDA:
        - GPU:               None
        - available:         False
        - version:           10.2
* Lightning:
        - pytorch-lightning: 1.7.1
        - torch:             1.11.0
        - torch-xla:         1.11
        - torchinfo:         1.7.0
        - torchmetrics:      0.9.3
        - torchvision:       0.12.0
* Packages:
        - absl-py:           1.2.0
        - aiohttp:           3.8.1
        - aiosignal:         1.2.0
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - astroid:           2.11.7
        - asttokens:         2.0.5
        - async-timeout:     4.0.2
        - attrs:             22.1.0
        - backcall:          0.2.0
        - beautifulsoup4:    4.11.1
        - bleach:            5.0.1
        - cachetools:        4.2.4
        - certifi:           2022.6.15
        - cffi:              1.15.1
        - charset-normalizer: 2.1.0
        - cloud-tpu-client:  0.10
        - cloud-tpu-profiler: 2.4.0
        - commonmark:        0.9.1
        - debugpy:           1.6.2
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - dill:              0.3.5.1
        - einops:            0.4.1
        - entrypoints:       0.4
        - executing:         0.9.1
        - fastjsonschema:    2.16.1
        - filelock:          3.7.1
        - flake8:            5.0.4
        - frozenlist:        1.3.1
        - fsspec:            2022.7.1
        - google-api-core:   1.32.0
        - google-api-python-client: 1.8.0
        - google-auth:       1.35.0
        - google-auth-httplib2: 0.1.0
        - google-auth-oauthlib: 0.4.6
        - googleapis-common-protos: 1.56.4
        - grpcio:            1.47.0
        - httplib2:          0.20.4
        - huggingface-hub:   0.8.1
        - idna:              3.3
        - importlib-metadata: 4.12.0
        - importlib-resources: 5.9.0
        - install:           1.3.5
        - ipykernel:         6.15.1
        - ipython:           8.4.0
        - ipython-genutils:  0.2.0
        - ipywidgets:        7.7.1
        - isort:             5.10.1
        - jedi:              0.18.1
        - jinja2:            3.1.2
        - jsonschema:        4.9.1
        - jupyter-client:    7.3.4
        - jupyter-core:      4.11.1
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-widgets: 1.1.1
        - lazy-object-proxy: 1.7.1
        - libtpu-nightly:    0.1.dev20220303
        - markdown:          3.4.1
        - markupsafe:        2.1.1
        - matplotlib-inline: 0.1.3
        - mccabe:            0.7.0
        - mistune:           0.8.4
        - multidict:         6.0.2
        - nbclient:          0.6.6
        - nbconvert:         6.5.0
        - nbformat:          5.4.0
        - nest-asyncio:      1.5.5
        - notebook:          6.4.12
        - numpy:             1.23.1
        - oauth2client:      4.1.3
        - oauthlib:          3.2.0
        - packaging:         21.3
        - pandocfilters:     1.5.0
        - parso:             0.8.3
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            9.2.0
        - pip:               22.1.2
        - pkgutil-resolve-name: 1.3.10
        - platformdirs:      2.5.2
        - prometheus-client: 0.14.1
        - prompt-toolkit:    3.0.30
        - protobuf:          3.19.4
        - psutil:            5.9.1
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pycodestyle:       2.9.1
        - pycparser:         2.21
        - pydeprecate:       0.3.2
        - pyflakes:          2.5.0
        - pygments:          2.12.0
        - pylint:            2.14.5
        - pyparsing:         3.0.9
        - pyrsistent:        0.18.1
        - python-dateutil:   2.8.2
        - pytorch-lightning: 1.7.1
        - pytz:              2022.1
        - pyyaml:            6.0
        - pyzmq:             23.2.0
        - regex:             2022.7.25
        - requests:          2.28.1
        - requests-oauthlib: 1.3.1
        - rich:              12.5.1
        - rsa:               4.9
        - send2trash:        1.8.0
        - sentencepiece:     0.1.97
        - setuptools:        61.2.0
        - six:               1.16.0
        - soupsieve:         2.3.2.post1
        - stack-data:        0.3.0
        - tablign:           0.3.4
        - tensorboard:       2.10.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.1
        - terminado:         0.15.0
        - tinycss2:          1.1.1
        - tokenizers:        0.12.1
        - tomli:             2.0.1
        - tomlkit:           0.11.1
        - torch:             1.11.0
        - torch-xla:         1.11
        - torchinfo:         1.7.0
        - torchmetrics:      0.9.3
        - torchvision:       0.12.0
        - tornado:           6.2
        - tqdm:              4.64.0
        - traitlets:         5.3.0
        - transformers:      4.21.1
        - typing-extensions: 4.3.0
        - uritemplate:       3.0.1
        - urllib3:           1.26.11
        - wcwidth:           0.2.5
        - webencodings:      0.5.1
        - werkzeug:          2.2.2
        - wheel:             0.37.1
        - widgetsnbextension: 3.6.1
        - yarl:              1.8.1
        - zipp:              3.8.1
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.13
        - version:           #46-Ubuntu SMP Mon Apr 19 19:17:04 UTC 2021

Additional context

It runs on a Google Cloud TPU VM (v3-8).

cc @kaushikb11 @rohitgr7

carmocca commented 2 years ago

Hi! Can you fix the link for the AssertionError? It does not point to an assertion.

dinhanhx commented 2 years ago

@carmocca Sorry, I didn't make myself clear. When the stack trace is printed, an AssertionError appears below the line I linked. I will send you the complete stack trace when I get back to my machine.

dinhanhx commented 2 years ago

@carmocca I have updated the issue with traces.

carmocca commented 2 years ago

Thanks! I guess this is caused by one of the recent mypy PRs we merged, but the real assertion error does not get surfaced.

cc @kaushikb11 @awaelchli

dinhanhx commented 2 years ago

@carmocca I downgraded lightning from 1.7.1 to 1.6.5 and reran the boring model. It yields no errors.
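
For anyone comparing versions, the downgrade is a one-liner (a sketch, assuming a pip-managed environment):

pip install "pytorch-lightning==1.6.5"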

awaelchli commented 2 years ago

We have a test for manual optimization silently failing with the same error here: #14034. I have not yet found a way to make the test failures surface in the CI.

awaelchli commented 2 years ago

Correction: the failures I saw a few weeks ago are gone; all tests are passing now. @dinhanhx maybe it is worth trying master?

dinhanhx commented 2 years ago

@awaelchli I just tried master (as of 2022-08-22). Yes, it works.
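
For reference, one way to try master is to install it straight from GitHub; this exact command is an assumption, not taken from the report:

pip install git+https://github.com/Lightning-AI/lightning.git@master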

carmocca commented 2 years ago

@kaushikb11 or @dinhanhx It would be great if you could git-bisect the commits between 1.7.0 and current master to find which commit fixed this, so we can include it in a bug-fix release.
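
A bisect session for finding a fix (rather than a regression) could look like this. This is a sketch: it assumes the 1.7.0 tag exists in the clone and uses git bisect's custom terms to invert the usual good/bad direction.

git clone https://github.com/Lightning-AI/lightning.git && cd lightning
# We are hunting the commit that FIXED the bug, so swap the usual terms:
git bisect start --term-old=broken --term-new=fixed
git bisect broken 1.7.0   # the AssertionError reproduces on the 1.7.0 tag
git bisect fixed master   # current master runs fine
# At each step git checks out a commit; install it and run the repro:
pip install -e . && python boring.py
git bisect broken         # mark this if the AssertionError still appears
git bisect fixed          # ...or this if training completes
# Repeat until git prints the first "fixed" commit.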

dinhanhx commented 2 years ago

@kaushikb11 @carmocca sorry I don't know how to use git bisect :(

carmocca commented 2 years ago

Closing, as there will be no more bug-fix releases before 1.8.