unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Error saving PEFT adapter, re-loading model & adapter, and continuing to train #1211

Closed: laura-burdick-sil closed this issue 3 weeks ago

laura-burdick-sil commented 4 weeks ago

In the code below, I am loading a LoRA adapter onto Llama 3.2 (3B), saving only the adapter, and then re-loading it to continue training. However, when I try to continue training, it errors out.

from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftConfig, PeftModel
import json
from datasets import Dataset, DatasetDict
from trl import SFTTrainer
from transformers import TrainingArguments

# Load Llama 3.2 model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-bnb-4bit",
    device_map={"":0}
)

# Add Lora adapter
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

# Save Lora adapter only
model.save_pretrained("/root/test_adapter2") # On a PeftModel this writes only the adapter weights
tokenizer.save_pretrained("/root/test_adapter2")

# Load base model and adapter
base_model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-3B-bnb-4bit")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B-bnb-4bit")

config = PeftConfig.from_pretrained("/root/test_adapter2")
model = PeftModel.from_pretrained(base_model, "/root/test_adapter2", config=config, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained("/root/test_adapter2")

# Prepare dataset for training
EOS_TOKEN = tokenizer.eos_token

prompt = """### Instruction:
    {}

    ### Input:
    {}

    ### Response:
    {}"""

# Format input to LLM
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs       = examples["input"]
    outputs      = examples["output"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
      # Must add EOS_TOKEN, otherwise your generation will go on forever!
      text = prompt.format(instruction, input, output) + EOS_TOKEN
      texts.append(text)
    return { "text" : texts, }

# Initialize a dictionary to hold the lists for each field
dataset_dict = {'input': [], 'output': [], 'instruction': []}

# Open the file and read line by line
with open("/root/all_llm_data/waima_languages_2bibles.jsonl", 'r', encoding='utf-8') as file:
    for line in file:
        # Each line is a complete JSON object
        json_object = json.loads(line)
        # Append each field to the appropriate list
        dataset_dict['input'].append(json_object.get('input', ''))
        dataset_dict['output'].append(json_object.get('output', ''))
        dataset_dict['instruction'].append(json_object.get('instruction', ''))

# Convert the dictionary of lists into a `Dataset`
dataset = Dataset.from_dict(dataset_dict)

dataset1 = dataset.map(formatting_prompts_func, batched = True,)

# Train on data
training_arguments = TrainingArguments(
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    warmup_steps = 5,
    num_train_epochs = 5,
    learning_rate = 2e-4,
    logging_steps = 1,
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = 3407,
    output_dir = "/root/checkpoints", # Directory to save checkpoints.
    save_steps=500,              # Save a checkpoint every 500 steps.
    save_total_limit=1,          # Keep only the most recent checkpoint.
    save_strategy="steps",       # Save checkpoints based on steps.
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset1,
    dataset_text_field = "text",
    max_seq_length = 1024,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = training_arguments,
)

trainer.train()

Here's the error that I'm getting:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[13], line 12
      1 trainer3 = SFTTrainer(
      2     model = model,
      3     tokenizer = tokenizer,
   (...)
      9     args = training_arguments,
     10 )
---> 12 trainer_stats = trainer3.train()

File <string>:156, in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)

File <string>:377, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File <string>:31, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/unsloth/models/_utils.py:1184, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
   1182     pass
   1183 pass
-> 1184 return self._old_compute_loss(model, inputs, *args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/transformers/trainer.py:3654, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3652         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3653     inputs = {**inputs, **loss_kwargs}
-> 3654 outputs = model(**inputs)
   3655 # Save past state if it exists
   3656 # TODO: this needs to be fixed and made cleaner later.
   3657 if self.args.past_index >= 0:

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/accelerate/utils/operations.py:820, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    819 def forward(*args, **kwargs):
--> 820     return model_forward(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/accelerate/utils/operations.py:808, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    807 def __call__(self, *args, **kwargs):
--> 808     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     41 @functools.wraps(func)
     42 def decorate_autocast(*args, **kwargs):
     43     with autocast_instance:
---> 44         return func(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/_compile.py:32, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     29     disable_fn = torch._dynamo.disable(fn, recursive)
     30     fn.__dynamo_disable = disable_fn
---> 32 return disable_fn(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    630 prior = _maybe_set_eval_frame(callback)
    631 try:
--> 632     return fn(*args, **kwargs)
    633 finally:
    634     _maybe_set_eval_frame(prior)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/unsloth/models/llama.py:1048, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, num_logits_to_keep, **kwargs)
   1033 @torch._disable_dynamo
   1034 def PeftModelForCausalLM_fast_forward(
   1035     self,
   (...)
   1046     **kwargs,
   1047 ):
-> 1048     return self.base_model(
   1049         input_ids=input_ids,
   1050         causal_mask=causal_mask,
   1051         attention_mask=attention_mask,
   1052         inputs_embeds=inputs_embeds,
   1053         labels=labels,
   1054         output_attentions=output_attentions,
   1055         output_hidden_states=output_hidden_states,
   1056         return_dict=return_dict,
   1057         num_logits_to_keep=num_logits_to_keep,
   1058         **kwargs,
   1059     )

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/unsloth/models/llama.py:946, in CausalLM_fast_forward.<locals>._CausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep, *args, **kwargs)
    944     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    945     self.model._has_no_labels = labels is None
--> 946     outputs = self.model(
    947         input_ids=input_ids,
    948         causal_mask=causal_mask,
    949         attention_mask=attention_mask,
    950         position_ids=position_ids,
    951         past_key_values=past_key_values,
    952         inputs_embeds=inputs_embeds,
    953         use_cache=use_cache,
    954         output_attentions=output_attentions,
    955         output_hidden_states=output_hidden_states,
    956         return_dict=return_dict,
    957     )
    958 pass
    959 hidden_states = outputs[0]

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/unsloth/models/llama.py:810, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    807     hidden_states = layer_outputs[0]
    809 else:
--> 810     layer_outputs = decoder_layer(
    811         hidden_states,
    812         causal_mask=mask,
    813         attention_mask=attention_mask,
    814         position_ids=position_ids,
    815         past_key_value=past_key_value,
    816         output_attentions=output_attentions,
    817         use_cache=use_cache,
    818         padding_mask=padding_mask,
    819     )
    820     hidden_states = layer_outputs[0]
    821 pass

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/unsloth/models/llama.py:495, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
    493 residual = hidden_states
    494 hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
--> 495 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    496     hidden_states=hidden_states,
    497     causal_mask=causal_mask,
    498     attention_mask=attention_mask,
    499     position_ids=position_ids,
    500     past_key_value=past_key_value,
    501     output_attentions=output_attentions,
    502     use_cache=use_cache,
    503     padding_mask=padding_mask,
    504 )
    505 hidden_states = residual + hidden_states
    507 # Fully Connected

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/unsloth/models/llama.py:364, in LlamaAttention_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
    361 head_dim   = self.head_dim
    362 assert(n_kv_heads * n_groups == n_heads)
--> 364 Q, K, V = self.apply_qkv(self, hidden_states)
    365 Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
    366 K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1931, in Module.__getattr__(self, name)
   1929     if name in modules:
   1930         return modules[name]
-> 1931 raise AttributeError(
   1932     f"'{type(self).__name__}' object has no attribute '{name}'"
   1933 )

AttributeError: 'LlamaSdpaAttention' object has no attribute 'apply_qkv'

Any ideas? Thank you!

danielhanchen commented 4 weeks ago

@laura-burdick-sil Apologies for the issue - you can use Unsloth directly to load the adapter, i.e.:

FastLanguageModel.from_pretrained("/root/test_adapter2")

We auto-detect LoRA adapters. The reason for this error is that Unsloth dynamically patches the model classes, so loading through plain Hugging Face in the same script will break.
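
For reference, a minimal sketch of the reload-and-resume flow (the max_seq_length and load_in_4bit values here are assumptions - pass whatever you used for the original run):

from unsloth import FastLanguageModel

# from_pretrained sees the adapter_config.json inside the folder, loads the
# 4-bit base model, attaches the LoRA adapter, and returns a
# (model, tokenizer) tuple, so unpack both.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/root/test_adapter2",
    max_seq_length = 1024,  # assumed - match your original training run
    load_in_4bit = True,
    device_map = {"": 0},
)

The returned model is already a PeftModel, so you can pass it straight to SFTTrainer and keep training.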

laura-burdick-sil commented 3 weeks ago

I'm not sure that I fully understand. When I replace this code:

# Load base model and adapter
base_model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-3B-bnb-4bit")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B-bnb-4bit")

config = PeftConfig.from_pretrained("/root/test_adapter2")
model = PeftModel.from_pretrained(base_model, "/root/test_adapter2", config=config, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained("/root/test_adapter2")

with this code:

model = FastLanguageModel.from_pretrained("/root/test_adapter2", device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained("/root/test_adapter2")

and try to train the model (using the same code as above to load the dataset and train), I get the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 18
      1 # Train on data
      2 training_arguments = TrainingArguments(
      3     per_device_train_batch_size = 2,
      4     gradient_accumulation_steps = 4,
   (...)
     15     save_strategy="steps",       # Save checkpoints based on steps.
     16 )
---> 18 trainer = SFTTrainer(
     19     model = model,
     20     tokenizer = tokenizer,
     21     train_dataset = dataset1,
     22     dataset_text_field = "text",
     23     max_seq_length = 1024,
     24     dataset_num_proc = 2,
     25     packing = False, # Can make training 5x faster for short sequences.
     26     args = training_arguments,
     27 )
     29 torch.cuda.empty_cache()
     30 trainer.train()

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:101, in _deprecate_arguments.<locals>._inner_deprecate_positional_args.<locals>.inner_f(*args, **kwargs)
     99         message += "\n\n" + custom_message
    100     warnings.warn(message, FutureWarning)
--> 101 return f(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:401, in SFTTrainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics, peft_config, dataset_text_field, packing, formatting_func, max_seq_length, infinite, num_of_sequences, chars_per_token, dataset_num_proc, dataset_batch_size, neftune_noise_alpha, model_init_kwargs, dataset_kwargs, eval_packing)
    395 if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
    396     warnings.warn(
    397         "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
    398         "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
    399     )
--> 401 super().__init__(
    402     model=model,
    403     args=args,
    404     data_collator=data_collator,
    405     train_dataset=train_dataset,
    406     eval_dataset=eval_dataset,
    407     tokenizer=tokenizer,
    408     model_init=model_init,
    409     compute_metrics=compute_metrics,
    410     callbacks=callbacks,
    411     optimizers=optimizers,
    412     preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    413 )
    415 # Add tags for models that have been loaded with the correct transformers version
    416 if hasattr(self.model, "add_model_tags"):

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/transformers/utils/deprecation.py:165, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
    161 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS):
    162     # DeprecationWarning is ignored by default, so we use FutureWarning instead
    163     warnings.warn(message, FutureWarning, stacklevel=2)
--> 165 return func(*args, **kwargs)

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/transformers/trainer.py:587, in Trainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, processing_class, model_init, compute_loss_func, compute_metrics, callbacks, optimizers, optimizer_cls_and_kwargs, preprocess_logits_for_metrics)
    582 # Bnb Quantized models doesn't support `.to` operation.
    583 if (
    584     self.place_model_on_device
    585     and not getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
    586 ):
--> 587     self._move_model_to_device(model, args.device)
    589 # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
    590 if self.is_model_parallel:

File ~/.clearml/venvs-builds/3.10/lib/python3.10/site-packages/transformers/trainer.py:860, in Trainer._move_model_to_device(self, model, device)
    859 def _move_model_to_device(self, model, device):
--> 860     model = model.to(device)
    861     # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
    862     if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):

AttributeError: 'tuple' object has no attribute 'to'

Am I missing something here? Thank you!

Erland366 commented 3 weeks ago

Ohhh, FastLanguageModel.from_pretrained returns both the model and the tokenizer at the same time, so you need to unpack it:

model, tokenizer = FastLanguageModel.from_pretrained("/root/test_adapter2", max_seq_length=2048, load_in_4bit=True, dtype=None)
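
The max_seq_length=2048 above is just an example value - use whatever you trained with - and dtype=None lets Unsloth auto-detect the dtype. Once unpacked, model is a ready PeftModel with the adapter attached, so it can go straight back into your trainer, e.g. (a sketch reusing the dataset1 and training_arguments you defined earlier):

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset1,
    dataset_text_field = "text",
    max_seq_length = 2048,  # keep consistent with from_pretrained
    packing = False,
    args = training_arguments,
)
trainer.train()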

laura-burdick-sil commented 3 weeks ago

Thank you! That worked!