axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

Jamba-1.5-Mini fine-tuning got ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16 #1889

Closed. coranholmes closed this issue 2 months ago.

coranholmes commented 2 months ago

Expected Behavior

Training should start normally.

Current behaviour

Training fails while FSDP is wrapping the model, with the following error:

[rank2]: Traceback (most recent call last):
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank2]:     return _run_code(code, main_globals, None,
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/runpy.py", line 86, in _run_code
[rank2]:     exec(code, run_globals)
[rank2]:   File "/root/Codes/axolotl/src/axolotl/cli/train.py", line 72, in <module>
[rank2]:     fire.Fire(do_cli)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank2]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank2]:     component, remaining_args = _CallAndUpdateTrace(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank2]:     component = fn(*varargs, **kwargs)
[rank2]:   File "/root/Codes/axolotl/src/axolotl/cli/train.py", line 39, in do_cli
[rank2]:     return do_train(parsed_cfg, parsed_cli_args)
[rank2]:   File "/root/Codes/axolotl/src/axolotl/cli/train.py", line 67, in do_train
[rank2]:     return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
[rank2]:   File "/root/Codes/axolotl/src/axolotl/train.py", line 188, in train
[rank2]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
[rank2]:     return inner_training_loop(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 2032, in _inner_training_loop
[rank2]:     self.model = self.accelerator.prepare(self.model)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1292, in prepare
[rank2]:     result = tuple(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1293, in <genexpr>
[rank2]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1169, in _prepare_one
[rank2]:     return self.prepare_model(obj, device_placement=device_placement)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1459, in prepare_model
[rank2]:     model = FSDP(model, **kwargs)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 485, in __init__
[rank2]:     _auto_wrap(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank2]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   [Previous line repeated 2 more times]
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
[rank2]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
[rank2]:     return wrapper_cls(module, **kwargs)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 511, in __init__
[rank2]:     _init_param_handle_from_module(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 598, in _init_param_handle_from_module
[rank2]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 610, in _init_param_handle_from_params
[rank2]:     handle = FlatParamHandle(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 582, in __init__
[rank2]:     self._init_flat_param_and_metadata(
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 632, in _init_flat_param_and_metadata
[rank2]:     ) = self._validate_tensors_to_flatten(params)
[rank2]:   File "/root/miniconda3/envs/axolotl/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten
[rank2]:     raise ValueError(
[rank2]: ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16

The full log is in the attached file complete log.txt.
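The ValueError above is raised when FSDP tries to flatten the parameters of one wrapped module and finds more than one dtype among them. A minimal sketch for finding which parameters are in which dtype before launching a run is shown below. The helper names `group_params_by_dtype` and `has_uniform_dtype` are hypothetical, not part of axolotl, PyTorch, or transformers; the code only assumes an object that exposes `named_parameters()` the way `torch.nn.Module` does.

```python
from collections import defaultdict


def group_params_by_dtype(model):
    """Map each parameter dtype (as a string) to the names of the parameters using it.

    `model` is assumed to expose `named_parameters()` yielding (name, parameter)
    pairs whose parameters carry a `.dtype` attribute, as torch.nn.Module does.
    """
    groups = defaultdict(list)
    for name, param in model.named_parameters():
        groups[str(param.dtype)].append(name)
    return dict(groups)


def has_uniform_dtype(model):
    """Return True when all parameters of `model` share a single dtype."""
    return len(group_params_by_dtype(model)) <= 1
```

With a loaded model, something like `for dtype, names in group_params_by_dtype(model).items(): print(dtype, len(names), names[:3])` lists the offending groups. Since the check in the traceback runs once per wrapped module, it may be more informative to call the helper on a single decoder layer rather than on the whole model.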

Steps to reproduce

The training command is:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29505 -m axolotl.cli.train examples/jamba/my_jamba.yaml

conda list:

# packages in environment at /root/miniconda3/envs/axolotl:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
absl-py                   2.1.0                    pypi_0    pypi
accelerate                0.30.1                   pypi_0    pypi
addict                    2.4.0                    pypi_0    pypi
aiobotocore               2.13.0                   pypi_0    pypi
aiofiles                  23.2.1                   pypi_0    pypi
aiohttp                   3.9.5                    pypi_0    pypi
aioitertools              0.11.0                   pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
altair                    5.3.0                    pypi_0    pypi
annotated-types           0.7.0                    pypi_0    pypi
anyio                     4.4.0                    pypi_0    pypi
art                       6.2                      pypi_0    pypi
async-timeout             4.0.3                    pypi_0    pypi
attrs                     23.2.0                   pypi_0    pypi
axolotl                   0.4.1                     dev_0    <develop>
bitsandbytes              0.43.1                   pypi_0    pypi
botocore                  1.34.106                 pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6
ca-certificates           2024.3.11            h06a4308_0
cachetools                5.3.3                    pypi_0    pypi
causal-conv1d             1.4.0                    pypi_0    pypi
certifi                   2024.2.2                 pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
click                     8.1.7                    pypi_0    pypi
cmake                     3.29.3                   pypi_0    pypi
colorama                  0.4.6                    pypi_0    pypi
coloredlogs               15.0.1                   pypi_0    pypi
contourpy                 1.2.1                    pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
datasets                  2.19.1                   pypi_0    pypi
decorator                 5.1.1                    pypi_0    pypi
deepspeed                 0.14.2                   pypi_0    pypi
deepspeed-kernels         0.0.1.dev1698255861          pypi_0    pypi
dill                      0.3.8                    pypi_0    pypi
dnspython                 2.6.1                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
docstring-parser          0.16                     pypi_0    pypi
einops                    0.8.0                    pypi_0    pypi
email-validator           2.1.1                    pypi_0    pypi
evaluate                  0.4.1                    pypi_0    pypi
exceptiongroup            1.2.1                    pypi_0    pypi
fastapi                   0.111.0                  pypi_0    pypi
fastapi-cli               0.0.4                    pypi_0    pypi
fastcore                  1.5.42                   pypi_0    pypi
ffmpy                     0.3.2                    pypi_0    pypi
filelock                  3.14.0                   pypi_0    pypi
fire                      0.6.0                    pypi_0    pypi
flash-attn                2.5.8                    pypi_0    pypi
fonttools                 4.53.0                   pypi_0    pypi
frozenlist                1.4.1                    pypi_0    pypi
fschat                    0.2.36                   pypi_0    pypi
fsspec                    2024.3.1                 pypi_0    pypi
gcsfs                     2024.3.1                 pypi_0    pypi
gitdb                     4.0.11                   pypi_0    pypi
gitpython                 3.1.43                   pypi_0    pypi
google-api-core           2.19.0                   pypi_0    pypi
google-auth               2.29.0                   pypi_0    pypi
google-auth-oauthlib      1.2.0                    pypi_0    pypi
google-cloud-core         2.4.1                    pypi_0    pypi
google-cloud-storage      2.16.0                   pypi_0    pypi
google-crc32c             1.5.0                    pypi_0    pypi
google-resumable-media    2.7.0                    pypi_0    pypi
googleapis-common-protos  1.63.0                   pypi_0    pypi
gradio                    3.50.2                   pypi_0    pypi
gradio-client             0.6.1                    pypi_0    pypi
grpcio                    1.64.0                   pypi_0    pypi
h11                       0.14.0                   pypi_0    pypi
hf-transfer               0.1.6                    pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
httpcore                  1.0.5                    pypi_0    pypi
httptools                 0.6.1                    pypi_0    pypi
httpx                     0.27.0                   pypi_0    pypi
huggingface-hub           0.23.2                   pypi_0    pypi
humanfriendly             10.0                     pypi_0    pypi
idna                      3.7                      pypi_0    pypi
importlib-resources       6.4.0                    pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
jmespath                  1.0.1                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
jsonschema                4.22.0                   pypi_0    pypi
jsonschema-specifications 2023.12.1                pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1
libffi                    3.4.4                h6a678d5_1
libgcc-ng                 11.2.0               h1234567_1
libgomp                   11.2.0               h1234567_1
libstdcxx-ng              11.2.0               h1234567_1
libuuid                   1.41.5               h5eee18b_0
llvmlite                  0.42.0                   pypi_0    pypi
mamba-ssm                 2.2.2                    pypi_0    pypi
markdown                  3.6                      pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markdown2                 2.4.13                   pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
matplotlib                3.9.0                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
multidict                 6.0.5                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
networkx                  3.3                      pypi_0    pypi
nh3                       0.2.17                   pypi_0    pypi
ninja                     1.11.1.1                 pypi_0    pypi
numba                     0.59.1                   pypi_0    pypi
numpy                     1.26.4                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.5.40                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
oauthlib                  3.2.2                    pypi_0    pypi
openssl                   3.0.13               h7f8727e_2
optimum                   1.16.2                   pypi_0    pypi
orjson                    3.10.3                   pypi_0    pypi
packaging                 23.2                     pypi_0    pypi
pandas                    2.2.2                    pypi_0    pypi
peft                      0.11.1                   pypi_0    pypi
pillow                    10.3.0                   pypi_0    pypi
pip                       24.0            py310h06a4308_0
platformdirs              4.2.2                    pypi_0    pypi
prompt-toolkit            3.0.45                   pypi_0    pypi
proto-plus                1.23.0                   pypi_0    pypi
protobuf                  4.25.3                   pypi_0    pypi
psutil                    5.9.8                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
pyarrow                   16.1.0                   pypi_0    pypi
pyarrow-hotfix            0.6                      pypi_0    pypi
pyasn1                    0.6.0                    pypi_0    pypi
pyasn1-modules            0.4.0                    pypi_0    pypi
pydantic                  2.6.3                    pypi_0    pypi
pydantic-core             2.16.3                   pypi_0    pypi
pydub                     0.25.1                   pypi_0    pypi
pygments                  2.18.0                   pypi_0    pypi
pynvml                    11.5.0                   pypi_0    pypi
pyparsing                 3.1.2                    pypi_0    pypi
python                    3.10.14              h955ad1f_1
python-dateutil           2.9.0.post0              pypi_0    pypi
python-dotenv             1.0.1                    pypi_0    pypi
python-multipart          0.0.9                    pypi_0    pypi
pytz                      2024.1                   pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0
referencing               0.35.1                   pypi_0    pypi
regex                     2024.5.15                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
requests-oauthlib         2.0.0                    pypi_0    pypi
responses                 0.18.0                   pypi_0    pypi
rich                      13.7.1                   pypi_0    pypi
rpds-py                   0.18.1                   pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
s3fs                      2024.3.1                 pypi_0    pypi
safetensors               0.4.3                    pypi_0    pypi
scikit-learn              1.2.2                    pypi_0    pypi
scipy                     1.13.1                   pypi_0    pypi
semantic-version          2.10.0                   pypi_0    pypi
sentencepiece             0.2.0                    pypi_0    pypi
sentry-sdk                2.3.1                    pypi_0    pypi
setproctitle              1.3.3                    pypi_0    pypi
setuptools                69.5.1          py310h06a4308_0
shellingham               1.5.4                    pypi_0    pypi
shortuuid                 1.0.13                   pypi_0    pypi
shtab                     1.7.1                    pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
smmap                     5.0.1                    pypi_0    pypi
sniffio                   1.3.1                    pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0
starlette                 0.37.2                   pypi_0    pypi
svgwrite                  1.4.3                    pypi_0    pypi
sympy                     1.12.1                   pypi_0    pypi
tensorboard               2.16.2                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
termcolor                 2.4.0                    pypi_0    pypi
threadpoolctl             3.5.0                    pypi_0    pypi
tiktoken                  0.7.0                    pypi_0    pypi
tk                        8.6.14               h39e8969_0
tokenizers                0.19.1                   pypi_0    pypi
toolz                     0.12.1                   pypi_0    pypi
torch                     2.3.0                    pypi_0    pypi
torchaudio                2.3.0                    pypi_0    pypi
torchvision               0.18.0                   pypi_0    pypi
tqdm                      4.66.4                   pypi_0    pypi
transformers              4.41.1                   pypi_0    pypi
triton                    2.3.0                    pypi_0    pypi
trl                       0.9.6                    pypi_0    pypi
typer                     0.12.3                   pypi_0    pypi
typing-extensions         4.12.0                   pypi_0    pypi
tyro                      0.8.4                    pypi_0    pypi
tzdata                    2024.1                   pypi_0    pypi
ujson                     5.10.0                   pypi_0    pypi
urllib3                   2.2.1                    pypi_0    pypi
uvicorn                   0.30.0                   pypi_0    pypi
uvloop                    0.19.0                   pypi_0    pypi
wandb                     0.17.0                   pypi_0    pypi
watchfiles                0.22.0                   pypi_0    pypi
wavedrom                  2.0.3.post3              pypi_0    pypi
wcwidth                   0.2.13                   pypi_0    pypi
websockets                11.0.3                   pypi_0    pypi
werkzeug                  3.0.3                    pypi_0    pypi
wheel                     0.43.0          py310h06a4308_0
wrapt                     1.16.0                   pypi_0    pypi
xformers                  0.0.26.post1             pypi_0    pypi
xxhash                    3.4.1                    pypi_0    pypi
xz                        5.4.6                h5eee18b_1
yarl                      1.9.4                    pypi_0    pypi
zlib                      1.2.13               h5eee18b_1
zstandard                 0.22.0                   pypi_0    pypi

Config yaml

base_model: /mnt/models/Jamba-1.5-Mini
tokenizer_type: AutoTokenizer

load_in_4bit: true
strict: false
use_tensorboard: true
datasets:
  - path: /mnt/data/SlimOrcaDedupCleaned/CleanedOrcaSlimDedup.json
    type: chat_template
    chat_template: jamba
    drop_system_message: true
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/Jamba_test
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj]
lora_target_linear: false

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: true
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD

Possible solution

I am not sure whether this is related to #1836.

Python Version

3.10.14

axolotl branch-commit

main/8a1da822

xgal commented 2 months ago

Hi :) Please try installing transformers from this specific commit:

pip install git+https://github.com/xgal/transformers@897f80665c37c531b7803f92655dbc9b3a593fe7

or upgrade to transformers >= 4.44.2.

coranholmes commented 2 months ago

I can train it properly now. Thanks
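The environment listed in the report has transformers 4.41.1, which is below the 4.44.2 minimum suggested in the fix. A small sketch for checking this before launching a run is shown below. The helper names (`release_tuple`, `meets_minimum`, `check_installed`) are hypothetical, and the parsing is deliberately simple: it compares only the numeric release part, so a pre-release such as `4.44.2.dev0` is treated the same as `4.44.2`.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum transformers version suggested in this thread.
MIN_TRANSFORMERS = (4, 44, 2)


def release_tuple(version_string):
    """Parse the leading 'X.Y[.Z]' of a version string into a 3-tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version_string)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version_string!r}")
    # A missing patch component is treated as 0.
    return tuple(int(part or 0) for part in match.groups())


def meets_minimum(version_string, minimum=MIN_TRANSFORMERS):
    """Return True when the numeric release is at least `minimum`."""
    return release_tuple(version_string) >= minimum


def check_installed(package="transformers"):
    """Return (installed_version, ok), or None when the package is not installed."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return None
    return installed, meets_minimum(installed)
```

Calling `check_installed()` in the training environment returns the installed version together with a flag saying whether it meets the minimum. A development install from a fork may report a version string that does not reflect the fix, so the result is only a hint.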