huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Unable to load mistralai/Mixtral-8x7B-Instruct-v0.1 using mps #2778

Closed: chimezie closed this issue 3 months ago

chimezie commented 4 months ago

System Info

- `Accelerate` version: 0.30.1
- Platform: macOS-14.2.1-arm64-arm-64bit
- `accelerate` bash location: /path/to/venv/mmlu-eval/bin/accelerate
- Python version: 3.11.6
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.3.0 (False)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 128.00 GB
- `Accelerate` default config:
        Not found

Reproduction

from transformers import AutoModelForCausalLM
import torch

AutoModelForCausalLM.from_pretrained('mistralai/Mixtral-8x7B-Instruct-v0.1',
                                     revision='main',
                                     torch_dtype=torch.float16,
                                     trust_remote_code=False,
                                     device_map={'': 'mps'})  # place the entire model on the MPS device

Running this results in:

File /path/to/python3.11/site-packages/transformers/models/auto/auto_factory.py:563, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    561 elif type(config) in cls._model_mapping.keys():
    562     model_class = _get_model_class(config, cls._model_mapping)
--> 563     return model_class.from_pretrained(
    564         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    565     )
    566 raise ValueError(
    567     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    568     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    569 )

File /path/to/python3.11/site-packages/transformers/modeling_utils.py:3531, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3522     if dtype_orig is not None:
   3523         torch.set_default_dtype(dtype_orig)
   3524     (
   3525         model,
   3526         missing_keys,
   3527         unexpected_keys,
   3528         mismatched_keys,
   3529         offload_index,
   3530         error_msgs,
-> 3531     ) = cls._load_pretrained_model(
   3532         model,
   3533         state_dict,
   3534         loaded_state_dict_keys,  # XXX: rename?
   3535         resolved_archive_file,
   3536         pretrained_model_name_or_path,
   3537         ignore_mismatched_sizes=ignore_mismatched_sizes,
   3538         sharded_metadata=sharded_metadata,
   3539         _fast_init=_fast_init,
   3540         low_cpu_mem_usage=low_cpu_mem_usage,
   3541         device_map=device_map,
   3542         offload_folder=offload_folder,
   3543         offload_state_dict=offload_state_dict,
   3544         dtype=torch_dtype,
   3545         hf_quantizer=hf_quantizer,
   3546         keep_in_fp32_modules=keep_in_fp32_modules,
   3547     )
   3549 # make sure token embedding weights are still tied if needed
   3550 model.tie_weights()
File /path/to/python3.11/site-packages/transformers/modeling_utils.py:3958, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_modules)
   3954                 set_module_tensor_to_device(
   3955                     model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
   3956                 )
   3957     else:
-> 3958         new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
   3959             model_to_load,
   3960             state_dict,
   3961             loaded_keys,
   3962             start_prefix,
   3963             expected_keys,
   3964             device_map=device_map,
   3965             offload_folder=offload_folder,
   3966             offload_index=offload_index,
   3967             state_dict_folder=state_dict_folder,
   3968             state_dict_index=state_dict_index,
   3969             dtype=dtype,
   3970             hf_quantizer=hf_quantizer,
   3971             is_safetensors=is_safetensors,
   3972             keep_in_fp32_modules=keep_in_fp32_modules,
   3973             unexpected_keys=unexpected_keys,
   3974         )
   3975         error_msgs += new_error_msgs
   3976 else:

File /path/to/python3.11/site-packages/transformers/modeling_utils.py:812, in _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix, expected_keys, device_map, offload_folder, offload_index, state_dict_folder, state_dict_index, dtype, hf_quantizer, is_safetensors, keep_in_fp32_modules, unexpected_keys)
    801     state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
    802 elif (
    803     not is_quantized
    804     or (not hf_quantizer.requires_parameters_quantization)
   (...)
    810 ):
    811     # For backward compatibility with older versions of `accelerate` and for non-quantized params
--> 812     set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
    813 else:
    814     hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)

File /path/to/python3.11/site-packages/accelerate/utils/modeling.py:400, in set_module_tensor_to_device(module, tensor_name, device, value, dtype, fp16_statistics, tied_params_map)
    398             module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
    399 elif isinstance(value, torch.Tensor):
--> 400     new_value = value.to(device)
    401 else:
    402     new_value = torch.tensor(value, device=device)

RuntimeError: MPS backend out of memory (MPS allocated: 163.01 GB, other allocations: 384.00 KB, max allowed: 163.20 GB). Tried to allocate 250.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

Expected behavior

The call should return the loaded model without raising an error.
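
The RuntimeError itself names a possible workaround: lifting the allocator cap via PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0. A minimal sketch of that workaround follows, with the error's own caveat that disabling the limit may cause system failure; the variable must be set before PyTorch initializes the MPS allocator, so exporting it in the shell before launching Python is the safest route:

# Workaround sketch based on the hint in the error message. Either run:
#   PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 python repro.py
# or set the variable at the very top of the script, before importing torch:
import os
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"  # warning from the error: may cause system failure

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    revision='main',
    torch_dtype=torch.float16,
    trust_remote_code=False,
    device_map={'': 'mps'},
)

This only relaxes the allocator's ceiling; it does not reduce how much memory loading actually needs.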

chimezie commented 4 months ago

The system has more than 100 GB of memory free at the time the code is run.
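
Free system RAM and what the MPS allocator reports can differ. A quick way to check the allocator's own view, assuming a recent PyTorch 2.x build on Apple Silicon (these torch.mps counters may not exist in older releases):

import torch

# MPS allocator statistics as PyTorch sees them.
print(f"current_allocated: {torch.mps.current_allocated_memory() / 1e9:.2f} GB")
print(f"driver_allocated:  {torch.mps.driver_allocated_memory() / 1e9:.2f} GB")
# Metal's suggested working-set limit; the "max allowed" figure in the
# error is derived from this value and the high-watermark ratio.
print(f"recommended_max:   {torch.mps.recommended_max_memory() / 1e9:.2f} GB")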

muellerzr commented 4 months ago

cc @SunMarc

SunMarc commented 4 months ago

Hi @chimezie, does this happen only with Mixtral-8x7B, or with all models? From the traceback, the MPS memory budget was completely exhausted: MPS backend out of memory (MPS allocated: 163.01 GB, other allocations: 384.00 KB, max allowed: 163.20 GB)

chimezie commented 4 months ago

This seems to happen only with Mixtral-8x7B. I was able to load Llama 3 8B, Qwen1.5-14B, and internistai/base-7b-v0.2, for example, without any issue.

SunMarc commented 4 months ago

Mixtral-8x7B is a very big model, around 100 GB of weights, but you should still be able to load it since you have over 160 GB available to MPS. At which checkpoint shard does the loading fail? Near the end? You can track the memory consumption using Activity Monitor on your Mac.
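
If loading keeps failing near the allocator ceiling, one option is to cap what MPS may hold and let accelerate place the remaining layers on CPU. A hedged sketch, assuming the installed accelerate version accepts an "mps" key in max_memory (support for this varies across releases); the "90GiB"/"30GiB" caps are illustrative, not tuned values:

import torch
from transformers import AutoModelForCausalLM

# Let accelerate infer the placement, capping MPS usage below the
# ~163 GB watermark from the traceback and spilling the rest to CPU RAM.
model = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    revision='main',
    torch_dtype=torch.float16,
    device_map='auto',
    max_memory={'mps': '90GiB', 'cpu': '30GiB'},
)

Layers offloaded to CPU will run slower, but this can distinguish a genuine capacity problem from an allocator-limit problem.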

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.