exo-explore / exo

Run your own AI cluster at home with everyday devices 📱💻 🖥️⌚
GNU General Public License v3.0
6.87k stars, 355 forks

OOM error when loading model using load_state_dict on multiple Mali GPU devices #128

Closed: artistlu closed this issue 1 month ago

artistlu commented 1 month ago

I am getting an out-of-memory (OOM) error when load_state_dict loads the model on the first node of my cluster of Mali GPU devices. Each device has roughly 3 GB of memory.

At first I suspected the cluster simply had too few nodes, so I scaled it up to 14 nodes. The first node still runs out of memory.
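One hypothesis (not confirmed in this issue) for why adding nodes does not help: if the first node materializes the whole checkpoint, or a large unsharded tensor, before keeping only its own layers, its peak memory does not shrink as nodes are added. The sketch below compares that case with an even layer split. The 8B-parameter model, the 2 bytes per parameter (fp16/bf16), and the helper names are all hypothetical, and it ignores KV cache, activations, and runtime overhead.

```python
# Hypothetical back-of-the-envelope estimate. The model size and the
# "full checkpoint load" scenario are assumptions, not measurements.
GIB = 1024 ** 3


def per_node_weight_bytes(total_params: int, bytes_per_param: int, num_nodes: int) -> float:
    """Weight bytes per node if the layers are split evenly across num_nodes."""
    return total_params * bytes_per_param / num_nodes


def first_node_fits(total_params: int, bytes_per_param: int, num_nodes: int,
                    device_mem_bytes: int, loads_full_checkpoint: bool) -> bool:
    """Check whether the first node's peak weight memory fits on its device.

    loads_full_checkpoint=True models a loader that materializes every tensor
    in the checkpoint before keeping only this node's shard, so the peak is
    independent of num_nodes.
    """
    if loads_full_checkpoint:
        peak = total_params * bytes_per_param
    else:
        peak = per_node_weight_bytes(total_params, bytes_per_param, num_nodes)
    return peak <= device_mem_bytes


if __name__ == "__main__":
    params = 8_000_000_000  # hypothetical 8B-parameter model
    for nodes in (4, 14):
        sharded = first_node_fits(params, 2, nodes, 3 * GIB, loads_full_checkpoint=False)
        full = first_node_fits(params, 2, nodes, 3 * GIB, loads_full_checkpoint=True)
        print(f"{nodes} nodes: sharded {'fits' if sharded else 'OOM'}, full load {'fits' if full else 'OOM'}")
```

Under these assumptions, the even split fits at 14 nodes (about 1.1 GB per node), while the full-load case needs the entire model on the first node no matter how many nodes there are.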

I have also modified the fix_bf16 function to work around an error on the Mali GPU devices, but I'm not sure whether this change affects the OOM issue.

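```python
from typing import Any, Dict

from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv

def fix_bf16(weights: Dict[Any, Tensor]) -> Dict[Any, Tensor]:
  # Weights are tinygrad Tensors, so compare and cast with tinygrad dtypes (torch dtypes never match them)
  if getenv("SUPPORT_BF16", 1):
    # Default path (SUPPORT_BF16 unset or nonzero): cast bfloat16 weights to float16 on their own device
    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
  # SUPPORT_BF16=0: convert bfloat16 to float16 with llvm_bf16_cast, then move the result back to the original device
  return {k: v.llvm_bf16_cast(dtypes.float16).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
```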
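For context on what a bf16-to-fp16 conversion involves: bfloat16 is the upper 16 bits of an IEEE-754 float32, so a value can be widened by shifting its 16-bit pattern left by 16 and reinterpreting the result as float32, then narrowed to float16. The sketch below shows that round trip using only the standard library struct module; the helper names are hypothetical and are not exo or tinygrad APIs. Because both 16-bit formats take 2 bytes per value, converting bf16 weights to fp16 should not, by itself, reduce memory use.

```python
import struct


def bf16_bits_to_float32(bits: int) -> float:
    """Widen a raw 16-bit bfloat16 pattern to a float via float32.

    bfloat16 is the top half of a float32, so shifting the pattern left by 16
    and reinterpreting the 32 bits as float32 recovers the value exactly.
    """
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def float32_to_bf16_bits(x: float) -> int:
    """Truncate a float32 to bfloat16 by keeping its upper 16 bits (rounds toward zero)."""
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16


def bf16_bits_to_float16_bytes(bits: int) -> bytes:
    """Convert bf16 -> float32 -> IEEE half ('e' format); the result is 2 bytes, same as bf16."""
    return struct.pack("<e", bf16_bits_to_float32(bits))


if __name__ == "__main__":
    bits = float32_to_bf16_bits(1.0)
    print(hex(bits), bf16_bits_to_float32(bits), len(bf16_bits_to_float16_bytes(bits)))
```

One caveat: float16 tops out around 65504, while bfloat16 covers float32's exponent range, so large bf16 values cannot be represented in float16 (here struct.pack raises OverflowError).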