pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

How can I convert llama3 safetensors to the pth file needed to use with executorch? #3303

Closed l3utterfly closed 6 months ago

l3utterfly commented 6 months ago

Fine-tunes of Llama3 usually only have safetensors files uploaded. In order to compile a Llama3 model following the tutorial, I need the original .pth checkpoint file.

Is there a way to convert the safetensors to the checkpoint file?

mergennachin commented 6 months ago

@l3utterfly

Take a look at some example util functions in torchtune. Let us know if it works:

https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_checkpointing/_checkpointer_utils.py#L66C18-L69
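
For illustration, a minimal sketch (not the torchtune helper itself) of reading sharded safetensors into a single state dict with safe_open and saving it with torch.save; file names and paths are placeholders:

from safetensors import safe_open
import torch

state_dict = {}
# Shard names are illustrative; use your model's actual *.safetensors files.
for shard in ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]:
    with safe_open(shard, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

# Note: the keys keep whatever naming convention the safetensors files use (HF-style for HF uploads).
torch.save(state_dict, "checkpoint.pth")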

l3utterfly commented 6 months ago

So I need to use PyTorch to save the state dict file?

When I tried that with a Llama3 fine-tune and then compiled it for XNNPACK, I got this error:

Could not import fairseq2 modules.
[INFO 2024-04-25 08:51:48,994 builder.py:84] Loading model with checkpoint=/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B.pth, params=/home/layla/src/text-generation-webui/models/Meta-Llama-3-8B-Instruct/original/params.json, use_kv_cache=True, weight_type=WeightType.LLAMA
[INFO 2024-04-25 08:51:49,226 builder.py:105] Loaded model with dtype=torch.bfloat16
[INFO 2024-04-25 08:51:49,240 config.py:58] PyTorch version 2.4.0.dev20240422+cpu available.
linear: layers.0.attention.wq, in=4096, out=4096
Traceback (most recent call last):
  File "/home/layla/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/layla/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/layla/src/executorch/examples/models/llama2/export_llama.py", line 30, in <module>
    main()  # pragma: no cover
  File "/home/layla/src/executorch/examples/models/llama2/export_llama.py", line 26, in main
    export_llama(modelname, args)
  File "/home/layla/src/executorch/examples/models/llama2/export_llama_lib.py", line 302, in export_llama
    return _export_llama(modelname, args)
  File "/home/layla/src/executorch/examples/models/llama2/export_llama_lib.py", line 380, in _export_llama
    builder_exported_to_edge = _prepare_for_llama_export(
  File "/home/layla/src/executorch/examples/models/llama2/export_llama_lib.py", line 365, in _prepare_for_llama_export
    .source_transform(transforms)
  File "/home/layla/src/executorch/examples/models/llama2/builder.py", line 203, in source_transform
    self.model = transform(self.model)
  File "/home/layla/src/executorch/examples/models/llama2/source_transformation/quantize.py", line 80, in quantize
    ).quantize(model)
  File "/home/layla/miniconda3/envs/executorch/lib/python3.10/site-packages/torchao/quantization/GPTQ.py", line 1256, in quantize
    state_dict = self._create_quantized_state_dict(model)
  File "/home/layla/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/layla/miniconda3/envs/executorch/lib/python3.10/site-packages/torchao/quantization/GPTQ.py", line 1214, in _create_quantized_state_dict
    ) = group_quantize_tensor_symmetric(
  File "/home/layla/miniconda3/envs/executorch/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 509, in group_quantize_tensor_symmetric
    scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision)
  File "/home/layla/miniconda3/envs/executorch/lib/python3.10/site-packages/torchao/quantization/quant_primitives.py", line 482, in get_group_qparams_symmetric
    assert torch.isnan(to_quant).sum() == 0
NotImplementedError: aten::_local_scalar_dense: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

This is the code I used to save the .pth file:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

print('loading model...')

model_id = "/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)

print('saving model...')

# Save only the state dictionary
torch.save(model.state_dict(), "/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B.pth")

What am I doing wrong here?

iseeyuan commented 6 months ago

@l3utterfly Could you share the command you used in torchtune, as well as the export_llama command?

l3utterfly commented 6 months ago

@iseeyuan sorry, I'm a little new to torchtune; I'm following the documentation here: https://pytorch.org/torchtune/stable/deep_dives/checkpointer.html#understand-checkpointer

  1. I am loading my safetensors file first:

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

  2. Then, I save it to a torch state dictionary:

    torch.save(model.state_dict(), "/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B.pth")

This is the command I'm using to convert: python -m examples.models.llama2.export_llama --checkpoint /models/Aura_Uncensored_l3_8B.pth -p /models/Meta-Llama-3-8B-Instruct/original/params.json -d=fp32 -X -qmode 8da4w -kv --use_sdpa_with_kv_cache --output_name="llama3_aura_uncensored_kv_sdpa_xnn_qe_4_32_ctx4096.pte" --group_size 128 --metadata '{"get_bos_id":128000, "get_eos_id":128001}' --embedding-quantize 4,32 --max_seq_len 4096

iseeyuan commented 6 months ago

@l3utterfly I can see two options:

  1. Have you tried directly opening your .safetensors file with safe_open and using the code @mergennachin suggested? Then you could use torch.save to save the resulting state_dict. Please let us know if it works.
  2. Fine-tune the model using torchtune directly. The output of the tuned model is saved directly as a PyTorch checkpoint, which has been verified to be compatible with ExecuTorch.

l3utterfly commented 6 months ago

@iseeyuan this is the error I'm getting after converting the safetensors file with the torchtune util function you suggested. The error happens when running the command to compile the .pth down to ExecuTorch:

NotImplementedError: aten::_local_scalar_dense: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

I tried to fine-tune with torchtune, but it appears torchtune does not support fine-tuning on top of another fine-tune? It is still looking for the original/checkpoint.pth file, but fine-tunes do not have that.

Any way forward from here?

iseeyuan commented 6 months ago

@l3utterfly Let me try to convert the safetensors files and let you know if there's a way to work around it.

I tried to fine-tune with torchtune, but it appears torchtune does not support fine-tuning on top of another fine-tune? It is still looking for the original/checkpoint.pth file, but fine-tunes do not have that.

@kartikayk, are you aware of this?

kartikayk commented 6 months ago

@l3utterfly torchtune doesn't really care about how the checkpoint is produced, i.e. whether it's a fine-tuned model or a pre-trained model. All it cares about is that the format matches what the checkpointer expects.

When the model changes, you need to update the config to point to the right checkpoint etc. Can you share the exact torchtune config you're using and the command you used to launch training?

l3utterfly commented 6 months ago

@kartikayk I am trying to convert this model: https://huggingface.co/ResplendentAI/Aura_Uncensored_l3_8B

It doesn't contain any of the original PyTorch checkpoint files that torchtune supports, so trying to fine-tune with torchtune gets me back to square one: how can I convert the safetensors files to a PyTorch checkpoint?

kartikayk commented 6 months ago

This should work OOTB if you update the checkpoint files in the config to point to the safetensors. I tried loading the checkpoint into the Llama3 8B model in torchtune and the keys loaded successfully:

[screenshot: the state dict keys loading into the torchtune Llama3 8B model]

Note that you need to update the checkpointer to point to the HFCheckpointer, since the safetensors files are in the HF format. The deep dive you pointed to above has a lot more information about checkpointer formats, but let me know if you have questions.

I'm not sure what your config looks like, but just update the Llama3 config to the following:

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /data/users/kartikayk/cpts/Aura_Uncensored_l3_8B
  checkpoint_files: [
   model-00001-of-00002.safetensors,
   model-00002-of-00002.safetensors
  ]
  output_dir: /data/users/kartikayk/cpts/Aura_Uncensored_l3_8B
  model_type: LLAMA3
resume_from_checkpoint: False

For examples of how to use safetensors, take a look at the Gemma configs.

l3utterfly commented 6 months ago

Thanks for the help, I got the fine-tuning to run now with this config.

First, I tried to load and save the state dict right away:

from torchtune.utils import FullModelHFCheckpointer
import torch

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir='/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B',
    checkpoint_files=['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors'],
    output_dir='/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B',
    model_type='LLAMA3'
)

print("loading checkpoint")
sd = checkpointer.load_checkpoint()

print("saving checkpoint")
torch.save(sd, "/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B/checkpoint.pth")

This still gives the same error as before when trying to compile down to executorch.

NotImplementedError: aten::_local_scalar_dense: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

I am now in the process of doing a "finetune" with learning_rate=0 to get the final checkpoint. Is there any way to save checkpoints every X steps instead of at the end?

kartikayk commented 6 months ago

Glad this worked!

I am now in the process of doing a "finetune" with learning_rate=0 to get the final checkpoint. Is there any way to save checkpoints every X steps instead of at the end?

Yes, if you just want to get dummy checkpoints you can set max_steps_per_epoch=1 and gradient_accumulation_steps=1. For checkpointing more frequently during training, unfortunately we don't support mid-epoch checkpointing because we don't have a good way to checkpoint the dataloader. This is WIP, and once it's supported we can do mid-epoch checkpointing. Let me know if this makes sense.
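
For reference, a sketch of those two overrides in config form, assuming the Llama3 fine-tuning config shown earlier in this thread (exact placement depends on your recipe):

# Finish an epoch after a single optimizer step so a checkpoint is written right away.
max_steps_per_epoch: 1
gradient_accumulation_steps: 1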

This still gives the same error as before when trying to compile down to executorch.

I don't think executorch supports the HF format though. @iseeyuan can confirm.

iseeyuan commented 6 months ago

@l3utterfly I took a deeper look into the state dict (sd) in your code. There are two issues.

  1. The checkpoint returned by checkpointer.load_checkpoint() is a dictionary with one element: the key is 'model' and the value is the actual state dict we are looking for.

Try exactly like what @kartikayk suggested:

sd = checkpointer.load_checkpoint()

print("saving checkpoint")
torch.save(sd['model'], "/home/layla/src/text-generation-webui/models/Aura_Uncensored_l3_8B/checkpoint.pth")

  2. I printed out the state dict from the original Llama3 checkpoint and from the checkpoint I got in step 1, and I see differences: the tensor names don't match. This can cause a mismatch between the model and the checkpoint. For example, the original layer 0 looks like:
    layers.0.attention.wk.weight torch.Size([1024, 4096])
    layers.0.attention.wo.weight torch.Size([4096, 4096])
    layers.0.attention.wq.weight torch.Size([4096, 4096])
    layers.0.attention.wv.weight torch.Size([1024, 4096])
    layers.0.attention_norm.weight torch.Size([4096])
    layers.0.feed_forward.w1.weight torch.Size([14336, 4096])
    layers.0.feed_forward.w2.weight torch.Size([4096, 14336])
    layers.0.feed_forward.w3.weight torch.Size([14336, 4096])
    layers.0.ffn_norm.weight torch.Size([4096])

    But the converted layer 0 looks like:

    layers.0.attn.k_proj.weight torch.Size([1024, 4096])
    layers.0.attn.output_proj.weight torch.Size([4096, 4096])
    layers.0.attn.q_proj.weight torch.Size([4096, 4096])
    layers.0.attn.v_proj.weight torch.Size([1024, 4096])
    layers.0.mlp.w1.weight torch.Size([14336, 4096])
    layers.0.mlp.w2.weight torch.Size([4096, 14336])
    layers.0.mlp.w3.weight torch.Size([14336, 4096])
    layers.0.mlp_norm.scale torch.Size([4096])
    layers.0.sa_norm.scale torch.Size([4096])

    I don't know what causes these name differences. Since our flow starts from the original Llama3 checkpoint, I'd suggest you use the same checkpoint and iterate based on that.
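
For illustration, a minimal sketch of how to print the tensor names and shapes like the listings above, assuming the .pth file contains a flat state dict (the path is a placeholder):

import torch

# Load the checkpoint on CPU and list every tensor's name and shape.
sd = torch.load("checkpoint.pth", map_location="cpu")
for name in sorted(sd):
    print(name, tuple(sd[name].shape))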

iseeyuan commented 6 months ago

Update: After chatting with @kartikayk, we need another conversion from the torchtune format to Meta's Llama3 format. So the code below should work:

from torchtune.utils import FullModelHFCheckpointer
from torchtune.models import convert_weights
import torch

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir='/Users/myuan/.cache/huggingface/hub/models--ResplendentAI--Aura_Uncensored_l3_8B/snapshots/e7720d40e4d8d3c0fa07a8a579fda4d0644aa731',
    checkpoint_files=['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors'],
    output_dir='/Users/myuan/data/Aura_Uncensored_l3_8B',
    model_type='LLAMA3'
)

print("loading checkpoint")
sd = checkpointer.load_checkpoint()
sd = convert_weights.tune_to_meta(sd['model'])

print("saving checkpoint")
torch.save(sd, "/Users/myuan/data/Aura_Uncensored_l3_8B/checkpoint.pth")

It works on my side and successfully lowers the checkpoint to ExecuTorch. @l3utterfly, could you try the above conversion and let us know if it works?

l3utterfly commented 6 months ago

Yes! This works, thank you so much for helping me!

I think it may be helpful to put/link this script in the ExecuTorch Llama3 docs? It would really help accelerate adoption of ExecuTorch by the wider community: people could then load Llama3 fine-tunes with ExecuTorch instead of only working with the base model, which will encourage the local AI community to build infrastructure around ExecuTorch!

iseeyuan commented 6 months ago

@l3utterfly That's a great idea. Let me put up a PR for this with documentation.

puja93 commented 6 months ago

Hey @kartikayk, do you know if torchtune can also convert the quantized version of the Llama3 safetensors back to .pth?

[screenshot: layer names from the quantized Llama3 safetensors checkpoint]

As you can see, every mlp or attention module has expanded to include 5 more tensors each (bias, g_idx, qweight, qzeros, scales). I naively tried the code you posted, @iseeyuan; of course it doesn't match model_type = LLAMA3.

Any ideas, perhaps?

Thanks

guotong1988 commented 5 months ago

from torchtune.utils import FullModelHFCheckpointer
from torchtune.models import convert_weights
import torch

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="pythonProject/llama3_main/meta-llama-3-8b-instruct/",
    checkpoint_files=["model-00001-of-00004.safetensors", "model-00002-of-00004.safetensors",
                      "model-00003-of-00004.safetensors", "model-00004-of-00004.safetensors"],
    output_dir="./tmp",
    model_type='LLAMA3'
)

print("loading checkpoint")
sd = checkpointer.load_checkpoint()
sd = convert_weights.tune_to_meta(sd['model'])
print("saving checkpoint")
torch.save(sd, "./tmp/checkpoint.pth")

BUT

Traceback (most recent call last):
  File "pythonProject/convert.py", line 1, in <module>
    from torchtune.utils import FullModelHFCheckpointer
  File "python3.8/site-packages/torchtune/__init__.py", line 9, in <module>
    from torchtune import datasets, models, modules, utils
  File "python3.8/site-packages/torchtune/datasets/__init__.py", line 7, in <module>
    from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset
  File "python3.8/site-packages/torchtune/datasets/_alpaca.py", line 10, in <module>
    from torchtune.datasets._instruct import InstructDataset
  File "python3.8/site-packages/torchtune/datasets/_instruct.py", line 12, in <module>
    from torchtune.config._utils import _get_instruct_template
  File "python3.8/site-packages/torchtune/config/__init__.py", line 7, in <module>
    from ._instantiate import instantiate
  File "python3.8/site-packages/torchtune/config/_instantiate.py", line 12, in <module>
    from torchtune.config._utils import _get_component_from_path, _has_component
  File "python3.8/site-packages/torchtune/config/_utils.py", line 16, in <module>
    from torchtune.utils import get_logger, get_world_size_and_rank
  File "python3.8/site-packages/torchtune/utils/__init__.py", line 7, in <module>
    from ._checkpointing import (  # noqa
  File "python3.8/site-packages/torchtune/utils/_checkpointing/__init__.py", line 7, in <module>
    from ._checkpointer import (  # noqa
  File "python3.8/site-packages/torchtune/utils/_checkpointing/_checkpointer.py", line 17, in <module>
    from torchtune.models import convert_weights
  File "python3.8/site-packages/torchtune/models/__init__.py", line 7, in <module>
    from torchtune.models import convert_weights, gemma, llama2, mistral  # noqa
  File "python3.8/site-packages/torchtune/models/gemma/__init__.py", line 7, in <module>
    from ._component_builders import gemma  # noqa
  File "python3.8/site-packages/torchtune/models/gemma/_component_builders.py", line 9, in <module>
    from torchtune.modules import (
  File "python3.8/site-packages/torchtune/modules/__init__.py", line 8, in <module>
    from .common_utils import reparametrize_as_dtype_state_dict_post_hook
  File "python3.8/site-packages/torchtune/modules/common_utils.py", line 12, in <module>
    from torchao.dtypes.nf4tensor import NF4Tensor
  File "python3.8/site-packages/torchao/__init__.py", line 2, in <module>
    from .quantization.quant_api import apply_dynamic_quant
  File "python3.8/site-packages/torchao/quantization/__init__.py", line 7, in <module>
    from .smoothquant import *  # noqa: F403
  File "python3.8/site-packages/torchao/quantization/smoothquant.py", line 18, in <module>
    import torchao.quantization.quant_api as quant_api
  File "python3.8/site-packages/torchao/quantization/quant_api.py", line 22, in <module>
    from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
  File "python3.8/site-packages/torchao/quantization/dynamic_quant.py", line 10, in <module>
    from .quant_primitives import (
  File "python3.8/site-packages/torchao/quantization/quant_primitives.py", line 9, in <module>
    from torch._higher_order_ops.out_dtype import out_dtype
ModuleNotFoundError: No module named 'torch._higher_order_ops'

WITH

torch==2.0.0    
torchtune==0.1.1    
transformers==4.41.1    
safetensors==0.4.3

Thank you very much @iseeyuan