yfzhang114 / SliME

✨✨Beyond LLaVA-HD: Diving into High-Resolution Large Multimodal Models
Apache License 2.0

【RuntimeError: "erfinv_cuda" not implemented for 'BFloat16'】 #8

Closed: Luo-Z13 closed this issue 1 month ago

Luo-Z13 commented 1 month ago

Hello, I get the error "RuntimeError: 'erfinv_cuda' not implemented for 'BFloat16'" when I try to fine-tune from the SliME-Vicuna-7B weights. Could you please suggest a fix? My script:

deepspeed --master_port=$((RANDOM + 10000)) --include localhost:0,1,2,3 llava/train/train_mem.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path $PROJECTOR_DIR \
    --version v1 \
    --data_path finetune \
    --image_folder $DATA_DIR \
    --vision_tower /clip-vit-large-patch14-336 \
    --pretrain_mm_mlp_adapter $PROJECTOR_DIR/mm_projector.bin \
    --pretrain_mm_re_sampler $PROJECTOR_DIR/sampler.bin \
    --mm_projector_type gated \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --task SFT \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir /SliME_out_checkpoints \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 10000 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb \
    --mm_patch_merge_type spatial \
    --image_aspect_ratio $padding \
    --mm_resampler_type cosine \
    --mm_resampler_topp 0.95 \
    --mm_resampler_dim 144 \
    --mm_resampler_temp 1.0

My environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main    defaults
_openmp_mutex             5.1                       1_gnu    defaults
accelerate                0.21.0                   pypi_0    pypi
aiofiles                  23.2.1                   pypi_0    pypi
aiohttp                   3.9.5                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
altair                    5.3.0                    pypi_0    pypi
annotated-types           0.7.0                    pypi_0    pypi
anyio                     4.4.0                    pypi_0    pypi
async-timeout             4.0.3                    pypi_0    pypi
attrs                     23.2.0                   pypi_0    pypi
bitsandbytes              0.42.0                   pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6    defaults
ca-certificates           2024.7.2             h06a4308_0    defaults
certifi                   2024.7.4                 pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
click                     8.1.7                    pypi_0    pypi
contourpy                 1.2.1                    pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
datasets                  2.20.0                   pypi_0    pypi
deepspeed                 0.12.6                   pypi_0    pypi
dill                      0.3.8                    pypi_0    pypi
dnspython                 2.6.1                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
einops                    0.6.1                    pypi_0    pypi
einops-exts               0.0.4                    pypi_0    pypi
email-validator           2.2.0                    pypi_0    pypi
exceptiongroup            1.2.2                    pypi_0    pypi
fastapi                   0.111.1                  pypi_0    pypi
fastapi-cli               0.0.4                    pypi_0    pypi
ffmpy                     0.3.2                    pypi_0    pypi
filelock                  3.15.4                   pypi_0    pypi
flash-attn                2.6.1                    pypi_0    pypi
fonttools                 4.53.1                   pypi_0    pypi
frozenlist                1.4.1                    pypi_0    pypi
fsspec                    2024.5.0                 pypi_0    pypi
gitdb                     4.0.11                   pypi_0    pypi
gitpython                 3.1.43                   pypi_0    pypi
gradio                    4.16.0                   pypi_0    pypi
gradio-client             0.8.1                    pypi_0    pypi
h11                       0.14.0                   pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
httpcore                  0.17.3                   pypi_0    pypi
httptools                 0.6.1                    pypi_0    pypi
httpx                     0.24.0                   pypi_0    pypi
huggingface-hub           0.23.4                   pypi_0    pypi
idna                      3.7                      pypi_0    pypi
importlib-resources       6.4.0                    pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
jsonschema                4.23.0                   pypi_0    pypi
jsonschema-specifications 2023.12.1                pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
latex2mathml              3.77.0                   pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1    defaults
libffi                    3.4.4                h6a678d5_1    defaults
libgcc-ng                 11.2.0               h1234567_1    defaults
libgomp                   11.2.0               h1234567_1    defaults
libstdcxx-ng              11.2.0               h1234567_1    defaults
libuuid                   1.41.5               h5eee18b_0    defaults
llava                     1.2.2.post1              pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markdown2                 2.5.0                    pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
matplotlib                3.9.1                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
multidict                 6.0.5                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0    defaults
networkx                  3.3                      pypi_0    pypi
ninja                     1.11.1.1                 pypi_0    pypi
numpy                     1.26.4                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.18.1                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.5.82                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
openssl                   3.0.14               h5eee18b_0    defaults
orjson                    3.10.6                   pypi_0    pypi
packaging                 24.1                     pypi_0    pypi
pandas                    2.2.2                    pypi_0    pypi
peft                      0.11.1                   pypi_0    pypi
pillow                    10.4.0                   pypi_0    pypi
pip                       24.1.2                   pypi_0    pypi
platformdirs              4.2.2                    pypi_0    pypi
protobuf                  5.27.2                   pypi_0    pypi
psutil                    6.0.0                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
pyarrow                   16.1.0                   pypi_0    pypi
pyarrow-hotfix            0.6                      pypi_0    pypi
pydantic                  2.8.2                    pypi_0    pypi
pydantic-core             2.20.1                   pypi_0    pypi
pydub                     0.25.1                   pypi_0    pypi
pygments                  2.18.0                   pypi_0    pypi
pynvml                    11.5.2                   pypi_0    pypi
pyparsing                 3.1.2                    pypi_0    pypi
python                    3.10.14              h955ad1f_1    defaults
python-dateutil           2.9.0.post0              pypi_0    pypi
python-dotenv             1.0.1                    pypi_0    pypi
python-multipart          0.0.9                    pypi_0    pypi
pytz                      2024.1                   pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0    defaults
referencing               0.35.1                   pypi_0    pypi
regex                     2024.5.15                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
rich                      13.7.1                   pypi_0    pypi
rpds-py                   0.19.0                   pypi_0    pypi
ruff                      0.5.2                    pypi_0    pypi
safetensors               0.4.3                    pypi_0    pypi
scikit-learn              1.2.2                    pypi_0    pypi
scipy                     1.14.0                   pypi_0    pypi
semantic-version          2.10.0                   pypi_0    pypi
sentencepiece             0.1.99                   pypi_0    pypi
sentry-sdk                2.10.0                   pypi_0    pypi
setproctitle              1.3.3                    pypi_0    pypi
setuptools                69.5.1          py310h06a4308_0    defaults
shellingham               1.5.4                    pypi_0    pypi
shortuuid                 1.0.13                   pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
smmap                     5.0.1                    pypi_0    pypi
sniffio                   1.3.1                    pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0    defaults
starlette                 0.37.2                   pypi_0    pypi
svgwrite                  1.4.3                    pypi_0    pypi
sympy                     1.13.0                   pypi_0    pypi
threadpoolctl             3.5.0                    pypi_0    pypi
timm                      0.6.13                   pypi_0    pypi
tk                        8.6.14               h39e8969_0    defaults
tokenizers                0.15.1                   pypi_0    pypi
tomlkit                   0.12.0                   pypi_0    pypi
toolz                     0.12.1                   pypi_0    pypi
torch                     2.1.2                    pypi_0    pypi
torchvision               0.16.2                   pypi_0    pypi
tqdm                      4.66.4                   pypi_0    pypi
transformers              4.37.2                   pypi_0    pypi
triton                    2.1.0                    pypi_0    pypi
typer                     0.12.3                   pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
tzdata                    2024.1                   pypi_0    pypi
urllib3                   2.2.2                    pypi_0    pypi
uvicorn                   0.30.1                   pypi_0    pypi
uvloop                    0.19.0                   pypi_0    pypi
wandb                     0.17.4                   pypi_0    pypi
watchfiles                0.22.0                   pypi_0    pypi
wavedrom                  2.0.3.post3              pypi_0    pypi
websockets                11.0.3                   pypi_0    pypi
wheel                     0.43.0          py310h06a4308_0    defaults
xxhash                    3.4.1                    pypi_0    pypi
xz                        5.4.6                h5eee18b_1    defaults
yarl                      1.9.4                    pypi_0    pypi
zlib                      1.2.13               h5eee18b_1    defaults

The full error output (all four ranks raise the same error):

Traceback (most recent call last):
  File "/project/luojunwei/VisionLanguage/Code/SliME/llava/train/train_mem.py", line 9, in <module>
    train(attn_implementation="flash_attention_2")
  File "/project/luojunwei/VisionLanguage/Code/SliME/llava/train/train.py", line 1010, in train
    model = LlavaLlamaForCausalLM.from_pretrained(
  File "/scratch/luojunwei/miniconda-3/envs/llavahd/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3594, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/scratch/luojunwei/miniconda-3/envs/llavahd/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 459, in wrapper
    f(module, *args, **kwargs)
  File "/project/luojunwei/VisionLanguage/Code/SliME/llava/model/language_model/llava_llama.py", line 46, in __init__
    self.model = LlavaLlamaModel(config)
  File "/scratch/luojunwei/miniconda-3/envs/llavahd/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 459, in wrapper
    f(module, *args, **kwargs)
  File "/project/luojunwei/VisionLanguage/Code/SliME/llava/model/language_model/llava_llama.py", line 38, in __init__
    super(LlavaLlamaModel, self).__init__(config)
  File "/project/luojunwei/VisionLanguage/Code/SliME/llava/model/llava_arch.py", line 35, in __init__
    self.mm_projector = build_vision_projector(config)
  File "/project/luojunwei/VisionLanguage/Code/SliME/llava/model/multimodal_projector/builder.py", line 239, in build_vision_projector
    return GatedBlock(config)
  File "/scratch/luojunwei/miniconda-3/envs/llavahd/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 459, in wrapper
    f(module, *args, **kwargs)
  File "/project/luojunwei/VisionLanguage/Code/SliME/llava/model/multimodal_projector/builder.py", line 43, in __init__
    self.attn = Resampler(
  File "/scratch/luojunwei/miniconda-3/envs/llavahd/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 459, in wrapper
    f(module, *args, **kwargs)
  File "/project/luojunwei/VisionLanguage/Code/SliME/llava/model/multimodal_resampler/sampler.py", line 120, in __init__
    trunc_normal_(self.query, std=.02)
  File "/scratch/luojunwei/miniconda-3/envs/llavahd/lib/python3.10/site-packages/torch/nn/init.py", line 183, in trunc_normal_
    return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
  File "/scratch/luojunwei/miniconda-3/envs/llavahd/lib/python3.10/site-packages/torch/nn/init.py", line 46, in _no_grad_trunc_normal_
    tensor.erfinv_()
RuntimeError: "erfinv_cuda" not implemented for 'BFloat16'

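The last frames isolate the failing call: trunc_normal_ runs tensor.erfinv_() on self.query, which here is a BFloat16 CUDA tensor. A quick way to check whether a given PyTorch build and device support that op (for example, after changing the torch version) is to call erfinv_ directly. The helper below, erfinv_supports, is a hypothetical diagnostic, not part of SliME or PyTorch; based on the traceback, it would be expected to return False for torch.bfloat16 on "cuda" in the torch 2.1.2 environment listed above.

```python
import torch


def erfinv_supports(dtype: torch.dtype, device: str = "cpu") -> bool:
    """Return True if Tensor.erfinv_ runs for this dtype on this device.

    Hypothetical diagnostic helper (not part of SliME or PyTorch). It calls
    the same in-place op that fails at the bottom of the traceback.
    """
    x = torch.zeros(4, dtype=dtype, device=device)
    try:
        x.erfinv_()
    except RuntimeError:
        # e.g. RuntimeError: "erfinv_cuda" not implemented for 'BFloat16'
        # (NotImplementedError is a subclass of RuntimeError, so it is caught too)
        return False
    return True


if __name__ == "__main__":
    # Only probe CUDA when a GPU is actually available
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    for device in devices:
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            print(device, dtype, erfinv_supports(dtype, device))
```

The result depends on the PyTorch version and build, so it is worth running in the same environment that launches training.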
yfzhang114 commented 1 month ago

Hi,

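This is a known PyTorch issue: torch.nn.init.trunc_normal_ calls tensor.erfinv_(), and in this PyTorch version the CUDA erfinv kernel is not implemented for BFloat16. Please refer to https://github.com/pytorch/pytorch/issues/123553. As a workaround, you can use this function instead of torch.nn.init.trunc_normal_: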


import math
import warnings

import torch


def _trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # Adapted from timm's trunc_normal_, with the erfinv_ step removed.
    # Defaults match torch.nn.init.trunc_normal_, so a call such as
    # _trunc_normal_(self.query, std=.02) works as a drop-in replacement.
    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # In-place ops on a parameter that requires grad must run under no_grad
    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [2l-1, 2u-1]
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # The inverse-CDF step that would turn these values into a truncated
        # standard normal is skipped, because "erfinv_cuda" is not implemented
        # for 'BFloat16':
        # tensor.erfinv_()
        # NOTE: without it, the result is a clamped uniform distribution,
        # not a truncated normal.

        # Scale by std and shift by mean
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure values stay in [a, b]
        tensor.clamp_(min=a, max=b)
    return tensor

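Commenting out tensor.erfinv_() avoids the error but changes the distribution: the values are drawn uniformly and then scaled, so the result is a clamped uniform rather than a truncated normal. With mean=0, std=.02 and torch's default bounds [-2, 2], that is roughly uniform on [-0.028, 0.028], with a standard deviation of about 0.016 instead of 0.02. If keeping the original distribution matters, one alternative (a sketch, not the maintainer's fix, and not tested under DeepSpeed ZeRO-3) is to run PyTorch's own trunc_normal_ on a float32 scratch tensor and copy the result into the BFloat16 parameter. The name bf16_safe_trunc_normal_ is hypothetical.

```python
import torch


def bf16_safe_trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill `tensor` in place with a truncated normal computed in float32.

    Hypothetical helper: avoids "erfinv_cuda" not implemented for 'BFloat16'
    by running torch.nn.init.trunc_normal_ on a float32 scratch tensor on the
    same device, then copying (and casting) the values into `tensor`.
    """
    with torch.no_grad():
        # Temporary buffer in a dtype that erfinv_ supports
        tmp = torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)
        torch.nn.init.trunc_normal_(tmp, mean=mean, std=std, a=a, b=b)
        # copy_ casts float32 -> tensor.dtype (e.g. bfloat16)
        tensor.copy_(tmp)
    return tensor


if __name__ == "__main__":
    # Hypothetical stand-in for the resampler's learnable queries
    query = torch.nn.Parameter(torch.empty(144, 64, dtype=torch.bfloat16))
    bf16_safe_trunc_normal_(query, std=.02)
    print(query.dtype, float(query.float().std()))
```

In sampler.py this would mean replacing the call at line 120, trunc_normal_(self.query, std=.02), with bf16_safe_trunc_normal_(self.query, std=.02). The float32 buffer only lives for the duration of the call, so the extra memory is one float32 copy of the tensor being initialized.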
Luo-Z13 commented 1 month ago

Thank you for the reply, the problem has been solved!