Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.24k stars 3.38k forks source link

mixed precision with Deepspeed #15168

Open wangleiofficial opened 2 years ago

wangleiofficial commented 2 years ago

Bug description

When using mixed precision with Deepspeed, the model resulted in the error: RuntimeError: expected scalar type Float but found Half.

How to reproduce the bug

class SimpleModel(LightningModule):
    """Frozen pretrained encoder (Bert) feeding a trainable classifier head.

    Args:
        args: model init hyperparameters; must contain 'lr' for the optimizer.
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)
        self.pretrain_model = Bert()
        self.classifier = SimpleMLP()

    def _shared_step(self, batch):
        """Encode one batch with the frozen encoder and classify it.

        Args:
            batch: a ``(tokens, labels)`` pair.

        Returns:
            Tuple ``(preds, loss)`` from the classifier head.
        """
        # Fixed: the original unpacked from an undefined name ``x`` and then
        # referenced undefined ``batch_tokens`` / ``label`` (NameError).
        tokens, labels = batch
        # The pretrained encoder is frozen — no gradients flow through it.
        with torch.no_grad():
            embeddings = self.pretrain_model(tokens)
        preds, loss = self.classifier(embeddings, labels)
        return preds, loss

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        preds, _ = self._shared_step(x)
        return preds

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop; it is independent of forward.
        _, loss = self._shared_step(batch)
        self.log("training_loss", loss, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        _, loss = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        _, loss = self._shared_step(batch)
        self.log("test_loss", loss, on_epoch=True, on_step=True, sync_dist=True)

    def configure_optimizers(self):
        # Only the classifier head is optimized; the encoder stays frozen.
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=self.args['lr'], weight_decay=0.01, eps=1e-6)
        return optimizer
# Build the model and train it with DeepSpeed ZeRO stage 3 across 4 GPUs
# using fp16 mixed precision (precision=16) — the configuration that
# triggers the reported "expected scalar type Float but found Half" error.
# NOTE(review): `args` and `dataset` are assumed to be defined earlier in
# the user's script — not shown in this reproduction snippet.
model = SimpleModel(args=args)
trainer = pl.Trainer(devices=4,strategy="deepspeed_stage_3", precision=16, max_epochs=20, accelerator='gpu')
trainer.fit(model, datamodule=dataset)

Error messages and logs


# Error messages and logs here please
Traceback (most recent call last):
  File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/esm_contactmap_pl.py", line 228, in <module>
    main(params)
  File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/esm_contactmap_pl.py", line 196, in main
    trainer.fit(model, datamodule=dataset)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1345, in _run_train
    self._run_sanity_check()
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1413, in _run_sanity_check
    val_loop.run()
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 155, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 128, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 226, in _evaluation_step
    output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1765, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/strategies/deepspeed.py", line 906, in validation_step
    return self.model(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
    return func(*args, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1599, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/strategies/deepspeed.py", line 80, in forward
    return super().forward(*inputs, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
    return self.module.validation_step(*inputs, **kwargs)
  File "/home/wanglei/data/alphafold_db/ProtBert/benchmark/src/esm_contactmap_pl.py", line 129, in validation_step
    protein_dict = self.esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/esm/model.py", line 140, in forward
    x = self.emb_layer_norm_before(x)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/modules/normalization.py", line 189, in forward
    return F.layer_norm(
  File "/home/wanglei/anaconda3/envs/dl/lib/python3.8/site-packages/torch/nn/functional.py", line 2486, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: expected scalar type Float but found Half

Environment

Current Environment ``` * CUDA: - GPU: - GeForce RTX 3090 - GeForce RTX 3090 - GeForce RTX 3090 - GeForce RTX 3090 - GeForce RTX 3090 - GeForce RTX 3090 - GeForce RTX 3090 - GeForce RTX 3090 - available: True - version: 11.3 * Lightning: - pytorch-lightning: 1.6.5 - torch: 1.11.0 - torchaudio: 0.11.0 - torchinfo: 1.7.0 - torchmetrics: 0.10.0 - torchvision: 0.12.0 * Packages: - absl-py: 1.0.0 - aiohttp: 3.8.1 - aiosignal: 1.2.0 - asttokens: 2.0.5 - async-timeout: 4.0.2 - attrs: 21.4.0 - backcall: 0.2.0 - biopython: 1.79 - brotlipy: 0.7.0 - cached-property: 1.5.2 - cachetools: 5.0.0 - certifi: 2022.6.15 - cffi: 1.14.4 - charset-normalizer: 2.1.0 - click: 8.1.3 - cryptography: 37.0.2 - cycler: 0.11.0 - decorator: 5.1.1 - deepspeed: 0.6.6 - deprecated: 1.2.13 - distlib: 0.3.4 - docker-pycreds: 0.4.0 - einops: 0.4.0 - executing: 0.8.3 - fair-esm: 0.4.2 - fairscale: 0.4.6 - filelock: 3.7.0 - fonttools: 4.29.1 - frozenlist: 1.3.0 - fsspec: 2022.2.0 - future: 0.18.2 - gitdb: 4.0.9 - gitpython: 3.1.27 - google-auth: 2.6.0 - google-auth-oauthlib: 0.4.6 - grpcio: 1.44.0 - h5py: 3.6.0 - hjson: 3.0.2 - huggingface-hub: 0.6.0 - idna: 3.3 - importlib-metadata: 4.11.2 - infinibatch: 0.1.0 - ipython: 8.1.0 - jedi: 0.18.1 - joblib: 1.1.0 - kiwisolver: 1.3.2 - lmdb: 1.3.0 - lxml: 4.8.0 - markdown: 3.3.6 - matplotlib: 3.5.1 - matplotlib-inline: 0.1.3 - mkl-fft: 1.3.1 - mkl-random: 1.2.2 - mkl-service: 2.4.0 - multidict: 6.0.2 - ninja: 1.10.2.3 - numpy: 1.22.3 - oauthlib: 3.2.0 - packaging: 21.3 - pandas: 1.4.2 - parso: 0.8.3 - pathtools: 0.1.2 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.1.1 - pip: 22.1.2 - platformdirs: 2.5.2 - plip: 2.2.2 - promise: 2.3 - prompt-toolkit: 3.0.28 - protobuf: 3.19.4 - psutil: 5.9.0 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py-cpuinfo: 8.0.0 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycparser: 2.21 - pydantic: 1.9.1 - pydeprecate: 0.3.1 - pygments: 2.11.2 - pyopenssl: 22.0.0 - pyparsing: 3.0.7 - pysocks: 1.7.1 - python-dateutil: 2.8.2 - 
pytorch-lightning: 1.6.5 - pytz: 2022.1 - pyyaml: 6.0 - redis: 4.3.1 - regex: 2022.4.24 - requests: 2.28.1 - requests-oauthlib: 1.3.1 - rsa: 4.8 - scikit-learn: 1.1.1 - scipy: 1.8.0 - sentencepiece: 0.1.97 - sentry-sdk: 1.5.12 - setproctitle: 1.2.3 - setuptools: 62.6.0 - shortuuid: 1.0.9 - six: 1.16.0 - smmap: 5.0.0 - stack-data: 0.2.0 - tensorboard: 2.8.0 - tensorboard-data-server: 0.6.1 - tensorboard-plugin-wit: 1.8.1 - threadpoolctl: 3.1.0 - tokenizers: 0.12.1 - torch: 1.11.0 - torchaudio: 0.11.0 - torchinfo: 1.7.0 - torchmetrics: 0.10.0 - torchvision: 0.12.0 - tqdm: 4.63.0 - traitlets: 5.1.1 - transformers: 4.21.2 - triton: 1.0.0 - typing-extensions: 4.3.0 - urllib3: 1.26.9 - virtualenv: 20.14.1 - wandb: 0.12.16 - wcwidth: 0.2.5 - werkzeug: 2.0.3 - wheel: 0.37.1 - wrapt: 1.14.1 - yarl: 1.7.2 - zipp: 3.7.0 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.8.0 - version: #1 SMP Thu Nov 8 23:39:32 UTC 2018 ```

More info

No response

cc @awaelchli

Line290 commented 1 year ago

Hi @wangleiofficial, I ran into the same problem as you. Did you manage to fix it?

wangleiofficial commented 1 year ago

@Line290 Not yet. I suspect the pretrained model's parameters are not being handled correctly.

FarzanT commented 1 year ago

I've got the same problem, any fixes yet?

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

ewrfcas commented 1 year ago

+1, deepspeed_stage_2 hits the same error.

YTEP-ZHI commented 6 months ago

Similar problem with deepspeed_stage_1.

File "python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2019, in backward self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) File "python3.9/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward scaled_loss.backward(retain_graph=retain_graph) File "python3.9/site-packages/torch/_tensor.py", line 487, in backward torch.autograd.backward( File "python3.9/site-packages/torch/autograd/init.py", line 200, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: Found dtype Float but expected Half