mbzuai-oryx / LLaVA-pp

🔥🔥 LLaVA++: Extending LLaVA with Phi-3 and LLaMA-3 (LLaVA LLaMA-3, LLaVA Phi-3)

Training Issue #19

Open DevonPeroutky opened 5 months ago

DevonPeroutky commented 5 months ago

Environment

Issue

I'm seeing random, sudden loss spikes during training. (If there is a simpler way of debugging this, I'm open to a different approach.) To track down the cause, I tried to reproduce the training loop in plain PyTorch so that I could log abnormal gradients during training and detect any erroneous examples in my training data.

However, the forward pass always fails with AttributeError: 'NoneType' object has no attribute 'device' (full stacktrace below).

I built the model exactly as it's done in train.py, and my training loop looks like this:

# Define a threshold for outlier detection
gradient_threshold = 10.0

# Create a DataLoader for iterating through the dataset
train_dataloader = torch.utils.data.DataLoader(data_module['train_dataset'], batch_size=1, shuffle=True)

for batch_idx, batch in enumerate(train_dataloader):
    input_ids = batch["input_ids"]        # torch.Size([1, 200])
    labels = batch["labels"]              # torch.Size([1, 200])
    image_tensor = batch["image"].half()  # torch.Size([1, 3, 336, 336])

    # Zero the gradient
    optimizer.zero_grad()

    # Always errors out here
    output = model.forward(input_ids=input_ids, images=image_tensor)
    ...
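
For completeness, the part elided by ... above is meant to look roughly like the sketch below (this is the outlier check I described, not code from train.py, and it assumes labels are passed to forward so that output.loss is populated):

    # Sketch of the elided part of the loop: compute the loss, backpropagate,
    # and flag any parameter whose gradient norm exceeds gradient_threshold so
    # the offending training example can be inspected.
    loss = output.loss                    # requires labels=labels in the forward call
    loss.backward()

    for name, param in model.named_parameters():
        if param.grad is not None and param.grad.norm() > gradient_threshold:
            print(f"Batch {batch_idx}: large gradient in {name}: {param.grad.norm().item():.2f}")

    optimizer.step()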

model.forward always fails with the stacktrace below. I've tried the forward pass both with and without labels, with similar results. After the prepare_inputs_labels_for_multimodal call, the inputs look like the following:

Input IDs:  None
position_ids:  None
Attention Mask:  None
past_key_values:  None
labels: None
inputs_embeds:  torch.Size([1, 512, 4096])
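
For reference, this is roughly how I inspect that intermediate step myself. It is only a sketch: the explicit device/dtype moves are my assumption about what the Trainer normally handles, and the call mirrors the one commented out in the notebook cell shown in the stacktrace.

# Sketch: run the multimodal preprocessing explicitly and inspect the result.
# The .to(...) placement and dtype handling are assumptions, not what train.py does verbatim.
device = next(model.parameters()).device
input_ids = input_ids.to(device)
labels = labels.to(device)
image_tensor = image_tensor.to(device=device, dtype=torch.float16)

(_input_ids, position_ids, attention_mask, past_key_values,
 inputs_embeds, labels_embeds) = model.prepare_inputs_labels_for_multimodal(
    input_ids=input_ids,
    position_ids=None,
    attention_mask=None,
    past_key_values=None,
    labels=labels,
    images=image_tensor,
)
print("inputs_embeds:", inputs_embeds.shape, inputs_embeds.device, inputs_embeds.dtype)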

Below is the full stacktrace and the model layers. What am I missing?

Model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlavaLlamaForCausalLM(
      (model): LlavaLlamaModel(
        (embed_tokens): Embedding(128257, 4096, padding_idx=128256)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaFlashAttention2(
              (q_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (v_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (o_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (up_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (down_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=14336, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=14336, out_features=128, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (act_fn): SiLU()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
        (norm): LlamaRMSNorm()
        (vision_tower): CLIPVisionTower(
          (vision_tower): CLIPVisionModel(
            (vision_model): CLIPVisionTransformer(
              (embeddings): CLIPVisionEmbeddings(
                (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
                (position_embedding): Embedding(577, 1024)
              )
              (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (encoder): CLIPEncoder(
                (layers): ModuleList(
                  (0-23): 24 x CLIPEncoderLayer(
                    (self_attn): CLIPAttention(
                      (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    )
                    (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                    (mlp): CLIPMLP(
                      (activation_fn): QuickGELUActivation()
                      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                    )
                    (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                  )
                )
              )
              (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
          )
        )
        (mm_projector): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=4096, out_features=4096, bias=True)
        )
      )
      (lm_head): Linear8bitLt(in_features=4096, out_features=128257, bias=False)
    )
  )
)

Full Stacktrace

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[36], line 46
     28 # (_input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels_embeds) = model.prepare_inputs_labels_for_multimodal(input_ids=input_ids, position_ids=None, attention_mask=None, past_key_values=None, labels=labels, images=image_tensor)
   (...)
     44 
     45 # 4
---> 46 output = model.forward(input_ids=input_ids, images=image_tensor, labels=labels)
     47 loss = compute_loss(output.logits, labels)
     48 print("LOSS: ", loss.item())

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/peft_model.py:1129, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1127     with self._enable_peft_forward_hooks(**kwargs):
   1128         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1129         return self.base_model(
   1130             input_ids=input_ids,
   1131             attention_mask=attention_mask,
   1132             inputs_embeds=inputs_embeds,
   1133             labels=labels,
   1134             output_attentions=output_attentions,
   1135             output_hidden_states=output_hidden_states,
   1136             return_dict=return_dict,
   1137             **kwargs,
   1138         )
   1140 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1141 if attention_mask is not None:
   1142     # concat prompt attention mask

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/LLaVA-pp/LLaVA/llava/model/language_model/llava_llama.py:103, in LlavaLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, images, image_sizes, return_dict, cache_position)
    101     print("inputs_embeds: ", inputs_embeds.shape)
    102 print("labels: ", labels)
--> 103 return super().forward(
    104     input_ids=input_ids,
    105     attention_mask=attention_mask,
    106     position_ids=position_ids,
    107     past_key_values=past_key_values,
    108     inputs_embeds=inputs_embeds,
    109     labels=labels,
    110     use_cache=use_cache,
    111     output_attentions=output_attentions,
    112     output_hidden_states=output_hidden_states,
    113     return_dict=return_dict
    114 )

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1183, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1180 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1182 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1183 outputs = self.model(
   1184     input_ids=input_ids,
   1185     attention_mask=attention_mask,
   1186     position_ids=position_ids,
   1187     past_key_values=past_key_values,
   1188     inputs_embeds=inputs_embeds,
   1189     use_cache=use_cache,
   1190     output_attentions=output_attentions,
   1191     output_hidden_states=output_hidden_states,
   1192     return_dict=return_dict,
   1193 )
   1195 hidden_states = outputs[0]
   1196 if self.config.pretraining_tp > 1:

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1070, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
   1060     layer_outputs = self._gradient_checkpointing_func(
   1061         decoder_layer.__call__,
   1062         hidden_states,
   (...)
   1067         use_cache,
   1068     )
   1069 else:
-> 1070     layer_outputs = decoder_layer(
   1071         hidden_states,
   1072         attention_mask=attention_mask,
   1073         position_ids=position_ids,
   1074         past_key_value=past_key_values,
   1075         output_attentions=output_attentions,
   1076         use_cache=use_cache,
   1077     )
   1079 hidden_states = layer_outputs[0]
   1081 if use_cache:

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:798, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    795 hidden_states = self.input_layernorm(hidden_states)
    797 # Self Attention
--> 798 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    799     hidden_states=hidden_states,
    800     attention_mask=attention_mask,
    801     position_ids=position_ids,
    802     past_key_value=past_key_value,
    803     output_attentions=output_attentions,
    804     use_cache=use_cache,
    805     **kwargs,
    806 )
    807 hidden_states = residual + hidden_states
    809 # Fully Connected

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:494, in LlamaFlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    490 output_attentions = False
    492 bsz, q_len, _ = hidden_states.size()
--> 494 query_states = self.q_proj(hidden_states)
    495 key_states = self.k_proj(hidden_states)
    496 value_states = self.v_proj(hidden_states)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/tuners/lora/bnb.py:217, in Linear8bitLt.forward(self, x, *args, **kwargs)
    215     result = self.base_layer(x, *args, **kwargs)
    216 else:
--> 217     result = self.base_layer(x, *args, **kwargs)
    218     for active_adapter in self.active_adapters:
    219         if active_adapter not in self.lora_A.keys():

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:797, in Linear8bitLt.forward(self, x)
    794 if self.bias is not None and self.bias.dtype != x.dtype:
    795     self.bias.data = self.bias.data.to(x.dtype)
--> 797 out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
    799 if not self.state.has_fp16_weights:
    800     if self.state.CB is not None and self.state.CxB is not None:
    801         # we converted 8-bit row major to turing/ampere format in the first inference pass
    802         # we no longer need the row-major weight

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:556, in matmul(A, B, out, state, threshold, bias)
    554 if threshold > 0.0:
    555     state.threshold = threshold
--> 556 return MatMul8bitLt.apply(A, B, out, bias, state)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
    536 if not torch._C._are_functorch_transforms_active():
    537     # See NOTE: [functorch vjp and autograd interaction]
    538     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]
    541 if cls.setup_context == _SingleLevelFunction.setup_context:
    542     raise RuntimeError(
    543         "In order to use an autograd.Function with functorch transforms "
    544         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    545         "staticmethod. For more details, please see "
    546         "https://pytorch.org/docs/master/notes/extending.func.html"
    547     )

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:335, in MatMul8bitLt.forward(ctx, A, B, out, bias, state)
    331     else:
    332         if state.CxB is None and using_igemmlt:
    333             # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
    334             # we also need to convert it to the turing/ampere format
--> 335             state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
    336 else:
    337     if not state.has_fp16_weights and state.CxB is None and using_igemmlt:

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/functional.py:2597, in transform(A, to_order, from_order, out, transpose, state, ld)
   2596 def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
-> 2597     prev_device = pre_call(A.device)
   2598     if state is None:
   2599         state = (A.shape, from_order)

AttributeError: 'NoneType' object has no attribute 'device'
mmaaz60 commented 4 months ago

Hi @DevonPeroutky,

Thank you for your interest in our work. Did you try upgrading transformers to the latest version? Please note that LLaMA-3 based training is only supported with transformers==4.41+, which you can install as follows:

pip install git+https://github.com/huggingface/transformers@a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3
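
As a quick sanity check (my addition, not part of the original instructions), you can confirm that the environment actually picked up that commit by printing the installed version, which should report a 4.41 build:

import transformers
print(transformers.__version__)  # expect a 4.41.x / 4.41.0.devX version after the install above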

Let me know if that helps. Good luck!