robogast opened this issue 2 years ago
Thanks for reporting the issue. We will look into it.
I have the same problem. The reason might be that PyTorch Lightning calls `optimizer.step(closure=closure, **kwargs)` in its precision plugin, so `.grad` is still `None` when `step` is entered, because `.backward()` has not been called yet.
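For illustration, here is a minimal, self-contained sketch of that interaction (this is not Lightning's or IPEX's actual code; `GradInspectingOptimizer` is a made-up stand-in for an optimizer that inspects `.grad` before running the closure, the way a fused update might):

```python
import torch

class GradInspectingOptimizer(torch.optim.SGD):
    """Toy optimizer that looks at .grad at the top of step(), before the closure runs."""
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                print("grad before closure:", p.grad)  # None on the first step
        # SGD.step() runs the closure (forward + backward) and then applies the update.
        return super().step(closure)

model = torch.nn.Linear(4, 1)
opt = GradInspectingOptimizer(model.parameters(), lr=0.1)

def closure():
    # In Lightning, the precision plugin wraps forward + backward in a closure
    # and hands it to optimizer.step(closure=...), so no backward has happened
    # yet when step() is entered.
    opt.zero_grad()
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    return loss

opt.step(closure)  # prints: grad before closure: None (for weight and bias)
```

An optimizer whose fused update reads the gradients eagerly at that point would hit exactly the `None` grads described above.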
Currently, IPEX does not work with PyTorch Lightning. We are working on enabling it.
Observations:

| Configuration | Result |
| --- | --- |
| `ipex.optimize(..., level='O1', dtype=torch.bfloat16)` | no error |
| `ipex.optimize(..., level='O1', dtype=torch.float32)` + AMP bf16 | no error |
| `ipex.optimize(..., level='O1', dtype=torch.bfloat16)` + AMP bf16 | error |

Context: I train an auto-encoder with PyTorch Lightning, and I am trying to move (part of) my codebase to IPEX + torchCCL, with the intention of running on Sapphire Rapids (in the meantime I'm running on Ice Lake). I ran into this issue and have created the minimal reproducible example below. I've posted my full environment for completeness, but I think the important dependencies are:

- pytorch-lightning 1.7.0
- torch 1.12.0+cpu
- intel-extension-for-pytorch 1.12.100+cpu
- oneccl-bind-pt 1.12.0+cpu
Bug report model
```python
import os

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer

###### User imports ######
import oneccl_bindings_for_pytorch
import intel_extension_for_pytorch as ipex
from pytorch_lightning.strategies import DDPStrategy


class RandomDataset(Dataset):
    def __init__(self, size, num_samples=10000):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        _, optim = ipex.optimize(
            model=self,
            optimizer=torch.optim.AdamW(self.layer.parameters(), lr=0.1),  # XXX: Changing to SGD or Adam works as expected
            level='O1',
            inplace=True,
            dtype=torch.bfloat16  # XXX: Commenting out this line also works!
        )
        return optim


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=5,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=True,
        # ADDED
        strategy=DDPStrategy(process_group_backend='ccl', find_unused_parameters=False),
        precision='bf16',  # XXX: changing precision to 32 + removing dtype from optimize also works
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == '__main__':
    run()
```
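For comparison, a closure-free loop in the style of the IPEX bf16 training examples (a minimal sketch, reusing the `Linear(32, 2)` / AdamW setup from the snippet above, without Lightning) runs `loss.backward()` before `optimizer.step()`, so the optimizer sees populated gradients:

```python
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
# ipex.optimize returns the (possibly bf16-converted) model and a wrapped optimizer.
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

data = torch.randn(64, 32)
for i in range(0, 64, 2):
    batch = data[i:i + 2]
    optimizer.zero_grad()
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        loss = model(batch).sum()
    loss.backward()   # gradients exist here...
    optimizer.step()  # ...before step(), unlike the closure-based step(closure=...) above
```

That ordering difference may be why the non-Lightning path does not trip over missing gradients, while the closure-based step in Lightning's precision plugin does.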
Execution log + Error
```
$ python bug_report_model.py
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: ae is an invalid version and will not be supported in a future release
  warnings.warn(
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:712: UserWarning: You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU. Using `precision='bf16'` instead.
  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
2022-08-10 14:09:38,344 - torch.distributed.distributed_c10d - INFO - Added key: store_based_barrier_key:1 to store for rank: 0
2022-08-10 14:09:38,344 - torch.distributed.distributed_c10d - INFO - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
----------------------------------------------------------------------------------------------------
distributed_backend=ccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------
2022:08:10-14:09:38:(3843473) |CCL_WARN| did not find MPI-launcher specific variables, switch to ATL/OFI, to force enable ATL/MPI set CCL_ATL_TRANSPORT=mpi
2022:08:10-14:09:39:(3843473) |CCL_WARN|
/gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu/lib/python3.9/site-packages/intel_extension_for_pytorch/frontend.py:280: UserWarning: IPEX does not support fused/fused split update for
```
Environment
```
$ pdm list --graph
Inside an active virtualenv /gpfs/home5/robertsc/2D-VQ-AE-2/.venv/2D-VQ-AE-2-LDjtrq15-py39-cpu, reuse it.
albumentations 1.2.1 [ required: >=1.1.0 ]
├── numpy 1.23.1 [ required: >=1.11.1 ]
├── opencv-python-headless 4.6.0.66 [ required: >=4.1.1 ]
│   └── numpy 1.23.1 [ required: >=1.19.3 ]
├── pyyaml 6.0 [ required: Any ]
├── qudida 0.0.4 [ required: >=0.0.4 ]
│   ├── numpy 1.23.1 [ required: >=0.18.0 ]
│   ├── opencv-python-headless 4.6.0.66 [ required: >=4.0.1 ]
│   │   └── numpy 1.23.1 [ required: >=1.19.3 ]
│   ├── scikit-learn 1.1.1 [ required: >=0.19.1 ]
│   │   ├── joblib 1.1.0 [ required: >=1.0.0 ]
│   │   ├── numpy 1.23.1 [ required: >=1.17.3 ]
│   │   ├── scipy 1.9.0 [ required: >=1.3.2 ]
│   │   │   └── numpy 1.23.1 [ required: <1.25.0,>=1.18.5 ]
│   │   └── threadpoolctl 3.1.0 [ required: >=2.0.0 ]
│   └── typing-extensions 4.3.0 [ required: Any ]
├── scikit-image 0.19.3 [ required: >=0.16.1 ]
│   ├── imageio 2.21.0 [ required: >=2.4.1 ]
│   │   ├── numpy 1.23.1 [ required: Any ]
│   │   └── pillow 9.2.0 [ required: >=8.3.2 ]
│   ├── networkx 2.8.5 [ required: >=2.2 ]
│   ├── numpy 1.23.1 [ required: >=1.17.0 ]
│   ├── packaging 21.3 [ required: >=20.0 ]
│   │   └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
│   ├── pillow 9.2.0 [ required: !=7.1.0,!=7.1.1,!=8.3.0,>=6.1.0 ]
│   ├── pywavelets 1.3.0 [ required: >=1.1.1 ]
│   │   └── numpy 1.23.1 [ required: >=1.17.3 ]
│   ├── scipy 1.9.0 [ required: >=1.4.1 ]
│   │   └── numpy 1.23.1 [ required: <1.25.0,>=1.18.5 ]
│   └── tifffile 2022.8.3 [ required: >=2019.7.26 ]
│       └── numpy 1.23.1 [ required: >=1.19.2 ]
└── scipy 1.9.0 [ required: Any ]
    └── numpy 1.23.1 [ required: <1.25.0,>=1.18.5 ]
h5py 3.7.0 [ required: >=3.6.0 ]
└── numpy 1.23.1 [ required: >=1.14.5 ]
hydra-submitit-launcher 1.2.0 [ required: Any ]
├── hydra-core 1.2.0 [ required: >=1.1.0.dev7 ]
│   ├── antlr4-python3-runtime 4.9.3 [ required: ==4.9.* ]
│   ├── omegaconf 2.2.2 [ required: ~=2.2 ]
│   │   ├── antlr4-python3-runtime 4.9.3 [ required: ==4.9.* ]
│   │   └── pyyaml 6.0 [ required: >=5.1.0 ]
│   └── packaging 21.3 [ required: Any ]
│       └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
└── submitit 1.4.5 [ required: >=1.3.3 ]
    ├── cloudpickle 2.1.0 [ required: >=1.2.1 ]
    └── typing-extensions 4.3.0 [ required: >=3.7.4.2 ]
intel-extension-for-pytorch 1.12.100+cpu [ required: ==1.12.100+cpu ]
└── psutil 5.9.1 [ required: Any ]
matplotlib 3.5.2 [ required: >=3.5.2 ]
├── cycler 0.11.0 [ required: >=0.10 ]
├── fonttools 4.34.4 [ required: >=4.22.0 ]
├── kiwisolver 1.4.4 [ required: >=1.0.1 ]
├── numpy 1.23.1 [ required: >=1.17 ]
├── packaging 21.3 [ required: >=20.0 ]
│   └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
├── pillow 9.2.0 [ required: >=6.2.0 ]
├── pyparsing 3.0.9 [ required: >=2.2.1 ]
└── python-dateutil 2.8.2 [ required: >=2.7 ]
    └── six 1.16.0 [ required: >=1.5 ]
oneccl-bind-pt 1.12.0+cpu [ required: ==1.12.0+cpu ]
pytorch-lightning 1.7.0 [ required: >=1.6.1 ]
├── fsspec[http] 2022.7.1 [ required: !=2021.06.0,>=2021.05.0 ]
│   ├── aiohttp 3.8.1 [ required: Any ]
│   │   ├── aiosignal 1.2.0 [ required: >=1.1.2 ]
│   │   │   └── frozenlist 1.3.1 [ required: >=1.1.0 ]
│   │   ├── async-timeout 4.0.2 [ required: <5.0,>=4.0.0a3 ]
│   │   ├── attrs 22.1.0 [ required: >=17.3.0 ]
│   │   ├── charset-normalizer 2.1.0 [ required: <3.0,>=2.0 ]
│   │   ├── frozenlist 1.3.1 [ required: >=1.1.1 ]
│   │   ├── multidict 6.0.2 [ required: <7.0,>=4.5 ]
│   │   └── yarl 1.8.1 [ required: <2.0,>=1.0 ]
│   │       ├── idna 3.3 [ required: >=2.0 ]
│   │       └── multidict 6.0.2 [ required: >=4.0 ]
│   └── requests 2.28.1 [ required: Any ]
│       ├── certifi 2022.6.15 [ required: >=2017.4.17 ]
│       ├── charset-normalizer 2.1.0 [ required: <3,>=2 ]
│       ├── idna 3.3 [ required: <4,>=2.5 ]
│       └── urllib3 1.26.11 [ required: <1.27,>=1.21.1 ]
├── numpy 1.23.1 [ required: >=1.17.2 ]
├── packaging 21.3 [ required: >=17.0 ]
│   └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
├── pydeprecate 0.3.2 [ required: >=0.3.1 ]
├── pyyaml 6.0 [ required: >=5.4 ]
├── tensorboard 2.9.1 [ required: >=2.9.1 ]
│   ├── absl-py 1.2.0 [ required: >=0.4 ]
│   ├── google-auth 2.9.1 [ required: <3,>=1.6.3 ]
│   │   ├── cachetools 5.2.0 [ required: <6.0,>=2.0.0 ]
│   │   ├── pyasn1-modules 0.2.8 [ required: >=0.2.1 ]
│   │   │   └── pyasn1 0.4.8 [ required: <0.5.0,>=0.4.6 ]
│   │   ├── rsa 4.9 [ required: <5,>=3.1.4 ]
│   │   │   └── pyasn1 0.4.8 [ required: >=0.1.3 ]
│   │   └── six 1.16.0 [ required: >=1.9.0 ]
│   ├── google-auth-oauthlib 0.4.6 [ required: <0.5,>=0.4.1 ]
│   │   ├── google-auth 2.9.1 [ required: >=1.0.0 ]
│   │   │   ├── cachetools 5.2.0 [ required: <6.0,>=2.0.0 ]
│   │   │   ├── pyasn1-modules 0.2.8 [ required: >=0.2.1 ]
│   │   │   │   └── pyasn1 0.4.8 [ required: <0.5.0,>=0.4.6 ]
│   │   │   ├── rsa 4.9 [ required: <5,>=3.1.4 ]
│   │   │   │   └── pyasn1 0.4.8 [ required: >=0.1.3 ]
│   │   │   └── six 1.16.0 [ required: >=1.9.0 ]
│   │   └── requests-oauthlib 1.3.1 [ required: >=0.7.0 ]
│   │       ├── oauthlib 3.2.0 [ required: >=3.0.0 ]
│   │       └── requests 2.28.1 [ required: >=2.0.0 ]
│   │           ├── certifi 2022.6.15 [ required: >=2017.4.17 ]
│   │           ├── charset-normalizer 2.1.0 [ required: <3,>=2 ]
│   │           ├── idna 3.3 [ required: <4,>=2.5 ]
│   │           └── urllib3 1.26.11 [ required: <1.27,>=1.21.1 ]
│   ├── grpcio 1.47.0 [ required: >=1.24.3 ]
│   │   └── six 1.16.0 [ required: >=1.5.2 ]
│   ├── markdown 3.4.1 [ required: >=2.6.8 ]
│   │   └── importlib-metadata 4.12.0 [ required: >=4.4 ]
│   │       └── zipp 3.8.1 [ required: >=0.5 ]
│   ├── numpy 1.23.1 [ required: >=1.12.0 ]
│   ├── protobuf 3.19.4 [ required: <3.20,>=3.9.2 ]
│   ├── requests 2.28.1 [ required: <3,>=2.21.0 ]
│   │   ├── certifi 2022.6.15 [ required: >=2017.4.17 ]
│   │   ├── charset-normalizer 2.1.0 [ required: <3,>=2 ]
│   │   ├── idna 3.3 [ required: <4,>=2.5 ]
│   │   └── urllib3 1.26.11 [ required: <1.27,>=1.21.1 ]
│   ├── setuptools 63.4.1 [ required: >=41.0.0 ]
│   ├── tensorboard-data-server 0.6.1 [ required: <0.7.0,>=0.6.0 ]
│   ├── tensorboard-plugin-wit 1.8.1 [ required: >=1.6.0 ]
│   ├── werkzeug 2.2.1 [ required: >=1.0.1 ]
│   │   └── markupsafe 2.1.1 [ required: >=2.1.1 ]
│   └── wheel 0.37.1 [ required: >=0.26 ]
├── torch 1.12.0+cpu [ required: >=1.9.* ]
│   └── typing-extensions 4.3.0 [ required: Any ]
├── torchmetrics 0.9.3 [ required: >=0.7.0 ]
│   ├── numpy 1.23.1 [ required: >=1.17.2 ]
│   ├── packaging 21.3 [ required: Any ]
│   │   └── pyparsing 3.0.9 [ required: !=3.0.5,>=2.0.2 ]
│   └── torch 1.12.0+cpu [ required: >=1.3.1 ]
│       └── typing-extensions 4.3.0 [ required: Any ]
├── tqdm 4.64.0 [ required: >=4.57.0 ]
└── typing-extensions 4.3.0 [ required: >=4.0.0 ]
rich 12.5.1 [ required: >=12.4.1 ]
├── commonmark 0.9.1 [ required: <0.10.0,>=0.9.0 ]
└── pygments 2.12.0 [ required: <3.0.0,>=2.6.0 ]
torchvision 0.13.0+cpu [ required: ==0.13.0+cpu ]
├── numpy 1.23.1 [ required: Any ]
├── pillow 9.2.0 [ required: !=8.3.*,>=5.3.0 ]
├── requests 2.28.1 [ required: Any ]
│   ├── certifi 2022.6.15 [ required: >=2017.4.17 ]
│   ├── charset-normalizer 2.1.0 [ required: <3,>=2 ]
│   ├── idna 3.3 [ required: <4,>=2.5 ]
│   └── urllib3 1.26.11 [ required: <1.27,>=1.21.1 ]
├── torch 1.12.0+cpu [ required: ==1.12.0 ]
│   └── typing-extensions 4.3.0 [ required: Any ]
└── typing-extensions 4.3.0 [ required: Any ]
vq-ae-2-2d 0.0.3 [ Not required ]
```