hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0

[BUG]: Assertion error when forward passing `nn.Embedding` without gradient. #5191

Open namespace-Pt opened 10 months ago

namespace-Pt commented 10 months ago

🐛 Describe the bug

I have a frozen embedding table, i.e. none of the parameters in the table require gradients. When this embedding table is used in the forward pass during training, an AssertionError is raised. (Note that other parameters in the model do require gradients.)

Here is a minimal script to reproduce:

test.py:

import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch({})

plugin = GeminiPlugin(precision="bf16", initial_scale=2**16)
booster = Booster(plugin=plugin)

class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embedding = nn.Embedding(100, 1024)
        self.embedding.requires_grad_(False)  # freeze the embedding table

        self.linear = nn.Linear(1024, 1024)  # trainable layer

    def forward(self, x):
        embed = self.embedding(x)
        transform = self.linear(embed)
        loss = (transform ** 2).sum()
        return loss

model = Model()
optimizer = HybridAdam(model.parameters(), lr=5e-5, betas=(0.9, 0.999), weight_decay=0)
model, optimizer = booster.boost(model, optimizer)[:2]

inputs = torch.tensor([1, 2, 3], device="cuda")
loss = model(inputs)
booster.backward(loss, optimizer)

Run with `torchrun --nproc_per_node 4 test.py`.
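For comparison, here is a minimal sketch of the same model run in plain PyTorch on CPU, without ColossalAI, the Booster or the Gemini plugin. It is an illustration, not part of the original report: `HybridAdam` is swapped for `torch.optim.Adam` and the input stays on CPU. Under standard autograd semantics the backward pass succeeds, the frozen embedding weight receives no gradient and the linear layer does, which suggests the model itself is valid and the failure comes from the Gemini wrapping.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(100, 1024)
        # Freeze the embedding table: its weight will never receive a gradient.
        self.embedding.requires_grad_(False)
        self.linear = nn.Linear(1024, 1024)  # trainable layer

    def forward(self, x):
        embed = self.embedding(x)
        transform = self.linear(embed)
        return (transform ** 2).sum()


model = Model()
# Plain Adam in place of HybridAdam; Adam skips parameters whose .grad is None.
optimizer = torch.optim.Adam(
    model.parameters(), lr=5e-5, betas=(0.9, 0.999), weight_decay=0
)

inputs = torch.tensor([1, 2, 3])  # CPU tensor, no distributed launch needed
loss = model(inputs)
loss.backward()
optimizer.step()

print(model.embedding.weight.grad is None)  # True: frozen table gets no gradient
print(model.linear.weight.grad is not None)  # True: trainable layer does
```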

Environment

Current environment
* CUDA:
  - GPU: NVIDIA A100-SXM4-40GB (x8)
  - available: True
  - version: 11.8
* Lightning:
  - torch: 2.1.0
  - torch-scatter: 2.1.2
  - torchvision: 0.16.0
* Packages:
  accelerate: 0.23.0, aiohttp: 3.8.6, aiosignal: 1.3.1, anyio: 4.1.0, asttokens: 2.4.0, async-timeout: 4.0.3, attrs: 23.1.0, backcall: 0.2.0, bcrypt: 4.1.1, beautifulsoup4: 4.12.2,
  blis: 0.7.11, cachetools: 5.3.1, catalogue: 2.0.10, certifi: 2023.7.22, cffi: 1.16.0, cfgv: 3.4.0, charset-normalizer: 3.3.0, click: 8.1.7, cloudpathlib: 0.16.0, coloredlogs: 15.0.1,
  colossalai: 0.3.4, comm: 0.1.4, confection: 0.1.4, contexttimer: 0.3.3, contourpy: 1.1.1, cryptography: 41.0.7, cycler: 0.12.1, cymem: 2.0.8, cython: 3.0.6, datasets: 2.14.5,
  debugpy: 1.8.0, decorator: 5.1.1, deepspeed: 0.11.1, deprecated: 1.2.14, dill: 0.3.7, distlib: 0.3.8, distro: 1.8.0, einops: 0.7.0, exceptiongroup: 1.1.3, executing: 2.0.0,
  fabric: 3.2.2, faiss: 1.7.2, filelock: 3.13.1, flagembedding: 1.1.5, flash-attn: 2.3.4, flatbuffers: 23.5.26, fonttools: 4.43.1, frozenlist: 1.4.0, fsspec: 2023.6.0, fuzzywuzzy: 0.18.0,
  gmpy2: 2.1.2, google: 3.0.0, h11: 0.14.0, hjson: 3.1.0, httpcore: 1.0.2, httpx: 0.25.2, huggingface-hub: 0.17.3, humanfriendly: 10.0, identify: 2.5.33, idna: 3.4,
  instructorembedding: 1.0.1, invoke: 2.2.0, ipykernel: 6.25.2, ipython: 8.16.1, ipywidgets: 8.1.1, jedi: 0.19.1, jieba: 0.42.1, jinja2: 3.1.2, joblib: 1.3.2, jsonschema: 4.20.0,
  jsonschema-specifications: 2023.11.2, jupyter-client: 8.4.0, jupyter-core: 5.4.0, jupyterlab-widgets: 3.0.9, keybert: 0.8.3, kiwisolver: 1.4.5, langcodes: 3.3.0, levenshtein: 0.23.0, lightgbm: 4.1.0, marisa-trie: 1.1.0,
  markdown-it-py: 3.0.0, markupsafe: 2.1.1, matplotlib: 3.8.0, matplotlib-inline: 0.1.6, mdurl: 0.1.2, mpmath: 1.3.0, msgpack: 1.0.7, multidict: 6.0.4, multiprocess: 0.70.15, murmurhash: 1.0.10,
  nest-asyncio: 1.5.8, networkx: 3.1, ninja: 1.11.1.1, nltk: 3.8.1, nmslib: 2.1.1, nodeenv: 1.8.0, numpy: 1.26.1, nvidia-ml-py: 12.535.108, nvitop: 1.3.1, onnxruntime: 1.16.3,
  openai: 1.3.9, packaging: 23.2, pandas: 2.1.1, paramiko: 3.3.1, parso: 0.8.3, peft: 0.6.1, pexpect: 4.8.0, pickleshare: 0.7.5, pillow: 10.1.0, pip: 23.3,
  platformdirs: 3.11.0, pre-commit: 3.6.0, preshed: 3.0.9, prompt-toolkit: 3.0.39, protobuf: 4.25.0, psutil: 5.9.6, ptyprocess: 0.7.0, pure-eval: 0.2.2, py-cpuinfo: 9.0.0, pyarrow: 13.0.0,
  pybind11: 2.6.1, pycparser: 2.21, pydantic: 1.10.13, pygments: 2.16.1, pyjnius: 1.6.1, pynacl: 1.5.0, pyparsing: 3.1.1, python-dateutil: 2.8.2, python-levenshtein: 0.23.0, pytrec-eval: 0.5,
  pytz: 2023.3.post1, pyyaml: 6.0, pyzmq: 25.1.1, rapidfuzz: 3.5.1, ray: 2.8.1, referencing: 0.32.0, regex: 2023.10.3, requests: 2.31.0, rich: 13.6.0, rouge: 1.0.1,
  rpds-py: 0.13.2, safetensors: 0.4.0, scikit-learn: 1.3.1, scipy: 1.11.3, seaborn: 0.13.0, sentence-transformers: 2.2.2, sentencepiece: 0.1.99, setuptools: 68.0.0, six: 1.16.0, smart-open: 6.4.0,
  sniffio: 1.3.0, soupsieve: 2.5, spacy: 3.7.2, spacy-legacy: 3.0.12, spacy-loggers: 1.0.5, srsly: 2.4.8, stack-data: 0.6.3, sympy: 1.11.1, termcolor: 2.3.0, thinc: 8.2.1,
  threadpoolctl: 3.2.0, tiktoken: 0.5.2, tokenizers: 0.14.1, torch: 2.1.0, torch-scatter: 2.1.2, torchvision: 0.16.0, tornado: 6.3.3, tqdm: 4.66.1, traitlets: 5.11.2, transformers: 4.34.1,
  triton: 2.1.0, typer: 0.9.0, typing-extensions: 4.7.1, tzdata: 2023.3, urllib3: 2.0.7, virtualenv: 20.25.0, wasabi: 1.1.2, wcwidth: 0.2.8, weasel: 0.3.4, wheel: 0.41.2,
  widgetsnbextension: 4.0.9, wrapt: 1.16.0, xformers: 0.0.22.post7, xxhash: 3.4.1, yarl: 1.9.2
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.10.13
  - release: 5.4.0-147-generic
  - version: #164-Ubuntu SMP Tue Mar 21 14:23:17 UTC 2023

flybird11111 commented 10 months ago

Hi, this appears to be an incompatibility in the Gemini strategy. I can confirm that the issue occurs with a frozen embedding table.
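A possible workaround, not confirmed by the maintainers and untested with Gemini, is to keep the frozen table out of `model.parameters()` entirely: store it as a buffer and call `torch.nn.functional.embedding` directly, so the plugin never has to manage a parameter with `requires_grad=False`. `FrozenEmbedding` is a hypothetical helper name. Whether Gemini handles such a buffer correctly (device placement, bf16 casting, checkpointing) is an assumption that would need to be verified.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FrozenEmbedding(nn.Module):
    """Hypothetical helper: an embedding table stored as a buffer, not a parameter."""

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        # Initialise like nn.Embedding (standard normal), but register the table
        # as a buffer so it never appears in model.parameters().
        self.register_buffer("weight", torch.randn(num_embeddings, embedding_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.embedding(x, self.weight)


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = FrozenEmbedding(100, 1024)  # frozen, not a parameter
        self.linear = nn.Linear(1024, 1024)  # trainable layer

    def forward(self, x):
        return (self.linear(self.embedding(x)) ** 2).sum()


model = Model()
# Every remaining parameter is trainable.
print(all(p.requires_grad for p in model.parameters()))  # True
print(len(list(model.parameters())))  # 2 (linear.weight and linear.bias)
```

Because the buffer is registered under the name `weight`, its state-dict key is still `embedding.weight`, the same key `nn.Embedding` produces, so pretrained embedding weights should load without renaming.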

namespace-Pt commented 10 months ago

Okay thank you.

luckyyangrun commented 10 months ago

Are there any plans to fix it?