intel / intel-extension-for-pytorch

A Python package that extends the official PyTorch to deliver extra performance on Intel platforms
Apache License 2.0

[Bug] BFloat16 + master_param_non_fused_step results in `None` grad #239

Open robogast opened 1 year ago

robogast commented 1 year ago

Context: I train an auto-encoder with PyTorch Lightning, and I am trying to move (part of) my codebase to IPEX + torchCCL, with the intention of running on Sapphire Rapids (in the meantime I'm running on Ice Lake). I ran into this issue and created the minimal reproducible example below.

Observations (see the `XXX` comments in the repro):

- With `dtype=torch.bfloat16`, `AdamW` triggers the crash; changing the optimizer to `SGD` or `Adam` works as expected.
- Commenting out `dtype=torch.bfloat16` in `ipex.optimize(...)` also works.
- Using `precision=32` in the `Trainer` and removing `dtype` from `ipex.optimize(...)` also works.

I've posted my full env for completeness, but I think the important deps are:

- torch 1.12.0+cpu
- intel-extension-for-pytorch 1.12.100+cpu
- oneccl-bind-pt 1.12.0+cpu
- pytorch-lightning 1.7.0

Bug report model

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer

###### User imports ######
import oneccl_bindings_for_pytorch
import intel_extension_for_pytorch as ipex
from pytorch_lightning.strategies import DDPStrategy


class RandomDataset(Dataset):
    def __init__(self, size, num_samples=10000):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        _, optim = ipex.optimize(
            model=self,
            optimizer=torch.optim.AdamW(self.layer.parameters(), lr=0.1),  # XXX: Changing to SGD or Adam works as expected
            level='O1',
            inplace=True,
            dtype=torch.bfloat16  # XXX: Commenting out this line also works!
        )
        return optim


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=5,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=True,  # ADDED
        strategy=DDPStrategy(process_group_backend='ccl', find_unused_parameters=False),
        precision='bf16',  # XXX: changing precision to 32 + removing dtype from optimize also works
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == '__main__':
    run()
```
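For contrast, here is a minimal non-Lightning sketch of the same bf16 setup that steps cleanly (a sketch with toy shapes, no CCL/DDP; the `ipex.optimize` call mirrors `configure_optimizers()` above). The point: in a plain loop, `backward()` runs before `step()`, so the trailing bf16 grads already exist when the master-weight sync reads them.

```python
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)

# Same optimization call as configure_optimizers() above, minus Lightning.
model, optimizer = ipex.optimize(model, optimizer=optimizer,
                                 dtype=torch.bfloat16)

for batch in torch.randn(5, 2, 32):
    optimizer.zero_grad()
    with torch.cpu.amp.autocast():
        loss = model(batch).sum()
    loss.backward()   # populates the bf16 .grad tensors first ...
    optimizer.step()  # ... so the master-weight update inside step() succeeds
```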
Execution log + error

```
$ python bug_report_model.py
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: ae is an invalid version and will not be supported in a future release
  warnings.warn(
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:712: UserWarning: You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU. Using `precision='bf16'` instead.
  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
2022-08-10 14:09:38,344 - torch.distributed.distributed_c10d - INFO - Added key: store_based_barrier_key:1 to store for rank: 0
2022-08-10 14:09:38,344 - torch.distributed.distributed_c10d - INFO - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
----------------------------------------------------------------------------------------------------
distributed_backend=ccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------
2022:08:10-14:09:38:(3843473) |CCL_WARN| did not find MPI-launcher specific variables, switch to ATL/OFI, to force enable ATL/MPI set CCL_ATL_TRANSPORT=mpi
2022:08:10-14:09:39:(3843473) |CCL_WARN|
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/intel_extension_for_pytorch/frontend.py:280: UserWarning: IPEX does not support fused/fused split update forwill use non-fused master weight update for bf16 training
  warnings.warn(

  | Name  | Type        | Params
--------------------------------------
0 | layer | _IPEXLinear | 66    
--------------------------------------
66        Trainable params
0         Non-trainable params
66        Total params
0.000     Total estimated model params size (MB)
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:219: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 72 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1891: PossibleUserWarning: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:219: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 72 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0:   0%|          | 0/6 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/REMOVE/bug_report_model.py", line ..., in <module>
    run()
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/REMOVE/bug_report_model.py", line 107, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 700, in fit
    self._call_and_handle_interrupt(
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 652, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in _run_train
    self.fit_loop.run()
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 269, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 203, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 87, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 201, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 248, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 358, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1549, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1666, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp.py", line 286, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 216, in optimizer_step
    return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 80, in optimizer_step
    return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 153, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/intel_extension_for_pytorch/optim/_optimizer_utils.py", line 57, in master_param_non_fused_step
    k.grad = value['bf16_param'].grad.detach().float()
AttributeError: 'NoneType' object has no attribute 'detach'
Epoch 0:   0%|          | 0/6 [00:00
```
Environment

```
$ pdm list --graph
Inside an active virtualenv /gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu, reuse it.
albumentations 1.2.1 [ required: >=1.1.0 ]
├── numpy 1.23.1 [ required: >=1.11.1 ]
├── opencv-python-headless 4.6.0.66 [ required: >=4.1.1 ]
│   └── numpy 1.23.1 [ required: >=1.19.3 ]
├── pyyaml 6.0 [ required: Any ]
├── qudida 0.0.4 [ required: >=0.0.4 ]
│   ├── numpy 1.23.1 [ required: >=0.18.0 ]
│   ├── opencv-python-headless 4.6.0.66 [ required: >=4.0.1 ]
│   │   └── numpy 1.23.1 [ required: >=1.19.3 ]
│   ├── scikit-learn 1.1.1 [ required: >=0.19.1 ]
│   │   ├── joblib 1.1.0 [ required: >=1.0.0 ]
│   │   ├── numpy 1.23.1 [ required: >=1.17.3 ]
│   │   ├── scipy 1.9.0 [ required: >=1.3.2 ]
│   │   │   └── numpy 1.23.1 [ required: <1.25.0,>=1.18.5 ]
│   │   └── threadpoolctl 3.1.0 [ required: >=2.0.0 ]
│   └── typing-extensions 4.3.0 [ required: Any ]
├── scikit-image 0.19.3 [ required: >=0.16.1 ]
│   ├── imageio 2.21.0 [ required: >=2.4.1 ]
│   │   ├── numpy 1.23.1 [ required: Any ]
│   │   └── pillow 9.2.0 [ required: >=8.3.2 ]
│   ├── networkx 2.8.5 [ required: >=2.2 ]
│   ├── numpy 1.23.1 [ required: >=1.17.0 ]
│   ├── packaging 21.3 [ required: >=20.0 ]
│   │   └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
│   ├── pillow 9.2.0 [ required: !=7.1.0,!=7.1.1,!=8.3.0,>=6.1.0 ]
│   ├── pywavelets 1.3.0 [ required: >=1.1.1 ]
│   │   └── numpy 1.23.1 [ required: >=1.17.3 ]
│   ├── scipy 1.9.0 [ required: >=1.4.1 ]
│   │   └── numpy 1.23.1 [ required: <1.25.0,>=1.18.5 ]
│   └── tifffile 2022.8.3 [ required: >=2019.7.26 ]
│       └── numpy 1.23.1 [ required: >=1.19.2 ]
└── scipy 1.9.0 [ required: Any ]
    └── numpy 1.23.1 [ required: <1.25.0,>=1.18.5 ]
h5py 3.7.0 [ required: >=3.6.0 ]
└── numpy 1.23.1 [ required: >=1.14.5 ]
hydra-submitit-launcher 1.2.0 [ required: Any ]
├── hydra-core 1.2.0 [ required: >=1.1.0.dev7 ]
│   ├── antlr4-python3-runtime 4.9.3 [ required: ==4.9.* ]
│   ├── omegaconf 2.2.2 [ required: ~=2.2 ]
│   │   ├── antlr4-python3-runtime 4.9.3 [ required: ==4.9.* ]
│   │   └── pyyaml 6.0 [ required: >=5.1.0 ]
│   └── packaging 21.3 [ required: Any ]
│       └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
└── submitit 1.4.5 [ required: >=1.3.3 ]
    ├── cloudpickle 2.1.0 [ required: >=1.2.1 ]
    └── typing-extensions 4.3.0 [ required: >=3.7.4.2 ]
intel-extension-for-pytorch 1.12.100+cpu [ required: ==1.12.100+cpu ]
└── psutil 5.9.1 [ required: Any ]
matplotlib 3.5.2 [ required: >=3.5.2 ]
├── cycler 0.11.0 [ required: >=0.10 ]
├── fonttools 4.34.4 [ required: >=4.22.0 ]
├── kiwisolver 1.4.4 [ required: >=1.0.1 ]
├── numpy 1.23.1 [ required: >=1.17 ]
├── packaging 21.3 [ required: >=20.0 ]
│   └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
├── pillow 9.2.0 [ required: >=6.2.0 ]
├── pyparsing 3.0.9 [ required: >=2.2.1 ]
└── python-dateutil 2.8.2 [ required: >=2.7 ]
    └── six 1.16.0 [ required: >=1.5 ]
oneccl-bind-pt 1.12.0+cpu [ required: ==1.12.0+cpu ]
pytorch-lightning 1.7.0 [ required: >=1.6.1 ]
├── fsspec[http] 2022.7.1 [ required: !=2021.06.0,>=2021.05.0 ]
│   ├── aiohttp 3.8.1 [ required: Any ]
│   │   ├── aiosignal 1.2.0 [ required: >=1.1.2 ]
│   │   │   └── frozenlist 1.3.1 [ required: >=1.1.0 ]
│   │   ├── async-timeout 4.0.2 [ required: <5.0,>=4.0.0a3 ]
│   │   ├── attrs 22.1.0 [ required: >=17.3.0 ]
│   │   ├── charset-normalizer 2.1.0 [ required: <3.0,>=2.0 ]
│   │   ├── frozenlist 1.3.1 [ required: >=1.1.1 ]
│   │   ├── multidict 6.0.2 [ required: <7.0,>=4.5 ]
│   │   └── yarl 1.8.1 [ required: <2.0,>=1.0 ]
│   │       ├── idna 3.3 [ required: >=2.0 ]
│   │       └── multidict 6.0.2 [ required: >=4.0 ]
│   └── requests 2.28.1 [ required: Any ]
│       ├── certifi 2022.6.15 [ required: >=2017.4.17 ]
│       ├── charset-normalizer 2.1.0 [ required: <3,>=2 ]
│       ├── idna 3.3 [ required: <4,>=2.5 ]
│       └── urllib3 1.26.11 [ required: <1.27,>=1.21.1 ]
├── numpy 1.23.1 [ required: >=1.17.2 ]
├── packaging 21.3 [ required: >=17.0 ]
│   └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
├── pydeprecate 0.3.2 [ required: >=0.3.1 ]
├── pyyaml 6.0 [ required: >=5.4 ]
├── tensorboard 2.9.1 [ required: >=2.9.1 ]
│   ├── absl-py 1.2.0 [ required: >=0.4 ]
│   ├── google-auth 2.9.1 [ required: <3,>=1.6.3 ]
│   │   ├── cachetools 5.2.0 [ required: <6.0,>=2.0.0 ]
│   │   ├── pyasn1-modules 0.2.8 [ required: >=0.2.1 ]
│   │   │   └── pyasn1 0.4.8 [ required: <0.5.0,>=0.4.6 ]
│   │   ├── rsa 4.9 [ required: <5,>=3.1.4 ]
│   │   │   └── pyasn1 0.4.8 [ required: >=0.1.3 ]
│   │   └── six 1.16.0 [ required: >=1.9.0 ]
│   ├── google-auth-oauthlib 0.4.6 [ required: <0.5,>=0.4.1 ]
│   │   ├── google-auth 2.9.1 [ required: >=1.0.0 ]
│   │   │   ├── cachetools 5.2.0 [ required: <6.0,>=2.0.0 ]
│   │   │   ├── pyasn1-modules 0.2.8 [ required: >=0.2.1 ]
│   │   │   │   └── pyasn1 0.4.8 [ required: <0.5.0,>=0.4.6 ]
│   │   │   ├── rsa 4.9 [ required: <5,>=3.1.4 ]
│   │   │   │   └── pyasn1 0.4.8 [ required: >=0.1.3 ]
│   │   │   └── six 1.16.0 [ required: >=1.9.0 ]
│   │   └── requests-oauthlib 1.3.1 [ required: >=0.7.0 ]
│   │       ├── oauthlib 3.2.0 [ required: >=3.0.0 ]
│   │       └── requests 2.28.1 [ required: >=2.0.0 ]
│   │           ├── certifi 2022.6.15 [ required: >=2017.4.17 ]
│   │           ├── charset-normalizer 2.1.0 [ required: <3,>=2 ]
│   │           ├── idna 3.3 [ required: <4,>=2.5 ]
│   │           └── urllib3 1.26.11 [ required: <1.27,>=1.21.1 ]
│   ├── grpcio 1.47.0 [ required: >=1.24.3 ]
│   │   └── six 1.16.0 [ required: >=1.5.2 ]
│   ├── markdown 3.4.1 [ required: >=2.6.8 ]
│   │   └── importlib-metadata 4.12.0 [ required: >=4.4 ]
│   │       └── zipp 3.8.1 [ required: >=0.5 ]
│   ├── numpy 1.23.1 [ required: >=1.12.0 ]
│   ├── protobuf 3.19.4 [ required: <3.20,>=3.9.2 ]
│   ├── requests 2.28.1 [ required: <3,>=2.21.0 ]
│   │   ├── certifi 2022.6.15 [ required: >=2017.4.17 ]
│   │   ├── charset-normalizer 2.1.0 [ required: <3,>=2 ]
│   │   ├── idna 3.3 [ required: <4,>=2.5 ]
│   │   └── urllib3 1.26.11 [ required: <1.27,>=1.21.1 ]
│   ├── setuptools 63.4.1 [ required: >=41.0.0 ]
│   ├── tensorboard-data-server 0.6.1 [ required: <0.7.0,>=0.6.0 ]
│   ├── tensorboard-plugin-wit 1.8.1 [ required: >=1.6.0 ]
│   ├── werkzeug 2.2.1 [ required: >=1.0.1 ]
│   │   └── markupsafe 2.1.1 [ required: >=2.1.1 ]
│   └── wheel 0.37.1 [ required: >=0.26 ]
├── torch 1.12.0+cpu [ required: >=1.9.* ]
│   └── typing-extensions 4.3.0 [ required: Any ]
├── torchmetrics 0.9.3 [ required: >=0.7.0 ]
│   ├── numpy 1.23.1 [ required: >=1.17.2 ]
│   ├── packaging 21.3 [ required: Any ]
│   │   └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
│   └── torch 1.12.0+cpu [ required: >=1.3.1 ]
│       └── typing-extensions 4.3.0 [ required: Any ]
├── tqdm 4.64.0 [ required: >=4.57.0 ]
└── typing-extensions 4.3.0 [ required: >=4.0.0 ]
rich 12.5.1 [ required: >=12.4.1 ]
├── commonmark 0.9.1 [ required: <0.10.0,>=0.9.0 ]
└── pygments 2.12.0 [ required: <3.0.0,>=2.6.0 ]
torchvision 0.13.0+cpu [ required: ==0.13.0+cpu ]
├── numpy 1.23.1 [ required: Any ]
├── pillow 9.2.0 [ required: !=8.3.*,>=5.3.0 ]
├── requests 2.28.1 [ required: Any ]
│   ├── certifi 2022.6.15 [ required: >=2017.4.17 ]
│   ├── charset-normalizer 2.1.0 [ required: <3,>=2 ]
│   ├── idna 3.3 [ required: <4,>=2.5 ]
│   └── urllib3 1.26.11 [ required: <1.27,>=1.21.1 ]
├── torch 1.12.0+cpu [ required: ==1.12.0 ]
│   └── typing-extensions 4.3.0 [ required: Any ]
└── typing-extensions 4.3.0 [ required: Any ]
vq-ae-2-2d 0.0.3 [ Not required ]
```
jingxu10 commented 1 year ago

Thanks for reporting the issue. We will look into it.

y199387 commented 1 year ago

I have the same problem. The reason might be that pytorch-lightning calls `optimizer.step(closure=closure, **kwargs)` in the `precision_plugin`, and at that point `.grad` is still `None` because `.backward()` has not been called yet (it only runs inside the closure).
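To see that contract at play: a stock `torch.optim` optimizer only runs the closure, and hence `backward()`, inside `step()` itself, so `.grad` is legitimately `None` on entry. A minimal sketch in plain PyTorch (no IPEX):

```python
import torch

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def closure():
    optimizer.zero_grad()
    loss = model(torch.randn(4, 32)).sum()
    loss.backward()  # grads only come into existence here
    return loss

assert model.weight.grad is None  # nothing has run backward yet
optimizer.step(closure)           # stock step() calls closure() first
assert model.weight.grad is not None
```

The patched `master_param_non_fused_step` instead reads `.grad` before the closure has had a chance to run, which is exactly where the traceback lands.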

jingxu10 commented 1 year ago

Currently, IPEX does not work with PyTorch Lightning. Enablement is a work in progress.
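Until that lands, one possible stopgap (unofficial and untested here, not an IPEX or Lightning API; `closure_first` is a hypothetical helper name) is to make the closure run before the patched step touches any grads:

```python
import functools

def closure_first(optimizer):
    """Wrap optimizer.step so a Lightning-style closure is evaluated
    (forward + backward) before the wrapped step reads any .grad."""
    inner_step = optimizer.step

    @functools.wraps(inner_step)
    def step(closure=None):
        loss = closure() if closure is not None else None
        inner_step()  # grads now exist for the master-weight sync
        return loss

    optimizer.step = step
    return optimizer

# e.g. in configure_optimizers(): return closure_first(optim)
```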