I modified the code a little to add some sanity checks.
# Imports assumed for this snippet (compare_weights is defined further down in this thread):
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

def train():
    gemma2it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")  # sanity check model
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer=tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=64,
            num_train_epochs=4,
            output_dir="output/trained_model",
            optim="adafactor",
            dataloader_drop_last=True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=2048,
    )
    # 1
    trainer.train()
    print("comparing gemma2it with trainer.model")
    compare_weights(gemma2it, trainer.model)  # different GemmaForCausalLM:2506172416 params vs SpmdFullyShardedDataParallel:3031123968 params
    # 2
    merged_model = trainer.model.merge_and_unload()
    print("comparing gemma2it with merged_model")
    compare_weights(gemma2it, merged_model)  # different GemmaForCausalLM:2506172416 params vs GemmaForCausalLM:3030460416 params
    # 3
    print("saving merged_model")
    merged_model.to("cpu")
    merged_model.save_pretrained("output/merged_model")
    compare_weights(gemma2it, merged_model)  # different GemmaForCausalLM:2506172416 params vs GemmaForCausalLM:3030460416 params
    # 4
    print("comparing loaded merged_model from disk with in-memory merged_model")
    loaded_merged_model = AutoModelForCausalLM.from_pretrained("output/merged_model")
    compare_weights(merged_model, loaded_merged_model)  # different GemmaForCausalLM:3030460416 params vs GemmaForCausalLM:2506172416 params
    # 5
    print("comparing gemma2it with loaded merged_model from disk")
    compare_weights(gemma2it, loaded_merged_model)  # models GemmaForCausalLM and GemmaForCausalLM are the same
I added some sanity checks against the base, untouched gemma2it model, plus some mid-step comparisons:
1. Does trainer.model differ from the base gemma2it? Yes, they differ in the number of parameters, which implies that training was successful.
2. Does merged_model differ from the base gemma2it? Yes, they differ in the number of parameters, which implies that merging was successful.
3. Does merged_model still differ from the base gemma2it after saving? Yes, they differ in the number of parameters, which implies that saving does nothing to the parameters.
4. Does loaded_merged_model differ from the in-memory merged_model it was saved from? YES, THEY ARE DIFFERENT :( which implies that something is wrong with loading the model (or saving).
4.1. This warning popped up when loading the model from disk:
Some weights of the model checkpoint at output/merged_model were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', (...) 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at output/merged_model and are newly initialized: ['model.layers.0.input_layernorm.weight', (...) 'model.layers.9.self_attn.v_proj.weight']
(...)You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
5. Does loaded_merged_model differ from the base gemma2it? No, they are the same... which implies that all my training was worthless. It looks like there is something fishy in my code when saving/loading the model from disk; I'll update if I notice what's wrong. I will also check why my weights get saved under keys containing _orig_module.
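For reference, one quick way to see which key names actually ended up in the saved checkpoint (and whether they carry the _orig_module prefix) is to list the tensor names in the saved safetensors file directly. This is only a diagnostic sketch; the single-file name model.safetensors is an assumption, since larger checkpoints are split into shards listed in model.safetensors.index.json.

# Diagnostic sketch (assumption: single-shard checkpoint named model.safetensors).
# For a sharded checkpoint, repeat this for every file referenced in
# model.safetensors.index.json.
from safetensors import safe_open

with safe_open("output/merged_model/model.safetensors", framework="pt") as f:
    for key in f.keys():
        if "_orig_module" in key:
            print("unexpected wrapper prefix in saved key:", key)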
Hi @PawKanarek
Please reference #29388. By the way, have you tested LoRA fine-tuning performance on TPU XLA? I explored LoRA a bit, but it had no effect on the base model and the generated output was essentially the same as the base model's.
Hi @zorrofox, thanks for the insight! It looks like my transformers fork didn't include the change from that PR. What kind of fine-tuning performance are you talking about? Do you want to know how long it takes to train the model with LoRA, or how well the model behaves after fine-tuning?
I used the trainer.save_pretrained function mentioned in PR https://github.com/huggingface/transformers/pull/29388, but it didn't change anything: the trained model after saving is still exactly the same as before training.
I think I fixed it, but I wouldn't recommend this fix to anyone, so I'm not even thinking about opening a PR.
It's a patch rather than a fix, but I think it works. To check that it really works, I will train gemma-2b-it until it overfits on the training dataset and then take a look at the inference output.
To apply my patch you would have to add a new parameter to save_pretrained https://github.com/huggingface/transformers/blob/f02aea27378dd57c2ced4b28ff9e58ec3876340a/src/transformers/modeling_utils.py#L2190C1-L2203C7
formatting_weights_func = None,
Also add this code before sharding https://github.com/huggingface/transformers/blob/03847ef45189d328a51f428b0a61a6b891e69f88/src/transformers/modeling_utils.py#L2429C1-L2437C111
# apply formatting to the weights before saving
if formatting_weights_func is not None:
    for old_key in list(state_dict.keys()):
        new_key = formatting_weights_func(old_key)
        logger.debug(f"changed {old_key=} to {new_key=}")
        state_dict[new_key] = state_dict.pop(old_key)
With these changes I can finally spot a difference between the trained model loaded from disk and the base model it was trained from, and this warning is also gone:
Some weights of the model checkpoint at output/merged_model were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', (...) 'model.layers.9._orig_module.self_attn.v_proj.weight']
def compare_weights(model1, model2):
    name1, name2 = model1.__class__.__name__, model2.__class__.__name__
    params1, params2 = model1.parameters(), model2.parameters()
    sum1, sum2 = sum(p.numel() for p in params1), sum(p.numel() for p in params2)
    if sum1 != sum2:
        print(f"!!! different in {name1}:{sum1} params vs {name2}:{sum2} params")
    for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
        if n1 != n2:
            print(f"!!! Parameter names differ: {n1} != {n2}")
            return False
        if not torch.equal(p1.data, p2.data):
            print(f"!!! Parameter values differ: {n1}, {p1.data}, {p2.data}")
            return False

def formmating_func(old_key):
    return old_key.replace('._orig_module', '')
def train():
    # the same training config as before
    trainer.train()
    trainer_model = trainer.model.to('cpu')
    merged_model = trainer_model.merge_and_unload()
    merged_model.save_pretrained("output/merged_model", formatting_weights_func=formmating_func)
    loaded_merged_model = AutoModelForCausalLM.from_pretrained("output/merged_model")
    gemma2it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
    print("!!! comparing gemma2it with loaded merged_model from disk")
    compare_weights(gemma2it, loaded_merged_model)  # !!! FINALLY !!! Parameter values differ: model.layers.0.self_attn.k_proj.weight, tensor([[-3.2043e-04, 8.1177e-03, 3.0365e-03, ..., -5.3101e-03,
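An alternative sketch that avoids patching transformers at all (under the same assumption that the stray ._orig_module prefix is the only problem): rename the keys of the merged model's state dict yourself and pass the result through the standard state_dict argument of save_pretrained.

# Sketch: strip the FSDP wrapper prefix from the keys and save through the
# regular save_pretrained(state_dict=...) argument, no transformers patch needed.
# Assumes merged_model from the snippet above is already on CPU.
state_dict = merged_model.state_dict()
clean_state_dict = {k.replace("._orig_module", ""): v for k, v in state_dict.items()}
merged_model.save_pretrained("output/merged_model", state_dict=clean_state_dict)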
I'm not closing this issue, because I didn't really fix it; the true issue is still hidden somewhere. This is only a workaround.
@PawKanarek Thanks a lot for your advice, I have the same issue as you. I think you have found the root cause of why the trained model didn't change.
cc @pacman100 @muellerzr @shub-kris
@PawKanarek just to isolate the error, what happens if you run the same code on a GPU instead of TPU?
@PawKanarek can you also provide the training logs, please, and run with logging_steps=1? Also use save_strategy="epoch".
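For clarity, a minimal sketch of how those two options would slot into the TrainingArguments from the earlier snippet (the other values are just carried over as placeholders):

# Sketch: per-step loss logging and per-epoch checkpointing, as suggested above.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output/trained_model",
    per_device_train_batch_size=64,
    num_train_epochs=4,
    optim="adafactor",
    logging_steps=1,        # log the loss on every optimizer step
    save_strategy="epoch",  # write a checkpoint at the end of each epoch
)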
@PawKanarek also after training can you try saving with trainer.save_model('output_dir')
@PawKanarek one last thing that I would like to see: does the generation differ when using model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", torch_dtype=torch.float16) for generation on a GPU?
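For reference, a minimal sketch of that generation check, assuming peft_model_id points at the saved adapter directory (for example output/trained_model) and a CUDA GPU is available:

# Sketch: load base model + LoRA adapter in one step and generate on a GPU.
# peft_model_id and the prompt are assumptions for illustration.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

peft_model_id = "output/trained_model"
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id, device_map="auto", torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
inputs = tokenizer("Quote: Imagination is more", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))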
@shub-kris thanks,
@PawKanarek just to isolate the error, what happens if you run the same code on a GPU instead of TPU?
I don't have a GPU capable of training the gemma-2b-it model. I only have my local MacBook with the MPS backend and Google Cloud TPUs (thanks to https://sites.research.google/trc/about/).
@PawKanarek can you also provide the training logs please and run with logging_steps=1? Also use save_strategy=epoch
I will try to give you logs tomorrow. Today the machine is busy with training :)
@PawKanarek also after training can you try saving with trainer.save_model('output_dir')
I tried it many times, no success.
@PawKanarek also with your https://github.com/huggingface/transformers/issues/29659#issuecomment-1999954329 did it work?
Yes. It works.
@PawKanarek one last thing that I would like to see is: does the generation differs when using this: model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", torch_dtype=torch.float16) for generation on a GPU
Sadly, I don't have an NVIDIA GPU.
@PawKanarek thanks for your answers. I am having a look and will post here once I get to the root of the issue.
I tried the following script on a GPU
import torch
import peft
import trl
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

print(f"{torch.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")

def check_model_weights_equality(model1, model2):
    params1, params2 = model1.parameters(), model2.parameters()
    sum1 = sum(p.numel() for p in params1)
    sum2 = sum(p.numel() for p in params2)
    if sum1 != sum2:
        print(f"Number of parameters are different in {model1.__class__}:{sum1} and {model2.__class__}:{sum2} are different")
        return False
    for p1, p2 in zip(params1, params2):
        if not torch.equal(p1, p2):
            print(f"weights of {model1.__class__} and {model2.__class__} are different")
            return False
    print(f"models {model1.__class__} and {model2.__class__} are the same")
    return True

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

def train():
    model_id = "google/gemma-2b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=16, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM", lora_alpha=16, lora_dropout=0.05,)
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer=tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=2,
            max_steps=40,  # small epochs for brevity, but the same is also with larger epochs
            output_dir="output/trained_model",
            optim="adafactor",
            logging_steps=1,
            learning_rate=3e-4,
        ),
        peft_config=lora_config,
        max_seq_length=512,
    )
    trainer.train()
    trainer.save_model()
    merged_model = trainer.model.merge_and_unload()  # merge LORA with base model
    merged_model.to("cpu")
    print(type(merged_model), count_parameters(merged_model))
    merged_model.save_pretrained("adapters_merged")

    ### VERIFICATION, ENSURE THAT MODEL WAS TRAINED
    trained_model = AutoModelForCausalLM.from_pretrained("adapters_merged")
    print(type(trained_model), count_parameters(trained_model))
    original_model = AutoModelForCausalLM.from_pretrained(model_id)
    print(type(original_model), count_parameters(original_model))
    check_model_weights_equality(trained_model, original_model)

if __name__ == "__main__":
    train()
And here was the output:
[2024-03-18 20:15:02,900] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
torch.__version__='2.1.0a0+32f93b1'
peft.__version__='0.8.2'
trl.__version__='0.7.10'
Loading checkpoint shards: 100%|███████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.38it/s]
/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:290: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
warnings.warn(
{'loss': 2.8281, 'grad_norm': 0.66015625, 'learning_rate': 0.00029249999999999995, 'epoch': 0.0}
{'loss': 2.7031, 'grad_norm': 0.6171875, 'learning_rate': 0.000285, 'epoch': 0.01}
{'loss': 2.7344, 'grad_norm': 0.765625, 'learning_rate': 0.00027749999999999997, 'epoch': 0.01}
{'loss': 2.5469, 'grad_norm': 0.8125, 'learning_rate': 0.00027, 'epoch': 0.02}
{'loss': 2.4688, 'grad_norm': 0.96484375, 'learning_rate': 0.0002625, 'epoch': 0.02}
{'loss': 2.3906, 'grad_norm': 0.90625, 'learning_rate': 0.00025499999999999996, 'epoch': 0.03}
{'loss': 2.4219, 'grad_norm': 1.1015625, 'learning_rate': 0.00024749999999999994, 'epoch': 0.03}
{'loss': 2.2344, 'grad_norm': 0.9296875, 'learning_rate': 0.00023999999999999998, 'epoch': 0.03}
{'loss': 2.2031, 'grad_norm': 1.015625, 'learning_rate': 0.00023249999999999999, 'epoch': 0.04}
{'loss': 2.0312, 'grad_norm': 0.96484375, 'learning_rate': 0.000225, 'epoch': 0.04}
{'loss': 2.0938, 'grad_norm': 1.1015625, 'learning_rate': 0.00021749999999999997, 'epoch': 0.05}
{'loss': 2.0, 'grad_norm': 1.296875, 'learning_rate': 0.00020999999999999998, 'epoch': 0.05}
{'loss': 1.8281, 'grad_norm': 3.078125, 'learning_rate': 0.0002025, 'epoch': 0.06}
{'loss': 1.7656, 'grad_norm': 1.9609375, 'learning_rate': 0.000195, 'epoch': 0.06}
{'loss': 1.7031, 'grad_norm': 3.859375, 'learning_rate': 0.00018749999999999998, 'epoch': 0.07}
{'loss': 1.6484, 'grad_norm': 2.171875, 'learning_rate': 0.00017999999999999998, 'epoch': 0.07}
{'loss': 1.5859, 'grad_norm': 2.453125, 'learning_rate': 0.00017249999999999996, 'epoch': 0.07}
{'loss': 1.5312, 'grad_norm': 1.96875, 'learning_rate': 0.000165, 'epoch': 0.08}
{'loss': 1.5391, 'grad_norm': 1.8671875, 'learning_rate': 0.00015749999999999998, 'epoch': 0.08}
{'loss': 1.3828, 'grad_norm': 2.109375, 'learning_rate': 0.00015, 'epoch': 0.09}
{'loss': 1.3438, 'grad_norm': 3.609375, 'learning_rate': 0.0001425, 'epoch': 0.09}
{'loss': 1.2969, 'grad_norm': 2.671875, 'learning_rate': 0.000135, 'epoch': 0.1}
{'loss': 1.2344, 'grad_norm': 3.328125, 'learning_rate': 0.00012749999999999998, 'epoch': 0.1}
{'loss': 1.2891, 'grad_norm': 2.9375, 'learning_rate': 0.00011999999999999999, 'epoch': 0.1}
{'loss': 1.2656, 'grad_norm': 2.109375, 'learning_rate': 0.0001125, 'epoch': 0.11}
{'loss': 1.0938, 'grad_norm': 2.890625, 'learning_rate': 0.00010499999999999999, 'epoch': 0.11}
{'loss': 1.0391, 'grad_norm': 2.46875, 'learning_rate': 9.75e-05, 'epoch': 0.12}
{'loss': 1.1016, 'grad_norm': 2.859375, 'learning_rate': 8.999999999999999e-05, 'epoch': 0.12}
{'loss': 1.0625, 'grad_norm': 2.421875, 'learning_rate': 8.25e-05, 'epoch': 0.13}
{'loss': 0.957, 'grad_norm': 2.4375, 'learning_rate': 7.5e-05, 'epoch': 0.13}
{'loss': 0.9219, 'grad_norm': 1.703125, 'learning_rate': 6.75e-05, 'epoch': 0.13}
{'loss': 0.8906, 'grad_norm': 1.7734375, 'learning_rate': 5.9999999999999995e-05, 'epoch': 0.14}
{'loss': 0.9609, 'grad_norm': 4.40625, 'learning_rate': 5.2499999999999995e-05, 'epoch': 0.14}
{'loss': 0.875, 'grad_norm': 2.109375, 'learning_rate': 4.4999999999999996e-05, 'epoch': 0.15}
{'loss': 0.9219, 'grad_norm': 2.8125, 'learning_rate': 3.75e-05, 'epoch': 0.15}
{'loss': 0.9102, 'grad_norm': 2.125, 'learning_rate': 2.9999999999999997e-05, 'epoch': 0.16}
{'loss': 0.9258, 'grad_norm': 1.515625, 'learning_rate': 2.2499999999999998e-05, 'epoch': 0.16}
{'loss': 0.8164, 'grad_norm': 1.8515625, 'learning_rate': 1.4999999999999999e-05, 'epoch': 0.17}
{'loss': 0.8164, 'grad_norm': 2.0, 'learning_rate': 7.499999999999999e-06, 'epoch': 0.17}
{'loss': 0.8086, 'grad_norm': 1.6484375, 'learning_rate': 0.0, 'epoch': 0.17}
{'train_runtime': 5.337, 'train_samples_per_second': 14.99, 'train_steps_per_second': 7.495, 'train_loss': 1.554296875, 'epoch': 0.17}
100%|████████████████████████████████████████████████████████████████████████████| 40/40 [00:05<00:00, 7.50it/s]
<class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> 2506172416
Loading checkpoint shards: 100%|███████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.78it/s]
<class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> 2506172416
Loading checkpoint shards: 100%|███████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.81it/s]
<class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> 2506172416
models <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> and <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> are the same
So, the issue has nothing to do with TPU for sure. cc @amyeroberts @pacman100 @muellerzr
However, one thing I would like to verify is your way of checking whether the model weights are equal or not. I will get back to you on that.
trainer.save_model(new_model_id)

# Reload model in FP16 and merge it with LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
)
newmodel = PeftModel.from_pretrained(base_model, new_model_id)
newmodel = newmodel.merge_and_unload()
print(check_model_weights_equality(model, newmodel))
Logs:
Number of parameters are different in <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'>:9327324160 and <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'>:8537680896 are different
False
@moficodes I think you misunderstood my intentions. I want to save a standalone model, not just the LoRA adapter. You saved only the LoRA adapter (with trainer.save_model()), but the problem is with loading/saving the merged model after merge_and_unload().
Please take a look at this updated script. I changed the comparison function to be more descriptive, and I added more logging as @shub-kris asked.
# Make sure to run the script with the following envs:
# PJRT_DEVICE=TPU XLA_USE_SPMD=1
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft.peft_model import PeftModel
from trl import SFTTrainer
from transformers import logging, IntervalStrategy

device = xm.xla_device()  # Set up TPU device.

def models_equal(model1, model2):
    name1, name2 = model1.__class__.__name__, model2.__class__.__name__
    params1, params2 = model1.parameters(), model2.parameters()
    sum1, sum2 = sum(p.numel() for p in params1), sum(p.numel() for p in params2)
    if sum1 != sum2:
        print(f"!!! numer of params are different in {name1}:{sum1} params vs {name2}:{sum2} params")
    for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
        if n1 != n2:
            print(f"!!! Parameter names differ: {n1} != {n2}")
            return False
        if not torch.equal(p1.data, p2.data):
            print(f"!!! Parameter values differ: {n1}, {p1.data}, {p2.data}")
            return False
    print(f"!!! models {name1} and {name2} are the same")
    return True

def train():
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer=tokenizer,
        args=TrainingArguments(
            logging_steps=1,
            save_strategy=IntervalStrategy.EPOCH,
            per_device_train_batch_size=64,
            num_train_epochs=1,
            output_dir="output/trained_model",
            optim="adafactor",
            dataloader_drop_last=True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=2048,
    )
    trainer.train()
    trainer.save_model()

    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", return_dict=True, torch_dtype=torch.bfloat16)
    new_model = PeftModel.from_pretrained(base_model, "output/trained_model")
    new_model = new_model.merge_and_unload()
    new_model.save_pretrained("output/new_model")

    new_model_from_disk = AutoModelForCausalLM.from_pretrained("output/new_model", torch_dtype=torch.bfloat16)
    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    print(f"are equal after load from disk? {models_equal(base_model, new_model_from_disk)}")  # they equal after loading from disk
    print(1)

if __name__ == "__main__":
    logging.set_verbosity(logging.DEBUG)
    train()
As you can see, at the end I again get the message that the base model and the model loaded from disk are the same:
!!! models GemmaForCausalLM and GemmaForCausalLM are the same
are equal after load from disk? True
I'm open to investigating further.
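One more diagnostic that might narrow this down (a sketch only, reusing base_model and new_model_from_disk from the script above): since LoRA only targeted k_proj and v_proj, compare just those tensors; if they are bit-for-bit equal to the base model, the adapter deltas were lost somewhere in the merge/save/load chain.

# Sketch: check only the tensors LoRA should have modified.
# Assumes torch, base_model and new_model_from_disk from the script above are in scope.
changed, unchanged = 0, 0
base_params = dict(base_model.named_parameters())
for name, param in new_model_from_disk.named_parameters():
    if "k_proj" in name or "v_proj" in name:
        if torch.equal(param.data, base_params[name].data):
            unchanged += 1
        else:
            changed += 1
print(f"LoRA target tensors changed: {changed}, unchanged: {unchanged}")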
Hi @PawKanarek, I tried a new script which is very similar to yours, and I ran inference before and after training; the results are different, which verifies that the model was trained and also saved perfectly.
# PJRT_DEVICE=TPU XLA_USE_SPMD=1
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")

device = xm.xla_device()  # Set up TPU device.

def inference(model, tokenizer):
    text = "Quote: Imagination is more"
    device = "cpu"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=20)  # generate only supported on GPU and CPU
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

def train():
    model_id = "google/gemma-2b"

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # tokenizer.pad_token = tokenizer.eos_token

    # Load and process dataset
    raw_dataset = load_dataset("Abirate/english_quotes", split="train")
    lora_config = LoraConfig(r=8, target_modules="all-linear", task_type="CAUSAL_LM", lora_alpha=16, lora_dropout=0.05,)
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}
    trainer = SFTTrainer(
        model=model,
        # train_dataset=format_dataset,
        train_dataset=raw_dataset,
        tokenizer=tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=32,
            num_train_epochs=10,
            output_dir="output",
            optim="adafactor",
            logging_steps=1,
            learning_rate=3e-4,
            save_strategy="no",
            dataloader_drop_last=True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=1024,
        packing=True,
        dataset_text_field="quote",
    )
    trainer.train()
    trainer.save_model()

    merged_model = trainer.model.merge_and_unload()  # merge LORA with base model
    merged_model.to("cpu")
    merged_model.save_pretrained("adapters_merged")

    ### VERIFICATION, ENSURE THAT MODEL WAS TRAINED
    trained_model = AutoModelForCausalLM.from_pretrained("adapters_merged")
    original_model = AutoModelForCausalLM.from_pretrained(model_id)
    print("Inference with base model: \n\n")
    inference(original_model, tokenizer)
    print("Inference with trained model: \n\n")
    inference(trained_model, tokenizer)

if __name__ == "__main__":
    train()
torch.__version__='2.3.0'
torch_xla.__version__='2.3.0+gite385c2f'
peft.__version__='0.8.2'
trl.__version__='0.7.12.dev0'
{'loss': 5.0312, 'grad_norm': 3.109375, 'learning_rate': 0.00029, 'epoch': 0.33}
{'loss': 4.7812, 'grad_norm': 2.921875, 'learning_rate': 0.00028, 'epoch': 0.67}
{'loss': 4.5312, 'grad_norm': 4.15625, 'learning_rate': 0.00027, 'epoch': 1.0}
{'loss': 4.1875, 'grad_norm': 3.90625, 'learning_rate': 0.00026, 'epoch': 1.33}
{'loss': 3.9062, 'grad_norm': 4.46875, 'learning_rate': 0.00025, 'epoch': 1.67}
{'loss': 3.75, 'grad_norm': 4.15625, 'learning_rate': 0.00023999999999999998, 'epoch': 2.0}
{'loss': 3.4688, 'grad_norm': 4.46875, 'learning_rate': 0.00023, 'epoch': 2.33}
{'loss': 3.3438, 'grad_norm': 3.71875, 'learning_rate': 0.00021999999999999995, 'epoch': 2.67}
{'loss': 3.2656, 'grad_norm': 3.5, 'learning_rate': 0.00020999999999999998, 'epoch': 3.0}
{'loss': 3.0781, 'grad_norm': 2.734375, 'learning_rate': 0.00019999999999999998, 'epoch': 3.33}
{'loss': 3.0, 'grad_norm': 2.328125, 'learning_rate': 0.00018999999999999998, 'epoch': 3.67}
{'loss': 2.9531, 'grad_norm': 1.796875, 'learning_rate': 0.00017999999999999998, 'epoch': 4.0}
{'loss': 2.875, 'grad_norm': 2.5, 'learning_rate': 0.00016999999999999999, 'epoch': 4.33}
{'loss': 2.8281, 'grad_norm': 3.15625, 'learning_rate': 0.00015999999999999999, 'epoch': 4.67}
{'loss': 2.7969, 'grad_norm': 3.546875, 'learning_rate': 0.00015, 'epoch': 5.0}
{'loss': 2.7188, 'grad_norm': 1.4375, 'learning_rate': 0.00014, 'epoch': 5.33}
{'loss': 2.7188, 'grad_norm': 2.21875, 'learning_rate': 0.00013, 'epoch': 5.67}
{'loss': 2.7656, 'grad_norm': 3.40625, 'learning_rate': 0.00011999999999999999, 'epoch': 6.0}
{'loss': 2.6875, 'grad_norm': 4.6875, 'learning_rate': 0.00010999999999999998, 'epoch': 6.33}
{'loss': 2.625, 'grad_norm': 1.6015625, 'learning_rate': 9.999999999999999e-05, 'epoch': 6.67}
{'loss': 2.6562, 'grad_norm': 1.546875, 'learning_rate': 8.999999999999999e-05, 'epoch': 7.0}
{'loss': 2.6562, 'grad_norm': 1.703125, 'learning_rate': 7.999999999999999e-05, 'epoch': 7.33}
{'loss': 2.5938, 'grad_norm': 1.40625, 'learning_rate': 7e-05, 'epoch': 7.67}
{'loss': 2.625, 'grad_norm': 1.1796875, 'learning_rate': 5.9999999999999995e-05, 'epoch': 8.0}
{'loss': 2.6562, 'grad_norm': 1.5078125, 'learning_rate': 4.9999999999999996e-05, 'epoch': 8.33}
{'loss': 2.5, 'grad_norm': 1.0234375, 'learning_rate': 3.9999999999999996e-05, 'epoch': 8.67}
{'loss': 2.5156, 'grad_norm': 1.359375, 'learning_rate': 2.9999999999999997e-05, 'epoch': 9.0}
{'loss': 2.5, 'grad_norm': 1.03125, 'learning_rate': 1.9999999999999998e-05, 'epoch': 9.33}
{'loss': 2.5938, 'grad_norm': 1.125, 'learning_rate': 9.999999999999999e-06, 'epoch': 9.67}
{'loss': 2.5, 'grad_norm': 0.97265625, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 386.8015, 'train_samples_per_second': 2.482, 'train_steps_per_second': 0.078, 'train_loss': 3.103645833333333, 'epoch': 10.0}
1. With base model
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein
I am
2. With finetuned model
Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa
@amyeroberts we can close this issue (#29659) and also issue #29608.
I also tried without FSDP, as it is easier to fine-tune. With the fine-tuned model I got this result:
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
Author: Albert Einstein
Thank you @shub-kris! I will run this script on my local machine and then I will share the results. I have one question regarding your code: why do you set tokenizer.pad_token = tokenizer.eos_token?
It configures the tokenizer's padding token to be the same as its end-of-sequence (EOS) token. But you don't need it for this use case, as the tokenizer already has a pad_token defined here.
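For completeness, the usual defensive pattern (a sketch, not taken from the script above) is to set the pad token only when the tokenizer doesn't already define one:

# Sketch: fall back to the EOS token only if no pad token is configured.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
if tokenizer.pad_token is None:  # Gemma already ships a pad token, so this is a no-op here
    tokenizer.pad_token = tokenizer.eos_token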
@shub-kris, the inference result I got on TPU is not like yours.
Logs:
torch.__version__='2.3.0.dev20240312+cu121'
torch_xla.__version__='2.3.0+git97acc14'
peft.__version__='0.9.0'
trl.__version__='0.7.11'
config.json: 100%|███████████████████████████████████████████████████████████████████████████| 627/627 [00:00<00:00, 3.63MB/s]
model.safetensors.index.json: 100%|██████████████████████████████████████████████████████| 13.5k/13.5k [00:00<00:00, 49.6MB/s]
model-00001-of-00002.safetensors: 100%|███████████████████████████████████████████████████| 4.95G/4.95G [00:22<00:00, 218MB/s]
model-00002-of-00002.safetensors: 100%|███████████████████████████████████████████████████| 67.1M/67.1M [00:00<00:00, 222MB/s]
Downloading shards: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:23<00:00, 11.71s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.05it/s]
generation_config.json: 100%|█████████████████████████████████████████████████████████████████| 137/137 [00:00<00:00, 832kB/s]
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████████| 1.11k/1.11k [00:00<00:00, 7.57MB/s]
tokenizer.model: 100%|████████████████████████████████████████████████████████████████████| 4.24M/4.24M [00:00<00:00, 207MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████████████| 17.5M/17.5M [00:00<00:00, 223MB/s]
special_tokens_map.json: 100%|███████████████████████████████████████████████████████████████| 555/555 [00:00<00:00, 1.27MB/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710857864.890503 6669 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710857864.890575 6669 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710857864.890586 6669 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py:104: UserWarning: `devkind` argument is deprecated and will be removed in a future release.
warnings.warn("`devkind` argument is deprecated and will be removed in a "
Generating train split: 102 examples [00:00, 159.49 examples/s]
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:294: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
warnings.warn(
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead:
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
warnings.warn(
0%| | 0/30 [00:00<?, ?it/s]/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1597: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
warnings.warn("For backward hooks to be called,"
/home/admin_greghuang_altostrat_com/.local/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at ../torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
{'loss': 4.273, 'grad_norm': 3.953125, 'learning_rate': 0.00029, 'epoch': 0.33}
{'loss': 4.1232, 'grad_norm': 4.25, 'learning_rate': 0.00028, 'epoch': 0.67}
{'loss': 3.7796, 'grad_norm': 5.3125, 'learning_rate': 0.00027, 'epoch': 1.0}
{'loss': 3.4005, 'grad_norm': 4.59375, 'learning_rate': 0.00026, 'epoch': 1.33}
{'loss': 3.2413, 'grad_norm': 3.046875, 'learning_rate': 0.00025, 'epoch': 1.67}
{'loss': 2.9242, 'grad_norm': 1.9765625, 'learning_rate': 0.00023999999999999998, 'epoch': 2.0}
{'loss': 2.7689, 'grad_norm': 2.5625, 'learning_rate': 0.00023, 'epoch': 2.33}
{'loss': 2.7829, 'grad_norm': 2.046875, 'learning_rate': 0.00021999999999999995, 'epoch': 2.67}
{'loss': 2.6584, 'grad_norm': 3.984375, 'learning_rate': 0.00020999999999999998, 'epoch': 3.0}
{'loss': 2.6561, 'grad_norm': 1.6171875, 'learning_rate': 0.00019999999999999998, 'epoch': 3.33}
{'loss': 2.5347, 'grad_norm': 5.34375, 'learning_rate': 0.00018999999999999998, 'epoch': 3.67}
{'loss': 2.4281, 'grad_norm': 2.328125, 'learning_rate': 0.00017999999999999998, 'epoch': 4.0}
{'loss': 2.4578, 'grad_norm': 3.015625, 'learning_rate': 0.00016999999999999999, 'epoch': 4.33}
{'loss': 2.5122, 'grad_norm': 1.4765625, 'learning_rate': 0.00015999999999999999, 'epoch': 4.67}
{'loss': 2.3117, 'grad_norm': 2.125, 'learning_rate': 0.00015, 'epoch': 5.0}
{'loss': 2.3832, 'grad_norm': 2.109375, 'learning_rate': 0.00014, 'epoch': 5.33}
{'loss': 2.3193, 'grad_norm': 1.609375, 'learning_rate': 0.00013, 'epoch': 5.67}
{'loss': 2.2856, 'grad_norm': 2.109375, 'learning_rate': 0.00011999999999999999, 'epoch': 6.0}
{'loss': 2.2524, 'grad_norm': 1.7421875, 'learning_rate': 0.00010999999999999998, 'epoch': 6.33}
{'loss': 2.2826, 'grad_norm': 1.328125, 'learning_rate': 9.999999999999999e-05, 'epoch': 6.67}
{'loss': 2.1978, 'grad_norm': 1.109375, 'learning_rate': 8.999999999999999e-05, 'epoch': 7.0}
{'loss': 2.2295, 'grad_norm': 1.078125, 'learning_rate': 7.999999999999999e-05, 'epoch': 7.33}
{'loss': 2.1379, 'grad_norm': 1.21875, 'learning_rate': 7e-05, 'epoch': 7.67}
{'loss': 2.2398, 'grad_norm': 1.6171875, 'learning_rate': 5.9999999999999995e-05, 'epoch': 8.0}
{'loss': 2.1681, 'grad_norm': 0.890625, 'learning_rate': 4.9999999999999996e-05, 'epoch': 8.33}
{'loss': 2.176, 'grad_norm': 5.96875, 'learning_rate': 3.9999999999999996e-05, 'epoch': 8.67}
{'loss': 2.1323, 'grad_norm': 0.89453125, 'learning_rate': 2.9999999999999997e-05, 'epoch': 9.0}
{'loss': 2.1921, 'grad_norm': 0.87109375, 'learning_rate': 1.9999999999999998e-05, 'epoch': 9.33}
{'loss': 2.0294, 'grad_norm': 5.625, 'learning_rate': 9.999999999999999e-06, 'epoch': 9.67}
{'loss': 2.1877, 'grad_norm': 0.73046875, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 567.8682, 'train_samples_per_second': 1.691, 'train_steps_per_second': 0.053, 'train_loss': 2.6022108157475787, 'epoch': 10.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████| 30/30 [09:27<00:00, 18.93s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.31it/s]
Some weights of the model checkpoint at adapters_merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 'model.layers.0._orig_module.mlp.down_proj.weight', 'model.layers.0._orig_module.mlp.gate_proj.weight', 'model.layers.0._orig_module.mlp.up_proj.weight', 'model.layers.0._orig_module.post_attention_layernorm.weight', 'model.layers.0._orig_module.self_attn.k_proj.weight', 'model.layers.0._orig_module.self_attn.o_proj.weight', 'model.layers.0._orig_module.self_attn.q_proj.weight', 'model.layers.0._orig_module.self_attn.v_proj.weight', 'model.layers.1._orig_module.input_layernorm.weight', 'model.layers.1._orig_module.mlp.down_proj.weight', 'model.layers.1._orig_module.mlp.gate_proj.weight', 'model.layers.1._orig_module.mlp.up_proj.weight', 'model.layers.1._orig_module.post_attention_layernorm.weight', 'model.layers.1._orig_module.self_attn.k_proj.weight', 'model.layers.1._orig_module.self_attn.o_proj.weight', 'model.layers.1._orig_module.self_attn.q_proj.weight', 'model.layers.1._orig_module.self_attn.v_proj.weight', 'model.layers.10._orig_module.input_layernorm.weight', 'model.layers.10._orig_module.mlp.down_proj.weight', 'model.layers.10._orig_module.mlp.gate_proj.weight', 'model.layers.10._orig_module.mlp.up_proj.weight', 'model.layers.10._orig_module.post_attention_layernorm.weight', 'model.layers.10._orig_module.self_attn.k_proj.weight', 'model.layers.10._orig_module.self_attn.o_proj.weight', 'model.layers.10._orig_module.self_attn.q_proj.weight', 'model.layers.10._orig_module.self_attn.v_proj.weight', 'model.layers.11._orig_module.input_layernorm.weight', 'model.layers.11._orig_module.mlp.down_proj.weight', 'model.layers.11._orig_module.mlp.gate_proj.weight', 'model.layers.11._orig_module.mlp.up_proj.weight', 'model.layers.11._orig_module.post_attention_layernorm.weight', 'model.layers.11._orig_module.self_attn.k_proj.weight', 'model.layers.11._orig_module.self_attn.o_proj.weight', 'model.layers.11._orig_module.self_attn.q_proj.weight', 'model.layers.11._orig_module.self_attn.v_proj.weight', 'model.layers.12._orig_module.input_layernorm.weight', 'model.layers.12._orig_module.mlp.down_proj.weight', 'model.layers.12._orig_module.mlp.gate_proj.weight', 'model.layers.12._orig_module.mlp.up_proj.weight', 'model.layers.12._orig_module.post_attention_layernorm.weight', 'model.layers.12._orig_module.self_attn.k_proj.weight', 'model.layers.12._orig_module.self_attn.o_proj.weight', 'model.layers.12._orig_module.self_attn.q_proj.weight', 'model.layers.12._orig_module.self_attn.v_proj.weight', 'model.layers.13._orig_module.input_layernorm.weight', 'model.layers.13._orig_module.mlp.down_proj.weight', 'model.layers.13._orig_module.mlp.gate_proj.weight', 'model.layers.13._orig_module.mlp.up_proj.weight', 'model.layers.13._orig_module.post_attention_layernorm.weight', 'model.layers.13._orig_module.self_attn.k_proj.weight', 'model.layers.13._orig_module.self_attn.o_proj.weight', 'model.layers.13._orig_module.self_attn.q_proj.weight', 'model.layers.13._orig_module.self_attn.v_proj.weight', 'model.layers.14._orig_module.input_layernorm.weight', 'model.layers.14._orig_module.mlp.down_proj.weight', 'model.layers.14._orig_module.mlp.gate_proj.weight', 'model.layers.14._orig_module.mlp.up_proj.weight', 'model.layers.14._orig_module.post_attention_layernorm.weight', 'model.layers.14._orig_module.self_attn.k_proj.weight', 'model.layers.14._orig_module.self_attn.o_proj.weight', 'model.layers.14._orig_module.self_attn.q_proj.weight', 
'model.layers.14._orig_module.self_attn.v_proj.weight', 'model.layers.15._orig_module.input_layernorm.weight', 'model.layers.15._orig_module.mlp.down_proj.weight', 'model.layers.15._orig_module.mlp.gate_proj.weight', 'model.layers.15._orig_module.mlp.up_proj.weight', 'model.layers.15._orig_module.post_attention_layernorm.weight', 'model.layers.15._orig_module.self_attn.k_proj.weight', 'model.layers.15._orig_module.self_attn.o_proj.weight', 'model.layers.15._orig_module.self_attn.q_proj.weight', 'model.layers.15._orig_module.self_attn.v_proj.weight', 'model.layers.16._orig_module.input_layernorm.weight', 'model.layers.16._orig_module.mlp.down_proj.weight', 'model.layers.16._orig_module.mlp.gate_proj.weight', 'model.layers.16._orig_module.mlp.up_proj.weight', 'model.layers.16._orig_module.post_attention_layernorm.weight', 'model.layers.16._orig_module.self_attn.k_proj.weight', 'model.layers.16._orig_module.self_attn.o_proj.weight', 'model.layers.16._orig_module.self_attn.q_proj.weight', 'model.layers.16._orig_module.self_attn.v_proj.weight', 'model.layers.17._orig_module.input_layernorm.weight', 'model.layers.17._orig_module.mlp.down_proj.weight', 'model.layers.17._orig_module.mlp.gate_proj.weight', 'model.layers.17._orig_module.mlp.up_proj.weight', 'model.layers.17._orig_module.post_attention_layernorm.weight', 'model.layers.17._orig_module.self_attn.k_proj.weight', 'model.layers.17._orig_module.self_attn.o_proj.weight', 'model.layers.17._orig_module.self_attn.q_proj.weight', 'model.layers.17._orig_module.self_attn.v_proj.weight', 'model.layers.2._orig_module.input_layernorm.weight', 'model.layers.2._orig_module.mlp.down_proj.weight', 'model.layers.2._orig_module.mlp.gate_proj.weight', 'model.layers.2._orig_module.mlp.up_proj.weight', 'model.layers.2._orig_module.post_attention_layernorm.weight', 'model.layers.2._orig_module.self_attn.k_proj.weight', 'model.layers.2._orig_module.self_attn.o_proj.weight', 'model.layers.2._orig_module.self_attn.q_proj.weight', 'model.layers.2._orig_module.self_attn.v_proj.weight', 'model.layers.3._orig_module.input_layernorm.weight', 'model.layers.3._orig_module.mlp.down_proj.weight', 'model.layers.3._orig_module.mlp.gate_proj.weight', 'model.layers.3._orig_module.mlp.up_proj.weight', 'model.layers.3._orig_module.post_attention_layernorm.weight', 'model.layers.3._orig_module.self_attn.k_proj.weight', 'model.layers.3._orig_module.self_attn.o_proj.weight', 'model.layers.3._orig_module.self_attn.q_proj.weight', 'model.layers.3._orig_module.self_attn.v_proj.weight', 'model.layers.4._orig_module.input_layernorm.weight', 'model.layers.4._orig_module.mlp.down_proj.weight', 'model.layers.4._orig_module.mlp.gate_proj.weight', 'model.layers.4._orig_module.mlp.up_proj.weight', 'model.layers.4._orig_module.post_attention_layernorm.weight', 'model.layers.4._orig_module.self_attn.k_proj.weight', 'model.layers.4._orig_module.self_attn.o_proj.weight', 'model.layers.4._orig_module.self_attn.q_proj.weight', 'model.layers.4._orig_module.self_attn.v_proj.weight', 'model.layers.5._orig_module.input_layernorm.weight', 'model.layers.5._orig_module.mlp.down_proj.weight', 'model.layers.5._orig_module.mlp.gate_proj.weight', 'model.layers.5._orig_module.mlp.up_proj.weight', 'model.layers.5._orig_module.post_attention_layernorm.weight', 'model.layers.5._orig_module.self_attn.k_proj.weight', 'model.layers.5._orig_module.self_attn.o_proj.weight', 'model.layers.5._orig_module.self_attn.q_proj.weight', 'model.layers.5._orig_module.self_attn.v_proj.weight', 
'model.layers.6._orig_module.input_layernorm.weight', 'model.layers.6._orig_module.mlp.down_proj.weight', 'model.layers.6._orig_module.mlp.gate_proj.weight', 'model.layers.6._orig_module.mlp.up_proj.weight', 'model.layers.6._orig_module.post_attention_layernorm.weight', 'model.layers.6._orig_module.self_attn.k_proj.weight', 'model.layers.6._orig_module.self_attn.o_proj.weight', 'model.layers.6._orig_module.self_attn.q_proj.weight', 'model.layers.6._orig_module.self_attn.v_proj.weight', 'model.layers.7._orig_module.input_layernorm.weight', 'model.layers.7._orig_module.mlp.down_proj.weight', 'model.layers.7._orig_module.mlp.gate_proj.weight', 'model.layers.7._orig_module.mlp.up_proj.weight', 'model.layers.7._orig_module.post_attention_layernorm.weight', 'model.layers.7._orig_module.self_attn.k_proj.weight', 'model.layers.7._orig_module.self_attn.o_proj.weight', 'model.layers.7._orig_module.self_attn.q_proj.weight', 'model.layers.7._orig_module.self_attn.v_proj.weight', 'model.layers.8._orig_module.input_layernorm.weight', 'model.layers.8._orig_module.mlp.down_proj.weight', 'model.layers.8._orig_module.mlp.gate_proj.weight', 'model.layers.8._orig_module.mlp.up_proj.weight', 'model.layers.8._orig_module.post_attention_layernorm.weight', 'model.layers.8._orig_module.self_attn.k_proj.weight', 'model.layers.8._orig_module.self_attn.o_proj.weight', 'model.layers.8._orig_module.self_attn.q_proj.weight', 'model.layers.8._orig_module.self_attn.v_proj.weight', 'model.layers.9._orig_module.input_layernorm.weight', 'model.layers.9._orig_module.mlp.down_proj.weight', 'model.layers.9._orig_module.mlp.gate_proj.weight', 'model.layers.9._orig_module.mlp.up_proj.weight', 'model.layers.9._orig_module.post_attention_layernorm.weight', 'model.layers.9._orig_module.self_attn.k_proj.weight', 'model.layers.9._orig_module.self_attn.o_proj.weight', 'model.layers.9._orig_module.self_attn.q_proj.weight', 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at adapters_merged and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 
'model.layers.16.self_attn.v_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.33it/s]
Inference with base model:
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein
I am
Inference with trained model:
Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa
@zorrofox that's the same as my result: https://github.com/huggingface/transformers/issues/29659#issuecomment-2006524634
But the inference result is very different.
@zorrofox nothing is different. Please go through the comment once again, and if it is different, what exactly is different?
Are you referring to this comment: https://github.com/huggingface/transformers/issues/29659#issuecomment-2006743324? There I tried without FSDP.
I think my original method for comparing weights was broken. When I access the parameters with params1 = model1.parameters(), the method returns a generator, which can only be iterated once. In my original comparison function I iterated over it twice, so my model-comparison function was buggy... :( Look at this sample code:
params1 = model1.parameters()
print(len(list(params1))) # prints 164
print(len(list(params1))) # prints 0
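A minimal sketch of a comparison that sidesteps the exhausted-generator problem by materializing named_parameters() into dictionaries (names and structure here are my own, not the exact function used earlier in the thread):

# Sketch: compare two models without reusing a consumed parameters() generator.
import torch

def weights_equal(model1, model2) -> bool:
    params1 = dict(model1.named_parameters())
    params2 = dict(model2.named_parameters())
    if params1.keys() != params2.keys():
        print("parameter names differ")
        return False
    for name, p1 in params1.items():
        if not torch.equal(p1.data, params2[name].data):
            print(f"parameter values differ: {name}")
            return False
    return True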
I tried your code @shub-kris and I got exactly the same result from the merged model:
Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa
That looks kind of broken, and I still get this warning when loading the merged model:
Some weights of the model checkpoint at adapters_merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 'model.layers.0._orig_module.mlp.down_proj.weight', 'model.layers.0._orig_module.mlp.gate_proj.weight', 'model.layers.0._orig_module.mlp.up_proj.weight', 'model.layers.0._orig_module.post_attention_layernorm.weight', 'model.layers.0._orig_module.self_attn.k_proj.weight', 'model.layers.0._orig_module.self_attn.o_proj.weight', 'model.layers.0._orig_module.self_attn.q_proj.weight', 'model.layers.0._orig_module.self_attn.v_proj.weight', 'model.layers.1._orig_module.input_layernorm.weight', 'model.layers.1._orig_module.mlp.down_proj.weight', 'model.layers.1._orig_module.mlp.gate_proj.weight', 'model.layers.1._orig_module.mlp.up_proj.weight', 'model.layers.1._orig_module.post_attention_layernorm.weight', 'model.layers.1._orig_module.self_attn.k_proj.weight', 'model.layers.1._orig_module.self_attn.o_proj.weight', 'model.layers.1._orig_module.self_attn.q_proj.weight', 'model.layers.1._orig_module.self_attn.v_proj.weight', 'model.layers.10._orig_module.input_layernorm.weight', 'model.layers.10._orig_module.mlp.down_proj.weight', 'model.layers.10._orig_module.mlp.gate_proj.weight', 'model.layers.10._orig_module.mlp.up_proj.weight', 'model.layers.10._orig_module.post_attention_layernorm.weight', 'model.layers.10._orig_module.self_attn.k_proj.weight', 'model.layers.10._orig_module.self_attn.o_proj.weight', 'model.layers.10._orig_module.self_attn.q_proj.weight', 'model.layers.10._orig_module.self_attn.v_proj.weight', 'model.layers.11._orig_module.input_layernorm.weight', 'model.layers.11._orig_module.mlp.down_proj.weight', 'model.layers.11._orig_module.mlp.gate_proj.weight', 'model.layers.11._orig_module.mlp.up_proj.weight', 'model.layers.11._orig_module.post_attention_layernorm.weight', 'model.layers.11._orig_module.self_attn.k_proj.weight', 'model.layers.11._orig_module.self_attn.o_proj.weight', 'model.layers.11._orig_module.self_attn.q_proj.weight', 'model.layers.11._orig_module.self_attn.v_proj.weight', 'model.layers.12._orig_module.input_layernorm.weight', 'model.layers.12._orig_module.mlp.down_proj.weight', 'model.layers.12._orig_module.mlp.gate_proj.weight', 'model.layers.12._orig_module.mlp.up_proj.weight', 'model.layers.12._orig_module.post_attention_layernorm.weight', 'model.layers.12._orig_module.self_attn.k_proj.weight', 'model.layers.12._orig_module.self_attn.o_proj.weight', 'model.layers.12._orig_module.self_attn.q_proj.weight', 'model.layers.12._orig_module.self_attn.v_proj.weight', 'model.layers.13._orig_module.input_layernorm.weight', 'model.layers.13._orig_module.mlp.down_proj.weight', 'model.layers.13._orig_module.mlp.gate_proj.weight', 'model.layers.13._orig_module.mlp.up_proj.weight', 'model.layers.13._orig_module.post_attention_layernorm.weight', 'model.layers.13._orig_module.self_attn.k_proj.weight', 'model.layers.13._orig_module.self_attn.o_proj.weight', 'model.layers.13._orig_module.self_attn.q_proj.weight', 'model.layers.13._orig_module.self_attn.v_proj.weight', 'model.layers.14._orig_module.input_layernorm.weight', 'model.layers.14._orig_module.mlp.down_proj.weight', 'model.layers.14._orig_module.mlp.gate_proj.weight', 'model.layers.14._orig_module.mlp.up_proj.weight', 'model.layers.14._orig_module.post_attention_layernorm.weight', 'model.layers.14._orig_module.self_attn.k_proj.weight', 'model.layers.14._orig_module.self_attn.o_proj.weight', 'model.layers.14._orig_module.self_attn.q_proj.weight', 
'model.layers.14._orig_module.self_attn.v_proj.weight', 'model.layers.15._orig_module.input_layernorm.weight', 'model.layers.15._orig_module.mlp.down_proj.weight', 'model.layers.15._orig_module.mlp.gate_proj.weight', 'model.layers.15._orig_module.mlp.up_proj.weight', 'model.layers.15._orig_module.post_attention_layernorm.weight', 'model.layers.15._orig_module.self_attn.k_proj.weight', 'model.layers.15._orig_module.self_attn.o_proj.weight', 'model.layers.15._orig_module.self_attn.q_proj.weight', 'model.layers.15._orig_module.self_attn.v_proj.weight', 'model.layers.16._orig_module.input_layernorm.weight', 'model.layers.16._orig_module.mlp.down_proj.weight', 'model.layers.16._orig_module.mlp.gate_proj.weight', 'model.layers.16._orig_module.mlp.up_proj.weight', 'model.layers.16._orig_module.post_attention_layernorm.weight', 'model.layers.16._orig_module.self_attn.k_proj.weight', 'model.layers.16._orig_module.self_attn.o_proj.weight', 'model.layers.16._orig_module.self_attn.q_proj.weight', 'model.layers.16._orig_module.self_attn.v_proj.weight', 'model.layers.17._orig_module.input_layernorm.weight', 'model.layers.17._orig_module.mlp.down_proj.weight', 'model.layers.17._orig_module.mlp.gate_proj.weight', 'model.layers.17._orig_module.mlp.up_proj.weight', 'model.layers.17._orig_module.post_attention_layernorm.weight', 'model.layers.17._orig_module.self_attn.k_proj.weight', 'model.layers.17._orig_module.self_attn.o_proj.weight', 'model.layers.17._orig_module.self_attn.q_proj.weight', 'model.layers.17._orig_module.self_attn.v_proj.weight', 'model.layers.2._orig_module.input_layernorm.weight', 'model.layers.2._orig_module.mlp.down_proj.weight', 'model.layers.2._orig_module.mlp.gate_proj.weight', 'model.layers.2._orig_module.mlp.up_proj.weight', 'model.layers.2._orig_module.post_attention_layernorm.weight', 'model.layers.2._orig_module.self_attn.k_proj.weight', 'model.layers.2._orig_module.self_attn.o_proj.weight', 'model.layers.2._orig_module.self_attn.q_proj.weight', 'model.layers.2._orig_module.self_attn.v_proj.weight', 'model.layers.3._orig_module.input_layernorm.weight', 'model.layers.3._orig_module.mlp.down_proj.weight', 'model.layers.3._orig_module.mlp.gate_proj.weight', 'model.layers.3._orig_module.mlp.up_proj.weight', 'model.layers.3._orig_module.post_attention_layernorm.weight', 'model.layers.3._orig_module.self_attn.k_proj.weight', 'model.layers.3._orig_module.self_attn.o_proj.weight', 'model.layers.3._orig_module.self_attn.q_proj.weight', 'model.layers.3._orig_module.self_attn.v_proj.weight', 'model.layers.4._orig_module.input_layernorm.weight', 'model.layers.4._orig_module.mlp.down_proj.weight', 'model.layers.4._orig_module.mlp.gate_proj.weight', 'model.layers.4._orig_module.mlp.up_proj.weight', 'model.layers.4._orig_module.post_attention_layernorm.weight', 'model.layers.4._orig_module.self_attn.k_proj.weight', 'model.layers.4._orig_module.self_attn.o_proj.weight', 'model.layers.4._orig_module.self_attn.q_proj.weight', 'model.layers.4._orig_module.self_attn.v_proj.weight', 'model.layers.5._orig_module.input_layernorm.weight', 'model.layers.5._orig_module.mlp.down_proj.weight', 'model.layers.5._orig_module.mlp.gate_proj.weight', 'model.layers.5._orig_module.mlp.up_proj.weight', 'model.layers.5._orig_module.post_attention_layernorm.weight', 'model.layers.5._orig_module.self_attn.k_proj.weight', 'model.layers.5._orig_module.self_attn.o_proj.weight', 'model.layers.5._orig_module.self_attn.q_proj.weight', 'model.layers.5._orig_module.self_attn.v_proj.weight', 
'model.layers.6._orig_module.input_layernorm.weight', 'model.layers.6._orig_module.mlp.down_proj.weight', 'model.layers.6._orig_module.mlp.gate_proj.weight', 'model.layers.6._orig_module.mlp.up_proj.weight', 'model.layers.6._orig_module.post_attention_layernorm.weight', 'model.layers.6._orig_module.self_attn.k_proj.weight', 'model.layers.6._orig_module.self_attn.o_proj.weight', 'model.layers.6._orig_module.self_attn.q_proj.weight', 'model.layers.6._orig_module.self_attn.v_proj.weight', 'model.layers.7._orig_module.input_layernorm.weight', 'model.layers.7._orig_module.mlp.down_proj.weight', 'model.layers.7._orig_module.mlp.gate_proj.weight', 'model.layers.7._orig_module.mlp.up_proj.weight', 'model.layers.7._orig_module.post_attention_layernorm.weight', 'model.layers.7._orig_module.self_attn.k_proj.weight', 'model.layers.7._orig_module.self_attn.o_proj.weight', 'model.layers.7._orig_module.self_attn.q_proj.weight', 'model.layers.7._orig_module.self_attn.v_proj.weight', 'model.layers.8._orig_module.input_layernorm.weight', 'model.layers.8._orig_module.mlp.down_proj.weight', 'model.layers.8._orig_module.mlp.gate_proj.weight', 'model.layers.8._orig_module.mlp.up_proj.weight', 'model.layers.8._orig_module.post_attention_layernorm.weight', 'model.layers.8._orig_module.self_attn.k_proj.weight', 'model.layers.8._orig_module.self_attn.o_proj.weight', 'model.layers.8._orig_module.self_attn.q_proj.weight', 'model.layers.8._orig_module.self_attn.v_proj.weight', 'model.layers.9._orig_module.input_layernorm.weight', 'model.layers.9._orig_module.mlp.down_proj.weight', 'model.layers.9._orig_module.mlp.gate_proj.weight', 'model.layers.9._orig_module.mlp.up_proj.weight', 'model.layers.9._orig_module.post_attention_layernorm.weight', 'model.layers.9._orig_module.self_attn.k_proj.weight', 'model.layers.9._orig_module.self_attn.o_proj.weight', 'model.layers.9._orig_module.self_attn.q_proj.weight', 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at adapters_merged and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 
'model.layers.16.self_attn.v_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
But maybe this should be addressed in another issue. Thanks once more for investigating and debugging.
@PawKanarek
That looks kinda broken, and I still experience this warning when loading the merged model
Is this happening when you're loading a saved model?
Is this happening when you're loading a saved model?
@amyeroberts No, I copied that warning message from the comment of @zorrofox https://github.com/huggingface/transformers/issues/29659#issuecomment-2007343622, but I remember that I also experienced this warning.
To be 100% certain, I once again launched the code from this comment of @shub-kris https://github.com/huggingface/transformers/issues/29659#issuecomment-2006524634, and this is my output:
torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.67it/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710873891.917161 1297506 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710873891.917242 1297506 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710873891.917250 1297506 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
/home/raix/trl/trl/trainer/sft_trainer.py:316: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
warnings.warn(
0%| | 0/30 [00:00<?, ?it/s]/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/nn/modules/module.py:1597: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
warnings.warn("For backward hooks to be called,"
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1709797140173/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
{'loss': 4.254, 'grad_norm': 3.453125, 'learning_rate': 0.00029, 'epoch': 0.33}
{'loss': 4.1319, 'grad_norm': 3.78125, 'learning_rate': 0.00028, 'epoch': 0.67}
{'loss': 3.8043, 'grad_norm': 4.5, 'learning_rate': 0.00027, 'epoch': 1.0}
{'loss': 3.4729, 'grad_norm': 3.859375, 'learning_rate': 0.00026, 'epoch': 1.33}
{'loss': 3.1394, 'grad_norm': 3.375, 'learning_rate': 0.00025, 'epoch': 1.67}
{'loss': 2.9524, 'grad_norm': 2.015625, 'learning_rate': 0.00023999999999999998, 'epoch': 2.0}
{'loss': 2.8268, 'grad_norm': 1.703125, 'learning_rate': 0.00023, 'epoch': 2.33}
{'loss': 2.6656, 'grad_norm': 1.4609375, 'learning_rate': 0.00021999999999999995, 'epoch': 2.67}
{'loss': 2.7338, 'grad_norm': 4.21875, 'learning_rate': 0.00020999999999999998, 'epoch': 3.0}
{'loss': 2.6369, 'grad_norm': 2.40625, 'learning_rate': 0.00019999999999999998, 'epoch': 3.33}
{'loss': 2.5441, 'grad_norm': 2.21875, 'learning_rate': 0.00018999999999999998, 'epoch': 3.67}
{'loss': 2.4651, 'grad_norm': 2.6875, 'learning_rate': 0.00017999999999999998, 'epoch': 4.0}
{'loss': 2.3907, 'grad_norm': 11.375, 'learning_rate': 0.00016999999999999999, 'epoch': 4.33}
{'loss': 2.3174, 'grad_norm': 3.875, 'learning_rate': 0.00015999999999999999, 'epoch': 4.67}
{'loss': 2.489, 'grad_norm': 1.609375, 'learning_rate': 0.00015, 'epoch': 5.0}
{'loss': 2.2825, 'grad_norm': 1.4921875, 'learning_rate': 0.00014, 'epoch': 5.33}
{'loss': 2.3592, 'grad_norm': 2.3125, 'learning_rate': 0.00013, 'epoch': 5.67}
{'loss': 2.4066, 'grad_norm': 1.859375, 'learning_rate': 0.00011999999999999999, 'epoch': 6.0}
{'loss': 2.2769, 'grad_norm': 2.515625, 'learning_rate': 0.00010999999999999998, 'epoch': 6.33}
{'loss': 2.2699, 'grad_norm': 1.65625, 'learning_rate': 9.999999999999999e-05, 'epoch': 6.67}
{'loss': 2.267, 'grad_norm': 1.4765625, 'learning_rate': 8.999999999999999e-05, 'epoch': 7.0}
{'loss': 2.0841, 'grad_norm': 1.21875, 'learning_rate': 7.999999999999999e-05, 'epoch': 7.33}
{'loss': 2.3272, 'grad_norm': 2.0625, 'learning_rate': 7e-05, 'epoch': 7.67}
{'loss': 2.2218, 'grad_norm': 2.6875, 'learning_rate': 5.9999999999999995e-05, 'epoch': 8.0}
{'loss': 2.1625, 'grad_norm': 0.74609375, 'learning_rate': 4.9999999999999996e-05, 'epoch': 8.33}
{'loss': 2.1687, 'grad_norm': 1.203125, 'learning_rate': 3.9999999999999996e-05, 'epoch': 8.67}
{'loss': 2.153, 'grad_norm': 7.65625, 'learning_rate': 2.9999999999999997e-05, 'epoch': 9.0}
{'loss': 2.1273, 'grad_norm': 1.359375, 'learning_rate': 1.9999999999999998e-05, 'epoch': 9.33}
{'loss': 2.1455, 'grad_norm': 3.015625, 'learning_rate': 9.999999999999999e-06, 'epoch': 9.67}
{'loss': 2.2011, 'grad_norm': 1.0078125, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 250.5815, 'train_samples_per_second': 3.831, 'train_steps_per_second': 0.12, 'train_loss': 2.6092514197031655, 'epoch': 10.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [04:10<00:00, 8.35s/it]
tcmalloc: large alloc 2097152000 bytes == 0x52ee08000 @ 0x7f0701396680 0x7f07013b7824 0x7f07013b7b8a 0x7f06e82d38e4 0x7f06e8298d03 0x7f06e9836af9 0x7f06e9830754 0x7f06e983079f 0x7f06e98307e5 0x7f06e9fcaf90 0x7f06eac15c91 0x7f06eac15ceb 0x7f06ea832d67 0x7f06eabdc25f 0x7f06ea87ad80 0x7f06f4dfaf12 0x4fc697 0x5089a9 0x4f2a14 0x4f561d 0x505be8 0x4f619b 0x4f40b0 0x4f561d 0x505be8 0x4f619b 0x4f434a 0x4f561d 0x505be8 0x4f64b6 0x5089a9
tcmalloc: large alloc 2097152000 bytes == 0x6a8630000 @ 0x7f0701396680 0x7f07013b7824 0x7f07013b7b8a 0x7f06e82d38e4 0x7f06e8298d03 0x7f06e9836af9 0x7f06e9830754 0x7f06e983079f 0x7f06e98307e5 0x7f06e9fcaf90 0x7f06eac15c91 0x7f06eac15ceb 0x7f06ea832d67 0x7f06eabdc25f 0x7f06ea87ad80 0x7f06f4dfaf12 0x4fc697 0x5089a9 0x4f2a14 0x4fcadf 0x4f56cd 0x505be8 0x4f619b 0x4f3851 0x4f561d 0x505be8 0x4f64b6 0x5089a9 0x4efb19 0x507eae 0x508858
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.93it/s]
Some weights of the model checkpoint at adapters_merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', (...) 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at adapters_merged and are newly initialized: ['model.layers.0.input_layernorm.weight', (...) 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tcmalloc: large alloc 2097152000 bytes == 0x6a8630000 @ 0x7f0701396680 0x7f07013b7824 0x7f07013b7b8a 0x7f06e82d38e4 0x7f06e8298d03 0x7f06e9836af9 0x7f06e9830754 0x7f06e983079f 0x7f06e98307e5 0x7f06e9fcaf90 0x7f06eac15c91 0x7f06eac15ceb 0x7f06ea832d67 0x7f06eabdc25f 0x7f06ea87ad80 0x7f06f4dfaf12 0x4fc697 0x5089a9 0x4f2a14 0x4f561d 0x505be8 0x4f619b 0x4f40b0 0x4f561d 0x505be8 0x4f619b 0x4f434a 0x4f561d 0x505be8 0x4f64b6 0x5089a9
tcmalloc: large alloc 2097152000 bytes == 0x8fde30000 @ 0x7f0701396680 0x7f07013b7824 0x7f07013b7b8a 0x7f06e82d38e4 0x7f06e8298d03 0x7f06e9836af9 0x7f06e9830754 0x7f06e983079f 0x7f06e98307e5 0x7f06e9fcaf90 0x7f06eac15c91 0x7f06eac15ceb 0x7f06ea832d67 0x7f06eabdc25f 0x7f06ea87ad80 0x7f06f4dfaf12 0x4fc697 0x5089a9 0x4f2a14 0x4fcadf 0x4f56cd 0x505be8 0x4f619b 0x4f3851 0x4f561d 0x505be8 0x4f64b6 0x5089a9 0x4efb19 0x507eae 0x508858
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.79it/s]
Inference with base model:
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein
I am
Inference with trained model:
Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa
As you can see, I also experience that kind of warning when loading the merged model.
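For reference, the base-vs-trained comparison printed above could be reproduced with something along these lines; the model paths and the prompt are assumptions based on the logs, not the exact script that was run:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed names: the base checkpoint and the merged output directory seen in the logs.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
prompt = "Quote: Imagination is more"
inputs = tokenizer(prompt, return_tensors="pt")

for label, path in [("base model", "google/gemma-2b"), ("trained model", "adapters_merged")]:
    model = AutoModelForCausalLM.from_pretrained(path)
    output = model.generate(**inputs, max_new_tokens=20)
    print(f"Inference with {label}:")
    print(tokenizer.decode(output[0], skip_special_tokens=True))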
@PawKanarek I am now able to replicate the error/warning you get; earlier I didn't.
While debugging, I encountered this error only when running with fsdp. I am looking into what is not working, whether it's the saving or something else.
Can you please re-run the script with these lines commented out:
#dataloader_drop_last = True, # Required for SPMD.
#fsdp="full_shard",
#fsdp_config=fsdp_config,
reduce the batch size according to your TPU, and post the results here again.
cc @amyeroberts
@shub-kris With FSDP commented out and batch_size reduced to 1, I could finally see a properly fine-tuned model without any warnings.
@PawKanarek thank you for the confirmation. I now need to look into what's going wrong when we use FSDP, whether it's the saving or something else.
cc @amyeroberts
@alanwaketan can you please take a look into it.
I think the issue is that for an FSDP-wrapped model, we need to unwrap the model before saving it. I have given instructions to @shub-kris for fixing the unwrap logic in HF.
If things don't work out in HF, I will provide a utility in torch-xla to unwrap the model.
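For what it's worth, the `_orig_module` prefix in the warnings above comes from the FSDP wrapper. A minimal sketch of one possible workaround, assuming the only mismatch is that prefix in the parameter names (this is not the fix that went into the PR, and the directory name is taken from the logs above), would be to rename the keys before saving the merged model:
# Hypothetical workaround, not the actual fix: strip the FSDP wrapper's
# "._orig_module" prefix from the parameter names so that GemmaForCausalLM
# can match them when the checkpoint is loaded again.
merged_model = trainer.model.merge_and_unload()
clean_state_dict = {
    name.replace("._orig_module", ""): tensor.cpu()
    for name, tensor in merged_model.state_dict().items()
}
merged_model.save_pretrained("adapters_merged", state_dict=clean_state_dict)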
@PawKanarek @zorrofox can you now try with the PR: #29780
For me everything works perfectly now.
@shub-kris This time the merged-model loading warning has disappeared, but the inference result is not very good.
@zorrofox try training longer, as your losses are still high. I don't remember the exact hyperparameters I tried, but I was able to get decent results.
Thanks for confirming that the issue is resolved regarding saving and reloading the weights.
That's great news @shub-kris! Thank you for the quick fix and hard work! I will post an update when I'm done with my current trainings (because my workaround still works and I don't want to break my pipeline). Could you provide minimal pseudo-code with the correct pattern for unloading and merging the LoRA adapter into a standalone model? Will this be correct?
trainer = SFTTrainer(...)
trainer.train()
merged_model = trainer.model.merge_and_unload() # merge LORA with base model
merged_model.to("cpu")
merged_model.save_pretrained("adapters_merged")
Is this OK? Or do I also need to call trainer.save_model() after training?
@PawKanarek In the training script, I would recommend doing the training and then saving the model:
trainer.train()
# saving final model
trainer.save_model()
Merging can be done in a separate script to avoid any kind of TPU or FSDP wrapper issues. I follow the approach mentioned here: https://huggingface.co/docs/trl/en/use_model#use-adapters-peft
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model_name = "google/gemma-2b"
adapter_model_name = "fsdp_output"
print(f"Adapter model is {adapter_model_name}")

# Load the base model and the trained PEFT (LoRA) adapter on top of it
model = AutoModelForCausalLM.from_pretrained(base_model_name)
trained_peft_model = PeftModel.from_pretrained(model, adapter_model_name)

# Merge the LoRA weights into the base model and save it as a standalone model
merged_model = trained_peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model")
The fix from #29780 and saving with the given pattern work flawlessly. Thank you @shub-kris 👨💻
System Info
torch.__version__='2.3.0.dev20240307', torch_xla.__version__='2.3.0+git46e2230', peft.__version__='0.9.0', trl.__version__='0.7.12.dev0', Python 3.10.13
Who can help?
@ArthurZucker, @younesbelkada, @muellerzr, @pacman100
Information
Tasks
Reproduction
Hello, I have a problem with training the gemma-2b-it model on Google TPU v3-8. My goal is to train it with a PEFT LoRA adapter, and then save it as a standalone model. For merging the base model with the LoRA adapter I was following this guide: https://huggingface.co/docs/trl/main/en/use_model The training code is based on this blog post: https://huggingface.co/blog/gemma-peft
The problem is that the training takes a while (for 300k rows in the data loader it might take as long as 8 hours), but after training the model seems… untrained. The inference output looks almost identical to the output of the base model.
Furthermore, when I check the weights of the trained and original models, they appear to be identical.
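A comparison like the one described above can be done with a small helper (a hypothetical sketch; the actual script used its own comparison function):
import torch

def weights_identical(model_a, model_b) -> bool:
    # Compare two models parameter by parameter; returns True only if the key
    # sets match and every tensor is element-wise equal.
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    if sd_a.keys() != sd_b.keys():
        return False
    return all(
        torch.equal(sd_a[k].cpu().float(), sd_b[k].cpu().float()) for k in sd_a
    )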
I also consistently encounter the following error message while loading the saved model:
Below is the minimal working code that trains and saves the model.
And this is the output
I'm stuck, so I'm asking for help. I tried many combinations of PeftModel.merge_and_unload(), save_pretrained(), and trainer.save_model(), and nothing seems to work. Every idea to push this issue forward will be appreciated. Thanks.
Expected behavior
Training trains the model.