huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

RuntimeError: Caught RuntimeError in replica 0 on device 0 #26203

Closed: ArnaudHureaux closed this issue 1 year ago

ArnaudHureaux commented 1 year ago

System Info

transformers version: 4.33, Python version: 3.10.6

I'm trying to fine-tune the Hugging Face model NousResearch/Llama-2-70b-chat-hf on the Hugging Face dataset mlabonne/guanaco-llama2-1k.

None of these previous answers helped me: in https://github.com/huggingface/transformers/issues/23754 I didn't understand the error, and following https://github.com/huggingface/transformers/issues/6855 I reduced the batch size to 1 and used 4 A100 GPUs, with no result.

Who can help?

Text models: @ArthurZucker and @younesbelkada

Reproduction

1. Deploy a RunPod server with 4 A100 GPUs ($7.96 per hour) using the PyTorch image "RunPod Pytorch 2.0.1".

2. Install these libraries:

    !pip install transformers[sentencepiece]
    !pip install yolk3k
    !yolk -V trl
    !pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.7.1
    !pip install scipy tensorboardX
    !pip install sentencepiece 
3. Run this code:

    
    import os
    import torch
    from datasets import load_dataset
    from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
    )
    from peft import LoraConfig, PeftModel
    from trl import SFTTrainer

    model_name = "NousResearch/Llama-2-70b-chat-hf"
    dataset_name = "mlabonne/guanaco-llama2-1k"
    new_model = "Llama-2-70b-chat-hf-miniguanaco"
    lora_r = 64
    lora_alpha = 16
    lora_dropout = 0.1
    use_4bit = True
    bnb_4bit_compute_dtype = "float16"
    bnb_4bit_quant_type = "nf4"
    use_nested_quant = False
    output_dir = "./results"
    num_train_epochs = 1
    fp16 = False
    bf16 = True
    per_device_train_batch_size = 1
    per_device_eval_batch_size = 2
    gradient_accumulation_steps = 1
    gradient_checkpointing = True
    max_grad_norm = 0.3
    learning_rate = 2e-4
    weight_decay = 0.001
    optim = "paged_adamw_32bit"
    lr_scheduler_type = "constant"
    max_steps = -1
    warmup_ratio = 0.03
    group_by_length = True
    save_steps = 25
    logging_steps = 25
    max_seq_length = None
    packing = False
    device_map = {"": 0}

    dataset = load_dataset(dataset_name, split="train")

    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )

    # Check whether the GPU supports bfloat16
    if compute_dtype == torch.float16 and use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16: accelerate training with bf16=True")
            print("=" * 80)

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map,  # Pass in the device map
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

    peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )

    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_ratio=warmup_ratio,
        group_by_length=group_by_length,
        lr_scheduler_type=lr_scheduler_type,
        report_to="tensorboard",
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=packing,
    )

    trainer.train()
    trainer.model.save_pretrained(new_model)


### Expected behavior

I expect to get a fine-tuned model; this same code worked with the 7B version of the model.
ArthurZucker commented 1 year ago

Hey, could you provide the full traceback? 🤗 Next time it would be great if you could format the code to make it easier to read! cc @younesbelkada in case you know what is going on here.

younesbelkada commented 1 year ago

Hi @ArnaudHureaux, thanks for the issue. In order to help you, can you share the full traceback please?

ArnaudHureaux commented 1 year ago


Oh yeah, my bad, sorry! I edited my comment; I had forgotten that my code contained "#####" lines 🤗

@younesbelkada thanks a lot for your help, please find the full traceback below:



You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[10], line 44
     13 training_arguments = TrainingArguments(
     14     output_dir=output_dir,
     15     num_train_epochs=num_train_epochs,
   (...)
     30     report_to="tensorboard"
     31 )
     33 trainer = SFTTrainer(
     34     model=model,
     35     train_dataset=dataset,
   (...)
     41     packing=packing,
     42 )
---> 44 trainer.train()
     45 trainer.model.save_pretrained(new_model)

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1534     self.model_wrapped = self.model
   1536 inner_training_loop = find_executable_batch_size(
   1537     self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1538 )
-> 1539 return inner_training_loop(
   1540     args=args,
   1541     resume_from_checkpoint=resume_from_checkpoint,
   1542     trial=trial,
   1543     ignore_keys_for_eval=ignore_keys_for_eval,
   1544 )

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1809, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1806     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1808 with self.accelerator.accumulate(model):
-> 1809     tr_loss_step = self.training_step(model, inputs)
   1811 if (
   1812     args.logging_nan_inf_filter
   1813     and not is_torch_tpu_available()
   1814     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1815 ):
   1816     # if loss is nan or inf simply add the average of previous logged losses
   1817     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2654, in Trainer.training_step(self, model, inputs)
   2651     return loss_mb.reduce_mean().detach().to(self.args.device)
   2653 with self.compute_loss_context_manager():
-> 2654     loss = self.compute_loss(model, inputs)
   2656 if self.args.n_gpu > 1:
   2657     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2679, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2677 else:
   2678     labels = None
-> 2679 outputs = model(**inputs)
   2680 # Save past state if it exists
   2681 # TODO: this needs to be fixed and made cleaner later.
   2682 if self.args.past_index >= 0:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py:171, in DataParallel.forward(self, *inputs, **kwargs)
    169     return self.module(*inputs[0], **kwargs[0])
    170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 171 outputs = self.parallel_apply(replicas, inputs, kwargs)
    172 return self.gather(outputs, self.output_device)

File /usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py:181, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
    180 def parallel_apply(self, replicas, inputs, kwargs):
--> 181     return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File /usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py:89, in parallel_apply(modules, inputs, kwargs_tup, devices)
     87     output = results[i]
     88     if isinstance(output, ExceptionWrapper):
---> 89         output.reraise()
     90     outputs.append(output)
     91 return outputs

File /usr/local/lib/python3.10/dist-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
    640 except TypeError:
    641     # If the exception takes multiple arguments, don't try to
    642     # instantiate since we don't know how to
    643     raise RuntimeError(msg) from None
--> 644 raise exception

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 922, in forward
    return self.base_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 685, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 681, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 295, in forward
    query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 295, in <listcomp>
    query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x8192 and 1x1024)
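The traceback passes through `torch.nn.parallel.DataParallel`, so the Trainer appears to have wrapped the model in DataParallel because several GPUs were visible to the process; that is where the "replica 0 on device 0" wording in the error comes from. A small sketch of the batch-size bookkeeping in that single-process mode, assuming (as `TrainingArguments.train_batch_size` does) that the per-device batch size is multiplied by `max(1, n_gpu)`; the helper name is hypothetical:

```python
def effective_batch_size(per_device_train_batch_size: int,
                         n_gpu: int,
                         gradient_accumulation_steps: int = 1) -> int:
    """Samples consumed per optimizer step in single-process DataParallel mode.

    Assumption: like TrainingArguments.train_batch_size, the per-device size
    is multiplied by max(1, n_gpu); gradient accumulation multiplies again.
    """
    train_batch_size = per_device_train_batch_size * max(1, n_gpu)
    return train_batch_size * gradient_accumulation_steps


# Settings from the reproduction script, on the 4-GPU RunPod server.
print(effective_batch_size(1, n_gpu=4, gradient_accumulation_steps=1))  # 4
# With a single visible GPU the Trainer has no reason to use DataParallel.
print(effective_batch_size(1, n_gpu=1))  # 1
```

DataParallel splits that batch of 4 into one sequence per replica, which matches the single sequence of 1024 tokens seen in the failing replica.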
younesbelkada commented 1 year ago

Hmm, it seems it is using pretraining_tp > 1 even though you force it to 1 in the script, and this is the culprit since it is not supported in PEFT. @ArthurZucker I thought pretraining_tp was always forced to be 1. Can you add revision="refs/pr/1" to from_pretrained?
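The last line of the traceback is consistent with this diagnosis. The back-of-the-envelope sketch below reproduces its shapes under two assumptions not stated in the thread: the checkpoint's config has `pretraining_tp = 8` (8192 / 8 = 1024, which matches the error), and bitsandbytes stores each 4-bit weight packed two values per byte in a single column of shape `(out * in // 2, 1)`. Slicing that packed column the way the `pretraining_tp > 1` branch slices a regular `q_proj` weight yields `(1024, 1)` pieces, which cannot multiply an 8192-wide input. The helper names are hypothetical, and the token count is read from the error message:

```python
def packed_4bit_shape(out_features: int, in_features: int) -> tuple:
    # Assumption: 4-bit weights are stored two values per byte in one column.
    return (out_features * in_features // 2, 1)


def linear_shapes_ok(input_shape: tuple, weight_shape: tuple) -> bool:
    # F.linear(x, W) computes x @ W.T, so x's last dim must equal W.shape[1].
    return input_shape[-1] == weight_shape[1]


hidden_size = 8192            # Llama-2-70B hidden size
num_heads, head_dim = 64, 128
pretraining_tp = 8            # assumption, see the paragraph above
tokens = 1024                 # read from "mat1 ... (1024x8192 ...)"

packed = packed_4bit_shape(hidden_size, hidden_size)    # packed q_proj weight
split_rows = (num_heads * head_dim) // pretraining_tp   # rows per slice
first_slice = (split_rows, packed[1])                   # (1024, 1)

mat1 = (tokens, hidden_size)
mat2 = (first_slice[1], first_slice[0])  # W.T, as reported in the error
print(f"mat1 {mat1[0]}x{mat1[1]}, mat2 {mat2[0]}x{mat2[1]}")  # mat1 1024x8192, mat2 1x1024
print("shapes compatible:", linear_shapes_ok(mat1, first_slice))  # shapes compatible: False
```

With `pretraining_tp == 1` the sliced branch is skipped and the call goes through the quantized linear module itself, which knows how to handle the packed weight.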

ArnaudHureaux commented 1 year ago

Thanks @younesbelkada, but which `from_pretrained`? This one:

    model = AutoModelForCausalLM.from_pretrained(

or this one:

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
ArnaudHureaux commented 1 year ago

?

ArthurZucker commented 1 year ago

Hey, I invite you to read the documentation for the model you are using: Llama. The pretraining_tp argument is explained there. 🤗 Since this is an argument for the modeling code, it should be passed to the from_pretrained call used to initialize the model.
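The failing line in the traceback reads `self.pretraining_tp`, an attribute of the attention layer, not `self.config.pretraining_tp`. One plausible explanation (an assumption, not confirmed in this thread) is that each attention layer copies the config value when it is built, so setting `model.config.pretraining_tp = 1` after loading never reaches it, while a value set before construction, e.g. `AutoModelForCausalLM.from_pretrained(model_name, pretraining_tp=1, ...)`, does. The toy classes below are stand-ins, not transformers classes:

```python
class ToyConfig:
    def __init__(self, pretraining_tp: int = 8):
        self.pretraining_tp = pretraining_tp


class ToyAttention:
    def __init__(self, config: ToyConfig):
        # The value is copied once, when the layer is constructed.
        self.pretraining_tp = config.pretraining_tp

    def uses_sliced_path(self) -> bool:
        return self.pretraining_tp > 1


class ToyModel:
    def __init__(self, config: ToyConfig, num_layers: int = 2):
        self.config = config
        self.layers = [ToyAttention(config) for _ in range(num_layers)]


# Pattern from the original script: patch the config after loading.
patched = ToyModel(ToyConfig(pretraining_tp=8))
patched.config.pretraining_tp = 1
print([layer.uses_sliced_path() for layer in patched.layers])  # [True, True]

# Pattern suggested above: set the value before the layers are built.
fresh = ToyModel(ToyConfig(pretraining_tp=1))
print([layer.uses_sliced_path() for layer in fresh.layers])  # [False, False]
```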

younesbelkada commented 1 year ago

Hi @ArnaudHureaux, can you try this snippet:

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

model_name = "NousResearch/Llama-2-70b-chat-hf"
dataset_name = "mlabonne/guanaco-llama2-1k"
new_model = "Llama-2-70b-chat-hf-miniguanaco"
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
output_dir = "./results"
num_train_epochs = 1
fp16 = False
bf16 = True
per_device_train_batch_size = 1
per_device_eval_batch_size = 2
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 2e-4
weight_decay = 0.001
optim = "paged_adamw_32bit"
lr_scheduler_type = "constant"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 25
logging_steps = 25
max_seq_length = None
packing = False
device_map = {"": 0}

dataset = load_dataset(dataset_name, split="train")

compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,  # Pass in the device map
    revision="refs/pr/1",
)

model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)

training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="tensorboard"
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)

trainer.train()
trainer.model.save_pretrained(new_model)

However, this is very surprising, since you correctly set `model.config.pretraining_tp = 1`. Can you also try with the latest transformers (`pip install -U transformers`)?

ArnaudHureaux commented 1 year ago

Hi @younesbelkada,

Thanks a lot for your answer, the training worked perfectly with your code!

But now I want to push the model to the Hugging Face Hub, so I stopped and reset my kernel and server and ran this code (which worked for the Llama-7B example):

# Reload model in FP16 and merge it with LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()

# Reload tokenizer to save it
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)
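The reload-and-merge step relies on `merge_and_unload()`, which folds each LoRA adapter into the base weight it wraps. A minimal pure-Python sketch of that arithmetic, assuming the standard LoRA update W' = W + (lora_alpha / r) * B @ A with PEFT's default scaling (16 / 64 = 0.25 for the `lora_alpha` and `lora_r` used in this thread); the rank-1, 2x2 matrices are made up for readability:

```python
def matmul(x, y):
    """Plain nested-list matrix product x @ y."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def matvec(m, v):
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def merge_lora(w, a, b, scaling):
    """Return W + scaling * (B @ A); W is out x in, B is out x r, A is r x in."""
    delta = matmul(b, a)
    return [[w[i][j] + scaling * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


lora_alpha, lora_r = 16, 64
scaling = lora_alpha / lora_r          # 0.25

w = [[1.0, 0.0], [0.0, 1.0]]           # frozen base weight (toy)
a = [[2.0, 4.0]]                       # LoRA A, rank 1 (toy)
b = [[1.0], [3.0]]                     # LoRA B (toy)
x = [1.0, 1.0]

merged = merge_lora(w, a, b, scaling)
unmerged_out = [base + scaling * upd
                for base, upd in zip(matvec(w, x), matvec(b, matvec(a, x)))]

print(merged)              # [[1.5, 1.0], [1.5, 4.0]]
print(matvec(merged, x))   # [2.5, 5.5]
print(unmerged_out)        # [2.5, 5.5]
```

The merged weight gives the same output as base plus adapter, so the pushed model no longer needs PEFT at inference time; it is, however, a full-size model in the dtype it was reloaded in.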

But I don't know why I get a CUDA out-of-memory error with my two 80 GB GPUs:

---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[8], line 2
      1 # Reload model in FP16 and merge it with LoRA weights
----> 2 base_model = AutoModelForCausalLM.from_pretrained(
      3     model_name,
      4     low_cpu_mem_usage=True,
      5     return_dict=True,
      6     torch_dtype=torch.float16,
      7     device_map=device_map,
      8 )
      9 model = PeftModel.from_pretrained(base_model, new_model)
     10 model = model.merge_and_unload()

File /usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py:493, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    491 elif type(config) in cls._model_mapping.keys():
    492     model_class = _get_model_class(config, cls._model_mapping)
--> 493     return model_class.from_pretrained(
    494         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    495     )
    496 raise ValueError(
    497     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    498     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    499 )

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:2903, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   2893     if dtype_orig is not None:
   2894         torch.set_default_dtype(dtype_orig)
   2896     (
   2897         model,
   2898         missing_keys,
   2899         unexpected_keys,
   2900         mismatched_keys,
   2901         offload_index,
   2902         error_msgs,
-> 2903     ) = cls._load_pretrained_model(
   2904         model,
   2905         state_dict,
   2906         loaded_state_dict_keys,  # XXX: rename?
   2907         resolved_archive_file,
   2908         pretrained_model_name_or_path,
   2909         ignore_mismatched_sizes=ignore_mismatched_sizes,
   2910         sharded_metadata=sharded_metadata,
   2911         _fast_init=_fast_init,
   2912         low_cpu_mem_usage=low_cpu_mem_usage,
   2913         device_map=device_map,
   2914         offload_folder=offload_folder,
   2915         offload_state_dict=offload_state_dict,
   2916         dtype=torch_dtype,
   2917         is_quantized=(load_in_8bit or load_in_4bit),
   2918         keep_in_fp32_modules=keep_in_fp32_modules,
   2919     )
   2921 model.is_loaded_in_4bit = load_in_4bit
   2922 model.is_loaded_in_8bit = load_in_8bit

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:3260, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, is_quantized, keep_in_fp32_modules)
   3250 mismatched_keys += _find_mismatched_keys(
   3251     state_dict,
   3252     model_state_dict,
   (...)
   3256     ignore_mismatched_sizes,
   3257 )
   3259 if low_cpu_mem_usage:
-> 3260     new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
   3261         model_to_load,
   3262         state_dict,
   3263         loaded_keys,
   3264         start_prefix,
   3265         expected_keys,
   3266         device_map=device_map,
   3267         offload_folder=offload_folder,
   3268         offload_index=offload_index,
   3269         state_dict_folder=state_dict_folder,
   3270         state_dict_index=state_dict_index,
   3271         dtype=dtype,
   3272         is_quantized=is_quantized,
   3273         is_safetensors=is_safetensors,
   3274         keep_in_fp32_modules=keep_in_fp32_modules,
   3275     )
   3276     error_msgs += new_error_msgs
   3277 else:

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:717, in _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix, expected_keys, device_map, offload_folder, offload_index, state_dict_folder, state_dict_index, dtype, is_quantized, is_safetensors, keep_in_fp32_modules)
    714     state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
    715 elif not is_quantized:
    716     # For backward compatibility with older versions of `accelerate`
--> 717     set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
    718 else:
    719     if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():

File /usr/local/lib/python3.10/dist-packages/accelerate/utils/modeling.py:298, in set_module_tensor_to_device(module, tensor_name, device, value, dtype, fp16_statistics)
    296             module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
    297 elif isinstance(value, torch.Tensor):
--> 298     new_value = value.to(device)
    299 else:
    300     new_value = torch.tensor(value, device=device)

OutOfMemoryError: CUDA out of memory. Tried to allocate 448.00 MiB (GPU 0; 79.15 GiB total capacity; 78.58 GiB already allocated; 153.25 MiB free; 78.58 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I ran

torch.cuda.memory_summary(device=None, abbreviated=False)

And I got this output:

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 1         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  80469 MiB |  80469 MiB |  80469 MiB |      0 B   |
|       from large pool |  80468 MiB |  80468 MiB |  80468 MiB |      0 B   |
|       from small pool |      1 MiB |      1 MiB |      1 MiB |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |  80469 MiB |  80469 MiB |  80469 MiB |      0 B   |
|       from large pool |  80468 MiB |  80468 MiB |  80468 MiB |      0 B   |
|       from small pool |      1 MiB |      1 MiB |      1 MiB |      0 B   |
|---------------------------------------------------------------------------|
| Requested memory      |  80469 MiB |  80469 MiB |  80469 MiB |      0 B   |
|       from large pool |  80468 MiB |  80468 MiB |  80468 MiB |      0 B   |
|       from small pool |      1 MiB |      1 MiB |      1 MiB |      0 B   |
|---------------------------------------------------------------------------|
| GPU reserved memory   |  80470 MiB |  80470 MiB |  80470 MiB |      0 B   |
|       from large pool |  80468 MiB |  80468 MiB |  80468 MiB |      0 B   |
|       from small pool |      2 MiB |      2 MiB |      2 MiB |      0 B   |
|---------------------------------------------------------------------------|
| Non-releasable memory | 450048 B   |   2032 KiB |   2032 KiB |   1592 KiB |
|       from large pool |      0 B   |      0 KiB |      0 KiB |      0 KiB |
|       from small pool | 450048 B   |   2032 KiB |   2032 KiB |   1592 KiB |
|---------------------------------------------------------------------------|
| Allocations           |     492    |     492    |     492    |       0    |
|       from large pool |     344    |     344    |     344    |       0    |
|       from small pool |     148    |     148    |     148    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |     492    |     492    |     492    |       0    |
|       from large pool |     344    |     344    |     344    |       0    |
|       from small pool |     148    |     148    |     148    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |     345    |     345    |     345    |       0    |
|       from large pool |     344    |     344    |     344    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       1    |       1    |       1    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

I think the problem is that I'm only using 1 GPU, whereas I have 2 GPUs on my server. Any idea what I should do? How can I use both GPUs in this code?
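A rough estimate shows why one GPU is not enough for the fp16 reload while it was enough for 4-bit training. The sketch counts weights only (no activations, CUDA context or fragmentation), assumes about 69 billion parameters for Llama-2-70B, and uses the 79.15 GiB capacity reported in the OOM message:

```python
GIB = 2 ** 30


def weights_gib(n_params: float, bytes_per_param: float) -> float:
    """Memory needed for the weights alone, in GiB."""
    return n_params * bytes_per_param / GIB


n_params = 69e9            # approximate Llama-2-70B size (assumption)
gpu_capacity_gib = 79.15   # "79.15 GiB total capacity" from the OOM message

fp16 = weights_gib(n_params, 2)    # torch_dtype=torch.float16
nf4 = weights_gib(n_params, 0.5)   # 4-bit, ignoring quantization constants

print(f"fp16 weights: {fp16:.1f} GiB, fits on 1 GPU: {fp16 < gpu_capacity_gib}")
print(f"4-bit weights: {nf4:.1f} GiB, fits on 1 GPU: {nf4 < gpu_capacity_gib}")
print(f"fp16 fits across 2 GPUs: {fp16 < 2 * gpu_capacity_gib}")
```

Roughly 128.5 GiB of fp16 weights cannot fit in one 80 GB card (the memory summary shows about 78.6 GiB already allocated on device 0 when loading stopped), but it does fit when split over two.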

ArnaudHureaux commented 1 year ago

OK, my bad: I just replaced `device_map=device_map` with `device_map='auto'` and it worked!

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map = 'auto'
    #device_map=device_map
)
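With `device_map="auto"`, transformers first computes a per-GPU memory budget (by default aiming to balance the model across the visible GPUs), and accelerate's `infer_auto_device_map` then assigns modules to devices in order, moving on once a device's budget is used up. The toy function below mimics only that last, sequential step; it is not accelerate's implementation and ignores embeddings, `lm_head` and activations. The ~1.6 GiB per fp16 decoder layer and the 65 GiB / 75 GiB budgets are assumptions chosen for illustration:

```python
def toy_device_map(layer_sizes_gib: dict, budgets_gib: list) -> dict:
    """Assign layers, in order, to the first GPU whose remaining budget fits them."""
    placement = {}
    device, used = 0, 0.0
    for name, size in layer_sizes_gib.items():
        # Move on to the next GPU once the current one cannot hold this layer.
        while device < len(budgets_gib) and used + size > budgets_gib[device]:
            device, used = device + 1, 0.0
        if device == len(budgets_gib):
            raise MemoryError(f"{name} does not fit on the remaining GPUs")
        placement[name] = device
        used += size
    return placement


# 80 decoder layers of ~1.6 GiB each in fp16 (assumption).
layers = {f"model.layers.{i}": 1.6 for i in range(80)}

two_gpus = toy_device_map(layers, budgets_gib=[65.0, 65.0])
on_gpu0 = sum(d == 0 for d in two_gpus.values())
on_gpu1 = sum(d == 1 for d in two_gpus.values())
print(on_gpu0, on_gpu1)  # 40 40

try:
    toy_device_map(layers, budgets_gib=[75.0])  # a single GPU, like device_map={"": 0}
except MemoryError as err:
    print("OOM:", err)
```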
younesbelkada commented 1 year ago

Awesome! Great that the training works now! Feel free to close the issue if you think your concerns have been addressed. Thanks again!

ArnaudHureaux commented 1 year ago

Yeah, it's 100% solved! I have my own model on the Hugging Face Hub; I will now try to serve it as an API with text-generation-inference.

Thanks again @younesbelkada

younesbelkada commented 1 year ago

Very nice @ArnaudHureaux! 🚀

ArnaudHureaux commented 1 year ago

Oh, one last question @younesbelkada

With this code I successfully created a fine-tuned model: https://huggingface.co/ArnaudHureaux/Llama-2-70b-chat-hf-miniguanaco/tree/main

But unlike the original model (https://huggingface.co/meta-llama/Llama-2-70b-chat-hf/tree/main), this model has no .safetensors files, and I need them to deploy it with the text-generation-inference API.

How can I convert the .bin files into .safetensors?

(I know this question is not related to my issue, so I didn't reopen it and created a topic on the general forum: https://discuss.huggingface.co/t/how-convert-the-bin-files-into-safetensors-files/56721)

Thanks in advance for your answer ;)

ArnaudHureaux commented 1 year ago

Oh, forget about it: I was finally able to deploy the model without the safetensors files ;)

younesbelkada commented 1 year ago

Perfect! Could you share the fix so that the community can take inspiration from it in case they face the same issue? 🙏

ArnaudHureaux commented 1 year ago

Yeah, it was just an error on my side: the "there is no .safetensors" message was only a warning, and the actual error came from bad settings in my docker run command (unrelated to the current issue).

viissshhal commented 2 months ago

Please help me, this code is not working:

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
import evaluate

device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataset = load_dataset("csv", data_files="/kaggle/input/chatbot/bitext-retail-banking-llm-chatbot-training-dataset.csv", usecols=['instruction', 'response'])
dataset["train"]
tokenizer = AutoTokenizer.from_pretrained("/kaggle/input/bert/transformers/default/1", clean_up_tokenization_spaces=True)


def token_function(example):
    encodings = tokenizer(text=example['instruction'], text_target=example['response'], padding="max_length", truncation=True, max_length=128)
    encodings['labels'] = encodings['input_ids']
    return encodings

token_data = dataset.map(token_function, batched=True)  # features: ['instruction', 'response', 'input_ids', 'token_type_ids', 'attention_mask', 'labels']
print(token_data['train'][0])
def preprocess_labels(example):
    example['labels'] = example['input_ids']  # For demonstration; adjust as needed
    return example


token_data = token_data.map(preprocess_labels, batched=True)
small_train_dataset = token_data["train"].shuffle(seed=42).select(range(500))

model = AutoModelForSequenceClassification.from_pretrained("/kaggle/input/bert/transformers/default/1", num_labels=5)

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    save_steps=10_000,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    logging_steps=200,
    num_train_epochs=3
)
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()

print("Successfully done")
Map: 100%
 25545/25545 [00:10<00:00, 2544.57 examples/s]
{'instruction': 'I would like to acivate a card, can you help me?', 'response': 'I\'m here to assist you with that! Activating your card is an important step to starting and enjoying its benefits. Here\'s how you can activate your card:\n\n1. Locate the activation instructions: Depending on the card issuer, you may find the activation instructions on a sticker attached to the card itself, in the welcome package, or on the issuer\'s website.\n\n2. Visit the card issuer\'s activation website: Using your computer or mobile device, open a web browser and navigate to the card issuer\'s website. Look for the activation page or section.\n\n3. Enter the required information: Follow the prompts on the activation page and provide the necessary information. This may include your card number, personal details, and security code.\n\n4. Set up your card: Once you\'ve entered the required information, you may have the option to set up a PIN, create an online account, or choose additional security features. Follow the instructions provided.\n\n5. Confirm activation: After entering all the necessary details and setting up any additional features, review the information you\'ve provided and click on the "Activate" or "Confirm" button. \n\n6. Await confirmation: In most cases, you\'ll receive a confirmation message either on the website or through email or SMS indicating that your card has been successfully activated.\n\nIf you encounter any issues during the activation process or have any questions, please don\'t hesitate to reach out. 
I\'m here to assist you every step of the way!', 'input_ids': [101, 146, 1156, 1176, 1106, 170, 6617, 21106, 170, 3621, 117, 1169, 1128, 1494, 1143, 136, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [101, 146, 1156, 1176, 1106, 170, 6617, 21106, 170, 3621, 117, 1169, 1128, 1494, 1143, 136, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
Map: 100%
 25545/25545 [00:03<00:00, 7900.92 examples/s]
A parameter name that contains `beta` will be renamed internally to `bias`. Please use a different name to suppress this warning.
A parameter name that contains `gamma` will be renamed internally to `weight`. Please use a different name to suppress this warning.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /kaggle/input/bert/transformers/default/1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/opt/conda/lib/python3.10/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[9], line 47
     39     return metric.compute(predictions=predictions, references=labels)
     41 trainer = Trainer(
     42     model=model,
     43     args=training_args,
     44     train_dataset=small_train_dataset,
     45     compute_metrics=compute_metrics,
     46 )
---> 47 trainer.train()
     49 print("sussessfuly Done")

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1948, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1946         hf_hub_utils.enable_progress_bars()
   1947 else:
-> 1948     return inner_training_loop(
   1949         args=args,
   1950         resume_from_checkpoint=resume_from_checkpoint,
   1951         trial=trial,
   1952         ignore_keys_for_eval=ignore_keys_for_eval,
   1953     )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2289, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2286     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   2288 with self.accelerator.accumulate(model):
-> 2289     tr_loss_step = self.training_step(model, inputs)
   2291 if (
   2292     args.logging_nan_inf_filter
   2293     and not is_torch_xla_available()
   2294     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2295 ):
   2296     # if loss is nan or inf simply add the average of previous logged losses
   2297     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:3328, in Trainer.training_step(self, model, inputs)
   3325     return loss_mb.reduce_mean().detach().to(self.args.device)
   3327 with self.compute_loss_context_manager():
-> 3328     loss = self.compute_loss(model, inputs)
   3330 del inputs
   3331 if (
   3332     self.args.torch_empty_cache_steps is not None
   3333     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3334 ):

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:3373, in Trainer.compute_loss(self, model, inputs, return_outputs)
   3371 else:
   3372     labels = None
-> 3373 outputs = model(**inputs)
   3374 # Save past state if it exists
   3375 # TODO: this needs to be fixed and made cleaner later.
   3376 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:186, in DataParallel.forward(self, *inputs, **kwargs)
    184     return self.module(*inputs[0], **module_kwargs[0])
    185 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 186 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
    187 return self.gather(outputs, self.output_device)

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:201, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
    200 def parallel_apply(self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) -> List[Any]:
--> 201     return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:108, in parallel_apply(modules, inputs, kwargs_tup, devices)
    106     output = results[i]
    107     if isinstance(output, ExceptionWrapper):
--> 108         output.reraise()
    109     outputs.append(output)
    110 return outputs

File /opt/conda/lib/python3.10/site-packages/torch/_utils.py:706, in ExceptionWrapper.reraise(self)
    702 except TypeError:
    703     # If the exception takes multiple arguments, don't try to
    704     # instantiate since we don't know how to
    705     raise RuntimeError(msg) from None
--> 706 raise exception

ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 1730, in forward
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 1188, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 3104, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (8) to match target batch_size (1024).
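The sizes in the final error follow from how the labels are built. As the last traceback frames show, `BertForSequenceClassification` reshapes its logits with `logits.view(-1, self.num_labels)` and its labels with `labels.view(-1)` before calling cross-entropy, which expects exactly one target per row of logits. Because `token_function` and `preprocess_labels` copy `input_ids` into `labels`, every example carries 128 padded token ids instead of a single class id. Replica 0 therefore compares 8 rows of logits (`per_device_train_batch_size=8`) against 8 * 128 = 1024 targets. The plain-Python sketch below only reproduces that shape arithmetic. `flattened_batch_sizes` is a hypothetical helper and is not part of PyTorch or 🤗 Transformers.

```python
from math import prod


def flattened_batch_sizes(logits_shape, labels_shape, num_labels):
    """Mimic logits.view(-1, num_labels) and labels.view(-1), using shapes only."""
    logits_rows = prod(logits_shape) // num_labels  # rows cross-entropy sees as "input"
    label_count = prod(labels_shape)                # entries it sees as "target"
    return logits_rows, label_count


# Replica 0 in the traceback: 8 examples, 5 classes (num_labels=5).
# Labels copied from input_ids are padded to max_length=128 per example.
print(flattened_batch_sizes((8, 5), (8, 128), num_labels=5))  # (8, 1024): mismatch

# One integer class id per example gives matching sizes.
print(flattened_batch_sizes((8, 5), (8,), num_labels=5))      # (8, 8)
```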
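A likely direction for a fix, sketched under assumptions: a sequence-classification model needs `labels` to be one integer in `[0, num_labels)` per example. The free-text `response` column cannot serve as that label. This sketch therefore assumes the CSV also has a categorical column. It is called `intent` here, but that column name and the example values are hypothetical, so check the actual file and add the column to `usecols`. The functions map each distinct value to an id, working on the dict-of-lists batch that `Dataset.map(batched=True)` passes in. If the goal is instead to generate the `response` text, a classification head is the wrong model type altogether, and a sequence-to-sequence or causal language model would be needed.

```python
def build_label_maps(values):
    """Assign a stable integer id to every distinct class string."""
    label_names = sorted(set(values))
    label2id = {name: i for i, name in enumerate(label_names)}
    id2label = {i: name for name, i in label2id.items()}
    return label2id, id2label


def encode_labels(batch, label2id, column="intent"):
    """Batched map function: store one integer class id per example in "labels".

    `column` is a hypothetical name; use whichever categorical column the CSV has.
    """
    batch["labels"] = [label2id[value] for value in batch[column]]
    return batch


# Toy batch in the dict-of-lists shape that Dataset.map(batched=True) passes in.
batch = {"intent": ["activate_card", "check_balance", "activate_card"]}
label2id, id2label = build_label_maps(batch["intent"])
print(encode_labels(batch, label2id)["labels"])  # [0, 1, 0]
print(len(label2id))                             # 2, the value to pass as num_labels
```

With 🤗 Datasets, this would be applied roughly as `token_data = token_data.map(encode_labels, batched=True, fn_kwargs={"label2id": label2id})`. It replaces `preprocess_labels` and the `encodings['labels'] = encodings['input_ids']` line, and the model would be loaded with `num_labels=len(label2id)`. Optionally, passing `id2label` and `label2id` lets predictions map back to names. The `text_target=` argument would also be dropped, since it adds a `labels` field of tokenized response ids, which is not a class id either.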
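The `compute_metrics` function in the question takes the argmax over the class dimension and compares one prediction with one label per example, which also assumes integer class labels. The plain-Python equivalent below makes that one-to-one pairing explicit. `accuracy_from_logits` is a hypothetical name standing in for `np.argmax` plus the `evaluate` accuracy metric, not code from either library. Note also that `compute_metrics` only runs during evaluation, and the `Trainer` in the question is given no `eval_dataset`. It would need one (for example, a held-out slice of `token_data["train"]`) before any accuracy is reported.

```python
def accuracy_from_logits(logits, labels):
    """Plain-Python stand-in for np.argmax(logits, axis=-1) followed by accuracy."""
    if len(logits) != len(labels):
        raise ValueError(f"{len(logits)} predictions vs {len(labels)} labels")
    # Index of the largest logit per row (first index wins on ties, like np.argmax).
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 3.0], [2.2, 0.0, 0.1]]
print(accuracy_from_logits(logits, [1, 0, 2, 1]))  # {'accuracy': 0.75}
```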