I have a frozen embedding table, i.e. none of the parameters in the table require gradients. When this embedding table is used in the forward pass during training, an AssertionError is raised. (Note that other parameters in the model do require gradients.)
Here is a minimal script to reproduce:
>>> test.py
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Initialize the distributed environment from the variables set by torchrun.
colossalai.launch_from_torch({})

plugin = GeminiPlugin(precision="bf16", initial_scale=2**16)
booster = Booster(plugin=plugin)


class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embedding = nn.Embedding(100, 1024)
        # Freeze the embedding table: its weight no longer requires gradients.
        self.embedding.requires_grad_(False)
        # The linear layer stays trainable.
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        embed = self.embedding(x)
        transform = self.linear(embed)
        loss = (transform ** 2).sum()
        return loss


model = Model()
optimizer = HybridAdam(model.parameters(), lr=5e-5, betas=(0.9, 0.999), weight_decay=0)
model, optimizer = booster.boost(model, optimizer)[:2]

inputs = torch.tensor([1, 2, 3], device="cuda")
loss = model(inputs)
booster.backward(loss, optimizer)
🐛 Describe the bug
Run with
torchrun --nproc_per_node 4 test.py
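To make the mix of frozen and trainable parameters in the script explicit, here is a minimal sketch of a helper that splits parameter names by their `requires_grad` flag. The helper name `split_by_requires_grad` and the `SimpleNamespace` stand-ins are hypothetical and only illustrate the setup; they are not part of ColossalAI or of the reproduction script. The stand-ins mirror the three parameters of `Model` so that the snippet runs without torch or a GPU.

```python
from types import SimpleNamespace


def split_by_requires_grad(named_params):
    """Return (trainable, frozen) lists of parameter names.

    Accepts any iterable of (name, param) pairs where `param` exposes a
    boolean `requires_grad`, such as the output of `model.named_parameters()`.
    """
    trainable, frozen = [], []
    for name, param in named_params:
        (trainable if param.requires_grad else frozen).append(name)
    return trainable, frozen


# Hypothetical stand-ins for the parameters of `Model` in test.py:
# the embedding weight is frozen, the linear weight and bias are trainable.
fake_named_params = [
    ("embedding.weight", SimpleNamespace(requires_grad=False)),
    ("linear.weight", SimpleNamespace(requires_grad=True)),
    ("linear.bias", SimpleNamespace(requires_grad=True)),
]

trainable, frozen = split_by_requires_grad(fake_named_params)
print("trainable:", trainable)
print("frozen:", frozen)
```

With the real model, the same helper could be called as `split_by_requires_grad(model.named_parameters())` before `booster.boost` to confirm that only `embedding.weight` is frozen.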
Environment
Current environment
* CUDA:
    - GPU:
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
    - available: True
    - version: 11.8
* Lightning:
    - torch: 2.1.0
    - torch-scatter: 2.1.2
    - torchvision: 0.16.0
* Packages:
    - accelerate: 0.23.0
    - aiohttp: 3.8.6
    - aiosignal: 1.3.1
    - anyio: 4.1.0
    - asttokens: 2.4.0
    - async-timeout: 4.0.3
    - attrs: 23.1.0
    - backcall: 0.2.0
    - bcrypt: 4.1.1
    - beautifulsoup4: 4.12.2
    - blis: 0.7.11
    - cachetools: 5.3.1
    - catalogue: 2.0.10
    - certifi: 2023.7.22
    - cffi: 1.16.0
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.0
    - click: 8.1.7
    - cloudpathlib: 0.16.0
    - coloredlogs: 15.0.1
    - colossalai: 0.3.4
    - comm: 0.1.4
    - confection: 0.1.4
    - contexttimer: 0.3.3
    - contourpy: 1.1.1
    - cryptography: 41.0.7
    - cycler: 0.12.1
    - cymem: 2.0.8
    - cython: 3.0.6
    - datasets: 2.14.5
    - debugpy: 1.8.0
    - decorator: 5.1.1
    - deepspeed: 0.11.1
    - deprecated: 1.2.14
    - dill: 0.3.7
    - distlib: 0.3.8
    - distro: 1.8.0
    - einops: 0.7.0
    - exceptiongroup: 1.1.3
    - executing: 2.0.0
    - fabric: 3.2.2
    - faiss: 1.7.2
    - filelock: 3.13.1
    - flagembedding: 1.1.5
    - flash-attn: 2.3.4
    - flatbuffers: 23.5.26
    - fonttools: 4.43.1
    - frozenlist: 1.4.0
    - fsspec: 2023.6.0
    - fuzzywuzzy: 0.18.0
    - gmpy2: 2.1.2
    - google: 3.0.0
    - h11: 0.14.0
    - hjson: 3.1.0
    - httpcore: 1.0.2
    - httpx: 0.25.2
    - huggingface-hub: 0.17.3
    - humanfriendly: 10.0
    - identify: 2.5.33
    - idna: 3.4
    - instructorembedding: 1.0.1
    - invoke: 2.2.0
    - ipykernel: 6.25.2
    - ipython: 8.16.1
    - ipywidgets: 8.1.1
    - jedi: 0.19.1
    - jieba: 0.42.1
    - jinja2: 3.1.2
    - joblib: 1.3.2
    - jsonschema: 4.20.0
    - jsonschema-specifications: 2023.11.2
    - jupyter-client: 8.4.0
    - jupyter-core: 5.4.0
    - jupyterlab-widgets: 3.0.9
    - keybert: 0.8.3
    - kiwisolver: 1.4.5
    - langcodes: 3.3.0
    - levenshtein: 0.23.0
    - lightgbm: 4.1.0
    - marisa-trie: 1.1.0
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.1
    - matplotlib: 3.8.0
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - mpmath: 1.3.0
    - msgpack: 1.0.7
    - multidict: 6.0.4
    - multiprocess: 0.70.15
    - murmurhash: 1.0.10
    - nest-asyncio: 1.5.8
    - networkx: 3.1
    - ninja: 1.11.1.1
    - nltk: 3.8.1
    - nmslib: 2.1.1
    - nodeenv: 1.8.0
    - numpy: 1.26.1
    - nvidia-ml-py: 12.535.108
    - nvitop: 1.3.1
    - onnxruntime: 1.16.3
    - openai: 1.3.9
    - packaging: 23.2
    - pandas: 2.1.1
    - paramiko: 3.3.1
    - parso: 0.8.3
    - peft: 0.6.1
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 10.1.0
    - pip: 23.3
    - platformdirs: 3.11.0
    - pre-commit: 3.6.0
    - preshed: 3.0.9
    - prompt-toolkit: 3.0.39
    - protobuf: 4.25.0
    - psutil: 5.9.6
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - pyarrow: 13.0.0
    - pybind11: 2.6.1
    - pycparser: 2.21
    - pydantic: 1.10.13
    - pygments: 2.16.1
    - pyjnius: 1.6.1
    - pynacl: 1.5.0
    - pyparsing: 3.1.1
    - python-dateutil: 2.8.2
    - python-levenshtein: 0.23.0
    - pytrec-eval: 0.5
    - pytz: 2023.3.post1
    - pyyaml: 6.0
    - pyzmq: 25.1.1
    - rapidfuzz: 3.5.1
    - ray: 2.8.1
    - referencing: 0.32.0
    - regex: 2023.10.3
    - requests: 2.31.0
    - rich: 13.6.0
    - rouge: 1.0.1
    - rpds-py: 0.13.2
    - safetensors: 0.4.0
    - scikit-learn: 1.3.1
    - scipy: 1.11.3
    - seaborn: 0.13.0
    - sentence-transformers: 2.2.2
    - sentencepiece: 0.1.99
    - setuptools: 68.0.0
    - six: 1.16.0
    - smart-open: 6.4.0
    - sniffio: 1.3.0
    - soupsieve: 2.5
    - spacy: 3.7.2
    - spacy-legacy: 3.0.12
    - spacy-loggers: 1.0.5
    - srsly: 2.4.8
    - stack-data: 0.6.3
    - sympy: 1.11.1
    - termcolor: 2.3.0
    - thinc: 8.2.1
    - threadpoolctl: 3.2.0
    - tiktoken: 0.5.2
    - tokenizers: 0.14.1
    - torch: 2.1.0
    - torch-scatter: 2.1.2
    - torchvision: 0.16.0
    - tornado: 6.3.3
    - tqdm: 4.66.1
    - traitlets: 5.11.2
    - transformers: 4.34.1
    - triton: 2.1.0
    - typer: 0.9.0
    - typing-extensions: 4.7.1
    - tzdata: 2023.3
    - urllib3: 2.0.7
    - virtualenv: 20.25.0
    - wasabi: 1.1.2
    - wcwidth: 0.2.8
    - weasel: 0.3.4
    - wheel: 0.41.2
    - widgetsnbextension: 4.0.9
    - wrapt: 1.16.0
    - xformers: 0.0.22.post7
    - xxhash: 3.4.1
    - yarl: 1.9.2
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.13
    - release: 5.4.0-147-generic
    - version: #164-Ubuntu SMP Tue Mar 21 14:23:17 UTC 2023