unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

[BUG] RuntimeError: Invalid device string: 'bfloat16' with transformers v4.40.1 and save_strategy="epoch" #404

Closed · OAHC2022 closed this issue 5 months ago

OAHC2022 commented 5 months ago

While fine-tuning the unsloth/codellama-7b model using transformers v4.40.1 with save_strategy="epoch", I encountered the following error:

line 540, in LlamaModel_fast_forward
    inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
RuntimeError: Invalid device string: 'bfloat16'
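
For context, a minimal training setup along these lines should reproduce it; the dataset, LoRA settings, and trainer arguments below are illustrative placeholders rather than my exact script, and the exact SFTTrainer keyword arguments depend on your trl version:

from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel

# Load the model and attach LoRA adapters (illustrative settings)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/codellama-7b",
    max_seq_length=2048,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(model, r=16)

# Placeholder dataset with a "text" column
dataset = load_dataset("json", data_files="train.jsonl", split="train")

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    args=TrainingArguments(
        output_dir="outputs",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        bf16=True,
        save_strategy="epoch",  # the error appears once a checkpoint has been saved
    ),
)
trainer.train()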

Upon examining the code, I traced the problematic line to LlamaModel_fast_forward in unsloth/models/llama.py:

# Embed positions
if inputs_embeds is None:
    inputs_embeds = self.embed_tokens(input_ids)

inputs_embeds = inputs_embeds.to(self.config.torch_dtype)

It appears that during checkpoint saving while fine-tuning, self.config.torch_dtype gets set to the string "bfloat16" instead of torch.bfloat16, and Tensor.to() then treats the bare string as a device name, hence the "Invalid device string" error. Here's a simple fix I implemented:

# Embed positions
if inputs_embeds is None:
    inputs_embeds = self.embed_tokens(input_ids)

# My modification: map a string dtype back to the corresponding torch.dtype before casting
if self.config.torch_dtype == "bfloat16":
    self.config.torch_dtype = torch.bfloat16
inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
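
For what it's worth, the reason the string fails is that Tensor.to() treats a bare string argument as a device name rather than a dtype, so "bfloat16" gets parsed as an invalid device:

import torch

x = torch.zeros(2)
x.to(torch.bfloat16)  # works: the argument is a torch.dtype
x.to("bfloat16")      # RuntimeError: Invalid device string: 'bfloat16' (parsed as a device)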
danielhanchen commented 5 months ago

Oh thanks for that!! Will add your fix in! Thanks!

OAHC2022 commented 5 months ago

Thank you!

johnsonice commented 3 months ago

Is this fixed? I'm still hitting the same bug.

danielhanchen commented 3 months ago

@johnsonice Could you try updating Unsloth as described in the wiki: https://github.com/unslothai/unsloth/wiki

chawins commented 2 months ago

@danielhanchen I can confirm that I'm still experiencing this bug after updating.

pip uninstall unsloth -y
pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
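
As a quick sanity check (a sketch that only reports the installed package version, not the exact git commit), I verify the reinstall actually took effect with:

from importlib.metadata import version

# Reports the installed unsloth version so the reinstall can be confirmed
print(version("unsloth"))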
SVHawk13 commented 1 month ago

@danielhanchen I'm having the same issue as of commit 976d11a10d54383aeb7a692c69e01151a20bfd72. It hits every second time I run my fine-tuning script: runs alternate between finishing without issue and throwing this error. I saw that this was previously addressed with if statements that check for the dtype's name as a string and replace it with the corresponding dtype object. Why not use a mapping from dtype strings to their dtype objects?

A naive approach could look like the following:

Original:

if self.config.torch_dtype == "float32":
    self.config.torch_dtype = torch.float32
elif self.config.torch_dtype == "bfloat16":
    self.config.torch_dtype = torch.bfloat16
elif self.config.torch_dtype == "float16":
    self.config.torch_dtype = torch.float16
inputs_embeds = inputs_embeds.to(self.config.torch_dtype)

To this:

DTYPE_MAP = {  # probably good to set as a class attribute
    "float32": torch.float32,
    torch.float32: torch.float32,
    "float16": torch.float16,
    torch.float16: torch.float16,
    "bfloat16": torch.bfloat16,
    torch.bfloat16: torch.bfloat16,
}

inputs_embeds = inputs_embeds.to(DTYPE_MAP[self.config.torch_dtype])

An enum might be a better solution:

from enum import Enum

import torch

class DtypeMap(Enum):
    float32: torch.dtype = torch.float32
    fp32: torch.dtype = float32
    float16: torch.dtype = torch.float16
    fp16: torch.dtype = float16
    bfloat16: torch.dtype = torch.bfloat16
    bf16: torch.dtype = bfloat16

    @classmethod
    def get_dtype(cls, _v) -> torch.dtype:
        if isinstance(_v, str):
            return cls[_v].value
        elif isinstance(_v, torch.dtype):
            return _v
        else:
            raise TypeError(f"expected str or torch.dtype, got {type(_v).__name__}")

DtypeMap.get_dtype(torch.bfloat16) == torch.bfloat16 # True
DtypeMap.get_dtype("bfloat16") == torch.bfloat16 # True
DtypeMap.get_dtype("bf16") == torch.bfloat16 # True; may be useful for `TrainingArguments`
# List all non-aliased dtype names:
[dt.name for dt in DtypeMap]  # ['float32', 'float16', 'bfloat16']
# List all torch.dtype dtypes defined in the enum:
[dt.value for dt in DtypeMap]  # [torch.float32, torch.float16, torch.bfloat16]
# `.to()` now works without if/elif/else branches cluttering methods and without mutating `config` state
inputs_embeds = inputs_embeds.to(DtypeMap.get_dtype(self.config.torch_dtype))
danielhanchen commented 1 month ago

I can add your mapping idea!

llllvvuu commented 1 month ago

I hit this same issue just now, but on this line:

https://github.com/unslothai/unsloth/blob/f1951c0f6d3e1f184af93e5d8f5eff6e7834e4b5/unsloth/models/llama.py#L961C9-L961C52

I don't know if this is the best place to fix it, but I changed it to logits = logits.to(__DTYPE_MAP[self.config.torch_dtype]) and it worked.