[BUG]: Assertion error when forward passing `nn.Embedding` without gradient. #5191

Open namespace-Pt opened 6 months ago

namespace-Pt commented 6 months ago

🐛 Describe the bug

I have a frozen embedding table, i.e. none of the parameters in the table require gradients. When this embedding table is used in the forward pass during training, an AssertionError is raised. (Note that other parameters in the model do require gradients.)

Here is a minimal script to reproduce:

>>> test.py

import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam


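# Initialize the distributed environment from the variables set by torchrun.
colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(precision="bf16", initial_scale=2**16)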
booster = Booster(plugin=plugin)

class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
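        self.embedding = nn.Embedding(100, 1024)
        # Freeze the embedding table so that its weight does not require gradients.
        self.embedding.requires_grad_(False)
        self.linear = nn.Linear(1024, 1024)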

    def forward(self, x):
        embed = self.embedding(x)
        transform = self.linear(embed)
        loss = (transform ** 2).sum()
        return loss

model = Model()
optimizer = HybridAdam(model.parameters(), lr=5e-5, betas=(0.9, 0.999), weight_decay=0)
model, optimizer = booster.boost(model, optimizer)[:2]

inputs = torch.tensor([1,2,3], device="cuda")
loss = model(inputs)
booster.backward(loss, optimizer)

Run with torchrun --nproc_per_node 4 test.py
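
For comparison, here is a minimal sketch of the same model in plain PyTorch, without ColossalAI, Gemini, or a distributed launch. It assumes the table is frozen with `requires_grad_(False)` and runs on CPU. It only shows the baseline behaviour one would expect (forward and backward succeed, and only the linear layer receives gradients); it does not reproduce the AssertionError.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(100, 1024)
        # Assumption: the table is frozen by disabling gradients on its weight.
        self.embedding.requires_grad_(False)
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        embed = self.embedding(x)
        transform = self.linear(embed)
        return (transform ** 2).sum()


model = Model()
inputs = torch.tensor([1, 2, 3])  # CPU tensor, no distributed setup involved
loss = model(inputs)
loss.backward()

# Record which parameters received a gradient after the backward pass.
grads = {name: p.grad is not None for name, p in model.named_parameters()}
print(grads)
# {'embedding.weight': False, 'linear.weight': True, 'linear.bias': True}
```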


Current environment

* CUDA:
    - GPU:
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
    - available: True
    - version: 11.8
* Lightning:
    - torch: 2.1.0
    - torch-scatter: 2.1.2
    - torchvision: 0.16.0
* Packages:
    - accelerate: 0.23.0
    - aiohttp: 3.8.6
    - aiosignal: 1.3.1
    - anyio: 4.1.0
    - asttokens: 2.4.0
    - async-timeout: 4.0.3
    - attrs: 23.1.0
    - backcall: 0.2.0
    - bcrypt: 4.1.1
    - beautifulsoup4: 4.12.2
    - blis: 0.7.11
    - cachetools: 5.3.1
    - catalogue: 2.0.10
    - certifi: 2023.7.22
    - cffi: 1.16.0
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.0
    - click: 8.1.7
    - cloudpathlib: 0.16.0
    - coloredlogs: 15.0.1
    - colossalai: 0.3.4
    - comm: 0.1.4
    - confection: 0.1.4
    - contexttimer: 0.3.3
    - contourpy: 1.1.1
    - cryptography: 41.0.7
    - cycler: 0.12.1
    - cymem: 2.0.8
    - cython: 3.0.6
    - datasets: 2.14.5
    - debugpy: 1.8.0
    - decorator: 5.1.1
    - deepspeed: 0.11.1
    - deprecated: 1.2.14
    - dill: 0.3.7
    - distlib: 0.3.8
    - distro: 1.8.0
    - einops: 0.7.0
    - exceptiongroup: 1.1.3
    - executing: 2.0.0
    - fabric: 3.2.2
    - faiss: 1.7.2
    - filelock: 3.13.1
    - flagembedding: 1.1.5
    - flash-attn: 2.3.4
    - flatbuffers: 23.5.26
    - fonttools: 4.43.1
    - frozenlist: 1.4.0
    - fsspec: 2023.6.0
    - fuzzywuzzy: 0.18.0
    - gmpy2: 2.1.2
    - google: 3.0.0
    - h11: 0.14.0
    - hjson: 3.1.0
    - httpcore: 1.0.2
    - httpx: 0.25.2
    - huggingface-hub: 0.17.3
    - humanfriendly: 10.0
    - identify: 2.5.33
    - idna: 3.4
    - instructorembedding: 1.0.1
    - invoke: 2.2.0
    - ipykernel: 6.25.2
    - ipython: 8.16.1
    - ipywidgets: 8.1.1
    - jedi: 0.19.1
    - jieba: 0.42.1
    - jinja2: 3.1.2
    - joblib: 1.3.2
    - jsonschema: 4.20.0
    - jsonschema-specifications: 2023.11.2
    - jupyter-client: 8.4.0
    - jupyter-core: 5.4.0
    - jupyterlab-widgets: 3.0.9
    - keybert: 0.8.3
    - kiwisolver: 1.4.5
    - langcodes: 3.3.0
    - levenshtein: 0.23.0
    - lightgbm: 4.1.0
    - marisa-trie: 1.1.0
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.1
    - matplotlib: 3.8.0
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - mpmath: 1.3.0
    - msgpack: 1.0.7
    - multidict: 6.0.4
    - multiprocess: 0.70.15
    - murmurhash: 1.0.10
    - nest-asyncio: 1.5.8
    - networkx: 3.1
    - ninja:
    - nltk: 3.8.1
    - nmslib: 2.1.1
    - nodeenv: 1.8.0
    - numpy: 1.26.1
    - nvidia-ml-py: 12.535.108
    - nvitop: 1.3.1
    - onnxruntime: 1.16.3
    - openai: 1.3.9
    - packaging: 23.2
    - pandas: 2.1.1
    - paramiko: 3.3.1
    - parso: 0.8.3
    - peft: 0.6.1
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 10.1.0
    - pip: 23.3
    - platformdirs: 3.11.0
    - pre-commit: 3.6.0
    - preshed: 3.0.9
    - prompt-toolkit: 3.0.39
    - protobuf: 4.25.0
    - psutil: 5.9.6
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - pyarrow: 13.0.0
    - pybind11: 2.6.1
    - pycparser: 2.21
    - pydantic: 1.10.13
    - pygments: 2.16.1
    - pyjnius: 1.6.1
    - pynacl: 1.5.0
    - pyparsing: 3.1.1
    - python-dateutil: 2.8.2
    - python-levenshtein: 0.23.0
    - pytrec-eval: 0.5
    - pytz: 2023.3.post1
    - pyyaml: 6.0
    - pyzmq: 25.1.1
    - rapidfuzz: 3.5.1
    - ray: 2.8.1
    - referencing: 0.32.0
    - regex: 2023.10.3
    - requests: 2.31.0
    - rich: 13.6.0
    - rouge: 1.0.1
    - rpds-py: 0.13.2
    - safetensors: 0.4.0
    - scikit-learn: 1.3.1
    - scipy: 1.11.3
    - seaborn: 0.13.0
    - sentence-transformers: 2.2.2
    - sentencepiece: 0.1.99
    - setuptools: 68.0.0
    - six: 1.16.0
    - smart-open: 6.4.0
    - sniffio: 1.3.0
    - soupsieve: 2.5
    - spacy: 3.7.2
    - spacy-legacy: 3.0.12
    - spacy-loggers: 1.0.5
    - srsly: 2.4.8
    - stack-data: 0.6.3
    - sympy: 1.11.1
    - termcolor: 2.3.0
    - thinc: 8.2.1
    - threadpoolctl: 3.2.0
    - tiktoken: 0.5.2
    - tokenizers: 0.14.1
    - torch: 2.1.0
    - torch-scatter: 2.1.2
    - torchvision: 0.16.0
    - tornado: 6.3.3
    - tqdm: 4.66.1
    - traitlets: 5.11.2
    - transformers: 4.34.1
    - triton: 2.1.0
    - typer: 0.9.0
    - typing-extensions: 4.7.1
    - tzdata: 2023.3
    - urllib3: 2.0.7
    - virtualenv: 20.25.0
    - wasabi: 1.1.2
    - wcwidth: 0.2.8
    - weasel: 0.3.4
    - wheel: 0.41.2
    - widgetsnbextension: 4.0.9
    - wrapt: 1.16.0
    - xformers: 0.0.22.post7
    - xxhash: 3.4.1
    - yarl: 1.9.2
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.13
    - release: 5.4.0-147-generic
    - version: #164-Ubuntu SMP Tue Mar 21 14:23:17 UTC 2023

flybird11111 commented 6 months ago

Hi, a frozen embedding table seems to be incompatible with the Gemini strategy. We do indeed encounter this issue when the embedding table is frozen.
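
When checking whether a model falls into this case (some parameters frozen, others trainable), it may help to list both groups before boosting. The sketch below uses a hypothetical helper, `split_params_by_grad`, which is not part of ColossalAI or PyTorch; it relies only on `named_parameters()` and `requires_grad`. It is a diagnostic aid, not a fix for the incompatibility.

```python
import torch.nn as nn


def split_params_by_grad(model: nn.Module):
    """Hypothetical helper: return (trainable, frozen) parameter names."""
    trainable, frozen = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable.append(name)
        else:
            frozen.append(name)
    return trainable, frozen


# Small stand-in model: a frozen embedding followed by a trainable linear layer.
model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16))
model[0].requires_grad_(False)

trainable, frozen = split_params_by_grad(model)
print("trainable:", trainable)  # trainable: ['1.weight', '1.bias']
print("frozen:", frozen)        # frozen: ['0.weight']
```

A non-empty `frozen` list alongside a non-empty `trainable` list matches the situation described in this issue.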

namespace-Pt commented 6 months ago

Okay, thank you.

luckyyangrun commented 6 months ago

Are there any plans to fix it?