I am encountering an Out-of-Memory (OOM) error when calling the load_state_dict function to load a model across multiple Mali GPU devices; the OOM occurs on the first node. Each Mali GPU device has approximately 3GB of memory.
Initially, I thought the issue might be due to having too few nodes, so I tried using 14 nodes. However, I continued to experience the OOM error on the first node, even with 14 nodes.
I have also modified the fix_bf16 function to work around a bfloat16-related error on the Mali GPU devices, but I'm not sure whether this change is related to the OOM issue.
from typing import Any, Dict
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv

def fix_bf16(weights: Dict[Any, Tensor]):
    if getenv("SUPPORT_BF16", 1):
        # When SUPPORT_BF16 is set, cast bfloat16 weights to float16 directly
        return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
    # Otherwise convert via llvm_bf16_cast for backends without native bfloat16 support
    return {k: v.llvm_bf16_cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
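For context on what the cast above is doing: bfloat16 is simply the top 16 bits of an IEEE-754 float32, so a bf16 value can be decoded without any backend support by shifting the bits and reinterpreting them. A minimal pure-Python sketch (no tinygrad required; the helper name is mine, not from the codebase):

```python
import struct

def bf16_bits_to_float(bits: int) -> float:
    # bfloat16 is the high 16 bits of a float32: shift left by 16,
    # then reinterpret the resulting 32-bit pattern as a float
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

# 0x3F80 is the bfloat16 pattern for 1.0 (top bits of float32 0x3F800000)
print(bf16_bits_to_float(0x3F80))  # → 1.0
```

This is why a bf16-to-fp16 path needs an explicit conversion step on devices whose compilers lack a native bfloat16 type.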